Unverified commit 99df8810 authored by Tim Moon, committed by GitHub

Add logic for block-scaled tensors with GEMM swizzled scales (#2486)



* Add general C API for setting tensor params
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Implement general accessors for NVTETensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Refactor tex swizzling to skip if scales are already swizzled
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add checks for non-swizzled scales in MXFP8 and NVFP4 kernels
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Support pre-swizzled scales in MXFP8Tensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add tex function to swizzle MXFP8 scales
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix bug in in-place swizzle function
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Tweak comments to use "compact/swizzled format"
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* MXFP8 quantize kernel with pre-swizzled scales
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Expose pre-swizzled scales in modules
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix bug in multi-swizzle
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Support MXFP8 gated activations with swizzled scales
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add PyTorch infrastructure for pre-swizzled NVFP4 tensors
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Deprecate DSv3-specific quantization logic in C API
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Remove support for DSv3 compact data from quantizer
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove DSv3 compact data format from core lib
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix bug in FP8 all-gather
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update JAX to use new swizzled scale API
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Review suggestion from @greptile-apps
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestions from @greptile-apps
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update C++ swizzle test with swizzled scales API
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Return default tensor params when querying params for invalid NVTETensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug DSv3 FP8 test failures
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug Userbuffers test failures
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Make sure gated activations populate FP8 transpose if needed
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Review suggestions from @greptile-apps
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Disable pre-swizzling with debug quantizer
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestion from @greptile-apps
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix merge conflicts and review suggestions

Update copyright years. Tweak comments. Fix various complaints from @greptile-apps.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use explicitly sized types in config accessors

Miscellaneous review suggestions from @ptrendx.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Make util header for functions that compute the swizzled scale index
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Apply suggestions from @greptile-apps
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Update expected error message in FP8 block-scaling test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestion from @yaox12
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
parent a652730f
@@ -199,10 +199,21 @@ class Quantizer(abc.ABC):
     """
     internal: bool
 
+    """Whether to solely optimize for matrix multiplication
+
+    The resulting quantized tensors are not guaranteed to support any
+    operation other than matrix multiplication. Use with care since
+    this is likely to break communication, checkpointing, and many
+    other features.
+    """
+    optimize_for_gemm: bool
+
     def __init__(self, *, rowwise: bool, columnwise: bool) -> None:
         self.rowwise_usage = rowwise
         self.columnwise_usage = columnwise
         self.internal = False
+        self.optimize_for_gemm = False
 
     def __repr__(self):
         return (
@@ -314,7 +325,11 @@ class Quantizer(abc.ABC):
         return False
 
     def is_quantizable(self, inp: torch.Tensor) -> bool:  # pylint: disable=unused-argument
-        """Returns whether or not given tensor can be quantized"""
+        """Whether tensor supports quantized all-gather
+
+        Consider a less misleading function name.
+
+        """
         return True
 
     def get_usages(self) -> Dict[str, bool]:
...
@@ -4,14 +4,14 @@
 """Tensor class with FP8 data quantized with NxN tiles"""
 from __future__ import annotations
-from typing import Optional, Tuple, Iterable, Union
+from collections.abc import Iterable
 import math
+from typing import Any, Optional, Tuple, Union
 
 import torch
 import transformer_engine_torch as tex
 from transformer_engine_torch import DType as TE_DType
-from transformer_engine_torch import Float8BlockScaleTensorFormat
 from transformer_engine.common.recipe import Float8BlockScaling, Recipe
 from .storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
 from ..quantized_tensor import QuantizedTensor, Quantizer
@@ -35,8 +35,6 @@ class Float8BlockQuantizer(Quantizer):
     amax_epsilon: float
     force_pow_2_scales: bool
     block_scaling_dim: int
-    # Whether to produce tensors that will be used in all-gather
-    all_gather_usage: bool
 
     def __init__(
         self,
@@ -47,7 +45,6 @@ class Float8BlockQuantizer(Quantizer):
         amax_epsilon: float = 0.0,
         force_pow_2_scales: bool = True,
         block_scaling_dim: int = 2,
-        all_gather_usage: bool = False,
     ) -> None:
         super().__init__(rowwise=rowwise, columnwise=columnwise)
         self.dtype = fp8_dtype
@@ -55,7 +52,6 @@ class Float8BlockQuantizer(Quantizer):
         self.force_pow_2_scales = force_pow_2_scales
         self.amax_epsilon = amax_epsilon
         self.block_scaling_dim = block_scaling_dim
-        self.all_gather_usage = all_gather_usage
 
     def copy(self) -> Float8BlockQuantizer:
         """Create shallow copy"""
@@ -65,11 +61,11 @@ class Float8BlockQuantizer(Quantizer):
             rowwise=self.rowwise_usage,
             columnwise=self.columnwise_usage,
             block_scaling_dim=self.block_scaling_dim,
-            all_gather_usage=self.all_gather_usage,
             amax_epsilon=self.amax_epsilon,
             force_pow_2_scales=self.force_pow_2_scales,
         )
         quantizer.internal = self.internal
+        quantizer.optimize_for_gemm = self.optimize_for_gemm
         return quantizer
@@ -123,103 +119,86 @@ class Float8BlockQuantizer(Quantizer):
         return tex.quantize(tensor, self)
 
     def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, int]:
-        """Calculate the shape of the scaling tensor for blockwise quantization.
+        """Scaling tensor shape.
 
-        This method determines the shape of the scaling tensor needed for blockwise quantization,
-        taking into account the input tensor shape and whether columnwise scaling is used.
-        The scales are padded to multiples of 4 on the inner dimension for compatibility with GEMM.
+        This method determines the shape of the scaling tensor based
+        on the quantizer configuration. The scales are padded to
+        multiples of 4 for compatibility with GEMM.
 
         Parameters
         ----------
         shape : Iterable[int]
-            Shape of the input tensor to be quantized
+            Logical tensor shape.
         columnwise : bool
-            Whether to use columnwise scaling (True) or rowwise scaling (False)
+            Whether the data is scaled column-wise (True) or row-wise (False).
 
         Returns
         -------
         Tuple[int, int]
-            Shape of the scaling tensor as (outer_dim, inner_dim)
-            For 2D tensors:
-            - If columnwise: (roundup(K/blocksize), round_to_multiple(roundup(M/blocksize), 4))
-            - If rowwise: (roundup(M/blocksize), round_to_multiple(roundup(K/blocksize), 4))
-            For 1D tensors:
-            - If columnwise: (roundup(M/blocksize), round_to_multiple(K, 4))
-            - If rowwise: (roundup(K/blocksize), round_to_multiple(M, 4))
+            Scaling tensor shape.
 
         """
-        M, K = 1, 1
-        for i in range(len(shape) - 1):
-            M *= shape[i]
-        if len(shape) > 0:
-            K = shape[-1]
 
-        # 2D 128x128 quantization block scaling
-        # CuBLAS requries 128x128 scaling factor to be padded
-        # currently rowwise and columnwise format option doesn't apply to 2D scaling
+        # Flatten tensor to 2D
+        dim0 = math.prod(shape[:-1])
+        dim1 = shape[-1] if shape else 1
+
+        # Check block dims
+        if self.block_scaling_dim not in (1, 2):
+            raise RuntimeError(
+                "Only 1D or 2D blocks are supported, "
+                f"but got block_scaling_dim={self.block_scaling_dim}"
+            )
+
+        # 128x128 block scaling
         if self.block_scaling_dim == 2:
+            scale_dim0 = (dim0 + self.block_len - 1) // self.block_len
+            scale_dim1 = (dim1 + self.block_len - 1) // self.block_len
             if columnwise:
-                outer = math.ceil(K / self.block_len)
-                inner = round_up_to_nearest_multiple(math.ceil(M / self.block_len), 4)
-                return (outer, inner)
-            # rowwise
-            outer = math.ceil(M / self.block_len)
-            inner = round_up_to_nearest_multiple(math.ceil(K / self.block_len), 4)
-            return (outer, inner)
+                return (scale_dim1, round_up_to_nearest_multiple(scale_dim0, 4))
+            return (scale_dim0, round_up_to_nearest_multiple(scale_dim1, 4))
 
-        # 1D 1x128 quantization block scaling
-        # CuBLAS requries 1x128 scaling factor to be padded and transposed
-        assert self.block_scaling_dim == 1, "Only 1D or 2D blocks supported"
+        # 1x128 block scaling
         if columnwise:
-            columnwise_compact = self.all_gather_usage
-            outer = math.ceil(M / self.block_len)
-            inner = round_up_to_nearest_multiple(K, 4) if not columnwise_compact else K
-            # GEMM READY case: scaling factor is [outer, inner], already transposed here for CuBLAS
-            # for COMPACT case, since we apply 1x128 scaling here without transposing columnwise data, scaling factor is also [outer, inner]
-            # so no need to swap inner outer here
-            return (outer, inner)
-        # rowwise
-        rowwise_compact = self.all_gather_usage
-        outer = math.ceil(K / self.block_len)
-        inner = round_up_to_nearest_multiple(M, 4) if not rowwise_compact else M
-        # GEMM READY case: scaling factor is [outer, inner], already transposed here for CuBLAS need
-        # for COMPACT case, since we apply 128x1 scaling, scaling block applies to inner dim, so we need to swap outer and inner here
-        return (outer, inner) if not rowwise_compact else (inner, outer)
+            return (
+                (dim0 + self.block_len - 1) // self.block_len,
+                round_up_to_nearest_multiple(dim1, 4),
+            )
+        return (
+            (dim1 + self.block_len - 1) // self.block_len,
+            round_up_to_nearest_multiple(dim0, 4),
+        )
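The scale-shape arithmetic in the rewritten `get_scale_shape` can be exercised outside the diff. The following is a minimal standalone sketch, not the Transformer Engine implementation itself, assuming `block_len == 128` as in Float8BlockScaling; it reproduces the logic for both 1x128 (1D) and 128x128 (2D) block scaling:

```python
import math

BLOCK_LEN = 128  # scaling block length used by FP8 block scaling

def round_up_to_nearest_multiple(value: int, multiple: int) -> int:
    """Round value up to the nearest multiple (cuBLAS padding requirement)."""
    return (value + multiple - 1) // multiple * multiple

def get_scale_shape(shape, columnwise: bool, block_scaling_dim: int):
    """Shape of the scale-inverse tensor, mirroring the diff's logic."""
    # Flatten tensor to 2D
    dim0 = math.prod(shape[:-1])
    dim1 = shape[-1] if shape else 1

    if block_scaling_dim == 2:  # 128x128 blocks
        scale_dim0 = (dim0 + BLOCK_LEN - 1) // BLOCK_LEN
        scale_dim1 = (dim1 + BLOCK_LEN - 1) // BLOCK_LEN
        if columnwise:
            return (scale_dim1, round_up_to_nearest_multiple(scale_dim0, 4))
        return (scale_dim0, round_up_to_nearest_multiple(scale_dim1, 4))

    if block_scaling_dim == 1:  # 1x128 blocks
        if columnwise:
            return ((dim0 + BLOCK_LEN - 1) // BLOCK_LEN,
                    round_up_to_nearest_multiple(dim1, 4))
        return ((dim1 + BLOCK_LEN - 1) // BLOCK_LEN,
                round_up_to_nearest_multiple(dim0, 4))

    raise RuntimeError(f"Only 1D or 2D blocks are supported, got {block_scaling_dim}")

# A (256, 384) tensor with 2D scaling needs a 2x3 grid of scales, padded to (2, 4).
print(get_scale_shape((256, 384), columnwise=False, block_scaling_dim=2))  # (2, 4)
# With 1D row-wise scaling there is one scale per 128-wide block of each row.
print(get_scale_shape((256, 384), columnwise=False, block_scaling_dim=1))  # (3, 256)
```

Note that in the 1D case the scale tensor is already laid out transposed relative to the data, which is what lets cuBLAS consume it directly.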
     def get_columnwise_shape(self, shape: Iterable[int]) -> Tuple[int, ...]:
-        """Calculate the shape of a tensor after columnwise permutation.
+        """Column-wise data shape
 
-        This method rearranges the dimensions of a tensor to be columnwise,
-        moving the last dimension to the front and keeping the order of other dimensions.
+        GEMMs expect that the column-wise data is transposed relative
+        to the logical tensor shape.
 
         Parameters
         ----------
         shape : Iterable[int]
-            Original shape of the tensor
+            Logical tensor shape.
 
         Returns
         -------
         Tuple[int, ...]
-            New shape with dimensions rearranged for columnwise layout.
-            For a shape (d1, d2, ..., dn), returns (dn, d1, d2, ..., dn-1).
-            Returns empty tuple for empty input shape.
+            Column-wise data shape.
 
         """
-        if len(shape) == 0:
-            return tuple()
-        # currently columnwise format option only applies to 1D quantizer
-        # for 2D scaling, columnwise format should always be GEMM_READY_DATA_AND_SCALES
-        # since currently 2D scaling only applies to module weights
-        if self.block_scaling_dim == 1 and self.all_gather_usage:
-            return shape
-        colwise_shape = [shape[-1]]
-        for i in range(len(shape) - 1):
-            colwise_shape.append(shape[i])
+        colwise_shape = []
+        if shape:
+            colwise_shape.append(shape[-1])
+            colwise_shape.extend(shape[:-1])
         return tuple(colwise_shape)
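The simplified `get_columnwise_shape` is a plain cyclic permutation of the dimensions, with the last dimension moved to the front. A standalone sketch (not the Transformer Engine implementation itself):

```python
def get_columnwise_shape(shape):
    """Map (d1, ..., dn) to (dn, d1, ..., dn-1); an empty shape maps to ()."""
    colwise_shape = []
    if shape:
        colwise_shape.append(shape[-1])   # last dim becomes the leading dim
        colwise_shape.extend(shape[:-1])  # remaining dims keep their order
    return tuple(colwise_shape)

print(get_columnwise_shape((2, 3, 4)))  # (4, 2, 3)
```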
     def is_quantizable(self, inp: torch.Tensor) -> bool:
         """Returns whether or not given inp can be quantized"""
-        if inp.ndim < 2:
+        shape = inp.size()
+        if len(shape) < 2:
             return False
-        if inp.shape[-1] % self.block_len != 0:
+        if shape[-1] % self.block_len != 0:
             return False
-        if math.prod(inp.shape[:-1]) % self.block_len != 0:
+        if math.prod(shape[:-1]) % self.block_len != 0:
             return False
         return True
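The divisibility conditions in `is_quantizable` can be checked on plain shapes. A hypothetical shape-only helper mirroring them (no torch dependency; the name `is_quantizable_shape` is illustrative, not part of the API):

```python
import math

def is_quantizable_shape(shape, block_len: int = 128) -> bool:
    """A tensor is block-quantizable if it is at least 2D and both the last
    dim and the product of the leading dims divide evenly into blocks."""
    if len(shape) < 2:
        return False
    if shape[-1] % block_len != 0:
        return False
    if math.prod(shape[:-1]) % block_len != 0:
        return False
    return True

print(is_quantizable_shape((256, 384)))  # True
print(is_quantizable_shape((256, 100)))  # False: last dim not a multiple of 128
```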
@@ -233,44 +212,36 @@ class Float8BlockQuantizer(Quantizer):
         pin_memory: bool = False,
     ) -> Float8BlockwiseQTensor:
         """Construct quantized tensor with uninitialized data"""
-        if device is None:
-            device = torch.device("cuda")
-        data_format = (
-            tex.Float8BlockScaleTensorFormat.COMPACT
-            if self.all_gather_usage
-            else tex.Float8BlockScaleTensorFormat.GEMM_READY
-        )
+        tensor_kwargs = {
+            "device": torch.device("cuda") if device is None else device,
+            "pin_memory": pin_memory,
+        }
 
-        # Allocate FP8 data
-        data = None
-        scale_inv = None
+        # Allocate buffers for row-scaled data
+        rowwise_data = None
+        rowwise_scale_inv = None
         if self.rowwise_usage:
-            data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory)
-            scale_shape = self.get_scale_shape(shape, columnwise=False)
-            scale_inv = torch.empty(
-                scale_shape,
+            rowwise_data = torch.empty(shape, dtype=torch.uint8, **tensor_kwargs)
+            rowwise_scale_inv = torch.empty(
+                self.get_scale_shape(shape, columnwise=False),
                 dtype=torch.float32,
-                device=device,
-                pin_memory=pin_memory,
+                **tensor_kwargs,
             )
 
-        # Allocate FP8 data transpose if needed
+        # Allocate buffers for column-scaled data
         columnwise_data = None
         columnwise_scale_inv = None
         if self.columnwise_usage:
             columnwise_data = torch.empty(
                 self.get_columnwise_shape(shape),
                 dtype=torch.uint8,
-                device=device,
-                pin_memory=pin_memory,
+                **tensor_kwargs,
             )
-            columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True)
             columnwise_scale_inv = torch.empty(
-                columnwise_scale_shape,
+                self.get_scale_shape(shape, columnwise=True),
                 dtype=torch.float32,
-                device=device,
-                pin_memory=pin_memory,
+                **tensor_kwargs,
             )
 
         # Construct FP8 tensor
@@ -278,13 +249,12 @@ class Float8BlockQuantizer(Quantizer):
             shape=shape,
             dtype=dtype,
             fp8_dtype=self.dtype,
-            rowwise_data=data,
-            rowwise_scale_inv=scale_inv,
+            rowwise_data=rowwise_data,
+            rowwise_scale_inv=rowwise_scale_inv,
             columnwise_data=columnwise_data,
             columnwise_scale_inv=columnwise_scale_inv,
             quantizer=self,
             is_2D_scaled=self.block_scaling_dim == 2,
-            data_format=data_format,
             requires_grad=requires_grad,
         )
@@ -334,7 +304,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
         fp8_dtype: TE_DType,
         quantizer: Quantizer,
         is_2D_scaled: bool,
-        data_format: tex.Float8BlockScaleTensorFormat = Float8BlockScaleTensorFormat.GEMM_READY,
         **kwargs,
     ):
         instance = super().__new__(
@@ -346,7 +315,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
             fp8_dtype,
             quantizer,
             is_2D_scaled,
-            data_format,
             *args,
             **kwargs,
         )
@@ -357,8 +325,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
         return (
             f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype},"
             f" is_2D_scaled={self._is_2D_scaled},"
-            f" data={self.dequantize(dtype=self.dtype)}),"
-            f" data_format={self._data_format}"
+            f" data={self.dequantize(dtype=self.dtype)})"
         )
 
     def quantize_(
@@ -509,7 +476,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
         dtype: torch.dtype,
         quantizer: Quantizer,
         is_2D_scaled: bool,
-        data_format: tex.Float8BlockScaleTensorFormat,
+        data_format: Any = None,  # pylint: disable=unused-argument
     ) -> Float8BlockwiseQTensor:
         """Build Float8BlockwiseQTensor, for use in __reduce__
@@ -527,7 +494,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
             dtype=dtype,
             quantizer=quantizer,
             is_2D_scaled=is_2D_scaled,
-            data_format=data_format,
         )
 
     def __reduce_ex__(self, protocol: int) -> tuple:
@@ -544,7 +510,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
                 self.dtype,
                 self._quantizer,
                 self._is_2D_scaled,
-                self._data_format,
+                None,  # data_format
             ),
         )
@@ -570,7 +536,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
         dst._fp8_dtype = src._fp8_dtype
         dst._rowwise_scale_inv = src._rowwise_scale_inv
         dst._columnwise_scale_inv = src._columnwise_scale_inv
-        dst._data_format = src._data_format
 
         # Check that tensor dimensions match
         if (
@@ -618,13 +583,6 @@ class _ViewFunc(torch.autograd.Function):
     ) -> Float8BlockwiseQTensor:
         # pylint: disable=missing-function-docstring
 
-        # Check for invalid configurations
-        if not tensor._is_gemm_ready_format():
-            raise NotImplementedError(
-                "View is only supported with GEMM_READY data format, "
-                f"but found data_format={tensor._data_format}"
-            )
-
         # Return input tensor if shape is not provided
         ctx.shape = tensor.shape
         if shape is None:
@@ -693,14 +651,6 @@ class _ViewFunc(torch.autograd.Function):
         # pylint: disable=missing-function-docstring
         if isinstance(grad, Float8BlockwiseQTensor):
 
-            # Check for invalid configurations
-            if not grad._is_gemm_ready_format():
-                raise NotImplementedError(
-                    "View is only supported with GEMM_READY data format, "
-                    f"but found data_format={grad._data_format}"
-                )
-
             new_data = (
                 grad._rowwise_data.view(*ctx.shape) if grad._rowwise_data is not None else None
             )
@@ -740,13 +690,6 @@ class _ReshapeFunc(torch.autograd.Function):
     ) -> Float8BlockwiseQTensor:
         # pylint: disable=missing-function-docstring
 
-        # Check for invalid configurations
-        if not tensor._is_gemm_ready_format():
-            raise NotImplementedError(
-                "Reshape is only supported with GEMM_READY data format, "
-                f"but found data_format={tensor._data_format}"
-            )
-
         # Return input tensor if shape is not provided
         ctx.shape = tensor.shape
         if shape is None:
@@ -814,14 +757,6 @@ class _ReshapeFunc(torch.autograd.Function):
         # pylint: disable=missing-function-docstring
         if isinstance(grad, Float8BlockwiseQTensor):
 
-            # Check for invalid configurations
-            if not grad._is_gemm_ready_format():
-                raise NotImplementedError(
-                    "Reshape is only supported with GEMM_READY data format, "
-                    f"but found data_format={grad._data_format}"
-                )
-
             new_rowwise_data = None
             new_columnwise_data = None
             if grad._rowwise_data is not None:
...
@@ -293,6 +293,7 @@ class Float8CurrentScalingQuantizer(Quantizer):
             amax=self.amax,
         )
         quantizer.internal = self.internal
+        quantizer.optimize_for_gemm = self.optimize_for_gemm
         return quantizer
...
@@ -54,6 +54,7 @@ class MXFP8Quantizer(Quantizer):
             columnwise=self.columnwise_usage,
         )
         quantizer.internal = self.internal
+        quantizer.optimize_for_gemm = self.optimize_for_gemm
         return quantizer
@@ -156,6 +157,7 @@ class MXFP8Quantizer(Quantizer):
             columnwise_scale_inv=columnwise_scale_inv,
             quantizer=self,
             requires_grad=requires_grad,
+            with_gemm_swizzled_scales=self.optimize_for_gemm,
         )
 
     def calibrate(self, tensor: torch.Tensor) -> None:
@@ -179,6 +181,7 @@ class MXFP8Quantizer(Quantizer):
             columnwise_scale_inv=None,
             fp8_dtype=fp8_dtype,
             quantizer=self,
+            with_gemm_swizzled_scales=False,
         )
 
     def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor:
@@ -188,6 +191,10 @@ class MXFP8Quantizer(Quantizer):
         return self.create_tensor_from_data(data, scale_inv, fake_dtype=torch.float32)
 
     def onnx_dequantize(self, tensor: Union[MXFP8TensorStorage, MXFP8Tensor]) -> torch.Tensor:
+        if tensor._with_gemm_swizzled_scales:
+            raise NotImplementedError(
+                "ONNX MXFP8 dequantization is only supported with scales in compact format."
+            )
         return torch.ops.tex.mxfp8_dequantize(tensor._rowwise_data, tensor._rowwise_scale_inv)
 
     def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
@@ -229,9 +236,10 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
         columnwise_scale_inv: Optional[torch.Tensor],
         fp8_dtype: TE_DType,
         quantizer: Optional[Quantizer],
+        with_gemm_swizzled_scales: bool,
         **kwargs,
     ):
-        instance = super().__new__(
+        return super().__new__(
             cls,
             rowwise_data,
             rowwise_scale_inv,
@@ -239,10 +247,10 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
             columnwise_scale_inv,
             fp8_dtype,
             quantizer,
+            with_gemm_swizzled_scales,
             *args,
             **kwargs,
         )
-        return instance
 
     def __repr__(self, *, tensor_contents=None):
         return f"MXFP8Tensor(fp8_dtype={self._fp8_dtype}, data={self.dequantize(dtype=self.dtype)})"
@@ -334,39 +342,44 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
 
-        # View op
         if func == aten.view.default:
             tensor = args[0]
-            data = tensor._rowwise_data
-            out_data = data.__torch_dispatch__(
-                func,
-                types,
-                [data] + list(args[1:]),
-                kwargs,
-            )
-            out_shape = out_data.size()
+            shape = args[1]
+            if len(shape) < 2 or shape[-1] != tensor.size(-1):
+                raise ValueError(
+                    f"Attempted to make view with size={tuple(shape)} "
+                    f"from MXFP8 tensor with shape={tuple(tensor.size())}."
+                )
+            rowwise_data_view = None
+            columnwise_data_view = None
+            if tensor._rowwise_data is not None:
+                rowwise_data_view = tensor._rowwise_data.view(shape)
+            if tensor._columnwise_data is not None:
+                columnwise_data_view = tensor._columnwise_data.view(shape)
             return MXFP8Tensor(
-                shape=out_shape,
+                shape=shape,
                 dtype=tensor.dtype,
-                rowwise_data=out_data,
+                rowwise_data=rowwise_data_view,
                 rowwise_scale_inv=tensor._rowwise_scale_inv,
-                columnwise_data=tensor._columnwise_data,
+                columnwise_data=columnwise_data_view,
                 columnwise_scale_inv=tensor._columnwise_scale_inv,
                 quantizer=tensor._quantizer,
                 requires_grad=False,
                 fp8_dtype=tensor._fp8_dtype,
+                with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales,
            )
         if func == torch.ops.aten.copy_.default:
             dst, src = args[0], args[1]
             if isinstance(src, MXFP8Tensor) and isinstance(dst, MXFP8Tensor):
-                # Booleans to check if src has all the usages that dst needs to respect dst quantizer usages.
-                # If not, default to base class behavior.
-                rowwise_matches = src._rowwise_data is not None or dst._rowwise_data is None
-                columnwise_matches = (
-                    src._columnwise_data is not None or dst._columnwise_data is None
-                )
-                if rowwise_matches and columnwise_matches:
-                    # src and dst match, so we can directly copy data
+                if src._rowwise_data is None and dst._rowwise_data is not None:
+                    pass
+                elif src._columnwise_data is None and dst._columnwise_data is not None:
+                    pass
+                elif src._with_gemm_swizzled_scales != dst._with_gemm_swizzled_scales:
+                    pass
+                else:
                     if dst._rowwise_data is not None:
                         dst._rowwise_data.copy_(src._rowwise_data.detach(), *args[2:], **kwargs)
                         dst._rowwise_scale_inv.copy_(
@@ -381,26 +394,25 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
                         )
             return dst
        if func == aten.split.Tensor:
            # With FSDP2, this is called if the entire model is
            # initialized on a CUDA device and then split. Finally
            # the shard needed by the process is used and the
            # other shards are discarded.
            tensor = args[0]
            split_size = args[1]
            if "dim" in kwargs:
                dim_to_split = kwargs["dim"]
            else:
                dim_to_split = args[2] if len(args) > 2 else 0

            # Fall back to high precision if the split is non-trivial
            if (
                dim_to_split != 0
                or tensor.size(0) % split_size != 0
                or split_size % MXFP8_BLOCK_SCALING_SIZE != 0
                or tensor._with_gemm_swizzled_scales
            ):
                return super().__torch_dispatch__(func, types, args, kwargs)
            out_data = []
...@@ -460,27 +472,25 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
                    quantizer=tensor._quantizer,
                    requires_grad=False,
                    fp8_dtype=tensor._fp8_dtype,
                    with_gemm_swizzled_scales=False,
                )
                for splitted_tensor_data in zip(*out_data)
            ]
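The fallback condition above can be summarized as: the split must run along dim 0, produce equal-sized shards whose edges land on block-scaling boundaries, and the scales must still be in compact (non-swizzled) format. A sketch with a hypothetical `split_stays_quantized` helper; 32 is the MXFP8 block size implied by `MXFP8_BLOCK_SCALING_SIZE` (one scale per 32 contiguous values):

```python
MXFP8_BLOCK_SCALING_SIZE = 32  # one shared scale per 32-element block


def split_stays_quantized(
    dim0_size: int, split_size: int, dim_to_split: int, swizzled_scales: bool
) -> bool:
    """Mirror of the fallback test: True when aten.split can operate
    directly on the FP8 buffers instead of dequantizing first."""
    return (
        dim_to_split == 0  # only dim-0 splits slice data and scales cheaply
        and dim0_size % split_size == 0  # equal-sized shards
        and split_size % MXFP8_BLOCK_SCALING_SIZE == 0  # shard edges on block boundaries
        and not swizzled_scales  # swizzled layout cannot be sliced row-wise
    )
```

When any condition fails, the real implementation dequantizes, splits the high-precision tensor, and discards the unneeded shards.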
        if func == torch.ops.aten.as_strided.default:
            # Applied on the unsharded param in FSDP2. In our case, this should be a no-op.
            # This is needed for the case where some MXFP8 shards need padding, i.e. dimension 0
            # of the unsharded param is not a multiple of the world size. If that is the case,
            # we go down the dequantization route and weights are allgathered in high precision.
            # If the weight doesn't need padding, this is just a no-op.
            tensor = args[0]
            shape = args[1]
            strides = args[2]
            if (
                len(shape) == len(strides) == 2
                and tuple(strides) == (shape[-1], 1)
                and tuple(shape) == tuple(tensor.size())
            ):
                return MXFP8Tensor.make_like(tensor)
            return super().__torch_dispatch__(func, types, args, kwargs)
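The no-op test above asks whether the requested view is exactly the tensor's own 2-D shape with plain row-major strides (row stride equals the number of columns, column stride equals 1). A standalone sketch of the same predicate, with a hypothetical `is_noop_as_strided` name:

```python
def is_noop_as_strided(cur_shape, new_shape, new_strides) -> bool:
    """True when an aten.as_strided call would reproduce the tensor as-is:
    2-D, contiguous row-major strides, and the shape is unchanged."""
    return (
        len(new_shape) == len(new_strides) == 2
        and tuple(new_strides) == (new_shape[-1], 1)  # contiguous row-major
        and tuple(new_shape) == tuple(cur_shape)  # same shape as the tensor
    )
```

Anything else (transposed strides, a different shape, or a non-2-D request) cannot be expressed on the packed FP8 buffers, so the real code falls back to the base-class path.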
        if func == aten.slice.Tensor:
...@@ -489,18 +499,11 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
            # of the unsharded param is not a multiple of the world size. If that is the case,
            # we go down the dequantization route and weights are allgathered in high precision instead.
            # If the sharded weight doesn't have padding, this is just a no-op.
            tensor = args[0]
            dim = args[1]
            start = args[2]
            length = args[3]
            if start == 0 and length == tensor.size(dim):
                return MXFP8Tensor.make_like(tensor)
            return super().__torch_dispatch__(func, types, args, kwargs)
        if func == aten.new_zeros.default:
...@@ -558,7 +561,9 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
                quantizer=tensor._quantizer,
                requires_grad=False,
                fp8_dtype=tensor._fp8_dtype,
                with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales,
            )

        # Default case
        return super().__torch_dispatch__(func, types, args, kwargs)
...@@ -584,19 +589,24 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
        # pylint: disable=unused-argument
        from transformer_engine.pytorch.distributed import _get_module_fsdp_state

        fsdp_state = _get_module_fsdp_state(module)
        reshard_after_forward = fsdp_state._fsdp_param_group._reshard_after_forward

        # Remove padding from scale inverses before allgather
        # Rowwise scale_inv should be divisible by [128, 4], columnwise by [4, 128]
        rowwise_scale_inv = self._rowwise_scale_inv
        columnwise_scale_inv = self._columnwise_scale_inv
        shape = self.shape
        if self._with_gemm_swizzled_scales:
            raise NotImplementedError(
                "FSDP2 is only supported for MXFP8Tensors with compact scales"
            )
        if rowwise_scale_inv is not None:
            # Remove padding from rowwise scale_inv
            flattened_in_shape0 = math.prod(shape[:-1])
            if rowwise_scale_inv.size(0) != flattened_in_shape0:
                rowwise_scale_inv = rowwise_scale_inv[:flattened_in_shape0]
        if columnwise_scale_inv is not None:
            # Remove padding from columnwise scale_inv
            flattened_in_shape0 = math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE
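Per the comment above, scale-inverse tensors are stored padded so each dimension is a multiple of a 128x4 tile (4x128 for the column-wise copy), and the padding rows are trimmed before the all-gather. A sketch of that pad/trim arithmetic with hypothetical helpers `padded_scale_shape` and `trim_rows` (plain lists stand in for the scale tensors):

```python
import math


def padded_scale_shape(rows: int, cols: int, tile=(128, 4)) -> tuple:
    """Round each scale-inverse dimension up to the tile multiple,
    mirroring the [128, 4] divisibility noted in the source comment."""
    return (
        math.ceil(rows / tile[0]) * tile[0],
        math.ceil(cols / tile[1]) * tile[1],
    )


def trim_rows(scale_inv, n_rows: int):
    """Drop padding rows before communication, as scale_inv[:n_rows] does."""
    return scale_inv[:n_rows] if len(scale_inv) != n_rows else scale_inv
```

The swizzled layout cannot be trimmed this way, which is consistent with the `NotImplementedError` raised above for tensors carrying GEMM-swizzled scales.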
...@@ -681,7 +691,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
            out._columnwise_data = columnwise_data
            out._columnwise_scale_inv = columnwise_scale_inv
        else:
            # We'll be here when post all-gather is called for the first time.
            # MXFP8Tensor constructor makes a copy of the quantizer to
            # save as its own quantizer. For the subsequent iterations,
            # the same quantizer is used. Copy is needed in the first iteration,
...@@ -696,6 +706,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
                dtype=param_dtype,
                shape=rowwise_data.shape if rowwise_data is not None else columnwise_data.shape,
                quantizer=self._quantizer,
                with_gemm_swizzled_scales=False,
            )
        out._quantizer.set_usage(rowwise=rowwise_usage, columnwise=columnwise_usage)
        return out, all_gather_outputs
...@@ -711,6 +722,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
        dtype: torch.dtype,
        shape: torch.shape,
        quantizer: Optional[Quantizer] = None,
        with_gemm_swizzled_scales: bool = False,
    ) -> MXFP8Tensor:
        """Build MXFP8Tensor, for use in __reduce__
...@@ -727,6 +739,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
            dtype=dtype,
            shape=shape,
            quantizer=quantizer,
            with_gemm_swizzled_scales=with_gemm_swizzled_scales,
        )

    def __reduce_ex__(self, protocol: int) -> tuple:
...@@ -742,6 +755,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
                self.dtype,
                self.shape,
                self._quantizer,
                self._with_gemm_swizzled_scales,
            ),
        )
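The `__reduce_ex__` change above threads the new swizzle flag through serialization: pickle receives a rebuild callable plus a state tuple, and adding `_with_gemm_swizzled_scales` to that tuple is all it takes for the flag to survive a save/load. A toy model of the same pattern, using hypothetical `Blob`/`_rebuild_blob` names and invoking the reduce pair directly rather than going through the `pickle` module:

```python
def _rebuild_blob(data, with_gemm_swizzled_scales):
    # Module-level rebuild function, as pickle requires the callable
    # in the reduce tuple to be importable by name.
    return Blob(data, with_gemm_swizzled_scales=with_gemm_swizzled_scales)


class Blob:
    """Toy analogue of the pattern above: serialize as
    (rebuild_callable, state_tuple) so new fields travel automatically."""

    def __init__(self, data, with_gemm_swizzled_scales=False):
        self.data = data
        self.with_gemm_swizzled_scales = with_gemm_swizzled_scales

    def __reduce_ex__(self, protocol):
        return (_rebuild_blob, (self.data, self.with_gemm_swizzled_scales))


# What pickle.loads(pickle.dumps(...)) would do under the hood:
rebuild, state = Blob([1, 2], with_gemm_swizzled_scales=True).__reduce_ex__(2)
clone = rebuild(*state)
```

Forgetting to extend both the state tuple and the rebuild signature is the classic failure mode here; the two edits in the diff (to `_make_in_reduce_ex` and `__reduce_ex__`) are the matching pair.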
...@@ -763,7 +777,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
        if not devices_match(new_device, tensor.device):
            tensor = tensor.to(device=new_device)

        # Just copy data if other tensor is MXFP8Tensor
        if isinstance(tensor, MXFP8Tensor):
            if (  # pylint: disable=too-many-boolean-expressions
                self.size() != tensor.size()
...@@ -791,6 +805,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
                self._fp8_dtype = tensor._fp8_dtype
                self._rowwise_scale_inv = tensor._rowwise_scale_inv
                self._columnwise_scale_inv = tensor._columnwise_scale_inv
                self._with_gemm_swizzled_scales = tensor._with_gemm_swizzled_scales
                return
        # Quantize to FP8
...@@ -862,6 +877,7 @@ class _ViewFunc(torch.autograd.Function):
            columnwise_scale_inv=tensor._columnwise_scale_inv,
            fp8_dtype=tensor._fp8_dtype,
            quantizer=tensor._quantizer,
            with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales,
        )

    @staticmethod
...@@ -888,6 +904,7 @@ class _ViewFunc(torch.autograd.Function):
                columnwise_scale_inv=grad._columnwise_scale_inv,
                fp8_dtype=grad._fp8_dtype,
                quantizer=grad._quantizer,
                with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales,
            )
            return dgrad, None
        return grad.view(ctx.shape), None
...@@ -948,6 +965,7 @@ class _ReshapeFunc(torch.autograd.Function):
            columnwise_scale_inv=tensor._columnwise_scale_inv,
            fp8_dtype=tensor._fp8_dtype,
            quantizer=tensor._quantizer,
            with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales,
        )

    @staticmethod
...@@ -973,6 +991,7 @@ class _ReshapeFunc(torch.autograd.Function):
                columnwise_scale_inv=grad._columnwise_scale_inv,
                fp8_dtype=grad._fp8_dtype,
                quantizer=grad._quantizer,
                with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales,
            )
            return dgrad, None
        return grad.view(ctx.shape), None