Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
8df1b6b8
"vscode:/vscode.git/clone" did not exist on "cfb3b75d63bf097fec2efb0835d7bf62a0bd3492"
Commit
8df1b6b8
authored
Apr 15, 2022
by
hubertlu-tw
Browse files
Fix NaN issues in FusedRMSNorm
parent
28c5638d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
7 deletions
+16
-7
csrc/layer_norm_cuda_kernel.cu
csrc/layer_norm_cuda_kernel.cu
+16
-7
No files found.
csrc/layer_norm_cuda_kernel.cu
View file @
8df1b6b8
...
@@ -712,7 +712,7 @@ void cuComputeGradInput(
...
@@ -712,7 +712,7 @@ void cuComputeGradInput(
#ifndef __HIP_PLATFORM_HCC__
#ifndef __HIP_PLATFORM_HCC__
int
l
=
4
*
thrx
;
int
l
=
4
*
thrx
;
for
(;
l
+
3
<
n2
;
l
+=
4
*
numx
)
{
for
(;
l
+
3
<
n2
;
l
+=
4
*
numx
)
{
for
(
int
k
=
0
;
k
<
4
;
++
k
)
{
for
(
int
k
=
0
;
k
<
4
;
++
k
)
{
const
U
c_h
=
static_cast
<
U
>
(
k_input
[
l
+
k
]);
const
U
c_h
=
static_cast
<
U
>
(
k_input
[
l
+
k
]);
const
U
c_loss
=
static_cast
<
U
>
(
k_dout
[
l
+
k
]);
const
U
c_loss
=
static_cast
<
U
>
(
k_dout
[
l
+
k
]);
if
(
!
rms_only
)
{
if
(
!
rms_only
)
{
...
@@ -741,8 +741,12 @@ void cuComputeGradInput(
...
@@ -741,8 +741,12 @@ void cuComputeGradInput(
const
U
gamma_idx
=
static_cast
<
U
>
((
idx
<
n2
)
?
gamma
[
idx
]
:
V
(
0
));
const
U
gamma_idx
=
static_cast
<
U
>
((
idx
<
n2
)
?
gamma
[
idx
]
:
V
(
0
));
const
U
c_h
=
static_cast
<
U
>
((
idx
<
n2
)
?
k_input
[
idx
]
:
T
(
0
));
const
U
c_h
=
static_cast
<
U
>
((
idx
<
n2
)
?
k_input
[
idx
]
:
T
(
0
));
const
U
c_loss
=
static_cast
<
U
>
((
idx
<
n2
)
?
k_dout
[
idx
]
:
V
(
0
));
const
U
c_loss
=
static_cast
<
U
>
((
idx
<
n2
)
?
k_dout
[
idx
]
:
V
(
0
));
sum_loss1
+=
c_loss
*
gamma_idx
;
if
(
!
rms_only
)
{
sum_loss2
+=
c_loss
*
gamma_idx
*
(
c_h
-
c_mean
)
*
c_invvar
;
sum_loss1
+=
c_loss
*
gamma_idx
;
sum_loss2
+=
c_loss
*
gamma_idx
*
(
c_h
-
c_mean
)
*
c_invvar
;
}
else
{
sum_loss2
+=
c_loss
*
gamma_idx
*
(
c_h
)
*
c_invvar
;
}
}
}
#endif
#endif
}
else
{
}
else
{
...
@@ -775,8 +779,12 @@ void cuComputeGradInput(
...
@@ -775,8 +779,12 @@ void cuComputeGradInput(
int
idx
=
l
+
thrx
;
int
idx
=
l
+
thrx
;
const
U
c_h
=
static_cast
<
U
>
((
idx
<
n2
)
?
k_input
[
idx
]
:
T
(
0
));
const
U
c_h
=
static_cast
<
U
>
((
idx
<
n2
)
?
k_input
[
idx
]
:
T
(
0
));
const
U
c_loss
=
static_cast
<
U
>
((
idx
<
n2
)
?
k_dout
[
idx
]
:
V
(
0
));
const
U
c_loss
=
static_cast
<
U
>
((
idx
<
n2
)
?
k_dout
[
idx
]
:
V
(
0
));
sum_loss1
+=
c_loss
;
if
(
!
rms_only
)
{
sum_loss2
+=
c_loss
*
(
c_h
-
c_mean
)
*
c_invvar
;
sum_loss1
+=
c_loss
;
sum_loss2
+=
c_loss
*
(
c_h
-
c_mean
)
*
c_invvar
;
}
else
{
sum_loss2
+=
c_loss
*
(
c_h
)
*
c_invvar
;
}
}
}
#endif
#endif
}
}
...
@@ -895,7 +903,7 @@ void HostApplyLayerNorm(
...
@@ -895,7 +903,7 @@ void HostApplyLayerNorm(
output
,
mean
,
invvar
,
input
,
n1
,
n2
,
U
(
epsilon
),
gamma
,
beta
,
warp_size
);
output
,
mean
,
invvar
,
input
,
n1
,
n2
,
U
(
epsilon
),
gamma
,
beta
,
warp_size
);
}
}
//
TODO:
Optimize HostRMSNormGradient for AMD GPUs: https://github.com/ROCmSoftwarePlatform/apex/pull/66/files
// Optimize HostRMSNormGradient for AMD GPUs: https://github.com/ROCmSoftwarePlatform/apex/pull/66/files
template
<
typename
T
,
typename
U
,
typename
V
=
T
>
template
<
typename
T
,
typename
U
,
typename
V
=
T
>
void
HostApplyRMSNorm
(
void
HostApplyRMSNorm
(
V
*
output
,
V
*
output
,
...
@@ -1070,7 +1078,7 @@ void HostLayerNormGradient(
...
@@ -1070,7 +1078,7 @@ void HostLayerNormGradient(
grad_input
,
grad_input
,
false
);
false
);
}
}
//
TODO:
Optimize HostRMSNormGradient for AMD GPUs: https://github.com/ROCmSoftwarePlatform/apex/pull/66/files
// Optimize HostRMSNormGradient for AMD GPUs: https://github.com/ROCmSoftwarePlatform/apex/pull/66/files
template
<
typename
T
,
typename
U
=
float
,
typename
V
=
T
>
template
<
typename
T
,
typename
U
=
float
,
typename
V
=
T
>
void
HostRMSNormGradient
(
void
HostRMSNormGradient
(
const
V
*
dout
,
const
V
*
dout
,
...
@@ -1220,3 +1228,4 @@ void cuda_rms_norm_gradient(
...
@@ -1220,3 +1228,4 @@ void cuda_rms_norm_gradient(
)
)
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment