OpenDAS / Uni-Core · Commits

Commit cf503760
Authored Aug 24, 2022 by Guolin Ke

    more layer_norm kernels

Parent: e578ae25

Showing 5 changed files with 83 additions and 11 deletions.
Files changed:

  .github/workflows/docker_latest.yml        +40  −0
  csrc/layernorm/interface.cpp                +4  −4
  csrc/layernorm/interface_gamma_beta.cpp     +2  −2
  csrc/layernorm/layernorm.cu                +24  −0
  unicore/modules/layer_norm.py              +13  −5
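In short, the commit does three things: it adds a GitHub Actions workflow that builds and publishes the project's Docker images; it widens the set of hidden dimensions accepted by the fused LayerNorm CUDA kernels, adding 320, 640, 2560, and 5120 to the existing whitelist; and it reworks unicore.modules.LayerNorm so the choice between the fused kernel and torch.nn.functional.layer_norm is made once, at construction time, instead of on every forward call.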
.github/workflows/docker_latest.yml  (new file, 0 → 100644)

name: Build and Publish Docker
on:
  push:
    branches:
      - main
jobs:
  docker:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v3
      - name: Set up QEMU
        uses: docker/setup-qemu-action@v2
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v2
      - name: Login to DockerHub
        uses: docker/login-action@v2
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}
      - name: Build and push cu113
        uses: docker/build-push-action@v3
        with:
          context: ./docker/cu113/
          push: true
          tags: dptechnology/unicore:latest-pytorch1.11.0-cuda11.3
      - name: Build and push cu116
        uses: docker/build-push-action@v3
        with:
          context: ./docker/cu116/
          push: true
          tags: dptechnology/unicore:latest-pytorch1.12.1-cuda11.6
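Once the workflow has run, the published images can be pulled directly, e.g. docker pull dptechnology/unicore:latest-pytorch1.11.0-cuda11.3. The QEMU and Buildx setup steps are the usual companion actions for docker/build-push-action.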
csrc/layernorm/interface.cpp

@@ -108,8 +108,8 @@ std::vector<at::Tensor> layer_norm(
      CHECK_INPUT(beta);
      int n1, n2;
      check_args(input, normalized_shape, gamma, beta, n1, n2);
-     TORCH_CHECK(n2 == 64 || n2 == 128 || n2 == 256 || n2 == 384 || n2 == 512 || n2 == 768 || n2 == 1024 || n2 == 1280 ||
-                 n2 == 1536 || n2 == 1792 || n2 == 2048, "dimension is not supported");
+     TORCH_CHECK(n2 == 64 || n2 == 128 || n2 == 256 || n2 == 320 || n2 == 384 || n2 == 512 || n2 == 640 || n2 == 768 || n2 == 1024 || n2 == 1280 ||
+                 n2 == 1536 || n2 == 1792 || n2 == 2048 || n2 == 2560 || n2 == 5120, "dimension is not supported");
      at::Tensor output = at::empty_like(input);
      at::Tensor mean = at::empty({n1}, input.options().dtype((input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16) ? at::ScalarType::Float : input.scalar_type()));
      at::Tensor invvar = at::empty_like(mean);
...
@@ -149,8 +149,8 @@ at::Tensor layer_norm_gradient(
      CHECK_INPUT(beta);
      int n1, n2;
      check_args(input, normalized_shape, gamma, beta, n1, n2);
-     TORCH_CHECK(n2 == 64 || n2 == 128 || n2 == 256 || n2 == 384 || n2 == 512 || n2 == 768 || n2 == 1024 || n2 == 1280 ||
-                 n2 == 1536 || n2 == 1792 || n2 == 2048, "dimension is not supported");
+     TORCH_CHECK(n2 == 64 || n2 == 128 || n2 == 256 || n2 == 320 || n2 == 384 || n2 == 512 || n2 == 640 || n2 == 768 || n2 == 1024 || n2 == 1280 ||
+                 n2 == 1536 || n2 == 1792 || n2 == 2048 || n2 == 2560 || n2 == 5120, "dimension is not supported");
      at::Tensor grad_input = at::empty_like(input);
      cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
                               normalized_shape, &gamma, &beta, epsilon,
...
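Seen from Python, the widened TORCH_CHECK simply means more hidden widths can take the fused path before hitting the "dimension is not supported" error. A minimal sketch of that whitelist check (the set mirrors the TORCH_CHECK above; can_use_fused is an illustrative helper, not part of Uni-Core):

# Hidden widths accepted by the fused kernels after this commit;
# mirrors the TORCH_CHECK whitelist in csrc/layernorm/interface.cpp.
SUPPORTED_WIDTHS = {64, 128, 256, 320, 384, 512, 640, 768, 1024,
                    1280, 1536, 1792, 2048, 2560, 5120}

def can_use_fused(hidden_dim: int) -> bool:
    """True when the fused CUDA layer norm has a kernel for this width."""
    return hidden_dim in SUPPORTED_WIDTHS

assert can_use_fused(320)       # newly supported by this commit
assert not can_use_fused(333)   # would trip "dimension is not supported"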
csrc/layernorm/interface_gamma_beta.cpp

@@ -117,8 +117,8 @@ std::vector<at::Tensor> layer_norm_gradient(
      CHECK_INPUT(beta);
      int n1, n2;
      check_args(input, normalized_shape, gamma, beta, n1, n2);
-     TORCH_CHECK(n2 == 64 || n2 == 128 || n2 == 256 || n2 == 384 || n2 == 512 || n2 == 768 || n2 == 1024 || n2 == 1280 ||
-                 n2 == 1536 || n2 == 1792 || n2 == 2048, "dimension is not supported");
+     TORCH_CHECK(n2 == 64 || n2 == 128 || n2 == 256 || n2 == 320 || n2 == 384 || n2 == 512 || n2 == 640 || n2 == 768 || n2 == 1024 || n2 == 1280 ||
+                 n2 == 1536 || n2 == 1792 || n2 == 2048 || n2 == 2560 || n2 == 5120, "dimension is not supported");
      at::Tensor grad_gamma = at::empty_like(gamma);
      at::Tensor grad_beta = at::empty_like(beta);
      cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
...
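interface_gamma_beta.cpp is the entry point that also returns the gamma/beta gradients; the change there is the same two-line widening of the whitelist, keeping that gradient path in sync with the forward check.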
csrc/layernorm/layernorm.cu

@@ -187,42 +187,54 @@ void cuda_layer_norm(
          case 64: LAUNCH_FORWARD_KERNEL(64, 2, 4, nv_bfloat16)
          case 128: LAUNCH_FORWARD_KERNEL(128, 2, 4, nv_bfloat16)
          case 256: LAUNCH_FORWARD_KERNEL(256, 2, 4, nv_bfloat16)
+         case 320: LAUNCH_FORWARD_KERNEL(320, 2, 4, nv_bfloat16)
          case 384: LAUNCH_FORWARD_KERNEL(384, 2, 4, nv_bfloat16)
          case 512: LAUNCH_FORWARD_KERNEL(512, 2, 4, nv_bfloat16)
+         case 640: LAUNCH_FORWARD_KERNEL(640, 2, 4, nv_bfloat16)
          case 768: LAUNCH_FORWARD_KERNEL(768, 2, 4, nv_bfloat16)
          case 1024: LAUNCH_FORWARD_KERNEL(1024, 2, 4, nv_bfloat16)
          case 1280: LAUNCH_FORWARD_KERNEL(1280, 2, 4, nv_bfloat16)
          case 1536: LAUNCH_FORWARD_KERNEL(1536, 2, 4, nv_bfloat16)
          case 1792: LAUNCH_FORWARD_KERNEL(1792, 2, 4, nv_bfloat16)
          case 2048: LAUNCH_FORWARD_KERNEL(2048, 2, 4, nv_bfloat16)
+         case 2560: LAUNCH_FORWARD_KERNEL(2560, 2, 4, nv_bfloat16)
+         case 5120: LAUNCH_FORWARD_KERNEL(5120, 2, 4, nv_bfloat16)
          }
      }
      else if (type == at::ScalarType::Half) {
          switch (n2) {
          case 64: LAUNCH_FORWARD_KERNEL(64, 2, 4, half)
          case 128: LAUNCH_FORWARD_KERNEL(128, 2, 4, half)
          case 256: LAUNCH_FORWARD_KERNEL(256, 2, 4, half)
+         case 320: LAUNCH_FORWARD_KERNEL(320, 2, 4, half)
          case 384: LAUNCH_FORWARD_KERNEL(384, 2, 4, half)
          case 512: LAUNCH_FORWARD_KERNEL(512, 2, 4, half)
+         case 640: LAUNCH_FORWARD_KERNEL(640, 2, 4, half)
          case 768: LAUNCH_FORWARD_KERNEL(768, 2, 4, half)
          case 1024: LAUNCH_FORWARD_KERNEL(1024, 2, 4, half)
          case 1280: LAUNCH_FORWARD_KERNEL(1280, 2, 4, half)
          case 1536: LAUNCH_FORWARD_KERNEL(1536, 2, 4, half)
          case 1792: LAUNCH_FORWARD_KERNEL(1792, 2, 4, half)
          case 2048: LAUNCH_FORWARD_KERNEL(2048, 2, 4, half)
+         case 2560: LAUNCH_FORWARD_KERNEL(2560, 2, 4, half)
+         case 5120: LAUNCH_FORWARD_KERNEL(5120, 2, 4, half)
          }
      }
      else if (type == at::ScalarType::Float) {
          switch (n2) {
          case 64: LAUNCH_FORWARD_KERNEL(64, 1, 4, float)
          case 128: LAUNCH_FORWARD_KERNEL(128, 1, 4, float)
          case 256: LAUNCH_FORWARD_KERNEL(256, 1, 4, float)
+         case 320: LAUNCH_FORWARD_KERNEL(320, 1, 4, float)
          case 384: LAUNCH_FORWARD_KERNEL(384, 1, 4, float)
          case 512: LAUNCH_FORWARD_KERNEL(512, 1, 4, float)
+         case 640: LAUNCH_FORWARD_KERNEL(640, 1, 4, float)
          case 768: LAUNCH_FORWARD_KERNEL(768, 1, 4, float)
          case 1024: LAUNCH_FORWARD_KERNEL(1024, 1, 4, float)
          case 1280: LAUNCH_FORWARD_KERNEL(1280, 1, 4, float)
          case 1536: LAUNCH_FORWARD_KERNEL(1536, 1, 4, float)
          case 1792: LAUNCH_FORWARD_KERNEL(1792, 1, 4, float)
          case 2048: LAUNCH_FORWARD_KERNEL(2048, 1, 4, float)
+         case 2560: LAUNCH_FORWARD_KERNEL(2560, 1, 4, float)
+         case 5120: LAUNCH_FORWARD_KERNEL(5120, 1, 4, float)
          }
      }
  }
...
@@ -248,42 +260,54 @@ void cuda_layer_norm_gradient(
          case 64: LAUNCH_BACKWARD_KERNEL(64, 2, 4, nv_bfloat16)
          case 128: LAUNCH_BACKWARD_KERNEL(128, 2, 4, nv_bfloat16)
          case 256: LAUNCH_BACKWARD_KERNEL(256, 2, 4, nv_bfloat16)
+         case 320: LAUNCH_BACKWARD_KERNEL(320, 2, 4, nv_bfloat16)
          case 384: LAUNCH_BACKWARD_KERNEL(384, 2, 4, nv_bfloat16)
          case 512: LAUNCH_BACKWARD_KERNEL(512, 2, 4, nv_bfloat16)
+         case 640: LAUNCH_BACKWARD_KERNEL(640, 2, 4, nv_bfloat16)
          case 768: LAUNCH_BACKWARD_KERNEL(768, 2, 4, nv_bfloat16)
          case 1024: LAUNCH_BACKWARD_KERNEL(1024, 2, 4, nv_bfloat16)
          case 1280: LAUNCH_BACKWARD_KERNEL(1280, 2, 4, nv_bfloat16)
          case 1536: LAUNCH_BACKWARD_KERNEL(1536, 2, 4, nv_bfloat16)
          case 1792: LAUNCH_BACKWARD_KERNEL(1792, 2, 4, nv_bfloat16)
          case 2048: LAUNCH_BACKWARD_KERNEL(2048, 2, 4, nv_bfloat16)
+         case 2560: LAUNCH_BACKWARD_KERNEL(2560, 2, 4, nv_bfloat16)
+         case 5120: LAUNCH_BACKWARD_KERNEL(5120, 2, 4, nv_bfloat16)
          }
      }
      else if (type == at::ScalarType::Half) {
          switch (n2) {
          case 64: LAUNCH_BACKWARD_KERNEL(64, 2, 4, half)
          case 128: LAUNCH_BACKWARD_KERNEL(128, 2, 4, half)
          case 256: LAUNCH_BACKWARD_KERNEL(256, 2, 4, half)
+         case 320: LAUNCH_BACKWARD_KERNEL(320, 2, 4, half)
          case 384: LAUNCH_BACKWARD_KERNEL(384, 2, 4, half)
          case 512: LAUNCH_BACKWARD_KERNEL(512, 2, 4, half)
+         case 640: LAUNCH_BACKWARD_KERNEL(640, 2, 4, half)
          case 768: LAUNCH_BACKWARD_KERNEL(768, 2, 4, half)
          case 1024: LAUNCH_BACKWARD_KERNEL(1024, 2, 4, half)
          case 1280: LAUNCH_BACKWARD_KERNEL(1280, 2, 4, half)
          case 1536: LAUNCH_BACKWARD_KERNEL(1536, 2, 4, half)
          case 1792: LAUNCH_BACKWARD_KERNEL(1792, 2, 4, half)
          case 2048: LAUNCH_BACKWARD_KERNEL(2048, 2, 4, half)
+         case 2560: LAUNCH_BACKWARD_KERNEL(2560, 2, 4, half)
+         case 5120: LAUNCH_BACKWARD_KERNEL(5120, 2, 4, half)
          }
      }
      else if (type == at::ScalarType::Float) {
          switch (n2) {
          case 64: LAUNCH_BACKWARD_KERNEL(64, 1, 4, float)
          case 128: LAUNCH_BACKWARD_KERNEL(128, 1, 4, float)
          case 256: LAUNCH_BACKWARD_KERNEL(256, 1, 4, float)
+         case 320: LAUNCH_BACKWARD_KERNEL(320, 1, 4, float)
          case 384: LAUNCH_BACKWARD_KERNEL(384, 1, 4, float)
          case 512: LAUNCH_BACKWARD_KERNEL(512, 1, 4, float)
+         case 640: LAUNCH_BACKWARD_KERNEL(640, 1, 4, float)
          case 768: LAUNCH_BACKWARD_KERNEL(768, 1, 4, float)
          case 1024: LAUNCH_BACKWARD_KERNEL(1024, 1, 4, float)
          case 1280: LAUNCH_BACKWARD_KERNEL(1280, 1, 4, float)
          case 1536: LAUNCH_BACKWARD_KERNEL(1536, 1, 4, float)
          case 1792: LAUNCH_BACKWARD_KERNEL(1792, 1, 4, float)
          case 2048: LAUNCH_BACKWARD_KERNEL(2048, 1, 4, float)
+         case 2560: LAUNCH_BACKWARD_KERNEL(2560, 1, 4, float)
+         case 5120: LAUNCH_BACKWARD_KERNEL(5120, 1, 4, float)
          }
      }
  }
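Each switch dispatches the normalized width n2 to a separate template instantiation per dtype. The LAUNCH_* macro bodies are outside this diff, so the meaning of the two integer arguments is not shown; the visible pattern is only that half and nv_bfloat16 use 2 where float uses 1. A hedged Python restatement of the dispatch shape (names here are illustrative, not Uni-Core API):

# Illustrative restatement of the switch-based dispatch in layernorm.cu.
# The role of the two integer template arguments is an assumption (the
# LAUNCH_* macros are not part of this diff); the widths and the 2-vs-1
# split between 16-bit and 32-bit types are taken from the diff itself.
SUPPORTED_WIDTHS = (64, 128, 256, 320, 384, 512, 640, 768, 1024,
                    1280, 1536, 1792, 2048, 2560, 5120)

def launch_args(n2: int, dtype: str):
    """Return the (width, x, 4) triple passed to LAUNCH_FORWARD_KERNEL."""
    if n2 not in SUPPORTED_WIDTHS:
        raise ValueError("dimension is not supported")
    x = 1 if dtype == "float" else 2   # half / nv_bfloat16 use 2
    return (n2, x, 4)

print(launch_args(2560, "half"))   # -> (2560, 2, 4), one of the new cases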
unicore/modules/layer_norm.py

@@ -45,6 +45,7 @@ class FusedLayerNormFastFunction(torch.autograd.Function):
                                  weight_, bias_, ctx.eps)
          return grad_input, grad_weight, grad_bias, None, None

+ FUSED_LAYER_NORM_SUPPORT_DIM = set([64, 128, 256, 320, 384, 512, 640, 768, 1024, 1280, 1536, 1792, 2048, 2560, 5120])

  class LayerNorm(torch.nn.Module):
      def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
...
@@ -57,17 +58,24 @@ class LayerNorm(torch.nn.Module):
              self.weight = Parameter(torch.Tensor(*normalized_shape))
              self.bias = Parameter(torch.Tensor(*normalized_shape))
          self.reset_parameters()
+         def torch_layer_norm(input):
+             return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
+         def fused_layer_norm(input):
+             if input.is_cuda:
+                 return FusedLayerNormFastFunction.apply(input, self.weight, self.bias, self.normalized_shape, self.eps)
+             else:
+                 return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
+         self.func = torch_layer_norm if (not HAS_LAYER_NORM or normalized_shape[0] not in FUSED_LAYER_NORM_SUPPORT_DIM) else fused_layer_norm

      def reset_parameters(self):
          init.ones_(self.weight)
          init.zeros_(self.bias)

      def forward(self, input):
-         if not input.is_cuda or not HAS_LAYER_NORM:
-             return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
-         return FusedLayerNormFastFunction.apply(input, self.weight, self.bias, self.normalized_shape, self.eps)
+         return self.func(input)

      def extra_repr(self):
          return '{normalized_shape}, eps={eps}, ' \
...
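A short usage sketch of the refactored module, assuming Uni-Core is built with its CUDA extension and that LayerNorm is re-exported from unicore.modules (as the file layout suggests):

import torch
from unicore.modules import LayerNorm   # unicore/modules/layer_norm.py

# 320 is one of the widths enabled by this commit, so this instance binds
# fused_layer_norm at construction time (HAS_LAYER_NORM permitting).
ln = LayerNorm(320)

if torch.cuda.is_available():
    ln = ln.cuda()
    y = ln(torch.randn(8, 320, device="cuda"))   # fused CUDA kernel

# A width outside FUSED_LAYER_NORM_SUPPORT_DIM (e.g. 333) binds the
# torch.nn.functional.layer_norm fallback instead, with the same semantics.
y_cpu = LayerNorm(333)(torch.randn(8, 333))      # native PyTorch path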
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment