Unverified Commit 99df8810 authored by Tim Moon, committed by GitHub

Add logic for block-scaled tensors with GEMM swizzled scales (#2486)



* Add general C API for setting tensor params
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Implement general accessors for NVTETensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Refactor tex swizzling to skip if scales are already swizzled
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add checks for non-swizzled scales in MXFP8 and NVFP4 kernels
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Support pre-swizzled scales in MXFP8Tensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add tex function to swizzle MXFP8 scales
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix bug in inplace swizzle function
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Tweak comments to use "compact/swizzled format"
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* MXFP8 quantize kernel with pre-swizzled scales
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Expose pre-swizzled scales in modules
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix bug in multi-swizzle
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Support MXFP8 gated activations with swizzled scales
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add PyTorch infrastructure for pre-swizzled NVFP4 tensors
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Deprecate DSv3-specific quantization logic in C API
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Remove support for DSv3 compact data from quantizer
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove DSv3 compact data format from core lib
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix bug in FP8 all-gather
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update JAX to use new swizzled scale API
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Review suggestion from @greptile-apps
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestions from @greptile-apps
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update C++ swizzle test with swizzled scales API
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Return default tensor params when querying params for invalid NVTETensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug DSv3 FP8 test failures
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug Userbuffers test failures
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Make sure gated activations populate FP8 transpose if needed
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Review suggestions from @greptile-apps
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Disable pre-swizzling with debug quantizer
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestion from @greptile-apps
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix merge conflicts and review suggestions

Update copyright years. Tweak comments. Fix various complaints from @greptile-apps.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use explicitly sized types in config accessors

Miscellaneous review suggestions from @ptrendx.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Make util header for function that computes swizzled scale index
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Apply suggestions from @greptile-apps
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Update expected error message in FP8 block-scaling test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestion from @yaox12
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
parent a652730f
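
Illustrative usage (a hedged sketch, not code from this PR; the MXFP8Quantizer constructor and make_empty arguments below are assumptions):

import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer

# Assumed constructor arguments; the real signature may differ.
quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)

# New flag from this PR: produce scales already swizzled into the layout the
# GEMM expects, so no separate swizzle pass is needed before the matmul.
quantizer.optimize_for_gemm = True

# make_empty propagates the flag to the tensor (see the MXFP8Quantizer hunk
# below, which passes with_gemm_swizzled_scales=self.optimize_for_gemm).
t = quantizer.make_empty((256, 1024), dtype=torch.bfloat16, device=torch.device("cuda"))
assert t._with_gemm_swizzled_scales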
......@@ -199,10 +199,21 @@ class Quantizer(abc.ABC):
"""
internal: bool
"""Whether to solely optimize for matrix multiplication
The resulting quantized tensors are not guaranteed to support any
operation other than matrix multiplication. Use with care since
this is likely to break communication, checkpointing, and many
other features.
"""
optimize_for_gemm: bool
def __init__(self, *, rowwise: bool, columnwise: bool) -> None:
self.rowwise_usage = rowwise
self.columnwise_usage = columnwise
self.internal = False
self.optimize_for_gemm = False
def __repr__(self):
return (
......@@ -314,7 +325,11 @@ class Quantizer(abc.ABC):
return False
def is_quantizable(self, inp: torch.Tensor) -> bool: # pylint: disable=unused-argument
"""Returns whether or not given tensor can be quantized"""
"""Whether tensor supports quantized all-gather
Consider a less misleading function name.
"""
return True
def get_usages(self) -> Dict[str, bool]:
......
......@@ -4,14 +4,14 @@
"""Tensor class with FP8 data quantized with NxN tiles"""
from __future__ import annotations
from typing import Optional, Tuple, Iterable, Union
from collections.abc import Iterable
import math
from typing import Any, Optional, Tuple, Union
import torch
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from transformer_engine_torch import Float8BlockScaleTensorFormat
from transformer_engine.common.recipe import Float8BlockScaling, Recipe
from .storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from ..quantized_tensor import QuantizedTensor, Quantizer
......@@ -35,8 +35,6 @@ class Float8BlockQuantizer(Quantizer):
amax_epsilon: float
force_pow_2_scales: bool
block_scaling_dim: int
# Whether to produce tensors that will be used in all-gather
all_gather_usage: bool
def __init__(
self,
......@@ -47,7 +45,6 @@ class Float8BlockQuantizer(Quantizer):
amax_epsilon: float = 0.0,
force_pow_2_scales: bool = True,
block_scaling_dim: int = 2,
all_gather_usage: bool = False,
) -> None:
super().__init__(rowwise=rowwise, columnwise=columnwise)
self.dtype = fp8_dtype
......@@ -55,7 +52,6 @@ class Float8BlockQuantizer(Quantizer):
self.force_pow_2_scales = force_pow_2_scales
self.amax_epsilon = amax_epsilon
self.block_scaling_dim = block_scaling_dim
self.all_gather_usage = all_gather_usage
def copy(self) -> Float8BlockQuantizer:
"""Create shallow copy"""
......@@ -65,11 +61,11 @@ class Float8BlockQuantizer(Quantizer):
rowwise=self.rowwise_usage,
columnwise=self.columnwise_usage,
block_scaling_dim=self.block_scaling_dim,
all_gather_usage=self.all_gather_usage,
amax_epsilon=self.amax_epsilon,
force_pow_2_scales=self.force_pow_2_scales,
)
quantizer.internal = self.internal
quantizer.optimize_for_gemm = self.optimize_for_gemm
return quantizer
......@@ -123,103 +119,86 @@ class Float8BlockQuantizer(Quantizer):
return tex.quantize(tensor, self)
def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, int]:
"""Calculate the shape of the scaling tensor for blockwise quantization.
"""Scaling tensor shape.
This method determines the shape of the scaling tensor needed for blockwise quantization,
taking into account the input tensor shape and whether columnwise scaling is used.
The scales are padded to multiples of 4 on the inner dimension for compatibility with GEMM.
This method determines the shape of the scaling tensor based
on the quantizer configuration. The scales are padded to
multiples of 4 for compatibility with GEMM.
Parameters
----------
shape : Iterable[int]
Shape of the input tensor to be quantized
Logical tensor shape.
columnwise : bool
Whether to use columnwise scaling (True) or rowwise scaling (False)
Whether the data is scaled column-wise (True) or row-wise (False).
Returns
-------
Tuple[int, int]
Shape of the scaling tensor as (outer_dim, inner_dim)
For 2D tensors:
- If columnwise: (roundup(K/blocksize), round_to_multiple(roundup(M/blocksize), 4))
- If rowwise: (roundup(M/blocksize), round_to_multiple(roundup(K/blocksize), 4))
For 1D tensors:
- If columnwise: (roundup(M/blocksize), round_to_multiple(K, 4))
- If rowwise: (roundup(K/blocksize), round_to_multiple(M, 4))
Scaling tensor shape.
"""
M, K = 1, 1
for i in range(len(shape) - 1):
M *= shape[i]
if len(shape) > 0:
K = shape[-1]
# 2D 128x128 quantization block scaling
# CuBLAS requires 128x128 scaling factor to be padded
# currently rowwise and columnwise format option doesn't apply to 2D scaling
# Flatten tensor to 2D
dim0 = math.prod(shape[:-1])
dim1 = shape[-1] if shape else 1
# Check block dims
if self.block_scaling_dim not in (1, 2):
raise RuntimeError(
"Only 1D or 2D blocks are supported, "
f"but got block_scaling_dim={self.block_scaling_dim}"
)
# 128x128 block scaling
if self.block_scaling_dim == 2:
scale_dim0 = (dim0 + self.block_len - 1) // self.block_len
scale_dim1 = (dim1 + self.block_len - 1) // self.block_len
if columnwise:
outer = math.ceil(K / self.block_len)
inner = round_up_to_nearest_multiple(math.ceil(M / self.block_len), 4)
return (outer, inner)
# rowwise
outer = math.ceil(M / self.block_len)
inner = round_up_to_nearest_multiple(math.ceil(K / self.block_len), 4)
return (outer, inner)
# 1D 1x128 quantization block scaling
# CuBLAS requires 1x128 scaling factor to be padded and transposed
assert self.block_scaling_dim == 1, "Only 1D or 2D blocks supported"
return (scale_dim1, round_up_to_nearest_multiple(scale_dim0, 4))
return (scale_dim0, round_up_to_nearest_multiple(scale_dim1, 4))
# 1x128 block scaling
if columnwise:
columnwise_compact = self.all_gather_usage
outer = math.ceil(M / self.block_len)
inner = round_up_to_nearest_multiple(K, 4) if not columnwise_compact else K
# GEMM READY case: scaling factor is [outer, inner], already transposed here for CuBLAS
# for COMPACT case, since we apply 1x128 scaling here without transposing columnwise data, scaling factor is also [outer, inner]
# so no need to swap inner outer here
return (outer, inner)
# rowwise
rowwise_compact = self.all_gather_usage
outer = math.ceil(K / self.block_len)
inner = round_up_to_nearest_multiple(M, 4) if not rowwise_compact else M
# GEMM READY case: scaling factor is [outer, inner], already transposed here for CuBLAS need
# for COMPACT case, since we apply 128x1 scaling, scaling block applies to inner dim, so we need to swap outer and inner here
return (outer, inner) if not rowwise_compact else (inner, outer)
return (
(dim0 + self.block_len - 1) // self.block_len,
round_up_to_nearest_multiple(dim1, 4),
)
return (
(dim1 + self.block_len - 1) // self.block_len,
round_up_to_nearest_multiple(dim0, 4),
)
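
For a concrete picture of the shapes above, a standalone sketch (block_len is assumed to be 128, matching the 128x128 and 1x128 block sizes):

import math

def round_up_to_nearest_multiple(x: int, m: int) -> int:
    return ((x + m - 1) // m) * m

def scale_shape(shape, block_scaling_dim, columnwise, block_len=128):
    """Mirror of the logic above on a plain shape tuple."""
    dim0 = math.prod(shape[:-1])
    dim1 = shape[-1] if shape else 1
    if block_scaling_dim == 2:  # 128x128 blocks
        s0 = (dim0 + block_len - 1) // block_len
        s1 = (dim1 + block_len - 1) // block_len
        if columnwise:
            return (s1, round_up_to_nearest_multiple(s0, 4))
        return (s0, round_up_to_nearest_multiple(s1, 4))
    # 1x128 blocks
    if columnwise:
        return ((dim0 + block_len - 1) // block_len, round_up_to_nearest_multiple(dim1, 4))
    return ((dim1 + block_len - 1) // block_len, round_up_to_nearest_multiple(dim0, 4))

# For a (256, 1024) tensor:
assert scale_shape((256, 1024), 2, columnwise=False) == (2, 8)
assert scale_shape((256, 1024), 2, columnwise=True) == (8, 4)
assert scale_shape((256, 1024), 1, columnwise=False) == (8, 256)
assert scale_shape((256, 1024), 1, columnwise=True) == (2, 1024)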
def get_columnwise_shape(self, shape: Iterable[int]) -> Tuple[int, ...]:
"""Calculate the shape of a tensor after columnwise permutation.
"""Column-wise data shape
This method rearranges the dimensions of a tensor to be columnwise,
moving the last dimension to the front and keeping the order of other dimensions.
GEMMs expect that the column-wise data is transposed relative
to the logical tensor shape.
Parameters
----------
shape : Iterable[int]
Original shape of the tensor
Logical tensor shape.
Returns
-------
Tuple[int, ...]
New shape with dimensions rearranged for columnwise layout.
For a shape (d1, d2, ..., dn), returns (dn, d1, d2, ..., dn-1).
Returns empty tuple for empty input shape.
Column-wise data shape.
"""
if len(shape) == 0:
return tuple()
# currently columnwise format option only applies to 1D quantizer
# for 2D scaling, columnwise format should always be GEMM_READY_DATA_AND_SCALES
# since currently 2D scaling only applies to module weights
if self.block_scaling_dim == 1 and self.all_gather_usage:
return shape
colwise_shape = [shape[-1]]
for i in range(len(shape) - 1):
colwise_shape.append(shape[i])
colwise_shape = []
if shape:
colwise_shape.append(shape[-1])
colwise_shape.extend(shape[:-1])
return tuple(colwise_shape)
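
A minimal sketch of the permutation with a concrete shape:

def columnwise_shape(shape):
    """Move the last dimension to the front; keep the rest in order."""
    return (shape[-1], *shape[:-1]) if shape else ()

assert columnwise_shape((8, 16, 32)) == (32, 8, 16)
assert columnwise_shape(()) == ()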
def is_quantizable(self, inp: torch.Tensor) -> bool:
"""Returns whether or not given inp can be quantized"""
if inp.ndim < 2:
shape = inp.size()
if len(shape) < 2:
return False
if inp.shape[-1] % self.block_len != 0:
if shape[-1] % self.block_len != 0:
return False
if math.prod(inp.shape[:-1]) % self.block_len != 0:
if math.prod(shape[:-1]) % self.block_len != 0:
return False
return True
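
The divisibility requirement, illustrated (block_len is assumed to be 128):

import math

def shape_is_quantizable(shape, block_len=128):
    """Mirror of the check above on a plain shape tuple."""
    if len(shape) < 2:
        return False
    if shape[-1] % block_len != 0:
        return False
    return math.prod(shape[:-1]) % block_len == 0

assert shape_is_quantizable((256, 1024))
assert not shape_is_quantizable((100, 1024))  # 100 is not a multiple of 128
assert not shape_is_quantizable((1024,))      # needs at least 2 dimensions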
......@@ -233,44 +212,36 @@ class Float8BlockQuantizer(Quantizer):
pin_memory: bool = False,
) -> Float8BlockwiseQTensor:
"""Construct quantized tensor with uninitialized data"""
if device is None:
device = torch.device("cuda")
data_format = (
tex.Float8BlockScaleTensorFormat.COMPACT
if self.all_gather_usage
else tex.Float8BlockScaleTensorFormat.GEMM_READY
)
tensor_kwargs = {
"device": torch.device("cuda") if device is None else device,
"pin_memory": pin_memory,
}
# Allocate FP8 data
data = None
scale_inv = None
# Allocate buffers for row-scaled data
rowwise_data = None
rowwise_scale_inv = None
if self.rowwise_usage:
data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory)
scale_shape = self.get_scale_shape(shape, columnwise=False)
scale_inv = torch.empty(
scale_shape,
rowwise_data = torch.empty(shape, dtype=torch.uint8, **tensor_kwargs)
rowwise_scale_inv = torch.empty(
self.get_scale_shape(shape, columnwise=False),
dtype=torch.float32,
device=device,
pin_memory=pin_memory,
**tensor_kwargs,
)
# Allocate FP8 data transpose if needed
# Allocate buffers for column-scaled data
columnwise_data = None
columnwise_scale_inv = None
if self.columnwise_usage:
columnwise_data = torch.empty(
self.get_columnwise_shape(shape),
dtype=torch.uint8,
device=device,
pin_memory=pin_memory,
**tensor_kwargs,
)
columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True)
columnwise_scale_inv = torch.empty(
columnwise_scale_shape,
self.get_scale_shape(shape, columnwise=True),
dtype=torch.float32,
device=device,
pin_memory=pin_memory,
**tensor_kwargs,
)
# Construct FP8 tensor
......@@ -278,13 +249,12 @@ class Float8BlockQuantizer(Quantizer):
shape=shape,
dtype=dtype,
fp8_dtype=self.dtype,
rowwise_data=data,
rowwise_scale_inv=scale_inv,
rowwise_data=rowwise_data,
rowwise_scale_inv=rowwise_scale_inv,
columnwise_data=columnwise_data,
columnwise_scale_inv=columnwise_scale_inv,
quantizer=self,
is_2D_scaled=self.block_scaling_dim == 2,
data_format=data_format,
requires_grad=requires_grad,
)
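
For orientation, the buffers make_empty allocates for a (256, 1024) tensor with 1x128 scaling (block_scaling_dim=1) and both usages enabled, following get_scale_shape and get_columnwise_shape above:

# rowwise_data:         uint8,   shape (256, 1024)
# rowwise_scale_inv:    float32, shape (8, 256)
# columnwise_data:      uint8,   shape (1024, 256)
# columnwise_scale_inv: float32, shape (2, 1024)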
......@@ -334,7 +304,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
fp8_dtype: TE_DType,
quantizer: Quantizer,
is_2D_scaled: bool,
data_format: tex.Float8BlockScaleTensorFormat = Float8BlockScaleTensorFormat.GEMM_READY,
**kwargs,
):
instance = super().__new__(
......@@ -346,7 +315,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
fp8_dtype,
quantizer,
is_2D_scaled,
data_format,
*args,
**kwargs,
)
......@@ -357,8 +325,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
return (
f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype},"
f" is_2D_scaled={self._is_2D_scaled},"
f" data={self.dequantize(dtype=self.dtype)}),"
f" data_format={self._data_format}"
f" data={self.dequantize(dtype=self.dtype)})"
)
def quantize_(
......@@ -509,7 +476,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
dtype: torch.dtype,
quantizer: Quantizer,
is_2D_scaled: bool,
data_format: tex.Float8BlockScaleTensorFormat,
data_format: Any = None, # pylint: disable=unused-argument
) -> Float8BlockwiseQTensor:
"""Build Float8BlockwiseQTensor, for use in __reduce__
......@@ -527,7 +494,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
dtype=dtype,
quantizer=quantizer,
is_2D_scaled=is_2D_scaled,
data_format=data_format,
)
def __reduce_ex__(self, protocol: int) -> tuple:
......@@ -544,7 +510,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
self.dtype,
self._quantizer,
self._is_2D_scaled,
self._data_format,
None, # data_format
),
)
......@@ -570,7 +536,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
dst._fp8_dtype = src._fp8_dtype
dst._rowwise_scale_inv = src._rowwise_scale_inv
dst._columnwise_scale_inv = src._columnwise_scale_inv
dst._data_format = src._data_format
# Check that tensor dimensions match
if (
......@@ -618,13 +583,6 @@ class _ViewFunc(torch.autograd.Function):
) -> Float8BlockwiseQTensor:
# pylint: disable=missing-function-docstring
# Check for invalid configurations
if not tensor._is_gemm_ready_format():
raise NotImplementedError(
"View is only supported with GEMM_READY data format, "
f"but found data_format={tensor._data_format}"
)
# Return input tensor if shape is not provided
ctx.shape = tensor.shape
if shape is None:
......@@ -693,14 +651,6 @@ class _ViewFunc(torch.autograd.Function):
# pylint: disable=missing-function-docstring
if isinstance(grad, Float8BlockwiseQTensor):
# Check for invalid configurations
if not grad._is_gemm_ready_format():
raise NotImplementedError(
"View is only supported with GEMM_READY data format, "
f"but found data_format={grad._data_format}"
)
new_data = (
grad._rowwise_data.view(*ctx.shape) if grad._rowwise_data is not None else None
)
......@@ -740,13 +690,6 @@ class _ReshapeFunc(torch.autograd.Function):
) -> Float8BlockwiseQTensor:
# pylint: disable=missing-function-docstring
# Check for invalid configurations
if not tensor._is_gemm_ready_format():
raise NotImplementedError(
"Reshape is only supported with GEMM_READY data format, "
f"but found data_format={tensor._data_format}"
)
# Return input tensor if shape is not provided
ctx.shape = tensor.shape
if shape is None:
......@@ -814,14 +757,6 @@ class _ReshapeFunc(torch.autograd.Function):
# pylint: disable=missing-function-docstring
if isinstance(grad, Float8BlockwiseQTensor):
# Check for invalid configurations
if not grad._is_gemm_ready_format():
raise NotImplementedError(
"Reshape is only supported with GEMM_READY data format, "
f"but found data_format={grad._data_format}"
)
new_rowwise_data = None
new_columnwise_data = None
if grad._rowwise_data is not None:
......
......@@ -293,6 +293,7 @@ class Float8CurrentScalingQuantizer(Quantizer):
amax=self.amax,
)
quantizer.internal = self.internal
quantizer.optimize_for_gemm = self.optimize_for_gemm
return quantizer
......
......@@ -54,6 +54,7 @@ class MXFP8Quantizer(Quantizer):
columnwise=self.columnwise_usage,
)
quantizer.internal = self.internal
quantizer.optimize_for_gemm = self.optimize_for_gemm
return quantizer
......@@ -156,6 +157,7 @@ class MXFP8Quantizer(Quantizer):
columnwise_scale_inv=columnwise_scale_inv,
quantizer=self,
requires_grad=requires_grad,
with_gemm_swizzled_scales=self.optimize_for_gemm,
)
def calibrate(self, tensor: torch.Tensor) -> None:
......@@ -179,6 +181,7 @@ class MXFP8Quantizer(Quantizer):
columnwise_scale_inv=None,
fp8_dtype=fp8_dtype,
quantizer=self,
with_gemm_swizzled_scales=False,
)
def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor:
......@@ -188,6 +191,10 @@ class MXFP8Quantizer(Quantizer):
return self.create_tensor_from_data(data, scale_inv, fake_dtype=torch.float32)
def onnx_dequantize(self, tensor: Union[MXFP8TensorStorage, MXFP8Tensor]) -> torch.Tensor:
if tensor._with_gemm_swizzled_scales:
raise NotImplementedError(
"ONNX MXFP8 dequantization is only supported with scales in compact format."
)
return torch.ops.tex.mxfp8_dequantize(tensor._rowwise_data, tensor._rowwise_scale_inv)
def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
......@@ -229,9 +236,10 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
columnwise_scale_inv: Optional[torch.Tensor],
fp8_dtype: TE_DType,
quantizer: Optional[Quantizer],
with_gemm_swizzled_scales: bool,
**kwargs,
):
instance = super().__new__(
return super().__new__(
cls,
rowwise_data,
rowwise_scale_inv,
......@@ -239,10 +247,10 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
columnwise_scale_inv,
fp8_dtype,
quantizer,
with_gemm_swizzled_scales,
*args,
**kwargs,
)
return instance
def __repr__(self, *, tensor_contents=None):
return f"MXFP8Tensor(fp8_dtype={self._fp8_dtype}, data={self.dequantize(dtype=self.dtype)})"
......@@ -334,39 +342,44 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
# View op
if func == aten.view.default:
tensor = args[0]
data = tensor._rowwise_data
out_data = data.__torch_dispatch__(
func,
types,
[data] + list(args[1:]),
kwargs,
shape = args[1]
if len(shape) < 2 or shape[-1] != tensor.size(-1):
raise ValueError(
f"Attempted to make view with size={tuple(shape)} "
f"from MXFP8 tensor with shape={tuple(tensor.size())}."
)
out_shape = out_data.size()
rowwise_data_view = None
columnwise_data_view = None
if tensor._rowwise_data is not None:
rowwise_data_view = tensor._rowwise_data.view(shape)
if tensor._columnwise_data is not None:
columnwise_data_view = tensor._columnwise_data.view(shape)
return MXFP8Tensor(
shape=out_shape,
shape=shape,
dtype=tensor.dtype,
rowwise_data=out_data,
rowwise_data=rowwise_data_view,
rowwise_scale_inv=tensor._rowwise_scale_inv,
columnwise_data=tensor._columnwise_data,
columnwise_data=columnwise_data_view,
columnwise_scale_inv=tensor._columnwise_scale_inv,
quantizer=tensor._quantizer,
requires_grad=False,
fp8_dtype=tensor._fp8_dtype,
with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales,
)
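
Concretely (hypothetical shapes):

# x: MXFP8Tensor with shape (4, 32, 64)
# x.view(128, 64)  -> allowed, the last dimension is unchanged
# x.view(4, 2048)  -> raises ValueError, the last dimension changes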
if func == torch.ops.aten.copy_.default:
dst, src = args[0], args[1]
if isinstance(src, MXFP8Tensor) and isinstance(dst, MXFP8Tensor):
# Booleans to check if src has all the usages that dst needs to respect dst quantizer usages.
# If not, default to base class behavior.
rowwise_matches = src._rowwise_data is not None or dst._rowwise_data is None
columnwise_matches = (
src._columnwise_data is not None or dst._columnwise_data is None
)
if rowwise_matches and columnwise_matches:
if src._rowwise_data is None and dst._rowwise_data is not None:
pass
elif src._columnwise_data is None and dst._columnwise_data is not None:
pass
elif src._with_gemm_swizzled_scales != dst._with_gemm_swizzled_scales:
pass
else:
# src and dst match, so we can directly copy data
if dst._rowwise_data is not None:
dst._rowwise_data.copy_(src._rowwise_data.detach(), *args[2:], **kwargs)
dst._rowwise_scale_inv.copy_(
......@@ -381,26 +394,25 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
)
return dst
# FSDP2 related functions.
if func == aten.split.Tensor:
# This is called if entire model is initialized on CUDA device and
# then splitted. Finally the shard needed by the process is used
# and other splitted shards are discarded.
# With FSDP2, this is called if the entire model is
# initialized on a CUDA device and then split. Finally
# the shard needed by the process is used and the other
# shards are discarded.
tensor = args[0]
split_size = args[1]
if "dim" in kwargs:
dim_to_split = kwargs["dim"]
else:
dim_to_split = args[2] if len(args) > 2 else 0
tensor = args[0]
split_size = args[1]
dim0_size = tensor.size(0)
dimlast_size = math.prod(tensor.shape[1:])
# Fall back to high-precision if split is non-trivial
if (
dim0_size % split_size != 0
or dim_to_split != 0
dim_to_split != 0
or tensor.size(0) % split_size != 0
or split_size % MXFP8_BLOCK_SCALING_SIZE != 0
or dimlast_size % MXFP8_BLOCK_SCALING_SIZE != 0
or tensor._with_gemm_swizzled_scales
):
# Handle splitting by dequantizing and splitting the hp tensor
return super().__torch_dispatch__(func, types, args, kwargs)
out_data = []
......@@ -460,27 +472,25 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
quantizer=tensor._quantizer,
requires_grad=False,
fp8_dtype=tensor._fp8_dtype,
with_gemm_swizzled_scales=False,
)
for splitted_tensor_data in zip(*out_data)
]
if func == torch.ops.aten.as_strided.default:
# Applied on unsharded param in FSDP2. In our case, this should be a no-op.
# This is needed for the case where some MXFP8 shards need padding, i.e. dimension 0
# of the unsharded param is not a multiple of the world size. If that is the case,
# we go down the dequantization route and weights are all-gathered in high precision.
# If the weight doesn't need padding, this is just a no-op.
tensor = args[0]
shape = args[1]
strides = args[2]
tensor = args[0]
if (
len(shape) != 2
or len(strides) != 2
or strides[1] != 1
or shape[0] != tensor.shape[0]
or shape[1] != tensor.shape[1]
len(shape) == len(strides) == 2
and tuple(strides) == (shape[-1], 1)
and tuple(shape) == tuple(tensor.size())
):
return super().__torch_dispatch__(func, types, args, kwargs)
return MXFP8Tensor.make_like(tensor)
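
In other words (hypothetical shapes):

# For an unpadded (512, 1024) MXFP8 param, as_strided((512, 1024), (1024, 1))
# matches the tensor's own shape and contiguous strides, so it is a no-op.
# Any other shape or strides fall back to the dequantize path via the base class.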
if func == aten.slice.Tensor:
......@@ -489,18 +499,11 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
# of the unsharded param is not a multiple of the world size. If that is the case,
# we go down the dequantization route and weights are all-gathered in high precision instead.
# If the sharded weight doesn't have padding, this is just a no-op.
tensor = args[0]
dim = args[1]
start = args[2]
length = args[3]
tensor = args[0]
if (
dim != 0
or length != tensor.shape[0]
or start != 0
or length % MXFP8_BLOCK_SCALING_SIZE != 0
or start % MXFP8_BLOCK_SCALING_SIZE != 0
):
return super().__torch_dispatch__(func, types, args, kwargs)
if start == 0 and length == tensor.size(dim):
return MXFP8Tensor.make_like(tensor)
if func == aten.new_zeros.default:
......@@ -558,7 +561,9 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
quantizer=tensor._quantizer,
requires_grad=False,
fp8_dtype=tensor._fp8_dtype,
with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales,
)
# Default case
return super().__torch_dispatch__(func, types, args, kwargs)
......@@ -584,19 +589,24 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
# pylint: disable=unused-argument
from transformer_engine.pytorch.distributed import _get_module_fsdp_state
# Get FSDP state
fsdp_state = _get_module_fsdp_state(module)
reshard_after_forward = fsdp_state._fsdp_param_group._reshard_after_forward
# Remove padding from scale inverses before all-gather
# Rowwise scale_inv is padded to multiples of [128, 4], columnwise to multiples of [4, 128]
rowwise_scale_inv = self._rowwise_scale_inv
columnwise_scale_inv = self._columnwise_scale_inv
shape = self.shape
if self._with_gemm_swizzled_scales:
raise NotImplementedError(
"FSDP2 is only supported for MXFP8Tensors with compact scales"
)
if rowwise_scale_inv is not None:
# Remove padding from rowwise scale_inv
flattened_in_shape0 = math.prod(shape[:-1])
if rowwise_scale_inv.size(0) != flattened_in_shape0:
rowwise_scale_inv = rowwise_scale_inv[:flattened_in_shape0]
if columnwise_scale_inv is not None:
# Remove padding from columnwise scale_inv
flattened_in_shape0 = math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE
......@@ -681,7 +691,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
out._columnwise_data = columnwise_data
out._columnwise_scale_inv = columnwise_scale_inv
else:
# We ll be here when post all gather is called the first time.
# We'll be here when post all gather is called the first time.
# MXFP8Tensor constructor makes a copy of the quantizer to
# save as its own quantizer. For the consequent iterations,
# the same quantizer is used. Copy is needed in the first iteration,
......@@ -696,6 +706,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
dtype=param_dtype,
shape=rowwise_data.shape if rowwise_data is not None else columnwise_data.shape,
quantizer=self._quantizer,
with_gemm_swizzled_scales=False,
)
out._quantizer.set_usage(rowwise=rowwise_usage, columnwise=columnwise_usage)
return out, all_gather_outputs
......@@ -711,6 +722,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
dtype: torch.dtype,
shape: torch.shape,
quantizer: Optional[Quantizer] = None,
with_gemm_swizzled_scales: bool = False,
) -> MXFP8Tensor:
"""Build MXFP8Tensor, for use in __reduce__
......@@ -727,6 +739,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
dtype=dtype,
shape=shape,
quantizer=quantizer,
with_gemm_swizzled_scales=with_gemm_swizzled_scales,
)
def __reduce_ex__(self, protocol: int) -> tuple:
......@@ -742,6 +755,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
self.dtype,
self.shape,
self._quantizer,
self._with_gemm_swizzled_scales,
),
)
......@@ -763,7 +777,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
if not devices_match(new_device, tensor.device):
tensor = tensor.to(device=new_device)
# Just copy FP8 data if other tensor is MXFP8Tensor
# Just copy data if other tensor is MXFP8Tensor
if isinstance(tensor, MXFP8Tensor):
if ( # pylint: disable=too-many-boolean-expressions
self.size() != tensor.size()
......@@ -791,6 +805,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
self._fp8_dtype = tensor._fp8_dtype
self._rowwise_scale_inv = tensor._rowwise_scale_inv
self._columnwise_scale_inv = tensor._columnwise_scale_inv
self._with_gemm_swizzled_scales = tensor._with_gemm_swizzled_scales
return
# Quantize to FP8
......@@ -862,6 +877,7 @@ class _ViewFunc(torch.autograd.Function):
columnwise_scale_inv=tensor._columnwise_scale_inv,
fp8_dtype=tensor._fp8_dtype,
quantizer=tensor._quantizer,
with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales,
)
@staticmethod
......@@ -888,6 +904,7 @@ class _ViewFunc(torch.autograd.Function):
columnwise_scale_inv=grad._columnwise_scale_inv,
fp8_dtype=grad._fp8_dtype,
quantizer=grad._quantizer,
with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales,
)
return dgrad, None
return grad.view(ctx.shape), None
......@@ -948,6 +965,7 @@ class _ReshapeFunc(torch.autograd.Function):
columnwise_scale_inv=tensor._columnwise_scale_inv,
fp8_dtype=tensor._fp8_dtype,
quantizer=tensor._quantizer,
with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales,
)
@staticmethod
......@@ -973,6 +991,7 @@ class _ReshapeFunc(torch.autograd.Function):
columnwise_scale_inv=grad._columnwise_scale_inv,
fp8_dtype=grad._fp8_dtype,
quantizer=grad._quantizer,
with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales,
)
return dgrad, None
return grad.view(ctx.shape), None