"...git@developer.sourcefind.cn:kecinstone/2024-pra-vllm.git" did not exist on "ff36139ffc66294c19b503c1e52dc42c2cd265f6"
Unverified Commit 9985b02c authored by Zhongbo Zhu's avatar Zhongbo Zhu Committed by GitHub
Browse files

[PyTorch] FP8 Subchannel Recipe With FP8 Gather And Configurable Scaling...


[PyTorch] FP8 Subchannel Recipe With FP8 Gather And Configurable Scaling Factor Tensor Swizzling (#1707)

* functional kernel for columnwise + no-transpose option, still hacky
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* pass all quantizer unit tests
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* refactor, add gemm ready api
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* make format options private members, simplify api
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* swizzle scales right before gemm
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* bug fix of single layer test
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* attempt to fix lint issue
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fp8 gather pass, need minor refine
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix return_layernorm_output_gathered case
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* remove special cases, add sanity check before gemm
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* fix lint
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* lint ungrouped imports
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Implement dequantize for compact 1D blocks.
Signed-off-by: default avatarKeith Wyss <kwyss@nvidia.com>

* add more unit test with dequantize compact supported
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* lint again
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* make ag for subchannel respect async
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* zero tolerance in distributed test
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* fix zero tolerance test
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* resolve rebase issues
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* lint & format
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* fix lint
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* clean up
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* bug fix
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* relax rtol for fp32 distributed test
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* fix some ci issue
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* fix ci test failure in debug mode
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* Force row-wise and column-wise data to have same data format

Prototype "all-gather usage" in quantizer.
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Remove dead logic for high-precision AGs
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Debug FP8 block-wise tensor tests
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Debug distributed test
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Handle case where LayerNormLinear returns gathered norm output
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* fix debug mode
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

---------
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>
Signed-off-by: default avatarKeith Wyss <kwyss@nvidia.com>
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarKeith Wyss <kwyss@nvidia.com>
Co-authored-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent 6123d7e0
...@@ -66,7 +66,7 @@ from ..tensor.quantized_tensor import ( ...@@ -66,7 +66,7 @@ from ..tensor.quantized_tensor import (
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.utils import any_feature_enabled from ...debug.pytorch.utils import any_feature_enabled
...@@ -136,12 +136,6 @@ class _Linear(torch.autograd.Function): ...@@ -136,12 +136,6 @@ class _Linear(torch.autograd.Function):
parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
) )
# Do TP communication in high precision if quantized format
# does not support communication
force_hp_input_gather = (
fp8 and with_input_all_gather_nccl and isinstance(input_quantizer, Float8BlockQuantizer)
)
# Configure Userbuffers communication (comm+GEMM overlap) # Configure Userbuffers communication (comm+GEMM overlap)
ub_obj = None ub_obj = None
ub_type = None ub_type = None
...@@ -168,7 +162,7 @@ class _Linear(torch.autograd.Function): ...@@ -168,7 +162,7 @@ class _Linear(torch.autograd.Function):
if fp8 or debug: if fp8 or debug:
if input_quantizer is None: if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor") raise ValueError("Missing quantizer for input tensor")
if not force_hp_input_gather and not isinstance(inputmat, QuantizedTensorBase): if not isinstance(inputmat, QuantizedTensorBase):
input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
if isinstance( if isinstance(
input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
...@@ -347,10 +341,11 @@ class _Linear(torch.autograd.Function): ...@@ -347,10 +341,11 @@ class _Linear(torch.autograd.Function):
# For sequence parallel in vanilla FP8, rowwise data is # For sequence parallel in vanilla FP8, rowwise data is
# to gather the input. For MXFP8, columnwise only data # to gather the input. For MXFP8, columnwise only data
# can be allgathered. # can be allgathered.
if isinstance(inputmat, MXFP8TensorBase) or not ctx.backward_input_needs_gather: if (
isinstance(inputmat, (MXFP8TensorBase, Float8BlockwiseQTensorBase))
or not ctx.backward_input_needs_gather
):
inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
if force_hp_input_gather:
assert not isinstance(inputmat, QuantizedTensorBase)
saved_inputmat = inputmat saved_inputmat = inputmat
# Weight with column-wise usage is needed for dgrad GEMM. # Weight with column-wise usage is needed for dgrad GEMM.
...@@ -396,7 +391,6 @@ class _Linear(torch.autograd.Function): ...@@ -396,7 +391,6 @@ class _Linear(torch.autograd.Function):
ctx.activation_dtype = activation_dtype ctx.activation_dtype = activation_dtype
ctx.fp8 = fp8 ctx.fp8 = fp8
ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
ctx.force_hp_input_gather = force_hp_input_gather
ctx.input_quantizer = input_quantizer ctx.input_quantizer = input_quantizer
ctx.grad_input_quantizer = grad_input_quantizer ctx.grad_input_quantizer = grad_input_quantizer
ctx.grad_weight_quantizer = grad_weight_quantizer ctx.grad_weight_quantizer = grad_weight_quantizer
...@@ -557,7 +551,7 @@ class _Linear(torch.autograd.Function): ...@@ -557,7 +551,7 @@ class _Linear(torch.autograd.Function):
inputmat_total_work = None inputmat_total_work = None
if ctx.backward_input_needs_gather: if ctx.backward_input_needs_gather:
quantizer = None quantizer = None
if (ctx.fp8 or ctx.debug) and not ctx.force_hp_input_gather: if ctx.fp8 or ctx.debug:
quantizer = ctx.input_quantizer quantizer = ctx.input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
# If data is in FP8, we compute FP8 transposes manually # If data is in FP8, we compute FP8 transposes manually
...@@ -1214,6 +1208,8 @@ class Linear(TransformerEngineBaseModule): ...@@ -1214,6 +1208,8 @@ class Linear(TransformerEngineBaseModule):
recipe = FP8GlobalStateManager.get_fp8_recipe() recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling(): if recipe.float8_current_scaling():
self._customize_quantizers_float8_current_scaling(fwd, recipe) self._customize_quantizers_float8_current_scaling(fwd, recipe)
elif recipe.float8_block_scaling():
self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
# elif for other recipes (mxfp8, etc.) # elif for other recipes (mxfp8, etc.)
def reset_parameters(self, defer_init=False): def reset_parameters(self, defer_init=False):
...@@ -1479,3 +1475,22 @@ class Linear(TransformerEngineBaseModule): ...@@ -1479,3 +1475,22 @@ class Linear(TransformerEngineBaseModule):
weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
weight_quantizer.internal = True weight_quantizer.internal = True
return [weight_quantizer] return [weight_quantizer]
def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on blockwise scaling recipe + linear."""
assert (
recipe.float8_block_scaling()
), "blockwise scaling recipe quantizer customization here"
if fwd:
if self.sequence_parallel and self.parallel_mode == "column":
# set compact for inp tensor X
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].all_gather_usage = True
else:
if self.sequence_parallel and self.parallel_mode == "row":
# set compact for grad_output tensor dY
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].all_gather_usage = True
...@@ -11,6 +11,7 @@ import torch ...@@ -11,6 +11,7 @@ import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType from transformer_engine_torch import DType as TE_DType
from transformer_engine_torch import Float8BlockScaleTensorFormat
from ..quantized_tensor import QuantizedTensorBase from ..quantized_tensor import QuantizedTensorBase
...@@ -37,6 +38,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase): ...@@ -37,6 +38,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
_rowwise_scale_inv: Optional[torch.Tensor] _rowwise_scale_inv: Optional[torch.Tensor]
_columnwise_scale_inv: Optional[torch.Tensor] _columnwise_scale_inv: Optional[torch.Tensor]
_is_2D_scaled: bool _is_2D_scaled: bool
_data_format: Float8BlockScaleTensorFormat
def __new__( def __new__(
cls, cls,
...@@ -48,6 +50,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase): ...@@ -48,6 +50,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
fp8_dtype: TE_DType, fp8_dtype: TE_DType,
quantizer: Quantizer, quantizer: Quantizer,
is_2D_scaled: bool, is_2D_scaled: bool,
data_format: Float8BlockScaleTensorFormat = Float8BlockScaleTensorFormat.GEMM_READY,
**kwargs, **kwargs,
): ):
instance = super().__new__(cls, *args, **kwargs) instance = super().__new__(cls, *args, **kwargs)
...@@ -58,6 +61,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase): ...@@ -58,6 +61,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
instance._rowwise_scale_inv = rowwise_scale_inv instance._rowwise_scale_inv = rowwise_scale_inv
instance._columnwise_scale_inv = columnwise_scale_inv instance._columnwise_scale_inv = columnwise_scale_inv
instance._is_2D_scaled = is_2D_scaled instance._is_2D_scaled = is_2D_scaled
instance._data_format = data_format
return instance return instance
...@@ -82,8 +86,13 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase): ...@@ -82,8 +86,13 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
"fp8_dtype": self._fp8_dtype, "fp8_dtype": self._fp8_dtype,
"quantizer": self._quantizer, "quantizer": self._quantizer,
"is_2D_scaled": self._is_2D_scaled, "is_2D_scaled": self._is_2D_scaled,
"data_format": self._data_format,
} }
def _is_gemm_ready_format(self) -> bool:
"""Whether data is in GEMM_READY format"""
return self._data_format == Float8BlockScaleTensorFormat.GEMM_READY
def prepare_for_saving( def prepare_for_saving(
self, self,
) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorBase]: ) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorBase]:
...@@ -136,34 +145,69 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase): ...@@ -136,34 +145,69 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
q_K = q.shape[-1] q_K = q.shape[-1]
for i in range(len(q.shape) - 1): for i in range(len(q.shape) - 1):
q_M *= q.shape[i] q_M *= q.shape[i]
inner_q_dimension_tiled = True
if self._is_gemm_ready_format():
scales_tiled_dim, scales_untiled_dim = scale_inv.shape
inner_scale_dimension_tiled = False
scales_are_compact = False
else:
scales_untiled_dim, scales_tiled_dim = scale_inv.shape
inner_scale_dimension_tiled = True
scales_are_compact = True
else: else:
assert self._columnwise_data is not None, "No data to dequantize" assert self._columnwise_data is not None, "No data to dequantize"
q = self._columnwise_data q = self._columnwise_data
scale_inv = self._columnwise_scale_inv scale_inv = self._columnwise_scale_inv
transpose_output = True scales_tiled_dim, scales_untiled_dim = scale_inv.shape
if len(q.shape) >= 1: inner_scale_dimension_tiled = False
q_M = q.shape[0] if self._is_gemm_ready_format():
for i in range(1, len(q.shape)): inner_q_dimension_tiled = True
q_K *= q.shape[i] transpose_output = True
if len(q.shape) >= 1:
q_M = q.shape[0]
for i in range(1, len(q.shape)):
q_K *= q.shape[i]
scales_are_compact = False
else:
inner_q_dimension_tiled = False
transpose_output = False
if len(q.shape) >= 1:
q_K = q.shape[-1]
for i in range(len(q.shape) - 1):
q_M *= q.shape[i]
scales_are_compact = True
orig_shape = q.shape orig_shape = q.shape
q = q.reshape(q_M, q_K) q = q.reshape(q_M, q_K)
k_tiles, scale_m = scale_inv.shape if inner_q_dimension_tiled:
if q_K % block_len != 0: if q_K % block_len != 0:
k_pad_amount = (block_len - (q_K % block_len)) % block_len k_pad_amount = (block_len - (q_K % block_len)) % block_len
q = torch.nn.functional.pad( q = torch.nn.functional.pad(
q, (0, k_pad_amount, 0, 0), mode="constant", value=0 q, (0, k_pad_amount, 0, 0), mode="constant", value=0
).contiguous() ).contiguous()
_, padded_K = q.shape padded_M, padded_K = q.shape
q_tiled = q.reshape(q_M, k_tiles, block_len) q_tiled = q.reshape(q_M, scales_tiled_dim, block_len)
if scale_m > q_M: else:
# scale_m is 4 element aligned. if q_M % block_len != 0:
m_pad_amount = (block_len - (q_M % block_len)) % block_len
q = torch.nn.functional.pad(
q, (0, 0, 0, m_pad_amount), mode="constant", value=0
).contiguous()
padded_M, padded_K = q.shape
q_tiled = q.reshape(scales_tiled_dim, block_len, q_K)
if not scales_are_compact and scales_untiled_dim > q_M:
# untiled scale dimension is 4 element aligned.
scale_inv = scale_inv[:, :q_M].contiguous() scale_inv = scale_inv[:, :q_M].contiguous()
dq_scale = scale_inv.transpose(-2, -1).contiguous().reshape(q_M, k_tiles, 1) if scales_are_compact and inner_scale_dimension_tiled:
dq_scale = scale_inv.contiguous().reshape(q_M, scales_tiled_dim, 1)
elif scales_are_compact and not inner_scale_dimension_tiled:
dq_scale = scale_inv.contiguous().reshape(scales_tiled_dim, 1, q_K)
else:
dq_scale = scale_inv.transpose(-2, -1).contiguous().reshape(q_M, scales_tiled_dim, 1)
torch_q_dtype = TE_DType_To_Torch[self._fp8_dtype] torch_q_dtype = TE_DType_To_Torch[self._fp8_dtype]
result = q_tiled.view(torch_q_dtype).to(torch.float32) * dq_scale result = q_tiled.view(torch_q_dtype).to(torch.float32) * dq_scale
if padded_K != q_K: if padded_M != q_M or padded_K != q_K:
result = result.reshape(q_M, padded_K)[:, :q_K] result = result.reshape(padded_M, padded_K)[:q_M, :q_K]
result = result.to(dtype) result = result.to(dtype)
if len(orig_shape) == 0: if len(orig_shape) == 0:
result = result.reshape([]) result = result.reshape([])
...@@ -182,6 +226,12 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase): ...@@ -182,6 +226,12 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
if not self._is_2D_scaled: if not self._is_2D_scaled:
return self._dequantize_vectorwise(dtype=dtype) return self._dequantize_vectorwise(dtype=dtype)
if not self._is_gemm_ready_format():
raise NotImplementedError(
"Dequantize is only supported with GEMM_READY data format, "
f"but found _data_format={self._data_format}"
)
def format_scale_as_logical_shape(q_K, scales, block_len): def format_scale_as_logical_shape(q_K, scales, block_len):
# The GEMM for 2D blocks required padding in the scales. # The GEMM for 2D blocks required padding in the scales.
derived_scale_k_shape = math.ceil(q_K / block_len) derived_scale_k_shape = math.ceil(q_K / block_len)
...@@ -247,6 +297,8 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase): ...@@ -247,6 +297,8 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
if self._rowwise_data is not None: if self._rowwise_data is not None:
return self._rowwise_data.size(*args, **kwargs) return self._rowwise_data.size(*args, **kwargs)
dims = list(self._columnwise_data.size(*args, **kwargs)) dims = list(self._columnwise_data.size(*args, **kwargs))
if not self._is_gemm_ready_format(): # compact format
return torch.Size(dims)
reordered = [] reordered = []
for i in range(1, len(dims)): for i in range(1, len(dims)):
reordered.append(dims[i]) reordered.append(dims[i])
...@@ -285,6 +337,13 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase): ...@@ -285,6 +337,13 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
w = min(self._columnwise_scale_inv.shape[1], columnwise_scale_inv.shape[1]) w = min(self._columnwise_scale_inv.shape[1], columnwise_scale_inv.shape[1])
self._columnwise_scale_inv[0:h, 0:w].copy_(columnwise_scale_inv[0:h, 0:w]) self._columnwise_scale_inv[0:h, 0:w].copy_(columnwise_scale_inv[0:h, 0:w])
def _transpose_columnwise_data(self):
"""Plainly transpose the columnwise data and scale inv."""
if self._columnwise_data is not None:
self._columnwise_data = tex.fp8_transpose(
self._columnwise_data, self._fp8_dtype, out=None
)
def __repr__(self): def __repr__(self):
if self._rowwise_data is not None: if self._rowwise_data is not None:
data = self.dequantize() data = self.dequantize()
......
...@@ -33,6 +33,8 @@ class Float8BlockQuantizer(Quantizer): ...@@ -33,6 +33,8 @@ class Float8BlockQuantizer(Quantizer):
amax_epsilon: float amax_epsilon: float
force_pow_2_scales: bool force_pow_2_scales: bool
block_scaling_dim: int block_scaling_dim: int
# Whether to produce tensors that will be used in all-gather
all_gather_usage: bool
def __init__( def __init__(
self, self,
...@@ -43,6 +45,7 @@ class Float8BlockQuantizer(Quantizer): ...@@ -43,6 +45,7 @@ class Float8BlockQuantizer(Quantizer):
amax_epsilon: float = 0.0, amax_epsilon: float = 0.0,
force_pow_2_scales: bool = True, force_pow_2_scales: bool = True,
block_scaling_dim: int = 2, block_scaling_dim: int = 2,
all_gather_usage: bool = False,
) -> None: ) -> None:
super().__init__(rowwise=rowwise, columnwise=columnwise) super().__init__(rowwise=rowwise, columnwise=columnwise)
self.dtype = fp8_dtype self.dtype = fp8_dtype
...@@ -50,6 +53,7 @@ class Float8BlockQuantizer(Quantizer): ...@@ -50,6 +53,7 @@ class Float8BlockQuantizer(Quantizer):
self.force_pow_2_scales = force_pow_2_scales self.force_pow_2_scales = force_pow_2_scales
self.amax_epsilon = amax_epsilon self.amax_epsilon = amax_epsilon
self.block_scaling_dim = block_scaling_dim self.block_scaling_dim = block_scaling_dim
self.all_gather_usage = all_gather_usage
def update_quantized( def update_quantized(
self, self,
...@@ -126,22 +130,36 @@ class Float8BlockQuantizer(Quantizer): ...@@ -126,22 +130,36 @@ class Float8BlockQuantizer(Quantizer):
M *= shape[i] M *= shape[i]
if len(shape) > 0: if len(shape) > 0:
K = shape[-1] K = shape[-1]
# 2D 128x128 quantization block scaling
# CuBLAS requires 128x128 scaling factor to be padded
# currently rowwise and columnwise format option doesn't apply to 2D scaling
if self.block_scaling_dim == 2: if self.block_scaling_dim == 2:
if columnwise: if columnwise:
outer = math.ceil(K / self.block_len) outer = math.ceil(K / self.block_len)
inner = round_up_to_nearest_multiple(math.ceil(M / self.block_len), 4) inner = round_up_to_nearest_multiple(math.ceil(M / self.block_len), 4)
return (outer, inner) return (outer, inner)
# rowwise
outer = math.ceil(M / self.block_len) outer = math.ceil(M / self.block_len)
inner = round_up_to_nearest_multiple(math.ceil(K / self.block_len), 4) inner = round_up_to_nearest_multiple(math.ceil(K / self.block_len), 4)
return (outer, inner) return (outer, inner)
# 1D 1x128 quantization block scaling
# CuBLAS requires 1x128 scaling factor to be padded and transposed
assert self.block_scaling_dim == 1, "Only 1D or 2D blocks supported" assert self.block_scaling_dim == 1, "Only 1D or 2D blocks supported"
if columnwise: if columnwise:
columnwise_compact = self.all_gather_usage
outer = math.ceil(M / self.block_len) outer = math.ceil(M / self.block_len)
inner = round_up_to_nearest_multiple(K, 4) inner = round_up_to_nearest_multiple(K, 4) if not columnwise_compact else K
# GEMM READY case: scaling factor is [outer, inner], already transposed here for CuBLAS
# for COMPACT case, since we apply 1x128 scaling here without transposing columnwise data, scaling factor is also [outer, inner]
# so no need to swap inner outer here
return (outer, inner) return (outer, inner)
# rowwise
rowwise_compact = self.all_gather_usage
outer = math.ceil(K / self.block_len) outer = math.ceil(K / self.block_len)
inner = round_up_to_nearest_multiple(M, 4) inner = round_up_to_nearest_multiple(M, 4) if not rowwise_compact else M
return (outer, inner) # GEMM READY case: scaling factor is [outer, inner], already transposed here for CuBLAS need
# for COMPACT case, since we apply 128x1 scaling, scaling block applies to inner dim, so we need to swap outer and inner here
return (outer, inner) if not rowwise_compact else (inner, outer)
def get_columnwise_shape(self, shape: Iterable[int]) -> Tuple[int, ...]: def get_columnwise_shape(self, shape: Iterable[int]) -> Tuple[int, ...]:
"""Calculate the shape of a tensor after columnwise permutation. """Calculate the shape of a tensor after columnwise permutation.
...@@ -163,15 +181,25 @@ class Float8BlockQuantizer(Quantizer): ...@@ -163,15 +181,25 @@ class Float8BlockQuantizer(Quantizer):
""" """
if len(shape) == 0: if len(shape) == 0:
return tuple() return tuple()
# currently columnwise format option only applies to 1D quantizer
# for 2D scaling, columnwise format should always be GEMM_READY_DATA_AND_SCALES
# since currently 2D scaling only applies to module weights
if self.block_scaling_dim == 1 and self.all_gather_usage:
return shape
colwise_shape = [shape[-1]] colwise_shape = [shape[-1]]
for i in range(len(shape) - 1): for i in range(len(shape) - 1):
colwise_shape.append(shape[i]) colwise_shape.append(shape[i])
return tuple(colwise_shape) return tuple(colwise_shape)
# TODO(kwyss): With FP8 gather support, we need to implement a def is_quantizable(self, inp: torch.Tensor) -> bool:
# shape/layout/swizzle check to know whether FP8 gather works """Returns whether or not given inp can be quantized"""
# cleanly by stacking data without aliasing tiles and whether if inp.ndim < 2:
# the scales also stack on the proper dimensions. return False
if inp.shape[-1] % self.block_len != 0:
return False
if math.prod(inp.shape[:-1]) % self.block_len != 0:
return False
return True
def make_empty( def make_empty(
self, self,
...@@ -185,6 +213,12 @@ class Float8BlockQuantizer(Quantizer): ...@@ -185,6 +213,12 @@ class Float8BlockQuantizer(Quantizer):
if device is None: if device is None:
device = torch.device("cuda") device = torch.device("cuda")
data_format = (
tex.Float8BlockScaleTensorFormat.COMPACT
if self.all_gather_usage
else tex.Float8BlockScaleTensorFormat.GEMM_READY
)
# Allocate FP8 data # Allocate FP8 data
data = None data = None
scale_inv = None scale_inv = None
...@@ -222,6 +256,7 @@ class Float8BlockQuantizer(Quantizer): ...@@ -222,6 +256,7 @@ class Float8BlockQuantizer(Quantizer):
columnwise_scale_inv=columnwise_scale_inv, columnwise_scale_inv=columnwise_scale_inv,
quantizer=self, quantizer=self,
is_2D_scaled=self.block_scaling_dim == 2, is_2D_scaled=self.block_scaling_dim == 2,
data_format=data_format,
requires_grad=requires_grad, requires_grad=requires_grad,
) )
...@@ -263,7 +298,8 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor): ...@@ -263,7 +298,8 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
return ( return (
f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype}," f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype},"
f" is_2D_scaled={self._is_2D_scaled}," f" is_2D_scaled={self._is_2D_scaled},"
f" data={self.dequantize(dtype=self.dtype)})" f" data={self.dequantize(dtype=self.dtype)}),"
f" data_format={self._data_format}"
) )
def _get_quantizer(self) -> Quantizer: def _get_quantizer(self) -> Quantizer:
...@@ -396,6 +432,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor): ...@@ -396,6 +432,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
dtype: torch.dtype, dtype: torch.dtype,
quantizer: Quantizer, quantizer: Quantizer,
is_2D_scaled: bool, is_2D_scaled: bool,
data_format: tex.Float8BlockScaleTensorFormat,
) -> Float8BlockwiseQTensor: ) -> Float8BlockwiseQTensor:
"""Build Float8BlockwiseQTensor, for use in __reduce__ """Build Float8BlockwiseQTensor, for use in __reduce__
...@@ -413,6 +450,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor): ...@@ -413,6 +450,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
dtype=dtype, dtype=dtype,
quantizer=quantizer, quantizer=quantizer,
is_2D_scaled=is_2D_scaled, is_2D_scaled=is_2D_scaled,
data_format=data_format,
) )
def __reduce_ex__(self, protocol: int) -> tuple: def __reduce_ex__(self, protocol: int) -> tuple:
...@@ -429,6 +467,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor): ...@@ -429,6 +467,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
self.dtype, self.dtype,
self._quantizer, self._quantizer,
self._is_2D_scaled, self._is_2D_scaled,
self._data_format,
), ),
) )
...@@ -454,6 +493,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor): ...@@ -454,6 +493,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
dst._fp8_dtype = src._fp8_dtype dst._fp8_dtype = src._fp8_dtype
dst._rowwise_scale_inv = src._rowwise_scale_inv dst._rowwise_scale_inv = src._rowwise_scale_inv
dst._columnwise_scale_inv = src._columnwise_scale_inv dst._columnwise_scale_inv = src._columnwise_scale_inv
dst._data_format = src._data_format
# Check that tensor dimensions match # Check that tensor dimensions match
if ( if (
...@@ -501,6 +541,13 @@ class _ViewFunc(torch.autograd.Function): ...@@ -501,6 +541,13 @@ class _ViewFunc(torch.autograd.Function):
) -> Float8BlockwiseQTensor: ) -> Float8BlockwiseQTensor:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
# Check for invalid configurations
if not tensor._is_gemm_ready_format():
raise NotImplementedError(
"View is only supported with GEMM_READY data format, "
f"but found data_format={tensor._data_format}"
)
# Return input tensor if shape is not provided # Return input tensor if shape is not provided
ctx.shape = tensor.shape ctx.shape = tensor.shape
if shape is None: if shape is None:
...@@ -569,6 +616,14 @@ class _ViewFunc(torch.autograd.Function): ...@@ -569,6 +616,14 @@ class _ViewFunc(torch.autograd.Function):
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
if isinstance(grad, Float8BlockwiseQTensor): if isinstance(grad, Float8BlockwiseQTensor):
# Check for invalid configurations
if not grad._is_gemm_ready_format():
raise NotImplementedError(
"View is only supported with GEMM_READY data format, "
f"but found data_format={grad._data_format}"
)
new_data = ( new_data = (
grad._rowwise_data.view(*ctx.shape) if grad._rowwise_data is not None else None grad._rowwise_data.view(*ctx.shape) if grad._rowwise_data is not None else None
) )
...@@ -608,6 +663,13 @@ class _ReshapeFunc(torch.autograd.Function): ...@@ -608,6 +663,13 @@ class _ReshapeFunc(torch.autograd.Function):
) -> Float8BlockwiseQTensor: ) -> Float8BlockwiseQTensor:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
# Check for invalid configurations
if not tensor._is_gemm_ready_format():
raise NotImplementedError(
"Reshape is only supported with GEMM_READY data format, "
f"but found data_format={tensor._data_format}"
)
# Return input tensor if shape is not provided # Return input tensor if shape is not provided
ctx.shape = tensor.shape ctx.shape = tensor.shape
if shape is None: if shape is None:
...@@ -675,6 +737,14 @@ class _ReshapeFunc(torch.autograd.Function): ...@@ -675,6 +737,14 @@ class _ReshapeFunc(torch.autograd.Function):
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
if isinstance(grad, Float8BlockwiseQTensor): if isinstance(grad, Float8BlockwiseQTensor):
# Check for invalid configurations
if not grad._is_gemm_ready_format():
raise NotImplementedError(
"Reshape is only supported with GEMM_READY data format, "
f"but found data_format={grad._data_format}"
)
new_rowwise_data = None new_rowwise_data = None
new_columnwise_data = None new_columnwise_data = None
if grad._rowwise_data is not None: if grad._rowwise_data is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment