OpenDAS / TransformerEngine / Commits

Commit 715c3bb8 (unverified), authored Aug 29, 2025 by Daniel Stokes, committed via GitHub Aug 29, 2025
feat: Add support for multiple quantization modes in the UB communicators (#2043)
parent f98e3053

Showing 14 changed files with 216 additions and 85 deletions (+216, -85)
docs/api/pytorch.rst                                                           +4   -1
examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py                    +3   -1
tests/pytorch/distributed/run_layer_with_overlap.py                          +65  -12
tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py                 +7   -1
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp   +1   -1
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h          +1   -1
transformer_engine/pytorch/__init__.py                                         +1   -0
transformer_engine/pytorch/module/__init__.py                                  +1   -1
transformer_engine/pytorch/module/base.py                                    +87  -33
transformer_engine/pytorch/module/layernorm_linear.py                        +13   -9
transformer_engine/pytorch/module/layernorm_mlp.py                           +14  -10
transformer_engine/pytorch/module/linear.py                                  +13   -9
transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py           +5   -5
transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py            +1   -1
docs/api/pytorch.rst
@@ -62,3 +62,6 @@ pyTorch
 .. autoapifunction:: transformer_engine.pytorch.initialize_ub

 .. autoapifunction:: transformer_engine.pytorch.destroy_ub
+
+.. autoapiclass:: transformer_engine.pytorch.UserBufferQuantizationMode
+  :members: FP8, NONE
\ No newline at end of file
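
For reference, a minimal sketch of what the newly documented enum exposes, assuming a build that includes this commit (member values come from the module/base.py hunk further down):

# Minimal usage sketch; requires transformer_engine built with this commit.
from transformer_engine.pytorch import UserBufferQuantizationMode

# The enum has exactly two members, whose string values are set in module/base.py.
assert UserBufferQuantizationMode.FP8.value == "fp8"
assert UserBufferQuantizationMode.NONE.value == "none"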
examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py
@@ -263,7 +263,9 @@ def _train(opts):
     te.module.base.initialize_ub(
         [batched_size, hidden_size],
         tp_size,
-        use_fp8=opts.fp8,
+        quantization_modes=[
+            UserBufferQuantizationMode.FP8 if opts.fp8 else UserBufferQuantizationMode.NONE
+        ],
         dtype=torch.bfloat16,
         bootstrap_backend=opts.bootstrap_backend,
     )
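This hunk is the migration pattern in miniature: the boolean `use_fp8` flag is replaced by a list of requested modes, one UB communicator set per entry. A hedged sketch of the same call as a standalone helper (`batched_size`, `hidden_size`, `tp_size`, and the `fp8`/`bootstrap_backend` options are the example script's own names):

# Hedged sketch of the migrated call, assuming this commit's API.
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch import UserBufferQuantizationMode

def init_userbuffers(batched_size, hidden_size, tp_size, fp8, bootstrap_backend=None):
    # One communicator set is created per entry in quantization_modes; passing
    # use_fp8 instead still works but emits a DeprecationWarning (see base.py below).
    te.module.base.initialize_ub(
        [batched_size, hidden_size],
        tp_size,
        quantization_modes=[
            UserBufferQuantizationMode.FP8 if fp8 else UserBufferQuantizationMode.NONE
        ],
        dtype=torch.bfloat16,
        bootstrap_backend=bootstrap_backend,
    )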
tests/pytorch/distributed/run_layer_with_overlap.py
@@ -12,6 +12,8 @@ import argparse
 import warnings
 import pprint
 import yaml
+from contextlib import nullcontext
+from functools import partial

 import torch
 import torch.distributed as dist
@@ -35,8 +37,9 @@ class multi_module_model(torch.nn.Module):
         self.num_layers = num_layers
         self.layers = torch.nn.ModuleList([module(*args, **kwargs) for _ in range(num_layers)])

-    def forward(self, x):
-        for layer in self.layers:
-            x = layer(x)
+    def forward(self, x, layer_contexts):
+        for layer, context in zip(self.layers, layer_contexts):
+            with context():
+                x = layer(x)
         return x
@@ -237,12 +240,46 @@ def _parse_args(argv=None, namespace=None):
         default=False,
         help="Print out additional debug information.",
     )
+    parser.add_argument(
+        "--first-last-layers-bf16",
+        action="store_true",
+        default=False,
+        help="Use bf16 for first and last N layers.",
+    )
+    parser.add_argument(
+        "--num-layers-at-start-in-bf16",
+        type=int,
+        default=0,
+        help="Number of layers at the start to run in bf16.",
+    )
+    parser.add_argument(
+        "--num-layers-at-end-in-bf16",
+        type=int,
+        default=0,
+        help="Number of layers at the end to run in bf16.",
+    )
     args = parser.parse_args(argv, namespace)
     if args.use_cuda_graphs and args.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
         warnings.warn(f"{args.layer_type.__name__} does not support CUDA Graphs!")
         args.use_cuda_graphs = False
+    if not args.first_last_layers_bf16 and (
+        args.num_layers_at_start_in_bf16 > 0 or args.num_layers_at_end_in_bf16 > 0
+    ):
+        warnings.warn(
+            "num-layers-at-start-in-bf16 and num-layers-at-end-in-bf16 are only supported when"
+            " first-last-layers-bf16 is enabled!"
+        )
+        args.num_layers_at_start_in_bf16 = 0
+        args.num_layers_at_end_in_bf16 = 0
+    if args.num_layers_at_start_in_bf16 + args.num_layers_at_end_in_bf16 > args.num_layers:
+        raise ValueError(
+            "num-layers-at-start-in-bf16 + num-layers-at-end-in-bf16 must be less than or equal to"
+            " num-layers!"
+        )
     return args
@@ -381,10 +418,17 @@ def _train(opts):
         "qkv_dgrad": {"method": "ring_exchange"},
         "fc1_dgrad": {"method": "ring_exchange"},
     }
+    quantization_modes = [
+        UserBufferQuantizationMode.FP8 if opts.fp8 else UserBufferQuantizationMode.NONE
+    ]
+    if opts.first_last_layers_bf16 and opts.fp8:
+        quantization_modes.append(UserBufferQuantizationMode.NONE)
     te.module.base.initialize_ub(
         [opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim],
         opts.tp,
-        use_fp8=opts.fp8,
+        quantization_modes=quantization_modes,
         dtype=torch.bfloat16,
         bootstrap_backend=opts.bootstrap_backend,
         ub_cfgs=ub_cfgs if opts.ub_cfg is None else opts.ub_cfg,
@@ -423,6 +467,16 @@ def _train(opts):
     elif opts.quantization == "mxfp8":
         fp8_recipe = MXFP8BlockScaling()

+    layer_contexts = [
+        (
+            partial(te.fp8_autocast, enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world)
+            if opts.num_layers_at_start_in_bf16 <= i
+            and i < (opts.num_layers - opts.num_layers_at_end_in_bf16)
+            else nullcontext
+        )
+        for i in range(opts.num_layers)
+    ]
+
     # Prepare random input tensors
     test_x = torch.randn(input_shape, dtype=torch.float32, device="cuda", requires_grad=True)
     test_x.retain_grad()
@@ -435,8 +489,7 @@ def _train(opts):
     # Execute fwd/bwd and collect tensors to test
     def run_fwd_bwd(model, x):
         with torch.amp.autocast("cuda", dtype=torch.bfloat16):
-            with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world):
-                y = model(x)
+            y = model(x, layer_contexts)
             if isinstance(y, tuple):
                 out, *_ = y
             else:
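The `layer_contexts` plumbing above is the standard "one context-manager factory per layer" pattern: middle layers get a pre-bound `te.fp8_autocast`, head and tail layers get `nullcontext`. A self-contained toy version, independent of Transformer Engine (the `fake_autocast` stand-in is hypothetical):

# Toy illustration of the per-layer context pattern used in the test above.
from contextlib import contextmanager, nullcontext
from functools import partial

@contextmanager
def fake_autocast(enabled):
    # Stand-in for te.fp8_autocast; a real recipe/group would be bound via partial.
    yield

num_layers, start_bf16, end_bf16 = 6, 1, 2
layer_contexts = [
    partial(fake_autocast, enabled=True)
    if start_bf16 <= i < (num_layers - end_bf16)
    else nullcontext
    for i in range(num_layers)
]

def forward(layers, x, layer_contexts):
    # Mirrors multi_module_model.forward: each layer runs under its own context.
    for layer, context in zip(layers, layer_contexts):
        with context():
            x = layer(x)
    return x

print(forward([lambda v: v + 1] * num_layers, 0, layer_contexts))  # -> 6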
tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py
@@ -506,7 +506,13 @@ def main() -> None:
             model_config.num_heads * model_config.head_dim,
         ],
         torch.distributed.get_world_size(group),
-        use_fp8=model_config.quantization is not None,
+        quantization_modes=[
+            (
+                UserBufferQuantizationMode.FP8
+                if model_config.quantization is not None
+                else UserBufferQuantizationMode.NONE
+            )
+        ],
         dtype=model_config.dtype,
         bootstrap_backend=bootstrap_backend,
         ub_cfgs=userbuffer_configs,
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp
@@ -511,7 +511,7 @@ void destroy_communicator_mpi(communicator *comm) {
 }

 int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm, bool alloc) {
-  if (comm->free_region > NVTE_MAX_REGIONS) return -1;
+  if (comm->free_region >= NVTE_MAX_REGIONS) return -1;
   int hndl = comm->free_region;
   comm->peer_ptr[hndl] = reinterpret_cast<void **>(malloc(sizeof(void *) * (comm->nvsize)));
   size_t aligned_size = bytes;

This also fixes an off-by-one bound check: with `>`, a registration arriving when `free_region == NVTE_MAX_REGIONS` passed the guard and indexed one slot past the end of the per-region arrays; `>=` rejects it.
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h
@@ -27,7 +27,7 @@
 using ExtAllgatherOp = std::function<void(void *, size_t, void *, size_t, ExtComm)>;
 using ExtBarrierOp = std::function<void(ExtComm)>;

-#define NVTE_MAX_REGIONS 16
+#define NVTE_MAX_REGIONS 32
 #define NVTE_MAX_SMS 32
 #define NVTE_MAX_OPS 32
 #define NVTE_MAX_PEERS 8192

The region cap doubles from 16 to 32, presumably to leave headroom now that each entry in `quantization_modes` registers its own set of buffer regions.
transformer_engine/pytorch/__init__.py
@@ -33,6 +33,7 @@ from transformer_engine.pytorch.module import GroupedLinear
 from transformer_engine.pytorch.module import Fp8Padding, Fp8Unpadding
 from transformer_engine.pytorch.module import initialize_ub
 from transformer_engine.pytorch.module import destroy_ub
+from transformer_engine.pytorch.module import UserBufferQuantizationMode
 from transformer_engine.pytorch.attention import DotProductAttention
 from transformer_engine.pytorch.attention import MultiheadAttention
 from transformer_engine.pytorch.attention import InferenceParams
transformer_engine/pytorch/module/__init__.py
@@ -11,4 +11,4 @@ from .layernorm import LayerNorm
 from .rmsnorm import RMSNorm
 from .fp8_padding import Fp8Padding
 from .fp8_unpadding import Fp8Unpadding
-from .base import initialize_ub, destroy_ub
+from .base import initialize_ub, destroy_ub, UserBufferQuantizationMode
transformer_engine/pytorch/module/base.py
@@ -8,6 +8,7 @@ import math
 import os
 import pickle
 import warnings
+from enum import Enum
 from abc import ABC, abstractmethod
 from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union
 from contextlib import contextmanager
@@ -49,7 +50,7 @@ from ...debug.pytorch.debug_state import TEDebugState
 from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor
 from ...debug.pytorch.utils import next_iter_when_debug_should_be_run, any_feature_enabled

-__all__ = ["initialize_ub", "destroy_ub"]
+__all__ = ["initialize_ub", "destroy_ub", "UserBufferQuantizationMode"]

 _2X_ACC_FPROP = False
 _2X_ACC_DGRAD = True
@@ -63,6 +64,15 @@ _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None
 layers_atomic_ring_exchange = []


+class UserBufferQuantizationMode(Enum):
+    """
+    UserBufferQuantizationMode is an enum that represents the quantization mode of the UserBuffer.
+    """
+
+    NONE = "none"
+    FP8 = "fp8"
+
+
 def get_cublas_workspace_size_bytes() -> None:
     """Return 32 MiB if using hopper, 4 MiB for all other architectures."""
     if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9:
@@ -111,8 +121,9 @@ def initialize_ub(
shape
:
list
,
tp_size
:
int
,
use_fp8
:
bool
=
False
,
quantization_modes
:
List
[
UserBufferQuantizationMode
]
=
None
,
dtype
:
torch
.
dtype
=
torch
.
bfloat16
,
ub_cfgs
:
Optional
[
dict
]
=
None
,
ub_cfgs
:
Optional
[
Union
[
dict
,
List
[
dict
]]
]
=
None
,
bootstrap_backend
:
Union
[
str
,
torch
.
distributed
.
Backend
]
=
None
,
)
->
None
:
r
"""
...
...
@@ -128,7 +139,11 @@ def initialize_ub(
tp_size : int
number of GPUs in the tensor-parallel process group
use_fp8 : bool = False
allocate the communication buffer for FP8 GEMM inputs/outputs
allocate the communication buffer for FP8 GEMM inputs/outputs.
DEPRECATED: Please use `quantization_modes` instead.
quantization_modes : List[UserBufferQuantizationMode] = None
if a list of UserBufferQuantizationMode is provided, a UB communicator is created for each quantization setting in the list.
falls back to the legacy `use_fp8` parameter if `None` is provided.
dtype : torch.dtype = torch.bfloat16
non-FP8 data type of the communication buffer when `use_fp8 = False`
ub_cfgs: dict = None
...
...
@@ -152,6 +167,7 @@ def initialize_ub(
for `te.TransformerLayer` GEMM layers in `["qkv_fprop", "qkv_dgrad", "qkv_wgrad",
"proj_fprop", "proj_dgrad", "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad",
"fc2_fprop", "fc2_wgrad"]`.
a list may be provided to specify different overlap configurations for different the quantization settings in `quantization_modes`
bootstrap_backend : str = None
`torch.distributed` communication backend for the all-gather, broadcast and
barrier collectives during Userbuffers initialization. Not all backends are
...
...
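Per the extended docstring, `ub_cfgs` may now be a list aligned entry-for-entry with `quantization_modes`. A hedged sketch (the buffer shape, tp_size, and config contents here are illustrative, not taken from this commit; only the alignment rule and the "pipeline"/"ring_exchange"/"fp8_buf" knobs come from the documented API):

# Hedged sketch: separate overlap configs for the FP8 and non-FP8 communicator sets.
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch import UserBufferQuantizationMode

quantization_modes = [UserBufferQuantizationMode.FP8, UserBufferQuantizationMode.NONE]
ub_cfgs = [
    {"proj_fprop": {"method": "pipeline", "fp8_buf": True}},  # applied to the FP8 set
    {"proj_fprop": {"method": "ring_exchange"}},              # applied to the non-FP8 set
]
te.module.base.initialize_ub(
    [4096, 1024],      # illustrative [sequence * batch, hidden] buffer shape
    tp_size=2,         # illustrative tensor-parallel size
    quantization_modes=quantization_modes,
    ub_cfgs=ub_cfgs,   # must have the same length as quantization_modes
    dtype=torch.bfloat16,
)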
@@ -168,6 +184,28 @@
             + "CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead."
         )

+    if not quantization_modes:
+        warnings.warn(
+            "Initializing Userbuffers with use_fp8 is deprecated. Please use quantization_modes"
+            " instead.",
+            DeprecationWarning,
+        )
+        quantization_modes = [
+            UserBufferQuantizationMode.FP8 if use_fp8 else UserBufferQuantizationMode.NONE
+        ]
+    else:
+        assert isinstance(quantization_modes, list), "quantization_modes must be a list"
+        assert all(
+            isinstance(mode, UserBufferQuantizationMode) for mode in quantization_modes
+        ), "quantization_modes must be a list of UserBufferQuantizationMode"
+
+    if isinstance(ub_cfgs, dict) or ub_cfgs is None:
+        ub_cfgs = [ub_cfgs] * len(quantization_modes)
+    else:
+        assert len(ub_cfgs) == len(
+            quantization_modes
+        ), "Number of ub_cfgs settings must match number of quantization configurations"
+
     global _ub_communicators
     assert _ub_communicators is None, "UB communicators are already initialized."
     _ub_communicators = {}
@@ -309,6 +347,7 @@ def initialize_ub(
     def add_ub(
         name: str,
+        quantization_mode: UserBufferQuantizationMode,
         method: str,
         is_reduce_scatter: bool,
         num_sm: int = 16,
@@ -327,7 +366,9 @@ def initialize_ub(
             warnings.warn(
                 "Atomic GEMM uses a beta API from cublas and is not tested for all use cases."
             )
-            assert use_fp8, "Atomic GEMM overlap supported only for FP8 GEMM."
+            assert (
+                quantization_mode == UserBufferQuantizationMode.FP8
+            ), "Atomic GEMM overlap supported only for FP8 GEMM."
         if method in ("bulk", "external"):
             warnings.warn(
                 f"At {name}, atomic GEMM is not supported for a bulk overlap."
@@ -367,7 +408,11 @@ def initialize_ub(
                 f"{external_gemm_to_overlap[name]} is not using `ring_exchange` overlap method"
             )

-        buffer_dtype = torch.uint8 if (use_fp8 and fp8_buf) else dtype
+        buffer_dtype = (
+            torch.uint8
+            if (quantization_mode == UserBufferQuantizationMode.FP8 and fp8_buf)
+            else dtype
+        )
         if method == "ring_exchange":
             ub_obj = tex.CommOverlapP2P(
                 shape,  # Communication buffer shape
@@ -401,38 +446,47 @@ def initialize_ub(
                 comm_priority=comm_priority,
                 rs_overlap_first_gemm=pipeline_rs_overlap_first_gemm,
             )
-        _ub_communicators[name] = ub_obj
+        _ub_communicators[(name, quantization_mode)] = ub_obj

-    if ub_cfgs is not None:
-        for name in dgrad_reduce_scatter_overlap:
-            if name in ub_cfgs and "method" in ub_cfgs[name] and ub_cfgs[name]["method"] != "bulk":
-                wgrad_name = name.replace("dgrad", "wgrad")
-                assert wgrad_name not in ub_cfgs
-                layers_reduce_scatter_overlap.remove(wgrad_name)
-                layers_all_gather_overlap.remove(name)
-                layers_reduce_scatter_overlap.append(name)
-                methods["bulk"].remove(name)
-                new_method = ub_cfgs[name]["method"]
-                methods[new_method].append(name)
+    for quantization_mode, user_ub_cfg in zip(quantization_modes, ub_cfgs):
+        if user_ub_cfg is not None:
+            for name in dgrad_reduce_scatter_overlap:
+                if (
+                    name in user_ub_cfg
+                    and "method" in user_ub_cfg[name]
+                    and user_ub_cfg[name]["method"] != "bulk"
+                ):
+                    wgrad_name = name.replace("dgrad", "wgrad")
+                    assert wgrad_name not in user_ub_cfg
+                    layers_reduce_scatter_overlap.remove(wgrad_name)
+                    layers_all_gather_overlap.remove(name)
+                    layers_reduce_scatter_overlap.append(name)
+                    methods["bulk"].remove(name)
+                    new_method = user_ub_cfg[name]["method"]
+                    methods[new_method].append(name)

-    for name in (
-        methods["ring_exchange"] + methods["pipeline"] + methods["bulk"] + methods["external"]
-    ):
-        ub_cfg = get_default_config(name)
-        if ub_cfgs is not None and name in ub_cfgs:
-            fp8_buf = (name in layers_all_gather_overlap) or (
-                ub_cfgs[name].get("fp8_buf", False) and name in methods["pipeline"]
-            )
-            ub_cfg.update(ub_cfgs[name])
-            ub_cfg["fp8_buf"] = fp8_buf
-        add_ub(name, **ub_cfg)
+        for name in (
+            methods["ring_exchange"] + methods["pipeline"] + methods["bulk"] + methods["external"]
+        ):
+            ub_cfg = get_default_config(name)
+            if user_ub_cfg is not None and name in user_ub_cfg:
+                fp8_buf = (name in layers_all_gather_overlap) or (
+                    user_ub_cfg[name].get("fp8_buf", False) and name in methods["pipeline"]
+                )
+                ub_cfg.update(user_ub_cfg[name])
+                ub_cfg["fp8_buf"] = fp8_buf
+            add_ub(name, quantization_mode, **ub_cfg)


-def get_ub(name: str):
+def get_ub(name: str, use_fp8: bool):
     """Get userbuffer communicator corresponding to given key."""
+    # For now use the `use_fp8` boolean input as it matches the current design in the
+    # modules, so favour simplicity until the correct design becomes clear.
+    # This is mainly an internal API so we don't need to worry about future changes.
+    key = (name, UserBufferQuantizationMode.FP8 if use_fp8 else UserBufferQuantizationMode.NONE)
     assert _ub_communicators is not None, "UB manager is not initialized."
-    assert name in _ub_communicators, f"UB for {name} is not registered."
-    return _ub_communicators[name]
+    assert key in _ub_communicators, f"UB for {name} with use_fp8={use_fp8} is not registered."
+    return _ub_communicators[key]


 def destroy_ub():
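With communicators now stored under `(name, mode)` tuples, `get_ub` translates its boolean `use_fp8` argument into a registry key. A minimal sketch of the lookup it performs; note the comment in the diff marks this as an internal API, so this is illustration rather than a supported entry point:

# Illustrative only: get_ub is internal; "proj_fprop" is one of the documented layer names.
from transformer_engine.pytorch.module.base import get_ub

# After initialize_ub(..., quantization_modes=[FP8, NONE]), both lookups resolve:
fp8_comm  = get_ub("proj_fprop", True)   # -> _ub_communicators[("proj_fprop", FP8)]
bf16_comm = get_ub("proj_fprop", False)  # -> _ub_communicators[("proj_fprop", NONE)]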
transformer_engine/pytorch/module/layernorm_linear.py
@@ -173,10 +173,10 @@ class _LayerNormLinear(torch.autograd.Function):
             ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output
         )
         if ub_overlap_rs_fprop:
-            ub_obj = get_ub(ub_name + "_fprop")
+            ub_obj = get_ub(ub_name + "_fprop", fp8)
             ub_type = tex.CommOverlapType.RS
         elif ub_overlap_ag_fprop:
-            ub_obj = get_ub(ub_name + "_fprop")
+            ub_obj = get_ub(ub_name + "_fprop", fp8)
             ub_type = tex.CommOverlapType.AG

         # Configure quantizer for norm output
@@ -575,23 +575,23 @@ class _LayerNormLinear(torch.autograd.Function):
         dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]]
         if ctx.ub_overlap_ag:
             # Overlap grad_output all-gather with dgrad compute
-            ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
+            ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
             ub_obj_dgrad = ctx.ub_obj_gradout
             ub_type_dgrad = tex.CommOverlapType.AG
         elif ctx.ub_overlap_rs_dgrad:
             # Overlap dgrad reduce-scatter with dgrad compute
-            ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
+            ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
             ub_obj_dgrad = ctx.ub_obj_gradout
             ub_type_dgrad = tex.CommOverlapType.RS
         else:
             if ctx.ub_bulk_dgrad:
                 # Overlap inputmat all-gather with dgrad compute
-                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
+                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
                 ub_obj_dgrad = ctx.ub_obj_gradout
                 ub_type_dgrad = tex.CommOverlapType.AG
             if ctx.ub_bulk_wgrad:
                 # Overlap dgrad reduce-scatter with wgrad compute
-                ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad")
+                ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8)
                 ub_type_wgrad = tex.CommOverlapType.RS

     # --------------------------------------------------
@@ -769,7 +769,7 @@ class _LayerNormLinear(torch.autograd.Function):
             dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream()

             # This object is separate from the ub_obj_wgrad object which is passed to the GEMM
-            ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad")
+            ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8)

             ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
@@ -1492,10 +1492,14 @@ class LayerNormLinear(TransformerEngineBaseModule):
             is_first_microbatch = False
         if self.ub_overlap_rs_fprop:
-            if get_ub(self.ub_name + "_fprop").is_fp8_ubuf():
+            if get_ub(
+                self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled()
+            ).is_fp8_ubuf():
                 fp8_output = True
         if self.ub_overlap_rs_dgrad:
-            if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf():
+            if get_ub(
+                self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled()
+            ).is_fp8_ubuf():
                 fp8_grad = True

         with torch.cuda.device(
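The module-side changes above make communicator selection a call-time decision: the same layer fetches the FP8 or non-FP8 set depending on whether it runs under `fp8_autocast`. A hedged behavioral sketch (assumes `initialize_ub` was called with both modes, and that `layer` is a TE module and `x` its input from earlier setup):

import transformer_engine.pytorch as te

# Inside fp8_autocast, FP8GlobalStateManager.is_fp8_enabled() is True, so the
# forward pass resolves e.g. get_ub(ub_name + "_fprop", True) -> FP8 communicators.
with te.fp8_autocast(enabled=True):
    y = layer(x)

# Outside, the same layer resolves get_ub(ub_name + "_fprop", False) and runs the
# overlap against the bf16 communicator set instead.
y = layer(x)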
transformer_engine/pytorch/module/layernorm_mlp.py
@@ -307,7 +307,7 @@ class _LayerNormMLP(torch.autograd.Function):
             fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
         if ub_overlap_ag:
             # Copy into Userbuffers buffer
-            ub_obj_lnout = get_ub("fc1_fprop")
+            ub_obj_lnout = get_ub("fc1_fprop", fp8)
             ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
                 ub_obj_lnout,
                 ln_out,
@@ -458,7 +458,7 @@ class _LayerNormMLP(torch.autograd.Function):
         ub_obj_fc2out = None
         reduce_scatter_out = None
         if ub_overlap_rs:
-            ub_obj_fc2out = get_ub("fc2_fprop")
+            ub_obj_fc2out = get_ub("fc2_fprop", fp8)
             dim_size = list(act_out.size())
             dim_size[0] //= tp_world_size
             dim_size[-1] = fc2_weight.size(0)
@@ -740,7 +740,7 @@ class _LayerNormMLP(torch.autograd.Function):
         # Note: Cast to expected dtype and perform tensor-parallel communication
         ub_obj_fc2_dgrad = None
         if ctx.ub_overlap_ag:
-            ub_obj_fc2_dgrad = get_ub("fc2_dgrad")
+            ub_obj_fc2_dgrad = get_ub("fc2_dgrad", ctx.fp8)
             ctx.ub_obj_gradout = ub_obj_fc2_dgrad
         (
             grad_output,
@@ -764,7 +764,7 @@ class _LayerNormMLP(torch.autograd.Function):
             # wgrad GEMM requires input with column-wise usage
             quantizer.set_usage(rowwise=False, columnwise=True)
         if ctx.ub_bulk_dgrad:
-            ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
+            ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8)
             ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
                 ub_obj_fc1_dgrad,
                 ln_out,
@@ -869,7 +869,7 @@ class _LayerNormMLP(torch.autograd.Function):
                 ub_obj_fc2_dgrad.get_communication_stream()
             )

-            ub_obj_fc2_wgrad = get_ub("fc2_wgrad")
+            ub_obj_fc2_wgrad = get_ub("fc2_wgrad", ctx.fp8)

             ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
@@ -1036,16 +1036,16 @@ class _LayerNormMLP(torch.autograd.Function):
         fc1_dgrad_shape = [reduce(multiply_op, inputmat.shape[:-1]), inputmat.shape[-1]]
         if ctx.ub_overlap_rs_dgrad:
             # Overlap DGRAD+RS
-            ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
+            ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8)
             ub_type_fc1_dgrad = tex.CommOverlapType.RS
         else:
             if ctx.ub_bulk_dgrad:
                 # Overlap ln_out all-gather with DGRAD compute
-                ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
+                ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8)
                 ub_type_fc1_dgrad = tex.CommOverlapType.AG
             if ctx.ub_bulk_wgrad:
                 # Overlap FC1 DGRAD reduce-scatter with WGRAD compute
-                ub_obj_fc1_wgrad = get_ub("fc1_wgrad")
+                ub_obj_fc1_wgrad = get_ub("fc1_wgrad", ctx.fp8)
                 ub_type_fc1_wgrad = tex.CommOverlapType.RS

     # --------------------------------------------------
@@ -1539,7 +1539,11 @@ class LayerNormMLP(TransformerEngineBaseModule):
         self.gemm_gelu_fusion = (
             bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0")))
             and self.activation == "gelu"
-            and ((_ub_communicators is None) or (not get_ub("fc1_fprop").is_atomic_gemm()))
+            and all(
+                ("fc1_fprop", use_fp8) not in _ub_communicators
+                or not get_ub("fc1_fprop", use_fp8).is_atomic_gemm()
+                for use_fp8 in [False, True]
+            )
         )
         self.name = name
@@ -1757,7 +1761,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
         fp8_output = False
         if self.ub_overlap_rs:
-            if get_ub("fc2_fprop").is_fp8_ubuf():
+            if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf():
                 fp8_output = True

         with torch.cuda.device(
transformer_engine/pytorch/module/linear.py
@@ -145,10 +145,10 @@ class _Linear(torch.autograd.Function):
         ub_obj = None
         ub_type = None
         if ub_overlap_rs_fprop:
-            ub_obj = get_ub(ub_name + "_fprop")
+            ub_obj = get_ub(ub_name + "_fprop", fp8)
             ub_type = tex.CommOverlapType.RS
         elif ub_overlap_ag_fprop:
-            ub_obj = get_ub(ub_name + "_fprop")
+            ub_obj = get_ub(ub_name + "_fprop", fp8)
             ub_type = tex.CommOverlapType.AG

         # ------------------------------------------------------
@@ -520,23 +520,23 @@ class _Linear(torch.autograd.Function):
         dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]]
         if ctx.ub_overlap_ag:
             # Overlap grad_output all-gather with dgrad compute
-            ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
+            ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
             ub_obj_dgrad = ctx.ub_obj_gradout
             ub_type_dgrad = tex.CommOverlapType.AG
         elif ctx.ub_overlap_rs_dgrad:
             # Overlap dgrad reduce-scatter with dgrad compute
-            ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
+            ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
             ub_obj_dgrad = ctx.ub_obj_gradout
             ub_type_dgrad = tex.CommOverlapType.RS
         else:
             if ctx.ub_bulk_dgrad:
                 # Overlap inputmat all-gather with dgrad compute
-                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
+                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
                 ub_obj_dgrad = ctx.ub_obj_gradout
                 ub_type_dgrad = tex.CommOverlapType.AG
             if ctx.ub_bulk_wgrad:
                 # Overlap dgrad reduce-scatter with wgrad compute
-                ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad")
+                ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8)
                 ub_type_wgrad = tex.CommOverlapType.RS

     # --------------------------------------------------
@@ -769,7 +769,7 @@ class _Linear(torch.autograd.Function):
             dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream()

             # This object is separate from the ub_obj_wgrad object which is passed to the GEMM
-            ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad")
+            ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8)

             ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
@@ -1377,10 +1377,14 @@ class Linear(TransformerEngineBaseModule):
             is_first_microbatch = False
         if self.ub_overlap_rs_fprop:
-            if get_ub(self.ub_name + "_fprop").is_fp8_ubuf():
+            if get_ub(
+                self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled()
+            ).is_fp8_ubuf():
                 fp8_output = True
         if self.ub_overlap_rs_dgrad:
-            if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf():
+            if get_ub(
+                self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled()
+            ).is_fp8_ubuf():
                 fp8_grad = True

         with torch.cuda.device(
transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py
@@ -241,16 +241,16 @@ class UserbuffersBackwardLinear(FusedOperation):
         with_dgrad_all_gather_x = False
         with_wgrad_reduce_scatter_dx = False
         if tensor_parallel_mode == "row":
-            ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad")
+            ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad", with_quantized_compute)
             ub_type_dgrad = CommOverlapType.AG
             with_dgrad_all_gather_dy = True
         elif tensor_parallel_mode == "column":
             if input_requires_grad and weight_requires_grad:
                 with_bulk_overlap = True
-                ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad")
+                ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad", with_quantized_compute)
                 ub_type_dgrad = CommOverlapType.AG
                 with_dgrad_all_gather_x = True
-                ub_comm_wgrad = get_ub(ub_comm_name + "_wgrad")
+                ub_comm_wgrad = get_ub(ub_comm_name + "_wgrad", with_quantized_compute)
                 ub_type_wgrad = CommOverlapType.RS
                 with_wgrad_reduce_scatter_dx = True
                 if ub_comm_wgrad.is_fp8_ubuf():
@@ -258,7 +258,7 @@ class UserbuffersBackwardLinear(FusedOperation):
                         "Userbuffers reduce-scatter is not supported with FP8 buffers"
                     )
             else:
-                ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad")
+                ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad", with_quantized_compute)
                 ub_type_dgrad = CommOverlapType.RS
                 with_dgrad_reduce_scatter_dx = True
                 if ub_comm_dgrad.is_fp8_ubuf():
@@ -409,7 +409,7 @@ class UserbuffersBackwardLinear(FusedOperation):
             # Get the communication stream from the dgrad GEMM to use for the AG
             dgrad_send_stream, dgrad_recv_stream = ub_comm_dgrad.get_communication_stream()

-            ub_obj_overlap_wgrad = get_ub(ub_comm_name + "_wgrad")
+            ub_obj_overlap_wgrad = get_ub(ub_comm_name + "_wgrad", with_quantized_compute)

             grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py
@@ -189,7 +189,7 @@ class UserbuffersForwardLinear(FusedOperation):
             output_quantizer = None

         # Get Userbuffers communicator
-        ub_comm = get_ub(ub_comm_name + "_fprop")
+        ub_comm = get_ub(ub_comm_name + "_fprop", with_quantized_compute)
         with_ub_all_gather = tensor_parallel_mode == "column"
         with_ub_reduce_scatter = tensor_parallel_mode == "row"
         ub_type = CommOverlapType.AG if with_ub_all_gather else CommOverlapType.RS