Unverified Commit ac81c85b authored by Kirthi Shankar Sivamani, committed by GitHub

[PyTorch] Python `GroupedTensor` (#2654)



* PyTorch-Python GroupedTensor
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Update transformer_engine/pytorch/tensor/storage/grouped_tensor.py
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Remove mxfp8 gq test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix recipe tests and FP8 weights
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix device test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Disable grouped weights for unsupported recipes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
parent 8ebb47e5
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for GroupedTensor class"""
from typing import List, Tuple
import pytest
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor
from transformer_engine.pytorch import (
Quantizer,
Float8Quantizer,
Float8CurrentScalingQuantizer,
Float8BlockQuantizer,
MXFP8Quantizer,
NVFP4Quantizer,
)
from transformer_engine.pytorch.constants import TE_DType_To_Torch
import transformer_engine_torch as tex
# Check available recipes
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available(
return_reason=True
)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True)
_quantization_params = [
pytest.param(
"fp8_delayed_scaling",
marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8),
),
pytest.param(
"fp8_current_scaling",
marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8),
),
pytest.param(
"fp8_blockwise",
marks=pytest.mark.skipif(
not fp8_block_scaling_available, reason=reason_for_no_fp8_block_scaling
),
),
pytest.param(
"mxfp8",
marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8),
),
pytest.param(
"nvfp4",
marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4),
),
]
def make_quantizer(quantization: str, num_tensors: int, shape: List[Tuple[int, int]]) -> Quantizer:
    """Create a quantizer for the given quantization scheme."""
if quantization == "fp8_delayed_scaling":
quantizer = Float8Quantizer(
scale=torch.ones(1, dtype=torch.float32, device="cuda"),
amax=torch.zeros(1, dtype=torch.float32, device="cuda"),
fp8_dtype=tex.DType.kFloat8E4M3,
)
elif quantization == "fp8_current_scaling":
quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
device="cuda",
)
quantizer.set_usage(rowwise=True, columnwise=False)
elif quantization == "fp8_blockwise":
quantizer = Float8BlockQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
rowwise=True,
columnwise=False,
force_pow_2_scales=True,
amax_epsilon=0.0,
block_scaling_dim=1,
)
elif quantization == "mxfp8":
quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
elif quantization == "nvfp4":
quantizer = NVFP4Quantizer(
with_rht=False,
with_post_rht_amax=False,
with_2d_quantization=False,
stochastic_rounding=False,
with_random_sign_mask=False,
)
else:
raise ValueError(f"Unknown quantization scheme: {quantization}")
quantizer.internal = False
return quantizer
def _get_rowwise_data_tensor(qtensor, quantization: str) -> torch.Tensor:
if quantization in ("fp8_delayed_scaling", "fp8_current_scaling"):
return qtensor._data
if quantization in ("fp8_blockwise", "mxfp8", "nvfp4"):
return qtensor._rowwise_data
raise ValueError(f"Unknown quantization scheme: {quantization}")
def _rowwise_offset_bytes(numel: int, quantization: str) -> int:
if quantization == "nvfp4":
return numel // 2
return numel
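# Illustrative check of the helper above (values are hypothetical): NVFP4 packs two
# 4-bit values per byte, so a tensor boundary's byte offset is half its element offset.
#     _rowwise_offset_bytes(512 * 512, "nvfp4")  == 131072
#     _rowwise_offset_bytes(512 * 512, "mxfp8")  == 262144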
class TestGroupedTensor:
@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
def test_basic_construction_all_same_shape(self) -> None:
"""Test GroupedTensor construction with all tensors having same shape"""
num_tensors = 4
shape = [(256, 512) for _ in range(num_tensors)]
grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=None,
device="cuda",
dtype=torch.float32,
)
assert grouped_tensor.num_tensors == num_tensors
assert grouped_tensor.all_same_shape()
assert grouped_tensor.all_same_first_dim()
assert grouped_tensor.all_same_last_dim()
assert grouped_tensor.logical_shape == (num_tensors * 256, 512)
assert grouped_tensor.get_common_first_dim() == 256
assert grouped_tensor.get_common_last_dim() == 512
assert grouped_tensor.has_data()
def test_basic_construction_varying_first_dim(self) -> None:
"""Test GroupedTensor construction with varying first dimension"""
num_tensors = 3
shape = [(128, 512), (256, 512), (384, 512)]
grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=None,
device="cuda",
dtype=torch.float32,
)
assert grouped_tensor.num_tensors == num_tensors
assert not grouped_tensor.all_same_shape()
assert not grouped_tensor.all_same_first_dim()
assert grouped_tensor.all_same_last_dim()
assert grouped_tensor.get_common_last_dim() == shape[0][1]
assert grouped_tensor.logical_shape == (
sum(v for v, _ in shape),
shape[0][1],
) # sum of first dims
def test_split_into_quantized_tensors_no_quantization(self) -> None:
"""Test split_into_quantized_tensors for unquantized tensors"""
num_tensors = 3
shape = [(256, 512) for _ in range(num_tensors)]
grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=None,
device="cuda",
dtype=torch.float32,
)
# Get the original data pointer
original_data_ptr = grouped_tensor.data.data_ptr()
# Split into tensors
tensors = grouped_tensor.split_into_quantized_tensors()
assert len(tensors) == num_tensors
# Verify each tensor has correct shape and shares storage
for i, tensor in enumerate(tensors):
assert tensor.shape == shape[i]
assert isinstance(tensor, torch.Tensor)
assert not hasattr(tensor, "_data") # Not a quantized tensor
# Verify data pointer is within the original grouped tensor storage
# The tensor should be a view of the original data
assert tensor.data_ptr() >= original_data_ptr
# Calculate expected offset
expected_offset = i * (shape[i][0] * shape[i][1]) * tensor.element_size()
assert tensor.data_ptr() == original_data_ptr + expected_offset
@pytest.mark.parametrize("quantization", _quantization_params)
def test_split_into_quantized_tensors_quantized(self, quantization: str) -> None:
"""Test split_into_quantized_tensors for quantized tensors"""
num_tensors = 3
shape = [(512, 512) for _ in range(num_tensors)]
        quantizer = make_quantizer(quantization, num_tensors, shape)
grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
            quantizer=quantizer,
device="cuda",
)
# Get the original data pointer
original_data_ptr = grouped_tensor.data.data_ptr()
# Split into tensors
tensors = grouped_tensor.split_into_quantized_tensors()
assert len(tensors) == num_tensors
# Verify each tensor shares storage with the grouped tensor
for i, tensor in enumerate(tensors):
rowwise_data = _get_rowwise_data_tensor(tensor, quantization)
assert rowwise_data is not None
assert rowwise_data.data_ptr() >= original_data_ptr
numel = shape[i][0] * shape[i][1]
expected_offset = _rowwise_offset_bytes(i * numel, quantization)
assert rowwise_data.data_ptr() == original_data_ptr + expected_offset
def test_split_varying_shapes(self) -> None:
"""Test split_into_quantized_tensors with varying shapes"""
num_tensors = 3
shape = [(128, 512), (256, 512), (384, 512)]
grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=None,
device="cuda",
dtype=torch.float32,
)
original_data_ptr = grouped_tensor.data.data_ptr()
tensors = grouped_tensor.split_into_quantized_tensors()
assert len(tensors) == num_tensors
# Verify shapes and storage
cumulative_offset = 0
for i, tensor in enumerate(tensors):
assert tensor.shape == shape[i]
expected_offset = cumulative_offset * tensor.element_size()
assert tensor.data_ptr() == original_data_ptr + expected_offset
cumulative_offset += shape[i][0] * shape[i][1]
@pytest.mark.parametrize("quantization", _quantization_params)
def test_quantize_inplace(self, quantization: str) -> None:
"""Test that quantize is done in-place for all recipes"""
num_tensors = 3
shape = [(512, 512) for _ in range(num_tensors)]
        quantizer = make_quantizer(quantization, num_tensors, shape)
grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
            quantizer=quantizer,
device="cuda",
)
# Get original data pointers before quantization
original_data_ptr = grouped_tensor.data.data_ptr()
original_scale_inv_ptr = grouped_tensor.scale_inv.data_ptr()
original_scale_ptr = (
grouped_tensor.scale.data_ptr() if grouped_tensor.scale is not None else None
)
# Create input tensors
input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape]
# Quantize in place
quantized_tensors = grouped_tensor.quantize(input_tensors)
# Verify data pointers haven't changed (in-place operation)
assert grouped_tensor.data.data_ptr() == original_data_ptr
assert grouped_tensor.scale_inv.data_ptr() == original_scale_inv_ptr
if original_scale_ptr is not None:
assert grouped_tensor.scale.data_ptr() == original_scale_ptr
# Verify returned tensors point to the same storage
for i, qtensor in enumerate(quantized_tensors):
rowwise_data = _get_rowwise_data_tensor(qtensor, quantization)
numel = shape[i][0] * shape[i][1]
expected_offset = _rowwise_offset_bytes(i * numel, quantization)
assert rowwise_data.data_ptr() == original_data_ptr + expected_offset
@pytest.mark.parametrize("quantization", _quantization_params)
def test_quantize_varying_shapes(self, quantization: str) -> None:
"""Test quantize with varying shapes"""
num_tensors = 3
shape = [(256, 512), (512, 512), (768, 512)]
        quantizer = make_quantizer(quantization, num_tensors, shape)
grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
            quantizer=quantizer,
device="cuda",
)
# Get original data pointers
original_data_ptr = grouped_tensor.data.data_ptr()
# Create input tensors with varying shapes
input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape]
# Quantize in place
quantized_tensors = grouped_tensor.quantize(input_tensors)
# Verify data pointer hasn't changed
assert grouped_tensor.data.data_ptr() == original_data_ptr
# Verify each tensor points to correct location
cumulative_numel = 0
for qtensor, tensor_shape in zip(quantized_tensors, shape):
rowwise_data = _get_rowwise_data_tensor(qtensor, quantization)
expected_offset = _rowwise_offset_bytes(cumulative_numel, quantization)
assert rowwise_data.data_ptr() == original_data_ptr + expected_offset
cumulative_numel += tensor_shape[0] * tensor_shape[1]
@pytest.mark.parametrize("quantization", _quantization_params)
def test_static_quantize_method(self, quantization: str) -> None:
"""Test the static quantize method"""
num_tensors = 3
shape = [(512, 512) for _ in range(num_tensors)]
        quantizer = make_quantizer(quantization, num_tensors, shape)
# Create input tensors
input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape]
# Use static quantize method
grouped_tensor = GroupedTensor.create_and_quantize(
tensors=input_tensors,
            quantizer=quantizer,
device="cuda",
)
# Verify the grouped tensor was created correctly
assert grouped_tensor.num_tensors == num_tensors
assert grouped_tensor.has_data()
# Verify quantized_tensors were created and point to same storage
assert grouped_tensor.quantized_tensors is not None
assert len(grouped_tensor.quantized_tensors) == num_tensors
original_data_ptr = grouped_tensor.data.data_ptr()
for i, qtensor in enumerate(grouped_tensor.quantized_tensors):
rowwise_data = _get_rowwise_data_tensor(qtensor, quantization)
numel = shape[i][0] * shape[i][1]
expected_offset = _rowwise_offset_bytes(i * numel, quantization)
assert rowwise_data.data_ptr() == original_data_ptr + expected_offset
def test_clear(self) -> None:
"""Test clear method"""
num_tensors = 3
shape = [(256, 512) for _ in range(num_tensors)]
grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=None,
device="cuda",
dtype=torch.float32,
)
assert grouped_tensor.has_data()
assert grouped_tensor.num_tensors == num_tensors
grouped_tensor.clear()
assert not grouped_tensor.has_data()
assert grouped_tensor.num_tensors == 0
assert grouped_tensor.data is None
assert grouped_tensor.logical_shape == (0, 0)
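# A minimal usage sketch of the API exercised above (illustrative values, written as
# comments so nothing runs at import time):
#
#     gt = GroupedTensor.make_grouped_tensor_with_shapes(
#         num_tensors=2,
#         shape=[(128, 256), (64, 256)],
#         quantizer=None,
#         device="cuda",
#         dtype=torch.bfloat16,
#     )
#     views = gt.split_into_quantized_tensors()
#     assert views[0].shape == (128, 256) and views[1].shape == (64, 256)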
@@ -2,7 +2,7 @@
#
# See LICENSE for license information.
from typing import Optional, List
import torch
import pytest
@@ -137,6 +137,117 @@ def reset_global_fp8_state():
    FP8GlobalStateManager.reset()
def check_grouped_tensor_pointers_helper(tensors, num_elems_in_byte=1, tensor_name="tensor"):
"""
Verify that tensors are stored in contiguous memory.
Args:
tensors: List or iterable of tensors to check
num_elems_in_byte: Number of elements packed per byte (1 for normal, 2 for NVFP4)
tensor_name: Name to use in error messages
"""
tensor_list = list(tensors)
if len(tensor_list) < 2:
return # Nothing to check
for i in range(1, len(tensor_list)):
prev_tensor = tensor_list[i - 1]
curr_tensor = tensor_list[i]
# Calculate expected offset based on previous tensor size
prev_numel = prev_tensor.numel()
expected_offset = (prev_numel // num_elems_in_byte) * prev_tensor.element_size()
# Verify current tensor's data pointer is correctly offset
expected_ptr = prev_tensor.data_ptr() + expected_offset
actual_ptr = curr_tensor.data_ptr()
assert (
actual_ptr == expected_ptr
), f"{tensor_name} {i} data pointer mismatch: expected {expected_ptr}, got {actual_ptr}"
def check_grouped_tensor_pointers(
weights: List[torch.Tensor], fp8_recipe: Optional[recipe.Recipe] = None
):
"""
Verify that the pointers of the weights are in contiguous memory for GroupedTensor.
TODO(ksivaman): This check can be made way more efficient but for now leaving the brute force approach.
"""
num_elems_in_a_data_byte = 1 if fp8_recipe is None else 2 if fp8_recipe.nvfp4() else 1
# Check data.
if hasattr(weights[0], "_data") and weights[0]._data is not None:
data_tensors = [w._data for w in weights]
check_grouped_tensor_pointers_helper(data_tensors, num_elems_in_byte=1, tensor_name="data")
# Check transpose.
if hasattr(weights[0], "_transpose") and weights[0]._transpose is not None:
transpose_tensors = [w._transpose for w in weights]
check_grouped_tensor_pointers_helper(
transpose_tensors, num_elems_in_byte=1, tensor_name="transpose"
)
# Check scale_inv.
if hasattr(weights[0], "_scale_inv") and weights[0]._scale_inv is not None:
scale_inv_tensors = [w._scale_inv for w in weights]
check_grouped_tensor_pointers_helper(
scale_inv_tensors, num_elems_in_byte=1, tensor_name="scale_inv"
)
# Check rowwise scale_inv.
if hasattr(weights[0], "_rowwise_scale_inv") and weights[0]._rowwise_scale_inv is not None:
scale_inv_tensors = [w._rowwise_scale_inv for w in weights]
check_grouped_tensor_pointers_helper(
scale_inv_tensors, num_elems_in_byte=1, tensor_name="rowwise_scale_inv"
)
# Check columnwise scale_inv.
if (
hasattr(weights[0], "_columnwise_scale_inv")
and weights[0]._columnwise_scale_inv is not None
):
columnwise_scale_inv_tensors = [w._columnwise_scale_inv for w in weights]
check_grouped_tensor_pointers_helper(
columnwise_scale_inv_tensors,
num_elems_in_byte=1,
tensor_name="columnwise scale_inv",
)
# Check rowwise amax.
if hasattr(weights[0], "_rowwise_amax") and weights[0]._rowwise_amax is not None:
rowwise_amax_tensors = [w._rowwise_amax for w in weights]
check_grouped_tensor_pointers_helper(
rowwise_amax_tensors, num_elems_in_byte=1, tensor_name="rowwise amax"
)
# Check columnwise amax.
if hasattr(weights[0], "_columnwise_amax") and weights[0]._columnwise_amax is not None:
columnwise_amax_tensors = [w._columnwise_amax for w in weights]
check_grouped_tensor_pointers_helper(
columnwise_amax_tensors, num_elems_in_byte=1, tensor_name="columnwise amax"
)
# Check rowwise data.
if hasattr(weights[0], "_rowwise_data") and weights[0]._rowwise_data is not None:
rowwise_data_tensors = [w._rowwise_data for w in weights]
check_grouped_tensor_pointers_helper(
rowwise_data_tensors,
num_elems_in_byte=num_elems_in_a_data_byte,
tensor_name="rowwise data",
)
# Check columnwise data.
if hasattr(weights[0], "_columnwise_data") and weights[0]._columnwise_data is not None:
columnwise_data_tensors = [w._columnwise_data for w in weights]
check_grouped_tensor_pointers_helper(
columnwise_data_tensors,
num_elems_in_byte=num_elems_in_a_data_byte,
tensor_name="columnwise data",
)
def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
    te_inp_hidden_states = torch.randn(
        (config.max_seqlen_q, config.batch_size, config.hidden_size),
@@ -495,9 +606,18 @@ def test_sanity_grouped_linear(
    use_fp8 = fp8_recipe is not None
    with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
        te_grouped_linear = GroupedLinear(
            num_gemms,
            config.hidden_size,
            ffn_hidden_size,
            bias=use_bias,
            params_dtype=dtype,
        ).cuda()
# Verify that weights are stored in contiguous GroupedTensor storage.
weights = [getattr(te_grouped_linear, f"weight{i}") for i in range(num_gemms)]
if fp8_recipe is None or not (fp8_recipe.delayed() or fp8_recipe.float8_current_scaling()):
check_grouped_tensor_pointers(weights, fp8_recipe)
    inp_hidden_states = torch.randn(
        num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
@@ -956,7 +1076,13 @@ def test_replace_raw_data_for_float8tensor():
    random_bf16_data = torch.randn(fp8_tensor.shape, dtype=torch.bfloat16, device="cuda")
    fp8_quantizer.update_quantized(random_bf16_data, fp8_tensor)
    attrs_to_check = [
        "_quantizer",
        "_fp8_dtype",
        "_scale_inv",
        "_transpose",
        "_transpose_invalid",
    ]
    attrs = {}
    for attr in attrs_to_check:
        attrs[attr] = getattr(fp8_tensor, attr)
...
@@ -88,33 +88,40 @@ class Recipe:
    Base recipe class.
    """

    @classmethod
    def nvfp4(cls):
        """Whether the given recipe is NVFP4 1D block scaling."""
        return issubclass(cls, NVFP4BlockScaling)

    @classmethod
    def mxfp8(cls):
        """Whether the given recipe is MXFP8 block scaling."""
        return issubclass(cls, MXFP8BlockScaling)

    @classmethod
    def delayed(cls):
        """Whether the given recipe is delayed scaling."""
        return issubclass(cls, DelayedScaling)

    @classmethod
    def float8_current_scaling(cls):
        """Whether the given recipe is (per-tensor) current scaling."""
        return issubclass(cls, Float8CurrentScaling)

    @classmethod
    def float8_per_tensor_scaling(cls):
        """Whether the given recipe is per-tensor scaling."""
        return issubclass(cls, (DelayedScaling, Float8CurrentScaling))

    @classmethod
    def float8_block_scaling(cls):
        """Whether the given recipe is float8 blockwise scaling."""
        return issubclass(cls, Float8BlockScaling)

    @classmethod
    def custom(cls):
        """Whether the given recipe is custom."""
        return issubclass(cls, CustomRecipe)

@dataclass()
...
@@ -13,6 +13,7 @@ import torch
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor
from .base import (
    get_dummy_wgrad,
    TransformerEngineBaseModule,
@@ -147,7 +148,10 @@ class _GroupedLinear(torch.autograd.Function):
            # tensors (like scales), but bulk allocation shares storage across all tensors,
            # so if scales can't be offloaded, nothing in the group can be offloaded.
            inputmats = tex.split_quantize(
                inp_view,
                m_splits,
                input_quantizers,
                disable_bulk_allocation=cpu_offloading,
            )
        elif debug:
            inputmats = DebugQuantizer.multi_tensor_quantize(
@@ -365,7 +369,10 @@ class _GroupedLinear(torch.autograd.Function):
                for i in range(ctx.num_gemms):
                    grad_biases[i] = grad_output_mats[i].sum(dim=0)
                grad_output = DebugQuantizer.multi_tensor_quantize(
                    grad_output_view,
                    ctx.grad_output_quantizers,
                    ctx.m_splits,
                    ctx.activation_dtype,
                )
            else:
                # Only split grad output. Grad bias is fused with
@@ -436,7 +443,8 @@ class _GroupedLinear(torch.autograd.Function):
            if ctx.input_quantizers[0] is not None:
                for input_quantizer in ctx.input_quantizers:
                    if isinstance(
                        input_quantizer,
                        (Float8Quantizer, Float8CurrentScalingQuantizer),
                    ):
                        input_quantizer.set_usage(rowwise=True, columnwise=True)
                    else:
@@ -446,7 +454,10 @@ class _GroupedLinear(torch.autograd.Function):
                inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers)
            elif ctx.debug:
                inputmats = DebugQuantizer.multi_tensor_quantize(
                    inp_view,
                    ctx.input_quantizers,
                    ctx.m_splits,
                    ctx.activation_dtype,
                )
            else:
                inputmats = torch.split(
@@ -616,7 +627,7 @@ class GroupedLinear(TransformerEngineBaseModule):
    ) -> None:
        super().__init__(name)
        self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
        self.num_gemms = num_gemms
        self.in_features = in_features
        self.out_features = out_features
@@ -631,12 +642,19 @@ class GroupedLinear(TransformerEngineBaseModule):
        assert (
            not ub_overlap_rs and not ub_overlap_ag
        ), "GroupedLinear doesn't support Userbuffer overlap."
self.init_method = init_method
        self.get_rng_state_tracker = get_rng_state_tracker
        self.rng_tracker_name = rng_tracker_name
        self.wgrad_store = WeightGradStore(delay_wgrad_compute)
        self._offsets = {
"input": 0,
"weight": 1,
"output": 2,
"grad_output": 0,
"grad_input": 1,
}
        self._num_fp8_tensors_per_gemm = {
            "fwd": 3,
            "bwd": 2,
@@ -678,7 +696,7 @@ class GroupedLinear(TransformerEngineBaseModule):
                    self.out_features,
                    self.in_features,
                    device=device,
                    dtype=self.params_dtype,
                ),
            ),
            init_fn=init_method,
@@ -694,13 +712,13 @@ class GroupedLinear(TransformerEngineBaseModule):
                    torch.empty(
                        self.out_features,
                        device=device,
                        dtype=self.params_dtype,
                    ),
                ),
                init_fn=init_method_constant(0.0),
            )
        else:
            bias = torch.Tensor().to(dtype=self.params_dtype, device=device)
            setattr(self, f"bias{i}", bias)
        if self.primary_weights_in_fp8:
@@ -724,8 +742,61 @@ class GroupedLinear(TransformerEngineBaseModule):
        if recipe.float8_current_scaling():
            self._customize_quantizers_float8_current_scaling(fwd, recipe)
def make_grouped_weights(self, defer_init=False) -> None:
"""
Convert parameters into a GroupedTensor and re-register them as parameters.
"""
if defer_init:
return
weight_quantizers = self._get_weight_quantizers()
recipe = (
weight_quantizers[0]._get_compatible_recipe()
if weight_quantizers and weight_quantizers[0] is not None
else None
)
if recipe is not None and (recipe.delayed() or recipe.float8_current_scaling()):
self.set_tensor_parallel_attributes(defer_init=defer_init)
return
weights = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
# Create the weight storage.
grouped_weights = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=self.num_gemms,
shape=[(self.out_features, self.in_features)] * self.num_gemms,
quantizer=weight_quantizers[0],
dtype=self.params_dtype,
device=weights[0].device,
)
# Copy existing params into storage.
with torch.no_grad():
for i in range(self.num_gemms):
if self.primary_weights_in_fp8:
grouped_weights.quantized_tensors[i].copy_from_storage(weights[i])
else:
grouped_weights.quantized_tensors[i].copy_(weights[i])
# Re-register the grouped weights as parameters.
for i in range(self.num_gemms):
self.register_parameter(
f"weight{i}",
torch.nn.Parameter(grouped_weights.quantized_tensors[i]),
init_fn=self.init_method,
get_rng_state_tracker=self.get_rng_state_tracker,
fp8_meta_index=self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"],
)
self.set_tensor_parallel_attributes(defer_init=defer_init)
    def reset_parameters(self, defer_init=False):
        super().reset_parameters(defer_init=defer_init)
self.make_grouped_weights(defer_init=defer_init)
def set_tensor_parallel_attributes(self, defer_init=False) -> None:
"""Set attributes needed for TP"""
        if not defer_init:
            # Set parallelism attributes for linear weights
@@ -925,7 +996,7 @@ class GroupedLinear(TransformerEngineBaseModule):
    def _get_weight_quantizers(self) -> List[Quantizer]:
        """Get the weight quantizers of the module."""
        if not self.fp8 and not self.fp8_calibration and not self.primary_weights_in_fp8:
            return [None] * self.num_gemms
        weight_quantizers = [
            self.quantizers["scaling_fwd"][
@@ -934,7 +1005,7 @@ class GroupedLinear(TransformerEngineBaseModule):
            for i in range(self.num_gemms)
        ]
        for i in range(self.num_gemms):
            weight_quantizers[i].internal = not self.primary_weights_in_fp8
        return weight_quantizers

    def _get_quantizers(self):
...
@@ -69,7 +69,9 @@ class QuantizedTensorStorage:
            f"{self.__class__.__name__} class does not implement get_usages function"
        )

    def prepare_for_saving(
        self,
    ) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]:
        """Prepare the tensor base for saving for backward"""
        raise NotImplementedError(
            f"{self.__class__.__name__} class does not implement prepare_for_saving function"
@@ -115,11 +117,18 @@ class QuantizedTensorStorage:
            warnings.warn("Quantizer is being updated, this may affect model behavior")
        self._quantizer = quantizer
def copy_from_storage(self, src: QuantizedTensorStorage) -> None:
"""Copy data from another QuantizedTensorStorage."""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement copy_from_storage function"
)
def prepare_for_saving(
    *tensors: Union[torch.Tensor, QuantizedTensorStorage],
) -> Tuple[
    list[Optional[Union[torch.Tensor, torch.nn.Parameter]]],
    list[Optional[QuantizedTensorStorage]],
]:
    """Prepare tensors for saving. Needed because save_for_backward accepts only
    torch.Tensor/torch.nn.Parameter types, while we want to be able to save
@@ -144,7 +153,10 @@ def restore_from_saved(
    return_saved_tensors: bool = False,
) -> (
    list[Optional[torch.Tensor | QuantizedTensorStorage]]
    | tuple[
        list[Optional[torch.Tensor | QuantizedTensorStorage]],
        list[Optional[torch.Tensor]],
    ]
):
    """Recombine the tensor data and metadata during backward pass."""
    tensor_objects = []
...
@@ -11,7 +11,11 @@ from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from transformer_engine.common.recipe import (
    DelayedScaling,
    Float8CurrentScaling,
    Recipe,
)
from ..utils import canonicalize_process_group, devices_match
from .storage.float8_tensor_storage import Float8TensorStorage, _FromFloat8Func
from ..quantized_tensor import QuantizedTensor, Quantizer
@@ -154,6 +158,10 @@ class Float8Quantizer(Quantizer):
        amin, amax = tensor.aminmax()
        self.amax.copy_(torch.max(-amin, amax))
    def get_columnwise_shape(self, rowwise_data_shape: Iterable[int]) -> Tuple[int, ...]:
        """Calculate the shape of the columnwise (transposed) data for FP8 delayed scaling."""
        return [rowwise_data_shape[-1]] + list(rowwise_data_shape[:-1])
    def create_tensor_from_data(
        self,
        data: torch.Tensor,
@@ -408,6 +416,10 @@ class Float8CurrentScalingQuantizer(Quantizer):
            quantizer=self,
        )
    def get_columnwise_shape(self, rowwise_data_shape: Iterable[int]) -> Tuple[int, ...]:
        """Calculate the shape of the columnwise (transposed) data for FP8 current scaling."""
        return [rowwise_data_shape[-1]] + list(rowwise_data_shape[:-1])
    def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor:
        """Function using primitives with ONNX defined translations."""
        if tensor.dtype != torch.float32:
@@ -769,7 +781,10 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
                kwargs,
            )
            return Float8Tensor.make_like(
                tensor,
data=func_out,
data_transpose=func_transposed_out,
shape=func_out.shape,
            )
        if func == torch.ops.aten.detach.default:
...
@@ -164,6 +164,49 @@ class MXFP8Quantizer(Quantizer):
        # TODO(ksivamani): No calibration needed for mxfp8?
        pass
def get_scale_shape(
self,
shape: Iterable[int],
columnwise: bool,
) -> Tuple[int, int]:
"""Calculate the shape of the scaling tensor for MXFP8 1D blockwise quantization.
This method determines the shape of the scaling tensor needed for blockwise quantization,
taking into account the input tensor shape and whether columnwise scaling is used.
Parameters
----------
shape : Iterable[int]
Shape of the input tensor to be quantized
columnwise : bool
Whether to use columnwise scaling (True) or rowwise scaling (False)
Returns
-------
Tuple[int, int]
Shape of the scaling tensor as (outer_dim, inner_dim)
            For MXFP8 1D blockwise quantization, the block size is 32.
            A swizzle kernel is applied before the GEMM to match the scaling-factor layout cuBLAS expects.
            cuBLAS doc: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
"""
if columnwise:
# Columnwise: scale_inv shape is [prod(shape[:-1]) // BLOCK_SIZE, shape[-1]]
# with padding to multiples of [4, 128]
return (
round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4),
round_up_to_nearest_multiple(shape[-1], 128),
)
# Rowwise: scale_inv shape is [prod(shape[:-1]), shape[-1] // BLOCK_SIZE]
# with padding to multiples of [128, 4]
return (
round_up_to_nearest_multiple(math.prod(shape[:-1]), 128),
round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4),
)
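    # Worked example for get_scale_shape (illustrative values): with shape (512, 512)
    # and MXFP8_BLOCK_SCALING_SIZE == 32,
    #     rowwise:    (round_up(512, 128), round_up(512 // 32, 4)) == (512, 16)
    #     columnwise: (round_up(512 // 32, 4), round_up(512, 128)) == (16, 512)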
def get_columnwise_shape(self, rowwise_data_shape: Tuple[int, ...]) -> Tuple[int, ...]:
"""Calculate the shape of the columnwise data for MXFP8 1D blockwise quantization."""
return rowwise_data_shape
    def create_tensor_from_data(
        self,
        data: torch.Tensor,
@@ -704,7 +747,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
            columnwise_scale_inv=columnwise_scale_inv,
            fp8_dtype=fp8_dtype,
            dtype=param_dtype,
            shape=(rowwise_data.shape if rowwise_data is not None else columnwise_data.shape),
            quantizer=self._quantizer,
            with_gemm_swizzled_scales=False,
        )
...
@@ -341,7 +341,10 @@ class NVFP4Quantizer(Quantizer):
        )
        columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True)
        columnwise_scale_inv = torch.empty(
            columnwise_scale_shape,
dtype=torch.uint8,
device=device,
pin_memory=pin_memory,
        )
        amax_columnwise = torch.zeros(
            1, dtype=torch.float32, device=device, pin_memory=pin_memory
...
@@ -7,3 +7,4 @@ from .float8_tensor_storage import Float8TensorStorage  # noqa: F401
from .mxfp8_tensor_storage import MXFP8TensorStorage  # noqa: F401
from .float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage  # noqa: F401
from .nvfp4_tensor_storage import NVFP4TensorStorage  # noqa: F401
from .grouped_tensor import GroupedTensor # noqa: F401
@@ -73,6 +73,24 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
            if t is not None:
                t.data = _empty_tensor()
def copy_from_storage(self, src: QuantizedTensorStorage) -> None:
"""Copy data buffers from another Float8BlockwiseQTensorStorage."""
if not isinstance(src, Float8BlockwiseQTensorStorage):
raise TypeError("copy_from_storage expects Float8BlockwiseQTensorStorage")
if self._fp8_dtype != src._fp8_dtype:
raise RuntimeError("FP8 dtype mismatch in copy_from_storage")
if self._is_2D_scaled != src._is_2D_scaled:
raise RuntimeError("Scale layout mismatch in copy_from_storage")
def _copy_optional(dst: Optional[torch.Tensor], src_tensor: Optional[torch.Tensor]):
if dst is not None and src_tensor is not None:
dst.copy_(src_tensor)
_copy_optional(self._rowwise_data, src._rowwise_data)
_copy_optional(self._columnwise_data, src._columnwise_data)
_copy_optional(self._rowwise_scale_inv, src._rowwise_scale_inv)
_copy_optional(self._columnwise_scale_inv, src._columnwise_scale_inv)
    def get_metadata(self) -> Dict[str, Any]:
        """Get this tensor's metadata."""
        return {
...
@@ -104,6 +104,24 @@ class Float8TensorStorage(QuantizedTensorStorage):
                t.data = _empty_tensor()
            self._transpose_invalid = True
def copy_from_storage(self, src: QuantizedTensorStorage) -> None:
"""Copy data buffers from another Float8TensorStorage."""
if not isinstance(src, Float8TensorStorage):
raise TypeError("copy_from_storage expects Float8TensorStorage")
if self._fp8_dtype != src._fp8_dtype:
raise RuntimeError("FP8 dtype mismatch in copy_from_storage")
def _copy_optional(
dst: Optional[torch.Tensor],
src_tensor: Optional[torch.Tensor],
):
if dst is not None and src_tensor is not None:
dst.copy_(src_tensor)
_copy_optional(self._data, src._data)
_copy_optional(self._transpose, src._transpose)
_copy_optional(self._scale_inv, src._scale_inv)
    def get_metadata(self) -> Dict[str, Any]:
        """Get this tensor's metadata."""
        return {
...
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Grouped tensor class for handling collections of tensors with different shapes"""
from __future__ import annotations
from typing import Optional, Tuple, List, Union
import math
import torch
from ...quantized_tensor import QuantizedTensorStorage, Quantizer
from ..mxfp8_tensor import MXFP8Tensor
from ..nvfp4_tensor import NVFP4Tensor
from ..float8_tensor import Float8Tensor
from ..float8_blockwise_tensor import Float8BlockwiseQTensor
from .float8_tensor_storage import Float8TensorStorage
from .mxfp8_tensor_storage import MXFP8TensorStorage
from .float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from .nvfp4_tensor_storage import NVFP4TensorStorage
class GroupedTensor:
"""
EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE.
    A grouped tensor is a collection of tensors with different shapes but the same dtype and scaling mode.
Shape Representation:
- logical_shape: 2D shape representing the conceptual layout, i.e. the shape when member tensors
are flattened to 2D and stacked together (REQUIRED)
      + When all_same_shape(): [num_tensors * M, N] where each tensor is (M, N)
      + When varying_first_dim(): [sum_of_first_dims, N] where N is common
      + When varying_last_dim(): [M, sum_of_last_dims] where M is common
      + When varying_both_dims(): [1, total_elements] (fully flattened)
- first_dims and last_dims are OPTIONAL (None if dimension is uniform)
+ None first_dims: all tensors have the same first dimension
+ None last_dims: all tensors have the same last dimension
+ Both None: all tensors have identical shapes
+ Both set: each tensor has unique shape (first_dims[i], last_dims[i])
Data Layout:
- ALL data fields are stored as 1D flattened arrays (data, columnwise_data, scale_inv, etc.)
- logical_shape provides the conceptual 2D interpretation
- All data is stored on device in contiguous layout
Note: This structure is used only for combined storage of multiple tensors with the same dtype and scaling mode.
"""
def __init__(
self,
num_tensors: int,
shape: List[Tuple[int, int]],
quantizer: Optional[Quantizer] = None,
dtype: Optional[torch.dtype] = None,
data: Optional[torch.Tensor] = None,
columnwise_data: Optional[torch.Tensor] = None,
scale_inv: Optional[torch.Tensor] = None,
columnwise_scale_inv: Optional[torch.Tensor] = None,
amax: Optional[torch.Tensor] = None,
columnwise_amax: Optional[torch.Tensor] = None,
scale: Optional[torch.Tensor] = None,
first_dims: Optional[torch.Tensor] = None,
last_dims: Optional[torch.Tensor] = None,
tensor_offsets: Optional[torch.Tensor] = None,
offsets: Optional[List[int]] = None,
scale_inv_offsets: Optional[List[int]] = None,
columnwise_scale_inv_offsets: Optional[List[int]] = None,
logical_shape: Optional[Tuple[int, int]] = None,
) -> None:
"""
Initialize a GroupedTensor.
Args:
num_tensors: Number of tensors in the group
            shape: List of 2D shapes, one per tensor (length num_tensors)
quantizer: Quantizer for the grouped tensor
data: Row-wise data buffer (1D flattened)
columnwise_data: Column-wise data buffer (1D flattened)
scale_inv: Row-wise scale inverse buffer
columnwise_scale_inv: Column-wise scale inverse buffer
amax: Row-wise amax buffer
columnwise_amax: Column-wise amax buffer
scale: Scale buffer (for FP8-DS only)
first_dims: Device tensor of int64 array of length num_tensors (or None if uniform)
last_dims: Device tensor of int64 array of length num_tensors (or None if uniform)
tensor_offsets: Device tensor of int64 array of length num_tensors (or None if uniform)
offsets: Vector of integer offsets for each tensor.
logical_shape: 2D tuple representing conceptual shape
"""
self.num_tensors = num_tensors
self.quantizer = quantizer
self.shape = shape
self.dtype = (
dtype if dtype is not None else torch.float32
) # Default to float32 if not provided
# Data buffers
self.data = data
self.columnwise_data = columnwise_data
self.scale_inv = scale_inv
self.columnwise_scale_inv = columnwise_scale_inv
self.amax = amax
self.columnwise_amax = columnwise_amax
self.scale = scale
# For convenient indexing for python GroupedTensor API.
self.scale_inv_offsets = scale_inv_offsets
self.columnwise_scale_inv_offsets = columnwise_scale_inv_offsets
# Shape information (OPTIONAL - None if dimension is uniform across all tensors)
# first_dims[i] = first dimension of tensor i (None if all tensors have same first dim)
# last_dims[i] = last dimension of tensor i (None if all tensors have same last dim)
self.first_dims = (
first_dims # Device pointer to int64_t array of length num_tensors (or None)
)
self.last_dims = (
last_dims # Device pointer to int64_t array of length num_tensors (or None)
)
# Offsets for indexing into contiguous 1D layout (OPTIONAL - not needed if all_same_shape())
# tensor_offsets[i] = element offset to start of tensor i (cumulative sum of numel for tensors 0..i-1)
# Usage: tensor_i_ptr = data.data_ptr() + tensor_offsets[i] * element_size
# If None and all_same_shape(): offset[i] = i * M * N (where M, N are common dimensions)
self.tensor_offsets = (
tensor_offsets # Device pointer to int64_t array of length num_tensors (or None)
)
self.offsets = offsets # Vector of integer offsets for each tensor.
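        # Example (illustrative): for shapes (128, 512) and (256, 512),
        # tensor_offsets == [0, 65536, 196608] (element counts, not bytes).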
# Logical shape: conceptual 2D shape of the grouped data (REQUIRED)
# Represents how the 1D flattened data should be interpreted as 2D
# Always 2D with positive dimensions
self.logical_shape = logical_shape if logical_shape is not None else (0, 0)
# Hold a reference to the quantized tensors that occupy same storage as the GroupedTensor.
# Used as a convenience.
self.quantized_tensors = None
def has_data(self) -> bool:
"""
Check if the tensor has row-wise data.
Returns:
True if data buffer is initialized, False otherwise
"""
return self.data is not None
def has_columnwise_data(self) -> bool:
"""
Check if the tensor has column-wise data.
Returns:
True if columnwise_data buffer is initialized, False otherwise
"""
return self.columnwise_data is not None
def all_same_first_dim(self) -> bool:
"""
Check if all tensors in the group have the same first dimension.
Returns:
True if first dimension is uniform across all tensors
"""
return self.first_dims is None
def all_same_last_dim(self) -> bool:
"""
Check if all tensors in the group have the same last dimension.
Returns:
True if last dimension is uniform across all tensors
"""
return self.last_dims is None
def all_same_shape(self) -> bool:
"""
Check if all tensors in the group have identical shapes.
Returns:
True if all tensors have the same shape
"""
return self.first_dims is None and self.last_dims is None
def varying_both_dims(self) -> bool:
"""
Check if both dimensions vary across tensors.
Returns:
True if both first and last dimensions vary
"""
return self.first_dims is not None and self.last_dims is not None
def get_common_first_dim(self) -> int:
"""
Get the common first dimension when all tensors share it.
Returns:
The common first dimension
Raises:
RuntimeError: If first dimension varies across tensors or logical_shape is not 2D
"""
if not self.all_same_first_dim():
raise RuntimeError("First dim varies across tensors")
if len(self.logical_shape) != 2:
raise RuntimeError("Logical shape must be 2D")
if self.all_same_shape():
# When both dims are uniform: logical_shape = [num_tensors * M, N]
return self.logical_shape[0] // self.num_tensors
# When varying last dims but not first dim: logical_shape = [M, sum_of_last_dims]
return self.logical_shape[0]
def get_common_last_dim(self) -> int:
"""
Get the common last dimension when all tensors share it.
Returns:
The common last dimension
Raises:
RuntimeError: If last dimension varies across tensors or logical_shape is not 2D
"""
if not self.all_same_last_dim():
raise RuntimeError("Last dim varies across tensors")
if len(self.logical_shape) != 2:
raise RuntimeError("Logical shape must be 2D")
# For both uniform and varying first dim cases: logical_shape[1] is the common last dim
return self.logical_shape[1]
def get_dtype(self) -> torch.dtype:
"""
Get the high precision data type of the tensor.
Returns:
The high precision dtype of the data buffer
"""
return self.dtype
def clear(self) -> None:
"""
Reset tensor data and clear all buffers.
"""
self.data = None
self.columnwise_data = None
self.scale_inv = None
self.columnwise_scale_inv = None
self.amax = None
self.columnwise_amax = None
self.scale = None
self.first_dims = None
self.last_dims = None
self.tensor_offsets = None
self.logical_shape = (0, 0)
self.num_tensors = 0
self.quantizer = None
self.quantized_tensors = None
self.offsets = None
self.scale_inv_offsets = None
self.columnwise_scale_inv_offsets = None
def __repr__(self) -> str:
"""String representation of the GroupedTensor."""
return (
f"GroupedTensor(num_tensors={self.num_tensors}, "
f"shape={self.shape}, "
f"logical_shape={self.logical_shape}, "
f"dtype={self.get_dtype()})"
)
def __str__(self) -> str:
"""User-friendly string representation."""
shape_info = []
if self.all_same_shape():
shape_info.append("uniform shape")
else:
if not self.all_same_first_dim():
shape_info.append("varying first dim")
if not self.all_same_last_dim():
shape_info.append("varying last dim")
return (
f"GroupedTensor with {self.num_tensors} tensors "
f"({', '.join(shape_info) if shape_info else 'uniform'}), "
f"logical_shape={self.logical_shape}, "
f"dtype={self.get_dtype()}"
)
@staticmethod
def make_grouped_tensor_with_shapes(
num_tensors: int,
shape: List[Tuple[int, int]],
quantizer: Optional[Quantizer] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> GroupedTensor:
"""
        Create a GroupedTensor for storing multiple 2D tensors; the last dimension must be uniform across tensors.
Args:
num_tensors: Number of tensors
            shape: List of 2D shapes, one per tensor (length num_tensors)
quantizer: Quantizer for each tensor
device: Device to allocate tensors on, defaults to current cuda device
dtype: Data type of the tensor (for high precision case)
Returns:
A GroupedTensor.
"""
# First dim
first_dim_list = [s[0] for s in shape]
uniform_first_dim = all(first_dim_list[0] == x for x in first_dim_list)
logical_first_dim = sum(first_dim_list)
if uniform_first_dim:
first_dims = None
else:
first_dims = torch.tensor([s[0] for s in shape], dtype=torch.int64, device=device)
# Last dim
last_dim_list = [s[1] for s in shape]
logical_last_dim = last_dim_list[0]
assert all(logical_last_dim == x for x in last_dim_list), "Last dims should be uniform"
return GroupedTensor.make_grouped_tensor(
num_tensors=num_tensors,
first_dims=first_dims,
last_dims=None,
logical_first_dim=logical_first_dim,
logical_last_dim=logical_last_dim,
quantizer=quantizer,
device=device,
dtype=dtype,
)
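    # A minimal construction sketch for the factory above (illustrative values,
    # mirroring the unit tests):
    #
    #     gt = GroupedTensor.make_grouped_tensor_with_shapes(
    #         num_tensors=3,
    #         shape=[(128, 512), (256, 512), (384, 512)],
    #         quantizer=None,
    #         device="cuda",
    #         dtype=torch.float32,
    #     )
    #     gt.logical_shape  # (768, 512)
    #     gt.offsets        # [0, 65536, 196608, 393216]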
@staticmethod
def make_grouped_tensor(
num_tensors: int,
first_dims: Optional[torch.Tensor],
last_dims: Optional[torch.Tensor],
logical_first_dim: int,
logical_last_dim: int,
quantizer: Optional[Quantizer] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> GroupedTensor:
"""
        Create a GroupedTensor from explicit dimension metadata; the last dimension must be uniform across tensors.
Args:
num_tensors: Number of tensors
first_dims: Device tensor of int64 array of length num_tensors (or None if uniform)
last_dims: Device tensor of int64 array of length num_tensors (or None if uniform)
logical_first_dim: Logical first dimension
logical_last_dim: Logical last dimension
quantizer: Quantizer for each tensor
Used to figure out the recipe and what to allocate.
device: Device to allocate tensors on, defaults to current cuda device
dtype: Data type of the tensor (for high precision case)
Returns:
A GroupedTensor.
"""
# Set device
if device is None:
device = torch.cuda.current_device()
# Shape patterns and validation.
all_same_first = first_dims is None
all_same_last = last_dims is None
assert all_same_last, "Last dim must be uniform for GroupedTensor"
assert logical_first_dim > 0, "Logical first dim must be positive for GroupedTensor"
assert logical_last_dim > 0, "Logical last dim must be positive for GroupedTensor"
# assert (
# logical_first_dim % 128 == 0
# ), "Logical first dim must be divisible by 128"
# assert logical_last_dim % 128 == 0, "Logical last dim must be divisible by 128"
# Calculate tensor offsets (cumulative element offsets)
tensor_offsets = None
offsets = None
shape = []
if not all_same_first:
# Need explicit offsets for non-uniform shapes
# Offsets are based on number of elements and not pointers.
# Kernels need to calculate precise pointers based on size of elements.
# TODO(ksivaman): Single kernel + remove the host offset calculation.
tensor_offsets = torch.cat(
[
torch.zeros(1, device=first_dims.device, dtype=first_dims.dtype),
torch.cumsum(first_dims * logical_last_dim, dim=0),
]
)
offsets = tensor_offsets.tolist()
first_dims_list = first_dims.tolist()
for i in range(num_tensors):
shape.append((first_dims_list[i], logical_last_dim))
else:
offsets = [
i * logical_first_dim * logical_last_dim // num_tensors
for i in range(num_tensors + 1)
]
for i in range(num_tensors):
shape.append((logical_first_dim // num_tensors, logical_last_dim))
        # Calculate the logical shape of the grouped data
logical_shape = (logical_first_dim, logical_last_dim)
no_quantization = quantizer is None
rowwise_usage = quantizer.rowwise_usage if not no_quantization else True
columnwise_usage = quantizer.columnwise_usage if not no_quantization else False
# Calculate total elements across all tensors
total_elements = logical_first_dim * logical_last_dim
data = None
columnwise_data = None
scale_inv = None
columnwise_scale_inv = None
amax = None
columnwise_amax = None
scale = None
scale_inv_offsets = None
columnwise_scale_inv_offsets = None
if no_quantization:
assert dtype is not None, "dtype must be provided for unquantized GroupedTensor"
            if rowwise_usage:
                # Allocate rowwise data buffer (1D flattened, high-precision dtype)
                data = torch.empty(total_elements, dtype=dtype, device=device)
            if columnwise_usage:
                # Allocate columnwise data buffer (1D flattened, high-precision dtype)
                columnwise_data = torch.empty(total_elements, dtype=dtype, device=device)
elif quantizer._get_compatible_recipe().mxfp8():
if rowwise_usage:
# Allocate rowwise data buffer (1D flattened, uint8)
data = torch.empty(total_elements, dtype=torch.uint8, device=device)
# Scale inverse buffer for MXFP8 - complex shape based on block scaling
# For grouped tensors, we need to calculate scale_inv size for all tensors
total_scale_elements = 0
scale_inv_offsets = [0]
for i, s in enumerate(shape):
scale_inv_shape = quantizer.get_scale_shape(s, False)
scale_elements = math.prod(scale_inv_shape)
total_scale_elements += scale_elements
if i < num_tensors - 1:
scale_inv_offsets.append(total_scale_elements)
scale_inv = torch.empty(total_scale_elements, dtype=torch.uint8, device=device)
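# Illustrative arithmetic (hypothetical scale shapes): if get_scale_shape
# returns (128, 4) and (256, 4) for two tensors, scale_inv holds
# 512 + 1024 = 1536 bytes and scale_inv_offsets = [0, 512].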
if columnwise_usage:
# Allocate columnwise data buffer (1D flattened, uint8)
columnwise_data = torch.empty(total_elements, dtype=torch.uint8, device=device)
# Columnwise scale inverse buffer
total_columnwise_scale_elements = 0
columnwise_scale_inv_offsets = [0]
for i, s in enumerate(shape):
columnwise_scale_inv_shape = quantizer.get_scale_shape(s, True)
columnwise_scale_elements = math.prod(columnwise_scale_inv_shape)
total_columnwise_scale_elements += columnwise_scale_elements
if i < num_tensors - 1:
columnwise_scale_inv_offsets.append(total_columnwise_scale_elements)
columnwise_scale_inv = torch.empty(
total_columnwise_scale_elements, dtype=torch.uint8, device=device
)
elif recipe.delayed():
if rowwise_usage:
# Allocate rowwise data buffer (1D flattened, uint8)
data = torch.empty(total_elements, dtype=torch.uint8, device=device)
# Scale inverse - one per tensor
scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device)
# One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors-1
scale_inv_offsets = list(range(num_tensors))
if columnwise_usage:
# Allocate columnwise data buffer (1D flattened, uint8)
columnwise_data = torch.empty(total_elements, dtype=torch.uint8, device=device)
# Columnwise scale inverse - one per tensor
columnwise_scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device)
# One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors-1
columnwise_scale_inv_offsets = list(range(num_tensors))
# Amax buffer for delayed scaling - one per tensor
amax = torch.empty(num_tensors, dtype=torch.float32, device=device)
elif recipe.nvfp4():
if rowwise_usage:
# Allocate rowwise data buffer (1D flattened, uint8; FP4 packs 2 values per byte)
data = torch.empty(total_elements // 2, dtype=torch.uint8, device=device)
# Scale inverse buffer for NVFP4 - complex shape based on block scaling
# For simplicity, calculate total scale elements needed
total_scale_elements = 0
scale_inv_offsets = [0]
for i, s in enumerate(shape):
scale_inv_shape = quantizer.get_scale_shape(s, False)
total_scale_elements += math.prod(scale_inv_shape)
if i < num_tensors - 1:
scale_inv_offsets.append(total_scale_elements)
scale_inv = torch.empty(total_scale_elements, dtype=torch.uint8, device=device)
# Amax buffer - one per tensor
amax = torch.empty(num_tensors, dtype=torch.float32, device=device)
if columnwise_usage:
# Allocate columnwise data buffer (1D flattened, uint8; FP4 packed)
columnwise_data = torch.empty(total_elements // 2, dtype=torch.uint8, device=device)
# Columnwise scale inverse buffer
total_columnwise_scale_elements = 0
columnwise_scale_inv_offsets = [0]
for i, s in enumerate(shape):
columnwise_scale_inv_shape = quantizer.get_scale_shape(s, True)
total_columnwise_scale_elements += math.prod(columnwise_scale_inv_shape)
if i < num_tensors - 1:
columnwise_scale_inv_offsets.append(total_columnwise_scale_elements)
columnwise_scale_inv = torch.empty(
total_columnwise_scale_elements, dtype=torch.uint8, device=device
)
# Columnwise amax buffer - one per tensor
columnwise_amax = torch.empty(num_tensors, dtype=torch.float32, device=device)
elif recipe.float8_block_scaling():
if rowwise_usage:
# Allocate rowwise data buffer (1D flattened, uint8)
data = torch.empty(total_elements, dtype=torch.uint8, device=device)
# Scale inverse - size depends on block configuration
# For simplicity, calculate total scale elements needed
total_scale_elements = 0
scale_inv_offsets = [0]
for i, s in enumerate(shape):
scale_inv_shape = quantizer.get_scale_shape(s, False)
total_scale_elements += math.prod(scale_inv_shape)
if i < num_tensors - 1:
scale_inv_offsets.append(total_scale_elements)
scale_inv = torch.empty(total_scale_elements, dtype=torch.float32, device=device)
if columnwise_usage:
# Allocate columnwise data buffer (1D flattened, uint8)
columnwise_data = torch.empty(total_elements, dtype=torch.uint8, device=device)
# Columnwise scale inverse
total_columnwise_scale_elements = 0
columnwise_scale_inv_offsets = [0]
for i, s in enumerate(shape):
columnwise_scale_inv_shape = quantizer.get_scale_shape(s, True)
total_columnwise_scale_elements += math.prod(columnwise_scale_inv_shape)
if i < num_tensors - 1:
columnwise_scale_inv_offsets.append(total_columnwise_scale_elements)
columnwise_scale_inv = torch.empty(
total_columnwise_scale_elements, dtype=torch.float32, device=device
)
elif recipe.float8_current_scaling():
# Current scaling - per-tensor scaling computed on the fly
if rowwise_usage:
# Allocate rowwise data buffer (1D flattened, uint8)
data = torch.empty(total_elements, dtype=torch.uint8, device=device)
# Scale inverse - one per tensor
scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device)
# One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors-1
scale_inv_offsets = list(range(num_tensors))
if columnwise_usage:
# Allocate columnwise data buffer (1D flattened, uint8)
columnwise_data = torch.empty(total_elements, dtype=torch.uint8, device=device)
# Columnwise scale inverse - one per tensor
columnwise_scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device)
# One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors-1
columnwise_scale_inv_offsets = list(range(num_tensors))
# Scale and amax buffers for current scaling - one per tensor
scale = torch.empty(num_tensors, dtype=torch.float32, device=device)
amax = torch.empty(num_tensors, dtype=torch.float32, device=device)
else:
raise ValueError(f"Unsupported quantizer for GroupedTensor: {quantizer}")
grouped_tensor = GroupedTensor(
num_tensors=num_tensors,
shape=shape,
dtype=dtype,
quantizer=quantizer,
data=data,
columnwise_data=columnwise_data,
scale_inv=scale_inv,
columnwise_scale_inv=columnwise_scale_inv,
amax=amax,
columnwise_amax=columnwise_amax,
scale=scale,
first_dims=first_dims,
last_dims=last_dims,
tensor_offsets=tensor_offsets,
offsets=offsets,
scale_inv_offsets=scale_inv_offsets,
columnwise_scale_inv_offsets=columnwise_scale_inv_offsets,
logical_shape=logical_shape,
)
grouped_tensor.quantized_tensors = grouped_tensor.split_into_quantized_tensors()
return grouped_tensor
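# Usage sketch (illustrative; the factory name `make_grouped_tensor` and the
# uniform 128 x 256 tensor sizes are assumptions for this example):
#
#   grouped = GroupedTensor.make_grouped_tensor(
#       num_tensors=4,
#       first_dims=None,
#       last_dims=None,
#       logical_first_dim=512,
#       logical_last_dim=256,
#       quantizer=MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3),
#   )
#   views = grouped.quantized_tensors  # four MXFP8 views into shared buffers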
def split_into_quantized_tensors(
self,
) -> List[Union[QuantizedTensorStorage, torch.Tensor]]:
"""
Split the GroupedTensor into a list of `num_tensors`
quantized tensors based on the quantizer. No additional memory allocation is performed,
so the tensors returned are the same as the ones used to create the GroupedTensor.
If quantizer is None, returns normal torch tensors.
If quantizer.internal is True, returns QuantizedTensorStorage.
Otherwise, returns QuantizedTensor.
TODO(ksivaman): Block cases where any dims are varying. This is needed only
to expose the weights as separate parameters.
"""
result = []
no_quantization = self.quantizer is None
# Case 1: No quantization - return regular torch tensors
if no_quantization:
for i in range(self.num_tensors):
# Get tensor shape and element count
tensor_shape = self.shape[i]
numel = tensor_shape[0] * tensor_shape[1]
# Get tensor data slice; explicit offsets are needed only for
# non-uniform shapes, otherwise slices are evenly spaced.
if self.offsets is not None:
start_offset = self.offsets[i]
else:
start_offset = i * numel
end_offset = start_offset + numel
if self.has_data():
result.append(self.data[start_offset:end_offset].view(tensor_shape))
elif self.has_columnwise_data():
result.append(self.columnwise_data[start_offset:end_offset].view(tensor_shape))
else:
raise RuntimeError("GroupedTensor has no data to split")
return result
# Case 2: Quantized tensors
recipe = self.quantizer._get_compatible_recipe()
for i in range(self.num_tensors):
# Get tensor shape
tensor_shape = self.shape[i]
numel = tensor_shape[0] * tensor_shape[1]
# Get data offsets
if self.offsets is not None:
data_start = self.offsets[i]
data_end = data_start + numel
else:
# All same shape
data_start = i * numel
data_end = data_start + numel
# Special offset handling for NVFP4: two FP4 values are packed per byte.
nvfp4 = recipe.nvfp4()
if nvfp4:
data_start = data_start // 2
data_end = data_end // 2
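# Illustrative: a 128 x 64 tensor holds 8192 FP4 values packed into
# 4096 bytes, so its byte range is half of its element range.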
# Extract rowwise and columnwise data
rowwise_data = None
columnwise_data = None
if self.has_data():
if nvfp4:
rowwise_tensor_shape = self.quantizer.convert_shape_for_fp4(tensor_shape)
else:
rowwise_tensor_shape = tensor_shape
rowwise_data = self.data[data_start:data_end].view(rowwise_tensor_shape)
if self.has_columnwise_data():
columnwise_tensor_shape = self.quantizer.get_columnwise_shape(tensor_shape)
if nvfp4:
columnwise_tensor_shape = self.quantizer.convert_shape_for_fp4(
columnwise_tensor_shape
)
columnwise_data = self.columnwise_data[data_start:data_end].view(
columnwise_tensor_shape
)
# MXFP8 format
if recipe.mxfp8():
# Extract scale_inv data
rowwise_scale_inv = None
columnwise_scale_inv = None
if self.scale_inv is not None and self.scale_inv_offsets is not None:
scale_start = self.scale_inv_offsets[i]
if i < self.num_tensors - 1:
scale_end = self.scale_inv_offsets[i + 1]
else:
scale_end = self.scale_inv.numel()
# Calculate expected scale shape for MXFP8
scale_shape = self.quantizer.get_scale_shape(tensor_shape, False)
rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape)
if (
self.columnwise_scale_inv is not None
and self.columnwise_scale_inv_offsets is not None
):
cscale_start = self.columnwise_scale_inv_offsets[i]
if i < self.num_tensors - 1:
cscale_end = self.columnwise_scale_inv_offsets[i + 1]
else:
cscale_end = self.columnwise_scale_inv.numel()
cscale_shape = self.quantizer.get_scale_shape(tensor_shape, True)
columnwise_scale_inv = self.columnwise_scale_inv[cscale_start:cscale_end].view(
cscale_shape
)
if self.quantizer.internal:
mxfp8_tensor_class = MXFP8TensorStorage
else:
mxfp8_tensor_class = MXFP8Tensor
tensor = mxfp8_tensor_class(
shape=tensor_shape,
dtype=self.dtype,
rowwise_data=rowwise_data,
rowwise_scale_inv=rowwise_scale_inv,
columnwise_data=columnwise_data,
columnwise_scale_inv=columnwise_scale_inv,
fp8_dtype=self.quantizer.dtype,
quantizer=self.quantizer,
with_gemm_swizzled_scales=self.quantizer.optimize_for_gemm,
)
result.append(tensor)
# Delayed scaling or current scaling (both use Float8TensorStorage)
elif recipe.delayed() or recipe.float8_current_scaling():
# Scale inverse - one per tensor
scale_inv = None
if self.scale_inv is not None:
scale_inv = self.scale_inv[i : i + 1]
elif self.columnwise_scale_inv is not None:
# Columnwise-only usage stores its per-tensor scales separately
scale_inv = self.columnwise_scale_inv[i : i + 1]
if self.quantizer.internal:
float8_tensor_class = Float8TensorStorage
else:
float8_tensor_class = Float8Tensor
tensor = float8_tensor_class(
shape=tensor_shape,
dtype=self.dtype,
data=rowwise_data,
fp8_scale_inv=scale_inv,
fp8_dtype=self.quantizer.dtype,
quantizer=self.quantizer,
data_transpose=columnwise_data,
)
result.append(tensor)
# Float8 block scaling
elif recipe.float8_block_scaling():
# Extract scale_inv data
rowwise_scale_inv = None
columnwise_scale_inv = None
if self.scale_inv is not None and self.scale_inv_offsets is not None:
scale_start = self.scale_inv_offsets[i]
if i < self.num_tensors - 1:
scale_end = self.scale_inv_offsets[i + 1]
else:
scale_end = self.scale_inv.numel()
# Get scale shape from quantizer
scale_shape = self.quantizer.get_scale_shape(tensor_shape, False)
rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape)
if (
self.columnwise_scale_inv is not None
and self.columnwise_scale_inv_offsets is not None
):
cscale_start = self.columnwise_scale_inv_offsets[i]
if i < self.num_tensors - 1:
cscale_end = self.columnwise_scale_inv_offsets[i + 1]
else:
cscale_end = self.columnwise_scale_inv.numel()
# Get columnwise scale shape from quantizer
cscale_shape = self.quantizer.get_scale_shape(tensor_shape, True)
columnwise_scale_inv = self.columnwise_scale_inv[cscale_start:cscale_end].view(
cscale_shape
)
# Determine 2D vs. 1D block scaling from the quantizer attributes
is_2D_scaled = self.quantizer.block_scaling_dim == 2
if self.quantizer.internal:
float8_blockwise_q_tensor_class = Float8BlockwiseQTensorStorage
else:
float8_blockwise_q_tensor_class = Float8BlockwiseQTensor
tensor = float8_blockwise_q_tensor_class(
shape=tensor_shape,
dtype=self.dtype,
rowwise_data=rowwise_data,
rowwise_scale_inv=rowwise_scale_inv,
columnwise_data=columnwise_data,
columnwise_scale_inv=columnwise_scale_inv,
fp8_dtype=self.quantizer.dtype,
quantizer=self.quantizer,
is_2D_scaled=is_2D_scaled,
)
result.append(tensor)
# NVFP4 format
elif recipe.nvfp4():
# Extract scale_inv data
rowwise_scale_inv = None
columnwise_scale_inv = None
amax_rowwise = None
amax_columnwise = None
if self.scale_inv is not None and self.scale_inv_offsets is not None:
scale_start = self.scale_inv_offsets[i]
if i < self.num_tensors - 1:
scale_end = self.scale_inv_offsets[i + 1]
else:
scale_end = self.scale_inv.numel()
# Get scale shape from quantizer
scale_shape = self.quantizer.get_scale_shape(tensor_shape, False)
rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape)
if (
self.columnwise_scale_inv is not None
and self.columnwise_scale_inv_offsets is not None
):
cscale_start = self.columnwise_scale_inv_offsets[i]
if i < self.num_tensors - 1:
cscale_end = self.columnwise_scale_inv_offsets[i + 1]
else:
cscale_end = self.columnwise_scale_inv.numel()
# Get columnwise scale shape from quantizer
cscale_shape = self.quantizer.get_scale_shape(tensor_shape, True)
columnwise_scale_inv = self.columnwise_scale_inv[cscale_start:cscale_end].view(
cscale_shape
)
# Extract amax - one per tensor
if self.amax is not None:
amax_rowwise = self.amax[i : i + 1]
if self.columnwise_amax is not None:
amax_columnwise = self.columnwise_amax[i : i + 1]
if self.quantizer.internal:
nvfp4_tensor_class = NVFP4TensorStorage
else:
nvfp4_tensor_class = NVFP4Tensor
tensor = nvfp4_tensor_class(
shape=tensor_shape,
dtype=self.dtype,
rowwise_data=rowwise_data,
rowwise_scale_inv=rowwise_scale_inv,
columnwise_data=columnwise_data,
columnwise_scale_inv=columnwise_scale_inv,
amax_rowwise=amax_rowwise,
amax_columnwise=amax_columnwise,
fp4_dtype=self.quantizer.dtype,
quantizer=self.quantizer,
with_gemm_swizzled_scales=self.quantizer.optimize_for_gemm,
)
result.append(tensor)
else:
raise ValueError(f"Unsupported quantization recipe: {recipe}")
return result
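# Usage sketch (illustrative): `views = grouped.split_into_quantized_tensors()`
# returns per-tensor views that alias the grouped buffers, so writing to
# views[i] mutates the GroupedTensor storage directly.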
@staticmethod
def create_and_quantize(
tensors: List[torch.Tensor],
quantizer: Optional[Quantizer],
*,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
noop_flag: Optional[torch.Tensor] = None,
) -> GroupedTensor:
"""
Quantize the given tensors into quantized tensors with underlying
storage allocated in a GroupedTensor, and return that GroupedTensor.
"""
grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=len(tensors),
shape=[t.shape for t in tensors],
quantizer=quantizer,
device=device,
dtype=dtype,
)
grouped_tensor.quantize(tensors, noop_flag=noop_flag)
return grouped_tensor
def quantize(
self,
tensors: List[torch.Tensor],
noop_flag: Optional[torch.Tensor] = None,
) -> List[Union[QuantizedTensorStorage, torch.Tensor]]:
"""
Quantize the given tensors into this GroupedTensor in place and
return the per-tensor quantized views.
"""
quantized_tensors = self.split_into_quantized_tensors()
for i in range(self.num_tensors):
self.quantizer.update_quantized(tensors[i], quantized_tensors[i], noop_flag=noop_flag)
return quantized_tensors
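# Usage sketch (illustrative; `weights` and `quantizer` are placeholders):
#
#   weights = [torch.randn(128, 256, device="cuda") for _ in range(4)]
#   grouped = GroupedTensor.create_and_quantize(weights, quantizer)
#   views = grouped.split_into_quantized_tensors()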
@@ -111,6 +111,24 @@ class MXFP8TensorStorage(QuantizedTensorStorage):
if t is not None:
t.data = _empty_tensor()
def copy_from_storage(self, src: QuantizedTensorStorage) -> None:
"""Copy data buffers from another MXFP8TensorStorage."""
if not isinstance(src, MXFP8TensorStorage):
raise TypeError("copy_from_storage expects MXFP8TensorStorage")
if self._fp8_dtype != src._fp8_dtype:
raise RuntimeError("FP8 dtype mismatch in copy_from_storage")
if self._with_gemm_swizzled_scales != src._with_gemm_swizzled_scales:
raise RuntimeError("Scale layout mismatch in copy_from_storage")
def _copy_optional(dst: Optional[torch.Tensor], src_tensor: Optional[torch.Tensor]):
if dst is not None and src_tensor is not None:
dst.copy_(src_tensor)
_copy_optional(self._rowwise_data, src._rowwise_data)
_copy_optional(self._columnwise_data, src._columnwise_data)
_copy_optional(self._rowwise_scale_inv, src._rowwise_scale_inv)
_copy_optional(self._columnwise_scale_inv, src._columnwise_scale_inv)
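# Illustrative note: `dst.copy_from_storage(src)` performs an in-place `copy_`
# on every buffer allocated on both sides; buffers present on only one side
# are skipped.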
def get_metadata(self) -> Dict[str, Any]:
"""Get this tensor's metadata."""
return {
...
@@ -136,6 +136,26 @@ class NVFP4TensorStorage(QuantizedTensorStorage):
if t is not None:
t.data = _empty_tensor()
def copy_from_storage(self, src: QuantizedTensorStorage) -> None:
"""Copy data buffers from another NVFP4TensorStorage."""
if not isinstance(src, NVFP4TensorStorage):
raise TypeError("copy_from_storage expects NVFP4TensorStorage")
if self._fp4_dtype != src._fp4_dtype:
raise RuntimeError("FP4 dtype mismatch in copy_from_storage")
if self._with_gemm_swizzled_scales != src._with_gemm_swizzled_scales:
raise RuntimeError("Scale layout mismatch in copy_from_storage")
def _copy_optional(dst: Optional[torch.Tensor], src_tensor: Optional[torch.Tensor]):
if dst is not None and src_tensor is not None:
dst.copy_(src_tensor)
_copy_optional(self._rowwise_data, src._rowwise_data)
_copy_optional(self._columnwise_data, src._columnwise_data)
_copy_optional(self._rowwise_scale_inv, src._rowwise_scale_inv)
_copy_optional(self._columnwise_scale_inv, src._columnwise_scale_inv)
_copy_optional(self._amax_rowwise, src._amax_rowwise)
_copy_optional(self._amax_columnwise, src._amax_columnwise)
def get_metadata(self) -> Dict[str, Any]:
"""Get this tensor's metadata."""
return {
...