Unverified Commit 9985b02c authored by Zhongbo Zhu, committed by GitHub

[PyTorch] FP8 Subchannel Recipe With FP8 Gather And Configurable Scaling Factor Tensor Swizzling (#1707)

* functional kernel for columnwise + no-transpose option, still hacky
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* pass all quantizer unit tests
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor, add GEMM-ready API
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* make format options private members, simplify API
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* swizzle scales right before GEMM
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* bug fix for single-layer test
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* attempt to fix lint issue
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* FP8 gather pass, needs minor refinement
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix return_layernorm_output_gathered case
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* remove special cases, add sanity check before GEMM
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* fix lint
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* lint ungrouped imports
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Implement dequantize for compact 1D blocks.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* add more unit tests with compact dequantize supported
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* lint again
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* make all-gather for subchannel respect async
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* zero tolerance in distributed test
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* fix zero-tolerance test
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* resolve rebase issues
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* lint & format
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* fix lint
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* clean up
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* bug fix
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* relax rtol for FP32 distributed test
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* fix some CI issues
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* fix CI test failure in debug mode
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Force row-wise and column-wise data to have the same data format

Prototype "all-gather usage" in quantizer.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove dead logic for high-precision all-gathers
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Debug FP8 block-wise tensor tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug distributed test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Handle case where LayerNormLinear returns gathered norm output
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* fix debug mode
Signed-off-by: zhongboz <zhongboz@nvidia.com>

---------
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: Keith Wyss <kwyss@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Keith Wyss <kwyss@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 6123d7e0
@@ -66,7 +66,7 @@ from ..tensor.quantized_tensor import (
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.utils import any_feature_enabled
@@ -136,12 +136,6 @@ class _Linear(torch.autograd.Function):
parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
)
# Do TP communication in high precision if quantized format
# does not support communication
force_hp_input_gather = (
fp8 and with_input_all_gather_nccl and isinstance(input_quantizer, Float8BlockQuantizer)
)
# Configure Userbuffers communication (comm+GEMM overlap)
ub_obj = None
ub_type = None
@@ -168,7 +162,7 @@
if fp8 or debug:
if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
if not force_hp_input_gather and not isinstance(inputmat, QuantizedTensorBase):
if not isinstance(inputmat, QuantizedTensorBase):
input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
if isinstance(
input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
@@ -347,10 +341,11 @@
# For sequence parallel in vanilla FP8, rowwise data is
# needed to gather the input. For MXFP8, only columnwise
# data can be all-gathered.
if isinstance(inputmat, MXFP8TensorBase) or not ctx.backward_input_needs_gather:
if (
isinstance(inputmat, (MXFP8TensorBase, Float8BlockwiseQTensorBase))
or not ctx.backward_input_needs_gather
):
inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
if force_hp_input_gather:
assert not isinstance(inputmat, QuantizedTensorBase)
saved_inputmat = inputmat
# Weight with column-wise usage is needed for dgrad GEMM.
@@ -396,7 +391,6 @@
ctx.activation_dtype = activation_dtype
ctx.fp8 = fp8
ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
ctx.force_hp_input_gather = force_hp_input_gather
ctx.input_quantizer = input_quantizer
ctx.grad_input_quantizer = grad_input_quantizer
ctx.grad_weight_quantizer = grad_weight_quantizer
@@ -557,7 +551,7 @@
inputmat_total_work = None
if ctx.backward_input_needs_gather:
quantizer = None
if (ctx.fp8 or ctx.debug) and not ctx.force_hp_input_gather:
if ctx.fp8 or ctx.debug:
quantizer = ctx.input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
# If data is in FP8, we compute FP8 transposes manually
@@ -1214,6 +1208,8 @@ class Linear(TransformerEngineBaseModule):
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
self._customize_quantizers_float8_current_scaling(fwd, recipe)
elif recipe.float8_block_scaling():
self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
# elif for other recipes (mxfp8, etc.)
def reset_parameters(self, defer_init=False):
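The dispatch above runs whenever the active FP8 recipe reports float8_block_scaling(). A minimal usage sketch follows; it assumes the blockwise recipe class is exposed as transformer_engine.common.recipe.Float8BlockScaling (name assumed and may differ by TE version) and uses default recipe arguments, so treat it as illustrative rather than authoritative.

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Illustrative: running a te.Linear forward under the blockwise (subchannel)
# recipe triggers _customize_quantizers_float8_blockwise_scaling above.
fp8_recipe = recipe.Float8BlockScaling()  # class name assumed
layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
x = torch.randn(128, 1024, device="cuda", dtype=torch.bfloat16)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = layer(x)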
@@ -1479,3 +1475,22 @@
weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
weight_quantizer.internal = True
return [weight_quantizer]
def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on blockwise scaling recipe + linear."""
assert (
recipe.float8_block_scaling()
), "blockwise scaling recipe quantizer customization here"
if fwd:
if self.sequence_parallel and self.parallel_mode == "column":
# set compact for inp tensor X
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].all_gather_usage = True
else:
if self.sequence_parallel and self.parallel_mode == "row":
# set compact for grad_output tensor dY
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].all_gather_usage = True
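The customization above only flips all_gather_usage on the quantizers whose outputs are all-gathered under sequence parallelism: the input tensor X for column-parallel layers and the grad-output tensor dY for row-parallel layers. The point of the COMPACT layout is that both the FP8 data and its per-block scales stay row-major along the gathered dimension, so shards can simply be concatenated and swizzled into the GEMM-ready layout later, right before the GEMM. The sketch below only illustrates that idea; it is not the actual TE communication path, the function name is made up, and it assumes torch.distributed is already initialized with a CUDA backend.

from typing import Optional
import torch
import torch.distributed as dist

def gather_compact_fp8_shard(
    data_shard: torch.Tensor,    # FP8 payload viewed as uint8, sharded on dim 0
    scale_shard: torch.Tensor,   # compact per-block inverse scales, sharded on dim 0
    group: Optional[dist.ProcessGroup] = None,
):
    # Illustrative only: with the COMPACT format, data and scales shard on the
    # same dimension, so an all-gather plus concatenation reconstructs both.
    world_size = dist.get_world_size(group)
    data_out = [torch.empty_like(data_shard) for _ in range(world_size)]
    scale_out = [torch.empty_like(scale_shard) for _ in range(world_size)]
    dist.all_gather(data_out, data_shard.contiguous(), group=group)
    dist.all_gather(scale_out, scale_shard.contiguous(), group=group)
    return torch.cat(data_out, dim=0), torch.cat(scale_out, dim=0)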
@@ -11,6 +11,7 @@ import torch
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from transformer_engine_torch import Float8BlockScaleTensorFormat
from ..quantized_tensor import QuantizedTensorBase
@@ -37,6 +38,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
_rowwise_scale_inv: Optional[torch.Tensor]
_columnwise_scale_inv: Optional[torch.Tensor]
_is_2D_scaled: bool
_data_format: Float8BlockScaleTensorFormat
def __new__(
cls,
@@ -48,6 +50,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
fp8_dtype: TE_DType,
quantizer: Quantizer,
is_2D_scaled: bool,
data_format: Float8BlockScaleTensorFormat = Float8BlockScaleTensorFormat.GEMM_READY,
**kwargs,
):
instance = super().__new__(cls, *args, **kwargs)
@@ -58,6 +61,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
instance._rowwise_scale_inv = rowwise_scale_inv
instance._columnwise_scale_inv = columnwise_scale_inv
instance._is_2D_scaled = is_2D_scaled
instance._data_format = data_format
return instance
@@ -82,8 +86,13 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
"fp8_dtype": self._fp8_dtype,
"quantizer": self._quantizer,
"is_2D_scaled": self._is_2D_scaled,
"data_format": self._data_format,
}
def _is_gemm_ready_format(self) -> bool:
"""Whether data is in GEMM_READY format"""
return self._data_format == Float8BlockScaleTensorFormat.GEMM_READY
def prepare_for_saving(
self,
) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorBase]:
@@ -136,34 +145,69 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
q_K = q.shape[-1]
for i in range(len(q.shape) - 1):
q_M *= q.shape[i]
inner_q_dimension_tiled = True
if self._is_gemm_ready_format():
scales_tiled_dim, scales_untiled_dim = scale_inv.shape
inner_scale_dimension_tiled = False
scales_are_compact = False
else:
scales_untiled_dim, scales_tiled_dim = scale_inv.shape
inner_scale_dimension_tiled = True
scales_are_compact = True
else:
assert self._columnwise_data is not None, "No data to dequantize"
q = self._columnwise_data
scale_inv = self._columnwise_scale_inv
transpose_output = True
if len(q.shape) >= 1:
q_M = q.shape[0]
for i in range(1, len(q.shape)):
q_K *= q.shape[i]
scales_tiled_dim, scales_untiled_dim = scale_inv.shape
inner_scale_dimension_tiled = False
if self._is_gemm_ready_format():
inner_q_dimension_tiled = True
transpose_output = True
if len(q.shape) >= 1:
q_M = q.shape[0]
for i in range(1, len(q.shape)):
q_K *= q.shape[i]
scales_are_compact = False
else:
inner_q_dimension_tiled = False
transpose_output = False
if len(q.shape) >= 1:
q_K = q.shape[-1]
for i in range(len(q.shape) - 1):
q_M *= q.shape[i]
scales_are_compact = True
orig_shape = q.shape
q = q.reshape(q_M, q_K)
k_tiles, scale_m = scale_inv.shape
if q_K % block_len != 0:
k_pad_amount = (block_len - (q_K % block_len)) % block_len
q = torch.nn.functional.pad(
q, (0, k_pad_amount, 0, 0), mode="constant", value=0
).contiguous()
_, padded_K = q.shape
q_tiled = q.reshape(q_M, k_tiles, block_len)
if scale_m > q_M:
# scale_m is 4 element aligned.
if inner_q_dimension_tiled:
if q_K % block_len != 0:
k_pad_amount = (block_len - (q_K % block_len)) % block_len
q = torch.nn.functional.pad(
q, (0, k_pad_amount, 0, 0), mode="constant", value=0
).contiguous()
padded_M, padded_K = q.shape
q_tiled = q.reshape(q_M, scales_tiled_dim, block_len)
else:
if q_M % block_len != 0:
m_pad_amount = (block_len - (q_M % block_len)) % block_len
q = torch.nn.functional.pad(
q, (0, 0, 0, m_pad_amount), mode="constant", value=0
).contiguous()
padded_M, padded_K = q.shape
q_tiled = q.reshape(scales_tiled_dim, block_len, q_K)
if not scales_are_compact and scales_untiled_dim > q_M:
# untiled scale dimension is 4 element aligned.
scale_inv = scale_inv[:, :q_M].contiguous()
dq_scale = scale_inv.transpose(-2, -1).contiguous().reshape(q_M, k_tiles, 1)
if scales_are_compact and inner_scale_dimension_tiled:
dq_scale = scale_inv.contiguous().reshape(q_M, scales_tiled_dim, 1)
elif scales_are_compact and not inner_scale_dimension_tiled:
dq_scale = scale_inv.contiguous().reshape(scales_tiled_dim, 1, q_K)
else:
dq_scale = scale_inv.transpose(-2, -1).contiguous().reshape(q_M, scales_tiled_dim, 1)
torch_q_dtype = TE_DType_To_Torch[self._fp8_dtype]
result = q_tiled.view(torch_q_dtype).to(torch.float32) * dq_scale
if padded_K != q_K:
result = result.reshape(q_M, padded_K)[:, :q_K]
if padded_M != q_M or padded_K != q_K:
result = result.reshape(padded_M, padded_K)[:q_M, :q_K]
result = result.to(dtype)
if len(orig_shape) == 0:
result = result.reshape([])
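As a cross-check of the compact rowwise branch above, the following standalone sketch dequantizes 1x128-scaled data whose inverse scales are stored in the compact [M, K // block_len] layout. It assumes K is already a multiple of block_len and skips the padding, FP8-view, and output-dtype handling done by the real code; the function name is illustrative.

import torch

def dequantize_rowwise_compact(q_fp32: torch.Tensor,
                               scale_inv: torch.Tensor,
                               block_len: int = 128) -> torch.Tensor:
    # q_fp32: (M, K) quantized values already cast to float32.
    # scale_inv: (M, K // block_len) compact per-block inverse scales.
    M, K = q_fp32.shape
    k_tiles = K // block_len                  # assumes K % block_len == 0
    q_tiled = q_fp32.reshape(M, k_tiles, block_len)
    dq_scale = scale_inv.reshape(M, k_tiles, 1).to(torch.float32)
    return (q_tiled * dq_scale).reshape(M, K)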
@@ -182,6 +226,12 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
if not self._is_2D_scaled:
return self._dequantize_vectorwise(dtype=dtype)
if not self._is_gemm_ready_format():
raise NotImplementedError(
"Dequantize is only supported with GEMM_READY data format, "
f"but found _data_format={self._data_format}"
)
def format_scale_as_logical_shape(q_K, scales, block_len):
# The GEMM for 2D blocks required padding in the scales.
derived_scale_k_shape = math.ceil(q_K / block_len)
@@ -247,6 +297,8 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
if self._rowwise_data is not None:
return self._rowwise_data.size(*args, **kwargs)
dims = list(self._columnwise_data.size(*args, **kwargs))
if not self._is_gemm_ready_format(): # compact format
return torch.Size(dims)
reordered = []
for i in range(1, len(dims)):
reordered.append(dims[i])
@@ -285,6 +337,13 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
w = min(self._columnwise_scale_inv.shape[1], columnwise_scale_inv.shape[1])
self._columnwise_scale_inv[0:h, 0:w].copy_(columnwise_scale_inv[0:h, 0:w])
def _transpose_columnwise_data(self):
"""Plainly transpose the columnwise data and scale inv."""
if self._columnwise_data is not None:
self._columnwise_data = tex.fp8_transpose(
self._columnwise_data, self._fp8_dtype, out=None
)
def __repr__(self):
if self._rowwise_data is not None:
data = self.dequantize()
......
@@ -33,6 +33,8 @@ class Float8BlockQuantizer(Quantizer):
amax_epsilon: float
force_pow_2_scales: bool
block_scaling_dim: int
# Whether to produce tensors that will be used in all-gather
all_gather_usage: bool
def __init__(
self,
@@ -43,6 +45,7 @@
amax_epsilon: float = 0.0,
force_pow_2_scales: bool = True,
block_scaling_dim: int = 2,
all_gather_usage: bool = False,
) -> None:
super().__init__(rowwise=rowwise, columnwise=columnwise)
self.dtype = fp8_dtype
@@ -50,6 +53,7 @@
self.force_pow_2_scales = force_pow_2_scales
self.amax_epsilon = amax_epsilon
self.block_scaling_dim = block_scaling_dim
self.all_gather_usage = all_gather_usage
def update_quantized(
self,
@@ -126,22 +130,36 @@
M *= shape[i]
if len(shape) > 0:
K = shape[-1]
# 2D 128x128 quantization block scaling
# cuBLAS requires the 128x128 scaling factor to be padded
# currently the rowwise/columnwise format option doesn't apply to 2D scaling
if self.block_scaling_dim == 2:
if columnwise:
outer = math.ceil(K / self.block_len)
inner = round_up_to_nearest_multiple(math.ceil(M / self.block_len), 4)
return (outer, inner)
# rowwise
outer = math.ceil(M / self.block_len)
inner = round_up_to_nearest_multiple(math.ceil(K / self.block_len), 4)
return (outer, inner)
# 1D 1x128 quantization block scaling
# cuBLAS requires the 1x128 scaling factor to be padded and transposed
assert self.block_scaling_dim == 1, "Only 1D or 2D blocks supported"
if columnwise:
columnwise_compact = self.all_gather_usage
outer = math.ceil(M / self.block_len)
inner = round_up_to_nearest_multiple(K, 4)
inner = round_up_to_nearest_multiple(K, 4) if not columnwise_compact else K
# GEMM_READY case: the scaling factor is [outer, inner], already transposed here for cuBLAS
# COMPACT case: since we apply 1x128 scaling here without transposing the columnwise data,
# the scaling factor is also [outer, inner], so there is no need to swap inner and outer
return (outer, inner)
# rowwise
rowwise_compact = self.all_gather_usage
outer = math.ceil(K / self.block_len)
inner = round_up_to_nearest_multiple(M, 4)
return (outer, inner)
inner = round_up_to_nearest_multiple(M, 4) if not rowwise_compact else M
# GEMM_READY case: the scaling factor is [outer, inner], already transposed here as cuBLAS requires
# COMPACT case: since we apply 128x1 scaling, the scaling block applies to the inner dim, so we need to swap outer and inner here
return (outer, inner) if not rowwise_compact else (inner, outer)
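The comments above can be condensed into a small standalone helper that reproduces the 1D (1x128) scale-shape arithmetic for both formats; the helper and the example shapes are illustrative, not part of the TE API.

import math

def round_up_to_nearest_multiple(x: int, m: int) -> int:
    return ((x + m - 1) // m) * m

def scale_shape_1d(M: int, K: int, block_len: int, columnwise: bool, compact: bool):
    # Mirror of the 1D branch above: GEMM-ready scales are padded to a multiple
    # of 4 and stored pre-transposed for cuBLAS; compact scales are neither.
    if columnwise:
        outer = math.ceil(M / block_len)
        inner = K if compact else round_up_to_nearest_multiple(K, 4)
        return (outer, inner)
    outer = math.ceil(K / block_len)
    inner = M if compact else round_up_to_nearest_multiple(M, 4)
    return (inner, outer) if compact else (outer, inner)

# Rowwise scales for a 256x512 tensor with 128-element blocks:
assert scale_shape_1d(256, 512, 128, columnwise=False, compact=False) == (4, 256)
assert scale_shape_1d(256, 512, 128, columnwise=False, compact=True) == (256, 4)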
def get_columnwise_shape(self, shape: Iterable[int]) -> Tuple[int, ...]:
"""Calculate the shape of a tensor after columnwise permutation.
......@@ -163,15 +181,25 @@ class Float8BlockQuantizer(Quantizer):
"""
if len(shape) == 0:
return tuple()
# currently columnwise format option only applies to 1D quantizer
# for 2D scaling, columnwise format should always be GEMM_READY_DATA_AND_SCALES
# since currently 2D scaling only applies to module weights
if self.block_scaling_dim == 1 and self.all_gather_usage:
return shape
colwise_shape = [shape[-1]]
for i in range(len(shape) - 1):
colwise_shape.append(shape[i])
return tuple(colwise_shape)
# TODO(kwyss): With FP8 gather support, we need to implement a
# shape/layout/swizzle check to know whether FP8 gather works
# cleanly by stacking data without aliasing tiles and whether
# the scales also stack on the proper dimensions.
def is_quantizable(self, inp: torch.Tensor) -> bool:
"""Returns whether or not given inp can be quantized"""
if inp.ndim < 2:
return False
if inp.shape[-1] % self.block_len != 0:
return False
if math.prod(inp.shape[:-1]) % self.block_len != 0:
return False
return True
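The two helpers in this hunk can likewise be restated as shape-only functions, shown below as an illustrative condensation (not the TE API): compact 1D tensors keep their layout so they can be all-gathered directly, and a tensor is quantizable only if both its last dimension and the product of its leading dimensions are multiples of the block length.

import math

def columnwise_shape(shape, compact_1d: bool):
    # Mirror of get_columnwise_shape above.
    if len(shape) == 0:
        return ()
    if compact_1d:
        return tuple(shape)
    return (shape[-1],) + tuple(shape[:-1])

def shape_is_quantizable(shape, block_len: int = 128) -> bool:
    # Mirror of is_quantizable above, operating on a shape instead of a tensor.
    if len(shape) < 2:
        return False
    return shape[-1] % block_len == 0 and math.prod(shape[:-1]) % block_len == 0

assert columnwise_shape((256, 512), compact_1d=False) == (512, 256)
assert columnwise_shape((256, 512), compact_1d=True) == (256, 512)
assert shape_is_quantizable((256, 512)) and not shape_is_quantizable((48, 512))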
def make_empty(
self,
......@@ -185,6 +213,12 @@ class Float8BlockQuantizer(Quantizer):
if device is None:
device = torch.device("cuda")
data_format = (
tex.Float8BlockScaleTensorFormat.COMPACT
if self.all_gather_usage
else tex.Float8BlockScaleTensorFormat.GEMM_READY
)
# Allocate FP8 data
data = None
scale_inv = None
@@ -222,6 +256,7 @@
columnwise_scale_inv=columnwise_scale_inv,
quantizer=self,
is_2D_scaled=self.block_scaling_dim == 2,
data_format=data_format,
requires_grad=requires_grad,
)
@@ -263,7 +298,8 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
return (
f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype},"
f" is_2D_scaled={self._is_2D_scaled},"
f" data={self.dequantize(dtype=self.dtype)})"
f" data={self.dequantize(dtype=self.dtype)}),"
f" data_format={self._data_format}"
)
def _get_quantizer(self) -> Quantizer:
@@ -396,6 +432,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
dtype: torch.dtype,
quantizer: Quantizer,
is_2D_scaled: bool,
data_format: tex.Float8BlockScaleTensorFormat,
) -> Float8BlockwiseQTensor:
"""Build Float8BlockwiseQTensor, for use in __reduce__
@@ -413,6 +450,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
dtype=dtype,
quantizer=quantizer,
is_2D_scaled=is_2D_scaled,
data_format=data_format,
)
def __reduce_ex__(self, protocol: int) -> tuple:
......@@ -429,6 +467,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
self.dtype,
self._quantizer,
self._is_2D_scaled,
self._data_format,
),
)
@@ -454,6 +493,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
dst._fp8_dtype = src._fp8_dtype
dst._rowwise_scale_inv = src._rowwise_scale_inv
dst._columnwise_scale_inv = src._columnwise_scale_inv
dst._data_format = src._data_format
# Check that tensor dimensions match
if (
@@ -501,6 +541,13 @@ class _ViewFunc(torch.autograd.Function):
) -> Float8BlockwiseQTensor:
# pylint: disable=missing-function-docstring
# Check for invalid configurations
if not tensor._is_gemm_ready_format():
raise NotImplementedError(
"View is only supported with GEMM_READY data format, "
f"but found data_format={tensor._data_format}"
)
# Return input tensor if shape is not provided
ctx.shape = tensor.shape
if shape is None:
@@ -569,6 +616,14 @@ class _ViewFunc(torch.autograd.Function):
# pylint: disable=missing-function-docstring
if isinstance(grad, Float8BlockwiseQTensor):
# Check for invalid configurations
if not grad._is_gemm_ready_format():
raise NotImplementedError(
"View is only supported with GEMM_READY data format, "
f"but found data_format={grad._data_format}"
)
new_data = (
grad._rowwise_data.view(*ctx.shape) if grad._rowwise_data is not None else None
)
@@ -608,6 +663,13 @@ class _ReshapeFunc(torch.autograd.Function):
) -> Float8BlockwiseQTensor:
# pylint: disable=missing-function-docstring
# Check for invalid configurations
if not tensor._is_gemm_ready_format():
raise NotImplementedError(
"Reshape is only supported with GEMM_READY data format, "
f"but found data_format={tensor._data_format}"
)
# Return input tensor if shape is not provided
ctx.shape = tensor.shape
if shape is None:
......@@ -675,6 +737,14 @@ class _ReshapeFunc(torch.autograd.Function):
# pylint: disable=missing-function-docstring
if isinstance(grad, Float8BlockwiseQTensor):
# Check for invalid configurations
if not grad._is_gemm_ready_format():
raise NotImplementedError(
"Reshape is only supported with GEMM_READY data format, "
f"but found data_format={grad._data_format}"
)
new_rowwise_data = None
new_columnwise_data = None
if grad._rowwise_data is not None:
......