Commit 87e3e56e authored by yuguo


Merge commit '734bcedd' of https://github.com/NVIDIA/TransformerEngine
parents 2f11bd2e 734bcedd
@@ -10,9 +10,9 @@ import warnings
 import torch
-from transformer_engine_torch import CommOverlapType
+from transformer_engine_torch import CommOverlapType, bulk_overlap_ag_with_external_gemm
 from ...cpp_extensions import general_gemm
-from ...distributed import gather_along_first_dim, get_distributed_world_size
+from ...distributed import get_distributed_world_size
 from ...module.base import (
     fill_userbuffers_buffer_for_all_gather,
     get_ub,
@@ -48,14 +48,14 @@ class UserbuffersBackwardLinear(FusedOperation):
         # Basic operations that comprise this fused operation
         op_idxs = {"linear": None, "bias": None, "reduce_scatter": None}
         ops = []
-        if reduce_scatter is not None:
-            op_idxs["reduce_scatter"] = len(ops)
-            ops.append(reduce_scatter)
+        op_idxs["linear"] = len(ops)
+        ops.append(linear)
         if bias is not None:
             op_idxs["bias"] = len(ops)
             ops.append(bias)
-        op_idxs["linear"] = len(ops)
-        ops.append(linear)
+        if reduce_scatter is not None:
+            op_idxs["reduce_scatter"] = len(ops)
+            ops.append(reduce_scatter)

         # Initialize base class
         super().__init__(ops)
@@ -398,26 +398,35 @@ class UserbuffersBackwardLinear(FusedOperation):

         # Initialize grad output
         if tensor_parallel_mode == "row" and isinstance(grad_output_quantizer, MXFP8Quantizer):
-            # UB does not support overlapping grad output
+            # UB does not support pipelined overlapping grad output
             # all-gather with wgrad GEMM. Also, we can't
             # convert row-scaled MXFP8 to column-scaled, so we
             # can't reuse the grad output that was gathered
             # for the dgrad GEMM. We work around by explicitly
-            # overlapping the NCCL operation with the dgrad GEMM.
+            # overlapping the AG operation with the dgrad GEMM.
+
+            # Get the communication stream from the dgrad GEMM to use for the AG
+            dgrad_send_stream, dgrad_recv_stream = ub_comm_dgrad.get_communication_stream()
+            ub_obj_overlap_wgrad = get_ub(ub_comm_name + "_wgrad")
             grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
-            # Get the communication stream from the dgrad GEMM and set it as the current torch stream
-            dgrad_comm_stream = ub_comm_dgrad.get_communication_stream()
-            with torch.cuda.stream(dgrad_comm_stream):
-                # Syncs with the current stream (dgrad_comm_stream) before starting the all-gather
-                # This ensures that we don't start until all communication for the dgrad GEMM is complete
-                dy, dy_work = gather_along_first_dim(
+
+            # We use the send stream to copy into the userbuffers.
+            # This is the same stream that we will use to access the data in the AG,
+            # so we don't need to add any syncs yet.
+            with torch.cuda.stream(dgrad_send_stream):
+                dy, _ = fill_userbuffers_buffer_for_all_gather(
+                    ub_obj_overlap_wgrad,
                     dy_local,
+                    grad_output_quantizer,
                     tensor_parallel_group,
-                    async_op=True,
-                    quantizer=grad_output_quantizer,
                 )
-            # Synchronize with the main stream
-            dy_work.wait()
+
+            # All-gather grad_outputs[0] using the dgrad streams so we can overlap with the fc2_dgrad GEMM
+            bulk_overlap_ag_with_external_gemm(
+                ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream
+            )
         if tensor_parallel_mode == "column":
             dy = dy_local
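The pattern above amounts to: stage the local grad output into the wgrad userbuffers region on the dgrad send stream, then run the bulk all-gather on the dgrad communication streams so it overlaps with the dgrad GEMM. A minimal sketch, reusing only the helpers named in this diff (the dgrad GEMM launch itself is elided):

    send_stream, recv_stream = ub_comm_dgrad.get_communication_stream()
    ub_obj_wgrad = get_ub(ub_comm_name + "_wgrad")
    grad_output_quantizer.set_usage(rowwise=False, columnwise=True)

    # Stage dy_local into the userbuffers region on the send stream; the
    # all-gather below consumes it from the same stream, so no extra sync.
    with torch.cuda.stream(send_stream):
        dy, _ = fill_userbuffers_buffer_for_all_gather(
            ub_obj_wgrad, dy_local, grad_output_quantizer, tensor_parallel_group
        )

    # Ring all-gather on the dgrad streams, overlapping with the dgrad GEMM
    # that runs concurrently on the main compute stream.
    bulk_overlap_ag_with_external_gemm(ub_obj_wgrad, send_stream, recv_stream)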
@@ -495,7 +504,7 @@ class UserbuffersBackwardLinear(FusedOperation):
         # Get basic operations
         idx = self._op_idxs["linear"]
         linear_op = self.basic_ops[idx]
-        linear_op_ctx = basic_op_ctxs[idx]
+        linear_op_ctx = basic_op_ctxs[-1]
         bias_op = None
         if self._op_idxs["bias"] is not None:
             idx = self._op_idxs["bias"]
@@ -556,6 +565,7 @@ class UserbuffersBackwardLinear(FusedOperation):
         grad_params[self._op_idxs["linear"]] = (grad_weight,)
         if bias_op is not None:
             grad_params[self._op_idxs["bias"]] = (grad_bias,)
+        grad_params.reverse()
         grad_extra_inputs = [() for _ in range(len(self.basic_ops))]
         return grad_input, grad_params, grad_extra_inputs
...
@@ -182,7 +182,7 @@ class UserbuffersForwardLinear(FusedOperation):
             if weight_quantizer is None:
                 raise ValueError("Missing quantizer for weight tensor")
             if output_quantizer is not None:
-                raise ValueError("FP8 output is not supported")
+                raise ValueError("Quantized output is not supported")
         else:
             input_quantizer = None
             weight_quantizer = None
...@@ -282,7 +282,7 @@ class UserbuffersForwardLinear(FusedOperation): ...@@ -282,7 +282,7 @@ class UserbuffersForwardLinear(FusedOperation):
input_: torch.Tensor, input_: torch.Tensor,
*, *,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
prev_op_grad_input_quantizer: Optional[Quantizer], prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer],
basic_op_kwargs: list[dict[str, Any]], basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
@@ -307,21 +307,17 @@ class UserbuffersForwardLinear(FusedOperation):
         weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad

         # Quantization metadata
+        input_quantizer = linear_op.get_quantizer("forward", 0)
+        weight_quantizer = linear_op.get_quantizer("forward", 1)
+        grad_output_quantizer = linear_op.get_quantizer("backward", 0)
+        grad_input_quantizer = prev_op_grad_output_quantizer
         with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
-        input_quantizer = None
-        weight_quantizer = None
-        grad_output_quantizer = None
-        grad_input_quantizer = None
         if with_quantized_compute:
             recipe = FP8GlobalStateManager.get_fp8_recipe()
             if not any((recipe.delayed(), recipe.float8_current_scaling(), recipe.mxfp8())):
                 raise RuntimeError(
                     f"Unsupported recipe for Userbuffers ({recipe.__class__.__name__})"
                 )
-            input_quantizer = linear_op.get_quantizer("forward", 0)
-            weight_quantizer = linear_op.get_quantizer("forward", 1)
-            grad_output_quantizer = linear_op.get_quantizer("backward", 0)
-            grad_input_quantizer = prev_op_grad_input_quantizer

         # Get autocast dtype if needed
         if torch.is_autocast_enabled():
@@ -356,19 +352,19 @@ class UserbuffersForwardLinear(FusedOperation):
         w = extra_outputs["weight"]

         # Save state for backward pass
-        linear_op_ctx.save_for_backward(x_local, w)
-        linear_op_ctx.with_quantized_compute = with_quantized_compute
-        linear_op_ctx.input_quantizer = input_quantizer
-        linear_op_ctx.weight_quantizer = weight_quantizer
-        linear_op_ctx.grad_output_quantizer = grad_output_quantizer
-        linear_op_ctx.grad_input_quantizer = grad_input_quantizer
-        linear_op_ctx.dtype = dtype
-        linear_op_ctx.input_dims = input_.size()
-        linear_op_ctx.input_requires_grad = input_requires_grad
-        linear_op_ctx.weight_requires_grad = weight_requires_grad
-        if bias_op is not None:
-            bias_op_ctx.with_quantized_compute = with_quantized_compute
-            bias_op_ctx.grad_input_quantizer = linear_op.get_grad_input_quantizer()
+        if linear_op_ctx.requires_grad:
+            linear_op_ctx.save_for_backward(x_local, w)
+            linear_op_ctx.with_quantized_compute = with_quantized_compute
+            linear_op_ctx.input_quantizer = input_quantizer
+            linear_op_ctx.weight_quantizer = weight_quantizer
+            linear_op_ctx.grad_output_quantizer = grad_output_quantizer
+            linear_op_ctx.grad_input_quantizer = grad_input_quantizer
+            linear_op_ctx.dtype = dtype
+            linear_op_ctx.input_dims = input_.size()
+            linear_op_ctx.input_requires_grad = input_requires_grad
+            linear_op_ctx.weight_requires_grad = weight_requires_grad
+        if bias_op is not None and bias_op_ctx.requires_grad:
+            bias_op_ctx.grad_input_quantizer = linear_op.get_grad_output_quantizer()

         return output, [() for _ in range(len(self.basic_ops))]
...
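With the guard above, nothing is stashed on the op context unless a backward pass is actually needed, so inference runs keep no activations alive. A minimal usage sketch (assuming the fusible-ops Linear from this repository):

    import torch
    import transformer_engine.pytorch as te

    layer = te.ops.Linear(1024, 1024).cuda()
    x = torch.randn(32, 1024, device="cuda")

    with torch.no_grad():
        y = layer(x)  # linear_op_ctx.requires_grad is False: no save_for_backward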
@@ -5,22 +5,25 @@
 """Manager class for a pipeline of fusible operations."""

 from __future__ import annotations
-from collections.abc import Callable
+from collections.abc import Callable, Iterable
 from typing import Any, Optional
+import itertools

 import torch

-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe
+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe, DelayedScaling
 from transformer_engine.pytorch.ops.op import (
     BasicOperation,
     FusibleOperation,
     OperationContext,
 )
 from transformer_engine.pytorch.ops.fused import (
-    fuse_backward_bias_activation,
+    fuse_backward_activation_bias,
     fuse_backward_linear_add,
+    fuse_backward_linear_scale,
     fuse_forward_linear_bias_activation,
     fuse_forward_linear_bias_add,
+    fuse_forward_linear_scale_add,
     fuse_userbuffers_backward_linear,
     fuse_userbuffers_forward_linear,
 )
@@ -68,8 +71,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
         input_: torch.Tensor,
         fuser: OperationFuser,
         basic_op_kwargs: list[dict[str, Any]],
-        is_grad_enabled: bool,
-        *params_and_extra_inputs: torch.nn.Parameter,
+        *params_and_extra_inputs: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, ...]:
         """Forward pass

@@ -83,8 +85,6 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
             Container for the pipeline of operations to run
         basic_op_kwargs: list of dict
             Keyword arguments to BasicOperation
-        is_grad_enabled: bool
-            Should context be saved for backward
         *params_and_extra_inputs: torch.Tensor
             Other tensor inputs to include in autograd graph. Consists
             of parameter tensors, followed by extra operation inputs.
@@ -103,10 +103,10 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):

         # Mark input tensors as not deletable in backward
         for tensor in (input_,) + params_and_extra_inputs:
-            tensor.do_not_clear = True
+            tensor._do_not_clear = True

         # Unflatten list of parameters and extra tensor inputs
-        extra_inputs = params_and_extra_inputs[-fuser._num_extra_inputs :]
+        extra_inputs = params_and_extra_inputs[-fuser.num_extra_inputs :]
         basic_op_extra_inputs = []
         for op in fuser._basic_ops:
             xs, extra_inputs = _split_tuple(extra_inputs, op.num_extra_inputs)
@@ -114,44 +114,37 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):

         # Apply forward ops
         x = input_
-        requires_grad = is_grad_enabled and x.requires_grad
-        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
         extra_outputs = [None] * fuser._num_basic_ops
         for op, basic_op_idxs in fuser._forward_ops:

-            # Check if backward op is required
-            if is_grad_enabled:
-                if not requires_grad:
-                    requires_grad = any(param.requires_grad for param in op.parameters())
-                if not requires_grad:
-                    requires_grad = any(any(x.requires_grad for x in xs) for xs in extra_inputs)
+            # Set if backward op is required
             for idx in basic_op_idxs:
-                basic_op_ctxs[idx].requires_grad = requires_grad
+                basic_op_ctxs[idx].requires_grad = idx >= fuser.first_op_requiring_backward

             # Forward op
             extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs]
             prev_op_idx = basic_op_idxs[0] - 1
             prev_op = fuser._basic_ops[prev_op_idx] if prev_op_idx >= 0 else None
-            prev_op_grad_input_quantizer = None
-            if prev_op is not None and with_quantized_compute:
-                prev_op_grad_input_quantizer = prev_op.get_grad_input_quantizer()
+            prev_op_grad_output_quantizer = None
+            if prev_op is not None:
+                prev_op_grad_output_quantizer = prev_op.get_grad_output_quantizer()
             next_op_idx = basic_op_idxs[-1] + 1
             next_op = fuser._basic_ops[next_op_idx] if next_op_idx < fuser._num_basic_ops else None
             next_op_input_quantizer = None
-            if next_op is not None and with_quantized_compute:
+            if next_op is not None:
                 next_op_input_quantizer = next_op.get_input_quantizer()
             x, fused_op_extra_outputs = op.fuser_forward(
                 [basic_op_ctxs[idx] for idx in basic_op_idxs],
                 x,
                 basic_op_extra_inputs=extra_inputs,
-                prev_op_grad_input_quantizer=prev_op_grad_input_quantizer,
+                prev_op_grad_output_quantizer=prev_op_grad_output_quantizer,
                 next_op_input_quantizer=next_op_input_quantizer,
                 basic_op_kwargs=[basic_op_kwargs[idx] for idx in basic_op_idxs],
             )
             for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs):
                 for y in ys:
-                    y.requires_grad_(requires_grad)
+                    y.requires_grad_(idx >= fuser.first_op_requiring_backward)
                 extra_outputs[idx] = ys

         # Flatten list of extra outputs
@@ -168,7 +161,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
                 extra_outputs_flat.extend(ys)

         # Save context for backward pass
-        if is_grad_enabled:
+        if func_ctx is not None:

             # Flatten list of saved tensors
             to_save = []
@@ -181,24 +174,29 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
                 ctx._saved_tensors_range = (range_start, range_end)

             # Save tensors for backward
-            if with_quantized_compute:
-                tensors_to_save, tensor_objects = prepare_for_saving(*to_save)
-                func_ctx.save_for_backward(*tensors_to_save)
-                func_ctx.tensor_objects = tensor_objects
-            else:
-                func_ctx.save_for_backward(*to_save)
+            tensors_to_save, tensor_objects = prepare_for_saving(*to_save)
+            func_ctx.save_for_backward(*tensors_to_save)
+            func_ctx.tensor_objects = tensor_objects
+
+            # Whether to perform recipe update in backward pass
+            is_first_module = False
+            if fuser.first_op_requiring_backward < fuser._num_basic_ops:
+                is_first_module = FP8GlobalStateManager.is_first_fp8_module()

             # Other context
             func_ctx.backward_ops = fuser._backward_ops
             func_ctx.basic_ops = fuser._basic_ops
             func_ctx.basic_op_ctxs = basic_op_ctxs
-            func_ctx.basic_op_num_params = fuser._num_list_basic_op_params
-            func_ctx.num_extra_inputs = fuser._num_extra_inputs
+            func_ctx.basic_op_num_params = fuser._basic_op_num_params
+            func_ctx.num_extra_inputs = fuser.num_extra_inputs
             func_ctx.num_extra_outputs = len(extra_outputs_flat)
-            func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
-            func_ctx.with_quantized_compute = with_quantized_compute
+            func_ctx.is_first_module = is_first_module

-        x.requires_grad_(requires_grad)
+        # Mark output tensors as not deletable in backward
+        for tensor in [x] + extra_outputs_flat:
+            tensor._do_not_clear = True
+
+        x.requires_grad_(fuser.first_op_requiring_backward < fuser._num_basic_ops)
         if extra_outputs_flat:
             return x, *extra_outputs_flat
@@ -220,10 +218,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
         basic_op_ctxs = func_ctx.basic_op_ctxs

         # Restore saved tensors
-        if func_ctx.with_quantized_compute:
-            saved_tensors = restore_from_saved(func_ctx.tensor_objects, func_ctx.saved_tensors)
-        else:
-            saved_tensors = func_ctx.saved_tensors
+        saved_tensors = restore_from_saved(func_ctx.tensor_objects, func_ctx.saved_tensors)

         # Unflatten list of saved tensors
         for ctx in basic_op_ctxs:
@@ -304,7 +299,6 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
             dx,  # input_
             None,  # fuser
             None,  # basic_op_kwargs
-            None,  # is_grad_enabled
             *grad_params_flat,
             *grad_extra_inputs_flat,
         )
@@ -317,19 +311,12 @@ class OperationFuser:
     ----------
     ops: list of FusibleOperation
         Pipeline of operations
-    fuse_ops: bool
-        Whether to attempt fusing operations
-    recipe: Recipe, optional
-        Quantization recipe to use when fusing and executing operations.
-        Note: certain fusions may depend on what kind of recipe is being used.

     """

     def __init__(
         self,
         ops: list[FusibleOperation],
-        fuse_ops: bool,
-        recipe: Optional[Recipe],
     ) -> None:

         # Get list of basic operations
@@ -343,25 +330,22 @@ class OperationFuser:
         self._basic_ops: list[BasicOperation] = basic_ops

         # Number of extra tensor inputs
-        self._num_extra_inputs: int = sum(op.num_extra_inputs for op in basic_ops)
+        self._basic_op_num_extra_inputs: list[int] = list(op.num_extra_inputs for op in basic_ops)
+        self.num_extra_inputs: int = sum(self._basic_op_num_extra_inputs)

-        # Ops for forward and backward pass
+        # Ops for forward and backward pass, will be populated in fuse_ops
         self._forward_ops: list[tuple[FusibleOperation, list[int]]]
         self._backward_ops: list[tuple[FusibleOperation, list[int]]]
-        self._forward_ops = [(op, (idx,)) for idx, op in enumerate(self._basic_ops)]
-        self._backward_ops = list(reversed(self._forward_ops))

-        # Flag for checking if this is the first iteration
-        self._is_first_forward = True
-
-        # Fuse ops if needed
-        self.recipe = recipe
-        if fuse_ops:
-            self.fuse_ops()
+        # Cache and detect change of state relevant for fusing operations
+        self.recipe_type = None
+        self.first_op_requiring_backward = 0
+        self._last_amax_history_len = 0

         # Flatten list of parameters
-        self._basic_op_params = [param for op in self._basic_ops for param in op.parameters()]
-        self._num_list_basic_op_params = [sum(1 for _ in op.parameters()) for op in self._basic_ops]
+        self._basic_op_params = [list(op.parameters()) for op in self._basic_ops]
+        self._basic_op_num_params = list(map(len, self._basic_op_params))
+        self._flat_basic_op_params = sum(self._basic_op_params, [])

     @classmethod
     def _fuse_forward_ops(
@@ -373,6 +357,7 @@ class OperationFuser:
         ops = fuse_userbuffers_forward_linear(ops)
         ops = fuse_forward_linear_bias_add(ops)
         ops = fuse_forward_linear_bias_activation(ops)
+        ops = fuse_forward_linear_scale_add(ops)
         return ops

     @classmethod
@@ -384,13 +369,74 @@ class OperationFuser:
         """Attempt to fuse operations in backward pass"""
         ops = fuse_userbuffers_backward_linear(ops)
         ops = fuse_backward_linear_add(ops)
-        ops = fuse_backward_bias_activation(ops, recipe)
+        ops = fuse_backward_linear_scale(ops)
+        ops = fuse_backward_activation_bias(ops, recipe)
         return ops

-    def fuse_ops(self) -> None:
-        """Attempt to fuse operations"""
-        self._forward_ops = self._fuse_forward_ops(self._forward_ops, self.recipe)
-        self._backward_ops = self._fuse_backward_ops(self._backward_ops, self.recipe)
+    def maybe_fuse_ops(
+        self,
+        is_grad_enabled: bool,
+        recipe: Optional[Recipe],
+        input_: torch.Tensor,
+        extra_inputs: list[Iterable[torch.Tensor]],
+    ):
+        """Attempt to fuse operations if necessary"""
+
+        # Determine which basic ops require backward
+        if not is_grad_enabled:
+            first_op_requiring_backward = self._num_basic_ops
+        elif input_.requires_grad:
+            first_op_requiring_backward = 0
+        else:
+            first_op_requiring_backward = self._num_basic_ops
+            for op_idx in range(self._num_basic_ops):
+                op_inputs = itertools.chain(self._basic_op_params[op_idx], extra_inputs[op_idx])
+                if any(tensor.requires_grad for tensor in op_inputs):
+                    first_op_requiring_backward = op_idx
+                    break
+
+        # Early exit if fusion parameters haven't changed
+        need_reset = False
+        recipe_type = type(recipe)
+        fusion_params = (recipe_type, first_op_requiring_backward)
+        if fusion_params != (self.recipe_type, self.first_op_requiring_backward):
+            # Recipe type or grad requirements have changed
+            need_reset = True
+        elif (
+            recipe is not None
+            and recipe.delayed()
+            and self._last_amax_history_len != recipe.amax_history_len
+        ):
+            # FP8 delayed scaling has changed amax history length
+            need_reset = True
+        if not need_reset:
+            return
+
+        # Reset recipe state
+        for op in self._basic_ops:
+            op.reset_recipe_state(recipe=recipe)
+
+        # Check if this is the first iteration
+        if self.recipe_type is None:
+            for op in self._basic_ops:
+                op.pre_first_fuser_forward()
+
+        # Prepare basic op lists for fusions
+        forward_ops = [(op, [idx]) for idx, op in enumerate(self._basic_ops)]
+        backward_ops = list(reversed(forward_ops[first_op_requiring_backward:]))
+
+        # Fuse ops
+        self._forward_ops = self._fuse_forward_ops(forward_ops, recipe)
+        self._backward_ops = self._fuse_backward_ops(backward_ops, recipe)
+
+        # Save current fusion params
+        self.recipe_type, self.first_op_requiring_backward = fusion_params
+
+        # Save amax history length
+        if isinstance(recipe, DelayedScaling):
+            self._last_amax_history_len = recipe.amax_history_len
+        else:
+            self._last_amax_history_len = 0
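The cache key is just the recipe class plus the index of the first op that needs backward; re-fusion happens only when that key changes, or when a delayed-scaling recipe resizes its amax history. A standalone illustration of that rule (names local to this sketch):

    def needs_refusion(cache, recipe, first_op_requiring_backward):
        # Re-fuse when the recipe type or the backward boundary changes.
        key = (type(recipe), first_op_requiring_backward)
        if key != (cache.get("recipe_type"), cache.get("first_op")):
            return True
        # Re-fuse when FP8 delayed scaling changes its amax history length.
        amax_len = getattr(recipe, "amax_history_len", None)
        if amax_len is not None and cache.get("amax_history_len") != amax_len:
            return True
        return False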
     def __call__(
         self,
@@ -399,23 +445,32 @@ class OperationFuser:
         basic_op_kwargs: Optional[list[dict[str, Any]]] = None,
     ) -> torch.Tensor | tuple[torch.Tensor, ...]:

         # Verify extra input count
-        if len(extra_inputs) != self._num_extra_inputs:
+        if len(extra_inputs) != self.num_extra_inputs:
             raise ValueError(
-                f"Expected {self._num_extra_inputs} extra inputs but got {len(extra_inputs)}"
+                f"Expected {self.num_extra_inputs} extra inputs but got {len(extra_inputs)}"
             )

-        # Initialization before forward pass
-        if self._is_first_forward:
-            for op in self._basic_ops:
-                op.pre_first_forward(recipe=self.recipe)
-            self._is_first_forward = False
-
         # Canonicalize op kwargs
         if basic_op_kwargs is None:
             basic_op_kwargs = [{}] * self._num_basic_ops

-        # Fuser forward pass
+        # Unflatten list of extra tensor inputs
+        extra_inputs_copy = list(extra_inputs)
+        basic_op_extra_inputs = []
+        for op in self._basic_ops:
+            xs, extra_inputs_copy = _split_tuple(extra_inputs_copy, op.num_extra_inputs)
+            basic_op_extra_inputs.append(xs)
+
+        # Get environment state
+        recipe = None
+        if FP8GlobalStateManager.is_fp8_enabled():
+            recipe = FP8GlobalStateManager.get_fp8_recipe()
         is_grad_enabled = torch.is_grad_enabled()
+
+        # Attempt to fuse operations if necessary
+        self.maybe_fuse_ops(is_grad_enabled, recipe, input, basic_op_extra_inputs)
+
+        # Fuser forward pass
         if is_grad_enabled:
             forward_func = _OperationFuserAutogradFunction.apply
             args = []
@@ -426,8 +481,7 @@ class OperationFuser:
                 input,
                 self,
                 basic_op_kwargs,
-                is_grad_enabled,
-                *self._basic_op_params,
+                *self._flat_basic_op_params,
                 *extra_inputs,
             )
         return forward_func(*args)
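The per-op unflattening above just peels op.num_extra_inputs tensors off the flat tuple for each basic op in order; a tiny standalone illustration (with a local stand-in for _split_tuple):

    def split_tuple(t, n):
        return t[:n], t[n:]

    flat_extra_inputs = ("a", "b", "c")
    per_op = []
    for num_extra in (2, 0, 1):  # op.num_extra_inputs for three hypothetical ops
        xs, flat_extra_inputs = split_tuple(flat_extra_inputs, num_extra)
        per_op.append(xs)
    # per_op == [("a", "b"), (), ("c",)]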
@@ -6,7 +6,7 @@
 from __future__ import annotations
 from collections.abc import Callable
-from typing import Optional
+from typing import Any, Optional

 import torch
@@ -91,6 +91,8 @@ class Linear(FusedOperation):

         # Construct basic ops
         ops = []
+        linear_idx = None
+        bias_idx = None
         linear_kwargs = {
             "in_features": in_features,
             "out_features": out_features,
@@ -111,14 +113,16 @@ class Linear(FusedOperation):
         }
         if tensor_parallel_mode == "row":
             # Row TP: GEMM + bias + reduction
+            linear_idx = len(ops)
             linear_kwargs["in_features"] = local_in_features
             linear_kwargs["out_features"] = local_out_features
             linear_kwargs["tensor_parallel_mode"] = None
             linear_kwargs["tensor_parallel_group"] = None
             linear_kwargs["sequence_parallel"] = False
-            bias_kwargs["size"] *= tensor_parallel_size
             ops.append(BasicLinear(**linear_kwargs))
             if bias:
+                bias_idx = len(ops)
+                bias_kwargs["size"] *= tensor_parallel_size
                 ops.append(Bias(**bias_kwargs))
             if sequence_parallel:
                 ops.append(ReduceScatter(tensor_parallel_group))
@@ -126,45 +130,81 @@ class Linear(FusedOperation):
                 ops.append(AllReduce(tensor_parallel_group))
         else:
             # Column TP or no TP: (gather + GEMM) + bias
+            linear_idx = len(ops)
             ops.append(BasicLinear(**linear_kwargs))
             if bias:
+                bias_idx = len(ops)
                 ops.append(Bias(**bias_kwargs))

         # Initialize base class
         super().__init__(ops)
-        self._has_bias: bool = bias

-    @property
-    def weight(self) -> torch.nn.Parameter:
-        """Weight tensor
-
-        Parameter is owned by `BasicLinear` operation.
-
-        """
-        return self.basic_ops[0].weight
-
-    @weight.setter
-    def weight(self, value: Optional[torch.nn.Parameter]) -> None:
-        self.basic_ops[0].weight = value
-
-    @property
-    def bias(self) -> Optional[torch.nn.Parameter]:
-        """Bias tensor
-
-        Parameter is owned by `Bias` operation.
-
-        """
-        if self._has_bias:
-            return self.basic_ops[1].bias
-        return None
-
-    @bias.setter
-    def bias(self, value: Optional[torch.nn.Parameter]) -> None:
-        if self._has_bias:
-            self.basic_ops[1].bias = value
-        elif value is not None:
+        # Register parameters
+        self._linear_idx: Optional[int] = linear_idx
+        self._bias_idx: Optional[int] = bias_idx
+        self.register_parameter("weight", self.basic_ops[self._linear_idx].weight)
+        bias = None
+        if self._bias_idx is not None:
+            bias = self.basic_ops[self._bias_idx].bias
+        self.register_parameter("bias", bias)
+
+    def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None:
+        """Add a parameter to the module
+
+        Also updates the basic operation that owns the parameter.
+
+        """
+        if name == "bias" and self._bias_idx is None and param is not None:
             raise ValueError(
                 "Attempted to set bias parameter in Linear operation "
                 "that does not have bias enabled"
             )
+        super().register_parameter(name, param)
+        if name == "weight":
+            self.basic_ops[self._linear_idx].weight = param
+        elif name == "bias" and self._bias_idx is not None:
+            self.basic_ops[self._bias_idx].bias = param
+
+    def state_dict(self, *, prefix: str = "", **kwargs) -> dict[str, Any]:
+        """Save state"""
+        state_dict = super().state_dict(prefix=prefix, **kwargs)
+
+        # Remove basic op params from state dict
+        # Note: Logically, basic ops own params and fused ops are
+        # considered as stateless. However, we register weight and
+        # bias params in the linear op for convenience. We remove
+        # these redundant params from the checkpoint for backward
+        # compatibility.
+        if f"{prefix}weight" in state_dict:
+            del state_dict[f"{prefix}weight"]
+        if f"{prefix}bias" in state_dict:
+            del state_dict[f"{prefix}bias"]
+
+        return state_dict
+
+    def _load_from_state_dict(
+        self,
+        state_dict: dict[str, Any],
+        prefix: str,
+        *args,
+        **kwargs,
+    ) -> None:
+
+        # Add basic op params to state dict
+        # Note: Logically, basic ops own params and fused ops are
+        # considered as stateless. However, we register weight and
+        # bias params in the linear op for convenience. We remove
+        # these redundant params from the checkpoint for backward
+        # compatibility.
+        if f"{prefix}weight" not in state_dict:
+            state_dict[f"{prefix}weight"] = state_dict[
+                f"{prefix}basic_ops.{self._linear_idx}.weight"
+            ]
+        if f"{prefix}bias" not in state_dict:
+            if self._bias_idx is None:
+                state_dict[f"{prefix}bias"] = None
+            else:
+                state_dict[f"{prefix}bias"] = state_dict[f"{prefix}basic_ops.{self._bias_idx}.bias"]
+
+        # Load state dict
+        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
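These hooks keep checkpoints in the pre-existing layout: saving drops the aliased "weight"/"bias" entries so only the basic-op entries (e.g. "basic_ops.0.weight") remain, and loading re-derives the aliases from those entries. A small round-trip sketch (assuming the fusible-ops Linear defined here):

    import transformer_engine.pytorch as te

    layer = te.ops.Linear(1024, 1024, bias=True)

    sd = layer.state_dict()
    # The aliased copies are stripped; only basic-op entries are stored.
    assert "weight" not in sd and "bias" not in sd
    assert any(key.startswith("basic_ops.") for key in sd)

    # Loading restores the aliases from the basic-op entries, so checkpoints
    # written before this change load unchanged.
    layer.load_state_dict(sd)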
@@ -15,9 +15,6 @@ import torch
 from transformer_engine.common.recipe import Recipe

 from ..fp8 import (
-    MXFP8BlockScalingRecipeState,
-    DelayedScalingRecipeState,
-    Float8BlockScalingRecipeState,
     FP8GlobalStateManager,
     RecipeState,
     fp8_autocast,
...@@ -65,18 +62,14 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta): ...@@ -65,18 +62,14 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
def is_fused_op(self) -> bool: def is_fused_op(self) -> bool:
"""Whether this op is the fusion of one or more basic ops""" """Whether this op is the fusion of one or more basic ops"""
def pre_first_forward( def pre_first_fuser_forward(self) -> None:
self, """Preprocessing before first fuser forward pass"""
*,
recipe: Optional[Recipe],
) -> None:
"""Preprocessing before forward pass"""
def get_input_quantizer(self) -> Optional[Quantizer]: def get_input_quantizer(self) -> Optional[Quantizer]:
"""Get builder class for quantized input tensor""" """Get builder class for quantized input tensor"""
def get_grad_input_quantizer(self) -> Optional[Quantizer]: def get_grad_output_quantizer(self) -> Optional[Quantizer]:
"""Get builder class for quantized input's grad tensor""" """Get builder class for quantized output's grad tensor"""
def fuser_forward( def fuser_forward(
self, self,
...@@ -84,7 +77,7 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta): ...@@ -84,7 +77,7 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
input_: torch.Tensor, input_: torch.Tensor,
*, *,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
prev_op_grad_input_quantizer: Optional[Quantizer], prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer],
basic_op_kwargs: list[dict[str, Any]], basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
...@@ -104,8 +97,8 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta): ...@@ -104,8 +97,8 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
Input tensor Input tensor
basic_op_extra_inputs: list of torch.Tensor basic_op_extra_inputs: list of torch.Tensor
Extra tensor inputs to basic operations Extra tensor inputs to basic operations
prev_op_grad_input_quantizer: Quantizer, optional prev_op_grad_output_quantizer: Quantizer, optional
The grad_input_quantizer of the preceeding operation The grad_output_quantizer of the preceeding operation
next_op_input_quantizer: Quantizer, optional next_op_input_quantizer: Quantizer, optional
The input_quantizer of the following operation The input_quantizer of the following operation
basic_op_kwargs: list of dict basic_op_kwargs: list of dict
...@@ -186,8 +179,11 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -186,8 +179,11 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
super().__init__() super().__init__()
# Objects for quantization # Objects for quantization
self._quantizers: Optional[dict[str, list[Quantizer]]] = None
self._fp8_metas: Optional[dict[str, dict[str, Any]]] = None self._fp8_metas: Optional[dict[str, dict[str, Any]]] = None
self._quantizers: Optional[dict[str, list[Quantizer]]] = None
with_fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
recipe = FP8GlobalStateManager.get_fp8_recipe() if with_fp8_parameters else None
self.reset_recipe_state(recipe=recipe)
@property @property
def is_fused_op(self) -> bool: def is_fused_op(self) -> bool:
...@@ -214,120 +210,141 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -214,120 +210,141 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
return self.get_quantizer("forward", 0) return self.get_quantizer("forward", 0)
return None return None
def get_grad_input_quantizer(self) -> Optional[Quantizer]: def get_grad_output_quantizer(self) -> Optional[Quantizer]:
if self.num_quantizers("backward") > 0: if self.num_quantizers("backward") > 0:
return self.get_quantizer("backward", 0) return self.get_quantizer("backward", 0)
return None return None
def _reset_quantization_recipe_state( def reset_recipe_state(
self, self,
*, *,
recipe: Recipe, recipe: Optional[Recipe],
) -> None: ) -> None:
"""Construct state for quantization recipe""" """Construct state for quantization recipe"""
# Quantization recipe state for forward and backward pass # Clear quantization state if necessary
self._fp8_metas = {"forward": None, "backward": None} if recipe is None:
self._quantizers = {"forward": [], "backward": []} self._fp8_metas = None
for mode in ("forward", "backward"): self._quantizers = None
num_quantizers = self.num_quantizers(mode) return
if num_quantizers == 0:
continue
if recipe.float8_block_scaling():
raise NotImplementedError(
"Fusible operations do not support FP8 block scaling recipe"
)
# Construct quantization recipe state
recipe_state = RecipeState.create(
recipe,
mode=mode,
num_quantizers=num_quantizers,
)
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=(mode == "forward"),
)
self._fp8_metas[mode] = {
fp8_meta_key: recipe_state,
"recipe": recipe,
"fp8_group": FP8GlobalStateManager.get_fp8_group(),
}
# Construct builder class for quantized tensors
self._quantizers[mode] = recipe_state.make_quantizers()
def _update_quantization_recipe_state( # Communication group for FP8 amax reductions
self, fp8_group = FP8GlobalStateManager.get_fp8_group()
*,
recipe: Recipe,
) -> None:
"""Make sure quantizer state matches quantization recipe"""
# Reset quantization state if needed # Skip resetting recipe type if it did not actually change.
# This could happen for example if calling BasicOperation.forward directly, as in that
# case, the OperationFuser is not persistent, or when loading from a checkpoint
need_to_reset_recipe_state = False
if self._fp8_metas is None or self._quantizers is None: if self._fp8_metas is None or self._quantizers is None:
self._reset_quantization_recipe_state(recipe=recipe) need_to_reset_recipe_state = True
return else:
for mode in ("forward", "backward"): for mode in ("forward", "backward"):
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=(mode == "forward"), forward=(mode == "forward"),
)
if self._fp8_metas[mode] is None or fp8_meta_key not in self._fp8_metas[mode]:
continue
recipe_state = self._fp8_metas[mode][fp8_meta_key]
need_to_reset_recipe_state = (
(recipe.delayed() and not isinstance(recipe_state, DelayedScalingRecipeState))
or (recipe.mxfp8() and not isinstance(recipe_state, MXFP8BlockScalingRecipeState))
or (
recipe.float8_block_scaling()
and not isinstance(recipe_state, Float8BlockScalingRecipeState)
) )
) if self._fp8_metas[mode] is None or fp8_meta_key not in self._fp8_metas[mode]:
if need_to_reset_recipe_state: continue
self._reset_quantization_recipe_state(recipe=recipe) recipe_state = self._fp8_metas[mode][fp8_meta_key]
return if not isinstance(recipe, type(recipe_state.recipe)):
need_to_reset_recipe_state = True
break
if need_to_reset_recipe_state:
# Construct quantization recipe states
self._fp8_metas = {"forward": None, "backward": None}
self._quantizers = {"forward": [], "backward": []}
for mode in ("forward", "backward"):
num_quantizers = self.num_quantizers(mode)
if num_quantizers == 0:
continue
# Quantization recipe state for forward and backward pass if recipe.float8_block_scaling():
for mode in ("forward", "backward"): raise NotImplementedError(
num_quantizers = self.num_quantizers(mode) "Fusible operations do not support FP8 block scaling recipe"
if num_quantizers == 0: )
continue
# Update FP8 metadata # Construct quantization recipe state
fp8_meta = self._fp8_metas[mode] recipe_state = RecipeState.create(
fp8_meta["recipe"] = recipe recipe,
fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() mode=mode,
num_quantizers=num_quantizers,
)
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=(mode == "forward"),
)
self._fp8_metas[mode] = {
fp8_meta_key: recipe_state,
"recipe": recipe,
"fp8_group": fp8_group,
}
# Get recipe state # Construct builder class for quantized tensors
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( self._quantizers[mode] = recipe_state.make_quantizers()
forward=(mode == "forward"), else:
) # Update quantization recipe states
recipe_state = fp8_meta[fp8_meta_key] for mode in ("forward", "backward"):
if self._fp8_metas[mode] is None:
continue
self._fp8_metas[mode]["recipe"] = recipe
self._fp8_metas[mode]["fp8_group"] = fp8_group
# Reallocate amax history if needed # Update amax history for FP8 delayed scaling
if not recipe.delayed(): if recipe.delayed():
continue fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=(mode == "forward"),
)
recipe_state = self._fp8_metas[mode][fp8_meta_key]
current_length = recipe_state.amax_history.size(0) # Reallocate amax history if needed
target_length = recipe.amax_history_len current_length = recipe_state.amax_history.size(0)
if current_length != target_length: target_length = recipe.amax_history_len
with torch.no_grad():
if target_length < current_length: if target_length < current_length:
recipe_state.amax_history = recipe_state.amax_history[ with torch.no_grad():
:target_length recipe_state.amax_history = recipe_state.amax_history[
].clone() :target_length
else: ].clone()
recipe_state.amax_history = torch.nn.functional.pad( elif target_length > current_length:
recipe_state.amax_history, with torch.no_grad():
pad=(0, 0, 0, target_length - current_length), recipe_state.amax_history = torch.nn.functional.pad(
) recipe_state.amax_history,
self._quantizers[mode] = recipe_state.make_quantizers() pad=(0, 0, 0, target_length - current_length),
)
# Update quantizers with new amax pointers
self._quantizers[mode] = recipe_state.make_quantizers()
# Update the global buffers with new amax pointers
if FP8GlobalStateManager.get_buffer_info() in self._fp8_metas[mode]:
pos, buffer_key = self._fp8_metas[mode][
FP8GlobalStateManager.get_buffer_info()
]
if buffer_key in FP8GlobalStateManager.global_amax_buffer:
assert (
buffer_key in FP8GlobalStateManager.global_amax_history_buffer
), "TE internal error during amax history change."
FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = (
recipe_state.amax_history[0]
)
FP8GlobalStateManager.global_amax_history_buffer[buffer_key][
pos
] = recipe_state.amax_history
# Add meta tensors to global buffer to participate in reduction
for mode in ("forward", "backward"):
if (
FP8GlobalStateManager.is_fp8_enabled()
and self.num_quantizers(mode)
and not FP8GlobalStateManager.fp8_graph_capturing()
):
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
self._fp8_metas[mode],
)
def get_quantizer( def get_quantizer(
self, self,
mode: str, mode: str,
index: int, index: int,
) -> Quantizer: ) -> Optional[Quantizer]:
"""Get builder class for quantized tensor """Get builder class for quantized tensor
Parameters Parameters
...@@ -337,7 +354,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -337,7 +354,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
""" """
if self._quantizers is None: if self._quantizers is None:
self._reset_quantization_recipe_state(recipe=FP8GlobalStateManager.get_fp8_recipe()) return None
return self._quantizers[mode][index] return self._quantizers[mode][index]
@torch.no_grad() @torch.no_grad()
...@@ -388,33 +405,13 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -388,33 +405,13 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
self._fp8_metas[mode][fp8_meta_key].scale.copy_(scale) self._fp8_metas[mode][fp8_meta_key].scale.copy_(scale)
self._fp8_metas[mode][fp8_meta_key].amax_history.copy_(amax_history) self._fp8_metas[mode][fp8_meta_key].amax_history.copy_(amax_history)
def pre_first_forward(
self,
*,
recipe: Optional[Recipe],
) -> None:
"""Preprocessing before forward pass"""
# Initialize FP8 metadata if needed
if recipe is not None:
self._update_quantization_recipe_state(recipe=recipe)
if not FP8GlobalStateManager.fp8_graph_capturing():
if self.num_quantizers("forward"):
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
self._fp8_metas["forward"],
)
if self.num_quantizers("backward"):
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
self._fp8_metas["backward"],
)
@abc.abstractmethod @abc.abstractmethod
def op_forward( def op_forward(
self, self,
ctx: OperationContext, ctx: OperationContext,
input_: torch.Tensor, input_: torch.Tensor,
*, *,
prev_op_grad_input_quantizer: Optional[Quantizer], prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer],
**kwargs: Any, **kwargs: Any,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -426,8 +423,8 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -426,8 +423,8 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
Context to coordinate between forward and backward passes Context to coordinate between forward and backward passes
input_: torch.Tensor input_: torch.Tensor
Input tensor Input tensor
prev_op_grad_input_quantizer: Quantizer, optional prev_op_grad_output_quantizer: Quantizer, optional
The grad_input_quantizer of the preceeding operation The grad_output_quantizer of the preceeding operation
next_op_input_quantizer: Quantizer, optional next_op_input_quantizer: Quantizer, optional
The input_quantizer of the following operation The input_quantizer of the following operation
...@@ -468,7 +465,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -468,7 +465,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
input_: torch.Tensor, input_: torch.Tensor,
*, *,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
prev_op_grad_input_quantizer: Optional[Quantizer], prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer],
basic_op_kwargs: list[dict[str, Any]], basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, list[tuple[()]]]: ) -> tuple[torch.Tensor, list[tuple[()]]]:
...@@ -482,7 +479,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -482,7 +479,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
output = self.op_forward( output = self.op_forward(
basic_op_ctxs[0], basic_op_ctxs[0],
input_, input_,
prev_op_grad_input_quantizer=prev_op_grad_input_quantizer, prev_op_grad_output_quantizer=prev_op_grad_output_quantizer,
next_op_input_quantizer=next_op_input_quantizer, next_op_input_quantizer=next_op_input_quantizer,
**basic_op_kwargs[0], **basic_op_kwargs[0],
) )
...@@ -518,9 +515,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -518,9 +515,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
"""Apply operation""" """Apply operation"""
from .fuser import OperationFuser from .fuser import OperationFuser
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() return OperationFuser([self])(
recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
return OperationFuser([self], fuse_ops=False, recipe=recipe)(
input, input,
*extra_inputs, *extra_inputs,
basic_op_kwargs=[kwargs], basic_op_kwargs=[kwargs],
...@@ -630,7 +625,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -630,7 +625,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
# Get op's quantizer state, initializing if needed # Get op's quantizer state, initializing if needed
if self._fp8_metas is None or self._fp8_metas[mode] is None: if self._fp8_metas is None or self._fp8_metas[mode] is None:
with fp8_autocast(fp8_recipe=state[mode]["recipe"]): with fp8_autocast(fp8_recipe=state[mode]["recipe"]):
self._reset_quantization_recipe_state(recipe=state[mode]["recipe"]) self.reset_recipe_state(recipe=state[mode]["recipe"])
fp8_meta = self._fp8_metas[mode] fp8_meta = self._fp8_metas[mode]
# Load extra items # Load extra items
...@@ -708,13 +703,12 @@ class FusedOperation(FusibleOperation): ...@@ -708,13 +703,12 @@ class FusedOperation(FusibleOperation):
def get_input_quantizer(self) -> Optional[Quantizer]: def get_input_quantizer(self) -> Optional[Quantizer]:
return self.basic_ops[0].get_input_quantizer() return self.basic_ops[0].get_input_quantizer()
def get_grad_input_quantizer(self) -> Optional[Quantizer]: def get_grad_output_quantizer(self) -> Optional[Quantizer]:
return self.basic_ops[-1].get_grad_input_quantizer() return self.basic_ops[-1].get_grad_output_quantizer()
def pre_first_forward(self, *args, **kwargs) -> None: def pre_first_fuser_forward(self) -> None:
"""Preprocessing before forward pass"""
for op in self.basic_ops: for op in self.basic_ops:
op.pre_first_forward(*args, **kwargs) op.pre_first_fuser_forward()
def forward( def forward(
self, self,
...@@ -727,9 +721,7 @@ class FusedOperation(FusibleOperation): ...@@ -727,9 +721,7 @@ class FusedOperation(FusibleOperation):
basic_op_kwargs = [{} for _ in range(len(self.basic_ops))] basic_op_kwargs = [{} for _ in range(len(self.basic_ops))]
from .fuser import OperationFuser from .fuser import OperationFuser
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() return OperationFuser([self])(
recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
return OperationFuser([self], fuse_ops=False, recipe=recipe)(
input, input,
*extra_inputs, *extra_inputs,
basic_op_kwargs=basic_op_kwargs, basic_op_kwargs=basic_op_kwargs,
......
@@ -10,7 +10,6 @@ from typing import Optional
 import torch

-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe
 from transformer_engine.pytorch.ops.op import FusibleOperation
 from transformer_engine.pytorch.ops.fuser import OperationFuser
...@@ -147,7 +146,6 @@ class Sequential(torch.nn.Module): ...@@ -147,7 +146,6 @@ class Sequential(torch.nn.Module):
def _make_module_groups( def _make_module_groups(
cls, cls,
modules: Iterable[torch.nn.Module], modules: Iterable[torch.nn.Module],
recipe: Optional[Recipe],
) -> list[OperationFuser | torch.nn.Module]: ) -> list[OperationFuser | torch.nn.Module]:
"""Make list of modules, with fusible operations grouped together""" """Make list of modules, with fusible operations grouped together"""
...@@ -162,24 +160,7 @@ class Sequential(torch.nn.Module): ...@@ -162,24 +160,7 @@ class Sequential(torch.nn.Module):
groups.append(module) groups.append(module)
for idx, group in enumerate(groups): for idx, group in enumerate(groups):
if isinstance(group, list): if isinstance(group, list):
groups[idx] = OperationFuser(group, fuse_ops=True, recipe=recipe) groups[idx] = OperationFuser(group)
# Check if operations expect extra input or output tensors
# Note: If any op has extra inputs or outputs, then the entire
# Sequential must be made up of TE ops.
if len(groups) > 1:
ops = []
for group in groups:
if isinstance(group, OperationFuser):
ops.extend(group._basic_ops)
num_extra_inputs = sum(op.num_extra_inputs for op in ops)
num_extra_outputs = sum(op.num_extra_outputs for op in ops)
if num_extra_inputs > 0 or num_extra_outputs > 0:
raise RuntimeError(
f"`Sequential` expects {num_extra_inputs} extra inputs "
f"and {num_extra_outputs} extra outputs, "
"but it contains non-fusible operations"
)
return groups return groups
...@@ -190,22 +171,28 @@ class Sequential(torch.nn.Module): ...@@ -190,22 +171,28 @@ class Sequential(torch.nn.Module):
) -> torch.Tensor | tuple[torch.Tensor, ...]: ) -> torch.Tensor | tuple[torch.Tensor, ...]:
"""Forward pass""" """Forward pass"""
# Get current global state
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
global_state = (with_quantized_compute, type(recipe))
# Reset module groups is global state changed
if self._last_global_state != global_state:
self._module_groups = None
self._last_global_state = global_state
# Create module groups if needed # Create module groups if needed
if self._module_groups is None: if self._module_groups is None:
self._module_groups = self._make_module_groups(self._modules.values(), recipe) self._module_groups = self._make_module_groups(self._modules.values())
# Forward pass for each module group # Forward pass for each module group
x = input x = input
extra_outputs: list[torch.Tensor] = []
for module_group in self._module_groups: for module_group in self._module_groups:
x = module_group(x, *extra_inputs) if isinstance(module_group, OperationFuser):
xs, extra_inputs = (
(x,) + extra_inputs[: module_group.num_extra_inputs],
extra_inputs[module_group.num_extra_inputs :],
)
xs = module_group(*xs)
if isinstance(xs, tuple):
x, ys = xs[0], xs[1:]
extra_outputs.extend(ys)
else:
x = xs
else:
x = module_group(x)
if extra_outputs:
return (x,) + tuple(extra_outputs)
return x return x
@@ -60,7 +60,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
             instance = super().__new__(cls, *args, **kwargs)
         instance._rowwise_data = rowwise_data
         instance._columnwise_data = columnwise_data
-        instance._quantizer = quantizer
+        instance._quantizer = quantizer.copy() if quantizer is not None else None
         instance._fp8_dtype = fp8_dtype
         instance._rowwise_scale_inv = rowwise_scale_inv
         instance._columnwise_scale_inv = columnwise_scale_inv
@@ -125,9 +125,15 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
         self._columnwise_scale_inv = tensors[3]
         return tensors[4:]

-    def get_data_tensors(self):
+    def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = True):
         """Get this Tensor's data."""
-        return self._rowwise_data, self._columnwise_data
+        if rowwise_data and columnwise_data:
+            return self._rowwise_data, self._columnwise_data
+        if rowwise_data:
+            return self._rowwise_data
+        if columnwise_data:
+            return self._columnwise_data
+        raise ValueError("No data to get, both rowwise_data and columnwise_data are False")

     def _transpose_dq_columnwise_output(self, columnwise_dq: torch.Tensor) -> torch.Tensor:
         """Takes dequantized columnwise data and permutes to a rowwise shape"""
...
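A brief usage sketch of the widened accessor (q stands for any already-constructed tensor of one of these quantized base classes; construction is elided):

    # Both buffers, as before (still the default):
    rowwise, columnwise = q.get_data_tensors()

    # Only one side, e.g. when the columnwise/transpose buffer was never allocated:
    rowwise_only = q.get_data_tensors(columnwise_data=False)
    columnwise_only = q.get_data_tensors(rowwise_data=False)

    # Requesting neither buffer raises ValueError.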
...@@ -86,7 +86,7 @@ class Float8TensorBase(QuantizedTensorBase): ...@@ -86,7 +86,7 @@ class Float8TensorBase(QuantizedTensorBase):
else: else:
instance = super().__new__(cls, *args, **kwargs) instance = super().__new__(cls, *args, **kwargs)
instance._data = data instance._data = data
instance._quantizer = quantizer instance._quantizer = quantizer.copy() if quantizer is not None else None
instance._fp8_dtype = fp8_dtype instance._fp8_dtype = fp8_dtype
instance._scale_inv = fp8_scale_inv instance._scale_inv = fp8_scale_inv
instance._transpose = data_transpose instance._transpose = data_transpose
...@@ -128,9 +128,15 @@ class Float8TensorBase(QuantizedTensorBase): ...@@ -128,9 +128,15 @@ class Float8TensorBase(QuantizedTensorBase):
self._scale_inv = tensors[2] self._scale_inv = tensors[2]
return tensors[3:] return tensors[3:]
def get_data_tensors(self): def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = True):
"""Get this Tensor's data.""" """Get this Tensor's data."""
return self._data, self._transpose if rowwise_data and columnwise_data:
return self._data, self._transpose
if rowwise_data:
return self._data
if columnwise_data:
return self._transpose
raise ValueError("No data to get, both rowwise_data and columnwise_data are False")
def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
"""Dequantize to a higher precision.""" """Dequantize to a higher precision."""
......
...@@ -83,7 +83,7 @@ class MXFP8TensorBase(QuantizedTensorBase): ...@@ -83,7 +83,7 @@ class MXFP8TensorBase(QuantizedTensorBase):
instance = super().__new__(cls, *args, **kwargs) instance = super().__new__(cls, *args, **kwargs)
instance._rowwise_data = rowwise_data instance._rowwise_data = rowwise_data
instance._columnwise_data = columnwise_data instance._columnwise_data = columnwise_data
instance._quantizer = quantizer instance._quantizer = quantizer.copy() if quantizer is not None else None
instance._fp8_dtype = fp8_dtype instance._fp8_dtype = fp8_dtype
instance._rowwise_scale_inv = rowwise_scale_inv instance._rowwise_scale_inv = rowwise_scale_inv
instance._columnwise_scale_inv = columnwise_scale_inv instance._columnwise_scale_inv = columnwise_scale_inv
...@@ -136,9 +136,15 @@ class MXFP8TensorBase(QuantizedTensorBase): ...@@ -136,9 +136,15 @@ class MXFP8TensorBase(QuantizedTensorBase):
self._columnwise_scale_inv = tensors[3] self._columnwise_scale_inv = tensors[3]
return tensors[4:] return tensors[4:]
def get_data_tensors(self): def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = True):
"""Get this Tensor's data.""" """Get this Tensor's data."""
return self._rowwise_data, self._columnwise_data if rowwise_data and columnwise_data:
return self._rowwise_data, self._columnwise_data
if rowwise_data:
return self._rowwise_data
if columnwise_data:
return self._columnwise_data
raise ValueError("No data to get, both rowwise_data and columnwise_data are False")
def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
"""Dequantize to a higher precision.""" """Dequantize to a higher precision."""
......
...@@ -524,7 +524,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor): ...@@ -524,7 +524,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
def _set_from_tensor(dst: Float8BlockwiseQTensor, src: Float8BlockwiseQTensor): def _set_from_tensor(dst: Float8BlockwiseQTensor, src: Float8BlockwiseQTensor):
dst._rowwise_data = src._rowwise_data dst._rowwise_data = src._rowwise_data
dst._columnwise_data = src._columnwise_data dst._columnwise_data = src._columnwise_data
dst._quantizer = src._quantizer dst._quantizer = src._quantizer.copy()
dst._fp8_dtype = src._fp8_dtype dst._fp8_dtype = src._fp8_dtype
dst._rowwise_scale_inv = src._rowwise_scale_inv dst._rowwise_scale_inv = src._rowwise_scale_inv
dst._columnwise_scale_inv = src._columnwise_scale_inv dst._columnwise_scale_inv = src._columnwise_scale_inv
......
...@@ -109,10 +109,9 @@ class Float8Quantizer(Quantizer): ...@@ -109,10 +109,9 @@ class Float8Quantizer(Quantizer):
# Allocate FP8 data transpose if needed # Allocate FP8 data transpose if needed
data_transpose = None data_transpose = None
if self.columnwise_usage: if self.columnwise_usage:
inner_dim = data.size(-1) transpose_shape = [data.size(-1)] + list(data.shape[:-1])
data_transpose = torch.empty( data_transpose = torch.empty(
inner_dim, transpose_shape,
data.numel() // inner_dim,
dtype=torch.uint8, dtype=torch.uint8,
device=device, device=device,
) )
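A small worked illustration of the new transpose allocation (shapes are assumed): the columnwise buffer now keeps the leading dimensions instead of flattening them, while holding the same number of elements.

import torch

data = torch.empty(4, 16, 128, dtype=torch.uint8)
# Old layout: (inner_dim, numel // inner_dim) -> (128, 64)
# New layout: [data.size(-1)] + list(data.shape[:-1]) -> [128, 4, 16]
transpose_shape = [data.size(-1)] + list(data.shape[:-1])
data_transpose = torch.empty(transpose_shape, dtype=torch.uint8)
assert data_transpose.numel() == data.numel()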
...@@ -186,6 +185,12 @@ class Float8Quantizer(Quantizer): ...@@ -186,6 +185,12 @@ class Float8Quantizer(Quantizer):
def _get_compatible_recipe(self) -> Union[type[Recipe], None]: def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
return DelayedScaling return DelayedScaling
def supports_only_rowwise_all_gather(self) -> bool:
"""
Float8Quantizer supports only rowwise all-gather
"""
return True
class Float8CurrentScalingQuantizer(Quantizer): class Float8CurrentScalingQuantizer(Quantizer):
"""Builder class for FP8 tensors with per-tensor current scaling """Builder class for FP8 tensors with per-tensor current scaling
...@@ -231,7 +236,7 @@ class Float8CurrentScalingQuantizer(Quantizer): ...@@ -231,7 +236,7 @@ class Float8CurrentScalingQuantizer(Quantizer):
amax_epsilon: float = 0.0, amax_epsilon: float = 0.0,
) -> None: ) -> None:
super().__init__(rowwise=rowwise, columnwise=columnwise) super().__init__(rowwise=rowwise, columnwise=columnwise)
self.scale = torch.ones(1, dtype=torch.float32, device=device) self.scale = torch.empty(1, dtype=torch.float32, device=device)
self.amax = torch.empty(1, dtype=torch.float32, device=device) self.amax = torch.empty(1, dtype=torch.float32, device=device)
self.dtype = tex.DType.kInt8 if int8_simulation_fp8_tensorwise else fp8_dtype self.dtype = tex.DType.kInt8 if int8_simulation_fp8_tensorwise else fp8_dtype
self.with_amax_reduction = with_amax_reduction self.with_amax_reduction = with_amax_reduction
...@@ -363,6 +368,12 @@ class Float8CurrentScalingQuantizer(Quantizer): ...@@ -363,6 +368,12 @@ class Float8CurrentScalingQuantizer(Quantizer):
def _get_compatible_recipe(self) -> Union[type[Recipe], None]: def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
return Float8CurrentScaling return Float8CurrentScaling
def supports_only_rowwise_all_gather(self) -> bool:
"""
Float8CurrentScalingQuantizer supports only rowwise all-gather
"""
return True
class Float8Tensor(Float8TensorBase, QuantizedTensor): class Float8Tensor(Float8TensorBase, QuantizedTensor):
"""Experimental tensor class with FP8 data """Experimental tensor class with FP8 data
...@@ -691,7 +702,7 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor): ...@@ -691,7 +702,7 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
# Float8Tensor attributes # Float8Tensor attributes
self._data = tensor._data self._data = tensor._data
self._quantizer = tensor._quantizer self._quantizer = tensor._quantizer.copy()
self._fp8_dtype = tensor._fp8_dtype self._fp8_dtype = tensor._fp8_dtype
self._scale_inv = tensor._scale_inv self._scale_inv = tensor._scale_inv
self._transpose = tensor._transpose self._transpose = tensor._transpose
......
...@@ -100,7 +100,7 @@ class MXFP8Quantizer(Quantizer): ...@@ -100,7 +100,7 @@ class MXFP8Quantizer(Quantizer):
# Allocate FP8 data # Allocate FP8 data
data = torch.empty(shape, dtype=torch.uint8, device=device) data = torch.empty(shape, dtype=torch.uint8, device=device)
scale_inv = torch.zeros( scale_inv = torch.empty(
round_up_to_nearest_multiple(math.prod(shape[:-1]), 128), round_up_to_nearest_multiple(math.prod(shape[:-1]), 128),
round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4), round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4),
dtype=torch.uint8, dtype=torch.uint8,
...@@ -112,7 +112,7 @@ class MXFP8Quantizer(Quantizer): ...@@ -112,7 +112,7 @@ class MXFP8Quantizer(Quantizer):
columnwise_scale_inv = None columnwise_scale_inv = None
if self.columnwise_usage: if self.columnwise_usage:
columnwise_data = torch.empty_like(data) columnwise_data = torch.empty_like(data)
columnwise_scale_inv = torch.zeros( columnwise_scale_inv = torch.empty(
round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4), round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4),
round_up_to_nearest_multiple(shape[-1], 128), round_up_to_nearest_multiple(shape[-1], 128),
dtype=torch.uint8, dtype=torch.uint8,
...@@ -433,7 +433,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor): ...@@ -433,7 +433,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
super(MXFP8Tensor, type(self)).data.__set__(self, dummy_tensor) super(MXFP8Tensor, type(self)).data.__set__(self, dummy_tensor)
self._rowwise_data = tensor._rowwise_data self._rowwise_data = tensor._rowwise_data
self._columnwise_data = tensor._columnwise_data self._columnwise_data = tensor._columnwise_data
self._quantizer = tensor._quantizer self._quantizer = tensor._quantizer.copy()
self._fp8_dtype = tensor._fp8_dtype self._fp8_dtype = tensor._fp8_dtype
self._rowwise_scale_inv = tensor._rowwise_scale_inv self._rowwise_scale_inv = tensor._rowwise_scale_inv
self._columnwise_scale_inv = tensor._columnwise_scale_inv self._columnwise_scale_inv = tensor._columnwise_scale_inv
......
...@@ -260,6 +260,10 @@ class Quantizer(abc.ABC): ...@@ -260,6 +260,10 @@ class Quantizer(abc.ABC):
def _get_compatible_recipe(self) -> Union[type[Recipe], None]: def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
"""Returns recipe class that is compatible with this quantizer""" """Returns recipe class that is compatible with this quantizer"""
def supports_only_rowwise_all_gather(self) -> bool:
"""Returns True if the quantizer supports only rowwise all-gather"""
return False
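A hedged sketch of how a caller might branch on the new capability flag; `quantizer` is an assumed instance of one of the quantizer classes above:

if quantizer.supports_only_rowwise_all_gather():
    # Per-tensor FP8 quantizers (delayed and current scaling) return True here,
    # so only the rowwise buffer is gathered and columnwise usage is disabled.
    quantizer.set_usage(rowwise=True, columnwise=False)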
class _QuantizeFunc(torch.autograd.Function): class _QuantizeFunc(torch.autograd.Function):
"""Cast to FP8 from other dtype""" """Cast to FP8 from other dtype"""
......
...@@ -236,14 +236,23 @@ class TransformerLayer(torch.nn.Module): ...@@ -236,14 +236,23 @@ class TransformerLayer(torch.nn.Module):
parameter for query-key-value. This enables optimizations such as QKV parameter for query-key-value. This enables optimizations such as QKV
fusion without concatenations/splits and also enables the argument fusion without concatenations/splits and also enables the argument
`fuse_wgrad_accumulation`. `fuse_wgrad_accumulation`.
use_qk_norm: bool, default = 'False' qk_norm_type: Optional[str], default = None
if set to `True`, L2 normalization is applied to query and key tensors type of normalization to apply to query and key tensors.
after RoPE (if applicable) but before attention computation. Options: None, 'L2Normalization', 'RMSNorm', 'LayerNorm'. When None, no normalization is applied.
This follows the Llama4 approach for QK normalization to improve When 'L2Normalization', L2 normalization is applied to query and key tensors.
training stability and model performance. When 'RMSNorm', RMS normalization is applied to query and key tensors.
When 'LayerNorm', layer normalization is applied to query and key tensors.
Normalization is applied after RoPE (if applicable) but before attention computation
when `qk_norm_before_rope` is False. This follows the approach used in e.g. Llama4
for QK normalization to improve training stability and model performance.
qk_norm_eps: float, default = 1e-6 qk_norm_eps: float, default = 1e-6
epsilon value for L2 normalization of query and key tensors. epsilon value for normalization of query and key tensors.
Only used when `use_qk_norm` is True. Only used when `qk_norm_type` is not None.
qk_norm_before_rope: bool, default = `False`
if set to `True`, query and key normalization is applied before rotary position
embedding. When `False` (default), normalization is applied after RoPE.
This parameter allows supporting different architectural variants that apply
QK normalization at different points.
""" """
def __init__( def __init__(
...@@ -293,8 +302,9 @@ class TransformerLayer(torch.nn.Module): ...@@ -293,8 +302,9 @@ class TransformerLayer(torch.nn.Module):
device: Union[torch.device, str] = "cuda", device: Union[torch.device, str] = "cuda",
attn_input_format: str = "sbhd", attn_input_format: str = "sbhd",
name: str = None, name: str = None,
use_qk_norm: bool = False, qk_norm_type: Optional[str] = None,
qk_norm_eps: float = 1e-6, qk_norm_eps: float = 1e-6,
qk_norm_before_rope: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -397,8 +407,9 @@ class TransformerLayer(torch.nn.Module): ...@@ -397,8 +407,9 @@ class TransformerLayer(torch.nn.Module):
return_bias=not self.parallel_attention_mlp, return_bias=not self.parallel_attention_mlp,
normalization=normalization, normalization=normalization,
device=device, device=device,
use_qk_norm=use_qk_norm, qk_norm_type=qk_norm_type,
qk_norm_eps=qk_norm_eps, qk_norm_eps=qk_norm_eps,
qk_norm_before_rope=qk_norm_before_rope,
name=name + ".self_attention" if name is not None else None, name=name + ".self_attention" if name is not None else None,
) )
...@@ -413,8 +424,9 @@ class TransformerLayer(torch.nn.Module): ...@@ -413,8 +424,9 @@ class TransformerLayer(torch.nn.Module):
return_bias=True, return_bias=True,
normalization=normalization, normalization=normalization,
device=device, device=device,
use_qk_norm=use_qk_norm, qk_norm_type=qk_norm_type,
qk_norm_eps=qk_norm_eps, qk_norm_eps=qk_norm_eps,
qk_norm_before_rope=qk_norm_before_rope,
name=name + ".inter_attention" if name is not None else None, name=name + ".inter_attention" if name is not None else None,
) )
......
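A hedged construction sketch showing the new arguments; the sizes and head count are placeholder values:

import transformer_engine.pytorch as te

layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    qk_norm_type="RMSNorm",       # replaces the previous use_qk_norm flag
    qk_norm_eps=1e-6,
    qk_norm_before_rope=True,     # apply QK normalization before RoPE
)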
...@@ -341,13 +341,17 @@ def cross_entropy_forward( ...@@ -341,13 +341,17 @@ def cross_entropy_forward(
return loss, _input return loss, _input
def cross_entropy_backward(_input: torch.Tensor, grad_output: torch.Tensor): def cross_entropy_backward(
_input: torch.Tensor, grad_output: torch.Tensor, is_cg_capturable: bool = False
):
"""Backward implementation of cross entropy loss kernel""" """Backward implementation of cross entropy loss kernel"""
# If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): # Only check torch.equal when not in CUDA graph capturable mode
if not is_cg_capturable and torch.equal(
grad_output, torch.tensor(1.0, device=grad_output.device)
):
pass pass
else: else:
B, SQ, V = _input.shape B, SQ, V = _input.shape
n_rows = B * SQ n_rows = B * SQ
......
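A hedged call sketch: under CUDA graph capture the data-dependent `torch.equal` check cannot run, so callers pass the new flag and the gradient multiply is always applied; `_input` and `grad_output` are assumed tensors from the forward pass:

# The flag only disables the grad_output == 1.0 fast path; numerics are unchanged.
cross_entropy_backward(_input, grad_output, is_cg_capturable=True)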
...@@ -359,7 +359,7 @@ def _permute_kernel( ...@@ -359,7 +359,7 @@ def _permute_kernel(
if prob == 0.0: if prob == 0.0:
# for routing_map padding # for routing_map padding
# dst_row != -1 and prob == 0.0 means that this slot is padded # dst_row != -1 and prob == 0.0 means that this slot is padded
tl.store(output_ptr + output_off, 0, mask=mask) tl.store(output_ptr + output_off, 0.0, mask=mask)
else: else:
tl.store(output_ptr + output_off, inp, mask=mask) tl.store(output_ptr + output_off, inp, mask=mask)
else: else:
......
...@@ -45,10 +45,10 @@ def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None: ...@@ -45,10 +45,10 @@ def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
for t in tensors: for t in tensors:
if t is not None: if t is not None:
# Workaround for double buffering in cpu offload # Workaround for double buffering in cpu offload
if hasattr(t, "do_not_clear"): if hasattr(t, "_do_not_clear"):
continue continue
if hasattr(t, "get_data_tensors"): if hasattr(t, "get_data_tensors"):
if any(hasattr(tensor, "do_not_clear") for tensor in t.get_data_tensors()): if any(hasattr(tensor, "_do_not_clear") for tensor in t.get_data_tensors()):
continue continue
if hasattr(t, "clear"): if hasattr(t, "clear"):
......
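A hedged sketch of the renamed opt-out attribute; `clear_tensor_data`, `activation`, and the device are assumed to be in scope:

keep_me = torch.empty(1024, device="cuda")
keep_me._do_not_clear = True                 # renamed from do_not_clear in this change
clear_tensor_data(activation, keep_me)       # frees activation's data, skips keep_me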