Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
d1d00b3e
Unverified
Commit
d1d00b3e
authored
Mar 16, 2023
by
Kirthi Shankar Sivamani
Committed by
GitHub
Mar 16, 2023
Browse files
Relax dimension checks for fp8 exec (#106)
Signed-off-by:
Kirthi Shankar Sivamani
<
ksivamani@nvidia.com
>
parent
44d64abc
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
10 deletions
+12
-10
transformer_engine/pytorch/module.py
transformer_engine/pytorch/module.py
+7
-7
transformer_engine/pytorch/utils.py
transformer_engine/pytorch/utils.py
+5
-3
No files found.
transformer_engine/pytorch/module.py
View file @
d1d00b3e
...
...
@@ -51,7 +51,7 @@ from .utils import (
divide
,
get_default_init_method
,
cast_if_needed
,
check_
modulo_16
,
check_
dim_for_fp8_forward_exec
,
)
from
.distributed
import
(
set_tensor_model_parallel_attributes
,
...
...
@@ -666,8 +666,8 @@ class _LayerNormLinear(torch.autograd.Function):
assert
inp
.
shape
[
-
1
]
==
in_features
,
"GEMM not possible"
inputmat
=
inp
.
view
((
-
1
,
in_features
))
assert
(
not
fp8
or
check_
modulo_16
(
inputmat
,
weight
)
),
"Input
s
and weight
s must be divisible by 16
for FP8 execution."
not
fp8
or
check_
dim_for_fp8_forward_exec
(
inputmat
,
weight
)
),
"Input and weight
dimensions are not compatible
for FP8 execution."
update_fp8_weights
=
is_first_microbatch
is
None
or
is_first_microbatch
...
...
@@ -1396,8 +1396,8 @@ class _Linear(torch.autograd.Function):
assert
inp
.
shape
[
-
1
]
==
in_features
,
"GEMM not possible"
inputmat
=
inp
.
view
((
-
1
,
in_features
))
assert
(
not
fp8
or
check_
modulo_16
(
inputmat
,
weight
)
),
"Input
s
and weight
s must be divisible by 16
for FP8 execution."
not
fp8
or
check_
dim_for_fp8_forward_exec
(
inputmat
,
weight
)
),
"Input and weight
dimensions are not compatible
for FP8 execution."
update_fp8_weights
=
is_first_microbatch
is
None
or
is_first_microbatch
...
...
@@ -2012,8 +2012,8 @@ class _LayerNormMLP(torch.autograd.Function):
assert
inp
.
shape
[
-
1
]
==
in_features
,
"GEMM not possible"
inputmat
=
inp
.
view
((
-
1
,
in_features
))
assert
(
not
fp8
or
check_
modulo_16
(
inputmat
,
fc1_weight
,
fc2_weight
)
),
"Input
s
and weight
s must be divisible by 16
for FP8 execution."
not
fp8
or
check_
dim_for_fp8_forward_exec
(
inputmat
,
fc1_weight
,
fc2_weight
)
),
"Input and weight
dimensions are not compatible
for FP8 execution."
update_fp8_weights
=
is_first_microbatch
is
None
or
is_first_microbatch
...
...
transformer_engine/pytorch/utils.py
View file @
d1d00b3e
...
...
@@ -179,6 +179,8 @@ def cast_if_needed(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
return
tensor
if
tensor
is
None
or
tensor
.
dtype
==
dtype
else
tensor
.
to
(
dtype
)
def
check_modulo_16
(
*
tensors
:
Tuple
[
torch
.
Tensor
,
...])
->
bool
:
"""Check if each dimension of given tensors is divisible by 16."""
return
all
(
all
(
n
%
16
==
0
for
n
in
t
.
shape
)
for
t
in
tensors
)
def
check_dim_for_fp8_forward_exec
(
*
tensors
:
Tuple
[
torch
.
Tensor
,
...])
->
bool
:
"""For fp8 fprop (TN layout), inputs and weights must be such
that dim0 is divisible by 8 and dim1 is divisible by 16.
"""
return
all
(
not
t
.
shape
[
0
]
%
8
and
not
t
.
shape
[
1
]
%
16
for
t
in
tensors
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment