OpenDAS / deepspeed

Commit 4acf0e01
authored Apr 26, 2023 by aiss

delete hip file

parent 7dd68788
Changes: 83

Showing 20 changed files with 0 additions and 11099 deletions (+0 -11099)
csrc/transformer_bak/inference/csrc/apply_rotary_pos_emb.hip    +0 -374
csrc/transformer_bak/inference/csrc/dequantize.cu               +0 -110
csrc/transformer_bak/inference/csrc/dequantize.hip              +0 -112
csrc/transformer_bak/inference/csrc/gelu.cu                     +0 -525
csrc/transformer_bak/inference/csrc/gelu.hip                    +0 -527
csrc/transformer_bak/inference/csrc/normalize.cu                +0 -451
csrc/transformer_bak/inference/csrc/normalize.hip               +0 -453
csrc/transformer_bak/inference/csrc/pt_binding.cpp              +0 -911
csrc/transformer_bak/inference/csrc/pt_binding_hip.cpp          +0 -912
csrc/transformer_bak/inference/csrc/softmax.cu                  +0 -432
csrc/transformer_bak/inference/csrc/softmax.hip                 +0 -434
csrc/transformer_bak/inference/includes/context.h               +0 -177
csrc/transformer_bak/inference/includes/context_hip.h           +0 -178
csrc/transformer_bak/inference/includes/cublas_wrappers.h       +0 -207
csrc/transformer_bak/inference/includes/cublas_wrappers_hip.h   +0 -208
csrc/transformer_bak/inference/includes/custom_cuda_layers.h    +0 -124
csrc/transformer_bak/inference/includes/custom_hip_layers.h     +0 -125
csrc/transformer_bak/normalize_kernels.cu                       +0 -2121
csrc/transformer_bak/normalize_kernels.hip                      +0 -2123
csrc/transformer_bak/softmax_kernels.cu                         +0 -595
csrc/transformer_bak/inference/csrc/apply_rotary_pos_emb.hip  deleted 100644 → 0
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "custom_hip_layers.h"
//#include <cuda_profiler_api.h>
namespace cg = cooperative_groups;
__global__ void apply_rotary_pos_emb(float* mixed_query,
float* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = mixed_query[offset + lane];
float k = key_layer[offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
q_rot = g.shfl_xor(q_rot, 1);
k_rot = g.shfl_xor(k_rot, 1);
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = q;
key_layer[offset + lane] = k;
lane += WARP_SIZE;
}
}
}
__global__ void apply_rotary_pos_emb(__half* mixed_query,
__half* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
#if __CUDA_ARCH__ >= 700
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = (float)mixed_query[offset + lane];
float k = (float)key_layer[offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
q_rot = g.shfl_xor(q_rot, 1);
k_rot = g.shfl_xor(k_rot, 1);
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = (__half)q;
key_layer[offset + lane] = (__half)k;
lane += WARP_SIZE;
}
}
#endif
}
__global__ void apply_rotary_pos_emb1(float* mixed_query,
float* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = mixed_query[offset + lane];
float k = key_layer[offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
q_rot = g.shfl_xor(q_rot, 1);
k_rot = g.shfl_xor(k_rot, 1);
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = q;
key_layer[offset + lane] = k;
lane += WARP_SIZE;
}
}
}
__global__ void apply_rotary_pos_emb1(__half* mixed_query,
__half* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
#if __CUDA_ARCH__ >= 700
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
constexpr unsigned mask[32] = {
0x1 | 0x1000, 0x2 | 0x2000, 0x4 | 0x4000, 0x8 | 0x8000, 0x10 | 0x10000,
0x20 | 0x20000, 0x40 | 0x40000, 0x80 | 0x80000, 0x100 | 0x100000, 0x200 | 0x200000,
0x400 | 0x400000, 0x800 | 0x800000, 0x1000 | 0x1, 0x2000 | 0x2, 0x4000 | 0x4,
0x8000 | 0x8, 0x10000 | 0x10, 0x20000 | 0x20, 0x40000 | 0x40, 0x80000 | 0x80,
0x100000 | 0x100, 0x200000 | 0x200, 0x400000 | 0x400, 0x800000 | 0x800, 0x1000000,
0x2000000, 0x4000000, 0x8000000, 0x10000000, 0x20000000,
0x40000000, 0x80000000};
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
unsigned half_dim = rotary_dim >> 1;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane % half_dim) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = (float)mixed_query[offset + lane];
float k = (float)key_layer[offset + lane];
float rotary_sign = (lane > (half_dim - 1) ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
auto q_rot_tmp = lane < half_dim ? __shfl_sync(mask[lane], q_rot, lane + half_dim)
: __shfl_sync(mask[lane], q_rot, lane - half_dim);
auto k_rot_tmp = lane < half_dim ? __shfl_sync(mask[lane], k_rot, lane + half_dim)
: __shfl_sync(mask[lane], k_rot, lane - half_dim);
q = q * cosf(inv_freq) + q_rot_tmp * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot_tmp * sinf(inv_freq);
mixed_query[offset + lane] = (__half)q;
key_layer[offset + lane] = (__half)k;
lane += WARP_SIZE;
}
}
#endif
}
template <typename T>
void launch_apply_rotary_pos_emb(T* mixed_query,
T* key_layer,
unsigned head_size,
unsigned seq_len,
unsigned rotary_dim,
unsigned offset,
unsigned num_heads,
unsigned batch,
bool rotate_half,
bool rotate_every_two,
hipStream_t stream)
{
int total_count = batch * num_heads * seq_len;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / MAX_WARP_NUM + 1); // (batch_size);
if (rotate_every_two)
hipLaunchKernelGGL(( apply_rotary_pos_emb), dim3(grid_dims), dim3(block_dims), 0, stream,
mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
else if (rotate_half)
hipLaunchKernelGGL(( apply_rotary_pos_emb1), dim3(grid_dims), dim3(block_dims), 0, stream,
mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
}
template void launch_apply_rotary_pos_emb<float>(float*,
float*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
bool,
bool,
hipStream_t);
template void launch_apply_rotary_pos_emb<__half>(__half*,
__half*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
bool,
bool,
hipStream_t);
/*
__global__ void apply_rotary_pos_emb(float* mixed_query,
float* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = mixed_query[offset + lane];
float k = key_layer[offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
q_rot = g.shfl_xor(q_rot, 1);
k_rot = g.shfl_xor(k_rot, 1);
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = q;
key_layer[offset + lane] = k;
lane += WARP_SIZE;
}
}
}
__global__ void apply_rotary_pos_emb(__half* mixed_query,
__half* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
#if __CUDA_ARCH__ >= 700
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
constexpr unsigned mask[32] = {0x1 | 0x1000, 0x2 | 0x2000, 0x4 | 0x4000, 0x8 | 0x8000,
0x10 | 0x10000, 0x20 | 0x20000, 0x40 | 0x40000, 0x80 | 0x80000,
0x100 | 0x100000, 0x200 | 0x200000, 0x400 | 0x400000, 0x800 | 0x800000,
0x1000 | 0x1, 0x2000 | 0x2, 0x4000 | 0x4, 0x8000 | 0x8,
0x10000 | 0x10, 0x20000 | 0x20, 0x40000 | 0x40, 0x80000 | 0x80,
0x100000 | 0x100, 0x200000 | 0x200, 0x400000 | 0x400, 0x800000 | 0x800,
0x1000000, 0x2000000, 0x4000000, 0x8000000,
0x10000000, 0x20000000, 0x40000000, 0x80000000};
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
if (head_id < total_count) {
while (lane < rotary_dim) {
//float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
float inv_freq = (float)((lane % (rotary_dim >> 1)) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = (float)mixed_query[offset + lane];
float k = (float)key_layer[offset + lane];
float rotary_sign = (lane > 11 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
auto q_rot_tmp = lane < 12 ? __shfl_sync(mask[lane], q_rot, lane + 12)
                           : __shfl_sync(mask[lane], q_rot, lane - 12);  // g.shfl_xor(q_rot, 12);
auto k_rot_tmp = lane < 12 ? __shfl_sync(mask[lane], k_rot, lane + 12)
                           : __shfl_sync(mask[lane], k_rot, lane - 12);  // g.shfl_xor(k_rot, 12);
q = q * cosf(inv_freq) + q_rot_tmp * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot_tmp * sinf(inv_freq);
mixed_query[offset + lane] = (__half)q;
key_layer[offset + lane] = (__half)k;
lane += WARP_SIZE;
}
}
#endif
}
template <typename T>
void launch_apply_rotary_pos_emb(T* mixed_query,
T* key_layer,
unsigned head_size,
unsigned seq_len,
unsigned rotary_dim,
unsigned offset,
unsigned num_heads,
unsigned batch,
hipStream_t stream)
{
int total_count = batch * num_heads * seq_len;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / MAX_WARP_NUM + 1); // (batch_size);
hipLaunchKernelGGL((
apply_rotary_pos_emb), dim3(grid_dims), dim3(block_dims), 0, stream,
mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
}
template void launch_apply_rotary_pos_emb<float>(float*,
float*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
hipStream_t);
template void launch_apply_rotary_pos_emb<__half>(__half*,
__half*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
hipStream_t);
*/
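For orientation: the two active apply_rotary_pos_emb kernels above implement the "rotate every two" form of rotary position embedding (lanes 2i and 2i+1 form a rotated pair), while apply_rotary_pos_emb1 implements "rotate half" (lane i pairs with lane i + rotary_dim/2). A minimal single-threaded C sketch of the rotate-every-two math, written here only to clarify what the shfl_xor exchange computes and not part of the deleted sources:

#include <math.h>

/* Hypothetical CPU reference for the rotate-every-two path above. q and k
   hold one head's values at absolute position seq_id; only the first
   rotary_dim entries are rotated, mirroring the kernel's while (lane <
   rotary_dim) loop. */
void rotary_every_two_ref(float* q, float* k, unsigned rotary_dim, unsigned seq_id)
{
    for (unsigned i = 0; i + 1 < rotary_dim; i += 2) {
        /* Both lanes of a pair share one frequency: ((i / 2) * 2) == i for even i. */
        float inv_freq = (float)i / (float)rotary_dim;
        float theta = 1.0f / powf(10000.0f, inv_freq) * (float)seq_id;
        float c = cosf(theta), s = sinf(theta);
        float q0 = q[i], q1 = q[i + 1], k0 = k[i], k1 = k[i + 1];
        /* In the kernel, the even lane receives its odd partner negated
           (rotary_sign = -1) and the odd lane receives its even partner
           unchanged, via g.shfl_xor(..., 1). That is exactly a 2-D rotation: */
        q[i] = q0 * c - q1 * s;
        q[i + 1] = q1 * c + q0 * s;
        k[i] = k0 * c - k1 * s;
        k[i + 1] = k1 * c + k0 * s;
    }
}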
csrc/transformer_bak/inference/csrc/dequantize.cu  deleted 100644 → 0
#include "custom_cuda_layers.h"
#define MAX_QUANTIZE_GROUPING 1024
#define loop_unroll 1
#define loop_unroll_bits 1
__global__ void dequantize_kernel(float* output,
                                  const int8_t* input,
                                  const float* qscale,
                                  int output_size,
                                  int hidden_dim,
                                  int groups,
                                  int merge_count)
{
    unsigned merge_hidden = hidden_dim >> merge_count;
    unsigned quantization_stride = (merge_hidden * output_size) / groups;

    unsigned bid = blockIdx.x;
    unsigned tid = threadIdx.x;

    while (tid < output_size) {
        unsigned w_index = bid / merge_hidden;
        unsigned q_index = tid + bid * output_size;

        auto q = input[q_index];

        unsigned merge_hidden_total = w_index * merge_hidden;
        unsigned scale_index =
            ((((bid - merge_hidden_total) + tid * merge_hidden) / quantization_stride)
             << merge_count) +
            w_index;

        float scale_data = qscale[scale_index];

        output[q_index] = (scale_data * (float)q);
        tid += blockDim.x;
    }
}

__global__ void dequantize_kernel(__half* output,
                                  const int8_t* input,
                                  const float* qscale,
                                  unsigned output_size,
                                  unsigned hidden_dim,
                                  unsigned groups,
                                  unsigned merge_count)
{
#ifdef HALF_PRECISION_AVAILABLE
    unsigned merge_hidden = hidden_dim >> merge_count;
    unsigned quantization_stride = (merge_hidden * output_size) / groups;

    unsigned bid = blockIdx.x;
    unsigned tid = threadIdx.x;

    while (tid < output_size) {
        unsigned w_index = bid / merge_hidden;
        unsigned q_index = tid + bid * output_size;

        auto q = input[q_index];

        unsigned merge_hidden_total = w_index * merge_hidden;
        unsigned scale_index =
            ((((bid - merge_hidden_total) + tid * merge_hidden) / quantization_stride)
             << merge_count) +
            w_index;

        float scale_data = qscale[scale_index];

        output[q_index] = __float2half(scale_data * (float)q);
        tid += blockDim.x;
    }
#endif
}

template <typename T>
void launch_dequantize(T* output,
                       const int8_t* input,
                       const float* qscale,
                       unsigned output_size,
                       unsigned hidden_dim,
                       unsigned groups,
                       unsigned merge_count,
                       cudaStream_t stream)
{
    unsigned threads = 1024;
    dim3 block_dims(threads);
    dim3 grid_dims(hidden_dim);

    dequantize_kernel<<<grid_dims, block_dims, 0, stream>>>(
        output, input, qscale, output_size, hidden_dim, groups, merge_count);
}

template void launch_dequantize<float>(float*,
                                       const int8_t*,
                                       const float*,
                                       unsigned,
                                       unsigned,
                                       unsigned,
                                       unsigned,
                                       cudaStream_t);
template void launch_dequantize<__half>(__half*,
                                        const int8_t*,
                                        const float*,
                                        unsigned,
                                        unsigned,
                                        unsigned,
                                        unsigned,
                                        cudaStream_t);
csrc/transformer_bak/inference/csrc/dequantize.hip  deleted 100644 → 0
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "custom_hip_layers.h"
#define MAX_QUANTIZE_GROUPING 1024
#define loop_unroll 1
#define loop_unroll_bits 1
__global__ void dequantize_kernel(float* output,
const int8_t* input,
const float* qscale,
int output_size,
int hidden_dim,
int groups,
int merge_count)
{
unsigned merge_hidden = hidden_dim >> merge_count;
unsigned quantization_stride = (merge_hidden * output_size) / groups;
unsigned bid = blockIdx.x;
unsigned tid = threadIdx.x;
while (tid < output_size) {
unsigned w_index = bid / merge_hidden;
unsigned q_index = tid + bid * output_size;
auto q = input[q_index];
unsigned merge_hidden_total = w_index * merge_hidden;
unsigned scale_index =
((((bid - merge_hidden_total) + tid * merge_hidden) / quantization_stride)
<< merge_count) +
w_index;
float scale_data = qscale[scale_index];
output[q_index] = (scale_data * (float)q);
tid += blockDim.x;
}
}
__global__ void dequantize_kernel(__half* output,
const int8_t* input,
const float* qscale,
unsigned output_size,
unsigned hidden_dim,
unsigned groups,
unsigned merge_count)
{
#ifdef HALF_PRECISION_AVAILABLE
unsigned merge_hidden = hidden_dim >> merge_count;
unsigned quantization_stride = (merge_hidden * output_size) / groups;
unsigned bid = blockIdx.x;
unsigned tid = threadIdx.x;
while (tid < output_size) {
unsigned w_index = bid / merge_hidden;
unsigned q_index = tid + bid * output_size;
auto q = input[q_index];
unsigned merge_hidden_total = w_index * merge_hidden;
unsigned scale_index =
((((bid - merge_hidden_total) + tid * merge_hidden) / quantization_stride)
<< merge_count) +
w_index;
float scale_data = qscale[scale_index];
output[q_index] = __float2half(scale_data * (float)q);
tid += blockDim.x;
}
#endif
}
template <typename T>
void launch_dequantize(T* output,
const int8_t* input,
const float* qscale,
unsigned output_size,
unsigned hidden_dim,
unsigned groups,
unsigned merge_count,
hipStream_t stream)
{
unsigned threads = 1024;
dim3 block_dims(threads);
dim3 grid_dims(hidden_dim);
hipLaunchKernelGGL(( dequantize_kernel), dim3(grid_dims), dim3(block_dims), 0, stream,
output, input, qscale, output_size, hidden_dim, groups, merge_count);
}
template void launch_dequantize<float>(float*,
const int8_t*,
const float*,
unsigned,
unsigned,
unsigned,
unsigned,
hipStream_t);
template void launch_dequantize<__half>(__half*,
const int8_t*,
const float*,
unsigned,
unsigned,
unsigned,
unsigned,
hipStream_t);
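The .cu and .hip versions of this file are identical apart from the hipify header and the launch syntax (<<<...>>> versus hipLaunchKernelGGL). The kernel itself is a per-element rescale: each int8 value is multiplied by the scale of the quantization group it belongs to. A single-threaded sketch, hypothetical and written only to make the scale_index arithmetic explicit:

#include <stdint.h>

/* Hypothetical CPU reference for dequantize_kernel above. The grid maps one
   block per bid in [0, hidden_dim) and strides threads over output_size, so
   these loops visit exactly the same (bid, tid) pairs as the kernel. */
void dequantize_ref(float* output, const int8_t* input, const float* qscale,
                    unsigned output_size, unsigned hidden_dim,
                    unsigned groups, unsigned merge_count)
{
    unsigned merge_hidden = hidden_dim >> merge_count;
    unsigned quantization_stride = (merge_hidden * output_size) / groups;
    for (unsigned bid = 0; bid < hidden_dim; bid++) {
        for (unsigned tid = 0; tid < output_size; tid++) {
            unsigned w_index = bid / merge_hidden; /* which merged weight slice */
            unsigned q_index = tid + bid * output_size;
            unsigned scale_index =
                ((((bid - w_index * merge_hidden) + tid * merge_hidden) / quantization_stride)
                 << merge_count) +
                w_index; /* group id, interleaved across the merged slices */
            output[q_index] = qscale[scale_index] * (float)input[q_index];
        }
    }
}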
csrc/transformer_bak/inference/csrc/gelu.cu  deleted 100644 → 0
#include "custom_cuda_layers.h"
#define MAX_CAP 4
#define MAX_SEQ 2048
inline __device__ float gelu(const float x)
{
    const float sqrt_param = 0.79788456080286535587989211986876f;
    const float mul_param = 0.044715;
    return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x)));
}

__global__ void fused_bias_gelu(float* input,
                                const float* bias,
                                int total_count,
                                int intermediate_size)
{
    float4* input_cast = reinterpret_cast<float4*>(input);
    const float4* bias_cast = reinterpret_cast<const float4*>(bias);
    int offset = blockIdx.x * blockDim.x + threadIdx.x;

    if (offset < total_count) {
        float4 data = input_cast[offset];
        float4 bias_data = bias_cast[offset % intermediate_size];

        data.x += bias_data.x;
        data.y += bias_data.y;
        data.z += bias_data.z;
        data.w += bias_data.w;

        data.x = gelu(data.x);
        data.y = gelu(data.y);
        data.z = gelu(data.z);
        data.w = gelu(data.w);

        input_cast[offset] = data;
    }
}

__global__ void fused_bias_gelu(__half* input,
                                const __half* bias,
                                int total_count,
                                int intermediate_size)
{
#ifdef HALF_PRECISION_AVAILABLE
    float2* input_cast = reinterpret_cast<float2*>(input);
    const float2* bias_cast = reinterpret_cast<const float2*>(bias);
    int offset = blockIdx.x * blockDim.x + threadIdx.x;

    if (offset < total_count) {
        float2 vals_vec = input_cast[offset];
        float2 bias_vec = bias_cast[offset % intermediate_size];

        __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
        __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);

        float2 low_data = __half22float2(vals_half[0]);
        float2 high_data = __half22float2(vals_half[1]);
        float2 low_bias = __half22float2(bias_half[0]);
        float2 high_bias = __half22float2(bias_half[1]);

        low_data.x += low_bias.x;
        low_data.y += low_bias.y;
        high_data.x += high_bias.x;
        high_data.y += high_bias.y;

        low_data.x = gelu(low_data.x);
        low_data.y = gelu(low_data.y);
        high_data.x = gelu(high_data.x);
        high_data.y = gelu(high_data.y);

        vals_half[0] = __float22half2_rn(low_data);
        vals_half[1] = __float22half2_rn(high_data);

        input_cast[offset] = vals_vec;
    }
#endif
}

template <typename T>
void launch_bias_gelu(T* input,
                      const T* bias,
                      int intermediate_size,
                      int batch_size,
                      cudaStream_t stream)
{
    int total_count = batch_size * (intermediate_size / 4);
    int threads = 1024;  // intermediate_size / iterations / 4;
    dim3 block_dims(threads);
    dim3 grid_dims(((total_count - 1) / 1024 + 1));  // (batch_size);

    fused_bias_gelu<<<grid_dims, block_dims, 0, stream>>>(
        input, bias, total_count, intermediate_size / 4);
}

template void launch_bias_gelu<float>(float*, const float*, int, int, cudaStream_t);
template void launch_bias_gelu<__half>(__half*, const __half*, int, int, cudaStream_t);

__global__ void fused_bias_add(float* input, const float* bias, int total_count, int hidden_size)
{
    float4* input_cast = reinterpret_cast<float4*>(input);
    const float4* bias_cast = reinterpret_cast<const float4*>(bias);
    int offset = blockIdx.x * blockDim.x + threadIdx.x;

    if (offset < total_count) {
        float4 data = input_cast[offset];
        float4 bias_data = bias_cast[offset % hidden_size];

        data.x += bias_data.x;
        data.y += bias_data.y;
        data.z += bias_data.z;
        data.w += bias_data.w;

        input_cast[offset] = data;
    }
}

__global__ void fused_bias_add(__half* input, const __half* bias, int total_count, int hidden_size)
{
#ifdef HALF_PRECISION_AVAILABLE
    float2* input_cast = reinterpret_cast<float2*>(input);
    const float2* bias_cast = reinterpret_cast<const float2*>(bias);
    int offset = blockIdx.x * blockDim.x + threadIdx.x;

    if (offset < total_count) {
        float2 vals_vec = input_cast[offset];
        float2 bias_vec = bias_cast[offset % hidden_size];

        __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
        __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);

        float2 low_data = __half22float2(vals_half[0]);
        float2 high_data = __half22float2(vals_half[1]);
        float2 low_bias = __half22float2(bias_half[0]);
        float2 high_bias = __half22float2(bias_half[1]);

        low_data.x += low_bias.x;
        low_data.y += low_bias.y;
        high_data.x += high_bias.x;
        high_data.y += high_bias.y;

        vals_half[0] = __float22half2_rn(low_data);
        vals_half[1] = __float22half2_rn(high_data);

        input_cast[offset] = vals_vec;
    }
#endif
}

template <typename T>
void launch_bias_add(T* input, const T* bias, int hidden_size, int batch_size, cudaStream_t stream)
{
    int total_count = batch_size * (hidden_size / 4);
    int threads = 1024;  // hidden_size / iterations / 4;
    dim3 block_dims(threads);
    dim3 grid_dims(((total_count - 1) / threads + 1));  // (batch_size);

    fused_bias_add<<<grid_dims, block_dims, 0, stream>>>(
        input, bias, total_count, hidden_size / 4);
}

template void launch_bias_add<float>(float*, const float*, int, int, cudaStream_t);
template void launch_bias_add<__half>(__half*, const __half*, int, int, cudaStream_t);

__global__ void fused_bias_residual(float* input,
                                    float* output,
                                    float* attn,
                                    float* bias,
                                    float* attnbias,
                                    int total_count,
                                    int intermediate_size,
                                    int mp_size)
{
    float4* input_cast = reinterpret_cast<float4*>(input);
    float4* output_cast = reinterpret_cast<float4*>(output);
    float4* attn_cast = reinterpret_cast<float4*>(attn);
    float4* bias_cast = reinterpret_cast<float4*>(bias);
    float4* attnbias_cast = reinterpret_cast<float4*>(attnbias);
    int offset = blockIdx.x * blockDim.x + threadIdx.x;

    if (offset < total_count) {
        float4 data = input_cast[offset];
        float4 out = output_cast[offset];
        float4 res_vec = attn_cast[offset];
        float4 bias_data = bias_cast[offset % intermediate_size];
        float4 attn_bias = attnbias_cast[offset % intermediate_size];

        data.x = (data.x + res_vec.x) * mp_size + (out.x + bias_data.x + attn_bias.x);
        data.y = (data.y + res_vec.y) * mp_size + (out.y + bias_data.y + attn_bias.y);
        data.z = (data.z + res_vec.z) * mp_size + (out.z + bias_data.z + attn_bias.z);
        data.w = (data.w + res_vec.w) * mp_size + (out.w + bias_data.w + attn_bias.w);

        output_cast[offset] = data;
    }
}

__global__ void fused_bias_residual(__half* input,
                                    __half* output,
                                    __half* attn,
                                    __half* bias,
                                    __half* attn_bias,
                                    int total_count,
                                    int intermediate_size,
                                    int mp_size)
{
#ifdef HALF_PRECISION_AVAILABLE
    float2* input_cast = reinterpret_cast<float2*>(input);
    float2* output_cast = reinterpret_cast<float2*>(output);
    float2* attn_cast = reinterpret_cast<float2*>(attn);
    float2* bias_cast = reinterpret_cast<float2*>(bias);
    float2* attnbias_cast = reinterpret_cast<float2*>(attn_bias);
    int offset = blockIdx.x * blockDim.x + threadIdx.x;

    if (offset < total_count) {
        float2 vals_vec = input_cast[offset];
        float2 out_vec = output_cast[offset];
        float2 res_vec = attn_cast[offset];
        float2 bias_vec = bias_cast[offset % intermediate_size];
        float2 attn_bias_vec = attnbias_cast[offset % intermediate_size];

        __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
        __half2* out_half = reinterpret_cast<__half2*>(&out_vec);
        __half2* res_half = reinterpret_cast<__half2*>(&res_vec);
        __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
        __half2* attnbias_half = reinterpret_cast<__half2*>(&attn_bias_vec);

        float2 low_data = __half22float2(vals_half[0]);
        float2 high_data = __half22float2(vals_half[1]);
        float2 low_out = __half22float2(out_half[0]);
        float2 high_out = __half22float2(out_half[1]);
        float2 low_res = __half22float2(res_half[0]);
        float2 high_res = __half22float2(res_half[1]);
        float2 low_bias = __half22float2(bias_half[0]);
        float2 high_bias = __half22float2(bias_half[1]);
        float2 attn_low_bias = __half22float2(attnbias_half[0]);
        float2 attn_high_bias = __half22float2(attnbias_half[1]);

        low_data.x =
            (low_data.x + low_res.x) * mp_size + (low_out.x + (low_bias.x + attn_low_bias.x));
        low_data.y =
            (low_data.y + low_res.y) * mp_size + (low_out.y + (low_bias.y + attn_low_bias.y));
        high_data.x =
            (high_data.x + high_res.x) * mp_size + (high_out.x + (high_bias.x + attn_high_bias.x));
        high_data.y =
            (high_data.y + high_res.y) * mp_size + (high_out.y + (high_bias.y + attn_high_bias.y));

        vals_half[0] = __float22half2_rn(low_data);
        vals_half[1] = __float22half2_rn(high_data);

        output_cast[offset] = vals_vec;
    }
#endif
}

template <typename T>
void launch_bias_residual(T* input,
                          T* output,
                          T* attn,
                          T* bias,
                          T* attn_bias,
                          int batch,
                          int hidden_dim,
                          int mp_size,
                          cudaStream_t stream)
{
    int total_count = batch * hidden_dim / 4;
    dim3 block_dims(1024);
    dim3 grid_dims((total_count - 1) / 1024 + 1);  // (batch_size);

    fused_bias_residual<<<grid_dims, block_dims, 0, stream>>>(
        input, output, attn, bias, attn_bias, total_count, hidden_dim / 4, 1.0 / mp_size);
}

template void
launch_bias_residual<float>(float*, float*, float*, float*, float*, int, int, int, cudaStream_t);
template void launch_bias_residual<__half>(__half*,
                                           __half*,
                                           __half*,
                                           __half*,
                                           __half*,
                                           int,
                                           int,
                                           int,
                                           cudaStream_t);

__global__ void gptj_residual_add(float* input,
                                  float* output,
                                  float* attn,
                                  float* bias,
                                  float* attnbias,
                                  int total_count,
                                  int intermediate_size,
                                  float mp_size)
{
    float4* input_cast = reinterpret_cast<float4*>(input);
    float4* output_cast = reinterpret_cast<float4*>(output);
    float4* attn_cast = reinterpret_cast<float4*>(attn);
    float4* bias_cast = reinterpret_cast<float4*>(bias);
    float4* attnbias_cast = reinterpret_cast<float4*>(attnbias);
    int offset = blockIdx.x * blockDim.x + threadIdx.x;

    if (offset < total_count) {
        float4 data = input_cast[offset];
        float4 out = output_cast[offset];
        float4 res_vec = attn_cast[offset];
        float4 bias_data = bias_cast[offset % intermediate_size];
        float4 attn_bias = attnbias_cast[offset % intermediate_size];

        data.x = data.x * mp_size + (out.x + res_vec.x + bias_data.x + attn_bias.x);
        data.y = data.y * mp_size + (out.y + res_vec.y + bias_data.y + attn_bias.y);
        data.z = data.z * mp_size + (out.z + res_vec.z + bias_data.z + attn_bias.z);
        data.w = data.w * mp_size + (out.w + res_vec.w + bias_data.w + attn_bias.w);

        output_cast[offset] = data;
    }
}

__global__ void gptj_residual_add(__half* input,
                                  __half* output,
                                  __half* attn,
                                  __half* bias,
                                  __half* attn_bias,
                                  int total_count,
                                  int intermediate_size,
                                  float mp_size)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
    float2* input_cast = reinterpret_cast<float2*>(input);
    float2* output_cast = reinterpret_cast<float2*>(output);
    float2* attn_cast = reinterpret_cast<float2*>(attn);
    float2* bias_cast = reinterpret_cast<float2*>(bias);
    float2* attnbias_cast = reinterpret_cast<float2*>(attn_bias);
    int offset = blockIdx.x * blockDim.x + threadIdx.x;

    if (offset < total_count) {
        float2 vals_vec = input_cast[offset];
        float2 out_vec = output_cast[offset];
        float2 res_vec = attn_cast[offset];
        float2 bias_vec = bias_cast[offset % intermediate_size];
        float2 attn_bias_vec = attnbias_cast[offset % intermediate_size];

        __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
        __half2* out_half = reinterpret_cast<__half2*>(&out_vec);
        __half2* res_half = reinterpret_cast<__half2*>(&res_vec);
        __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
        __half2* attnbias_half = reinterpret_cast<__half2*>(&attn_bias_vec);

        float2 low_data = __half22float2(vals_half[0]);
        float2 high_data = __half22float2(vals_half[1]);
        float2 low_out = __half22float2(out_half[0]);
        float2 high_out = __half22float2(out_half[1]);
        float2 low_res = __half22float2(res_half[0]);
        float2 high_res = __half22float2(res_half[1]);
        float2 low_bias = __half22float2(bias_half[0]);
        float2 high_bias = __half22float2(bias_half[1]);
        float2 attn_low_bias = __half22float2(attnbias_half[0]);
        float2 attn_high_bias = __half22float2(attnbias_half[1]);

        low_data.x =
            low_data.x * mp_size + (low_out.x + low_res.x + (low_bias.x + attn_low_bias.x));
        low_data.y =
            low_data.y * mp_size + (low_out.y + low_res.y + (low_bias.y + attn_low_bias.y));
        high_data.x =
            high_data.x * mp_size + (high_out.x + high_res.x + (high_bias.x + attn_high_bias.x));
        high_data.y =
            high_data.y * mp_size + (high_out.y + high_res.y + (high_bias.y + attn_high_bias.y));

        vals_half[0] = __float22half2_rn(low_data);
        vals_half[1] = __float22half2_rn(high_data);

        output_cast[offset] = vals_vec;
    }
#endif
}

template <typename T>
void launch_gptj_residual_add(T* input,
                              T* output,
                              T* attn,
                              T* bias,
                              T* attn_bias,
                              int hidden_dim,
                              int batch,
                              int mp_size,
                              cudaStream_t stream)
{
    int total_count = batch * hidden_dim / 4;
    dim3 block_dims(1024);
    dim3 grid_dims((total_count - 1) / 1024 + 1);  // (batch_size);

    gptj_residual_add<<<grid_dims, block_dims, 0, stream>>>(
        input, output, attn, bias, attn_bias, total_count, hidden_dim / 4, 1.0 / mp_size);
}

template void launch_gptj_residual_add<float>(float*,
                                              float*,
                                              float*,
                                              float*,
                                              float*,
                                              int,
                                              int,
                                              int,
                                              cudaStream_t);
template void launch_gptj_residual_add<__half>(__half*,
                                               __half*,
                                               __half*,
                                               __half*,
                                               __half*,
                                               int,
                                               int,
                                               int,
                                               cudaStream_t);

__global__ void moe_res_matmul(float* residual,
                               float* coef,
                               float* mlp_out,
                               int seq_len,
                               int hidden_dim)
{
    unsigned tid = threadIdx.x;

    float4* residual_cast = reinterpret_cast<float4*>(residual);
    float4* coef_cast = reinterpret_cast<float4*>(coef);
    float4* mlp_out_cast = reinterpret_cast<float4*>(mlp_out);

    residual_cast += blockIdx.x * hidden_dim;
    mlp_out_cast += blockIdx.x * hidden_dim;

    float4* coef_cast2 = coef_cast + hidden_dim;

    while (tid < hidden_dim) {
        float4 res = residual_cast[tid];
        float4 mlp = mlp_out_cast[tid];
        float4 coef1 = coef_cast[tid];
        float4 coef2 = coef_cast2[tid];

        mlp.x = mlp.x * coef2.x + res.x * coef1.x;
        mlp.y = mlp.y * coef2.y + res.y * coef1.y;
        mlp.z = mlp.z * coef2.z + res.z * coef1.z;
        mlp.w = mlp.w * coef2.w + res.w * coef1.w;

        mlp_out_cast[tid] = mlp;
        tid += blockDim.x;
    }
}

__global__ void moe_res_matmul(__half* residual,
                               __half* coef,
                               __half* mlp_out,
                               int seq_len,
                               int hidden_dim)
{
    unsigned tid = threadIdx.x;

    float2* residual_cast = reinterpret_cast<float2*>(residual);
    float2* mlp_out_cast = reinterpret_cast<float2*>(mlp_out);
    float2* coef_cast = reinterpret_cast<float2*>(coef);
    float2* coef_cast2 = coef_cast + hidden_dim;

    residual_cast += blockIdx.x * hidden_dim;
    mlp_out_cast += blockIdx.x * hidden_dim;

    while (tid < hidden_dim) {
        float2 res = residual_cast[tid];
        float2 coef1 = coef_cast[tid];
        float2 coef2 = coef_cast2[tid];
        float2 data = mlp_out_cast[tid];
        __half* data_h = reinterpret_cast<__half*>(&data);
        __half* coef1_h = reinterpret_cast<__half*>(&coef1);
        __half* coef2_h = reinterpret_cast<__half*>(&coef2);
        __half* res_h = reinterpret_cast<__half*>(&res);
        data_h[0] = res_h[0] * coef1_h[0] + data_h[0] * coef2_h[0];
        data_h[1] = res_h[1] * coef1_h[1] + data_h[1] * coef2_h[1];
        data_h[2] = res_h[2] * coef1_h[2] + data_h[2] * coef2_h[2];
        data_h[3] = res_h[3] * coef1_h[3] + data_h[3] * coef2_h[3];

        mlp_out_cast[tid] = data;
        tid += blockDim.x;
    }
}

template <typename T>
void launch_moe_res_matmul(T* residual,
                           T* coef,
                           T* mlp_out,
                           int seq_len,
                           int hidden_dim,
                           cudaStream_t stream)
{
    dim3 grid_dim(seq_len);
    dim3 block_dim(1024);
    moe_res_matmul<<<grid_dim, block_dim, 0, stream>>>(
        residual, coef, mlp_out, seq_len, hidden_dim / 4);
}

template void launch_moe_res_matmul(float* residual,
                                    float* coef,
                                    float* mlp_out,
                                    int seq_len,
                                    int hidden_dim,
                                    cudaStream_t stream);
template void launch_moe_res_matmul(__half* residual,
                                    __half* coef,
                                    __half* mlp_out,
                                    int seq_len,
                                    int hidden_dim,
                                    cudaStream_t stream);
csrc/transformer_bak/inference/csrc/gelu.hip  deleted 100644 → 0
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "custom_hip_layers.h"
#define MAX_CAP 4
#define MAX_SEQ 2048
inline __device__ float gelu(const float x)
{
const float sqrt_param = 0.79788456080286535587989211986876f;
const float mul_param = 0.044715;
return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x)));
}
__global__ void fused_bias_gelu(float* input,
const float* bias,
int total_count,
int intermediate_size)
{
float4* input_cast = reinterpret_cast<float4*>(input);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float4 data = input_cast[offset];
float4 bias_data = bias_cast[offset % intermediate_size];
data.x += bias_data.x;
data.y += bias_data.y;
data.z += bias_data.z;
data.w += bias_data.w;
data.x = gelu(data.x);
data.y = gelu(data.y);
data.z = gelu(data.z);
data.w = gelu(data.w);
input_cast[offset] = data;
}
}
__global__ void fused_bias_gelu(__half* input,
const __half* bias,
int total_count,
int intermediate_size)
{
#ifdef HALF_PRECISION_AVAILABLE
float2* input_cast = reinterpret_cast<float2*>(input);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float2 vals_vec = input_cast[offset];
float2 bias_vec = bias_cast[offset % intermediate_size];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);
low_data.x += low_bias.x;
low_data.y += low_bias.y;
high_data.x += high_bias.x;
high_data.y += high_bias.y;
low_data.x = gelu(low_data.x);
low_data.y = gelu(low_data.y);
high_data.x = gelu(high_data.x);
high_data.y = gelu(high_data.y);
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
input_cast[offset] = vals_vec;
}
#endif
}
template <typename T>
void launch_bias_gelu(T* input,
const T* bias,
int intermediate_size,
int batch_size,
hipStream_t stream)
{
int total_count = batch_size * (intermediate_size / 4);
int threads = 1024; // intermediate_size / iterations / 4;
dim3 block_dims(threads);
dim3 grid_dims(((total_count - 1) / 1024 + 1)); // (batch_size);
hipLaunchKernelGGL(( fused_bias_gelu), dim3(grid_dims), dim3(block_dims), 0, stream,
input, bias, total_count, intermediate_size / 4);
}
template void launch_bias_gelu<float>(float*, const float*, int, int, hipStream_t);
template void launch_bias_gelu<__half>(__half*, const __half*, int, int, hipStream_t);
__global__ void fused_bias_add(float* input, const float* bias, int total_count, int hidden_size)
{
float4* input_cast = reinterpret_cast<float4*>(input);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float4 data = input_cast[offset];
float4 bias_data = bias_cast[offset % hidden_size];
data.x += bias_data.x;
data.y += bias_data.y;
data.z += bias_data.z;
data.w += bias_data.w;
input_cast[offset] = data;
}
}
__global__ void fused_bias_add(__half* input, const __half* bias, int total_count, int hidden_size)
{
#ifdef HALF_PRECISION_AVAILABLE
float2* input_cast = reinterpret_cast<float2*>(input);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float2 vals_vec = input_cast[offset];
float2 bias_vec = bias_cast[offset % hidden_size];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);
low_data.x += low_bias.x;
low_data.y += low_bias.y;
high_data.x += high_bias.x;
high_data.y += high_bias.y;
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
input_cast[offset] = vals_vec;
}
#endif
}
template <typename T>
void launch_bias_add(T* input, const T* bias, int hidden_size, int batch_size, hipStream_t stream)
{
int total_count = batch_size * (hidden_size / 4);
int threads = 1024; // hidden_size / iterations / 4;
dim3 block_dims(threads);
dim3 grid_dims(((total_count - 1) / threads + 1)); // (batch_size);
hipLaunchKernelGGL(( fused_bias_add), dim3(grid_dims), dim3(block_dims), 0, stream, input, bias, total_count, hidden_size / 4);
}
template void launch_bias_add<float>(float*, const float*, int, int, hipStream_t);
template void launch_bias_add<__half>(__half*, const __half*, int, int, hipStream_t);
__global__ void fused_bias_residual(float* input,
float* output,
float* attn,
float* bias,
float* attnbias,
int total_count,
int intermediate_size,
int mp_size)
{
float4* input_cast = reinterpret_cast<float4*>(input);
float4* output_cast = reinterpret_cast<float4*>(output);
float4* attn_cast = reinterpret_cast<float4*>(attn);
float4* bias_cast = reinterpret_cast<float4*>(bias);
float4* attnbias_cast = reinterpret_cast<float4*>(attnbias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float4 data = input_cast[offset];
float4 out = output_cast[offset];
float4 res_vec = attn_cast[offset];
float4 bias_data = bias_cast[offset % intermediate_size];
float4 attn_bias = attnbias_cast[offset % intermediate_size];
data.x = (data.x + res_vec.x) * mp_size + (out.x + bias_data.x + attn_bias.x);
data.y = (data.y + res_vec.y) * mp_size + (out.y + bias_data.y + attn_bias.y);
data.z = (data.z + res_vec.z) * mp_size + (out.z + bias_data.z + attn_bias.z);
data.w = (data.w + res_vec.w) * mp_size + (out.w + bias_data.w + attn_bias.w);
output_cast[offset] = data;
}
}
__global__ void fused_bias_residual(__half* input,
__half* output,
__half* attn,
__half* bias,
__half* attn_bias,
int total_count,
int intermediate_size,
int mp_size)
{
#ifdef HALF_PRECISION_AVAILABLE
float2* input_cast = reinterpret_cast<float2*>(input);
float2* output_cast = reinterpret_cast<float2*>(output);
float2* attn_cast = reinterpret_cast<float2*>(attn);
float2* bias_cast = reinterpret_cast<float2*>(bias);
float2* attnbias_cast = reinterpret_cast<float2*>(attn_bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float2 vals_vec = input_cast[offset];
float2 out_vec = output_cast[offset];
float2 res_vec = attn_cast[offset];
float2 bias_vec = bias_cast[offset % intermediate_size];
float2 attn_bias_vec = attnbias_cast[offset % intermediate_size];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* out_half = reinterpret_cast<__half2*>(&out_vec);
__half2* res_half = reinterpret_cast<__half2*>(&res_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
__half2* attnbias_half = reinterpret_cast<__half2*>(&attn_bias_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
float2 low_out = __half22float2(out_half[0]);
float2 high_out = __half22float2(out_half[1]);
float2 low_res = __half22float2(res_half[0]);
float2 high_res = __half22float2(res_half[1]);
float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);
float2 attn_low_bias = __half22float2(attnbias_half[0]);
float2 attn_high_bias = __half22float2(attnbias_half[1]);
low_data.x =
(low_data.x + low_res.x) * mp_size + (low_out.x + (low_bias.x + attn_low_bias.x));
low_data.y =
(low_data.y + low_res.y) * mp_size + (low_out.y + (low_bias.y + attn_low_bias.y));
high_data.x =
(high_data.x + high_res.x) * mp_size + (high_out.x + (high_bias.x + attn_high_bias.x));
high_data.y =
(high_data.y + high_res.y) * mp_size + (high_out.y + (high_bias.y + attn_high_bias.y));
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
output_cast[offset] = vals_vec;
}
#endif
}
template <typename T>
void launch_bias_residual(T* input,
T* output,
T* attn,
T* bias,
T* attn_bias,
int batch,
int hidden_dim,
int mp_size,
hipStream_t stream)
{
int total_count = batch * hidden_dim / 4;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / 1024 + 1); // (batch_size);
hipLaunchKernelGGL(( fused_bias_residual), dim3(grid_dims), dim3(block_dims), 0, stream,
input, output, attn, bias, attn_bias, total_count, hidden_dim / 4, 1.0 / mp_size);
}
template void
launch_bias_residual<float>(float*, float*, float*, float*, float*, int, int, int, hipStream_t);
template void launch_bias_residual<__half>(__half*,
__half*,
__half*,
__half*,
__half*,
int,
int,
int,
hipStream_t);
__global__ void gptj_residual_add(float* input,
float* output,
float* attn,
float* bias,
float* attnbias,
int total_count,
int intermediate_size,
float mp_size)
{
float4* input_cast = reinterpret_cast<float4*>(input);
float4* output_cast = reinterpret_cast<float4*>(output);
float4* attn_cast = reinterpret_cast<float4*>(attn);
float4* bias_cast = reinterpret_cast<float4*>(bias);
float4* attnbias_cast = reinterpret_cast<float4*>(attnbias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float4 data = input_cast[offset];
float4 out = output_cast[offset];
float4 res_vec = attn_cast[offset];
float4 bias_data = bias_cast[offset % intermediate_size];
float4 attn_bias = attnbias_cast[offset % intermediate_size];
data.x = data.x * mp_size + (out.x + res_vec.x + bias_data.x + attn_bias.x);
data.y = data.y * mp_size + (out.y + res_vec.y + bias_data.y + attn_bias.y);
data.z = data.z * mp_size + (out.z + res_vec.z + bias_data.z + attn_bias.z);
data.w = data.w * mp_size + (out.w + res_vec.w + bias_data.w + attn_bias.w);
output_cast[offset] = data;
}
}
__global__ void gptj_residual_add(__half* input,
__half* output,
__half* attn,
__half* bias,
__half* attn_bias,
int total_count,
int intermediate_size,
float mp_size)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
float2* input_cast = reinterpret_cast<float2*>(input);
float2* output_cast = reinterpret_cast<float2*>(output);
float2* attn_cast = reinterpret_cast<float2*>(attn);
float2* bias_cast = reinterpret_cast<float2*>(bias);
float2* attnbias_cast = reinterpret_cast<float2*>(attn_bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float2 vals_vec = input_cast[offset];
float2 out_vec = output_cast[offset];
float2 res_vec = attn_cast[offset];
float2 bias_vec = bias_cast[offset % intermediate_size];
float2 attn_bias_vec = attnbias_cast[offset % intermediate_size];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* out_half = reinterpret_cast<__half2*>(&out_vec);
__half2* res_half = reinterpret_cast<__half2*>(&res_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
__half2* attnbias_half = reinterpret_cast<__half2*>(&attn_bias_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
float2 low_out = __half22float2(out_half[0]);
float2 high_out = __half22float2(out_half[1]);
float2 low_res = __half22float2(res_half[0]);
float2 high_res = __half22float2(res_half[1]);
float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);
float2 attn_low_bias = __half22float2(attnbias_half[0]);
float2 attn_high_bias = __half22float2(attnbias_half[1]);
low_data.x =
low_data.x * mp_size + (low_out.x + low_res.x + (low_bias.x + attn_low_bias.x));
low_data.y =
low_data.y * mp_size + (low_out.y + low_res.y + (low_bias.y + attn_low_bias.y));
high_data.x =
high_data.x * mp_size + (high_out.x + high_res.x + (high_bias.x + attn_high_bias.x));
high_data.y =
high_data.y * mp_size + (high_out.y + high_res.y + (high_bias.y + attn_high_bias.y));
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
output_cast[offset] = vals_vec;
}
#endif
}
template <typename T>
void launch_gptj_residual_add(T* input,
T* output,
T* attn,
T* bias,
T* attn_bias,
int hidden_dim,
int batch,
int mp_size,
hipStream_t stream)
{
int total_count = batch * hidden_dim / 4;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / 1024 + 1); // (batch_size);
hipLaunchKernelGGL(( gptj_residual_add), dim3(grid_dims), dim3(block_dims), 0, stream,
input, output, attn, bias, attn_bias, total_count, hidden_dim / 4, 1.0 / mp_size);
}
template void launch_gptj_residual_add<float>(float*,
float*,
float*,
float*,
float*,
int,
int,
int,
hipStream_t);
template void launch_gptj_residual_add<__half>(__half*,
__half*,
__half*,
__half*,
__half*,
int,
int,
int,
hipStream_t);
__global__ void moe_res_matmul(float* residual,
float* coef,
float* mlp_out,
int seq_len,
int hidden_dim)
{
unsigned tid = threadIdx.x;
float4* residual_cast = reinterpret_cast<float4*>(residual);
float4* coef_cast = reinterpret_cast<float4*>(coef);
float4* mlp_out_cast = reinterpret_cast<float4*>(mlp_out);
residual_cast += blockIdx.x * hidden_dim;
mlp_out_cast += blockIdx.x * hidden_dim;
float4* coef_cast2 = coef_cast + hidden_dim;
while (tid < hidden_dim) {
float4 res = residual_cast[tid];
float4 mlp = mlp_out_cast[tid];
float4 coef1 = coef_cast[tid];
float4 coef2 = coef_cast2[tid];
mlp.x = mlp.x * coef2.x + res.x * coef1.x;
mlp.y = mlp.y * coef2.y + res.y * coef1.y;
mlp.z = mlp.z * coef2.z + res.z * coef1.z;
mlp.w = mlp.w * coef2.w + res.w * coef1.w;
mlp_out_cast[tid] = mlp;
tid += blockDim.x;
}
}
__global__ void moe_res_matmul(__half* residual,
__half* coef,
__half* mlp_out,
int seq_len,
int hidden_dim)
{
unsigned tid = threadIdx.x;
float2* residual_cast = reinterpret_cast<float2*>(residual);
float2* mlp_out_cast = reinterpret_cast<float2*>(mlp_out);
float2* coef_cast = reinterpret_cast<float2*>(coef);
float2* coef_cast2 = coef_cast + hidden_dim;
residual_cast += blockIdx.x * hidden_dim;
mlp_out_cast += blockIdx.x * hidden_dim;
while (tid < hidden_dim) {
float2 res = residual_cast[tid];
float2 coef1 = coef_cast[tid];
float2 coef2 = coef_cast2[tid];
float2 data = mlp_out_cast[tid];
__half* data_h = reinterpret_cast<__half*>(&data);
__half* coef1_h = reinterpret_cast<__half*>(&coef1);
__half* coef2_h = reinterpret_cast<__half*>(&coef2);
__half* res_h = reinterpret_cast<__half*>(&res);
data_h[0] = res_h[0] * coef1_h[0] + data_h[0] * coef2_h[0];
data_h[1] = res_h[1] * coef1_h[1] + data_h[1] * coef2_h[1];
data_h[2] = res_h[2] * coef1_h[2] + data_h[2] * coef2_h[2];
data_h[3] = res_h[3] * coef1_h[3] + data_h[3] * coef2_h[3];
mlp_out_cast[tid] = data;
tid += blockDim.x;
}
}
template <typename T>
void launch_moe_res_matmul(T* residual,
T* coef,
T* mlp_out,
int seq_len,
int hidden_dim,
hipStream_t stream)
{
dim3 grid_dim(seq_len);
dim3 block_dim(1024);
hipLaunchKernelGGL(( moe_res_matmul), dim3(grid_dim), dim3(block_dim), 0, stream,
residual, coef, mlp_out, seq_len, hidden_dim / 4);
}
template void launch_moe_res_matmul(float* residual,
float* coef,
float* mlp_out,
int seq_len,
int hidden_dim,
hipStream_t stream);
template void launch_moe_res_matmul(__half* residual,
__half* coef,
__half* mlp_out,
int seq_len,
int hidden_dim,
hipStream_t stream);
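The gelu device function at the top of this file uses the standard tanh approximation, GELU(x) ≈ 0.5·x·(1 + tanh(√(2/π)·(x + 0.044715·x³))); the constant 0.79788456... is √(2/π). A small host-side check against the exact erf-based definition, offered only as a sketch and not part of the deleted sources:

#include <math.h>
#include <stdio.h>

/* Tanh approximation, same constants as the kernels above. */
static float gelu_tanh(float x)
{
    const float sqrt_param = 0.79788456080286535587989211986876f; /* sqrt(2/pi) */
    return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + 0.044715f * x * x * x)));
}

/* Exact GELU via the Gauss error function; 0.70710678f is 1/sqrt(2). */
static float gelu_exact(float x) { return x * 0.5f * (1.0f + erff(x * 0.70710678f)); }

int main(void)
{
    /* The two forms agree to roughly 1e-3 over the typical activation range. */
    for (float x = -4.0f; x <= 4.0f; x += 0.5f)
        printf("x=%+.2f  tanh-approx=%+.6f  exact=%+.6f\n", x, gelu_tanh(x), gelu_exact(x));
    return 0;
}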
csrc/transformer_bak/inference/csrc/normalize.cu  deleted 100644 → 0
#include <limits>
#include "custom_cuda_layers.h"
//#include <cuda_profiler_api.h>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#define NORM_REG (MAX_REGISTERS)
namespace cg = cooperative_groups;

__global__ void fused_bias_residual_layer_norm(float* output,
                                               const float* vals,
                                               const float* gamma,
                                               const float* beta,
                                               float epsilon,
                                               int row_stride)
{
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int gid = id >> 5;
    int warp_num = iteration_stride >> 5;

    float inp_reg[NORM_REG];

    int k = 0;
    float sum = 0;
    int input_id = id;
    while (input_id < row_stride) {
        inp_reg[k] = vals[input_id + row * row_stride];
        sum += inp_reg[k++];
        input_id += iteration_stride;
    }

    for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);

    __shared__ float shr[MAX_WARP_NUM];
    if (g.thread_rank() == 0) shr[gid] = sum;

    b.sync();

    if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];

    b.sync();

    for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    float mean = sum / (row_stride);
    sum = 0.f;
    for (int f = 0; f < k; f++) {
        inp_reg[f] -= mean;
        sum += inp_reg[f] * inp_reg[f];
    }
    for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);

    if (g.thread_rank() == 0) shr[gid] = sum;

    b.sync();

    if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];

    b.sync();

    for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= (row_stride);
    sum += epsilon;
    sum = __frsqrt_rn(sum);

    for (int f = 0; f < k; f++) {
        int out_id = f * iteration_stride + id;
        inp_reg[f] = inp_reg[f] * sum;
        inp_reg[f] = inp_reg[f] * gamma[out_id] + beta[out_id];
        output[out_id + row * row_stride] = inp_reg[f];
    }
}

__global__ void fused_bias_residual_layer_norm(__half* output,
                                               const __half* vals,
                                               const __half* gamma,
                                               const __half* beta,
                                               float epsilon,
                                               int row_stride)
{
#ifdef HALF_PRECISION_AVAILABLE
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int gid = id >> 5;
    int warp_num = iteration_stride >> 5;

    __half2 inp_reg[NORM_REG];

    const __half2* vals_cast = reinterpret_cast<const __half2*>(vals);
    __half2* out_cast = reinterpret_cast<__half2*>(output);

    int k = 0;
    int input_id = id;
    while (input_id < row_stride) {
        inp_reg[k++] = vals_cast[input_id + row * row_stride];
        input_id += iteration_stride;
    }

    float sum = 0;
    for (int f = k - 1; f >= 0; f--) {
        float2 inp_f = __half22float2(inp_reg[f]);
        sum += inp_f.x + inp_f.y;
    }
    for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);

    __shared__ float shr[MAX_WARP_NUM];
    if (g.thread_rank() == 0) shr[gid] = sum;

    b.sync();

    if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];

    b.sync();

    for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    float mean = sum / (row_stride << 1);
    sum = 0.f;
    for (int f = 0; f < k; f++) {
        float2 inp_f = __half22float2(inp_reg[f]);
        inp_f.x -= mean;
        inp_f.y -= mean;
        inp_reg[f] = __float22half2_rn(inp_f);
        sum += inp_f.x * inp_f.x;
        sum += inp_f.y * inp_f.y;
    }
    for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);

    if (g.thread_rank() == 0) shr[gid] = sum;

    b.sync();

    if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];

    b.sync();

    for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= (row_stride << 1);
    sum += epsilon;
    sum = __frsqrt_rn(sum);
    __half2 variance_h = __float2half2_rn(sum);

    const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
    const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);

    for (int f = 0; f < k; f++) {
        int out_id = f * iteration_stride + id;
        inp_reg[f] = inp_reg[f] * variance_h;
        inp_reg[f] = inp_reg[f] * gamma_cast[out_id] + beta_cast[out_id];
        out_cast[out_id + row * row_stride] = inp_reg[f];
    }
#endif
}

template <typename T>
void launch_layer_norm(T* out,
                       T* vals,
                       const T* gamma,
                       const T* beta,
                       float epsilon,
                       int batch_size,
                       int hidden_dim,
                       cudaStream_t stream);

template <>
void launch_layer_norm<float>(float* out,
                              float* vals,
                              const float* gamma,
                              const float* beta,
                              float epsilon,
                              int batch_size,
                              int hidden_dim,
                              cudaStream_t stream)
{
    constexpr int threads = 1024;

    dim3 grid_dim(batch_size);
    dim3 block_dim(threads);

    fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
        out, vals, gamma, beta, epsilon, hidden_dim);
}

template <>
void launch_layer_norm<__half>(__half* out,
                               __half* vals,
                               const __half* gamma,
                               const __half* beta,
                               float epsilon,
                               int batch_size,
                               int hidden_dim,
                               cudaStream_t stream)
{
    constexpr int threads = 1024;

    dim3 grid_dim(batch_size);
    dim3 block_dim(threads);

    fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
        out, vals, gamma, beta, epsilon, hidden_dim / 2);
}

__global__ void fused_residual_layer_norm(float* norm,
                                          float* res_add,
                                          float* vals,
                                          float* residual,
                                          const float* bias,
                                          const float* gamma,
                                          const float* beta,
                                          float epsilon,
                                          int row_stride,
                                          bool preLN,
                                          bool mlp_after_attn)
{
    int iteration_stride = blockDim.x;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int gid = id >> 5;
    int warp_num = iteration_stride >> 5;

    float inp_reg[NORM_REG];

    int k = 0;
    int input_id = id;
    float sum = 0;
    while (input_id < row_stride) {
        inp_reg[k] = vals[input_id + row * row_stride];
        float res_f = (residual[input_id + row * row_stride]);
        float bias_f = (bias[input_id]);
        if (mlp_after_attn) inp_reg[k] += res_f + bias_f;
        // if (preLN) res_add[input_id + row * row_stride] = inp_reg[k];
        sum += inp_reg[k++];
        input_id += iteration_stride;
    }
    for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);

    __shared__ float shr[MAX_WARP_NUM];
    if (g.thread_rank() == 0) shr[gid] = sum;

    b.sync();

    if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];

    b.sync();

    for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    float mean = sum / (row_stride);
    sum = 0.f;
    for (int f = 0; f < k; f++) {
        inp_reg[f] -= mean;
        sum += inp_reg[f] * inp_reg[f];
    }
    for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);

    if (g.thread_rank() == 0) shr[gid] = sum;

    b.sync();

    if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];

    b.sync();

    for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= (row_stride);
    sum += epsilon;
    sum = __frsqrt_rn(sum);

    for (int f = 0; f < k; f++) {
        int out_id = f * iteration_stride + id;
        inp_reg[f] = inp_reg[f] * sum;
        inp_reg[f] = inp_reg[f] * gamma[out_id] + beta[out_id];
        norm[out_id + row * row_stride] = inp_reg[f];
    }
}

__global__ void fused_residual_layer_norm(__half* norm,
                                          __half* res_add,
                                          __half* vals,
                                          __half* residual,
                                          const __half* bias,
                                          const __half* gamma,
                                          const __half* beta,
                                          float epsilon,
                                          int row_stride,
                                          bool preLN,
                                          bool mlp_after_attn)
{
#ifdef HALF_PRECISION_AVAILABLE
    int iteration_stride = blockDim.x;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int gid = id >> 5;
    int warp_num = iteration_stride >> 5;

    __half2 inp_reg[NORM_REG];

    __half2* vals_cast = reinterpret_cast<__half2*>(vals);
    __half2* norm_cast = reinterpret_cast<__half2*>(norm);
    __half2* res_add_cast = reinterpret_cast<__half2*>(res_add);
    __half2* residual_cast = reinterpret_cast<__half2*>(residual);
    const __half2* bias_cast = reinterpret_cast<const __half2*>(bias);

    int k = 0;
    int input_id = id;
    float sum = 0;
    while (input_id < row_stride) {
        inp_reg[k] = vals_cast[input_id + row * row_stride];
        float2 inp_f = __half22float2(inp_reg[k]);
        float2 res_f = __half22float2(residual_cast[input_id + row * row_stride]);
        float2 bias_f = __half22float2(bias_cast[input_id]);
        if (mlp_after_attn) {
            inp_f.x += res_f.x + bias_f.x;
            inp_f.y += res_f.y + bias_f.y;
        }
        inp_reg[k] = __float22half2_rn(inp_f);
        // if (preLN) res_add_cast[input_id + row * row_stride] = __float22half2_rn(res_f);
        // //inp_reg[k];
        sum += inp_f.x + inp_f.y;
        input_id += iteration_stride;
        k++;
    }
    for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);

    __shared__ float shr[MAX_WARP_NUM];
    if (g.thread_rank() == 0) shr[gid] = sum;

    b.sync();

    if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];

    b.sync();

    for (int i = 1; i < (warp_num
);
i
*=
2
)
sum
+=
g
.
shfl_down
(
sum
,
i
);
sum
=
g
.
shfl
(
sum
,
0
);
float
mean
=
sum
/
(
row_stride
<<
1
);
sum
=
0.
f
;
for
(
int
f
=
0
;
f
<
k
;
f
++
)
{
float2
inp_f
=
__half22float2
(
inp_reg
[
f
]);
inp_f
.
x
-=
mean
;
inp_f
.
y
-=
mean
;
inp_reg
[
f
]
=
__float22half2_rn
(
inp_f
);
sum
+=
inp_f
.
x
*
inp_f
.
x
;
sum
+=
inp_f
.
y
*
inp_f
.
y
;
}
for
(
int
i
=
1
;
i
<
32
;
i
*=
2
)
sum
+=
g
.
shfl_down
(
sum
,
i
);
if
(
g
.
thread_rank
()
==
0
)
shr
[
gid
]
=
sum
;
b
.
sync
();
if
(
g
.
thread_rank
()
<
(
warp_num
))
sum
=
shr
[
g
.
thread_rank
()];
b
.
sync
();
for
(
int
i
=
1
;
i
<
(
warp_num
);
i
*=
2
)
sum
+=
g
.
shfl_down
(
sum
,
i
);
sum
=
g
.
shfl
(
sum
,
0
);
sum
/=
(
row_stride
<<
1
);
sum
+=
epsilon
;
sum
=
__frsqrt_rn
(
sum
);
__half2
variance_h
=
__float2half2_rn
(
sum
);
const
__half2
*
gamma_cast
=
reinterpret_cast
<
const
__half2
*>
(
gamma
);
const
__half2
*
beta_cast
=
reinterpret_cast
<
const
__half2
*>
(
beta
);
for
(
int
f
=
0
;
f
<
k
;
f
++
)
{
int
out_id
=
f
*
iteration_stride
+
id
;
inp_reg
[
f
]
=
inp_reg
[
f
]
*
variance_h
;
inp_reg
[
f
]
=
inp_reg
[
f
]
*
gamma_cast
[
out_id
]
+
beta_cast
[
out_id
];
norm_cast
[
out_id
+
row
*
row_stride
]
=
inp_reg
[
f
];
}
#endif
}
template
<
typename
T
>
void
launch_residual_layer_norm
(
T
*
norm
,
T
*
res_add
,
T
*
vals
,
T
*
residual
,
const
T
*
bias
,
const
T
*
gamma
,
const
T
*
beta
,
float
epsilon
,
int
batch_size
,
int
hidden_dim
,
bool
preLN
,
bool
mlp_after_attn
,
cudaStream_t
stream
);
template
<
>
void
launch_residual_layer_norm
<
float
>
(
float
*
norm
,
float
*
res_add
,
float
*
vals
,
float
*
residual
,
const
float
*
bias
,
const
float
*
gamma
,
const
float
*
beta
,
float
epsilon
,
int
batch_size
,
int
hidden_dim
,
bool
preLN
,
bool
mlp_after_attn
,
cudaStream_t
stream
)
{
constexpr
int
threads
=
1024
;
dim3
grid_dim
(
batch_size
);
dim3
block_dim
(
threads
);
fused_residual_layer_norm
<<<
grid_dim
,
block_dim
,
0
,
stream
>>>
(
norm
,
res_add
,
vals
,
residual
,
bias
,
gamma
,
beta
,
epsilon
,
hidden_dim
,
preLN
,
mlp_after_attn
);
}
template
<
>
void
launch_residual_layer_norm
<
__half
>
(
__half
*
norm
,
__half
*
res_add
,
__half
*
vals
,
__half
*
residual
,
const
__half
*
bias
,
const
__half
*
gamma
,
const
__half
*
beta
,
float
epsilon
,
int
batch_size
,
int
hidden_dim
,
bool
preLN
,
bool
mlp_after_attn
,
cudaStream_t
stream
)
{
constexpr
int
threads
=
1024
;
dim3
grid_dim
(
batch_size
);
dim3
block_dim
(
threads
);
fused_residual_layer_norm
<<<
grid_dim
,
block_dim
,
0
,
stream
>>>
(
norm
,
res_add
,
vals
,
residual
,
bias
,
gamma
,
beta
,
epsilon
,
hidden_dim
/
2
,
preLN
,
mlp_after_attn
);
}
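Taken together, the kernels in this file compute a standard layer normalization over each row of the hidden dimension; the two shuffle-plus-shared-memory reduction passes produce the row statistics. As a sketch of the math being evaluated (with H the scalar row length, i.e. row_stride for the fp32 kernels and row_stride << 1 for the packed __half2 kernels):

    \mu = \frac{1}{H}\sum_{i=1}^{H} x_i, \qquad
    \sigma^2 = \frac{1}{H}\sum_{i=1}^{H}(x_i - \mu)^2, \qquad
    y_i = \gamma_i \, \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta_i

Note that epsilon is added to the variance before the reciprocal square root (sum += epsilon; sum = __frsqrt_rn(sum);), matching the formula above.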
csrc/transformer_bak/inference/csrc/normalize.hip
deleted
100644 → 0
View file @
7dd68788
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include <limits>
#include "custom_hip_layers.h"
//#include <cuda_profiler_api.h>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#define NORM_REG (MAX_REGISTERS)
namespace cg = cooperative_groups;
__global__ void fused_bias_residual_layer_norm(float* output,
const float* vals,
const float* gamma,
const float* beta,
float epsilon,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
float inp_reg[NORM_REG];
int k = 0;
float sum = 0;
int input_id = id;
while (input_id < row_stride) {
inp_reg[k] = vals[input_id + row * row_stride];
sum += inp_reg[k++];
input_id += iteration_stride;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride);
sum = 0.f;
for (int f = 0; f < k; f++) {
inp_reg[f] -= mean;
sum += inp_reg[f] * inp_reg[f];
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride);
sum += epsilon;
sum = __frsqrt_rn(sum);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * sum;
inp_reg[f] = inp_reg[f] * gamma[out_id] + beta[out_id];
output[out_id + row * row_stride] = inp_reg[f];
}
}
__global__ void fused_bias_residual_layer_norm(__half* output,
const __half* vals,
const __half* gamma,
const __half* beta,
float epsilon,
int row_stride)
{
#ifdef HALF_PRECISION_AVAILABLE
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
__half2 inp_reg[NORM_REG];
const __half2* vals_cast = reinterpret_cast<const __half2*>(vals);
__half2* out_cast = reinterpret_cast<__half2*>(output);
int k = 0;
int input_id = id;
while (input_id < row_stride) {
inp_reg[k++] = vals_cast[input_id + row * row_stride];
input_id += iteration_stride;
}
float sum = 0;
for (int f = k - 1; f >= 0; f--) {
float2 inp_f = __half22float2(inp_reg[f]);
sum += inp_f.x + inp_f.y;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride << 1);
sum = 0.f;
for (int f = 0; f < k; f++) {
float2 inp_f = __half22float2(inp_reg[f]);
inp_f.x -= mean;
inp_f.y -= mean;
inp_reg[f] = __float22half2_rn(inp_f);
sum += inp_f.x * inp_f.x;
sum += inp_f.y * inp_f.y;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride << 1);
sum += epsilon;
sum = __frsqrt_rn(sum);
__half2 variance_h = __float2half2_rn(sum);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * variance_h;
inp_reg[f] = inp_reg[f] * gamma_cast[out_id] + beta_cast[out_id];
out_cast[out_id + row * row_stride] = inp_reg[f];
}
#endif
}
template <typename T>
void launch_layer_norm(T* out,
T* vals,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream);
template <>
void launch_layer_norm<float>(float* out,
float* vals,
const float* gamma,
const float* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_bias_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream,
out, vals, gamma, beta, epsilon, hidden_dim);
}
template <>
void launch_layer_norm<__half>(__half* out,
__half* vals,
const __half* gamma,
const __half* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_bias_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream,
out, vals, gamma, beta, epsilon, hidden_dim / 2);
}
__global__ void fused_residual_layer_norm(float* norm,
float* res_add,
float* vals,
float* residual,
const float* bias,
const float* gamma,
const float* beta,
float epsilon,
int row_stride,
bool preLN,
bool mlp_after_attn)
{
int iteration_stride = blockDim.x;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
float inp_reg[NORM_REG];
int k = 0;
int input_id = id;
float sum = 0;
while (input_id < row_stride) {
inp_reg[k] = vals[input_id + row * row_stride];
float res_f = (residual[input_id + row * row_stride]);
float bias_f = (bias[input_id]);
if (mlp_after_attn) inp_reg[k] += res_f + bias_f;
// if (preLN) res_add[input_id + row * row_stride] = inp_reg[k];
sum += inp_reg[k++];
input_id += iteration_stride;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride);
sum = 0.f;
for (int f = 0; f < k; f++) {
inp_reg[f] -= mean;
sum += inp_reg[f] * inp_reg[f];
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride);
sum += epsilon;
sum = __frsqrt_rn(sum);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * sum;
inp_reg[f] = inp_reg[f] * gamma[out_id] + beta[out_id];
norm[out_id + row * row_stride] = inp_reg[f];
}
}
__global__ void fused_residual_layer_norm(__half* norm,
__half* res_add,
__half* vals,
__half* residual,
const __half* bias,
const __half* gamma,
const __half* beta,
float epsilon,
int row_stride,
bool preLN,
bool mlp_after_attn)
{
#ifdef HALF_PRECISION_AVAILABLE
int iteration_stride = blockDim.x;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
__half2 inp_reg[NORM_REG];
__half2* vals_cast = reinterpret_cast<__half2*>(vals);
__half2* norm_cast = reinterpret_cast<__half2*>(norm);
__half2* res_add_cast = reinterpret_cast<__half2*>(res_add);
__half2* residual_cast = reinterpret_cast<__half2*>(residual);
const __half2* bias_cast = reinterpret_cast<const __half2*>(bias);
int k = 0;
int input_id = id;
float sum = 0;
while (input_id < row_stride) {
inp_reg[k] = vals_cast[input_id + row * row_stride];
float2 inp_f = __half22float2(inp_reg[k]);
float2 res_f = __half22float2(residual_cast[input_id + row * row_stride]);
float2 bias_f = __half22float2(bias_cast[input_id]);
if (mlp_after_attn) {
inp_f.x += res_f.x + bias_f.x;
inp_f.y += res_f.y + bias_f.y;
}
inp_reg[k] = __float22half2_rn(inp_f);
// if (preLN) res_add_cast[input_id + row * row_stride] = __float22half2_rn(res_f);
// //inp_reg[k];
sum += inp_f.x + inp_f.y;
input_id += iteration_stride;
k++;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride << 1);
sum = 0.f;
for (int f = 0; f < k; f++) {
float2 inp_f = __half22float2(inp_reg[f]);
inp_f.x -= mean;
inp_f.y -= mean;
inp_reg[f] = __float22half2_rn(inp_f);
sum += inp_f.x * inp_f.x;
sum += inp_f.y * inp_f.y;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride << 1);
sum += epsilon;
sum = __frsqrt_rn(sum);
__half2 variance_h = __float2half2_rn(sum);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * variance_h;
inp_reg[f] = inp_reg[f] * gamma_cast[out_id] + beta_cast[out_id];
norm_cast[out_id + row * row_stride] = inp_reg[f];
}
#endif
}
template <typename T>
void launch_residual_layer_norm(T* norm,
T* res_add,
T* vals,
T* residual,
const T* bias,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
bool preLN,
bool mlp_after_attn,
hipStream_t stream);
template <>
void launch_residual_layer_norm<float>(float* norm,
float* res_add,
float* vals,
float* residual,
const float* bias,
const float* gamma,
const float* beta,
float epsilon,
int batch_size,
int hidden_dim,
bool preLN,
bool mlp_after_attn,
hipStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream, norm,
res_add,
vals,
residual,
bias,
gamma,
beta,
epsilon,
hidden_dim,
preLN,
mlp_after_attn);
}
template <>
void launch_residual_layer_norm<__half>(__half* norm,
__half* res_add,
__half* vals,
__half* residual,
const __half* bias,
const __half* gamma,
const __half* beta,
float epsilon,
int batch_size,
int hidden_dim,
bool preLN,
bool mlp_after_attn,
hipStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream, norm,
res_add,
vals,
residual,
bias,
gamma,
beta,
epsilon,
hidden_dim / 2,
preLN,
mlp_after_attn);
}
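Before moving on to the Python bindings, it is worth isolating the reduction idiom every kernel in both normalize files repeats: a warp-level shfl_down tree, a shared-memory exchange of per-warp partials, a second shfl_down tree over those partials, and a broadcast from lane 0. A condensed, self-contained CUDA sketch (a paraphrase, not code from this repository; it assumes what the 1024-thread launches guarantee, namely warp size 32 and blockDim.x a power-of-two multiple of 32):

#include <cooperative_groups.h>
namespace cg = cooperative_groups;

// Block-wide sum. Every warp reloads all per-warp partials from shared
// memory and reduces them independently, so after the final broadcast
// every thread in the block holds the same total -- the same trick the
// fused_*_layer_norm kernels use for their mean and variance passes.
__device__ float block_sum(cg::thread_block& b, cg::thread_block_tile<32>& g, float val)
{
    int gid = threadIdx.x >> 5;      // warp index within the block
    int warp_num = blockDim.x >> 5;  // number of warps in the block
    for (int i = 1; i < 32; i *= 2) val += g.shfl_down(val, i);  // intra-warp tree
    __shared__ float shr[32];        // one partial sum per warp (at most 32 warps)
    if (g.thread_rank() == 0) shr[gid] = val;
    b.sync();
    if (g.thread_rank() < warp_num) val = shr[g.thread_rank()];
    b.sync();
    for (int i = 1; i < warp_num; i *= 2) val += g.shfl_down(val, i);  // reduce partials
    return g.shfl(val, 0);           // broadcast lane 0's total across the warp
}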
csrc/transformer_bak/inference/csrc/pt_binding.cpp
deleted
100644 → 0
View file @
7dd68788
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <vector>
#include "context.h"
#include "cublas_wrappers.h"
#include "custom_cuda_layers.h"
std::array<int, 3> gemm_algos = std::array<int, 3>({99, 99, 99});
#define MAX_OUT_TOKES 10
template <typename T>
at::Tensor ds_softmax(at::Tensor& attn_scores,
                      at::Tensor& attn_mask,
                      bool triangular,
                      bool recompute,
                      bool local_attention,
                      int window_size,
                      bool async_op)
{
    auto attn_scores_c = attn_scores.contiguous();
    int bsz = attn_scores_c.size(0);
    int seq_len = attn_scores_c.size(1);
    int len = attn_scores_c.sizes().size();
    if (len > 3) seq_len = attn_scores_c.size(2);
    int soft_len = attn_scores_c.size(2);
    if (len > 3) soft_len = attn_scores_c.size(3);
    int heads = 1;
    if (len > 3) heads = attn_scores_c.size(1);
    launch_attn_softmax_v2((T*)attn_scores_c.data_ptr(),
                           (attn_mask.sizes().size() > 1 ? (T*)attn_mask.data_ptr() : nullptr),
                           triangular, recompute, local_attention, window_size,
                           bsz, heads, seq_len, soft_len, 1.0,
                           Context::Instance().GetCurrentStream(async_op));
    return attn_scores_c;
}
template <typename T>
void allocate_workspace(size_t hidden_dim, size_t max_seq_len, size_t batch_size, size_t head_size = 128)
{
    size_t _workSpaceSize = (hidden_dim * batch_size * max_seq_len);
    Context::Instance().GenWorkSpace(_workSpaceSize * sizeof(T));
}
template <typename T>
at::Tensor einsum_sec_sm_ecm(at::Tensor& Q, at::Tensor& W)
{
    auto options = at::TensorOptions()
                       .dtype(Q.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);
    T* workspace = (T*)Context::Instance().GetWorkSpace();
    float alpha = 1;
    float gemm_beta = 0.0;
    if (!workspace) {
        allocate_workspace<T>(W.size(1), MAX_OUT_TOKES, Q.size(0));
        workspace = (T*)Context::Instance().GetWorkSpace();
    }
    auto O = at::from_blob(workspace, {Q.size(1), Q.size(2), W.size(1)}, options);
    unsigned m = W.size(1);
    unsigned n = Q.size(1) * Q.size(2);
    unsigned k = Q.size(0);
    cublas_gemm_ex(Context::Instance().GetCublasHandle(),
                   CUBLAS_OP_N, CUBLAS_OP_T,
                   m, n, k, &alpha, &gemm_beta,
                   (T*)W.data_ptr(), (T*)Q.data_ptr(), (T*)O.data_ptr(),
                   CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    return O;
}
template <typename T>
void attention_unfused(at::Tensor& prev_key_cont,
                       at::Tensor& query_cont,
                       at::Tensor& attn_mask,
                       at::Tensor& prev_value_cont,
                       at::Tensor& output,
                       int& bsz,
                       int& seq_len,
                       int& soft_len,
                       int& heads,
                       float& norm_factor,
                       bool triangular,
                       bool recompute,
                       bool local_attention,
                       int window_size)
{
    auto options = at::TensorOptions()
                       .dtype(query_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);
    float alpha = norm_factor;
    float gemm_beta = 0.0;
    auto attn_score = at::empty({bsz, heads, seq_len, soft_len}, options);
    int k = prev_value_cont.size(2) / heads;
    cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
    cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
                                soft_len, seq_len, k, &alpha, &gemm_beta,
                                (T*)prev_key_cont.data_ptr(),
                                (T*)query_cont.data_ptr(),
                                (T*)attn_score.data_ptr(),
                                CUBLAS_OP_N, CUBLAS_OP_N,
                                soft_len * k, seq_len * k, seq_len * soft_len,
                                bsz * heads, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    attn_score = ds_softmax<T>(
        attn_score, attn_mask, triangular, recompute, local_attention, window_size, false);
    alpha = 1.0;
    cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
                                k, seq_len, soft_len, &alpha, &gemm_beta,
                                (T*)prev_value_cont.data_ptr(),
                                (T*)attn_score.data_ptr(),
                                (T*)output.data_ptr(),
                                CUBLAS_OP_N, CUBLAS_OP_N,
                                soft_len * k, seq_len * soft_len, seq_len * k,
                                bsz * heads, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
template <typename T>
std::vector<at::Tensor> ds_softmax_context(at::Tensor& query,
                                           at::Tensor& prev_key,
                                           at::Tensor& new_key,
                                           at::Tensor& attn_mask,
                                           at::Tensor& prev_value,
                                           at::Tensor& new_value,
                                           int heads,
                                           float norm_factor,
                                           bool merging,
                                           bool triangular,
                                           bool local_attention,
                                           int window_size,
                                           bool no_masking)
{
    auto query_cont = query.contiguous();
    auto prev_key_cont = prev_key.contiguous();
    auto prev_value_cont = prev_value.contiguous();
    int new_size = (new_value.sizes().size() > 1 ? new_value.size(1) : 0);
    // Attn_Score [ batch Head Sequence-length Softmax-length]
    int bsz = query_cont.size(0);
    int seq_len = query_cont.size(1);
    int soft_len = prev_value.size(1);
    auto options = at::TensorOptions()
                       .dtype(query_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);
    auto output =
        at::empty({prev_value.size(0), heads, seq_len, prev_value.size(2) / heads}, options);
    attention_unfused<T>(prev_key_cont,
                         query_cont,
                         attn_mask,  //(no_masking ? nullptr : (T*)attn_mask.data_ptr()),
                         prev_value_cont,
                         output,
                         bsz, seq_len, soft_len, heads, norm_factor,
                         (triangular && (new_size == 0)),
                         (new_size == 0),
                         local_attention,
                         window_size);
    return {output, prev_key, prev_value};
}
template <typename T>
at::Tensor ds_bias_gelu(at::Tensor& input, at::Tensor& bias)
{
    auto input_cont = input.contiguous();
    int bsz = input_cont.size(0) * input_cont.size(1);
    int intermediate_size = input_cont.size(2);
    launch_bias_gelu((T*)input_cont.data_ptr(),
                     (T*)bias.data_ptr(),
                     intermediate_size, bsz,
                     Context::Instance().GetCurrentStream());
    return input_cont;
}
template <typename T>
at::Tensor ds_bias_residual(at::Tensor& input, at::Tensor& residual, at::Tensor& bias)
{
    auto input_cont = input.contiguous();
    auto residual_cont = residual.contiguous();
    int bsz = input_cont.size(0) * input_cont.size(1);
    // launch_bias_residual((T*)input_cont.data_ptr(),
    //                      (T*)residual_cont.data_ptr(),
    //                      (T*)bias.data_ptr(),
    //                      bsz,
    //                      input_cont.size(2),
    //                      (bias.size(0) > 1),
    //                      Context::Instance().GetCurrentStream());
    return input_cont;
}
template <typename T>
at::Tensor ds_layernorm(at::Tensor& input_cont, at::Tensor& gamma, at::Tensor& betta, float epsilon)
{
    int bsz = input_cont.size(0) * input_cont.size(1);
    auto inp_norm = at::empty_like(input_cont);
    launch_layer_norm((T*)inp_norm.data_ptr(),
                      (T*)input_cont.data_ptr(),
                      (T*)gamma.data_ptr(),
                      (T*)betta.data_ptr(),
                      epsilon, bsz, input_cont.size(2),
                      Context::Instance().GetCurrentStream());
    return inp_norm;
}
template <typename T>
at::Tensor qkv_unfused_cublas(at::Tensor& output,
                              at::Tensor& input,
                              at::Tensor& weight,
                              at::Tensor& bias,
                              at::Tensor& gamma,
                              at::Tensor& beta,
                              const float epsilon,
                              bool add_bias)
{
    auto inp_norm = ds_layernorm<T>(input, gamma, beta, epsilon);
    // cudaEventRecord(Context::Instance().GetCompEvent(1), Context::Instance().GetCurrentStream());
    float alpha = (T)1.0;
    float gemm_beta = (T)0.0;
    int bsz = input.size(0) * input.size(1);
    cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
    cublas_gemm_ex(Context::Instance().GetCublasHandle(),
                   CUBLAS_OP_N, CUBLAS_OP_N,
                   weight.size(1), bsz, input.size(2),
                   &alpha, &gemm_beta,
                   (T*)weight.data_ptr(), (T*)inp_norm.data_ptr(), (T*)output.data_ptr(),
                   CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    if (add_bias)
        launch_bias_add((T*)output.data_ptr(),
                        (T*)bias.data_ptr(),
                        weight.size(1), bsz,
                        Context::Instance().GetCurrentStream());
    return inp_norm;
}
template <typename T>
std::vector<at::Tensor> ds_qkv_gemm(at::Tensor& input,
                                    at::Tensor& weight,
                                    at::Tensor& bias,
                                    at::Tensor& gamma,
                                    at::Tensor& beta,
                                    const float epsilon,
                                    bool add_bias)
{
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);
    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
    int bsz = input_cont.size(0) * input_cont.size(1);
    auto inp_norm =
        qkv_unfused_cublas<T>(output, input_cont, weight, bias, gamma, beta, epsilon, add_bias);
    return {output, inp_norm};
}
template <typename T>
void quantized_gemm(at::Tensor& output,
                    at::Tensor& input,
                    at::Tensor& weight,
                    at::Tensor& qscale,
                    int groups,
                    int merge_count)
{
    int bsz = input.size(0) * input.size(1);
    auto options = at::TensorOptions()
                       .dtype(input.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);
    auto weight16 = at::empty({weight.size(0), weight.size(1)}, options);
    launch_dequantize((T*)weight16.data_ptr(),
                      (int8_t*)weight.data_ptr(),
                      (float*)qscale.data_ptr(),
                      weight.size(1), weight.size(0), groups, merge_count,
                      Context::Instance().GetCurrentStream());
    cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
    float alpha = (T)1.0;
    float gemm_beta = (T)0.0;
    cublas_gemm_ex(Context::Instance().GetCublasHandle(),
                   CUBLAS_OP_N, CUBLAS_OP_N,
                   weight.size(1), bsz, input.size(2),
                   &alpha, &gemm_beta,
                   (T*)weight16.data_ptr(), (T*)input.data_ptr(), (T*)output.data_ptr(),
                   CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
template <typename T>
at::Tensor ds_qkv_gemm_int8(at::Tensor& input,
                            at::Tensor& weight,
                            at::Tensor& bias,
                            at::Tensor& gamma,
                            at::Tensor& beta,
                            const float epsilon,
                            at::Tensor& q_scale,
                            int groups,
                            bool add_bias)
{
    int bsz = input.size(0) * input.size(1);
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);
    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
    auto inp_norm = ds_layernorm<T>(input_cont, gamma, beta, epsilon);
    quantized_gemm<T>(output, inp_norm, weight, q_scale, groups, 0);
    if (add_bias)
        launch_bias_add((T*)output.data_ptr(),
                        (T*)bias.data_ptr(),
                        weight.size(1), bsz,
                        Context::Instance().GetCurrentStream());
    return output;
}
template <typename T>
at::Tensor ds_linear_layer(at::Tensor& input, at::Tensor& weight, at::Tensor& bias)
{
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);
    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
    int bsz = input_cont.size(0) * input_cont.size(1);
    float alpha = (T)1.0;
    float gemm_beta = (T)0.0;
    cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
    cublas_gemm_ex(Context::Instance().GetCublasHandle(),
                   CUBLAS_OP_N, CUBLAS_OP_N,
                   weight.size(1), bsz, input_cont.size(2),
                   &alpha, &gemm_beta,
                   (T*)weight.data_ptr(), (T*)input_cont.data_ptr(), (T*)output.data_ptr(),
                   CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    launch_bias_add((T*)output.data_ptr(),
                    (T*)bias.data_ptr(),
                    weight.size(1), bsz,
                    Context::Instance().GetCurrentStream());
    return output;
}
template <typename T>
at::Tensor ds_linear_layer_int8(at::Tensor& input,
                                at::Tensor& weight,
                                at::Tensor& bias,
                                at::Tensor& q_scale,
                                int groups)
{
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);
    int bsz = input_cont.size(0) * input_cont.size(1);
    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
    quantized_gemm<T>(output, input_cont, weight, q_scale, groups, 0);
    launch_bias_add((T*)output.data_ptr(),
                    (T*)bias.data_ptr(),
                    weight.size(1), bsz,
                    Context::Instance().GetCurrentStream());
    return output;
}
template <typename T>
at::Tensor ds_vector_matmul(at::Tensor& input, at::Tensor& weight, bool async_op)
{
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);
    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
    int bsz = input_cont.size(0) * input_cont.size(1);
    float alpha = (T)1.0;
    float gemm_beta = (T)0.0;
    cublasSetStream(Context::Instance().GetCublasHandle(),
                    Context::Instance().GetCurrentStream(async_op));
    cublas_gemm_ex(Context::Instance().GetCublasHandle(),
                   CUBLAS_OP_N, CUBLAS_OP_N,
                   weight.size(1), bsz, input_cont.size(2),
                   &alpha, &gemm_beta,
                   (T*)weight.data_ptr(), (T*)input_cont.data_ptr(), (T*)output.data_ptr(),
                   CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    return output;
}
template <typename T>
at::Tensor ds_vector_matmul_int8(at::Tensor& input,
                                 at::Tensor& weight,
                                 at::Tensor& q_scale,
                                 int groups,
                                 int merge_count)
{
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);
    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
    quantized_gemm<T>(output, input_cont, weight, q_scale, groups, merge_count);
    return output;
}
template <typename T>
void mlp_unfused_cublas(at::Tensor& output,
                        at::Tensor& input,
                        at::Tensor& residual,
                        at::Tensor& input_bias,
                        at::Tensor& weight,
                        at::Tensor& bias,
                        at::Tensor& gamma,
                        at::Tensor& beta,
                        const float epsilon,
                        bool preLayerNorm,
                        bool mlp_after_attn)
{
    int bsz = input.size(0) * input.size(1);
    auto inp_norm = at::empty_like(input);
    launch_residual_layer_norm((T*)inp_norm.data_ptr(),
                               (T*)nullptr,
                               (T*)input.data_ptr(),
                               (T*)residual.data_ptr(),
                               (T*)input_bias.data_ptr(),
                               (T*)gamma.data_ptr(),
                               (T*)beta.data_ptr(),
                               epsilon, bsz, input.size(2),
                               preLayerNorm, mlp_after_attn,
                               Context::Instance().GetCurrentStream());
    float alpha = (T)1.0;
    float gemm_beta = (T)0.0;
    cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
    cublas_gemm_ex(Context::Instance().GetCublasHandle(),
                   CUBLAS_OP_N, CUBLAS_OP_N,
                   weight.size(1), bsz, input.size(2),
                   &alpha, &gemm_beta,
                   (T*)weight.data_ptr(), (T*)inp_norm.data_ptr(), (T*)output.data_ptr(),
                   CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    launch_bias_gelu((T*)output.data_ptr(),
                     (T*)bias.data_ptr(),
                     weight.size(1), bsz,
                     Context::Instance().GetCurrentStream());
}
template <typename T>
at::Tensor ds_mlp_gemm(at::Tensor& input,
                       at::Tensor& residual,
                       at::Tensor& input_bias,
                       at::Tensor& weight,
                       at::Tensor& bias,
                       at::Tensor& gamma,
                       at::Tensor& beta,
                       const float epsilon,
                       bool preLayerNorm,
                       bool mlp_after_attn)
{
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);
    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
    int bsz = input_cont.size(0) * input_cont.size(1);
    mlp_unfused_cublas<T>(output,
                          mlp_after_attn ? input : residual,
                          residual,
                          input_bias,
                          weight, bias, gamma, beta,
                          epsilon, preLayerNorm, mlp_after_attn);
    return output;
}
template <typename T>
std::vector<at::Tensor> ds_mlp_gemm_int8(at::Tensor& input,
                                         at::Tensor& residual,
                                         at::Tensor& input_bias,
                                         at::Tensor& weight,
                                         at::Tensor& bias,
                                         at::Tensor& gamma,
                                         at::Tensor& beta,
                                         const float epsilon,
                                         at::Tensor& q_scale,
                                         int groups,
                                         bool preLayerNorm)
{
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);
    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
    int bsz = input_cont.size(0) * input_cont.size(1);
    auto inp_norm = at::empty_like(input_cont);
    auto residual_add = (preLayerNorm ? at::empty_like(input_cont) : inp_norm);
    // computing the blocking across K dimension
    // launch_residual_layer_norm((T*)inp_norm.data_ptr(),
    //                            (T*)residual_add.data_ptr(),
    //                            (T*)input_cont.data_ptr(),
    //                            (T*)residual.data_ptr(),
    //                            (T*)input_bias.data_ptr(),
    //                            (T*)gamma.data_ptr(),
    //                            (T*)beta.data_ptr(),
    //                            epsilon,
    //                            bsz,
    //                            input_cont.size(2),
    //                            preLayerNorm,
    //                            Context::Instance().GetCurrentStream());
    quantized_gemm<T>(output, inp_norm, weight, q_scale, groups, 0);
    launch_bias_gelu((T*)output.data_ptr(),
                     (T*)bias.data_ptr(),
                     weight.size(1), bsz,
                     Context::Instance().GetCurrentStream());
    return {output, residual_add};
}
template <typename T>
at::Tensor fused_gemm_gelu(at::Tensor& input,
                           at::Tensor& weight,
                           at::Tensor& bias,
                           at::Tensor& weight_out,
                           const float epsilon,
                           bool preLayerNorm,
                           bool async_op)
{
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);
    auto intermediate =
        at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight_out.size(1)}, options);
    int bsz = input_cont.size(0) * input_cont.size(1);
    float alpha = (T)1.0;
    float gemm_beta = (T)0.0;
    cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
    cublas_gemm_ex(Context::Instance().GetCublasHandle(),
                   CUBLAS_OP_N, CUBLAS_OP_N,
                   weight.size(1), bsz, input.size(2),
                   &alpha, &gemm_beta,
                   (T*)weight.data_ptr(), (T*)input_cont.data_ptr(), (T*)intermediate.data_ptr(),
                   CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    launch_bias_gelu((T*)intermediate.data_ptr(),
                     (T*)bias.data_ptr(),
                     weight.size(1), bsz,
                     Context::Instance().GetCurrentStream());
    cublas_gemm_ex(Context::Instance().GetCublasHandle(),
                   CUBLAS_OP_N, CUBLAS_OP_N,
                   weight_out.size(1), bsz, intermediate.size(2),
                   &alpha, &gemm_beta,
                   (T*)weight_out.data_ptr(), (T*)intermediate.data_ptr(), (T*)output.data_ptr(),
                   CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    // cudaEventRecord(Context::Instance().GetCompEvent(2),
    //                 Context::Instance().GetCurrentStream(true));
    return output;
}
void residual_add_bias(at::Tensor& output,
                       at::Tensor& input,
                       at::Tensor& attention_output,
                       at::Tensor& output_b,
                       at::Tensor& attention_b,
                       int mp_size,
                       bool mlp_after_attn)
{
    int bsz = input.size(0) * input.size(1);
    int hidden_size = input.size(2);
    // cudaStreamWaitEvent(
    //     Context::Instance().GetCurrentStream(), Context::Instance().GetCompEvent(2), 0);
    if (input.scalar_type() == at::kFloat)
        if (mlp_after_attn)
            launch_bias_residual((float*)input.data_ptr(),
                                 (float*)output.data_ptr(),
                                 (float*)attention_output.data_ptr(),
                                 (float*)output_b.data_ptr(),
                                 (float*)attention_b.data_ptr(),
                                 bsz, hidden_size, mp_size,
                                 Context::Instance().GetCurrentStream());
        else
            launch_gptj_residual_add<float>((float*)input.data_ptr(),
                                            (float*)output.data_ptr(),
                                            (float*)attention_output.data_ptr(),
                                            (float*)output_b.data_ptr(),
                                            (float*)attention_b.data_ptr(),
                                            hidden_size, bsz, mp_size,
                                            Context::Instance().GetCurrentStream());
    else if (mlp_after_attn)
        launch_bias_residual((__half*)input.data_ptr(),
                             (__half*)output.data_ptr(),
                             (__half*)attention_output.data_ptr(),
                             (__half*)output_b.data_ptr(),
                             (__half*)attention_b.data_ptr(),
                             bsz, hidden_size, mp_size,
                             Context::Instance().GetCurrentStream());
    else
        launch_gptj_residual_add<__half>((__half*)input.data_ptr(),
                                         (__half*)output.data_ptr(),
                                         (__half*)attention_output.data_ptr(),
                                         (__half*)output_b.data_ptr(),
                                         (__half*)attention_b.data_ptr(),
                                         hidden_size, bsz, mp_size,
                                         Context::Instance().GetCurrentStream());
}
std::vector<at::Tensor> apply_rotary_pos_emb(at::Tensor& mixed_query,
                                             at::Tensor& key_layer,
                                             unsigned rotary_dim,
                                             unsigned offset,
                                             unsigned num_heads,
                                             bool rotate_half,
                                             bool rotate_every_two)
{
    auto query_cont = mixed_query.contiguous();
    auto key_cont = key_layer.contiguous();
    unsigned bsz = mixed_query.size(0);
    unsigned head_size = mixed_query.size(2) / num_heads;
    unsigned seq_len = mixed_query.size(1);
    if (mixed_query.scalar_type() == at::kFloat)
        launch_apply_rotary_pos_emb<float>((float*)query_cont.data_ptr(),
                                           (float*)key_cont.data_ptr(),
                                           head_size, seq_len, rotary_dim, offset, num_heads, bsz,
                                           rotate_half, rotate_every_two,
                                           Context::Instance().GetCurrentStream());
    else
        launch_apply_rotary_pos_emb<__half>((__half*)query_cont.data_ptr(),
                                            (__half*)key_cont.data_ptr(),
                                            head_size, seq_len, rotary_dim, offset, num_heads, bsz,
                                            rotate_half, rotate_every_two,
                                            Context::Instance().GetCurrentStream());
    return {query_cont, key_cont};
}
template <typename T>
at::Tensor fused_gemm_gelu_int8(at::Tensor& input,
                                at::Tensor& weight,
                                at::Tensor& bias,
                                const float epsilon,
                                at::Tensor& q_scale,
                                int groups,
                                bool preLayerNorm)
{
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);
    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
    int bsz = input_cont.size(0) * input_cont.size(1);
    quantized_gemm<T>(output, input_cont, weight, q_scale, groups, 0);
    launch_bias_gelu((T*)output.data_ptr(),
                     (T*)bias.data_ptr(),
                     weight.size(1), bsz,
                     Context::Instance().GetCurrentStream());
    return output;
}
at::Tensor moe_res_matmul(at::Tensor& moe_res, at::Tensor& coef, at::Tensor& output)
{
    int M = moe_res.size(0) * moe_res.size(1);
    int N = moe_res.size(2);
    Context::Instance().SynchComm();
    if (moe_res.scalar_type() == at::kFloat) {
        launch_moe_res_matmul<float>((float*)moe_res.data_ptr(),
                                     (float*)coef.data_ptr(),
                                     (float*)output.data_ptr(),
                                     M, N,
                                     at::cuda::getCurrentCUDAStream());
    } else {
        launch_moe_res_matmul<__half>((__half*)moe_res.data_ptr(),
                                      (__half*)coef.data_ptr(),
                                      (__half*)output.data_ptr(),
                                      M, N,
                                      at::cuda::getCurrentCUDAStream());
    }
    return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("softmax_fp32", &ds_softmax<float>, "DeepSpeed SoftMax with fp32 (CUDA)");
    m.def("softmax_fp16", &ds_softmax<__half>, "DeepSpeed SoftMax with fp16 (CUDA)");
    m.def("softmax_context_fp32", &ds_softmax_context<float>, "DeepSpeed attention with fp32 (CUDA)");
    m.def("softmax_context_fp16", &ds_softmax_context<__half>, "DeepSpeed attention with fp16 (CUDA)");
    m.def("bias_gelu_fp32", &ds_bias_gelu<float>, "DeepSpeed Gelu with fp32 (CUDA)");
    m.def("bias_gelu_fp16", &ds_bias_gelu<__half>, "DeepSpeed Gelu with fp16 (CUDA)");
    m.def("bias_residual_fp32", &ds_bias_residual<float>, "DeepSpeed residual-bias add with fp32 (CUDA)");
    m.def("bias_residual_fp16", &ds_bias_residual<__half>, "DeepSpeed residual-bias add with fp16 (CUDA)");
    m.def("layer_norm_fp32", &ds_layernorm<float>, "DeepSpeed layer-norm with fp32 (CUDA)");
    m.def("layer_norm_fp16", &ds_layernorm<__half>, "DeepSpeed layer-norm with fp16 (CUDA)");
    m.def("qkv_gemm_fp32", &ds_qkv_gemm<float>, "DeepSpeed qkv gemm with fp32 (CUDA)");
    m.def("qkv_gemm_fp16", &ds_qkv_gemm<__half>, "DeepSpeed qkv gemm with fp16 (CUDA)");
    m.def("qkv_gemm_int8", &ds_qkv_gemm_int8<__half>, "DeepSpeed qkv gemm with int8 (CUDA)");
    m.def("mlp_gemm_fp32", &ds_mlp_gemm<float>, "DeepSpeed mlp with fp32 (CUDA)");
    m.def("mlp_gemm_fp16", &ds_mlp_gemm<__half>, "DeepSpeed mlp with fp16 (CUDA)");
    m.def("mlp_gemm_int8", &ds_mlp_gemm_int8<__half>, "DeepSpeed mlp with int8 (CUDA)");
    m.def("vector_matmul_fp32", &ds_vector_matmul<float>, "DeepSpeed vector-MM with fp32 (CUDA)");
    m.def("vector_matmul_fp16", &ds_vector_matmul<__half>, "DeepSpeed vector-MM with fp16 (CUDA)");
    m.def("vector_matmul_int8", &ds_vector_matmul_int8<__half>, "DeepSpeed vector-MM with int8 (CUDA)");
    m.def("linear_layer_fp32", &ds_linear_layer<float>, "DeepSpeed linear_layer with fp32 (CUDA)");
    m.def("linear_layer_fp16", &ds_linear_layer<__half>, "DeepSpeed linear_layer with fp16 (CUDA)");
    m.def("linear_layer_int8", &ds_linear_layer_int8<__half>, "DeepSpeed linear_layer with int8 (CUDA)");
    m.def("fused_gemm_gelu_fp32", &fused_gemm_gelu<float>, "DeepSpeed mlp with fp32 (CUDA)");
    m.def("fused_gemm_gelu_fp16", &fused_gemm_gelu<__half>, "DeepSpeed mlp with fp16 (CUDA)");
    m.def("residual_add", &residual_add_bias, "DeepSpeed residual-bias add (CUDA)");
    m.def("apply_rotary_pos_emb", &apply_rotary_pos_emb, "DeepSpeed rotary positional embedding (CUDA)");
    m.def("einsum_sec_sm_ecm_fp32", &einsum_sec_sm_ecm<float>, "DeepSpeed vector-MM with fp32 (CUDA)");
    m.def("einsum_sec_sm_ecm_fp16", &einsum_sec_sm_ecm<__half>, "DeepSpeed vector-MM with fp16 (CUDA)");
    m.def("moe_res_matmul", &moe_res_matmul, "DeepSpeed moe residual matmul (CUDA)");
}
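One convention in the bindings above deserves a note, since it recurs in every cublas_gemm_ex call: the weight is passed as the first operand with m = weight.size(1), n = bsz, k = input.size(2) and no transpose flags. Assuming the cublas_wrappers.h helper forwards these parameters to column-major cuBLAS unchanged, this is the standard identity for computing a row-major product directly:

    C = XW \in \mathbb{R}^{b \times m}
    \quad\Longleftrightarrow\quad
    C^{\top} = W^{\top} X^{\top},

so the row-major buffers of W (k x m) and X (b x k), read by cuBLAS as the column-major W^{\top} and X^{\top}, can be fed to gemm(N, N, m, b, k) and the column-major result buffer is exactly the row-major C.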
csrc/transformer_bak/inference/csrc/pt_binding_hip.cpp
deleted
100644 → 0
View file @
7dd68788
// !!! This is a file automatically generated by hipify!!!
#include <ATen/hip/HIPContext.h>
#include <torch/extension.h>
#include <vector>
#include "context_hip.h"
#include "cublas_wrappers_hip.h"
#include "custom_hip_layers.h"
std
::
array
<
int
,
3
>
gemm_algos
=
std
::
array
<
int
,
3
>
({
99
,
99
,
99
});
#define MAX_OUT_TOKES 10
template
<
typename
T
>
at
::
Tensor
ds_softmax
(
at
::
Tensor
&
attn_scores
,
at
::
Tensor
&
attn_mask
,
bool
triangular
,
bool
recompute
,
bool
local_attention
,
int
window_size
,
bool
async_op
)
{
auto
attn_scores_c
=
attn_scores
.
contiguous
();
int
bsz
=
attn_scores_c
.
size
(
0
);
int
seq_len
=
attn_scores_c
.
size
(
1
);
int
len
=
attn_scores_c
.
sizes
().
size
();
if
(
len
>
3
)
seq_len
=
attn_scores_c
.
size
(
2
);
int
soft_len
=
attn_scores_c
.
size
(
2
);
if
(
len
>
3
)
soft_len
=
attn_scores_c
.
size
(
3
);
int
heads
=
1
;
if
(
len
>
3
)
heads
=
attn_scores_c
.
size
(
1
);
launch_attn_softmax_v2
((
T
*
)
attn_scores_c
.
data_ptr
(),
(
attn_mask
.
sizes
().
size
()
>
1
?
(
T
*
)
attn_mask
.
data_ptr
()
:
nullptr
),
triangular
,
recompute
,
local_attention
,
window_size
,
bsz
,
heads
,
seq_len
,
soft_len
,
1.0
,
Context
::
Instance
().
GetCurrentStream
(
async_op
));
return
attn_scores_c
;
}
template
<
typename
T
>
void
allocate_workspace
(
size_t
hidden_dim
,
size_t
max_seq_len
,
size_t
batch_size
,
size_t
head_size
=
128
)
{
size_t
_workSpaceSize
=
(
hidden_dim
*
batch_size
*
max_seq_len
);
Context
::
Instance
().
GenWorkSpace
(
_workSpaceSize
*
sizeof
(
T
));
}
template
<
typename
T
>
at
::
Tensor
einsum_sec_sm_ecm
(
at
::
Tensor
&
Q
,
at
::
Tensor
&
W
)
{
auto
options
=
at
::
TensorOptions
()
.
dtype
(
Q
.
options
().
dtype
())
.
layout
(
at
::
kStrided
)
.
device
(
at
::
kCUDA
)
.
requires_grad
(
false
);
T
*
workspace
=
(
T
*
)
Context
::
Instance
().
GetWorkSpace
();
float
alpha
=
1
;
float
gemm_beta
=
0.0
;
if
(
!
workspace
)
{
allocate_workspace
<
T
>
(
W
.
size
(
1
),
MAX_OUT_TOKES
,
Q
.
size
(
0
));
workspace
=
(
T
*
)
Context
::
Instance
().
GetWorkSpace
();
}
auto
O
=
at
::
from_blob
(
workspace
,
{
Q
.
size
(
1
),
Q
.
size
(
2
),
W
.
size
(
1
)},
options
);
unsigned
m
=
W
.
size
(
1
);
unsigned
n
=
Q
.
size
(
1
)
*
Q
.
size
(
2
);
unsigned
k
=
Q
.
size
(
0
);
cublas_gemm_ex
(
Context
::
Instance
().
GetCublasHandle
(),
rocblas_operation_none
,
rocblas_operation_transpose
,
m
,
n
,
k
,
&
alpha
,
&
gemm_beta
,
(
T
*
)
W
.
data_ptr
(),
(
T
*
)
Q
.
data_ptr
(),
(
T
*
)
O
.
data_ptr
(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP
);
return
O
;
}
template
<
typename
T
>
void
attention_unfused
(
at
::
Tensor
&
prev_key_cont
,
at
::
Tensor
&
query_cont
,
at
::
Tensor
&
attn_mask
,
at
::
Tensor
&
prev_value_cont
,
at
::
Tensor
&
output
,
int
&
bsz
,
int
&
seq_len
,
int
&
soft_len
,
int
&
heads
,
float
&
norm_factor
,
bool
triangular
,
bool
recompute
,
bool
local_attention
,
int
window_size
)
{
auto
options
=
at
::
TensorOptions
()
.
dtype
(
query_cont
.
options
().
dtype
())
.
layout
(
at
::
kStrided
)
.
device
(
at
::
kCUDA
)
.
requires_grad
(
false
);
float
alpha
=
norm_factor
;
float
gemm_beta
=
0.0
;
auto
attn_score
=
at
::
empty
({
bsz
,
heads
,
seq_len
,
soft_len
},
options
);
int
k
=
prev_value_cont
.
size
(
2
)
/
heads
;
rocblas_set_stream
(
Context
::
Instance
().
GetCublasHandle
(),
Context
::
Instance
().
GetCurrentStream
());
cublas_strided_batched_gemm
(
Context
::
Instance
().
GetCublasHandle
(),
soft_len
,
seq_len
,
k
,
&
alpha
,
&
gemm_beta
,
(
T
*
)
prev_key_cont
.
data_ptr
(),
(
T
*
)
query_cont
.
data_ptr
(),
(
T
*
)
attn_score
.
data_ptr
(),
rocblas_operation_none
,
rocblas_operation_none
,
soft_len
*
k
,
seq_len
*
k
,
seq_len
*
soft_len
,
bsz
*
heads
,
CUBLAS_GEMM_DEFAULT_TENSOR_OP
);
attn_score
=
ds_softmax
<
T
>
(
attn_score
,
attn_mask
,
triangular
,
recompute
,
local_attention
,
window_size
,
false
);
alpha
=
1.0
;
cublas_strided_batched_gemm
(
Context
::
Instance
().
GetCublasHandle
(),
k
,
seq_len
,
soft_len
,
&
alpha
,
&
gemm_beta
,
(
T
*
)
prev_value_cont
.
data_ptr
(),
(
T
*
)
attn_score
.
data_ptr
(),
(
T
*
)
output
.
data_ptr
(),
rocblas_operation_none
,
rocblas_operation_none
,
soft_len
*
k
,
seq_len
*
soft_len
,
seq_len
*
k
,
bsz
*
heads
,
CUBLAS_GEMM_DEFAULT_TENSOR_OP
);
}
template
<
typename
T
>
std
::
vector
<
at
::
Tensor
>
ds_softmax_context
(
at
::
Tensor
&
query
,
at
::
Tensor
&
prev_key
,
at
::
Tensor
&
new_key
,
at
::
Tensor
&
attn_mask
,
at
::
Tensor
&
prev_value
,
at
::
Tensor
&
new_value
,
int
heads
,
float
norm_factor
,
bool
merging
,
bool
triangular
,
bool
local_attention
,
int
window_size
,
bool
no_masking
)
{
auto
query_cont
=
query
.
contiguous
();
auto
prev_key_cont
=
prev_key
.
contiguous
();
auto
prev_value_cont
=
prev_value
.
contiguous
();
int
new_size
=
(
new_value
.
sizes
().
size
()
>
1
?
new_value
.
size
(
1
)
:
0
);
// Attn_Score [ batch Head Sequence-length Softmax-length]
int
bsz
=
query_cont
.
size
(
0
);
int
seq_len
=
query_cont
.
size
(
1
);
int
soft_len
=
prev_value
.
size
(
1
);
auto
options
=
at
::
TensorOptions
()
.
dtype
(
query_cont
.
options
().
dtype
())
.
layout
(
at
::
kStrided
)
.
device
(
at
::
kCUDA
)
.
requires_grad
(
false
);
auto
output
=
at
::
empty
({
prev_value
.
size
(
0
),
heads
,
seq_len
,
prev_value
.
size
(
2
)
/
heads
},
options
);
attention_unfused
<
T
>
(
prev_key_cont
,
query_cont
,
attn_mask
,
//(no_masking ? nullptr : (T*)attn_mask.data_ptr()),
prev_value_cont
,
output
,
bsz
,
seq_len
,
soft_len
,
heads
,
norm_factor
,
(
triangular
&&
(
new_size
==
0
)),
(
new_size
==
0
),
local_attention
,
window_size
);
return
{
output
,
prev_key
,
prev_value
};
}
template
<
typename
T
>
at
::
Tensor
ds_bias_gelu
(
at
::
Tensor
&
input
,
at
::
Tensor
&
bias
)
{
auto
input_cont
=
input
.
contiguous
();
int
bsz
=
input_cont
.
size
(
0
)
*
input_cont
.
size
(
1
);
int
intermediate_size
=
input_cont
.
size
(
2
);
launch_bias_gelu
((
T
*
)
input_cont
.
data_ptr
(),
(
T
*
)
bias
.
data_ptr
(),
intermediate_size
,
bsz
,
Context
::
Instance
().
GetCurrentStream
());
return
input_cont
;
}
template
<
typename
T
>
at
::
Tensor
ds_bias_residual
(
at
::
Tensor
&
input
,
at
::
Tensor
&
residual
,
at
::
Tensor
&
bias
)
{
auto
input_cont
=
input
.
contiguous
();
auto
residual_cont
=
residual
.
contiguous
();
int
bsz
=
input_cont
.
size
(
0
)
*
input_cont
.
size
(
1
);
// launch_bias_residual((T*)input_cont.data_ptr(),
// (T*)residual_cont.data_ptr(),
// (T*)bias.data_ptr(),
// bsz,
// input_cont.size(2),
// (bias.size(0) > 1),
// Context::Instance().GetCurrentStream());
return
input_cont
;
}
template
<
typename
T
>
at
::
Tensor
ds_layernorm
(
at
::
Tensor
&
input_cont
,
at
::
Tensor
&
gamma
,
at
::
Tensor
&
betta
,
float
epsilon
)
{
int
bsz
=
input_cont
.
size
(
0
)
*
input_cont
.
size
(
1
);
auto
inp_norm
=
at
::
empty_like
(
input_cont
);
launch_layer_norm
((
T
*
)
inp_norm
.
data_ptr
(),
(
T
*
)
input_cont
.
data_ptr
(),
(
T
*
)
gamma
.
data_ptr
(),
(
T
*
)
betta
.
data_ptr
(),
epsilon
,
bsz
,
input_cont
.
size
(
2
),
Context
::
Instance
().
GetCurrentStream
());
return
inp_norm
;
}
template
<
typename
T
>
at
::
Tensor
qkv_unfused_cublas
(
at
::
Tensor
&
output
,
at
::
Tensor
&
input
,
at
::
Tensor
&
weight
,
at
::
Tensor
&
bias
,
at
::
Tensor
&
gamma
,
at
::
Tensor
&
beta
,
const
float
epsilon
,
bool
add_bias
)
{
auto
inp_norm
=
ds_layernorm
<
T
>
(
input
,
gamma
,
beta
,
epsilon
);
// hipEventRecord(Context::Instance().GetCompEvent(1), Context::Instance().GetCurrentStream());
float
alpha
=
(
T
)
1.0
;
float
gemm_beta
=
(
T
)
0.0
;
int
bsz
=
input
.
size
(
0
)
*
input
.
size
(
1
);
rocblas_set_stream
(
Context
::
Instance
().
GetCublasHandle
(),
Context
::
Instance
().
GetCurrentStream
());
cublas_gemm_ex
(
Context
::
Instance
().
GetCublasHandle
(),
rocblas_operation_none
,
rocblas_operation_none
,
weight
.
size
(
1
),
bsz
,
input
.
size
(
2
),
&
alpha
,
&
gemm_beta
,
(
T
*
)
weight
.
data_ptr
(),
(
T
*
)
inp_norm
.
data_ptr
(),
(
T
*
)
output
.
data_ptr
(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP
);
if
(
add_bias
)
launch_bias_add
((
T
*
)
output
.
data_ptr
(),
(
T
*
)
bias
.
data_ptr
(),
weight
.
size
(
1
),
bsz
,
Context
::
Instance
().
GetCurrentStream
());
return
inp_norm
;
}
template
<
typename
T
>
std
::
vector
<
at
::
Tensor
>
ds_qkv_gemm
(
at
::
Tensor
&
input
,
at
::
Tensor
&
weight
,
at
::
Tensor
&
bias
,
at
::
Tensor
&
gamma
,
at
::
Tensor
&
beta
,
const
float
epsilon
,
bool
add_bias
)
{
auto
input_cont
=
input
.
contiguous
();
auto
options
=
at
::
TensorOptions
()
.
dtype
(
input_cont
.
options
().
dtype
())
.
layout
(
at
::
kStrided
)
.
device
(
at
::
kCUDA
)
.
requires_grad
(
false
);
auto
output
=
at
::
empty
({
input_cont
.
size
(
0
),
input_cont
.
size
(
1
),
weight
.
size
(
1
)},
options
);
int
bsz
=
input_cont
.
size
(
0
)
*
input_cont
.
size
(
1
);
auto
inp_norm
=
qkv_unfused_cublas
<
T
>
(
output
,
input_cont
,
weight
,
bias
,
gamma
,
beta
,
epsilon
,
add_bias
);
return
{
output
,
inp_norm
};
}
template
<
typename
T
>
void
quantized_gemm
(
at
::
Tensor
&
output
,
at
::
Tensor
&
input
,
at
::
Tensor
&
weight
,
at
::
Tensor
&
qscale
,
int
groups
,
int
merge_count
)
{
int
bsz
=
input
.
size
(
0
)
*
input
.
size
(
1
);
auto
options
=
at
::
TensorOptions
()
.
dtype
(
input
.
options
().
dtype
())
.
layout
(
at
::
kStrided
)
.
device
(
at
::
kCUDA
)
.
requires_grad
(
false
);
auto
weight16
=
at
::
empty
({
weight
.
size
(
0
),
weight
.
size
(
1
)},
options
);
launch_dequantize
((
T
*
)
weight16
.
data_ptr
(),
(
int8_t
*
)
weight
.
data_ptr
(),
(
float
*
)
qscale
.
data_ptr
(),
weight
.
size
(
1
),
weight
.
size
(
0
),
groups
,
merge_count
,
Context
::
Instance
().
GetCurrentStream
());
rocblas_set_stream
(
Context
::
Instance
().
GetCublasHandle
(),
Context
::
Instance
().
GetCurrentStream
());
float
alpha
=
(
T
)
1.0
;
float
gemm_beta
=
(
T
)
0.0
;
cublas_gemm_ex
(
Context
::
Instance
().
GetCublasHandle
(),
rocblas_operation_none
,
rocblas_operation_none
,
weight
.
size
(
1
),
bsz
,
input
.
size
(
2
),
&
alpha
,
&
gemm_beta
,
(
T
*
)
weight16
.
data_ptr
(),
(
T
*
)
input
.
data_ptr
(),
(
T
*
)
output
.
data_ptr
(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP
);
}
template
<
typename
T
>
at
::
Tensor
ds_qkv_gemm_int8
(
at
::
Tensor
&
input
,
at
::
Tensor
&
weight
,
at
::
Tensor
&
bias
,
at
::
Tensor
&
gamma
,
at
::
Tensor
&
beta
,
const
float
epsilon
,
at
::
Tensor
&
q_scale
,
int
groups
,
bool
add_bias
)
{
int
bsz
=
input
.
size
(
0
)
*
input
.
size
(
1
);
auto
input_cont
=
input
.
contiguous
();
auto
options
=
at
::
TensorOptions
()
.
dtype
(
input_cont
.
options
().
dtype
())
.
layout
(
at
::
kStrided
)
.
device
(
at
::
kCUDA
)
.
requires_grad
(
false
);
auto
output
=
at
::
empty
({
input_cont
.
size
(
0
),
input_cont
.
size
(
1
),
weight
.
size
(
1
)},
options
);
auto
inp_norm
=
ds_layernorm
<
T
>
(
input_cont
,
gamma
,
beta
,
epsilon
);
quantized_gemm
<
T
>
(
output
,
inp_norm
,
weight
,
q_scale
,
groups
,
0
);
if
(
add_bias
)
launch_bias_add
((
T
*
)
output
.
data_ptr
(),
(
T
*
)
bias
.
data_ptr
(),
weight
.
size
(
1
),
bsz
,
Context
::
Instance
().
GetCurrentStream
());
return
output
;
}
template
<
typename
T
>
at
::
Tensor
ds_linear_layer
(
at
::
Tensor
&
input
,
at
::
Tensor
&
weight
,
at
::
Tensor
&
bias
)
{
auto
input_cont
=
input
.
contiguous
();
auto
options
=
at
::
TensorOptions
()
.
dtype
(
input_cont
.
options
().
dtype
())
.
layout
(
at
::
kStrided
)
.
device
(
at
::
kCUDA
)
.
requires_grad
(
false
);
auto
output
=
at
::
empty
({
input_cont
.
size
(
0
),
input_cont
.
size
(
1
),
weight
.
size
(
1
)},
options
);
int
bsz
=
input_cont
.
size
(
0
)
*
input_cont
.
size
(
1
);
float
alpha
=
(
T
)
1.0
;
float
gemm_beta
=
(
T
)
0.0
;
rocblas_set_stream
(
Context
::
Instance
().
GetCublasHandle
(),
Context
::
Instance
().
GetCurrentStream
());
cublas_gemm_ex
(
Context
::
Instance
().
GetCublasHandle
(),
rocblas_operation_none
,
rocblas_operation_none
,
weight
.
size
(
1
),
bsz
,
input_cont
.
size
(
2
),
&
alpha
,
&
gemm_beta
,
(
T
*
)
weight
.
data_ptr
(),
(
T
*
)
input_cont
.
data_ptr
(),
(
T
*
)
output
.
data_ptr
(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP
);
launch_bias_add
((
T
*
)
output
.
data_ptr
(),
(
T
*
)
bias
.
data_ptr
(),
weight
.
size
(
1
),
bsz
,
Context
::
Instance
().
GetCurrentStream
());
return
output
;
}
template <typename T>
at::Tensor ds_linear_layer_int8(at::Tensor& input,
                                at::Tensor& weight,
                                at::Tensor& bias,
                                at::Tensor& q_scale,
                                int groups)
{
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);
    int bsz = input_cont.size(0) * input_cont.size(1);

    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);

    quantized_gemm<T>(output, input_cont, weight, q_scale, groups, 0);
    launch_bias_add((T*)output.data_ptr(),
                    (T*)bias.data_ptr(),
                    weight.size(1),
                    bsz,
                    Context::Instance().GetCurrentStream());
    return output;
}
template <typename T>
at::Tensor ds_vector_matmul(at::Tensor& input, at::Tensor& weight, bool async_op)
{
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);

    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
    int bsz = input_cont.size(0) * input_cont.size(1);
    float alpha = (T)1.0;
    float gemm_beta = (T)0.0;
    rocblas_set_stream(Context::Instance().GetCublasHandle(),
                       Context::Instance().GetCurrentStream(async_op));
    cublas_gemm_ex(Context::Instance().GetCublasHandle(),
                   rocblas_operation_none,
                   rocblas_operation_none,
                   weight.size(1),
                   bsz,
                   input_cont.size(2),
                   &alpha,
                   &gemm_beta,
                   (T*)weight.data_ptr(),
                   (T*)input_cont.data_ptr(),
                   (T*)output.data_ptr(),
                   CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    return output;
}
template <typename T>
at::Tensor ds_vector_matmul_int8(at::Tensor& input,
                                 at::Tensor& weight,
                                 at::Tensor& q_scale,
                                 int groups,
                                 int merge_count)
{
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);
    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);

    quantized_gemm<T>(output, input_cont, weight, q_scale, groups, merge_count);
    return output;
}
template <typename T>
void mlp_unfused_cublas(at::Tensor& output,
                        at::Tensor& input,
                        at::Tensor& residual,
                        at::Tensor& input_bias,
                        at::Tensor& weight,
                        at::Tensor& bias,
                        at::Tensor& gamma,
                        at::Tensor& beta,
                        const float epsilon,
                        bool preLayerNorm,
                        bool mlp_after_attn)
{
    int bsz = input.size(0) * input.size(1);
    auto inp_norm = at::empty_like(input);

    launch_residual_layer_norm((T*)inp_norm.data_ptr(),
                               (T*)nullptr,
                               (T*)input.data_ptr(),
                               (T*)residual.data_ptr(),
                               (T*)input_bias.data_ptr(),
                               (T*)gamma.data_ptr(),
                               (T*)beta.data_ptr(),
                               epsilon,
                               bsz,
                               input.size(2),
                               preLayerNorm,
                               mlp_after_attn,
                               Context::Instance().GetCurrentStream());

    float alpha = (T)1.0;
    float gemm_beta = (T)0.0;
    rocblas_set_stream(Context::Instance().GetCublasHandle(),
                       Context::Instance().GetCurrentStream());
    cublas_gemm_ex(Context::Instance().GetCublasHandle(),
                   rocblas_operation_none,
                   rocblas_operation_none,
                   weight.size(1),
                   bsz,
                   input.size(2),
                   &alpha,
                   &gemm_beta,
                   (T*)weight.data_ptr(),
                   (T*)inp_norm.data_ptr(),
                   (T*)output.data_ptr(),
                   CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    launch_bias_gelu((T*)output.data_ptr(),
                     (T*)bias.data_ptr(),
                     weight.size(1),
                     bsz,
                     Context::Instance().GetCurrentStream());
}
template <typename T>
at::Tensor ds_mlp_gemm(at::Tensor& input,
                       at::Tensor& residual,
                       at::Tensor& input_bias,
                       at::Tensor& weight,
                       at::Tensor& bias,
                       at::Tensor& gamma,
                       at::Tensor& beta,
                       const float epsilon,
                       bool preLayerNorm,
                       bool mlp_after_attn)
{
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);

    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
    int bsz = input_cont.size(0) * input_cont.size(1);

    mlp_unfused_cublas<T>(output,
                          mlp_after_attn ? input : residual,
                          residual,
                          input_bias,
                          weight,
                          bias,
                          gamma,
                          beta,
                          epsilon,
                          preLayerNorm,
                          mlp_after_attn);

    return output;
}
template <typename T>
std::vector<at::Tensor> ds_mlp_gemm_int8(at::Tensor& input,
                                         at::Tensor& residual,
                                         at::Tensor& input_bias,
                                         at::Tensor& weight,
                                         at::Tensor& bias,
                                         at::Tensor& gamma,
                                         at::Tensor& beta,
                                         const float epsilon,
                                         at::Tensor& q_scale,
                                         int groups,
                                         bool preLayerNorm)
{
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);

    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);

    int bsz = input_cont.size(0) * input_cont.size(1);
    auto inp_norm = at::empty_like(input_cont);

    auto residual_add = (preLayerNorm ? at::empty_like(input_cont) : inp_norm);
    // computing the blocking across K dimension
    // launch_residual_layer_norm((T*)inp_norm.data_ptr(),
    //                            (T*)residual_add.data_ptr(),
    //                            (T*)input_cont.data_ptr(),
    //                            (T*)residual.data_ptr(),
    //                            (T*)input_bias.data_ptr(),
    //                            (T*)gamma.data_ptr(),
    //                            (T*)beta.data_ptr(),
    //                            epsilon,
    //                            bsz,
    //                            input_cont.size(2),
    //                            preLayerNorm,
    //                            Context::Instance().GetCurrentStream());

    quantized_gemm<T>(output, inp_norm, weight, q_scale, groups, 0);
    launch_bias_gelu((T*)output.data_ptr(),
                     (T*)bias.data_ptr(),
                     weight.size(1),
                     bsz,
                     Context::Instance().GetCurrentStream());

    return {output, residual_add};
}
template <typename T>
at::Tensor fused_gemm_gelu(at::Tensor& input,
                           at::Tensor& weight,
                           at::Tensor& bias,
                           at::Tensor& weight_out,
                           const float epsilon,
                           bool preLayerNorm,
                           bool async_op)
{
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);

    auto intermediate =
        at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
    auto output =
        at::empty({input_cont.size(0), input_cont.size(1), weight_out.size(1)}, options);
    int bsz = input_cont.size(0) * input_cont.size(1);
    float alpha = (T)1.0;
    float gemm_beta = (T)0.0;
    rocblas_set_stream(Context::Instance().GetCublasHandle(),
                       Context::Instance().GetCurrentStream());
    cublas_gemm_ex(Context::Instance().GetCublasHandle(),
                   rocblas_operation_none,
                   rocblas_operation_none,
                   weight.size(1),
                   bsz,
                   input.size(2),
                   &alpha,
                   &gemm_beta,
                   (T*)weight.data_ptr(),
                   (T*)input_cont.data_ptr(),
                   (T*)intermediate.data_ptr(),
                   CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    launch_bias_gelu((T*)intermediate.data_ptr(),
                     (T*)bias.data_ptr(),
                     weight.size(1),
                     bsz,
                     Context::Instance().GetCurrentStream());

    cublas_gemm_ex(Context::Instance().GetCublasHandle(),
                   rocblas_operation_none,
                   rocblas_operation_none,
                   weight_out.size(1),
                   bsz,
                   intermediate.size(2),
                   &alpha,
                   &gemm_beta,
                   (T*)weight_out.data_ptr(),
                   (T*)intermediate.data_ptr(),
                   (T*)output.data_ptr(),
                   CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    // hipEventRecord(Context::Instance().GetCompEvent(2),
    //                Context::Instance().GetCurrentStream(true));
    return output;
}
void residual_add_bias(at::Tensor& output,
                       at::Tensor& input,
                       at::Tensor& attention_output,
                       at::Tensor& output_b,
                       at::Tensor& attention_b,
                       int mp_size,
                       bool mlp_after_attn)
{
    int bsz = input.size(0) * input.size(1);
    int hidden_size = input.size(2);
    // hipStreamWaitEvent(
    //     Context::Instance().GetCurrentStream(), Context::Instance().GetCompEvent(2), 0);
    if (input.scalar_type() == at::kFloat)
        if (mlp_after_attn)
            launch_bias_residual((float*)input.data_ptr(),
                                 (float*)output.data_ptr(),
                                 (float*)attention_output.data_ptr(),
                                 (float*)output_b.data_ptr(),
                                 (float*)attention_b.data_ptr(),
                                 bsz,
                                 hidden_size,
                                 mp_size,
                                 Context::Instance().GetCurrentStream());
        else
            launch_gptj_residual_add<float>((float*)input.data_ptr(),
                                            (float*)output.data_ptr(),
                                            (float*)attention_output.data_ptr(),
                                            (float*)output_b.data_ptr(),
                                            (float*)attention_b.data_ptr(),
                                            hidden_size,
                                            bsz,
                                            mp_size,
                                            Context::Instance().GetCurrentStream());
    else if (mlp_after_attn)
        launch_bias_residual((__half*)input.data_ptr(),
                             (__half*)output.data_ptr(),
                             (__half*)attention_output.data_ptr(),
                             (__half*)output_b.data_ptr(),
                             (__half*)attention_b.data_ptr(),
                             bsz,
                             hidden_size,
                             mp_size,
                             Context::Instance().GetCurrentStream());
    else
        launch_gptj_residual_add<__half>((__half*)input.data_ptr(),
                                         (__half*)output.data_ptr(),
                                         (__half*)attention_output.data_ptr(),
                                         (__half*)output_b.data_ptr(),
                                         (__half*)attention_b.data_ptr(),
                                         hidden_size,
                                         bsz,
                                         mp_size,
                                         Context::Instance().GetCurrentStream());
}
std::vector<at::Tensor> apply_rotary_pos_emb(at::Tensor& mixed_query,
                                             at::Tensor& key_layer,
                                             unsigned rotary_dim,
                                             unsigned offset,
                                             unsigned num_heads,
                                             bool rotate_half,
                                             bool rotate_every_two)
{
    auto query_cont = mixed_query.contiguous();
    auto key_cont = key_layer.contiguous();

    unsigned bsz = mixed_query.size(0);
    unsigned head_size = mixed_query.size(2) / num_heads;
    unsigned seq_len = mixed_query.size(1);

    if (mixed_query.scalar_type() == at::kFloat)
        launch_apply_rotary_pos_emb<float>((float*)query_cont.data_ptr(),
                                           (float*)key_cont.data_ptr(),
                                           head_size,
                                           seq_len,
                                           rotary_dim,
                                           offset,
                                           num_heads,
                                           bsz,
                                           rotate_half,
                                           rotate_every_two,
                                           Context::Instance().GetCurrentStream());
    else
        launch_apply_rotary_pos_emb<__half>((__half*)query_cont.data_ptr(),
                                            (__half*)key_cont.data_ptr(),
                                            head_size,
                                            seq_len,
                                            rotary_dim,
                                            offset,
                                            num_heads,
                                            bsz,
                                            rotate_half,
                                            rotate_every_two,
                                            Context::Instance().GetCurrentStream());
    return {query_cont, key_cont};
}
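For reference, the rotate_every_two path dispatched above corresponds to the GPT-J-style rotary transform; the kernel body is not part of this commit view, so the following is the standard formulation rather than a transcription of this implementation. For channel pairs $2i < \text{rotary\_dim}$ at sequence position $p$:

$$q'_{2i} = q_{2i}\cos\theta_{p,i} - q_{2i+1}\sin\theta_{p,i}, \qquad q'_{2i+1} = q_{2i+1}\cos\theta_{p,i} + q_{2i}\sin\theta_{p,i}, \qquad \theta_{p,i} = (p + \text{offset})\cdot 10000^{-2i/\text{rotary\_dim}}$$

The same rotation is applied to key_layer, and channels at or beyond rotary_dim pass through unchanged.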
template <typename T>
at::Tensor fused_gemm_gelu_int8(at::Tensor& input,
                                at::Tensor& weight,
                                at::Tensor& bias,
                                const float epsilon,
                                at::Tensor& q_scale,
                                int groups,
                                bool preLayerNorm)
{
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);

    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
    int bsz = input_cont.size(0) * input_cont.size(1);

    quantized_gemm<T>(output, input_cont, weight, q_scale, groups, 0);
    launch_bias_gelu((T*)output.data_ptr(),
                     (T*)bias.data_ptr(),
                     weight.size(1),
                     bsz,
                     Context::Instance().GetCurrentStream());

    return output;
}
at::Tensor moe_res_matmul(at::Tensor& moe_res, at::Tensor& coef, at::Tensor& output)
{
    int M = moe_res.size(0) * moe_res.size(1);
    int N = moe_res.size(2);
    Context::Instance().SynchComm();
    if (moe_res.scalar_type() == at::kFloat) {
        launch_moe_res_matmul<float>((float*)moe_res.data_ptr(),
                                     (float*)coef.data_ptr(),
                                     (float*)output.data_ptr(),
                                     M,
                                     N,
                                     at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
    } else {
        launch_moe_res_matmul<__half>((__half*)moe_res.data_ptr(),
                                      (__half*)coef.data_ptr(),
                                      (__half*)output.data_ptr(),
                                      M,
                                      N,
                                      at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
    }
    return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("softmax_fp32", &ds_softmax<float>, "DeepSpeed SoftMax with fp32 (CUDA)");
    m.def("softmax_fp16", &ds_softmax<__half>, "DeepSpeed SoftMax with fp16 (CUDA)");
    m.def("softmax_context_fp32",
          &ds_softmax_context<float>,
          "DeepSpeed attention with fp32 (CUDA)");
    m.def("softmax_context_fp16",
          &ds_softmax_context<__half>,
          "DeepSpeed attention with fp16 (CUDA)");
    m.def("bias_gelu_fp32", &ds_bias_gelu<float>, "DeepSpeed Gelu with fp32 (CUDA)");
    m.def("bias_gelu_fp16", &ds_bias_gelu<__half>, "DeepSpeed Gelu with fp16 (CUDA)");
    m.def("bias_residual_fp32",
          &ds_bias_residual<float>,
          "DeepSpeed residual-bias add with fp32 (CUDA)");
    m.def("bias_residual_fp16",
          &ds_bias_residual<__half>,
          "DeepSpeed residual-bias add with fp16 (CUDA)");
    m.def("layer_norm_fp32", &ds_layernorm<float>, "DeepSpeed layer-norm with fp32 (CUDA)");
    m.def("layer_norm_fp16", &ds_layernorm<__half>, "DeepSpeed layer-norm with fp16 (CUDA)");
    m.def("qkv_gemm_fp32", &ds_qkv_gemm<float>, "DeepSpeed qkv gemm with fp32 (CUDA)");
    m.def("qkv_gemm_fp16", &ds_qkv_gemm<__half>, "DeepSpeed qkv gemm with fp16 (CUDA)");
    m.def("qkv_gemm_int8", &ds_qkv_gemm_int8<__half>, "DeepSpeed qkv gemm with int8 (CUDA)");
    m.def("mlp_gemm_fp32", &ds_mlp_gemm<float>, "DeepSpeed mlp with fp32 (CUDA)");
    m.def("mlp_gemm_fp16", &ds_mlp_gemm<__half>, "DeepSpeed mlp with fp16 (CUDA)");
    m.def("mlp_gemm_int8", &ds_mlp_gemm_int8<__half>, "DeepSpeed mlp with int8 (CUDA)");
    m.def("vector_matmul_fp32", &ds_vector_matmul<float>, "DeepSpeed vector-MM with fp32 (CUDA)");
    m.def("vector_matmul_fp16",
          &ds_vector_matmul<__half>,
          "DeepSpeed vector-MM with fp16 (CUDA)");
    m.def("vector_matmul_int8",
          &ds_vector_matmul_int8<__half>,
          "DeepSpeed vector-MM with int8 (CUDA)");
    m.def("linear_layer_fp32",
          &ds_linear_layer<float>,
          "DeepSpeed linear_layer with fp32 (CUDA)");
    m.def("linear_layer_fp16",
          &ds_linear_layer<__half>,
          "DeepSpeed linear_layer with fp16 (CUDA)");
    m.def("linear_layer_int8",
          &ds_linear_layer_int8<__half>,
          "DeepSpeed linear_layer with int8 (CUDA)");
    m.def("fused_gemm_gelu_fp32", &fused_gemm_gelu<float>, "DeepSpeed mlp with fp32 (CUDA)");
    m.def("fused_gemm_gelu_fp16", &fused_gemm_gelu<__half>, "DeepSpeed mlp with fp16 (CUDA)");
    m.def("residual_add", &residual_add_bias, "DeepSpeed residual-bias add (CUDA)");
    m.def("apply_rotary_pos_emb",
          &apply_rotary_pos_emb,
          "DeepSpeed rotary positional embedding (CUDA)");
    m.def("einsum_sec_sm_ecm_fp32",
          &einsum_sec_sm_ecm<float>,
          "DeepSpeed vector-MM with fp32 (CUDA)");
    m.def("einsum_sec_sm_ecm_fp16",
          &einsum_sec_sm_ecm<__half>,
          "DeepSpeed vector-MM with fp16 (CUDA)");
    m.def("moe_res_matmul", &moe_res_matmul, "DeepSpeed moe residual matmul (CUDA)");
}
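The fp32/fp16 registrations above are hand-copied pairs, which is how the docstring mix-ups corrected above crept in. A hypothetical refactor sketch (not in the original file) that registers a precision pair in one line:

// Illustrative macro; adjacent string literals concatenate at compile time.
#define DS_BIND_PAIR(mod, name, fn, desc)                                      \
    mod.def(name "_fp32", &fn<float>, "DeepSpeed " desc " with fp32 (CUDA)");  \
    mod.def(name "_fp16", &fn<__half>, "DeepSpeed " desc " with fp16 (CUDA)")

// Usage inside PYBIND11_MODULE:
//   DS_BIND_PAIR(m, "softmax", ds_softmax, "SoftMax");
//   DS_BIND_PAIR(m, "layer_norm", ds_layernorm, "layer-norm");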
csrc/transformer_bak/inference/csrc/softmax.cu
deleted
100644 → 0
#include <limits>
#include "custom_cuda_layers.h"
//#include <cuda_profiler_api.h>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#define ATTN_THREADS 1024
#define MAX_REG_SIZE 8
#define minus_infinity -10000.0
void CheckCudaErrorAux(const char* file, unsigned line)
{
    cudaError_t err = cudaGetLastError();
    if (err == cudaSuccess) return;
    std::cerr << cudaGetErrorString(err) << "(" << err << ") at " << file << ":" << line
              << std::endl;
    throw std::runtime_error("CUDA ERROR!!!\n");
}
#define CUDA_CHECK_ERROR() CheckCudaErrorAux(__FILE__, __LINE__)
namespace cg = cooperative_groups;
__global__ void attn_softmax_v2(__half* vals,
                                __half* mask,
                                bool triangular,
                                bool recompute,
                                bool local_attention,
                                int window_size,
                                int total_count,
                                int heads,
                                int sequence_length,
                                int num_seq,
                                float scale,
                                int iterations,
                                int reduceWidth)
{
#ifdef HALF_PRECISION_AVAILABLE
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    float2 low_data[MAX_REG_SIZE];
    float2 high_data[MAX_REG_SIZE];

    __half2 h_scale = __float2half2_rn(scale);

    int wid = threadIdx.x >> 5;
    int lane = threadIdx.x & 0x1f;
    int warp_num = blockDim.x >> 5;

    int reduce_blocks = reduceWidth >> 5;
    int seq_lane = threadIdx.x % reduceWidth;

    __shared__ float partialSum[MAX_WARP_NUM];

    int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks);
    if (iter_offset < total_count) {
        vals += (iter_offset * sequence_length);

        int mask_offset = (iter_offset / (heads * num_seq)) * (sequence_length);
        int seq_id = iter_offset % num_seq;
        int seq_id4 = seq_id >> 2;
        int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length);
        int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2))
                                 ? (real_seq_id >> 2) - (window_size >> 2)
                                 : 0;
        int window_stride =
            (local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1;

        float max_val = minus_infinity;

        for (int i = 0; i < iterations; i++) {
            int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
            if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 &&
                data_id < sequence_length) {
                if ((sequence_length - data_id) >= 4) {
                    low_data[i].x = data_id > window_stride ? __half2float(vals[data_id])
                                                            : minus_infinity;
                    low_data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) &&
                                     (data_id + 1) > window_stride)
                                        ? __half2float(vals[data_id + 1])
                                        : minus_infinity;
                    high_data[i].x = ((!triangular || ((data_id + 2) <= seq_id)) &&
                                      (data_id + 2) > window_stride)
                                         ? __half2float(vals[data_id + 2])
                                         : minus_infinity;
                    high_data[i].y = ((!triangular || ((data_id + 3) <= seq_id)) &&
                                      (data_id + 3) > window_stride)
                                         ? __half2float(vals[data_id + 3])
                                         : minus_infinity;
                    if (mask && recompute) {
                        low_data[i].x += __half2float(mask[data_id + mask_offset]);
                        low_data[i].y += __half2float(mask[data_id + mask_offset + 1]);
                        high_data[i].x += __half2float(mask[data_id + mask_offset + 2]);
                        high_data[i].y += __half2float(mask[data_id + mask_offset + 3]);
                    }
                } else {
                    low_data[i].x = data_id > window_stride ? __half2float(vals[data_id])
                                                            : minus_infinity;
                    low_data[i].y = (((!triangular || (data_id + 1) <= seq_id) &&
                                      (data_id + 1) > window_stride) &&
                                     (data_id + 1) < sequence_length)
                                        ? __half2float(vals[data_id + 1])
                                        : minus_infinity;
                    high_data[i].x = (((!triangular || (data_id + 2) <= seq_id) &&
                                       (data_id + 2) > window_stride) &&
                                      (data_id + 2) < sequence_length)
                                         ? __half2float(vals[data_id + 2])
                                         : minus_infinity;
                    high_data[i].y = minus_infinity;
                    if (mask && recompute) {
                        low_data[i].x += __half2float(mask[data_id + mask_offset]);
                        if ((data_id + 1) < sequence_length)
                            low_data[i].y += __half2float(mask[data_id + mask_offset + 1]);
                        if ((data_id + 2) < sequence_length)
                            high_data[i].x += __half2float(mask[data_id + mask_offset + 2]);
                    }
                }
                // if(lane == 0) printf("%f , %d, %d \n", low_data[i].x, data_id, seq_id);
                max_val = (low_data[i].x > max_val ? low_data[i].x : max_val);
                max_val = (low_data[i].y > max_val ? low_data[i].y : max_val);
                max_val = (high_data[i].x > max_val ? high_data[i].x : max_val);
                max_val = (high_data[i].y > max_val ? high_data[i].y : max_val);
            } else {
                low_data[i].x = minus_infinity;
                low_data[i].y = minus_infinity;
                high_data[i].x = minus_infinity;
                high_data[i].y = minus_infinity;
            }
        }

        for (int i = 1; i < WARP_SIZE; i *= 2) {
            auto temp = g.shfl_xor(max_val, i);
            max_val = (temp > max_val ? temp : max_val);
        }

        if (reduceWidth > WARP_SIZE) {
            if (lane == 0) partialSum[wid] = max_val;
            b.sync();
            if (lane < warp_num) max_val = partialSum[lane];
            b.sync();
            for (int i = 1; i < reduce_blocks; i *= 2) {
                auto temp = g.shfl_xor(max_val, i);
                max_val = (temp > max_val ? temp : max_val);
            }
            max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE);
        }

        float sum = 0;
        for (int i = 0; i < iterations; i++) {
            low_data[i].x = __expf(low_data[i].x - max_val);
            low_data[i].y = __expf(low_data[i].y - max_val);
            high_data[i].x = __expf(high_data[i].x - max_val);
            high_data[i].y = __expf(high_data[i].y - max_val);
            sum += (low_data[i].x + low_data[i].y + high_data[i].x + high_data[i].y);
        }

        for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i);

        if (reduceWidth > WARP_SIZE) {
            if (lane == 0) partialSum[wid] = sum;
            b.sync();
            if (lane < warp_num) sum = partialSum[lane];
            b.sync();
            for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); }
            sum = g.shfl(sum, threadIdx.x / WARP_SIZE);
        }

        sum += 1e-6;

        for (int i = 0; i < iterations; i++) {
            int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
            if (data_id < sequence_length) {
                if ((sequence_length - data_id) >= 4) {
                    vals[data_id] = low_data[i].x / sum;
                    vals[data_id + 1] = low_data[i].y / sum;
                    vals[data_id + 2] = high_data[i].x / sum;
                    vals[data_id + 3] = high_data[i].y / sum;
                } else {
                    vals[data_id] = low_data[i].x / sum;
                    if ((data_id + 1) < sequence_length) vals[data_id + 1] = low_data[i].y / sum;
                    if ((data_id + 2) < sequence_length) vals[data_id + 2] = high_data[i].x / sum;
                }
            }
        }
    }
#endif
}
__global__ void attn_softmax_v2(float* vals,
                                float* attn_mask,
                                bool triangular,
                                bool recompute,
                                bool local_attention,
                                int window_size,
                                int total_count,
                                int heads,
                                int sequence_length,
                                int num_seq,
                                float scale,
                                int iterations,
                                int reduceWidth)
{
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    float4 data[MAX_REG_SIZE];

    int wid = threadIdx.x >> 5;
    int lane = threadIdx.x & 0x1f;
    int warp_num = blockDim.x >> 5;

    int reduce_blocks = reduceWidth >> 5;
    int seq_lane = threadIdx.x % reduceWidth;

    __shared__ float partialSum[MAX_WARP_NUM];

    int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks);
    if (iter_offset < total_count) {
        vals += (iter_offset * sequence_length);

        int mask_offset = (iter_offset / (heads * num_seq)) * (sequence_length);
        int seq_id = iter_offset % num_seq;
        int seq_id4 = seq_id >> 2;
        int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length);
        int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2))
                                 ? (real_seq_id >> 2) - (window_size >> 2)
                                 : 0;
        int window_stride =
            (local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1;

        float max_val = minus_infinity;

        for (int i = 0; i < iterations; i++) {
            int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
            if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 &&
                data_id < sequence_length) {
                if ((sequence_length - data_id) >= 4) {
                    data[i].x = (data_id > window_stride ? vals[data_id] : minus_infinity);
                    data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) &&
                                 (data_id + 1) > window_stride)
                                    ? vals[data_id + 1]
                                    : minus_infinity;
                    data[i].z = ((!triangular || ((data_id + 2) <= seq_id)) &&
                                 (data_id + 2) > window_stride)
                                    ? vals[data_id + 2]
                                    : minus_infinity;
                    data[i].w = ((!triangular || ((data_id + 3) <= seq_id)) &&
                                 (data_id + 3) > window_stride)
                                    ? vals[data_id + 3]
                                    : minus_infinity;
                    if (attn_mask && recompute) {
                        data[i].x += attn_mask[data_id + mask_offset];
                        data[i].y += attn_mask[data_id + mask_offset + 1];
                        data[i].z += attn_mask[data_id + mask_offset + 2];
                        data[i].w += attn_mask[data_id + mask_offset + 3];
                    }
                } else {
                    data[i].x = data_id > window_stride ? vals[data_id] : minus_infinity;
                    data[i].y = (((!triangular || (data_id + 1) <= seq_id)) &&
                                 (data_id + 1) > window_stride && (data_id + 1) < sequence_length)
                                    ? (vals[data_id + 1])
                                    : minus_infinity;
                    data[i].z = (((!triangular || (data_id + 2) <= seq_id)) &&
                                 (data_id + 2) > window_stride && (data_id + 2) < sequence_length)
                                    ? (vals[data_id + 2])
                                    : minus_infinity;
                    data[i].w = minus_infinity;
                    if (attn_mask && recompute) {
                        data[i].x += attn_mask[data_id + mask_offset];
                        if ((data_id + 1) < sequence_length)
                            data[i].y += attn_mask[data_id + mask_offset + 1];
                        if ((data_id + 2) < sequence_length)
                            data[i].z += attn_mask[data_id + mask_offset + 2];
                    }
                }

                max_val = (data[i].x > max_val ? data[i].x : max_val);
                max_val = (data[i].y > max_val ? data[i].y : max_val);
                max_val = (data[i].z > max_val ? data[i].z : max_val);
                max_val = (data[i].w > max_val ? data[i].w : max_val);
            } else {
                data[i].x = minus_infinity;
                data[i].y = minus_infinity;
                data[i].z = minus_infinity;
                data[i].w = minus_infinity;
            }
        }

        for (int i = 1; i < WARP_SIZE; i *= 2) {
            auto temp = g.shfl_xor(max_val, i);
            max_val = (temp > max_val ? temp : max_val);
        }

        if (reduceWidth > WARP_SIZE) {
            if (lane == 0) partialSum[wid] = max_val;
            b.sync();
            if (lane < warp_num) max_val = partialSum[lane];
            b.sync();
            for (int i = 1; i < reduce_blocks; i *= 2) {
                auto temp = g.shfl_xor(max_val, i);
                max_val = (temp > max_val ? temp : max_val);
            }
            max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE);
        }

        float sum = 0;
        for (int i = 0; i < iterations; i++) {
            data[i].x = __expf(data[i].x - max_val);
            data[i].y = __expf(data[i].y - max_val);
            data[i].z = __expf(data[i].z - max_val);
            data[i].w = __expf(data[i].w - max_val);
            sum += (data[i].x + data[i].y + data[i].z + data[i].w);
        }

        for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i);

        if (reduceWidth > WARP_SIZE) {
            if (lane == 0) partialSum[wid] = sum;
            b.sync();
            if (lane < warp_num) sum = partialSum[lane];
            b.sync();
            for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); }
            sum = g.shfl(sum, threadIdx.x / WARP_SIZE);
        }

        sum += 1e-6;

        for (int i = 0; i < iterations; i++) {
            int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
            if (data_id < sequence_length) {
                if ((sequence_length - data_id) >= 4) {
                    vals[data_id] = data[i].x / sum;
                    vals[data_id + 1] = data[i].y / sum;
                    vals[data_id + 2] = data[i].z / sum;
                    vals[data_id + 3] = data[i].w / sum;
                } else {
                    vals[data_id] = data[i].x / sum;
                    if ((data_id + 1) < sequence_length) vals[data_id + 1] = data[i].y / sum;
                    if ((data_id + 2) < sequence_length) vals[data_id + 2] = data[i].z / sum;
                }
            }
        }
    }
}
template <typename T>
void launch_attn_softmax_v2(T* vals,
                            T* mask,
                            bool triangular,
                            bool recompute,
                            bool local_attention,
                            int window_size,
                            int batch_size,
                            int heads,
                            int num_seq,
                            int sequence_length,
                            float scale,
                            cudaStream_t stream)
{
    int total_count = batch_size * heads * num_seq;
    dim3 grid_dim((total_count - 1) / (WARP_SIZE / ((sequence_length - 1) / ATTN_THREADS + 1)) +
                  1);
    dim3 block_dim(ATTN_THREADS);

    const int reduce_width = ((sequence_length - 1) / ATTN_THREADS + 1) * WARP_SIZE;
    const int iterations = (sequence_length - 1) / (reduce_width << 2) + 1;

    if (sequence_length <= 32768)
        attn_softmax_v2<<<grid_dim, block_dim, 0, stream>>>(
            vals,
            mask,
            triangular,
            recompute,
            local_attention,
            window_size,
            total_count,
            (triangular ? (heads * batch_size) : heads),
            sequence_length,
            num_seq,
            scale,
            iterations,
            reduce_width);
    else
        throw std::runtime_error("Unsupported sequence length!");
}
template void launch_attn_softmax_v2(float* vals,
                                     float* mask,
                                     bool triangular,
                                     bool recompute,
                                     bool local_attention,
                                     int window_size,
                                     int batch_size,
                                     int heads,
                                     int num_seq,
                                     int sequence_length,
                                     float scale,
                                     cudaStream_t stream);
template void launch_attn_softmax_v2(__half* vals,
                                     __half* mask,
                                     bool triangular,
                                     bool recompute,
                                     bool local_attention,
                                     int window_size,
                                     int batch_size,
                                     int heads,
                                     int num_seq,
                                     int sequence_length,
                                     float scale,
                                     cudaStream_t stream);
csrc/transformer_bak/inference/csrc/softmax.hip
deleted
100644 → 0
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include <limits>
#include "custom_hip_layers.h"
//#include <cuda_profiler_api.h>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#define ATTN_THREADS 1024
#define MAX_REG_SIZE 8
#define minus_infinity -10000.0
void CheckCudaErrorAux(const char* file, unsigned line)
{
hipError_t err = hipGetLastError();
if (err == hipSuccess) return;
std::cerr << hipGetErrorString(err) << "(" << err << ") at " << file << ":" << line
<< std::endl;
throw std::runtime_error("CUDA ERROR!!!\n");
}
#define CUDA_CHECK_ERROR() CheckCudaErrorAux(__FILE__, __LINE__)
namespace cg = cooperative_groups;
__global__ void attn_softmax_v2(__half* vals,
__half* mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int total_count,
int heads,
int sequence_length,
int num_seq,
float scale,
int iterations,
int reduceWidth)
{
#ifdef HALF_PRECISION_AVAILABLE
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
float2 low_data[MAX_REG_SIZE];
float2 high_data[MAX_REG_SIZE];
__half2 h_scale = __float2half2_rn(scale);
int wid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int reduce_blocks = reduceWidth >> 5;
int seq_lane = threadIdx.x % reduceWidth;
__shared__ float partialSum[MAX_WARP_NUM];
int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks);
if (iter_offset < total_count) {
vals += (iter_offset * sequence_length);
int mask_offset = (iter_offset / (heads * num_seq)) * (sequence_length);
int seq_id = iter_offset % num_seq;
int seq_id4 = seq_id >> 2;
int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length);
int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2))
? (real_seq_id >> 2) - (window_size >> 2)
: 0;
int window_stride =
(local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1;
float max_val = minus_infinity;
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 &&
data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
low_data[i].x = data_id > window_stride ? __half2float(vals[data_id])
: minus_infinity;
low_data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) &&
(data_id + 1) > window_stride)
? __half2float(vals[data_id + 1])
: minus_infinity;
high_data[i].x = ((!triangular || ((data_id + 2) <= seq_id)) &&
(data_id + 2) > window_stride)
? __half2float(vals[data_id + 2])
: minus_infinity;
high_data[i].y = ((!triangular || ((data_id + 3) <= seq_id)) &&
(data_id + 3) > window_stride)
? __half2float(vals[data_id + 3])
: minus_infinity;
if (mask && recompute) {
low_data[i].x += __half2float(mask[data_id + mask_offset]);
low_data[i].y += __half2float(mask[data_id + mask_offset + 1]);
high_data[i].x += __half2float(mask[data_id + mask_offset + 2]);
high_data[i].y += __half2float(mask[data_id + mask_offset + 3]);
}
} else {
low_data[i].x = data_id > window_stride ? __half2float(vals[data_id])
: minus_infinity;
low_data[i].y = (((!triangular || (data_id + 1) <= seq_id) &&
(data_id + 1) > window_stride) &&
(data_id + 1) < sequence_length)
? __half2float(vals[data_id + 1])
: minus_infinity;
high_data[i].x = (((!triangular || (data_id + 2) <= seq_id) &&
(data_id + 2) > window_stride) &&
(data_id + 2) < sequence_length)
? __half2float(vals[data_id + 2])
: minus_infinity;
high_data[i].y = minus_infinity;
if (mask && recompute) {
low_data[i].x += __half2float(mask[data_id + mask_offset]);
if ((data_id + 1) < sequence_length)
low_data[i].y += __half2float(mask[data_id + mask_offset + 1]);
if ((data_id + 2) < sequence_length)
high_data[i].x += __half2float(mask[data_id + mask_offset + 2]);
}
}
// if(lane == 0) printf("%f , %d, %d \n", low_data[i].x, data_id, seq_id);
max_val = (low_data[i].x > max_val ? low_data[i].x : max_val);
max_val = (low_data[i].y > max_val ? low_data[i].y : max_val);
max_val = (high_data[i].x > max_val ? high_data[i].x : max_val);
max_val = (high_data[i].y > max_val ? high_data[i].y : max_val);
} else {
low_data[i].x = minus_infinity;
low_data[i].y = minus_infinity;
high_data[i].x = minus_infinity;
high_data[i].y = minus_infinity;
}
}
for (int i = 1; i < WARP_SIZE; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
if (reduceWidth > WARP_SIZE) {
if (lane == 0) partialSum[wid] = max_val;
b.sync();
if (lane < warp_num) max_val = partialSum[lane];
b.sync();
for (int i = 1; i < reduce_blocks; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE);
}
float sum = 0;
for (int i = 0; i < iterations; i++) {
low_data[i].x = __expf(low_data[i].x - max_val);
low_data[i].y = __expf(low_data[i].y - max_val);
high_data[i].x = __expf(high_data[i].x - max_val);
high_data[i].y = __expf(high_data[i].y - max_val);
sum += (low_data[i].x + low_data[i].y + high_data[i].x + high_data[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i);
if (reduceWidth > WARP_SIZE) {
if (lane == 0) partialSum[wid] = sum;
b.sync();
if (lane < warp_num) sum = partialSum[lane];
b.sync();
for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); }
sum = g.shfl(sum, threadIdx.x / WARP_SIZE);
}
sum += 1e-6;
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if (data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
vals[data_id] = low_data[i].x / sum;
vals[data_id + 1] = low_data[i].y / sum;
vals[data_id + 2] = high_data[i].x / sum;
vals[data_id + 3] = high_data[i].y / sum;
} else {
vals[data_id] = low_data[i].x / sum;
if ((data_id + 1) < sequence_length) vals[data_id + 1] = low_data[i].y / sum;
if ((data_id + 2) < sequence_length) vals[data_id + 2] = high_data[i].x / sum;
}
}
}
}
#endif
}
__global__ void attn_softmax_v2(float* vals,
float* attn_mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int total_count,
int heads,
int sequence_length,
int num_seq,
float scale,
int iterations,
int reduceWidth)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
float4 data[MAX_REG_SIZE];
int wid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int reduce_blocks = reduceWidth >> 5;
int seq_lane = threadIdx.x % reduceWidth;
__shared__ float partialSum[MAX_WARP_NUM];
int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks);
if (iter_offset < total_count) {
vals += (iter_offset * sequence_length);
int mask_offset = (iter_offset / (heads * num_seq)) * (sequence_length);
int seq_id = iter_offset % num_seq;
int seq_id4 = seq_id >> 2;
int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length);
int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2))
? (real_seq_id >> 2) - (window_size >> 2)
: 0;
int window_stride =
(local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1;
float max_val = minus_infinity;
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 &&
data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
data[i].x = (data_id > window_stride ? vals[data_id] : minus_infinity);
data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) &&
(data_id + 1) > window_stride)
? vals[data_id + 1]
: minus_infinity;
data[i].z = ((!triangular || ((data_id + 2) <= seq_id)) &&
(data_id + 2) > window_stride)
? vals[data_id + 2]
: minus_infinity;
data[i].w = ((!triangular || ((data_id + 3) <= seq_id)) &&
(data_id + 3) > window_stride)
? vals[data_id + 3]
: minus_infinity;
if (attn_mask && recompute) {
data[i].x += attn_mask[data_id + mask_offset];
data[i].y += attn_mask[data_id + mask_offset + 1];
data[i].z += attn_mask[data_id + mask_offset + 2];
data[i].w += attn_mask[data_id + mask_offset + 3];
}
} else {
data[i].x = data_id > window_stride ? vals[data_id] : minus_infinity;
data[i].y = (((!triangular || (data_id + 1) <= seq_id)) &&
(data_id + 1) > window_stride && (data_id + 1) < sequence_length)
? (vals[data_id + 1])
: minus_infinity;
data[i].z = (((!triangular || (data_id + 2) <= seq_id)) &&
(data_id + 2) > window_stride && (data_id + 2) < sequence_length)
? (vals[data_id + 2])
: minus_infinity;
data[i].w = minus_infinity;
if (attn_mask && recompute) {
data[i].x += attn_mask[data_id + mask_offset];
if ((data_id + 1) < sequence_length)
data[i].y += attn_mask[data_id + mask_offset + 1];
if ((data_id + 2) < sequence_length)
data[i].z += attn_mask[data_id + mask_offset + 2];
}
}
max_val = (data[i].x > max_val ? data[i].x : max_val);
max_val = (data[i].y > max_val ? data[i].y : max_val);
max_val = (data[i].z > max_val ? data[i].z : max_val);
max_val = (data[i].w > max_val ? data[i].w : max_val);
} else {
data[i].x = minus_infinity;
data[i].y = minus_infinity;
data[i].z = minus_infinity;
data[i].w = minus_infinity;
}
}
for (int i = 1; i < WARP_SIZE; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
if (reduceWidth > WARP_SIZE) {
if (lane == 0) partialSum[wid] = max_val;
b.sync();
if (lane < warp_num) max_val = partialSum[lane];
b.sync();
for (int i = 1; i < reduce_blocks; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE);
}
float sum = 0;
for (int i = 0; i < iterations; i++) {
data[i].x = __expf(data[i].x - max_val);
data[i].y = __expf(data[i].y - max_val);
data[i].z = __expf(data[i].z - max_val);
data[i].w = __expf(data[i].w - max_val);
sum += (data[i].x + data[i].y + data[i].z + data[i].w);
}
for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i);
if (reduceWidth > WARP_SIZE) {
if (lane == 0) partialSum[wid] = sum;
b.sync();
if (lane < warp_num) sum = partialSum[lane];
b.sync();
for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); }
sum = g.shfl(sum, threadIdx.x / WARP_SIZE);
}
sum += 1e-6;
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if (data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
vals[data_id] = data[i].x / sum;
vals[data_id + 1] = data[i].y / sum;
vals[data_id + 2] = data[i].z / sum;
vals[data_id + 3] = data[i].w / sum;
} else {
vals[data_id] = data[i].x / sum;
if ((data_id + 1) < sequence_length) vals[data_id + 1] = data[i].y / sum;
if ((data_id + 2) < sequence_length) vals[data_id + 2] = data[i].z / sum;
}
}
}
}
}
template <typename T>
void launch_attn_softmax_v2(T* vals,
T* mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int batch_size,
int heads,
int num_seq,
int sequence_length,
float scale,
hipStream_t stream)
{
int total_count = batch_size * heads * num_seq;
dim3 grid_dim((total_count - 1) / (WARP_SIZE / ((sequence_length - 1) / ATTN_THREADS + 1)) + 1);
dim3 block_dim(ATTN_THREADS);
const int reduce_width = ((sequence_length - 1) / ATTN_THREADS + 1) * WARP_SIZE;
const int iterations = (sequence_length - 1) / (reduce_width << 2) + 1;
if (sequence_length <= 32768)
hipLaunchKernelGGL(( attn_softmax_v2), dim3(grid_dim), dim3(block_dim), 0, stream,
vals,
mask,
triangular,
recompute,
local_attention,
window_size,
total_count,
(triangular ? (heads * batch_size) : heads),
sequence_length,
num_seq,
scale,
iterations,
reduce_width);
else
        throw std::runtime_error("Unsupported sequence length!");
}
template void launch_attn_softmax_v2(float* vals,
float* mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int batch_size,
int heads,
int num_seq,
int sequence_length,
float scale,
hipStream_t stream);
template void launch_attn_softmax_v2(__half* vals,
__half* mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int batch_size,
int heads,
int num_seq,
int sequence_length,
float scale,
hipStream_t stream);
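To make the launch arithmetic in launch_attn_softmax_v2 concrete, here is a worked example (illustrative, not in the original files) for sequence_length = 4096 with ATTN_THREADS = 1024 and WARP_SIZE = 32: each softmax row gets (4096-1)/1024+1 = 4 warps, i.e. reduce_width = 128 lanes per row; each lane loads 4 elements per pass, so iterations = (4096-1)/(128*4)+1 = 8, and 128 * 4 * 8 = 4096 exactly covers the row. The constants can be checked at compile time:

static_assert((4096 - 1) / 1024 + 1 == 4, "warps cooperating on one softmax row");
static_assert(((4096 - 1) / 1024 + 1) * 32 == 128, "reduce_width: lanes per row");
static_assert((4096 - 1) / (128 << 2) + 1 == 8, "iterations: 4 elements per lane per pass");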
csrc/transformer_bak/inference/includes/context.h
deleted
100644 → 0
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime_api.h>
#include <cassert>
#include <iostream>
#include <vector>
#include "cublas_v2.h"
#include "cuda.h"
#include "curand.h"
#define WARP_SIZE 32
#define CUDA_CHECK(callstr) \
{ \
cudaError_t error_code = callstr; \
if (error_code != cudaSuccess) { \
std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \
assert(0); \
} \
}
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)
#define CUDA_2D_KERNEL_LOOP(i, n, j, m) \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) \
for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); j += blockDim.y * gridDim.y)
#define DS_CUDA_NUM_THREADS 512
#define DS_MAXIMUM_NUM_BLOCKS 262144
inline int DS_GET_BLOCKS(const int N)
{
    return std::max(
        std::min((N + DS_CUDA_NUM_THREADS - 1) / DS_CUDA_NUM_THREADS, DS_MAXIMUM_NUM_BLOCKS),
        // Use at least 1 block, since CUDA does not allow empty block
        1);
}
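A hypothetical constexpr mirror of DS_GET_BLOCKS (names illustrative, not from this header) makes the grid-size clamping checkable at compile time; the same arithmetic appears again in the hipified copy of this header below:

constexpr int ds_get_blocks_ref(int n)
{
    int blocks = (n + 512 - 1) / 512;      // ceil(n / DS_CUDA_NUM_THREADS)
    if (blocks > 262144) blocks = 262144;  // clamp to DS_MAXIMUM_NUM_BLOCKS
    return blocks < 1 ? 1 : blocks;        // never launch an empty grid
}
static_assert(ds_get_blocks_ref(1) == 1, "one element still needs one block");
static_assert(ds_get_blocks_ref(10000) == 20, "10000 elements / 512 threads, rounded up");
static_assert(ds_get_blocks_ref(1 << 30) == 262144, "clamped at the block maximum");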
class Context {
public:
    Context() : _workspace(nullptr), _seed(42), _curr_offset(0), _stream(0)
    {
        curandCreateGenerator(&_gen, CURAND_RNG_PSEUDO_DEFAULT);
        curandSetPseudoRandomGeneratorSeed(_gen, 123);
        if (cublasCreate(&_cublasHandle) != CUBLAS_STATUS_SUCCESS) {
            auto message = std::string("Fail to create cublas handle.");
            std::cerr << message << std::endl;
            throw std::runtime_error(message);
        }
        cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
        cudaEventCreate(&_comp1_event, (cudaEventDisableTiming | cudaEventBlockingSync));
        cudaEventCreate(&_comp2_event, (cudaEventDisableTiming | cudaEventBlockingSync));
        cudaEventCreate(&_comp_event, (cudaEventDisableTiming | cudaEventBlockingSync));
        cudaEventCreate(&_comm_event, (cudaEventDisableTiming | cudaEventBlockingSync));
    }

    virtual ~Context()
    {
        cublasDestroy(_cublasHandle);
        cudaFree(_workspace);
        cudaEventDestroy(_comp1_event);
        cudaEventDestroy(_comp2_event);
        cudaEventDestroy(_comp_event);
        cudaEventDestroy(_comm_event);
    }

    static Context& Instance()
    {
        static Context _ctx;
        return _ctx;
    }

    void GenWorkSpace(size_t size)
    {
        if (!_workspace) {
            assert(_workspace == nullptr);
            cudaMalloc(&_workspace, size);
        } else if (_workSpaceSize < size) {
            cudaFree(_workspace);
            cudaMalloc(&_workspace, size);
        }

        _workSpaceSize = size;
    }

    cudaEvent_t GetCompEvent(int id) { return id == 1 ? _comp1_event : _comp2_event; }

    size_t get_workspace_size() const { return _workSpaceSize; }
    void* GetWorkSpace() { return _workspace; }

    inline unsigned new_token(unsigned layer_id)
    {
        if (layer_id == 0) _token_length++;
        return _token_length;
    }

    inline void reset_tokens(unsigned initial_tokens = 0)
    {
        _num_tokens = initial_tokens;
    }  //_token_length = 0; }

    inline unsigned current_tokens() const { return _num_tokens; }

    inline void advance_tokens() { _num_tokens++; }

    curandGenerator_t& GetRandGenerator() { return _gen; }

    cudaStream_t GetCommStream(bool async_op = false)
    {
        if (!_comm_stream)
            _comm_stream =
                async_op ? at::cuda::getStreamFromPool(true) : at::cuda::getCurrentCUDAStream();
        return _comm_stream;
    }

    cudaStream_t GetCurrentStream(bool other_stream = false)
    {
        // get current pytorch stream.
        if (other_stream) {
            if (!_stream) _stream = at::cuda::getStreamFromPool(true);
            return _stream;
        }
        cudaStream_t stream = at::cuda::getCurrentCUDAStream();
        return stream;
    }

    cublasHandle_t GetCublasHandle() { return _cublasHandle; }

    std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc)
    {
        uint64_t offset = _curr_offset;
        _curr_offset += offset_inc;
        return std::pair<uint64_t, uint64_t>(_seed, offset);
    }

    void SetSeed(uint64_t new_seed) { _seed = new_seed; }

    const std::vector<std::array<int, 3>>& GetGemmAlgos() const { return _gemm_algos; }

    inline void SynchComp()
    {
        cudaEventRecord(_comp_event, _comp_stream);
        cudaStreamWaitEvent(_comm_stream, _comp_event, 0);
    }
    inline void SynchComm()
    {
        cudaEventRecord(_comm_event, _comm_stream);
        cudaStreamWaitEvent(_comp_stream, _comm_event, 0);
    }

private:
    curandGenerator_t _gen;
    cublasHandle_t _cublasHandle;
    cudaEvent_t _comp_event;
    cudaEvent_t _comm_event;
    void* _workspace;
    uint64_t _seed;
    uint64_t _curr_offset;
    size_t _workSpaceSize;
    cudaEvent_t _comp1_event;
    cudaEvent_t _comp2_event;
    cudaStream_t _stream;
    unsigned _token_length;
    unsigned _num_tokens;
    std::vector<std::array<int, 3>> _gemm_algos;
    cudaStream_t _comp_stream;
    cudaStream_t _comm_stream;
    std::unordered_map<int, int> _world_sizes;
};
csrc/transformer_bak/inference/includes/context_hip.h
deleted
100644 → 0
// !!! This is a file automatically generated by hipify!!!
#pragma once
#include <ATen/hip/HIPContext.h>
#include <hip/hip_runtime_api.h>
#include <cassert>
#include <iostream>
#include <vector>
#include "rocblas.h"
#include "hip/hip_runtime.h"
#include "hiprand/hiprand.h"
#define WARP_SIZE 32
#define CUDA_CHECK(callstr) \
{ \
hipError_t error_code = callstr; \
if (error_code != hipSuccess) { \
std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \
assert(0); \
} \
}
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)
#define CUDA_2D_KERNEL_LOOP(i, n, j, m) \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) \
for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); j += blockDim.y * gridDim.y)
#define DS_CUDA_NUM_THREADS 512
#define DS_MAXIMUM_NUM_BLOCKS 262144
inline int DS_GET_BLOCKS(const int N)
{
    return std::max(
        std::min((N + DS_CUDA_NUM_THREADS - 1) / DS_CUDA_NUM_THREADS, DS_MAXIMUM_NUM_BLOCKS),
        // Use at least 1 block, since CUDA does not allow empty block
        1);
}
class Context {
public:
    Context() : _workspace(nullptr), _seed(42), _curr_offset(0), _stream(0)
    {
        hiprandCreateGenerator(&_gen, HIPRAND_RNG_PSEUDO_DEFAULT);
        hiprandSetPseudoRandomGeneratorSeed(_gen, 123);
        if (rocblas_create_handle(&_cublasHandle) != rocblas_status_success) {
            auto message = std::string("Fail to create cublas handle.");
            std::cerr << message << std::endl;
            throw std::runtime_error(message);
        }
        rocblas_set_math_mode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
        hipEventCreate(&_comp1_event, (hipEventDisableTiming | hipEventBlockingSync));
        hipEventCreate(&_comp2_event, (hipEventDisableTiming | hipEventBlockingSync));
        hipEventCreate(&_comp_event, (hipEventDisableTiming | hipEventBlockingSync));
        hipEventCreate(&_comm_event, (hipEventDisableTiming | hipEventBlockingSync));
    }

    virtual ~Context()
    {
        rocblas_destroy_handle(_cublasHandle);
        hipFree(_workspace);
        hipEventDestroy(_comp1_event);
        hipEventDestroy(_comp2_event);
        hipEventDestroy(_comp_event);
        hipEventDestroy(_comm_event);
    }

    static Context& Instance()
    {
        static Context _ctx;
        return _ctx;
    }

    void GenWorkSpace(size_t size)
    {
        if (!_workspace) {
            assert(_workspace == nullptr);
            hipMalloc(&_workspace, size);
        } else if (_workSpaceSize < size) {
            hipFree(_workspace);
            hipMalloc(&_workspace, size);
        }

        _workSpaceSize = size;
    }

    hipEvent_t GetCompEvent(int id) { return id == 1 ? _comp1_event : _comp2_event; }

    size_t get_workspace_size() const { return _workSpaceSize; }
    void* GetWorkSpace() { return _workspace; }

    inline unsigned new_token(unsigned layer_id)
    {
        if (layer_id == 0) _token_length++;
        return _token_length;
    }

    inline void reset_tokens(unsigned initial_tokens = 0)
    {
        _num_tokens = initial_tokens;
    }  //_token_length = 0; }

    inline unsigned current_tokens() const { return _num_tokens; }

    inline void advance_tokens() { _num_tokens++; }

    hiprandGenerator_t& GetRandGenerator() { return _gen; }

    hipStream_t GetCommStream(bool async_op = false)
    {
        if (!_comm_stream)
            _comm_stream = async_op ? at::hip::getStreamFromPoolMasqueradingAsCUDA(true)
                                    : at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
        return _comm_stream;
    }

    hipStream_t GetCurrentStream(bool other_stream = false)
    {
        // get current pytorch stream.
        if (other_stream) {
            if (!_stream) _stream = at::hip::getStreamFromPoolMasqueradingAsCUDA(true);
            return _stream;
        }
        hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
        return stream;
    }

    rocblas_handle GetCublasHandle() { return _cublasHandle; }

    std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc)
    {
        uint64_t offset = _curr_offset;
        _curr_offset += offset_inc;
        return std::pair<uint64_t, uint64_t>(_seed, offset);
    }

    void SetSeed(uint64_t new_seed) { _seed = new_seed; }

    const std::vector<std::array<int, 3>>& GetGemmAlgos() const { return _gemm_algos; }

    inline void SynchComp()
    {
        hipEventRecord(_comp_event, _comp_stream);
        hipStreamWaitEvent(_comm_stream, _comp_event, 0);
    }
    inline void SynchComm()
    {
        hipEventRecord(_comm_event, _comm_stream);
        hipStreamWaitEvent(_comp_stream, _comm_event, 0);
    }

private:
    hiprandGenerator_t _gen;
    rocblas_handle _cublasHandle;
    hipEvent_t _comp_event;
    hipEvent_t _comm_event;
    void* _workspace;
    uint64_t _seed;
    uint64_t _curr_offset;
    size_t _workSpaceSize;
    hipEvent_t _comp1_event;
    hipEvent_t _comp2_event;
    hipStream_t _stream;
    unsigned _token_length;
    unsigned _num_tokens;
    std::vector<std::array<int, 3>> _gemm_algos;
    hipStream_t _comp_stream;
    hipStream_t _comm_stream;
    std::unordered_map<int, int> _world_sizes;
};
csrc/transformer_bak/inference/includes/cublas_wrappers.h
deleted
100644 → 0
#pragma once
#include <assert.h>
#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <mma.h>
#include <stdio.h>
int cublas_gemm_ex(cublasHandle_t handle,
                   cublasOperation_t transa,
                   cublasOperation_t transb,
                   int m,
                   int n,
                   int k,
                   const float* alpha,
                   const float* beta,
                   const float* A,
                   const float* B,
                   float* C,
                   cublasGemmAlgo_t algo)
{
    cublasStatus_t status = cublasGemmEx(handle,
                                         transa,
                                         transb,
                                         m,
                                         n,
                                         k,
                                         (const void*)alpha,
                                         (const void*)A,
                                         CUDA_R_32F,
                                         (transa == CUBLAS_OP_N) ? m : k,
                                         (const void*)B,
                                         CUDA_R_32F,
                                         (transb == CUBLAS_OP_N) ? k : n,
                                         (const void*)beta,
                                         C,
                                         CUDA_R_32F,
                                         m,
                                         CUDA_R_32F,
                                         algo);

    if (status != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr,
                "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
                m,
                n,
                k,
                (int)status);
        return EXIT_FAILURE;
    }
    return 0;
}

int cublas_gemm_ex(cublasHandle_t handle,
                   cublasOperation_t transa,
                   cublasOperation_t transb,
                   int m,
                   int n,
                   int k,
                   const float* alpha,
                   const float* beta,
                   const __half* A,
                   const __half* B,
                   __half* C,
                   cublasGemmAlgo_t algo)
{
    cublasStatus_t status = cublasGemmEx(handle,
                                         transa,
                                         transb,
                                         m,
                                         n,
                                         k,
                                         (const void*)alpha,
                                         (const void*)A,
                                         CUDA_R_16F,
                                         (transa == CUBLAS_OP_N) ? m : k,
                                         (const void*)B,
                                         CUDA_R_16F,
                                         (transb == CUBLAS_OP_N) ? k : n,
                                         (const void*)beta,
                                         (void*)C,
                                         CUDA_R_16F,
                                         m,
                                         CUDA_R_32F,
                                         algo);

    if (status != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr,
                "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
                m,
                n,
                k,
                (int)status);
        return EXIT_FAILURE;
    }
    return 0;
}

int cublas_strided_batched_gemm(cublasHandle_t handle,
                                int m,
                                int n,
                                int k,
                                const float* alpha,
                                const float* beta,
                                const float* A,
                                const float* B,
                                float* C,
                                cublasOperation_t op_A,
                                cublasOperation_t op_B,
                                int stride_A,
                                int stride_B,
                                int stride_C,
                                int batch,
                                cublasGemmAlgo_t algo)
{
    cublasStatus_t status = cublasGemmStridedBatchedEx(handle,
                                                       op_A,
                                                       op_B,
                                                       m,
                                                       n,
                                                       k,
                                                       alpha,
                                                       A,
                                                       CUDA_R_32F,
                                                       (op_A == CUBLAS_OP_N) ? m : k,
                                                       stride_A,
                                                       B,
                                                       CUDA_R_32F,
                                                       (op_B == CUBLAS_OP_N) ? k : n,
                                                       stride_B,
                                                       beta,
                                                       C,
                                                       CUDA_R_32F,
                                                       m,
                                                       stride_C,
                                                       batch,
                                                       CUDA_R_32F,
                                                       algo);

    if (status != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr,
                "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n",
                batch,
                m,
                n,
                k,
                (int)status);
        return EXIT_FAILURE;
    }
    return 0;
}

int cublas_strided_batched_gemm(cublasHandle_t handle,
                                int m,
                                int n,
                                int k,
                                const float* alpha,
                                const float* beta,
                                const __half* A,
                                const __half* B,
                                __half* C,
                                cublasOperation_t op_A,
                                cublasOperation_t op_B,
                                int stride_A,
                                int stride_B,
                                int stride_C,
                                int batch,
                                cublasGemmAlgo_t algo)
{
    cublasStatus_t status = cublasGemmStridedBatchedEx(handle,
                                                       op_A,
                                                       op_B,
                                                       m,
                                                       n,
                                                       k,
                                                       alpha,
                                                       A,
                                                       CUDA_R_16F,
                                                       (op_A == CUBLAS_OP_N) ? m : k,
                                                       stride_A,
                                                       B,
                                                       CUDA_R_16F,
                                                       (op_B == CUBLAS_OP_N) ? k : n,
                                                       stride_B,
                                                       beta,
                                                       C,
                                                       CUDA_R_16F,
                                                       m,
                                                       stride_C,
                                                       batch,
                                                       CUDA_R_32F,
                                                       algo);

    if (status != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr,
                "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
                m,
                n,
                k,
                (int)status);
        return EXIT_FAILURE;
    }
    return 0;
}
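The lda/ldb expressions above follow the column-major rule that the leading dimension is the stored row count of the operand before the transpose op is applied; a small illustrative helper (not from this header) states the rule once:

// For op(A) of shape m x k: A is stored m x k when not transposed (ld = m),
// or k x m when transposed (ld = k) -- hence (transa == CUBLAS_OP_N) ? m : k.
inline int leading_dim(bool transposed, int rows_of_op, int cols_of_op)
{
    return transposed ? cols_of_op : rows_of_op;
}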
csrc/transformer_bak/inference/includes/cublas_wrappers_hip.h
deleted
100644 → 0
// !!! This is a file automatically generated by hipify!!!
#pragma once

#include <assert.h>
#include <rocblas.h>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <mma.h>
#include <stdio.h>

int cublas_gemm_ex(rocblas_handle handle,
                   rocblas_operation transa,
                   rocblas_operation transb,
                   int m,
                   int n,
                   int k,
                   const float* alpha,
                   const float* beta,
                   const float* A,
                   const float* B,
                   float* C,
                   cublasGemmAlgo_t algo)
{
    rocblas_status status = rocblas_gemmex(handle,
                                           transa,
                                           transb,
                                           m, n, k,
                                           (const void*)alpha,
                                           (const void*)A,
                                           hipR32F,
                                           (transa == rocblas_operation_none) ? m : k,
                                           (const void*)B,
                                           hipR32F,
                                           (transb == rocblas_operation_none) ? k : n,
                                           (const void*)beta,
                                           C,
                                           hipR32F,
                                           m,
                                           hipR32F,
                                           algo);
    if (status != rocblas_status_success) {
        fprintf(stderr,
                "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d)\n",
                m, n, k, (int)status);
        return EXIT_FAILURE;
    }
    return 0;
}

int cublas_gemm_ex(rocblas_handle handle,
                   rocblas_operation transa,
                   rocblas_operation transb,
                   int m,
                   int n,
                   int k,
                   const float* alpha,
                   const float* beta,
                   const __half* A,
                   const __half* B,
                   __half* C,
                   cublasGemmAlgo_t algo)
{
    rocblas_status status = rocblas_gemmex(handle,
                                           transa,
                                           transb,
                                           m, n, k,
                                           (const void*)alpha,
                                           (const void*)A,
                                           hipR16F,
                                           (transa == rocblas_operation_none) ? m : k,
                                           (const void*)B,
                                           hipR16F,
                                           (transb == rocblas_operation_none) ? k : n,
                                           (const void*)beta,
                                           (void*)C,
                                           hipR16F,
                                           m,
                                           hipR32F,
                                           algo);
    if (status != rocblas_status_success) {
        fprintf(stderr,
                "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d)\n",
                m, n, k, (int)status);
        return EXIT_FAILURE;
    }
    return 0;
}

int cublas_strided_batched_gemm(rocblas_handle handle,
                                int m,
                                int n,
                                int k,
                                const float* alpha,
                                const float* beta,
                                const float* A,
                                const float* B,
                                float* C,
                                rocblas_operation op_A,
                                rocblas_operation op_B,
                                int stride_A,
                                int stride_B,
                                int stride_C,
                                int batch,
                                cublasGemmAlgo_t algo)
{
    rocblas_status status = cublasGemmStridedBatchedEx(handle,
                                                       op_A,
                                                       op_B,
                                                       m, n, k,
                                                       alpha,
                                                       A,
                                                       hipR32F,
                                                       (op_A == rocblas_operation_none) ? m : k,
                                                       stride_A,
                                                       B,
                                                       hipR32F,
                                                       (op_B == rocblas_operation_none) ? k : n,
                                                       stride_B,
                                                       beta,
                                                       C,
                                                       hipR32F,
                                                       m,
                                                       stride_C,
                                                       batch,
                                                       hipR32F,
                                                       algo);
    if (status != rocblas_status_success) {
        fprintf(stderr,
                "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d)\n",
                batch, m, n, k, (int)status);
        return EXIT_FAILURE;
    }
    return 0;
}

int cublas_strided_batched_gemm(rocblas_handle handle,
                                int m,
                                int n,
                                int k,
                                const float* alpha,
                                const float* beta,
                                const __half* A,
                                const __half* B,
                                __half* C,
                                rocblas_operation op_A,
                                rocblas_operation op_B,
                                int stride_A,
                                int stride_B,
                                int stride_C,
                                int batch,
                                cublasGemmAlgo_t algo)
{
    rocblas_status status = cublasGemmStridedBatchedEx(handle,
                                                       op_A,
                                                       op_B,
                                                       m, n, k,
                                                       alpha,
                                                       A,
                                                       hipR16F,
                                                       (op_A == rocblas_operation_none) ? m : k,
                                                       stride_A,
                                                       B,
                                                       hipR16F,
                                                       (op_B == rocblas_operation_none) ? k : n,
                                                       stride_B,
                                                       beta,
                                                       C,
                                                       hipR16F,
                                                       m,
                                                       stride_C,
                                                       batch,
                                                       hipR32F,
                                                       algo);
    if (status != rocblas_status_success) {
        fprintf(stderr,
                "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d)\n",
                m, n, k, (int)status);
        return EXIT_FAILURE;
    }
    return 0;
}
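Editorial note: as deleted, this hipified header still references CUDA-named symbols (cublasGemmAlgo_t, cublasGemmStridedBatchedEx) and types (hipR32F, rocblas_gemmex) that do not exist in the ROCm headers it includes, so the file would not have compiled on a HIP toolchain as written. A minimal sketch of what a hand-corrected float wrapper might look like, assuming the rocBLAS rocblas_gemm_ex interface (separate C/D outputs plus a solution-index/flags pair); the helper name and in-place C==D usage are this note's assumptions:

    // Hypothetical hand-port: in-place C==D mimics cublasGemmEx semantics.
    int rocblas_gemm_ex_float(rocblas_handle handle, rocblas_operation transa,
                              rocblas_operation transb, int m, int n, int k,
                              const float* alpha, const float* beta,
                              const float* A, const float* B, float* C)
    {
        rocblas_status status = rocblas_gemm_ex(
            handle, transa, transb, m, n, k, alpha,
            A, rocblas_datatype_f32_r, (transa == rocblas_operation_none) ? m : k,
            B, rocblas_datatype_f32_r, (transb == rocblas_operation_none) ? k : n,
            beta,
            C, rocblas_datatype_f32_r, m,
            C, rocblas_datatype_f32_r, m,
            rocblas_datatype_f32_r, rocblas_gemm_algo_standard, 0, 0);
        return (status == rocblas_status_success) ? 0 : EXIT_FAILURE;
    }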
csrc/transformer_bak/inference/includes/custom_cuda_layers.h  deleted 100644 → 0
#pragma once

#ifdef __HIP_PLATFORM_HCC__
#define HALF_PRECISION_AVAILABLE = 1
#include <hip/hip_cooperative_groups.h>
#else
#if __CUDA_ARCH__ >= 700
#define HALF_PRECISION_AVAILABLE = 1
#endif
#include <cooperative_groups.h>
#endif

#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include <stdlib.h>
#include <cassert>
#include <iostream>

#define MAX_WARP_NUM 32
#define WARP_SIZE 32
#define SMs 80
#define MAX_REGISTERS 256

template <typename T>
void launch_attn_softmax_v2(T* vals,
                            T* mask,
                            bool triangular,
                            bool recompute,
                            bool local_attention,
                            int window_size,
                            int batch_size,
                            int heads,
                            int num_seq,
                            int sequence_length,
                            float scale,
                            cudaStream_t stream);

// Fused bias add with gelu activation
template <typename T>
void launch_bias_gelu(T* input,
                      const T* bias,
                      int intermediate_size,
                      int batch_size,
                      cudaStream_t stream);

template <typename T>
void launch_bias_add(T* input, const T* bias, int hidden_size, int batch_size, cudaStream_t stream);

template <typename T>
void launch_bias_residual(T* input,
                          T* output,
                          T* attn,
                          T* bias,
                          T* attn_bias,
                          int batch,
                          int hidden_dim,
                          int mp_size,
                          cudaStream_t stream);

template <typename T>
void launch_layer_norm(T* out,
                       T* vals,
                       const T* gamma,
                       const T* beta,
                       float epsilon,
                       int batch_size,
                       int hidden_dim,
                       cudaStream_t stream);

template <typename T>
void launch_residual_layer_norm(T* norm,
                                T* res_add,
                                T* vals,
                                T* residual,
                                const T* bias,
                                const T* gamma,
                                const T* beta,
                                float epsilon,
                                int batch_size,
                                int hidden_dim,
                                bool preLN,
                                bool mlp_after_attn,
                                cudaStream_t stream);

template <typename T>
void launch_dequantize(T* output,
                       const int8_t* input,
                       const float* qscale,
                       unsigned output_size,
                       unsigned hidden_dim,
                       unsigned groups,
                       unsigned merge_count,
                       cudaStream_t stream);

template <typename T>
void launch_gptj_residual_add(T* input,
                              T* output,
                              T* attn,
                              T* bias,
                              T* attn_bias,
                              int batch,
                              int head_size,
                              int mp_size,
                              cudaStream_t stream);

template <typename T>
void launch_apply_rotary_pos_emb(T* mixed_query,
                                 T* key_layer,
                                 unsigned head_size,
                                 unsigned seq_len,
                                 unsigned rotary_dim,
                                 unsigned offset,
                                 unsigned num_heads,
                                 unsigned batch,
                                 bool rotate_half,
                                 bool rotate_every_two,
                                 cudaStream_t stream);

template <typename T>
void launch_moe_res_matmul(T* residual, T* coef, T* mlp_out, int seq_len, int hidden_dim, cudaStream_t stream);
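Editorial note: a minimal host-side sketch of driving one of the launchers declared above, launch_bias_gelu. The buffer sizes and example dimensions are illustrative assumptions; per the signature, input holds batch_size rows of intermediate_size elements and is updated in place:

    void example_bias_gelu(cudaStream_t stream)
    {
        const int batch_size = 8, intermediate_size = 4096;
        float *input, *bias;
        cudaMalloc((void**)&input, sizeof(float) * batch_size * intermediate_size);
        cudaMalloc((void**)&bias, sizeof(float) * intermediate_size);
        // ... fill input/bias (e.g., cudaMemcpyAsync from host) ...
        launch_bias_gelu<float>(input, bias, intermediate_size, batch_size, stream);
        cudaStreamSynchronize(stream);
        cudaFree(input);
        cudaFree(bias);
    }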
csrc/transformer_bak/inference/includes/custom_hip_layers.h  deleted 100644 → 0
// !!! This is a file automatically generated by hipify!!!
#pragma once

#ifdef __HIP_PLATFORM_HCC__
#define HALF_PRECISION_AVAILABLE = 1
#include <hip/hip_cooperative_groups.h>
#else
#if __CUDA_ARCH__ >= 700
#define HALF_PRECISION_AVAILABLE = 1
#endif
#include <cooperative_groups.h>
#endif

#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include <stdlib.h>
#include <cassert>
#include <iostream>

#define MAX_WARP_NUM 32
#define WARP_SIZE 32
#define SMs 80
#define MAX_REGISTERS 256

template <typename T>
void launch_attn_softmax_v2(T* vals,
                            T* mask,
                            bool triangular,
                            bool recompute,
                            bool local_attention,
                            int window_size,
                            int batch_size,
                            int heads,
                            int num_seq,
                            int sequence_length,
                            float scale,
                            hipStream_t stream);

// Fused bias add with gelu activation
template <typename T>
void launch_bias_gelu(T* input,
                      const T* bias,
                      int intermediate_size,
                      int batch_size,
                      hipStream_t stream);

template <typename T>
void launch_bias_add(T* input, const T* bias, int hidden_size, int batch_size, hipStream_t stream);

template <typename T>
void launch_bias_residual(T* input,
                          T* output,
                          T* attn,
                          T* bias,
                          T* attn_bias,
                          int batch,
                          int hidden_dim,
                          int mp_size,
                          hipStream_t stream);

template <typename T>
void launch_layer_norm(T* out,
                       T* vals,
                       const T* gamma,
                       const T* beta,
                       float epsilon,
                       int batch_size,
                       int hidden_dim,
                       hipStream_t stream);

template <typename T>
void launch_residual_layer_norm(T* norm,
                                T* res_add,
                                T* vals,
                                T* residual,
                                const T* bias,
                                const T* gamma,
                                const T* beta,
                                float epsilon,
                                int batch_size,
                                int hidden_dim,
                                bool preLN,
                                bool mlp_after_attn,
                                hipStream_t stream);

template <typename T>
void launch_dequantize(T* output,
                       const int8_t* input,
                       const float* qscale,
                       unsigned output_size,
                       unsigned hidden_dim,
                       unsigned groups,
                       unsigned merge_count,
                       hipStream_t stream);

template <typename T>
void launch_gptj_residual_add(T* input,
                              T* output,
                              T* attn,
                              T* bias,
                              T* attn_bias,
                              int batch,
                              int head_size,
                              int mp_size,
                              hipStream_t stream);

template <typename T>
void launch_apply_rotary_pos_emb(T* mixed_query,
                                 T* key_layer,
                                 unsigned head_size,
                                 unsigned seq_len,
                                 unsigned rotary_dim,
                                 unsigned offset,
                                 unsigned num_heads,
                                 unsigned batch,
                                 bool rotate_half,
                                 bool rotate_every_two,
                                 hipStream_t stream);

template <typename T>
void launch_moe_res_matmul(T* residual, T* coef, T* mlp_out, int seq_len, int hidden_dim, hipStream_t stream);
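Editorial note: apart from the hipStream_t stream type, this header is call-compatible with the CUDA one above. A minimal sketch of the HIP-side stream setup around launch_layer_norm, assuming device buffers of hidden_dim floats per row are already populated; the epsilon value is an illustrative assumption:

    void example_layer_norm_hip(float* out, float* vals, const float* gamma,
                                const float* beta, int batch_size, int hidden_dim)
    {
        hipStream_t stream;
        hipStreamCreate(&stream);
        launch_layer_norm<float>(out, vals, gamma, beta, 1e-5f, batch_size, hidden_dim, stream);
        hipStreamSynchronize(stream);
        hipStreamDestroy(stream);
    }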
csrc/transformer_bak/normalize_kernels.cu  deleted 100644 → 0
#include "custom_cuda_layers.h"

namespace cg = cooperative_groups;

/*
Fused bias add, residual (elementwise) add, and normalization layer.

For FP16, this kernel does not promote to FP32 in order to utilize the 2x throughput for
__half2 instructions, and avoid the conversion overhead (1/8 of __half2 arithmetic).

For specific launch constraints, see the launch functions.
*/

#define NORM_REG (MAX_REGISTERS / 4)
__global__ void fused_bias_residual_layer_norm(float* vals,
                                               const float* residual,
                                               const float* gamma,
                                               const float* beta,
                                               float epsilon,
                                               bool preLayerNorm,
                                               bool training,
                                               float* vars,
                                               float* means,
                                               int row_stride)
{
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int gid = id / WARP_SIZE;

    float vals_arr[NORM_REG];
    __shared__ float shr[MAX_WARP_NUM];

    residual += (row * row_stride);
    vals += (row * row_stride);

    float sum = 0.f;
    int high_index = iterations * iteration_stride + id;
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        vals_arr[i] = residual[i * iteration_stride + id];
        sum += vals_arr[i];
    }
    if (high_index < row_stride) {
        vals_arr[iterations] = residual[high_index];
        sum += vals_arr[iterations];
        iterations++;
    }

    for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) shr[gid] = sum;

    b.sync();

    if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];

#if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700
    b.sync();
#endif

    for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) { sum += g.shfl_down(sum, i); }

    sum = g.shfl(sum, 0);
    float mean = sum / row_stride;
    if (training)
        if (threadIdx.x == 0) means[row] = mean;
    float variance = 0.f;
    for (int i = 0; i < iterations; i++) {
        vals_arr[i] -= mean;
        variance += vals_arr[i] * vals_arr[i];
    }

    for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }

    if (g.thread_rank() == 0) shr[gid] = variance;

    b.sync();

    if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    b.sync();
#endif

    for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
        variance += g.shfl_down(variance, i);
    }
    variance = g.shfl(variance, 0);
    variance /= row_stride;
    variance += epsilon;
    if (training)
        if (threadIdx.x == 0) vars[row] = variance;

    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++) {
        vals_arr[i] = vals_arr[i] * rsqrtf(variance);
        vals_arr[i] = vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id];
        vals[i * iteration_stride + id] = vals_arr[i];
    }
    if ((high_index) < row_stride) {
        vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance);
        vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index];
        vals[high_index] = vals_arr[iterations];
    }
}
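Editorial note: every kernel in this file leans on the same two-stage block reduction: an intra-warp shfl_down tree, per-warp partials stashed in shared memory, and a second tree over the warp leaders. A minimal sketch of that pattern in isolation (the kernel name and I/O layout are this note's assumptions; it assumes blockDim.x is a power-of-two multiple of 32, as the launchers above guarantee):

    __global__ void block_sum(const float* in, float* out)
    {
        cg::thread_block b = cg::this_thread_block();
        cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
        __shared__ float warp_sums[MAX_WARP_NUM];

        float v = in[blockIdx.x * blockDim.x + threadIdx.x];
        for (int i = 1; i < 32; i *= 2) v += g.shfl_down(v, i);  // intra-warp tree sum
        if (g.thread_rank() == 0) warp_sums[threadIdx.x >> 5] = v;
        b.sync();
        int warps = blockDim.x >> 5;
        if (g.thread_rank() < warps) v = warp_sums[g.thread_rank()];
        for (int i = 1; i < warps; i *= 2) v += g.shfl_down(v, i);  // reduce warp leaders
        if (threadIdx.x == 0) out[blockIdx.x] = v;                  // lane 0 holds the block sum
    }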
__global__ void fused_bias_residual_layer_norm(__half* vals,
                                               const __half* residual,
                                               const __half* gamma,
                                               const __half* beta,
                                               float epsilon,
                                               bool preLayerNorm,
                                               bool training,
                                               __half* vars,
                                               __half* means,
                                               int row_stride)
{
#ifdef HALF_PRECISION_AVAILABLE
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int gid = id >> WARP_SIZE_BITS;

    float2 vals_f[NORM_REG];
    __shared__ float shr[MAX_WARP_NUM];

    __half2* vals_cast = reinterpret_cast<__half2*>(vals);
    const __half2* residual_cast = reinterpret_cast<const __half2*>(residual);

    residual_cast += (row * row_stride);
    vals_cast += (row * row_stride);

    float sum = 0.f;
    int high_index = iterations * iteration_stride + id;
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]);
        sum += vals_f[i].x;
        sum += vals_f[i].y;
    }
    if ((high_index) < row_stride) {
        vals_f[iterations] = __half22float2(residual_cast[high_index]);
        sum += vals_f[iterations].x;
        sum += vals_f[iterations].y;
        iterations++;
    }

    for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) shr[gid] = sum;

    b.sync();

    if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    b.sync();
#endif

    for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) { sum += g.shfl_down(sum, i); }

    sum = g.shfl(sum, 0);
    float mean = sum / (row_stride * 2);

    float variance = 0.f;
    for (int i = 0; i < iterations; i++) {
        vals_f[i].x -= mean;
        vals_f[i].y -= mean;
        variance += vals_f[i].x * vals_f[i].x;
        variance += vals_f[i].y * vals_f[i].y;
    }

    for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }

    if (g.thread_rank() == 0) shr[gid] = variance;

    b.sync();

    if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    b.sync();
#endif

    for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
        variance += g.shfl_down(variance, i);
    }
    variance = g.shfl(variance, 0);
    variance /= (row_stride * 2);
    variance += epsilon;

    __half2 variance_h = __float2half2_rn(variance);
    const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
    const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);

    if (training && threadIdx.x == 0) {
        vars[row] = __float2half(variance);
        means[row] = __float2half(mean);
    }

    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++) {
        __half2 vals_arr = __float22half2_rn(vals_f[i]);
        vals_arr = vals_arr * h2rsqrt(variance_h);
        vals_arr = vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id];
        vals_cast[i * iteration_stride + id] = vals_arr;
    }
    if ((high_index) < row_stride) {
        __half2 vals_arr = __float22half2_rn(vals_f[iterations]);
        vals_arr = vals_arr * h2rsqrt(variance_h);
        vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index];
        vals_cast[high_index] = vals_arr;
    }
#endif
}
template <typename T>
void launch_bias_residual_layer_norm(T* vals,
                                     const T* residual,
                                     const T* gamma,
                                     const T* beta,
                                     float epsilon,
                                     int batch_size,
                                     int hidden_dim,
                                     cudaStream_t stream,
                                     bool preLayerNorm,
                                     bool training,
                                     T* vars,
                                     T* means);

template <>
void launch_bias_residual_layer_norm<float>(float* vals,
                                            const float* residual,
                                            const float* gamma,
                                            const float* beta,
                                            float epsilon,
                                            int batch_size,
                                            int hidden_dim,
                                            cudaStream_t stream,
                                            bool preLayerNorm,
                                            bool training,
                                            float* vars,
                                            float* means)
{
    int threads = THREADS;

    dim3 grid_dim(batch_size);

    if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 1;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 2;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupport hidden_dim.");

    dim3 block_dim(threads);

    fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
        vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim);
}
template <>
void launch_bias_residual_layer_norm<__half>(__half* vals,
                                             const __half* residual,
                                             const __half* gamma,
                                             const __half* beta,
                                             float epsilon,
                                             int batch_size,
                                             int hidden_dim,
                                             cudaStream_t stream,
                                             bool preLayerNorm,
                                             bool training,
                                             __half* vars,
                                             __half* means)
{
    int threads = 128;

    dim3 grid_dim(batch_size);

    if (hidden_dim > 8192 && hidden_dim <= 16384)
        threads <<= 1;
    else if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 2;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 3;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupport hidden_dim.");

    dim3 block_dim(threads);

    fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
        vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim / 2);
}
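Editorial note: the if/else ladders above scale the block size by hidden_dim tier so that threads * iterations == row_stride stays within the NORM_REG per-thread register budget; the half path starts at 128 threads and a lower first tier because each thread handles two packed elements. A sketch of that tier selection factored into a helper (the helper name and parameterization are this note's assumptions):

    static int norm_threads_for(int hidden_dim, int base, int first_tier, int max_dim)
    {
        if (hidden_dim > max_dim) throw std::runtime_error("Unsupport hidden_dim.");
        int threads = base;
        // Double the block size each time hidden_dim crosses the next power-of-two tier.
        for (int tier = first_tier; hidden_dim > tier; tier <<= 1) threads <<= 1;
        return threads;
    }
    // float path: norm_threads_for(hidden_dim, THREADS, 16384, 65536)
    // half path:  norm_threads_for(hidden_dim, 128, 8192, 65536)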
__global__ void fused_bias_residual_layer_norm(float* vals,
                                               const float* residual,
                                               const float* gamma,
                                               const float* beta,
                                               float epsilon,
                                               bool preLayerNorm,
                                               bool training,
                                               float* vars,
                                               int row_stride)
{
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int gid = id / 32;

    float vals_arr[NORM_REG];
    __shared__ float shr[MAX_WARP_NUM];

    residual += (row * row_stride);
    vals += (row * row_stride);

    float sum = 0.f;
    int high_index = iterations * iteration_stride + id;
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        vals_arr[i] = residual[i * iteration_stride + id];
        sum += vals_arr[i];
    }
    if ((high_index) < row_stride) {
        vals_arr[iterations] = residual[high_index];
        sum += vals_arr[iterations];
        iterations++;
    }

    for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) shr[gid] = sum;

    b.sync();

    if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];

#if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700
    b.sync();
#endif

    for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) { sum += g.shfl_down(sum, i); }

    sum = g.shfl(sum, 0);
    float mean = sum / row_stride;
    float variance = 0.f;
    for (int i = 0; i < iterations; i++) {
        vals_arr[i] -= mean;
        variance += vals_arr[i] * vals_arr[i];
    }

    for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }

    if (g.thread_rank() == 0) shr[gid] = variance;

    b.sync();

    if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    b.sync();
#endif

    for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
        variance += g.shfl_down(variance, i);
    }
    variance = g.shfl(variance, 0);
    variance /= row_stride;
    variance += epsilon;
    if (training)
        if (threadIdx.x == 0) vars[row] = variance;

    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++) {
        vals_arr[i] = vals_arr[i] * rsqrtf(variance);
        vals_arr[i] = vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id];
        vals[i * iteration_stride + id] = vals_arr[i];
    }
    if ((high_index) < row_stride) {
        vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance);
        vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index];
        vals[high_index] = vals_arr[iterations];
    }
}
__global__ void fused_bias_residual_layer_norm(__half* vals,
                                               const __half* residual,
                                               const __half* gamma,
                                               const __half* beta,
                                               float epsilon,
                                               bool preLayerNorm,
                                               bool training,
                                               __half* vars,
                                               int row_stride)
{
#ifdef HALF_PRECISION_AVAILABLE
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int gid = id >> WARP_SIZE_BITS;

    float2 vals_f[NORM_REG];
    __shared__ float shr[MAX_WARP_NUM];

    __half2* vals_cast = reinterpret_cast<__half2*>(vals);
    const __half2* residual_cast = reinterpret_cast<const __half2*>(residual);

    residual_cast += (row * row_stride);
    vals_cast += (row * row_stride);

    float sum = 0.f;
    int high_index = iterations * iteration_stride + id;
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]);
        sum += vals_f[i].x;
        sum += vals_f[i].y;
    }
    if ((high_index) < row_stride) {
        vals_f[iterations] = __half22float2(residual_cast[high_index]);
        sum += vals_f[iterations].x;
        sum += vals_f[iterations].y;
        iterations++;
    }

    for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) shr[gid] = sum;

    b.sync();

    if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    b.sync();
#endif

    for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) { sum += g.shfl_down(sum, i); }

    sum = g.shfl(sum, 0);
    float mean = sum / (row_stride * 2);

    float variance = 0.f;
    for (int i = 0; i < iterations; i++) {
        vals_f[i].x -= mean;
        vals_f[i].y -= mean;
        variance += vals_f[i].x * vals_f[i].x;
        variance += vals_f[i].y * vals_f[i].y;
    }

    for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }

    if (g.thread_rank() == 0) shr[gid] = variance;

    b.sync();

    if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    b.sync();
#endif

    for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
        variance += g.shfl_down(variance, i);
    }
    variance = g.shfl(variance, 0);
    variance /= (row_stride * 2);
    variance += epsilon;

    __half2 variance_h = __float2half2_rn(variance);
    const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
    const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);

    if (training && threadIdx.x == 0) vars[row] = __float2half(variance);

    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++) {
        __half2 vals_arr = __float22half2_rn(vals_f[i]);
        vals_arr = vals_arr * h2rsqrt(variance_h);
        vals_arr = vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id];
        vals_cast[i * iteration_stride + id] = vals_arr;
    }
    if ((high_index) < row_stride) {
        __half2 vals_arr = __float22half2_rn(vals_f[iterations]);
        vals_arr = vals_arr * h2rsqrt(variance_h);
        vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index];
        vals_cast[high_index] = vals_arr;
    }
#endif
}
template <typename T>
void launch_bias_residual_layer_norm(T* vals,
                                     const T* residual,
                                     const T* gamma,
                                     const T* beta,
                                     float epsilon,
                                     int batch_size,
                                     int hidden_dim,
                                     cudaStream_t stream,
                                     bool preLayerNorm,
                                     bool training,
                                     T* vars);

/*
To tune this launch, the following restrictions must be met:

For float:
row_stride == hidden_size
threads * iterations == row_stride
threads is in [32, 64, 128, 256, 512, 1024]

For half:
row_stride == hidden_size / 2
threads * iterations == row_stride
threads is in [32, 64, 128, 256, 512, 1024]
*/

template <>
void launch_bias_residual_layer_norm<float>(float* vals,
                                            const float* residual,
                                            const float* gamma,
                                            const float* beta,
                                            float epsilon,
                                            int batch_size,
                                            int hidden_dim,
                                            cudaStream_t stream,
                                            bool preLayerNorm,
                                            bool training,
                                            float* vars)
{
    int threads = THREADS;

    dim3 grid_dim(batch_size);

    // There are limitations on calling the functions below; for now we just
    // enumerate the supported situations.
    if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 1;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 2;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupport hidden_dim.");

    dim3 block_dim(threads);

    fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
        vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim);
}

template <>
void launch_bias_residual_layer_norm<__half>(__half* vals,
                                             const __half* residual,
                                             const __half* gamma,
                                             const __half* beta,
                                             float epsilon,
                                             int batch_size,
                                             int hidden_dim,
                                             cudaStream_t stream,
                                             bool preLayerNorm,
                                             bool training,
                                             __half* vars)
{
    int threads = 128;

    dim3 grid_dim(batch_size);

    // There are limitations on calling the functions below; for now we just
    // enumerate the supported situations.
    if (hidden_dim > 8192 && hidden_dim <= 16384)
        threads <<= 1;
    else if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 2;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 3;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupport hidden_dim.");

    dim3 block_dim(threads);

    fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
        vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim / 2);
}
/* Normalize Gamma & Betta gradients
 * Compute gradients using either X_hat or
 * normalize input (invertible).
 * Combine transpose with gradients computation.
 */
template <typename T>
__global__ void LayerNormBackward1(const T* __restrict__ out_grad,
                                   const T* __restrict__ vals_hat,
                                   const T* __restrict__ gamma,
                                   const T* __restrict__ betta,
                                   T* __restrict__ gamma_grad,
                                   T* __restrict__ betta_grad,
                                   int rows,
                                   int width,
                                   bool invertible)
{
    __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
    __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);

    int idx = blockDim.x * blockIdx.x + threadIdx.x;
    int offset = threadIdx.y * width + idx;
    int y_stride = width * TILE_DIM;

    float betta_reg = (invertible ? (float)betta[idx] : 0.0f);
    float gamma_reg = (float)gamma[idx];

    // Loop across matrix height
    float betta_tmp = 0;
    float gamma_tmp = 0;
    for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
        float grad = (float)out_grad[offset];
        float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg
                                : (float)vals_hat[offset]);
        betta_tmp += grad;
        gamma_tmp += (val * grad);

        offset += y_stride;
    }

    betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
    gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;

    __syncthreads();

    // Sum the shared buffer.
    float s1 = betta_buffer[threadIdx.y][threadIdx.x];
    float s2 = gamma_buffer[threadIdx.y][threadIdx.x];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < TILE_DIM; i <<= 1) {
        s1 += g.shfl_down(s1, i);
        s2 += g.shfl_down(s2, i);
    }

    if (threadIdx.x == 0) {
        int pos = blockIdx.x * TILE_DIM + threadIdx.y;
        betta_grad[pos] = s1;
        gamma_grad[pos] = s2;
    }
}
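Editorial note: what LayerNormBackward1 computes, stated compactly. For a width-H layer over R rows, each output column j accumulates (assuming standard layer-norm notation; the kernel recovers x-hat either from the saved output via (y - beta)/gamma in the invertible path, or directly):

    \frac{\partial L}{\partial \beta_j}  = \sum_{r=1}^{R} \frac{\partial L}{\partial y_{rj}},
    \qquad
    \frac{\partial L}{\partial \gamma_j} = \sum_{r=1}^{R} \hat{x}_{rj}\,\frac{\partial L}{\partial y_{rj}},
    \qquad
    \hat{x}_{rj} = \frac{x_{rj} - \mu_r}{\sqrt{\sigma_r^{2}}}

The TILE_DIM x TILE_DIM shared buffers fuse the transpose with the reduction: each thread accumulates down a column, the buffer is read back transposed, and a shuffle tree finishes the column sums.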
/* Normalize Gamma & Betta gradients
 * Compute gradients using the input to
 * the normalize.
 * Combine transpose with gradients computation.
 */
template <typename T>
__global__ void LayerNormBackward1(const T* __restrict__ out_grad,
                                   const T* __restrict__ X_data,
                                   const T* __restrict__ vars,
                                   const T* __restrict__ means,
                                   T* __restrict__ gamma_grad,
                                   T* __restrict__ betta_grad,
                                   int rows,
                                   int width)
{
    __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
    __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);

    int idx = blockDim.x * blockIdx.x + threadIdx.x;
    int offset = threadIdx.y * width + idx;
    int y_stride = width * TILE_DIM;

    int pos = blockIdx.x * TILE_DIM + threadIdx.y;
    // Loop across matrix height
    float betta_tmp = 0;
    float gamma_tmp = 0;
    for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
        float grad = (float)out_grad[offset];
        float val = (float)X_data[offset];
        val = (val - (float)means[r]) * rsqrtf((float)vars[r]);
        betta_tmp += grad;
        gamma_tmp += (val * grad);

        offset += y_stride;
    }

    betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
    gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;

    __syncthreads();

    // Sum the shared buffer.
    float s1 = betta_buffer[threadIdx.y][threadIdx.x];
    float s2 = gamma_buffer[threadIdx.y][threadIdx.x];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < TILE_DIM; i <<= 1) {
        s1 += g.shfl_down(s1, i);
        s2 += g.shfl_down(s2, i);
    }

    if (threadIdx.x == 0) {
        betta_grad[pos] = s1;
        gamma_grad[pos] = s2;
    }
}
/* Backward Normalize (Input-Gradient)
 * Using the means and variances from the input
 * This type of backward is invertible!
 * We do the backward using the X_hat (X - u) / sqrt(variance) or the output of Normalization.
 */
__global__ void LayerNormBackward2(const float* out_grad,
                                   const float* vals_hat,
                                   const float* gamma,
                                   const float* betta,
                                   const float* vars,
                                   float* inp_grad,
                                   bool invertible,
                                   int row_stride)
{
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int wid = id / WARP_SIZE;
    int warp_num = iteration_stride >> WARP_SIZE_BITS;
    __shared__ float partialSum[MAX_WARP_NUM];

    out_grad += (row * row_stride);
    vals_hat += (row * row_stride);
    inp_grad += (row * row_stride);

    float vals_arr[NORM_REG];
    float vals_hat_arr[NORM_REG];

    int high_index = iterations * iteration_stride + id;
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        float gamma_reg = gamma[i * iteration_stride + id];
        vals_arr[i] = out_grad[i * iteration_stride + id];
        vals_arr[i] *= gamma_reg;
        vals_hat_arr[i] =
            (invertible
                 ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) / gamma_reg
                 : vals_hat[i * iteration_stride + id]);
    }
    if ((high_index) < row_stride) {
        float gamma_reg = gamma[high_index];
        vals_arr[iterations] = out_grad[high_index];
        vals_arr[iterations] *= gamma_reg;
        vals_hat_arr[iterations] =
            (invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg
                        : vals_hat[high_index]);
        iterations++;
    }

    float var_reg = vars[row];

    float sum = 0;
    for (int i = 0; i < iterations; i++) {
        sum += vals_hat_arr[i] * vals_arr[i] * sqrtf(var_reg);  // dval_hat = gamma * (x - u) * out_grad
        vals_arr[i] *= rsqrtf(var_reg);                         // dvar_inv = gamma * out_grad / sqrt(var)
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= row_stride;

    for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); }

    sum = 0;
    for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= row_stride;

    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++)
        inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum);
    if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum);
}
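Editorial note: the input gradient this kernel realizes is the standard layer-norm backward. Writing g_j = gamma_j * dL/dy_j and H = row_stride (2 * row_stride in the packed half path), it computes, per row:

    \frac{\partial L}{\partial x_i}
      = \frac{1}{\sqrt{\sigma^{2}}}\left( g_i
        \;-\; \frac{1}{H}\sum_{j} g_j
        \;-\; \hat{x}_i \cdot \frac{1}{H}\sum_{j} g_j\,\hat{x}_j \right)

Note that sigma^2 here already includes epsilon (the forward pass stores variance + epsilon), and the kernel folds the 1/sqrt(sigma^2) factors into the two reduction passes rather than applying them at the end.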
__global__ void LayerNormBackward2(const __half* out_grad,
                                   const __half* vals_hat,
                                   const __half* gamma,
                                   const __half* betta,
                                   const __half* vars,
                                   __half* inp_grad,
                                   bool invertible,
                                   int row_stride)
{
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int wid = id / WARP_SIZE;
    int warp_num = iteration_stride >> WARP_SIZE_BITS;
    __shared__ float partialSum[MAX_WARP_NUM];

    __half2 vals_arr[NORM_REG];
    float2 vals_arr_f[NORM_REG];
    __half2 vals_hat_arr[NORM_REG];

    __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
    const __half2* out_grad_h = reinterpret_cast<const __half2*>(out_grad);
    const __half2* vals_hat_h = reinterpret_cast<const __half2*>(vals_hat);

    inp_grad_h += (row * row_stride);
    out_grad_h += (row * row_stride);
    vals_hat_h += (row * row_stride);

    const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
    const __half2* betta_h = (invertible ? reinterpret_cast<const __half2*>(betta) : nullptr);

    int high_index = iterations * iteration_stride + id;
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        __half2 gamma_reg = gamma_h[i * iteration_stride + id];
        vals_arr[i] = out_grad_h[i * iteration_stride + id];
        vals_arr[i] *= gamma_reg;
        vals_hat_arr[i] =
            (invertible
                 ? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) /
                       gamma_reg
                 : vals_hat_h[i * iteration_stride + id]);
    }
    if ((high_index) < row_stride) {
        __half2 gamma_reg = gamma_h[high_index];
        vals_arr[iterations] = out_grad_h[high_index];
        vals_arr[iterations] *= gamma_reg;
        vals_hat_arr[iterations] =
            (invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg
                        : vals_hat_h[high_index]);
        iterations++;
    }

    __half var_h = vars[row];
    __half2 var_reg = __halves2half2(var_h, var_h);

    float sum = 0.f;
    for (int i = 0; i < iterations; i++) {
        __half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg));
        float2 result_f = __half22float2(result_h);
        sum += result_f.x;
        sum += result_f.y;
        vals_arr[i] *= h2rsqrt(var_reg);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= (2 * row_stride);
    __half2 sum_h = __float2half2_rn(sum);

    for (int i = 0; i < iterations; i++) {
        __half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg));
        vals_arr_f[i] = __half22float2(vals_arr[i]);
        float2 temp_f = __half22float2(temp);
        vals_arr_f[i].x += temp_f.x;
        vals_arr_f[i].y += temp_f.y;
    }

    sum = 0.f;
    for (int i = 0; i < iterations; i++) {
        sum += (vals_arr_f[i].x);
        sum += (vals_arr_f[i].y);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= (2 * row_stride);

    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++) {
        vals_arr_f[i].x -= sum;
        vals_arr_f[i].y -= sum;
        __half2 temp = __float22half2_rn(vals_arr_f[i]);
        inp_grad_h[i * iteration_stride + id] = temp;
    }
    if ((high_index) < row_stride) {
        vals_arr_f[iterations].x -= sum;
        vals_arr_f[iterations].y -= sum;
        __half2 temp = __float22half2_rn(vals_arr_f[iterations]);
        inp_grad_h[high_index] = temp;
    }
}
template <>
void launch_layerNorm_backward<float>(const float* out_grad,
                                      const float* vals_hat,
                                      const float* vars,
                                      const float* gamma,
                                      float* gamma_grad,
                                      float* betta_grad,
                                      float* inp_grad,
                                      int batch,
                                      int hidden_dim,
                                      cudaStream_t stream[2],
                                      bool invertible,
                                      const float* betta)
{
    int threads = THREADS;

    dim3 grid_dim(hidden_dim / TILE_DIM);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
        out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);

    dim3 grid_dim2(batch);

    if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 1;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 2;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupport hidden_dim.");

    dim3 block_dim2(threads);
    LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
        out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim);
}

template <>
void launch_layerNorm_backward<__half>(const __half* out_grad,
                                       const __half* vals_hat,
                                       const __half* vars,
                                       const __half* gamma,
                                       __half* gamma_grad,
                                       __half* betta_grad,
                                       __half* inp_grad,
                                       int batch,
                                       int hidden_dim,
                                       cudaStream_t stream[2],
                                       bool invertible,
                                       const __half* betta)
{
    int threads = THREADS;

    dim3 grid_dim(hidden_dim / TILE_DIM);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    // LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
    //     out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);

    dim3 grid_dim2(batch);

    if (hidden_dim > 8192 && hidden_dim <= 16384)
        threads <<= 1;
    else if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 2;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 3;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupport hidden_dim.");

    dim3 block_dim2(threads / 2);
    LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
        out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2);
}
/* Backward Normalize (Input-Gradient)
 * Using the means and variances from the input
 * This type of backward is not invertible!
 * We do the backward using the input (X)
 */
__global__ void LayerNormBackward2(const float* out_grad,
                                   const float* X_vals,
                                   const float* gamma,
                                   const float* vars,
                                   const float* means,
                                   float* inp_grad,
                                   int row_stride)
{
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int wid = id >> WARP_SIZE_BITS;
    int warp_num = iteration_stride >> WARP_SIZE_BITS;
    __shared__ float partialSum[MAX_WARP_NUM];

    out_grad += (row * row_stride);
    X_vals += (row * row_stride);
    inp_grad += (row * row_stride);

    float vals_arr[NORM_REG];
    int high_index = iterations * iteration_stride + id;
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        float gamma_reg = gamma[i * iteration_stride + id];
        vals_arr[i] = out_grad[i * iteration_stride + id];
        vals_arr[i] *= gamma_reg;
    }
    if ((high_index) < row_stride) {
        float gamma_reg = gamma[high_index];
        vals_arr[iterations] = out_grad[high_index];
        vals_arr[iterations] *= gamma_reg;
        iterations++;
    }

    float var_reg = vars[row];
    float mean_reg = means[row];

    float sum = 0;
    float xu[NORM_REG];
    for (int i = 0; i < iterations; i++) {
        xu[i] = (X_vals[i * iteration_stride + id] - mean_reg);
        sum += vals_arr[i] * xu[i];
        vals_arr[i] *= rsqrtf(var_reg);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= row_stride;

    for (int i = 0; i < iterations; i++) {
        vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg));
    }

    sum = 0;
    for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= row_stride;

    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++)
        inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum);
    if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum);
}
__global__ void LayerNormBackward2(const __half* out_grad,
                                   const __half* X_vals,
                                   const __half* gamma,
                                   const __half* vars,
                                   const __half* means,
                                   __half* inp_grad,
                                   int row_stride)
{
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int wid = id >> WARP_SIZE_BITS;
    int warp_num = iteration_stride >> WARP_SIZE_BITS;

    __shared__ float partialSum[MAX_WARP_NUM];

    __half2 vals_arr[NORM_REG];
    float2 vals_arr_f[NORM_REG];
    __half2 xu[NORM_REG];

    __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
    const __half2* out_grad_h = reinterpret_cast<const __half2*>(out_grad);
    const __half2* vals_hat_h = reinterpret_cast<const __half2*>(X_vals);

    inp_grad_h += (row * row_stride);
    out_grad_h += (row * row_stride);
    vals_hat_h += (row * row_stride);

    const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);

    int high_index = iterations * iteration_stride + id;

    __half mean_h = means[row];
    __half2 mean_reg = __halves2half2(mean_h, mean_h);
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        __half2 gamma_reg = gamma_h[i * iteration_stride + id];
        vals_arr[i] = out_grad_h[i * iteration_stride + id];
        vals_arr[i] *= gamma_reg;  // out_grad * gamma
        xu[i] = (vals_hat_h[i * iteration_stride + id] - mean_reg);
    }
    if ((high_index) < row_stride) {
        __half2 gamma_reg = gamma_h[high_index];
        vals_arr[iterations] = out_grad_h[high_index];
        vals_arr[iterations] *= gamma_reg;  // out_grad * gamma
        xu[iterations] = (vals_hat_h[high_index] - mean_reg);
        iterations++;
    }

    __half var_h = vars[row];
    __half2 var_reg = __halves2half2(var_h, var_h);

    float sum = 0.f;
    for (int i = 0; i < iterations; i++) {
        __half2 result_h = (xu[i] * vals_arr[i]);
        float2 result_f = __half22float2(result_h);
        sum += result_f.x;
        sum += result_f.y;
        vals_arr[i] *= h2rsqrt(var_reg);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= (2 * row_stride);
    __half2 sum_h = __float2half2_rn(sum);

    for (int i = 0; i < iterations; i++) {
        __half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg));
        vals_arr_f[i] = __half22float2(vals_arr[i]);
        float2 xu_grad_f = __half22float2(xu_grad);
        vals_arr_f[i].x += xu_grad_f.x;
        vals_arr_f[i].y += xu_grad_f.y;
    }

    sum = 0.f;
    for (int i = 0; i < iterations; i++) {
        sum += (vals_arr_f[i].x);
        sum += (vals_arr_f[i].y);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= (2 * row_stride);

    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++) {
        vals_arr_f[i].x -= sum;
        vals_arr_f[i].y -= sum;
        __half2 temp = __float22half2_rn(vals_arr_f[i]);

        inp_grad_h[i * iteration_stride + id] = temp;
    }
    if ((high_index) < row_stride) {
        vals_arr_f[iterations].x -= sum;
        vals_arr_f[iterations].y -= sum;
        __half2 temp = __float22half2_rn(vals_arr_f[iterations]);
        inp_grad_h[high_index] = temp;
    }
}
template <>
void launch_layerNorm_backward<float>(const float* out_grad,
                                      const float* X_data,
                                      const float* vars,
                                      const float* means,
                                      const float* gamma,
                                      float* gamma_grad,
                                      float* betta_grad,
                                      float* inp_grad,
                                      int batch,
                                      int hidden_dim,
                                      cudaStream_t stream[2])
{
    int threads = THREADS;

    dim3 grid_dim(hidden_dim / TILE_DIM);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
        out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);

    dim3 grid_dim2(batch);

    if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 1;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 2;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupport hidden_dim.");

    dim3 block_dim2(threads);
    LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
        out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim);
}

template <>
void launch_layerNorm_backward<__half>(const __half* out_grad,
                                       const __half* X_data,
                                       const __half* vars,
                                       const __half* means,
                                       const __half* gamma,
                                       __half* gamma_grad,
                                       __half* betta_grad,
                                       __half* inp_grad,
                                       int batch,
                                       int hidden_dim,
                                       cudaStream_t stream[2])
{
    int threads = THREADS;

    dim3 grid_dim(hidden_dim / TILE_DIM);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
        out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);

    dim3 grid_dim2(batch);

    if (hidden_dim > 8192 && hidden_dim <= 16384)
        threads <<= 1;
    else if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 2;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 3;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupport hidden_dim.");

    dim3 block_dim2(threads / 2);
    LayerNormBackward2<<<grid_dim2, block_dim2, 0, stream[1]>>>(
        out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim / 2);
}
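Editorial note: these launchers take cudaStream_t stream[2] so Backward1 (gamma/beta gradients) and Backward2 (input gradients) can run concurrently; the two kernels read the same inputs but write disjoint outputs. A sketch of a hypothetical call site (the function name and synchronization policy are this note's assumptions):

    void example_ln_backward(const float* out_grad, const float* X, const float* vars,
                             const float* means, const float* gamma, float* dgamma,
                             float* dbeta, float* dX, int batch, int hidden_dim)
    {
        cudaStream_t streams[2];
        cudaStreamCreate(&streams[0]);
        cudaStreamCreate(&streams[1]);
        launch_layerNorm_backward<float>(out_grad, X, vars, means, gamma,
                                         dgamma, dbeta, dX, batch, hidden_dim, streams);
        cudaStreamSynchronize(streams[0]);
        cudaStreamSynchronize(streams[1]);
        cudaStreamDestroy(streams[0]);
        cudaStreamDestroy(streams[1]);
    }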
template <typename T>
__global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
                                             const T* __restrict__ out_grad2,
                                             const T* __restrict__ vals_hat,
                                             const T* __restrict__ gamma,
                                             const T* __restrict__ betta,
                                             T* __restrict__ gamma_grad,
                                             T* __restrict__ betta_grad,
                                             int rows,
                                             int width,
                                             bool invertible)
{
    __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
    __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);

    int idx = blockDim.x * blockIdx.x + threadIdx.x;
    int offset = threadIdx.y * width + idx;
    int y_stride = width * TILE_DIM;

    float betta_reg = (invertible ? (float)betta[idx] : 0.0f);
    float gamma_reg = (float)gamma[idx];

    // Loop across matrix height
    float betta_tmp = 0;
    float gamma_tmp = 0;
    for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
        float grad = (float)out_grad1[offset] + (float)out_grad2[offset];
        float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg
                                : (float)vals_hat[offset]);
        betta_tmp += grad;
        gamma_tmp += (val * grad);

        offset += y_stride;
    }

    betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
    gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;

    __syncthreads();

    // Sum the shared buffer.
    float s1 = betta_buffer[threadIdx.y][threadIdx.x];
    float s2 = gamma_buffer[threadIdx.y][threadIdx.x];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < TILE_DIM; i <<= 1) {
        s1 += g.shfl_down(s1, i);
        s2 += g.shfl_down(s2, i);
    }

    if (threadIdx.x == 0) {
        int pos = blockIdx.x * TILE_DIM + threadIdx.y;
        betta_grad[pos] = s1;
        gamma_grad[pos] = s2;
    }
}

template <typename T>
__global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
                                             const T* __restrict__ out_grad2,
                                             const T* __restrict__ X_data,
                                             const T* __restrict__ vars,
                                             const T* __restrict__ means,
                                             T* __restrict__ gamma_grad,
                                             T* __restrict__ betta_grad,
                                             int rows,
                                             int width)
{
    __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
    __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);

    int idx = blockDim.x * blockIdx.x + threadIdx.x;
    int offset = threadIdx.y * width + idx;
    int y_stride = width * TILE_DIM;

    int pos = blockIdx.x * TILE_DIM + threadIdx.y;
    // Loop across matrix height
    float betta_tmp = 0;
    float gamma_tmp = 0;
    for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
        float grad = (float)out_grad1[offset] + (float)out_grad2[offset];
        float val = (float)X_data[offset];
        val = (val - (float)means[r]) * rsqrtf((float)vars[r]);
        betta_tmp += grad;
        gamma_tmp += (val * grad);

        offset += y_stride;
    }

    betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
    gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;

    __syncthreads();

    // Sum the shared buffer.
    float s1 = betta_buffer[threadIdx.y][threadIdx.x];
    float s2 = gamma_buffer[threadIdx.y][threadIdx.x];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < TILE_DIM; i <<= 1) {
        s1 += g.shfl_down(s1, i);
        s2 += g.shfl_down(s2, i);
    }

    if (threadIdx.x == 0) {
        betta_grad[pos] = s1;
        gamma_grad[pos] = s2;
    }
}
__global__ void LayerNormBackward2_fused_add(const float* out_grad1,
                                             const float* out_grad2,
                                             const float* vals_hat,
                                             const float* gamma,
                                             const float* betta,
                                             const float* vars,
                                             float* inp_grad,
                                             bool invertible,
                                             int row_stride)
{
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int wid = id / WARP_SIZE;
    int warp_num = iteration_stride >> WARP_SIZE_BITS;
    __shared__ float partialSum[MAX_WARP_NUM];

    out_grad1 += (row * row_stride);
    out_grad2 += (row * row_stride);
    vals_hat += (row * row_stride);
    inp_grad += (row * row_stride);

    float vals_arr[NORM_REG];
    float vals_hat_arr[NORM_REG];

    int high_index = iterations * iteration_stride + id;
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        float gamma_reg = gamma[i * iteration_stride + id];
        vals_arr[i] = out_grad1[i * iteration_stride + id];
        vals_arr[i] *= gamma_reg;
        vals_hat_arr[i] =
            (invertible
                 ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) / gamma_reg
                 : vals_hat[i * iteration_stride + id]);
    }
    if ((high_index) < row_stride) {
        float gamma_reg = gamma[high_index];
        vals_arr[iterations] = out_grad1[high_index];
        vals_arr[iterations] *= gamma_reg;
        vals_hat_arr[iterations] =
            (invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg
                        : vals_hat[high_index]);
        iterations++;
    }

    float var_reg = vars[row];

    float sum = 0;
    for (int i = 0; i < iterations; i++) {
        sum += vals_hat_arr[i] * vals_arr[i] * sqrtf(var_reg);
        vals_arr[i] *= rsqrtf(var_reg);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= row_stride;

    for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); }

    sum = 0;
    for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= row_stride;

    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++)
        inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum) + out_grad2[i * iteration_stride + id];
    if ((high_index) < row_stride)
        inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index];
}
__global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
                                             const __half* out_grad2,
                                             const __half* vals_hat,
                                             const __half* gamma,
                                             const __half* betta,
                                             const __half* vars,
                                             __half* inp_grad,
                                             bool invertible,
                                             int row_stride)
{
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int wid = id / WARP_SIZE;
    int warp_num = iteration_stride >> WARP_SIZE_BITS;
    __shared__ float partialSum[MAX_WARP_NUM];

    __half2 vals_arr[NORM_REG];
    float2 vals_arr_f[NORM_REG];
    __half2 vals_hat_arr[NORM_REG];

    // float2 result[iterations];

    __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
    const __half2* out_grad_h1 = reinterpret_cast<const __half2*>(out_grad1);
    const __half2* out_grad_h2 = reinterpret_cast<const __half2*>(out_grad2);
    const __half2* vals_hat_h = reinterpret_cast<const __half2*>(vals_hat);

    inp_grad_h += (row * row_stride);
    out_grad_h1 += (row * row_stride);
    out_grad_h2 += (row * row_stride);
    vals_hat_h += (row * row_stride);

    const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
    const __half2* betta_h = (invertible ? reinterpret_cast<const __half2*>(betta) : nullptr);

    int high_index = iterations * iteration_stride + id;
#pragma unroll
    for (int i = 0; i < iterations; i++) {
        __half2 gamma_reg = gamma_h[i * iteration_stride + id];
        vals_arr[i] = out_grad_h1[i * iteration_stride + id];
        vals_arr[i] *= gamma_reg;  // out_grad * gamma
        vals_hat_arr[i] =
            (invertible
                 ? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) /
                       gamma_reg
                 : vals_hat_h[i * iteration_stride + id]);
    }
    if ((high_index) < row_stride) {
        __half2 gamma_reg = gamma_h[high_index];
        vals_arr[iterations] = out_grad_h1[high_index];
        vals_arr[iterations] *= gamma_reg;  // out_grad * gamma
        vals_hat_arr[iterations] =
            (invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg
                        : vals_hat_h[high_index]);
        iterations++;
    }

    __half var_h = vars[row];
    __half2 var_reg = __halves2half2(var_h, var_h);

    float sum = 0.f;
    for (int i = 0; i < iterations; i++) {
        __half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg));
        float2 result_f = __half22float2(result_h);
        sum += result_f.x;
        sum += result_f.y;
        vals_arr[i] *= h2rsqrt(var_reg);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= (2 * row_stride);
    __half2 sum_h = __float2half2_rn(sum);

    for (int i = 0; i < iterations; i++) {
        __half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg));
        vals_arr_f[i] = __half22float2(vals_arr[i]);
        float2 temp_f = __half22float2(temp);
        vals_arr_f[i].x += temp_f.x;
        vals_arr_f[i].y += temp_f.y;
    }

    sum = 0.f;
    for (int i = 0; i < iterations; i++) {
        sum += (vals_arr_f[i].x);
        sum += (vals_arr_f[i].y);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }

    if (g.thread_rank() == 0) partialSum[wid] = sum;

    __syncthreads();

    if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);

    sum = g.shfl(sum, 0);
    sum /= (2 * row_stride);

    iterations = row_stride / iteration_stride;
    for (int i = 0; i < iterations; i++) {
        vals_arr_f[i].x -= sum;
        vals_arr_f[i].y -= sum;
        __half2 temp = __float22half2_rn(vals_arr_f[i]);
        inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id];
    }
    if ((high_index) < row_stride) {
        vals_arr_f[iterations].x -= sum;
        vals_arr_f[iterations].y -= sum;
        __half2 temp = __float22half2_rn(vals_arr_f[iterations]);
        inp_grad_h[high_index] = temp + out_grad_h2[high_index];
    }
}
template <>
void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
                                                const float* out_grad2,
                                                const float* vals_hat,
                                                const float* vars,
                                                const float* gamma,
                                                float* gamma_grad,
                                                float* betta_grad,
                                                float* inp_grad,
                                                int batch,
                                                int hidden_dim,
                                                cudaStream_t stream[2],
                                                bool invertible,
                                                const float* betta)
{
    int threads = THREADS;

    dim3 grid_dim(hidden_dim / TILE_DIM);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    LayerNormBackward1<float><<<grid_dim, block_dim, 0, stream[0]>>>(
        out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);

    dim3 grid_dim2(batch);

    if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 1;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 2;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupport hidden_dim.");

    dim3 block_dim2(threads);
    LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
        out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim);
}

template <>
void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
                                                 const __half* out_grad2,
                                                 const __half* vals_hat,
                                                 const __half* vars,
                                                 const __half* gamma,
                                                 __half* gamma_grad,
                                                 __half* betta_grad,
                                                 __half* inp_grad,
                                                 int batch,
                                                 int hidden_dim,
                                                 cudaStream_t stream[2],
                                                 bool invertible,
                                                 const __half* betta)
{
    int threads = THREADS;

    dim3 grid_dim(hidden_dim / TILE_DIM);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    LayerNormBackward1<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
        out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);

    dim3 grid_dim2(batch);

    if (hidden_dim > 8192 && hidden_dim <= 16384)
        threads <<= 1;
    else if (hidden_dim > 16384 && hidden_dim <= 32768)
        threads <<= 2;
    else if (hidden_dim > 32768 && hidden_dim <= 65536)
        threads <<= 3;
    else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupport hidden_dim.");

    dim3 block_dim2(threads / 2);
    LayerNormBackward2_fused_add<<<grid_dim2, block_dim2, 0, stream[1]>>>(
        out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2);
}
/* Backward Normalize (Input-Gradient)
* Using the means and variances from the input
* This type of backward is not invertible!
* We do the backward using the input (X)
*/
__global__
void
LayerNormBackward2_fused_add
(
const
float
*
out_grad1
,
const
float
*
out_grad2
,
const
float
*
X_vals
,
const
float
*
gamma
,
const
float
*
vars
,
const
float
*
means
,
float
*
inp_grad
,
int
row_stride
)
{
int
iteration_stride
=
blockDim
.
x
;
int
iterations
=
row_stride
/
iteration_stride
;
cg
::
thread_block
b
=
cg
::
this_thread_block
();
cg
::
thread_block_tile
<
WARP_SIZE
>
g
=
cg
::
tiled_partition
<
WARP_SIZE
>
(
b
);
int
row
=
blockIdx
.
x
;
int
id
=
threadIdx
.
x
;
int
wid
=
id
/
WARP_SIZE
;
int
warp_num
=
iteration_stride
>>
WARP_SIZE_BITS
;
__shared__
float
partialSum
[
MAX_WARP_NUM
];
float
vals_arr
[
NORM_REG
];
float
vals_hat_arr
[
NORM_REG
];
out_grad1
+=
(
row
*
row_stride
);
out_grad2
+=
(
row
*
row_stride
);
X_vals
+=
(
row
*
row_stride
);
inp_grad
+=
(
row
*
row_stride
);
int
high_index
=
iterations
*
iteration_stride
+
id
;
#pragma unroll
for
(
int
i
=
0
;
i
<
iterations
;
i
++
)
{
float
gamma_reg
=
gamma
[
i
*
iteration_stride
+
id
];
vals_arr
[
i
]
=
out_grad1
[
i
*
iteration_stride
+
id
];
vals_arr
[
i
]
*=
gamma_reg
;
vals_hat_arr
[
i
]
=
X_vals
[
i
*
iteration_stride
+
id
];
}
if
((
high_index
)
<
row_stride
)
{
float
gamma_reg
=
gamma
[
high_index
];
vals_arr
[
iterations
]
=
out_grad1
[
high_index
];
vals_arr
[
iterations
]
*=
gamma_reg
;
vals_hat_arr
[
iterations
]
=
X_vals
[
high_index
];
iterations
++
;
}
float
var_reg
=
vars
[
row
];
float
mean_reg
=
means
[
row
];
float
sum
=
0
;
float
xu
[
NORM_REG
];
for
(
int
i
=
0
;
i
<
iterations
;
i
++
)
{
xu
[
i
]
=
(
vals_hat_arr
[
i
]
-
mean_reg
);
sum
+=
vals_arr
[
i
]
*
xu
[
i
];
vals_arr
[
i
]
*=
rsqrtf
(
var_reg
);
}
for
(
int
i
=
1
;
i
<
WARP_SIZE
;
i
*=
2
)
{
sum
+=
g
.
shfl_down
(
sum
,
i
);
}
if
(
g
.
thread_rank
()
==
0
)
partialSum
[
wid
]
=
sum
;
__syncthreads
();
if
(
g
.
thread_rank
()
<
warp_num
)
sum
=
partialSum
[
g
.
thread_rank
()];
#ifndef __STOCHASTIC_MODE__
__syncthreads
();
#endif
for
(
int
i
=
1
;
i
<
warp_num
;
i
*=
2
)
sum
+=
g
.
shfl_down
(
sum
,
i
);
sum
=
g
.
shfl
(
sum
,
0
);
sum
/=
row_stride
;
for
(
int
i
=
0
;
i
<
iterations
;
i
++
)
{
vals_arr
[
i
]
+=
(
-
sum
*
xu
[
i
]
*
rsqrtf
(
var_reg
)
/
(
var_reg
));
}
sum
=
0
;
for
(
int
i
=
0
;
i
<
iterations
;
i
++
)
{
sum
+=
vals_arr
[
i
];
}
for
(
int
i
=
1
;
i
<
WARP_SIZE
;
i
*=
2
)
{
sum
+=
g
.
shfl_down
(
sum
,
i
);
}
if
(
g
.
thread_rank
()
==
0
)
partialSum
[
wid
]
=
sum
;
__syncthreads
();
if
(
g
.
thread_rank
()
<
warp_num
)
sum
=
partialSum
[
g
.
thread_rank
()];
#ifndef __STOCHASTIC_MODE__
__syncthreads
();
#endif
for
(
int
i
=
1
;
i
<
warp_num
;
i
*=
2
)
sum
+=
g
.
shfl_down
(
sum
,
i
);
sum
=
g
.
shfl
(
sum
,
0
);
sum
/=
row_stride
;
iterations
=
row_stride
/
iteration_stride
;
for
(
int
i
=
0
;
i
<
iterations
;
i
++
)
inp_grad
[
i
*
iteration_stride
+
id
]
=
(
vals_arr
[
i
]
-
sum
)
+
out_grad2
[
i
*
iteration_stride
+
id
];
if
((
high_index
)
<
row_stride
)
inp_grad
[
high_index
]
=
(
vals_arr
[
iterations
]
-
sum
)
+
out_grad2
[
high_index
];
}
__global__
void
LayerNormBackward2_fused_add
(
const
__half
*
out_grad1
,
const
__half
*
out_grad2
,
const
__half
*
X_vals
,
const
__half
*
gamma
,
const
__half
*
vars
,
const
__half
*
means
,
__half
*
inp_grad
,
int
row_stride
)
{
int
iteration_stride
=
blockDim
.
x
;
int
iterations
=
row_stride
/
iteration_stride
;
cg
::
thread_block
b
=
cg
::
this_thread_block
();
cg
::
thread_block_tile
<
WARP_SIZE
>
g
=
cg
::
tiled_partition
<
WARP_SIZE
>
(
b
);
int
row
=
blockIdx
.
x
;
int
id
=
threadIdx
.
x
;
int
wid
=
id
/
WARP_SIZE
;
int
warp_num
=
iteration_stride
>>
WARP_SIZE_BITS
;
__shared__
float
partialSum
[
MAX_WARP_NUM
];
__half2
vals_arr
[
NORM_REG
];
float2
vals_arr_f
[
NORM_REG
];
__half2
vals_hat_arr
[
NORM_REG
];
__half2
*
inp_grad_h
=
reinterpret_cast
<
__half2
*>
(
inp_grad
);
const
__half2
*
out_grad_h1
=
reinterpret_cast
<
const
__half2
*>
(
out_grad1
);
const
__half2
*
out_grad_h2
=
reinterpret_cast
<
const
__half2
*>
(
out_grad2
);
const
__half2
*
vals_hat_h
=
reinterpret_cast
<
const
__half2
*>
(
X_vals
);
out_grad_h1
+=
(
row
*
row_stride
);
out_grad_h2
+=
(
row
*
row_stride
);
inp_grad_h
+=
(
row
*
row_stride
);
vals_hat_h
+=
(
row
*
row_stride
);
const
__half2
*
gamma_h
=
reinterpret_cast
<
const
__half2
*>
(
gamma
);
int
high_index
=
iterations
*
iteration_stride
+
id
;
#pragma unroll
for
(
int
i
=
0
;
i
<
iterations
;
i
++
)
{
__half2
gamma_reg
=
gamma_h
[
i
*
iteration_stride
+
id
];
vals_arr
[
i
]
=
out_grad_h1
[
i
*
iteration_stride
+
id
];
vals_arr
[
i
]
*=
gamma_reg
;
// out_grad * gamma
vals_hat_arr
[
i
]
=
vals_hat_h
[
i
*
iteration_stride
+
id
];
}
if
((
high_index
)
<
row_stride
)
{
__half2
gamma_reg
=
gamma_h
[
high_index
];
vals_arr
[
iterations
]
=
out_grad_h1
[
high_index
];
vals_arr
[
iterations
]
*=
gamma_reg
;
// out_grad * gamma
vals_hat_arr
[
iterations
]
=
vals_hat_h
[
high_index
];
iterations
++
;
}
__half
mean_h
=
means
[
row
];
__half
var_h
=
vars
[
row
];
__half2
var_reg
=
__halves2half2
(
var_h
,
var_h
);
__half2
mean_reg
=
__halves2half2
(
mean_h
,
mean_h
);
__half2
xu
[
NORM_REG
];
float
sum
=
0.
f
;
for
(
int
i
=
0
;
i
<
iterations
;
i
++
)
{
xu
[
i
]
=
(
vals_hat_arr
[
i
]
-
mean_reg
);
__half2
result_h
=
(
xu
[
i
]
*
vals_arr
[
i
]);
float2
result_f
=
__half22float2
(
result_h
);
sum
+=
result_f
.
x
;
sum
+=
result_f
.
y
;
vals_arr
[
i
]
*=
h2rsqrt
(
var_reg
);
}
for
(
int
i
=
1
;
i
<
WARP_SIZE
;
i
*=
2
)
{
sum
+=
g
.
shfl_down
(
sum
,
i
);
}
if
(
g
.
thread_rank
()
==
0
)
partialSum
[
wid
]
=
sum
;
__syncthreads
();
if
(
g
.
thread_rank
()
<
warp_num
)
sum
=
partialSum
[
g
.
thread_rank
()];
#ifndef __STOCHASTIC_MODE__
__syncthreads
();
#endif
for
(
int
i
=
1
;
i
<
warp_num
;
i
*=
2
)
sum
+=
g
.
shfl_down
(
sum
,
i
);
sum
=
g
.
shfl
(
sum
,
0
);
sum
/=
(
2
*
row_stride
);
__half2
sum_h
=
__float2half2_rn
(
sum
);
for
(
int
i
=
0
;
i
<
iterations
;
i
++
)
{
__half2
xu_grad
=
((
-
sum_h
*
xu
[
i
]
*
h2rsqrt
(
var_reg
))
/
(
var_reg
));
vals_arr_f
[
i
]
=
__half22float2
(
vals_arr
[
i
]);
float2
xu_grad_f
=
__half22float2
(
xu_grad
);
vals_arr_f
[
i
].
x
+=
xu_grad_f
.
x
;
vals_arr_f
[
i
].
y
+=
xu_grad_f
.
y
;
}
sum
=
0.
f
;
for
(
int
i
=
0
;
i
<
iterations
;
i
++
)
{
sum
+=
(
vals_arr_f
[
i
].
x
);
sum
+=
(
vals_arr_f
[
i
].
y
);
}
for
(
int
i
=
1
;
i
<
WARP_SIZE
;
i
*=
2
)
{
sum
+=
g
.
shfl_down
(
sum
,
i
);
}
if
(
g
.
thread_rank
()
==
0
)
partialSum
[
wid
]
=
sum
;
__syncthreads
();
if
(
g
.
thread_rank
()
<
warp_num
)
sum
=
partialSum
[
g
.
thread_rank
()];
#ifndef __STOCHASTIC_MODE__
__syncthreads
();
#endif
for
(
int
i
=
1
;
i
<
warp_num
;
i
*=
2
)
sum
+=
g
.
shfl_down
(
sum
,
i
);
sum
=
g
.
shfl
(
sum
,
0
);
sum
/=
(
2
*
row_stride
);
iterations
=
row_stride
/
iteration_stride
;
for
(
int
i
=
0
;
i
<
iterations
;
i
++
)
{
vals_arr_f
[
i
].
x
-=
sum
;
vals_arr_f
[
i
].
y
-=
sum
;
__half2
temp
=
__float22half2_rn
(
vals_arr_f
[
i
]);
inp_grad_h
[
i
*
iteration_stride
+
id
]
=
temp
+
out_grad_h2
[
i
*
iteration_stride
+
id
];
}
if
((
high_index
)
<
row_stride
)
{
vals_arr_f
[
iterations
].
x
-=
sum
;
vals_arr_f
[
iterations
].
y
-=
sum
;
__half2
temp
=
__float22half2_rn
(
vals_arr_f
[
iterations
]);
inp_grad_h
[
high_index
]
=
temp
+
out_grad_h2
[
high_index
];
}
}
template
<
>
void
launch_layerNorm_backward_fused_add
<
float
>
(
const
float
*
out_grad1
,
const
float
*
out_grad2
,
const
float
*
X_data
,
const
float
*
vars
,
const
float
*
means
,
const
float
*
gamma
,
float
*
gamma_grad
,
float
*
betta_grad
,
float
*
inp_grad
,
int
batch
,
int
hidden_dim
,
cudaStream_t
stream
[
2
])
{
int
threads
=
THREADS
;
dim3
grid_dim
(
hidden_dim
/
TILE_DIM
);
dim3
block_dim
(
TILE_DIM
,
TILE_DIM
);
LayerNormBackward1
<
float
><<<
grid_dim
,
block_dim
,
0
,
stream
[
0
]
>>>
(
out_grad1
,
X_data
,
vars
,
means
,
gamma_grad
,
betta_grad
,
batch
,
hidden_dim
);
dim3
grid_dim2
(
batch
);
if
(
hidden_dim
>
16384
&&
hidden_dim
<=
32768
)
threads
<<=
1
;
else
if
(
hidden_dim
>
32768
&&
hidden_dim
<=
65536
)
threads
<<=
2
;
else
if
(
hidden_dim
>
65536
)
throw
std
::
runtime_error
(
"Unsupport hidden_dim."
);
dim3
block_dim2
(
threads
);
LayerNormBackward2_fused_add
<<<
grid_dim2
,
block_dim2
,
0
,
stream
[
1
]
>>>
(
out_grad1
,
out_grad2
,
X_data
,
gamma
,
vars
,
means
,
inp_grad
,
hidden_dim
);
}
template
<
>
void
launch_layerNorm_backward_fused_add
<
__half
>
(
const
__half
*
out_grad1
,
const
__half
*
out_grad2
,
const
__half
*
X_data
,
const
__half
*
vars
,
const
__half
*
means
,
const
__half
*
gamma
,
__half
*
gamma_grad
,
__half
*
betta_grad
,
__half
*
inp_grad
,
int
batch
,
int
hidden_dim
,
cudaStream_t
stream
[
2
])
{
int
threads
=
THREADS
;
dim3
grid_dim
(
hidden_dim
/
TILE_DIM
);
dim3
block_dim
(
TILE_DIM
,
TILE_DIM
);
LayerNormBackward1
<
__half
><<<
grid_dim
,
block_dim
,
0
,
stream
[
0
]
>>>
(
out_grad1
,
X_data
,
vars
,
means
,
gamma_grad
,
betta_grad
,
batch
,
hidden_dim
);
dim3
grid_dim2
(
batch
);
if
(
hidden_dim
>
8192
&&
hidden_dim
<=
16384
)
threads
<<=
1
;
else
if
(
hidden_dim
>
16384
&&
hidden_dim
<=
32768
)
threads
<<=
2
;
else
if
(
hidden_dim
>
32768
&&
hidden_dim
<=
65536
)
threads
<<=
3
;
else
if
(
hidden_dim
>
65536
)
throw
std
::
runtime_error
(
"Unsupport hidden_dim."
);
dim3
block_dim2
(
threads
/
2
);
LayerNormBackward2_fused_add
<<<
grid_dim2
,
block_dim2
,
0
,
stream
[
1
]
>>>
(
out_grad1
,
out_grad2
,
X_data
,
gamma
,
vars
,
means
,
inp_grad
,
hidden_dim
/
2
);
}
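For reference, all of the Backward2 kernels above (and their HIP twins below) reduce the same two row-wise sums. With x̂ the normalized input, ĝ = γ ⊙ ∂L/∂y, and H the row width, the input gradient they assemble is the standard layer-norm backward; this is our reconstruction from the reductions in the code, and the stored per-row variance already includes epsilon:

$$
\frac{\partial L}{\partial x_i}
  = \frac{1}{\sqrt{\sigma^2}}\left(\hat g_i
      - \frac{1}{H}\sum_{j=1}^{H}\hat g_j
      - \hat x_i\,\frac{1}{H}\sum_{j=1}^{H}\hat g_j\,\hat x_j\right),
\qquad
\hat x_i = \frac{x_i - \mu}{\sqrt{\sigma^2}}
$$

The FP16 kernels divide the reduced sums by 2 * row_stride because each __half2 slot covers two elements of the row.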
csrc/transformer_bak/normalize_kernels.hip
deleted 100644 → 0
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "custom_hip_layers.h"
namespace cg = cooperative_groups;
/*
Fused bias add, residual (elementwise) add, and normalization layer.
For FP16, this kernel does not promote to FP32 in order to utilize the 2x throughput for
__half2 instructions, and to avoid the conversion overhead (1/8 of __half2 arithmetic).
For specific launch constraints, see the launch functions.
*/
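/* A minimal side sketch of the __half2 strategy described above (helper
 * names are ours, not part of the original file): keep two FP16 lanes per
 * register and promote to float2 only where a reduction needs FP32 accuracy.
 * All intrinsics below come from the standard fp16 headers. */
__device__ __half2 scale_pair_sketch(__half2 v, float inv_std)
{
    // Broadcast inv_std into both lanes, then one vectorized multiply
    // where two scalar __half multiplies would otherwise be needed.
    return __hmul2(v, __float2half2_rn(inv_std));
}

__device__ float pair_sum_sketch(__half2 v)
{
    float2 f = __half22float2(v);  // unpack once, accumulate in FP32
    return f.x + f.y;
}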
#define NORM_REG (MAX_REGISTERS / 4)
__global__ void fused_bias_residual_layer_norm(float* vals,
const float* residual,
const float* gamma,
const float* beta,
float epsilon,
bool preLayerNorm,
bool training,
float* vars,
float* means,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id / WARP_SIZE;
float vals_arr[NORM_REG];
__shared__ float shr[MAX_WARP_NUM];
residual += (row * row_stride);
vals += (row * row_stride);
float sum = 0.f;
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
vals_arr[i] = residual[i * iteration_stride + id];
sum += vals_arr[i];
}
if (high_index < row_stride) {
vals_arr[iterations] = residual[high_index];
sum += vals_arr[iterations];
iterations++;
}
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];
#if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
sum += g.shfl_down(sum, i);
}
sum = g.shfl(sum, 0);
float mean = sum / row_stride;
if (training)
if (threadIdx.x == 0) means[row] = mean;
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
vals_arr[i] -= mean;
variance += vals_arr[i] * vals_arr[i];
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
if (g.thread_rank() == 0) shr[gid] = variance;
b.sync();
if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
variance += g.shfl_down(variance, i);
}
variance = g.shfl(variance, 0);
variance /= row_stride;
variance += epsilon;
if (training)
if (threadIdx.x == 0) vars[row] = variance;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr[i] = vals_arr[i] * rsqrtf(variance);
vals_arr[i] =
vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id];
vals[i * iteration_stride + id] = vals_arr[i];
}
if ((high_index) < row_stride) {
vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance);
vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index];
vals[high_index] = vals_arr[iterations];
}
}
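/* The reduction pattern used throughout this file -- an intra-warp shuffle
 * tree, one shared-memory slot per warp, then a redundant per-warp pass over
 * the partials -- distilled into a standalone sketch (names are ours):
 */
__device__ float block_sum_sketch(float v, float* shr /* [MAX_WARP_NUM] */)
{
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
    int warps = blockDim.x >> 5;  // assumes power-of-two blockDim.x >= 32

    for (int i = 1; i < 32; i *= 2) v += g.shfl_down(v, i);  // lane 0 holds warp sum
    if (g.thread_rank() == 0) shr[threadIdx.x >> 5] = v;     // one partial per warp
    b.sync();

    // Lanes 0..warps-1 of *every* warp re-load the partials, so each warp
    // reduces to the same total without a second shared-memory round trip.
    if (g.thread_rank() < warps) v = shr[g.thread_rank()];
    b.sync();
    for (int i = 1; i < warps; i *= 2) v += g.shfl_down(v, i);
    return g.shfl(v, 0);  // broadcast lane 0's total to the whole warp
}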
__global__ void fused_bias_residual_layer_norm(__half* vals,
const __half* residual,
const __half* gamma,
const __half* beta,
float epsilon,
bool preLayerNorm,
bool training,
__half* vars,
__half* means,
int row_stride)
{
#ifdef HALF_PRECISION_AVAILABLE
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> WARP_SIZE_BITS;
float2 vals_f[NORM_REG];
__shared__ float shr[MAX_WARP_NUM];
__half2* vals_cast = reinterpret_cast<__half2*>(vals);
const __half2* residual_cast = reinterpret_cast<const __half2*>(residual);
residual_cast += (row * row_stride);
vals_cast += (row * row_stride);
float sum = 0.f;
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]);
sum += vals_f[i].x;
sum += vals_f[i].y;
}
if ((high_index) < row_stride) {
vals_f[iterations] = __half22float2(residual_cast[high_index]);
sum += vals_f[iterations].x;
sum += vals_f[iterations].y;
iterations++;
}
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
sum += g.shfl_down(sum, i);
}
sum = g.shfl(sum, 0);
float mean = sum / (row_stride * 2);
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
vals_f[i].x -= mean;
vals_f[i].y -= mean;
variance += vals_f[i].x * vals_f[i].x;
variance += vals_f[i].y * vals_f[i].y;
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
if (g.thread_rank() == 0) shr[gid] = variance;
b.sync();
if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
variance += g.shfl_down(variance, i);
}
variance = g.shfl(variance, 0);
variance /= (row_stride * 2);
variance += epsilon;
__half2 variance_h = __float2half2_rn(variance);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
if (training && threadIdx.x == 0) {
vars[row] = __float2half(variance);
means[row] = __float2half(mean);
}
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
__half2 vals_arr = __float22half2_rn(vals_f[i]);
vals_arr = vals_arr * h2rsqrt(variance_h);
vals_arr =
vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id];
vals_cast[i * iteration_stride + id] = vals_arr;
}
if ((high_index) < row_stride) {
__half2 vals_arr = __float22half2_rn(vals_f[iterations]);
vals_arr = vals_arr * h2rsqrt(variance_h);
vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index];
vals_cast[high_index] = vals_arr;
}
#endif
}
template <typename T>
void launch_bias_residual_layer_norm(T* vals,
const T* residual,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream,
bool preLayerNorm,
bool training,
T* vars,
T* means);
template <>
void launch_bias_residual_layer_norm<float>(float* vals,
const float* residual,
const float* gamma,
const float* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream,
bool preLayerNorm,
bool training,
float* vars,
float* means)
{
int threads = THREADS;
dim3 grid_dim(batch_size);
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupport hidden_dim.");
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_bias_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim);
}
template <>
void launch_bias_residual_layer_norm<__half>(__half* vals,
const __half* residual,
const __half* gamma,
const __half* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream,
bool preLayerNorm,
bool training,
__half* vars,
__half* means)
{
int threads = 128;
dim3 grid_dim(batch_size);
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupport hidden_dim.");
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_bias_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim / 2);
}
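/* A hedged host-side usage sketch for the launchers above (names, sizes,
 * and epsilon are ours; error handling elided). Only verbatim HIP runtime
 * calls are used; the launcher signature is the one declared in this file. */
static void example_layer_norm_forward()
{
    int rows = 8 * 512;  // batch * seq_len; the launcher runs one block per row
    int hidden_dim = 1024;
    size_t bytes = (size_t)rows * hidden_dim * sizeof(float);

    float *vals, *residual, *gamma, *beta, *vars, *means;
    hipMalloc(&vals, bytes);
    hipMalloc(&residual, bytes);
    hipMalloc(&gamma, hidden_dim * sizeof(float));
    hipMalloc(&beta, hidden_dim * sizeof(float));
    hipMalloc(&vars, rows * sizeof(float));   // per-row variance (training mode)
    hipMalloc(&means, rows * sizeof(float));  // per-row mean (training mode)

    hipStream_t stream;
    hipStreamCreate(&stream);
    launch_bias_residual_layer_norm<float>(vals, residual, gamma, beta,
                                           /*epsilon=*/1e-12f, rows, hidden_dim,
                                           stream, /*preLayerNorm=*/false,
                                           /*training=*/true, vars, means);
    hipStreamSynchronize(stream);
}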
__global__ void fused_bias_residual_layer_norm(float* vals,
const float* residual,
const float* gamma,
const float* beta,
float epsilon,
bool preLayerNorm,
bool training,
float* vars,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id / 32;
float vals_arr[NORM_REG];
__shared__ float shr[MAX_WARP_NUM];
residual += (row * row_stride);
vals += (row * row_stride);
float sum = 0.f;
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
vals_arr[i] = residual[i * iteration_stride + id];
sum += vals_arr[i];
}
if ((high_index) < row_stride) {
vals_arr[iterations] = residual[high_index];
sum += vals_arr[iterations];
iterations++;
}
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];
#if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
sum += g.shfl_down(sum, i);
}
sum = g.shfl(sum, 0);
float mean = sum / row_stride;
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
vals_arr[i] -= mean;
variance += vals_arr[i] * vals_arr[i];
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
if (g.thread_rank() == 0) shr[gid] = variance;
b.sync();
if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
variance += g.shfl_down(variance, i);
}
variance = g.shfl(variance, 0);
variance /= row_stride;
variance += epsilon;
if (training)
if (threadIdx.x == 0) vars[row] = variance;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr[i] = vals_arr[i] * rsqrtf(variance);
vals_arr[i] =
vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id];
vals[i * iteration_stride + id] = vals_arr[i];
}
if ((high_index) < row_stride) {
vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance);
vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index];
vals[high_index] = vals_arr[iterations];
}
}
__global__ void fused_bias_residual_layer_norm(__half* vals,
const __half* residual,
const __half* gamma,
const __half* beta,
float epsilon,
bool preLayerNorm,
bool training,
__half* vars,
int row_stride)
{
#ifdef HALF_PRECISION_AVAILABLE
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> WARP_SIZE_BITS;
float2 vals_f[NORM_REG];
__shared__ float shr[MAX_WARP_NUM];
__half2* vals_cast = reinterpret_cast<__half2*>(vals);
const __half2* residual_cast = reinterpret_cast<const __half2*>(residual);
residual_cast += (row * row_stride);
vals_cast += (row * row_stride);
float sum = 0.f;
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]);
sum += vals_f[i].x;
sum += vals_f[i].y;
}
if ((high_index) < row_stride) {
vals_f[iterations] = __half22float2(residual_cast[high_index]);
sum += vals_f[iterations].x;
sum += vals_f[iterations].y;
iterations++;
}
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
sum += g.shfl_down(sum, i);
}
sum = g.shfl(sum, 0);
float mean = sum / (row_stride * 2);
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
vals_f[i].x -= mean;
vals_f[i].y -= mean;
variance += vals_f[i].x * vals_f[i].x;
variance += vals_f[i].y * vals_f[i].y;
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
if (g.thread_rank() == 0) shr[gid] = variance;
b.sync();
if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
variance += g.shfl_down(variance, i);
}
variance = g.shfl(variance, 0);
variance /= (row_stride * 2);
variance += epsilon;
__half2 variance_h = __float2half2_rn(variance);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
if (training && threadIdx.x == 0) vars[row] = __float2half(variance);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
__half2 vals_arr = __float22half2_rn(vals_f[i]);
vals_arr = vals_arr * h2rsqrt(variance_h);
vals_arr =
vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id];
vals_cast[i * iteration_stride + id] = vals_arr;
}
if ((high_index) < row_stride) {
__half2 vals_arr = __float22half2_rn(vals_f[iterations]);
vals_arr = vals_arr * h2rsqrt(variance_h);
vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index];
vals_cast[high_index] = vals_arr;
}
#endif
}
template <typename T>
void launch_bias_residual_layer_norm(T* vals,
const T* residual,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream,
bool preLayerNorm,
bool training,
T* vars);
/*
To tune this launch the following restrictions must be met:
For float:
row_stride == hidden_size
threads * iterations == row_stride
threads is in [32, 64, 128, 256, 512, 1024]
For half:
row_stride == hidden_size / 2
threads * iterations == row_stride
threads is in [32, 64, 128, 256, 512, 1024]
*/
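/* Worked example of the constraints above (illustrative; assumes THREADS
 * expands to 256, as declared in this repo's headers -- check
 * custom_hip_layers.h):
 *
 *   float path, hidden_dim = 40960:
 *     40960 > 32768 && 40960 <= 65536  ->  threads <<= 2  ->  1024
 *     iterations = row_stride / threads = 40960 / 1024 = 40 floats per thread
 *
 *   half path, hidden_dim = 40960 (row_stride = hidden_dim / 2 = 20480):
 *     40960 > 32768 && 40960 <= 65536  ->  threads <<= 3  ->  1024 (from 128)
 *     iterations = 20480 / 1024 = 20 __half2 elements per thread
 */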
template <>
void launch_bias_residual_layer_norm<float>(float* vals,
const float* residual,
const float* gamma,
const float* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream,
bool preLayerNorm,
bool training,
float* vars)
{
int threads = THREADS;
dim3 grid_dim(batch_size);
// There are some limitations to call below functions, now just enumerate the situations.
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupport hidden_dim.");
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_bias_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim);
}
template <>
void launch_bias_residual_layer_norm<__half>(__half* vals,
const __half* residual,
const __half* gamma,
const __half* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream,
bool preLayerNorm,
bool training,
__half* vars)
{
int threads = 128;
dim3 grid_dim(batch_size);
// There are some limitations to call below functions, now just enumerate the situations.
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupport hidden_dim.");
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_bias_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim / 2);
}
/* Normalize Gamma & Betta gradients
* Compute gradients using either X_hat or
* normalize input (invertible).
* Combine transpose with gradients computation.
*/
template <typename T>
__global__ void LayerNormBackward1(const T* __restrict__ out_grad,
const T* __restrict__ vals_hat,
const T* __restrict__ gamma,
const T* __restrict__ betta,
T* __restrict__ gamma_grad,
T* __restrict__ betta_grad,
int rows,
int width,
bool invertible)
{
__shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
__shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
float betta_reg = (invertible ? (float)betta[idx] : 0.0f);
float gamma_reg = (float)gamma[idx];
// Loop across matrix height
float betta_tmp = 0;
float gamma_tmp = 0;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
float grad = (float)out_grad[offset];
float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg
: (float)vals_hat[offset]);
betta_tmp += grad;
gamma_tmp += (val * grad);
offset += y_stride;
}
betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
__syncthreads();
// Sum the shared buffer.
float s1 = betta_buffer[threadIdx.y][threadIdx.x];
float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) {
s1 += g.shfl_down(s1, i);
s2 += g.shfl_down(s2, i);
}
if (threadIdx.x == 0) {
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
betta_grad[pos] = s1;
gamma_grad[pos] = s2;
}
}
/* Normalize Gamma & Betta gradients
* Compute gradients using the input to
* the normalize.
* Combine transpose with gradients computation.
*/
template <typename T>
__global__ void LayerNormBackward1(const T* __restrict__ out_grad,
const T* __restrict__ X_data,
const T* __restrict__ vars,
const T* __restrict__ means,
T* __restrict__ gamma_grad,
T* __restrict__ betta_grad,
int rows,
int width)
{
__shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
__shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
// Loop across matrix height
float betta_tmp = 0;
float gamma_tmp = 0;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
float grad = (float)out_grad[offset];
float val = (float)X_data[offset];
val = (val - (float)means[r]) * rsqrtf((float)vars[r]);
betta_tmp += grad;
gamma_tmp += (val * grad);
offset += y_stride;
}
betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
__syncthreads();
// Sum the shared buffer.
float s1 = betta_buffer[threadIdx.y][threadIdx.x];
float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) {
s1 += g.shfl_down(s1, i);
s2 += g.shfl_down(s2, i);
}
if (threadIdx.x == 0) {
betta_grad[pos] = s1;
gamma_grad[pos] = s2;
}
}
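/* What both LayerNormBackward1 variants above reduce, written out
 * (our notation; r indexes rows, j columns, g_{rj} = dL/dy_{rj}):
 *
 *   dL/dbetta_j = sum_r g_{rj}
 *   dL/dgamma_j = sum_r g_{rj} * xhat_{rj}
 *
 * xhat is recovered from the stored output as (y - betta)/gamma in the
 * invertible variant, and recomputed as (x - mu_r) * rsqrt(var_r) in the
 * X-based variant. The [TILE_DIM][TILE_DIM + 1] buffers swap threadIdx.x/y
 * between store and load, fusing the transpose into the reduction; the +1
 * padding avoids shared-memory bank conflicts.
 */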
/* Backward Normalize (Input-Gradient)
* Using the means and variances from the input
* This type of backward is invertible!
* We do the backward using the X_hat (X - u) / sqrt(variance) or the output of Normalization.
*/
__global__ void LayerNormBackward2(const float* out_grad,
const float* vals_hat,
const float* gamma,
const float* betta,
const float* vars,
float* inp_grad,
bool invertible,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = iteration_stride >> WARP_SIZE_BITS;
__shared__ float partialSum[MAX_WARP_NUM];
out_grad += (row * row_stride);
vals_hat += (row * row_stride);
inp_grad += (row * row_stride);
float vals_arr[NORM_REG];
float vals_hat_arr[NORM_REG];
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
float gamma_reg = gamma[i * iteration_stride + id];
vals_arr[i] = out_grad[i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
vals_hat_arr[i] =
(invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) /
gamma_reg
: vals_hat[i * iteration_stride + id]);
}
if ((high_index) < row_stride) {
float gamma_reg = gamma[high_index];
vals_arr[iterations] = out_grad[high_index];
vals_arr[iterations] *= gamma_reg;
vals_hat_arr[iterations] =
(invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg
: vals_hat[high_index]);
iterations++;
}
float var_reg = vars[row];
float sum = 0;
for (int i = 0; i < iterations; i++) {
sum += vals_hat_arr[i] * vals_arr[i] *
sqrtf(var_reg); // dval_hat = gamma * (x - u) * out_grad
vals_arr[i] *= rsqrtf(var_reg); // dvar_inv = gamma * out_grad / sqrt(var)
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); }
sum = 0;
for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum);
if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum);
}
__global__ void LayerNormBackward2(const __half* out_grad,
const __half* vals_hat,
const __half* gamma,
const __half* betta,
const __half* vars,
__half* inp_grad,
bool invertible,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = iteration_stride >> WARP_SIZE_BITS;
__shared__ float partialSum[MAX_WARP_NUM];
__half2 vals_arr[NORM_REG];
float2 vals_arr_f[NORM_REG];
__half2 vals_hat_arr[NORM_REG];
__half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
const __half2* out_grad_h = reinterpret_cast<const __half2*>(out_grad);
const __half2* vals_hat_h = reinterpret_cast<const __half2*>(vals_hat);
inp_grad_h += (row * row_stride);
out_grad_h += (row * row_stride);
vals_hat_h += (row * row_stride);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
const __half2* betta_h = (invertible ? reinterpret_cast<const __half2*>(betta) : nullptr);
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
__half2 gamma_reg = gamma_h[i * iteration_stride + id];
vals_arr[i] = out_grad_h[i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
vals_hat_arr[i] =
(invertible
? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) /
gamma_reg
: vals_hat_h[i * iteration_stride + id]);
}
if ((high_index) < row_stride) {
__half2 gamma_reg = gamma_h[high_index];
vals_arr[iterations] = out_grad_h[high_index];
vals_arr[iterations] *= gamma_reg;
vals_hat_arr[iterations] =
(invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg
: vals_hat_h[high_index]);
iterations++;
}
__half var_h = vars[row];
__half2 var_reg = __halves2half2(var_h, var_h);
float sum = 0.f;
for (int i = 0; i < iterations; i++) {
__half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg));
float2 result_f = __half22float2(result_h);
sum += result_f.x;
sum += result_f.y;
vals_arr[i] *= h2rsqrt(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
__half2 sum_h = __float2half2_rn(sum);
for (int i = 0; i < iterations; i++) {
__half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg));
vals_arr_f[i] = __half22float2(vals_arr[i]);
float2 temp_f = __half22float2(temp);
vals_arr_f[i].x += temp_f.x;
vals_arr_f[i].y += temp_f.y;
}
sum = 0.f;
for (int i = 0; i < iterations; i++) {
sum += (vals_arr_f[i].x);
sum += (vals_arr_f[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr_f[i].x -= sum;
vals_arr_f[i].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[i]);
inp_grad_h[i * iteration_stride + id] = temp;
}
if ((high_index) < row_stride) {
vals_arr_f[iterations].x -= sum;
vals_arr_f[iterations].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[iterations]);
inp_grad_h[high_index] = temp;
}
}
template <>
void launch_layerNorm_backward<float>(const float* out_grad,
const float* vals_hat,
const float* vars,
const float* gamma,
float* gamma_grad,
float* betta_grad,
float* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2],
bool invertible,
const float* betta)
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<float>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
dim3 grid_dim2(batch);
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupport hidden_dim.");
dim3 block_dim2(threads);
hipLaunchKernelGGL(( LayerNormBackward2), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim);
}
template <>
void launch_layerNorm_backward<__half>(const __half* out_grad,
const __half* vals_hat,
const __half* vars,
const __half* gamma,
__half* gamma_grad,
__half* betta_grad,
__half* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2],
bool invertible,
const __half* betta)
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
//hipLaunchKernelGGL(( LayerNormBackward1<__half>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
// out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
dim3 grid_dim2(batch);
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupport hidden_dim.");
dim3 block_dim2(threads / 2);
hipLaunchKernelGGL(( LayerNormBackward2), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2);
}
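/* A note on the stream[2] signature (our reading of the launch code above):
 * LayerNormBackward1 (the gamma/betta reduction) and LayerNormBackward2
 * (the input gradient) read the same activations but write disjoint outputs,
 * so they are launched on separate streams and may overlap on the device.
 * Callers must synchronize both streams before consuming the results.
 */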
/* Backward Normalize (Input-Gradient)
* Using the means and variances from the input
* This type of backward is not invertible!
* We do the backward using the input (X)
*/
__global__ void LayerNormBackward2(const float* out_grad,
const float* X_vals,
const float* gamma,
const float* vars,
const float* means,
float* inp_grad,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id >> WARP_SIZE_BITS;
int warp_num = iteration_stride >> WARP_SIZE_BITS;
__shared__ float partialSum[MAX_WARP_NUM];
out_grad += (row * row_stride);
X_vals += (row * row_stride);
inp_grad += (row * row_stride);
float vals_arr[NORM_REG];
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
float gamma_reg = gamma[i * iteration_stride + id];
vals_arr[i] = out_grad[i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
}
if ((high_index) < row_stride) {
float gamma_reg = gamma[high_index];
vals_arr[iterations] = out_grad[high_index];
vals_arr[iterations] *= gamma_reg;
iterations++;
}
float var_reg = vars[row];
float mean_reg = means[row];
float sum = 0;
float xu[NORM_REG];
for (int i = 0; i < iterations; i++) {
xu[i] = (X_vals[i * iteration_stride + id] - mean_reg);
sum += vals_arr[i] * xu[i];
vals_arr[i] *= rsqrtf(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
for (int i = 0; i < iterations; i++) {
vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg));
}
sum = 0;
for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum);
if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum);
}
__global__ void LayerNormBackward2(const __half* out_grad,
const __half* X_vals,
const __half* gamma,
const __half* vars,
const __half* means,
__half* inp_grad,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id >> WARP_SIZE_BITS;
int warp_num = iteration_stride >> WARP_SIZE_BITS;
__shared__ float partialSum[MAX_WARP_NUM];
__half2 vals_arr[NORM_REG];
float2 vals_arr_f[NORM_REG];
__half2 xu[NORM_REG];
__half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
const __half2* out_grad_h = reinterpret_cast<const __half2*>(out_grad);
const __half2* vals_hat_h = reinterpret_cast<const __half2*>(X_vals);
inp_grad_h += (row * row_stride);
out_grad_h += (row * row_stride);
vals_hat_h += (row * row_stride);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
int high_index = iterations * iteration_stride + id;
__half mean_h = means[row];
__half2 mean_reg = __halves2half2(mean_h, mean_h);
#pragma unroll
for (int i = 0; i < iterations; i++) {
__half2 gamma_reg = gamma_h[i * iteration_stride + id];
vals_arr[i] = out_grad_h[i * iteration_stride + id];
vals_arr[i] *= gamma_reg; // out_grad * gamma
xu[i] = (vals_hat_h[i * iteration_stride + id] - mean_reg);
}
if ((high_index) < row_stride) {
__half2 gamma_reg = gamma_h[high_index];
vals_arr[iterations] = out_grad_h[high_index];
vals_arr[iterations] *= gamma_reg; // out_grad * gamma
xu[iterations] = (vals_hat_h[high_index] - mean_reg);
iterations++;
}
__half var_h = vars[row];
__half2 var_reg = __halves2half2(var_h, var_h);
float sum = 0.f;
for (int i = 0; i < iterations; i++) {
__half2 result_h = (xu[i] * vals_arr[i]);
float2 result_f = __half22float2(result_h);
sum += result_f.x;
sum += result_f.y;
vals_arr[i] *= h2rsqrt(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
__half2 sum_h = __float2half2_rn(sum);
for (int i = 0; i < iterations; i++) {
__half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg));
vals_arr_f[i] = __half22float2(vals_arr[i]);
float2 xu_grad_f = __half22float2(xu_grad);
vals_arr_f[i].x += xu_grad_f.x;
vals_arr_f[i].y += xu_grad_f.y;
}
sum = 0.f;
for (int i = 0; i < iterations; i++) {
sum += (vals_arr_f[i].x);
sum += (vals_arr_f[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr_f[i].x -= sum;
vals_arr_f[i].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[i]);
inp_grad_h[i * iteration_stride + id] = temp;
}
if ((high_index) < row_stride) {
vals_arr_f[iterations].x -= sum;
vals_arr_f[iterations].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[iterations]);
inp_grad_h[high_index] = temp;
}
}
template <>
void launch_layerNorm_backward<float>(const float* out_grad,
const float* X_data,
const float* vars,
const float* means,
const float* gamma,
float* gamma_grad,
float* betta_grad,
float* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2])
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<float>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
dim3 grid_dim2(batch);
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupport hidden_dim.");
dim3 block_dim2(threads);
hipLaunchKernelGGL(( LayerNormBackward2), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim);
}
template <>
void launch_layerNorm_backward<__half>(const __half* out_grad,
const __half* X_data,
const __half* vars,
const __half* means,
const __half* gamma,
__half* gamma_grad,
__half* betta_grad,
__half* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2])
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<__half>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
dim3 grid_dim2(batch);
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupport hidden_dim.");
dim3 block_dim2(threads / 2);
hipLaunchKernelGGL(( LayerNormBackward2), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim / 2);
}
template <typename T>
__global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
const T* __restrict__ out_grad2,
const T* __restrict__ vals_hat,
const T* __restrict__ gamma,
const T* __restrict__ betta,
T* __restrict__ gamma_grad,
T* __restrict__ betta_grad,
int rows,
int width,
bool invertible)
{
__shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
__shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
float betta_reg = (invertible ? (float)betta[idx] : 0.0f);
float gamma_reg = (float)gamma[idx];
// Loop across matrix height
float betta_tmp = 0;
float gamma_tmp = 0;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
float grad = (float)out_grad1[offset] + (float)out_grad2[offset];
float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg
: (float)vals_hat[offset]);
betta_tmp += grad;
gamma_tmp += (val * grad);
offset += y_stride;
}
betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
__syncthreads();
// Sum the shared buffer.
float s1 = betta_buffer[threadIdx.y][threadIdx.x];
float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) {
s1 += g.shfl_down(s1, i);
s2 += g.shfl_down(s2, i);
}
if (threadIdx.x == 0) {
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
betta_grad[pos] = s1;
gamma_grad[pos] = s2;
}
}
template <typename T>
__global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
const T* __restrict__ out_grad2,
const T* __restrict__ X_data,
const T* __restrict__ vars,
const T* __restrict__ means,
T* __restrict__ gamma_grad,
T* __restrict__ betta_grad,
int rows,
int width)
{
__shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
__shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
// Loop across matrix height
float betta_tmp = 0;
float gamma_tmp = 0;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
float grad = (float)out_grad1[offset] + (float)out_grad2[offset];
float val = (float)X_data[offset];
val = (val - (float)means[r]) * rsqrtf((float)vars[r]);
betta_tmp += grad;
gamma_tmp += (val * grad);
offset += y_stride;
}
betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
__syncthreads();
// Sum the shared buffer.
float s1 = betta_buffer[threadIdx.y][threadIdx.x];
float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) {
s1 += g.shfl_down(s1, i);
s2 += g.shfl_down(s2, i);
}
if (threadIdx.x == 0) {
betta_grad[pos] = s1;
gamma_grad[pos] = s2;
}
}
__global__ void LayerNormBackward2_fused_add(const float* out_grad1,
const float* out_grad2,
const float* vals_hat,
const float* gamma,
const float* betta,
const float* vars,
float* inp_grad,
bool invertible,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = iteration_stride >> WARP_SIZE_BITS;
__shared__ float partialSum[MAX_WARP_NUM];
out_grad1 += (row * row_stride);
out_grad2 += (row * row_stride);
vals_hat += (row * row_stride);
inp_grad += (row * row_stride);
float vals_arr[NORM_REG];
float vals_hat_arr[NORM_REG];
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
float gamma_reg = gamma[i * iteration_stride + id];
vals_arr[i] = out_grad1[i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
vals_hat_arr[i] =
(invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) /
gamma_reg
: vals_hat[i * iteration_stride + id]);
}
if ((high_index) < row_stride) {
float gamma_reg = gamma[high_index];
vals_arr[iterations] = out_grad1[high_index];
vals_arr[iterations] *= gamma_reg;
vals_hat_arr[iterations] =
(invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg
: vals_hat[high_index]);
iterations++;
}
float var_reg = vars[row];
float sum = 0;
for (int i = 0; i < iterations; i++) {
sum += vals_hat_arr[i] * vals_arr[i] * sqrtf(var_reg);
vals_arr[i] *= rsqrtf(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); }
sum = 0;
for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++)
inp_grad[i * iteration_stride + id] =
(vals_arr[i] - sum) + out_grad2[i * iteration_stride + id];
if ((high_index) < row_stride)
inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index];
}
__global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
const __half* out_grad2,
const __half* vals_hat,
const __half* gamma,
const __half* betta,
const __half* vars,
__half* inp_grad,
bool invertible,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = iteration_stride >> WARP_SIZE_BITS;
__shared__ float partialSum[MAX_WARP_NUM];
__half2 vals_arr[NORM_REG];
float2 vals_arr_f[NORM_REG];
__half2 vals_hat_arr[NORM_REG];
// float2 result[iterations];
__half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
const __half2* out_grad_h1 = reinterpret_cast<const __half2*>(out_grad1);
const __half2* out_grad_h2 = reinterpret_cast<const __half2*>(out_grad2);
const __half2* vals_hat_h = reinterpret_cast<const __half2*>(vals_hat);
inp_grad_h += (row * row_stride);
out_grad_h1 += (row * row_stride);
out_grad_h2 += (row * row_stride);
vals_hat_h += (row * row_stride);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
const __half2* betta_h = (invertible ? reinterpret_cast<const __half2*>(betta) : nullptr);
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
__half2 gamma_reg = gamma_h[i * iteration_stride + id];
vals_arr[i] = out_grad_h1[i * iteration_stride + id];
vals_arr[i] *= gamma_reg; // out_grad * gamma
vals_hat_arr[i] =
(invertible
? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) /
gamma_reg
: vals_hat_h[i * iteration_stride + id]);
}
if ((high_index) < row_stride) {
__half2 gamma_reg = gamma_h[high_index];
vals_arr[iterations] = out_grad_h1[high_index];
vals_arr[iterations] *= gamma_reg; // out_grad * gamma
vals_hat_arr[iterations] =
(invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg
: vals_hat_h[high_index]);
iterations++;
}
__half var_h = vars[row];
__half2 var_reg = __halves2half2(var_h, var_h);
float sum = 0.f;
for (int i = 0; i < iterations; i++) {
__half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg));
float2 result_f = __half22float2(result_h);
sum += result_f.x;
sum += result_f.y;
vals_arr[i] *= h2rsqrt(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
__half2 sum_h = __float2half2_rn(sum);
for (int i = 0; i < iterations; i++) {
__half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg));
vals_arr_f[i] = __half22float2(vals_arr[i]);
float2 temp_f = __half22float2(temp);
vals_arr_f[i].x += temp_f.x;
vals_arr_f[i].y += temp_f.y;
}
sum = 0.f;
for (int i = 0; i < iterations; i++) {
sum += (vals_arr_f[i].x);
sum += (vals_arr_f[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr_f[i].x -= sum;
vals_arr_f[i].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[i]);
inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id];
}
if ((high_index) < row_stride) {
vals_arr_f[iterations].x -= sum;
vals_arr_f[iterations].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[iterations]);
inp_grad_h[high_index] = temp + out_grad_h2[high_index];
}
}
template <>
void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
const float* out_grad2,
const float* vals_hat,
const float* vars,
const float* gamma,
float* gamma_grad,
float* betta_grad,
float* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2],
bool invertible,
const float* betta)
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<float>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
dim3 grid_dim2(batch);
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupport hidden_dim.");
dim3 block_dim2(threads);
hipLaunchKernelGGL(( LayerNormBackward2_fused_add), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim);
}
template <>
void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
const __half* out_grad2,
const __half* vals_hat,
const __half* vars,
const __half* gamma,
__half* gamma_grad,
__half* betta_grad,
__half* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2],
bool invertible,
const __half* betta)
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<__half>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
dim3 grid_dim2(batch);
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupport hidden_dim.");
dim3 block_dim2(threads / 2);
hipLaunchKernelGGL(( LayerNormBackward2_fused_add), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2);
}
/* Backward Normalize (Input-Gradient)
* Using the means and variances from the input
* This type of backward is not invertible!
* We do the backward using the input (X)
*/
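// Editor's note (summary of the math, not in the original file): with
// x_hat = (x - mu) * rsqrt(var) and dy = out_grad1 * gamma, the input gradient
// computed below reduces (using the identity mean(x_hat) = 0) to
//     dx = (dy - mean(dy) - x_hat * mean(dy * x_hat)) * rsqrt(var) + out_grad2,
// where mean(.) averages over a row of width row_stride.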
__global__ void LayerNormBackward2_fused_add(const float* out_grad1,
const float* out_grad2,
const float* X_vals,
const float* gamma,
const float* vars,
const float* means,
float* inp_grad,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = iteration_stride >> WARP_SIZE_BITS;
__shared__ float partialSum[MAX_WARP_NUM];
float vals_arr[NORM_REG];
float vals_hat_arr[NORM_REG];
out_grad1 += (row * row_stride);
out_grad2 += (row * row_stride);
X_vals += (row * row_stride);
inp_grad += (row * row_stride);
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
float gamma_reg = gamma[i * iteration_stride + id];
vals_arr[i] = out_grad1[i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
vals_hat_arr[i] = X_vals[i * iteration_stride + id];
}
if ((high_index) < row_stride) {
float gamma_reg = gamma[high_index];
vals_arr[iterations] = out_grad1[high_index];
vals_arr[iterations] *= gamma_reg;
vals_hat_arr[iterations] = X_vals[high_index];
iterations++;
}
float var_reg = vars[row];
float mean_reg = means[row];
float sum = 0;
float xu[NORM_REG];
for (int i = 0; i < iterations; i++) {
xu[i] = (vals_hat_arr[i] - mean_reg);
sum += vals_arr[i] * xu[i];
vals_arr[i] *= rsqrtf(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
for (int i = 0; i < iterations; i++) {
vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg));
}
sum = 0;
for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++)
inp_grad[i * iteration_stride + id] =
(vals_arr[i] - sum) + out_grad2[i * iteration_stride + id];
if ((high_index) < row_stride)
inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index];
}
__global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
const __half* out_grad2,
const __half* X_vals,
const __half* gamma,
const __half* vars,
const __half* means,
__half* inp_grad,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = iteration_stride >> WARP_SIZE_BITS;
__shared__ float partialSum[MAX_WARP_NUM];
__half2 vals_arr[NORM_REG];
float2 vals_arr_f[NORM_REG];
__half2 vals_hat_arr[NORM_REG];
__half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
const __half2* out_grad_h1 = reinterpret_cast<const __half2*>(out_grad1);
const __half2* out_grad_h2 = reinterpret_cast<const __half2*>(out_grad2);
const __half2* vals_hat_h = reinterpret_cast<const __half2*>(X_vals);
out_grad_h1 += (row * row_stride);
out_grad_h2 += (row * row_stride);
inp_grad_h += (row * row_stride);
vals_hat_h += (row * row_stride);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
__half2 gamma_reg = gamma_h[i * iteration_stride + id];
vals_arr[i] = out_grad_h1[i * iteration_stride + id];
vals_arr[i] *= gamma_reg; // out_grad * gamma
vals_hat_arr[i] = vals_hat_h[i * iteration_stride + id];
}
if ((high_index) < row_stride) {
__half2 gamma_reg = gamma_h[high_index];
vals_arr[iterations] = out_grad_h1[high_index];
vals_arr[iterations] *= gamma_reg; // out_grad * gamma
vals_hat_arr[iterations] = vals_hat_h[high_index];
iterations++;
}
__half mean_h = means[row];
__half var_h = vars[row];
__half2 var_reg = __halves2half2(var_h, var_h);
__half2 mean_reg = __halves2half2(mean_h, mean_h);
__half2 xu[NORM_REG];
float sum = 0.f;
for (int i = 0; i < iterations; i++) {
xu[i] = (vals_hat_arr[i] - mean_reg);
__half2 result_h = (xu[i] * vals_arr[i]);
float2 result_f = __half22float2(result_h);
sum += result_f.x;
sum += result_f.y;
vals_arr[i] *= h2rsqrt(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
__half2 sum_h = __float2half2_rn(sum);
for (int i = 0; i < iterations; i++) {
__half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg));
vals_arr_f[i] = __half22float2(vals_arr[i]);
float2 xu_grad_f = __half22float2(xu_grad);
vals_arr_f[i].x += xu_grad_f.x;
vals_arr_f[i].y += xu_grad_f.y;
}
sum = 0.f;
for (int i = 0; i < iterations; i++) {
sum += (vals_arr_f[i].x);
sum += (vals_arr_f[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr_f[i].x -= sum;
vals_arr_f[i].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[i]);
inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id];
}
if ((high_index) < row_stride) {
vals_arr_f[iterations].x -= sum;
vals_arr_f[iterations].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[iterations]);
inp_grad_h[high_index] = temp + out_grad_h2[high_index];
}
}
template <>
void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
const float* out_grad2,
const float* X_data,
const float* vars,
const float* means,
const float* gamma,
float* gamma_grad,
float* betta_grad,
float* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2])
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<float>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
dim3 grid_dim2(batch);
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupport hidden_dim.");
dim3 block_dim2(threads);
hipLaunchKernelGGL(( LayerNormBackward2_fused_add), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim);
}
template <>
void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
const __half* out_grad2,
const __half* X_data,
const __half* vars,
const __half* means,
const __half* gamma,
__half* gamma_grad,
__half* betta_grad,
__half* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2])
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<__half>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
dim3 grid_dim2(batch);
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupport hidden_dim.");
dim3 block_dim2(threads / 2);
hipLaunchKernelGGL(( LayerNormBackward2_fused_add), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim / 2);
}
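// Editor's illustrative usage sketch (not part of the original file; the wrapper name is
// hypothetical). The fused-add backward issues LayerNormBackward1 on stream[0]
// (gamma/betta gradients) and LayerNormBackward2_fused_add on stream[1] (input gradient),
// so both streams should be synchronized before the gradients are consumed.
static void example_layernorm_backward(const float* out_grad1,
                                       const float* out_grad2,
                                       const float* X,
                                       const float* vars,
                                       const float* means,
                                       const float* gamma,
                                       float* gamma_grad,
                                       float* betta_grad,
                                       float* inp_grad,
                                       int batch,
                                       int hidden_dim,
                                       hipStream_t streams[2])
{
    launch_layerNorm_backward_fused_add<float>(out_grad1, out_grad2, X, vars, means, gamma,
                                               gamma_grad, betta_grad, inp_grad, batch,
                                               hidden_dim, streams);
    // Wait for both the parameter-gradient and the input-gradient kernels.
    hipStreamSynchronize(streams[0]);
    hipStreamSynchronize(streams[1]);
}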
csrc/transformer_bak/softmax_kernels.cu
deleted 100644 → 0
#include <math.h>
#include "custom_cuda_layers.h"
#include "general_kernels.h"

namespace cg = cooperative_groups;

dim3 get_attn_softmax_grid(int batch_size, int heads, int sequence_length, int threads)
{
    int seq_length4 = sequence_length / 4;
    int block_compute_size =
        (seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) : 1);
    // Note that the Y and Z dimensions are limited to 65535, while X is basically unlimited:
    // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications
    // The batch size is typically relatively small, while the sequence length could potentially be
    // arbitrarily large. We therefore place the batch size second to avoid hitting the Y limit.
    unsigned x = heads * sequence_length / block_compute_size;
    unsigned y = batch_size;
    return {x, y};
}
// Fused attention + softmax
template <int tbSize, int blockStride, int tbSeq>
__global__ void attn_softmax(float* vals,
                             const float* attn_mask,
                             int heads,
                             int seq_length,
                             int iterations)
{
    __shared__ float partialSum[MAX_WARP_NUM];

    int warp_num = blockDim.x >> WARP_SIZE_BITS;

    int iteration_stride = blockDim.x;
    int block_width = blockStride * seq_length;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);

    int batch = blockIdx.y;
    int row = blockIdx.x;
    int max_threads_in_sequence = std::max(seq_length, tbSeq);
    int seq_lane = threadIdx.x % max_threads_in_sequence;

    int data_offset = batch * (gridDim.x * block_width) + row * block_width +
                      (threadIdx.x / max_threads_in_sequence) * seq_length;
    int mask_offset = batch * seq_length;

    int wid = threadIdx.x >> WARP_SIZE_BITS;
    int lane = threadIdx.x & 0x1f;

    float4* val_cast = reinterpret_cast<float4*>(vals);
    const float4* attn_mask_cast = reinterpret_cast<const float4*>(attn_mask);

    float4 data[MAX_THREAD_ITERATIONS];

    float max_val = minus_infinity;

    for (int i = 0; i < iterations; i++) {
        int data_id = i * iteration_stride + seq_lane;
        if (data_id < seq_length) {
            float4 mask = attn_mask_cast[mask_offset + data_id];
            data[i] = val_cast[data_offset + data_id];

            data[i].x += mask.x;
            data[i].y += mask.y;
            data[i].z += mask.z;
            data[i].w += mask.w;

            max_val = (data[i].x > max_val ? data[i].x : max_val);
            max_val = (data[i].y > max_val ? data[i].y : max_val);
            max_val = (data[i].z > max_val ? data[i].z : max_val);
            max_val = (data[i].w > max_val ? data[i].w : max_val);
        } else {
            data[i].x = minus_infinity;
            data[i].y = minus_infinity;
            data[i].z = minus_infinity;
            data[i].w = minus_infinity;
        }
    }

    for (int i = 1; i < tbSize; i *= 2) {
        auto temp = g.shfl_xor(max_val, i);
        max_val = (temp > max_val ? temp : max_val);
    }

    if (seq_length > tbSize) {
        if (lane == 0) partialSum[wid] = max_val;
        b.sync();

        if (lane < warp_num) max_val = partialSum[lane];

#ifndef __STOCHASTIC_MODE__
        b.sync();
#endif

        int iters = warp_num;
        if (seq_length < iteration_stride)
            iters = warp_num / (iteration_stride / max_threads_in_sequence);

        for (int i = 1; i < iters; i *= 2) {
            auto temp = g.shfl_xor(max_val, i);
            max_val = (temp > max_val ? temp : max_val);
        }

        max_val = g.shfl(max_val, threadIdx.x / tbSize);
    }

    float sum = 0;
    for (int i = 0; i < iterations; i++) {
        data[i].x = __expf(data[i].x - max_val);
        data[i].y = __expf(data[i].y - max_val);
        data[i].z = __expf(data[i].z - max_val);
        data[i].w = __expf(data[i].w - max_val);

        sum += (data[i].x + data[i].y + data[i].z + data[i].w);
    }

    for (int i = 1; i < tbSize; i *= 2) { sum += g.shfl_xor(sum, i); }

    if (seq_length > tbSize) {
        if (lane == 0) partialSum[wid] = sum;
        b.sync();

        if (lane < warp_num) sum = partialSum[lane];

#ifndef __STOCHASTIC_MODE__
        b.sync();
#endif

        int iters = warp_num;
        if (seq_length < iteration_stride)
            iters = warp_num / (iteration_stride / max_threads_in_sequence);

        for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); }

        sum = g.shfl(sum, threadIdx.x / tbSize);
    }

    sum += 1e-6;

    for (int i = 0; i < iterations; i++) {
        data[i].x /= sum;
        data[i].y /= sum;
        data[i].z /= sum;
        data[i].w /= sum;

        int data_id = i * iteration_stride + seq_lane;
        if (data_id < seq_length) val_cast[data_offset + data_id] = data[i];
    }
}
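// Editor's note (not in the original file): both the float and __half kernels compute the
// numerically stable softmax
//     softmax(x)_i = exp(x_i - max_j x_j) / (sum_k exp(x_k - max_j x_j) + 1e-6),
// subtracting the row maximum before exponentiation so __expf cannot overflow; the 1e-6
// term keeps the denominator nonzero for fully masked rows.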
template <int tbSize, int blockStride, int tbSeq>
__global__ void attn_softmax(__half* vals,
                             const __half* attn_mask,
                             int heads,
                             int seq_length,
                             int iterations)
{
#ifdef HALF_PRECISION_AVAILABLE
    __shared__ float partialSum[MAX_WARP_NUM];

    int warp_num = blockDim.x >> WARP_SIZE_BITS;

    int iteration_stride = blockDim.x;
    int block_width = blockStride * seq_length;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);

    int batch = blockIdx.y;
    int row = blockIdx.x;
    int max_threads_in_sequence = std::max(seq_length, tbSeq);
    int seq_lane = threadIdx.x % max_threads_in_sequence;

    int data_offset = batch * (gridDim.x * block_width) + row * block_width +
                      (threadIdx.x / max_threads_in_sequence) * seq_length;
    int mask_offset = batch * seq_length;

    int wid = threadIdx.x >> WARP_SIZE_BITS;
    int lane = threadIdx.x & 0x1f;

    float2* val_cast = reinterpret_cast<float2*>(vals);
    const float2* attn_mask_cast = reinterpret_cast<const float2*>(attn_mask);

    val_cast += data_offset;
    attn_mask_cast += mask_offset;

    float2 low_data[MAX_THREAD_ITERATIONS];
    float2 high_data[MAX_THREAD_ITERATIONS];

    float max_val = minus_infinity;

    for (int i = 0; i < iterations; i++) {
        int data_id = i * iteration_stride + seq_lane;
        if (data_id < seq_length) {
            float2 data = val_cast[data_id];
            float2 mask = attn_mask_cast[data_id];

            __half2* data_arr = reinterpret_cast<__half2*>(&data);
            __half2* mask_arr = reinterpret_cast<__half2*>(&mask);

            low_data[i] = __half22float2(data_arr[0]);
            high_data[i] = __half22float2(data_arr[1]);
            float2 low_mask = __half22float2(mask_arr[0]);
            float2 high_mask = __half22float2(mask_arr[1]);

            low_data[i].x += low_mask.x;
            low_data[i].y += low_mask.y;
            high_data[i].x += high_mask.x;
            high_data[i].y += high_mask.y;

            max_val = (low_data[i].x > max_val ? low_data[i].x : max_val);
            max_val = (low_data[i].y > max_val ? low_data[i].y : max_val);
            max_val = (high_data[i].x > max_val ? high_data[i].x : max_val);
            max_val = (high_data[i].y > max_val ? high_data[i].y : max_val);
        }
    }

    for (int i = 1; i < tbSize; i *= 2) {
        auto temp = g.shfl_xor(max_val, i);
        max_val = (temp > max_val ? temp : max_val);
    }

    if (seq_length > tbSize) {
        if (lane == 0) partialSum[wid] = max_val;
        b.sync();

        if (lane < warp_num) max_val = partialSum[lane];

#ifndef __STOCHASTIC_MODE__
        b.sync();
#endif

        int iters = warp_num;
        if (seq_length < iteration_stride)
            iters = warp_num / (iteration_stride / max_threads_in_sequence);

        for (int i = 1; i < iters; i *= 2) {
            auto temp = g.shfl_xor(max_val, i);
            max_val = (temp > max_val ? temp : max_val);
        }

        max_val = g.shfl(max_val, threadIdx.x / tbSize);
    }

    float sum = 0;
    for (int i = 0; i < iterations; i++) {
        int data_id = i * iteration_stride + seq_lane;
        if (data_id < seq_length) {
            low_data[i].x = __expf(low_data[i].x - max_val);
            low_data[i].y = __expf(low_data[i].y - max_val);
            high_data[i].x = __expf(high_data[i].x - max_val);
            high_data[i].y = __expf(high_data[i].y - max_val);

            sum += (low_data[i].x + low_data[i].y + high_data[i].x + high_data[i].y);
        }
    }

    for (int i = 1; i < tbSize; i *= 2) { sum += g.shfl_xor(sum, i); }

    if (seq_length > tbSize) {
        if (lane == 0) partialSum[wid] = sum;
        b.sync();

        if (lane < warp_num) sum = partialSum[lane];

#ifndef __STOCHASTIC_MODE__
        b.sync();
#endif

        int iters = warp_num;
        if (seq_length < iteration_stride)
            iters = warp_num / (iteration_stride / max_threads_in_sequence);

        for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); }

        sum = g.shfl(sum, threadIdx.x / tbSize);
    }

    sum += 1e-6;

    for (int i = 0; i < iterations; i++) {
        int data_id = i * iteration_stride + seq_lane;
        if (data_id < seq_length) {
            float2 result_f;
            __half2* result_h = reinterpret_cast<__half2*>(&result_f);

            low_data[i].x /= sum;
            low_data[i].y /= sum;
            high_data[i].x /= sum;
            high_data[i].y /= sum;

            result_h[0] = __float22half2_rn(low_data[i]);
            result_h[1] = __float22half2_rn(high_data[i]);

            val_cast[data_id] = result_f;
        }
    }
#endif
}
template <typename T>
void launch_attn_softmax(T*, const T*, int, int, int, cudaStream_t);
template <>
void launch_attn_softmax<float>(float* vals,
                                const float* attn_mask,
                                int batch_size,
                                int heads,
                                int sequence_length,
                                cudaStream_t stream)
{
    const int threads = 128;
    int seq_length4 = sequence_length / 4;

    dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads);

    int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;

    dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
                                            subblock_max_workload * threads)
                                         : threads);
    int iterations = (sequence_length < subblock_max_workload
                          ? (seq_length4 + threads - 1) / threads
                          : MAX_THREAD_ITERATIONS);

    if (sequence_length <= 8)
        attn_softmax<2, (threads / 2), 2>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 16)
        attn_softmax<4, (threads / 4), 4>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 32)
        attn_softmax<8, (threads / 8), 8>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 64)
        attn_softmax<16, (threads / 16), 16>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 128)
        attn_softmax<32, (threads / 32), 32>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 256)
        attn_softmax<32, (threads / 64), 64>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else {
        const int threads = 256;
        dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads);

        int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;

        dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
                                                subblock_max_workload * threads)
                                             : threads);
        iterations = (sequence_length < subblock_max_workload
                          ? (seq_length4 + threads - 1) / threads
                          : MAX_THREAD_ITERATIONS);
        if (sequence_length <= 512)
            attn_softmax<32, (threads / 128), 128><<<grid_dim, block_dim, 0, stream>>>(
                vals, attn_mask, heads, seq_length4, iterations);
        else if (sequence_length < (MAX_THREADS * MAX_THREAD_ITERATIONS * 4))
            attn_softmax<32, 1, 128><<<grid_dim, block_dim, 0, stream>>>(
                vals, attn_mask, heads, seq_length4, iterations);
        else
            throw std::runtime_error(
                "Unsupported Seq_Length! Check the restriction of the max_threads and "
                "max_thread_iterations!");
    }
}
template <>
void launch_attn_softmax<__half>(__half* vals,
                                 const __half* attn_mask,
                                 int batch_size,
                                 int heads,
                                 int sequence_length,
                                 cudaStream_t stream)
{
    const int threads = 128;
    int seq_length4 = sequence_length / 4;

    dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads);

    int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;

    dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
                                            subblock_max_workload * threads)
                                         : threads);
    int iterations = (sequence_length < subblock_max_workload
                          ? (seq_length4 + threads - 1) / threads
                          : MAX_THREAD_ITERATIONS);

    if (sequence_length <= 8)
        attn_softmax<2, (threads / 2), 2>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 16)
        attn_softmax<4, (threads / 4), 4>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 32)
        attn_softmax<8, (threads / 8), 8>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 64)
        attn_softmax<16, (threads / 16), 16>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 128)
        attn_softmax<32, (threads / 32), 32>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else if (sequence_length <= 256)
        attn_softmax<32, (threads / 64), 64>
            <<<grid_dim, block_dim, 0, stream>>>(vals, attn_mask, heads, seq_length4, iterations);
    else {
        const int threads = 256;
        dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads);

        int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;

        dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
                                                subblock_max_workload * threads)
                                             : threads);
        iterations = (sequence_length < subblock_max_workload
                          ? (seq_length4 + threads - 1) / threads
                          : MAX_THREAD_ITERATIONS);
        if (sequence_length <= 512)
            attn_softmax<32, (threads / 128), 128><<<grid_dim, block_dim, 0, stream>>>(
                vals, attn_mask, heads, seq_length4, iterations);
        else if (sequence_length < (MAX_THREADS * MAX_THREAD_ITERATIONS * 4))
            attn_softmax<32, 1, 128><<<grid_dim, block_dim, 0, stream>>>(
                vals, attn_mask, heads, seq_length4, iterations);
        else
            throw std::runtime_error(
                "Unsupported Seq_Length! Check the restriction of the max_threads and "
                "max_thread_iterations!");
    }
}
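// Editor's illustrative usage sketch (not in the original file; the wrapper name and buffer
// names are hypothetical). `scores` holds the [batch, heads, seq, seq] attention logits
// (e.g. QK^T) and `mask` a [batch, seq] additive mask; the launcher normalizes each row of
// `scores` in place on the given stream.
static void example_launch_attn_softmax(float* scores, const float* mask, cudaStream_t stream)
{
    const int batch_size = 8, heads = 16, sequence_length = 128;
    launch_attn_softmax<float>(scores, mask, batch_size, heads, sequence_length, stream);
}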
template <typename T, int tbSize, int blockStride>
__global__ void softmax_backward_kernel(T* out_grad, const T* soft_inp, int seq_length)
{
    __shared__ float partialSum[MAX_WARP_NUM];

    int warp_num = blockDim.x >> WARP_SIZE_BITS;  // warp-count = num_threads / WARP_SIZE (32)

    int iteration_stride = blockDim.x;
    int block_width = blockStride * seq_length;

    int iterations = (seq_length < (MAX_THREAD_ITERATIONS * iteration_stride)
                          ? (seq_length + iteration_stride - 1) / iteration_stride
                          : MAX_THREAD_ITERATIONS);

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;

    int wid = id >> WARP_SIZE_BITS;
    int lane = id & 0x1f;

    T val_reg[MAX_THREAD_ITERATIONS];
    T soft_reg[MAX_THREAD_ITERATIONS];
    float grad_reg = 0.0f;

#pragma unroll
    for (int i = 0; i < iterations; i++) {
        int data_id = i * iteration_stride + id;
        if (data_id < block_width) {
            val_reg[i] = out_grad[row * block_width + data_id];
            soft_reg[i] = soft_inp[row * block_width + data_id];

            grad_reg += ((float)val_reg[i] * (float)soft_reg[i]);
            // Doing this multiplication in half precision can lose
            // around 2% of accuracy, so accumulate in float.
        }
    }

    for (int i = 1; i < tbSize; i *= 2) grad_reg += g.shfl_xor(grad_reg, i);

    if (seq_length > tbSize) {
        if (lane == 0) partialSum[wid] = grad_reg;
        b.sync();

        if (lane < warp_num) grad_reg = partialSum[lane];

        int iters = warp_num;
        if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length);

        for (int i = 1; i < iters; i *= 2) grad_reg += g.shfl_xor(grad_reg, i);

        grad_reg = g.shfl(grad_reg, id / tbSize);
    }

    for (int i = 0; i < iterations; i++) {
        int data_id = i * iteration_stride + id;
        if (data_id < block_width) {
            float temp = (float)soft_reg[i] * ((float)val_reg[i] - grad_reg);
            out_grad[row * block_width + data_id] = (T)temp;
        }
    }
}
template <typename T, int ITERATIONS>
__global__ void softmax_backward_kernel_v2(T* grad /* input & output */,
                                           const T* output,
                                           int softmax_length)
{
    int batch_idx = blockIdx.x * blockDim.y + threadIdx.y;
    int offset = batch_idx * softmax_length + threadIdx.x;

    grad += offset;
    output += offset;

    T grad_reg[ITERATIONS];
    T output_reg[ITERATIONS];
    float sum = 0.0;

#pragma unroll
    for (int i = 0; i < ITERATIONS; ++i) {
        int curr_idx = threadIdx.x + i * WARP_SIZE;
        if (curr_idx < softmax_length) {
            grad_reg[i] = grad[i * WARP_SIZE];
            output_reg[i] = output[i * WARP_SIZE];
            sum += (float)grad_reg[i] * (float)output_reg[i];
        }
    }

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i);

#pragma unroll
    for (int i = 0; i < ITERATIONS; ++i) {
        int curr_idx = threadIdx.x + i * WARP_SIZE;
        if (curr_idx < softmax_length)
            grad[i * WARP_SIZE] = (float)output_reg[i] * ((float)grad_reg[i] - sum);
    }
}
template <typename T>
void launch_attn_softmax_backward_v2(T* out_grad,
                                     const T* soft_inp,
                                     int batch_size,
                                     int heads,
                                     int seq_length,
                                     cudaStream_t stream)
{
    const int warps_per_block = 4;
    dim3 grid_dim(batch_size * heads * seq_length / warps_per_block);
    dim3 block_dim(WARP_SIZE, warps_per_block);

    if (seq_length <= 32)
        softmax_backward_kernel_v2<T, 1>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else if (seq_length <= 64)
        softmax_backward_kernel_v2<T, 2>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else if (seq_length <= 128)
        softmax_backward_kernel_v2<T, 4>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else if (seq_length <= 256)
        softmax_backward_kernel_v2<T, 8>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else if (seq_length <= 384)
        softmax_backward_kernel_v2<T, 12>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else if (seq_length <= 512)
        softmax_backward_kernel_v2<T, 16>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else if (seq_length <= 768)
        softmax_backward_kernel_v2<T, 24>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else if (seq_length <= 1024)
        softmax_backward_kernel_v2<T, 32>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else if (seq_length <= 2048)
        softmax_backward_kernel_v2<T, 64>
            <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
    else
        throw std::runtime_error(
            std::string("Unsupported sequence length in softmax backward, seq_length: ") +
            std::to_string(seq_length));
}
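// Editor's note (not in the original file): each warp owns one softmax row, so ITERATIONS
// is chosen as ceil(seq_length / WARP_SIZE). For example, seq_length = 384 selects
// softmax_backward_kernel_v2<T, 12>, since 12 * 32 lanes cover the row exactly; anything
// beyond 2048 falls through to the runtime_error above.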
template void launch_attn_softmax_backward_v2<__half>(__half* out_grad,
                                                      const __half* soft_inp,
                                                      int batch_size,
                                                      int heads,
                                                      int seq_length,
                                                      cudaStream_t stream);
template void launch_attn_softmax_backward_v2<float>(float* out_grad,
                                                     const float* soft_inp,
                                                     int batch_size,
                                                     int heads,
                                                     int seq_length,
                                                     cudaStream_t stream);