Commit 3aa106dc authored by Guolin Ke

support layernorm with dim 192

parent 561c2132
@@ -108,7 +108,7 @@ std::vector<at::Tensor> layer_norm(
   CHECK_INPUT(beta);
   int n1,n2;
   check_args(input,normalized_shape,gamma,beta,n1,n2);
-  TORCH_CHECK(n2 == 64 || n2 == 128 || n2 == 256 || n2 == 320 || n2 == 384 || n2 == 512 || n2 == 640 || n2 == 768 || n2 == 1024 || n2 == 1280 ||
+  TORCH_CHECK(n2 == 64 || n2 == 128 || n2 == 192 || n2 == 256 || n2 == 320 || n2 == 384 || n2 == 512 || n2 == 640 || n2 == 768 || n2 == 1024 || n2 == 1280 ||
               n2 == 1536 || n2 == 1792 || n2 == 2048 || n2 == 2560 || n2 == 5120, "dimension is not supported");
   at::Tensor output = at::empty_like(input);
   at::Tensor mean = at::empty({n1}, input.options().dtype((input.scalar_type()==at::ScalarType::Half || input.scalar_type()==at::ScalarType::BFloat16) ? at::ScalarType::Float : input.scalar_type()));
@@ -149,7 +149,7 @@ at::Tensor layer_norm_gradient(
   CHECK_INPUT(beta);
   int n1,n2;
   check_args(input,normalized_shape,gamma,beta,n1,n2);
-  TORCH_CHECK(n2 == 64 || n2 == 128 || n2 == 256 || n2 == 320 || n2 == 384 || n2 == 512 || n2 == 640 || n2 == 768 || n2 == 1024 || n2 == 1280 ||
+  TORCH_CHECK(n2 == 64 || n2 == 128 || n2 == 192 || n2 == 256 || n2 == 320 || n2 == 384 || n2 == 512 || n2 == 640 || n2 == 768 || n2 == 1024 || n2 == 1280 ||
               n2 == 1536 || n2 == 1792 || n2 == 2048 || n2 == 2560 || n2 == 5120, "dimension is not supported");
   at::Tensor grad_input = at::empty_like(input);
   cuda_layer_norm_gradient(&dout,&mean,&invvar,&input,n1,n2,
...
@@ -117,7 +117,7 @@ std::vector<at::Tensor> layer_norm_gradient(
   CHECK_INPUT(beta);
   int n1,n2;
   check_args(input,normalized_shape,gamma,beta,n1,n2);
-  TORCH_CHECK(n2 == 64 || n2 == 128 || n2 == 256 || n2 == 320 || n2 == 384 || n2 == 512 || n2 == 640 || n2 == 768 || n2 == 1024 || n2 == 1280 ||
+  TORCH_CHECK(n2 == 64 || n2 == 128 || n2 == 192 || n2 == 256 || n2 == 320 || n2 == 384 || n2 == 512 || n2 == 640 || n2 == 768 || n2 == 1024 || n2 == 1280 ||
               n2 == 1536 || n2 == 1792 || n2 == 2048 || n2 == 2560 || n2 == 5120, "dimension is not supported");
   at::Tensor grad_gamma = at::empty_like(gamma);
   at::Tensor grad_beta = at::empty_like(beta);
...
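Worth noting in the bindings above: for half and bfloat16 inputs, the saved mean (and matching invvar) statistics are allocated in float32 rather than the input dtype. Below is a minimal PyTorch sketch of that precision rule; `layer_norm_ref` is a hypothetical reference helper, not the extension's actual entry point, and 192 is the width this commit enables.

```python
import torch

def layer_norm_ref(x, weight, bias, eps=1e-5):
    # Hypothetical reference helper, not the extension's API: mirror the
    # binding above, which keeps mean/invvar in float32 whenever the
    # input is half or bfloat16.
    stat_dtype = torch.float32 if x.dtype in (torch.float16, torch.bfloat16) else x.dtype
    xf = x.to(stat_dtype)
    mean = xf.mean(dim=-1, keepdim=True)
    invvar = torch.rsqrt(xf.var(dim=-1, unbiased=False, keepdim=True) + eps)
    y = (xf - mean) * invvar
    return (y * weight.to(stat_dtype) + bias.to(stat_dtype)).to(x.dtype), mean, invvar

x = torch.randn(8, 192, dtype=torch.float16)   # 192 is the newly supported width
w = torch.ones(192, dtype=torch.float16)
b = torch.zeros(192, dtype=torch.float16)
y, mean, invvar = layer_norm_ref(x, w, b)
assert y.dtype == torch.float16 and mean.dtype == torch.float32
```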
@@ -186,6 +186,7 @@ void cuda_layer_norm(
   switch (n2) {
     case 64: LAUNCH_FORWARD_KERNEL(64, 2, 4, nv_bfloat16)
     case 128: LAUNCH_FORWARD_KERNEL(128, 2, 4, nv_bfloat16)
+    case 192: LAUNCH_FORWARD_KERNEL(192, 2, 4, nv_bfloat16)
     case 256: LAUNCH_FORWARD_KERNEL(256, 2, 4, nv_bfloat16)
     case 320: LAUNCH_FORWARD_KERNEL(320, 2, 4, nv_bfloat16)
     case 384: LAUNCH_FORWARD_KERNEL(384, 2, 4, nv_bfloat16)
@@ -204,6 +205,7 @@ void cuda_layer_norm(
   switch (n2) {
     case 64: LAUNCH_FORWARD_KERNEL(64, 2, 4, half)
     case 128: LAUNCH_FORWARD_KERNEL(128, 2, 4, half)
+    case 192: LAUNCH_FORWARD_KERNEL(192, 2, 4, half)
     case 256: LAUNCH_FORWARD_KERNEL(256, 2, 4, half)
     case 320: LAUNCH_FORWARD_KERNEL(320, 2, 4, half)
     case 384: LAUNCH_FORWARD_KERNEL(384, 2, 4, half)
@@ -222,6 +224,7 @@ void cuda_layer_norm(
   switch (n2) {
     case 64: LAUNCH_FORWARD_KERNEL(64, 1, 4, float)
     case 128: LAUNCH_FORWARD_KERNEL(128, 1, 4, float)
+    case 192: LAUNCH_FORWARD_KERNEL(192, 1, 4, float)
     case 256: LAUNCH_FORWARD_KERNEL(256, 1, 4, float)
     case 320: LAUNCH_FORWARD_KERNEL(320, 1, 4, float)
     case 384: LAUNCH_FORWARD_KERNEL(384, 1, 4, float)
@@ -259,6 +262,7 @@ void cuda_layer_norm_gradient(
   switch (n2) {
     case 64: LAUNCH_BACKWARD_KERNEL(64, 2, 4, nv_bfloat16)
     case 128: LAUNCH_BACKWARD_KERNEL(128, 2, 4, nv_bfloat16)
+    case 192: LAUNCH_BACKWARD_KERNEL(192, 2, 4, nv_bfloat16)
     case 256: LAUNCH_BACKWARD_KERNEL(256, 2, 4, nv_bfloat16)
     case 320: LAUNCH_BACKWARD_KERNEL(320, 2, 4, nv_bfloat16)
     case 384: LAUNCH_BACKWARD_KERNEL(384, 2, 4, nv_bfloat16)
@@ -277,6 +281,7 @@ void cuda_layer_norm_gradient(
   switch (n2) {
     case 64: LAUNCH_BACKWARD_KERNEL(64, 2, 4, half)
     case 128: LAUNCH_BACKWARD_KERNEL(128, 2, 4, half)
+    case 192: LAUNCH_BACKWARD_KERNEL(192, 2, 4, half)
     case 256: LAUNCH_BACKWARD_KERNEL(256, 2, 4, half)
     case 320: LAUNCH_BACKWARD_KERNEL(320, 2, 4, half)
     case 384: LAUNCH_BACKWARD_KERNEL(384, 2, 4, half)
@@ -295,6 +300,7 @@ void cuda_layer_norm_gradient(
   switch (n2) {
     case 64: LAUNCH_BACKWARD_KERNEL(64, 1, 4, float)
     case 128: LAUNCH_BACKWARD_KERNEL(128, 1, 4, float)
+    case 192: LAUNCH_BACKWARD_KERNEL(192, 1, 4, float)
     case 256: LAUNCH_BACKWARD_KERNEL(256, 1, 4, float)
     case 320: LAUNCH_BACKWARD_KERNEL(320, 1, 4, float)
     case 384: LAUNCH_BACKWARD_KERNEL(384, 1, 4, float)
...
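Each `case` above selects a kernel specialization whose row width is fixed at compile time, which is why supporting a new width means adding an entry to every dispatch table: forward and backward, once per dtype. Here is a hypothetical Python analogue of those tables (not the extension's real API) to make the pattern concrete; the integer pairs are the two launch parameters visible in the diff, whose exact meaning depends on the kernel template.

```python
# Each entry stands for one pre-instantiated kernel; a width with no
# entry has no compiled kernel to run, hence the hard whitelist.
SUPPORTED_DIMS = [64, 128, 192, 256, 320, 384, 512, 640, 768,
                  1024, 1280, 1536, 1792, 2048, 2560, 5120]

# From the diff: half/bfloat16 use (2, 4) and float uses (1, 4) throughout.
KERNEL_CONFIGS = {
    "half":     {n2: (2, 4) for n2 in SUPPORTED_DIMS},
    "bfloat16": {n2: (2, 4) for n2 in SUPPORTED_DIMS},
    "float":    {n2: (1, 4) for n2 in SUPPORTED_DIMS},
}

def launch_forward(dtype: str, n2: int) -> tuple:
    cfg = KERNEL_CONFIGS[dtype].get(n2)
    if cfg is None:
        # Mirrors the TORCH_CHECK message in the C++ bindings.
        raise RuntimeError("dimension is not supported")
    return cfg  # the real code launches the matching template instantiation

print(launch_forward("half", 192))  # (2, 4) -- newly supported by this commit
```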
@@ -45,7 +45,7 @@ class FusedLayerNormFastFunction(torch.autograd.Function):
                                  weight_, bias_, ctx.eps)
         return grad_input, grad_weight, grad_bias, None, None
 
-FUSED_LAYER_NORM_SUPPORT_DIM = set([64, 128, 256, 320, 384, 512, 640, 768, 1024, 1280, 1536, 1792, 2048, 2560, 5120])
+FUSED_LAYER_NORM_SUPPORT_DIM = set([64, 128, 192, 256, 320, 384, 512, 640, 768, 1024, 1280, 1536, 1792, 2048, 2560, 5120])
 
 class LayerNorm(torch.nn.Module):
     def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
...
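With 192 added to `FUSED_LAYER_NORM_SUPPORT_DIM`, a `LayerNorm` over a 192-wide hidden dimension can presumably take the fused-kernel path that this set gates. A minimal usage sketch, assuming a CUDA build of the extension; the import path below is hypothetical and depends on how the package is laid out.

```python
import torch
from fused_layer_norm import LayerNorm, FUSED_LAYER_NORM_SUPPORT_DIM  # hypothetical path

# 192 is in the support set after this commit, so the fused CUDA kernel
# can be used instead of the fallback taken for unsupported widths.
assert 192 in FUSED_LAYER_NORM_SUPPORT_DIM

ln = LayerNorm(192, eps=1e-5, elementwise_affine=True).cuda().half()
x = torch.randn(32, 128, 192, device="cuda", dtype=torch.float16)
y = ln(x)  # forward dispatches to the n2 == 192 kernel instantiation
```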