Unverified commit c78c29f9 authored by Reza Yazdani, committed by GitHub

supporting different hidden dimensions (#559)



* supporting different hidden dimensions

* add support for larger hidden dimensions (greater than 8K)

* remove empty line

* add loop unrolling factor for dropout kernels

* update different kernels based on the reviews

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
parent 17f36f1b
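The new launch scheme in the diffs below leans on three helpers that this commit does not define: CUDA_1D_KERNEL_LOOP, DS_GET_BLOCKS, and DS_CUDA_NUM_THREADS. As a minimal sketch, these presumably follow the conventional grid-stride-loop pattern; the exact constants here are assumptions, not values taken from this commit:

// Sketch of the assumed helper definitions (not part of this diff).
// Each thread starts at its global index and strides by the total number
// of launched threads, so a grid of any size covers all N elements.
#define CUDA_1D_KERNEL_LOOP(i, n)                                  \
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n);   \
         i += blockDim.x * gridDim.x)

#define DS_CUDA_NUM_THREADS 512     // assumed threads per block
#define DS_MAXIMUM_NUM_BLOCKS 4096  // assumed grid-size cap

// Enough blocks to cover N elements, clamped to the cap; at least 1.
inline int DS_GET_BLOCKS(const int N)
{
    int blocks = (N + DS_CUDA_NUM_THREADS - 1) / DS_CUDA_NUM_THREADS;
    if (blocks > DS_MAXIMUM_NUM_BLOCKS) blocks = DS_MAXIMUM_NUM_BLOCKS;
    return blocks > 0 ? blocks : 1;
}

Because the loop strides, correctness no longer depends on launching one thread per element; the grid size only affects how many elements each thread repeats over.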
@@ -22,6 +22,8 @@
 #define MAX_THREAD_ITERATIONS 8  // Maximum 8K
 #define MAX_WARP_NUM 32
 
+#define MAX_REGISTERS 256
+
 // Fused bias add with gelu activation
 template <typename T>
 void launch_bias_gelu(const T* input,
...
@@ -133,7 +133,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
     if (status != CUBLAS_STATUS_SUCCESS) {
         fprintf(stderr,
-                "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
+                "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n",
+                batch,
                 m,
                 n,
                 k,
...
This diff is collapsed.
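The dropout-kernel changes (the loop-unrolling factor mentioned in the commit message) sit in a collapsed diff and are not visible on this page. Purely as a hypothetical illustration of the pattern, with every name below invented for the sketch, an unrolled grid-stride dropout kernel might look like this; the unroll factor is kept small so per-thread register use stays within a budget in the spirit of the new MAX_REGISTERS define:

#include <curand_kernel.h>
#include <cstdint>

constexpr int kUnroll = 4;  // illustrative unroll factor, bounded by register budget

// Hypothetical sketch, not the code from the collapsed diff. Each thread
// handles kUnroll consecutive elements per grid-stride step, which matches
// curand_uniform4() producing four random numbers per call.
__global__ void dropout_sketch(const int N,
                               const float ratio,
                               const float* in,
                               float* out,
                               uint8_t* mask,
                               const uint64_t seed)
{
    const float scale = 1.0f / (1.0f - ratio);
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int stride = blockDim.x * gridDim.x;

    curandStatePhilox4_32_10_t state;
    curand_init(seed, idx, 0, &state);

    for (int i = idx * kUnroll; i < N; i += stride * kUnroll) {
        const float4 rand = curand_uniform4(&state);
        const float r[kUnroll] = {rand.x, rand.y, rand.z, rand.w};
#pragma unroll
        for (int k = 0; k < kUnroll; ++k) {
            const int j = i + k;
            if (j < N) {
                const uint8_t keep = (uint8_t)(r[k] > ratio);  // keep with prob. 1 - ratio
                mask[j] = keep;
                out[j] = in[j] * scale * keep;
            }
        }
    }
}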
@@ -86,40 +86,29 @@ void launch_fuse_transpose_bias_kernel<__half>(const __half* inp,
     column_sum_reduce<__half><<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
 }
 
-__global__ void fused_add2_kernel(float* out,
-                                  const float* inp1,
-                                  const float* inp2,
-                                  int size,
-                                  int row_stride)
+__global__ void fused_add2_kernel(const int N, float* out, const float* inp1, const float* inp2)
 {
-    int row = blockIdx.x;
-    int id = threadIdx.x;
-
     const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
     const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
     float4* out_4 = reinterpret_cast<float4*>(out);
 
-    float4 val;
-    float4 inp1_reg = inp1_4[row * row_stride + id];
-    float4 inp2_reg = inp2_4[row * row_stride + id];
-
-    val.x = inp1_reg.x + inp2_reg.x;
-    val.y = inp1_reg.y + inp2_reg.y;
-    val.z = inp1_reg.z + inp2_reg.z;
-    val.w = inp1_reg.w + inp2_reg.w;
-
-    out_4[row * row_stride + id] = val;
+    CUDA_1D_KERNEL_LOOP(j, N)
+    {
+        float4 val;
+        float4 inp1_reg = inp1_4[j];
+        float4 inp2_reg = inp2_4[j];
+
+        val.x = inp1_reg.x + inp2_reg.x;
+        val.y = inp1_reg.y + inp2_reg.y;
+        val.z = inp1_reg.z + inp2_reg.z;
+        val.w = inp1_reg.w + inp2_reg.w;
+
+        out_4[j] = val;
+    }
 }
 
-__global__ void fused_add2_kernel(__half* out,
-                                  const __half* inp1,
-                                  const __half* inp2,
-                                  int size,
-                                  int row_stride)
+__global__ void fused_add2_kernel(const int N, __half* out, const __half* inp1, const __half* inp2)
 {
-    int row = blockIdx.x;
-    int id = threadIdx.x;
-
     float2 inp1_4;
     float2 inp2_4;
@@ -129,28 +118,31 @@ __global__ void fused_add2_kernel(__half* out,
     const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
     const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
 
-    inp1_4 = inp1_arr[row * row_stride + id];
-    inp2_4 = inp2_arr[row * row_stride + id];
-
-    float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
-    float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
-    float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
-    float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
-
-    inp1_h_f_0.x += inp2_h_f_0.x;
-    inp1_h_f_0.y += inp2_h_f_0.y;
-    inp1_h_f_1.x += inp2_h_f_1.x;
-    inp1_h_f_1.y += inp2_h_f_1.y;
-
-    float2 val_f;
-    __half2* val_h = reinterpret_cast<__half2*>(&val_f);
-    val_h[0] = __float22half2_rn(inp1_h_f_0);
-    val_h[1] = __float22half2_rn(inp1_h_f_1);
-
-    float2* out_4 = reinterpret_cast<float2*>(out);
-    out_4[row * row_stride + id] = val_f;
+    CUDA_1D_KERNEL_LOOP(j, N)
+    {
+        inp1_4 = inp1_arr[j];
+        inp2_4 = inp2_arr[j];
+
+        float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
+        float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
+        float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
+        float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
+
+        inp1_h_f_0.x += inp2_h_f_0.x;
+        inp1_h_f_0.y += inp2_h_f_0.y;
+        inp1_h_f_1.x += inp2_h_f_1.x;
+        inp1_h_f_1.y += inp2_h_f_1.y;
+
+        float2 val_f;
+        __half2* val_h = reinterpret_cast<__half2*>(&val_f);
+        val_h[0] = __float22half2_rn(inp1_h_f_0);
+        val_h[1] = __float22half2_rn(inp1_h_f_1);
+
+        float2* out_4 = reinterpret_cast<float2*>(out);
+        out_4[j] = val_f;
+    }
 }
 
 template <>
@@ -162,12 +154,12 @@ void launch_fused_add2<float>(float* out,
                               int hidden_dim,
                               cudaStream_t& stream)
 {
-    dim3 grid_dim(batch_size * seq_length);
-    dim3 block_dim(hidden_dim / 4);
+    int total_count = batch_size * seq_length * hidden_dim / 4;
+    dim3 grid_dim = DS_GET_BLOCKS(total_count);  //(batch_size * seq_length);
+    dim3 block_dim = DS_CUDA_NUM_THREADS;        //(hidden_dim / 4);
 
-    fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(
-        out, inp1, inp2, (batch_size * seq_length * hidden_dim), hidden_dim / 4);
+    fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(total_count, out, inp1, inp2);
 }
 
 template <>
@@ -179,12 +171,12 @@ void launch_fused_add2<__half>(__half* out,
                               int hidden_dim,
                               cudaStream_t& stream)
 {
-    dim3 grid_dim(batch_size * seq_length);
-    dim3 block_dim(hidden_dim / 4);
+    int total_count = batch_size * seq_length * hidden_dim / 4;
+    dim3 grid_dim = DS_GET_BLOCKS(total_count);  //(batch_size * seq_length);
+    dim3 block_dim = DS_CUDA_NUM_THREADS;        //(hidden_dim / 4);
 
-    fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(
-        out, inp1, inp2, (batch_size * seq_length * hidden_dim), hidden_dim / 4);
+    fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(total_count, out, inp1, inp2);
 }
 
 __global__ void fused_add3_kernel(float* out,
...
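The launcher change above is the heart of the fix for this file: the old configuration used hidden_dim / 4 threads per block, which exceeds CUDA's 1024-threads-per-block limit as soon as hidden_dim passes 4096, so fused_add2 could not even launch for larger models. A host-only sketch of the arithmetic (the constants mirror the assumed helper definitions near the top of this page):

#include <algorithm>
#include <cstdio>

constexpr int kThreads = 512;     // assumed DS_CUDA_NUM_THREADS
constexpr int kMaxBlocks = 4096;  // assumed DS_MAXIMUM_NUM_BLOCKS

// Same clamped-block computation as the assumed DS_GET_BLOCKS sketch.
int get_blocks(int n)
{
    return std::max(std::min((n + kThreads - 1) / kThreads, kMaxBlocks), 1);
}

int main()
{
    const int batch_size = 8, seq_length = 128;
    for (int hidden_dim : {1024, 4096, 8192, 16384}) {
        // float4 vectorization: each thread handles four floats at once.
        const int total_count = batch_size * seq_length * hidden_dim / 4;
        std::printf("hidden %5d | old block_dim %5d (limit 1024) | new grid %4d x %d\n",
                    hidden_dim, hidden_dim / 4, get_blocks(total_count), kThreads);
    }
    return 0;
}

For hidden_dim = 8192 the old scheme would request an illegal 2048-thread block; the new scheme launches 4096 blocks of 512 threads, and CUDA_1D_KERNEL_LOOP strides over anything a capped grid does not cover in one pass.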
This diff is collapsed.
This diff is collapsed.
@@ -261,6 +261,8 @@ def run_backward(ds_config, atol=1e-2, verbose=False):
                              (3,1024,120,16,24,True,True, 0.05),
                              (3,1024,56,16,24,False,False, 0.1),
                              (3,1024,56,16,24,False,True, 0.2),
+                             (3,128,56,2,24,False,False, 0.1),
+                             (3,128,56,2,24,False,True, 0.2),
                          ]) # yapf: disable
 def test_backward(batch_size,
                   hidden_size,
...
@@ -213,6 +213,8 @@ def run_forward(ds_config, seq_len, atol=1e-2, verbose=False, test_bsz=None):
 # FP16 test cases can only run on the devices support FP16.
 @pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16',
                          [
+                             (8,256,128,4,3,True,False),
+                             (8,256,128,4,3,True,True),
                              (64,1024,128,16,3,True,False),
                              (64,1024,128,16,3,True,True),
                              (8,1024,384,16,3,True,False),
@@ -236,6 +238,10 @@ def run_forward(ds_config, seq_len, atol=1e-2, verbose=False, test_bsz=None):
                              (8,2048,128,32,3,False,True),
                              (8,2560,128,40,3,False,False),
                              (8,2560,128,40,3,False,True),
+                             (8,128,128,2,3,True,False),
+                             (8,128,128,2,3,True,True),
+                             (8,4096,128,64,3,True,True),
+                             (8,8192,128,64,3,False,True),
                          ]) # yapf: disable
 def test_forward(batch_size,
                  hidden_size,
...
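Taken together with the kernel changes, the new forward cases bracket the previously tested range: hidden sizes 128 and 256 exercise the small end, while 4096 and 8192 sit exactly at and beyond the point where the old hidden_dim / 4 block size of the fused-add launcher hits the 1024-thread limit. The backward additions above do the same for training at hidden size 128.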