Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
d755f1f1
Commit
d755f1f1
authored
Apr 15, 2022
by
hubertlu-tw
Browse files
Fix some bugs
parent
47921708
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
23 additions
and
7 deletions
+23
-7
csrc/layer_norm_cuda_kernel.cu
csrc/layer_norm_cuda_kernel.cu
+22
-6
tests/L0/run_fused_layer_norm/test_fused_layer_norm.py
tests/L0/run_fused_layer_norm/test_fused_layer_norm.py
+1
-1
No files found.
csrc/layer_norm_cuda_kernel.cu
View file @
d755f1f1
...
...
@@ -75,7 +75,7 @@ void cuWelfordMuSigma2(
U
&
mu
,
U
&
sigma2
,
U
*
buf
,
const
int
GPU_WARP_SIZE
)
const
int
GPU_WARP_SIZE
,
bool
rms_only
)
{
// Assumptions:
...
...
@@ -185,7 +185,7 @@ void cuWelfordMuSigma2(
float
&
mu
,
float
&
sigma2
,
float
*
buf
,
const
int
GPU_WARP_SIZE
)
const
int
GPU_WARP_SIZE
,
bool
rms_only
)
{
// Assumptions:
...
...
@@ -369,9 +369,8 @@ void cuApplyLayerNorm_(
const
U
epsilon
,
const
V
*
__restrict__
gamma
,
const
V
*
__restrict__
beta
,
const
int
GPU_WARP_SIZE
bool
rms_only
)
const
int
GPU_WARP_SIZE
,
bool
rms_only
)
{
// Assumptions:
// 1) blockDim.x == warpSize
...
...
@@ -433,6 +432,20 @@ void cuApplyLayerNorm(
cuApplyLayerNorm_
<
T
,
U
,
V
>
(
output_vals
,
mean
,
invvar
,
vals
,
n1
,
n2
,
epsilon
,
gamma
,
beta
,
warp_size
,
false
);
}
template
<
typename
T
,
typename
U
,
typename
V
=
T
>
__global__
void
cuApplyRMSNorm
(
V
*
__restrict__
output_vals
,
U
*
__restrict__
invvar
,
const
T
*
__restrict__
vals
,
const
int
n1
,
const
int
n2
,
const
U
epsilon
,
const
V
*
__restrict__
gamma
,
const
int
warp_size
)
{
cuApplyLayerNorm_
<
T
,
U
,
V
>
(
output_vals
,
NULL
,
invvar
,
vals
,
n1
,
n2
,
epsilon
,
gamma
,
NULL
,
warp_size
,
true
);
}
template
<
typename
T
,
typename
U
,
typename
V
>
__device__
void
cuLoadWriteStridedInputs
(
const
int
i1_block
,
...
...
@@ -882,6 +895,7 @@ void HostApplyLayerNorm(
output
,
mean
,
invvar
,
input
,
n1
,
n2
,
U
(
epsilon
),
gamma
,
beta
,
warp_size
);
}
// TODO: Optimize HostRMSNormGradient for AMD GPUs: https://github.com/ROCmSoftwarePlatform/apex/pull/66/files
template
<
typename
T
,
typename
U
,
typename
V
=
T
>
void
HostApplyRMSNorm
(
V
*
output
,
...
...
@@ -893,6 +907,7 @@ void HostApplyRMSNorm(
const
V
*
gamma
)
{
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
const
int
warp_size
=
at
::
cuda
::
getCurrentDeviceProperties
()
->
warpSize
;
const
dim3
threads
(
32
,
4
,
1
);
const
uint64_t
maxGridY
=
at
::
cuda
::
getCurrentDeviceProperties
()
->
maxGridSize
[
1
];
const
dim3
blocks
(
1
,
std
::
min
((
uint64_t
)
n1
,
maxGridY
),
1
);
...
...
@@ -901,7 +916,7 @@ void HostApplyRMSNorm(
threads
.
y
*
sizeof
(
U
)
+
(
threads
.
y
/
2
)
*
sizeof
(
U
)
:
0
;
cuApplyRMSNorm
<<<
blocks
,
threads
,
nshared
,
stream
>>>
(
output
,
invvar
,
input
,
n1
,
n2
,
U
(
epsilon
),
gamma
);
output
,
invvar
,
input
,
n1
,
n2
,
U
(
epsilon
),
gamma
,
warp_size
);
}
void
cuda_layer_norm
(
...
...
@@ -1200,3 +1215,4 @@ void cuda_rms_norm_gradient(
gamma
!=
NULL
?
grad_gamma
->
DATA_PTR
<
scalar_t_out
>
()
:
NULL
);
)
}
tests/L0/run_fused_layer_norm/test_fused_layer_norm.py
View file @
d755f1f1
import
unittest
import
os
import
random
import
itertools
import
torch
import
apex
from
torch.autograd
import
Variable
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment