Unverified Commit c78c29f9 authored by Reza Yazdani, committed by GitHub

supporting different hidden dimensions (#559)



* supporting different hidden dimensions

* add support for larger hidden dimensions (greater than 8K)

* remove empty line

* add loop unrolling factor for dropout kernels

* update the different kernels based on review feedback
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
parent 17f36f1b
@@ -22,6 +22,8 @@
 #define MAX_THREAD_ITERATIONS 8  // Maximum 8K
 #define MAX_WARP_NUM 32
+#define MAX_REGISTERS 256

 // Fused bias add with gelu activation
 template <typename T>
 void launch_bias_gelu(const T* input,
......
@@ -133,7 +133,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
     if (status != CUBLAS_STATUS_SUCCESS) {
         fprintf(stderr,
-                "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
+                "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n",
+                batch,
                 m,
                 n,
                 k,
......
This diff is collapsed.
@@ -86,40 +86,29 @@ void launch_fuse_transpose_bias_kernel<__half>(const __half* inp,
     column_sum_reduce<__half><<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
 }

-__global__ void fused_add2_kernel(float* out,
-                                  const float* inp1,
-                                  const float* inp2,
-                                  int size,
-                                  int row_stride)
+__global__ void fused_add2_kernel(const int N, float* out, const float* inp1, const float* inp2)
 {
-    int row = blockIdx.x;
-    int id = threadIdx.x;
-
     const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
     const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
     float4* out_4 = reinterpret_cast<float4*>(out);

+    CUDA_1D_KERNEL_LOOP(j, N)
+    {
         float4 val;
-        float4 inp1_reg = inp1_4[row * row_stride + id];
-        float4 inp2_reg = inp2_4[row * row_stride + id];
+        float4 inp1_reg = inp1_4[j];
+        float4 inp2_reg = inp2_4[j];

         val.x = inp1_reg.x + inp2_reg.x;
         val.y = inp1_reg.y + inp2_reg.y;
         val.z = inp1_reg.z + inp2_reg.z;
         val.w = inp1_reg.w + inp2_reg.w;

-        out_4[row * row_stride + id] = val;
+        out_4[j] = val;
+    }
 }

-__global__ void fused_add2_kernel(__half* out,
-                                  const __half* inp1,
-                                  const __half* inp2,
-                                  int size,
-                                  int row_stride)
+__global__ void fused_add2_kernel(const int N, __half* out, const __half* inp1, const __half* inp2)
 {
-    int row = blockIdx.x;
-    int id = threadIdx.x;
-
     float2 inp1_4;
     float2 inp2_4;
@@ -129,8 +118,10 @@ __global__ void fused_add2_kernel(__half* out,
     const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
     const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);

-    inp1_4 = inp1_arr[row * row_stride + id];
-    inp2_4 = inp2_arr[row * row_stride + id];
+    CUDA_1D_KERNEL_LOOP(j, N)
+    {
+        inp1_4 = inp1_arr[j];
+        inp2_4 = inp2_arr[j];

         float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
         float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
@@ -150,7 +141,8 @@ __global__ void fused_add2_kernel(__half* out,
         val_h[1] = __float22half2_rn(inp1_h_f_1);

         float2* out_4 = reinterpret_cast<float2*>(out);
-        out_4[row * row_stride + id] = val_f;
+        out_4[j] = val_f;
+    }
 }

 template <>
@@ -162,12 +154,12 @@ void launch_fused_add2<float>(float* out,
                               int hidden_dim,
                               cudaStream_t& stream)
 {
-    dim3 grid_dim(batch_size * seq_length);
-    dim3 block_dim(hidden_dim / 4);
-
-    fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(
-        out, inp1, inp2, (batch_size * seq_length * hidden_dim), hidden_dim / 4);
+    int total_count = batch_size * seq_length * hidden_dim / 4;
+    dim3 grid_dim = DS_GET_BLOCKS(total_count);  //(batch_size * seq_length);
+    dim3 block_dim = DS_CUDA_NUM_THREADS;        //(hidden_dim / 4);
+
+    fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(total_count, out, inp1, inp2);
 }

 template <>
@@ -179,12 +171,12 @@ void launch_fused_add2<__half>(__half* out,
                                int hidden_dim,
                                cudaStream_t& stream)
 {
-    dim3 grid_dim(batch_size * seq_length);
-    dim3 block_dim(hidden_dim / 4);
-
-    fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(
-        out, inp1, inp2, (batch_size * seq_length * hidden_dim), hidden_dim / 4);
+    int total_count = batch_size * seq_length * hidden_dim / 4;
+    dim3 grid_dim = DS_GET_BLOCKS(total_count);  //(batch_size * seq_length);
+    dim3 block_dim = DS_CUDA_NUM_THREADS;        //(hidden_dim / 4);
+
+    fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(total_count, out, inp1, inp2);
 }

 __global__ void fused_add3_kernel(float* out,
......
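Note on the fused_add2 change above: the kernels no longer assume one block per row with blockDim.x == hidden_dim / 4. The launcher instead flattens the tensor into total_count vector groups and each thread covers part of that range with CUDA_1D_KERNEL_LOOP, so the launch geometry is independent of the hidden size. A minimal sketch of what such grid-stride helpers look like; the real DS_GET_BLOCKS / DS_CUDA_NUM_THREADS / CUDA_1D_KERNEL_LOOP definitions live in the collapsed header diff, and the constants and macro bodies below are assumptions rather than the committed code:

#include <algorithm>

// Assumed values; the names follow the diff, the numbers are guesses.
#define DS_CUDA_NUM_THREADS 512
#define DS_MAXIMUM_NUM_BLOCKS 4096

inline int DS_GET_BLOCKS(const int N)
{
    // Enough blocks to cover N elements, capped so very large tensors simply
    // make each thread loop more times instead of exhausting the grid.
    return std::max(std::min((N + DS_CUDA_NUM_THREADS - 1) / DS_CUDA_NUM_THREADS,
                             DS_MAXIMUM_NUM_BLOCKS),
                    1);
}

// Grid-stride loop: each thread starts at its global index and advances by the
// total number of launched threads until all N elements are covered.
#define CUDA_1D_KERNEL_LOOP(i, n)                                 \
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
         i += blockDim.x * gridDim.x)

The practical effect is that the old hard limit of 1024 threads per block, hit as soon as hidden_dim / 4 exceeded 1024, no longer constrains these kernels.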
This diff is collapsed.
@@ -53,26 +53,32 @@ void Transpose<float>(const float* inp_mat, float* out_mat, int rows, int cols,
 }

 template <typename T>
-__global__ void transform_0213(T* output, const T* vals, int hidden_dim, int seq_length, int heads);
+__global__ void transform_0213(T* output,
+                               const T* vals,
+                               int hidden_dim,
+                               int seq_length,
+                               int heads,
+                               int head_ext);

 template <>
 __global__ void transform_0213<float>(float* output,
                                       const float* vals,
                                       int hidden_dim,
                                       int seq_length,
-                                      int heads)
+                                      int heads,
+                                      int head_ext)
 {
-    int d0_stride = hidden_dim * seq_length / 4;
-    int d1_stride = hidden_dim / 4;
-    int d2_stride = hidden_dim / heads / 4;
+    int d0_stride = hidden_dim * seq_length;
+    int d1_stride = hidden_dim;
+    int d2_stride = hidden_dim / heads;

     int d0_out_stride = d0_stride;
     int d1_out_stride = d2_stride;
     int d2_out_stride = d2_stride * seq_length;

     int d0 = blockIdx.x;  // Batch
-    int d1 = blockIdx.y;   // Sequence ID (0-127)
-    int d2 = threadIdx.y;  // Head (0-11)
+    int d1 = blockIdx.y / head_ext;                                       // Sequence ID (0-127)
+    int d2 = threadIdx.y + (blockIdx.y % head_ext) * (heads / head_ext);  // Head (0-11)
     int d3 = threadIdx.x;  // Values (groups of 4)

     const float4* vals_vec = reinterpret_cast<const float4*>(vals);
@@ -87,21 +93,22 @@ __global__ void transform_0213<__half>(__half* output,
                                        const __half* vals,
                                        int hidden_dim,
                                        int seq_length,
-                                       int heads)
+                                       int heads,
+                                       int head_ext)
 {
 #if __CUDA_ARCH__ >= 700

-    int d0_stride = hidden_dim * seq_length / 8;
-    int d1_stride = hidden_dim / 8;
-    int d2_stride = hidden_dim / heads / 8;
+    int d0_stride = hidden_dim * seq_length;
+    int d1_stride = hidden_dim;
+    int d2_stride = hidden_dim / heads;

     int d0_out_stride = d0_stride;
     int d1_out_stride = d2_stride;
     int d2_out_stride = d2_stride * seq_length;

     int d0 = blockIdx.x;  // Batch
-    int d1 = blockIdx.y;   // Sequence ID (0-127)
-    int d2 = threadIdx.y;  // Head (0-11)
+    int d1 = blockIdx.y / head_ext;                                       // Sequence ID (0-127)
+    int d2 = threadIdx.y + (blockIdx.y % head_ext) * (heads / head_ext);  // Head (0-11)
     int d3 = threadIdx.x;  // Values (groups of 4)

     float4 vals_arr[1];
@@ -123,10 +130,13 @@ void launch_transform_0213<float>(float* output,
                                   int heads,
                                   cudaStream_t stream)
 {
-    dim3 block_dim(hidden_dim / heads / 4, heads);
-    dim3 grid_dim(batch_size, seq_length);
+    hidden_dim >>= 2;
+    int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
+    dim3 block_dim(hidden_dim / heads, (heads / head_ext));
+    dim3 grid_dim(batch_size, (seq_length * head_ext));

     transform_0213<float>
-        <<<grid_dim, block_dim, 0, stream>>>(output, vals, hidden_dim, seq_length, heads);
+        <<<grid_dim, block_dim, 0, stream>>>(output, vals, hidden_dim, seq_length, heads, head_ext);
 }

 template <>
@@ -138,10 +148,12 @@ void launch_transform_0213<__half>(__half* output,
                                    int heads,
                                    cudaStream_t stream)
 {
-    dim3 block_dim(hidden_dim / heads / 8, heads);
-    dim3 grid_dim(batch_size, seq_length);
+    hidden_dim >>= 3;
+    int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
+    dim3 block_dim(hidden_dim / heads, (heads / head_ext));
+    dim3 grid_dim(batch_size, (seq_length * head_ext));

     transform_0213<__half>
-        <<<grid_dim, block_dim, 0, stream>>>(output, vals, hidden_dim, seq_length, heads);
+        <<<grid_dim, block_dim, 0, stream>>>(output, vals, hidden_dim, seq_length, heads, head_ext);
 }
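The head_ext factor introduced above is what lets these transpose kernels scale past one thread block per token: the launcher first converts hidden_dim into vector groups (float4, i.e. 4 floats or 8 halves) and then splits the heads across head_ext blocks whenever one token's groups would not fit into a single block. A small sketch of that dimension math, assuming MAX_THREADS is the usual 1024-thread per-block limit (its actual value is defined in the collapsed header diff):

#include <cuda_runtime.h>

constexpr int MAX_THREADS = 1024;  // assumed per-block thread limit

// Not the DeepSpeed launcher itself, just the dimension math it performs (fp32 path).
void transform_0213_launch_dims(int batch_size, int seq_length, int hidden_dim, int heads,
                                dim3& block_dim, dim3& grid_dim)
{
    hidden_dim >>= 2;                                   // count float4 groups per token
    int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;  // blocks needed to cover one token's heads
    block_dim = dim3(hidden_dim / heads, heads / head_ext);
    grid_dim = dim3(batch_size, seq_length * head_ext);  // extra y-blocks handle the remaining heads
}

// Example: hidden_dim = 8192, heads = 64
//   8192 / 4 = 2048 float4 groups -> head_ext = 2
//   block_dim = (32, 32) = 1024 threads (legal), grid_dim = (batch, 2 * seq_length)
// The pre-patch launcher would have requested a (32, 64) = 2048-thread block, which no
// CUDA device allows, so hidden sizes above 4K (fp32) / 8K (fp16) failed to launch.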
 // Bias add
@@ -151,7 +163,8 @@ __global__ void bias_add_transform_0213(T* output,
                                         const T* bias,
                                         int hidden_dim,
                                         int seq_length,
-                                        int heads);
+                                        int heads,
+                                        int head_ext);

 template <>
 __global__ void bias_add_transform_0213<float>(float* output,
@@ -159,11 +172,12 @@ __global__ void bias_add_transform_0213<float>(float* output,
                                                const float* bias,
                                                int hidden_dim,
                                                int seq_length,
-                                               int heads)
+                                               int heads,
+                                               int head_ext)
 {
-    int d0_stride = hidden_dim * seq_length / 4;
-    int d1_stride = hidden_dim / 4;
-    int d2_stride = hidden_dim / heads / 4;
+    int d0_stride = hidden_dim * seq_length;
+    int d1_stride = hidden_dim;
+    int d2_stride = hidden_dim / heads;

     int d0_out_stride = d0_stride;
     int d1_out_stride = d2_stride;
@@ -171,16 +185,16 @@ __global__ void bias_add_transform_0213<float>(float* output,
     int d0 = blockIdx.x;  // Batch
     int d1 = blockIdx.y;  // Sequence ID (0-127)
-    int cnt = blockIdx.z;  // Hidden count
-    int d2 = threadIdx.y;  // Head (0-11)
+    int cnt = blockIdx.z / head_ext;                                      // Hidden count
+    int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext);  // Head (0-11)
     int d3 = threadIdx.x;  // Values (groups of 4)

     const float4* vals_vec = reinterpret_cast<const float4*>(vals);
     const float4* bias_vec = reinterpret_cast<const float4*>(bias);
     float4* output_vec = reinterpret_cast<float4*>(output);

-    float4 inputs = vals_vec[d0 * d0_stride * gridDim.z + cnt * d1_stride +
-                             d1 * d1_stride * gridDim.z + d2 * d2_stride + d3];
+    float4 inputs = vals_vec[d0 * d0_stride * (gridDim.z / head_ext) + cnt * d1_stride +
+                             d1 * d1_stride * (gridDim.z / head_ext) + d2 * d2_stride + d3];
     float4 biases = bias_vec[cnt * d1_stride + d2 * d2_stride + d3];

     float4 outputs;
@@ -198,6 +212,65 @@ __global__ void bias_add_transform_0213<float>(float* output,
 template <>
 __global__ void bias_add_transform_0213<__half>(__half* output,
+                                                const __half* vals,
+                                                const __half* bias,
+                                                int hidden_dim,
+                                                int seq_length,
+                                                int heads,
+                                                int head_ext)
+{
+#if __CUDA_ARCH__ >= 700
+
+    int d0_stride = hidden_dim * seq_length;
+    int d1_stride = hidden_dim;
+    int d2_stride = hidden_dim / heads;
+
+    int d2_out_stride = d2_stride * seq_length;
+
+    int d0 = blockIdx.x;                                                  // Batch
+    int d1 = blockIdx.y;                                                  // Sequence ID (0-127)
+    int cnt = blockIdx.z / head_ext;                                      // Hidden count
+    int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext);  // Head (0-11)
+    int d3 = threadIdx.x;                                                 // Values (groups of 4)
+
+    float4 vals_arr;
+    float4 bias_arr;
+    float4 output_arr;
+    __half2* vals_half = reinterpret_cast<__half2*>(&vals_arr);
+    __half2* bias_half = reinterpret_cast<__half2*>(&bias_arr);
+    __half2* output_half = reinterpret_cast<__half2*>(&output_arr);
+
+    const float4* vals_vec = reinterpret_cast<const float4*>(vals);
+    const float4* bias_vec = reinterpret_cast<const float4*>(bias);
+    float4* output_vec = reinterpret_cast<float4*>(output);
+
+    vals_vec += (d0 * d0_stride * (gridDim.z / head_ext));
+    vals_vec += (d1 * d1_stride * (gridDim.z / head_ext));
+    vals_vec += (cnt * d1_stride);
+    vals_vec += (d2 * d2_stride);
+
+    bias_vec += (cnt * d1_stride);
+    bias_vec += (d2 * d2_stride);
+
+    output_vec += (cnt * d0_stride * gridDim.x);
+    output_vec += (d1 * d2_stride);
+    output_vec += (d0 * d0_stride);
+    output_vec += (d2 * d2_out_stride);
+
+    bias_arr = bias_vec[d3];
+    vals_arr = vals_vec[d3];
+
+    output_half[0] = vals_half[0] + bias_half[0];
+    output_half[1] = vals_half[1] + bias_half[1];
+    output_half[2] = vals_half[2] + bias_half[2];
+    output_half[3] = vals_half[3] + bias_half[3];
+
+    output_vec[d3] = output_arr;
+
+#endif
+}
+
+__global__ void bias_add_transform_0213_v2(__half* output,
                                            const __half* vals,
                                            const __half* bias,
                                            int hidden_dim,
@@ -207,9 +280,9 @@ __global__ void bias_add_transform_0213<__half>(__half* output,
 #if __CUDA_ARCH__ >= 700
     __shared__ float4 in_data[3072];

-    int d0_stride = hidden_dim * seq_length / 8;
-    int d1_stride = hidden_dim / 8;
-    int d2_stride = hidden_dim / heads / 8;
+    int d0_stride = hidden_dim * seq_length;
+    int d1_stride = hidden_dim;
+    int d2_stride = hidden_dim / heads;

     int iteration_stride = d1_stride * blockDim.z;  // Hidden * 3 / 8
     int batch_stride = d0_stride * blockDim.z;      // Hidden * S * 3 / 8
@@ -237,6 +310,8 @@ __global__ void bias_add_transform_0213<__half>(__half* output,
     int iter_index = cnt * d1_stride + d2 * d2_stride + d3;
     int input_offset = d0 * batch_stride + d1 * (iteration_stride << 1);
     bias_arr[0] = bias_vec[iter_index];

+#pragma unroll
     for (int iter = 0; iter < 2; iter++) {
         int iter_id = iter * iteration_stride + iter_index;
         vals_arr[0] = vals_vec[input_offset + iter_id];
@@ -255,6 +330,8 @@ __global__ void bias_add_transform_0213<__half>(__half* output,
     int head_count = (d2 >> 1) + cnt * (blockDim.y >> 1);
     int out_index = d0 * d0_out_stride + d1 * (d1_out_stride << 1) + d3 + (d2 % 2) * d2_stride;

+#pragma unroll
     for (int iter = 0; iter < 2; iter++) {
         int iter_row = (iter * iteration_stride) + head_count;
         int iter_offset =
@@ -277,10 +354,14 @@ void launch_bias_add_transform_0213<float>(float* output,
                                            cudaStream_t stream,
                                            int trans_count)
 {
-    dim3 block_dim(hidden_dim / heads / 4, heads);
-    dim3 grid_dim(batch_size, seq_length, trans_count);
-    bias_add_transform_0213<float>
-        <<<grid_dim, block_dim, 0, stream>>>(output, vals, bias, hidden_dim, seq_length, heads);
+    hidden_dim >>= 2;
+    int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
+
+    dim3 block_dim(hidden_dim / heads, (heads / head_ext));
+    dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext));
+
+    bias_add_transform_0213<float><<<grid_dim, block_dim, 0, stream>>>(
+        output, vals, bias, hidden_dim, seq_length, heads, head_ext);
 }

 template <>
@@ -294,32 +375,47 @@ void launch_bias_add_transform_0213<__half>(__half* output,
                                             cudaStream_t stream,
                                             int trans_count)
 {
-    dim3 block_dim(hidden_dim / heads / 8, heads, trans_count);
-    dim3 grid_dim(batch_size, seq_length / 2);
-    bias_add_transform_0213<__half>
-        <<<grid_dim, block_dim, 0, stream>>>(output, vals, bias, hidden_dim, seq_length, heads);
+    hidden_dim >>= 3;
+    if (hidden_dim > 128 || hidden_dim < 16) {
+        int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
+        dim3 block_dim(hidden_dim / heads, (heads / head_ext));
+        dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext));
+        bias_add_transform_0213<__half><<<grid_dim, block_dim, 0, stream>>>(
+            output, vals, bias, hidden_dim, seq_length, heads, head_ext);
+    } else {
+        dim3 block_dim(hidden_dim / heads, heads, trans_count);
+        dim3 grid_dim(batch_size, seq_length / 2);
+        bias_add_transform_0213_v2<<<grid_dim, block_dim, 0, stream>>>(
+            output, vals, bias, hidden_dim, seq_length, heads);
+    }
 }
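For the fp16 path the original shared-memory kernel is kept under a _v2 name, and the launcher picks between the two implementations based on the vectorized hidden size. A rough reading of that dispatch; the helper name below is hypothetical, but the bounds are the ones in the diff:

// hidden_dim here is already in float4 units, i.e. the real hidden size divided by 8.
static bool use_v2_shared_memory_path(int hidden_dim)
{
    // The v2 kernel keeps its working tile in a fixed __shared__ float4 in_data[3072]
    // buffer (3072 * 16 B = 48 KB, the default per-block shared-memory limit) and
    // launches (hidden_dim / heads) x heads x trans_count threads per block, so it is
    // only used for moderate hidden sizes (real hidden between 128 and 1024).
    return hidden_dim >= 16 && hidden_dim <= 128;
}
// Everything outside that window, including the new >8K hidden sizes, takes the
// head_ext kernel above, which needs no shared memory and spreads the heads over
// extra blockIdx.z slices instead.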
 template <typename T>
-__global__ void transform4d_0213(T* out, const T* in, int heads, int seq_length, int hidden_dim);
+__global__ void transform4d_0213(T* out,
+                                 const T* in,
+                                 int heads,
+                                 int seq_length,
+                                 int hidden_dim,
+                                 int head_ext);

 template <>
 __global__ void transform4d_0213<float>(float* out,
                                         const float* in,
                                         int heads,
                                         int seq_length,
-                                        int hidden_dim)
+                                        int hidden_dim,
+                                        int head_ext)
 {
-    int d0_stride = hidden_dim * seq_length / 4;
+    int d0_stride = hidden_dim * seq_length;
     int d1_stride = d0_stride / heads;
-    int d2_stride = hidden_dim / heads / 4;
+    int d2_stride = hidden_dim / heads;

     int d0_out_stride = d0_stride;
     int d1_out_stride = d2_stride;
-    int d2_out_stride = hidden_dim / 4;
+    int d2_out_stride = hidden_dim;

     int d0 = blockIdx.x;  // Batch
-    int d1 = blockIdx.y / ((seq_length + blockDim.y - 1) / blockDim.y);  // Head
+    int d1 = blockIdx.y / ((seq_length - 1) / blockDim.y + 1);  // Head
     int d2 = (threadIdx.y + blockDim.y * blockIdx.y) % seq_length;
     int cnt = blockIdx.z;
     int d3 = threadIdx.x;  // Values (groups of 8)
@@ -337,6 +433,43 @@ __global__ void transform4d_0213<float>(float* out,

 template <>
 __global__ void transform4d_0213<__half>(__half* out,
+                                         const __half* in,
+                                         int heads,
+                                         int seq_length,
+                                         int hidden_dim,
+                                         int head_ext)
+{
+#if __CUDA_ARCH__ >= 700
+
+    int d0_stride = hidden_dim * (seq_length / head_ext);
+    int d1_stride = hidden_dim;
+    int d2_stride = hidden_dim / heads;
+
+    int d0 = blockIdx.x;                                                  // Batch
+    int d1 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext);  // Head
+    int d2 = blockIdx.z / head_ext;                                       // Sequence
+    int cnt = blockIdx.y;                                                 // Hidden count
+    int d3 = threadIdx.x;                                                 // Values (groups of 8)
+
+    const float4* in_vec = reinterpret_cast<const float4*>(in);
+    float4* out_vec = reinterpret_cast<float4*>(out);
+
+    in_vec += (cnt * d0_stride * gridDim.x);
+    in_vec += (d0 * d0_stride);
+    in_vec += (d2 * d2_stride);
+    in_vec += (d1 * d2_stride * seq_length);
+
+    out_vec += (cnt * d1_stride);
+    out_vec += (d1 * d2_stride);
+    out_vec += (d0 * d0_stride * gridDim.y);
+    out_vec += (d2 * d1_stride * gridDim.y);
+
+    out_vec[d3] = in_vec[d3];
+
+#endif
+}
+
+__global__ void transform4d_0213_v2(__half* out,
                                     const __half* in,
                                     int heads,
                                     int seq_length,
@@ -345,9 +478,9 @@ __global__ void transform4d_0213<__half>(__half* out,
 #if __CUDA_ARCH__ >= 700
     __shared__ float4 in_data[3072];

-    int d0_stride = hidden_dim * seq_length / 8;
-    int d1_stride = hidden_dim / 8;
-    int d2_stride = hidden_dim / heads / 8;
+    int d0_stride = hidden_dim * seq_length;
+    int d1_stride = hidden_dim;
+    int d2_stride = hidden_dim / heads;

     int d0 = blockIdx.x;   // Batch
     int d1 = threadIdx.y;  // Head
@@ -358,11 +491,12 @@ __global__ void transform4d_0213<__half>(__half* out,
     const float4* in_vec = reinterpret_cast<const float4*>(in);
     float4* out_vec = reinterpret_cast<float4*>(out);

-    int input_offset = d0 * d0_stride + d2 * (d2_stride << 1) + d3 + d1 % 2 * d2_stride;
+    int input_offset = d0 * d0_stride + d2 * (d2_stride << 1) + d3 + (d1 % 2) * d2_stride;
     int head_count = (d1 >> 1) + cnt * (blockDim.y >> 1);
     int iteration_stride = blockDim.z * (blockDim.y >> 1);
     int matrix_stride = (d0_stride * gridDim.x);

+#pragma unroll
     for (int iter = 0; iter < 2; iter++) {
         int iter_row = iter * iteration_stride + head_count;
         int iter_offset = (iter_row % blockDim.y) * d2_stride;
@@ -377,6 +511,7 @@ __global__ void transform4d_0213<__half>(__half* out,
     int iter_index = cnt * d1_stride + d1 * d2_stride + d3;
     int output_offset = d0 * d0_stride * blockDim.z + d2 * (iteration_stride << 1);

+#pragma unroll
     for (int iter = 0; iter < 2; iter++) {
         int iter_id = iter * iteration_stride + iter_index;
         out_vec[output_offset + iter_id] = in_data[iter_id];
@@ -395,10 +530,11 @@ void launch_transform4d_0213<float>(float* out,
                                     cudaStream_t stream,
                                     int trans_count)
 {
-    dim3 grid_dims(batch_size, heads * ((seq_length + 7) / 8), trans_count);
-    dim3 block_dims(hidden_dim / heads / 4, 8);
+    hidden_dim >>= 2;
+    dim3 grid_dims(batch_size, heads * ((seq_length - 1) / 8 + 1), trans_count);
+    dim3 block_dims(hidden_dim / heads, 8);
     transform4d_0213<float>
-        <<<grid_dims, block_dims, 0, stream>>>(out, in, heads, seq_length, hidden_dim);
+        <<<grid_dims, block_dims, 0, stream>>>(out, in, heads, seq_length, hidden_dim, 1);
 }

 template <>
@@ -411,8 +547,17 @@ void launch_transform4d_0213<__half>(__half* out,
                                      cudaStream_t stream,
                                      int trans_count)
 {
-    dim3 grid_dims(batch_size, seq_length / 2);
-    dim3 block_dims(hidden_dim / heads / 8, heads, trans_count);
-    transform4d_0213<__half>
-        <<<grid_dims, block_dims, 0, stream>>>(out, in, heads, seq_length, hidden_dim);
+    hidden_dim >>= 3;
+    if (hidden_dim > 128 || hidden_dim < 16) {
+        int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
+        dim3 grid_dims(batch_size, trans_count, (seq_length * head_ext));
+        dim3 block_dims(hidden_dim / heads, (heads / head_ext));
+        transform4d_0213<__half><<<grid_dims, block_dims, 0, stream>>>(
+            out, in, heads, seq_length, hidden_dim, head_ext);
+    } else {
+        dim3 grid_dims(batch_size, seq_length / 2);
+        dim3 block_dims(hidden_dim / heads, heads, trans_count);
+        transform4d_0213_v2<<<grid_dims, block_dims, 0, stream>>>(
+            out, in, heads, seq_length, hidden_dim);
+    }
 }
......
@@ -261,6 +261,8 @@ def run_backward(ds_config, atol=1e-2, verbose=False):
                              (3,1024,120,16,24,True,True, 0.05),
                              (3,1024,56,16,24,False,False, 0.1),
                              (3,1024,56,16,24,False,True, 0.2),
+                             (3,128,56,2,24,False,False, 0.1),
+                             (3,128,56,2,24,False,True, 0.2),
                          ]) # yapf: disable
 def test_backward(batch_size,
                   hidden_size,
......
@@ -213,6 +213,8 @@ def run_forward(ds_config, seq_len, atol=1e-2, verbose=False, test_bsz=None):
 # FP16 test cases can only run on the devices support FP16.
 @pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16',
                          [
+                             (8,256,128,4,3,True,False),
+                             (8,256,128,4,3,True,True),
                              (64,1024,128,16,3,True,False),
                              (64,1024,128,16,3,True,True),
                              (8,1024,384,16,3,True,False),
@@ -236,6 +238,10 @@ def run_forward(ds_config, seq_len, atol=1e-2, verbose=False, test_bsz=None):
                              (8,2048,128,32,3,False,True),
                              (8,2560,128,40,3,False,False),
                              (8,2560,128,40,3,False,True),
+                             (8,128,128,2,3,True,False),
+                             (8,128,128,2,3,True,True),
+                             (8,4096,128,64,3,True,True),
+                             (8,8192,128,64,3,False,True),
                          ]) # yapf: disable
 def test_forward(batch_size,
                  hidden_size,
......