OpenDAS / TransformerEngine · Commits

Commit 715c3bb8 (unverified)
feat: Add support for multiple quantization modes in the UB communicators (#2043)

Authored Aug 29, 2025 by Daniel Stokes; committed by GitHub on Aug 29, 2025.
Parent: f98e3053

Showing 14 changed files with 216 additions and 85 deletions (+216 −85)
docs/api/pytorch.rst  (+4 −1)
examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py  (+3 −1)
tests/pytorch/distributed/run_layer_with_overlap.py  (+65 −12)
tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py  (+7 −1)
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp  (+1 −1)
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h  (+1 −1)
transformer_engine/pytorch/__init__.py  (+1 −0)
transformer_engine/pytorch/module/__init__.py  (+1 −1)
transformer_engine/pytorch/module/base.py  (+87 −33)
transformer_engine/pytorch/module/layernorm_linear.py  (+13 −9)
transformer_engine/pytorch/module/layernorm_mlp.py  (+14 −10)
transformer_engine/pytorch/module/linear.py  (+13 −9)
transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py  (+5 −5)
transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py  (+1 −1)
docs/api/pytorch.rst  (view file @ 715c3bb8)

@@ -49,7 +49,7 @@ pyTorch
 .. autoapifunction:: transformer_engine.pytorch.moe_permute
 .. autoapifunction:: transformer_engine.pytorch.moe_permute_with_probs
 .. autoapifunction:: transformer_engine.pytorch.moe_unpermute
 ...
@@ -62,3 +62,6 @@ pyTorch
 .. autoapifunction:: transformer_engine.pytorch.initialize_ub
 .. autoapifunction:: transformer_engine.pytorch.destroy_ub
+.. autoapiclass:: transformer_engine.pytorch.UserBufferQuantizationMode
+  :members: FP8, NONE
\ No newline at end of file
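
The class documented here pairs with the new `quantization_modes` argument on `initialize_ub` (see the base.py diff below). A minimal usage sketch; the buffer shape, tensor-parallel size, and surrounding distributed setup are illustrative placeholders, not from this commit:

import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch import UserBufferQuantizationMode

# Replaces the deprecated use_fp8=True flag with an explicit mode list.
te.initialize_ub(
    [4096, 1024],                                        # [sequence * batch, hidden]; illustrative
    8,                                                   # tensor-parallel size; illustrative
    quantization_modes=[UserBufferQuantizationMode.FP8],
    dtype=torch.bfloat16,
)
# ... run comm-gemm-overlap modules ...
te.destroy_ub()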
examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py  (view file @ 715c3bb8)

@@ -263,7 +263,9 @@ def _train(opts):
     te.module.base.initialize_ub(
         [batched_size, hidden_size],
         tp_size,
-        use_fp8=opts.fp8,
+        quantization_modes=[
+            UserBufferQuantizationMode.FP8 if opts.fp8 else UserBufferQuantizationMode.NONE
+        ],
         dtype=torch.bfloat16,
         bootstrap_backend=opts.bootstrap_backend,
     )
 ...
tests/pytorch/distributed/run_layer_with_overlap.py  (view file @ 715c3bb8)

@@ -12,6 +12,8 @@ import argparse
 import warnings
 import pprint
 import yaml
+from contextlib import nullcontext
+from functools import partial
 
 import torch
 import torch.distributed as dist
 ...
@@ -35,9 +37,10 @@ class multi_module_model(torch.nn.Module):
         self.num_layers = num_layers
         self.layers = torch.nn.ModuleList([module(*args, **kwargs) for _ in range(num_layers)])
 
-    def forward(self, x):
-        for layer in self.layers:
-            x = layer(x)
+    def forward(self, x, layer_contexts):
+        for layer, context in zip(self.layers, layer_contexts):
+            with context():
+                x = layer(x)
         return x
 ...
@@ -237,12 +240,46 @@ def _parse_args(argv=None, namespace=None):
         default=False,
         help="Print out additional debug information.",
     )
+    parser.add_argument(
+        "--first-last-layers-bf16",
+        action="store_true",
+        default=False,
+        help="Use bf16 for first and last N layers.",
+    )
+    parser.add_argument(
+        "--num-layers-at-start-in-bf16",
+        type=int,
+        default=0,
+        help="Number of layers at the start to run in bf16.",
+    )
+    parser.add_argument(
+        "--num-layers-at-end-in-bf16",
+        type=int,
+        default=0,
+        help="Number of layers at the end to run in bf16.",
+    )
     args = parser.parse_args(argv, namespace)
     if args.use_cuda_graphs and args.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
         warnings.warn(f"{args.layer_type.__name__} does not support CUDA Graphs!")
         args.use_cuda_graphs = False
+    if not args.first_last_layers_bf16 and (
+        args.num_layers_at_start_in_bf16 > 0 or args.num_layers_at_end_in_bf16 > 0
+    ):
+        warnings.warn(
+            "num-layers-at-start-in-bf16 and num-layers-at-end-in-bf16 are only supported when"
+            " first-last-layers-bf16 is enabled!"
+        )
+        args.num_layers_at_start_in_bf16 = 0
+        args.num_layers_at_end_in_bf16 = 0
+    if args.num_layers_at_start_in_bf16 + args.num_layers_at_end_in_bf16 > args.num_layers:
+        raise ValueError(
+            "num-layers-at-start-in-bf16 + num-layers-at-end-in-bf16 must be less than or equal to"
+            " num-layers!"
+        )
 
     return args
 ...
@@ -381,10 +418,17 @@ def _train(opts):
             "qkv_dgrad": {"method": "ring_exchange"},
             "fc1_dgrad": {"method": "ring_exchange"},
         }
+    quantization_modes = [
+        UserBufferQuantizationMode.FP8 if opts.fp8 else UserBufferQuantizationMode.NONE
+    ]
+    if opts.first_last_layers_bf16 and opts.fp8:
+        quantization_modes.append(UserBufferQuantizationMode.NONE)
     te.module.base.initialize_ub(
         [opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim],
         opts.tp,
-        use_fp8=opts.fp8,
+        quantization_modes=quantization_modes,
         dtype=torch.bfloat16,
         bootstrap_backend=opts.bootstrap_backend,
         ub_cfgs=ub_cfgs if opts.ub_cfg is None else opts.ub_cfg,
 ...
@@ -423,6 +467,16 @@ def _train(opts):
     elif opts.quantization == "mxfp8":
         fp8_recipe = MXFP8BlockScaling()
 
+    layer_contexts = [
+        (
+            partial(
+                te.fp8_autocast, enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world
+            )
+            if opts.num_layers_at_start_in_bf16 <= i
+            and i < (opts.num_layers - opts.num_layers_at_end_in_bf16)
+            else nullcontext
+        )
+        for i in range(opts.num_layers)
+    ]
+
     # Prepare random input tensors
     test_x = torch.randn(input_shape, dtype=torch.float32, device="cuda", requires_grad=True)
     test_x.retain_grad()
 ...
@@ -435,14 +489,13 @@ def _train(opts):
     # Execute fwd/bwd and collect tensors to test
     def run_fwd_bwd(model, x):
         with torch.amp.autocast("cuda", dtype=torch.bfloat16):
-            with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world):
-                y = model(x)
+            y = model(x, layer_contexts)
             if isinstance(y, tuple):
                 out, *_ = y
             else:
                 out = y
         loss = out.sum()
         loss.backward()
         return out
 
     torch_rng_state = torch.get_rng_state()
 ...
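
The test above threads one context factory per layer so that only the middle of the network runs under FP8 autocast. The pattern in isolation, as a standalone runnable sketch; `DummyFp8Context` is a placeholder for `te.fp8_autocast`, and the layer counts are illustrative:

from contextlib import nullcontext
from functools import partial

class DummyFp8Context:
    """Placeholder for te.fp8_autocast(enabled=..., fp8_recipe=..., fp8_group=...)."""
    def __init__(self, enabled=True):
        self.enabled = enabled
    def __enter__(self):
        return self
    def __exit__(self, *exc):
        return False

num_layers, start_bf16, end_bf16 = 6, 1, 1

# FP8 factory for the middle layers, nullcontext for the first/last bf16 layers.
layer_contexts = [
    partial(DummyFp8Context, enabled=True)
    if start_bf16 <= i < num_layers - end_bf16
    else nullcontext
    for i in range(num_layers)
]

x = 0.0
for context in layer_contexts:
    with context():
        x += 1.0  # stand-in for x = layer(x)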
tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py  (view file @ 715c3bb8)

@@ -506,7 +506,13 @@ def main() -> None:
             model_config.num_heads * model_config.head_dim,
         ],
         torch.distributed.get_world_size(group),
-        use_fp8=model_config.quantization is not None,
+        quantization_modes=[
+            (
+                UserBufferQuantizationMode.FP8
+                if model_config.quantization is not None
+                else UserBufferQuantizationMode.NONE
+            )
+        ],
         dtype=model_config.dtype,
         bootstrap_backend=bootstrap_backend,
         ub_cfgs=userbuffer_configs,
 ...
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp  (view file @ 715c3bb8)

@@ -511,7 +511,7 @@ void destroy_communicator_mpi(communicator *comm) {
 }
 
 int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm, bool alloc) {
-  if (comm->free_region > NVTE_MAX_REGIONS) return -1;
+  if (comm->free_region >= NVTE_MAX_REGIONS) return -1;
   int hndl = comm->free_region;
   comm->peer_ptr[hndl] = reinterpret_cast<void **>(malloc(sizeof(void *) * (comm->nvsize)));
   size_t aligned_size = bytes;
 ...

(The `>=` closes an off-by-one: with `free_region == NVTE_MAX_REGIONS`, the old check still admitted one registration past the last region slot.)
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h  (view file @ 715c3bb8)

@@ -27,7 +27,7 @@
 using ExtAllgatherOp = std::function<void(void *, size_t, void *, size_t, ExtComm)>;
 using ExtBarrierOp = std::function<void(ExtComm)>;
 
-#define NVTE_MAX_REGIONS 16
+#define NVTE_MAX_REGIONS 32
 #define NVTE_MAX_SMS 32
 #define NVTE_MAX_OPS 32
 #define NVTE_MAX_PEERS 8192
 ...

(The region cap is doubled, presumably to leave room for the additional per-mode buffer registrations created when more than one quantization mode is requested.)
transformer_engine/pytorch/__init__.py  (view file @ 715c3bb8)

@@ -33,6 +33,7 @@ from transformer_engine.pytorch.module import GroupedLinear
 from transformer_engine.pytorch.module import Fp8Padding, Fp8Unpadding
 from transformer_engine.pytorch.module import initialize_ub
 from transformer_engine.pytorch.module import destroy_ub
+from transformer_engine.pytorch.module import UserBufferQuantizationMode
 from transformer_engine.pytorch.attention import DotProductAttention
 from transformer_engine.pytorch.attention import MultiheadAttention
 from transformer_engine.pytorch.attention import InferenceParams
 ...
transformer_engine/pytorch/module/__init__.py  (view file @ 715c3bb8)

@@ -11,4 +11,4 @@ from .layernorm import LayerNorm
 from .rmsnorm import RMSNorm
 from .fp8_padding import Fp8Padding
 from .fp8_unpadding import Fp8Unpadding
-from .base import initialize_ub, destroy_ub
+from .base import initialize_ub, destroy_ub, UserBufferQuantizationMode
transformer_engine/pytorch/module/base.py  (view file @ 715c3bb8)

@@ -8,6 +8,7 @@ import math
 import os
 import pickle
 import warnings
+from enum import Enum
 from abc import ABC, abstractmethod
 from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union
 from contextlib import contextmanager
 ...
@@ -49,7 +50,7 @@ from ...debug.pytorch.debug_state import TEDebugState
 from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor
 from ...debug.pytorch.utils import next_iter_when_debug_should_be_run, any_feature_enabled
 
-__all__ = ["initialize_ub", "destroy_ub"]
+__all__ = ["initialize_ub", "destroy_ub", "UserBufferQuantizationMode"]
 
 _2X_ACC_FPROP = False
 _2X_ACC_DGRAD = True
 ...
@@ -63,6 +64,15 @@ _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None
 layers_atomic_ring_exchange = []
 
 
+class UserBufferQuantizationMode(Enum):
+    """
+    UserBufferQuantizationMode is an enum that represents the quantization mode of the UserBuffer.
+    """
+
+    NONE = "none"
+    FP8 = "fp8"
+
+
 def get_cublas_workspace_size_bytes() -> None:
     """Return 32 MiB if using hopper, 4 MiB for all other architectures."""
     if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9:
 ...
@@ -111,8 +121,9 @@ def initialize_ub(
     shape: list,
     tp_size: int,
     use_fp8: bool = False,
+    quantization_modes: List[UserBufferQuantizationMode] = None,
     dtype: torch.dtype = torch.bfloat16,
-    ub_cfgs: Optional[dict] = None,
+    ub_cfgs: Optional[Union[dict, List[dict]]] = None,
     bootstrap_backend: Union[str, torch.distributed.Backend] = None,
 ) -> None:
     r"""
 ...
@@ -128,7 +139,11 @@ def initialize_ub(
     tp_size : int
         number of GPUs in the tensor-parallel process group
     use_fp8 : bool = False
-        allocate the communication buffer for FP8 GEMM inputs/outputs
+        allocate the communication buffer for FP8 GEMM inputs/outputs.
+        DEPRECATED: Please use `quantization_modes` instead.
+    quantization_modes : List[UserBufferQuantizationMode] = None
+        if a list of UserBufferQuantizationMode is provided, a UB communicator is created for
+        each quantization setting in the list; falls back to the legacy `use_fp8` parameter
+        if `None` is provided.
     dtype : torch.dtype = torch.bfloat16
         non-FP8 data type of the communication buffer when `use_fp8 = False`
     ub_cfgs: dict = None
 ...
@@ -152,6 +167,7 @@ def initialize_ub(
         for `te.TransformerLayer` GEMM layers in `["qkv_fprop", "qkv_dgrad", "qkv_wgrad",
         "proj_fprop", "proj_dgrad", "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad",
         "fc2_fprop", "fc2_wgrad"]`.
+        a list may be provided to specify different overlap configurations for the different
+        quantization settings in `quantization_modes`
     bootstrap_backend : str = None
         `torch.distributed` communication backend for the all-gather, broadcast and
         barrier collectives during Userbuffers initialization. Not all backends are
 ...
@@ -168,6 +184,28 @@ def initialize_ub(
             + "CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead."
         )
 
+    if not quantization_modes:
+        warnings.warn(
+            "Initializing Userbuffers with use_fp8 is deprecated. Please use quantization_modes"
+            " instead.",
+            DeprecationWarning,
+        )
+        quantization_modes = [
+            UserBufferQuantizationMode.FP8 if use_fp8 else UserBufferQuantizationMode.NONE
+        ]
+    else:
+        assert isinstance(quantization_modes, list), "quantization_modes must be a list"
+        assert all(
+            isinstance(mode, UserBufferQuantizationMode) for mode in quantization_modes
+        ), "quantization_modes must be a list of UserBufferQuantizationMode"
+
+    if isinstance(ub_cfgs, dict) or ub_cfgs is None:
+        ub_cfgs = [ub_cfgs] * len(quantization_modes)
+    else:
+        assert len(ub_cfgs) == len(
+            quantization_modes
+        ), "Number of ub_cfgs settings must match number of quantization configurations"
+
     global _ub_communicators
     assert _ub_communicators is None, "UB communicators are already initialized."
     _ub_communicators = {}
 ...
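
The validation block above normalizes `ub_cfgs` against `quantization_modes`: a single dict (or `None`) is broadcast to every mode, while a list must pair one config with each mode. A runnable, dependency-free sketch of that rule; all names here are local to the sketch:

from enum import Enum

class UserBufferQuantizationMode(Enum):
    NONE = "none"
    FP8 = "fp8"

def normalize_ub_cfgs(ub_cfgs, quantization_modes):
    # A single dict (or None) is expanded so every mode shares the same config...
    if isinstance(ub_cfgs, dict) or ub_cfgs is None:
        return [ub_cfgs] * len(quantization_modes)
    # ...while a list must supply exactly one config per quantization mode.
    assert len(ub_cfgs) == len(quantization_modes)
    return ub_cfgs

modes = [UserBufferQuantizationMode.FP8, UserBufferQuantizationMode.NONE]
print(normalize_ub_cfgs({"proj_fprop": {"method": "pipeline"}}, modes))
print(normalize_ub_cfgs([{"proj_fprop": {"method": "pipeline"}}, None], modes))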
@@ -309,6 +347,7 @@ def initialize_ub(
     def add_ub(
         name: str,
+        quantization_mode: UserBufferQuantizationMode,
         method: str,
         is_reduce_scatter: bool,
         num_sm: int = 16,
 ...
@@ -327,7 +366,9 @@ def initialize_ub(
             warnings.warn(
                 "Atomic GEMM uses a beta API from cublas and is not tested for all use cases."
             )
-            assert use_fp8, "Atomic GEMM overlap supported only for FP8 GEMM."
+            assert (
+                quantization_mode == UserBufferQuantizationMode.FP8
+            ), "Atomic GEMM overlap supported only for FP8 GEMM."
         if method in ("bulk", "external"):
             warnings.warn(
                 f"At {name}, atomic GEMM is not supported for a bulk overlap."
 ...
@@ -367,7 +408,11 @@ def initialize_ub(
                 f"{external_gemm_to_overlap[name]} is not using `ring_exchange` overlap method"
             )
 
-        buffer_dtype = torch.uint8 if (use_fp8 and fp8_buf) else dtype
+        buffer_dtype = (
+            torch.uint8
+            if (quantization_mode == UserBufferQuantizationMode.FP8 and fp8_buf)
+            else dtype
+        )
         if method == "ring_exchange":
             ub_obj = tex.CommOverlapP2P(
                 shape,  # Communication buffer shape
 ...
@@ -401,38 +446,47 @@ def initialize_ub(
                 comm_priority=comm_priority,
                 rs_overlap_first_gemm=pipeline_rs_overlap_first_gemm,
             )
-        _ub_communicators[name] = ub_obj
-
-    if ub_cfgs is not None:
-        for name in dgrad_reduce_scatter_overlap:
-            if name in ub_cfgs and "method" in ub_cfgs[name] and ub_cfgs[name]["method"] != "bulk":
-                wgrad_name = name.replace("dgrad", "wgrad")
-                assert wgrad_name not in ub_cfgs
-                layers_reduce_scatter_overlap.remove(wgrad_name)
-                layers_all_gather_overlap.remove(name)
-                layers_reduce_scatter_overlap.append(name)
-                methods["bulk"].remove(name)
-                new_method = ub_cfgs[name]["method"]
-                methods[new_method].append(name)
-
-    for name in (
-        methods["ring_exchange"] + methods["pipeline"] + methods["bulk"] + methods["external"]
-    ):
-        ub_cfg = get_default_config(name)
-        if ub_cfgs is not None and name in ub_cfgs:
-            fp8_buf = (name in layers_all_gather_overlap) or (
-                ub_cfgs[name].get("fp8_buf", False) and name in methods["pipeline"]
-            )
-            ub_cfg.update(ub_cfgs[name])
-            ub_cfg["fp8_buf"] = fp8_buf
-        add_ub(name, **ub_cfg)
+        _ub_communicators[(name, quantization_mode)] = ub_obj
+
+    for quantization_mode, user_ub_cfg in zip(quantization_modes, ub_cfgs):
+        if user_ub_cfg is not None:
+            for name in dgrad_reduce_scatter_overlap:
+                if (
+                    name in user_ub_cfg
+                    and "method" in user_ub_cfg[name]
+                    and user_ub_cfg[name]["method"] != "bulk"
+                ):
+                    wgrad_name = name.replace("dgrad", "wgrad")
+                    assert wgrad_name not in user_ub_cfg
+                    layers_reduce_scatter_overlap.remove(wgrad_name)
+                    layers_all_gather_overlap.remove(name)
+                    layers_reduce_scatter_overlap.append(name)
+                    methods["bulk"].remove(name)
+                    new_method = user_ub_cfg[name]["method"]
+                    methods[new_method].append(name)
+
+        for name in (
+            methods["ring_exchange"] + methods["pipeline"] + methods["bulk"] + methods["external"]
+        ):
+            ub_cfg = get_default_config(name)
+            if user_ub_cfg is not None and name in user_ub_cfg:
+                fp8_buf = (name in layers_all_gather_overlap) or (
+                    user_ub_cfg[name].get("fp8_buf", False) and name in methods["pipeline"]
+                )
+                ub_cfg.update(user_ub_cfg[name])
+                ub_cfg["fp8_buf"] = fp8_buf
+            add_ub(name, quantization_mode, **ub_cfg)
 
 
-def get_ub(name: str):
+def get_ub(name: str, use_fp8: bool):
     """Get userbuffer communicator corresponding to given key."""
+    # For now use `use_fp8` boolean input as it matches the current design in the modules
+    # So favour simplicity until the correct design becomes clear.
+    # This is mainly an internal API so we don't need to worry about future changes
+    key = (name, UserBufferQuantizationMode.FP8 if use_fp8 else UserBufferQuantizationMode.NONE)
    assert _ub_communicators is not None, "UB manager is not initialized."
-    assert name in _ub_communicators, f"UB for {name} is not registered."
-    return _ub_communicators[name]
+    assert key in _ub_communicators, f"UB for {name} with use_fp8={use_fp8} is not registered."
+    return _ub_communicators[key]
 
 
 def destroy_ub():
 ...
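
The registry change above is the heart of the commit: communicators are now stored under `(name, quantization_mode)` pairs rather than bare names, so FP8 and non-FP8 communicators for the same GEMM can coexist. A runnable toy model of the lookup; the string "communicators" stand in for the real `tex.CommOverlap` handles:

from enum import Enum

class UserBufferQuantizationMode(Enum):
    NONE = "none"
    FP8 = "fp8"

# Toy registry mirroring _ub_communicators keyed by (name, quantization_mode).
_ub_communicators = {
    ("proj_fprop", UserBufferQuantizationMode.FP8): "fp8-communicator",
    ("proj_fprop", UserBufferQuantizationMode.NONE): "bf16-communicator",
}

def get_ub(name, use_fp8):
    # Modules still pass a boolean; it is mapped onto the enum key here.
    mode = UserBufferQuantizationMode.FP8 if use_fp8 else UserBufferQuantizationMode.NONE
    key = (name, mode)
    assert key in _ub_communicators, f"UB for {name} with use_fp8={use_fp8} is not registered."
    return _ub_communicators[key]

assert get_ub("proj_fprop", True) == "fp8-communicator"
assert get_ub("proj_fprop", False) == "bf16-communicator"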
transformer_engine/pytorch/module/layernorm_linear.py  (view file @ 715c3bb8)

@@ -173,10 +173,10 @@ class _LayerNormLinear(torch.autograd.Function):
             ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output
         )
         if ub_overlap_rs_fprop:
-            ub_obj = get_ub(ub_name + "_fprop")
+            ub_obj = get_ub(ub_name + "_fprop", fp8)
             ub_type = tex.CommOverlapType.RS
         elif ub_overlap_ag_fprop:
-            ub_obj = get_ub(ub_name + "_fprop")
+            ub_obj = get_ub(ub_name + "_fprop", fp8)
             ub_type = tex.CommOverlapType.AG
 
         # Configure quantizer for norm output
 ...
@@ -575,23 +575,23 @@ class _LayerNormLinear(torch.autograd.Function):
         dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]]
         if ctx.ub_overlap_ag:
             # Overlap grad_output all-gather with dgrad compute
-            ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
+            ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
             ub_obj_dgrad = ctx.ub_obj_gradout
             ub_type_dgrad = tex.CommOverlapType.AG
         elif ctx.ub_overlap_rs_dgrad:
             # Overlap dgrad reduce-scatter with dgrad compute
-            ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
+            ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
             ub_obj_dgrad = ctx.ub_obj_gradout
             ub_type_dgrad = tex.CommOverlapType.RS
         else:
             if ctx.ub_bulk_dgrad:
                 # Overlap inputmat all-gather with dgrad compute
-                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
+                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
                 ub_obj_dgrad = ctx.ub_obj_gradout
                 ub_type_dgrad = tex.CommOverlapType.AG
             if ctx.ub_bulk_wgrad:
                 # Overlap dgrad reduce-scatter with wgrad compute
-                ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad")
+                ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8)
                 ub_type_wgrad = tex.CommOverlapType.RS
 
         # --------------------------------------------------
 ...
@@ -769,7 +769,7 @@ class _LayerNormLinear(torch.autograd.Function):
             dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream()
 
             # This object is separate from the ub_obj_wgrad object which is passed to the GEMM
-            ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad")
+            ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8)
 
             ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
 ...
@@ -1492,10 +1492,14 @@ class LayerNormLinear(TransformerEngineBaseModule):
             is_first_microbatch = False
 
         if self.ub_overlap_rs_fprop:
-            if get_ub(self.ub_name + "_fprop").is_fp8_ubuf():
+            if get_ub(
+                self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled()
+            ).is_fp8_ubuf():
                 fp8_output = True
         if self.ub_overlap_rs_dgrad:
-            if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf():
+            if get_ub(
+                self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled()
+            ).is_fp8_ubuf():
                 fp8_grad = True
 
         with torch.cuda.device(
 ...
transformer_engine/pytorch/module/layernorm_mlp.py  (view file @ 715c3bb8)

@@ -307,7 +307,7 @@ class _LayerNormMLP(torch.autograd.Function):
             fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
         if ub_overlap_ag:
             # Copy into Userbuffers buffer
-            ub_obj_lnout = get_ub("fc1_fprop")
+            ub_obj_lnout = get_ub("fc1_fprop", fp8)
             ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
                 ub_obj_lnout,
                 ln_out,
 ...
@@ -458,7 +458,7 @@ class _LayerNormMLP(torch.autograd.Function):
         ub_obj_fc2out = None
         reduce_scatter_out = None
         if ub_overlap_rs:
-            ub_obj_fc2out = get_ub("fc2_fprop")
+            ub_obj_fc2out = get_ub("fc2_fprop", fp8)
             dim_size = list(act_out.size())
             dim_size[0] //= tp_world_size
             dim_size[-1] = fc2_weight.size(0)
 ...
@@ -740,7 +740,7 @@ class _LayerNormMLP(torch.autograd.Function):
         # Note: Cast to expected dtype and perform tensor-parallel communication
         ub_obj_fc2_dgrad = None
         if ctx.ub_overlap_ag:
-            ub_obj_fc2_dgrad = get_ub("fc2_dgrad")
+            ub_obj_fc2_dgrad = get_ub("fc2_dgrad", ctx.fp8)
             ctx.ub_obj_gradout = ub_obj_fc2_dgrad
         (
             grad_output,
 ...
@@ -764,7 +764,7 @@ class _LayerNormMLP(torch.autograd.Function):
             # wgrad GEMM requires input with column-wise usage
             quantizer.set_usage(rowwise=False, columnwise=True)
         if ctx.ub_bulk_dgrad:
-            ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
+            ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8)
             ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
                 ub_obj_fc1_dgrad,
                 ln_out,
 ...
@@ -869,7 +869,7 @@ class _LayerNormMLP(torch.autograd.Function):
                 ub_obj_fc2_dgrad.get_communication_stream()
             )
 
-            ub_obj_fc2_wgrad = get_ub("fc2_wgrad")
+            ub_obj_fc2_wgrad = get_ub("fc2_wgrad", ctx.fp8)
 
             ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
 ...
@@ -1036,16 +1036,16 @@ class _LayerNormMLP(torch.autograd.Function):
         fc1_dgrad_shape = [reduce(multiply_op, inputmat.shape[:-1]), inputmat.shape[-1]]
         if ctx.ub_overlap_rs_dgrad:
             # Overlap DGRAD+RS
-            ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
+            ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8)
             ub_type_fc1_dgrad = tex.CommOverlapType.RS
         else:
             if ctx.ub_bulk_dgrad:
                 # Overlap ln_out all-gather with DGRAD compute
-                ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
+                ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8)
                 ub_type_fc1_dgrad = tex.CommOverlapType.AG
             if ctx.ub_bulk_wgrad:
                 # Overlap FC1 DGRAD reduce-scatter with WGRAD compute
-                ub_obj_fc1_wgrad = get_ub("fc1_wgrad")
+                ub_obj_fc1_wgrad = get_ub("fc1_wgrad", ctx.fp8)
                 ub_type_fc1_wgrad = tex.CommOverlapType.RS
 
         # --------------------------------------------------
 ...
@@ -1539,7 +1539,11 @@ class LayerNormMLP(TransformerEngineBaseModule):
         self.gemm_gelu_fusion = (
             bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0")))
             and self.activation == "gelu"
-            and ((_ub_communicators is None) or (not get_ub("fc1_fprop").is_atomic_gemm()))
+            and all(
+                ("fc1_fprop", use_fp8) not in _ub_communicators
+                or not get_ub("fc1_fprop", use_fp8).is_atomic_gemm()
+                for use_fp8 in [False, True]
+            )
         )
         self.name = name
 ...
@@ -1757,7 +1761,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
         fp8_output = False
         if self.ub_overlap_rs:
-            if get_ub("fc2_fprop").is_fp8_ubuf():
+            if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf():
                 fp8_output = True
 
         with torch.cuda.device(
 ...
transformer_engine/pytorch/module/linear.py  (view file @ 715c3bb8)

@@ -145,10 +145,10 @@ class _Linear(torch.autograd.Function):
         ub_obj = None
         ub_type = None
         if ub_overlap_rs_fprop:
-            ub_obj = get_ub(ub_name + "_fprop")
+            ub_obj = get_ub(ub_name + "_fprop", fp8)
             ub_type = tex.CommOverlapType.RS
         elif ub_overlap_ag_fprop:
-            ub_obj = get_ub(ub_name + "_fprop")
+            ub_obj = get_ub(ub_name + "_fprop", fp8)
             ub_type = tex.CommOverlapType.AG
 
         # ------------------------------------------------------
 ...
@@ -520,23 +520,23 @@ class _Linear(torch.autograd.Function):
         dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]]
         if ctx.ub_overlap_ag:
             # Overlap grad_output all-gather with dgrad compute
-            ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
+            ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
             ub_obj_dgrad = ctx.ub_obj_gradout
             ub_type_dgrad = tex.CommOverlapType.AG
         elif ctx.ub_overlap_rs_dgrad:
             # Overlap dgrad reduce-scatter with dgrad compute
-            ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
+            ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
             ub_obj_dgrad = ctx.ub_obj_gradout
             ub_type_dgrad = tex.CommOverlapType.RS
         else:
             if ctx.ub_bulk_dgrad:
                 # Overlap inputmat all-gather with dgrad compute
-                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
+                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
                 ub_obj_dgrad = ctx.ub_obj_gradout
                 ub_type_dgrad = tex.CommOverlapType.AG
             if ctx.ub_bulk_wgrad:
                 # Overlap dgrad reduce-scatter with wgrad compute
-                ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad")
+                ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8)
                 ub_type_wgrad = tex.CommOverlapType.RS
 
         # --------------------------------------------------
 ...
@@ -769,7 +769,7 @@ class _Linear(torch.autograd.Function):
             dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream()
 
             # This object is separate from the ub_obj_wgrad object which is passed to the GEMM
-            ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad")
+            ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8)
 
             ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
 ...
@@ -1377,10 +1377,14 @@ class Linear(TransformerEngineBaseModule):
             is_first_microbatch = False
 
         if self.ub_overlap_rs_fprop:
-            if get_ub(self.ub_name + "_fprop").is_fp8_ubuf():
+            if get_ub(
+                self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled()
+            ).is_fp8_ubuf():
                 fp8_output = True
         if self.ub_overlap_rs_dgrad:
-            if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf():
+            if get_ub(
+                self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled()
+            ).is_fp8_ubuf():
                 fp8_grad = True
 
         with torch.cuda.device(
 ...
transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py  (view file @ 715c3bb8)

@@ -241,16 +241,16 @@ class UserbuffersBackwardLinear(FusedOperation):
         with_dgrad_all_gather_x = False
         with_wgrad_reduce_scatter_dx = False
         if tensor_parallel_mode == "row":
-            ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad")
+            ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad", with_quantized_compute)
             ub_type_dgrad = CommOverlapType.AG
             with_dgrad_all_gather_dy = True
         elif tensor_parallel_mode == "column":
             if input_requires_grad and weight_requires_grad:
                 with_bulk_overlap = True
-                ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad")
+                ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad", with_quantized_compute)
                 ub_type_dgrad = CommOverlapType.AG
                 with_dgrad_all_gather_x = True
-                ub_comm_wgrad = get_ub(ub_comm_name + "_wgrad")
+                ub_comm_wgrad = get_ub(ub_comm_name + "_wgrad", with_quantized_compute)
                 ub_type_wgrad = CommOverlapType.RS
                 with_wgrad_reduce_scatter_dx = True
                 if ub_comm_wgrad.is_fp8_ubuf():
 ...
@@ -258,7 +258,7 @@ class UserbuffersBackwardLinear(FusedOperation):
                         "Userbuffers reduce-scatter is not supported with FP8 buffers"
                     )
             else:
-                ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad")
+                ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad", with_quantized_compute)
                 ub_type_dgrad = CommOverlapType.RS
                 with_dgrad_reduce_scatter_dx = True
                 if ub_comm_dgrad.is_fp8_ubuf():
 ...
@@ -409,7 +409,7 @@ class UserbuffersBackwardLinear(FusedOperation):
             # Get the communication stream from the dgrad GEMM to use for the AG
             dgrad_send_stream, dgrad_recv_stream = ub_comm_dgrad.get_communication_stream()
 
-            ub_obj_overlap_wgrad = get_ub(ub_comm_name + "_wgrad")
+            ub_obj_overlap_wgrad = get_ub(ub_comm_name + "_wgrad", with_quantized_compute)
 
             grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
 ...
transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py  (view file @ 715c3bb8)

@@ -189,7 +189,7 @@ class UserbuffersForwardLinear(FusedOperation):
             output_quantizer = None
 
         # Get Userbuffers communicator
-        ub_comm = get_ub(ub_comm_name + "_fprop")
+        ub_comm = get_ub(ub_comm_name + "_fprop", with_quantized_compute)
         with_ub_all_gather = tensor_parallel_mode == "column"
         with_ub_reduce_scatter = tensor_parallel_mode == "row"
         ub_type = CommOverlapType.AG if with_ub_all_gather else CommOverlapType.RS
 ...
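
Zooming out from the per-file changes: after this commit one process can hold communicators for several quantization settings at once, with every module and fused op resolving its communicator via `get_ub(name, use_fp8)`. A usage sketch under assumed shapes and group sizes; the two-mode setup mirrors the updated run_layer_with_overlap.py test, and nothing here beyond the public names above is from the commit:

import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch import UserBufferQuantizationMode

# One communicator set per mode: FP8 for the middle of the network, NONE for
# the bf16 first/last layers (shape and tp_size are illustrative).
te.initialize_ub(
    [4096, 1024],
    8,
    quantization_modes=[
        UserBufferQuantizationMode.FP8,
        UserBufferQuantizationMode.NONE,
    ],
    dtype=torch.bfloat16,
)
# Inside te.fp8_autocast(enabled=True), modules resolve the (name, FP8)
# communicators; outside it (the bf16 layers), they resolve the (name, NONE)
# ones through get_ub(name, use_fp8).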