Commit cf503760 authored by Guolin Ke

more layer_norm kernels

parent e578ae25
name: Build and Publish Docker
on:
  push:
    branches:
      - main
jobs:
  docker:
    runs-on: ubuntu-latest
    steps:
      -
        name: Checkout
        uses: actions/checkout@v3
      -
        name: Set up QEMU
        uses: docker/setup-qemu-action@v2
      -
        name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v2
      -
        name: Login to DockerHub
        uses: docker/login-action@v2
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}
      -
        name: Build and push cu113
        uses: docker/build-push-action@v3
        with:
          context: ./docker/cu113/
          push: true
          tags: dptechnology/unicore:latest-pytorch1.11.0-cuda11.3
      -
        name: Build and push cu116
        uses: docker/build-push-action@v3
        with:
          context: ./docker/cu116/
          push: true
          tags: dptechnology/unicore:latest-pytorch1.12.1-cuda11.6
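Once this workflow runs on a push to main, both images are published to Docker Hub: dptechnology/unicore:latest-pytorch1.11.0-cuda11.3 (built from docker/cu113/) and dptechnology/unicore:latest-pytorch1.12.1-cuda11.6 (built from docker/cu116/).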
@@ -108,8 +108,8 @@ std::vector<at::Tensor> layer_norm(
   CHECK_INPUT(beta);
   int n1,n2;
   check_args(input,normalized_shape,gamma,beta,n1,n2);
-  TORCH_CHECK(n2 == 64 || n2 == 128 || n2 == 256 || n2 == 384 || n2 == 512 || n2 == 768 || n2 == 1024 || n2 == 1280 ||
-    n2 == 1536 || n2 == 1792 || n2 == 2048, "dimension is not supported");
+  TORCH_CHECK(n2 == 64 || n2 == 128 || n2 == 256 || n2 == 320 || n2 == 384 || n2 == 512 || n2 == 640 || n2 == 768 || n2 == 1024 || n2 == 1280 ||
+    n2 == 1536 || n2 == 1792 || n2 == 2048 || n2 == 2560 || n2 == 5120, "dimension is not supported");
   at::Tensor output = at::empty_like(input);
   at::Tensor mean = at::empty({n1}, input.options().dtype((input.scalar_type()==at::ScalarType::Half || input.scalar_type()==at::ScalarType::BFloat16) ? at::ScalarType::Float : input.scalar_type()));
   at::Tensor invvar = at::empty_like(mean);
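For reference, the forward binding above returns the output together with per-row mean and invvar, and keeps those statistics in fp32 even for half/bfloat16 inputs. A minimal PyTorch sketch of the same semantics, assuming a 2-D (n1, n2) input (reference code only, not the fused CUDA kernel):

import torch

def layer_norm_forward_reference(x, gamma, beta, eps):
    # Statistics in fp32, matching the dtype promotion of `mean`/`invvar` above.
    xf = x.float()
    mean = xf.mean(dim=-1)
    invvar = torch.rsqrt(xf.var(dim=-1, unbiased=False) + eps)
    xhat = (xf - mean.unsqueeze(-1)) * invvar.unsqueeze(-1)
    y = xhat * gamma.float() + beta.float()
    return y.to(x.dtype), mean, invvar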
@@ -149,8 +149,8 @@ at::Tensor layer_norm_gradient(
   CHECK_INPUT(beta);
   int n1,n2;
   check_args(input,normalized_shape,gamma,beta,n1,n2);
-  TORCH_CHECK(n2 == 64 || n2 == 128 || n2 == 256 || n2 == 384 || n2 == 512 || n2 == 768 || n2 == 1024 || n2 == 1280 ||
-    n2 == 1536 || n2 == 1792 || n2 == 2048, "dimension is not supported");
+  TORCH_CHECK(n2 == 64 || n2 == 128 || n2 == 256 || n2 == 320 || n2 == 384 || n2 == 512 || n2 == 640 || n2 == 768 || n2 == 1024 || n2 == 1280 ||
+    n2 == 1536 || n2 == 1792 || n2 == 2048 || n2 == 2560 || n2 == 5120, "dimension is not supported");
   at::Tensor grad_input = at::empty_like(input);
   cuda_layer_norm_gradient(&dout,&mean,&invvar,&input,n1,n2,
       normalized_shape,&gamma,&beta,epsilon,
@@ -117,8 +117,8 @@ std::vector<at::Tensor> layer_norm_gradient(
   CHECK_INPUT(beta);
   int n1,n2;
   check_args(input,normalized_shape,gamma,beta,n1,n2);
-  TORCH_CHECK(n2 == 64 || n2 == 128 || n2 == 256 || n2 == 384 || n2 == 512 || n2 == 768 || n2 == 1024 || n2 == 1280 ||
-    n2 == 1536 || n2 == 1792 || n2 == 2048, "dimension is not supported");
+  TORCH_CHECK(n2 == 64 || n2 == 128 || n2 == 256 || n2 == 320 || n2 == 384 || n2 == 512 || n2 == 640 || n2 == 768 || n2 == 1024 || n2 == 1280 ||
+    n2 == 1536 || n2 == 1792 || n2 == 2048 || n2 == 2560 || n2 == 5120, "dimension is not supported");
   at::Tensor grad_gamma = at::empty_like(gamma);
   at::Tensor grad_beta = at::empty_like(beta);
   cuda_layer_norm_gradient(&dout,&mean,&invvar,&input,n1,n2,
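The two gradient bindings above split the backward pass: one entry point returns grad_input, the other returns grad_gamma and grad_beta, both reusing the mean and invvar saved by the forward. A hedged PyTorch sketch of the math those kernels fuse, under the same (n1, n2) layout assumption:

import torch

def layer_norm_backward_reference(dout, x, gamma, mean, invvar):
    # dout, x: (n1, n2); mean, invvar: (n1,) fp32 tensors saved by the forward.
    xhat = (x.float() - mean.unsqueeze(-1)) * invvar.unsqueeze(-1)
    g = dout.float()
    # Parameter gradients reduce over rows (n1).
    grad_gamma = (g * xhat).sum(dim=0)
    grad_beta = g.sum(dim=0)
    # Input gradient reduces over features (n2).
    gw = g * gamma.float()
    grad_input = (gw - gw.mean(dim=-1, keepdim=True)
                  - xhat * (gw * xhat).mean(dim=-1, keepdim=True)) * invvar.unsqueeze(-1)
    return grad_input.to(x.dtype), grad_gamma, grad_beta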
@@ -187,42 +187,54 @@ void cuda_layer_norm(
       case 64: LAUNCH_FORWARD_KERNEL(64, 2, 4, nv_bfloat16)
       case 128: LAUNCH_FORWARD_KERNEL(128, 2, 4, nv_bfloat16)
       case 256: LAUNCH_FORWARD_KERNEL(256, 2, 4, nv_bfloat16)
+      case 320: LAUNCH_FORWARD_KERNEL(320, 2, 4, nv_bfloat16)
       case 384: LAUNCH_FORWARD_KERNEL(384, 2, 4, nv_bfloat16)
       case 512: LAUNCH_FORWARD_KERNEL(512, 2, 4, nv_bfloat16)
+      case 640: LAUNCH_FORWARD_KERNEL(640, 2, 4, nv_bfloat16)
       case 768: LAUNCH_FORWARD_KERNEL(768, 2, 4, nv_bfloat16)
       case 1024: LAUNCH_FORWARD_KERNEL(1024, 2, 4, nv_bfloat16)
       case 1280: LAUNCH_FORWARD_KERNEL(1280, 2, 4, nv_bfloat16)
       case 1536: LAUNCH_FORWARD_KERNEL(1536, 2, 4, nv_bfloat16)
       case 1792: LAUNCH_FORWARD_KERNEL(1792, 2, 4, nv_bfloat16)
       case 2048: LAUNCH_FORWARD_KERNEL(2048, 2, 4, nv_bfloat16)
+      case 2560: LAUNCH_FORWARD_KERNEL(2560, 2, 4, nv_bfloat16)
+      case 5120: LAUNCH_FORWARD_KERNEL(5120, 2, 4, nv_bfloat16)
     }
   } else if (type == at::ScalarType::Half) {
     switch (n2) {
       case 64: LAUNCH_FORWARD_KERNEL(64, 2, 4, half)
       case 128: LAUNCH_FORWARD_KERNEL(128, 2, 4, half)
       case 256: LAUNCH_FORWARD_KERNEL(256, 2, 4, half)
+      case 320: LAUNCH_FORWARD_KERNEL(320, 2, 4, half)
       case 384: LAUNCH_FORWARD_KERNEL(384, 2, 4, half)
       case 512: LAUNCH_FORWARD_KERNEL(512, 2, 4, half)
+      case 640: LAUNCH_FORWARD_KERNEL(640, 2, 4, half)
       case 768: LAUNCH_FORWARD_KERNEL(768, 2, 4, half)
       case 1024: LAUNCH_FORWARD_KERNEL(1024, 2, 4, half)
       case 1280: LAUNCH_FORWARD_KERNEL(1280, 2, 4, half)
       case 1536: LAUNCH_FORWARD_KERNEL(1536, 2, 4, half)
       case 1792: LAUNCH_FORWARD_KERNEL(1792, 2, 4, half)
       case 2048: LAUNCH_FORWARD_KERNEL(2048, 2, 4, half)
+      case 2560: LAUNCH_FORWARD_KERNEL(2560, 2, 4, half)
+      case 5120: LAUNCH_FORWARD_KERNEL(5120, 2, 4, half)
     }
   } else if (type == at::ScalarType::Float) {
     switch (n2) {
       case 64: LAUNCH_FORWARD_KERNEL(64, 1, 4, float)
       case 128: LAUNCH_FORWARD_KERNEL(128, 1, 4, float)
       case 256: LAUNCH_FORWARD_KERNEL(256, 1, 4, float)
+      case 320: LAUNCH_FORWARD_KERNEL(320, 1, 4, float)
       case 384: LAUNCH_FORWARD_KERNEL(384, 1, 4, float)
       case 512: LAUNCH_FORWARD_KERNEL(512, 1, 4, float)
+      case 640: LAUNCH_FORWARD_KERNEL(640, 1, 4, float)
       case 768: LAUNCH_FORWARD_KERNEL(768, 1, 4, float)
       case 1024: LAUNCH_FORWARD_KERNEL(1024, 1, 4, float)
       case 1280: LAUNCH_FORWARD_KERNEL(1280, 1, 4, float)
       case 1536: LAUNCH_FORWARD_KERNEL(1536, 1, 4, float)
       case 1792: LAUNCH_FORWARD_KERNEL(1792, 1, 4, float)
       case 2048: LAUNCH_FORWARD_KERNEL(2048, 1, 4, float)
+      case 2560: LAUNCH_FORWARD_KERNEL(2560, 1, 4, float)
+      case 5120: LAUNCH_FORWARD_KERNEL(5120, 1, 4, float)
     }
   }
 }
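A note on the launch arguments: the 16-bit types (half, nv_bfloat16) are dispatched with a second macro argument of 2 where float uses 1. The exact semantics are fixed by the LAUNCH_FORWARD_KERNEL macro elsewhere in this file; the pattern is consistent with wider vectorized access for 16-bit elements, though that is an inference from the call sites rather than something shown in this diff.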
@@ -248,42 +260,54 @@ void cuda_layer_norm_gradient(
       case 64: LAUNCH_BACKWARD_KERNEL(64, 2, 4, nv_bfloat16)
       case 128: LAUNCH_BACKWARD_KERNEL(128, 2, 4, nv_bfloat16)
       case 256: LAUNCH_BACKWARD_KERNEL(256, 2, 4, nv_bfloat16)
+      case 320: LAUNCH_BACKWARD_KERNEL(320, 2, 4, nv_bfloat16)
       case 384: LAUNCH_BACKWARD_KERNEL(384, 2, 4, nv_bfloat16)
       case 512: LAUNCH_BACKWARD_KERNEL(512, 2, 4, nv_bfloat16)
+      case 640: LAUNCH_BACKWARD_KERNEL(640, 2, 4, nv_bfloat16)
       case 768: LAUNCH_BACKWARD_KERNEL(768, 2, 4, nv_bfloat16)
       case 1024: LAUNCH_BACKWARD_KERNEL(1024, 2, 4, nv_bfloat16)
       case 1280: LAUNCH_BACKWARD_KERNEL(1280, 2, 4, nv_bfloat16)
       case 1536: LAUNCH_BACKWARD_KERNEL(1536, 2, 4, nv_bfloat16)
       case 1792: LAUNCH_BACKWARD_KERNEL(1792, 2, 4, nv_bfloat16)
       case 2048: LAUNCH_BACKWARD_KERNEL(2048, 2, 4, nv_bfloat16)
+      case 2560: LAUNCH_BACKWARD_KERNEL(2560, 2, 4, nv_bfloat16)
+      case 5120: LAUNCH_BACKWARD_KERNEL(5120, 2, 4, nv_bfloat16)
     }
   } else if (type == at::ScalarType::Half) {
     switch (n2) {
       case 64: LAUNCH_BACKWARD_KERNEL(64, 2, 4, half)
       case 128: LAUNCH_BACKWARD_KERNEL(128, 2, 4, half)
       case 256: LAUNCH_BACKWARD_KERNEL(256, 2, 4, half)
+      case 320: LAUNCH_BACKWARD_KERNEL(320, 2, 4, half)
       case 384: LAUNCH_BACKWARD_KERNEL(384, 2, 4, half)
       case 512: LAUNCH_BACKWARD_KERNEL(512, 2, 4, half)
+      case 640: LAUNCH_BACKWARD_KERNEL(640, 2, 4, half)
       case 768: LAUNCH_BACKWARD_KERNEL(768, 2, 4, half)
       case 1024: LAUNCH_BACKWARD_KERNEL(1024, 2, 4, half)
       case 1280: LAUNCH_BACKWARD_KERNEL(1280, 2, 4, half)
       case 1536: LAUNCH_BACKWARD_KERNEL(1536, 2, 4, half)
       case 1792: LAUNCH_BACKWARD_KERNEL(1792, 2, 4, half)
       case 2048: LAUNCH_BACKWARD_KERNEL(2048, 2, 4, half)
+      case 2560: LAUNCH_BACKWARD_KERNEL(2560, 2, 4, half)
+      case 5120: LAUNCH_BACKWARD_KERNEL(5120, 2, 4, half)
     }
   } else if (type == at::ScalarType::Float) {
     switch (n2) {
       case 64: LAUNCH_BACKWARD_KERNEL(64, 1, 4, float)
       case 128: LAUNCH_BACKWARD_KERNEL(128, 1, 4, float)
       case 256: LAUNCH_BACKWARD_KERNEL(256, 1, 4, float)
+      case 320: LAUNCH_BACKWARD_KERNEL(320, 1, 4, float)
       case 384: LAUNCH_BACKWARD_KERNEL(384, 1, 4, float)
       case 512: LAUNCH_BACKWARD_KERNEL(512, 1, 4, float)
+      case 640: LAUNCH_BACKWARD_KERNEL(640, 1, 4, float)
       case 768: LAUNCH_BACKWARD_KERNEL(768, 1, 4, float)
       case 1024: LAUNCH_BACKWARD_KERNEL(1024, 1, 4, float)
       case 1280: LAUNCH_BACKWARD_KERNEL(1280, 1, 4, float)
       case 1536: LAUNCH_BACKWARD_KERNEL(1536, 1, 4, float)
       case 1792: LAUNCH_BACKWARD_KERNEL(1792, 1, 4, float)
       case 2048: LAUNCH_BACKWARD_KERNEL(2048, 1, 4, float)
+      case 2560: LAUNCH_BACKWARD_KERNEL(2560, 1, 4, float)
+      case 5120: LAUNCH_BACKWARD_KERNEL(5120, 1, 4, float)
     }
   }
 }
@@ -45,6 +45,7 @@ class FusedLayerNormFastFunction(torch.autograd.Function):
                                    weight_, bias_, ctx.eps)
         return grad_input, grad_weight, grad_bias, None, None
 
+FUSED_LAYER_NORM_SUPPORT_DIM = set([64, 128, 256, 320, 384, 512, 640, 768, 1024, 1280, 1536, 1792, 2048, 2560, 5120])
 
 class LayerNorm(torch.nn.Module):
     def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
@@ -57,17 +58,24 @@ class LayerNorm(torch.nn.Module):
             self.weight = Parameter(torch.Tensor(*normalized_shape))
             self.bias = Parameter(torch.Tensor(*normalized_shape))
         self.reset_parameters()
+        def torch_layer_norm(input):
+            return F.layer_norm(
+                input, self.normalized_shape, self.weight, self.bias, self.eps)
+        def fused_layer_norm(input):
+            if input.is_cuda:
+                return FusedLayerNormFastFunction.apply(
+                    input, self.weight, self.bias, self.normalized_shape, self.eps)
+            else:
+                return F.layer_norm(
+                    input, self.normalized_shape, self.weight, self.bias, self.eps)
+        self.func = torch_layer_norm if (not HAS_LAYER_NORM or normalized_shape[0] not in FUSED_LAYER_NORM_SUPPORT_DIM) else fused_layer_norm
 
     def reset_parameters(self):
         init.ones_(self.weight)
         init.zeros_(self.bias)
 
     def forward(self, input):
-        if not input.is_cuda or not HAS_LAYER_NORM:
-            return F.layer_norm(
-                input, self.normalized_shape, self.weight, self.bias, self.eps)
-        return FusedLayerNormFastFunction.apply(
-            input, self.weight, self.bias, self.normalized_shape, self.eps)
+        return self.func(input)
 
     def extra_repr(self):
         return '{normalized_shape}, eps={eps}, ' \
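Usage sketch for the updated Python module (the import path below is an assumption; adjust to wherever LayerNorm is exported). Widths newly added to FUSED_LAYER_NORM_SUPPORT_DIM, such as 320 and 2560, now take the fused CUDA path, while anything else falls back to F.layer_norm:

import torch
from unicore.modules import LayerNorm  # assumed export path

ln = LayerNorm(320).cuda().half()   # 320 is in FUSED_LAYER_NORM_SUPPORT_DIM
y = ln(torch.randn(8, 320, device="cuda", dtype=torch.half))  # fused path

ln_cpu = LayerNorm(100)             # unsupported width -> torch_layer_norm
z = ln_cpu(torch.randn(8, 100))     # plain F.layer_norm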