Unverified Commit c78c29f9 authored by Reza Yazdani, committed by GitHub

supporting different hidden dimensions (#559)



* supporting different hidden dimensions

* add support for larger hidden dimensions (greater than 8K)

* remove empty line

* add loop unrolling factor for dropout kernels

* update different kernels based on the reviews
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
parent 17f36f1b
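
The common thread in the hunks below is the launch strategy. Previously, kernels such as fused_add2 used one thread block per (batch, sequence) row with the block size tied to hidden_dim / 4, so the 1024-threads-per-block hardware limit (stretched to roughly 8K elements via MAX_THREAD_ITERATIONS) capped the supported hidden dimension. The updated kernels flatten the whole tensor and walk it with a grid-stride loop, so any hidden dimension divisible by 4 works. The sketch below only illustrates that arithmetic and is not code from this PR; the 1024-thread figure is the standard CUDA per-block maximum.

// Illustration only (not part of the PR): why tying block size to hidden_dim
// capped the model width under the old launch scheme.
#include <cstdio>

int main()
{
    const int max_threads_per_block = 1024;  // CUDA per-block thread limit
    const int dims[] = {1024, 4096, 8192, 12288};
    for (int hidden_dim : dims) {
        // Old fused_add2 launch: one row per block, one float4 per thread.
        int old_block_dim = hidden_dim / 4;
        std::printf("hidden_dim=%5d -> block_dim=%4d (%s)\n",
                    hidden_dim,
                    old_block_dim,
                    old_block_dim <= max_threads_per_block ? "fits in a block"
                                                           : "exceeds the 1024-thread limit");
    }
    return 0;
}
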
@@ -22,6 +22,8 @@
#define MAX_THREAD_ITERATIONS 8 // Maximum 8K
#define MAX_WARP_NUM 32
#define MAX_REGISTERS 256
// Fused bias add with gelu activation
template <typename T>
void launch_bias_gelu(const T* input,
@@ -133,7 +133,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
     if (status != CUBLAS_STATUS_SUCCESS) {
         fprintf(stderr,
-                "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
+                "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n",
+                batch,
                 m,
                 n,
                 k,
@@ -86,40 +86,29 @@ void launch_fuse_transpose_bias_kernel<__half>(const __half* inp,
     column_sum_reduce<__half><<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
 }
 
-__global__ void fused_add2_kernel(float* out,
-                                  const float* inp1,
-                                  const float* inp2,
-                                  int size,
-                                  int row_stride)
+__global__ void fused_add2_kernel(const int N, float* out, const float* inp1, const float* inp2)
 {
-    int row = blockIdx.x;
-    int id = threadIdx.x;
     const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
     const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
     float4* out_4 = reinterpret_cast<float4*>(out);
 
-    float4 val;
-    float4 inp1_reg = inp1_4[row * row_stride + id];
-    float4 inp2_reg = inp2_4[row * row_stride + id];
+    CUDA_1D_KERNEL_LOOP(j, N)
+    {
+        float4 val;
+        float4 inp1_reg = inp1_4[j];
+        float4 inp2_reg = inp2_4[j];
 
-    val.x = inp1_reg.x + inp2_reg.x;
-    val.y = inp1_reg.y + inp2_reg.y;
-    val.z = inp1_reg.z + inp2_reg.z;
-    val.w = inp1_reg.w + inp2_reg.w;
+        val.x = inp1_reg.x + inp2_reg.x;
+        val.y = inp1_reg.y + inp2_reg.y;
+        val.z = inp1_reg.z + inp2_reg.z;
+        val.w = inp1_reg.w + inp2_reg.w;
 
-    out_4[row * row_stride + id] = val;
+        out_4[j] = val;
+    }
 }
 
-__global__ void fused_add2_kernel(__half* out,
-                                  const __half* inp1,
-                                  const __half* inp2,
-                                  int size,
-                                  int row_stride)
+__global__ void fused_add2_kernel(const int N, __half* out, const __half* inp1, const __half* inp2)
 {
-    int row = blockIdx.x;
-    int id = threadIdx.x;
     float2 inp1_4;
     float2 inp2_4;
@@ -129,28 +118,31 @@ __global__ void fused_add2_kernel(__half* out,
     const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
     const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
 
-    inp1_4 = inp1_arr[row * row_stride + id];
-    inp2_4 = inp2_arr[row * row_stride + id];
+    CUDA_1D_KERNEL_LOOP(j, N)
+    {
+        inp1_4 = inp1_arr[j];
+        inp2_4 = inp2_arr[j];
 
-    float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
-    float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
+        float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
+        float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
 
-    float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
-    float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
+        float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
+        float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
 
-    inp1_h_f_0.x += inp2_h_f_0.x;
-    inp1_h_f_0.y += inp2_h_f_0.y;
-    inp1_h_f_1.x += inp2_h_f_1.x;
-    inp1_h_f_1.y += inp2_h_f_1.y;
+        inp1_h_f_0.x += inp2_h_f_0.x;
+        inp1_h_f_0.y += inp2_h_f_0.y;
+        inp1_h_f_1.x += inp2_h_f_1.x;
+        inp1_h_f_1.y += inp2_h_f_1.y;
 
-    float2 val_f;
-    __half2* val_h = reinterpret_cast<__half2*>(&val_f);
+        float2 val_f;
+        __half2* val_h = reinterpret_cast<__half2*>(&val_f);
 
-    val_h[0] = __float22half2_rn(inp1_h_f_0);
-    val_h[1] = __float22half2_rn(inp1_h_f_1);
+        val_h[0] = __float22half2_rn(inp1_h_f_0);
+        val_h[1] = __float22half2_rn(inp1_h_f_1);
 
-    float2* out_4 = reinterpret_cast<float2*>(out);
-    out_4[row * row_stride + id] = val_f;
+        float2* out_4 = reinterpret_cast<float2*>(out);
+        out_4[j] = val_f;
+    }
 }
 
 template <>
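
Both kernels above now rely on CUDA_1D_KERNEL_LOOP, a grid-stride loop helper defined elsewhere in the repo and not shown in this diff. A minimal sketch of how such a macro is commonly written (an assumption about its shape, not the exact DeepSpeed definition):

// Sketch of the grid-stride loop idiom; the real CUDA_1D_KERNEL_LOOP macro is
// defined in a header outside this diff, so treat this as an assumption.
#include <cuda_runtime.h>

#define CUDA_1D_KERNEL_LOOP(i, n)                                  \
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n);   \
         i += blockDim.x * gridDim.x)

// Toy kernel: each thread starts at its global index and strides by the total
// number of launched threads, so a fixed-size grid covers any N regardless of
// the hidden dimension.
__global__ void add_one_kernel(const int N, float* data)
{
    CUDA_1D_KERNEL_LOOP(j, N) { data[j] += 1.0f; }
}

Because the loop covers any N with whatever grid was launched, the per-row indexing (row * row_stride + id) and the size / row_stride arguments become unnecessary, which is exactly what the hunks above remove.
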
@@ -162,12 +154,12 @@ void launch_fused_add2<float>(float* out,
                              int hidden_dim,
                              cudaStream_t& stream)
 {
-    dim3 grid_dim(batch_size * seq_length);
+    int total_count = batch_size * seq_length * hidden_dim / 4;
+    dim3 grid_dim = DS_GET_BLOCKS(total_count);  //(batch_size * seq_length);
 
-    dim3 block_dim(hidden_dim / 4);
+    dim3 block_dim = DS_CUDA_NUM_THREADS;  //(hidden_dim / 4);
 
-    fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(
-        out, inp1, inp2, (batch_size * seq_length * hidden_dim), hidden_dim / 4);
+    fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(total_count, out, inp1, inp2);
 }
 
 template <>
@@ -179,12 +171,12 @@ void launch_fused_add2<__half>(__half* out,
                               int hidden_dim,
                               cudaStream_t& stream)
 {
-    dim3 grid_dim(batch_size * seq_length);
+    int total_count = batch_size * seq_length * hidden_dim / 4;
+    dim3 grid_dim = DS_GET_BLOCKS(total_count);  //(batch_size * seq_length);
 
-    dim3 block_dim(hidden_dim / 4);
+    dim3 block_dim = DS_CUDA_NUM_THREADS;  //(hidden_dim / 4);
 
-    fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(
-        out, inp1, inp2, (batch_size * seq_length * hidden_dim), hidden_dim / 4);
+    fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(total_count, out, inp1, inp2);
 }
 
 __global__ void fused_add3_kernel(float* out,
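
The launchers above size the grid with DS_GET_BLOCKS and DS_CUDA_NUM_THREADS, which are also defined outside this diff. A plausible sketch of such helpers (the names match the diff; the constants and the clamping are assumptions):

#include <algorithm>

// Assumed sizing helpers for the launches above; the real definitions live in
// a DeepSpeed header that is not part of this hunk.
#define DS_CUDA_NUM_THREADS 512
#define DS_MAXIMUM_NUM_BLOCKS 4096

inline int DS_GET_BLOCKS(const int N)
{
    // Enough DS_CUDA_NUM_THREADS-wide blocks to cover N elements, clamped to a
    // maximum grid size and never zero; the grid-stride loop inside the kernel
    // absorbs any remainder.
    return std::max(
        std::min((N + DS_CUDA_NUM_THREADS - 1) / DS_CUDA_NUM_THREADS, DS_MAXIMUM_NUM_BLOCKS),
        1);
}

With total_count = batch_size * seq_length * hidden_dim / 4, each thread handles one float4 (or one float2 holding four __half values), so the only requirement on hidden_dim is that it be divisible by 4; that is what the new 128, 4096, and 8192 hidden-size test cases below exercise.
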
@@ -261,6 +261,8 @@ def run_backward(ds_config, atol=1e-2, verbose=False):
                              (3,1024,120,16,24,True,True, 0.05),
                              (3,1024,56,16,24,False,False, 0.1),
                              (3,1024,56,16,24,False,True, 0.2),
+                             (3,128,56,2,24,False,False, 0.1),
+                             (3,128,56,2,24,False,True, 0.2),
                          ]) # yapf: disable
 def test_backward(batch_size,
                   hidden_size,
@@ -213,6 +213,8 @@ def run_forward(ds_config, seq_len, atol=1e-2, verbose=False, test_bsz=None):
 # FP16 test cases can only run on the devices support FP16.
 @pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16',
                          [
+                             (8,256,128,4,3,True,False),
+                             (8,256,128,4,3,True,True),
                              (64,1024,128,16,3,True,False),
                              (64,1024,128,16,3,True,True),
                              (8,1024,384,16,3,True,False),
@@ -236,6 +238,10 @@ def run_forward(ds_config, seq_len, atol=1e-2, verbose=False, test_bsz=None):
                              (8,2048,128,32,3,False,True),
                              (8,2560,128,40,3,False,False),
                              (8,2560,128,40,3,False,True),
+                             (8,128,128,2,3,True,False),
+                             (8,128,128,2,3,True,True),
+                             (8,4096,128,64,3,True,True),
+                             (8,8192,128,64,3,False,True),
                          ]) # yapf: disable
 def test_forward(batch_size,
                  hidden_size,