Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
28c5638d
Commit
28c5638d
authored
Apr 15, 2022
by
hubertlu-tw
Browse files
Optimize HostRMSNormGradient and HostApplyRMSNorm for AMD GPUs
parent
d755f1f1
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
6 deletions
+10
-6
csrc/layer_norm_cuda_kernel.cu
csrc/layer_norm_cuda_kernel.cu
+10
-6
No files found.
csrc/layer_norm_cuda_kernel.cu
View file @
28c5638d
...
...
@@ -908,9 +908,13 @@ void HostApplyRMSNorm(
{
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
const
int
warp_size
=
at
::
cuda
::
getCurrentDeviceProperties
()
->
warpSize
;
const
dim3
threads
(
32
,
4
,
1
);
const
uint64_t
maxGridY
=
at
::
cuda
::
getCurrentDeviceProperties
()
->
maxGridSize
[
1
];
const
dim3
blocks
(
1
,
std
::
min
((
uint64_t
)
n1
,
maxGridY
),
1
);
dim3
threads
(
warp_size
,
4
,
1
);
#ifdef __HIP_PLATFORM_HCC__
// Optimization for ROCm MI100
threads
.
y
=
2
;
#endif
int
nshared
=
threads
.
y
>
1
?
threads
.
y
*
sizeof
(
U
)
+
(
threads
.
y
/
2
)
*
sizeof
(
U
)
:
...
...
@@ -1080,10 +1084,10 @@ void HostRMSNormGradient(
V
*
grad_gamma
)
{
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
const
int
warp_size
=
at
::
cuda
::
getCurrentDeviceProperties
()
->
warpSize
;
if
(
gamma
!=
NULL
)
{
const
int
part_size
=
16
;
const
dim3
threads2
(
32
,
4
,
1
);
const
int
part_size
=
warp_size
;
const
dim3
threads2
(
warp_size
,
4
,
1
);
const
dim3
blocks2
((
n2
+
threads2
.
x
-
1
)
/
threads2
.
x
,
part_size
,
1
);
const
int
nshared2_a
=
2
*
sizeof
(
U
)
*
threads2
.
y
*
threads2
.
y
*
(
threads2
.
x
+
1
);
const
int
nshared2_b
=
threads2
.
x
*
threads2
.
y
*
sizeof
(
U
);
...
...
@@ -1106,7 +1110,7 @@ void HostRMSNormGradient(
part_grad_gamma
.
DATA_PTR
<
U
>
(),
/* unused */
true
);
const
dim3
threads3
(
32
,
8
,
1
);
const
dim3
threads3
(
warp_size
,
8
,
1
);
const
dim3
blocks3
((
n2
+
threads2
.
x
-
1
)
/
threads2
.
x
,
1
,
1
);
const
int
nshared3
=
threads3
.
x
*
threads3
.
y
*
sizeof
(
U
);
cuComputeGradGammaBeta
<<<
blocks3
,
threads3
,
nshared3
,
stream
>>>
(
...
...
@@ -1122,7 +1126,7 @@ void HostRMSNormGradient(
// compute grad_input
const
uint64_t
maxGridY
=
at
::
cuda
::
getCurrentDeviceProperties
()
->
maxGridSize
[
1
];
const
dim3
blocks1
(
1
,
std
::
min
((
uint64_t
)
n1
,
maxGridY
),
1
);
const
dim3
threads1
(
32
,
4
,
1
);
const
dim3
threads1
(
warp_size
,
4
,
1
);
int
nshared
=
threads1
.
y
>
1
?
threads1
.
y
*
threads1
.
x
*
sizeof
(
U
)
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment