Unverified Commit 86813893 authored by Kunlun Li, committed by GitHub

[PyTorch] Enable fp8_primary_weights for current scaling (#1544)



* Enable fp8_primary_weights for current scaling
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Use different cast_master_weights_to_fp8 functions depending on the type of quantizer
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* All amaxes of model_weights should participate in reduce-max
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Clear _high_precision_init_val automatically in cast_master_weights_to_fp8 function
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Merge all amax all-reduces into one NCCL kernel
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add unit tests for multi_tensor_compute_scale_and_scale_inv and preserve_high_precision_init_val
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Fix conflicts
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add unit test for cast_master_weights_to_fp8
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Use mock group to initialize fp8_autocast to avoid reduction of amax_history by fp8_autocast_exit
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Remove with_computing_amax and with_computing_scale
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Move replace_raw_data from QuantizedTensor to utils.py
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Remove allow_empty_output argument from nvte_compute_amax and make it always true
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Rename import guard of recipe_common.cuh to align with other import guards
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Add unit test for replace_raw_data
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add test_replace_raw_data into qa/L0_pytorch_unittest/test.sh
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Minor changes in comments
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Add randomness to the unit test of replace_raw_data
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* (May need to revert) Add tex.quantize_to_fragment
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* (May need to revert) Use nvte_quantize_noop in quantize_to_fragment
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix lint error
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Move high_precision_init_val test and replace_raw_data test to test_sanity.py
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Remove test_fp8_model_init.py and test_replace_raw_data.py
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Remove cast_master_weights_to_fp8 and replace_raw_data from __all__ of tensor.__init__.py
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Move FP8 casting logic back from C++ tex funcs to Python
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Remove unimplemented function from header
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: kunlunl <kunlunl@nvidia.com>
Signed-off-by: Kunlun Li <94586211+kunlunl@users.noreply.github.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
parent e80fbd7e
......@@ -39,6 +39,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py ||
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py"
NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_paged_attn.py || test_fail "test_paged_attn.py"
if [ "$RET" -ne 0 ]; then
echo "Error in the following test cases:$FAILED_CASES"
exit 1
......
......@@ -26,6 +26,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py |
python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
# python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" ### TODO Debug UB support with te.Sequential
python3 -m pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
if [ "$RET" -ne 0 ]; then
echo "Error in the following test cases:$FAILED_CASES"
......
#!/usr/bin/python3
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import argparse
import datetime
import os
import sys
import torch
from torch import nn
import torch.distributed as dist
from transformer_engine.common.recipe import (
DelayedScaling,
Float8CurrentScaling,
Format,
Recipe,
)
import transformer_engine.pytorch as te
from transformer_engine.pytorch.tensor import QuantizedTensor, cast_master_weights_to_fp8
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
def _get_raw_data(quantized_tensor):
"""Get the underlying data of a quantized tensor, used in zero-1 optimizer"""
if isinstance(quantized_tensor, Float8Tensor):
assert hasattr(quantized_tensor, "_data"), "Float8Tensor does not have _data attribute"
assert quantized_tensor._data.dtype == torch.uint8, "Float8Tensor _data must be uint8"
return quantized_tensor._data
else:
raise ValueError(f"Unsupported quantized tensor type: {type(quantized_tensor)}")
class MiniZero_1:
"""A mini zero-1 optimizer implementation, just used for this test"""
def __init__(self, weights, lr, dp_group):
self.rank = dist.get_rank(dp_group)
self.world_size = dist.get_world_size(dp_group)
self.weights = weights
self.lr = lr
self.dp_group = dp_group
# [self.offsets[i], self.offsets[i+1]) is the range of weights[i] in the global buffer
self.offsets = [0]
for weight in self.weights:
self.offsets.append(self.offsets[-1] + weight.numel())
# Pad so the global buffer size is divisible by the world size; as a result,
# offsets[-1] may extend past the end of the last weight.
if self.offsets[-1] % self.world_size != 0:
self.offsets[-1] += self.world_size - self.offsets[-1] % self.world_size
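# Hypothetical example: weights of sizes 6, 10, and 7 with world_size == 4 give
# offsets == [0, 6, 16, 23]; since 23 % 4 != 0, offsets[-1] is padded up to 24,
# so each rank owns a contiguous 24 // 4 == 6 element slice of the global buffer.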
self.master_weights = []
# The start offset of each master weight shard within its weight
self.start_offsets = []
# The overlapping area of the weight and this rank's local buffer
self.overlapping_areas = []
# The start and end of this rank's local buffer in the global buffer
rank_start = self.offsets[-1] // self.world_size * self.rank
rank_end = rank_start + self.offsets[-1] // self.world_size
for weight, offset in zip(self.weights, self.offsets[:-1]):
if offset >= rank_end or (offset + weight.numel()) <= rank_start:
# This weight is not in this rank's local buffer
master_weight = None
start_offset = None
overlapping_area = None
else:
overlapping_start = max(rank_start, offset)
overlapping_end = min(rank_end, offset + weight.numel())
length = overlapping_end - overlapping_start
start_offset = overlapping_start - offset
if isinstance(weight, QuantizedTensor):
# If weight is an FP8 tensor, we need to use the original high-precision version
# to initialize the master weight.
high_precision_init_val = weight.get_high_precision_init_val().view(-1)
master_weight = high_precision_init_val.to(weight.device).float()[
start_offset : start_offset + length
]
else:
master_weight = (
weight.detach().view(-1).float()[start_offset : start_offset + length]
)
overlapping_area = (overlapping_start, overlapping_end)
self.master_weights.append(master_weight)
self.start_offsets.append(start_offset)
self.overlapping_areas.append(overlapping_area)
# Create global buffer for grads reduce-scatter
self.grad_buffer = torch.empty(
[self.offsets[-1]], dtype=torch.float32, device=weights[0].device
)
self.grad_buffer_slice = self.grad_buffer[rank_start:rank_end]
# Create global buffer for weights all-gather
if isinstance(self.weights[0], QuantizedTensor):
weight_buffer_dtype = torch.uint8
else:
weight_buffer_dtype = weights[0].dtype
self.weight_buffer = torch.empty(
[self.offsets[-1]], dtype=weight_buffer_dtype, device=weights[0].device
)
self.weight_buffer_slice = self.weight_buffer[rank_start:rank_end]
def step(self):
# -----------------------------------------------------------------------------------------
# Step 1: Copy grads to the grad buffer
# -----------------------------------------------------------------------------------------
for weight, offset in zip(self.weights, self.offsets[:-1]):
start = offset
end = offset + weight.numel()
self.grad_buffer[start:end].copy_(weight.main_grad.view(-1))
# -----------------------------------------------------------------------------------------
# Step 2: Grads reduce-scatter
# -----------------------------------------------------------------------------------------
# Don't use reduce_scatter directly to explicitly control the reduce order.
# dist.reduce_scatter_tensor(self.grad_buffer_slice, self.grad_buffer, op=dist.ReduceOp.AVG,
# group=self.dp_group)
buffers = [torch.empty_like(self.grad_buffer) for _ in range(self.world_size)]
dist.all_gather(buffers, self.grad_buffer, group=self.dp_group)
for i in range(1, self.world_size):
buffers[0] += buffers[i]
rank_start = self.offsets[-1] // self.world_size * self.rank
rank_end = rank_start + self.offsets[-1] // self.world_size
self.grad_buffer_slice.copy_(buffers[0][rank_start:rank_end])
self.grad_buffer_slice /= self.world_size
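# Summing the rank buffers in a fixed order keeps this floating-point reduction
# bitwise identical to the emulated all-reduce in MiniOptimizer.step(), which is
# what makes the exact (atol=0, rtol=0) comparisons in _test_zero_1 possible.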
# -----------------------------------------------------------------------------------------
# Step 3: Update master weights
# -----------------------------------------------------------------------------------------
for master_weight, overlapping_area in zip(self.master_weights, self.overlapping_areas):
if master_weight is None:
# This weight's master weight lives on another rank.
continue
grad = self.grad_buffer[overlapping_area[0] : overlapping_area[1]]
master_weight -= grad * self.lr
# -----------------------------------------------------------------------------------------
# Step 4: Cast master weights to BF16 or FP8, depending on the type of the weight
# -----------------------------------------------------------------------------------------
if isinstance(self.weights[0], QuantizedTensor):
# FP8 weights case
for i in range(1, len(self.weights)):
assert isinstance(self.weights[i], QuantizedTensor)
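# All ranks call cast_master_weights_to_fp8, even for weights whose master-weight
# shard on this rank is None: every weight's amax must participate in the
# reduce-max across dp_group (merged into a single NCCL call, per the commit notes).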
cast_master_weights_to_fp8(
self.weights, self.master_weights, self.start_offsets, self.dp_group
)
else:
# BF16 weights case
for weight, master_weight, start_offset in zip(
self.weights, self.master_weights, self.start_offsets
):
if master_weight is None:
continue
start = start_offset
end = start_offset + master_weight.numel()
weight.data.view(-1)[start:end].copy_(master_weight)
# -----------------------------------------------------------------------------------------
# Step 5: Copy the updated weights (not all weights) to the weight buffer
# -----------------------------------------------------------------------------------------
for i in range(len(self.weights)):
master_weight = self.master_weights[i]
if master_weight is None:
continue
start_offset = self.start_offsets[i]
if isinstance(self.weights[i], QuantizedTensor):
weight = _get_raw_data(self.weights[i])
else:
weight = self.weights[i]
weight_slice = weight.view(-1)[start_offset : start_offset + master_weight.numel()]
overlapping_start, overlapping_end = self.overlapping_areas[i]
self.weight_buffer[overlapping_start:overlapping_end].copy_(weight_slice)
# -----------------------------------------------------------------------------------------
# Step 6: Weight all-gather (FP8 or BF16)
# -----------------------------------------------------------------------------------------
dist.all_gather_into_tensor(
self.weight_buffer, self.weight_buffer_slice, group=self.dp_group
)
# -----------------------------------------------------------------------------------------
# Step 7: Copy the gathered weights from weight buffer to the actual weights
# -----------------------------------------------------------------------------------------
for weight, offset in zip(self.weights, self.offsets[:-1]):
start = offset
end = offset + weight.numel()
if isinstance(weight, QuantizedTensor):
weight = _get_raw_data(weight)
weight.view(-1).data.copy_(self.weight_buffer[start:end])
class MiniOptimizer:
def __init__(self, weights, lr, dp_group):
self.world_size = dist.get_world_size(dp_group)
self.weights = weights
self.lr = lr
self.dp_group = dp_group
master_weights = []
for weight in self.weights:
master_weights.append(weight.detach().float())
self.master_weights = master_weights
def step(self):
for weight, master_weight in zip(self.weights, self.master_weights):
main_grad = weight.main_grad
# Don't use all-reduce directly to explicitly control the reduce order.
# dist.all_reduce(main_grad, op=dist.ReduceOp.AVG, group=self.dp_group)
buffers = [torch.empty_like(main_grad) for _ in range(self.world_size)]
dist.all_gather(buffers, main_grad, group=self.dp_group)
for i in range(1, self.world_size):
buffers[0] += buffers[i]
main_grad.copy_(buffers[0])
main_grad /= self.world_size
master_weight -= main_grad * self.lr
weight.data.copy_(master_weight)
def _test_zero_1(dp_group):
"""Make sure the implementation of zero-1 optimizer is correct"""
rank = dist.get_rank(dp_group)
world_size = dist.get_world_size(dp_group)
torch.manual_seed(12345)
torch.cuda.manual_seed(12345)
weights = [
torch.randn(256 * 256, dtype=torch.bfloat16, device="cuda"),
torch.randn(256 * 256 * 3, dtype=torch.bfloat16, device="cuda"),
torch.randn(256 * 256 * 2 - 1, dtype=torch.bfloat16, device="cuda"),
]
weights_1 = weights
weights_2 = [weight.clone() for weight in weights]
lr = 1.0
optimizer_1 = MiniZero_1(weights_1, lr, dp_group)
optimizer_2 = MiniOptimizer(weights_2, lr, dp_group)
for _ in range(100):
for w1, w2 in zip(weights_1, weights_2):
main_grads = [
torch.randn_like(w1, dtype=torch.float32, device="cuda") for _ in range(world_size)
]
# Choose based on rank to make sure the grads of different ranks are different.
main_grad = main_grads[rank]
w1.main_grad = main_grad
w2.main_grad = main_grad
optimizer_1.step()
optimizer_2.step()
for w1, w2 in zip(weights_1, weights_2):
torch.testing.assert_close(w1, w2, atol=0, rtol=0)
def quantization_recipe(quantization) -> Recipe:
"""Quantization recipe setup"""
if quantization == "fp8":
return DelayedScaling(
fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max"
)
elif quantization == "fp8_cs":
return Float8CurrentScaling()
else:
raise ValueError(f"Unsupported quantization: {quantization}")
def _test_cast_master_weights_to_fp8(quantization, dp_group):
rank = dist.get_rank(dp_group)
world_size = dist.get_world_size(dp_group)
torch.manual_seed(12345)
torch.cuda.manual_seed(12345)
mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)]
mock_group = mock_groups[rank]
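# Each rank passes a single-rank "mock" group to fp8_autocast so that the autocast
# exit hook does not all-reduce amax histories across data-parallel ranks; amax
# reduction is instead exercised explicitly via cast_master_weights_to_fp8.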
linear_kwargs = {"params_dtype": torch.bfloat16, "bias": False, "fuse_wgrad_accumulation": True}
# Create model with FP8 weights
with te.fp8.fp8_model_init(
enabled=quantization is not None,
recipe=quantization_recipe(quantization),
preserve_high_precision_init_val=True,
):
model_fp8 = nn.Sequential(
te.Linear(128, 256, **linear_kwargs),
te.Linear(256, 256 * 3, **linear_kwargs),
te.Linear(256 * 3, 128, **linear_kwargs),
)
# Create model with BF16 weights
model = nn.Sequential(
te.Linear(128, 256, **linear_kwargs),
te.Linear(256, 256 * 3, **linear_kwargs),
te.Linear(256 * 3, 128, **linear_kwargs),
)
# Make sure the BF16 model and FP8 model have the same initial weights
for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
high_precision_init_val = w_fp8.get_high_precision_init_val()
w.data.copy_(high_precision_init_val)
# Allocate main_grads for each weight
for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
w_fp8.main_grad = torch.zeros_like(w_fp8, dtype=torch.float32, device="cuda")
w.main_grad = torch.zeros_like(w, dtype=torch.float32, device="cuda")
optimizer_fp8 = MiniZero_1([w for w in model_fp8.parameters()], 10.0, dp_group)
optimizer = MiniZero_1([w for w in model.parameters()], 10.0, dp_group)
for _ in range(100):
for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
w_fp8.main_grad.zero_()
w.main_grad.zero_()
inputs = [
torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size)
]
# Choose based on rank to make sure the inputs of different ranks are different.
x = inputs[rank]
with te.fp8.fp8_autocast(
enabled=quantization is not None,
fp8_recipe=quantization_recipe(quantization),
fp8_group=mock_group,
):
y_fp8 = model_fp8(x)
with te.fp8_autocast(
enabled=quantization is not None,
fp8_recipe=quantization_recipe(quantization),
fp8_group=mock_group,
):
y = model(x)
targets = [torch.randn_like(y) for _ in range(world_size)]
# Choose based on rank to make sure the targets of different ranks are different.
target = targets[rank]
loss_fp8 = nn.MSELoss()(y_fp8, target)
loss = nn.MSELoss()(y, target)
loss_fp8.backward()
loss.backward()
optimizer_fp8.step()
optimizer.step()
torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0)
def main(argv=None, namespace=None):
WORLD_RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
assert WORLD_SIZE == LOCAL_SIZE # this test supports only 1 node
assert LOCAL_SIZE <= torch.cuda.device_count()
dist_init_kwargs = {
"backend": "nccl",
"rank": WORLD_RANK,
"world_size": WORLD_SIZE,
"timeout": datetime.timedelta(seconds=30),
}
dist_init_kwargs["init_method"] = "env://"
dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}")
assert dist.is_nccl_available()
torch.cuda.set_device(LOCAL_RANK)
dist.init_process_group(**dist_init_kwargs)
parser = argparse.ArgumentParser()
parser.add_argument("--quantization", type=str, default=None, choices=["fp8", "fp8_cs"])
args = parser.parse_args(argv, namespace)
dp_group = dist.new_group(backend="nccl")
_test_zero_1(dp_group)
_test_cast_master_weights_to_fp8(args.quantization, dp_group)
dist.destroy_process_group()
return 0
if __name__ == "__main__":
sys.exit(main())
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import subprocess
from pathlib import Path
import pytest
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
if torch.cuda.device_count() < 2:
pytest.skip("cast_master_weights_to_fp8 test needs at least 2 GPUs.", allow_module_level=True)
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(2, torch.cuda.device_count())
LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"]
def _run_test(quantization):
test_path = TEST_ROOT / "run_cast_master_weights_to_fp8.py"
test_cmd = LAUNCH_CMD + [str(test_path)] + ["--quantization", quantization]
result = subprocess.run(test_cmd, env=os.environ, check=False)
assert result.returncode == 0
@pytest.mark.parametrize("quantization", ["fp8", "fp8_cs"])
def test_cast_master_weights_to_fp8(quantization):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
_run_test(quantization)
......@@ -8,12 +8,8 @@ import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType_To_Torch
# compute amax and scale
def _ref_compute_amax_scale(x, quant_dtype, eps, pow_2_scales):
x_fp32 = x.to(torch.float32)
amax = torch.amax(torch.abs(x_fp32)).view(1)
assert amax.dtype == torch.float, "amax must be a float tensor."
fp8_max = torch.finfo(quant_dtype).max
# Compute scale and scale_inv from amax
def _ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales):
# Clamping amax to avoid division by small numbers
amax = torch.max(amax, torch.tensor(eps))
......@@ -52,6 +48,20 @@ def _ref_compute_amax_scale(x, quant_dtype, eps, pow_2_scales):
# Compute scale_inv
scale_inv = torch.reciprocal(scale)
return scale, scale_inv
# compute amax and scale
def _ref_compute_amax_scale(x, quant_dtype, eps, pow_2_scales):
x_fp32 = x.to(torch.float32)
amax = torch.amax(torch.abs(x_fp32)).view(1)
assert amax.dtype == torch.float, "amax must be a float tensor."
fp8_max = torch.finfo(quant_dtype).max
scale, scale_inv = _ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales)
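# Worked example: amax == 0.5 with fp8_max == 448.0 (E4M3) gives scale == 896.0
# and scale_inv == 1 / 896.0, before any eps clamping or power-of-two rounding.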
# Clamping amax to avoid division by small numbers
amax = torch.max(amax, torch.tensor(eps))
return scale, scale_inv, amax
......@@ -103,3 +113,7 @@ def ref_per_tensor_cs_cast(
qx_t = _multi_dim_transpose(qx)
sx_t = sx
return qx, sx, qx_t, sx_t
def ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales):
return _ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales)
......@@ -9,6 +9,9 @@ import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.optimizers import MultiTensorApply
from references.ref_per_tensor_cs import ref_compute_scale_and_scale_inv_from_amax
input_size_pairs = [
(7777 * 77, 555 * 555),
(777, 555),
......@@ -216,3 +219,42 @@ def test_multi_tensor_unscale_l2norm(input_size_pair, applier, repeat, in_type,
if per_tensor:
torch.testing.assert_close(norm_per_tensor, normab.broadcast_to(norm_per_tensor.shape))
assert overflow_buf.item() == 0
@pytest.mark.parametrize("input_size_pair", input_size_pairs + [(1, 1)])
@pytest.mark.parametrize("applier", appliers)
@pytest.mark.parametrize("repeat", [1, 55])
@pytest.mark.parametrize("max_fp8", [448.0, 57344.0])
@pytest.mark.parametrize("pow_2_scales", [False, True])
@pytest.mark.parametrize("epsilon", [0.0, 100.0])
def test_multi_tensor_compute_scale_and_scale_inv(
input_size_pair, applier, repeat, max_fp8, pow_2_scales, epsilon
):
sizea, sizeb = input_size_pair
device = torch.device("cuda")
overflow_buf = torch.zeros(1, dtype=torch.int32, device=device)
a = torch.randn([sizea], dtype=torch.float32, device=device).abs()
b = torch.randn([sizeb], dtype=torch.float32, device=device).abs()
amax_list = []
for i in range(repeat):
amax_list += [a.clone(), b.clone()]
scale_list = [torch.empty_like(x) for x in amax_list]
scale_inv_list = [torch.empty_like(x) for x in amax_list]
applier(
tex.multi_tensor_compute_scale_and_scale_inv,
overflow_buf,
[amax_list, scale_list, scale_inv_list],
max_fp8,
pow_2_scales,
epsilon,
)
for amax, scale, scale_inv in zip(amax_list, scale_list, scale_inv_list):
scale_ref, scale_inv_ref = ref_compute_scale_and_scale_inv_from_amax(
amax, max_fp8, epsilon, pow_2_scales
)
torch.testing.assert_close(scale, scale_ref, rtol=0, atol=0)
torch.testing.assert_close(scale_inv, scale_inv_ref, rtol=0, atol=0)
......@@ -36,7 +36,12 @@ from transformer_engine.common import recipe
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import general_gemm
from transformer_engine.pytorch.module.base import get_workspace
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer,
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.tensor.utils import replace_raw_data
from test_numerics import reset_rng_states, dtype_tols
# Only run FP8 tests on supported devices.
......@@ -1196,3 +1201,70 @@ def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False
outputs.append(p.grad)
return outputs
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_replace_raw_data_for_float8tensor():
"""Test the functionality of replace_raw_data"""
torch.manual_seed(12345)
torch.cuda.manual_seed(12345)
fp8_quantizer = Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda")
fp8_tensor = fp8_quantizer.make_empty([128, 128], dtype=torch.bfloat16, device="cuda")
random_bf16_data = torch.randn(fp8_tensor.shape, dtype=torch.bfloat16, device="cuda")
fp8_quantizer.update_quantized(random_bf16_data, fp8_tensor)
attrs_to_check = ["_quantizer", "_fp8_dtype", "_scale_inv", "_transpose", "_transpose_invalid"]
attrs = {}
for attr in attrs_to_check:
attrs[attr] = getattr(fp8_tensor, attr)
old_data = fp8_tensor._data
new_data = torch.empty_like(old_data)
replace_raw_data(fp8_tensor, new_data)
# Make sure the new_data is properly assigned.
assert fp8_tensor._data.data_ptr() != old_data.data_ptr()
assert fp8_tensor._data.data_ptr() == new_data.data_ptr()
# Make sure the values are not changed.
torch.testing.assert_close(old_data, fp8_tensor._data, atol=0, rtol=0)
# Make sure the other attributes are unchanged (identical objects)
for attr in attrs_to_check:
assert id(getattr(fp8_tensor, attr)) == id(attrs[attr])
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_model_init_high_precision_init_val():
"""Test fp8_model_init with preserve_high_precision_init_val=True"""
with fp8_model_init(preserve_high_precision_init_val=True):
model = Linear(768, 768)
weight = model.weight
assert isinstance(weight, QuantizedTensor), "Weight should be QuantizedTensor"
assert hasattr(weight, "_high_precision_init_val"), "_high_precision_init_val not found"
assert hasattr(weight, "get_high_precision_init_val"), "get_high_precision_init_val() not found"
assert hasattr(
weight, "clear_high_precision_init_val"
), "clear_high_precision_init_val() not found"
high_precision = weight.get_high_precision_init_val()
assert high_precision.device.type == "cpu", "high_precision_init_val is not on the CPU"
new_weight = weight._get_quantizer().make_empty(
shape=weight.shape, dtype=weight.dtype, device=weight.device
)
weight._get_quantizer().update_quantized(high_precision.to(weight.device), new_weight)
torch.testing.assert_close(
new_weight.dequantize(dtype=weight.dtype),
weight.dequantize(dtype=weight.dtype),
rtol=0,
atol=0,
)
weight.clear_high_precision_init_val()
assert weight.get_high_precision_init_val() is None, "clear_high_precision_init_val() did not work"
assert not hasattr(
weight, "_high_precision_init_val"
), "clear_high_precision_init_val() did not work"
......@@ -13,6 +13,7 @@
#include "../common.h"
#include "../util/logging.h"
#include "../util/vectorized_pointwise.h"
#include "recipe_common.cuh"
namespace transformer_engine {
namespace {
......@@ -135,7 +136,7 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt
"Output tensor for amax computation has invalid amax tensor "
"(expected FP32, got dtype=",
to_string(output.amax.dtype), ")");
CheckOutputTensor(output, "output_compute_amax");
CheckOutputTensor(output, "output_compute_amax", true);
// Compute amax
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
......@@ -151,41 +152,7 @@ namespace {
__global__ void compute_scale_from_amax_kernel(const float *amax_ptr, float *scale_ptr,
const float max_fp8, const bool force_pow_2_scales,
const float epsilon) {
float amax = *amax_ptr;
if (amax < epsilon) {
amax = epsilon;
}
float scale = 1.f;
if (isinf(amax) || amax == 0.f) {
*scale_ptr = scale;
return;
}
scale = max_fp8 / amax;
// The amax is so small that the scale becomes infinite in FP32. In other words,
// the scale is not representable in FP32.
if (isinf(scale)) {
// use fp32 max to represent the scale
scale = std::numeric_limits<float>::max();
}
if (isnan(scale)) {
scale = 1.f;
}
if (force_pow_2_scales) {
uint32_t scale_bits = *reinterpret_cast<uint32_t *>(&scale);
scale_bits &= 0xFF800000;
// If the exponent was zero, we have a logic error.
__builtin_assume(scale_bits != 0);
__builtin_assume(scale_bits != 0x80000000);
scale = *reinterpret_cast<float *>(&scale_bits);
}
*scale_ptr = scale;
*scale_ptr = compute_scale_from_amax(*amax_ptr, max_fp8, force_pow_2_scales, epsilon);
}
} // namespace
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_
#define TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_
#include <limits>
namespace transformer_engine {
__device__ __forceinline__ float compute_scale_from_amax(float amax, float max_fp8,
bool force_pow_2_scales, float epsilon) {
if (amax < epsilon) {
amax = epsilon;
}
float scale = 1.f;
if (isinf(amax) || amax == 0.f) {
return scale;
}
// Here we don't use "scale = max_fp8 / amax" because it gives different results
// with and without "--use_fast_math".
// "__fdiv_rn" behaves the same as "max_fp8 / amax" when fast math is disabled.
scale = __fdiv_rn(max_fp8, amax);
// The amax is so small that the scale becomes infinite in FP32. In other words,
// the scale is not representable in FP32.
if (isinf(scale)) {
// use fp32 max to represent the scale
scale = std::numeric_limits<float>::max();
}
if (isnan(scale)) {
scale = 1.f;
}
if (force_pow_2_scales) {
uint32_t scale_bits = *reinterpret_cast<uint32_t *>(&scale);
scale_bits &= 0xFF800000;
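// 0xFF800000 keeps the sign and exponent bits and zeroes the mantissa,
// rounding the scale's magnitude down to the nearest power of two.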
// If the exponent was zero, we have a logic error.
__builtin_assume(scale_bits != 0);
__builtin_assume(scale_bits != 0x80000000);
scale = *reinterpret_cast<float *>(&scale_bits);
}
return scale;
}
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_
......@@ -252,6 +252,8 @@ at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads
* FP8 recipe
**************************************************************************************************/
void compute_amax(const at::Tensor &tensor, at::Tensor &amax);
void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer,
std::vector<at::Tensor> amax_histories,
std::vector<at::Tensor> scales,
......@@ -359,6 +361,10 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
float momentum, float dampening, float lr, bool nesterov, bool first_run,
bool wd_after_momentum, float scale);
void multi_tensor_compute_scale_and_scale_inv_cuda(
int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
float max_fp8, bool force_pow_2_scales, float epsilon);
/***************************************************************************************************
* padding
**************************************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
// Stringstream is a big hammer, but I want to rely on operator<< for dtype.
#include <sstream>
#include "common/recipe/recipe_common.cuh"
#include "common/utils.cuh"
#include "multi_tensor_apply.cuh"
#include "type_shim.h"
#define BLOCK_SIZE 256
struct ComputeScaleAndScaleInvFunctor {
__device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem,
TensorListMetadata<3> &tl, // NOLINT(*)
float max_fp8, bool force_pow_2_scales,
float epsilon) {
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
float *amax = reinterpret_cast<float *>(tl.addresses[0][tensor_loc]);
amax += chunk_idx * chunk_size;
float *scale = reinterpret_cast<float *>(tl.addresses[1][tensor_loc]);
scale += chunk_idx * chunk_size;
float *scale_inv = reinterpret_cast<float *>(tl.addresses[2][tensor_loc]);
scale_inv += chunk_idx * chunk_size;
n -= chunk_idx * chunk_size;
for (int i_start = threadIdx.x; i_start < n && i_start < chunk_size; i_start += blockDim.x) {
float scale_val = transformer_engine::compute_scale_from_amax(amax[i_start], max_fp8,
force_pow_2_scales, epsilon);
scale[i_start] = scale_val;
transformer_engine::reciprocal(scale_inv + i_start, scale_val);
}
}
};
void multi_tensor_compute_scale_and_scale_inv_cuda(
int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
float max_fp8, bool force_pow_2_scales, float epsilon) {
using namespace at;
multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
ComputeScaleAndScaleInvFunctor(), max_fp8, force_pow_2_scales, epsilon);
AT_CUDA_CHECK(cudaGetLastError());
}
......@@ -178,6 +178,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard<py::gil_scoped_release>());
m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend",
py::call_guard<py::gil_scoped_release>());
m.def("compute_amax", &compute_amax, "Compute amax", py::arg("input"), py::arg("amax"));
m.def("fused_amax_and_scale_update_after_reduction", &fused_amax_and_scale_update_after_reduction,
"Update amax history and FP8 scale/scale_inv after reduction",
py::call_guard<py::gil_scoped_release>());
......@@ -265,6 +266,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda,
"Fused SGD optimizer for list of contiguous tensors",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_compute_scale_and_scale_inv", &multi_tensor_compute_scale_and_scale_inv_cuda,
"Fused compute scale and scale_inv from amax", py::call_guard<py::gil_scoped_release>());
// Data structures
py::class_<transformer_engine::pytorch::FP8TensorMeta>(m, "FP8TensorMeta")
......
......@@ -12,10 +12,27 @@
#include "common/common.h"
#include "extensions.h"
void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer,
void compute_amax(const at::Tensor& tensor, at::Tensor& amax) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
auto input_tensor = tensor.contiguous();
const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor);
TORCH_CHECK(amax.scalar_type() == at::kFloat, "amax must be a float tensor");
TORCH_CHECK(amax.numel() == 1, "amax must have exactly one element");
TensorWrapper fake_te_output(
nullptr, te_input.shape(),
transformer_engine::DType::kFloat8E4M3, // It doesn't matter because we only compute amax.
amax.data_ptr<float>());
nvte_compute_amax(te_input.data(), fake_te_output.data(), at::cuda::getCurrentCUDAStream());
}
void fused_amax_and_scale_update_after_reduction(const at::Tensor& amax_reduction_buffer,
std::vector<at::Tensor> amax_histories,
std::vector<at::Tensor> scales,
const std::string &amax_compute_algo,
const std::string& amax_compute_algo,
transformer_engine::DType fp8_dtype,
float margin) {
using namespace transformer_engine;
......
......@@ -93,6 +93,7 @@ class FP8GlobalStateManager:
FP8_RECIPE = None
FP8_DISTRIBUTED_GROUP = None
FP8_PARAMETERS = False
HIGH_PRECISION_INIT_VAL = False
IS_FIRST_FP8_MODULE = False
FP8_GRAPH_CAPTURING = False
FP8_AUTOCAST_DEPTH = 0
......@@ -117,6 +118,7 @@ class FP8GlobalStateManager:
cls.FP8_RECIPE = None
cls.FP8_DISTRIBUTED_GROUP = None
cls.FP8_PARAMETERS = False
cls.HIGH_PRECISION_INIT_VAL = False
cls.IS_FIRST_FP8_MODULE = False
cls.FP8_GRAPH_CAPTURING = False
cls.FP8_AUTOCAST_DEPTH = 0
......@@ -267,6 +269,11 @@ class FP8GlobalStateManager:
"""Should the parameters be stored as FP8"""
return cls.FP8_PARAMETERS
@classmethod
def with_high_precision_init_val(cls) -> bool:
"""Should the high precision initial values be stored with FP8 parameters"""
return cls.HIGH_PRECISION_INIT_VAL
@classmethod
def fp8_graph_capturing(cls) -> bool:
"""Is CUDA graph capture under way?"""
......@@ -500,7 +507,11 @@ class FP8GlobalStateManager:
@contextmanager
def fp8_model_init(enabled: bool = True, recipe: Optional[Recipe] = None) -> None:
def fp8_model_init(
enabled: bool = True,
recipe: Optional[Recipe] = None,
preserve_high_precision_init_val: bool = False,
) -> None:
"""
Context manager for FP8 initialization of parameters.
......@@ -511,6 +522,12 @@ def fp8_model_init(enabled: bool = True, recipe: Optional[Recipe] = None) -> Non
with fp8_model_init(enabled=True):
model = transformer_engine.pytorch.Linear(768, 768)
# Preserving high precision initial value to initialize master weight
with fp8_model_init(enabled=True, preserve_high_precision_init_val=True):
model = transformer_engine.pytorch.Linear(768, 768)
master_weight = model.weight.get_high_precision_init_val()
model.weight.clear_high_precision_init_val()
Parameters
----------
enabled: bool, default = `True`
......@@ -526,18 +543,29 @@ def fp8_model_init(enabled: bool = True, recipe: Optional[Recipe] = None) -> Non
* LoRA-like fine-tuning, where the main parameters of the model do not change.
recipe: transformer_engine.common.recipe.Recipe, default = `None`
Recipe used to create the parameters. If left to None, it uses the default FP8 recipe.
preserve_high_precision_init_val: bool, default = `False`
When enabled, store the high-precision tensor used to initialize each FP8
parameter in CPU memory, and add two method attributes,
`get_high_precision_init_val()` and `clear_high_precision_init_val()`, to FP8
parameters for getting/clearing this high-precision tensor. This lets users
initialize master weights from the high-precision copy, avoiding the loss of
precision that can occur when initializing them from FP8 parameters directly.
Note that after the master weights are initialized, users should call
`clear_high_precision_init_val()` to release this CPU memory.
This functionality is *EXPERIMENTAL*.
"""
_fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS
_fp8_recipe = FP8GlobalStateManager.FP8_RECIPE
_high_precision_init_val = FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL
FP8GlobalStateManager.FP8_PARAMETERS = enabled
FP8GlobalStateManager.FP8_RECIPE = get_default_fp8_recipe() if recipe is None else recipe
FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = preserve_high_precision_init_val
try:
yield
finally:
FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters
FP8GlobalStateManager.FP8_RECIPE = _fp8_recipe
FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = _high_precision_init_val
@contextmanager
......
......@@ -10,6 +10,7 @@ import warnings
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union
from contextlib import contextmanager
from types import MethodType
import torch
import torch.nn.functional as F
......@@ -405,6 +406,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.sequence_parallel = False
self.param_init_meta = {}
self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
self.preserve_high_precision_init_val = FP8GlobalStateManager.with_high_precision_init_val()
self.fsdp_wrapped = False
self.fsdp_group = None
self._fp8_workspaces: Dict[str, QuantizedTensor] = {}
......@@ -902,7 +904,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# If primary weights are in fp8, wrap the parameter as FP8Tensor
fp8_meta_index = self.param_init_meta[name].fp8_meta_index
high_precision_init_val = None
if self.primary_weights_in_fp8 and fp8_meta_index is not None:
if self.preserve_high_precision_init_val:
high_precision_init_val = param.detach().cpu()
quantizer = self.quantizers["scaling_fwd"][fp8_meta_index]
assert (
quantizer is not None
......@@ -914,7 +920,34 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# NOTE: Currently this can only be broken when primary weights are in Fp8 but
# re-applying the nn.Parameter() wrap is a no-op when the input is already
# a parameter so we always re-apply it just for extra safety.
setattr(self, name, torch.nn.Parameter(param))
param = torch.nn.Parameter(param)
if high_precision_init_val is not None:
# - Master weights are initialized from model weights. If we used the FP8 primary
# weights to initialize master weights, their numerical values would not match
# those obtained by initializing from bf16/fp16 weights.
# - So we add a `_high_precision_init_val` attribute to each model weight to store
# the original bf16/fp16 weight on CPU before it is cast to FP8. Users can call
# `get_high_precision_init_val` to retrieve this CPU tensor.
# - This CPU tensor is no longer needed once the master weight is initialized, so
# users should call `clear_high_precision_init_val` to release it afterwards.
def get(self):
if hasattr(self, "_high_precision_init_val"):
return self._high_precision_init_val
return None
def clear(self):
if hasattr(self, "_high_precision_init_val"):
del self._high_precision_init_val
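# MethodType binds `get` and `clear` to this specific parameter instance, so each
# FP8 parameter carries its own accessors for the preserved CPU copy.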
param._high_precision_init_val = high_precision_init_val
param.get_high_precision_init_val = MethodType(get, param)
param.clear_high_precision_init_val = MethodType(clear, param)
setattr(self, name, param)
@abstractmethod
def forward(self):
......@@ -953,6 +986,15 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
FSDP process group that the weights are distributed over.
"""
# FP8 primary weights are already quantized: skip the workspace cast and just refresh usage
if isinstance(tensor, QuantizedTensor):
if update_workspace and quantizer is not None:
tensor.update_usage(
rowwise_usage=quantizer.rowwise_usage,
columnwise_usage=quantizer.columnwise_usage,
)
return tensor
# Try getting workspace from cache
out = None
if cache_name is not None:
......
......@@ -130,7 +130,6 @@ class _GroupedLinear(torch.autograd.Function):
)
weights_fp8 = []
bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
if not isinstance(weights[0], QuantizedTensor):
# FP8 cast to workspace buffer
update_workspace = is_first_microbatch is None or is_first_microbatch
for i in range(num_gemms):
......@@ -142,8 +141,6 @@ class _GroupedLinear(torch.autograd.Function):
skip_update_flag=skip_fp8_weight_update,
)
weights_fp8.append(weight_fp8)
else:
weights_fp8 = weights
else:
inputmats = inputmats_no_fp8
......
......@@ -256,13 +256,11 @@ class _LayerNormLinear(torch.autograd.Function):
nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm")
# Cast weight to expected dtype
weightmat = weight
quantized_weight = False
if not fp8:
weightmat = cast_if_needed(weightmat, activation_dtype)
quantized_weight = False
weightmat = cast_if_needed(weight, activation_dtype)
else:
if not isinstance(weight, QuantizedTensor):
quantized_weight = True
quantized_weight = not isinstance(weight, QuantizedTensor)
# Configure quantizer
if weight_quantizer is not None:
......
......@@ -315,15 +315,12 @@ class _LayerNormMLP(torch.autograd.Function):
ln_out_total = ln_out
# Cast weights to expected dtype
fc1_weight_final = fc1_weight
fc2_weight_final = fc2_weight
if not fp8:
fc1_weight_final = cast_if_needed(fc1_weight_final, activation_dtype)
fc2_weight_final = cast_if_needed(fc2_weight_final, activation_dtype)
fc1_weight_final = cast_if_needed(fc1_weight, activation_dtype)
fc2_weight_final = cast_if_needed(fc2_weight, activation_dtype)
else:
# If weights are not quantized, we call get_weight_workspace,
# which handles weight caching etc.
if not isinstance(fc1_weight, QuantizedTensor):
# FP8 cast to workspace buffer
update_workspace = is_first_microbatch is None or is_first_microbatch
fc1_weight_final = module.get_weight_workspace(
......@@ -334,7 +331,6 @@ class _LayerNormMLP(torch.autograd.Function):
skip_update_flag=skip_fp8_weight_update,
fsdp_group=fsdp_group,
)
if not isinstance(fc2_weight, QuantizedTensor):
fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True)
fc2_weight_final = module.get_weight_workspace(
tensor=fc2_weight,
......
......@@ -176,11 +176,9 @@ class _Linear(torch.autograd.Function):
nvtx_range_pop(f"{nvtx_label}.input_cast_comm")
# Cast weight to expected dtype
weightmat = weight
if not fp8:
weightmat = cast_if_needed(weightmat, activation_dtype)
weightmat = cast_if_needed(weight, activation_dtype)
else:
if not isinstance(weight, QuantizedTensor):
# Configure quantizer
if weight_quantizer is not None:
columnwise_usage = is_grad_enabled and inp.requires_grad
......
......@@ -7,6 +7,7 @@
import torch
from .quantized_tensor import QuantizedTensor, Quantizer
from .utils import cast_master_weights_to_fp8, replace_raw_data
__all__ = [
"QuantizedTensor",
......