OpenDAS / TransformerEngine · Commits · 035c48c0

Commit 035c48c0 · authored Mar 24, 2025 by yuguo

    Merge branch 'main' of https://github.com/NVIDIA/TransformerEngine

Parents: ea272d4a, 86813893

Changes: 25 · Showing 5 changed files with 328 additions and 52 deletions (+328 / -52), diff page 1 of 2:

    transformer_engine/pytorch/module/layernorm_mlp.py    +21  -27
    transformer_engine/pytorch/module/linear.py           +21  -23
    transformer_engine/pytorch/tensor/__init__.py          +1   -0
    transformer_engine/pytorch/tensor/float8_tensor.py     +2   -2
    transformer_engine/pytorch/tensor/utils.py            +283   -0
transformer_engine/pytorch/module/layernorm_mlp.py

@@ -213,8 +213,6 @@ class _LayerNormMLP(torch.autograd.Function):
         # for return_layernorm_output: layernorm output = High precision, then cast to FP8
         # high precision layernorm output and output of the linear are returned
         with_quantized_norm = fp8 and not return_layernorm_output
-        if isinstance(fc1_input_quantizer, Float8CurrentScalingQuantizer):
-            with_quantized_norm = False
         tp_world_size = get_distributed_world_size(tp_group)
         ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output

@@ -318,35 +316,31 @@ class _LayerNormMLP(torch.autograd.Function):
             ln_out_total = ln_out

         # Cast weights to expected dtype
-        fc1_weight_final = fc1_weight
-        fc2_weight_final = fc2_weight
         if not fp8:
-            fc1_weight_final = cast_if_needed(fc1_weight_final, activation_dtype)
-            fc2_weight_final = cast_if_needed(fc2_weight_final, activation_dtype)
+            fc1_weight_final = cast_if_needed(fc1_weight, activation_dtype)
+            fc2_weight_final = cast_if_needed(fc2_weight, activation_dtype)
         else:
             # If weights are not quantized, we call get_weight_workspace,
             # which handles weight caching etc.
-            if not isinstance(fc1_weight, QuantizedTensor):
             # FP8 cast to workspace buffer
             update_workspace = is_first_microbatch is None or is_first_microbatch
             fc1_weight_final = module.get_weight_workspace(
                 tensor=fc1_weight,
                 quantizer=fc1_weight_quantizer,
                 cache_name=(None if is_first_microbatch is None else "fc1_weight"),
                 update_workspace=update_workspace,
                 skip_update_flag=skip_fp8_weight_update,
                 fsdp_group=fsdp_group,
             )
-            if not isinstance(fc2_weight, QuantizedTensor):
             fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True)
             fc2_weight_final = module.get_weight_workspace(
                 tensor=fc2_weight,
                 quantizer=fc2_weight_quantizer,
                 cache_name=(None if is_first_microbatch is None else "fc2_weight"),
                 update_workspace=update_workspace,
                 skip_update_flag=skip_fp8_weight_update,
                 fsdp_group=fsdp_group,
             )

         # Cast biases to expected dtype
         bias_dtype = activation_dtype
transformer_engine/pytorch/module/linear.py

@@ -176,31 +176,29 @@ class _Linear(torch.autograd.Function):
                 nvtx_range_pop(f"{nvtx_label}.input_cast_comm")

         # Cast weight to expected dtype
-        weightmat = weight
         if not fp8:
-            weightmat = cast_if_needed(weightmat, activation_dtype)
+            weightmat = cast_if_needed(weight, activation_dtype)
         else:
-            if not isinstance(weight, QuantizedTensor):
             # Configure quantizer
             if weight_quantizer is not None:
                 columnwise_usage = is_grad_enabled and inp.requires_grad
                 if not columnwise_usage:
                     columnwise_usage = (
                         is_fp8_activation_recompute_enabled()
                         and not in_fp8_activation_recompute_phase()
                     )
                 weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)

             # FP8 cast to workspace buffer
             update_workspace = is_first_microbatch is None or is_first_microbatch
             weightmat = module.get_weight_workspace(
                 tensor=weight,
                 quantizer=weight_quantizer,
                 cache_name=(None if is_first_microbatch is None else "weight"),
                 update_workspace=update_workspace,
                 skip_update_flag=skip_fp8_weight_update,
                 fsdp_group=fsdp_group,
             )

         # Cast bias to expected dtype
         bias_dtype = activation_dtype
transformer_engine/pytorch/tensor/__init__.py

@@ -7,6 +7,7 @@
 import torch

 from .quantized_tensor import QuantizedTensor, Quantizer
+from .utils import cast_master_weights_to_fp8, replace_raw_data

 __all__ = [
     "QuantizedTensor",
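(Editor's note: with this export in place, the new helpers become reachable from the tensor subpackage. Assuming a Transformer Engine build that includes this commit, the import below is all a caller needs.)

from transformer_engine.pytorch.tensor import cast_master_weights_to_fp8, replace_raw_data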
transformer_engine/pytorch/tensor/float8_tensor.py

@@ -185,9 +185,9 @@ class Float8CurrentScalingQuantizer(Quantizer):
     """

-    """Scaling factor to multiply when quantizing to FP8"""
+    """Workspace buffer for FP8 scaling factor"""
     scale: torch.Tensor
-    """Max-abs value from last FP8 cast"""
+    """Workspace buffer for max-abs value"""
     amax: torch.Tensor
     """FP8 datatype"""
     dtype: TE_DType
transformer_engine/pytorch/tensor/utils.py (new file, 0 → 100644)

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Helper functions for using fp8 tensors as weights"""

import torch

import transformer_engine_torch as tex
from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_scale_and_scale_inv

from .quantized_tensor import QuantizedTensor
from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer
from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
from ..optimizers.multi_tensor_apply import multi_tensor_applier


def replace_raw_data(tensor: QuantizedTensor, new_raw_data: torch.Tensor):
    r"""Change a quantized tensor's data buffer while preserving values

    This function modifies only the address space of the underlying
    raw data and does not alter any other tensor attributes or values.

    This may be used for custom buffer allocations, e.g. packing
    multiple parameter tensors together into a single contiguous
    buffer for ZeRO-2.
    """
    if isinstance(tensor, Float8Tensor):
        old_raw_data = tensor._data
        assert old_raw_data.dtype == new_raw_data.dtype, "The data types of raw data don't match"
        new_raw_data.detach().copy_(old_raw_data)
        tensor._data = new_raw_data
        del old_raw_data
    elif isinstance(tensor, MXFP8Tensor):
        raise NotImplementedError("replace_raw_data for MXFP8Tensor is not supported yet")
    else:
        raise ValueError(f"replace_raw_data for {type(tensor)} is not supported yet")
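(Editor's note, not part of this commit: a minimal sketch of how replace_raw_data could be used to repack several FP8 parameters into one contiguous buffer, e.g. for a ZeRO-2 style optimizer. The helper pack_fp8_params and the params list are hypothetical; the sketch assumes each entry is a Float8Tensor whose raw bytes live in its _data attribute, as in the function above.)

import torch

from transformer_engine.pytorch.tensor import replace_raw_data


def pack_fp8_params(params):
    """Repoint each Float8Tensor's raw storage at a slice of one shared buffer."""
    # Assumption: all entries share the same raw dtype (uint8 for Float8Tensor) and device.
    total = sum(p._data.numel() for p in params)
    flat_buffer = torch.empty(total, dtype=params[0]._data.dtype, device=params[0].device)
    offset = 0
    for p in params:
        n = p._data.numel()
        # replace_raw_data copies the old values into the new storage before swapping,
        # so parameter values are preserved while the address space changes.
        replace_raw_data(p, flat_buffer[offset : offset + n].view_as(p._data))
        offset += n
    return flat_buffer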
def cast_master_weights_to_fp8(model_weights, master_weights, start_offsets, group):
    r"""Helper function to cast master weights to FP8 primary weights.

    This is intended for use with ZeRO/FSDP. Each rank has a shard of
    the master weights (possibly empty) and a full copy of the model
    weights.

    Parameters
    ----------
    model_weights  : list of FP8 weights.
    master_weights : list of master weights. Typically they are FP32 weights.
    start_offsets  : list of integers, the starting index of the master weight in the model weight.
                     master_weight may be smaller than model_weight because it could be distributed
                     across multiple ranks. These offsets indicate which part of the model_weight
                     should be updated.
    group          : The distributed group to do amax reduction. Typically it's the data parallel
                     group.
    """

    delayed_scaling_params = []
    current_scaling_params = []

    for model_weight, master_weight, start_offset in zip(
        model_weights, master_weights, start_offsets
    ):
        # Clear `_high_precision_init_val` of model_weight automatically.
        # - Master weights are initialized from model weights, if we use fp8 primary weights to
        #   initialize master weights, the numerical values of master weights are not consistent
        #   with the numerical values when we initialize them from bf16/fp16 weights.
        # - So we add a `_high_precision_init_val` attribute to each model weight to store the
        #   original bf16/fp16 weight on cpu before casting it to fp8. And users can use
        #   `get_high_precision_init_val` to get this cpu tensor.
        # - This cpu tensor is not needed once the master weight is initialized, so users should
        #   call `clear_high_precision_init_val` to remove it after master weight is initialized.
        # - In case users don't call `clear_high_precision_init_val`, we will clear it automatically
        #   here. It's safe to clear the `_high_precision_init_val` at this time because this
        #   function is supposed to be called after the master weights are initialized and updated.
        if hasattr(model_weight, "clear_high_precision_init_val"):
            model_weight.clear_high_precision_init_val()

        if master_weight is not None:
            # When not using fp8_primary_weights, the master_weight (fp32) is first cast to
            # bf16/fp16, and then cast to fp8 during forward. Although it's not necessary when
            # fp8_primary_weights is enabled, we still keep this logic to keep numerical
            # consistency. So here we cast the master_weight to model_weight.dtype.
            master_weight = master_weight.to(model_weight.dtype)

        quantizer = model_weight._get_quantizer()
        if isinstance(quantizer, Float8Quantizer):
            delayed_scaling_params.append((model_weight, master_weight, start_offset))
        elif isinstance(quantizer, Float8CurrentScalingQuantizer):
            current_scaling_params.append((model_weight, master_weight, start_offset))
        elif isinstance(quantizer, MXFP8Quantizer):
            raise NotImplementedError(
                "cast_master_weights_to_fp8 for MXFP8BlockScaling is not supported yet"
            )
        else:
            raise ValueError(
                f"cast_master_weights_to_fp8 for {type(quantizer)} is not supported yet"
            )

    if len(delayed_scaling_params) > 0:
        _cast_master_weights_to_fp8_delayed_scaling(delayed_scaling_params, group)
    if len(current_scaling_params) > 0:
        _cast_master_weights_to_fp8_current_scaling(current_scaling_params, group)
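(Editor's note, not part of this commit: a sketch of how a ZeRO-style optimizer step might drive cast_master_weights_to_fp8. The shard bookkeeping names here (shards, fp32_shard, owned_start) are hypothetical placeholders for whatever state the optimizer keeps; the only real API used is the function defined above.)

from transformer_engine.pytorch.tensor import cast_master_weights_to_fp8


def update_fp8_primary_weights(fp8_weights, shards, dp_group):
    """fp8_weights: full FP8 model weights on this rank; shards: per-weight shard info or None."""
    model_weights, master_weights, start_offsets = [], [], []
    for weight, shard in zip(fp8_weights, shards):
        model_weights.append(weight)
        if shard is None:
            # This rank owns no fragment of this weight; it still participates
            # in the amax all-reduce inside cast_master_weights_to_fp8.
            master_weights.append(None)
            start_offsets.append(None)
        else:
            master_weights.append(shard.fp32_shard)   # hypothetical FP32 master fragment
            start_offsets.append(shard.owned_start)   # offset of the fragment in the flat weight
    cast_master_weights_to_fp8(model_weights, master_weights, start_offsets, dp_group)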
def _cast_master_weights_to_fp8_delayed_scaling(params, group):
    r"""Helper function to cast master weights to FP8 primary weights for delayed scaling.

    Parameters
    ----------
    params : List of tuple, each tuple contains a model weight, a master weight, and an offset
             indicating the starting index of the master weight in the model weight.
    group  : The distributed group to do amax reduction. Typically it's the data parallel
             group.
    """

    # Collect amaxes to do reduce-max among dp group.
    # Collect scales and scale_invs to update scale_invs of the fp8 weights.
    amaxes, scales, scale_invs = [], [], []

    for model_weight, master_weight, start_offset in params:
        # Reset transpose cache for all model weights.
        # We cannot create transpose cache here because users (like megatron) may want to overlap
        # the all-gather of model weights and forward process, so the model weight is not updated
        # currently.
        model_weight._reset_caches()

        quantizer = model_weight._get_quantizer()

        amaxes.append(quantizer.amax.view(1))
        scales.append(quantizer.scale.view(1))
        scale_invs.append(model_weight._scale_inv.view(1))

        # If master weight is None, it means that the master weight of the current model weight
        # is in other DP ranks.
        if master_weight is None:
            continue

        # If master weight is not None, start_offset must be a valid value.
        assert start_offset is not None
        assert start_offset >= 0
        end_offset = start_offset + master_weight.numel()
        assert end_offset <= model_weight.numel()

        # master_weight may be smaller than model_weight because it could be distributed across
        # multiple ranks. So we need to create a dummy weight using the raw data from model_weight.
        shard_model_weight_raw = model_weight._data.view(-1)[start_offset:end_offset]
        shard_model_weight_fp8 = quantizer.create_tensor_from_data(
            shard_model_weight_raw.view(1, -1),
            model_weight.dtype,
        )

        # Cast master weight to fp8.
        quantizer.update_quantized(master_weight.view(1, -1), shard_model_weight_fp8)

    if len(amaxes) > 0:
        dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=amaxes[0].device)

        # Reduce amaxes.
        packed_amaxes = torch.empty(len(amaxes), dtype=torch.float32, device=amaxes[0].device)
        packed_amax_views = [packed_amaxes[i].view(1) for i in range(len(amaxes))]
        multi_tensor_applier(
            multi_tensor_scale, dummy_overflow_buf, [amaxes, packed_amax_views], 1.0
        )
        torch.distributed.all_reduce(
            packed_amaxes,
            op=torch.distributed.ReduceOp.MAX,
            group=group,
        )
        multi_tensor_applier(
            multi_tensor_scale, dummy_overflow_buf, [packed_amax_views, amaxes], 1.0
        )

        # Update scale_invs.
        packed_scales = torch.empty(len(scales), dtype=torch.float32, device=scales[0].device)
        packed_scale_views = [packed_scales[i].view(1) for i in range(len(scales))]
        multi_tensor_applier(
            multi_tensor_scale, dummy_overflow_buf, [scales, packed_scale_views], 1.0
        )
        torch.reciprocal(packed_scales, out=packed_scales)
        multi_tensor_applier(
            multi_tensor_scale, dummy_overflow_buf, [packed_scale_views, scale_invs], 1.0
        )


def _cast_master_weights_to_fp8_current_scaling(params, group):
    r"""Helper function to cast master weights to FP8 primary weights for current scaling.

    Parameters
    ----------
    params : List of tuple, each tuple contains a model weight, a master weight, and an offset
             indicating the starting index of the master weight in the model weight.
    group  : The distributed group to do amax reduction. Typically it's the data parallel
             group.
    """

    # Parameter attributes
    device = params[0][0].device
    fp8_dtype = params[0][0]._get_quantizer().dtype
    force_pow_2_scales = params[0][0]._get_quantizer().force_pow_2_scales
    amax_epsilon = params[0][0]._get_quantizer().amax_epsilon

    # Create a dummy overflow buffer, it's needed by multi_tensor_applier.
    dummy_overflow_buf = torch.zeros(1, dtype=torch.int, device=device)

    # Create a contiguous buffer to store amaxes temporarily, so we can perform all all-reduce
    # NCCL kernels at once.
    packed_amaxes = torch.zeros(len(params), dtype=torch.float32, device=device)
    amaxes = [packed_amaxes[i : i + 1] for i in range(len(params))]

    # Collect scales and scale_invs to update them after amax reduction.
    scales, scale_invs = [], []

    # ---------------------------------------------------------------------------------------------
    # Step 1: Iterate through all the none empty master weights and compute amax of them. Store the
    #         amaxes in a contiguous buffer. If the master weight is None, the corresponding amax
    #         will be set to 0.
    # ---------------------------------------------------------------------------------------------
    for (model_weight, master_weight, _), amax in zip(params, amaxes):
        # Make sure all the model weights have the same numerical options.
        quantizer = model_weight._get_quantizer()
        assert quantizer.dtype == fp8_dtype
        assert quantizer.force_pow_2_scales == force_pow_2_scales
        assert quantizer.amax_epsilon == amax_epsilon

        scales.append(quantizer.scale.view(1))
        scale_invs.append(model_weight._scale_inv.view(1))

        # Compute amax of the master weight and store it in packed_amaxes.
        if master_weight is not None:
            tex.compute_amax(master_weight, amax)

    # ---------------------------------------------------------------------------------------------
    # Step 2: Perform all-reduce on packed_amaxes to get the global amax.
    # ---------------------------------------------------------------------------------------------
    torch.distributed.all_reduce(packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=group)

    # ---------------------------------------------------------------------------------------------
    # Step 3: Update scales and scale_invs.
    # ---------------------------------------------------------------------------------------------
    if fp8_dtype == tex.DType.kFloat8E4M3:
        max_fp8 = 448.0
    elif fp8_dtype == tex.DType.kFloat8E5M2:
        max_fp8 = 57344.0
    else:
        raise ValueError(f"Unsupported FP8 dtype: {fp8_dtype}")
    multi_tensor_applier(
        multi_tensor_compute_scale_and_scale_inv,
        dummy_overflow_buf,
        [amaxes, scales, scale_invs],
        max_fp8,
        force_pow_2_scales,
        amax_epsilon,
    )

    # ---------------------------------------------------------------------------------------------
    # Step 4: Cast master weights to FP8.
    # ---------------------------------------------------------------------------------------------
    for (model_weight, master_weight, start_offset), scale in zip(params, scales):
        # Reset transpose cache for all model weights.
        # We cannot create transpose cache here because users (like megatron) may want to overlap
        # the all-gather of model weights and forward process, so the model weight is not updated
        # currently.
        model_weight._reset_caches()

        # If master weight is None, it means that the master weight of the current model weight
        # is in other DP ranks.
        if master_weight is None:
            continue

        # Cast master weight to FP8
        end_offset = start_offset + master_weight.numel()
        model_weight_fragment = model_weight.reshape(-1)[start_offset:end_offset]
        quantizer = Float8Quantizer(
            scale=scale,
            amax=torch.Tensor(),
            fp8_dtype=model_weight._fp8_dtype,
        )
        quantizer.update_quantized(master_weight, model_weight_fragment)
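(Editor's note, not part of this commit: Step 3 above delegates the scale math to the multi_tensor_compute_scale_and_scale_inv kernel. Under the usual FP8 current-scaling recipe this amounts to scale = max_fp8 / amax with optional power-of-two rounding; the rough per-tensor reference below states that assumption explicitly, and the real kernel may differ in edge cases such as amax == 0.)

import math


def compute_scale_and_scale_inv(amax, max_fp8, force_pow_2_scales, amax_epsilon):
    """Assumed per-tensor recipe behind Step 3; not the actual CUDA kernel."""
    amax = max(amax, amax_epsilon)
    if amax == 0.0:
        return 1.0, 1.0  # assumption: fall back to identity scaling when nothing was observed
    scale = max_fp8 / amax
    if force_pow_2_scales:
        # Round down to a power of two so the scale is exactly representable.
        scale = 2.0 ** math.floor(math.log2(scale))
    return scale, 1.0 / scale

# Example: for E4M3 (max_fp8 = 448.0) and a weight shard with amax = 3.5,
# scale = 128.0 and scale_inv = 1 / 128.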