Unverified commit 99df8810 authored by Tim Moon, committed by GitHub

Add logic for block-scaled tensors with GEMM swizzled scales (#2486)

* Add general C API for setting tensor params
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Implement general accessors for NVTETensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Refactor tex swizzling to skip if scales are already swizzled
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add checks for non-swizzled scales in MXFP8 and NVFP4 kernels
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Support pre-swizzled scales in MXFP8Tensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add tex function to swizzle MXFP8 scales
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix bug in inplace swizzle function
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Tweak comments to use "compact/swizzled format"
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* MXFP8 quantize kernel with pre-swizzled scales
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Expose pre-swizzled scales in modules
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix bug in multi-swizzle
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Support MXFP8 gated activations with swizzled scales
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add PyTorch infrastructure for pre-swizzled NVFP4 tensors
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Deprecate DSv3-specific quantization logic in C API
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove support for DSv3 compact data from quantizer
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove DSv3 compact data format from core lib
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix bug in FP8 all-gather
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update JAX to use new swizzled scale API
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Review suggestion from @greptile-apps
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestions from @greptile-apps
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update C++ swizzle test with swizzled scales API
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Return default tensor params when querying params for invalid NVTETensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug DSv3 FP8 test failures
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug Userbuffers test failures
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Make sure gated activations populate FP8 transpose if needed
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Review suggestions from @greptile-apps
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Disable pre-swizzling with debug quantizer
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestion from @greptile-apps
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix merge conflicts and review suggestions

Update copyright years. Tweak comments. Fix various complaints from @greptile-apps.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use explicitly sized types in config accessors

Miscellaneous review suggestions from @ptrendx.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Make util header for function that computes swizzled scale index
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Apply suggestions from @greptile-apps
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Update expected error message in FP8 block-scaling test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestion from @yaox12
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
parent a652730f
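Before the diff itself, a rough usage sketch of the feature this PR adds: a quantizer can now be asked to produce tensors whose scaling factors are already in the swizzled layout the GEMM expects. This is an illustration only; the import path, the MXFP8Quantizer constructor arguments, and the make_empty signature are assumptions inferred from the diff below, not verbatim from the changed files.

    import torch
    import transformer_engine_torch as tex
    from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer  # assumed path

    # Request GEMM-swizzled scaling factors (new quantizer option in this PR)
    quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
    quantizer.optimize_for_gemm = True

    # Buffers allocated by the quantizer carry the flag through to the tensor
    out = quantizer.make_empty((256, 1024), dtype=torch.bfloat16, device="cuda")
    assert out._with_gemm_swizzled_scales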
@@ -199,10 +199,21 @@ class Quantizer(abc.ABC):
     """
     internal: bool
+    """Whether to solely optimize for matrix multiplication
+
+    The resulting quantized tensors are not guaranteed to support any
+    operation other than matrix multiplication. Use with care since
+    this is likely to break communication, checkpointing, and many
+    other features.
+    """
+    optimize_for_gemm: bool

     def __init__(self, *, rowwise: bool, columnwise: bool) -> None:
         self.rowwise_usage = rowwise
         self.columnwise_usage = columnwise
         self.internal = False
+        self.optimize_for_gemm = False

     def __repr__(self):
         return (
@@ -314,7 +325,11 @@ class Quantizer(abc.ABC):
         return False

     def is_quantizable(self, inp: torch.Tensor) -> bool:  # pylint: disable=unused-argument
-        """Returns whether or not given tensor can be quantized"""
+        """Whether tensor supports quantized all-gather
+
+        Consider a less misleading function name.
+        """
         return True

     def get_usages(self) -> Dict[str, bool]:
......
@@ -4,14 +4,14 @@
 """Tensor class with FP8 data quantized with NxN tiles"""
 from __future__ import annotations

-from typing import Optional, Tuple, Iterable, Union
+from collections.abc import Iterable
 import math
+from typing import Any, Optional, Tuple, Union

 import torch
 import transformer_engine_torch as tex

 from transformer_engine_torch import DType as TE_DType
-from transformer_engine_torch import Float8BlockScaleTensorFormat
 from transformer_engine.common.recipe import Float8BlockScaling, Recipe
 from .storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
 from ..quantized_tensor import QuantizedTensor, Quantizer
@@ -35,8 +35,6 @@ class Float8BlockQuantizer(Quantizer):
     amax_epsilon: float
     force_pow_2_scales: bool
     block_scaling_dim: int
-    # Whether to produce tensors that will be used in all-gather
-    all_gather_usage: bool

     def __init__(
         self,
@@ -47,7 +45,6 @@ class Float8BlockQuantizer(Quantizer):
         amax_epsilon: float = 0.0,
         force_pow_2_scales: bool = True,
         block_scaling_dim: int = 2,
-        all_gather_usage: bool = False,
     ) -> None:
         super().__init__(rowwise=rowwise, columnwise=columnwise)
         self.dtype = fp8_dtype
@@ -55,7 +52,6 @@ class Float8BlockQuantizer(Quantizer):
         self.force_pow_2_scales = force_pow_2_scales
         self.amax_epsilon = amax_epsilon
         self.block_scaling_dim = block_scaling_dim
-        self.all_gather_usage = all_gather_usage

     def copy(self) -> Float8BlockQuantizer:
         """Create shallow copy"""
@@ -65,11 +61,11 @@ class Float8BlockQuantizer(Quantizer):
             rowwise=self.rowwise_usage,
             columnwise=self.columnwise_usage,
             block_scaling_dim=self.block_scaling_dim,
-            all_gather_usage=self.all_gather_usage,
             amax_epsilon=self.amax_epsilon,
             force_pow_2_scales=self.force_pow_2_scales,
         )
         quantizer.internal = self.internal
+        quantizer.optimize_for_gemm = self.optimize_for_gemm
         return quantizer
@@ -123,103 +119,86 @@ class Float8BlockQuantizer(Quantizer):
         return tex.quantize(tensor, self)

     def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, int]:
-        """Calculate the shape of the scaling tensor for blockwise quantization.
+        """Scaling tensor shape.

-        This method determines the shape of the scaling tensor needed for blockwise quantization,
-        taking into account the input tensor shape and whether columnwise scaling is used.
-        The scales are padded to multiples of 4 on the inner dimension for compatibility with GEMM.
+        This method determines the shape of the scaling tensor based
+        on the quantizer configuration. The scales are padded to
+        multiples of 4 for compatibility with GEMM.

         Parameters
         ----------
         shape : Iterable[int]
-            Shape of the input tensor to be quantized
+            Logical tensor shape.
         columnwise : bool
-            Whether to use columnwise scaling (True) or rowwise scaling (False)
+            Whether the data is scaled column-wise (True) or row-wise (False).

         Returns
         -------
         Tuple[int, int]
-            Shape of the scaling tensor as (outer_dim, inner_dim)
-            For 2D tensors:
-            - If columnwise: (roundup(K/blocksize), round_to_multiple(roundup(M/blocksize), 4))
-            - If rowwise: (roundup(M/blocksize), round_to_multiple(roundup(K/blocksize), 4))
-            For 1D tensors:
-            - If columnwise: (roundup(M/blocksize), round_to_multiple(K, 4))
-            - If rowwise: (roundup(K/blocksize), round_to_multiple(M, 4))
+            Scaling tensor shape.

         """
-        M, K = 1, 1
-        for i in range(len(shape) - 1):
-            M *= shape[i]
-        if len(shape) > 0:
-            K = shape[-1]
+        # Flatten tensor to 2D
+        dim0 = math.prod(shape[:-1])
+        dim1 = shape[-1] if shape else 1

-        # 2D 128x128 quantization block scaling
-        # CuBLAS requries 128x128 scaling factor to be padded
-        # currently rowwise and columnwise format option doesn't apply to 2D scaling
+        # Check block dims
+        if self.block_scaling_dim not in (1, 2):
+            raise RuntimeError(
+                "Only 1D or 2D blocks are supported, "
+                f"but got block_scaling_dim={self.block_scaling_dim}"
+            )
+
+        # 128x128 block scaling
         if self.block_scaling_dim == 2:
+            scale_dim0 = (dim0 + self.block_len - 1) // self.block_len
+            scale_dim1 = (dim1 + self.block_len - 1) // self.block_len
             if columnwise:
-                outer = math.ceil(K / self.block_len)
-                inner = round_up_to_nearest_multiple(math.ceil(M / self.block_len), 4)
-                return (outer, inner)
-            # rowwise
-            outer = math.ceil(M / self.block_len)
-            inner = round_up_to_nearest_multiple(math.ceil(K / self.block_len), 4)
-            return (outer, inner)
-        # 1D 1x128 quantization block scaling
-        # CuBLAS requries 1x128 scaling factor to be padded and transposed
-        assert self.block_scaling_dim == 1, "Only 1D or 2D blocks supported"
+                return (scale_dim1, round_up_to_nearest_multiple(scale_dim0, 4))
+            return (scale_dim0, round_up_to_nearest_multiple(scale_dim1, 4))
+
+        # 1x128 block scaling
         if columnwise:
-            columnwise_compact = self.all_gather_usage
-            outer = math.ceil(M / self.block_len)
-            inner = round_up_to_nearest_multiple(K, 4) if not columnwise_compact else K
-            # GEMM READY case: scaling factor is [outer, inner], already transposed here for CuBLAS
-            # for COMPACT case, since we apply 1x128 scaling here without transposing columnwise data, scaling factor is also [outer, inner]
-            # so no need to swap inner outer here
-            return (outer, inner)
-        # rowwise
-        rowwise_compact = self.all_gather_usage
-        outer = math.ceil(K / self.block_len)
-        inner = round_up_to_nearest_multiple(M, 4) if not rowwise_compact else M
-        # GEMM READY case: scaling factor is [outer, inner], already transposed here for CuBLAS need
-        # for COMPACT case, since we apply 128x1 scaling, scaling block applies to inner dim, so we need to swap outer and inner here
-        return (outer, inner) if not rowwise_compact else (inner, outer)
+            return (
+                (dim0 + self.block_len - 1) // self.block_len,
+                round_up_to_nearest_multiple(dim1, 4),
+            )
+        return (
+            (dim1 + self.block_len - 1) // self.block_len,
+            round_up_to_nearest_multiple(dim0, 4),
+        )
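To make the padding rule above concrete, here is a small standalone mirror of the new logic. It is an illustration, not part of the diff; the 128-element block length is an assumption, since self.block_len is set outside this hunk.

    import math

    BLOCK_LEN = 128

    def round_up(x: int, m: int) -> int:
        """Round x up to the nearest multiple of m."""
        return ((x + m - 1) // m) * m

    def scale_shape(shape, columnwise, block_scaling_dim):
        dim0 = math.prod(shape[:-1])  # flattened leading dims
        dim1 = shape[-1]              # innermost dim
        if block_scaling_dim == 2:    # 128x128 blocks
            s0 = (dim0 + BLOCK_LEN - 1) // BLOCK_LEN
            s1 = (dim1 + BLOCK_LEN - 1) // BLOCK_LEN
            return (s1, round_up(s0, 4)) if columnwise else (s0, round_up(s1, 4))
        # 1x128 blocks
        if columnwise:
            return ((dim0 + BLOCK_LEN - 1) // BLOCK_LEN, round_up(dim1, 4))
        return ((dim1 + BLOCK_LEN - 1) // BLOCK_LEN, round_up(dim0, 4))

    # Example: a (2048, 768) tensor
    assert scale_shape((2048, 768), columnwise=False, block_scaling_dim=1) == (6, 2048)
    assert scale_shape((2048, 768), columnwise=True, block_scaling_dim=1) == (16, 768)
    assert scale_shape((2048, 768), columnwise=False, block_scaling_dim=2) == (16, 8)
    assert scale_shape((2048, 768), columnwise=True, block_scaling_dim=2) == (6, 16)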
     def get_columnwise_shape(self, shape: Iterable[int]) -> Tuple[int, ...]:
-        """Calculate the shape of a tensor after columnwise permutation.
+        """Column-wise data shape

-        This method rearranges the dimensions of a tensor to be columnwise,
-        moving the last dimension to the front and keeping the order of other dimensions.
+        GEMMs expect that the column-wise data is transposed relative
+        to the logical tensor shape.

         Parameters
         ----------
         shape : Iterable[int]
-            Original shape of the tensor
+            Logical tensor shape.

         Returns
         -------
         Tuple[int, ...]
-            New shape with dimensions rearranged for columnwise layout.
-            For a shape (d1, d2, ..., dn), returns (dn, d1, d2, ..., dn-1).
-            Returns empty tuple for empty input shape.
+            Column-wise data shape.

         """
-        if len(shape) == 0:
-            return tuple()
-        # currently columnwise format option only applies to 1D quantizer
-        # for 2D scaling, columnwise format should always be GEMM_READY_DATA_AND_SCALES
-        # since currently 2D scaling only applies to module weights
-        if self.block_scaling_dim == 1 and self.all_gather_usage:
-            return shape
-        colwise_shape = [shape[-1]]
-        for i in range(len(shape) - 1):
-            colwise_shape.append(shape[i])
+        colwise_shape = []
+        if shape:
+            colwise_shape.append(shape[-1])
+            colwise_shape.extend(shape[:-1])
         return tuple(colwise_shape)

     def is_quantizable(self, inp: torch.Tensor) -> bool:
         """Returns whether or not given inp can be quantized"""
-        if inp.ndim < 2:
+        shape = inp.size()
+        if len(shape) < 2:
             return False
-        if inp.shape[-1] % self.block_len != 0:
+        if shape[-1] % self.block_len != 0:
             return False
-        if math.prod(inp.shape[:-1]) % self.block_len != 0:
+        if math.prod(shape[:-1]) % self.block_len != 0:
             return False
         return True
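Again for illustration only (not part of the diff; 128-element blocks assumed), the two helpers above reduce to:

    import math

    BLOCK_LEN = 128

    def columnwise_shape(shape):
        # Last dim moves to the front; the GEMM consumes transposed column-wise data.
        return (shape[-1], *shape[:-1]) if shape else ()

    def quantizable(shape):
        # Both the inner dim and the product of the outer dims must be divisible
        # by the block length.
        return (
            len(shape) >= 2
            and shape[-1] % BLOCK_LEN == 0
            and math.prod(shape[:-1]) % BLOCK_LEN == 0
        )

    assert columnwise_shape((16, 128, 256)) == (256, 16, 128)
    assert quantizable((16, 128, 256))   # 16 * 128 and 256 are both multiples of 128
    assert not quantizable((100, 256))   # leading dim product is not a multiple of 128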
@@ -233,44 +212,36 @@ class Float8BlockQuantizer(Quantizer):
         pin_memory: bool = False,
     ) -> Float8BlockwiseQTensor:
         """Construct quantized tensor with uninitialized data"""
-        if device is None:
-            device = torch.device("cuda")
-        data_format = (
-            tex.Float8BlockScaleTensorFormat.COMPACT
-            if self.all_gather_usage
-            else tex.Float8BlockScaleTensorFormat.GEMM_READY
-        )
+        tensor_kwargs = {
+            "device": torch.device("cuda") if device is None else device,
+            "pin_memory": pin_memory,
+        }

-        # Allocate FP8 data
-        data = None
-        scale_inv = None
+        # Allocate buffers for row-scaled data
+        rowwise_data = None
+        rowwise_scale_inv = None
         if self.rowwise_usage:
-            data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory)
-            scale_shape = self.get_scale_shape(shape, columnwise=False)
-            scale_inv = torch.empty(
-                scale_shape,
+            rowwise_data = torch.empty(shape, dtype=torch.uint8, **tensor_kwargs)
+            rowwise_scale_inv = torch.empty(
+                self.get_scale_shape(shape, columnwise=False),
                 dtype=torch.float32,
-                device=device,
-                pin_memory=pin_memory,
+                **tensor_kwargs,
             )

-        # Allocate FP8 data transpose if needed
+        # Allocate buffers for column-scaled data
         columnwise_data = None
         columnwise_scale_inv = None
         if self.columnwise_usage:
             columnwise_data = torch.empty(
                 self.get_columnwise_shape(shape),
                 dtype=torch.uint8,
-                device=device,
-                pin_memory=pin_memory,
+                **tensor_kwargs,
             )
-            columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True)
             columnwise_scale_inv = torch.empty(
-                columnwise_scale_shape,
+                self.get_scale_shape(shape, columnwise=True),
                 dtype=torch.float32,
-                device=device,
-                pin_memory=pin_memory,
+                **tensor_kwargs,
             )

         # Construct FP8 tensor
@@ -278,13 +249,12 @@ class Float8BlockQuantizer(Quantizer):
             shape=shape,
             dtype=dtype,
             fp8_dtype=self.dtype,
-            rowwise_data=data,
-            rowwise_scale_inv=scale_inv,
+            rowwise_data=rowwise_data,
+            rowwise_scale_inv=rowwise_scale_inv,
             columnwise_data=columnwise_data,
             columnwise_scale_inv=columnwise_scale_inv,
             quantizer=self,
             is_2D_scaled=self.block_scaling_dim == 2,
-            data_format=data_format,
             requires_grad=requires_grad,
         )
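For a concrete picture of what make_empty now allocates, a sketch under the same assumptions as the examples above (1x128 scaling, both row-wise and column-wise usage; in the real code the buffers default to the CUDA device):

    import torch

    # Logical tensor shape (2048, 768), 1x128 block scaling, block_len = 128
    shape = (2048, 768)
    rowwise_data = torch.empty(shape, dtype=torch.uint8)                # FP8 bytes
    rowwise_scale_inv = torch.empty((6, 2048), dtype=torch.float32)     # get_scale_shape(shape, columnwise=False)
    columnwise_data = torch.empty((768, 2048), dtype=torch.uint8)       # get_columnwise_shape(shape)
    columnwise_scale_inv = torch.empty((16, 768), dtype=torch.float32)  # get_scale_shape(shape, columnwise=True)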
...@@ -334,7 +304,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor): ...@@ -334,7 +304,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
fp8_dtype: TE_DType, fp8_dtype: TE_DType,
quantizer: Quantizer, quantizer: Quantizer,
is_2D_scaled: bool, is_2D_scaled: bool,
data_format: tex.Float8BlockScaleTensorFormat = Float8BlockScaleTensorFormat.GEMM_READY,
**kwargs, **kwargs,
): ):
instance = super().__new__( instance = super().__new__(
...@@ -346,7 +315,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor): ...@@ -346,7 +315,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
fp8_dtype, fp8_dtype,
quantizer, quantizer,
is_2D_scaled, is_2D_scaled,
data_format,
*args, *args,
**kwargs, **kwargs,
) )
...@@ -357,8 +325,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor): ...@@ -357,8 +325,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
return ( return (
f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype}," f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype},"
f" is_2D_scaled={self._is_2D_scaled}," f" is_2D_scaled={self._is_2D_scaled},"
f" data={self.dequantize(dtype=self.dtype)})," f" data={self.dequantize(dtype=self.dtype)})"
f" data_format={self._data_format}"
) )
def quantize_( def quantize_(
...@@ -509,7 +476,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor): ...@@ -509,7 +476,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
dtype: torch.dtype, dtype: torch.dtype,
quantizer: Quantizer, quantizer: Quantizer,
is_2D_scaled: bool, is_2D_scaled: bool,
data_format: tex.Float8BlockScaleTensorFormat, data_format: Any = None, # pylint: disable=unused-argument
) -> Float8BlockwiseQTensor: ) -> Float8BlockwiseQTensor:
"""Build Float8BlockwiseQTensor, for use in __reduce__ """Build Float8BlockwiseQTensor, for use in __reduce__
...@@ -527,7 +494,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor): ...@@ -527,7 +494,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
dtype=dtype, dtype=dtype,
quantizer=quantizer, quantizer=quantizer,
is_2D_scaled=is_2D_scaled, is_2D_scaled=is_2D_scaled,
data_format=data_format,
) )
def __reduce_ex__(self, protocol: int) -> tuple: def __reduce_ex__(self, protocol: int) -> tuple:
...@@ -544,7 +510,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor): ...@@ -544,7 +510,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
self.dtype, self.dtype,
self._quantizer, self._quantizer,
self._is_2D_scaled, self._is_2D_scaled,
self._data_format, None, # data_format
), ),
) )
...@@ -570,7 +536,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor): ...@@ -570,7 +536,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
dst._fp8_dtype = src._fp8_dtype dst._fp8_dtype = src._fp8_dtype
dst._rowwise_scale_inv = src._rowwise_scale_inv dst._rowwise_scale_inv = src._rowwise_scale_inv
dst._columnwise_scale_inv = src._columnwise_scale_inv dst._columnwise_scale_inv = src._columnwise_scale_inv
dst._data_format = src._data_format
# Check that tensor dimensions match # Check that tensor dimensions match
if ( if (
...@@ -618,13 +583,6 @@ class _ViewFunc(torch.autograd.Function): ...@@ -618,13 +583,6 @@ class _ViewFunc(torch.autograd.Function):
) -> Float8BlockwiseQTensor: ) -> Float8BlockwiseQTensor:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
# Check for invalid configurations
if not tensor._is_gemm_ready_format():
raise NotImplementedError(
"View is only supported with GEMM_READY data format, "
f"but found data_format={tensor._data_format}"
)
# Return input tensor if shape is not provided # Return input tensor if shape is not provided
ctx.shape = tensor.shape ctx.shape = tensor.shape
if shape is None: if shape is None:
...@@ -693,14 +651,6 @@ class _ViewFunc(torch.autograd.Function): ...@@ -693,14 +651,6 @@ class _ViewFunc(torch.autograd.Function):
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
if isinstance(grad, Float8BlockwiseQTensor): if isinstance(grad, Float8BlockwiseQTensor):
# Check for invalid configurations
if not grad._is_gemm_ready_format():
raise NotImplementedError(
"View is only supported with GEMM_READY data format, "
f"but found data_format={grad._data_format}"
)
new_data = ( new_data = (
grad._rowwise_data.view(*ctx.shape) if grad._rowwise_data is not None else None grad._rowwise_data.view(*ctx.shape) if grad._rowwise_data is not None else None
) )
...@@ -740,13 +690,6 @@ class _ReshapeFunc(torch.autograd.Function): ...@@ -740,13 +690,6 @@ class _ReshapeFunc(torch.autograd.Function):
) -> Float8BlockwiseQTensor: ) -> Float8BlockwiseQTensor:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
# Check for invalid configurations
if not tensor._is_gemm_ready_format():
raise NotImplementedError(
"Reshape is only supported with GEMM_READY data format, "
f"but found data_format={tensor._data_format}"
)
# Return input tensor if shape is not provided # Return input tensor if shape is not provided
ctx.shape = tensor.shape ctx.shape = tensor.shape
if shape is None: if shape is None:
...@@ -814,14 +757,6 @@ class _ReshapeFunc(torch.autograd.Function): ...@@ -814,14 +757,6 @@ class _ReshapeFunc(torch.autograd.Function):
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
if isinstance(grad, Float8BlockwiseQTensor): if isinstance(grad, Float8BlockwiseQTensor):
# Check for invalid configurations
if not grad._is_gemm_ready_format():
raise NotImplementedError(
"Reshape is only supported with GEMM_READY data format, "
f"but found data_format={grad._data_format}"
)
new_rowwise_data = None new_rowwise_data = None
new_columnwise_data = None new_columnwise_data = None
if grad._rowwise_data is not None: if grad._rowwise_data is not None:
......
@@ -293,6 +293,7 @@ class Float8CurrentScalingQuantizer(Quantizer):
             amax=self.amax,
         )
         quantizer.internal = self.internal
+        quantizer.optimize_for_gemm = self.optimize_for_gemm
         return quantizer
......
@@ -54,6 +54,7 @@ class MXFP8Quantizer(Quantizer):
             columnwise=self.columnwise_usage,
         )
         quantizer.internal = self.internal
+        quantizer.optimize_for_gemm = self.optimize_for_gemm
         return quantizer
...@@ -156,6 +157,7 @@ class MXFP8Quantizer(Quantizer): ...@@ -156,6 +157,7 @@ class MXFP8Quantizer(Quantizer):
columnwise_scale_inv=columnwise_scale_inv, columnwise_scale_inv=columnwise_scale_inv,
quantizer=self, quantizer=self,
requires_grad=requires_grad, requires_grad=requires_grad,
with_gemm_swizzled_scales=self.optimize_for_gemm,
) )
def calibrate(self, tensor: torch.Tensor) -> None: def calibrate(self, tensor: torch.Tensor) -> None:
...@@ -179,6 +181,7 @@ class MXFP8Quantizer(Quantizer): ...@@ -179,6 +181,7 @@ class MXFP8Quantizer(Quantizer):
columnwise_scale_inv=None, columnwise_scale_inv=None,
fp8_dtype=fp8_dtype, fp8_dtype=fp8_dtype,
quantizer=self, quantizer=self,
with_gemm_swizzled_scales=False,
) )
def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor: def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor:
...@@ -188,6 +191,10 @@ class MXFP8Quantizer(Quantizer): ...@@ -188,6 +191,10 @@ class MXFP8Quantizer(Quantizer):
return self.create_tensor_from_data(data, scale_inv, fake_dtype=torch.float32) return self.create_tensor_from_data(data, scale_inv, fake_dtype=torch.float32)
def onnx_dequantize(self, tensor: Union[MXFP8TensorStorage, MXFP8Tensor]) -> torch.Tensor: def onnx_dequantize(self, tensor: Union[MXFP8TensorStorage, MXFP8Tensor]) -> torch.Tensor:
if tensor._with_gemm_swizzled_scales:
raise NotImplementedError(
"ONNX MXFP8 dequantization is only supported with scales in compact format."
)
return torch.ops.tex.mxfp8_dequantize(tensor._rowwise_data, tensor._rowwise_scale_inv) return torch.ops.tex.mxfp8_dequantize(tensor._rowwise_data, tensor._rowwise_scale_inv)
def _get_compatible_recipe(self) -> Union[type[Recipe], None]: def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
...@@ -229,9 +236,10 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor): ...@@ -229,9 +236,10 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
columnwise_scale_inv: Optional[torch.Tensor], columnwise_scale_inv: Optional[torch.Tensor],
fp8_dtype: TE_DType, fp8_dtype: TE_DType,
quantizer: Optional[Quantizer], quantizer: Optional[Quantizer],
with_gemm_swizzled_scales: bool,
**kwargs, **kwargs,
): ):
instance = super().__new__( return super().__new__(
cls, cls,
rowwise_data, rowwise_data,
rowwise_scale_inv, rowwise_scale_inv,
...@@ -239,10 +247,10 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor): ...@@ -239,10 +247,10 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
columnwise_scale_inv, columnwise_scale_inv,
fp8_dtype, fp8_dtype,
quantizer, quantizer,
with_gemm_swizzled_scales,
*args, *args,
**kwargs, **kwargs,
) )
return instance
def __repr__(self, *, tensor_contents=None): def __repr__(self, *, tensor_contents=None):
return f"MXFP8Tensor(fp8_dtype={self._fp8_dtype}, data={self.dequantize(dtype=self.dtype)})" return f"MXFP8Tensor(fp8_dtype={self._fp8_dtype}, data={self.dequantize(dtype=self.dtype)})"
@@ -334,39 +342,44 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
-        # View op
         if func == aten.view.default:
             tensor = args[0]
-            data = tensor._rowwise_data
-            out_data = data.__torch_dispatch__(
-                func,
-                types,
-                [data] + list(args[1:]),
-                kwargs,
-            )
-            out_shape = out_data.size()
+            shape = args[1]
+            if len(shape) < 2 or shape[-1] != tensor.size(-1):
+                raise ValueError(
+                    f"Attempted to make view with size={tuple(shape)} "
+                    f"from MXFP8 tensor with shape={tuple(tensor.size())}."
+                )
+            rowwise_data_view = None
+            columnwise_data_view = None
+            if tensor._rowwise_data is not None:
+                rowwise_data_view = tensor._rowwise_data.view(shape)
+            if tensor._columnwise_data is not None:
+                columnwise_data_view = tensor._columnwise_data.view(shape)
             return MXFP8Tensor(
-                shape=out_shape,
+                shape=shape,
                 dtype=tensor.dtype,
-                rowwise_data=out_data,
+                rowwise_data=rowwise_data_view,
                 rowwise_scale_inv=tensor._rowwise_scale_inv,
-                columnwise_data=tensor._columnwise_data,
+                columnwise_data=columnwise_data_view,
                 columnwise_scale_inv=tensor._columnwise_scale_inv,
                 quantizer=tensor._quantizer,
                 requires_grad=False,
                 fp8_dtype=tensor._fp8_dtype,
+                with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales,
             )

         if func == torch.ops.aten.copy_.default:
             dst, src = args[0], args[1]
             if isinstance(src, MXFP8Tensor) and isinstance(dst, MXFP8Tensor):
-                # Booleans to check if src has all the usages that dst needs to respect dst quantizer usages.
-                # If not, default to base class behavior.
-                rowwise_matches = src._rowwise_data is not None or dst._rowwise_data is None
-                columnwise_matches = (
-                    src._columnwise_data is not None or dst._columnwise_data is None
-                )
-                if rowwise_matches and columnwise_matches:
-                    # src and dst match, so we can directly copy data
+                if src._rowwise_data is None and dst._rowwise_data is not None:
+                    pass
+                elif src._columnwise_data is None and dst._columnwise_data is not None:
+                    pass
+                elif src._with_gemm_swizzled_scales != dst._with_gemm_swizzled_scales:
+                    pass
+                else:
                     if dst._rowwise_data is not None:
                         dst._rowwise_data.copy_(src._rowwise_data.detach(), *args[2:], **kwargs)
                         dst._rowwise_scale_inv.copy_(
@@ -381,26 +394,25 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
                        )
                    return dst

-        # FSDP2 related functions.
         if func == aten.split.Tensor:
-            # This is called if entire model is initialized on CUDA device and
-            # then splitted. Finally the shard needed by the process is used
-            # and other splitted shards are discarded.
+            # With FSDP2, this is called if entire model is
+            # initialized on CUDA device and then splitted. Finally
+            # the shard needed by the process is used and other
+            # splitted shards are discarded.
+            tensor = args[0]
+            split_size = args[1]
             if "dim" in kwargs:
                 dim_to_split = kwargs["dim"]
             else:
                 dim_to_split = args[2] if len(args) > 2 else 0
-            tensor = args[0]
-            split_size = args[1]
-            dim0_size = tensor.size(0)
-            dimlast_size = math.prod(tensor.shape[1:])
+
+            # Fall back to high-precision if split is non-trivial
             if (
-                dim0_size % split_size != 0
-                or dim_to_split != 0
+                dim_to_split != 0
+                or tensor.size(0) % split_size != 0
                 or split_size % MXFP8_BLOCK_SCALING_SIZE != 0
-                or dimlast_size % MXFP8_BLOCK_SCALING_SIZE != 0
+                or tensor._with_gemm_swizzled_scales
             ):
-                # Handle splitting by dequantizing and splitting the hp tensor
                 return super().__torch_dispatch__(func, types, args, kwargs)

             out_data = []
@@ -460,28 +472,26 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
                     quantizer=tensor._quantizer,
                     requires_grad=False,
                     fp8_dtype=tensor._fp8_dtype,
+                    with_gemm_swizzled_scales=False,
                 )
                 for splitted_tensor_data in zip(*out_data)
             ]
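The dispatch guards above are easier to read side by side; an illustrative mirror (not part of the diff; the 32-element MXFP8 block size is an assumption):

    MXFP8_BLOCK_SCALING_SIZE = 32  # assumed value of the constant used above

    def view_is_allowed(old_shape, new_shape):
        # The new view path keeps the innermost dimension, since MXFP8 scales
        # are tied to 32-element spans along that dimension.
        return len(new_shape) >= 2 and new_shape[-1] == old_shape[-1]

    def split_stays_quantized(shape, split_size, dim, gemm_swizzled):
        # Otherwise the dispatch falls back to dequantizing and splitting
        # in high precision.
        return (
            dim == 0
            and shape[0] % split_size == 0
            and split_size % MXFP8_BLOCK_SCALING_SIZE == 0
            and not gemm_swizzled
        )

    assert view_is_allowed((4, 32, 128), (128, 128))
    assert not view_is_allowed((4, 32, 128), (4, 4096))
    assert split_stays_quantized((256, 128), split_size=64, dim=0, gemm_swizzled=False)
    assert not split_stays_quantized((256, 128), split_size=64, dim=0, gemm_swizzled=True)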
if func == torch.ops.aten.as_strided.default: if func == torch.ops.aten.as_strided.default:
# Applied on unsharded param in FSDP2. In our case, this should be a no-op # Applied on unsharded param in FSDP2. In our case, this should be a no-op
# This is needed for the case where some MXFP8 shards need padding i.e dimension 0 # This is needed for the case where some MXFP8 shards need padding i.e dimension 0
# of the unsharded param is not a multiple of the world size. If that is the case, # of the unsharded param is not a multiple of the world size. If that is the case,
# we down the dequantization route and weights are allgathered in high precision. # we down the dequantization route and weights are allgathered in high precision.
# If weight doesnt need padding, this is just a no-op. # If weight doesnt need padding, this is just a no-op.
tensor = args[0]
shape = args[1] shape = args[1]
strides = args[2] strides = args[2]
tensor = args[0]
if ( if (
len(shape) != 2 len(shape) == len(strides) == 2
or len(strides) != 2 and tuple(strides) == (shape[-1], 1)
or strides[1] != 1 and tuple(shape) == tuple(tensor.size())
or shape[0] != tensor.shape[0]
or shape[1] != tensor.shape[1]
): ):
return super().__torch_dispatch__(func, types, args, kwargs) return MXFP8Tensor.make_like(tensor)
return MXFP8Tensor.make_like(tensor)
if func == aten.slice.Tensor: if func == aten.slice.Tensor:
# FSDP2 needed function. # FSDP2 needed function.
...@@ -489,19 +499,12 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor): ...@@ -489,19 +499,12 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
# of the unsharded param is not a multiple of the world size. If that is the case, # of the unsharded param is not a multiple of the world size. If that is the case,
# we down the dequantization route and weights are allgathered in high precision instead. # we down the dequantization route and weights are allgathered in high precision instead.
# If sharded weight doesnt have padding, this is just a no-op. # If sharded weight doesnt have padding, this is just a no-op.
tensor = args[0]
dim = args[1] dim = args[1]
start = args[2] start = args[2]
length = args[3] length = args[3]
tensor = args[0] if start == 0 and length == tensor.size(dim):
if ( return MXFP8Tensor.make_like(tensor)
dim != 0
or length != tensor.shape[0]
or start != 0
or length % MXFP8_BLOCK_SCALING_SIZE != 0
or start % MXFP8_BLOCK_SCALING_SIZE != 0
):
return super().__torch_dispatch__(func, types, args, kwargs)
return MXFP8Tensor.make_like(tensor)
if func == aten.new_zeros.default: if func == aten.new_zeros.default:
rowwise_data = None rowwise_data = None
...@@ -558,7 +561,9 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor): ...@@ -558,7 +561,9 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
quantizer=tensor._quantizer, quantizer=tensor._quantizer,
requires_grad=False, requires_grad=False,
fp8_dtype=tensor._fp8_dtype, fp8_dtype=tensor._fp8_dtype,
with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales,
) )
# Default case # Default case
return super().__torch_dispatch__(func, types, args, kwargs) return super().__torch_dispatch__(func, types, args, kwargs)
...@@ -584,19 +589,24 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor): ...@@ -584,19 +589,24 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
# pylint: disable=unused-argument # pylint: disable=unused-argument
from transformer_engine.pytorch.distributed import _get_module_fsdp_state from transformer_engine.pytorch.distributed import _get_module_fsdp_state
# Get FSDP state
fsdp_state = _get_module_fsdp_state(module) fsdp_state = _get_module_fsdp_state(module)
reshard_after_forward = fsdp_state._fsdp_param_group._reshard_after_forward reshard_after_forward = fsdp_state._fsdp_param_group._reshard_after_forward
# Remove padding from scale inverses before allgather # Remove padding from scale inverses before allgather
# Rowwise scale_inv should be divisible by [128,4], columnwise by [4, 128] # Rowwise scale_inv should be divisible by [128,4], columnwise by [4, 128]
rowwise_scale_inv = self._rowwise_scale_inv rowwise_scale_inv = self._rowwise_scale_inv
columnwise_scale_inv = self._columnwise_scale_inv columnwise_scale_inv = self._columnwise_scale_inv
shape = self.shape shape = self.shape
if self._with_gemm_swizzled_scales:
raise NotImplementedError(
"FSDP2 is only supported for MXFP8Tensors with compact scales"
)
if rowwise_scale_inv is not None: if rowwise_scale_inv is not None:
# Remove padding from rowwise scale_inv # Remove padding from rowwise scale_inv
flattened_in_shape0 = math.prod(shape[:-1]) flattened_in_shape0 = math.prod(shape[:-1])
if rowwise_scale_inv.size(0) != flattened_in_shape0: if rowwise_scale_inv.size(0) != flattened_in_shape0:
rowwise_scale_inv = rowwise_scale_inv[:flattened_in_shape0] rowwise_scale_inv = rowwise_scale_inv[:flattened_in_shape0]
if columnwise_scale_inv is not None: if columnwise_scale_inv is not None:
# Remove padding from columnwise scale_inv # Remove padding from columnwise scale_inv
flattened_in_shape0 = math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE flattened_in_shape0 = math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE
...@@ -681,7 +691,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor): ...@@ -681,7 +691,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
out._columnwise_data = columnwise_data out._columnwise_data = columnwise_data
out._columnwise_scale_inv = columnwise_scale_inv out._columnwise_scale_inv = columnwise_scale_inv
else: else:
# We ll be here when post all gather is called the first time. # We'll be here when post all gather is called the first time.
# MXFP8Tensor constructor makes a copy of the quantizer to # MXFP8Tensor constructor makes a copy of the quantizer to
# save as its own quantizer. For the consequent iterations, # save as its own quantizer. For the consequent iterations,
# the same quantizer is used. Copy is needed in the first iteration, # the same quantizer is used. Copy is needed in the first iteration,
...@@ -696,6 +706,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor): ...@@ -696,6 +706,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
dtype=param_dtype, dtype=param_dtype,
shape=rowwise_data.shape if rowwise_data is not None else columnwise_data.shape, shape=rowwise_data.shape if rowwise_data is not None else columnwise_data.shape,
quantizer=self._quantizer, quantizer=self._quantizer,
with_gemm_swizzled_scales=False,
) )
out._quantizer.set_usage(rowwise=rowwise_usage, columnwise=columnwise_usage) out._quantizer.set_usage(rowwise=rowwise_usage, columnwise=columnwise_usage)
return out, all_gather_outputs return out, all_gather_outputs
...@@ -711,6 +722,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor): ...@@ -711,6 +722,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
dtype: torch.dtype, dtype: torch.dtype,
shape: torch.shape, shape: torch.shape,
quantizer: Optional[Quantizer] = None, quantizer: Optional[Quantizer] = None,
with_gemm_swizzled_scales: bool = False,
) -> MXFP8Tensor: ) -> MXFP8Tensor:
"""Build MXFP8Tensor, for use in __reduce__ """Build MXFP8Tensor, for use in __reduce__
...@@ -727,6 +739,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor): ...@@ -727,6 +739,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
dtype=dtype, dtype=dtype,
shape=shape, shape=shape,
quantizer=quantizer, quantizer=quantizer,
with_gemm_swizzled_scales=with_gemm_swizzled_scales,
) )
def __reduce_ex__(self, protocol: int) -> tuple: def __reduce_ex__(self, protocol: int) -> tuple:
...@@ -742,6 +755,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor): ...@@ -742,6 +755,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
self.dtype, self.dtype,
self.shape, self.shape,
self._quantizer, self._quantizer,
self._with_gemm_swizzled_scales,
), ),
) )
...@@ -763,7 +777,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor): ...@@ -763,7 +777,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
if not devices_match(new_device, tensor.device): if not devices_match(new_device, tensor.device):
tensor = tensor.to(device=new_device) tensor = tensor.to(device=new_device)
# Just copy FP8 data if other tensor is MXFP8Tensor # Just copy data if other tensor is MXFP8Tensor
if isinstance(tensor, MXFP8Tensor): if isinstance(tensor, MXFP8Tensor):
if ( # pylint: disable=too-many-boolean-expressions if ( # pylint: disable=too-many-boolean-expressions
self.size() != tensor.size() self.size() != tensor.size()
...@@ -791,6 +805,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor): ...@@ -791,6 +805,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
self._fp8_dtype = tensor._fp8_dtype self._fp8_dtype = tensor._fp8_dtype
self._rowwise_scale_inv = tensor._rowwise_scale_inv self._rowwise_scale_inv = tensor._rowwise_scale_inv
self._columnwise_scale_inv = tensor._columnwise_scale_inv self._columnwise_scale_inv = tensor._columnwise_scale_inv
self._with_gemm_swizzled_scales = tensor._with_gemm_swizzled_scales
return return
# Quantize to FP8 # Quantize to FP8
...@@ -862,6 +877,7 @@ class _ViewFunc(torch.autograd.Function): ...@@ -862,6 +877,7 @@ class _ViewFunc(torch.autograd.Function):
columnwise_scale_inv=tensor._columnwise_scale_inv, columnwise_scale_inv=tensor._columnwise_scale_inv,
fp8_dtype=tensor._fp8_dtype, fp8_dtype=tensor._fp8_dtype,
quantizer=tensor._quantizer, quantizer=tensor._quantizer,
with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales,
) )
@staticmethod @staticmethod
...@@ -888,6 +904,7 @@ class _ViewFunc(torch.autograd.Function): ...@@ -888,6 +904,7 @@ class _ViewFunc(torch.autograd.Function):
columnwise_scale_inv=grad._columnwise_scale_inv, columnwise_scale_inv=grad._columnwise_scale_inv,
fp8_dtype=grad._fp8_dtype, fp8_dtype=grad._fp8_dtype,
quantizer=grad._quantizer, quantizer=grad._quantizer,
with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales,
) )
return dgrad, None return dgrad, None
return grad.view(ctx.shape), None return grad.view(ctx.shape), None
...@@ -948,6 +965,7 @@ class _ReshapeFunc(torch.autograd.Function): ...@@ -948,6 +965,7 @@ class _ReshapeFunc(torch.autograd.Function):
columnwise_scale_inv=tensor._columnwise_scale_inv, columnwise_scale_inv=tensor._columnwise_scale_inv,
fp8_dtype=tensor._fp8_dtype, fp8_dtype=tensor._fp8_dtype,
quantizer=tensor._quantizer, quantizer=tensor._quantizer,
with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales,
) )
@staticmethod @staticmethod
...@@ -973,6 +991,7 @@ class _ReshapeFunc(torch.autograd.Function): ...@@ -973,6 +991,7 @@ class _ReshapeFunc(torch.autograd.Function):
columnwise_scale_inv=grad._columnwise_scale_inv, columnwise_scale_inv=grad._columnwise_scale_inv,
fp8_dtype=grad._fp8_dtype, fp8_dtype=grad._fp8_dtype,
quantizer=grad._quantizer, quantizer=grad._quantizer,
with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales,
) )
return dgrad, None return dgrad, None
return grad.view(ctx.shape), None return grad.view(ctx.shape), None
...@@ -193,6 +193,7 @@ class NVFP4Quantizer(Quantizer): ...@@ -193,6 +193,7 @@ class NVFP4Quantizer(Quantizer):
stochastic_rounding=self.stochastic_rounding, stochastic_rounding=self.stochastic_rounding,
) )
quantizer.internal = self.internal quantizer.internal = self.internal
quantizer.optimize_for_gemm = self.optimize_for_gemm
quantizer.rht_matrix = self.rht_matrix quantizer.rht_matrix = self.rht_matrix
quantizer.rht_matrix_random_sign_mask_t = self.rht_matrix_random_sign_mask_t quantizer.rht_matrix_random_sign_mask_t = self.rht_matrix_random_sign_mask_t
...@@ -359,6 +360,7 @@ class NVFP4Quantizer(Quantizer): ...@@ -359,6 +360,7 @@ class NVFP4Quantizer(Quantizer):
fp4_dtype=self.dtype, fp4_dtype=self.dtype,
quantizer=self, quantizer=self,
requires_grad=requires_grad, requires_grad=requires_grad,
with_gemm_swizzled_scales=False,
) )
def calibrate(self, tensor: torch.Tensor) -> None: def calibrate(self, tensor: torch.Tensor) -> None:
...@@ -418,6 +420,7 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor): ...@@ -418,6 +420,7 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
amax_columnwise: Optional[torch.Tensor], amax_columnwise: Optional[torch.Tensor],
fp4_dtype: TE_DType, fp4_dtype: TE_DType,
quantizer: Quantizer, quantizer: Quantizer,
with_gemm_swizzled_scales: bool,
**kwargs, **kwargs,
): ):
instance = super().__new__( instance = super().__new__(
...@@ -430,6 +433,7 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor): ...@@ -430,6 +433,7 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
amax_columnwise, amax_columnwise,
fp4_dtype, fp4_dtype,
quantizer, quantizer,
with_gemm_swizzled_scales,
*args, *args,
**kwargs, **kwargs,
) )
...@@ -592,6 +596,7 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor): ...@@ -592,6 +596,7 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
amax_columnwise=amax_columnwise, amax_columnwise=amax_columnwise,
quantizer=tensor._quantizer, quantizer=tensor._quantizer,
requires_grad=tensor.requires_grad, requires_grad=tensor.requires_grad,
with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales,
) )
# Default case # Default case
...@@ -610,6 +615,7 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor): ...@@ -610,6 +615,7 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
fp4_dtype: TE_DType, fp4_dtype: TE_DType,
dtype: torch.dtype, dtype: torch.dtype,
quantizer: Quantizer, quantizer: Quantizer,
with_gemm_swizzled_scales: bool = False,
) -> NVFP4Tensor: ) -> NVFP4Tensor:
"""Build NVFP4Tensor, for use in __reduce__ """Build NVFP4Tensor, for use in __reduce__
...@@ -629,6 +635,7 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor): ...@@ -629,6 +635,7 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
amax_columnwise=amax_columnwise, amax_columnwise=amax_columnwise,
quantizer=quantizer, quantizer=quantizer,
requires_grad=False, requires_grad=False,
with_gemm_swizzled_scales=with_gemm_swizzled_scales,
) )
def __reduce_ex__(self, protocol: int) -> tuple: def __reduce_ex__(self, protocol: int) -> tuple:
...@@ -646,6 +653,7 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor): ...@@ -646,6 +653,7 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
self._fp4_dtype, self._fp4_dtype,
self.dtype, self.dtype,
self._quantizer, self._quantizer,
self._with_gemm_swizzled_scales,
), ),
) )
...@@ -696,6 +704,7 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor): ...@@ -696,6 +704,7 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
self._columnwise_scale_inv = tensor._columnwise_scale_inv self._columnwise_scale_inv = tensor._columnwise_scale_inv
self._amax_rowwise = tensor._amax_rowwise self._amax_rowwise = tensor._amax_rowwise
self._amax_columnwise = tensor._amax_columnwise self._amax_columnwise = tensor._amax_columnwise
self._with_gemm_swizzled_scales = tensor._with_gemm_swizzled_scales
return return
# Quantize to FP8 # Quantize to FP8
...@@ -782,6 +791,7 @@ class _ViewFunc(torch.autograd.Function): ...@@ -782,6 +791,7 @@ class _ViewFunc(torch.autograd.Function):
quantizer=tensor._quantizer, quantizer=tensor._quantizer,
fp4_dtype=tensor._fp4_dtype, fp4_dtype=tensor._fp4_dtype,
requires_grad=tensor.requires_grad, requires_grad=tensor.requires_grad,
with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales,
) )
@staticmethod @staticmethod
...@@ -823,6 +833,7 @@ class _ViewFunc(torch.autograd.Function): ...@@ -823,6 +833,7 @@ class _ViewFunc(torch.autograd.Function):
quantizer=grad._quantizer, quantizer=grad._quantizer,
fp4_dtype=grad._fp4_dtype, fp4_dtype=grad._fp4_dtype,
requires_grad=grad.requires_grad, requires_grad=grad.requires_grad,
with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales,
) )
return dgrad, None return dgrad, None
return grad.view(ctx.shape), None return grad.view(ctx.shape), None
...@@ -902,6 +913,7 @@ class _ReshapeFunc(torch.autograd.Function): ...@@ -902,6 +913,7 @@ class _ReshapeFunc(torch.autograd.Function):
quantizer=tensor._quantizer, quantizer=tensor._quantizer,
fp4_dtype=tensor._fp4_dtype, fp4_dtype=tensor._fp4_dtype,
requires_grad=tensor.requires_grad, requires_grad=tensor.requires_grad,
with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales,
) )
@staticmethod @staticmethod
...@@ -943,6 +955,7 @@ class _ReshapeFunc(torch.autograd.Function): ...@@ -943,6 +955,7 @@ class _ReshapeFunc(torch.autograd.Function):
quantizer=grad._quantizer, quantizer=grad._quantizer,
fp4_dtype=grad._fp4_dtype, fp4_dtype=grad._fp4_dtype,
requires_grad=grad.requires_grad, requires_grad=grad.requires_grad,
with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales,
) )
return dgrad, None return dgrad, None
return grad.view(ctx.shape), None return grad.view(ctx.shape), None
@@ -11,7 +11,6 @@ import torch
 import transformer_engine_torch as tex
 from transformer_engine_torch import DType as TE_DType
-from transformer_engine_torch import Float8BlockScaleTensorFormat

 from ...quantized_tensor import QuantizedTensorStorage, Quantizer
...@@ -36,7 +35,6 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage): ...@@ -36,7 +35,6 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
_rowwise_scale_inv: Optional[torch.Tensor] _rowwise_scale_inv: Optional[torch.Tensor]
_columnwise_scale_inv: Optional[torch.Tensor] _columnwise_scale_inv: Optional[torch.Tensor]
_is_2D_scaled: bool _is_2D_scaled: bool
_data_format: Float8BlockScaleTensorFormat
def __new__( def __new__(
cls, cls,
...@@ -47,7 +45,6 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage): ...@@ -47,7 +45,6 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
fp8_dtype: TE_DType, fp8_dtype: TE_DType,
quantizer: Quantizer, quantizer: Quantizer,
is_2D_scaled: bool, is_2D_scaled: bool,
data_format: Float8BlockScaleTensorFormat,
*args, *args,
**kwargs, **kwargs,
): ):
...@@ -62,7 +59,6 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage): ...@@ -62,7 +59,6 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
instance._rowwise_scale_inv = rowwise_scale_inv instance._rowwise_scale_inv = rowwise_scale_inv
instance._columnwise_scale_inv = columnwise_scale_inv instance._columnwise_scale_inv = columnwise_scale_inv
instance._is_2D_scaled = is_2D_scaled instance._is_2D_scaled = is_2D_scaled
instance._data_format = data_format
return instance return instance
...@@ -87,13 +83,8 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage): ...@@ -87,13 +83,8 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
"fp8_dtype": self._fp8_dtype, "fp8_dtype": self._fp8_dtype,
"quantizer": self._quantizer, "quantizer": self._quantizer,
"is_2D_scaled": self._is_2D_scaled, "is_2D_scaled": self._is_2D_scaled,
"data_format": self._data_format,
} }
def _is_gemm_ready_format(self) -> bool:
"""Whether data is in GEMM_READY format"""
return self._data_format == Float8BlockScaleTensorFormat.GEMM_READY
def prepare_for_saving( def prepare_for_saving(
self, self,
) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorStorage]: ) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorStorage]:
...@@ -153,36 +144,18 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage): ...@@ -153,36 +144,18 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
for i in range(len(q.shape) - 1): for i in range(len(q.shape) - 1):
q_M *= q.shape[i] q_M *= q.shape[i]
inner_q_dimension_tiled = True inner_q_dimension_tiled = True
if self._is_gemm_ready_format(): scales_tiled_dim, scales_untiled_dim = scale_inv.shape
scales_tiled_dim, scales_untiled_dim = scale_inv.shape
inner_scale_dimension_tiled = False
scales_are_compact = False
else:
scales_untiled_dim, scales_tiled_dim = scale_inv.shape
inner_scale_dimension_tiled = True
scales_are_compact = True
else: else:
assert self._columnwise_data is not None, "No data to dequantize" assert self._columnwise_data is not None, "No data to dequantize"
q = self._columnwise_data q = self._columnwise_data
scale_inv = self._columnwise_scale_inv scale_inv = self._columnwise_scale_inv
scales_tiled_dim, scales_untiled_dim = scale_inv.shape scales_tiled_dim, scales_untiled_dim = scale_inv.shape
inner_scale_dimension_tiled = False inner_q_dimension_tiled = True
if self._is_gemm_ready_format(): transpose_output = True
inner_q_dimension_tiled = True if len(q.shape) >= 1:
transpose_output = True q_M = q.shape[0]
if len(q.shape) >= 1: for i in range(1, len(q.shape)):
q_M = q.shape[0] q_K *= q.shape[i]
for i in range(1, len(q.shape)):
q_K *= q.shape[i]
scales_are_compact = False
else:
inner_q_dimension_tiled = False
transpose_output = False
if len(q.shape) >= 1:
q_K = q.shape[-1]
for i in range(len(q.shape) - 1):
q_M *= q.shape[i]
scales_are_compact = True
orig_shape = q.shape orig_shape = q.shape
q = q.reshape(q_M, q_K) q = q.reshape(q_M, q_K)
...@@ -202,15 +175,10 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage): ...@@ -202,15 +175,10 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
).contiguous() ).contiguous()
padded_M, padded_K = q.shape padded_M, padded_K = q.shape
q_tiled = q.reshape(scales_tiled_dim, block_len, q_K) q_tiled = q.reshape(scales_tiled_dim, block_len, q_K)
if not scales_are_compact and scales_untiled_dim > q_M: if scales_untiled_dim > q_M:
# untiled scale dimension is 4 element aligned. # untiled scale dimension is 4 element aligned.
scale_inv = scale_inv[:, :q_M].contiguous() scale_inv = scale_inv[:, :q_M].contiguous()
if scales_are_compact and inner_scale_dimension_tiled: dq_scale = scale_inv.transpose(-2, -1).contiguous().reshape(q_M, scales_tiled_dim, 1)
dq_scale = scale_inv.contiguous().reshape(q_M, scales_tiled_dim, 1)
elif scales_are_compact and not inner_scale_dimension_tiled:
dq_scale = scale_inv.contiguous().reshape(scales_tiled_dim, 1, q_K)
else:
dq_scale = scale_inv.transpose(-2, -1).contiguous().reshape(q_M, scales_tiled_dim, 1)
torch_q_dtype = TE_DType_To_Torch[self._fp8_dtype] torch_q_dtype = TE_DType_To_Torch[self._fp8_dtype]
result = q_tiled.view(torch_q_dtype).to(torch.float32) * dq_scale result = q_tiled.view(torch_q_dtype).to(torch.float32) * dq_scale
if padded_M != q_M or padded_K != q_K: if padded_M != q_M or padded_K != q_K:
...@@ -233,12 +201,6 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage): ...@@ -233,12 +201,6 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
if not self._is_2D_scaled: if not self._is_2D_scaled:
return self._dequantize_vectorwise(dtype=dtype) return self._dequantize_vectorwise(dtype=dtype)
if not self._is_gemm_ready_format():
raise NotImplementedError(
"Dequantize is only supported with GEMM_READY data format, "
f"but found _data_format={self._data_format}"
)
def format_scale_as_logical_shape(q_K, scales, block_len): def format_scale_as_logical_shape(q_K, scales, block_len):
# The GEMM for 2D blocks required padding in the scales. # The GEMM for 2D blocks required padding in the scales.
derived_scale_k_shape = math.ceil(q_K / block_len) derived_scale_k_shape = math.ceil(q_K / block_len)
@@ -304,8 +266,6 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
        if self._rowwise_data is not None:
            return self._rowwise_data.size(*args, **kwargs)
        dims = list(self._columnwise_data.size(*args, **kwargs))
-        if not self._is_gemm_ready_format():  # compact format
-            return torch.Size(dims)
        reordered = []
        for i in range(1, len(dims)):
            reordered.append(dims[i])
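With the compact format gone, size() always derives the logical shape from the columnwise (transposed) storage by rotating the leading dimension back to the end. A small stand-alone sketch of that reordering, inferred from the loop above rather than copied from the library:

import torch

def logical_size_from_columnwise(columnwise):
    # Columnwise storage keeps the last logical dim first, so rotate it back.
    dims = list(columnwise.size())
    return torch.Size(dims[1:] + dims[:1])

assert logical_size_from_columnwise(torch.empty(8, 4, 2)) == torch.Size([4, 2, 8])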
@@ -366,7 +326,7 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
        return (
            "Float8BlockwiseQTensorStorage("
            f"fp8_dtype={self._fp8_dtype}, "
-            f"{descriptor}_scaled_data={data}"
+            f"{descriptor}_scaled_data={data})"
        )
    def update_usage(
...
@@ -57,13 +57,23 @@ class MXFP8TensorStorage(QuantizedTensorStorage):
    """
+    # Row-scaled FP8 data
    _rowwise_data: Optional[torch.Tensor]
+    # Column-scaled FP8 data
    _columnwise_data: Optional[torch.Tensor]
-    _quantizer: Optional[Quantizer]
-    _fp8_dtype: TE_DType
+    # Scaling factors for row-scaled FP8 data
    _rowwise_scale_inv: torch.Tensor
+    # Scaling factors for column-scaled FP8 data
    _columnwise_scale_inv: torch.Tensor
+    # Builder class for casting to MXFP8
+    _quantizer: Optional[Quantizer]
+    # FP8 data type
+    _fp8_dtype: TE_DType
+    # Whether scaling factors are in the swizzled format expected by
+    # GEMM
+    _with_gemm_swizzled_scales: bool
    def __new__(
        cls,
        rowwise_data: Optional[torch.Tensor],
@@ -72,6 +82,7 @@ class MXFP8TensorStorage(QuantizedTensorStorage):
        columnwise_scale_inv: Optional[torch.Tensor],
        fp8_dtype: TE_DType,
        quantizer: Optional[Quantizer],
+        with_gemm_swizzled_scales: bool,
        *args,
        **kwargs,
    ):
@@ -81,10 +92,11 @@ class MXFP8TensorStorage(QuantizedTensorStorage):
        instance = super().__new__(cls, *args, **kwargs)
        instance._rowwise_data = rowwise_data
        instance._columnwise_data = columnwise_data
-        instance._quantizer = quantizer.copy() if quantizer is not None else None
-        instance._fp8_dtype = fp8_dtype
        instance._rowwise_scale_inv = rowwise_scale_inv
        instance._columnwise_scale_inv = columnwise_scale_inv
+        instance._quantizer = quantizer.copy() if quantizer is not None else None
+        instance._fp8_dtype = fp8_dtype
+        instance._with_gemm_swizzled_scales = with_gemm_swizzled_scales
        return instance
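The new _with_gemm_swizzled_scales flag records whether the quantize kernel already wrote the scale factors in the layout the GEMM consumes. A hypothetical sketch of how a GEMM wrapper might consult it; swizzle_scales is a placeholder name, not a Transformer Engine function:

import torch

def swizzle_scales(scale_inv):
    """Placeholder for the GEMM-specific scale reordering; the real layout
    lives in the library's swizzle kernels and is not reproduced here."""
    raise NotImplementedError

def scales_for_gemm(tensor):
    # If scales were produced directly in the swizzled layout, skip the
    # extra reordering pass; otherwise swizzle them before the GEMM.
    if tensor._with_gemm_swizzled_scales:
        return tensor._rowwise_scale_inv
    return swizzle_scales(tensor._rowwise_scale_inv)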
@@ -108,6 +120,7 @@ class MXFP8TensorStorage(QuantizedTensorStorage):
            "columnwise_scale_inv": self._columnwise_scale_inv,
            "fp8_dtype": self._fp8_dtype,
            "quantizer": self._quantizer,
+            "with_gemm_swizzled_scales": self._with_gemm_swizzled_scales,
        }
    def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorStorage]:
@@ -197,6 +210,7 @@ class MXFP8TensorStorage(QuantizedTensorStorage):
            columnwise_scale_inv=self._columnwise_scale_inv,
            fp8_dtype=self._fp8_dtype,
            quantizer=self._quantizer,
+            with_gemm_swizzled_scales=self._with_gemm_swizzled_scales,
        )
    def __repr__(self):
@@ -255,7 +269,7 @@ class MXFP8TensorStorage(QuantizedTensorStorage):
        self._columnwise_data = None
        self._columnwise_scale_inv = None
-    def get_usages(self) -> Tuple[bool, bool]:
+    def get_usages(self) -> Dict[str, bool]:
        """Get the usage of the tensor"""
        return {
            "rowwise": self._rowwise_data is not None,
...
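get_usages() now returns a dict keyed by usage name instead of a positional tuple, so call sites read by key. A self-contained sketch of the new contract; the "columnwise" key is assumed by symmetry with the "rowwise" key visible above:

from typing import Dict

class _UsageDemo:
    """Stand-in mirroring the dict-based get_usages() contract."""
    _rowwise_data = object()
    _columnwise_data = None

    def get_usages(self) -> Dict[str, bool]:
        return {
            "rowwise": self._rowwise_data is not None,
            "columnwise": self._columnwise_data is not None,  # assumed key
        }

usages = _UsageDemo().get_usages()
assert usages["rowwise"] and not usages["columnwise"]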
@@ -71,15 +71,29 @@ class NVFP4TensorStorage(QuantizedTensorStorage):
    """
+    # Row-scaled FP4 data
    _rowwise_data: Optional[torch.Tensor]
+    # Column-scaled FP4 data
    _columnwise_data: Optional[torch.Tensor]
-    _quantizer: Optional[Quantizer]
+    # Block scaling factors for row-scaled FP4 data
    _rowwise_scale_inv: torch.Tensor
+    # Block scaling factors for column-scaled FP4 data
    _columnwise_scale_inv: torch.Tensor
-    _fp4_dtype: TE_DType
+    # Input absolute maximum value (used to compute tensor scale for
+    # row-scaled FP4 data)
    _amax_rowwise: torch.Tensor
+    # Input absolute maximum value (used to compute tensor scale for
+    # column-scaled FP4 data)
    _amax_columnwise: torch.Tensor
+    # Builder class for casting to NVFP4
+    _quantizer: Optional[Quantizer]
+    # FP4 data type
+    _fp4_dtype: TE_DType
+    # Whether scaling factors are in the swizzled format expected by
+    # GEMM
+    _with_gemm_swizzled_scales: bool
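The _amax_* fields feed the second level of NVFP4 scaling: besides the per-block FP8 scale factors, a per-tensor scale is derived from the input's absolute maximum. A rough sketch under the common convention of dividing the representable range (FP4 E2M1 max of 6.0 times FP8 E4M3 max of 448.0) by the amax; this formula is an assumption and the exact recipe is not shown in this diff:

import torch

FP4_E2M1_MAX = 6.0    # largest magnitude representable by FP4 (E2M1)
FP8_E4M3_MAX = 448.0  # largest magnitude representable by the FP8 block scales

def per_tensor_encode_scale(amax):
    """Illustrative second-level scale from amax (assumed convention)."""
    amax = torch.clamp(amax.float(), min=torch.finfo(torch.float32).tiny)
    return (FP4_E2M1_MAX * FP8_E4M3_MAX) / amax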
    def __new__(
        cls,
        rowwise_data: Optional[torch.Tensor],
@@ -90,6 +104,7 @@ class NVFP4TensorStorage(QuantizedTensorStorage):
        amax_columnwise: torch.Tensor,
        fp4_dtype: TE_DType,
        quantizer: Optional[Quantizer],
+        with_gemm_swizzled_scales: bool,
        *args,
        **kwargs,
    ):
@@ -104,6 +119,7 @@ class NVFP4TensorStorage(QuantizedTensorStorage):
        instance._columnwise_scale_inv = columnwise_scale_inv
        instance._amax_rowwise = amax_rowwise
        instance._amax_columnwise = amax_columnwise
+        instance._with_gemm_swizzled_scales = with_gemm_swizzled_scales
        return instance
@@ -131,6 +147,7 @@ class NVFP4TensorStorage(QuantizedTensorStorage):
            "amax_columnwise": self._amax_columnwise,
            "fp4_dtype": self._fp4_dtype,
            "quantizer": self._quantizer,
+            "with_gemm_swizzled_scales": self._with_gemm_swizzled_scales,
        }
    def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], NVFP4TensorStorage]:
@@ -248,6 +265,7 @@ class NVFP4TensorStorage(QuantizedTensorStorage):
            amax_columnwise=self._amax_columnwise,
            quantizer=self._quantizer,
            fp4_dtype=self._fp4_dtype,
+            with_gemm_swizzled_scales=self._with_gemm_swizzled_scales,
        )
    def __repr__(self):
...