Unverified commit ac81c85b authored by Kirthi Shankar Sivamani, committed by GitHub

[PyTorch] Python `GroupedTensor` (#2654)



* PyTorch-Python GroupedTensor
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Update transformer_engine/pytorch/tensor/storage/grouped_tensor.py
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Remove mxfp8 gq test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix recipe tests and FP8 weights
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix device test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Disable grouped weights for unsupported recipes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
parent 8ebb47e5
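A minimal usage sketch of the Python `GroupedTensor` API added here, inferred from the tests below (argument names follow the test file; the defaults are illustrative, not authoritative):

    import torch
    from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor

    # One contiguous buffer backing three 256x512 tensors.
    grouped = GroupedTensor.make_grouped_tensor_with_shapes(
        num_tensors=3,
        shape=[(256, 512)] * 3,
        quantizer=None,  # pass a Quantizer for quantized grouped storage
        device="cuda",
        dtype=torch.float32,
    )
    views = grouped.split_into_quantized_tensors()  # zero-copy views into grouped.data
    assert views[1].data_ptr() == grouped.data.data_ptr() + 256 * 512 * views[1].element_size()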
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for GroupedTensor class"""
from typing import List, Tuple
import pytest
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor
from transformer_engine.pytorch import (
Quantizer,
Float8Quantizer,
Float8CurrentScalingQuantizer,
Float8BlockQuantizer,
MXFP8Quantizer,
NVFP4Quantizer,
)
from transformer_engine.pytorch.constants import TE_DType_To_Torch
import transformer_engine_torch as tex
# Check available recipes
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available(
return_reason=True
)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True)
_quantization_params = [
pytest.param(
"fp8_delayed_scaling",
marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8),
),
pytest.param(
"fp8_current_scaling",
marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8),
),
pytest.param(
"fp8_blockwise",
marks=pytest.mark.skipif(
not fp8_block_scaling_available, reason=reason_for_no_fp8_block_scaling
),
),
pytest.param(
"mxfp8",
marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8),
),
pytest.param(
"nvfp4",
marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4),
),
]
def make_quantizer(quantization: str, num_tensors: int, shape: List[Tuple[int, int]]) -> Quantizer:
"""Create a quantizer for the given quantization scheme (num_tensors and shape are currently unused)."""
if quantization == "fp8_delayed_scaling":
quantizer = Float8Quantizer(
scale=torch.ones(1, dtype=torch.float32, device="cuda"),
amax=torch.zeros(1, dtype=torch.float32, device="cuda"),
fp8_dtype=tex.DType.kFloat8E4M3,
)
elif quantization == "fp8_current_scaling":
quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
device="cuda",
)
quantizer.set_usage(rowwise=True, columnwise=False)
elif quantization == "fp8_blockwise":
quantizer = Float8BlockQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
rowwise=True,
columnwise=False,
force_pow_2_scales=True,
amax_epsilon=0.0,
block_scaling_dim=1,
)
elif quantization == "mxfp8":
quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
elif quantization == "nvfp4":
quantizer = NVFP4Quantizer(
with_rht=False,
with_post_rht_amax=False,
with_2d_quantization=False,
stochastic_rounding=False,
with_random_sign_mask=False,
)
else:
raise ValueError(f"Unknown quantization scheme: {quantization}")
quantizer.internal = False
return quantizer
def _get_rowwise_data_tensor(qtensor, quantization: str) -> torch.Tensor:
if quantization in ("fp8_delayed_scaling", "fp8_current_scaling"):
return qtensor._data
if quantization in ("fp8_blockwise", "mxfp8", "nvfp4"):
return qtensor._rowwise_data
raise ValueError(f"Unknown quantization scheme: {quantization}")
def _rowwise_offset_bytes(numel: int, quantization: str) -> int:
# NVFP4 packs two 4-bit elements per byte; the FP8 formats store one element per byte.
if quantization == "nvfp4":
return numel // 2
return numel
class TestGroupedTensor:
@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
def test_basic_construction_all_same_shape(self) -> None:
"""Test GroupedTensor construction with all tensors having same shape"""
num_tensors = 4
shape = [(256, 512) for _ in range(num_tensors)]
grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=None,
device="cuda",
dtype=torch.float32,
)
assert grouped_tensor.num_tensors == num_tensors
assert grouped_tensor.all_same_shape()
assert grouped_tensor.all_same_first_dim()
assert grouped_tensor.all_same_last_dim()
assert grouped_tensor.logical_shape == (num_tensors * 256, 512)
assert grouped_tensor.get_common_first_dim() == 256
assert grouped_tensor.get_common_last_dim() == 512
assert grouped_tensor.has_data()
def test_basic_construction_varying_first_dim(self) -> None:
"""Test GroupedTensor construction with varying first dimension"""
num_tensors = 3
shape = [(128, 512), (256, 512), (384, 512)]
grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=None,
device="cuda",
dtype=torch.float32,
)
assert grouped_tensor.num_tensors == num_tensors
assert not grouped_tensor.all_same_shape()
assert not grouped_tensor.all_same_first_dim()
assert grouped_tensor.all_same_last_dim()
assert grouped_tensor.get_common_last_dim() == shape[0][1]
assert grouped_tensor.logical_shape == (
sum(v for v, _ in shape),
shape[0][1],
) # sum of first dims
def test_split_into_quantized_tensors_no_quantization(self) -> None:
"""Test split_into_quantized_tensors for unquantized tensors"""
num_tensors = 3
shape = [(256, 512) for _ in range(num_tensors)]
grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=None,
device="cuda",
dtype=torch.float32,
)
# Get the original data pointer
original_data_ptr = grouped_tensor.data.data_ptr()
# Split into tensors
tensors = grouped_tensor.split_into_quantized_tensors()
assert len(tensors) == num_tensors
# Verify each tensor has correct shape and shares storage
for i, tensor in enumerate(tensors):
assert tensor.shape == shape[i]
assert isinstance(tensor, torch.Tensor)
assert not hasattr(tensor, "_data") # Not a quantized tensor
# Verify data pointer is within the original grouped tensor storage
# The tensor should be a view of the original data
assert tensor.data_ptr() >= original_data_ptr
# Calculate expected offset
expected_offset = i * (shape[i][0] * shape[i][1]) * tensor.element_size()
assert tensor.data_ptr() == original_data_ptr + expected_offset
@pytest.mark.parametrize("quantization", _quantization_params)
def test_split_into_quantized_tensors_quantized(self, quantization: str) -> None:
"""Test split_into_quantized_tensors for quantized tensors"""
num_tensors = 3
shape = [(512, 512) for _ in range(num_tensors)]
quantizers = make_quantizer(quantization, num_tensors, shape)
grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=quantizers,
device="cuda",
)
# Get the original data pointer
original_data_ptr = grouped_tensor.data.data_ptr()
# Split into tensors
tensors = grouped_tensor.split_into_quantized_tensors()
assert len(tensors) == num_tensors
# Verify each tensor shares storage with the grouped tensor
for i, tensor in enumerate(tensors):
rowwise_data = _get_rowwise_data_tensor(tensor, quantization)
assert rowwise_data is not None
assert rowwise_data.data_ptr() >= original_data_ptr
numel = shape[i][0] * shape[i][1]
expected_offset = _rowwise_offset_bytes(i * numel, quantization)
assert rowwise_data.data_ptr() == original_data_ptr + expected_offset
def test_split_varying_shapes(self) -> None:
"""Test split_into_quantized_tensors with varying shapes"""
num_tensors = 3
shape = [(128, 512), (256, 512), (384, 512)]
grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=None,
device="cuda",
dtype=torch.float32,
)
original_data_ptr = grouped_tensor.data.data_ptr()
tensors = grouped_tensor.split_into_quantized_tensors()
assert len(tensors) == num_tensors
# Verify shapes and storage
cumulative_offset = 0
for i, tensor in enumerate(tensors):
assert tensor.shape == shape[i]
expected_offset = cumulative_offset * tensor.element_size()
assert tensor.data_ptr() == original_data_ptr + expected_offset
cumulative_offset += shape[i][0] * shape[i][1]
@pytest.mark.parametrize("quantization", _quantization_params)
def test_quantize_inplace(self, quantization: str) -> None:
"""Test that quantize is done in-place for all recipes"""
num_tensors = 3
shape = [(512, 512) for _ in range(num_tensors)]
quantizers = make_quantizer(quantization, num_tensors, shape)
grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=quantizers,
device="cuda",
)
# Get original data pointers before quantization
original_data_ptr = grouped_tensor.data.data_ptr()
original_scale_inv_ptr = grouped_tensor.scale_inv.data_ptr()
original_scale_ptr = (
grouped_tensor.scale.data_ptr() if grouped_tensor.scale is not None else None
)
# Create input tensors
input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape]
# Quantize in place
quantized_tensors = grouped_tensor.quantize(input_tensors)
# Verify data pointers haven't changed (in-place operation)
assert grouped_tensor.data.data_ptr() == original_data_ptr
assert grouped_tensor.scale_inv.data_ptr() == original_scale_inv_ptr
if original_scale_ptr is not None:
assert grouped_tensor.scale.data_ptr() == original_scale_ptr
# Verify returned tensors point to the same storage
for i, qtensor in enumerate(quantized_tensors):
rowwise_data = _get_rowwise_data_tensor(qtensor, quantization)
numel = shape[i][0] * shape[i][1]
expected_offset = _rowwise_offset_bytes(i * numel, quantization)
assert rowwise_data.data_ptr() == original_data_ptr + expected_offset
@pytest.mark.parametrize("quantization", _quantization_params)
def test_quantize_varying_shapes(self, quantization: str) -> None:
"""Test quantize with varying shapes"""
num_tensors = 3
shape = [(256, 512), (512, 512), (768, 512)]
quantizers = make_quantizer(quantization, num_tensors, shape)
grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=quantizers,
device="cuda",
)
# Get original data pointers
original_data_ptr = grouped_tensor.data.data_ptr()
# Create input tensors with varying shapes
input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape]
# Quantize in place
quantized_tensors = grouped_tensor.quantize(input_tensors)
# Verify data pointer hasn't changed
assert grouped_tensor.data.data_ptr() == original_data_ptr
# Verify each tensor points to correct location
cumulative_numel = 0
for qtensor, tensor_shape in zip(quantized_tensors, shape):
rowwise_data = _get_rowwise_data_tensor(qtensor, quantization)
expected_offset = _rowwise_offset_bytes(cumulative_numel, quantization)
assert rowwise_data.data_ptr() == original_data_ptr + expected_offset
cumulative_numel += tensor_shape[0] * tensor_shape[1]
@pytest.mark.parametrize("quantization", _quantization_params)
def test_static_quantize_method(self, quantization: str) -> None:
"""Test the static quantize method"""
num_tensors = 3
shape = [(512, 512) for _ in range(num_tensors)]
quantizers = make_quantizer(quantization, num_tensors, shape)
# Create input tensors
input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape]
# Use static quantize method
grouped_tensor = GroupedTensor.create_and_quantize(
tensors=input_tensors,
quantizer=quantizers,
device="cuda",
)
# Verify the grouped tensor was created correctly
assert grouped_tensor.num_tensors == num_tensors
assert grouped_tensor.has_data()
# Verify quantized_tensors were created and point to same storage
assert grouped_tensor.quantized_tensors is not None
assert len(grouped_tensor.quantized_tensors) == num_tensors
original_data_ptr = grouped_tensor.data.data_ptr()
for i, qtensor in enumerate(grouped_tensor.quantized_tensors):
rowwise_data = _get_rowwise_data_tensor(qtensor, quantization)
numel = shape[i][0] * shape[i][1]
expected_offset = _rowwise_offset_bytes(i * numel, quantization)
assert rowwise_data.data_ptr() == original_data_ptr + expected_offset
def test_clear(self) -> None:
"""Test clear method"""
num_tensors = 3
shape = [(256, 512) for _ in range(num_tensors)]
grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=None,
device="cuda",
dtype=torch.float32,
)
assert grouped_tensor.has_data()
assert grouped_tensor.num_tensors == num_tensors
grouped_tensor.clear()
assert not grouped_tensor.has_data()
assert grouped_tensor.num_tensors == 0
assert grouped_tensor.data is None
assert grouped_tensor.logical_shape == (0, 0)
@@ -2,7 +2,7 @@
#
# See LICENSE for license information.
-from typing import Optional
+from typing import Optional, List
import torch
import pytest
@@ -137,6 +137,117 @@ def reset_global_fp8_state():
FP8GlobalStateManager.reset()
def check_grouped_tensor_pointers_helper(tensors, num_elems_in_byte=1, tensor_name="tensor"):
"""
Verify that tensors are stored in contiguous memory.
Args:
tensors: List or iterable of tensors to check
num_elems_in_byte: Number of elements packed per byte (1 for normal, 2 for NVFP4)
tensor_name: Name to use in error messages
"""
tensor_list = list(tensors)
if len(tensor_list) < 2:
return # Nothing to check
for i in range(1, len(tensor_list)):
prev_tensor = tensor_list[i - 1]
curr_tensor = tensor_list[i]
# Calculate expected offset based on previous tensor size
prev_numel = prev_tensor.numel()
expected_offset = (prev_numel // num_elems_in_byte) * prev_tensor.element_size()
# Verify current tensor's data pointer is correctly offset
expected_ptr = prev_tensor.data_ptr() + expected_offset
actual_ptr = curr_tensor.data_ptr()
assert (
actual_ptr == expected_ptr
), f"{tensor_name} {i} data pointer mismatch: expected {expected_ptr}, got {actual_ptr}"
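# Example: for two contiguous float32 tensors of shape (256, 512), the second
# tensor's data_ptr() must equal the first's data_ptr() plus 256 * 512 * 4 bytes.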
def check_grouped_tensor_pointers(
weights: List[torch.Tensor], fp8_recipe: Optional[recipe.Recipe] = None
):
"""
Verify that the weight pointers are laid out contiguously in memory, as expected for GroupedTensor storage.
TODO(ksivaman): This check can be made much more efficient; for now we keep the brute-force approach.
"""
num_elems_in_a_data_byte = 2 if fp8_recipe is not None and fp8_recipe.nvfp4() else 1
# Check data.
if hasattr(weights[0], "_data") and weights[0]._data is not None:
data_tensors = [w._data for w in weights]
check_grouped_tensor_pointers_helper(data_tensors, num_elems_in_byte=1, tensor_name="data")
# Check transpose.
if hasattr(weights[0], "_transpose") and weights[0]._transpose is not None:
transpose_tensors = [w._transpose for w in weights]
check_grouped_tensor_pointers_helper(
transpose_tensors, num_elems_in_byte=1, tensor_name="transpose"
)
# Check scale_inv.
if hasattr(weights[0], "_scale_inv") and weights[0]._scale_inv is not None:
scale_inv_tensors = [w._scale_inv for w in weights]
check_grouped_tensor_pointers_helper(
scale_inv_tensors, num_elems_in_byte=1, tensor_name="scale_inv"
)
# Check rowwise scale_inv.
if hasattr(weights[0], "_rowwise_scale_inv") and weights[0]._rowwise_scale_inv is not None:
scale_inv_tensors = [w._rowwise_scale_inv for w in weights]
check_grouped_tensor_pointers_helper(
scale_inv_tensors, num_elems_in_byte=1, tensor_name="rowwise_scale_inv"
)
# Check columnwise scale_inv.
if (
hasattr(weights[0], "_columnwise_scale_inv")
and weights[0]._columnwise_scale_inv is not None
):
columnwise_scale_inv_tensors = [w._columnwise_scale_inv for w in weights]
check_grouped_tensor_pointers_helper(
columnwise_scale_inv_tensors,
num_elems_in_byte=1,
tensor_name="columnwise scale_inv",
)
# Check rowwise amax.
if hasattr(weights[0], "_rowwise_amax") and weights[0]._rowwise_amax is not None:
rowwise_amax_tensors = [w._rowwise_amax for w in weights]
check_grouped_tensor_pointers_helper(
rowwise_amax_tensors, num_elems_in_byte=1, tensor_name="rowwise amax"
)
# Check columnwise amax.
if hasattr(weights[0], "_columnwise_amax") and weights[0]._columnwise_amax is not None:
columnwise_amax_tensors = [w._columnwise_amax for w in weights]
check_grouped_tensor_pointers_helper(
columnwise_amax_tensors, num_elems_in_byte=1, tensor_name="columnwise amax"
)
# Check rowwise data.
if hasattr(weights[0], "_rowwise_data") and weights[0]._rowwise_data is not None:
rowwise_data_tensors = [w._rowwise_data for w in weights]
check_grouped_tensor_pointers_helper(
rowwise_data_tensors,
num_elems_in_byte=num_elems_in_a_data_byte,
tensor_name="rowwise data",
)
# Check columnwise data.
if hasattr(weights[0], "_columnwise_data") and weights[0]._columnwise_data is not None:
columnwise_data_tensors = [w._columnwise_data for w in weights]
check_grouped_tensor_pointers_helper(
columnwise_data_tensors,
num_elems_in_byte=num_elems_in_a_data_byte,
tensor_name="columnwise data",
)
def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
(config.max_seqlen_q, config.batch_size, config.hidden_size),
@@ -495,9 +606,18 @@ def test_sanity_grouped_linear(
use_fp8 = fp8_recipe is not None
with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
te_grouped_linear = GroupedLinear(
-num_gemms, config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
+num_gemms,
+config.hidden_size,
+ffn_hidden_size,
+bias=use_bias,
+params_dtype=dtype,
).cuda()
# Verify that weights are stored in contiguous GroupedTensor storage.
weights = [getattr(te_grouped_linear, f"weight{i}") for i in range(num_gemms)]
if fp8_recipe is None or not (fp8_recipe.delayed() or fp8_recipe.float8_current_scaling()):
check_grouped_tensor_pointers(weights, fp8_recipe)
inp_hidden_states = torch.randn(
num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
@@ -956,7 +1076,13 @@ def test_replace_raw_data_for_float8tensor():
random_bf16_data = torch.randn(fp8_tensor.shape, dtype=torch.bfloat16, device="cuda")
fp8_quantizer.update_quantized(random_bf16_data, fp8_tensor)
-attrs_to_check = ["_quantizer", "_fp8_dtype", "_scale_inv", "_transpose", "_transpose_invalid"]
+attrs_to_check = [
+"_quantizer",
+"_fp8_dtype",
+"_scale_inv",
+"_transpose",
+"_transpose_invalid",
+]
attrs = {}
for attr in attrs_to_check:
attrs[attr] = getattr(fp8_tensor, attr)
......
@@ -88,33 +88,40 @@
Base recipe class.
"""
-def nvfp4(self):
+@classmethod
+def nvfp4(cls):
"""Whether the given recipe is NVFP4 1D block scaling."""
-return isinstance(self, NVFP4BlockScaling)
+return issubclass(cls, NVFP4BlockScaling)
-def mxfp8(self):
+@classmethod
+def mxfp8(cls):
"""Whether the given recipe is MXFP8 block scaling."""
-return isinstance(self, MXFP8BlockScaling)
+return issubclass(cls, MXFP8BlockScaling)
-def delayed(self):
+@classmethod
+def delayed(cls):
"""Whether the given recipe is delayed scaling."""
-return isinstance(self, DelayedScaling)
+return issubclass(cls, DelayedScaling)
-def float8_current_scaling(self):
+@classmethod
+def float8_current_scaling(cls):
"""Whether the given recipe is (per-tensor) current scaling."""
-return isinstance(self, Float8CurrentScaling)
+return issubclass(cls, Float8CurrentScaling)
-def float8_per_tensor_scaling(self):
+@classmethod
+def float8_per_tensor_scaling(cls):
"""Whether the given recipe is per-tensor scaling."""
-return isinstance(self, (DelayedScaling, Float8CurrentScaling))
+return issubclass(cls, (DelayedScaling, Float8CurrentScaling))
-def float8_block_scaling(self):
+@classmethod
+def float8_block_scaling(cls):
"""Whether the given recipe is float8 blockwise scaling."""
-return isinstance(self, Float8BlockScaling)
+return issubclass(cls, Float8BlockScaling)
-def custom(self):
+@classmethod
+def custom(cls):
"""Whether the given recipe is custom."""
-return isinstance(self, CustomRecipe)
+return issubclass(cls, CustomRecipe)
@dataclass()
......
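Because the predicates are now classmethods, they can be called on a recipe class as well as on an instance; a quick illustration of the resulting behavior (plain Python classmethod binding, shown here only for clarity):

    from transformer_engine.common.recipe import DelayedScaling

    DelayedScaling.delayed()    # True; previously this required an instance
    DelayedScaling().delayed()  # still True: cls binds to type(instance)
    DelayedScaling.nvfp4()      # False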
@@ -13,6 +13,7 @@ import torch
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor
from .base import (
get_dummy_wgrad,
TransformerEngineBaseModule,
@@ -147,7 +148,10 @@ class _GroupedLinear(torch.autograd.Function):
# tensors (like scales), but bulk allocation shares storage across all tensors,
# so if scales can't be offloaded, nothing in the group can be offloaded.
inputmats = tex.split_quantize(
-inp_view, m_splits, input_quantizers, disable_bulk_allocation=cpu_offloading
+inp_view,
+m_splits,
+input_quantizers,
+disable_bulk_allocation=cpu_offloading,
)
elif debug:
inputmats = DebugQuantizer.multi_tensor_quantize(
@@ -365,7 +369,10 @@ class _GroupedLinear(torch.autograd.Function):
for i in range(ctx.num_gemms):
grad_biases[i] = grad_output_mats[i].sum(dim=0)
grad_output = DebugQuantizer.multi_tensor_quantize(
-grad_output_view, ctx.grad_output_quantizers, ctx.m_splits, ctx.activation_dtype
+grad_output_view,
+ctx.grad_output_quantizers,
+ctx.m_splits,
+ctx.activation_dtype,
)
else:
# Only split grad output. Grad bias is fused with
@@ -436,7 +443,8 @@ class _GroupedLinear(torch.autograd.Function):
if ctx.input_quantizers[0] is not None:
for input_quantizer in ctx.input_quantizers:
if isinstance(
-input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
+input_quantizer,
+(Float8Quantizer, Float8CurrentScalingQuantizer),
):
input_quantizer.set_usage(rowwise=True, columnwise=True)
else:
@@ -446,7 +454,10 @@ class _GroupedLinear(torch.autograd.Function):
inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers)
elif ctx.debug:
inputmats = DebugQuantizer.multi_tensor_quantize(
-inp_view, ctx.input_quantizers, ctx.m_splits, ctx.activation_dtype
+inp_view,
+ctx.input_quantizers,
+ctx.m_splits,
+ctx.activation_dtype,
)
else:
inputmats = torch.split(
@@ -616,7 +627,7 @@ class GroupedLinear(TransformerEngineBaseModule):
) -> None:
super().__init__(name)
-params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
+self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
self.num_gemms = num_gemms
self.in_features = in_features
self.out_features = out_features
@@ -631,12 +642,19 @@ class GroupedLinear(TransformerEngineBaseModule):
assert (
not ub_overlap_rs and not ub_overlap_ag
), "GroupedLinear doesn't support Userbuffer overlap."
self.init_method = init_method
self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name
self.wgrad_store = WeightGradStore(delay_wgrad_compute)
-self._offsets = {"input": 0, "weight": 1, "output": 2, "grad_output": 0, "grad_input": 1}
+self._offsets = {
+"input": 0,
+"weight": 1,
+"output": 2,
+"grad_output": 0,
+"grad_input": 1,
+}
self._num_fp8_tensors_per_gemm = {
"fwd": 3,
"bwd": 2,
@@ -678,7 +696,7 @@ class GroupedLinear(TransformerEngineBaseModule):
self.out_features,
self.in_features,
device=device,
-dtype=params_dtype,
+dtype=self.params_dtype,
),
),
init_fn=init_method,
@@ -694,13 +712,13 @@ class GroupedLinear(TransformerEngineBaseModule):
torch.empty(
self.out_features,
device=device,
-dtype=params_dtype,
+dtype=self.params_dtype,
),
),
init_fn=init_method_constant(0.0),
)
else:
-bias = torch.Tensor().to(dtype=params_dtype, device=device)
+bias = torch.Tensor().to(dtype=self.params_dtype, device=device)
setattr(self, f"bias{i}", bias)
if self.primary_weights_in_fp8:
@@ -724,8 +742,61 @@
if recipe.float8_current_scaling():
self._customize_quantizers_float8_current_scaling(fwd, recipe)
def make_grouped_weights(self, defer_init=False) -> None:
"""
Convert parameters into a GroupedTensor and re-register them as parameters.
"""
if defer_init:
return
weight_quantizers = self._get_weight_quantizers()
recipe = (
weight_quantizers[0]._get_compatible_recipe()
if weight_quantizers and weight_quantizers[0] is not None
else None
)
if recipe is not None and (recipe.delayed() or recipe.float8_current_scaling()):
self.set_tensor_parallel_attributes(defer_init=defer_init)
return
weights = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
# Create the weight storage.
grouped_weights = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=self.num_gemms,
shape=[(self.out_features, self.in_features)] * self.num_gemms,
quantizer=weight_quantizers[0],
dtype=self.params_dtype,
device=weights[0].device,
)
# Copy existing params into storage.
with torch.no_grad():
for i in range(self.num_gemms):
if self.primary_weights_in_fp8:
grouped_weights.quantized_tensors[i].copy_from_storage(weights[i])
else:
grouped_weights.quantized_tensors[i].copy_(weights[i])
# Re-register the grouped weights as parameters.
for i in range(self.num_gemms):
self.register_parameter(
f"weight{i}",
torch.nn.Parameter(grouped_weights.quantized_tensors[i]),
init_fn=self.init_method,
get_rng_state_tracker=self.get_rng_state_tracker,
fp8_meta_index=self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"],
)
self.set_tensor_parallel_attributes(defer_init=defer_init)
def reset_parameters(self, defer_init=False):
super().reset_parameters(defer_init=defer_init)
self.make_grouped_weights(defer_init=defer_init)
def set_tensor_parallel_attributes(self, defer_init=False) -> None:
"""Set attributes needed for TP"""
if not defer_init:
# Set parallelism attributes for linear weights
@@ -925,7 +996,7 @@
def _get_weight_quantizers(self) -> List[Quantizer]:
"""Get the weight quantizers of the module."""
-if not self.fp8 and not self.fp8_calibration:
+if not self.fp8 and not self.fp8_calibration and not self.primary_weights_in_fp8:
return [None] * self.num_gemms
weight_quantizers = [
self.quantizers["scaling_fwd"][
@@ -934,7 +1005,7 @@
for i in range(self.num_gemms)
]
for i in range(self.num_gemms):
-weight_quantizers[i].internal = True
+weight_quantizers[i].internal = not self.primary_weights_in_fp8
return weight_quantizers
def _get_quantizers(self):
......
@@ -69,7 +69,9 @@ class QuantizedTensorStorage:
f"{self.__class__.__name__} class does not implement get_usages function"
)
-def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]:
+def prepare_for_saving(
+self,
+) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]:
"""Prepare the tensor base for saving for backward"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement prepare_for_saving function"
@@ -115,11 +117,18 @@
warnings.warn("Quantizer is being updated, this may affect model behavior")
self._quantizer = quantizer
def copy_from_storage(self, src: QuantizedTensorStorage) -> None:
"""Copy data from another QuantizedTensorStorage."""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement copy_from_storage function"
)
def prepare_for_saving(
*tensors: Union[torch.Tensor, QuantizedTensorStorage],
) -> Tuple[
-list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], list[Optional[QuantizedTensorStorage]]
+list[Optional[Union[torch.Tensor, torch.nn.Parameter]]],
+list[Optional[QuantizedTensorStorage]],
]:
"""Prepare tensors for saving. Needed because save_for_backward accepts only
torch.Tensor/torch.nn.Parameter types, while we want to be able to save
@@ -144,7 +153,10 @@
return_saved_tensors: bool = False,
) -> (
list[Optional[torch.Tensor | QuantizedTensorStorage]]
-| tuple[list[Optional[torch.Tensor | QuantizedTensorStorage]], list[Optional[torch.Tensor]]]
+| tuple[
+list[Optional[torch.Tensor | QuantizedTensorStorage]],
+list[Optional[torch.Tensor]],
+]
):
"""Recombine the tensor data and metadata during backward pass."""
tensor_objects = []
......
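The base-class `copy_from_storage` above only defines the contract; the concrete storages (Float8, blockwise, MXFP8, NVFP4) override it in the hunks further below. `GroupedLinear.make_grouped_weights` relies on it to move existing quantized params into the grouped allocation, roughly as follows (names abbreviated from the diff above):

    with torch.no_grad():
        for i in range(num_gemms):
            if primary_weights_in_fp8:
                # quantized -> quantized copy, buffer by buffer, into grouped storage
                grouped_weights.quantized_tensors[i].copy_from_storage(weights[i])
            else:
                # high-precision params use a plain tensor copy
                grouped_weights.quantized_tensors[i].copy_(weights[i])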
@@ -11,7 +11,11 @@ from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
-from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, Recipe
+from transformer_engine.common.recipe import (
+DelayedScaling,
+Float8CurrentScaling,
+Recipe,
+)
from ..utils import canonicalize_process_group, devices_match
from .storage.float8_tensor_storage import Float8TensorStorage, _FromFloat8Func
from ..quantized_tensor import QuantizedTensor, Quantizer
@@ -154,6 +158,10 @@ class Float8Quantizer(Quantizer):
amin, amax = tensor.aminmax()
self.amax.copy_(torch.max(-amin, amax))
def get_columnwise_shape(self, rowwise_data_shape: Iterable[int]) -> Tuple[int, ...]:
"""Calculate the shape of the columnwise (transpose) data for FP8 per-tensor scaling."""
return (rowwise_data_shape[-1], *rowwise_data_shape[:-1])
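# e.g. a rowwise (256, 512) buffer has columnwise (transpose) shape (512, 256)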
def create_tensor_from_data(
self,
data: torch.Tensor,
@@ -408,6 +416,10 @@ class Float8CurrentScalingQuantizer(Quantizer):
quantizer=self,
)
def get_columnwise_shape(self, rowwise_data_shape: Iterable[int]) -> Tuple[int, ...]:
"""Calculate the shape of the columnwise (transpose) data for FP8 per-tensor current scaling."""
return (rowwise_data_shape[-1], *rowwise_data_shape[:-1])
def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Function using primitives with ONNX defined translations."""
if tensor.dtype != torch.float32:
@@ -769,7 +781,10 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
kwargs,
)
return Float8Tensor.make_like(
-tensor, data=func_out, data_transpose=func_transposed_out, shape=func_out.shape
+tensor,
+data=func_out,
+data_transpose=func_transposed_out,
+shape=func_out.shape,
)
if func == torch.ops.aten.detach.default:
......
@@ -164,6 +164,49 @@ class MXFP8Quantizer(Quantizer):
# TODO(ksivamani): No calibration needed for mxfp8?
pass
def get_scale_shape(
self,
shape: Iterable[int],
columnwise: bool,
) -> Tuple[int, int]:
"""Calculate the shape of the scaling tensor for MXFP8 1D blockwise quantization.
This method determines the shape of the scaling tensor needed for blockwise quantization,
taking into account the input tensor shape and whether columnwise scaling is used.
Parameters
----------
shape : Iterable[int]
Shape of the input tensor to be quantized
columnwise : bool
Whether to use columnwise scaling (True) or rowwise scaling (False)
Returns
-------
Tuple[int, int]
Shape of the scaling tensor as (outer_dim, inner_dim).
For MXFP8 1D blockwise quantization the block size is 32.
A swizzle kernel is applied before the GEMM to produce the layout cuBLAS expects.
cuBLAS doc: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
"""
if columnwise:
# Columnwise: scale_inv shape is [prod(shape[:-1]) // BLOCK_SIZE, shape[-1]]
# with padding to multiples of [4, 128]
return (
round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4),
round_up_to_nearest_multiple(shape[-1], 128),
)
# Rowwise: scale_inv shape is [prod(shape[:-1]), shape[-1] // BLOCK_SIZE]
# with padding to multiples of [128, 4]
return (
round_up_to_nearest_multiple(math.prod(shape[:-1]), 128),
round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4),
)
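# Worked example: a (512, 512) input gives a rowwise scale_inv of shape (512, 16),
# since 512 is already a multiple of 128 and 512 // 32 = 16 a multiple of 4;
# the columnwise scale_inv for the same input has shape (16, 512).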
def get_columnwise_shape(self, rowwise_data_shape: Tuple[int, ...]) -> Tuple[int, ...]:
"""Calculate the shape of the columnwise data for MXFP8 1D blockwise quantization."""
return rowwise_data_shape
def create_tensor_from_data(
self,
data: torch.Tensor,
@@ -704,7 +747,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
columnwise_scale_inv=columnwise_scale_inv,
fp8_dtype=fp8_dtype,
dtype=param_dtype,
-shape=rowwise_data.shape if rowwise_data is not None else columnwise_data.shape,
+shape=(rowwise_data.shape if rowwise_data is not None else columnwise_data.shape),
quantizer=self._quantizer,
with_gemm_swizzled_scales=False,
)
......
@@ -341,7 +341,10 @@ class NVFP4Quantizer(Quantizer):
)
columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True)
columnwise_scale_inv = torch.empty(
-columnwise_scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory
+columnwise_scale_shape,
+dtype=torch.uint8,
+device=device,
+pin_memory=pin_memory,
)
amax_columnwise = torch.zeros(
1, dtype=torch.float32, device=device, pin_memory=pin_memory
......
@@ -7,3 +7,4 @@ from .float8_tensor_storage import Float8TensorStorage # noqa: F401
from .mxfp8_tensor_storage import MXFP8TensorStorage # noqa: F401
from .float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage # noqa: F401
from .nvfp4_tensor_storage import NVFP4TensorStorage # noqa: F401
from .grouped_tensor import GroupedTensor # noqa: F401
@@ -73,6 +73,24 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
if t is not None:
t.data = _empty_tensor()
def copy_from_storage(self, src: QuantizedTensorStorage) -> None:
"""Copy data buffers from another Float8BlockwiseQTensorStorage."""
if not isinstance(src, Float8BlockwiseQTensorStorage):
raise TypeError("copy_from_storage expects Float8BlockwiseQTensorStorage")
if self._fp8_dtype != src._fp8_dtype:
raise RuntimeError("FP8 dtype mismatch in copy_from_storage")
if self._is_2D_scaled != src._is_2D_scaled:
raise RuntimeError("Scale layout mismatch in copy_from_storage")
def _copy_optional(dst: Optional[torch.Tensor], src_tensor: Optional[torch.Tensor]):
if dst is not None and src_tensor is not None:
dst.copy_(src_tensor)
_copy_optional(self._rowwise_data, src._rowwise_data)
_copy_optional(self._columnwise_data, src._columnwise_data)
_copy_optional(self._rowwise_scale_inv, src._rowwise_scale_inv)
_copy_optional(self._columnwise_scale_inv, src._columnwise_scale_inv)
def get_metadata(self) -> Dict[str, Any]:
"""Get this tensor's metadata."""
return {
......
@@ -104,6 +104,24 @@ class Float8TensorStorage(QuantizedTensorStorage):
t.data = _empty_tensor()
self._transpose_invalid = True
def copy_from_storage(self, src: QuantizedTensorStorage) -> None:
"""Copy data buffers from another Float8TensorStorage."""
if not isinstance(src, Float8TensorStorage):
raise TypeError("copy_from_storage expects Float8TensorStorage")
if self._fp8_dtype != src._fp8_dtype:
raise RuntimeError("FP8 dtype mismatch in copy_from_storage")
def _copy_optional(
dst: Optional[torch.Tensor],
src_tensor: Optional[torch.Tensor],
):
if dst is not None and src_tensor is not None:
dst.copy_(src_tensor)
_copy_optional(self._data, src._data)
_copy_optional(self._transpose, src._transpose)
_copy_optional(self._scale_inv, src._scale_inv)
def get_metadata(self) -> Dict[str, Any]:
"""Get this tensor's metadata."""
return {
......
@@ -111,6 +111,24 @@ class MXFP8TensorStorage(QuantizedTensorStorage):
if t is not None:
t.data = _empty_tensor()
def copy_from_storage(self, src: QuantizedTensorStorage) -> None:
"""Copy data buffers from another MXFP8TensorStorage."""
if not isinstance(src, MXFP8TensorStorage):
raise TypeError("copy_from_storage expects MXFP8TensorStorage")
if self._fp8_dtype != src._fp8_dtype:
raise RuntimeError("FP8 dtype mismatch in copy_from_storage")
if self._with_gemm_swizzled_scales != src._with_gemm_swizzled_scales:
raise RuntimeError("Scale layout mismatch in copy_from_storage")
def _copy_optional(dst: Optional[torch.Tensor], src_tensor: Optional[torch.Tensor]):
if dst is not None and src_tensor is not None:
dst.copy_(src_tensor)
_copy_optional(self._rowwise_data, src._rowwise_data)
_copy_optional(self._columnwise_data, src._columnwise_data)
_copy_optional(self._rowwise_scale_inv, src._rowwise_scale_inv)
_copy_optional(self._columnwise_scale_inv, src._columnwise_scale_inv)
def get_metadata(self) -> Dict[str, Any]:
"""Get this tensor's metadata."""
return {
......
@@ -136,6 +136,26 @@ class NVFP4TensorStorage(QuantizedTensorStorage):
if t is not None:
t.data = _empty_tensor()
def copy_from_storage(self, src: QuantizedTensorStorage) -> None:
"""Copy data buffers from another NVFP4TensorStorage."""
if not isinstance(src, NVFP4TensorStorage):
raise TypeError("copy_from_storage expects NVFP4TensorStorage")
if self._fp4_dtype != src._fp4_dtype:
raise RuntimeError("FP4 dtype mismatch in copy_from_storage")
if self._with_gemm_swizzled_scales != src._with_gemm_swizzled_scales:
raise RuntimeError("Scale layout mismatch in copy_from_storage")
def _copy_optional(dst: Optional[torch.Tensor], src_tensor: Optional[torch.Tensor]):
if dst is not None and src_tensor is not None:
dst.copy_(src_tensor)
_copy_optional(self._rowwise_data, src._rowwise_data)
_copy_optional(self._columnwise_data, src._columnwise_data)
_copy_optional(self._rowwise_scale_inv, src._rowwise_scale_inv)
_copy_optional(self._columnwise_scale_inv, src._columnwise_scale_inv)
_copy_optional(self._amax_rowwise, src._amax_rowwise)
_copy_optional(self._amax_columnwise, src._amax_columnwise)
def get_metadata(self) -> Dict[str, Any]:
"""Get this tensor's metadata."""
return {
......