Commit 2b05e121 authored by yuguo's avatar yuguo
Browse files

Merge commit 'a69692ac' of...

Merge commit 'a69692ac' of https://github.com/NVIDIA/TransformerEngine
parents 0fd441c2 a69692ac
...@@ -123,6 +123,35 @@ def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor: ...@@ -123,6 +123,35 @@ def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor:
return dgelu return dgelu
@jit_fuser
def l2normalization_fused_(x: torch.Tensor, eps: float) -> torch.Tensor:
    """Fused L2 normalization along the last dimension (inference-only path)."""
    # Per-row sum of squares, kept broadcastable against x via keepdim.
    sum_sq = x.pow(2).sum(dim=-1, keepdim=True)
    inv_norm = torch.rsqrt(sum_sq + eps)
    return x * inv_norm
@jit_fuser
def l2normalization_fwd_fused_(x: torch.Tensor, eps: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Fused L2 normalization forward for training.

    Returns the normalized tensor together with the reciprocal norm, which
    the backward pass reuses instead of recomputing it.
    """
    sum_sq = x.pow(2).sum(dim=-1, keepdim=True)
    inv_norm = torch.rsqrt(sum_sq + eps)
    return x * inv_norm, inv_norm
@jit_fuser
def l2normalization_backward_fused_(
    grad_output: torch.Tensor, x: torch.Tensor, rsqrt_norm: torch.Tensor, eps: float
) -> torch.Tensor:
    """Fused backward of L2 normalization: gradient of x * rsqrt(|x|^2 + eps)."""
    # Per-row projection of the incoming gradient onto x.
    proj = (x * grad_output).sum(dim=-1, keepdim=True)
    # Recompute |x|^2 + eps instead of inverting rsqrt_norm, matching forward numerics.
    denom = x.pow(2).sum(dim=-1, keepdim=True) + eps
    return rsqrt_norm * (grad_output - x * proj / denom)
def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor: def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
"""Disable native AMP for bias_gelu_fused_""" """Disable native AMP for bias_gelu_fused_"""
with gpu_autocast_ctx(enabled=False): with gpu_autocast_ctx(enabled=False):
...@@ -141,6 +170,26 @@ def bgrad_dgelu_fused( ...@@ -141,6 +170,26 @@ def bgrad_dgelu_fused(
return None, dgelu_fused_(grad_output, inp) return None, dgelu_fused_(grad_output, inp)
def l2normalization_fused(x: torch.Tensor, eps: float) -> torch.Tensor:
    """Inference L2 normalization with native AMP autocast disabled."""
    with gpu_autocast_ctx(enabled=False):
        result = l2normalization_fused_(x, eps)
    return result
def l2normalization_fwd_fused(x: torch.Tensor, eps: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Training L2 normalization forward with native AMP autocast disabled."""
    with gpu_autocast_ctx(enabled=False):
        outputs = l2normalization_fwd_fused_(x, eps)
    return outputs
def l2normalization_backward_fused(
    grad_output: torch.Tensor, x: torch.Tensor, rsqrt_norm: torch.Tensor, eps: float
) -> torch.Tensor:
    """L2 normalization backward with native AMP autocast disabled."""
    with gpu_autocast_ctx(enabled=False):
        grad_input = l2normalization_backward_fused_(grad_output, x, rsqrt_norm, eps)
    return grad_input
def bias_dropout_add( def bias_dropout_add(
x: torch.Tensor, x: torch.Tensor,
bias: torch.Tensor, bias: torch.Tensor,
...@@ -266,3 +315,45 @@ def warmup_jit_bias_gelu_all_dtypes( ...@@ -266,3 +315,45 @@ def warmup_jit_bias_gelu_all_dtypes(
"""Call `warmup_jit_bias_gelu` for all training dtypes""" """Call `warmup_jit_bias_gelu` for all training dtypes"""
for dtype in [torch.float32, torch.bfloat16, torch.float16]: for dtype in [torch.float32, torch.bfloat16, torch.float16]:
warmup_jit_bias_gelu(ffn_hidden_size, dtype, seq_length, micro_batch_size) warmup_jit_bias_gelu(ffn_hidden_size, dtype, seq_length, micro_batch_size)
def warmup_jit_l2normalization(
    hidden_size: int, dtype: torch.dtype, seq_length: int, micro_batch_size: int
) -> None:
    """Compile the L2Normalization JIT kernels before the main training steps."""
    # Snapshot the CUDA RNG state so warmup leaves reproducibility intact.
    rng_state = torch.cuda.get_rng_state()
    inp = torch.rand(
        (seq_length * micro_batch_size, hidden_size),
        dtype=dtype,
        device="cuda",
    )
    eps = 1e-6
    # Trigger JIT fusion for both grad-enabled states, matching the input
    # configurations seen during forward prop and activation recomputation.
    for requires_grad in (False, True):
        inp.requires_grad = requires_grad
        for _ in range(5):
            if not requires_grad:
                # Inference-only fused kernel.
                output = l2normalization_fused_(inp, eps)
            else:
                # Training kernel returning the reciprocal norm ...
                output, rsqrt_norm = l2normalization_fwd_fused_(inp, eps)
                # ... which feeds the fused backward kernel.
                grad_out = torch.rand_like(output)
                _ = l2normalization_backward_fused_(grad_out, inp, rsqrt_norm, eps)
    del inp, output
    torch.cuda.empty_cache()
    torch.cuda.set_rng_state(rng_state)
def warmup_jit_l2normalization_all_dtypes(
    hidden_size: int, seq_length: int, micro_batch_size: int
) -> None:
    """Call `warmup_jit_l2normalization` once per supported training dtype."""
    for dtype in (torch.float32, torch.bfloat16, torch.float16):
        warmup_jit_l2normalization(hidden_size, dtype, seq_length, micro_batch_size)
...@@ -44,7 +44,7 @@ from ..tensor._internal.float8_tensor_base import Float8TensorBase ...@@ -44,7 +44,7 @@ from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..utils import torch_get_autocast_gpu_dtype from ..utils import torch_get_autocast_gpu_dtype
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ...common.recipe import Recipe from ...common.recipe import DelayedScaling, Recipe
from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor
from torch.utils.cpp_extension import IS_HIP_EXTENSION from torch.utils.cpp_extension import IS_HIP_EXTENSION
...@@ -89,7 +89,7 @@ def get_multi_stream_cublas_workspace() -> List[torch.Tensor]: ...@@ -89,7 +89,7 @@ def get_multi_stream_cublas_workspace() -> List[torch.Tensor]:
"""Returns workspace for multi-stream cublas.""" """Returns workspace for multi-stream cublas."""
global _multi_stream_cublas_workspace global _multi_stream_cublas_workspace
if not _multi_stream_cublas_workspace: if not _multi_stream_cublas_workspace:
for _ in range(tex._num_cublas_streams): for _ in range(tex.get_num_cublas_streams()):
_multi_stream_cublas_workspace.append( _multi_stream_cublas_workspace.append(
torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda") torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda")
) )
...@@ -685,6 +685,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -685,6 +685,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# Update quantizers with new amax pointers. # Update quantizers with new amax pointers.
self.quantizers[meta_key] = self.fp8_meta[meta_key].make_quantizers() self.quantizers[meta_key] = self.fp8_meta[meta_key].make_quantizers()
# Make sure weight tensors has correct quantizers
self._update_weight_quantizers()
# Update the global buffers with new amax and history pointers. # Update the global buffers with new amax and history pointers.
if FP8GlobalStateManager.get_buffer_info() in self.fp8_meta: if FP8GlobalStateManager.get_buffer_info() in self.fp8_meta:
...@@ -738,6 +740,30 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -738,6 +740,30 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.fp8_meta[fp8_meta_tensor_key] = recipe_state self.fp8_meta[fp8_meta_tensor_key] = recipe_state
self.quantizers[fp8_meta_tensor_key] = recipe_state.make_quantizers() self.quantizers[fp8_meta_tensor_key] = recipe_state.make_quantizers()
def _update_weight_quantizers(self) -> None:
"""Update the quantizers for the weight tensors."""
weight_tensors = self._get_weight_tensors()
weight_quantizers = self._get_weight_quantizers()
assert len(weight_tensors) == len(weight_quantizers), (
f"Number of weight tensors ({len(weight_tensors)}) and quantizers "
f"({len(weight_quantizers)}) must match"
)
for weight, quantizer in zip(weight_tensors, weight_quantizers):
if quantizer is not None and isinstance(weight, QuantizedTensorBase):
weight.update_quantizer(quantizer)
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
    """Return the module's weight tensors.

    Subclasses that support weight quantization must override this.
    """
    message = f"{self.__class__.__name__} class does not implement _get_weight_tensors function"
    raise NotImplementedError(message)
def _get_weight_quantizers(self) -> List[Quantizer]:
    """Return the module's weight quantizers.

    Subclasses that support weight quantization must override this.
    """
    message = f"{self.__class__.__name__} class does not implement _get_weight_quantizers function"
    raise NotImplementedError(message)
def init_fp8_meta_tensors(self, recipe: Recipe) -> None: def init_fp8_meta_tensors(self, recipe: Recipe) -> None:
"""Init scales and amaxes.""" """Init scales and amaxes."""
self.set_meta_tensor(True, recipe) self.set_meta_tensor(True, recipe)
...@@ -777,7 +803,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -777,7 +803,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
reset("scaling_fwd") reset("scaling_fwd")
reset("scaling_bwd") reset("scaling_bwd")
def get_extra_state(self) -> Optional[torch.Tensor]: def get_extra_state(self) -> torch.Tensor:
"""Save before checkpointing.""" """Save before checkpointing."""
# This implementation is working around a few issues: # This implementation is working around a few issues:
...@@ -812,7 +838,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -812,7 +838,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
state = None state = None
fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration
if not fp8_checkpoint: if not fp8_checkpoint:
return None return torch.empty(0, dtype=torch.uint8)
# Copy tensors to CPU and store # Copy tensors to CPU and store
state = {} state = {}
...@@ -838,13 +864,18 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -838,13 +864,18 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8) state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8)
return state_serialized return state_serialized
def set_extra_state(self, state: Optional[torch.Tensor]) -> None: def set_extra_state(self, state: torch.Tensor) -> None:
"""Load previous state.""" """Load previous state."""
# Maintain backwards compatibility with older checkpoints.
if state is None: if state is None:
return return
# Load state # Load state
if isinstance(state, torch.Tensor): if isinstance(state, torch.Tensor):
# No FP8 is indicated by an empty tensor we don't need to unpickle.
if state.numel() == 0:
return
# Default format: byte tensor with pickled data # Default format: byte tensor with pickled data
state = pickle.loads(state.detach().cpu().numpy().tobytes()) state = pickle.loads(state.detach().cpu().numpy().tobytes())
elif isinstance(state, io.BytesIO): elif isinstance(state, io.BytesIO):
...@@ -857,6 +888,14 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -857,6 +888,14 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if state is None: if state is None:
return return
# TE 1.x checkpoint compatibility: add DelayedScaling recipe if missing
if "recipe" not in state:
# TE 1.x only supported delayed scaling, which was the default recipe
state["recipe"] = DelayedScaling()
# TE 1.x also saved scale_inv, which is not needed with Recipe object
state.pop("scale_inv_fwd", None)
state.pop("scale_inv_bwd", None)
# Load extra items # Load extra items
self.fp8_meta.update(state["extra_fp8_variables"]) self.fp8_meta.update(state["extra_fp8_variables"])
self.fp8_meta["recipe"] = state["recipe"] self.fp8_meta["recipe"] = state["recipe"]
...@@ -930,6 +969,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -930,6 +969,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# assume FP8 execution. # assume FP8 execution.
def init_fp8_metadata(self, num_gemms: int = 1) -> None: def init_fp8_metadata(self, num_gemms: int = 1) -> None:
"""Initialize fp8 related metadata and tensors during fprop.""" """Initialize fp8 related metadata and tensors during fprop."""
_original_recipe = self.fp8_meta.get("recipe", None)
self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
self.fp8 = FP8GlobalStateManager.is_fp8_enabled() self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
...@@ -968,6 +1009,19 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -968,6 +1009,19 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
_current_recipe = self.fp8_meta["recipe"]
if _original_recipe is not None and not (
issubclass(_current_recipe.__class__, _original_recipe.__class__)
or issubclass(_original_recipe.__class__, _current_recipe.__class__)
):
warnings.warn(
f"Recipe type changed from {_original_recipe.__class__.__name__} "
f"to {_current_recipe.__class__.__name__}. "
"This may affect model behavior."
)
# Clear cached workspaces as they were created with the old recipe/quantizer type
self._fp8_workspaces.clear()
@contextmanager @contextmanager
def prepare_forward( def prepare_forward(
self, self,
...@@ -992,6 +1046,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -992,6 +1046,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.set_activation_dtype(inp) self.set_activation_dtype(inp)
self.init_fp8_metadata(num_gemms=num_gemms) self.init_fp8_metadata(num_gemms=num_gemms)
self._check_weight_tensor_recipe_correspondence()
if self.fp8 and self.sequence_parallel and self.fp8_meta["recipe"].delayed(): if self.fp8 and self.sequence_parallel and self.fp8_meta["recipe"].delayed():
assert self.fp8_meta["recipe"].reduce_amax, ( assert self.fp8_meta["recipe"].reduce_amax, (
...@@ -1103,7 +1158,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1103,7 +1158,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if ( if (
isinstance( isinstance(
grad_output_.get_tensor(True), grad_output_.get_tensor(True),
(QuantizedTensor, Float8TensorBase, MXFP8TensorBase), (
QuantizedTensor,
Float8TensorBase,
MXFP8TensorBase,
Float8BlockwiseQTensorBase,
),
) )
and ctx.use_bias and ctx.use_bias
): ):
...@@ -1169,18 +1229,23 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1169,18 +1229,23 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
with get_rng_state_tracker().fork(): with get_rng_state_tracker().fork():
init_fn(param) init_fn(param)
# If primary weights are in fp8, wrap the parameter as FP8Tensor # Wrap parameters in QuantizedTensor if needed
fp8_meta_index = self.param_init_meta[name].fp8_meta_index fp8_meta_index = self.param_init_meta[name].fp8_meta_index
high_precision_init_val = None high_precision_init_val = None
if self.primary_weights_in_fp8 and fp8_meta_index is not None: if self.primary_weights_in_fp8 and fp8_meta_index is not None:
# Keep high-precision values on CPU if needed
if self.preserve_high_precision_init_val: if self.preserve_high_precision_init_val:
high_precision_init_val = param.detach().cpu() high_precision_init_val = param.detach().cpu()
# Configure quantizer
quantizer = self.quantizers["scaling_fwd"][fp8_meta_index] quantizer = self.quantizers["scaling_fwd"][fp8_meta_index]
assert ( if quantizer is None:
quantizer is not None raise RuntimeError("Weight quantizer has not been initialized")
) # to use primary fp8 weight one needs to use FP8 autocast with specific recipe. quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled())
quantizer.internal = False quantizer.internal = False
# Quantize parameter
param = quantizer(param) param = quantizer(param)
# Redo parameter wrap in case we broke it above # Redo parameter wrap in case we broke it above
...@@ -1188,6 +1253,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1188,6 +1253,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# re-applying the nn.Parameter() wrap is a no-op when the input is already # re-applying the nn.Parameter() wrap is a no-op when the input is already
# a parameter so we always re-apply it just for extra safety. # a parameter so we always re-apply it just for extra safety.
param = torch.nn.Parameter(param) param = torch.nn.Parameter(param)
# Keep high-precision values on CPU if needed
if high_precision_init_val is not None: if high_precision_init_val is not None:
# - Master weights are initialized from model weights, if we use fp8 primary # - Master weights are initialized from model weights, if we use fp8 primary
...@@ -1231,7 +1298,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1231,7 +1298,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
fsdp_group: Optional[dist_group_type] = None, fsdp_group: Optional[dist_group_type] = None,
workspace_dtype: Optional[torch.dtype] = None, workspace_dtype: Optional[torch.dtype] = None,
) -> QuantizedTensor: ) -> QuantizedTensor:
"""Get FP8 workspace buffer and maybe update its values """Get workspace buffer for weights and maybe update its values
The workspace buffer may be cached for future function calls. The workspace buffer may be cached for future function calls.
...@@ -1257,13 +1324,16 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1257,13 +1324,16 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
for debug quantization, this is dtype of the tensor. for debug quantization, this is dtype of the tensor.
""" """
# FP8 primary weights # Handle case where weights are already quantized
# Note: Make sure weights have required usages, but do not
# destroy unnecessary usages since they may be used later.
if isinstance(tensor, QuantizedTensor): if isinstance(tensor, QuantizedTensor):
if update_workspace and quantizer is not None: update_rowwise_usage = True if quantizer.rowwise_usage else None
tensor.update_usage( update_columnwise_usage = True if quantizer.columnwise_usage else None
rowwise_usage=quantizer.rowwise_usage, tensor.update_usage(
columnwise_usage=quantizer.columnwise_usage, rowwise_usage=update_rowwise_usage,
) columnwise_usage=update_columnwise_usage,
)
return tensor return tensor
# Try getting workspace from cache # Try getting workspace from cache
...@@ -1387,6 +1457,43 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1387,6 +1457,43 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
) )
self.name = f"Layer_{TEDebugState.get_layer_count()}" self.name = f"Layer_{TEDebugState.get_layer_count()}"
def _check_weight_tensor_recipe_correspondence(self) -> None:
"""
Verify that the weight tensor types match their corresponding recipe type.
This is invoked in the forward().
This establishes a 1:1 correspondence between recipe types and tensor types:
- DelayedScaling → Float8Tensor
- Float8CurrentScaling → Float8Tensor
- MXFP8BlockScaling → MXFP8Tensor
- Float8BlockScaling → Float8BlockTensor
Example case to check: recipe is DelayedScaling (DelayedScaling is set in fp8_autocast()),
but the weight tensor is MXFP8Tensor (MXFP8BlockScaling is set in fp8_model_init()).
"""
if not self.fp8 and not self.fp8_calibration:
return
if not hasattr(self, "weight_names") or not self.weight_names:
return
recipe = self.fp8_meta["recipe"]
weight_tensors = [getattr(self, name) for name in self.weight_names]
for i, tensor in enumerate(weight_tensors):
if isinstance(tensor, QuantizedTensorBase):
quantizer = tensor._get_quantizer()
if quantizer is None:
continue
compatible_recipe_class = quantizer._get_compatible_recipe()
if compatible_recipe_class is None:
continue
if not isinstance(recipe, compatible_recipe_class):
raise RuntimeError(
f"Recipe mismatch for '{self.weight_names[i]}': tensor supports recipe"
f" {compatible_recipe_class.__name__}, but got {recipe.__class__.__name__}."
" Please check the recipes assigned during fp8_model_init() and"
" fp8_autocast() calls."
)
def _turn_off_unsupported_features_in_debug(self): def _turn_off_unsupported_features_in_debug(self):
if ( if (
getattr(self, "ub_bulk_wgrad", False) getattr(self, "ub_bulk_wgrad", False)
......
...@@ -242,8 +242,8 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -242,8 +242,8 @@ class _GroupedLinear(torch.autograd.Function):
biases = saved_tensors[3 * N : 4 * N] biases = saved_tensors[3 * N : 4 * N]
main_grads = ctx.main_grads main_grads = ctx.main_grads
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: # TOSO if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
for i in ctx.num_gemms: for i in range(ctx.num_gemms):
w = torch.nn.Parameter(weights[i], weights[i].requires_grad) w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
w.main_grad = main_grads[i] w.main_grad = main_grads[i]
weights[i] = w weights[i] = w
...@@ -673,26 +673,19 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -673,26 +673,19 @@ class GroupedLinear(TransformerEngineBaseModule):
), "GroupedLinear doesn't support input tensor in FP8." ), "GroupedLinear doesn't support input tensor in FP8."
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() if FP8GlobalStateManager.fp8_graph_capturing():
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
else:
skip_fp8_weight_update = None
if skip_fp8_weight_update is not None: if skip_fp8_weight_update is not None:
is_first_microbatch = False is_first_microbatch = False
with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp: with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
weight_tensors = self._get_weight_tensors()
weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
if not self.fp8 and any(isinstance(w, QuantizedTensorBase) for w in weight_tensors):
warnings.warn(
"You are using quantized weights without quantized compute. "
"Please make sure this is intentional."
)
weight_tensors = [
w.dequantize() if isinstance(w, QuantizedTensorBase) else w
for w in weight_tensors
]
input_quantizers, weight_quantizers, output_quantizers = ( weight_quantizers = self._get_weight_quantizers()
[None] * self.num_gemms, input_quantizers, output_quantizers = (
[None] * self.num_gemms, [None] * self.num_gemms,
[None] * self.num_gemms, [None] * self.num_gemms,
) )
...@@ -707,14 +700,6 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -707,14 +700,6 @@ class GroupedLinear(TransformerEngineBaseModule):
# TODO: use internal after #1638 is merged. # pylint: disable=fixme # TODO: use internal after #1638 is merged. # pylint: disable=fixme
for i in range(self.num_gemms): for i in range(self.num_gemms):
input_quantizers[i].internal = False input_quantizers[i].internal = False
weight_quantizers = [
self.quantizers["scaling_fwd"][
self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"]
]
for i in range(self.num_gemms)
]
for i in range(self.num_gemms):
weight_quantizers[i].internal = True
if torch.is_grad_enabled(): if torch.is_grad_enabled():
grad_output_quantizers = [ grad_output_quantizers = [
self.quantizers["scaling_bwd"][ self.quantizers["scaling_bwd"][
...@@ -813,3 +798,30 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -813,3 +798,30 @@ class GroupedLinear(TransformerEngineBaseModule):
self.quantizers["scaling_bwd"][ self.quantizers["scaling_bwd"][
self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"] self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"]
].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
    """Return per-GEMM weight tensors, dequantized when FP8 compute is off."""
    weights = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
    if self.fp8:
        return weights
    # Outside FP8 compute, quantized weights are unusual enough to flag,
    # and the GEMMs need high-precision inputs.
    if any(isinstance(w, QuantizedTensorBase) for w in weights):
        warnings.warn(
            "You are using quantized weights without quantized compute. "
            "Please make sure this is intentional."
        )
        weights = [
            w.dequantize() if isinstance(w, QuantizedTensorBase) else w for w in weights
        ]
    return weights
def _get_weight_quantizers(self) -> List[Quantizer]:
    """Return per-GEMM weight quantizers (all None when FP8 compute is off)."""
    if not self.fp8:
        return [None] * self.num_gemms
    quantizers = []
    for i in range(self.num_gemms):
        quantizer = self.quantizers["scaling_fwd"][
            self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"]
        ]
        # Weight quantizers are used internally by the fused GEMM path.
        quantizer.internal = True
        quantizers.append(quantizer)
    return quantizers
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
"""LayerNormLinear API""" """LayerNormLinear API"""
import os import os
import warnings import warnings
from typing import Callable, Dict, Optional, Tuple, Union from typing import Callable, Dict, Optional, Tuple, Union, List
from functools import reduce from functools import reduce
from operator import mul as multiply_op from operator import mul as multiply_op
...@@ -63,10 +63,11 @@ from ..tensor.quantized_tensor import ( ...@@ -63,10 +63,11 @@ from ..tensor.quantized_tensor import (
) )
from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.utils import any_feature_enabled from ...debug.pytorch.utils import any_feature_enabled
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ..cpp_extensions import ( from ..cpp_extensions import (
...@@ -190,19 +191,13 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -190,19 +191,13 @@ class _LayerNormLinear(torch.autograd.Function):
# All-gather is not supported with FP8 column-wise data # All-gather is not supported with FP8 column-wise data
input_quantizer.set_usage(columnwise=False) input_quantizer.set_usage(columnwise=False)
# Do TP communication in high precision if quantized format
# does not support communication
force_hp_blockwise_ln_out_gather = (
fp8 and with_input_all_gather and isinstance(input_quantizer, Float8BlockQuantizer)
)
# Avoid quantized norm kernel if norm output will be returned # Avoid quantized norm kernel if norm output will be returned
# or if a gather of ln_out must be in high precision. # or if a gather of ln_out must be in high precision.
with_quantized_norm = ( with_quantized_norm = (
fp8 fp8
and not debug
and not return_layernorm_output and not return_layernorm_output
and not return_layernorm_output_gathered and not return_layernorm_output_gathered
and not force_hp_blockwise_ln_out_gather
) )
# Apply normalization # Apply normalization
...@@ -239,15 +234,16 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -239,15 +234,16 @@ class _LayerNormLinear(torch.autograd.Function):
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
ln_out_return = ln_out_total ln_out_return = ln_out_total
if fp8 or debug: if fp8 or debug:
if not force_hp_blockwise_ln_out_gather: ln_out = input_quantizer(ln_out)
ln_out = input_quantizer(ln_out)
input_quantizer.set_usage(rowwise=True, columnwise=False) input_quantizer.set_usage(rowwise=True, columnwise=False)
if isinstance(input_quantizer, Float8BlockQuantizer):
input_quantizer.all_gather_usage = False
ln_out_total = input_quantizer(ln_out_total) ln_out_total = input_quantizer(ln_out_total)
else: else:
quantizer = None quantizer = None
if fp8 or debug: if fp8 or debug:
quantizer = input_quantizer quantizer = input_quantizer
if not with_quantized_norm and not force_hp_blockwise_ln_out_gather: if not with_quantized_norm:
ln_out = quantizer(ln_out) ln_out = quantizer(ln_out)
quantizer.set_usage(rowwise=True, columnwise=False) quantizer.set_usage(rowwise=True, columnwise=False)
if ub_overlap_ag_fprop: # Initialize Userbuffers all-gather if ub_overlap_ag_fprop: # Initialize Userbuffers all-gather
...@@ -282,7 +278,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -282,7 +278,7 @@ class _LayerNormLinear(torch.autograd.Function):
# Configure quantizer # Configure quantizer
if weight_quantizer is not None: if weight_quantizer is not None:
weight_quantizer.set_usage(rowwise=True, columnwise=True) weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
# Get quantized weight # Get quantized weight
update_workspace = is_first_microbatch is None or is_first_microbatch update_workspace = is_first_microbatch is None or is_first_microbatch
...@@ -397,7 +393,6 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -397,7 +393,6 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.ln_out_needs_gather = ( ctx.ln_out_needs_gather = (
weight.requires_grad and parallel_mode == "column" and sequence_parallel weight.requires_grad and parallel_mode == "column" and sequence_parallel
) )
ctx.force_hp_blockwise_ln_out_gather = force_hp_blockwise_ln_out_gather
# Input with column-wise usage is needed for wgrad GEMM. # Input with column-wise usage is needed for wgrad GEMM.
if backward_needs_input: if backward_needs_input:
...@@ -405,7 +400,10 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -405,7 +400,10 @@ class _LayerNormLinear(torch.autograd.Function):
# For sequence parallel in vanilla FP8, rowwise data is # For sequence parallel in vanilla FP8, rowwise data is
# to gather the input. For MXFP8, columnwise only data # to gather the input. For MXFP8, columnwise only data
# can be allgathered. # can be allgathered.
if isinstance(ln_out, MXFP8TensorBase) or not ctx.ln_out_needs_gather: if (
isinstance(ln_out, (MXFP8TensorBase, Float8BlockwiseQTensorBase))
or not ctx.ln_out_needs_gather
):
ln_out.update_usage(rowwise_usage=False) ln_out.update_usage(rowwise_usage=False)
# Weight with column-wise usage is needed for dgrad GEMM. # Weight with column-wise usage is needed for dgrad GEMM.
...@@ -502,8 +500,8 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -502,8 +500,8 @@ class _LayerNormLinear(torch.autograd.Function):
if return_layernorm_output: if return_layernorm_output:
if return_layernorm_output_gathered: if return_layernorm_output_gathered:
shape = list(inp_shape) shape = list(inp.shape)
shape[0] *= tp_size shape[0] *= tp_size if with_input_all_gather else 1
return out, ln_out_return.view(shape) return out, ln_out_return.view(shape)
return out, ln_out_return.view(inp_shape) return out, ln_out_return.view(inp_shape)
return out return out
...@@ -637,7 +635,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -637,7 +635,7 @@ class _LayerNormLinear(torch.autograd.Function):
ln_out_total_work = None ln_out_total_work = None
if ctx.ln_out_needs_gather: if ctx.ln_out_needs_gather:
quantizer = None quantizer = None
if ctx.input_quantizer is not None and not ctx.force_hp_blockwise_ln_out_gather: if ctx.input_quantizer is not None:
quantizer = ctx.input_quantizer quantizer = ctx.input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
# If data is in FP8, we compute FP8 transposes manually # If data is in FP8, we compute FP8 transposes manually
...@@ -752,6 +750,31 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -752,6 +750,31 @@ class _LayerNormLinear(torch.autograd.Function):
wgrad = None wgrad = None
if ctx.requires_wgrad: if ctx.requires_wgrad:
# Prepare grad output tensor
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer):
# UB does not support overlapping grad output
# all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered
# for the dgrad GEMM. We work around by explicitly
# overlapping the NCCL operation with the dgrad GEMM.
ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
# Get the communication stream from the dgrad GEMM and set it as the current torch stream
dgrad_comm_stream = ub_obj_dgrad.get_communication_stream()
with torch.cuda.stream(dgrad_comm_stream):
# Syncs with the current stream (dgrad_comm_stream) before starting the all-gather
# This ensures that we don't start until all communication for the dgrad GEMM is complete
grad_output, mxfp8_grad_output_work = gather_along_first_dim(
grad_outputs[0],
ctx.tp_group,
async_op=True,
quantizer=ctx.grad_output_quantizer,
)
# Synchronize with the main stream
mxfp8_grad_output_work.wait()
# Prepare input tensor # Prepare input tensor
# Note: Synchronize tensor-parallel communication and # Note: Synchronize tensor-parallel communication and
...@@ -766,22 +789,6 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -766,22 +789,6 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
ln_out_total = ctx.input_quantizer(ln_out_total) ln_out_total = ctx.input_quantizer(ln_out_total)
# Prepare grad output tensor
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer):
# UB does not support overlapping grad output
# all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered
# for the dgrad GEMM. We work around with blocking
# all-gather for column-scaled MXFP8 data.
ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
grad_output, _ = gather_along_first_dim(
grad_outputs[0],
ctx.tp_group,
quantizer=ctx.grad_output_quantizer,
)
if ctx.fp8 or ctx.debug: if ctx.fp8 or ctx.debug:
if isinstance(grad_output, QuantizedTensorBase): if isinstance(grad_output, QuantizedTensorBase):
grad_output.update_usage(columnwise_usage=True) grad_output.update_usage(columnwise_usage=True)
...@@ -1389,6 +1396,8 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1389,6 +1396,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
recipe = FP8GlobalStateManager.get_fp8_recipe() recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling(): if recipe.float8_current_scaling():
self._customize_quantizers_float8_current_scaling(fwd, recipe) self._customize_quantizers_float8_current_scaling(fwd, recipe)
elif recipe.float8_block_scaling():
self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
# elif other recipes (mxfp8, etc) # elif other recipes (mxfp8, etc)
def reset_layer_norm_parameters(self) -> None: def reset_layer_norm_parameters(self) -> None:
...@@ -1484,20 +1493,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1484,20 +1493,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
) as inp: ) as inp:
# Get concatenated weight and bias tensors # Get concatenated weight and bias tensors
unfused_weights = [getattr(self, name) for name in self.weight_names] unfused_weights = self._get_weight_tensors()
if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
if self.fp8:
if len(unfused_weights) != 1:
raise RuntimeError(
"Splitting QuantizedTensor into multiple params is not supported"
)
else:
warnings.warn(
"You are using quantized weights without quantized compute. "
"Please make sure this is intentional."
)
unfused_weights = [w.dequantize() for w in unfused_weights]
weight_tensor = noop_cat(unfused_weights) weight_tensor = noop_cat(unfused_weights)
if self.use_bias: if self.use_bias:
bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names]) bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
...@@ -1603,8 +1599,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1603,8 +1599,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
output_quantizer = None output_quantizer = None
input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
input_quantizer.internal = True input_quantizer.internal = True
weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] (weight_quantizer,) = self._get_weight_quantizers()
weight_quantizer.internal = True
if fp8_output: if fp8_output:
output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
if torch.is_grad_enabled(): if torch.is_grad_enabled():
...@@ -1679,3 +1674,39 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1679,3 +1674,39 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.quantizers["scaling_bwd"][ self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1 tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_reduction_group = self.tp_group ].amax_reduction_group = self.tp_group
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
"""Get the weight tensors of the module."""
unfused_weights = [getattr(self, name) for name in self.weight_names]
if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
if self.fp8:
if len(unfused_weights) != 1:
raise RuntimeError(
"Splitting QuantizedTensor into multiple params is not supported"
)
else:
warnings.warn(
"You are using quantized weights without quantized compute. "
"Please make sure this is intentional."
)
unfused_weights = [w.dequantize() for w in unfused_weights]
return unfused_weights
def _get_weight_quantizers(self) -> List[Quantizer]:
"""Get the weight quantizers of the module."""
if not self.fp8:
return [None]
weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
weight_quantizer.internal = True
return [weight_quantizer]
def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on blockwise scaling recipe + layernorm_linear."""
assert (
recipe.float8_block_scaling()
), "blockwise scaling recipe quantizer customization here"
if fwd:
if self.sequence_parallel and self.parallel_mode == "column":
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].all_gather_usage = True
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
"""LayerNormMLP API""" """LayerNormMLP API"""
import os import os
import warnings import warnings
from typing import Callable, Optional, Tuple, Union from typing import Callable, Optional, Tuple, Union, List
from functools import reduce from functools import reduce
from operator import mul as multiply_op from operator import mul as multiply_op
...@@ -244,26 +244,18 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -244,26 +244,18 @@ class _LayerNormMLP(torch.autograd.Function):
# All-gather is not supported with FP8 column-wise data # All-gather is not supported with FP8 column-wise data
fc1_input_quantizer.set_usage(columnwise=False) fc1_input_quantizer.set_usage(columnwise=False)
# Do TP communication in high precision if quantized format
# does not support communication
force_hp_fc1_input_gather = (
fp8 and sequence_parallel and isinstance(fc1_input_quantizer, Float8BlockQuantizer)
)
# for fp8 DelayedScaling: layernorm output = FP8 # for fp8 DelayedScaling: layernorm output = FP8
# only output of the linear is returned # only output of the linear is returned
# for return_layernorm_output: layernorm output = High precision, then cast to FP8 # for return_layernorm_output: layernorm output = High precision, then cast to FP8
# high precision layernorm output and output of the linear are returned # high precision layernorm output and output of the linear are returned
# for debug: : layernorm output = High precision to enable processing of this norm # for debug: : layernorm output = High precision to enable processing of this norm
with_quantized_norm = ( with_quantized_norm = (
fp8 fp8
and not debug
and not return_layernorm_output and not return_layernorm_output
and not return_layernorm_output_gathered and not return_layernorm_output_gathered
and not debug
) )
if isinstance(fc1_input_quantizer, Float8BlockQuantizer):
# Kernels not available for norm fusion.
with_quantized_norm = False
# Apply normalization # Apply normalization
ln_out, mu, rsigma = apply_normalization( ln_out, mu, rsigma = apply_normalization(
...@@ -293,15 +285,16 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -293,15 +285,16 @@ class _LayerNormMLP(torch.autograd.Function):
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
ln_out_return = ln_out_total ln_out_return = ln_out_total
if fp8 or debug: if fp8 or debug:
if not force_hp_fc1_input_gather: ln_out = fc1_input_quantizer(ln_out)
ln_out = fc1_input_quantizer(ln_out)
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
if isinstance(fc1_input_quantizer, Float8BlockQuantizer):
fc1_input_quantizer.all_gather_usage = False
ln_out_total = fc1_input_quantizer(ln_out_total) ln_out_total = fc1_input_quantizer(ln_out_total)
else: else:
quantizer = None quantizer = None
if fp8 or debug: if fp8 or debug:
quantizer = fc1_input_quantizer quantizer = fc1_input_quantizer
if not with_quantized_norm and not force_hp_fc1_input_gather: if not with_quantized_norm:
ln_out = fc1_input_quantizer(ln_out) ln_out = fc1_input_quantizer(ln_out)
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
if ub_overlap_ag: if ub_overlap_ag:
...@@ -333,8 +326,8 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -333,8 +326,8 @@ class _LayerNormMLP(torch.autograd.Function):
# which handles weight caching etc. # which handles weight caching etc.
# FP8 cast to workspace buffer # FP8 cast to workspace buffer
update_workspace = is_first_microbatch is None or is_first_microbatch update_workspace = is_first_microbatch is None or is_first_microbatch
fc1_weight_quantizer.set_usage(rowwise=True, columnwise=True) fc1_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True) fc2_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
fc1_weight_final = module.get_weight_workspace( fc1_weight_final = module.get_weight_workspace(
tensor=fc1_weight, tensor=fc1_weight,
quantizer=fc1_weight_quantizer, quantizer=fc1_weight_quantizer,
...@@ -567,7 +560,6 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -567,7 +560,6 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.tensor_objects = tensor_objects ctx.tensor_objects = tensor_objects
ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
ctx.force_hp_fc1_input_gather = force_hp_fc1_input_gather
ctx.fc1_grad_input_quantizer = fc1_grad_input_quantizer ctx.fc1_grad_input_quantizer = fc1_grad_input_quantizer
ctx.fc1_grad_weight_quantizer = fc1_grad_weight_quantizer ctx.fc1_grad_weight_quantizer = fc1_grad_weight_quantizer
ctx.fc1_grad_output_quantizer = fc1_grad_output_quantizer ctx.fc1_grad_output_quantizer = fc1_grad_output_quantizer
...@@ -628,7 +620,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -628,7 +620,7 @@ class _LayerNormMLP(torch.autograd.Function):
if return_layernorm_output: if return_layernorm_output:
if return_layernorm_output_gathered: if return_layernorm_output_gathered:
shape = list(inp_shape) shape = list(inp_shape)
shape[0] *= tp_size shape[0] *= tp_size if (sequence_parallel and set_parallel_mode) else 1
return fc2_out, ln_out_return.view(shape) return fc2_out, ln_out_return.view(shape)
return fc2_out, ln_out_return.view(inp_shape) return fc2_out, ln_out_return.view(inp_shape)
return fc2_out return fc2_out
...@@ -743,7 +735,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -743,7 +735,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub_obj_fc1_dgrad = None ub_obj_fc1_dgrad = None
if ctx.fc1_weight_requires_grad and ctx.tensor_parallel and ctx.sequence_parallel: if ctx.fc1_weight_requires_grad and ctx.tensor_parallel and ctx.sequence_parallel:
quantizer = None quantizer = None
if ctx.fp8 or ctx.debug and not ctx.force_hp_fc1_input_gather: if ctx.fp8 or ctx.debug:
quantizer = ctx.fc1_input_quantizer quantizer = ctx.fc1_input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
# If data is in FP8, we compute FP8 transposes manually # If data is in FP8, we compute FP8 transposes manually
...@@ -841,6 +833,30 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -841,6 +833,30 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_wgrad = None fc2_wgrad = None
if ctx.fc2_weight_requires_grad: if ctx.fc2_weight_requires_grad:
# Prepare grad output tensor
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if ctx.ub_overlap_ag and isinstance(ctx.fc2_grad_output_quantizer, MXFP8Quantizer):
# UB does not support overlapping grad output
# all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered
# for the dgrad GEMM. We work around by explicitly
# overlapping the NCCL operation with the dgrad GEMM.
ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
# Get the communication stream from the dgrad GEMM and set it as the current torch stream
dgrad_comm_stream = ub_obj_fc2_dgrad.get_communication_stream()
with torch.cuda.stream(dgrad_comm_stream):
# Syncs with the current stream (dgrad_comm_stream) before starting the all-gather
# This ensures that we don't start until all communication for the dgrad GEMM is complete
grad_output, mxfp8_fc2_grad_output_work = gather_along_first_dim(
grad_outputs[0],
ctx.tp_group,
async_op=True,
quantizer=ctx.fc2_grad_output_quantizer,
)
# Synchronize with the main stream
mxfp8_fc2_grad_output_work.wait()
# Prepare input tensor # Prepare input tensor
# Note: Synchronize tensor-parallel communication and # Note: Synchronize tensor-parallel communication and
...@@ -852,22 +868,6 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -852,22 +868,6 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True) ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True)
act_out = ctx.fc2_input_quantizer(act_out) act_out = ctx.fc2_input_quantizer(act_out)
# Prepare grad output tensor
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if ctx.ub_overlap_ag and isinstance(ctx.fc2_grad_output_quantizer, MXFP8Quantizer):
# UB does not support overlapping grad output
# all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered
# for the dgrad GEMM. We work around with blocking
# all-gather for column-scaled MXFP8 data.
ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
grad_output, _ = gather_along_first_dim(
grad_outputs[0],
ctx.tp_group,
quantizer=ctx.fc2_grad_output_quantizer,
)
if ctx.fp8 or ctx.debug: if ctx.fp8 or ctx.debug:
if isinstance(grad_output, QuantizedTensorBase): if isinstance(grad_output, QuantizedTensorBase):
grad_output.update_usage(columnwise_usage=True) grad_output.update_usage(columnwise_usage=True)
...@@ -1661,8 +1661,11 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1661,8 +1661,11 @@ class LayerNormMLP(TransformerEngineBaseModule):
super().set_meta_tensor(fwd, recipe) super().set_meta_tensor(fwd, recipe)
# customize quantizers based on each recipe & layer configs # customize quantizers based on each recipe & layer configs
if FP8GlobalStateManager.get_fp8_recipe().float8_current_scaling(): recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
self._customize_quantizers_float8_current_scaling(fwd, recipe) self._customize_quantizers_float8_current_scaling(fwd, recipe)
elif recipe.float8_block_scaling():
self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
# elif for other recipes (mxfp8, etc.) # elif for other recipes (mxfp8, etc.)
def reset_layer_norm_parameters(self) -> None: def reset_layer_norm_parameters(self) -> None:
...@@ -1772,15 +1775,14 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1772,15 +1775,14 @@ class LayerNormMLP(TransformerEngineBaseModule):
) = quantizers ) = quantizers
# Get weight tensors # Get weight tensors
fc1_weight = self.fc1_weight fc1_weight, fc2_weight = self._get_weight_tensors()
fc1_bias = self.fc1_bias if self.use_bias else None fc1_bias = self.fc1_bias if self.use_bias else None
fc2_weight = self.fc2_weight
fc2_bias = self.fc2_bias if self.use_bias else None fc2_bias = self.fc2_bias if self.use_bias else None
if not self.fp8: if not self.fp8:
if isinstance(fc1_weight, Float8Tensor): if isinstance(fc1_weight, Float8Tensor):
fc1_weight = fc1_weight.from_float8() fc1_weight = fc1_weight.dequantize()
if isinstance(fc2_weight, Float8Tensor): if isinstance(fc2_weight, Float8Tensor):
fc2_weight = fc2_weight.from_float8() fc2_weight = fc2_weight.dequantize()
# Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode
if ( not IS_HIP_EXTENSION if ( not IS_HIP_EXTENSION
...@@ -1866,31 +1868,26 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1866,31 +1868,26 @@ class LayerNormMLP(TransformerEngineBaseModule):
def _get_quantizers(self, fp8_output): def _get_quantizers(self, fp8_output):
( (
fc1_input_quantizer, fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer, fc1_output_quantizer,
fc1_grad_input_quantizer, fc1_grad_input_quantizer,
fc1_grad_weight_quantizer, fc1_grad_weight_quantizer,
fc1_grad_output_quantizer, fc1_grad_output_quantizer,
fc2_input_quantizer, fc2_input_quantizer,
fc2_weight_quantizer,
fc2_output_quantizer, fc2_output_quantizer,
fc2_grad_input_quantizer, fc2_grad_input_quantizer,
fc2_grad_weight_quantizer, fc2_grad_weight_quantizer,
fc2_grad_output_quantizer, fc2_grad_output_quantizer,
) = [None] * 12 ) = [None] * 10
fc1_weight_quantizer, fc2_weight_quantizer = self._get_weight_quantizers()
if self.fp8: if self.fp8:
fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
fc1_input_quantizer.internal = True fc1_input_quantizer.internal = True
fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
fc1_weight_quantizer.internal = True
fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT] fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
fc2_input_quantizer.set_usage( fc2_input_quantizer.set_usage(
rowwise=True, rowwise=True,
columnwise=isinstance(fc2_input_quantizer, (MXFP8Quantizer, Float8BlockQuantizer)), columnwise=isinstance(fc2_input_quantizer, (MXFP8Quantizer, Float8BlockQuantizer)),
) )
fc1_input_quantizer.internal = True fc1_input_quantizer.internal = True
fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT]
fc2_weight_quantizer.internal = True
if fp8_output: if fp8_output:
fc2_output_quantizer = self.quantizers["scaling_fwd"][ fc2_output_quantizer = self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM2_OUTPUT tex.FP8FwdTensors.GEMM2_OUTPUT
...@@ -2007,6 +2004,36 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -2007,6 +2004,36 @@ class LayerNormMLP(TransformerEngineBaseModule):
tex.FP8BwdTensors.GRAD_OUTPUT2 tex.FP8BwdTensors.GRAD_OUTPUT2
].amax_reduction_group = self.tp_group ].amax_reduction_group = self.tp_group
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
"""Get the weight tensors of the module."""
return [self.fc1_weight, self.fc2_weight]
def _get_weight_quantizers(self) -> List[Quantizer]:
"""Get the weight quantizers of the module."""
if not self.fp8:
return [None, None]
fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
fc1_weight_quantizer.internal = True
fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT]
fc2_weight_quantizer.internal = True
return [fc1_weight_quantizer, fc2_weight_quantizer]
def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on blockwise scaling recipe + layernorm_mlp."""
assert (
recipe.float8_block_scaling()
), "blockwise scaling recipe quantizer customization here"
if fwd:
if self.sequence_parallel and self.set_parallel_mode:
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].all_gather_usage = True
else:
if self.sequence_parallel and self.set_parallel_mode:
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT2
].all_gather_usage = True
def backward_dw(self): def backward_dw(self):
""" """
Execute the delayed weight gradient computation. Execute the delayed weight gradient computation.
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# See LICENSE for license information. # See LICENSE for license information.
"""Linear API""" """Linear API"""
from typing import Callable, Dict, Optional, Tuple, Union from typing import Callable, Dict, Optional, Tuple, Union, List
from functools import reduce from functools import reduce
from operator import mul as multiply_op from operator import mul as multiply_op
import warnings import warnings
...@@ -67,7 +67,7 @@ from ..tensor.quantized_tensor import ( ...@@ -67,7 +67,7 @@ from ..tensor.quantized_tensor import (
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.utils import any_feature_enabled from ...debug.pytorch.utils import any_feature_enabled
...@@ -137,12 +137,6 @@ class _Linear(torch.autograd.Function): ...@@ -137,12 +137,6 @@ class _Linear(torch.autograd.Function):
parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
) )
# Do TP communication in high precision if quantized format
# does not support communication
force_hp_input_gather = (
fp8 and with_input_all_gather_nccl and isinstance(input_quantizer, Float8BlockQuantizer)
)
# Configure Userbuffers communication (comm+GEMM overlap) # Configure Userbuffers communication (comm+GEMM overlap)
ub_obj = None ub_obj = None
ub_type = None ub_type = None
...@@ -169,7 +163,7 @@ class _Linear(torch.autograd.Function): ...@@ -169,7 +163,7 @@ class _Linear(torch.autograd.Function):
if fp8 or debug: if fp8 or debug:
if input_quantizer is None: if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor") raise ValueError("Missing quantizer for input tensor")
if not force_hp_input_gather and not isinstance(inputmat, QuantizedTensorBase): if not isinstance(inputmat, QuantizedTensorBase):
input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
if isinstance( if isinstance(
input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
...@@ -348,10 +342,11 @@ class _Linear(torch.autograd.Function): ...@@ -348,10 +342,11 @@ class _Linear(torch.autograd.Function):
# For sequence parallel in vanilla FP8, rowwise data is # For sequence parallel in vanilla FP8, rowwise data is
# to gather the input. For MXFP8, columnwise only data # to gather the input. For MXFP8, columnwise only data
# can be allgathered. # can be allgathered.
if isinstance(inputmat, MXFP8TensorBase) or not ctx.backward_input_needs_gather: if (
isinstance(inputmat, (MXFP8TensorBase, Float8BlockwiseQTensorBase))
or not ctx.backward_input_needs_gather
):
inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
if force_hp_input_gather:
assert not isinstance(inputmat, QuantizedTensorBase)
saved_inputmat = inputmat saved_inputmat = inputmat
# Weight with column-wise usage is needed for dgrad GEMM. # Weight with column-wise usage is needed for dgrad GEMM.
...@@ -397,7 +392,6 @@ class _Linear(torch.autograd.Function): ...@@ -397,7 +392,6 @@ class _Linear(torch.autograd.Function):
ctx.activation_dtype = activation_dtype ctx.activation_dtype = activation_dtype
ctx.fp8 = fp8 ctx.fp8 = fp8
ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
ctx.force_hp_input_gather = force_hp_input_gather
ctx.input_quantizer = input_quantizer ctx.input_quantizer = input_quantizer
ctx.grad_input_quantizer = grad_input_quantizer ctx.grad_input_quantizer = grad_input_quantizer
ctx.grad_weight_quantizer = grad_weight_quantizer ctx.grad_weight_quantizer = grad_weight_quantizer
...@@ -558,7 +552,7 @@ class _Linear(torch.autograd.Function): ...@@ -558,7 +552,7 @@ class _Linear(torch.autograd.Function):
inputmat_total_work = None inputmat_total_work = None
if ctx.backward_input_needs_gather: if ctx.backward_input_needs_gather:
quantizer = None quantizer = None
if (ctx.fp8 or ctx.debug) and not ctx.force_hp_input_gather: if ctx.fp8 or ctx.debug:
quantizer = ctx.input_quantizer quantizer = ctx.input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
# If data is in FP8, we compute FP8 transposes manually # If data is in FP8, we compute FP8 transposes manually
...@@ -696,14 +690,23 @@ class _Linear(torch.autograd.Function): ...@@ -696,14 +690,23 @@ class _Linear(torch.autograd.Function):
# all-gather with wgrad GEMM. Also, we can't # all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we # convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered # can't reuse the grad output that was gathered
# for the dgrad GEMM. We work around with blocking # for the dgrad GEMM. We work around by explicitly
# all-gather for column-scaled MXFP8 data. # overlapping the NCCL operation with the dgrad GEMM.
ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
grad_output, _ = gather_along_first_dim( # Get the communication stream from the dgrad GEMM and set it as the current torch stream
grad_output_arg, dgrad_comm_stream = ub_obj_dgrad.get_communication_stream()
ctx.tp_group, with torch.cuda.stream(dgrad_comm_stream):
quantizer=ctx.grad_output_quantizer, # Syncs with the current stream (dgrad_comm_stream) before starting the all-gather
) # This ensures that we don't start until all communication for the dgrad GEMM is complete
grad_output, grad_output_work = gather_along_first_dim(
grad_output_arg,
ctx.tp_group,
async_op=True,
quantizer=ctx.grad_output_quantizer,
)
# Synchronize with the main stream
grad_output_work.wait()
if ctx.fp8 or ctx.debug: if ctx.fp8 or ctx.debug:
if isinstance(grad_output, QuantizedTensorBase): if isinstance(grad_output, QuantizedTensorBase):
grad_output.update_usage(columnwise_usage=True) grad_output.update_usage(columnwise_usage=True)
...@@ -1218,6 +1221,8 @@ class Linear(TransformerEngineBaseModule): ...@@ -1218,6 +1221,8 @@ class Linear(TransformerEngineBaseModule):
recipe = FP8GlobalStateManager.get_fp8_recipe() recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling(): if recipe.float8_current_scaling():
self._customize_quantizers_float8_current_scaling(fwd, recipe) self._customize_quantizers_float8_current_scaling(fwd, recipe)
elif recipe.float8_block_scaling():
self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
# elif for other recipes (mxfp8, etc.) # elif for other recipes (mxfp8, etc.)
def reset_parameters(self, defer_init=False): def reset_parameters(self, defer_init=False):
...@@ -1294,20 +1299,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -1294,20 +1299,7 @@ class Linear(TransformerEngineBaseModule):
) as inp: ) as inp:
# Get concatenated weight and bias tensors # Get concatenated weight and bias tensors
unfused_weights = [getattr(self, name) for name in self.weight_names] unfused_weights = self._get_weight_tensors()
if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
if self.fp8:
if len(unfused_weights) != 1:
raise RuntimeError(
"Splitting QuantizedTensor into multiple params is not supported"
)
else:
warnings.warn(
"You are using quantized weights without quantized compute. "
"Please make sure this is intentional."
)
unfused_weights = [w.dequantize() for w in unfused_weights]
weight_tensor = noop_cat(unfused_weights) weight_tensor = noop_cat(unfused_weights)
if self.use_bias: if self.use_bias:
bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names]) bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
...@@ -1337,12 +1329,6 @@ class Linear(TransformerEngineBaseModule): ...@@ -1337,12 +1329,6 @@ class Linear(TransformerEngineBaseModule):
grad_output_quantizer, grad_output_quantizer,
) = quantizers ) = quantizers
# Make sure weight tensor has correct quantizer
# Note: Quantizer might have changed if quantization
# recipe changed
if weight_quantizer is not None and isinstance(weight_tensor, QuantizedTensor):
weight_tensor._quantizer = weight_quantizer
if torch.is_grad_enabled(): if torch.is_grad_enabled():
linear_fn = _Linear.apply linear_fn = _Linear.apply
args = [] args = []
...@@ -1403,8 +1389,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -1403,8 +1389,7 @@ class Linear(TransformerEngineBaseModule):
output_quantizer = None output_quantizer = None
input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
input_quantizer.internal = True input_quantizer.internal = True
weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] (weight_quantizer,) = self._get_weight_quantizers()
weight_quantizer.internal = True
if fp8_output: if fp8_output:
output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
if torch.is_grad_enabled(): if torch.is_grad_enabled():
...@@ -1478,3 +1463,47 @@ class Linear(TransformerEngineBaseModule): ...@@ -1478,3 +1463,47 @@ class Linear(TransformerEngineBaseModule):
self.quantizers["scaling_bwd"][ self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1 tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_reduction_group = self.tp_group ].amax_reduction_group = self.tp_group
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
"""Get the weight tensors of the module."""
unfused_weights = [getattr(self, name) for name in self.weight_names]
if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
if self.fp8:
if len(unfused_weights) != 1:
raise RuntimeError(
"Splitting QuantizedTensor into multiple params is not supported"
)
else:
warnings.warn(
"You are using quantized weights without quantized compute. "
"Please make sure this is intentional."
)
unfused_weights = [w.dequantize() for w in unfused_weights]
return unfused_weights
def _get_weight_quantizers(self) -> List[Quantizer]:
"""Get the weight quantizers of the module."""
if not self.fp8:
return [None]
weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
weight_quantizer.internal = True
return [weight_quantizer]
def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on blockwise scaling recipe + linear."""
assert (
recipe.float8_block_scaling()
), "blockwise scaling recipe quantizer customization here"
if fwd:
if self.sequence_parallel and self.parallel_mode == "column":
# set compact for inp tensor X
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].all_gather_usage = True
else:
if self.sequence_parallel and self.parallel_mode == "row":
# set compact for grad_output tensor dY
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].all_gather_usage = True
...@@ -11,6 +11,7 @@ from .all_reduce import AllReduce ...@@ -11,6 +11,7 @@ from .all_reduce import AllReduce
from .basic_linear import BasicLinear from .basic_linear import BasicLinear
from .bias import Bias from .bias import Bias
from .identity import Identity from .identity import Identity
from .l2normalization import L2Normalization
from .layer_norm import LayerNorm from .layer_norm import LayerNorm
from .make_extra_output import MakeExtraOutput from .make_extra_output import MakeExtraOutput
from .quantize import Quantize from .quantize import Quantize
......
...@@ -96,12 +96,15 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta): ...@@ -96,12 +96,15 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
if not x.is_contiguous(): if not x.is_contiguous():
x = x.contiguous() x = x.contiguous()
# Check if FP8 is enabled # Check if quantized compute is enabled
fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() quantized_compute_enabled = FP8GlobalStateManager.is_fp8_enabled()
if fp8_enabled and next_op is not None and next_op.num_quantizers("forward") > 0: quantizer = None
if (
quantized_compute_enabled
and next_op is not None
and next_op.num_quantizers("forward") > 0
):
quantizer = next_op.get_quantizer("forward", 0) quantizer = next_op.get_quantizer("forward", 0)
else:
quantizer = None
# Launch kernel # Launch kernel
y = self._activation_forward_impl( y = self._activation_forward_impl(
...@@ -115,13 +118,13 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta): ...@@ -115,13 +118,13 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
# Quantize input to FP8 before caching if needed # Quantize input to FP8 before caching if needed
if self.cache_quantized_input: if self.cache_quantized_input:
quantizer = Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, x.device) input_quantizer = Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, x.device)
quantizer.set_usage(rowwise=True, columnwise=False) input_quantizer.set_usage(rowwise=True, columnwise=False)
x = quantizer(x) x = input_quantizer(x)
# Save state for backward pass # Save state for backward pass
ctx.save_for_backward(x.detach()) ctx.save_for_backward(x.detach())
ctx.fp8_enabled = fp8_enabled ctx.quantized_compute_enabled = quantized_compute_enabled
ctx.dtype = dtype ctx.dtype = dtype
ctx.prev_op = prev_op ctx.prev_op = prev_op
...@@ -153,11 +156,20 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta): ...@@ -153,11 +156,20 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
if not dy.is_contiguous(): if not dy.is_contiguous():
dy = dy.contiguous() dy = dy.contiguous()
# Check if quantized compute is enabled
quantizer = None
if (
ctx.quantized_compute_enabled
and ctx.prev_op is not None
and ctx.prev_op.num_quantizers("backward") > 0
):
quantizer = ctx.prev_op.get_quantizer("backward", 0)
# Launch kernel # Launch kernel
dx = self._activation_backward_impl( dx = self._activation_backward_impl(
reshape(dy, (-1, dy.size(-1))), reshape(dy, (-1, dy.size(-1))),
reshape(x, (-1, x.size(-1))), reshape(x, (-1, x.size(-1))),
None, quantizer,
) )
# Check grad input tensor # Check grad input tensor
......
...@@ -22,7 +22,7 @@ from ...distributed import ( ...@@ -22,7 +22,7 @@ from ...distributed import (
from ...fp8 import FP8GlobalStateManager from ...fp8 import FP8GlobalStateManager
from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD
from ...tensor import Quantizer, QuantizedTensor from ...tensor import Quantizer, QuantizedTensor
from ...tensor.float8_tensor import Float8Quantizer from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ...tensor.mxfp8_tensor import MXFP8Quantizer from ...tensor.mxfp8_tensor import MXFP8Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase from ...tensor._internal.float8_tensor_base import Float8TensorBase
...@@ -324,12 +324,38 @@ class BasicLinear(BasicOperation): ...@@ -324,12 +324,38 @@ class BasicLinear(BasicOperation):
weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
# Recipe-specific configuration
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
if any(
not isinstance(q, Float8CurrentScalingQuantizer)
for q in (input_quantizer, weight_quantizer, grad_output_quantizer)
):
raise RuntimeError(
"FP8 current-scaling recipe is enabled, "
f"but input quantizer is {input_quantizer.__class__.__name__}, "
f"weight quantizer is {weight_quantizer.__class__.__name__}, "
f"grad output quantizer is {grad_output_quantizer.__class__.__name__}"
)
input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
if self.sequence_parallel and self.tensor_parallel_mode == "column":
input_quantizer.with_amax_reduction = True
input_quantizer.amax_reduction_group = self.tensor_parallel_group
if self.sequence_parallel and self.tensor_parallel_mode == "row":
grad_output_quantizer.with_amax_reduction = True
grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group
# Make sure weight tensor has correct quantizer # Make sure weight tensor has correct quantizer
# Note: Quantizer might have changed if quantization # Note: Quantizer might have changed if quantization
# recipe changed # recipe changed
if isinstance(weight_quantizer, Float8Quantizer) and isinstance( if isinstance(
weight, Float8TensorBase weight_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
): ) and isinstance(weight, Float8TensorBase):
weight._quantizer = weight_quantizer weight._quantizer = weight_quantizer
@staticmethod @staticmethod
...@@ -349,7 +375,9 @@ class BasicLinear(BasicOperation): ...@@ -349,7 +375,9 @@ class BasicLinear(BasicOperation):
input_quantizer: Optional[Quantizer] = None, input_quantizer: Optional[Quantizer] = None,
weight_quantizer: Optional[Quantizer] = None, weight_quantizer: Optional[Quantizer] = None,
output_quantizer: Optional[Quantizer] = None, output_quantizer: Optional[Quantizer] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: input_requires_grad: bool = True,
weight_requires_grad: bool = True,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
"""Functional API for forward pass """Functional API for forward pass
Parameters Parameters
...@@ -385,17 +413,25 @@ class BasicLinear(BasicOperation): ...@@ -385,17 +413,25 @@ class BasicLinear(BasicOperation):
Builder class for quantized weight tensor. Builder class for quantized weight tensor.
output_quantizer: Quantizer, optional output_quantizer: Quantizer, optional
Builder class for quantized output tensor. Builder class for quantized output tensor.
input_requires_grad: bool, default = `True`
Whether the loss gradient w.r.t. the input tensor is
required in the backward pass.
weight_requires_grad: bool, default = `True`
Whether the loss gradient w.r.t. the weight tensor is
required in the backward pass.
Returns Returns
------- -------
torch.Tensor torch.Tensor
Output tensor Output tensor
torch.Tensor torch.Tensor, optional
Input tensor used in GEMM, possibly cast and reshaped from Input tensor, ready for use in backward pass. `None` is
provided input tensor returned if loss gradient w.r.t. the weight tensor is not
torch.Tensor required.
Weight tensor used in GEMM, possibly cast and reshaped from torch.Tensor, optional
provided weight tensor Weight tensor, ready for use in backward pass. `None` is
returned if loss gradient w.r.t. the input tensor is not
required.
""" """
...@@ -416,7 +452,7 @@ class BasicLinear(BasicOperation): ...@@ -416,7 +452,7 @@ class BasicLinear(BasicOperation):
if with_quantized_compute: if with_quantized_compute:
if input_quantizer is None: if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor") raise ValueError("Missing quantizer for input tensor")
input_quantizer.set_usage(rowwise=True) input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
if with_x_all_gather: if with_x_all_gather:
input_quantizer.set_usage(columnwise=False) input_quantizer.set_usage(columnwise=False)
x, x_async = gather_along_first_dim( x, x_async = gather_along_first_dim(
...@@ -449,7 +485,7 @@ class BasicLinear(BasicOperation): ...@@ -449,7 +485,7 @@ class BasicLinear(BasicOperation):
if with_quantized_compute and not w_is_quantized: if with_quantized_compute and not w_is_quantized:
if weight_quantizer is None: if weight_quantizer is None:
raise ValueError("Missing quantizer for weight tensor") raise ValueError("Missing quantizer for weight tensor")
weight_quantizer.set_usage(rowwise=True) weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
w = weight_quantizer(w) w = weight_quantizer(w)
elif not with_quantized_compute and w_is_quantized: elif not with_quantized_compute and w_is_quantized:
w = w.dequantize() w = w.dequantize()
...@@ -526,17 +562,25 @@ class BasicLinear(BasicOperation): ...@@ -526,17 +562,25 @@ class BasicLinear(BasicOperation):
else: else:
torch.distributed.all_reduce(y, group=tensor_parallel_group) torch.distributed.all_reduce(y, group=tensor_parallel_group)
# Detach input tensor if needed # Prepare weight tensor for backward pass
# Note: PyTorch autograd produces esoteric errors if we save if input_requires_grad:
# input tensor as context for backward pass. if w is not weight and with_quantized_compute and isinstance(w, QuantizedTensor):
if x_local is input: w.update_usage(rowwise_usage=False, columnwise_usage=True)
x_local = x_local.detach() else:
w = None
# Configure input tensor for backward pass # Prepare input tensor for backward pass
if with_quantized_compute and isinstance(x_local, QuantizedTensor): if weight_requires_grad:
if not (isinstance(x_local, Float8TensorBase) and with_x_all_gather): if x_local is input:
# FP8 does not support all-gather of transpose data # PyTorch autograd produces esoteric errors if we
x_local.update_usage(rowwise_usage=False, columnwise_usage=True) # cache input tensor directly.
x_local = x_local.detach()
if with_quantized_compute and isinstance(x_local, QuantizedTensor):
if not (isinstance(x_local, Float8TensorBase) and with_x_all_gather):
# FP8 does not support all-gather of transpose data
x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
else:
x_local = None
return y, x_local, w return y, x_local, w
...@@ -892,7 +936,7 @@ class BasicLinear(BasicOperation): ...@@ -892,7 +936,7 @@ class BasicLinear(BasicOperation):
dtype = torch.get_autocast_dtype("cuda") dtype = torch.get_autocast_dtype("cuda")
# Linear forward # Linear forward
output, x_local, _ = BasicLinear._functional_forward( output, x_local, w = BasicLinear._functional_forward(
input=input_, input=input_,
weight=self.weight, weight=self.weight,
dtype=dtype, dtype=dtype,
...@@ -903,10 +947,12 @@ class BasicLinear(BasicOperation): ...@@ -903,10 +947,12 @@ class BasicLinear(BasicOperation):
input_quantizer=input_quantizer, input_quantizer=input_quantizer,
weight_quantizer=weight_quantizer, weight_quantizer=weight_quantizer,
output_quantizer=output_quantizer, output_quantizer=output_quantizer,
input_requires_grad=input_requires_grad,
weight_requires_grad=weight_requires_grad,
) )
# Save state for backward pass # Save state for backward pass
ctx.save_for_backward(x_local) ctx.save_for_backward(x_local, w)
ctx.with_quantized_compute = with_quantized_compute ctx.with_quantized_compute = with_quantized_compute
ctx.input_quantizer = input_quantizer ctx.input_quantizer = input_quantizer
ctx.weight_quantizer = weight_quantizer ctx.weight_quantizer = weight_quantizer
...@@ -926,7 +972,7 @@ class BasicLinear(BasicOperation): ...@@ -926,7 +972,7 @@ class BasicLinear(BasicOperation):
) -> tuple[torch.Tensor, Iterable[Optional[torch.Tensor]]]: ) -> tuple[torch.Tensor, Iterable[Optional[torch.Tensor]]]:
# Saved tensors from forward pass # Saved tensors from forward pass
(x_local,) = ctx.saved_tensors (x_local, w) = ctx.saved_tensors
# wgrad fusion # wgrad fusion
accumulate_into_main_grad = self._accumulate_into_main_grad accumulate_into_main_grad = self._accumulate_into_main_grad
...@@ -946,7 +992,7 @@ class BasicLinear(BasicOperation): ...@@ -946,7 +992,7 @@ class BasicLinear(BasicOperation):
grad_input, grad_weight = BasicLinear._functional_backward( grad_input, grad_weight = BasicLinear._functional_backward(
grad_output=grad_output, grad_output=grad_output,
input=x_local, input=x_local,
weight=self.weight, weight=w,
input_requires_grad=ctx.input_requires_grad, input_requires_grad=ctx.input_requires_grad,
weight_requires_grad=ctx.weight_requires_grad, weight_requires_grad=ctx.weight_requires_grad,
dtype=ctx.dtype, dtype=ctx.dtype,
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fusable operation for L2 Normalization."""
from __future__ import annotations
from typing import Optional
import torch
from ...tensor import QuantizedTensor
from ...utils import clear_tensor_data
from ..op import BasicOperation, OperationContext
from ...jit import (
l2normalization_fused,
l2normalization_fwd_fused,
l2normalization_backward_fused,
set_jit_fusion_options,
warmup_jit_l2normalization_all_dtypes,
)
class L2Normalization(BasicOperation):
    r"""L2 Normalization

    Normalizes input tensors to unit L2 norm along the last dimension.
    This is a parameter-free normalization.

    .. math::
        y = \frac{x}{\sqrt{\sum_{i} x_i^2 + \varepsilon}}

    Used e.g. for query-key normalization in attention mechanisms.

    Parameters
    ----------
    eps : float, default = 1e-6
        A value added to the denominator for numerical stability
    seq_length: int, default = None
        sequence length of input samples. Needed for JIT Warmup, a technique where jit fused
        functions are warmed up before training to ensure same kernels are used for forward
        propagation and activation recompute phase.
    micro_batch_size: int, default = None
        batch size per training step. Needed for JIT Warmup, a technique where jit
        fused functions are warmed up before training to ensure same kernels are
        used for forward propagation and activation recompute phase.

    """

    def __init__(
        self,
        *,
        eps: float = 1e-6,
        seq_length: Optional[int] = None,
        micro_batch_size: Optional[int] = None,
    ) -> None:
        super().__init__()
        # Numerical-stability constant added to the squared norm
        self.eps: float = eps

        # Warm up the JIT-fused kernels so the forward pass and the
        # activation-recompute phase hit the same compiled kernels. The
        # hidden size is unknown until the first forward pass; for QK
        # normalization it is the per-attention-head dimension
        # (hidden_size_per_attention_head), not the full model hidden
        # dimension, so warm up with common head sizes.
        if seq_length and micro_batch_size and torch.cuda.is_available():
            set_jit_fusion_options()
            for head_dim in (32, 64, 80, 96, 128, 256):
                warmup_jit_l2normalization_all_dtypes(head_dim, seq_length, micro_batch_size)

    def op_forward(
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        prev_op: Optional[BasicOperation] = None,
        next_op: Optional[BasicOperation] = None,
    ) -> torch.Tensor:

        # The fused (torch.compile) kernels handle multi-dimensional
        # tensors directly, so the input is used as-is.
        x = input_
        if isinstance(x, QuantizedTensor):
            x = x.dequantize()

        # L2 norm: x / sqrt(sum(x^2) + eps) = x * rsqrt(sum(x^2) + eps)
        if ctx.requires_grad:
            # Training: fused kernel that also returns the reciprocal
            # root of the squared norm, needed by the backward pass.
            y, rsqrt_norm = l2normalization_fwd_fused(x, self.eps)
            ctx.save_for_backward(x, rsqrt_norm)
            ctx.has_prev_op = prev_op is not None
        else:
            # Inference: lightweight fused kernel, no intermediates kept.
            y = l2normalization_fused(x, self.eps)

        return y

    def op_backward(
        self,
        ctx: OperationContext,
        grad_output: torch.Tensor,
    ) -> tuple[torch.Tensor, tuple[()]]:

        # Saved tensors from forward pass
        x, rsqrt_norm = ctx.saved_tensors

        dy = grad_output
        if isinstance(dy, QuantizedTensor):
            dy = dy.dequantize()

        # Fused backward for y = x * rsqrt(sum(x^2) + eps)
        dx = l2normalization_backward_fused(dy, x, rsqrt_norm, self.eps)

        # Saved tensors may be freed when the input came from a previous
        # op (i.e. it is an intermediate, not a user-provided tensor).
        if ctx.has_prev_op:
            clear_tensor_data(x)
            clear_tensor_data(rsqrt_norm)

        # No parameters, so empty tuple for param grads
        return dx, ()
...@@ -51,7 +51,7 @@ class BackwardLinearAdd(FusedOperation): ...@@ -51,7 +51,7 @@ class BackwardLinearAdd(FusedOperation):
linear_op_ctx = basic_op_ctxs[0] linear_op_ctx = basic_op_ctxs[0]
# Saved tensors from forward pass # Saved tensors from forward pass
(x_local,) = linear_op_ctx.saved_tensors (x_local, w) = linear_op_ctx.saved_tensors
# wgrad fusion # wgrad fusion
accumulate_into_main_grad = linear_op._accumulate_into_main_grad accumulate_into_main_grad = linear_op._accumulate_into_main_grad
...@@ -72,7 +72,7 @@ class BackwardLinearAdd(FusedOperation): ...@@ -72,7 +72,7 @@ class BackwardLinearAdd(FusedOperation):
grad_input, grad_weight = BasicLinear._functional_backward( grad_input, grad_weight = BasicLinear._functional_backward(
grad_output=grad_output, grad_output=grad_output,
input=x_local, input=x_local,
weight=linear_op.weight, weight=w,
input_requires_grad=linear_op_ctx.input_requires_grad, input_requires_grad=linear_op_ctx.input_requires_grad,
weight_requires_grad=linear_op_ctx.weight_requires_grad, weight_requires_grad=linear_op_ctx.weight_requires_grad,
dtype=grad_input.dtype, dtype=grad_input.dtype,
......
...@@ -82,6 +82,10 @@ class ForwardLinearBiasActivation(FusedOperation): ...@@ -82,6 +82,10 @@ class ForwardLinearBiasActivation(FusedOperation):
else: else:
raise NotImplementedError("Activations are not yet supported") raise NotImplementedError("Activations are not yet supported")
# Check which grads are required
input_requires_grad = linear_op_ctx.requires_grad and input_.requires_grad
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# FP8 metadata # FP8 metadata
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
input_quantizer = None input_quantizer = None
...@@ -106,7 +110,7 @@ class ForwardLinearBiasActivation(FusedOperation): ...@@ -106,7 +110,7 @@ class ForwardLinearBiasActivation(FusedOperation):
dtype = torch.get_autocast_dtype("cuda") dtype = torch.get_autocast_dtype("cuda")
# Linear forward # Linear forward
output, x_local, _ = BasicLinear._functional_forward( output, x_local, w = BasicLinear._functional_forward(
input=input_, input=input_,
weight=linear_op.weight, weight=linear_op.weight,
bias=bias, bias=bias,
...@@ -118,18 +122,20 @@ class ForwardLinearBiasActivation(FusedOperation): ...@@ -118,18 +122,20 @@ class ForwardLinearBiasActivation(FusedOperation):
input_quantizer=input_quantizer, input_quantizer=input_quantizer,
weight_quantizer=weight_quantizer, weight_quantizer=weight_quantizer,
output_quantizer=output_quantizer, output_quantizer=output_quantizer,
input_requires_grad=input_requires_grad,
weight_requires_grad=weight_requires_grad,
) )
# Save state for backward pass # Save state for backward pass
linear_op_ctx.save_for_backward(x_local) linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.input_quantizer = input_quantizer
linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.weight_quantizer = weight_quantizer
linear_op_ctx.grad_output_quantizer = grad_output_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer
linear_op_ctx.grad_input_quantizer = grad_input_quantizer linear_op_ctx.grad_input_quantizer = grad_input_quantizer
linear_op_ctx.dtype = dtype linear_op_ctx.dtype = dtype
linear_op_ctx.input_requires_grad = input_.requires_grad linear_op_ctx.input_requires_grad = input_requires_grad
linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad linear_op_ctx.weight_requires_grad = weight_requires_grad
linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None
return output, [() for _ in range(len(self.basic_ops))] return output, [() for _ in range(len(self.basic_ops))]
......
...@@ -76,6 +76,10 @@ class ForwardLinearBiasAdd(FusedOperation): ...@@ -76,6 +76,10 @@ class ForwardLinearBiasAdd(FusedOperation):
if basic_op_kwargs[idx]: if basic_op_kwargs[idx]:
raise ValueError("Bias operation forward does not expect keyword arguments") raise ValueError("Bias operation forward does not expect keyword arguments")
# Check which grads are required
input_requires_grad = linear_op_ctx.requires_grad and input_.requires_grad
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# FP8 metadata # FP8 metadata
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
input_quantizer = None input_quantizer = None
...@@ -98,7 +102,7 @@ class ForwardLinearBiasAdd(FusedOperation): ...@@ -98,7 +102,7 @@ class ForwardLinearBiasAdd(FusedOperation):
# Linear forward # Linear forward
output = basic_op_extra_inputs[self._op_idxs["add"]][0] output = basic_op_extra_inputs[self._op_idxs["add"]][0]
output, x_local, _ = BasicLinear._functional_forward( output, x_local, w = BasicLinear._functional_forward(
input=input_, input=input_,
weight=linear_op.weight, weight=linear_op.weight,
bias=bias, bias=bias,
...@@ -111,18 +115,20 @@ class ForwardLinearBiasAdd(FusedOperation): ...@@ -111,18 +115,20 @@ class ForwardLinearBiasAdd(FusedOperation):
input_quantizer=input_quantizer, input_quantizer=input_quantizer,
weight_quantizer=weight_quantizer, weight_quantizer=weight_quantizer,
output_quantizer=output_quantizer, output_quantizer=output_quantizer,
input_requires_grad=input_requires_grad,
weight_requires_grad=weight_requires_grad,
) )
# Save state for backward pass # Save state for backward pass
linear_op_ctx.save_for_backward(x_local) linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.input_quantizer = input_quantizer
linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.weight_quantizer = weight_quantizer
linear_op_ctx.grad_output_quantizer = grad_output_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer
linear_op_ctx.grad_input_quantizer = grad_input_quantizer linear_op_ctx.grad_input_quantizer = grad_input_quantizer
linear_op_ctx.dtype = dtype linear_op_ctx.dtype = dtype
linear_op_ctx.input_requires_grad = input_.requires_grad linear_op_ctx.input_requires_grad = input_requires_grad
linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad linear_op_ctx.weight_requires_grad = weight_requires_grad
linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None
return output, [() for _ in range(len(self.basic_ops))] return output, [() for _ in range(len(self.basic_ops))]
......
...@@ -407,16 +407,26 @@ class UserbuffersBackwardLinear(FusedOperation): ...@@ -407,16 +407,26 @@ class UserbuffersBackwardLinear(FusedOperation):
# Initialize grad output # Initialize grad output
if tensor_parallel_mode == "row" and isinstance(grad_output_quantizer, MXFP8Quantizer): if tensor_parallel_mode == "row" and isinstance(grad_output_quantizer, MXFP8Quantizer):
# UB does not support overlapping grad output # UB does not support overlapping grad output
# all-gather with wgrad GEMM. Also, MXFP8 does not # all-gather with wgrad GEMM. Also, we can't
# allow reusing the grad output that was gathered for # convert row-scaled MXFP8 to column-scaled, so we
# the dgrad GEMM. We work around with blocking # can't reuse the grad output that was gathered
# all-gather for column-scaled MXFP8 data. # for the dgrad GEMM. We work around by explicitly
# overlapping the NCCL operation with the dgrad GEMM.
grad_output_quantizer.set_usage(rowwise=False, columnwise=True) grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
dy, _ = gather_along_first_dim( # Get the communication stream from the dgrad GEMM and set it as the current torch stream
grad_output, dgrad_comm_stream = ub_comm_dgrad.get_communication_stream()
tensor_parallel_group, with torch.cuda.stream(dgrad_comm_stream):
quantizer=grad_output_quantizer, # Syncs with the current stream (dgrad_comm_stream) before starting the all-gather
) # This ensures that we don't start until all communication for the dgrad GEMM is complete
dy, dy_work = gather_along_first_dim(
dy_local,
tensor_parallel_group,
async_op=True,
quantizer=grad_output_quantizer,
)
# Synchronize with the main stream
dy_work.wait()
if tensor_parallel_mode == "column": if tensor_parallel_mode == "column":
dy = dy_local dy = dy_local
if dy is None: if dy is None:
...@@ -500,7 +510,7 @@ class UserbuffersBackwardLinear(FusedOperation): ...@@ -500,7 +510,7 @@ class UserbuffersBackwardLinear(FusedOperation):
bias_op = self.basic_ops[idx] bias_op = self.basic_ops[idx]
# Saved tensors from forward pass # Saved tensors from forward pass
(x_local,) = linear_op_ctx.saved_tensors (x_local, w) = linear_op_ctx.saved_tensors
# wgrad fusion # wgrad fusion
accumulate_into_main_grad = linear_op._accumulate_into_main_grad accumulate_into_main_grad = linear_op._accumulate_into_main_grad
...@@ -520,7 +530,7 @@ class UserbuffersBackwardLinear(FusedOperation): ...@@ -520,7 +530,7 @@ class UserbuffersBackwardLinear(FusedOperation):
retval = UserbuffersBackwardLinear._functional_backward( retval = UserbuffersBackwardLinear._functional_backward(
grad_output=grad_output, grad_output=grad_output,
input=x_local, input=x_local,
weight=linear_op.weight, weight=w,
weight_requires_grad=linear_op_ctx.weight_requires_grad, weight_requires_grad=linear_op_ctx.weight_requires_grad,
bias_requires_grad=(bias_op is not None), bias_requires_grad=(bias_op is not None),
dtype=linear_op_ctx.dtype, dtype=linear_op_ctx.dtype,
......
...@@ -21,7 +21,7 @@ from ...module.base import ( ...@@ -21,7 +21,7 @@ from ...module.base import (
_2X_ACC_FPROP, _2X_ACC_FPROP,
) )
from ...tensor.quantized_tensor import QuantizedTensorBase, Quantizer from ...tensor.quantized_tensor import QuantizedTensorBase, Quantizer
from ...tensor.float8_tensor import Float8Quantizer from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ...utils import canonicalize_device, canonicalize_dtype from ...utils import canonicalize_device, canonicalize_dtype
from ..basic import BasicLinear, Bias, ReduceScatter from ..basic import BasicLinear, Bias, ReduceScatter
...@@ -98,6 +98,8 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -98,6 +98,8 @@ class UserbuffersForwardLinear(FusedOperation):
input_quantizer: Optional[Quantizer] = None, input_quantizer: Optional[Quantizer] = None,
weight_quantizer: Optional[Quantizer] = None, weight_quantizer: Optional[Quantizer] = None,
output_quantizer: Optional[Quantizer] = None, output_quantizer: Optional[Quantizer] = None,
input_requires_grad: bool = True,
weight_requires_grad: bool = True,
ub_comm_name: str, ub_comm_name: str,
) -> tuple[torch.Tensor, dict]: ) -> tuple[torch.Tensor, dict]:
"""Functional API for forward pass """Functional API for forward pass
...@@ -131,6 +133,12 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -131,6 +133,12 @@ class UserbuffersForwardLinear(FusedOperation):
Builder class for quantized weight tensor. Builder class for quantized weight tensor.
output_quantizer: Quantizer, optional output_quantizer: Quantizer, optional
Builder class for quantized output tensor. Builder class for quantized output tensor.
input_requires_grad: bool, default = `True`
Whether the loss gradient w.r.t. the input tensor is
required in the backward pass.
weight_requires_grad: bool, default = `True`
Whether the loss gradient w.r.t. the weight tensor is
required in the backward pass.
ub_comm_name: str ub_comm_name: str
Layer type (e.g. "qkv", "proj", "fc1", "fc2"). This is Layer type (e.g. "qkv", "proj", "fc1", "fc2"). This is
used to access the corresponding Userbuffers communicators used to access the corresponding Userbuffers communicators
...@@ -141,8 +149,9 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -141,8 +149,9 @@ class UserbuffersForwardLinear(FusedOperation):
torch.Tensor torch.Tensor
Output tensor Output tensor
dict dict
Extra output tensors. "input" is the input tensor, Extra output tensors. "input" is the input tensor and
possibly cast and reshaped from the provided input tensor. "weight" is the weight tensor, both ready for use in the
backward pass.
""" """
...@@ -198,8 +207,10 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -198,8 +207,10 @@ class UserbuffersForwardLinear(FusedOperation):
if with_ub_all_gather: if with_ub_all_gather:
if input_quantizer is not None: if input_quantizer is not None:
if not isinstance(x_local, QuantizedTensorBase): if not isinstance(x_local, QuantizedTensorBase):
input_quantizer.set_usage(rowwise=True, columnwise=True) input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
if isinstance(input_quantizer, Float8Quantizer): if isinstance(
input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
):
input_quantizer.set_usage(columnwise=False) input_quantizer.set_usage(columnwise=False)
x_local = input_quantizer(x_local) x_local = input_quantizer(x_local)
input_quantizer.set_usage(rowwise=True, columnwise=False) input_quantizer.set_usage(rowwise=True, columnwise=False)
...@@ -212,7 +223,7 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -212,7 +223,7 @@ class UserbuffersForwardLinear(FusedOperation):
else: else:
if with_quantized_compute: if with_quantized_compute:
if not isinstance(x_local, QuantizedTensorBase): if not isinstance(x_local, QuantizedTensorBase):
input_quantizer.set_usage(rowwise=True, columnwise=True) input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
x_local = input_quantizer(x_local) x_local = input_quantizer(x_local)
else: else:
if isinstance(x_local, QuantizedTensorBase): if isinstance(x_local, QuantizedTensorBase):
...@@ -225,7 +236,7 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -225,7 +236,7 @@ class UserbuffersForwardLinear(FusedOperation):
w = weight w = weight
w_is_quantized = isinstance(w, QuantizedTensorBase) w_is_quantized = isinstance(w, QuantizedTensorBase)
if with_quantized_compute and not w_is_quantized: if with_quantized_compute and not w_is_quantized:
weight_quantizer.set_usage(rowwise=True) weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
w = weight_quantizer(w) w = weight_quantizer(w)
elif not with_quantized_compute and w_is_quantized: elif not with_quantized_compute and w_is_quantized:
w = w.dequantize() w = w.dequantize()
...@@ -258,17 +269,25 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -258,17 +269,25 @@ class UserbuffersForwardLinear(FusedOperation):
else: else:
y_local = gemm_output y_local = gemm_output
# Detach input tensor if needed # Prepare weight tensor for backward pass
# Note: PyTorch autograd produces esoteric errors if we save if input_requires_grad:
# input tensor as context for backward pass. if w is not weight and with_quantized_compute and isinstance(w, QuantizedTensorBase):
if x_local is input: w.update_usage(rowwise_usage=False, columnwise_usage=True)
x_local = x_local.detach() else:
w = None
# Configure input tensor for backward pass
if with_quantized_compute and isinstance(x_local, QuantizedTensorBase): # Prepare input tensor for backward pass
if not (isinstance(x_local, Float8TensorBase) and with_ub_all_gather): if weight_requires_grad:
# FP8 does not support all-gather of transpose data if x_local is input:
x_local.update_usage(rowwise_usage=False, columnwise_usage=True) # PyTorch autograd produces esoteric errors if we
# cache input tensor directly.
x_local = x_local.detach()
if with_quantized_compute and isinstance(x_local, QuantizedTensorBase):
if not (isinstance(x_local, Float8TensorBase) and with_ub_all_gather):
# FP8 does not support all-gather of transpose data
x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
else:
x_local = None
# Return cast tensors # Return cast tensors
extra_outputs = {"input": x_local, "weight": w} extra_outputs = {"input": x_local, "weight": w}
...@@ -298,6 +317,10 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -298,6 +317,10 @@ class UserbuffersForwardLinear(FusedOperation):
if basic_op_kwargs[idx]: if basic_op_kwargs[idx]:
raise ValueError("Bias operation forward does not expect keyword arguments") raise ValueError("Bias operation forward does not expect keyword arguments")
# Check which grads are required
input_requires_grad = linear_op_ctx.requires_grad and input_.requires_grad
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# Quantization metadata # Quantization metadata
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
input_quantizer = None input_quantizer = None
...@@ -306,8 +329,10 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -306,8 +329,10 @@ class UserbuffersForwardLinear(FusedOperation):
grad_input_quantizer = None grad_input_quantizer = None
if with_quantized_compute: if with_quantized_compute:
recipe = FP8GlobalStateManager.get_fp8_recipe() recipe = FP8GlobalStateManager.get_fp8_recipe()
if not recipe.delayed() and not recipe.mxfp8(): if not any((recipe.delayed(), recipe.float8_current_scaling(), recipe.mxfp8())):
raise RuntimeError("Userbuffers is only supported with FP8 delayed scaling recipe") raise RuntimeError(
f"Unsupported recipe for Userbuffers ({recipe.__class__.__name__})"
)
input_quantizer = linear_op.get_quantizer("forward", 0) input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1) weight_quantizer = linear_op.get_quantizer("forward", 1)
grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_output_quantizer = linear_op.get_quantizer("backward", 0)
...@@ -338,12 +363,15 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -338,12 +363,15 @@ class UserbuffersForwardLinear(FusedOperation):
input_quantizer=input_quantizer, input_quantizer=input_quantizer,
weight_quantizer=weight_quantizer, weight_quantizer=weight_quantizer,
output_quantizer=None, # Not supported output_quantizer=None, # Not supported
input_requires_grad=input_requires_grad,
weight_requires_grad=weight_requires_grad,
ub_comm_name=linear_op._userbuffers_options["comm_name"], ub_comm_name=linear_op._userbuffers_options["comm_name"],
) )
x_local = extra_outputs["input"] x_local = extra_outputs["input"]
w = extra_outputs["weight"]
# Save state for backward pass # Save state for backward pass
linear_op_ctx.save_for_backward(x_local) linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.input_quantizer = input_quantizer
linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.weight_quantizer = weight_quantizer
...@@ -351,8 +379,8 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -351,8 +379,8 @@ class UserbuffersForwardLinear(FusedOperation):
linear_op_ctx.grad_input_quantizer = grad_input_quantizer linear_op_ctx.grad_input_quantizer = grad_input_quantizer
linear_op_ctx.dtype = dtype linear_op_ctx.dtype = dtype
linear_op_ctx.input_dims = input_.size() linear_op_ctx.input_dims = input_.size()
linear_op_ctx.input_requires_grad = input_.requires_grad linear_op_ctx.input_requires_grad = input_requires_grad
linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad linear_op_ctx.weight_requires_grad = weight_requires_grad
linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None
return output, [() for _ in range(len(self.basic_ops))] return output, [() for _ in range(len(self.basic_ops))]
......
...@@ -61,13 +61,9 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -61,13 +61,9 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
def forward( def forward(
func_ctx: Optional[torch.autograd.function.FunctionCtx], func_ctx: Optional[torch.autograd.function.FunctionCtx],
input_: torch.Tensor, input_: torch.Tensor,
forward_ops: list[tuple[FusibleOperation, list[int]]], fuser: OperationFuser,
backward_ops: list[tuple[FusibleOperation, list[int]]],
basic_ops: list[BasicOperation],
basic_op_kwargs: list[dict[str, Any]], basic_op_kwargs: list[dict[str, Any]],
is_grad_enabled: bool, is_grad_enabled: bool,
num_params: int,
num_extra_inputs: int,
*params_and_extra_inputs: torch.nn.Parameter, *params_and_extra_inputs: torch.nn.Parameter,
) -> torch.Tensor | tuple[torch.Tensor, ...]: ) -> torch.Tensor | tuple[torch.Tensor, ...]:
"""Forward pass """Forward pass
...@@ -78,20 +74,12 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -78,20 +74,12 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
Context for PyTorch autograd function Context for PyTorch autograd function
input_: torch.Tensor input_: torch.Tensor
Input to first operation in pipeline Input to first operation in pipeline
forward_ops: list of tuple fuser: OperationFuser
Forward pass operations and the indices of the Container for the pipeline of operations to run
corresponding basic operations. The order should match
basic_ops.
backward_ops: list of tuple
Backward pass operations and the indices of the
corresponding basic operations. The order should be the
reverse of basic_ops.
basic_ops: list of BasicOperation
Basic operations
basic_op_kwargs: list of dict basic_op_kwargs: list of dict
Keyword arguments to BasicOperation Keyword arguments to BasicOperation
num_params: int is_grad_enabled: bool
Number of parameter tensors to include in autograd graph. Should context be saved for backward
*params_and_extra_inputs: torch.Tensor *params_and_extra_inputs: torch.Tensor
Other tensor inputs to include in autograd graph. Consists Other tensor inputs to include in autograd graph. Consists
of parameter tensors, followed by extra operation inputs. of parameter tensors, followed by extra operation inputs.
...@@ -106,26 +94,20 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -106,26 +94,20 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
""" """
# Operation autograd contexts # Operation autograd contexts
basic_op_ctxs = [OperationContext() for _ in range(len(basic_ops))] basic_op_ctxs = [OperationContext() for _ in range(fuser._num_basic_ops)]
# Unflatten list of parameters and extra tensor inputs # Unflatten list of parameters and extra tensor inputs
if len(params_and_extra_inputs) != num_params + num_extra_inputs: extra_inputs = params_and_extra_inputs[-fuser._num_extra_inputs :]
raise ValueError(
f"Expected {num_params + num_extra_inputs} extra tensor arguments "
f"({num_params} parameters, {num_extra_inputs} extra inputs), "
f"but got {len(params_and_extra_inputs)}"
)
_, extra_inputs = _split_tuple(params_and_extra_inputs, num_params)
basic_op_extra_inputs = [] basic_op_extra_inputs = []
for op in basic_ops: for op in fuser._basic_ops:
xs, extra_inputs = _split_tuple(extra_inputs, op.num_extra_inputs) xs, extra_inputs = _split_tuple(extra_inputs, op.num_extra_inputs)
basic_op_extra_inputs.append(xs) basic_op_extra_inputs.append(xs)
# Apply forward ops # Apply forward ops
x = input_ x = input_
requires_grad = is_grad_enabled and x.requires_grad requires_grad = is_grad_enabled and x.requires_grad
extra_outputs = [None for _ in range(len(basic_ops))] extra_outputs = [None] * fuser._num_basic_ops
for op, basic_op_idxs in forward_ops: for op, basic_op_idxs in fuser._forward_ops:
# Check if backward op is required # Check if backward op is required
if is_grad_enabled: if is_grad_enabled:
...@@ -143,9 +125,10 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -143,9 +125,10 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
# Forward op # Forward op
extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs] extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs]
prev_ops = [basic_ops[idx - 1] if idx > 0 else None for idx in basic_op_idxs] prev_ops = [fuser._basic_ops[idx - 1] if idx > 0 else None for idx in basic_op_idxs]
next_ops = [ next_ops = [
basic_ops[idx + 1] if (idx < len(basic_ops) - 1) else None for idx in basic_op_idxs fuser._basic_ops[idx + 1] if (idx < fuser._num_basic_ops - 1) else None
for idx in basic_op_idxs
] ]
x, fused_op_extra_outputs = op.fuser_forward( x, fused_op_extra_outputs = op.fuser_forward(
[basic_op_ctxs[idx] for idx in basic_op_idxs], [basic_op_ctxs[idx] for idx in basic_op_idxs],
...@@ -165,7 +148,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -165,7 +148,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
extra_outputs_flat = [] extra_outputs_flat = []
for idx, ys in enumerate(extra_outputs): for idx, ys in enumerate(extra_outputs):
ys = list(ys) ys = list(ys)
num_extra_outputs = basic_ops[idx].num_extra_outputs num_extra_outputs = fuser._basic_ops[idx].num_extra_outputs
if len(ys) != num_extra_outputs: if len(ys) != num_extra_outputs:
raise RuntimeError( raise RuntimeError(
f"Expected op {idx} to generate " f"Expected op {idx} to generate "
...@@ -189,11 +172,11 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -189,11 +172,11 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
func_ctx.save_for_backward(*to_save) func_ctx.save_for_backward(*to_save)
# Other context # Other context
func_ctx.backward_ops = backward_ops func_ctx.backward_ops = fuser._backward_ops
func_ctx.basic_ops = basic_ops func_ctx.basic_ops = fuser._basic_ops
func_ctx.basic_op_ctxs = basic_op_ctxs func_ctx.basic_op_ctxs = basic_op_ctxs
func_ctx.basic_op_num_params = [sum(1 for _ in op.parameters()) for op in basic_ops] func_ctx.basic_op_num_params = fuser._num_list_basic_op_params
func_ctx.num_extra_inputs = num_extra_inputs func_ctx.num_extra_inputs = fuser._num_extra_inputs
func_ctx.num_extra_outputs = len(extra_outputs_flat) func_ctx.num_extra_outputs = len(extra_outputs_flat)
func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
...@@ -216,8 +199,9 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -216,8 +199,9 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
basic_op_ctxs = func_ctx.basic_op_ctxs basic_op_ctxs = func_ctx.basic_op_ctxs
# Unflatten list of saved tensors # Unflatten list of saved tensors
saved_tensors = func_ctx.saved_tensors
for ctx in basic_op_ctxs: for ctx in basic_op_ctxs:
ctx.saved_tensors = func_ctx.saved_tensors[slice(*ctx._saved_tensors_range)] ctx.saved_tensors = saved_tensors[slice(*ctx._saved_tensors_range)]
ctx._saved_tensors_range = None ctx._saved_tensors_range = None
# Unflatten list of extra tensor output grads # Unflatten list of extra tensor output grads
...@@ -292,13 +276,9 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -292,13 +276,9 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
return ( return (
dx, # input_ dx, # input_
None, # forward_ops None, # fuser
None, # backward_ops
None, # basic_ops
None, # basic_op_kwargs None, # basic_op_kwargs
None, # is_grad_enabled None, # is_grad_enabled
None, # num_params
None, # num_extra_inputs
*grad_params_flat, *grad_params_flat,
*grad_extra_inputs_flat, *grad_extra_inputs_flat,
) )
...@@ -345,6 +325,10 @@ class OperationFuser: ...@@ -345,6 +325,10 @@ class OperationFuser:
if fuse_ops: if fuse_ops:
self.fuse_ops() self.fuse_ops()
# Flatten list of parameters
self._basic_op_params = [param for op in self._basic_ops for param in op.parameters()]
self._num_list_basic_op_params = [sum(1 for _ in op.parameters()) for op in self._basic_ops]
@classmethod @classmethod
def _fuse_forward_ops( def _fuse_forward_ops(
cls, cls,
...@@ -377,6 +361,11 @@ class OperationFuser: ...@@ -377,6 +361,11 @@ class OperationFuser:
*extra_inputs: torch.Tensor, *extra_inputs: torch.Tensor,
basic_op_kwargs: Optional[list[dict[str, Any]]] = None, basic_op_kwargs: Optional[list[dict[str, Any]]] = None,
) -> torch.Tensor | tuple[torch.Tensor, ...]: ) -> torch.Tensor | tuple[torch.Tensor, ...]:
# Verify extra input count
if len(extra_inputs) != self._num_extra_inputs:
raise ValueError(
f"Expected {self._num_extra_inputs} extra inputs but got {len(extra_inputs)}"
)
# Initialization before forward pass # Initialization before forward pass
for op in self._basic_ops: for op in self._basic_ops:
...@@ -384,10 +373,7 @@ class OperationFuser: ...@@ -384,10 +373,7 @@ class OperationFuser:
# Canonicalize op kwargs # Canonicalize op kwargs
if basic_op_kwargs is None: if basic_op_kwargs is None:
basic_op_kwargs = [{} for _ in range(len(self._basic_ops))] basic_op_kwargs = [{}] * self._num_basic_ops
# Flatten list of parameters
params = [param for op in self._basic_ops for param in op.parameters()]
# Fuser forward pass # Fuser forward pass
is_grad_enabled = torch.is_grad_enabled() is_grad_enabled = torch.is_grad_enabled()
...@@ -399,14 +385,10 @@ class OperationFuser: ...@@ -399,14 +385,10 @@ class OperationFuser:
args = [None] args = [None]
args += ( args += (
input, input,
self._forward_ops, self,
self._backward_ops,
self._basic_ops,
basic_op_kwargs, basic_op_kwargs,
is_grad_enabled, is_grad_enabled,
len(params), *self._basic_op_params,
self._num_extra_inputs,
*params,
*extra_inputs, *extra_inputs,
) )
return forward_func(*args) return forward_func(*args)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
[build-system]
requires = ["setuptools>=61.0", "pip", "torch>=2.1"]
# Use legacy backend to import local packages in setup.py
build-backend = "setuptools.build_meta:__legacy__"
...@@ -31,7 +31,7 @@ if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_ ...@@ -31,7 +31,7 @@ if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_
from build_tools.build_ext import get_build_ext from build_tools.build_ext import get_build_ext
from build_tools.utils import copy_common_headers from build_tools.utils import copy_common_headers
from build_tools.te_version import te_version from build_tools.te_version import te_version
from build_tools.pytorch import setup_pytorch_extension from build_tools.pytorch import setup_pytorch_extension, install_requirements, test_requirements
os.environ["NVTE_PROJECT_BUILDING"] = "1" os.environ["NVTE_PROJECT_BUILDING"] = "1"
...@@ -55,18 +55,8 @@ if __name__ == "__main__": ...@@ -55,18 +55,8 @@ if __name__ == "__main__":
description="Transformer acceleration library - Torch Lib", description="Transformer acceleration library - Torch Lib",
ext_modules=ext_modules, ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension}, cmdclass={"build_ext": CMakeBuildExtension},
setup_requires=[ install_requires=install_requirements(),
"torch>=2.1", tests_require=test_requirements(),
"nvidia-cuda-runtime-cu12",
"nvidia-cublas-cu12",
"nvidia-cudnn-cu12",
"nvidia-cuda-cccl-cu12",
"nvidia-cuda-nvcc-cu12",
"nvidia-nvtx-cu12",
"nvidia-cuda-nvrtc-cu12",
],
install_requires=["torch>=2.1"],
tests_require=["numpy", "torchvision"],
) )
if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
shutil.rmtree(common_headers_dir) shutil.rmtree(common_headers_dir)
......
...@@ -11,6 +11,7 @@ import torch ...@@ -11,6 +11,7 @@ import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType from transformer_engine_torch import DType as TE_DType
from transformer_engine_torch import Float8BlockScaleTensorFormat
from ..quantized_tensor import QuantizedTensorBase from ..quantized_tensor import QuantizedTensorBase
...@@ -37,6 +38,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase): ...@@ -37,6 +38,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
_rowwise_scale_inv: Optional[torch.Tensor] _rowwise_scale_inv: Optional[torch.Tensor]
_columnwise_scale_inv: Optional[torch.Tensor] _columnwise_scale_inv: Optional[torch.Tensor]
_is_2D_scaled: bool _is_2D_scaled: bool
_data_format: Float8BlockScaleTensorFormat
def __new__( def __new__(
cls, cls,
...@@ -48,6 +50,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase): ...@@ -48,6 +50,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
fp8_dtype: TE_DType, fp8_dtype: TE_DType,
quantizer: Quantizer, quantizer: Quantizer,
is_2D_scaled: bool, is_2D_scaled: bool,
data_format: Float8BlockScaleTensorFormat = Float8BlockScaleTensorFormat.GEMM_READY,
**kwargs, **kwargs,
): ):
instance = super().__new__(cls, *args, **kwargs) instance = super().__new__(cls, *args, **kwargs)
...@@ -58,6 +61,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase): ...@@ -58,6 +61,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
instance._rowwise_scale_inv = rowwise_scale_inv instance._rowwise_scale_inv = rowwise_scale_inv
instance._columnwise_scale_inv = columnwise_scale_inv instance._columnwise_scale_inv = columnwise_scale_inv
instance._is_2D_scaled = is_2D_scaled instance._is_2D_scaled = is_2D_scaled
instance._data_format = data_format
return instance return instance
...@@ -82,8 +86,13 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase): ...@@ -82,8 +86,13 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
"fp8_dtype": self._fp8_dtype, "fp8_dtype": self._fp8_dtype,
"quantizer": self._quantizer, "quantizer": self._quantizer,
"is_2D_scaled": self._is_2D_scaled, "is_2D_scaled": self._is_2D_scaled,
"data_format": self._data_format,
} }
def _is_gemm_ready_format(self) -> bool:
"""Whether data is in GEMM_READY format"""
return self._data_format == Float8BlockScaleTensorFormat.GEMM_READY
def prepare_for_saving( def prepare_for_saving(
self, self,
) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorBase]: ) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorBase]:
...@@ -136,34 +145,69 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase): ...@@ -136,34 +145,69 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
q_K = q.shape[-1] q_K = q.shape[-1]
for i in range(len(q.shape) - 1): for i in range(len(q.shape) - 1):
q_M *= q.shape[i] q_M *= q.shape[i]
inner_q_dimension_tiled = True
if self._is_gemm_ready_format():
scales_tiled_dim, scales_untiled_dim = scale_inv.shape
inner_scale_dimension_tiled = False
scales_are_compact = False
else:
scales_untiled_dim, scales_tiled_dim = scale_inv.shape
inner_scale_dimension_tiled = True
scales_are_compact = True
else: else:
assert self._columnwise_data is not None, "No data to dequantize" assert self._columnwise_data is not None, "No data to dequantize"
q = self._columnwise_data q = self._columnwise_data
scale_inv = self._columnwise_scale_inv scale_inv = self._columnwise_scale_inv
transpose_output = True scales_tiled_dim, scales_untiled_dim = scale_inv.shape
if len(q.shape) >= 1: inner_scale_dimension_tiled = False
q_M = q.shape[0] if self._is_gemm_ready_format():
for i in range(1, len(q.shape)): inner_q_dimension_tiled = True
q_K *= q.shape[i] transpose_output = True
if len(q.shape) >= 1:
q_M = q.shape[0]
for i in range(1, len(q.shape)):
q_K *= q.shape[i]
scales_are_compact = False
else:
inner_q_dimension_tiled = False
transpose_output = False
if len(q.shape) >= 1:
q_K = q.shape[-1]
for i in range(len(q.shape) - 1):
q_M *= q.shape[i]
scales_are_compact = True
orig_shape = q.shape orig_shape = q.shape
q = q.reshape(q_M, q_K) q = q.reshape(q_M, q_K)
k_tiles, scale_m = scale_inv.shape if inner_q_dimension_tiled:
if q_K % block_len != 0: if q_K % block_len != 0:
k_pad_amount = (block_len - (q_K % block_len)) % block_len k_pad_amount = (block_len - (q_K % block_len)) % block_len
q = torch.nn.functional.pad( q = torch.nn.functional.pad(
q, (0, k_pad_amount, 0, 0), mode="constant", value=0 q, (0, k_pad_amount, 0, 0), mode="constant", value=0
).contiguous() ).contiguous()
_, padded_K = q.shape padded_M, padded_K = q.shape
q_tiled = q.reshape(q_M, k_tiles, block_len) q_tiled = q.reshape(q_M, scales_tiled_dim, block_len)
if scale_m > q_M: else:
# scale_m is 4 element aligned. if q_M % block_len != 0:
m_pad_amount = (block_len - (q_M % block_len)) % block_len
q = torch.nn.functional.pad(
q, (0, 0, 0, m_pad_amount), mode="constant", value=0
).contiguous()
padded_M, padded_K = q.shape
q_tiled = q.reshape(scales_tiled_dim, block_len, q_K)
if not scales_are_compact and scales_untiled_dim > q_M:
# untiled scale dimension is 4 element aligned.
scale_inv = scale_inv[:, :q_M].contiguous() scale_inv = scale_inv[:, :q_M].contiguous()
dq_scale = scale_inv.transpose(-2, -1).contiguous().reshape(q_M, k_tiles, 1) if scales_are_compact and inner_scale_dimension_tiled:
dq_scale = scale_inv.contiguous().reshape(q_M, scales_tiled_dim, 1)
elif scales_are_compact and not inner_scale_dimension_tiled:
dq_scale = scale_inv.contiguous().reshape(scales_tiled_dim, 1, q_K)
else:
dq_scale = scale_inv.transpose(-2, -1).contiguous().reshape(q_M, scales_tiled_dim, 1)
torch_q_dtype = TE_DType_To_Torch[self._fp8_dtype] torch_q_dtype = TE_DType_To_Torch[self._fp8_dtype]
result = q_tiled.view(torch_q_dtype).to(torch.float32) * dq_scale result = q_tiled.view(torch_q_dtype).to(torch.float32) * dq_scale
if padded_K != q_K: if padded_M != q_M or padded_K != q_K:
result = result.reshape(q_M, padded_K)[:, :q_K] result = result.reshape(padded_M, padded_K)[:q_M, :q_K]
result = result.to(dtype) result = result.to(dtype)
if len(orig_shape) == 0: if len(orig_shape) == 0:
result = result.reshape([]) result = result.reshape([])
...@@ -182,6 +226,12 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase): ...@@ -182,6 +226,12 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
if not self._is_2D_scaled: if not self._is_2D_scaled:
return self._dequantize_vectorwise(dtype=dtype) return self._dequantize_vectorwise(dtype=dtype)
if not self._is_gemm_ready_format():
raise NotImplementedError(
"Dequantize is only supported with GEMM_READY data format, "
f"but found _data_format={self._data_format}"
)
def format_scale_as_logical_shape(q_K, scales, block_len): def format_scale_as_logical_shape(q_K, scales, block_len):
# The GEMM for 2D blocks required padding in the scales. # The GEMM for 2D blocks required padding in the scales.
derived_scale_k_shape = math.ceil(q_K / block_len) derived_scale_k_shape = math.ceil(q_K / block_len)
...@@ -247,6 +297,8 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase): ...@@ -247,6 +297,8 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
if self._rowwise_data is not None: if self._rowwise_data is not None:
return self._rowwise_data.size(*args, **kwargs) return self._rowwise_data.size(*args, **kwargs)
dims = list(self._columnwise_data.size(*args, **kwargs)) dims = list(self._columnwise_data.size(*args, **kwargs))
if not self._is_gemm_ready_format(): # compact format
return torch.Size(dims)
reordered = [] reordered = []
for i in range(1, len(dims)): for i in range(1, len(dims)):
reordered.append(dims[i]) reordered.append(dims[i])
...@@ -285,6 +337,13 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase): ...@@ -285,6 +337,13 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
w = min(self._columnwise_scale_inv.shape[1], columnwise_scale_inv.shape[1]) w = min(self._columnwise_scale_inv.shape[1], columnwise_scale_inv.shape[1])
self._columnwise_scale_inv[0:h, 0:w].copy_(columnwise_scale_inv[0:h, 0:w]) self._columnwise_scale_inv[0:h, 0:w].copy_(columnwise_scale_inv[0:h, 0:w])
def _transpose_columnwise_data(self):
"""Plainly transpose the columnwise data and scale inv."""
if self._columnwise_data is not None:
self._columnwise_data = tex.fp8_transpose(
self._columnwise_data, self._fp8_dtype, out=None
)
def __repr__(self): def __repr__(self):
if self._rowwise_data is not None: if self._rowwise_data is not None:
data = self.dequantize() data = self.dequantize()
......
...@@ -4,13 +4,15 @@ ...@@ -4,13 +4,15 @@
"""Tensor class with FP8 data quantized with NxN tiles""" """Tensor class with FP8 data quantized with NxN tiles"""
from __future__ import annotations from __future__ import annotations
from typing import Optional, Tuple, Iterable from typing import Optional, Tuple, Iterable, Union
import math import math
import torch import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
import os import os
from transformer_engine_torch import DType as TE_DType from transformer_engine_torch import DType as TE_DType
from transformer_engine.common.recipe import Float8BlockScaling, Recipe
from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
from ..utils import devices_match, round_up_to_nearest_multiple from ..utils import devices_match, round_up_to_nearest_multiple
...@@ -33,6 +35,8 @@ class Float8BlockQuantizer(Quantizer): ...@@ -33,6 +35,8 @@ class Float8BlockQuantizer(Quantizer):
amax_epsilon: float amax_epsilon: float
force_pow_2_scales: bool force_pow_2_scales: bool
block_scaling_dim: int block_scaling_dim: int
# Whether to produce tensors that will be used in all-gather
all_gather_usage: bool
def __init__( def __init__(
self, self,
...@@ -43,6 +47,7 @@ class Float8BlockQuantizer(Quantizer): ...@@ -43,6 +47,7 @@ class Float8BlockQuantizer(Quantizer):
amax_epsilon: float = 0.0, amax_epsilon: float = 0.0,
force_pow_2_scales: bool = True, force_pow_2_scales: bool = True,
block_scaling_dim: int = 2, block_scaling_dim: int = 2,
all_gather_usage: bool = False,
) -> None: ) -> None:
super().__init__(rowwise=rowwise, columnwise=columnwise) super().__init__(rowwise=rowwise, columnwise=columnwise)
self.dtype = tex.DType.kInt8 if int8_simulation_fp8 else fp8_dtype self.dtype = tex.DType.kInt8 if int8_simulation_fp8 else fp8_dtype
...@@ -50,6 +55,7 @@ class Float8BlockQuantizer(Quantizer): ...@@ -50,6 +55,7 @@ class Float8BlockQuantizer(Quantizer):
self.force_pow_2_scales = force_pow_2_scales self.force_pow_2_scales = force_pow_2_scales
self.amax_epsilon = amax_epsilon self.amax_epsilon = amax_epsilon
self.block_scaling_dim = block_scaling_dim self.block_scaling_dim = block_scaling_dim
self.all_gather_usage = all_gather_usage
def update_quantized( def update_quantized(
self, self,
...@@ -126,22 +132,36 @@ class Float8BlockQuantizer(Quantizer): ...@@ -126,22 +132,36 @@ class Float8BlockQuantizer(Quantizer):
M *= shape[i] M *= shape[i]
if len(shape) > 0: if len(shape) > 0:
K = shape[-1] K = shape[-1]
# 2D 128x128 quantization block scaling
# CuBLAS requries 128x128 scaling factor to be padded
# currently rowwise and columnwise format option doesn't apply to 2D scaling
if self.block_scaling_dim == 2: if self.block_scaling_dim == 2:
if columnwise: if columnwise:
outer = math.ceil(K / self.block_len) outer = math.ceil(K / self.block_len)
inner = round_up_to_nearest_multiple(math.ceil(M / self.block_len), 4) inner = round_up_to_nearest_multiple(math.ceil(M / self.block_len), 4)
return (outer, inner) return (outer, inner)
# rowwise
outer = math.ceil(M / self.block_len) outer = math.ceil(M / self.block_len)
inner = round_up_to_nearest_multiple(math.ceil(K / self.block_len), 4) inner = round_up_to_nearest_multiple(math.ceil(K / self.block_len), 4)
return (outer, inner) return (outer, inner)
# 1D 1x128 quantization block scaling
# CuBLAS requries 1x128 scaling factor to be padded and transposed
assert self.block_scaling_dim == 1, "Only 1D or 2D blocks supported" assert self.block_scaling_dim == 1, "Only 1D or 2D blocks supported"
if columnwise: if columnwise:
columnwise_compact = self.all_gather_usage
outer = math.ceil(M / self.block_len) outer = math.ceil(M / self.block_len)
inner = round_up_to_nearest_multiple(K, 4) inner = round_up_to_nearest_multiple(K, 4) if not columnwise_compact else K
# GEMM READY case: scaling factor is [outer, inner], already transposed here for CuBLAS
# for COMPACT case, since we apply 1x128 scaling here without transposing columnwise data, scaling factor is also [outer, inner]
# so no need to swap inner outer here
return (outer, inner) return (outer, inner)
# rowwise
rowwise_compact = self.all_gather_usage
outer = math.ceil(K / self.block_len) outer = math.ceil(K / self.block_len)
inner = round_up_to_nearest_multiple(M, 4) inner = round_up_to_nearest_multiple(M, 4) if not rowwise_compact else M
return (outer, inner) # GEMM READY case: scaling factor is [outer, inner], already transposed here for CuBLAS need
# for COMPACT case, since we apply 128x1 scaling, scaling block applies to inner dim, so we need to swap outer and inner here
return (outer, inner) if not rowwise_compact else (inner, outer)
def get_columnwise_shape(self, shape: Iterable[int]) -> Tuple[int, ...]: def get_columnwise_shape(self, shape: Iterable[int]) -> Tuple[int, ...]:
"""Calculate the shape of a tensor after columnwise permutation. """Calculate the shape of a tensor after columnwise permutation.
...@@ -163,15 +183,25 @@ class Float8BlockQuantizer(Quantizer): ...@@ -163,15 +183,25 @@ class Float8BlockQuantizer(Quantizer):
""" """
if len(shape) == 0: if len(shape) == 0:
return tuple() return tuple()
# currently columnwise format option only applies to 1D quantizer
# for 2D scaling, columnwise format should always be GEMM_READY_DATA_AND_SCALES
# since currently 2D scaling only applies to module weights
if self.block_scaling_dim == 1 and self.all_gather_usage:
return shape
colwise_shape = [shape[-1]] colwise_shape = [shape[-1]]
for i in range(len(shape) - 1): for i in range(len(shape) - 1):
colwise_shape.append(shape[i]) colwise_shape.append(shape[i])
return tuple(colwise_shape) return tuple(colwise_shape)
# TODO(kwyss): With FP8 gather support, we need to implement a def is_quantizable(self, inp: torch.Tensor) -> bool:
# shape/layout/swizzle check to know whether FP8 gather works """Returns whether or not given inp can be quantized"""
# cleanly by stacking data without aliasing tiles and whether if inp.ndim < 2:
# the scales also stack on the proper dimensions. return False
if inp.shape[-1] % self.block_len != 0:
return False
if math.prod(inp.shape[:-1]) % self.block_len != 0:
return False
return True
def make_empty( def make_empty(
self, self,
...@@ -185,6 +215,12 @@ class Float8BlockQuantizer(Quantizer): ...@@ -185,6 +215,12 @@ class Float8BlockQuantizer(Quantizer):
if device is None: if device is None:
device = torch.device("cuda") device = torch.device("cuda")
data_format = (
tex.Float8BlockScaleTensorFormat.COMPACT
if self.all_gather_usage
else tex.Float8BlockScaleTensorFormat.GEMM_READY
)
# Allocate FP8 data # Allocate FP8 data
data = None data = None
scale_inv = None scale_inv = None
...@@ -222,6 +258,7 @@ class Float8BlockQuantizer(Quantizer): ...@@ -222,6 +258,7 @@ class Float8BlockQuantizer(Quantizer):
columnwise_scale_inv=columnwise_scale_inv, columnwise_scale_inv=columnwise_scale_inv,
quantizer=self, quantizer=self,
is_2D_scaled=self.block_scaling_dim == 2, is_2D_scaled=self.block_scaling_dim == 2,
data_format=data_format,
requires_grad=requires_grad, requires_grad=requires_grad,
) )
...@@ -230,6 +267,9 @@ class Float8BlockQuantizer(Quantizer): ...@@ -230,6 +267,9 @@ class Float8BlockQuantizer(Quantizer):
# where state from an estimator influences distribution parameters. # where state from an estimator influences distribution parameters.
pass pass
def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
return Float8BlockScaling
class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor): class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
"""Tensor class with FP8 data quantized via NxN blocks or 1xN blocks. """Tensor class with FP8 data quantized via NxN blocks or 1xN blocks.
...@@ -260,7 +300,8 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor): ...@@ -260,7 +300,8 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
return ( return (
f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype}," f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype},"
f" is_2D_scaled={self._is_2D_scaled}," f" is_2D_scaled={self._is_2D_scaled},"
f" data={self.dequantize(dtype=self.dtype)})" f" data={self.dequantize(dtype=self.dtype)}),"
f" data_format={self._data_format}"
) )
def _get_quantizer(self) -> Quantizer: def _get_quantizer(self) -> Quantizer:
...@@ -393,6 +434,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor): ...@@ -393,6 +434,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
dtype: torch.dtype, dtype: torch.dtype,
quantizer: Quantizer, quantizer: Quantizer,
is_2D_scaled: bool, is_2D_scaled: bool,
data_format: tex.Float8BlockScaleTensorFormat,
) -> Float8BlockwiseQTensor: ) -> Float8BlockwiseQTensor:
"""Build Float8BlockwiseQTensor, for use in __reduce__ """Build Float8BlockwiseQTensor, for use in __reduce__
...@@ -410,6 +452,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor): ...@@ -410,6 +452,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
dtype=dtype, dtype=dtype,
quantizer=quantizer, quantizer=quantizer,
is_2D_scaled=is_2D_scaled, is_2D_scaled=is_2D_scaled,
data_format=data_format,
) )
def __reduce_ex__(self, protocol: int) -> tuple: def __reduce_ex__(self, protocol: int) -> tuple:
...@@ -426,6 +469,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor): ...@@ -426,6 +469,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
self.dtype, self.dtype,
self._quantizer, self._quantizer,
self._is_2D_scaled, self._is_2D_scaled,
self._data_format,
), ),
) )
...@@ -451,6 +495,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor): ...@@ -451,6 +495,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
dst._fp8_dtype = src._fp8_dtype dst._fp8_dtype = src._fp8_dtype
dst._rowwise_scale_inv = src._rowwise_scale_inv dst._rowwise_scale_inv = src._rowwise_scale_inv
dst._columnwise_scale_inv = src._columnwise_scale_inv dst._columnwise_scale_inv = src._columnwise_scale_inv
dst._data_format = src._data_format
# Check that tensor dimensions match # Check that tensor dimensions match
if ( if (
...@@ -498,6 +543,13 @@ class _ViewFunc(torch.autograd.Function): ...@@ -498,6 +543,13 @@ class _ViewFunc(torch.autograd.Function):
) -> Float8BlockwiseQTensor: ) -> Float8BlockwiseQTensor:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
# Check for invalid configurations
if not tensor._is_gemm_ready_format():
raise NotImplementedError(
"View is only supported with GEMM_READY data format, "
f"but found data_format={tensor._data_format}"
)
# Return input tensor if shape is not provided # Return input tensor if shape is not provided
ctx.shape = tensor.shape ctx.shape = tensor.shape
if shape is None: if shape is None:
...@@ -566,6 +618,14 @@ class _ViewFunc(torch.autograd.Function): ...@@ -566,6 +618,14 @@ class _ViewFunc(torch.autograd.Function):
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
if isinstance(grad, Float8BlockwiseQTensor): if isinstance(grad, Float8BlockwiseQTensor):
# Check for invalid configurations
if not grad._is_gemm_ready_format():
raise NotImplementedError(
"View is only supported with GEMM_READY data format, "
f"but found data_format={grad._data_format}"
)
new_data = ( new_data = (
grad._rowwise_data.view(*ctx.shape) if grad._rowwise_data is not None else None grad._rowwise_data.view(*ctx.shape) if grad._rowwise_data is not None else None
) )
...@@ -605,6 +665,13 @@ class _ReshapeFunc(torch.autograd.Function): ...@@ -605,6 +665,13 @@ class _ReshapeFunc(torch.autograd.Function):
) -> Float8BlockwiseQTensor: ) -> Float8BlockwiseQTensor:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
# Check for invalid configurations
if not tensor._is_gemm_ready_format():
raise NotImplementedError(
"Reshape is only supported with GEMM_READY data format, "
f"but found data_format={tensor._data_format}"
)
# Return input tensor if shape is not provided # Return input tensor if shape is not provided
ctx.shape = tensor.shape ctx.shape = tensor.shape
if shape is None: if shape is None:
...@@ -672,6 +739,14 @@ class _ReshapeFunc(torch.autograd.Function): ...@@ -672,6 +739,14 @@ class _ReshapeFunc(torch.autograd.Function):
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
if isinstance(grad, Float8BlockwiseQTensor): if isinstance(grad, Float8BlockwiseQTensor):
# Check for invalid configurations
if not grad._is_gemm_ready_format():
raise NotImplementedError(
"Reshape is only supported with GEMM_READY data format, "
f"but found data_format={grad._data_format}"
)
new_rowwise_data = None new_rowwise_data = None
new_columnwise_data = None new_columnwise_data = None
if grad._rowwise_data is not None: if grad._rowwise_data is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment