OpenDAS / TransformerEngine / Commits / 035c48c0

Commit 035c48c0, authored Mar 24, 2025 by yuguo

Merge branch 'main' of https://github.com/NVIDIA/TransformerEngine

Parents: ea272d4a 86813893

Changes: 25 · Showing 20 changed files with 903 additions and 123 deletions (+903 −123)
3rdparty/cudnn-frontend  (+1 −1)
qa/L0_pytorch_unittest/test.sh  (+2 −1)
qa/L1_pytorch_distributed_unittest/test.sh  (+1 −0)
tests/pytorch/distributed/run_cast_master_weights_to_fp8.py  (+399 −0)
tests/pytorch/distributed/test_cast_master_weights_to_fp8.py  (+35 −0)
tests/pytorch/references/ref_per_tensor_cs.py  (+20 −6)
tests/pytorch/test_multi_tensor.py  (+42 −0)
tests/pytorch/test_sanity.py  (+73 −1)
transformer_engine/common/recipe/current_scaling.cu  (+3 −36)
transformer_engine/common/recipe/recipe_common.cuh  (+56 −0)
transformer_engine/pytorch/csrc/extensions.h  (+6 −0)
transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_compute_scale.cu  (+66 −0)
transformer_engine/pytorch/csrc/extensions/normalization.cpp  (+72 −17)
transformer_engine/pytorch/csrc/extensions/pybind.cpp  (+3 −0)
transformer_engine/pytorch/csrc/extensions/recipe.cpp  (+19 −2)
transformer_engine/pytorch/fp8.py  (+29 −1)
transformer_engine/pytorch/module/_common.py  (+3 −17)
transformer_engine/pytorch/module/base.py  (+43 −1)
transformer_engine/pytorch/module/grouped_linear.py  (+11 −14)
transformer_engine/pytorch/module/layernorm_linear.py  (+19 −26)
3rdparty/cudnn-frontend @ 6ed19fd2 (compare 20c28ea7...6ed19fd2)

-Subproject commit 20c28ea798fe99e31d7274e009ee2fbf0e88abfd
+Subproject commit 6ed19fd213e33af2d9a1841b1023ccb2f81d45a1
qa/L0_pytorch_unittest/test.sh

@@ -26,7 +26,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py"
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py"
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py"
 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py"
-NVTE_CUDNN_MXFP8_NORM=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
+PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"

@@ -39,6 +39,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py ||
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py"
+NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_paged_attn.py || test_fail "test_paged_attn.py"
 if [ "$RET" -ne 0 ]; then
     echo "Error in the following test cases: $FAILED_CASES"
     exit 1
qa/L1_pytorch_distributed_unittest/test.sh

@@ -26,6 +26,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py |
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
 # python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" ### TODO Debug UB support with te.Sequential
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py"
+python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
 if [ "$RET" -ne 0 ]; then
     echo "Error in the following test cases: $FAILED_CASES"
tests/pytorch/distributed/run_cast_master_weights_to_fp8.py (new file, mode 100644)

#!/usr/bin/python3

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import argparse
import datetime
import os
import sys

import torch
from torch import nn
import torch.distributed as dist

from transformer_engine.common.recipe import (
    DelayedScaling,
    Float8CurrentScaling,
    Format,
    Recipe,
)
import transformer_engine.pytorch as te
from transformer_engine.pytorch.tensor import QuantizedTensor, cast_master_weights_to_fp8
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor


def _get_raw_data(quantized_tensor):
    """Get the underlying data of a quantized tensor, used in zero-1 optimizer"""
    if isinstance(quantized_tensor, Float8Tensor):
        assert hasattr(quantized_tensor, "_data"), "Float8Tensor does not have _data attribute"
        assert quantized_tensor._data.dtype == torch.uint8, "Float8Tensor _data must be uint8"
        return quantized_tensor._data
    else:
        raise ValueError(f"Unsupported quantized tensor type: {type(quantized_tensor)}")


class MiniZero_1:
    """A mini zero-1 optimizer implementation, just used for this test"""

    def __init__(self, weights, lr, dp_group):
        self.rank = dist.get_rank(dp_group)
        self.world_size = dist.get_world_size(dp_group)
        self.weights = weights
        self.lr = lr
        self.dp_group = dp_group

        # [self.offsets[i], self.offsets[i+1]) is the range of weights[i] in the global buffer
        self.offsets = [0]
        for weight in self.weights:
            self.offsets.append(self.offsets[-1] + weight.numel())
        # Padding to avoid global buffer cannot be divided by world size, so the offsets[-1] may
        # not be the end range of the last weight.
        if self.offsets[-1] % self.world_size != 0:
            self.offsets[-1] += self.world_size - self.offsets[-1] % self.world_size

        self.master_weights = []
        # The start offset of the master weight in the weight
        self.start_offsets = []
        # The overlapping area of the weight and this rank's local buffer
        self.overlapping_areas = []

        # The start and end of this rank's local buffer in the global buffer
        rank_start = self.offsets[-1] // self.world_size * self.rank
        rank_end = rank_start + self.offsets[-1] // self.world_size

        for weight, offset in zip(self.weights, self.offsets[:-1]):
            if offset >= rank_end or (offset + weight.numel()) <= rank_start:
                # This weight is not in this rank's local buffer
                master_weight = None
                start_offset = None
                overlapping_area = None
            else:
                overlapping_start = max(rank_start, offset)
                overlapping_end = min(rank_end, offset + weight.numel())
                length = overlapping_end - overlapping_start
                start_offset = overlapping_start - offset
                if isinstance(weight, QuantizedTensor):
                    # If weight is a FP8 tensor, we need to use the original high precision version
                    # to initialize the master weight.
                    high_precision_init_val = weight.get_high_precision_init_val().view(-1)
                    master_weight = high_precision_init_val.to(weight.device).float()[
                        start_offset : start_offset + length
                    ]
                else:
                    master_weight = (
                        weight.detach().view(-1).float()[start_offset : start_offset + length]
                    )
                overlapping_area = (overlapping_start, overlapping_end)
            self.master_weights.append(master_weight)
            self.start_offsets.append(start_offset)
            self.overlapping_areas.append(overlapping_area)

        # Create global buffer for grads reduce-scatter
        self.grad_buffer = torch.empty(
            [self.offsets[-1]], dtype=torch.float32, device=weights[0].device
        )
        self.grad_buffer_slice = self.grad_buffer[rank_start:rank_end]

        # Create global buffer for weights all-gather
        if isinstance(self.weights[0], QuantizedTensor):
            weight_buffer_dtype = torch.uint8
        else:
            weight_buffer_dtype = weights[0].dtype
        self.weight_buffer = torch.empty(
            [self.offsets[-1]], dtype=weight_buffer_dtype, device=weights[0].device
        )
        self.weight_buffer_slice = self.weight_buffer[rank_start:rank_end]

    def step(self):
        # -----------------------------------------------------------------------------------------
        # Step 1: Copy grads to the grad buffer
        # -----------------------------------------------------------------------------------------
        for weight, offset in zip(self.weights, self.offsets[:-1]):
            start = offset
            end = offset + weight.numel()
            self.grad_buffer[start:end].copy_(weight.main_grad.view(-1))

        # -----------------------------------------------------------------------------------------
        # Step 2: Grads reduce-scatter
        # -----------------------------------------------------------------------------------------
        # Don't use reduce_scatter directly to explicitly control the reduce order.
        # dist.reduce_scatter_tensor(self.grad_buffer_slice, self.grad_buffer, op=dist.ReduceOp.AVG,
        #                            group=self.dp_group)
        buffers = [torch.empty_like(self.grad_buffer) for _ in range(self.world_size)]
        dist.all_gather(buffers, self.grad_buffer, group=self.dp_group)
        for i in range(1, self.world_size):
            buffers[0] += buffers[i]
        rank_start = self.offsets[-1] // self.world_size * self.rank
        rank_end = rank_start + self.offsets[-1] // self.world_size
        self.grad_buffer_slice.copy_(buffers[0][rank_start:rank_end])
        self.grad_buffer_slice /= self.world_size

        # -----------------------------------------------------------------------------------------
        # Step 3: Update master weights
        # -----------------------------------------------------------------------------------------
        for master_weight, overlapping_area in zip(self.master_weights, self.overlapping_areas):
            if master_weight is None:
                # This weight's master weight is in other rank.
                continue
            grad = self.grad_buffer[overlapping_area[0] : overlapping_area[1]]
            master_weight -= grad * self.lr

        # -----------------------------------------------------------------------------------------
        # Step 4: Cast master weights to BF16 or FP8, depending on the type of the weight
        # -----------------------------------------------------------------------------------------
        if isinstance(self.weights[0], QuantizedTensor):
            # FP8 weights case
            for i in range(1, len(self.weights)):
                assert isinstance(self.weights[i], QuantizedTensor)
            cast_master_weights_to_fp8(
                self.weights, self.master_weights, self.start_offsets, self.dp_group
            )
        else:
            # BF16 weights case
            for weight, master_weight, start_offset in zip(
                self.weights, self.master_weights, self.start_offsets
            ):
                if master_weight is None:
                    continue
                start = start_offset
                end = start_offset + master_weight.numel()
                weight.data.view(-1)[start:end].copy_(master_weight)

        # -----------------------------------------------------------------------------------------
        # Step 5: Copy the updated weights (not all weights) to the weight buffer
        # -----------------------------------------------------------------------------------------
        for i in range(len(self.weights)):
            master_weight = self.master_weights[i]
            if master_weight is None:
                continue
            start_offset = self.start_offsets[i]
            if isinstance(self.weights[i], QuantizedTensor):
                weight = _get_raw_data(self.weights[i])
            else:
                weight = self.weights[i]
            weight_slice = weight.view(-1)[start_offset : start_offset + master_weight.numel()]
            overlapping_start, overlapping_end = self.overlapping_areas[i]
            self.weight_buffer[overlapping_start:overlapping_end].copy_(weight_slice)

        # -----------------------------------------------------------------------------------------
        # Step 6: Weight all-gather (FP8 or BF16)
        # -----------------------------------------------------------------------------------------
        dist.all_gather_into_tensor(
            self.weight_buffer, self.weight_buffer_slice, group=self.dp_group
        )

        # -----------------------------------------------------------------------------------------
        # Step 7: Copy the gathered weights from weight buffer to the actual weights
        # -----------------------------------------------------------------------------------------
        for weight, offset in zip(self.weights, self.offsets[:-1]):
            start = offset
            end = offset + weight.numel()
            if isinstance(weight, QuantizedTensor):
                weight = _get_raw_data(weight)
            weight.view(-1).data.copy_(self.weight_buffer[start:end])


class MiniOptimizer:
    def __init__(self, weights, lr, dp_group):
        self.world_size = dist.get_world_size(dp_group)
        self.weights = weights
        self.lr = lr
        self.dp_group = dp_group

        master_weights = []
        for weight in self.weights:
            master_weights.append(weight.detach().float())
        self.master_weights = master_weights

    def step(self):
        for weight, master_weight in zip(self.weights, self.master_weights):
            main_grad = weight.main_grad
            # Don't use all-reduce directly to explicitly control the reduce order.
            # dist.all_reduce(main_grad, op=dist.ReduceOp.AVG, group=self.dp_group)
            buffers = [torch.empty_like(main_grad) for _ in range(self.world_size)]
            dist.all_gather(buffers, main_grad, group=self.dp_group)
            for i in range(1, self.world_size):
                buffers[0] += buffers[i]
            main_grad.copy_(buffers[0])
            main_grad /= self.world_size
            master_weight -= main_grad * self.lr
            weight.data.copy_(master_weight)


def _test_zero_1(dp_group):
    """Make sure the implementation of zero-1 optimizer is correct"""
    rank = dist.get_rank(dp_group)
    world_size = dist.get_world_size(dp_group)

    torch.manual_seed(12345)
    torch.cuda.manual_seed(12345)

    weights = [
        torch.randn(256 * 256, dtype=torch.bfloat16, device="cuda"),
        torch.randn(256 * 256 * 3, dtype=torch.bfloat16, device="cuda"),
        torch.randn(256 * 256 * 2 - 1, dtype=torch.bfloat16, device="cuda"),
    ]
    weights_1 = weights
    weights_2 = [weight.clone() for weight in weights]

    lr = 1.0
    optimizer_1 = MiniZero_1(weights_1, lr, dp_group)
    optimizer_2 = MiniOptimizer(weights_2, lr, dp_group)

    for _ in range(100):
        for w1, w2 in zip(weights_1, weights_2):
            main_grads = [
                torch.randn_like(w1, dtype=torch.float32, device="cuda")
                for _ in range(world_size)
            ]
            # Choose based on rank to make sure the grads of different ranks are different.
            main_grad = main_grads[rank]
            w1.main_grad = main_grad
            w2.main_grad = main_grad
        optimizer_1.step()
        optimizer_2.step()
        for w1, w2 in zip(weights_1, weights_2):
            torch.testing.assert_close(w1, w2, atol=0, rtol=0)


def quantization_recipe(quantization) -> Recipe:
    """Quantization recipe setup"""
    if quantization == "fp8":
        return DelayedScaling(
            fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max"
        )
    elif quantization == "fp8_cs":
        return Float8CurrentScaling()
    else:
        raise ValueError(f"Unsupported quantization: {quantization}")


def _test_cast_master_weights_to_fp8(quantization, dp_group):
    rank = dist.get_rank(dp_group)
    world_size = dist.get_world_size(dp_group)

    torch.manual_seed(12345)
    torch.cuda.manual_seed(12345)

    mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)]
    mock_group = mock_groups[rank]

    linear_kwargs = {"params_dtype": torch.bfloat16, "bias": False, "fuse_wgrad_accumulation": True}

    # Create model with FP8 weights
    with te.fp8.fp8_model_init(
        enabled=quantization is not None,
        recipe=quantization_recipe(quantization),
        preserve_high_precision_init_val=True,
    ):
        model_fp8 = nn.Sequential(
            te.Linear(128, 256, **linear_kwargs),
            te.Linear(256, 256 * 3, **linear_kwargs),
            te.Linear(256 * 3, 128, **linear_kwargs),
        )

    # Create model with BF16 weights
    model = nn.Sequential(
        te.Linear(128, 256, **linear_kwargs),
        te.Linear(256, 256 * 3, **linear_kwargs),
        te.Linear(256 * 3, 128, **linear_kwargs),
    )

    # Make sure the BF16 model and FP8 model have the same initial weights
    for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
        high_precision_init_val = w_fp8.get_high_precision_init_val()
        w.data.copy_(high_precision_init_val)

    # Allocate main_grads for each weight
    for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
        w_fp8.main_grad = torch.zeros_like(w_fp8, dtype=torch.float32, device="cuda")
        w.main_grad = torch.zeros_like(w, dtype=torch.float32, device="cuda")

    optimizer_fp8 = MiniZero_1([w for w in model_fp8.parameters()], 10.0, dp_group)
    optimizer = MiniZero_1([w for w in model.parameters()], 10.0, dp_group)

    for _ in range(100):
        for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
            w_fp8.main_grad.zero_()
            w.main_grad.zero_()

        inputs = [
            torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size)
        ]
        # Choose based on rank to make sure the inputs of different ranks are different.
        x = inputs[rank]

        with te.fp8.fp8_autocast(
            enabled=quantization is not None,
            fp8_recipe=quantization_recipe(quantization),
            fp8_group=mock_group,
        ):
            y_fp8 = model_fp8(x)
        with te.fp8_autocast(
            enabled=quantization is not None,
            fp8_recipe=quantization_recipe(quantization),
            fp8_group=mock_group,
        ):
            y = model(x)

        targets = [torch.randn_like(y) for _ in range(world_size)]
        # Choose based on rank to make sure the targets of different ranks are different.
        target = targets[rank]

        loss_fp8 = nn.MSELoss()(y_fp8, target)
        loss = nn.MSELoss()(y, target)

        loss_fp8.backward()
        loss.backward()

        optimizer_fp8.step()
        optimizer.step()

        torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0)


def main(argv=None, namespace=None):
    WORLD_RANK = int(os.getenv("RANK", "0"))
    WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
    LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
    LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
    assert WORLD_SIZE == LOCAL_SIZE  # this test supports only 1 node
    assert LOCAL_SIZE <= torch.cuda.device_count()
    dist_init_kwargs = {
        "backend": "nccl",
        "rank": WORLD_RANK,
        "world_size": WORLD_SIZE,
        "timeout": datetime.timedelta(seconds=30),
    }
    dist_init_kwargs["init_method"] = "env://"
    dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}")
    assert dist.is_nccl_available()
    torch.cuda.set_device(LOCAL_RANK)
    dist.init_process_group(**dist_init_kwargs)

    parser = argparse.ArgumentParser()
    parser.add_argument("--quantization", type=str, default=None, choices=["fp8", "fp8_cs"])
    args = parser.parse_args(argv, namespace)

    dp_group = dist.new_group(backend="nccl")

    _test_zero_1(dp_group)
    _test_cast_master_weights_to_fp8(args.quantization, dp_group)

    dist.destroy_process_group()
    return 0


if __name__ == "__main__":
    sys.exit(main())
tests/pytorch/distributed/test_cast_master_weights_to_fp8.py (new file, mode 100644)

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os
import subprocess
from pathlib import Path

import pytest
import torch

from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

if torch.cuda.device_count() < 2:
    pytest.skip("cast_master_weights_to_fp8 test needs at least 2 GPUs.")

fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()

TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(2, torch.cuda.device_count())
LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"]


def _run_test(quantization):
    test_path = TEST_ROOT / "run_cast_master_weights_to_fp8.py"
    test_cmd = LAUNCH_CMD + [str(test_path)] + ["--quantization", quantization]
    result = subprocess.run(test_cmd, env=os.environ, check=False)
    assert result.returncode == 0


@pytest.mark.parametrize("quantization", ["fp8", "fp8_cs"])
def test_cast_master_weights_to_fp8(quantization):
    if not fp8_available:
        pytest.skip(reason_for_no_fp8)
    _run_test(quantization)
tests/pytorch/references/ref_per_tensor_cs.py

@@ -8,12 +8,8 @@ import transformer_engine_torch as tex
 from transformer_engine.pytorch.constants import TE_DType_To_Torch

-# compute amax and scale
-def _ref_compute_amax_scale(x, quant_dtype, eps, pow_2_scales):
-    x_fp32 = x.to(torch.float32)
-    amax = torch.amax(torch.abs(x_fp32)).view(1)
-    assert amax.dtype == torch.float, "amax must be a float tensor."
-    fp8_max = torch.finfo(quant_dtype).max
+# Compute scale and scale_inv from amax
+def _ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales):
     # Clamping amax to avoid division by small numbers
     amax = torch.max(amax, torch.tensor(eps))

@@ -52,6 +48,20 @@ def _ref_compute_amax_scale(x, quant_dtype, eps, pow_2_scales):
     # Compute scale_inv
     scale_inv = torch.reciprocal(scale)
     return scale, scale_inv
+
+
+# compute amax and scale
+def _ref_compute_amax_scale(x, quant_dtype, eps, pow_2_scales):
+    x_fp32 = x.to(torch.float32)
+    amax = torch.amax(torch.abs(x_fp32)).view(1)
+    assert amax.dtype == torch.float, "amax must be a float tensor."
+    fp8_max = torch.finfo(quant_dtype).max
+    scale, scale_inv = _ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales)
+    # Clamping amax to avoid division by small numbers
+    amax = torch.max(amax, torch.tensor(eps))
+    return scale, scale_inv, amax

@@ -103,3 +113,7 @@ def ref_per_tensor_cs_cast(
     qx_t = _multi_dim_transpose(qx)
     sx_t = sx
     return qx, sx, qx_t, sx_t
+
+
+def ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales):
+    return _ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales)
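For a feel of what the reference computes, here is a small worked example of the amax → scale path with E4M3 (fp8_max = 448.0); the input value is illustrative, not taken from the test suite:

import torch

amax = torch.tensor([0.25])
eps = 0.0
fp8_max = 448.0
amax = torch.max(amax, torch.tensor(eps))  # clamp away tiny amax values
scale = fp8_max / amax                     # 448.0 / 0.25 = 1792.0
scale_inv = torch.reciprocal(scale)        # ~5.58e-4
print(scale.item(), scale_inv.item())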
tests/pytorch/test_multi_tensor.py

@@ -9,6 +9,9 @@ import transformer_engine.pytorch as te
 import transformer_engine_torch as tex
 from transformer_engine.pytorch.optimizers import MultiTensorApply

+from references.ref_per_tensor_cs import ref_compute_scale_and_scale_inv_from_amax
+
 input_size_pairs = [
     (7777 * 77, 555 * 555),
     (777, 555),

@@ -216,3 +219,42 @@ def test_multi_tensor_unscale_l2norm(input_size_pair, applier, repeat, in_type,
     if per_tensor:
         torch.testing.assert_close(norm_per_tensor, normab.broadcast_to(norm_per_tensor.shape))
     assert overflow_buf.item() == 0
+
+
+@pytest.mark.parametrize("input_size_pair", input_size_pairs + [(1, 1)])
+@pytest.mark.parametrize("applier", appliers)
+@pytest.mark.parametrize("repeat", [1, 55])
+@pytest.mark.parametrize("max_fp8", [448.0, 57344.0])
+@pytest.mark.parametrize("pow_2_scales", [False, True])
+@pytest.mark.parametrize("epsilon", [0.0, 100.0])
+def test_multi_tensor_compute_scale_and_scale_inv(
+    input_size_pair, applier, repeat, max_fp8, pow_2_scales, epsilon
+):
+    sizea, sizeb = input_size_pair
+    device = torch.device("cuda")
+
+    overflow_buf = torch.zeros(1, dtype=torch.int32, device=device)
+    a = torch.randn([sizea], dtype=torch.float32, device=device).abs()
+    b = torch.randn([sizeb], dtype=torch.float32, device=device).abs()
+
+    amax_list = []
+    for i in range(repeat):
+        amax_list += [a.clone(), b.clone()]
+    scale_list = [torch.empty_like(x) for x in amax_list]
+    scale_inv_list = [torch.empty_like(x) for x in amax_list]
+
+    applier(
+        tex.multi_tensor_compute_scale_and_scale_inv,
+        overflow_buf,
+        [amax_list, scale_list, scale_inv_list],
+        max_fp8,
+        pow_2_scales,
+        epsilon,
+    )
+
+    for amax, scale, scale_inv in zip(amax_list, scale_list, scale_inv_list):
+        scale_ref, scale_inv_ref = ref_compute_scale_and_scale_inv_from_amax(
+            amax, max_fp8, epsilon, pow_2_scales
+        )
+        torch.testing.assert_close(scale, scale_ref, rtol=0, atol=0)
+        torch.testing.assert_close(scale_inv, scale_inv_ref, rtol=0, atol=0)
tests/pytorch/test_sanity.py

@@ -37,7 +37,12 @@ from transformer_engine.common import recipe
 import transformer_engine_torch as tex
 from transformer_engine.pytorch.cpp_extensions import general_gemm
 from transformer_engine.pytorch.module.base import get_workspace
-from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
+from transformer_engine.pytorch.tensor import QuantizedTensor
+from transformer_engine.pytorch.tensor.float8_tensor import (
+    Float8Quantizer,
+    Float8CurrentScalingQuantizer,
+)
+from transformer_engine.pytorch.tensor.utils import replace_raw_data
 from test_numerics import reset_rng_states, dtype_tols

 # Only run FP8 tests on supported devices.

@@ -1207,3 +1212,70 @@ def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False
         outputs.append(p.grad)
     return outputs
+
+
+@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
+def test_replace_raw_data_for_float8tensor():
+    """Test the functionality of replace_raw_data"""
+    torch.manual_seed(12345)
+    torch.cuda.manual_seed(12345)
+
+    fp8_quantizer = Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda")
+    fp8_tensor = fp8_quantizer.make_empty([128, 128], dtype=torch.bfloat16, device="cuda")
+    random_bf16_data = torch.randn(fp8_tensor.shape, dtype=torch.bfloat16, device="cuda")
+    fp8_quantizer.update_quantized(random_bf16_data, fp8_tensor)
+
+    attrs_to_check = ["_quantizer", "_fp8_dtype", "_scale_inv", "_transpose", "_transpose_invalid"]
+    attrs = {}
+    for attr in attrs_to_check:
+        attrs[attr] = getattr(fp8_tensor, attr)
+
+    old_data = fp8_tensor._data
+    new_data = torch.empty_like(old_data)
+    replace_raw_data(fp8_tensor, new_data)
+
+    # Make sure the new_data is properly assigned.
+    assert fp8_tensor._data.data_ptr() != old_data.data_ptr()
+    assert fp8_tensor._data.data_ptr() == new_data.data_ptr()
+    # Make sure the values are not changed.
+    torch.testing.assert_close(old_data, fp8_tensor._data, atol=0, rtol=0)
+    # Make sure other attributes are not changed (totally identical)
+    for attr in attrs_to_check:
+        assert id(getattr(fp8_tensor, attr)) == id(attrs[attr])
+
+
+@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
+def test_fp8_model_init_high_precision_init_val():
+    """Test fp8_model_init with preserve_high_precision_init_val=True"""
+    with fp8_model_init(preserve_high_precision_init_val=True):
+        model = Linear(768, 768)
+
+    weight = model.weight
+    assert isinstance(weight, QuantizedTensor), "Weight should be QuantizedTensor"
+    assert hasattr(weight, "_high_precision_init_val"), "_high_precision_init_val not found"
+    assert hasattr(weight, "get_high_precision_init_val"), "get_high_precision_init_val() not found"
+    assert hasattr(
+        weight, "clear_high_precision_init_val"
+    ), "clear_high_precision_init_val() not found"
+
+    high_precision = weight.get_high_precision_init_val()
+    assert high_precision.device.type == "cpu", "high_precision_init_val is not on the CPU"
+
+    new_weight = weight._get_quantizer().make_empty(
+        shape=weight.shape, dtype=weight.dtype, device=weight.device
+    )
+    weight._get_quantizer().update_quantized(high_precision.to(weight.device), new_weight)
+    torch.testing.assert_close(
+        new_weight.dequantize(dtype=weight.dtype),
+        weight.dequantize(dtype=weight.dtype),
+        rtol=0,
+        atol=0,
+    )
+
+    weight.clear_high_precision_init_val()
+    assert weight.get_high_precision_init_val() is None, "clear_high_precision_init_val() did not work"
+    assert not hasattr(
+        weight, "_high_precision_init_val"
+    ), "clear_high_precision_init_val() did not work"
transformer_engine/common/recipe/current_scaling.cu

@@ -13,6 +13,7 @@
 #include "../common.h"
 #include "../util/logging.h"
 #include "../util/vectorized_pointwise.h"
+#include "recipe_common.cuh"

 namespace transformer_engine {
 namespace {

@@ -135,7 +136,7 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt
              "Output tensor for amax computation has invalid amax tensor "
              "(expected FP32, got dtype=", to_string(output.amax.dtype), ")");
-  CheckOutputTensor(output, "output_compute_amax");
+  CheckOutputTensor(output, "output_compute_amax", true);

   // Compute amax
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(

@@ -151,41 +152,7 @@ namespace {
 __global__ void compute_scale_from_amax_kernel(const float *amax_ptr, float *scale_ptr,
                                                const float max_fp8, const bool force_pow_2_scales,
                                                const float epsilon) {
-  float amax = *amax_ptr;
-  if (amax < epsilon) {
-    amax = epsilon;
-  }
-  float scale = 1.f;
-  if (isinf(amax) || amax == 0.f) {
-    *scale_ptr = scale;
-    return;
-  }
-  scale = max_fp8 / amax;
-  // The amax is too small that the scale becoming infinite in FP32. In other word,
-  // the scale is not representable in FP32.
-  if (isinf(scale)) {
-    // use fp32 max to represent the scale
-    scale = std::numeric_limits<float>::max();
-  }
-  if (isnan(scale)) {
-    scale = 1.f;
-  }
-  if (force_pow_2_scales) {
-    uint32_t scale_bits = *reinterpret_cast<uint32_t *>(&scale);
-    scale_bits &= 0xFF800000;
-    // If the exponent was zero, we have a logic error.
-    __builtin_assume(scale_bits != 0);
-    __builtin_assume(scale_bits != 0x80000000);
-    scale = *reinterpret_cast<float *>(&scale_bits);
-  }
-  *scale_ptr = scale;
+  *scale_ptr = compute_scale_from_amax(*amax_ptr, max_fp8, force_pow_2_scales, epsilon);
 }
 }  // namespace
transformer_engine/common/recipe/recipe_common.cuh (new file, mode 100644)

/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_
#define TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_

#include <limits>

namespace transformer_engine {

__device__ __forceinline__ float compute_scale_from_amax(float amax, float max_fp8,
                                                         bool force_pow_2_scales, float epsilon) {
  if (amax < epsilon) {
    amax = epsilon;
  }
  float scale = 1.f;
  if (isinf(amax) || amax == 0.f) {
    return scale;
  }
  // Here we don't use "scale = max_fp8 / amax" because it gives different results with and
  // without "--use_fast_math".
  // "__fdiv_rn" has the same behavior as "max_fp8 / amax" when not using fast math.
  scale = __fdiv_rn(max_fp8, amax);

  // The amax is so small that the scale becomes infinite in FP32. In other words,
  // the scale is not representable in FP32.
  if (isinf(scale)) {
    // use fp32 max to represent the scale
    scale = std::numeric_limits<float>::max();
  }
  if (isnan(scale)) {
    scale = 1.f;
  }
  if (force_pow_2_scales) {
    uint32_t scale_bits = *reinterpret_cast<uint32_t *>(&scale);
    scale_bits &= 0xFF800000;
    // If the exponent was zero, we have a logic error.
    __builtin_assume(scale_bits != 0);
    __builtin_assume(scale_bits != 0x80000000);
    scale = *reinterpret_cast<float *>(&scale_bits);
  }
  return scale;
}

}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_
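The force_pow_2_scales branch works by masking the FP32 bit pattern: 0xFF800000 keeps the sign bit and the 8 exponent bits and zeroes the 23 mantissa bits, which rounds a positive scale down to the nearest power of two. A minimal Python sketch of the same bit trick (input values are illustrative):

import struct

def force_pow_2(scale: float) -> float:
    # Keep sign + exponent, zero the mantissa: round down to a power of two.
    bits = struct.unpack("<I", struct.pack("<f", scale))[0]
    bits &= 0xFF800000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

print(force_pow_2(1792.0))  # 1024.0 (1792 = 1.75 * 2^10)
print(force_pow_2(448.0))   # 256.0  (448  = 1.75 * 2^8)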
transformer_engine/pytorch/csrc/extensions.h

@@ -262,6 +262,8 @@ at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads
 * FP8 recipe
 **************************************************************************************************/

+void compute_amax(const at::Tensor &tensor, at::Tensor &amax);
+
 void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer,
                                                  std::vector<at::Tensor> amax_histories,
                                                  std::vector<at::Tensor> scales,

@@ -369,6 +371,10 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
                            float momentum, float dampening, float lr, bool nesterov,
                            bool first_run, bool wd_after_momentum, float scale);

+void multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, at::Tensor noop_flag,
+                                                   std::vector<std::vector<at::Tensor>> tensor_lists,
+                                                   float max_fp8, bool force_pow_2_scales,
+                                                   float epsilon);

 /***************************************************************************************************
 * padding
 **************************************************************************************************/
transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_compute_scale.cu (new file, mode 100644)

/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>

// Stringstream is a big hammer, but I want to rely on operator<< for dtype.
#include <sstream>

#include "common/recipe/recipe_common.cuh"
#include "common/utils.cuh"
#include "multi_tensor_apply.cuh"
#include "type_shim.h"

#define BLOCK_SIZE 256

struct ComputeScaleAndScaleInvFunctor {
  __device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem,
                                             TensorListMetadata<3> &tl,  // NOLINT(*)
                                             float max_fp8, bool force_pow_2_scales,
                                             float epsilon) {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    float *amax = reinterpret_cast<float *>(tl.addresses[0][tensor_loc]);
    amax += chunk_idx * chunk_size;

    float *scale = reinterpret_cast<float *>(tl.addresses[1][tensor_loc]);
    scale += chunk_idx * chunk_size;

    float *scale_inv = reinterpret_cast<float *>(tl.addresses[2][tensor_loc]);
    scale_inv += chunk_idx * chunk_size;

    n -= chunk_idx * chunk_size;

    for (int i_start = threadIdx.x; i_start < n && i_start < chunk_size; i_start += blockDim.x) {
      float scale_val = transformer_engine::compute_scale_from_amax(
          amax[i_start], max_fp8, force_pow_2_scales, epsilon);
      scale[i_start] = scale_val;
      transformer_engine::reciprocal(scale_inv + i_start, scale_val);
    }
  }
};

void multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, at::Tensor noop_flag,
                                                   std::vector<std::vector<at::Tensor>> tensor_lists,
                                                   float max_fp8, bool force_pow_2_scales,
                                                   float epsilon) {
  using namespace at;

  multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
                        ComputeScaleAndScaleInvFunctor(), max_fp8, force_pow_2_scales, epsilon);

  AT_CUDA_CHECK(cudaGetLastError());
}
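On the Python side the kernel is driven through MultiTensorApply, following the pattern in the new tests/pytorch/test_multi_tensor.py case. A rough sketch, assuming MultiTensorApply takes a chunk size as in the existing tests (the chunk size and tensor shapes here are illustrative):

import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.optimizers import MultiTensorApply

applier = MultiTensorApply(2048 * 32)  # assumed chunk size, as elsewhere in the tests
overflow_buf = torch.zeros(1, dtype=torch.int32, device="cuda")
amaxes = [torch.rand(1024, device="cuda"), torch.rand(77, device="cuda")]
scales = [torch.empty_like(t) for t in amaxes]
scale_invs = [torch.empty_like(t) for t in amaxes]
applier(
    tex.multi_tensor_compute_scale_and_scale_inv,
    overflow_buf,
    [amaxes, scales, scale_invs],  # order matches tl.addresses[0..2] above
    448.0,   # max_fp8 (E4M3)
    False,   # force_pow_2_scales
    0.0,     # epsilon
)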
transformer_engine/pytorch/csrc/extensions/normalization.cpp

@@ -6,12 +6,13 @@
 #include "common/util/system.h"
 #include "extensions.h"
 #include "pybind.h"

 namespace transformer_engine::pytorch {

 std::pair<TensorWrapper, py::object> createOutputTensor(const NVTEShape &shape, DType dtype,
                                                         py::handle quantizer) {
   std::vector<size_t> shape_vec;
-  for (int i = 0; i < shape.ndim; i++) {
+  for (size_t i = 0; i < shape.ndim; i++) {
     size_t t = shape.data[i];
     shape_vec.push_back(t);
   }

@@ -74,6 +75,7 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
                                       float eps, py::object out, py::handle quantizer,
                                       DType out_dtype, const int sm_margin,
                                       const bool zero_centered_gamma) {
   using namespace transformer_engine::pytorch::detail;
   using namespace transformer_engine::pytorch;
   using namespace transformer_engine;

@@ -107,14 +109,17 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
   }

   // Determine whether to avoid fused kernel
-  bool force_unfused_kernel = false;
-  if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) {
-    if (!transformer_engine::getenv<bool>("NVTE_CUDNN_MXFP8_NORM", false)) {
-      // TE only supports MXFP8 norm with cuDNN backend
-      force_unfused_kernel = true;
-    } else if (N % 128 != 0 || H % 128 != 0) {
-      // cuDNN norm requires full tile for MXFP8
-      force_unfused_kernel = true;
-    }
-  }
+  bool force_unfused_kernel = true;
+  if (quantizer.is_none()) {
+    // No need for separate quantization step if output is unquantized
+    force_unfused_kernel = false;
+  } else if (IsFloat8Quantizers(quantizer.ptr())) {
+    // Always use fused kernel for FP8 delayed scaling
+    force_unfused_kernel = false;
+  } else if (IsMXFP8Quantizers(quantizer.ptr())) {
+    if (transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
+      // cuDNN MXFP8 kernel requires full tile
+      force_unfused_kernel = N % 128 != 0 || H % 128 != 0;
+    }
+  }
   TensorWrapper unquantized_out_cu;

@@ -145,6 +150,29 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
   // Quantize output if using unfused kernel
   if (force_unfused_kernel) {
+    if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
+      // my_quantizer here has to be a Float8CurrentScalingQuantizer
+      auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
+      nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(),
+                        at::cuda::getCurrentCUDAStream());
+      // check if we need to do amax reduction (depending on model parallel configs)
+      if (my_quantizer_cs->with_amax_reduction) {
+        c10::intrusive_ptr<dist_group_type> process_group_ptr =
+            my_quantizer_cs->amax_reduction_group;
+        // construct torch tensor from NVTEBasicTensor without reallocating memory
+        at::Tensor &amax_tensor_torch = my_quantizer_cs->amax;
+        std::vector<at::Tensor> tensors = {amax_tensor_torch};
+        // allreduce amax tensor
+        c10d::AllreduceOptions allreduce_opts;
+        allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
+        process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
+      }
+      QuantizationConfigWrapper quant_config;
+      quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
+      quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
+      nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream());
+      // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
+      out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape);
+    }
     nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr,
                        at::cuda::getCurrentCUDAStream());
   }

@@ -196,6 +224,7 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
                                     py::object out, py::handle quantizer,
                                     transformer_engine::DType out_dtype, const int sm_margin,
                                     const bool zero_centered_gamma) {
   using namespace transformer_engine::pytorch::detail;
   using namespace transformer_engine::pytorch;
   using namespace transformer_engine;

@@ -223,14 +252,17 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
   }

   // Determine whether to avoid fused kernel
-  bool force_unfused_kernel = false;
-  if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) {
-    if (!transformer_engine::getenv<bool>("NVTE_CUDNN_MXFP8_NORM", false)) {
-      // TE only supports MXFP8 norm with cuDNN backend
-      force_unfused_kernel = true;
-    } else if (N % 128 != 0 || H % 128 != 0) {
-      // cuDNN norm requires full tile for MXFP8
-      force_unfused_kernel = true;
-    }
-  }
+  bool force_unfused_kernel = true;
+  if (quantizer.is_none()) {
+    // No need for separate quantization step if output is unquantized
+    force_unfused_kernel = false;
+  } else if (IsFloat8Quantizers(quantizer.ptr())) {
+    // Always use fused kernel for FP8 delayed scaling
+    force_unfused_kernel = false;
+  } else if (IsMXFP8Quantizers(quantizer.ptr())) {
+    if (transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
+      // cuDNN MXFP8 kernel requires full tile
+      force_unfused_kernel = N % 128 != 0 || H % 128 != 0;
+    }
+  }
   TensorWrapper unquantized_out_cu;

@@ -261,6 +293,29 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
   // Quantize output if using unfused kernel
   if (force_unfused_kernel) {
+    if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
+      // my_quantizer here has to be a Float8CurrentScalingQuantizer
+      auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
+      nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(),
+                        at::cuda::getCurrentCUDAStream());
+      // check if we need to do amax reduction (depending on model parallel configs)
+      if (my_quantizer_cs->with_amax_reduction) {
+        c10::intrusive_ptr<dist_group_type> process_group_ptr =
+            my_quantizer_cs->amax_reduction_group;
+        // construct torch tensor from NVTEBasicTensor without reallocating memory
+        at::Tensor &amax_tensor_torch = my_quantizer_cs->amax;
+        std::vector<at::Tensor> tensors = {amax_tensor_torch};
+        // allreduce amax tensor
+        c10d::AllreduceOptions allreduce_opts;
+        allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
+        process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
+      }
+      QuantizationConfigWrapper quant_config;
+      quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
+      quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
+      nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream());
+      // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
+      out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape);
+    }
     nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr,
                        at::cuda::getCurrentCUDAStream());
   }
transformer_engine/pytorch/csrc/extensions/pybind.cpp

@@ -181,6 +181,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("dtype"), py::kw_only(), py::arg("out"),
         py::call_guard<py::gil_scoped_release>());
   m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend",
         py::call_guard<py::gil_scoped_release>());
+  m.def("compute_amax", &compute_amax, "Compute amax", py::arg("input"), py::arg("amax"));
   m.def("fused_amax_and_scale_update_after_reduction",
         &fused_amax_and_scale_update_after_reduction,
         "Update amax history and FP8 scale/scale_inv after reduction",
         py::call_guard<py::gil_scoped_release>());

@@ -271,6 +272,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda,
         "Fused SGD optimizer for list of contiguous tensors",
         py::call_guard<py::gil_scoped_release>());
+  m.def("multi_tensor_compute_scale_and_scale_inv", &multi_tensor_compute_scale_and_scale_inv_cuda,
+        "Fused compute scale and scale_inv from amax",
+        py::call_guard<py::gil_scoped_release>());

   // Data structures
   py::class_<transformer_engine::pytorch::FP8TensorMeta>(m, "FP8TensorMeta")
transformer_engine/pytorch/csrc/extensions/recipe.cpp

@@ -12,10 +12,27 @@
 #include "common/common.h"
 #include "extensions.h"

+void compute_amax(const at::Tensor &tensor, at::Tensor &amax) {
+  using namespace transformer_engine;
+  using namespace transformer_engine::pytorch;
+
+  auto input_tensor = tensor.contiguous();
+  const TensorWrapper &te_input = makeTransformerEngineTensor(input_tensor);
+
+  TORCH_CHECK(amax.scalar_type() == at::kFloat, "amax must be a float tensor");
+  TORCH_CHECK(amax.numel() == 1, "amax must have exactly one element");
+  TensorWrapper fake_te_output(
+      nullptr, te_input.shape(),
+      transformer_engine::DType::kFloat8E4M3,  // It doesn't matter because we only compute amax.
+      amax.data_ptr<float>());
+
+  nvte_compute_amax(te_input.data(), fake_te_output.data(), at::cuda::getCurrentCUDAStream());
+}
+
 void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer,
                                                  std::vector<at::Tensor> amax_histories,
                                                  std::vector<at::Tensor> scales,
                                                  const std::string &amax_compute_algo,
                                                  transformer_engine::DType fp8_dtype,
                                                  float margin) {
   using namespace transformer_engine;
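Given the pybind registration above, the new helper is callable from Python as tex.compute_amax. A minimal usage sketch of the contract enforced by the TORCH_CHECKs (the amax argument must be a one-element FP32 tensor; the input values here are illustrative):

import torch
import transformer_engine_torch as tex

x = torch.randn(128, 128, device="cuda")
amax = torch.zeros(1, dtype=torch.float32, device="cuda")  # one-element FP32, per the checks
tex.compute_amax(x, amax)
# The result should match a plain abs-max reduction.
torch.testing.assert_close(amax, x.abs().amax().reshape(1))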
transformer_engine/pytorch/fp8.py

@@ -100,6 +100,7 @@ class FP8GlobalStateManager:
     FP8_RECIPE = None
     FP8_DISTRIBUTED_GROUP = None
     FP8_PARAMETERS = False
+    HIGH_PRECISION_INIT_VAL = False
     IS_FIRST_FP8_MODULE = False
     FP8_GRAPH_CAPTURING = False
     FP8_AUTOCAST_DEPTH = 0

@@ -124,6 +125,7 @@ class FP8GlobalStateManager:
         cls.FP8_RECIPE = None
         cls.FP8_DISTRIBUTED_GROUP = None
         cls.FP8_PARAMETERS = False
+        cls.HIGH_PRECISION_INIT_VAL = False
         cls.IS_FIRST_FP8_MODULE = False
         cls.FP8_GRAPH_CAPTURING = False
         cls.FP8_AUTOCAST_DEPTH = 0

@@ -274,6 +276,11 @@ class FP8GlobalStateManager:
         """Should the parameters be stored as FP8"""
         return cls.FP8_PARAMETERS

+    @classmethod
+    def with_high_precision_init_val(cls) -> bool:
+        """Should the high precision initial values be stored with FP8 parameters"""
+        return cls.HIGH_PRECISION_INIT_VAL
+
     @classmethod
     def fp8_graph_capturing(cls) -> bool:
         """Is CUDA graph capture under way?"""

@@ -507,7 +514,11 @@
 @contextmanager
-def fp8_model_init(enabled: bool = True, recipe: Optional[Recipe] = None) -> None:
+def fp8_model_init(
+    enabled: bool = True,
+    recipe: Optional[Recipe] = None,
+    preserve_high_precision_init_val: bool = False,
+) -> None:
     """
     Context manager for FP8 initialization of parameters.

@@ -518,6 +529,12 @@ def fp8_model_init(enabled: bool = True, recipe: Optional[Recipe] = None) -> Non
         with fp8_model_init(enabled=True):
             model = transformer_engine.pytorch.Linear(768, 768)

+        # Preserving high precision initial value to initialize master weight
+        with fp8_model_init(enabled=True, preserve_high_precision_init_val=True):
+            model = transformer_engine.pytorch.Linear(768, 768)
+        master_weight = model.weight.get_high_precision_init_val()
+        model.weight.clear_high_precision_init_val()
+
     Parameters
     ----------
     enabled: bool, default = `True`

@@ -533,18 +550,29 @@ def fp8_model_init(enabled: bool = True, recipe: Optional[Recipe] = None) -> Non
               * LoRA-like fine-tuning, where the main parameters of the model do not change.

     recipe: transformer_engine.common.recipe.Recipe, default = `None`
             Recipe used to create the parameters. If left to None, it uses the default FP8 recipe.
+    preserve_high_precision_init_val: bool, default = `False`
+            when enabled, store the high precision tensor used to initialize FP8 parameters
+            in CPU memory, and add two function attributes named `get_high_precision_init_val()`
+            and `clear_high_precision_init_val()` to FP8 parameters to get/clear this high
+            precision tensor. The purpose is that users can use this high-precision copy
+            to initialize master weights, avoiding the loss of precision that can occur when
+            using FP8 parameters directly. Note that after the master weights are initialized,
+            users should call `clear_high_precision_init_val()` to release this CPU memory.
+            This functionality is *EXPERIMENTAL*.
     """
     _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS
     _fp8_recipe = FP8GlobalStateManager.FP8_RECIPE
+    _high_precision_init_val = FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL
     FP8GlobalStateManager.FP8_PARAMETERS = enabled
     FP8GlobalStateManager.FP8_RECIPE = get_default_fp8_recipe() if recipe is None else recipe
+    FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = preserve_high_precision_init_val
     try:
         yield
     finally:
         FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters
         FP8GlobalStateManager.FP8_RECIPE = _fp8_recipe
+        FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = _high_precision_init_val


 @contextmanager
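Putting the new flag together with the accessors it installs (see module/base.py below), a master-weight initialization loop might look like the following sketch; the optimizer wiring is omitted, and the getattr guard accounts for parameters (e.g. biases) that never received the accessor:

import torch
import transformer_engine.pytorch as te

with te.fp8.fp8_model_init(enabled=True, preserve_high_precision_init_val=True):
    model = te.Linear(768, 768)

master_weights = []
for param in model.parameters():
    getter = getattr(param, "get_high_precision_init_val", None)
    if getter is not None and getter() is not None:
        # Seed the FP32 master weight from the preserved BF16/FP16 init value...
        master_weights.append(getter().to(param.device).float())
        # ...then release the CPU copy, as the docstring recommends.
        param.clear_high_precision_init_val()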
transformer_engine/pytorch/module/_common.py

@@ -4,7 +4,6 @@
 """Internal function used by multiple modules."""

-import os
 from typing import Any, List, Optional, Tuple, Union, Callable
 from dataclasses import dataclass
 from functools import reduce

@@ -16,9 +15,6 @@ from .. import cpp_extensions as tex
 from ..constants import TE_DType
 from ..utils import get_default_init_method
 from ..tensor.float8_tensor import Float8Tensor
-from ..tensor.mxfp8_tensor import MXFP8Quantizer
-
-_use_cudnn_mxfp8_norm = bool(int(os.getenv("NVTE_CUDNN_MXFP8_NORM", "0")))


 def _get_normalization_func(normalization: str, forward: bool):

@@ -86,26 +82,16 @@ def apply_normalization(
     inputs = (inputmat, ln_weight) if ln_bias is None else (inputmat, ln_weight, ln_bias)
-    split_mxfp8_cast = False
-    if not _use_cudnn_mxfp8_norm and isinstance(output_quantizer, MXFP8Quantizer):
-        split_mxfp8_cast = True
-    output = normalization_func(
+    return normalization_func(
         *inputs,
         eps,
-        None if split_mxfp8_cast else ln_out,
-        None if split_mxfp8_cast else output_quantizer,
+        ln_out,
+        output_quantizer,
         TE_DType[output_dtype] if output_dtype in TE_DType else output_dtype,
         fwd_ln_sm_margin,
         zero_centered_gamma,
     )
-    return (
-        (output_quantizer.quantize(output[0], out=ln_out), *output[1:])
-        if split_mxfp8_cast
-        else output
-    )


 class _NoopCatFunc(torch.autograd.Function):
     """Concatenate tensors, doing a no-op if possible
transformer_engine/pytorch/module/base.py

@@ -10,6 +10,7 @@ import warnings
 from abc import ABC, abstractmethod
 from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union
 from contextlib import contextmanager
+from types import MethodType

 import torch
 import torch.nn.functional as F

@@ -424,6 +425,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         self.sequence_parallel = False
         self.param_init_meta = {}
         self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
+        self.preserve_high_precision_init_val = FP8GlobalStateManager.with_high_precision_init_val()
         self.fsdp_wrapped = False
         self.fsdp_group = None
         self._fp8_workspaces: Dict[str, QuantizedTensor] = {}

@@ -921,7 +923,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             # If primary weights are in fp8, wrap the parameter as FP8Tensor
             fp8_meta_index = self.param_init_meta[name].fp8_meta_index
+            high_precision_init_val = None
             if self.primary_weights_in_fp8 and fp8_meta_index is not None:
+                if self.preserve_high_precision_init_val:
+                    high_precision_init_val = param.detach().cpu()
                 quantizer = self.quantizers["scaling_fwd"][fp8_meta_index]
                 assert (
                     quantizer is not None

@@ -933,7 +939,34 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             # NOTE: Currently this can only be broken when primary weights are in Fp8 but
             # re-applying the nn.Parameter() wrap is a no-op when the input is already
             # a parameter so we always re-apply it just for extra safety.
-            setattr(self, name, torch.nn.Parameter(param))
+            param = torch.nn.Parameter(param)
+            if high_precision_init_val is not None:
+                # - Master weights are initialized from model weights, if we use fp8 primary
+                #   weights to initialize master weights, the numerical values of master weights
+                #   are not consistent with the numerical values when we initialize them from
+                #   bf16/fp16 weights.
+                # - So we add a `_high_precision_init_val` attribute to each model weight to store
+                #   the original bf16/fp16 weight on cpu before casting it to fp8. And users can
+                #   use `get_high_precision_init_val` to get this cpu tensor.
+                # - This cpu tensor is not needed once the master weight is initialized, so users
+                #   should call `clear_high_precision_init_val` to remove it after master weight
+                #   is initialized.
+                def get(self):
+                    if hasattr(self, "_high_precision_init_val"):
+                        return self._high_precision_init_val
+                    return None
+
+                def clear(self):
+                    if hasattr(self, "_high_precision_init_val"):
+                        del self._high_precision_init_val
+
+                param._high_precision_init_val = high_precision_init_val
+                param.get_high_precision_init_val = MethodType(get, param)
+                param.clear_high_precision_init_val = MethodType(clear, param)
+
+            setattr(self, name, param)

     @abstractmethod
     def forward(self):

@@ -972,6 +1005,15 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             FSDP process group that the weights are distributed over.
         """
         # FP8 primary weights
+        if isinstance(tensor, QuantizedTensor):
+            if update_workspace and quantizer is not None:
+                tensor.update_usage(
+                    rowwise_usage=quantizer.rowwise_usage,
+                    columnwise_usage=quantizer.columnwise_usage,
+                )
+            return tensor

         # Try getting workspace from cache
         out = None
         if cache_name is not None:
transformer_engine/pytorch/module/grouped_linear.py

@@ -130,20 +130,17 @@ class _GroupedLinear(torch.autograd.Function):
             )

             weights_fp8 = []
             bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype

-            if not isinstance(weights[0], QuantizedTensor):
-                # FP8 cast to workspace buffer
-                update_workspace = is_first_microbatch is None or is_first_microbatch
-                for i in range(num_gemms):
-                    weight_fp8 = module.get_weight_workspace(
-                        tensor=weights[i],
-                        quantizer=weight_quantizers[i],
-                        cache_name=(None if is_first_microbatch is None else f"weight{i}"),
-                        update_workspace=update_workspace,
-                        skip_update_flag=skip_fp8_weight_update,
-                    )
-                    weights_fp8.append(weight_fp8)
-            else:
-                weights_fp8 = weights
+            # FP8 cast to workspace buffer
+            update_workspace = is_first_microbatch is None or is_first_microbatch
+            for i in range(num_gemms):
+                weight_fp8 = module.get_weight_workspace(
+                    tensor=weights[i],
+                    quantizer=weight_quantizers[i],
+                    cache_name=(None if is_first_microbatch is None else f"weight{i}"),
+                    update_workspace=update_workspace,
+                    skip_update_flag=skip_fp8_weight_update,
+                )
+                weights_fp8.append(weight_fp8)

         else:
             inputmats = inputmats_no_fp8
transformer_engine/pytorch/module/layernorm_linear.py

@@ -55,9 +55,9 @@ from ..tensor.quantized_tensor import (
     prepare_for_saving,
     restore_from_saved,
 )
-from ..tensor.float8_tensor import Float8CurrentScalingQuantizer
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
+from ..tensor.float8_tensor import Float8CurrentScalingQuantizer
 from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
 from ..cpp_extensions import (
     general_gemm,

@@ -160,11 +160,6 @@ class _LayerNormLinear(torch.autograd.Function):
         # Configure quantizer for normalization output
         with_quantized_norm = fp8 and not return_layernorm_output
-        # for Float8CurrentScalingQuantizer, layernorm/rmsnorm has not been fused with quantizer
-        # so we need to set with_quantized_norm to False
-        if isinstance(input_quantizer, Float8CurrentScalingQuantizer):
-            with_quantized_norm = False

         if with_quantized_norm:
             if with_input_all_gather:
                 input_quantizer.set_usage(rowwise=True, columnwise=False)

@@ -261,28 +256,26 @@ class _LayerNormLinear(torch.autograd.Function):
         nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm")

         # Cast weight to expected dtype
         weightmat = weight
-        quantized_weight = False
         if not fp8:
-            weightmat = cast_if_needed(weightmat, activation_dtype)
+            quantized_weight = False
+            weightmat = cast_if_needed(weight, activation_dtype)
         else:
-            if not isinstance(weight, QuantizedTensor):
-                quantized_weight = True
-                # Configure quantizer
-                if weight_quantizer is not None:
-                    weight_quantizer.set_usage(rowwise=True, columnwise=True)
-                # FP8 cast to workspace buffer
-                update_workspace = is_first_microbatch is None or is_first_microbatch
-                weightmat = module.get_weight_workspace(
-                    tensor=weight,
-                    quantizer=weight_quantizer,
-                    cache_name=(None if is_first_microbatch is None else "weight"),
-                    update_workspace=update_workspace,
-                    skip_update_flag=skip_fp8_weight_update,
-                    fsdp_group=fsdp_group,
-                )
+            quantized_weight = not isinstance(weight, QuantizedTensor)
+            # Configure quantizer
+            if weight_quantizer is not None:
+                weight_quantizer.set_usage(rowwise=True, columnwise=True)
+            # FP8 cast to workspace buffer
+            update_workspace = is_first_microbatch is None or is_first_microbatch
+            weightmat = module.get_weight_workspace(
+                tensor=weight,
+                quantizer=weight_quantizer,
+                cache_name=(None if is_first_microbatch is None else "weight"),
+                update_workspace=update_workspace,
+                skip_update_flag=skip_fp8_weight_update,
+                fsdp_group=fsdp_group,
+            )

         # Cast bias to expected dtype
         bias_dtype = activation_dtype