Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
905d94f4
Unverified
Commit
905d94f4
authored
Jun 13, 2024
by
Tim Moon
Committed by
GitHub
Jun 13, 2024
Browse files
Use unoptimized RMSNorm kernel if pointers are not aligned (#886)
Signed-off-by:
Tim Moon
<
tmoon@nvidia.com
>
parent
e706e5fa
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
48 additions
and
17 deletions
+48
-17
transformer_engine/common/rmsnorm/rmsnorm_api.cpp
transformer_engine/common/rmsnorm/rmsnorm_api.cpp
+48
-17
No files found.
transformer_engine/common/rmsnorm/rmsnorm_api.cpp
View file @
905d94f4
...
...
@@ -4,11 +4,14 @@
* See LICENSE for license information.
************************************************************************/
#include "transformer_engine/rmsnorm.h"
#include <cstdint>
#include <numeric>
#include <vector>
#include "../common.h"
#include "rmsnorm.h"
#include "
transformer_engine/rmsnorm
.h"
#include "
../common
.h"
/*
...
...
@@ -46,11 +49,23 @@ BwdTunedRegistry BWD_TUNED_FUNCS;
FwdGeneralRegistry
FWD_GENERAL_FUNCS
;
BwdGeneralRegistry
BWD_GENERAL_FUNCS
;
FwdFunction
&
get_fwd_launcher
(
DType
wtype
,
DType
itype
,
DType
otype
,
DType
ctype
,
uint32_t
hidden_size
,
uint32_t
batch_size
)
{
FwdFunction
&
get_fwd_launcher
(
DType
wtype
,
DType
itype
,
DType
otype
,
DType
ctype
,
const
layer_norm
::
FwdParams
&
params
)
{
// Look for tuned kernel
auto
tuned_key
=
layer_norm
::
get_key
(
wtype
,
itype
,
otype
,
ctype
,
hidden_size
);
if
(
batch_size
%
4
==
0
&&
FWD_TUNED_FUNCS
.
count
(
tuned_key
)
>
0
)
{
auto
tuned_key
=
layer_norm
::
get_key
(
wtype
,
itype
,
otype
,
ctype
,
params
.
cols
);
auto
is_aligned
=
[](
const
void
*
ptr
)
->
bool
{
// Assume vectorized memory accesses are <=16B
return
reinterpret_cast
<
uintptr_t
>
(
ptr
)
%
16
==
0
;
};
if
(
params
.
rows
%
4
==
0
&&
is_aligned
(
params
.
x
)
&&
is_aligned
(
params
.
rs
)
&&
is_aligned
(
params
.
gamma
)
&&
is_aligned
(
params
.
z
)
&&
FWD_TUNED_FUNCS
.
count
(
tuned_key
)
>
0
)
{
return
FWD_TUNED_FUNCS
.
at
(
tuned_key
);
}
...
...
@@ -60,7 +75,7 @@ FwdFunction &get_fwd_launcher(DType wtype, DType itype, DType otype, DType ctype
NVTE_ERROR
(
"FWD: Unsupported types."
);
}
auto
&
general_func_map
=
FWD_GENERAL_FUNCS
.
at
(
general_key
);
auto
func_iter
=
general_func_map
.
lower_bound
(
hidden_size
);
auto
func_iter
=
general_func_map
.
lower_bound
(
params
.
cols
);
if
(
func_iter
==
general_func_map
.
end
())
{
// Hidden size is too big, need to use multi-CTA
return
general_func_map
.
rbegin
()
->
second
;
...
...
@@ -71,11 +86,26 @@ FwdFunction &get_fwd_launcher(DType wtype, DType itype, DType otype, DType ctype
////////////////////////////////////////////////////////////////////////////////////////////////////
BwdFunction
&
get_bwd_launcher
(
DType
wtype
,
DType
itype
,
DType
otype
,
DType
ctype
,
uint32_t
hidden_size
,
uint32_t
batch_size
)
{
BwdFunction
&
get_bwd_launcher
(
DType
wtype
,
DType
itype
,
DType
otype
,
DType
ctype
,
const
layer_norm
::
BwdParams
&
params
)
{
// Look for tuned kernel
auto
tuned_key
=
layer_norm
::
get_key
(
wtype
,
itype
,
otype
,
ctype
,
hidden_size
);
if
(
batch_size
%
4
==
0
&&
BWD_TUNED_FUNCS
.
count
(
tuned_key
)
>
0
)
{
auto
tuned_key
=
layer_norm
::
get_key
(
wtype
,
itype
,
otype
,
ctype
,
params
.
cols
);
auto
is_aligned
=
[](
const
void
*
ptr
)
->
bool
{
// Assume vectorized memory accesses are <=16B
return
reinterpret_cast
<
uintptr_t
>
(
ptr
)
%
16
==
0
;
};
if
(
params
.
rows
%
4
==
0
&&
is_aligned
(
params
.
x
)
&&
is_aligned
(
params
.
rs
)
&&
is_aligned
(
params
.
gamma
)
&&
is_aligned
(
params
.
dz
)
&&
is_aligned
(
params
.
dx
)
&&
is_aligned
(
params
.
dgamma
)
&&
is_aligned
(
params
.
dgamma_part
)
&&
layer_norm
::
BWD_TUNED_FUNCS
.
count
(
tuned_key
)
>
0
)
{
return
BWD_TUNED_FUNCS
.
at
(
tuned_key
);
}
...
...
@@ -85,7 +115,7 @@ BwdFunction &get_bwd_launcher(DType wtype, DType itype, DType otype, DType ctype
NVTE_ERROR
(
"BWD: Unsupported types."
);
}
auto
&
general_func_map
=
BWD_GENERAL_FUNCS
.
at
(
general_key
);
auto
func_iter
=
general_func_map
.
lower_bound
(
hidden_size
);
auto
func_iter
=
general_func_map
.
lower_bound
(
params
.
cols
);
if
(
func_iter
==
general_func_map
.
end
())
{
// Hidden size is too big, need to use multi-CTA
return
general_func_map
.
rbegin
()
->
second
;
...
...
@@ -132,9 +162,6 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
launch_params
.
multiprocessorCount
=
multiprocessorCount
;
launch_params
.
stream
=
stream
;
// Request the kernel launcher.
auto
launcher
=
rmsnorm
::
get_fwd_launcher
(
wtype
,
itype
,
otype
,
ctype
,
hidden_size
,
rows
);
// Set the kernel runtime parameters.
rmsnorm
::
FwdParams
&
params
=
launch_params
.
params
;
params
.
rows
=
rows
;
...
...
@@ -151,6 +178,9 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
params
.
fp8_out
=
fp8_out
;
params
.
zero_centered_gamma
=
zero_centered_gamma
;
// Request the kernel launcher.
auto
launcher
=
rmsnorm
::
get_fwd_launcher
(
wtype
,
itype
,
otype
,
ctype
,
params
);
// Query the kernel-specific launch parameters.
launcher
(
launch_params
,
true
);
if
(
launch_params
.
workspace_bytes
==
0
)
{
...
...
@@ -242,8 +272,6 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
launch_params
.
stream
=
stream
;
launch_params
.
multiprocessorCount
=
multiprocessorCount
;
auto
launcher
=
rmsnorm
::
get_bwd_launcher
(
wtype
,
itype
,
otype
,
ctype
,
hidden_size
,
rows
);
// Set the kernel runtime parameters.
rmsnorm
::
BwdParams
&
params
=
launch_params
.
params
;
params
.
rows
=
rows
;
...
...
@@ -260,6 +288,9 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
params
.
dgamma_part
=
dgamma_part
->
data
.
dptr
;
params
.
zero_centered_gamma
=
zero_centered_gamma
;
// Request the kernel launcher.
auto
launcher
=
rmsnorm
::
get_bwd_launcher
(
wtype
,
itype
,
otype
,
ctype
,
params
);
// Query the kernel-specific launch parameters.
launcher
(
launch_params
,
true
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment