Commit 2b05e121 authored by yuguo


Merge commit 'a69692ac' of https://github.com/NVIDIA/TransformerEngine
parents 0fd441c2 a69692ac
......@@ -123,6 +123,35 @@ def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor:
return dgelu
@jit_fuser
def l2normalization_fused_(x: torch.Tensor, eps: float) -> torch.Tensor:
"""L2 normalization fused - inference version"""
x_squared = x.pow(2)
l2_norm_squared = x_squared.sum(dim=-1, keepdim=True)
rsqrt_norm = torch.rsqrt(l2_norm_squared + eps)
return x * rsqrt_norm
@jit_fuser
def l2normalization_fwd_fused_(x: torch.Tensor, eps: float) -> tuple[torch.Tensor, torch.Tensor]:
"""L2 normalization fused - training version that returns intermediate values"""
x_squared = x.pow(2)
l2_norm_squared = x_squared.sum(dim=-1, keepdim=True)
rsqrt_norm = torch.rsqrt(l2_norm_squared + eps)
y = x * rsqrt_norm
return y, rsqrt_norm
@jit_fuser
def l2normalization_backward_fused_(
grad_output: torch.Tensor, x: torch.Tensor, rsqrt_norm: torch.Tensor, eps: float
) -> torch.Tensor:
"""L2 normalization backward fused"""
x_dy_sum = (x * grad_output).sum(dim=-1, keepdim=True)
x_norm_squared = x.pow(2).sum(dim=-1, keepdim=True) + eps
return rsqrt_norm * (grad_output - x * x_dy_sum / x_norm_squared)
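# Illustrative sanity check (not part of this commit): the fused kernels above
# should agree with a plain-PyTorch reference and with autograd. Shapes, dtype,
# and eps below are arbitrary.
def _check_l2normalization_fused(eps: float = 1e-6) -> None:
    x = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
    # Reference forward: y = x * rsqrt(sum(x^2) + eps)
    y_ref = x * torch.rsqrt(x.pow(2).sum(dim=-1, keepdim=True) + eps)
    y, rsqrt_norm = l2normalization_fwd_fused_(x.detach(), eps)
    torch.testing.assert_close(y, y_ref.detach())
    # Reference backward via autograd; the fused version computes
    # dx = rsqrt_norm * (dy - x * sum(x * dy) / (||x||^2 + eps))
    dy = torch.randn_like(y_ref)
    (dx_ref,) = torch.autograd.grad(y_ref, x, dy)
    dx = l2normalization_backward_fused_(dy, x.detach(), rsqrt_norm, eps)
    torch.testing.assert_close(dx, dx_ref)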
def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
"""Disable native AMP for bias_gelu_fused_"""
with gpu_autocast_ctx(enabled=False):
......@@ -141,6 +170,26 @@ def bgrad_dgelu_fused(
return None, dgelu_fused_(grad_output, inp)
def l2normalization_fused(x: torch.Tensor, eps: float) -> torch.Tensor:
"""Disable native AMP for l2normalization_fused_ - inference version"""
with gpu_autocast_ctx(enabled=False):
return l2normalization_fused_(x, eps)
def l2normalization_fwd_fused(x: torch.Tensor, eps: float) -> tuple[torch.Tensor, torch.Tensor]:
"""Disable native AMP for l2normalization_fwd_fused_ - training version"""
with gpu_autocast_ctx(enabled=False):
return l2normalization_fwd_fused_(x, eps)
def l2normalization_backward_fused(
grad_output: torch.Tensor, x: torch.Tensor, rsqrt_norm: torch.Tensor, eps: float
) -> torch.Tensor:
"""Disable native AMP for l2normalization_backward_fused_"""
with gpu_autocast_ctx(enabled=False):
return l2normalization_backward_fused_(grad_output, x, rsqrt_norm, eps)
def bias_dropout_add(
x: torch.Tensor,
bias: torch.Tensor,
......@@ -266,3 +315,45 @@ def warmup_jit_bias_gelu_all_dtypes(
"""Call `warmup_jit_bias_gelu` for all training dtypes"""
for dtype in [torch.float32, torch.bfloat16, torch.float16]:
warmup_jit_bias_gelu(ffn_hidden_size, dtype, seq_length, micro_batch_size)
def warmup_jit_l2normalization(
hidden_size: int, dtype: torch.dtype, seq_length: int, micro_batch_size: int
) -> None:
"""Compile L2Normalization JIT function before the main training steps"""
# Save cuda RNG state to ensure warmup does not affect reproducibility.
rng_state = torch.cuda.get_rng_state()
inp = torch.rand(
(seq_length * micro_batch_size, hidden_size),
dtype=dtype,
device="cuda",
)
eps = 1e-6
# Warm up JIT fusions with the input requires_grad state used in both
# forward prop and activation recomputation
for input_grad in [False, True]:
inp.requires_grad = input_grad
for _ in range(5):
if input_grad:
# Test training version that returns intermediate values
output, rsqrt_norm = l2normalization_fwd_fused_(inp, eps)
# Test backward pass as well
grad_out = torch.rand_like(output)
_ = l2normalization_backward_fused_(grad_out, inp, rsqrt_norm, eps)
else:
# Test inference version
output = l2normalization_fused_(inp, eps)
del inp, output
torch.cuda.empty_cache()
torch.cuda.set_rng_state(rng_state)
def warmup_jit_l2normalization_all_dtypes(
hidden_size: int, seq_length: int, micro_batch_size: int
) -> None:
"""Call `warmup_jit_l2normalization` for all training dtypes"""
for dtype in [torch.float32, torch.bfloat16, torch.float16]:
warmup_jit_l2normalization(hidden_size, dtype, seq_length, micro_batch_size)
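# Hypothetical call site (illustration only): warm up the fused kernels for a
# known head dimension before the first training step so the activation
# recompute phase hits the same compiled kernels. The values are placeholders.
#
# warmup_jit_l2normalization_all_dtypes(hidden_size=128, seq_length=2048, micro_batch_size=4)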
......@@ -44,7 +44,7 @@ from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..utils import torch_get_autocast_gpu_dtype
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ...common.recipe import Recipe
from ...common.recipe import DelayedScaling, Recipe
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor
from torch.utils.cpp_extension import IS_HIP_EXTENSION
......@@ -89,7 +89,7 @@ def get_multi_stream_cublas_workspace() -> List[torch.Tensor]:
"""Returns workspace for multi-stream cublas."""
global _multi_stream_cublas_workspace
if not _multi_stream_cublas_workspace:
for _ in range(tex._num_cublas_streams):
for _ in range(tex.get_num_cublas_streams()):
_multi_stream_cublas_workspace.append(
torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda")
)
......@@ -685,6 +685,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# Update quantizers with new amax pointers.
self.quantizers[meta_key] = self.fp8_meta[meta_key].make_quantizers()
# Make sure weight tensors have correct quantizers
self._update_weight_quantizers()
# Update the global buffers with new amax and history pointers.
if FP8GlobalStateManager.get_buffer_info() in self.fp8_meta:
......@@ -738,6 +740,30 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.fp8_meta[fp8_meta_tensor_key] = recipe_state
self.quantizers[fp8_meta_tensor_key] = recipe_state.make_quantizers()
def _update_weight_quantizers(self) -> None:
"""Update the quantizers for the weight tensors."""
weight_tensors = self._get_weight_tensors()
weight_quantizers = self._get_weight_quantizers()
assert len(weight_tensors) == len(weight_quantizers), (
f"Number of weight tensors ({len(weight_tensors)}) and quantizers "
f"({len(weight_quantizers)}) must match"
)
for weight, quantizer in zip(weight_tensors, weight_quantizers):
if quantizer is not None and isinstance(weight, QuantizedTensorBase):
weight.update_quantizer(quantizer)
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
"""Get the weight tensors of the module."""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement _get_weight_tensors function"
)
def _get_weight_quantizers(self) -> List[Quantizer]:
"""Get the weight quantizers of the module."""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement _get_weight_quantizers function"
)
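# Minimal sketch (not part of this diff) of how a derived module is expected to
# implement the two hooks above; the real overrides are added to Linear,
# GroupedLinear, LayerNormLinear, and LayerNormMLP later in this change.
#
# class MyModule(TransformerEngineBaseModule):
#     def _get_weight_tensors(self):
#         return [self.weight]
#
#     def _get_weight_quantizers(self):
#         if not self.fp8:
#             return [None]
#         return [self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]]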
def init_fp8_meta_tensors(self, recipe: Recipe) -> None:
"""Init scales and amaxes."""
self.set_meta_tensor(True, recipe)
......@@ -777,7 +803,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
reset("scaling_fwd")
reset("scaling_bwd")
def get_extra_state(self) -> Optional[torch.Tensor]:
def get_extra_state(self) -> torch.Tensor:
"""Save before checkpointing."""
# This implementation is working around a few issues:
......@@ -812,7 +838,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
state = None
fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration
if not fp8_checkpoint:
return None
return torch.empty(0, dtype=torch.uint8)
# Copy tensors to CPU and store
state = {}
......@@ -838,13 +864,18 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8)
return state_serialized
def set_extra_state(self, state: Optional[torch.Tensor]) -> None:
def set_extra_state(self, state: torch.Tensor) -> None:
"""Load previous state."""
# Maintain backwards compatibility with older checkpoints.
if state is None:
return
# Load state
if isinstance(state, torch.Tensor):
# An empty tensor indicates no FP8 state, so there is nothing to unpickle.
if state.numel() == 0:
return
# Default format: byte tensor with pickled data
state = pickle.loads(state.detach().cpu().numpy().tobytes())
elif isinstance(state, io.BytesIO):
......@@ -857,6 +888,14 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if state is None:
return
# TE 1.x checkpoint compatibility: add DelayedScaling recipe if missing
if "recipe" not in state:
# TE 1.x only supported delayed scaling, which was the default recipe
state["recipe"] = DelayedScaling()
# TE 1.x also saved scale_inv, which is not needed with Recipe object
state.pop("scale_inv_fwd", None)
state.pop("scale_inv_bwd", None)
# Load extra items
self.fp8_meta.update(state["extra_fp8_variables"])
self.fp8_meta["recipe"] = state["recipe"]
......@@ -930,6 +969,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# assume FP8 execution.
def init_fp8_metadata(self, num_gemms: int = 1) -> None:
"""Initialize fp8 related metadata and tensors during fprop."""
_original_recipe = self.fp8_meta.get("recipe", None)
self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
......@@ -968,6 +1009,19 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
_current_recipe = self.fp8_meta["recipe"]
if _original_recipe is not None and not (
issubclass(_current_recipe.__class__, _original_recipe.__class__)
or issubclass(_original_recipe.__class__, _current_recipe.__class__)
):
warnings.warn(
f"Recipe type changed from {_original_recipe.__class__.__name__} "
f"to {_current_recipe.__class__.__name__}. "
"This may affect model behavior."
)
# Clear cached workspaces as they were created with the old recipe/quantizer type
self._fp8_workspaces.clear()
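# Illustrative scenario (not part of this diff) that triggers the warning above:
# running the same module under fp8_autocast with two different recipe types.
#
# with fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
#     y = linear(x)
# with fp8_autocast(enabled=True, fp8_recipe=Float8CurrentScaling()):
#     y = linear(x)  # warns: recipe type changed; cached FP8 workspaces are cleared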
@contextmanager
def prepare_forward(
self,
......@@ -992,6 +1046,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.set_activation_dtype(inp)
self.init_fp8_metadata(num_gemms=num_gemms)
self._check_weight_tensor_recipe_correspondence()
if self.fp8 and self.sequence_parallel and self.fp8_meta["recipe"].delayed():
assert self.fp8_meta["recipe"].reduce_amax, (
......@@ -1103,7 +1158,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if (
isinstance(
grad_output_.get_tensor(True),
(QuantizedTensor, Float8TensorBase, MXFP8TensorBase),
(
QuantizedTensor,
Float8TensorBase,
MXFP8TensorBase,
Float8BlockwiseQTensorBase,
),
)
and ctx.use_bias
):
......@@ -1169,18 +1229,23 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
with get_rng_state_tracker().fork():
init_fn(param)
# If primary weights are in fp8, wrap the parameter as FP8Tensor
# Wrap parameters in QuantizedTensor if needed
fp8_meta_index = self.param_init_meta[name].fp8_meta_index
high_precision_init_val = None
if self.primary_weights_in_fp8 and fp8_meta_index is not None:
# Keep high-precision values on CPU if needed
if self.preserve_high_precision_init_val:
high_precision_init_val = param.detach().cpu()
# Configure quantizer
quantizer = self.quantizers["scaling_fwd"][fp8_meta_index]
assert (
quantizer is not None
) # to use primary fp8 weight one needs to use FP8 autocast with specific recipe.
if quantizer is None:
raise RuntimeError("Weight quantizer has not been initialized")
quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled())
quantizer.internal = False
# Quantize parameter
param = quantizer(param)
# Redo parameter wrap in case we broke it above
......@@ -1188,6 +1253,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# re-applying the nn.Parameter() wrap is a no-op when the input is already
# a parameter so we always re-apply it just for extra safety.
param = torch.nn.Parameter(param)
# Keep high-precision values on CPU if needed
if high_precision_init_val is not None:
# - Master weights are initialized from model weights, if we use fp8 primary
......@@ -1231,7 +1298,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
fsdp_group: Optional[dist_group_type] = None,
workspace_dtype: Optional[torch.dtype] = None,
) -> QuantizedTensor:
"""Get FP8 workspace buffer and maybe update its values
"""Get workspace buffer for weights and maybe update its values
The workspace buffer may be cached for future function calls.
......@@ -1257,13 +1324,16 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
for debug quantization, this is dtype of the tensor.
"""
# FP8 primary weights
# Handle case where weights are already quantized
# Note: Make sure weights have the required usages, but do not
# drop other existing usages since they may be needed later.
if isinstance(tensor, QuantizedTensor):
if update_workspace and quantizer is not None:
tensor.update_usage(
rowwise_usage=quantizer.rowwise_usage,
columnwise_usage=quantizer.columnwise_usage,
)
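# Note on the semantics assumed below: passing None to update_usage leaves
# that usage untouched, while True forces it on, so only the usages the
# quantizer requires are added and existing ones are preserved.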
update_rowwise_usage = True if quantizer.rowwise_usage else None
update_columnwise_usage = True if quantizer.columnwise_usage else None
tensor.update_usage(
rowwise_usage=update_rowwise_usage,
columnwise_usage=update_columnwise_usage,
)
return tensor
# Try getting workspace from cache
......@@ -1387,6 +1457,43 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
)
self.name = f"Layer_{TEDebugState.get_layer_count()}"
def _check_weight_tensor_recipe_correspondence(self) -> None:
"""
Verify that the weight tensor types match their corresponding recipe type.
This check is invoked in forward().
This establishes a 1:1 correspondence between recipe types and tensor types:
- DelayedScaling → Float8Tensor
- Float8CurrentScaling → Float8Tensor
- MXFP8BlockScaling → MXFP8Tensor
- Float8BlockScaling → Float8BlockTensor
Example case to catch: the recipe is DelayedScaling (set in fp8_autocast()),
but the weight tensor is an MXFP8Tensor (MXFP8BlockScaling was set in fp8_model_init()).
"""
if not self.fp8 and not self.fp8_calibration:
return
if not hasattr(self, "weight_names") or not self.weight_names:
return
recipe = self.fp8_meta["recipe"]
weight_tensors = [getattr(self, name) for name in self.weight_names]
for i, tensor in enumerate(weight_tensors):
if isinstance(tensor, QuantizedTensorBase):
quantizer = tensor._get_quantizer()
if quantizer is None:
continue
compatible_recipe_class = quantizer._get_compatible_recipe()
if compatible_recipe_class is None:
continue
if not isinstance(recipe, compatible_recipe_class):
raise RuntimeError(
f"Recipe mismatch for '{self.weight_names[i]}': tensor supports recipe"
f" {compatible_recipe_class.__name__}, but got {recipe.__class__.__name__}."
" Please check the recipes assigned during fp8_model_init() and"
" fp8_autocast() calls."
)
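# Example of the mismatch this check catches (illustration only; assumes the
# recipe argument of fp8_model_init):
#
# with fp8_model_init(enabled=True, recipe=MXFP8BlockScaling()):
#     linear = Linear(1024, 1024)   # weights stored as MXFP8Tensor
# with fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
#     out = linear(inp)             # raises RuntimeError: recipe mismatch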
def _turn_off_unsupported_features_in_debug(self):
if (
getattr(self, "ub_bulk_wgrad", False)
......
......@@ -242,8 +242,8 @@ class _GroupedLinear(torch.autograd.Function):
biases = saved_tensors[3 * N : 4 * N]
main_grads = ctx.main_grads
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: # TOSO
for i in ctx.num_gemms:
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
for i in range(ctx.num_gemms):
w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
w.main_grad = main_grads[i]
weights[i] = w
......@@ -673,26 +673,19 @@ class GroupedLinear(TransformerEngineBaseModule):
), "GroupedLinear doesn't support input tensor in FP8."
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
if FP8GlobalStateManager.fp8_graph_capturing():
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
else:
skip_fp8_weight_update = None
if skip_fp8_weight_update is not None:
is_first_microbatch = False
with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
weight_tensors = self._get_weight_tensors()
bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
if not self.fp8 and any(isinstance(w, QuantizedTensorBase) for w in weight_tensors):
warnings.warn(
"You are using quantized weights without quantized compute. "
"Please make sure this is intentional."
)
weight_tensors = [
w.dequantize() if isinstance(w, QuantizedTensorBase) else w
for w in weight_tensors
]
input_quantizers, weight_quantizers, output_quantizers = (
[None] * self.num_gemms,
weight_quantizers = self._get_weight_quantizers()
input_quantizers, output_quantizers = (
[None] * self.num_gemms,
[None] * self.num_gemms,
)
......@@ -707,14 +700,6 @@ class GroupedLinear(TransformerEngineBaseModule):
# TODO: use internal after #1638 is merged. # pylint: disable=fixme
for i in range(self.num_gemms):
input_quantizers[i].internal = False
weight_quantizers = [
self.quantizers["scaling_fwd"][
self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"]
]
for i in range(self.num_gemms)
]
for i in range(self.num_gemms):
weight_quantizers[i].internal = True
if torch.is_grad_enabled():
grad_output_quantizers = [
self.quantizers["scaling_bwd"][
......@@ -813,3 +798,30 @@ class GroupedLinear(TransformerEngineBaseModule):
self.quantizers["scaling_bwd"][
self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"]
].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
"""Get the weight tensors of the module."""
weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
if not self.fp8 and any(isinstance(w, QuantizedTensorBase) for w in weight_tensors):
warnings.warn(
"You are using quantized weights without quantized compute. "
"Please make sure this is intentional."
)
weight_tensors = [
w.dequantize() if isinstance(w, QuantizedTensorBase) else w for w in weight_tensors
]
return weight_tensors
def _get_weight_quantizers(self) -> List[Quantizer]:
"""Get the weight quantizers of the module."""
if not self.fp8:
return [None] * self.num_gemms
weight_quantizers = [
self.quantizers["scaling_fwd"][
self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"]
]
for i in range(self.num_gemms)
]
for i in range(self.num_gemms):
weight_quantizers[i].internal = True
return weight_quantizers
......@@ -5,7 +5,7 @@
"""LayerNormLinear API"""
import os
import warnings
from typing import Callable, Dict, Optional, Tuple, Union
from typing import Callable, Dict, Optional, Tuple, Union, List
from functools import reduce
from operator import mul as multiply_op
......@@ -63,10 +63,11 @@ from ..tensor.quantized_tensor import (
)
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.utils import any_feature_enabled
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ..cpp_extensions import (
......@@ -190,19 +191,13 @@ class _LayerNormLinear(torch.autograd.Function):
# All-gather is not supported with FP8 column-wise data
input_quantizer.set_usage(columnwise=False)
# Do TP communication in high precision if quantized format
# does not support communication
force_hp_blockwise_ln_out_gather = (
fp8 and with_input_all_gather and isinstance(input_quantizer, Float8BlockQuantizer)
)
# Avoid quantized norm kernel if norm output will be returned
# or if a gather of ln_out must be in high precision.
with_quantized_norm = (
fp8
and not debug
and not return_layernorm_output
and not return_layernorm_output_gathered
and not force_hp_blockwise_ln_out_gather
)
# Apply normalization
......@@ -239,15 +234,16 @@ class _LayerNormLinear(torch.autograd.Function):
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
ln_out_return = ln_out_total
if fp8 or debug:
if not force_hp_blockwise_ln_out_gather:
ln_out = input_quantizer(ln_out)
ln_out = input_quantizer(ln_out)
input_quantizer.set_usage(rowwise=True, columnwise=False)
if isinstance(input_quantizer, Float8BlockQuantizer):
input_quantizer.all_gather_usage = False
ln_out_total = input_quantizer(ln_out_total)
else:
quantizer = None
if fp8 or debug:
quantizer = input_quantizer
if not with_quantized_norm and not force_hp_blockwise_ln_out_gather:
if not with_quantized_norm:
ln_out = quantizer(ln_out)
quantizer.set_usage(rowwise=True, columnwise=False)
if ub_overlap_ag_fprop: # Initialize Userbuffers all-gather
......@@ -282,7 +278,7 @@ class _LayerNormLinear(torch.autograd.Function):
# Configure quantizer
if weight_quantizer is not None:
weight_quantizer.set_usage(rowwise=True, columnwise=True)
weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
# Get quantized weight
update_workspace = is_first_microbatch is None or is_first_microbatch
......@@ -397,7 +393,6 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.ln_out_needs_gather = (
weight.requires_grad and parallel_mode == "column" and sequence_parallel
)
ctx.force_hp_blockwise_ln_out_gather = force_hp_blockwise_ln_out_gather
# Input with column-wise usage is needed for wgrad GEMM.
if backward_needs_input:
......@@ -405,7 +400,10 @@ class _LayerNormLinear(torch.autograd.Function):
# For sequence parallel in vanilla FP8, rowwise data is needed
# to gather the input. For MXFP8, only columnwise data
# can be all-gathered.
if isinstance(ln_out, MXFP8TensorBase) or not ctx.ln_out_needs_gather:
if (
isinstance(ln_out, (MXFP8TensorBase, Float8BlockwiseQTensorBase))
or not ctx.ln_out_needs_gather
):
ln_out.update_usage(rowwise_usage=False)
# Weight with column-wise usage is needed for dgrad GEMM.
......@@ -502,8 +500,8 @@ class _LayerNormLinear(torch.autograd.Function):
if return_layernorm_output:
if return_layernorm_output_gathered:
shape = list(inp_shape)
shape[0] *= tp_size
shape = list(inp.shape)
shape[0] *= tp_size if with_input_all_gather else 1
return out, ln_out_return.view(shape)
return out, ln_out_return.view(inp_shape)
return out
......@@ -637,7 +635,7 @@ class _LayerNormLinear(torch.autograd.Function):
ln_out_total_work = None
if ctx.ln_out_needs_gather:
quantizer = None
if ctx.input_quantizer is not None and not ctx.force_hp_blockwise_ln_out_gather:
if ctx.input_quantizer is not None:
quantizer = ctx.input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
# If data is in FP8, we compute FP8 transposes manually
......@@ -752,6 +750,31 @@ class _LayerNormLinear(torch.autograd.Function):
wgrad = None
if ctx.requires_wgrad:
# Prepare grad output tensor
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer):
# UB does not support overlapping grad output
# all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered
# for the dgrad GEMM. We work around by explicitly
# overlapping the NCCL operation with the dgrad GEMM.
ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
# Get the communication stream from the dgrad GEMM and set it as the current torch stream
dgrad_comm_stream = ub_obj_dgrad.get_communication_stream()
with torch.cuda.stream(dgrad_comm_stream):
# Syncs with the current stream (dgrad_comm_stream) before starting the all-gather
# This ensures that we don't start until all communication for the dgrad GEMM is complete
grad_output, mxfp8_grad_output_work = gather_along_first_dim(
grad_outputs[0],
ctx.tp_group,
async_op=True,
quantizer=ctx.grad_output_quantizer,
)
# Synchronize with the main stream
mxfp8_grad_output_work.wait()
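# The stream-overlap pattern used above, in isolation (generic sketch with the
# standard torch.distributed API; tensors, group, and stream are placeholders):
#
# comm_stream = torch.cuda.Stream()
# with torch.cuda.stream(comm_stream):
#     work = torch.distributed.all_gather_into_tensor(
#         out, inp, group=tp_group, async_op=True
#     )
# # ... independent dgrad GEMM work proceeds on the default stream ...
# work.wait()  # synchronize before the wgrad GEMM consumes the gathered tensor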
# Prepare input tensor
# Note: Synchronize tensor-parallel communication and
......@@ -766,22 +789,6 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
ln_out_total = ctx.input_quantizer(ln_out_total)
# Prepare grad output tensor
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer):
# UB does not support overlapping grad output
# all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered
# for the dgrad GEMM. We work around with blocking
# all-gather for column-scaled MXFP8 data.
ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
grad_output, _ = gather_along_first_dim(
grad_outputs[0],
ctx.tp_group,
quantizer=ctx.grad_output_quantizer,
)
if ctx.fp8 or ctx.debug:
if isinstance(grad_output, QuantizedTensorBase):
grad_output.update_usage(columnwise_usage=True)
......@@ -1389,6 +1396,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
self._customize_quantizers_float8_current_scaling(fwd, recipe)
elif recipe.float8_block_scaling():
self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
# elif other recipes (mxfp8, etc)
def reset_layer_norm_parameters(self) -> None:
......@@ -1484,20 +1493,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
) as inp:
# Get concatenated weight and bias tensors
unfused_weights = [getattr(self, name) for name in self.weight_names]
if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
if self.fp8:
if len(unfused_weights) != 1:
raise RuntimeError(
"Splitting QuantizedTensor into multiple params is not supported"
)
else:
warnings.warn(
"You are using quantized weights without quantized compute. "
"Please make sure this is intentional."
)
unfused_weights = [w.dequantize() for w in unfused_weights]
unfused_weights = self._get_weight_tensors()
weight_tensor = noop_cat(unfused_weights)
if self.use_bias:
bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
......@@ -1603,8 +1599,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
output_quantizer = None
input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
input_quantizer.internal = True
weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
weight_quantizer.internal = True
(weight_quantizer,) = self._get_weight_quantizers()
if fp8_output:
output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
if torch.is_grad_enabled():
......@@ -1679,3 +1674,39 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_reduction_group = self.tp_group
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
"""Get the weight tensors of the module."""
unfused_weights = [getattr(self, name) for name in self.weight_names]
if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
if self.fp8:
if len(unfused_weights) != 1:
raise RuntimeError(
"Splitting QuantizedTensor into multiple params is not supported"
)
else:
warnings.warn(
"You are using quantized weights without quantized compute. "
"Please make sure this is intentional."
)
unfused_weights = [w.dequantize() for w in unfused_weights]
return unfused_weights
def _get_weight_quantizers(self) -> List[Quantizer]:
"""Get the weight quantizers of the module."""
if not self.fp8:
return [None]
weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
weight_quantizer.internal = True
return [weight_quantizer]
def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on blockwise scaling recipe + layernorm_linear."""
assert (
recipe.float8_block_scaling()
), "Expected a blockwise scaling recipe for this quantizer customization"
if fwd:
if self.sequence_parallel and self.parallel_mode == "column":
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].all_gather_usage = True
......@@ -5,7 +5,7 @@
"""LayerNormMLP API"""
import os
import warnings
from typing import Callable, Optional, Tuple, Union
from typing import Callable, Optional, Tuple, Union, List
from functools import reduce
from operator import mul as multiply_op
......@@ -244,26 +244,18 @@ class _LayerNormMLP(torch.autograd.Function):
# All-gather is not supported with FP8 column-wise data
fc1_input_quantizer.set_usage(columnwise=False)
# Do TP communication in high precision if quantized format
# does not support communication
force_hp_fc1_input_gather = (
fp8 and sequence_parallel and isinstance(fc1_input_quantizer, Float8BlockQuantizer)
)
# for fp8 DelayedScaling: layernorm output = FP8
# only output of the linear is returned
# for return_layernorm_output: layernorm output = High precision, then cast to FP8
# high precision layernorm output and output of the linear are returned
# for debug: layernorm output = high precision, to enable processing of this norm
with_quantized_norm = (
fp8
and not debug
and not return_layernorm_output
and not return_layernorm_output_gathered
and not debug
)
if isinstance(fc1_input_quantizer, Float8BlockQuantizer):
# Kernels not available for norm fusion.
with_quantized_norm = False
# Apply normalization
ln_out, mu, rsigma = apply_normalization(
......@@ -293,15 +285,16 @@ class _LayerNormMLP(torch.autograd.Function):
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
ln_out_return = ln_out_total
if fp8 or debug:
if not force_hp_fc1_input_gather:
ln_out = fc1_input_quantizer(ln_out)
ln_out = fc1_input_quantizer(ln_out)
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
if isinstance(fc1_input_quantizer, Float8BlockQuantizer):
fc1_input_quantizer.all_gather_usage = False
ln_out_total = fc1_input_quantizer(ln_out_total)
else:
quantizer = None
if fp8 or debug:
quantizer = fc1_input_quantizer
if not with_quantized_norm and not force_hp_fc1_input_gather:
if not with_quantized_norm:
ln_out = fc1_input_quantizer(ln_out)
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
if ub_overlap_ag:
......@@ -333,8 +326,8 @@ class _LayerNormMLP(torch.autograd.Function):
# which handles weight caching etc.
# FP8 cast to workspace buffer
update_workspace = is_first_microbatch is None or is_first_microbatch
fc1_weight_quantizer.set_usage(rowwise=True, columnwise=True)
fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True)
fc1_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
fc2_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
fc1_weight_final = module.get_weight_workspace(
tensor=fc1_weight,
quantizer=fc1_weight_quantizer,
......@@ -567,7 +560,6 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.tensor_objects = tensor_objects
ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
ctx.force_hp_fc1_input_gather = force_hp_fc1_input_gather
ctx.fc1_grad_input_quantizer = fc1_grad_input_quantizer
ctx.fc1_grad_weight_quantizer = fc1_grad_weight_quantizer
ctx.fc1_grad_output_quantizer = fc1_grad_output_quantizer
......@@ -628,7 +620,7 @@ class _LayerNormMLP(torch.autograd.Function):
if return_layernorm_output:
if return_layernorm_output_gathered:
shape = list(inp_shape)
shape[0] *= tp_size
shape[0] *= tp_size if (sequence_parallel and set_parallel_mode) else 1
return fc2_out, ln_out_return.view(shape)
return fc2_out, ln_out_return.view(inp_shape)
return fc2_out
......@@ -743,7 +735,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub_obj_fc1_dgrad = None
if ctx.fc1_weight_requires_grad and ctx.tensor_parallel and ctx.sequence_parallel:
quantizer = None
if ctx.fp8 or ctx.debug and not ctx.force_hp_fc1_input_gather:
if ctx.fp8 or ctx.debug:
quantizer = ctx.fc1_input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
# If data is in FP8, we compute FP8 transposes manually
......@@ -841,6 +833,30 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_wgrad = None
if ctx.fc2_weight_requires_grad:
# Prepare grad output tensor
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if ctx.ub_overlap_ag and isinstance(ctx.fc2_grad_output_quantizer, MXFP8Quantizer):
# UB does not support overlapping grad output
# all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered
# for the dgrad GEMM. We work around by explicitly
# overlapping the NCCL operation with the dgrad GEMM.
ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
# Get the communication stream from the dgrad GEMM and set it as the current torch stream
dgrad_comm_stream = ub_obj_fc2_dgrad.get_communication_stream()
with torch.cuda.stream(dgrad_comm_stream):
# Syncs with the current stream (dgrad_comm_stream) before starting the all-gather
# This ensures that we don't start until all communication for the dgrad GEMM is complete
grad_output, mxfp8_fc2_grad_output_work = gather_along_first_dim(
grad_outputs[0],
ctx.tp_group,
async_op=True,
quantizer=ctx.fc2_grad_output_quantizer,
)
# Synchronize with the main stream
mxfp8_fc2_grad_output_work.wait()
# Prepare input tensor
# Note: Synchronize tensor-parallel communication and
......@@ -852,22 +868,6 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True)
act_out = ctx.fc2_input_quantizer(act_out)
# Prepare grad output tensor
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if ctx.ub_overlap_ag and isinstance(ctx.fc2_grad_output_quantizer, MXFP8Quantizer):
# UB does not support overlapping grad output
# all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered
# for the dgrad GEMM. We work around with blocking
# all-gather for column-scaled MXFP8 data.
ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
grad_output, _ = gather_along_first_dim(
grad_outputs[0],
ctx.tp_group,
quantizer=ctx.fc2_grad_output_quantizer,
)
if ctx.fp8 or ctx.debug:
if isinstance(grad_output, QuantizedTensorBase):
grad_output.update_usage(columnwise_usage=True)
......@@ -1661,8 +1661,11 @@ class LayerNormMLP(TransformerEngineBaseModule):
super().set_meta_tensor(fwd, recipe)
# customize quantizers based on each recipe & layer configs
if FP8GlobalStateManager.get_fp8_recipe().float8_current_scaling():
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
self._customize_quantizers_float8_current_scaling(fwd, recipe)
elif recipe.float8_block_scaling():
self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
# elif for other recipes (mxfp8, etc.)
def reset_layer_norm_parameters(self) -> None:
......@@ -1772,15 +1775,14 @@ class LayerNormMLP(TransformerEngineBaseModule):
) = quantizers
# Get weight tensors
fc1_weight = self.fc1_weight
fc1_weight, fc2_weight = self._get_weight_tensors()
fc1_bias = self.fc1_bias if self.use_bias else None
fc2_weight = self.fc2_weight
fc2_bias = self.fc2_bias if self.use_bias else None
if not self.fp8:
if isinstance(fc1_weight, Float8Tensor):
fc1_weight = fc1_weight.from_float8()
fc1_weight = fc1_weight.dequantize()
if isinstance(fc2_weight, Float8Tensor):
fc2_weight = fc2_weight.from_float8()
fc2_weight = fc2_weight.dequantize()
# Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode
if (not IS_HIP_EXTENSION
......@@ -1866,31 +1868,26 @@ class LayerNormMLP(TransformerEngineBaseModule):
def _get_quantizers(self, fp8_output):
(
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
) = [None] * 12
) = [None] * 10
fc1_weight_quantizer, fc2_weight_quantizer = self._get_weight_quantizers()
if self.fp8:
fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
fc1_input_quantizer.internal = True
fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
fc1_weight_quantizer.internal = True
fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
fc2_input_quantizer.set_usage(
rowwise=True,
columnwise=isinstance(fc2_input_quantizer, (MXFP8Quantizer, Float8BlockQuantizer)),
)
fc1_input_quantizer.internal = True
fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT]
fc2_weight_quantizer.internal = True
if fp8_output:
fc2_output_quantizer = self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM2_OUTPUT
......@@ -2007,6 +2004,36 @@ class LayerNormMLP(TransformerEngineBaseModule):
tex.FP8BwdTensors.GRAD_OUTPUT2
].amax_reduction_group = self.tp_group
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
"""Get the weight tensors of the module."""
return [self.fc1_weight, self.fc2_weight]
def _get_weight_quantizers(self) -> List[Quantizer]:
"""Get the weight quantizers of the module."""
if not self.fp8:
return [None, None]
fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
fc1_weight_quantizer.internal = True
fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT]
fc2_weight_quantizer.internal = True
return [fc1_weight_quantizer, fc2_weight_quantizer]
def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on blockwise scaling recipe + layernorm_mlp."""
assert (
recipe.float8_block_scaling()
), "Expected a blockwise scaling recipe for this quantizer customization"
if fwd:
if self.sequence_parallel and self.set_parallel_mode:
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].all_gather_usage = True
else:
if self.sequence_parallel and self.set_parallel_mode:
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT2
].all_gather_usage = True
def backward_dw(self):
"""
Execute the delayed weight gradient computation.
......
......@@ -3,7 +3,7 @@
# See LICENSE for license information.
"""Linear API"""
from typing import Callable, Dict, Optional, Tuple, Union
from typing import Callable, Dict, Optional, Tuple, Union, List
from functools import reduce
from operator import mul as multiply_op
import warnings
......@@ -67,7 +67,7 @@ from ..tensor.quantized_tensor import (
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.utils import any_feature_enabled
......@@ -137,12 +137,6 @@ class _Linear(torch.autograd.Function):
parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
)
# Do TP communication in high precision if quantized format
# does not support communication
force_hp_input_gather = (
fp8 and with_input_all_gather_nccl and isinstance(input_quantizer, Float8BlockQuantizer)
)
# Configure Userbuffers communication (comm+GEMM overlap)
ub_obj = None
ub_type = None
......@@ -169,7 +163,7 @@ class _Linear(torch.autograd.Function):
if fp8 or debug:
if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
if not force_hp_input_gather and not isinstance(inputmat, QuantizedTensorBase):
if not isinstance(inputmat, QuantizedTensorBase):
input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
if isinstance(
input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
......@@ -348,10 +342,11 @@ class _Linear(torch.autograd.Function):
# For sequence parallel in vanilla FP8, rowwise data is needed
# to gather the input. For MXFP8, only columnwise data
# can be all-gathered.
if isinstance(inputmat, MXFP8TensorBase) or not ctx.backward_input_needs_gather:
if (
isinstance(inputmat, (MXFP8TensorBase, Float8BlockwiseQTensorBase))
or not ctx.backward_input_needs_gather
):
inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
if force_hp_input_gather:
assert not isinstance(inputmat, QuantizedTensorBase)
saved_inputmat = inputmat
# Weight with column-wise usage is needed for dgrad GEMM.
......@@ -397,7 +392,6 @@ class _Linear(torch.autograd.Function):
ctx.activation_dtype = activation_dtype
ctx.fp8 = fp8
ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
ctx.force_hp_input_gather = force_hp_input_gather
ctx.input_quantizer = input_quantizer
ctx.grad_input_quantizer = grad_input_quantizer
ctx.grad_weight_quantizer = grad_weight_quantizer
......@@ -558,7 +552,7 @@ class _Linear(torch.autograd.Function):
inputmat_total_work = None
if ctx.backward_input_needs_gather:
quantizer = None
if (ctx.fp8 or ctx.debug) and not ctx.force_hp_input_gather:
if ctx.fp8 or ctx.debug:
quantizer = ctx.input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
# If data is in FP8, we compute FP8 transposes manually
......@@ -696,14 +690,23 @@ class _Linear(torch.autograd.Function):
# all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered
# for the dgrad GEMM. We work around with blocking
# all-gather for column-scaled MXFP8 data.
# for the dgrad GEMM. We work around by explicitly
# overlapping the NCCL operation with the dgrad GEMM.
ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
grad_output, _ = gather_along_first_dim(
grad_output_arg,
ctx.tp_group,
quantizer=ctx.grad_output_quantizer,
)
# Get the communication stream from the dgrad GEMM and set it as the current torch stream
dgrad_comm_stream = ub_obj_dgrad.get_communication_stream()
with torch.cuda.stream(dgrad_comm_stream):
# Syncs with the current stream (dgrad_comm_stream) before starting the all-gather
# This ensures that we don't start until all communication for the dgrad GEMM is complete
grad_output, grad_output_work = gather_along_first_dim(
grad_output_arg,
ctx.tp_group,
async_op=True,
quantizer=ctx.grad_output_quantizer,
)
# Synchronize with the main stream
grad_output_work.wait()
if ctx.fp8 or ctx.debug:
if isinstance(grad_output, QuantizedTensorBase):
grad_output.update_usage(columnwise_usage=True)
......@@ -1218,6 +1221,8 @@ class Linear(TransformerEngineBaseModule):
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
self._customize_quantizers_float8_current_scaling(fwd, recipe)
elif recipe.float8_block_scaling():
self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
# elif for other recipes (mxfp8, etc.)
def reset_parameters(self, defer_init=False):
......@@ -1294,20 +1299,7 @@ class Linear(TransformerEngineBaseModule):
) as inp:
# Get concatenated weight and bias tensors
unfused_weights = [getattr(self, name) for name in self.weight_names]
if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
if self.fp8:
if len(unfused_weights) != 1:
raise RuntimeError(
"Splitting QuantizedTensor into multiple params is not supported"
)
else:
warnings.warn(
"You are using quantized weights without quantized compute. "
"Please make sure this is intentional."
)
unfused_weights = [w.dequantize() for w in unfused_weights]
unfused_weights = self._get_weight_tensors()
weight_tensor = noop_cat(unfused_weights)
if self.use_bias:
bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
......@@ -1337,12 +1329,6 @@ class Linear(TransformerEngineBaseModule):
grad_output_quantizer,
) = quantizers
# Make sure weight tensor has correct quantizer
# Note: Quantizer might have changed if quantization
# recipe changed
if weight_quantizer is not None and isinstance(weight_tensor, QuantizedTensor):
weight_tensor._quantizer = weight_quantizer
if torch.is_grad_enabled():
linear_fn = _Linear.apply
args = []
......@@ -1403,8 +1389,7 @@ class Linear(TransformerEngineBaseModule):
output_quantizer = None
input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
input_quantizer.internal = True
weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
weight_quantizer.internal = True
(weight_quantizer,) = self._get_weight_quantizers()
if fp8_output:
output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
if torch.is_grad_enabled():
......@@ -1478,3 +1463,47 @@ class Linear(TransformerEngineBaseModule):
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_reduction_group = self.tp_group
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
"""Get the weight tensors of the module."""
unfused_weights = [getattr(self, name) for name in self.weight_names]
if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
if self.fp8:
if len(unfused_weights) != 1:
raise RuntimeError(
"Splitting QuantizedTensor into multiple params is not supported"
)
else:
warnings.warn(
"You are using quantized weights without quantized compute. "
"Please make sure this is intentional."
)
unfused_weights = [w.dequantize() for w in unfused_weights]
return unfused_weights
def _get_weight_quantizers(self) -> List[Quantizer]:
"""Get the weight quantizers of the module."""
if not self.fp8:
return [None]
weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
weight_quantizer.internal = True
return [weight_quantizer]
def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on blockwise scaling recipe + linear."""
assert (
recipe.float8_block_scaling()
), "Expected a blockwise scaling recipe for this quantizer customization"
if fwd:
if self.sequence_parallel and self.parallel_mode == "column":
# set compact for inp tensor X
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].all_gather_usage = True
else:
if self.sequence_parallel and self.parallel_mode == "row":
# set compact for grad_output tensor dY
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].all_gather_usage = True
......@@ -11,6 +11,7 @@ from .all_reduce import AllReduce
from .basic_linear import BasicLinear
from .bias import Bias
from .identity import Identity
from .l2normalization import L2Normalization
from .layer_norm import LayerNorm
from .make_extra_output import MakeExtraOutput
from .quantize import Quantize
......
......@@ -96,12 +96,15 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
if not x.is_contiguous():
x = x.contiguous()
# Check if FP8 is enabled
fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
if fp8_enabled and next_op is not None and next_op.num_quantizers("forward") > 0:
# Check if quantized compute is enabled
quantized_compute_enabled = FP8GlobalStateManager.is_fp8_enabled()
quantizer = None
if (
quantized_compute_enabled
and next_op is not None
and next_op.num_quantizers("forward") > 0
):
quantizer = next_op.get_quantizer("forward", 0)
else:
quantizer = None
# Launch kernel
y = self._activation_forward_impl(
......@@ -115,13 +118,13 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
# Quantize input to FP8 before caching if needed
if self.cache_quantized_input:
quantizer = Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, x.device)
quantizer.set_usage(rowwise=True, columnwise=False)
x = quantizer(x)
input_quantizer = Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, x.device)
input_quantizer.set_usage(rowwise=True, columnwise=False)
x = input_quantizer(x)
# Save state for backward pass
ctx.save_for_backward(x.detach())
ctx.fp8_enabled = fp8_enabled
ctx.quantized_compute_enabled = quantized_compute_enabled
ctx.dtype = dtype
ctx.prev_op = prev_op
......@@ -153,11 +156,20 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
if not dy.is_contiguous():
dy = dy.contiguous()
# Check if quantized compute is enabled
quantizer = None
if (
ctx.quantized_compute_enabled
and ctx.prev_op is not None
and ctx.prev_op.num_quantizers("backward") > 0
):
quantizer = ctx.prev_op.get_quantizer("backward", 0)
# Launch kernel
dx = self._activation_backward_impl(
reshape(dy, (-1, dy.size(-1))),
reshape(x, (-1, x.size(-1))),
None,
quantizer,
)
# Check grad input tensor
......
......@@ -22,7 +22,7 @@ from ...distributed import (
from ...fp8 import FP8GlobalStateManager
from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD
from ...tensor import Quantizer, QuantizedTensor
from ...tensor.float8_tensor import Float8Quantizer
from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ...tensor.mxfp8_tensor import MXFP8Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase
......@@ -324,12 +324,38 @@ class BasicLinear(BasicOperation):
weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
# Recipe-specific configuration
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
if any(
not isinstance(q, Float8CurrentScalingQuantizer)
for q in (input_quantizer, weight_quantizer, grad_output_quantizer)
):
raise RuntimeError(
"FP8 current-scaling recipe is enabled, "
f"but input quantizer is {input_quantizer.__class__.__name__}, "
f"weight quantizer is {weight_quantizer.__class__.__name__}, "
f"grad output quantizer is {grad_output_quantizer.__class__.__name__}"
)
input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
input_quantizer.amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon
weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale
weight_quantizer.amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon
grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
grad_output_quantizer.amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
if self.sequence_parallel and self.tensor_parallel_mode == "column":
input_quantizer.with_amax_reduction = True
input_quantizer.amax_reduction_group = self.tensor_parallel_group
if self.sequence_parallel and self.tensor_parallel_mode == "row":
grad_output_quantizer.with_amax_reduction = True
grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group
# Make sure weight tensor has correct quantizer
# Note: Quantizer might have changed if quantization
# recipe changed
if isinstance(weight_quantizer, Float8Quantizer) and isinstance(
weight, Float8TensorBase
):
if isinstance(
weight_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
) and isinstance(weight, Float8TensorBase):
weight._quantizer = weight_quantizer
@staticmethod
......@@ -349,7 +375,9 @@ class BasicLinear(BasicOperation):
input_quantizer: Optional[Quantizer] = None,
weight_quantizer: Optional[Quantizer] = None,
output_quantizer: Optional[Quantizer] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
input_requires_grad: bool = True,
weight_requires_grad: bool = True,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
"""Functional API for forward pass
Parameters
......@@ -385,17 +413,25 @@ class BasicLinear(BasicOperation):
Builder class for quantized weight tensor.
output_quantizer: Quantizer, optional
Builder class for quantized output tensor.
input_requires_grad: bool, default = `True`
Whether the loss gradient w.r.t. the input tensor is
required in the backward pass.
weight_requires_grad: bool, default = `True`
Whether the loss gradient w.r.t. the weight tensor is
required in the backward pass.
Returns
-------
torch.Tensor
Output tensor
torch.Tensor
Input tensor used in GEMM, possibly cast and reshaped from
provided input tensor
torch.Tensor
Weight tensor used in GEMM, possibly cast and reshaped from
provided weight tensor
torch.Tensor, optional
Input tensor, ready for use in backward pass. `None` is
returned if loss gradient w.r.t. the weight tensor is not
required.
torch.Tensor, optional
Weight tensor, ready for use in backward pass. `None` is
returned if loss gradient w.r.t. the input tensor is not
required.
"""
......@@ -416,7 +452,7 @@ class BasicLinear(BasicOperation):
if with_quantized_compute:
if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
input_quantizer.set_usage(rowwise=True)
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
if with_x_all_gather:
input_quantizer.set_usage(columnwise=False)
x, x_async = gather_along_first_dim(
......@@ -449,7 +485,7 @@ class BasicLinear(BasicOperation):
if with_quantized_compute and not w_is_quantized:
if weight_quantizer is None:
raise ValueError("Missing quantizer for weight tensor")
weight_quantizer.set_usage(rowwise=True)
weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
w = weight_quantizer(w)
elif not with_quantized_compute and w_is_quantized:
w = w.dequantize()
......@@ -526,17 +562,25 @@ class BasicLinear(BasicOperation):
else:
torch.distributed.all_reduce(y, group=tensor_parallel_group)
# Detach input tensor if needed
# Note: PyTorch autograd produces esoteric errors if we save
# input tensor as context for backward pass.
if x_local is input:
x_local = x_local.detach()
# Prepare weight tensor for backward pass
if input_requires_grad:
if w is not weight and with_quantized_compute and isinstance(w, QuantizedTensor):
w.update_usage(rowwise_usage=False, columnwise_usage=True)
else:
w = None
# Configure input tensor for backward pass
if with_quantized_compute and isinstance(x_local, QuantizedTensor):
if not (isinstance(x_local, Float8TensorBase) and with_x_all_gather):
# FP8 does not support all-gather of transpose data
x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
# Prepare input tensor for backward pass
if weight_requires_grad:
if x_local is input:
# PyTorch autograd produces esoteric errors if we
# cache input tensor directly.
x_local = x_local.detach()
if with_quantized_compute and isinstance(x_local, QuantizedTensor):
if not (isinstance(x_local, Float8TensorBase) and with_x_all_gather):
# FP8 does not support all-gather of transpose data
x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
else:
x_local = None
return y, x_local, w
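# Net effect of the change above (sketch): tensors the backward pass will not
# use are returned as None so they are not kept alive, e.g.
#
# y, x_local, w = BasicLinear._functional_forward(
#     input=x, weight=weight, weight_requires_grad=False, ...
# )
# assert x_local is None  # no input copy is cached for the (skipped) wgrad GEMM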
......@@ -892,7 +936,7 @@ class BasicLinear(BasicOperation):
dtype = torch.get_autocast_dtype("cuda")
# Linear forward
output, x_local, _ = BasicLinear._functional_forward(
output, x_local, w = BasicLinear._functional_forward(
input=input_,
weight=self.weight,
dtype=dtype,
......@@ -903,10 +947,12 @@ class BasicLinear(BasicOperation):
input_quantizer=input_quantizer,
weight_quantizer=weight_quantizer,
output_quantizer=output_quantizer,
input_requires_grad=input_requires_grad,
weight_requires_grad=weight_requires_grad,
)
# Save state for backward pass
ctx.save_for_backward(x_local)
ctx.save_for_backward(x_local, w)
ctx.with_quantized_compute = with_quantized_compute
ctx.input_quantizer = input_quantizer
ctx.weight_quantizer = weight_quantizer
......@@ -926,7 +972,7 @@ class BasicLinear(BasicOperation):
) -> tuple[torch.Tensor, Iterable[Optional[torch.Tensor]]]:
# Saved tensors from forward pass
(x_local,) = ctx.saved_tensors
(x_local, w) = ctx.saved_tensors
# wgrad fusion
accumulate_into_main_grad = self._accumulate_into_main_grad
......@@ -946,7 +992,7 @@ class BasicLinear(BasicOperation):
grad_input, grad_weight = BasicLinear._functional_backward(
grad_output=grad_output,
input=x_local,
weight=self.weight,
weight=w,
input_requires_grad=ctx.input_requires_grad,
weight_requires_grad=ctx.weight_requires_grad,
dtype=ctx.dtype,
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fusable operation for L2 Normalization."""
from __future__ import annotations
from typing import Optional
import torch
from ...tensor import QuantizedTensor
from ...utils import clear_tensor_data
from ..op import BasicOperation, OperationContext
from ...jit import (
l2normalization_fused,
l2normalization_fwd_fused,
l2normalization_backward_fused,
set_jit_fusion_options,
warmup_jit_l2normalization_all_dtypes,
)
class L2Normalization(BasicOperation):
r"""L2 Normalization
Applies L2 normalization over the last dimension of input tensors.
This is a parameter-free normalization that scales each vector to unit L2 norm.
.. math::
y = \frac{x}{\sqrt{\sum_{i} x_i^2 + \varepsilon}}
This operation is used e.g. for query-key normalization in attention mechanisms.
Parameters
----------
eps : float, default = 1e-6
A value added to the denominator for numerical stability
seq_length: int, default = None
sequence length of input samples. Needed for JIT warmup, a technique where
JIT-fused functions are warmed up before training to ensure the same kernels
are used in the forward propagation and activation recompute phases.
micro_batch_size: int, default = None
batch size per training step. Needed for JIT warmup, a technique where
JIT-fused functions are warmed up before training to ensure the same kernels
are used in the forward propagation and activation recompute phases.
"""
def __init__(
self,
*,
eps: float = 1e-6,
seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None,
) -> None:
super().__init__()
self.eps: float = eps
# JIT warmup for L2Normalization fused operations
if seq_length and micro_batch_size:
if torch.cuda.is_available():
set_jit_fusion_options()
# For L2Normalization, we don't know the hidden size until forward pass,
# but we can warm up with common sizes. For QK normalization, this will be
# the attention head dimension (hidden_size_per_attention_head), not the full
# model hidden dimension. Common head dimensions are 32, 64, 80, 96, 128, 256.
common_hidden_sizes = [32, 64, 80, 96, 128, 256]
for hidden_size in common_hidden_sizes:
warmup_jit_l2normalization_all_dtypes(hidden_size, seq_length, micro_batch_size)
def op_forward(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
) -> torch.Tensor:
# Use input directly - torch.compile can handle multi-dimensional tensors
x = input_
if isinstance(x, QuantizedTensor):
x = x.dequantize()
# Check if backward pass is needed
requires_grad = ctx.requires_grad
# Compute L2 normalization using fused implementation
# L2 norm: x / sqrt(sum(x^2) + eps) = x * rsqrt(sum(x^2) + eps)
if requires_grad:
# Training: use version that returns both output and intermediate values
y, rsqrt_norm = l2normalization_fwd_fused(x, self.eps)
else:
# Inference: use lightweight version that only returns output
y = l2normalization_fused(x, self.eps)
rsqrt_norm = None # Not needed for inference
# Save state for backward pass
if requires_grad:
ctx.save_for_backward(x, rsqrt_norm)
ctx.has_prev_op = prev_op is not None
return y
def op_backward(
self,
ctx: OperationContext,
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, tuple[()]]:
# Saved tensors from forward pass
x, rsqrt_norm = ctx.saved_tensors
dy = grad_output
if isinstance(dy, QuantizedTensor):
dy = dy.dequantize()
# Compute L2 norm backward pass using fused implementation
dx = l2normalization_backward_fused(dy, x, rsqrt_norm, self.eps)
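# Derivation sketch: with r = rsqrt(sum(x^2) + eps) and y = x * r,
#   dL/dx = r * dy - r^3 * x * sum(x * dy)
#         = rsqrt_norm * (dy - x * sum(x * dy) / (sum(x^2) + eps)),
# which is the expression the fused backward computes.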
# Clear saved tensors if possible
if ctx.has_prev_op:
clear_tensor_data(x)
clear_tensor_data(rsqrt_norm)
# No parameters, so empty tuple for param grads
return dx, ()
......@@ -51,7 +51,7 @@ class BackwardLinearAdd(FusedOperation):
linear_op_ctx = basic_op_ctxs[0]
# Saved tensors from forward pass
(x_local,) = linear_op_ctx.saved_tensors
(x_local, w) = linear_op_ctx.saved_tensors
# wgrad fusion
accumulate_into_main_grad = linear_op._accumulate_into_main_grad
......@@ -72,7 +72,7 @@ class BackwardLinearAdd(FusedOperation):
grad_input, grad_weight = BasicLinear._functional_backward(
grad_output=grad_output,
input=x_local,
weight=linear_op.weight,
weight=w,
input_requires_grad=linear_op_ctx.input_requires_grad,
weight_requires_grad=linear_op_ctx.weight_requires_grad,
dtype=grad_input.dtype,
......
......@@ -82,6 +82,10 @@ class ForwardLinearBiasActivation(FusedOperation):
else:
raise NotImplementedError("Activations are not yet supported")
# Check which grads are required
input_requires_grad = linear_op_ctx.requires_grad and input_.requires_grad
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# FP8 metadata
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
input_quantizer = None
......@@ -106,7 +110,7 @@ class ForwardLinearBiasActivation(FusedOperation):
dtype = torch.get_autocast_dtype("cuda")
# Linear forward
output, x_local, _ = BasicLinear._functional_forward(
output, x_local, w = BasicLinear._functional_forward(
input=input_,
weight=linear_op.weight,
bias=bias,
......@@ -118,18 +122,20 @@ class ForwardLinearBiasActivation(FusedOperation):
input_quantizer=input_quantizer,
weight_quantizer=weight_quantizer,
output_quantizer=output_quantizer,
input_requires_grad=input_requires_grad,
weight_requires_grad=weight_requires_grad,
)
# Save state for backward pass
linear_op_ctx.save_for_backward(x_local)
linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer
linear_op_ctx.weight_quantizer = weight_quantizer
linear_op_ctx.grad_output_quantizer = grad_output_quantizer
linear_op_ctx.grad_input_quantizer = grad_input_quantizer
linear_op_ctx.dtype = dtype
linear_op_ctx.input_requires_grad = input_.requires_grad
linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad
linear_op_ctx.input_requires_grad = input_requires_grad
linear_op_ctx.weight_requires_grad = weight_requires_grad
linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None
return output, [() for _ in range(len(self.basic_ops))]
......
......@@ -76,6 +76,10 @@ class ForwardLinearBiasAdd(FusedOperation):
if basic_op_kwargs[idx]:
raise ValueError("Bias operation forward does not expect keyword arguments")
# Check which grads are required
input_requires_grad = linear_op_ctx.requires_grad and input_.requires_grad
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# FP8 metadata
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
input_quantizer = None
......@@ -98,7 +102,7 @@ class ForwardLinearBiasAdd(FusedOperation):
# Linear forward
output = basic_op_extra_inputs[self._op_idxs["add"]][0]
output, x_local, _ = BasicLinear._functional_forward(
output, x_local, w = BasicLinear._functional_forward(
input=input_,
weight=linear_op.weight,
bias=bias,
......@@ -111,18 +115,20 @@ class ForwardLinearBiasAdd(FusedOperation):
input_quantizer=input_quantizer,
weight_quantizer=weight_quantizer,
output_quantizer=output_quantizer,
input_requires_grad=input_requires_grad,
weight_requires_grad=weight_requires_grad,
)
# Save state for backward pass
linear_op_ctx.save_for_backward(x_local)
linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer
linear_op_ctx.weight_quantizer = weight_quantizer
linear_op_ctx.grad_output_quantizer = grad_output_quantizer
linear_op_ctx.grad_input_quantizer = grad_input_quantizer
linear_op_ctx.dtype = dtype
linear_op_ctx.input_requires_grad = input_.requires_grad
linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad
linear_op_ctx.input_requires_grad = input_requires_grad
linear_op_ctx.weight_requires_grad = weight_requires_grad
linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None
return output, [() for _ in range(len(self.basic_ops))]
......
......@@ -407,16 +407,26 @@ class UserbuffersBackwardLinear(FusedOperation):
# Initialize grad output
if tensor_parallel_mode == "row" and isinstance(grad_output_quantizer, MXFP8Quantizer):
# UB does not support overlapping grad output
# all-gather with wgrad GEMM. Also, MXFP8 does not
# allow reusing the grad output that was gathered for
# the dgrad GEMM. We work around with blocking
# all-gather for column-scaled MXFP8 data.
# all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered
# for the dgrad GEMM. We work around by explicitly
# overlapping the NCCL operation with the dgrad GEMM.
grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
dy, _ = gather_along_first_dim(
grad_output,
tensor_parallel_group,
quantizer=grad_output_quantizer,
)
# Get the communication stream from the dgrad GEMM and set it as the current torch stream
dgrad_comm_stream = ub_comm_dgrad.get_communication_stream()
with torch.cuda.stream(dgrad_comm_stream):
# Syncs with the current stream (dgrad_comm_stream) before starting the all-gather
# This ensures that we don't start until all communication for the dgrad GEMM is complete
dy, dy_work = gather_along_first_dim(
dy_local,
tensor_parallel_group,
async_op=True,
quantizer=grad_output_quantizer,
)
# Synchronize with the main stream
dy_work.wait()
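# Why this overlaps: the all-gather is enqueued on dgrad's communication
# stream, so its NCCL kernels can run concurrently with the dgrad GEMM on
# the main stream; dy_work.wait() only stalls the main stream at the point
# where the gathered dy is actually consumed.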
if tensor_parallel_mode == "column":
dy = dy_local
if dy is None:
......@@ -500,7 +510,7 @@ class UserbuffersBackwardLinear(FusedOperation):
bias_op = self.basic_ops[idx]
# Saved tensors from forward pass
(x_local,) = linear_op_ctx.saved_tensors
(x_local, w) = linear_op_ctx.saved_tensors
# wgrad fusion
accumulate_into_main_grad = linear_op._accumulate_into_main_grad
......@@ -520,7 +530,7 @@ class UserbuffersBackwardLinear(FusedOperation):
retval = UserbuffersBackwardLinear._functional_backward(
grad_output=grad_output,
input=x_local,
weight=linear_op.weight,
weight=w,
weight_requires_grad=linear_op_ctx.weight_requires_grad,
bias_requires_grad=(bias_op is not None),
dtype=linear_op_ctx.dtype,
......
......@@ -21,7 +21,7 @@ from ...module.base import (
_2X_ACC_FPROP,
)
from ...tensor.quantized_tensor import QuantizedTensorBase, Quantizer
from ...tensor.float8_tensor import Float8Quantizer
from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ...utils import canonicalize_device, canonicalize_dtype
from ..basic import BasicLinear, Bias, ReduceScatter
......@@ -98,6 +98,8 @@ class UserbuffersForwardLinear(FusedOperation):
input_quantizer: Optional[Quantizer] = None,
weight_quantizer: Optional[Quantizer] = None,
output_quantizer: Optional[Quantizer] = None,
input_requires_grad: bool = True,
weight_requires_grad: bool = True,
ub_comm_name: str,
) -> tuple[torch.Tensor, dict]:
"""Functional API for forward pass
......@@ -131,6 +133,12 @@ class UserbuffersForwardLinear(FusedOperation):
Builder class for quantized weight tensor.
output_quantizer: Quantizer, optional
Builder class for quantized output tensor.
input_requires_grad: bool, default = `True`
Whether the loss gradient w.r.t. the input tensor is
required in the backward pass.
weight_requires_grad: bool, default = `True`
Whether the loss gradient w.r.t. the weight tensor is
required in the backward pass.
ub_comm_name: str
Layer type (e.g. "qkv", "proj", "fc1", "fc2"). This is
used to access the corresponding Userbuffers communicators
......@@ -141,8 +149,9 @@ class UserbuffersForwardLinear(FusedOperation):
torch.Tensor
Output tensor
dict
Extra output tensors. "input" is the input tensor,
possibly cast and reshaped from the provided input tensor.
Extra output tensors. "input" is the input tensor and
"weight" is the weight tensor, both ready for use in the
backward pass.
"""
......@@ -198,8 +207,10 @@ class UserbuffersForwardLinear(FusedOperation):
if with_ub_all_gather:
if input_quantizer is not None:
if not isinstance(x_local, QuantizedTensorBase):
input_quantizer.set_usage(rowwise=True, columnwise=True)
if isinstance(input_quantizer, Float8Quantizer):
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
if isinstance(
input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
):
input_quantizer.set_usage(columnwise=False)
x_local = input_quantizer(x_local)
input_quantizer.set_usage(rowwise=True, columnwise=False)
......@@ -212,7 +223,7 @@ class UserbuffersForwardLinear(FusedOperation):
else:
if with_quantized_compute:
if not isinstance(x_local, QuantizedTensorBase):
input_quantizer.set_usage(rowwise=True, columnwise=True)
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
x_local = input_quantizer(x_local)
else:
if isinstance(x_local, QuantizedTensorBase):
......@@ -225,7 +236,7 @@ class UserbuffersForwardLinear(FusedOperation):
w = weight
w_is_quantized = isinstance(w, QuantizedTensorBase)
if with_quantized_compute and not w_is_quantized:
weight_quantizer.set_usage(rowwise=True)
weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
w = weight_quantizer(w)
elif not with_quantized_compute and w_is_quantized:
w = w.dequantize()
......@@ -258,17 +269,25 @@ class UserbuffersForwardLinear(FusedOperation):
else:
y_local = gemm_output
# Detach input tensor if needed
# Note: PyTorch autograd produces esoteric errors if we save
# input tensor as context for backward pass.
if x_local is input:
x_local = x_local.detach()
# Configure input tensor for backward pass
if with_quantized_compute and isinstance(x_local, QuantizedTensorBase):
if not (isinstance(x_local, Float8TensorBase) and with_ub_all_gather):
# FP8 does not support all-gather of transpose data
x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
# Prepare weight tensor for backward pass
if input_requires_grad:
if w is not weight and with_quantized_compute and isinstance(w, QuantizedTensorBase):
w.update_usage(rowwise_usage=False, columnwise_usage=True)
else:
w = None
# Prepare input tensor for backward pass
if weight_requires_grad:
if x_local is input:
# PyTorch autograd produces esoteric errors if we
# cache input tensor directly.
x_local = x_local.detach()
if with_quantized_compute and isinstance(x_local, QuantizedTensorBase):
if not (isinstance(x_local, Float8TensorBase) and with_ub_all_gather):
# FP8 does not support all-gather of transpose data
x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
else:
x_local = None
# Return cast tensors
extra_outputs = {"input": x_local, "weight": w}
......@@ -298,6 +317,10 @@ class UserbuffersForwardLinear(FusedOperation):
if basic_op_kwargs[idx]:
raise ValueError("Bias operation forward does not expect keyword arguments")
# Check which grads are required
input_requires_grad = linear_op_ctx.requires_grad and input_.requires_grad
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# Quantization metadata
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
input_quantizer = None
......@@ -306,8 +329,10 @@ class UserbuffersForwardLinear(FusedOperation):
grad_input_quantizer = None
if with_quantized_compute:
recipe = FP8GlobalStateManager.get_fp8_recipe()
if not recipe.delayed() and not recipe.mxfp8():
raise RuntimeError("Userbuffers is only supported with FP8 delayed scaling recipe")
if not any((recipe.delayed(), recipe.float8_current_scaling(), recipe.mxfp8())):
raise RuntimeError(
f"Unsupported recipe for Userbuffers ({recipe.__class__.__name__})"
)
input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1)
grad_output_quantizer = linear_op.get_quantizer("backward", 0)
......@@ -338,12 +363,15 @@ class UserbuffersForwardLinear(FusedOperation):
input_quantizer=input_quantizer,
weight_quantizer=weight_quantizer,
output_quantizer=None, # Not supported
input_requires_grad=input_requires_grad,
weight_requires_grad=weight_requires_grad,
ub_comm_name=linear_op._userbuffers_options["comm_name"],
)
x_local = extra_outputs["input"]
w = extra_outputs["weight"]
# Save state for backward pass
linear_op_ctx.save_for_backward(x_local)
linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer
linear_op_ctx.weight_quantizer = weight_quantizer
......@@ -351,8 +379,8 @@ class UserbuffersForwardLinear(FusedOperation):
linear_op_ctx.grad_input_quantizer = grad_input_quantizer
linear_op_ctx.dtype = dtype
linear_op_ctx.input_dims = input_.size()
linear_op_ctx.input_requires_grad = input_.requires_grad
linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad
linear_op_ctx.input_requires_grad = input_requires_grad
linear_op_ctx.weight_requires_grad = weight_requires_grad
linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None
return output, [() for _ in range(len(self.basic_ops))]
......
......@@ -61,13 +61,9 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
def forward(
func_ctx: Optional[torch.autograd.function.FunctionCtx],
input_: torch.Tensor,
forward_ops: list[tuple[FusibleOperation, list[int]]],
backward_ops: list[tuple[FusibleOperation, list[int]]],
basic_ops: list[BasicOperation],
fuser: OperationFuser,
basic_op_kwargs: list[dict[str, Any]],
is_grad_enabled: bool,
num_params: int,
num_extra_inputs: int,
*params_and_extra_inputs: torch.nn.Parameter,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
"""Forward pass
......@@ -78,20 +74,12 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
Context for PyTorch autograd function
input_: torch.Tensor
Input to first operation in pipeline
forward_ops: list of tuple
Forward pass operations and the indices of the
corresponding basic operations. The order should match
basic_ops.
backward_ops: list of tuple
Backward pass operations and the indices of the
corresponding basic operations. The order should be the
reverse of basic_ops.
basic_ops: list of BasicOperation
Basic operations
fuser: OperationFuser
Container for the pipeline of operations to run
basic_op_kwargs: list of dict
Keyword arguments to BasicOperation
num_params: int
Number of parameter tensors to include in autograd graph.
is_grad_enabled: bool
Should context be saved for backward
*params_and_extra_inputs: torch.Tensor
Other tensor inputs to include in autograd graph. Consists
of parameter tensors, followed by extra operation inputs.
......@@ -106,26 +94,20 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
"""
# Operation autograd contexts
basic_op_ctxs = [OperationContext() for _ in range(len(basic_ops))]
basic_op_ctxs = [OperationContext() for _ in range(fuser._num_basic_ops)]
# Unflatten list of parameters and extra tensor inputs
if len(params_and_extra_inputs) != num_params + num_extra_inputs:
raise ValueError(
f"Expected {num_params + num_extra_inputs} extra tensor arguments "
f"({num_params} parameters, {num_extra_inputs} extra inputs), "
f"but got {len(params_and_extra_inputs)}"
)
_, extra_inputs = _split_tuple(params_and_extra_inputs, num_params)
extra_inputs = params_and_extra_inputs[-fuser._num_extra_inputs :]
basic_op_extra_inputs = []
for op in basic_ops:
for op in fuser._basic_ops:
xs, extra_inputs = _split_tuple(extra_inputs, op.num_extra_inputs)
basic_op_extra_inputs.append(xs)
# Apply forward ops
x = input_
requires_grad = is_grad_enabled and x.requires_grad
extra_outputs = [None for _ in range(len(basic_ops))]
for op, basic_op_idxs in forward_ops:
extra_outputs = [None] * fuser._num_basic_ops
for op, basic_op_idxs in fuser._forward_ops:
# Check if backward op is required
if is_grad_enabled:
......@@ -143,9 +125,10 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
# Forward op
extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs]
prev_ops = [basic_ops[idx - 1] if idx > 0 else None for idx in basic_op_idxs]
prev_ops = [fuser._basic_ops[idx - 1] if idx > 0 else None for idx in basic_op_idxs]
next_ops = [
basic_ops[idx + 1] if (idx < len(basic_ops) - 1) else None for idx in basic_op_idxs
fuser._basic_ops[idx + 1] if (idx < fuser._num_basic_ops - 1) else None
for idx in basic_op_idxs
]
x, fused_op_extra_outputs = op.fuser_forward(
[basic_op_ctxs[idx] for idx in basic_op_idxs],
......@@ -165,7 +148,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
extra_outputs_flat = []
for idx, ys in enumerate(extra_outputs):
ys = list(ys)
num_extra_outputs = basic_ops[idx].num_extra_outputs
num_extra_outputs = fuser._basic_ops[idx].num_extra_outputs
if len(ys) != num_extra_outputs:
raise RuntimeError(
f"Expected op {idx} to generate "
......@@ -189,11 +172,11 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
func_ctx.save_for_backward(*to_save)
# Other context
func_ctx.backward_ops = backward_ops
func_ctx.basic_ops = basic_ops
func_ctx.backward_ops = fuser._backward_ops
func_ctx.basic_ops = fuser._basic_ops
func_ctx.basic_op_ctxs = basic_op_ctxs
func_ctx.basic_op_num_params = [sum(1 for _ in op.parameters()) for op in basic_ops]
func_ctx.num_extra_inputs = num_extra_inputs
func_ctx.basic_op_num_params = fuser._num_list_basic_op_params
func_ctx.num_extra_inputs = fuser._num_extra_inputs
func_ctx.num_extra_outputs = len(extra_outputs_flat)
func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
......@@ -216,8 +199,9 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
basic_op_ctxs = func_ctx.basic_op_ctxs
# Unflatten list of saved tensors
saved_tensors = func_ctx.saved_tensors
for ctx in basic_op_ctxs:
ctx.saved_tensors = func_ctx.saved_tensors[slice(*ctx._saved_tensors_range)]
ctx.saved_tensors = saved_tensors[slice(*ctx._saved_tensors_range)]
ctx._saved_tensors_range = None
# Unflatten list of extra tensor output grads
......@@ -292,13 +276,9 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
return (
dx, # input_
None, # forward_ops
None, # backward_ops
None, # basic_ops
None, # fuser
None, # basic_op_kwargs
None, # is_grad_enabled
None, # num_params
None, # num_extra_inputs
*grad_params_flat,
*grad_extra_inputs_flat,
)
......@@ -345,6 +325,10 @@ class OperationFuser:
if fuse_ops:
self.fuse_ops()
# Flatten list of parameters
self._basic_op_params = [param for op in self._basic_ops for param in op.parameters()]
self._num_list_basic_op_params = [sum(1 for _ in op.parameters()) for op in self._basic_ops]
@classmethod
def _fuse_forward_ops(
cls,
......@@ -377,6 +361,11 @@ class OperationFuser:
*extra_inputs: torch.Tensor,
basic_op_kwargs: Optional[list[dict[str, Any]]] = None,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
# Verify extra input count
if len(extra_inputs) != self._num_extra_inputs:
raise ValueError(
f"Expected {self._num_extra_inputs} extra inputs but got {len(extra_inputs)}"
)
# Initialization before forward pass
for op in self._basic_ops:
......@@ -384,10 +373,7 @@ class OperationFuser:
# Canonicalize op kwargs
if basic_op_kwargs is None:
basic_op_kwargs = [{} for _ in range(len(self._basic_ops))]
# Flatten list of parameters
params = [param for op in self._basic_ops for param in op.parameters()]
basic_op_kwargs = [{}] * self._num_basic_ops
# Fuser forward pass
is_grad_enabled = torch.is_grad_enabled()
......@@ -399,14 +385,10 @@ class OperationFuser:
args = [None]
args += (
input,
self._forward_ops,
self._backward_ops,
self._basic_ops,
self,
basic_op_kwargs,
is_grad_enabled,
len(params),
self._num_extra_inputs,
*params,
*self._basic_op_params,
*extra_inputs,
)
return forward_func(*args)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
[build-system]
requires = ["setuptools>=61.0", "pip", "torch>=2.1"]
# Use legacy backend to import local packages in setup.py
build-backend = "setuptools.build_meta:__legacy__"
......@@ -31,7 +31,7 @@ if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_
from build_tools.build_ext import get_build_ext
from build_tools.utils import copy_common_headers
from build_tools.te_version import te_version
from build_tools.pytorch import setup_pytorch_extension
from build_tools.pytorch import setup_pytorch_extension, install_requirements, test_requirements
os.environ["NVTE_PROJECT_BUILDING"] = "1"
......@@ -55,18 +55,8 @@ if __name__ == "__main__":
description="Transformer acceleration library - Torch Lib",
ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension},
setup_requires=[
"torch>=2.1",
"nvidia-cuda-runtime-cu12",
"nvidia-cublas-cu12",
"nvidia-cudnn-cu12",
"nvidia-cuda-cccl-cu12",
"nvidia-cuda-nvcc-cu12",
"nvidia-nvtx-cu12",
"nvidia-cuda-nvrtc-cu12",
],
install_requires=["torch>=2.1"],
tests_require=["numpy", "torchvision"],
install_requires=install_requirements(),
tests_require=test_requirements(),
)
if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
shutil.rmtree(common_headers_dir)
......
......@@ -11,6 +11,7 @@ import torch
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from transformer_engine_torch import Float8BlockScaleTensorFormat
from ..quantized_tensor import QuantizedTensorBase
......@@ -37,6 +38,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
_rowwise_scale_inv: Optional[torch.Tensor]
_columnwise_scale_inv: Optional[torch.Tensor]
_is_2D_scaled: bool
_data_format: Float8BlockScaleTensorFormat
def __new__(
cls,
......@@ -48,6 +50,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
fp8_dtype: TE_DType,
quantizer: Quantizer,
is_2D_scaled: bool,
data_format: Float8BlockScaleTensorFormat = Float8BlockScaleTensorFormat.GEMM_READY,
**kwargs,
):
instance = super().__new__(cls, *args, **kwargs)
......@@ -58,6 +61,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
instance._rowwise_scale_inv = rowwise_scale_inv
instance._columnwise_scale_inv = columnwise_scale_inv
instance._is_2D_scaled = is_2D_scaled
instance._data_format = data_format
return instance
......@@ -82,8 +86,13 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
"fp8_dtype": self._fp8_dtype,
"quantizer": self._quantizer,
"is_2D_scaled": self._is_2D_scaled,
"data_format": self._data_format,
}
def _is_gemm_ready_format(self) -> bool:
"""Whether data is in GEMM_READY format"""
return self._data_format == Float8BlockScaleTensorFormat.GEMM_READY
def prepare_for_saving(
self,
) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorBase]:
......@@ -136,34 +145,69 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
q_K = q.shape[-1]
for i in range(len(q.shape) - 1):
q_M *= q.shape[i]
inner_q_dimension_tiled = True
if self._is_gemm_ready_format():
scales_tiled_dim, scales_untiled_dim = scale_inv.shape
inner_scale_dimension_tiled = False
scales_are_compact = False
else:
scales_untiled_dim, scales_tiled_dim = scale_inv.shape
inner_scale_dimension_tiled = True
scales_are_compact = True
else:
assert self._columnwise_data is not None, "No data to dequantize"
q = self._columnwise_data
scale_inv = self._columnwise_scale_inv
transpose_output = True
if len(q.shape) >= 1:
q_M = q.shape[0]
for i in range(1, len(q.shape)):
q_K *= q.shape[i]
scales_tiled_dim, scales_untiled_dim = scale_inv.shape
inner_scale_dimension_tiled = False
if self._is_gemm_ready_format():
inner_q_dimension_tiled = True
transpose_output = True
if len(q.shape) >= 1:
q_M = q.shape[0]
for i in range(1, len(q.shape)):
q_K *= q.shape[i]
scales_are_compact = False
else:
inner_q_dimension_tiled = False
transpose_output = False
if len(q.shape) >= 1:
q_K = q.shape[-1]
for i in range(len(q.shape) - 1):
q_M *= q.shape[i]
scales_are_compact = True
orig_shape = q.shape
q = q.reshape(q_M, q_K)
k_tiles, scale_m = scale_inv.shape
if q_K % block_len != 0:
k_pad_amount = (block_len - (q_K % block_len)) % block_len
q = torch.nn.functional.pad(
q, (0, k_pad_amount, 0, 0), mode="constant", value=0
).contiguous()
_, padded_K = q.shape
q_tiled = q.reshape(q_M, k_tiles, block_len)
if scale_m > q_M:
# scale_m is 4 element aligned.
if inner_q_dimension_tiled:
if q_K % block_len != 0:
k_pad_amount = (block_len - (q_K % block_len)) % block_len
q = torch.nn.functional.pad(
q, (0, k_pad_amount, 0, 0), mode="constant", value=0
).contiguous()
padded_M, padded_K = q.shape
q_tiled = q.reshape(q_M, scales_tiled_dim, block_len)
else:
if q_M % block_len != 0:
m_pad_amount = (block_len - (q_M % block_len)) % block_len
q = torch.nn.functional.pad(
q, (0, 0, 0, m_pad_amount), mode="constant", value=0
).contiguous()
padded_M, padded_K = q.shape
q_tiled = q.reshape(scales_tiled_dim, block_len, q_K)
if not scales_are_compact and scales_untiled_dim > q_M:
# The untiled scale dimension is padded to a multiple of 4.
scale_inv = scale_inv[:, :q_M].contiguous()
dq_scale = scale_inv.transpose(-2, -1).contiguous().reshape(q_M, k_tiles, 1)
if scales_are_compact and inner_scale_dimension_tiled:
dq_scale = scale_inv.contiguous().reshape(q_M, scales_tiled_dim, 1)
elif scales_are_compact and not inner_scale_dimension_tiled:
dq_scale = scale_inv.contiguous().reshape(scales_tiled_dim, 1, q_K)
else:
dq_scale = scale_inv.transpose(-2, -1).contiguous().reshape(q_M, scales_tiled_dim, 1)
torch_q_dtype = TE_DType_To_Torch[self._fp8_dtype]
result = q_tiled.view(torch_q_dtype).to(torch.float32) * dq_scale
if padded_K != q_K:
result = result.reshape(q_M, padded_K)[:, :q_K]
if padded_M != q_M or padded_K != q_K:
result = result.reshape(padded_M, padded_K)[:q_M, :q_K]
result = result.to(dtype)
if len(orig_shape) == 0:
result = result.reshape([])
......@@ -182,6 +226,12 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
if not self._is_2D_scaled:
return self._dequantize_vectorwise(dtype=dtype)
if not self._is_gemm_ready_format():
raise NotImplementedError(
"Dequantize is only supported with GEMM_READY data format, "
f"but found _data_format={self._data_format}"
)
def format_scale_as_logical_shape(q_K, scales, block_len):
# The GEMM for 2D blocks requires padding in the scales.
derived_scale_k_shape = math.ceil(q_K / block_len)
......@@ -247,6 +297,8 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
if self._rowwise_data is not None:
return self._rowwise_data.size(*args, **kwargs)
dims = list(self._columnwise_data.size(*args, **kwargs))
if not self._is_gemm_ready_format(): # compact format
return torch.Size(dims)
reordered = []
for i in range(1, len(dims)):
reordered.append(dims[i])
......@@ -285,6 +337,13 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
w = min(self._columnwise_scale_inv.shape[1], columnwise_scale_inv.shape[1])
self._columnwise_scale_inv[0:h, 0:w].copy_(columnwise_scale_inv[0:h, 0:w])
def _transpose_columnwise_data(self):
"""Plainly transpose the columnwise data and scale inv."""
if self._columnwise_data is not None:
self._columnwise_data = tex.fp8_transpose(
self._columnwise_data, self._fp8_dtype, out=None
)
def __repr__(self):
if self._rowwise_data is not None:
data = self.dequantize()
......
......@@ -4,13 +4,15 @@
"""Tensor class with FP8 data quantized with NxN tiles"""
from __future__ import annotations
from typing import Optional, Tuple, Iterable
from typing import Optional, Tuple, Iterable, Union
import math
import torch
import transformer_engine_torch as tex
import os
from transformer_engine_torch import DType as TE_DType
from transformer_engine.common.recipe import Float8BlockScaling, Recipe
from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
from ..utils import devices_match, round_up_to_nearest_multiple
......@@ -33,6 +35,8 @@ class Float8BlockQuantizer(Quantizer):
amax_epsilon: float
force_pow_2_scales: bool
block_scaling_dim: int
# Whether to produce tensors that will be used in all-gather
all_gather_usage: bool
def __init__(
self,
......@@ -43,6 +47,7 @@ class Float8BlockQuantizer(Quantizer):
amax_epsilon: float = 0.0,
force_pow_2_scales: bool = True,
block_scaling_dim: int = 2,
all_gather_usage: bool = False,
) -> None:
super().__init__(rowwise=rowwise, columnwise=columnwise)
self.dtype = tex.DType.kInt8 if int8_simulation_fp8 else fp8_dtype
......@@ -50,6 +55,7 @@ class Float8BlockQuantizer(Quantizer):
self.force_pow_2_scales = force_pow_2_scales
self.amax_epsilon = amax_epsilon
self.block_scaling_dim = block_scaling_dim
self.all_gather_usage = all_gather_usage
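# Illustrative sketch (the constructor arguments elided by this diff are
# assumed): a quantizer built for all-gather produces COMPACT-format tensors
# instead of GEMM_READY ones, e.g.
#   quantizer = Float8BlockQuantizer(
#       ...,  # fp8_dtype, rowwise, columnwise, etc.
#       block_scaling_dim=1,
#       all_gather_usage=True,
#   )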
def update_quantized(
self,
......@@ -126,22 +132,36 @@ class Float8BlockQuantizer(Quantizer):
M *= shape[i]
if len(shape) > 0:
K = shape[-1]
# 2D 128x128 quantization block scaling.
# cuBLAS requires the 128x128 scaling factors to be padded.
# Currently the rowwise/columnwise format option does not apply to 2D scaling.
if self.block_scaling_dim == 2:
if columnwise:
outer = math.ceil(K / self.block_len)
inner = round_up_to_nearest_multiple(math.ceil(M / self.block_len), 4)
return (outer, inner)
# rowwise
outer = math.ceil(M / self.block_len)
inner = round_up_to_nearest_multiple(math.ceil(K / self.block_len), 4)
return (outer, inner)
# 1D 1x128 quantization block scaling.
# cuBLAS requires the 1x128 scaling factors to be padded and transposed.
assert self.block_scaling_dim == 1, "Only 1D or 2D blocks supported"
if columnwise:
columnwise_compact = self.all_gather_usage
outer = math.ceil(M / self.block_len)
inner = round_up_to_nearest_multiple(K, 4)
inner = round_up_to_nearest_multiple(K, 4) if not columnwise_compact else K
# GEMM_READY case: the scaling factor is [outer, inner], already transposed
# here for cuBLAS. In the COMPACT case we apply 1x128 scaling without
# transposing the columnwise data, so the scaling factor is also
# [outer, inner] and no swap is needed.
return (outer, inner)
# rowwise
rowwise_compact = self.all_gather_usage
outer = math.ceil(K / self.block_len)
inner = round_up_to_nearest_multiple(M, 4)
return (outer, inner)
inner = round_up_to_nearest_multiple(M, 4) if not rowwise_compact else M
# GEMM_READY case: the scaling factor is [outer, inner], already transposed
# here as cuBLAS requires. In the COMPACT case we apply 128x1 scaling, so the
# scaling block applies to the inner dimension and outer/inner are swapped.
return (outer, inner) if not rowwise_compact else (inner, outer)
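# Worked example (illustrative, assuming block_len = 128) for M = 256, K = 384:
#   2D rowwise:             (ceil(256/128), round_up(ceil(384/128), 4)) = (2, 4)
#   1D rowwise, GEMM_READY: (ceil(384/128), round_up(256, 4))           = (3, 256)
#   1D rowwise, COMPACT:    swapped                                     -> (256, 3)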
def get_columnwise_shape(self, shape: Iterable[int]) -> Tuple[int, ...]:
"""Calculate the shape of a tensor after columnwise permutation.
......@@ -163,15 +183,25 @@ class Float8BlockQuantizer(Quantizer):
"""
if len(shape) == 0:
return tuple()
# Currently the columnwise format option only applies to the 1D quantizer.
# For 2D scaling the columnwise format should always be
# GEMM_READY_DATA_AND_SCALES, since 2D scaling currently only applies to
# module weights.
if self.block_scaling_dim == 1 and self.all_gather_usage:
return shape
colwise_shape = [shape[-1]]
for i in range(len(shape) - 1):
colwise_shape.append(shape[i])
return tuple(colwise_shape)
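# Example (illustrative): for a [seq, batch, hidden] = [128, 2, 1024] input,
# the GEMM_READY columnwise layout is [1024, 128, 2] (last dim moved first),
# while the 1D COMPACT layout leaves the shape as [128, 2, 1024].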
# TODO(kwyss): With FP8 gather support, we need to implement a
# shape/layout/swizzle check to know whether FP8 gather works
# cleanly by stacking data without aliasing tiles and whether
# the scales also stack on the proper dimensions.
def is_quantizable(self, inp: torch.Tensor) -> bool:
"""Returns whether or not given inp can be quantized"""
if inp.ndim < 2:
return False
if inp.shape[-1] % self.block_len != 0:
return False
if math.prod(inp.shape[:-1]) % self.block_len != 0:
return False
return True
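# Example (illustrative, assuming block_len = 128): a [256, 384] input is
# quantizable (384 % 128 == 0 and 256 % 128 == 0), while a [100, 384] input
# is not (100 % 128 != 0).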
def make_empty(
self,
......@@ -185,6 +215,12 @@ class Float8BlockQuantizer(Quantizer):
if device is None:
device = torch.device("cuda")
data_format = (
tex.Float8BlockScaleTensorFormat.COMPACT
if self.all_gather_usage
else tex.Float8BlockScaleTensorFormat.GEMM_READY
)
# Allocate FP8 data
data = None
scale_inv = None
......@@ -222,6 +258,7 @@ class Float8BlockQuantizer(Quantizer):
columnwise_scale_inv=columnwise_scale_inv,
quantizer=self,
is_2D_scaled=self.block_scaling_dim == 2,
data_format=data_format,
requires_grad=requires_grad,
)
......@@ -230,6 +267,9 @@ class Float8BlockQuantizer(Quantizer):
# where state from an estimator influences distribution parameters.
pass
def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
return Float8BlockScaling
class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
"""Tensor class with FP8 data quantized via NxN blocks or 1xN blocks.
......@@ -260,7 +300,8 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
return (
f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype},"
f" is_2D_scaled={self._is_2D_scaled},"
f" data={self.dequantize(dtype=self.dtype)})"
f" data={self.dequantize(dtype=self.dtype)}),"
f" data_format={self._data_format}"
)
def _get_quantizer(self) -> Quantizer:
......@@ -393,6 +434,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
dtype: torch.dtype,
quantizer: Quantizer,
is_2D_scaled: bool,
data_format: tex.Float8BlockScaleTensorFormat,
) -> Float8BlockwiseQTensor:
"""Build Float8BlockwiseQTensor, for use in __reduce__
......@@ -410,6 +452,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
dtype=dtype,
quantizer=quantizer,
is_2D_scaled=is_2D_scaled,
data_format=data_format,
)
def __reduce_ex__(self, protocol: int) -> tuple:
......@@ -426,6 +469,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
self.dtype,
self._quantizer,
self._is_2D_scaled,
self._data_format,
),
)
......@@ -451,6 +495,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
dst._fp8_dtype = src._fp8_dtype
dst._rowwise_scale_inv = src._rowwise_scale_inv
dst._columnwise_scale_inv = src._columnwise_scale_inv
dst._data_format = src._data_format
# Check that tensor dimensions match
if (
......@@ -498,6 +543,13 @@ class _ViewFunc(torch.autograd.Function):
) -> Float8BlockwiseQTensor:
# pylint: disable=missing-function-docstring
# Check for invalid configurations
if not tensor._is_gemm_ready_format():
raise NotImplementedError(
"View is only supported with GEMM_READY data format, "
f"but found data_format={tensor._data_format}"
)
# Return input tensor if shape is not provided
ctx.shape = tensor.shape
if shape is None:
......@@ -566,6 +618,14 @@ class _ViewFunc(torch.autograd.Function):
# pylint: disable=missing-function-docstring
if isinstance(grad, Float8BlockwiseQTensor):
# Check for invalid configurations
if not grad._is_gemm_ready_format():
raise NotImplementedError(
"View is only supported with GEMM_READY data format, "
f"but found data_format={grad._data_format}"
)
new_data = (
grad._rowwise_data.view(*ctx.shape) if grad._rowwise_data is not None else None
)
......@@ -605,6 +665,13 @@ class _ReshapeFunc(torch.autograd.Function):
) -> Float8BlockwiseQTensor:
# pylint: disable=missing-function-docstring
# Check for invalid configurations
if not tensor._is_gemm_ready_format():
raise NotImplementedError(
"Reshape is only supported with GEMM_READY data format, "
f"but found data_format={tensor._data_format}"
)
# Return input tensor if shape is not provided
ctx.shape = tensor.shape
if shape is None:
......@@ -672,6 +739,14 @@ class _ReshapeFunc(torch.autograd.Function):
# pylint: disable=missing-function-docstring
if isinstance(grad, Float8BlockwiseQTensor):
# Check for invalid configurations
if not grad._is_gemm_ready_format():
raise NotImplementedError(
"Reshape is only supported with GEMM_READY data format, "
f"but found data_format={grad._data_format}"
)
new_rowwise_data = None
new_columnwise_data = None
if grad._rowwise_data is not None:
......