Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
8df1b6b8
Commit
8df1b6b8
authored
Apr 15, 2022
by
hubertlu-tw
Browse files
Fix NaN issues in FusedRMSNorm
parent
28c5638d
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
7 deletions
+16
-7
csrc/layer_norm_cuda_kernel.cu
csrc/layer_norm_cuda_kernel.cu
+16
-7
No files found.
csrc/layer_norm_cuda_kernel.cu
View file @
8df1b6b8
...
...
@@ -741,8 +741,12 @@ void cuComputeGradInput(
const
U
gamma_idx
=
static_cast
<
U
>
((
idx
<
n2
)
?
gamma
[
idx
]
:
V
(
0
));
const
U
c_h
=
static_cast
<
U
>
((
idx
<
n2
)
?
k_input
[
idx
]
:
T
(
0
));
const
U
c_loss
=
static_cast
<
U
>
((
idx
<
n2
)
?
k_dout
[
idx
]
:
V
(
0
));
if
(
!
rms_only
)
{
sum_loss1
+=
c_loss
*
gamma_idx
;
sum_loss2
+=
c_loss
*
gamma_idx
*
(
c_h
-
c_mean
)
*
c_invvar
;
}
else
{
sum_loss2
+=
c_loss
*
gamma_idx
*
(
c_h
)
*
c_invvar
;
}
}
#endif
}
else
{
...
...
@@ -775,8 +779,12 @@ void cuComputeGradInput(
int
idx
=
l
+
thrx
;
const
U
c_h
=
static_cast
<
U
>
((
idx
<
n2
)
?
k_input
[
idx
]
:
T
(
0
));
const
U
c_loss
=
static_cast
<
U
>
((
idx
<
n2
)
?
k_dout
[
idx
]
:
V
(
0
));
if
(
!
rms_only
)
{
sum_loss1
+=
c_loss
;
sum_loss2
+=
c_loss
*
(
c_h
-
c_mean
)
*
c_invvar
;
}
else
{
sum_loss2
+=
c_loss
*
(
c_h
)
*
c_invvar
;
}
}
#endif
}
...
...
@@ -895,7 +903,7 @@ void HostApplyLayerNorm(
output
,
mean
,
invvar
,
input
,
n1
,
n2
,
U
(
epsilon
),
gamma
,
beta
,
warp_size
);
}
//
TODO:
Optimize HostRMSNormGradient for AMD GPUs: https://github.com/ROCmSoftwarePlatform/apex/pull/66/files
// Optimize HostRMSNormGradient for AMD GPUs: https://github.com/ROCmSoftwarePlatform/apex/pull/66/files
template
<
typename
T
,
typename
U
,
typename
V
=
T
>
void
HostApplyRMSNorm
(
V
*
output
,
...
...
@@ -1070,7 +1078,7 @@ void HostLayerNormGradient(
grad_input
,
false
);
}
//
TODO:
Optimize HostRMSNormGradient for AMD GPUs: https://github.com/ROCmSoftwarePlatform/apex/pull/66/files
// Optimize HostRMSNormGradient for AMD GPUs: https://github.com/ROCmSoftwarePlatform/apex/pull/66/files
template
<
typename
T
,
typename
U
=
float
,
typename
V
=
T
>
void
HostRMSNormGradient
(
const
V
*
dout
,
...
...
@@ -1220,3 +1228,4 @@ void cuda_rms_norm_gradient(
)
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment