Commit 87e3e56e authored by yuguo's avatar yuguo
Browse files

Merge commit '734bcedd' of...

Merge commit '734bcedd' of https://github.com/NVIDIA/TransformerEngine
parents 2f11bd2e 734bcedd
......@@ -10,9 +10,9 @@ import warnings
import torch
from transformer_engine_torch import CommOverlapType
from transformer_engine_torch import CommOverlapType, bulk_overlap_ag_with_external_gemm
from ...cpp_extensions import general_gemm
from ...distributed import gather_along_first_dim, get_distributed_world_size
from ...distributed import get_distributed_world_size
from ...module.base import (
fill_userbuffers_buffer_for_all_gather,
get_ub,
......@@ -48,14 +48,14 @@ class UserbuffersBackwardLinear(FusedOperation):
# Basic operations that comprise this fused operation
op_idxs = {"linear": None, "bias": None, "reduce_scatter": None}
ops = []
if reduce_scatter is not None:
op_idxs["reduce_scatter"] = len(ops)
ops.append(reduce_scatter)
op_idxs["linear"] = len(ops)
ops.append(linear)
if bias is not None:
op_idxs["bias"] = len(ops)
ops.append(bias)
op_idxs["linear"] = len(ops)
ops.append(linear)
if reduce_scatter is not None:
op_idxs["reduce_scatter"] = len(ops)
ops.append(reduce_scatter)
# Initialize base class
super().__init__(ops)
......@@ -398,26 +398,35 @@ class UserbuffersBackwardLinear(FusedOperation):
# Initialize grad output
if tensor_parallel_mode == "row" and isinstance(grad_output_quantizer, MXFP8Quantizer):
# UB does not support overlapping grad output
# UB does not support pipelined overlapping grad output
# all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered
# for the dgrad GEMM. We work around by explicitly
# overlapping the NCCL operation with the dgrad GEMM.
# overlapping the AG operation with the dgrad GEMM.
# Get the communication stream from the dgrad GEMM to use for the AG
dgrad_send_stream, dgrad_recv_stream = ub_comm_dgrad.get_communication_stream()
ub_obj_overlap_wgrad = get_ub(ub_comm_name + "_wgrad")
grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
# Get the communication stream from the dgrad GEMM and set it as the current torch stream
dgrad_comm_stream = ub_comm_dgrad.get_communication_stream()
with torch.cuda.stream(dgrad_comm_stream):
# Syncs with the current stream (dgrad_comm_stream) before starting the all-gather
# This ensures that we don't start until all communication for the dgrad GEMM is complete
dy, dy_work = gather_along_first_dim(
# We use the send stream to copy into the userbuffers.
# This is the same stream that we will use to access the data in the AG,
# so we don't need to add any syncs yet.
with torch.cuda.stream(dgrad_send_stream):
dy, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj_overlap_wgrad,
dy_local,
grad_output_quantizer,
tensor_parallel_group,
async_op=True,
quantizer=grad_output_quantizer,
)
# Synchronize with the main stream
dy_work.wait()
# Allgather grad_outputs[0] using the dgrad streams so we can overlap with the fc2_dgrad gemm
bulk_overlap_ag_with_external_gemm(
ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream
)
if tensor_parallel_mode == "column":
dy = dy_local
......@@ -495,7 +504,7 @@ class UserbuffersBackwardLinear(FusedOperation):
# Get basic operations
idx = self._op_idxs["linear"]
linear_op = self.basic_ops[idx]
linear_op_ctx = basic_op_ctxs[idx]
linear_op_ctx = basic_op_ctxs[-1]
bias_op = None
if self._op_idxs["bias"] is not None:
idx = self._op_idxs["bias"]
......@@ -556,6 +565,7 @@ class UserbuffersBackwardLinear(FusedOperation):
grad_params[self._op_idxs["linear"]] = (grad_weight,)
if bias_op is not None:
grad_params[self._op_idxs["bias"]] = (grad_bias,)
grad_params.reverse()
grad_extra_inputs = [() for _ in range(len(self.basic_ops))]
return grad_input, grad_params, grad_extra_inputs
......
......@@ -182,7 +182,7 @@ class UserbuffersForwardLinear(FusedOperation):
if weight_quantizer is None:
raise ValueError("Missing quantizer for weight tensor")
if output_quantizer is not None:
raise ValueError("FP8 output is not supported")
raise ValueError("Quantized output is not supported")
else:
input_quantizer = None
weight_quantizer = None
......@@ -282,7 +282,7 @@ class UserbuffersForwardLinear(FusedOperation):
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
......@@ -307,21 +307,17 @@ class UserbuffersForwardLinear(FusedOperation):
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# Quantization metadata
input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1)
grad_output_quantizer = linear_op.get_quantizer("backward", 0)
grad_input_quantizer = prev_op_grad_output_quantizer
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
input_quantizer = None
weight_quantizer = None
grad_output_quantizer = None
grad_input_quantizer = None
if with_quantized_compute:
recipe = FP8GlobalStateManager.get_fp8_recipe()
if not any((recipe.delayed(), recipe.float8_current_scaling(), recipe.mxfp8())):
raise RuntimeError(
f"Unsupported recipe for Userbuffers ({recipe.__class__.__name__})"
)
input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1)
grad_output_quantizer = linear_op.get_quantizer("backward", 0)
grad_input_quantizer = prev_op_grad_input_quantizer
# Get autocast dtype if needed
if torch.is_autocast_enabled():
......@@ -356,6 +352,7 @@ class UserbuffersForwardLinear(FusedOperation):
w = extra_outputs["weight"]
# Save state for backward pass
if linear_op_ctx.requires_grad:
linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer
......@@ -366,9 +363,8 @@ class UserbuffersForwardLinear(FusedOperation):
linear_op_ctx.input_dims = input_.size()
linear_op_ctx.input_requires_grad = input_requires_grad
linear_op_ctx.weight_requires_grad = weight_requires_grad
if bias_op is not None:
bias_op_ctx.with_quantized_compute = with_quantized_compute
bias_op_ctx.grad_input_quantizer = linear_op.get_grad_input_quantizer()
if bias_op is not None and bias_op_ctx.requires_grad:
bias_op_ctx.grad_input_quantizer = linear_op.get_grad_output_quantizer()
return output, [() for _ in range(len(self.basic_ops))]
......
......@@ -5,22 +5,25 @@
"""Manager class for a pipeline of fusible operations."""
from __future__ import annotations
from collections.abc import Callable
from collections.abc import Callable, Iterable
from typing import Any, Optional
import itertools
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe, DelayedScaling
from transformer_engine.pytorch.ops.op import (
BasicOperation,
FusibleOperation,
OperationContext,
)
from transformer_engine.pytorch.ops.fused import (
fuse_backward_bias_activation,
fuse_backward_activation_bias,
fuse_backward_linear_add,
fuse_backward_linear_scale,
fuse_forward_linear_bias_activation,
fuse_forward_linear_bias_add,
fuse_forward_linear_scale_add,
fuse_userbuffers_backward_linear,
fuse_userbuffers_forward_linear,
)
......@@ -68,8 +71,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
input_: torch.Tensor,
fuser: OperationFuser,
basic_op_kwargs: list[dict[str, Any]],
is_grad_enabled: bool,
*params_and_extra_inputs: torch.nn.Parameter,
*params_and_extra_inputs: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
"""Forward pass
......@@ -83,8 +85,6 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
Container for the pipeline of operations to run
basic_op_kwargs: list of dict
Keyword arguments to BasicOperation
is_grad_enabled: bool
Should context be saved for backward
*params_and_extra_inputs: torch.Tensor
Other tensor inputs to include in autograd graph. Consists
of parameter tensors, followed by extra operation inputs.
......@@ -103,10 +103,10 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
# Mark input tensors as not deletable in backward
for tensor in (input_,) + params_and_extra_inputs:
tensor.do_not_clear = True
tensor._do_not_clear = True
# Unflatten list of parameters and extra tensor inputs
extra_inputs = params_and_extra_inputs[-fuser._num_extra_inputs :]
extra_inputs = params_and_extra_inputs[-fuser.num_extra_inputs :]
basic_op_extra_inputs = []
for op in fuser._basic_ops:
xs, extra_inputs = _split_tuple(extra_inputs, op.num_extra_inputs)
......@@ -114,44 +114,37 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
# Apply forward ops
x = input_
requires_grad = is_grad_enabled and x.requires_grad
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
extra_outputs = [None] * fuser._num_basic_ops
for op, basic_op_idxs in fuser._forward_ops:
# Check if backward op is required
if is_grad_enabled:
if not requires_grad:
requires_grad = any(param.requires_grad for param in op.parameters())
if not requires_grad:
requires_grad = any(any(x.requires_grad for x in xs) for xs in extra_inputs)
# Set if backward op is required
for idx in basic_op_idxs:
basic_op_ctxs[idx].requires_grad = requires_grad
basic_op_ctxs[idx].requires_grad = idx >= fuser.first_op_requiring_backward
# Forward op
extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs]
prev_op_idx = basic_op_idxs[0] - 1
prev_op = fuser._basic_ops[prev_op_idx] if prev_op_idx >= 0 else None
prev_op_grad_input_quantizer = None
if prev_op is not None and with_quantized_compute:
prev_op_grad_input_quantizer = prev_op.get_grad_input_quantizer()
prev_op_grad_output_quantizer = None
if prev_op is not None:
prev_op_grad_output_quantizer = prev_op.get_grad_output_quantizer()
next_op_idx = basic_op_idxs[-1] + 1
next_op = fuser._basic_ops[next_op_idx] if next_op_idx < fuser._num_basic_ops else None
next_op_input_quantizer = None
if next_op is not None and with_quantized_compute:
if next_op is not None:
next_op_input_quantizer = next_op.get_input_quantizer()
x, fused_op_extra_outputs = op.fuser_forward(
[basic_op_ctxs[idx] for idx in basic_op_idxs],
x,
basic_op_extra_inputs=extra_inputs,
prev_op_grad_input_quantizer=prev_op_grad_input_quantizer,
prev_op_grad_output_quantizer=prev_op_grad_output_quantizer,
next_op_input_quantizer=next_op_input_quantizer,
basic_op_kwargs=[basic_op_kwargs[idx] for idx in basic_op_idxs],
)
for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs):
for y in ys:
y.requires_grad_(requires_grad)
y.requires_grad_(idx >= fuser.first_op_requiring_backward)
extra_outputs[idx] = ys
# Flatten list of extra outputs
......@@ -168,7 +161,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
extra_outputs_flat.extend(ys)
# Save context for backward pass
if is_grad_enabled:
if func_ctx is not None:
# Flatten list of saved tensors
to_save = []
......@@ -181,24 +174,29 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
ctx._saved_tensors_range = (range_start, range_end)
# Save tensors for backward
if with_quantized_compute:
tensors_to_save, tensor_objects = prepare_for_saving(*to_save)
func_ctx.save_for_backward(*tensors_to_save)
func_ctx.tensor_objects = tensor_objects
else:
func_ctx.save_for_backward(*to_save)
# Whether to perform recipe update in backward pass
is_first_module = False
if fuser.first_op_requiring_backward < fuser._num_basic_ops:
is_first_module = FP8GlobalStateManager.is_first_fp8_module()
# Other context
func_ctx.backward_ops = fuser._backward_ops
func_ctx.basic_ops = fuser._basic_ops
func_ctx.basic_op_ctxs = basic_op_ctxs
func_ctx.basic_op_num_params = fuser._num_list_basic_op_params
func_ctx.num_extra_inputs = fuser._num_extra_inputs
func_ctx.basic_op_num_params = fuser._basic_op_num_params
func_ctx.num_extra_inputs = fuser.num_extra_inputs
func_ctx.num_extra_outputs = len(extra_outputs_flat)
func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
func_ctx.with_quantized_compute = with_quantized_compute
func_ctx.is_first_module = is_first_module
x.requires_grad_(requires_grad)
# Mark output tensors as not deletable in backward
for tensor in [x] + extra_outputs_flat:
tensor._do_not_clear = True
x.requires_grad_(fuser.first_op_requiring_backward < fuser._num_basic_ops)
if extra_outputs_flat:
return x, *extra_outputs_flat
......@@ -220,10 +218,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
basic_op_ctxs = func_ctx.basic_op_ctxs
# Restore saved tensors
if func_ctx.with_quantized_compute:
saved_tensors = restore_from_saved(func_ctx.tensor_objects, func_ctx.saved_tensors)
else:
saved_tensors = func_ctx.saved_tensors
# Unflatten list of saved tensors
for ctx in basic_op_ctxs:
......@@ -304,7 +299,6 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
dx, # input_
None, # fuser
None, # basic_op_kwargs
None, # is_grad_enabled
*grad_params_flat,
*grad_extra_inputs_flat,
)
......@@ -317,19 +311,12 @@ class OperationFuser:
----------
ops: list of FusibleOperation
Pipeline of operations
fuse_ops: bool
Whether to attempt fusing operations
recipe: Recipe, optional
Quantization recipe to use when fusing and executing operations.
Note: certain fusions may depend on what kind of recipe is being used.
"""
def __init__(
self,
ops: list[FusibleOperation],
fuse_ops: bool,
recipe: Optional[Recipe],
) -> None:
# Get list of basic operations
......@@ -343,25 +330,22 @@ class OperationFuser:
self._basic_ops: list[BasicOperation] = basic_ops
# Number of extra tensor inputs
self._num_extra_inputs: int = sum(op.num_extra_inputs for op in basic_ops)
self._basic_op_num_extra_inputs: list[int] = list(op.num_extra_inputs for op in basic_ops)
self.num_extra_inputs: int = sum(self._basic_op_num_extra_inputs)
# Ops for forward and backward pass
# Ops for forward and backward pass, will be populated in fuse_ops
self._forward_ops: list[tuple[FusibleOperation, list[int]]]
self._backward_ops: list[tuple[FusibleOperation, list[int]]]
self._forward_ops = [(op, (idx,)) for idx, op in enumerate(self._basic_ops)]
self._backward_ops = list(reversed(self._forward_ops))
# Flag for checking if this is the first iteration
self._is_first_forward = True
# Fuse ops if needed
self.recipe = recipe
if fuse_ops:
self.fuse_ops()
# Cache and detect change of state relevant for fusing operations
self.recipe_type = None
self.first_op_requiring_backward = 0
self._last_amax_history_len = 0
# Flatten list of parameters
self._basic_op_params = [param for op in self._basic_ops for param in op.parameters()]
self._num_list_basic_op_params = [sum(1 for _ in op.parameters()) for op in self._basic_ops]
self._basic_op_params = [list(op.parameters()) for op in self._basic_ops]
self._basic_op_num_params = list(map(len, self._basic_op_params))
self._flat_basic_op_params = sum(self._basic_op_params, [])
@classmethod
def _fuse_forward_ops(
......@@ -373,6 +357,7 @@ class OperationFuser:
ops = fuse_userbuffers_forward_linear(ops)
ops = fuse_forward_linear_bias_add(ops)
ops = fuse_forward_linear_bias_activation(ops)
ops = fuse_forward_linear_scale_add(ops)
return ops
@classmethod
......@@ -384,13 +369,74 @@ class OperationFuser:
"""Attempt to fuse operations in backward pass"""
ops = fuse_userbuffers_backward_linear(ops)
ops = fuse_backward_linear_add(ops)
ops = fuse_backward_bias_activation(ops, recipe)
ops = fuse_backward_linear_scale(ops)
ops = fuse_backward_activation_bias(ops, recipe)
return ops
def fuse_ops(self) -> None:
"""Attempt to fuse operations"""
self._forward_ops = self._fuse_forward_ops(self._forward_ops, self.recipe)
self._backward_ops = self._fuse_backward_ops(self._backward_ops, self.recipe)
def maybe_fuse_ops(
self,
is_grad_enabled: bool,
recipe: Optional[Recipe],
input_: torch.Tensor,
extra_inputs: list[Iterable[torch.Tensor]],
):
"""Attempt to fuse operations if neccesary"""
# Determine which basic ops require backward
if not is_grad_enabled:
first_op_requiring_backward = self._num_basic_ops
elif input_.requires_grad:
first_op_requiring_backward = 0
else:
first_op_requiring_backward = self._num_basic_ops
for op_idx in range(self._num_basic_ops):
op_inputs = itertools.chain(self._basic_op_params[op_idx], extra_inputs[op_idx])
if any(tensor.requires_grad for tensor in op_inputs):
first_op_requiring_backward = op_idx
break
# Early exit if fusion parameters haven't changed
need_reset = False
recipe_type = type(recipe)
fusion_params = (recipe_type, first_op_requiring_backward)
if fusion_params != (self.recipe_type, self.first_op_requiring_backward):
# Recipe type or grad requirements have changed
need_reset = True
elif (
recipe is not None
and recipe.delayed()
and self._last_amax_history_len != recipe.amax_history_len
):
# FP8 delayed scaling has changed amax history length
need_reset = True
if not need_reset:
return
# Reset recipe state
for op in self._basic_ops:
op.reset_recipe_state(recipe=recipe)
# Check if this is the first iteration
if self.recipe_type is None:
for op in self._basic_ops:
op.pre_first_fuser_forward()
# Prepare basic op lists for fusions
forward_ops = [(op, [idx]) for idx, op in enumerate(self._basic_ops)]
backward_ops = list(reversed(forward_ops[first_op_requiring_backward:]))
# Fuse ops
self._forward_ops = self._fuse_forward_ops(forward_ops, recipe)
self._backward_ops = self._fuse_backward_ops(backward_ops, recipe)
# Save current fusion params
self.recipe_type, self.first_op_requiring_backward = fusion_params
# Save amax history length
if isinstance(recipe, DelayedScaling):
self._last_amax_history_len = recipe.amax_history_len
else:
self._last_amax_history_len = 0
def __call__(
self,
......@@ -399,23 +445,32 @@ class OperationFuser:
basic_op_kwargs: Optional[list[dict[str, Any]]] = None,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
# Verify extra input count
if len(extra_inputs) != self._num_extra_inputs:
if len(extra_inputs) != self.num_extra_inputs:
raise ValueError(
f"Expected {self._num_extra_inputs} extra inputs but got {len(extra_inputs)}"
f"Expected {self.num_extra_inputs} extra inputs but got {len(extra_inputs)}"
)
# Initialization before forward pass
if self._is_first_forward:
for op in self._basic_ops:
op.pre_first_forward(recipe=self.recipe)
self._is_first_forward = False
# Canonicalize op kwargs
if basic_op_kwargs is None:
basic_op_kwargs = [{}] * self._num_basic_ops
# Fuser forward pass
# Unflatten list of extra tensor inputs
extra_inputs_copy = list(extra_inputs)
basic_op_extra_inputs = []
for op in self._basic_ops:
xs, extra_inputs_copy = _split_tuple(extra_inputs_copy, op.num_extra_inputs)
basic_op_extra_inputs.append(xs)
# Get environment state
recipe = None
if FP8GlobalStateManager.is_fp8_enabled():
recipe = FP8GlobalStateManager.get_fp8_recipe()
is_grad_enabled = torch.is_grad_enabled()
# Attempt to fuse operations if necessary
self.maybe_fuse_ops(is_grad_enabled, recipe, input, basic_op_extra_inputs)
# Fuser forward pass
if is_grad_enabled:
forward_func = _OperationFuserAutogradFunction.apply
args = []
......@@ -426,8 +481,7 @@ class OperationFuser:
input,
self,
basic_op_kwargs,
is_grad_enabled,
*self._basic_op_params,
*self._flat_basic_op_params,
*extra_inputs,
)
return forward_func(*args)
......@@ -6,7 +6,7 @@
from __future__ import annotations
from collections.abc import Callable
from typing import Optional
from typing import Any, Optional
import torch
......@@ -91,6 +91,8 @@ class Linear(FusedOperation):
# Construct basic ops
ops = []
linear_idx = None
bias_idx = None
linear_kwargs = {
"in_features": in_features,
"out_features": out_features,
......@@ -111,14 +113,16 @@ class Linear(FusedOperation):
}
if tensor_parallel_mode == "row":
# Row TP: GEMM + bias + reduction
linear_idx = len(ops)
linear_kwargs["in_features"] = local_in_features
linear_kwargs["out_features"] = local_out_features
linear_kwargs["tensor_parallel_mode"] = None
linear_kwargs["tensor_parallel_group"] = None
linear_kwargs["sequence_parallel"] = False
bias_kwargs["size"] *= tensor_parallel_size
ops.append(BasicLinear(**linear_kwargs))
if bias:
bias_idx = len(ops)
bias_kwargs["size"] *= tensor_parallel_size
ops.append(Bias(**bias_kwargs))
if sequence_parallel:
ops.append(ReduceScatter(tensor_parallel_group))
......@@ -126,45 +130,81 @@ class Linear(FusedOperation):
ops.append(AllReduce(tensor_parallel_group))
else:
# Column TP or no TP: (gather + GEMM) + bias
linear_idx = len(ops)
ops.append(BasicLinear(**linear_kwargs))
if bias:
bias_idx = len(ops)
ops.append(Bias(**bias_kwargs))
# Initialize base class
super().__init__(ops)
self._has_bias: bool = bias
@property
def weight(self) -> torch.nn.Parameter:
"""Weight tensor
Parameter is owned by `BasicLinear` operation.
"""
return self.basic_ops[0].weight
@weight.setter
def weight(self, value: Optional[torch.nn.Parameter]) -> None:
self.basic_ops[0].weight = value
# Register parameters
self._linear_idx: Optional[int] = linear_idx
self._bias_idx: Optional[int] = bias_idx
self.register_parameter("weight", self.basic_ops[self._linear_idx].weight)
bias = None
if self._bias_idx is not None:
bias = self.basic_ops[self._bias_idx].bias
self.register_parameter("bias", bias)
@property
def bias(self) -> Optional[torch.nn.Parameter]:
"""Bias tensor
def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None:
"""Add a parameter to the module
Parameter is owned by `Bias` operation.
Also updates the basic operation that owns the parameter.
"""
if self._has_bias:
return self.basic_ops[1].bias
return None
@bias.setter
def bias(self, value: Optional[torch.nn.Parameter]) -> None:
if self._has_bias:
self.basic_ops[1].bias = value
elif value is not None:
if name == "bias" and self._bias_idx is None and param is not None:
raise ValueError(
"Attempted to set bias parameter in Linear operation "
"that does not have bias enabled"
)
super().register_parameter(name, param)
if name == "weight":
self.basic_ops[self._linear_idx].weight = param
elif name == "bias" and self._bias_idx is not None:
self.basic_ops[self._bias_idx].bias = param
def state_dict(self, *, prefix: str = "", **kwargs) -> dict[str, Any]:
"""Save state"""
state_dict = super().state_dict(prefix=prefix, **kwargs)
# Remove basic op params from state dict
# Note: Logically, basic ops own params and fused ops are
# considered as stateless. However, we register weight and
# bias params in the linear op for convenience. We remove
# these redundant params from the checkpoint for backward
# compatibility.
if f"{prefix}weight" in state_dict:
del state_dict[f"{prefix}weight"]
if f"{prefix}bias" in state_dict:
del state_dict[f"{prefix}bias"]
return state_dict
def _load_from_state_dict(
self,
state_dict: dict[str, Any],
prefix: str,
*args,
**kwargs,
) -> None:
# Add basic op params to state dict
# Note: Logically, basic ops own params and fused ops are
# considered as stateless. However, we register weight and
# bias params in the linear op for convenience. We remove
# these redundant params from the checkpoint for backward
# compatibility.
if f"{prefix}weight" not in state_dict:
state_dict[f"{prefix}weight"] = state_dict[
f"{prefix}basic_ops.{self._linear_idx}.weight"
]
if f"{prefix}bias" not in state_dict:
if self._bias_idx is None:
state_dict[f"{prefix}bias"] = None
else:
state_dict[f"{prefix}bias"] = state_dict[f"{prefix}basic_ops.{self._bias_idx}.bias"]
# Load state dict
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
......@@ -15,9 +15,6 @@ import torch
from transformer_engine.common.recipe import Recipe
from ..fp8 import (
MXFP8BlockScalingRecipeState,
DelayedScalingRecipeState,
Float8BlockScalingRecipeState,
FP8GlobalStateManager,
RecipeState,
fp8_autocast,
......@@ -65,18 +62,14 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
def is_fused_op(self) -> bool:
"""Whether this op is the fusion of one or more basic ops"""
def pre_first_forward(
self,
*,
recipe: Optional[Recipe],
) -> None:
"""Preprocessing before forward pass"""
def pre_first_fuser_forward(self) -> None:
"""Preprocessing before first fuser forward pass"""
def get_input_quantizer(self) -> Optional[Quantizer]:
"""Get builder class for quantized input tensor"""
def get_grad_input_quantizer(self) -> Optional[Quantizer]:
"""Get builder class for quantized input's grad tensor"""
def get_grad_output_quantizer(self) -> Optional[Quantizer]:
"""Get builder class for quantized output's grad tensor"""
def fuser_forward(
self,
......@@ -84,7 +77,7 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
......@@ -104,8 +97,8 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
Input tensor
basic_op_extra_inputs: list of torch.Tensor
Extra tensor inputs to basic operations
prev_op_grad_input_quantizer: Quantizer, optional
The grad_input_quantizer of the preceding operation
prev_op_grad_output_quantizer: Quantizer, optional
The grad_output_quantizer of the preceding operation
next_op_input_quantizer: Quantizer, optional
The input_quantizer of the following operation
basic_op_kwargs: list of dict
......@@ -186,8 +179,11 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
super().__init__()
# Objects for quantization
self._quantizers: Optional[dict[str, list[Quantizer]]] = None
self._fp8_metas: Optional[dict[str, dict[str, Any]]] = None
self._quantizers: Optional[dict[str, list[Quantizer]]] = None
with_fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
recipe = FP8GlobalStateManager.get_fp8_recipe() if with_fp8_parameters else None
self.reset_recipe_state(recipe=recipe)
@property
def is_fused_op(self) -> bool:
......@@ -214,19 +210,47 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
return self.get_quantizer("forward", 0)
return None
def get_grad_input_quantizer(self) -> Optional[Quantizer]:
def get_grad_output_quantizer(self) -> Optional[Quantizer]:
if self.num_quantizers("backward") > 0:
return self.get_quantizer("backward", 0)
return None
def _reset_quantization_recipe_state(
def reset_recipe_state(
self,
*,
recipe: Recipe,
recipe: Optional[Recipe],
) -> None:
"""Construct state for quantization recipe"""
# Quantization recipe state for forward and backward pass
# Clear quantization state if necessary
if recipe is None:
self._fp8_metas = None
self._quantizers = None
return
# Communication group for FP8 amax reductions
fp8_group = FP8GlobalStateManager.get_fp8_group()
# Skip resetting recipe type if it did not actually change.
# This could happen for example if calling BasicOperation.forward directly, as in that
# case, the OperationFuser is not persistent, or when loading from a checkpoint
need_to_reset_recipe_state = False
if self._fp8_metas is None or self._quantizers is None:
need_to_reset_recipe_state = True
else:
for mode in ("forward", "backward"):
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=(mode == "forward"),
)
if self._fp8_metas[mode] is None or fp8_meta_key not in self._fp8_metas[mode]:
continue
recipe_state = self._fp8_metas[mode][fp8_meta_key]
if not isinstance(recipe, type(recipe_state.recipe)):
need_to_reset_recipe_state = True
break
if need_to_reset_recipe_state:
# Construct quantization recipe states
self._fp8_metas = {"forward": None, "backward": None}
self._quantizers = {"forward": [], "backward": []}
for mode in ("forward", "backward"):
......@@ -251,83 +275,76 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
self._fp8_metas[mode] = {
fp8_meta_key: recipe_state,
"recipe": recipe,
"fp8_group": FP8GlobalStateManager.get_fp8_group(),
"fp8_group": fp8_group,
}
# Construct builder class for quantized tensors
self._quantizers[mode] = recipe_state.make_quantizers()
def _update_quantization_recipe_state(
self,
*,
recipe: Recipe,
) -> None:
"""Make sure quantizer state matches quantization recipe"""
# Reset quantization state if needed
if self._fp8_metas is None or self._quantizers is None:
self._reset_quantization_recipe_state(recipe=recipe)
return
for mode in ("forward", "backward"):
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=(mode == "forward"),
)
if self._fp8_metas[mode] is None or fp8_meta_key not in self._fp8_metas[mode]:
continue
recipe_state = self._fp8_metas[mode][fp8_meta_key]
need_to_reset_recipe_state = (
(recipe.delayed() and not isinstance(recipe_state, DelayedScalingRecipeState))
or (recipe.mxfp8() and not isinstance(recipe_state, MXFP8BlockScalingRecipeState))
or (
recipe.float8_block_scaling()
and not isinstance(recipe_state, Float8BlockScalingRecipeState)
)
)
if need_to_reset_recipe_state:
self._reset_quantization_recipe_state(recipe=recipe)
return
# Quantization recipe state for forward and backward pass
else:
# Update quantization recipe states
for mode in ("forward", "backward"):
num_quantizers = self.num_quantizers(mode)
if num_quantizers == 0:
if self._fp8_metas[mode] is None:
continue
self._fp8_metas[mode]["recipe"] = recipe
self._fp8_metas[mode]["fp8_group"] = fp8_group
# Update FP8 metadata
fp8_meta = self._fp8_metas[mode]
fp8_meta["recipe"] = recipe
fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
# Get recipe state
# Update amax history for FP8 delayed scaling
if recipe.delayed():
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=(mode == "forward"),
)
recipe_state = fp8_meta[fp8_meta_key]
recipe_state = self._fp8_metas[mode][fp8_meta_key]
# Reallocate amax history if needed
if not recipe.delayed():
continue
current_length = recipe_state.amax_history.size(0)
target_length = recipe.amax_history_len
if current_length != target_length:
with torch.no_grad():
if target_length < current_length:
with torch.no_grad():
recipe_state.amax_history = recipe_state.amax_history[
:target_length
].clone()
else:
elif target_length > current_length:
with torch.no_grad():
recipe_state.amax_history = torch.nn.functional.pad(
recipe_state.amax_history,
pad=(0, 0, 0, target_length - current_length),
)
# Update quantizers with new amax pointers
self._quantizers[mode] = recipe_state.make_quantizers()
# Update the global buffers with new amax pointers
if FP8GlobalStateManager.get_buffer_info() in self._fp8_metas[mode]:
pos, buffer_key = self._fp8_metas[mode][
FP8GlobalStateManager.get_buffer_info()
]
if buffer_key in FP8GlobalStateManager.global_amax_buffer:
assert (
buffer_key in FP8GlobalStateManager.global_amax_history_buffer
), "TE internal error during amax history change."
FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = (
recipe_state.amax_history[0]
)
FP8GlobalStateManager.global_amax_history_buffer[buffer_key][
pos
] = recipe_state.amax_history
# Add meta tensors to global buffer to participate in reduction
for mode in ("forward", "backward"):
if (
FP8GlobalStateManager.is_fp8_enabled()
and self.num_quantizers(mode)
and not FP8GlobalStateManager.fp8_graph_capturing()
):
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
self._fp8_metas[mode],
)
def get_quantizer(
self,
mode: str,
index: int,
) -> Quantizer:
) -> Optional[Quantizer]:
"""Get builder class for quantized tensor
Parameters
......@@ -337,7 +354,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
"""
if self._quantizers is None:
self._reset_quantization_recipe_state(recipe=FP8GlobalStateManager.get_fp8_recipe())
return None
return self._quantizers[mode][index]
@torch.no_grad()
......@@ -388,33 +405,13 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
self._fp8_metas[mode][fp8_meta_key].scale.copy_(scale)
self._fp8_metas[mode][fp8_meta_key].amax_history.copy_(amax_history)
def pre_first_forward(
self,
*,
recipe: Optional[Recipe],
) -> None:
"""Preprocessing before forward pass"""
# Initialize FP8 metadata if needed
if recipe is not None:
self._update_quantization_recipe_state(recipe=recipe)
if not FP8GlobalStateManager.fp8_graph_capturing():
if self.num_quantizers("forward"):
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
self._fp8_metas["forward"],
)
if self.num_quantizers("backward"):
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
self._fp8_metas["backward"],
)
@abc.abstractmethod
def op_forward(
self,
ctx: OperationContext,
input_: torch.Tensor,
*,
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
**kwargs: Any,
) -> torch.Tensor:
......@@ -426,8 +423,8 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
Context to coordinate between forward and backward passes
input_: torch.Tensor
Input tensor
prev_op_grad_input_quantizer: Quantizer, optional
The grad_input_quantizer of the preceeding operation
prev_op_grad_output_quantizer: Quantizer, optional
The grad_output_quantizer of the preceeding operation
next_op_input_quantizer: Quantizer, optional
The input_quantizer of the following operation
......@@ -468,7 +465,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, list[tuple[()]]]:
......@@ -482,7 +479,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
output = self.op_forward(
basic_op_ctxs[0],
input_,
prev_op_grad_input_quantizer=prev_op_grad_input_quantizer,
prev_op_grad_output_quantizer=prev_op_grad_output_quantizer,
next_op_input_quantizer=next_op_input_quantizer,
**basic_op_kwargs[0],
)
......@@ -518,9 +515,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
"""Apply operation"""
from .fuser import OperationFuser
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
return OperationFuser([self], fuse_ops=False, recipe=recipe)(
return OperationFuser([self])(
input,
*extra_inputs,
basic_op_kwargs=[kwargs],
......@@ -630,7 +625,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
# Get op's quantizer state, initializing if needed
if self._fp8_metas is None or self._fp8_metas[mode] is None:
with fp8_autocast(fp8_recipe=state[mode]["recipe"]):
self._reset_quantization_recipe_state(recipe=state[mode]["recipe"])
self.reset_recipe_state(recipe=state[mode]["recipe"])
fp8_meta = self._fp8_metas[mode]
# Load extra items
......@@ -708,13 +703,12 @@ class FusedOperation(FusibleOperation):
def get_input_quantizer(self) -> Optional[Quantizer]:
return self.basic_ops[0].get_input_quantizer()
def get_grad_input_quantizer(self) -> Optional[Quantizer]:
return self.basic_ops[-1].get_grad_input_quantizer()
def get_grad_output_quantizer(self) -> Optional[Quantizer]:
return self.basic_ops[-1].get_grad_output_quantizer()
def pre_first_forward(self, *args, **kwargs) -> None:
"""Preprocessing before forward pass"""
def pre_first_fuser_forward(self) -> None:
for op in self.basic_ops:
op.pre_first_forward(*args, **kwargs)
op.pre_first_fuser_forward()
def forward(
self,
......@@ -727,9 +721,7 @@ class FusedOperation(FusibleOperation):
basic_op_kwargs = [{} for _ in range(len(self.basic_ops))]
from .fuser import OperationFuser
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
return OperationFuser([self], fuse_ops=False, recipe=recipe)(
return OperationFuser([self])(
input,
*extra_inputs,
basic_op_kwargs=basic_op_kwargs,
......
......@@ -10,7 +10,6 @@ from typing import Optional
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe
from transformer_engine.pytorch.ops.op import FusibleOperation
from transformer_engine.pytorch.ops.fuser import OperationFuser
......@@ -147,7 +146,6 @@ class Sequential(torch.nn.Module):
def _make_module_groups(
cls,
modules: Iterable[torch.nn.Module],
recipe: Optional[Recipe],
) -> list[OperationFuser | torch.nn.Module]:
"""Make list of modules, with fusible operations grouped together"""
......@@ -162,24 +160,7 @@ class Sequential(torch.nn.Module):
groups.append(module)
for idx, group in enumerate(groups):
if isinstance(group, list):
groups[idx] = OperationFuser(group, fuse_ops=True, recipe=recipe)
# Check if operations expect extra input or output tensors
# Note: If any op has extra inputs or outputs, then the entire
# Sequential must be made up of TE ops.
if len(groups) > 1:
ops = []
for group in groups:
if isinstance(group, OperationFuser):
ops.extend(group._basic_ops)
num_extra_inputs = sum(op.num_extra_inputs for op in ops)
num_extra_outputs = sum(op.num_extra_outputs for op in ops)
if num_extra_inputs > 0 or num_extra_outputs > 0:
raise RuntimeError(
f"`Sequential` expects {num_extra_inputs} extra inputs "
f"and {num_extra_outputs} extra outputs, "
"but it contains non-fusible operations"
)
groups[idx] = OperationFuser(group)
return groups
......@@ -190,22 +171,28 @@ class Sequential(torch.nn.Module):
) -> torch.Tensor | tuple[torch.Tensor, ...]:
"""Forward pass"""
# Get current global state
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
global_state = (with_quantized_compute, type(recipe))
# Reset module groups is global state changed
if self._last_global_state != global_state:
self._module_groups = None
self._last_global_state = global_state
# Create module groups if needed
if self._module_groups is None:
self._module_groups = self._make_module_groups(self._modules.values(), recipe)
self._module_groups = self._make_module_groups(self._modules.values())
# Forward pass for each module group
x = input
extra_outputs: list[torch.Tensor] = []
for module_group in self._module_groups:
x = module_group(x, *extra_inputs)
if isinstance(module_group, OperationFuser):
xs, extra_inputs = (
(x,) + extra_inputs[: module_group.num_extra_inputs],
extra_inputs[module_group.num_extra_inputs :],
)
xs = module_group(*xs)
if isinstance(xs, tuple):
x, ys = xs[0], xs[1:]
extra_outputs.extend(ys)
else:
x = xs
else:
x = module_group(x)
if extra_outputs:
return (x,) + tuple(extra_outputs)
return x
......@@ -60,7 +60,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
instance = super().__new__(cls, *args, **kwargs)
instance._rowwise_data = rowwise_data
instance._columnwise_data = columnwise_data
instance._quantizer = quantizer
instance._quantizer = quantizer.copy() if quantizer is not None else None
instance._fp8_dtype = fp8_dtype
instance._rowwise_scale_inv = rowwise_scale_inv
instance._columnwise_scale_inv = columnwise_scale_inv
......@@ -125,9 +125,15 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
self._columnwise_scale_inv = tensors[3]
return tensors[4:]
def get_data_tensors(self):
def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = True):
"""Get this Tensor's data."""
if rowwise_data and columnwise_data:
return self._rowwise_data, self._columnwise_data
if rowwise_data:
return self._rowwise_data
if columnwise_data:
return self._columnwise_data
raise ValueError("No data to get, both rowwise_data and columnwise_data are False")
def _transpose_dq_columnwise_output(self, columnwise_dq: torch.Tensor) -> torch.Tensor:
"""Takes dequantized columnwise data and permutes to a rowwise shape"""
......
......@@ -86,7 +86,7 @@ class Float8TensorBase(QuantizedTensorBase):
else:
instance = super().__new__(cls, *args, **kwargs)
instance._data = data
instance._quantizer = quantizer
instance._quantizer = quantizer.copy() if quantizer is not None else None
instance._fp8_dtype = fp8_dtype
instance._scale_inv = fp8_scale_inv
instance._transpose = data_transpose
......@@ -128,9 +128,15 @@ class Float8TensorBase(QuantizedTensorBase):
self._scale_inv = tensors[2]
return tensors[3:]
def get_data_tensors(self):
def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = True):
"""Get this Tensor's data."""
if rowwise_data and columnwise_data:
return self._data, self._transpose
if rowwise_data:
return self._data
if columnwise_data:
return self._transpose
raise ValueError("No data to get, both rowwise_data and columnwise_data are False")
def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
"""Dequantize to a higher precision."""
......
......@@ -83,7 +83,7 @@ class MXFP8TensorBase(QuantizedTensorBase):
instance = super().__new__(cls, *args, **kwargs)
instance._rowwise_data = rowwise_data
instance._columnwise_data = columnwise_data
instance._quantizer = quantizer
instance._quantizer = quantizer.copy() if quantizer is not None else None
instance._fp8_dtype = fp8_dtype
instance._rowwise_scale_inv = rowwise_scale_inv
instance._columnwise_scale_inv = columnwise_scale_inv
......@@ -136,9 +136,15 @@ class MXFP8TensorBase(QuantizedTensorBase):
self._columnwise_scale_inv = tensors[3]
return tensors[4:]
def get_data_tensors(self):
def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = True):
"""Get this Tensor's data."""
if rowwise_data and columnwise_data:
return self._rowwise_data, self._columnwise_data
if rowwise_data:
return self._rowwise_data
if columnwise_data:
return self._columnwise_data
raise ValueError("No data to get, both rowwise_data and columnwise_data are False")
def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
"""Dequantize to a higher precision."""
......
......@@ -524,7 +524,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
def _set_from_tensor(dst: Float8BlockwiseQTensor, src: Float8BlockwiseQTensor):
dst._rowwise_data = src._rowwise_data
dst._columnwise_data = src._columnwise_data
dst._quantizer = src._quantizer
dst._quantizer = src._quantizer.copy()
dst._fp8_dtype = src._fp8_dtype
dst._rowwise_scale_inv = src._rowwise_scale_inv
dst._columnwise_scale_inv = src._columnwise_scale_inv
......
......@@ -109,10 +109,9 @@ class Float8Quantizer(Quantizer):
# Allocate FP8 data transpose if needed
data_transpose = None
if self.columnwise_usage:
inner_dim = data.size(-1)
transpose_shape = [data.size(-1)] + list(data.shape[:-1])
data_transpose = torch.empty(
inner_dim,
data.numel() // inner_dim,
transpose_shape,
dtype=torch.uint8,
device=device,
)
......@@ -186,6 +185,12 @@ class Float8Quantizer(Quantizer):
def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
return DelayedScaling
def supports_only_rowwise_all_gather(self) -> bool:
"""
Float8Quantizer supports only rowwise all-gather
"""
return True
class Float8CurrentScalingQuantizer(Quantizer):
"""Builder class for FP8 tensors with per-tensor current scaling
......@@ -231,7 +236,7 @@ class Float8CurrentScalingQuantizer(Quantizer):
amax_epsilon: float = 0.0,
) -> None:
super().__init__(rowwise=rowwise, columnwise=columnwise)
self.scale = torch.ones(1, dtype=torch.float32, device=device)
self.scale = torch.empty(1, dtype=torch.float32, device=device)
self.amax = torch.empty(1, dtype=torch.float32, device=device)
self.dtype = tex.DType.kInt8 if int8_simulation_fp8_tensorwise else fp8_dtype
self.with_amax_reduction = with_amax_reduction
......@@ -363,6 +368,12 @@ class Float8CurrentScalingQuantizer(Quantizer):
def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
return Float8CurrentScaling
def supports_only_rowwise_all_gather(self) -> bool:
"""
Float8CurrentScalingQuantizer supports only rowwise all-gather
"""
return True
class Float8Tensor(Float8TensorBase, QuantizedTensor):
"""Experimental tensor class with FP8 data
......@@ -691,7 +702,7 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
# Float8Tensor attributes
self._data = tensor._data
self._quantizer = tensor._quantizer
self._quantizer = tensor._quantizer.copy()
self._fp8_dtype = tensor._fp8_dtype
self._scale_inv = tensor._scale_inv
self._transpose = tensor._transpose
......
......@@ -100,7 +100,7 @@ class MXFP8Quantizer(Quantizer):
# Allocate FP8 data
data = torch.empty(shape, dtype=torch.uint8, device=device)
scale_inv = torch.zeros(
scale_inv = torch.empty(
round_up_to_nearest_multiple(math.prod(shape[:-1]), 128),
round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4),
dtype=torch.uint8,
......@@ -112,7 +112,7 @@ class MXFP8Quantizer(Quantizer):
columnwise_scale_inv = None
if self.columnwise_usage:
columnwise_data = torch.empty_like(data)
columnwise_scale_inv = torch.zeros(
columnwise_scale_inv = torch.empty(
round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4),
round_up_to_nearest_multiple(shape[-1], 128),
dtype=torch.uint8,
......@@ -433,7 +433,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
super(MXFP8Tensor, type(self)).data.__set__(self, dummy_tensor)
self._rowwise_data = tensor._rowwise_data
self._columnwise_data = tensor._columnwise_data
self._quantizer = tensor._quantizer
self._quantizer = tensor._quantizer.copy()
self._fp8_dtype = tensor._fp8_dtype
self._rowwise_scale_inv = tensor._rowwise_scale_inv
self._columnwise_scale_inv = tensor._columnwise_scale_inv
......
......@@ -260,6 +260,10 @@ class Quantizer(abc.ABC):
def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
"""Returns recipe class that is compatible with this quantizer"""
def supports_only_rowwise_all_gather(self) -> bool:
"""Returns True if the quantizer supports only rowwise all-gather"""
return False
class _QuantizeFunc(torch.autograd.Function):
"""Cast to FP8 from other dtype"""
......
......@@ -236,14 +236,23 @@ class TransformerLayer(torch.nn.Module):
parameter for query-key-value. This enables optimizations such as QKV
fusion without concatentations/splits and also enables the argument
`fuse_wgrad_accumulation`.
use_qk_norm: bool, default = 'False'
if set to `True`, L2 normalization is applied to query and key tensors
after RoPE (if applicable) but before attention computation.
This follows the Llama4 approach for QK normalization to improve
training stability and model performance.
qk_norm_type: Optional[str], default = None
type of normalization to apply to query and key tensors.
Options: None, 'L2Normalization', 'RMSNorm', 'LayerNorm'. When None, no normalization is applied.
When 'L2Normalization', L2 normalization is applied to query and key tensors.
When 'RMSNorm', RMS normalization is applied to query and key tensors.
When 'LayerNorm', layer normalization is applied to query and key tensors.
Normalization is applied after RoPE (if applicable) but before attention computation
when `qk_norm_before_rope` is False. This follows the e.g. Llama4 approach for
QK normalization to improve training stability and model performance.
qk_norm_eps: float, default = 1e-6
epsilon value for L2 normalization of query and key tensors.
Only used when `use_qk_norm` is True.
epsilon value for normalization of query and key tensors.
Only used when `qk_norm_type` is not None.
qk_norm_before_rope: bool, default = `False`
if set to `True`, query and key normalization is applied before rotary position
embedding. When `False` (default), normalization is applied after RoPE.
This parameter allows supporting different architectural variants that apply
QK normalization at different points.
"""
def __init__(
......@@ -293,8 +302,9 @@ class TransformerLayer(torch.nn.Module):
device: Union[torch.device, str] = "cuda",
attn_input_format: str = "sbhd",
name: str = None,
use_qk_norm: bool = False,
qk_norm_type: Optional[str] = None,
qk_norm_eps: float = 1e-6,
qk_norm_before_rope: bool = False,
) -> None:
super().__init__()
......@@ -397,8 +407,9 @@ class TransformerLayer(torch.nn.Module):
return_bias=not self.parallel_attention_mlp,
normalization=normalization,
device=device,
use_qk_norm=use_qk_norm,
qk_norm_type=qk_norm_type,
qk_norm_eps=qk_norm_eps,
qk_norm_before_rope=qk_norm_before_rope,
name=name + ".self_attention" if name is not None else None,
)
......@@ -413,8 +424,9 @@ class TransformerLayer(torch.nn.Module):
return_bias=True,
normalization=normalization,
device=device,
use_qk_norm=use_qk_norm,
qk_norm_type=qk_norm_type,
qk_norm_eps=qk_norm_eps,
qk_norm_before_rope=qk_norm_before_rope,
name=name + ".inter_attention" if name is not None else None,
)
......
......@@ -341,13 +341,17 @@ def cross_entropy_forward(
return loss, _input
def cross_entropy_backward(_input: torch.Tensor, grad_output: torch.Tensor):
def cross_entropy_backward(
_input: torch.Tensor, grad_output: torch.Tensor, is_cg_capturable: bool = False
):
"""Backward implementation of cross entropy loss kernel"""
# If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
# Only check torch.equal when not in CUDA graph capturable mode
if not is_cg_capturable and torch.equal(
grad_output, torch.tensor(1.0, device=grad_output.device)
):
pass
else:
B, SQ, V = _input.shape
n_rows = B * SQ
......
......@@ -359,7 +359,7 @@ def _permute_kernel(
if prob == 0.0:
# for routing_map padding
# dst_row != -1 and prob == 0.0 means that this slot is padded
tl.store(output_ptr + output_off, 0, mask=mask)
tl.store(output_ptr + output_off, 0.0, mask=mask)
else:
tl.store(output_ptr + output_off, inp, mask=mask)
else:
......
......@@ -45,10 +45,10 @@ def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
for t in tensors:
if t is not None:
# Workaround for double buffering in cpu offload
if hasattr(t, "do_not_clear"):
if hasattr(t, "_do_not_clear"):
continue
if hasattr(t, "get_data_tensors"):
if any(hasattr(tensor, "do_not_clear") for tensor in t.get_data_tensors()):
if any(hasattr(tensor, "_do_not_clear") for tensor in t.get_data_tensors()):
continue
if hasattr(t, "clear"):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment