Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
2d0f9cf2
"docs/vscode:/vscode.git/clone" did not exist on "32b85dfa8d4a5fa54469ddc72be89d827c1ee9d6"
Unverified
Commit
2d0f9cf2
authored
May 07, 2020
by
Chaitanya Sri Krishna Lolla
Committed by
GitHub
May 07, 2020
Browse files
Enable fusedlayernorm extension (#3)
parent
3ccdd63d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
5 deletions
+17
-5
csrc/layer_norm_cuda_kernel.cu
csrc/layer_norm_cuda_kernel.cu
+10
-4
setup.py
setup.py
+7
-1
No files found.
csrc/layer_norm_cuda_kernel.cu
View file @
2d0f9cf2
...
...
@@ -172,8 +172,8 @@ void cuWelfordMuSigma2(
for
(;
l
+
7
<
n2
;
l
+=
8
*
numx
)
{
for
(
int
k
=
0
;
k
<
8
;
k
+=
2
)
{
float2
curr
=
__half22float2
(
*
((
__half2
*
)(
lvals
+
l
+
k
)));
cuWelfordOnlineSum
(
curr
.
x
,
mu
,
sigma2
,
count
);
cuWelfordOnlineSum
(
curr
.
y
,
mu
,
sigma2
,
count
);
cuWelfordOnlineSum
<
float
>
(
curr
.
x
,
mu
,
sigma2
,
count
);
cuWelfordOnlineSum
<
float
>
(
curr
.
y
,
mu
,
sigma2
,
count
);
}
}
for
(;
l
<
n2
;
++
l
)
{
...
...
@@ -230,9 +230,15 @@ void cuWelfordMuSigma2(
template
<
typename
U
>
U
rsqrt
(
U
v
)
{
return
U
(
1
)
/
sqrt
(
v
);
}
#if defined __HIP_PLATFORM_HCC__
__device__
float
rsqrt
(
float
v
)
{
return
rsqrtf
(
v
);
}
#else
template
<
>
float
rsqrt
(
float
v
)
{
return
rsqrtf
(
v
);
}
#endif
template
<
>
double
rsqrt
(
double
v
)
{
return
rsqrt
(
v
);
}
...
...
@@ -293,7 +299,7 @@ void cuApplyLayerNorm(
// 1) blockDim.x == warpSize
// 2) Tensors are contiguous
//
for
(
auto
i1
=
blockIdx
.
y
;
i1
<
n1
;
i1
+=
gridDim
.
y
)
{
for
(
int
i1
=
blockIdx
.
y
;
i1
<
n1
;
i1
+=
gridDim
.
y
)
{
SharedMemory
<
U
>
shared
;
U
*
buf
=
shared
.
getPointer
();
U
mu
,
sigma2
;
...
...
@@ -531,7 +537,7 @@ void cuComputeGradInput(
const
T
*
gamma
,
T
*
grad_input
)
{
for
(
auto
i1
=
blockIdx
.
y
;
i1
<
n1
;
i1
+=
gridDim
.
y
)
{
for
(
int
i1
=
blockIdx
.
y
;
i1
<
n1
;
i1
+=
gridDim
.
y
)
{
U
sum_loss1
=
U
(
0
);
U
sum_loss2
=
U
(
0
);
const
U
c_mean
=
mean
[
i1
];
...
...
setup.py
View file @
2d0f9cf2
...
...
@@ -177,7 +177,13 @@ if "--cuda_ext" in sys.argv:
'-O3'
,
'--use_fast_math'
]
+
version_dependent_macros
}))
else
:
print
(
"INFO: Skipping FusedLayerNorm extension."
)
print
(
"INFO: Building FusedLayerNorm extension."
)
ext_modules
.
append
(
CUDAExtension
(
name
=
'fused_layer_norm_cuda'
,
sources
=
[
'csrc/layer_norm_cuda.cpp'
,
'csrc/hip/layer_norm_hip_kernel.hip'
],
extra_compile_args
=
{
'cxx'
:
[
'-O3'
]
+
version_dependent_macros
,
'nvcc'
:
[]}))
if
not
is_rocm_pytorch
:
ext_modules
.
append
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment