Unverified Commit 99df8810 authored by Tim Moon, committed by GitHub

Add logic for block-scaled tensors with GEMM swizzled scales (#2486)



* Add general C API for setting tensor params
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Implement general accessors for NVTETensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Refactor tex swizzling to skip if scales are already swizzled
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add checks for non-swizzled scales in MXFP8 and NVFP4 kernels
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Support pre-swizzled scales in MXFP8Tensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add tex function to swizzle MXFP8 scales
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix bug in inplace swizzle function
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Tweak comments to use "compact/swizzled format"
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* MXFP8 quantize kernel with pre-swizzled scales
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Expose pre-swizzled scales in modules
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix bug in multi-swizzle
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Support MXFP8 gated activations with swizzled scales
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add PyTorch infrastructure for pre-swizzled NVFP4 tensors
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Deprecate DSv3-specific quantization logic in C API
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Remove support for DSv3 compact data from quantizer
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove DSv3 compact data format from core lib
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix bug in FP8 all-gather
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update JAX to use new swizzled scale API
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Review suggestion from @greptile-apps
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestions from @greptile-apps
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update C++ swizzle test with swizzled scales API
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Return default tensor params when querying params for invalid NVTETensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug DSv3 FP8 test failures
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug Userbuffers test failures
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Make sure gated activations populate FP8 transpose if needed
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Review suggestions from @greptile-apps
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Disable pre-swizzling with debug quantizer
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestion from @greptile-apps
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix merge conflicts and review suggestions

Update copyright years. Tweak comments. Fix various complaints from @greptile-apps.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use explicitly sized types in config accessors

Miscellaneous review suggestions from @ptrendx.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Make util header for function that computes swizzled scale index
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Apply suggestions from @greptile-apps
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Update expected error message in FP8 block-scaling test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestion from @yaox12
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
parent a652730f
@@ -199,10 +199,21 @@ class Quantizer(abc.ABC):
     """
     internal: bool
+    """Whether to solely optimize for matrix multiplication
+
+    The resulting quantized tensors are not guaranteed to support any
+    operation other than matrix multiplication. Use with care since
+    this is likely to break communication, checkpointing, and many
+    other features.
+
+    """
+    optimize_for_gemm: bool
 
     def __init__(self, *, rowwise: bool, columnwise: bool) -> None:
         self.rowwise_usage = rowwise
         self.columnwise_usage = columnwise
         self.internal = False
+        self.optimize_for_gemm = False
 
     def __repr__(self):
         return (

@@ -314,7 +325,11 @@ class Quantizer(abc.ABC):
         return False
 
     def is_quantizable(self, inp: torch.Tensor) -> bool:  # pylint: disable=unused-argument
-        """Returns whether or not given tensor can be quantized"""
+        """Whether tensor supports quantized all-gather
+
+        Consider a less misleading function name.
+
+        """
         return True
 
     def get_usages(self) -> Dict[str, bool]:
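For orientation, a minimal usage sketch of the new flag. This is hypothetical: `make_quantizer()` is a placeholder rather than a real Transformer Engine function, and the callable-quantizer interface is assumed; only the `optimize_for_gemm` attribute itself comes from this diff.

```python
import torch

# Hypothetical sketch. `make_quantizer()` stands in for however a quantizer is
# normally obtained (e.g. from a recipe); it is not part of this change.
quantizer = make_quantizer()
quantizer.optimize_for_gemm = True  # default is False (set in __init__ above)

x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
# The resulting quantized tensor is only guaranteed to work as a GEMM operand;
# communication and checkpointing paths may not accept it.
x_q = quantizer(x)
```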
@@ -293,6 +293,7 @@ class Float8CurrentScalingQuantizer(Quantizer):
             amax=self.amax,
         )
         quantizer.internal = self.internal
+        quantizer.optimize_for_gemm = self.optimize_for_gemm
         return quantizer
@@ -193,6 +193,7 @@ class NVFP4Quantizer(Quantizer):
             stochastic_rounding=self.stochastic_rounding,
         )
         quantizer.internal = self.internal
+        quantizer.optimize_for_gemm = self.optimize_for_gemm
         quantizer.rht_matrix = self.rht_matrix
         quantizer.rht_matrix_random_sign_mask_t = self.rht_matrix_random_sign_mask_t

@@ -359,6 +360,7 @@ class NVFP4Quantizer(Quantizer):
             fp4_dtype=self.dtype,
             quantizer=self,
             requires_grad=requires_grad,
+            with_gemm_swizzled_scales=False,
         )

     def calibrate(self, tensor: torch.Tensor) -> None:

@@ -418,6 +420,7 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
         amax_columnwise: Optional[torch.Tensor],
         fp4_dtype: TE_DType,
         quantizer: Quantizer,
+        with_gemm_swizzled_scales: bool,
         **kwargs,
     ):
         instance = super().__new__(

@@ -430,6 +433,7 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
             amax_columnwise,
             fp4_dtype,
             quantizer,
+            with_gemm_swizzled_scales,
             *args,
             **kwargs,
         )

@@ -592,6 +596,7 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
                 amax_columnwise=amax_columnwise,
                 quantizer=tensor._quantizer,
                 requires_grad=tensor.requires_grad,
+                with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales,
             )

         # Default case

@@ -610,6 +615,7 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
         fp4_dtype: TE_DType,
         dtype: torch.dtype,
         quantizer: Quantizer,
+        with_gemm_swizzled_scales: bool = False,
     ) -> NVFP4Tensor:
         """Build NVFP4Tensor, for use in __reduce__

@@ -629,6 +635,7 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
             amax_columnwise=amax_columnwise,
             quantizer=quantizer,
             requires_grad=False,
+            with_gemm_swizzled_scales=with_gemm_swizzled_scales,
         )

     def __reduce_ex__(self, protocol: int) -> tuple:

@@ -646,6 +653,7 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
                 self._fp4_dtype,
                 self.dtype,
                 self._quantizer,
+                self._with_gemm_swizzled_scales,
             ),
         )

@@ -696,6 +704,7 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
             self._columnwise_scale_inv = tensor._columnwise_scale_inv
             self._amax_rowwise = tensor._amax_rowwise
             self._amax_columnwise = tensor._amax_columnwise
+            self._with_gemm_swizzled_scales = tensor._with_gemm_swizzled_scales
             return

         # Quantize to FP8

@@ -782,6 +791,7 @@ class _ViewFunc(torch.autograd.Function):
                 quantizer=tensor._quantizer,
                 fp4_dtype=tensor._fp4_dtype,
                 requires_grad=tensor.requires_grad,
+                with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales,
             )

     @staticmethod

@@ -823,6 +833,7 @@ class _ViewFunc(torch.autograd.Function):
                 quantizer=grad._quantizer,
                 fp4_dtype=grad._fp4_dtype,
                 requires_grad=grad.requires_grad,
+                with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales,
             )
             return dgrad, None
         return grad.view(ctx.shape), None

@@ -902,6 +913,7 @@ class _ReshapeFunc(torch.autograd.Function):
                 quantizer=tensor._quantizer,
                 fp4_dtype=tensor._fp4_dtype,
                 requires_grad=tensor.requires_grad,
+                with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales,
             )

     @staticmethod

@@ -943,6 +955,7 @@ class _ReshapeFunc(torch.autograd.Function):
                 quantizer=grad._quantizer,
                 fp4_dtype=grad._fp4_dtype,
                 requires_grad=grad.requires_grad,
+                with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales,
             )
             return dgrad, None
         return grad.view(ctx.shape), None
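The `with_gemm_swizzled_scales` flag is threaded through `__new__`, `__reduce_ex__`, and the view/reshape autograd functions so that it survives copies, views, and pickling. A self-contained toy (not the Transformer Engine implementation) showing the same pattern:

```python
import pickle


class _ToyStorage:
    """Toy stand-in: a layout flag carried through the constructor and __reduce_ex__."""

    def __init__(self, data, with_gemm_swizzled_scales: bool):
        self.data = data
        self.with_gemm_swizzled_scales = with_gemm_swizzled_scales

    def __reduce_ex__(self, protocol):
        # Include the flag in the reconstruction args so it is not lost on pickling
        return (_ToyStorage, (self.data, self.with_gemm_swizzled_scales))


t = pickle.loads(pickle.dumps(_ToyStorage([1, 2, 3], with_gemm_swizzled_scales=True)))
assert t.with_gemm_swizzled_scales  # flag survives the round trip
```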
@@ -11,7 +11,6 @@ import torch
 import transformer_engine_torch as tex
 from transformer_engine_torch import DType as TE_DType
-from transformer_engine_torch import Float8BlockScaleTensorFormat

 from ...quantized_tensor import QuantizedTensorStorage, Quantizer

@@ -36,7 +35,6 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
     _rowwise_scale_inv: Optional[torch.Tensor]
     _columnwise_scale_inv: Optional[torch.Tensor]
     _is_2D_scaled: bool
-    _data_format: Float8BlockScaleTensorFormat

     def __new__(
         cls,

@@ -47,7 +45,6 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
         fp8_dtype: TE_DType,
         quantizer: Quantizer,
         is_2D_scaled: bool,
-        data_format: Float8BlockScaleTensorFormat,
         *args,
         **kwargs,
     ):

@@ -62,7 +59,6 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
         instance._rowwise_scale_inv = rowwise_scale_inv
         instance._columnwise_scale_inv = columnwise_scale_inv
         instance._is_2D_scaled = is_2D_scaled
-        instance._data_format = data_format

         return instance

@@ -87,13 +83,8 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
             "fp8_dtype": self._fp8_dtype,
             "quantizer": self._quantizer,
             "is_2D_scaled": self._is_2D_scaled,
-            "data_format": self._data_format,
         }

-    def _is_gemm_ready_format(self) -> bool:
-        """Whether data is in GEMM_READY format"""
-        return self._data_format == Float8BlockScaleTensorFormat.GEMM_READY
-
     def prepare_for_saving(
         self,
     ) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorStorage]:

@@ -153,36 +144,18 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
                 for i in range(len(q.shape) - 1):
                     q_M *= q.shape[i]
             inner_q_dimension_tiled = True
-            if self._is_gemm_ready_format():
-                scales_tiled_dim, scales_untiled_dim = scale_inv.shape
-                inner_scale_dimension_tiled = False
-                scales_are_compact = False
-            else:
-                scales_untiled_dim, scales_tiled_dim = scale_inv.shape
-                inner_scale_dimension_tiled = True
-                scales_are_compact = True
+            scales_tiled_dim, scales_untiled_dim = scale_inv.shape
         else:
             assert self._columnwise_data is not None, "No data to dequantize"
             q = self._columnwise_data
             scale_inv = self._columnwise_scale_inv
             scales_tiled_dim, scales_untiled_dim = scale_inv.shape
-            inner_scale_dimension_tiled = False
-            if self._is_gemm_ready_format():
-                inner_q_dimension_tiled = True
-                transpose_output = True
-                if len(q.shape) >= 1:
-                    q_M = q.shape[0]
-                    for i in range(1, len(q.shape)):
-                        q_K *= q.shape[i]
-                scales_are_compact = False
-            else:
-                inner_q_dimension_tiled = False
-                transpose_output = False
-                if len(q.shape) >= 1:
-                    q_K = q.shape[-1]
-                    for i in range(len(q.shape) - 1):
-                        q_M *= q.shape[i]
-                scales_are_compact = True
+            inner_q_dimension_tiled = True
+            transpose_output = True
+            if len(q.shape) >= 1:
+                q_M = q.shape[0]
+                for i in range(1, len(q.shape)):
+                    q_K *= q.shape[i]

         orig_shape = q.shape
         q = q.reshape(q_M, q_K)

@@ -202,15 +175,10 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
         ).contiguous()
         padded_M, padded_K = q.shape
         q_tiled = q.reshape(scales_tiled_dim, block_len, q_K)
-        if not scales_are_compact and scales_untiled_dim > q_M:
+        if scales_untiled_dim > q_M:
             # untiled scale dimension is 4 element aligned.
             scale_inv = scale_inv[:, :q_M].contiguous()
-        if scales_are_compact and inner_scale_dimension_tiled:
-            dq_scale = scale_inv.contiguous().reshape(q_M, scales_tiled_dim, 1)
-        elif scales_are_compact and not inner_scale_dimension_tiled:
-            dq_scale = scale_inv.contiguous().reshape(scales_tiled_dim, 1, q_K)
-        else:
-            dq_scale = scale_inv.transpose(-2, -1).contiguous().reshape(q_M, scales_tiled_dim, 1)
+        dq_scale = scale_inv.transpose(-2, -1).contiguous().reshape(q_M, scales_tiled_dim, 1)
         torch_q_dtype = TE_DType_To_Torch[self._fp8_dtype]
         result = q_tiled.view(torch_q_dtype).to(torch.float32) * dq_scale
         if padded_M != q_M or padded_K != q_K:

@@ -233,12 +201,6 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
         if not self._is_2D_scaled:
             return self._dequantize_vectorwise(dtype=dtype)

-        if not self._is_gemm_ready_format():
-            raise NotImplementedError(
-                "Dequantize is only supported with GEMM_READY data format, "
-                f"but found _data_format={self._data_format}"
-            )
-
         def format_scale_as_logical_shape(q_K, scales, block_len):
             # The GEMM for 2D blocks required padding in the scales.
             derived_scale_k_shape = math.ceil(q_K / block_len)

@@ -304,8 +266,6 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
         if self._rowwise_data is not None:
             return self._rowwise_data.size(*args, **kwargs)
         dims = list(self._columnwise_data.size(*args, **kwargs))
-        if not self._is_gemm_ready_format():  # compact format
-            return torch.Size(dims)
         reordered = []
         for i in range(1, len(dims)):
             reordered.append(dims[i])

@@ -366,7 +326,7 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
         return (
             "Float8BlockwiseQTensorStorage("
             f"fp8_dtype={self._fp8_dtype}, "
-            f"{descriptor}_scaled_data={data}"
+            f"{descriptor}_scaled_data={data})"
         )

     def update_usage(
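The dequantize path above tiles the FP8 data by block and broadcasts the inverse scales over each tile. A self-contained sketch of the same arithmetic for 128x128 (2D) block scaling, independent of the Transformer Engine classes and assuming shapes already padded to the block size:

```python
import torch

block = 128
M, K = 256, 384  # assumed already padded to multiples of `block`
q = torch.randn(M, K)  # stand-in for FP8 data viewed as float32
scale_inv = torch.rand(M // block, K // block)  # one dequant scale per 128x128 tile

# Group data as (row tile, row in tile, col tile, col in tile) and broadcast scales
dq = q.reshape(M // block, block, K // block, block) * scale_inv[:, None, :, None]
dq = dq.reshape(M, K)
```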
@@ -57,13 +57,23 @@ class MXFP8TensorStorage(QuantizedTensorStorage):
     """

+    # Row-scaled FP8 data
     _rowwise_data: Optional[torch.Tensor]
+    # Column-scaled FP8 data
     _columnwise_data: Optional[torch.Tensor]
-    _quantizer: Optional[Quantizer]
-    _fp8_dtype: TE_DType
+    # Scaling factors for row-scaled FP8 data
     _rowwise_scale_inv: torch.Tensor
+    # Scaling factors for column-scaled FP8 data
     _columnwise_scale_inv: torch.Tensor
+    # Builder class for casting to MXFP8
+    _quantizer: Optional[Quantizer]
+    # FP8 data type
+    _fp8_dtype: TE_DType
+    # Whether scaling factors are in the swizzled format expected by
+    # GEMM
+    _with_gemm_swizzled_scales: bool

     def __new__(
         cls,
         rowwise_data: Optional[torch.Tensor],

@@ -72,6 +82,7 @@ class MXFP8TensorStorage(QuantizedTensorStorage):
         columnwise_scale_inv: Optional[torch.Tensor],
         fp8_dtype: TE_DType,
         quantizer: Optional[Quantizer],
+        with_gemm_swizzled_scales: bool,
         *args,
         **kwargs,
     ):

@@ -81,10 +92,11 @@ class MXFP8TensorStorage(QuantizedTensorStorage):
         instance = super().__new__(cls, *args, **kwargs)
         instance._rowwise_data = rowwise_data
         instance._columnwise_data = columnwise_data
-        instance._quantizer = quantizer.copy() if quantizer is not None else None
-        instance._fp8_dtype = fp8_dtype
         instance._rowwise_scale_inv = rowwise_scale_inv
         instance._columnwise_scale_inv = columnwise_scale_inv
+        instance._quantizer = quantizer.copy() if quantizer is not None else None
+        instance._fp8_dtype = fp8_dtype
+        instance._with_gemm_swizzled_scales = with_gemm_swizzled_scales

         return instance

@@ -108,6 +120,7 @@ class MXFP8TensorStorage(QuantizedTensorStorage):
             "columnwise_scale_inv": self._columnwise_scale_inv,
             "fp8_dtype": self._fp8_dtype,
             "quantizer": self._quantizer,
+            "with_gemm_swizzled_scales": self._with_gemm_swizzled_scales,
         }

     def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorStorage]:

@@ -197,6 +210,7 @@ class MXFP8TensorStorage(QuantizedTensorStorage):
             columnwise_scale_inv=self._columnwise_scale_inv,
             fp8_dtype=self._fp8_dtype,
             quantizer=self._quantizer,
+            with_gemm_swizzled_scales=self._with_gemm_swizzled_scales,
         )

     def __repr__(self):

@@ -255,7 +269,7 @@ class MXFP8TensorStorage(QuantizedTensorStorage):
         self._columnwise_data = None
         self._columnwise_scale_inv = None

-    def get_usages(self) -> Tuple[bool, bool]:
+    def get_usages(self) -> Dict[str, bool]:
         """Get the usage of the tensor"""
         return {
             "rowwise": self._rowwise_data is not None,
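For reference, the "swizzled format expected by GEMM" mentioned in the comments above is the block-scaled GEMM scale-factor layout, which also pads the scale matrix. The padding constants below (rows to a multiple of 128, scale columns to a multiple of 4, one scale per 32 elements) follow the common cuBLAS block-scaling convention and are an assumption here, not stated in the diff:

```python
import math


def gemm_scale_shape(rows: int, cols: int, block: int = 32) -> tuple:
    """Padded shape of an MXFP8 scale tensor for the GEMM-swizzled layout.

    Illustration only: assumes the common block-scaling padding of the scale
    matrix to multiples of 128 (rows) and 4 (scale columns).
    """
    scale_cols = math.ceil(cols / block)  # one E8M0 scale per 32 elements
    return math.ceil(rows / 128) * 128, math.ceil(scale_cols / 4) * 4


# e.g. a (1000, 4096) MXFP8 tensor -> (1024, 128) scale buffer
print(gemm_scale_shape(1000, 4096))
```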