OpenDAS / TransformerEngine · Commits · 06eebf66

Unverified commit 06eebf66, authored Sep 26, 2023 by Kirthi Shankar Sivamani,
committed via GitHub on Sep 26, 2023.

[PyTorch] Mcore DDP support (#446)

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

Parent: 76669cdd
Changes: 5 changed files, with 25 additions and 29 deletions (+25 -29)
tests/pytorch/test_sanity.py                            +1   -3
transformer_engine/pytorch/cpp_extensions/gemm.py       +3  -14
transformer_engine/pytorch/module/layernorm_linear.py   +5   -3
transformer_engine/pytorch/module/layernorm_mlp.py      +10  -6
transformer_engine/pytorch/module/linear.py             +6   -3
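What the commit does: Megatron-core (mcore) ships a custom DistributedDataParallel that owns a preallocated main_grad buffer per parameter, and it needs to know whether a module has already accumulated the weight gradient into that buffer. After this commit, TE's autograd functions record that fact on the parameter itself via grad_added_to_main_grad. A minimal sketch of the contract in plain PyTorch (illustrative names and shapes; not TE or mcore source):

    import torch

    weight = torch.nn.Parameter(torch.randn(4, 4))
    weight.main_grad = torch.zeros_like(weight)  # buffer a DDP wrapper would own

    def backward_wgrad(grad_output, inp, fuse_wgrad_accumulation):
        wgrad = grad_output.t() @ inp            # stand-in for the wgrad GEMM
        if fuse_wgrad_accumulation:
            weight.main_grad += wgrad            # fused: accumulate in place
        # the handshake this commit adds:
        weight.grad_added_to_main_grad = fuse_wgrad_accumulation
        return None if fuse_wgrad_accumulation else wgrad

    backward_wgrad(torch.randn(8, 4), torch.randn(8, 4), fuse_wgrad_accumulation=True)
    assert weight.grad_added_to_main_grad
    assert torch.count_nonzero(weight.main_grad) > 0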
tests/pytorch/test_sanity.py

@@ -226,9 +226,7 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, bs, dtype, config, fp8_
         if "layer_norm_weight" in name:
             continue
         elif "weight" in name and p.requires_grad:
-            assert (
-                p.grad is None and torch.count_nonzero(p.main_grad) > 0
-            ), "Gradient not accumulated."
+            assert torch.count_nonzero(p.main_grad) > 0, "Gradient not accumulated."


 def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad):
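The p.grad is None clause is gone: with the mcore handshake introduced below, .grad is no longer guaranteed to stay empty on the fused path, so the test now asserts only that the fused buffer received the gradient. A standalone re-creation of the tightened check (hypothetical helper mirroring the test body):

    import torch

    def check_fused_wgrad(block):
        for name, p in block.named_parameters():
            if "layer_norm_weight" in name:
                continue
            if "weight" in name and p.requires_grad:
                assert torch.count_nonzero(p.main_grad) > 0, "Gradient not accumulated."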
transformer_engine/pytorch/cpp_extensions/gemm.py

@@ -45,7 +45,6 @@ def fp8_gemm(
     assert_dim_for_fp8_exec(A)
     assert_dim_for_fp8_exec(B)

-    return_output = False
     if out is None:
         out = torch.empty(
             B.shape[0],
@@ -53,7 +52,7 @@ def fp8_gemm(
             dtype=out_dtype,
             device="cuda",
         )
-        return_output = True

     # Use bfloat16 as default bias_dtype
     bias_dtype = torch.bfloat16 if bias is None else bias.dtype
     if gelu:
@@ -110,13 +109,7 @@ def fp8_gemm(
         args = tuple(args + (True, extra_output_tensor,))
     _ = fn(*args)

-    if return_output:
-        if gelu:
-            return out, gelu_input
-        return out
-    if gelu:
-        return gelu_input
-    return None
+    return out, gelu_input


 def gemm(
@@ -144,7 +137,6 @@ def gemm(
     empty_tensor = torch.Tensor()
     fp8_index = -1  # dummy index

-    return_output = False
     if out is None:
         out = torch.empty(
             B.shape[1] if transb else B.shape[0],
@@ -152,7 +144,6 @@ def gemm(
             dtype=dtype,
             device="cuda",
         )
-        return_output = True

     if gelu and not grad:
         gelu_input = torch.empty_like(out, dtype=dtype)
@@ -222,6 +213,4 @@ def gemm(
         args = tuple(args + (False, extra_output_tensor,))
     _ = fn(*args)

-    if return_output:
-        return out, grad_bias, gelu_input
-    return None, grad_bias, gelu_input
+    return out, grad_bias, gelu_input
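Net effect in this file: fp8_gemm previously returned out, gelu_input, or None depending on whether the caller passed out= and requested GELU fusion; it now always returns the pair (out, gelu_input), and gemm likewise always returns (out, grad_bias, gelu_input). Call sites unpack a fixed arity instead of branching on return_output. A runnable stand-in for the new convention (plain PyTorch, not the cuBLASLt-backed implementation):

    import torch

    def fp8_gemm_sketch(A, B, out=None, gelu=False):
        # TN layout as above: the output has shape (B.shape[0], A.shape[0])
        if out is None:
            out = torch.empty(B.shape[0], A.shape[0])
        gelu_input = torch.empty_like(out) if gelu else torch.Tensor()
        torch.matmul(B, A.t(), out=out)   # stand-in for the fused kernel
        return out, gelu_input            # fixed arity, never None

    out, _ = fp8_gemm_sketch(torch.randn(3, 4), torch.randn(5, 4))  # (5, 3) result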
transformer_engine/pytorch/module/layernorm_linear.py

@@ -173,7 +173,7 @@ class _LayerNormLinear(torch.autograd.Function):
                     tex.FP8FwdTensors.GEMM1_WEIGHT,
                     fp8_dtype_forward)
-            out = tex.fp8_gemm(
+            out, _ = tex.fp8_gemm(
                 weight_fp8,
                 fp8_meta["scaling_fwd"].scale_inv,
                 tex.FP8FwdTensors.GEMM1_WEIGHT,
@@ -389,7 +389,7 @@ class _LayerNormLinear(torch.autograd.Function):
             # WGRAD
             if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                 ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
-                wgrad = tex.fp8_gemm(
+                wgrad, _ = tex.fp8_gemm(
                     ln_out_total_t,
                     fwd_scale_inverses,
                     tex.FP8FwdTensors.GEMM1_INPUT,
@@ -444,7 +444,6 @@ class _LayerNormLinear(torch.autograd.Function):
                 ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None
             )
             if ctx.ub_bulk_wgrad:
                 dgrad = ub_obj_dgrad.get_ubuf_output(0)  # Reduce-scatter output
-
         # Column Parallel Linear
@@ -474,6 +473,9 @@ class _LayerNormLinear(torch.autograd.Function):
         if not ctx.use_bias:
             grad_bias = None

+        # Handle custom DDP from mcore.
+        weight.grad_added_to_main_grad = ctx.fuse_wgrad_accumulation
+
         return (
             dxmat.view(ctx.inp_shape) if ctx.requires_dgrad else None,
             dgamma,
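The backward now records, unconditionally, whether the weight gradient went into main_grad. A hedged sketch of how an mcore-style DDP wrapper could consume the flag when flushing gradients (illustrative only; mcore's actual reduce path differs):

    import torch

    def flush_grads_to_main_grad(params):
        for p in params:
            if getattr(p, "grad_added_to_main_grad", False):
                continue                # wgrad already accumulated into p.main_grad
            if p.grad is not None:
                p.main_grad += p.grad   # slow path: copy the autograd gradient
                p.grad = None
        # the p.main_grad buffers would then be all-reduced across DP ranks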
transformer_engine/pytorch/module/layernorm_mlp.py

@@ -223,7 +223,7 @@ class _LayerNormMLP(torch.autograd.Function):
                 fp8_dtype_forward,
             )
-            fc1_out = tex.fp8_gemm(
+            fc1_out, _ = tex.fp8_gemm(
                 fc1_weight_fp8,
                 fp8_meta["scaling_fwd"].scale_inv,
                 tex.FP8FwdTensors.GEMM1_WEIGHT,
@@ -344,7 +344,7 @@ class _LayerNormMLP(torch.autograd.Function):
             dim_size = list(gelu_out.size())
             dim_size[1] = fc2_weight.size(0)
             fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)
-            _, _, _ = tex.gemm(
+            _ = tex.gemm(
                 fc2_weight,
                 gelu_out,
                 activation_dtype,
@@ -498,7 +498,7 @@ class _LayerNormMLP(torch.autograd.Function):
             )
             # FC2 DGRAD; Unconditional
-            fc2_dgrad = tex.fp8_gemm(
+            fc2_dgrad, _ = tex.fp8_gemm(
                 fc2_weight_t_fp8,
                 fwd_scale_inverses,
                 tex.FP8FwdTensors.GEMM2_WEIGHT,
@@ -519,7 +519,7 @@ class _LayerNormMLP(torch.autograd.Function):
             if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                 if fc2_weight.requires_grad:
                     gelu_out_t = tex.fp8_transpose(gelu_out, fp8_dtype_forward)
-                    fc2_wgrad = tex.fp8_gemm(
+                    fc2_wgrad, _ = tex.fp8_gemm(
                         gelu_out_t,
                         fwd_scale_inverses,
                         tex.FP8FwdTensors.GEMM2_INPUT,
@@ -675,7 +675,7 @@ class _LayerNormMLP(torch.autograd.Function):
                 fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device
             )
             # FC1 DGRAD: Unconditional
-            _, _, _ = tex.gemm(
+            _ = tex.gemm(
                 fc1_weight,
                 dgelu,
                 ctx.activation_dtype,
@@ -705,7 +705,7 @@ class _LayerNormMLP(torch.autograd.Function):
             # FC1 WGRAD
             if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                 ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
-                fc1_wgrad = tex.fp8_gemm(
+                fc1_wgrad, _ = tex.fp8_gemm(
                     ln_out_total_t,
                     fwd_scale_inverses,
                     tex.FP8FwdTensors.GEMM1_INPUT,
@@ -794,6 +794,10 @@ class _LayerNormMLP(torch.autograd.Function):
             )
             dbeta = None

+        # Handle custom DDP from mcore.
+        fc1_weight.grad_added_to_main_grad = ctx.fuse_wgrad_accumulation
+        fc2_weight.grad_added_to_main_grad = ctx.fuse_wgrad_accumulation
+
         return (
             dxmat.view(ctx.inp_shape) if ctx.requires_dgrad else None,
             dgamma,
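Both GEMM weights are flagged here, while the layer-norm parameters keep ordinary .grad handling. A hypothetical usage sketch, assuming TE at this commit and a CUDA device (te.LayerNormMLP and fuse_wgrad_accumulation are existing TE names; attaching main_grad by hand stands in for what mcore DDP would do):

    import torch
    import transformer_engine.pytorch as te

    mlp = te.LayerNormMLP(256, 1024, fuse_wgrad_accumulation=True).cuda()
    for name, p in mlp.named_parameters():
        if "weight" in name and "layer_norm" not in name:
            p.main_grad = torch.zeros_like(p)  # buffer mcore DDP would attach

    mlp(torch.randn(8, 256, device="cuda")).sum().backward()
    assert mlp.fc1_weight.grad_added_to_main_grad
    assert mlp.fc2_weight.grad_added_to_main_grad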
transformer_engine/pytorch/module/linear.py

@@ -211,7 +211,7 @@ class _Linear(torch.autograd.Function):
             dim_size[1] = weight.size(0)
             out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
-            _, _, _ = gemm(
+            _ = gemm(
                 weight,
                 inputmat_total,
                 activation_dtype,
@@ -325,7 +325,7 @@ class _Linear(torch.autograd.Function):
         if ctx.requires_dgrad:
             if ctx.fp8:
-                dgrad = fp8_gemm(
+                dgrad, _ = fp8_gemm(
                     weight_t_fp8,
                     fwd_scale_inverses,
                     tex.FP8FwdTensors.GEMM1_WEIGHT,
@@ -368,7 +368,7 @@ class _Linear(torch.autograd.Function):
             if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                 if ctx.ub_split_ag:
                     grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
-                wgrad = fp8_gemm(
+                wgrad, _ = fp8_gemm(
                     inputmat_t_total,
                     fwd_scale_inverses,
                     tex.FP8FwdTensors.GEMM1_INPUT,
@@ -415,6 +415,9 @@ class _Linear(torch.autograd.Function):
         if not ctx.use_bias:
             grad_bias = None

+        # Handle custom DDP from mcore.
+        weight.grad_added_to_main_grad = ctx.fuse_wgrad_accumulation
+
         return (
             wgrad if weight.requires_grad else None,
             None,
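Linear gets the same handshake. Downstream, whatever steps the parameters has to read main_grad on the fused path; a minimal illustrative optimizer step (not mcore code):

    import torch

    def sgd_step(params, lr=1e-3):
        with torch.no_grad():
            for p in params:
                fused = getattr(p, "grad_added_to_main_grad", False)
                g = p.main_grad if fused else p.grad
                if g is not None:
                    p -= lr * g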