Commit 063ef88d authored by wenjh's avatar wenjh
Browse files

Merge nv main up to v2.10.0.dev0


Signed-off-by: wenjh <wenjh@sugon.com>
parents 91670b05 5624dbb4
...@@ -13,6 +13,7 @@ from contextlib import contextmanager ...@@ -13,6 +13,7 @@ from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from typing import Callable, Optional from typing import Callable, Optional
import warnings import warnings
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax.interpreters import pxla from jax.interpreters import pxla
...@@ -130,7 +131,22 @@ def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec): ...@@ -130,7 +131,22 @@ def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec):
# We want to exclude the axes that already used by shard_map and shard_map # We want to exclude the axes that already used by shard_map and shard_map
# only sets those in the abstract_mesh, not the physical one # only sets those in the abstract_mesh, not the physical one
manual_axis_names = get_abstract_mesh().manual_axes manual_axis_names = get_abstract_mesh().manual_axes
cleaned_axis_names = tuple(name if name not in manual_axis_names else None for name in pspec)
# Multiple mesh axes can be mapped to a single shape axis, so we need to unpack and process tuples here too
def filter_manual_axes(name_or_tuple):
    # Drop mesh-axis names already consumed by shard_map (listed in
    # manual_axis_names); a PartitionSpec entry may be a single name or a
    # tuple of names mapped to one shape axis.
    if isinstance(name_or_tuple, tuple):
        kept = tuple(axis for axis in name_or_tuple if axis not in manual_axis_names)
        # An emptied tuple means the shape axis is fully manual -> unconstrained.
        return kept or None
    return None if name_or_tuple in manual_axis_names else name_or_tuple
cleaned_axis_names = tuple(filter_manual_axes(name_or_tuple) for name_or_tuple in pspec)
if cleaned_axis_names == (None,) * len(cleaned_axis_names):
return x
cleaned_pspec = PartitionSpec(*cleaned_axis_names) cleaned_pspec = PartitionSpec(*cleaned_axis_names)
return jax.lax.with_sharding_constraint(x, cleaned_pspec) return jax.lax.with_sharding_constraint(x, cleaned_pspec)
...@@ -349,6 +365,21 @@ def all_reduce_sum_along_dp_fsdp(x: jnp.array, mesh: jax.sharding.Mesh): ...@@ -349,6 +365,21 @@ def all_reduce_sum_along_dp_fsdp(x: jnp.array, mesh: jax.sharding.Mesh):
return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource, mesh) return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource, mesh)
def all_reduce_sum_along_dp_fsdp_tpsp(x: jnp.array, mesh: jax.sharding.Mesh):
    """Perform an all-reduce sum along the TP/SP, DP, and FSDP mesh axes.

    Args:
        x: Input tensor to reduce
        mesh: JAX mesh for distributed computation

    Returns:
        Reduced tensor
    """
    resources = global_mesh_resource()
    # Reduce over each parallelism axis in turn: TP/SP first, then DP, then FSDP.
    for axis_resource in (
        resources.tpsp_resource,
        resources.dp_resource,
        resources.fsdp_resource,
    ):
        x = lax_paral_op(x, jax.lax.psum, axis_resource, mesh)
    return x
def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mesh): def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mesh):
"""Perform all-reduce max operation along all axes except pipeline parallelism. """Perform all-reduce max operation along all axes except pipeline parallelism.
...@@ -364,3 +395,21 @@ def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mes ...@@ -364,3 +395,21 @@ def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mes
if axis != global_mesh_resource().pp_resource: if axis != global_mesh_resource().pp_resource:
x = lax_paral_op(x, jax.lax.pmax, axis, mesh) x = lax_paral_op(x, jax.lax.pmax, axis, mesh)
return x return x
def tpsp_axis_size():
    """
    Get the size of the tensor parallelism axis.
    Return 1 if no TP axis is set.
    """
    tpsp_resource = global_mesh_resource().tpsp_resource
    return get_mesh_axis_size(tpsp_resource)
def dp_or_fsdp_axis_size():
    """
    Get the size of the data parallelism or FSDP axis.
    Return 1 if no DP/FSDP axis is set.
    """
    resources = global_mesh_resource()
    # Query both axes; prefer DP when it is actually sharded (> 1),
    # otherwise fall back to the FSDP axis size.
    dp_size = get_mesh_axis_size(resources.dp_resource)
    fsdp_size = get_mesh_axis_size(resources.fsdp_resource)
    return fsdp_size if dp_size <= 1 else dp_size
...@@ -46,8 +46,18 @@ from transformer_engine.pytorch.permutation import ( ...@@ -46,8 +46,18 @@ from transformer_engine.pytorch.permutation import (
moe_sort_chunks_by_index, moe_sort_chunks_by_index,
moe_sort_chunks_by_index_with_probs, moe_sort_chunks_by_index_with_probs,
) )
from transformer_engine.pytorch.fp8 import fp8_autocast from transformer_engine.pytorch.quantization import fp8_autocast
from transformer_engine.pytorch.fp8 import fp8_model_init from transformer_engine.pytorch.quantization import fp8_model_init
from transformer_engine.pytorch.quantization import autocast
from transformer_engine.pytorch.quantization import quantized_model_init
from transformer_engine.pytorch.quantization import is_fp8_available
from transformer_engine.pytorch.quantization import is_mxfp8_available
from transformer_engine.pytorch.quantization import is_fp8_block_scaling_available
from transformer_engine.pytorch.quantization import is_nvfp4_available
from transformer_engine.pytorch.quantization import get_default_recipe
from transformer_engine.pytorch.utils import get_cudnn_version
from transformer_engine.pytorch.utils import get_device_compute_capability
from transformer_engine.pytorch.utils import is_bf16_available
from transformer_engine.pytorch.graph import make_graphed_callables from transformer_engine.pytorch.graph import make_graphed_callables
from transformer_engine.pytorch.distributed import checkpoint from transformer_engine.pytorch.distributed import checkpoint
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
...@@ -56,6 +66,24 @@ from transformer_engine.pytorch import ops ...@@ -56,6 +66,24 @@ from transformer_engine.pytorch import ops
from transformer_engine.pytorch import optimizers from transformer_engine.pytorch import optimizers
from transformer_engine.pytorch.export import onnx_export from transformer_engine.pytorch.export import onnx_export
from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy
from transformer_engine.pytorch.tensor import Quantizer
from transformer_engine.pytorch.tensor import Float8Quantizer
from transformer_engine.pytorch.tensor import Float8CurrentScalingQuantizer
from transformer_engine.pytorch.tensor import MXFP8Quantizer
from transformer_engine.pytorch.tensor import Float8BlockQuantizer
from transformer_engine.pytorch.tensor import NVFP4Quantizer
from transformer_engine.pytorch.tensor import QuantizedTensorStorage
from transformer_engine.pytorch.tensor import Float8TensorStorage
from transformer_engine.pytorch.tensor import MXFP8TensorStorage
from transformer_engine.pytorch.tensor import Float8BlockwiseQTensorStorage
from transformer_engine.pytorch.tensor import NVFP4TensorStorage
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor import Float8Tensor
from transformer_engine.pytorch.tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor import Float8BlockwiseQTensor
from transformer_engine.pytorch.tensor import NVFP4Tensor
from transformer_engine.pytorch.tensor import prepare_for_saving
from transformer_engine.pytorch.tensor import restore_from_saved
try: try:
torch._dynamo.config.error_on_nested_jit_trace = False torch._dynamo.config.error_on_nested_jit_trace = False
......
...@@ -13,21 +13,24 @@ import logging ...@@ -13,21 +13,24 @@ import logging
from packaging.version import Version as PkgVersion from packaging.version import Version as PkgVersion
import torch import torch
import torch.nn.functional as F
from torch.utils.cpp_extension import IS_HIP_EXTENSION from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.pytorch.utils import ( from transformer_engine.pytorch.utils import (
SplitAlongDim,
get_device_compute_capability, get_device_compute_capability,
combine_tensors,
split_tensor_along_dim, split_tensor_along_dim,
) )
from transformer_engine.pytorch.utils import attention_mask_func from transformer_engine.pytorch.utils import attention_mask_func, nvtx_range_push, nvtx_range_pop
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer,
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.tensor.quantized_tensor import ( from transformer_engine.pytorch.tensor.quantized_tensor import (
QuantizedTensor, QuantizedTensorStorage,
prepare_for_saving, prepare_for_saving,
restore_from_saved, restore_from_saved,
) )
from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.constants import ( from transformer_engine.pytorch.constants import (
TE_DType, TE_DType,
QKVLayouts, QKVLayouts,
...@@ -40,7 +43,7 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import ( ...@@ -40,7 +43,7 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
META_O, META_O,
META_QKV, META_QKV,
) )
from transformer_engine.pytorch.fp8 import get_fp8_torch_dtype from transformer_engine.pytorch.quantization import get_fp8_torch_dtype, FP8GlobalStateManager
from transformer_engine.pytorch.distributed import get_distributed_world_size from transformer_engine.pytorch.distributed import get_distributed_world_size
from transformer_engine.pytorch.jit import no_torch_dynamo from transformer_engine.pytorch.jit import no_torch_dynamo
from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import ( from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
...@@ -53,6 +56,9 @@ from transformer_engine.pytorch.attention.inference import InferenceParams ...@@ -53,6 +56,9 @@ from transformer_engine.pytorch.attention.inference import InferenceParams
import transformer_engine.pytorch.attention.dot_product_attention.utils as dpa_utils import transformer_engine.pytorch.attention.dot_product_attention.utils as dpa_utils
from transformer_engine.pytorch.attention.dot_product_attention.utils import ( from transformer_engine.pytorch.attention.dot_product_attention.utils import (
FlashAttentionUtils as fa_utils, FlashAttentionUtils as fa_utils,
combine_and_quantize,
combine_and_dequantize,
print_quantizers,
) )
from transformer_engine.pytorch.attention.dot_product_attention.utils import ( from transformer_engine.pytorch.attention.dot_product_attention.utils import (
AttentionLogging as attn_log, AttentionLogging as attn_log,
...@@ -131,6 +137,58 @@ if not IS_HIP_EXTENSION: ...@@ -131,6 +137,58 @@ if not IS_HIP_EXTENSION:
fa_utils.set_flash_attention_3_params() fa_utils.set_flash_attention_3_params()
# Float8CurrentScaling: fused_attn_bwd takes O in FP8 by default, this flag allows it in F16
_dpa_fp8_cs_o_in_f16 = os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1") == "1"
class FP8EmulationFunc(torch.autograd.Function):
    """
    Emulate the effects of FP8 quantization on tensors. Used in UnfusedDotProductAttention as follows:
    - forward : QKV (quantize+dequantize), P (pass-through), S (quantize+dequantize), O (pass-through)
    - backward: dO (quantize+dequantize), dS (pass-through), dP (quantize+dequantize), dQKV (pass-through)
    """

    @staticmethod
    def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout):
        # pylint: disable=missing-function-docstring
        # QKV path: tensor1/2/3 are Q, K, V. Quantize them (combined per
        # qkv_layout) and immediately dequantize back to the original dtype,
        # so later math observes FP8 rounding without running in FP8.
        if quantizer_name == "QKV_quantizer":
            query_layer, key_layer, value_layer = [
                x.contiguous() for x in [tensor1, tensor2, tensor3]
            ]
            q_fp8, k_fp8, v_fp8 = combine_and_quantize(
                qkv_layout, query_layer, key_layer, value_layer, quantizer
            )
            tensors = combine_and_dequantize(
                qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=query_layer.dtype
            )
        # S/O path: only tensor1 is round-tripped through FP8; tensor2 and
        # tensor3 are passed through untouched.
        elif quantizer_name in ["S_quantizer", "O_quantizer"]:
            t_fp8 = quantizer(tensor1)
            tensors = (t_fp8.dequantize(dtype=tensor1.dtype), tensor2, tensor3)
        else:
            # Forward pass-through: the corresponding grad quantizer
            # (dO/dP/dQKV) is applied in backward instead.
            tensors = (tensor1, tensor2, tensor3)
        # Stash quantizer info for backward; no tensors need to be saved.
        ctx.quantizer = quantizer
        ctx.quantizer_name = quantizer_name
        ctx.qkv_layout = qkv_layout
        return tensors[0], tensors[1], tensors[2]

    @staticmethod
    def backward(ctx, grad1, grad2, grad3):
        # pylint: disable=missing-function-docstring
        # dO/dP path: round-trip grad1 through FP8; grad2/grad3 pass through.
        if ctx.quantizer_name in ["dO_quantizer", "dP_quantizer"]:
            dt_fp8 = ctx.quantizer(grad1)
            tensors = dt_fp8.dequantize(dtype=grad1.dtype), grad2, grad3
        # dQKV path: grad1/2/3 are dQ, dK, dV; combine per the saved layout,
        # quantize, and dequantize back to the incoming grad dtype.
        elif ctx.quantizer_name == "dQKV_quantizer":
            query_grad, key_grad, value_grad = [x.contiguous() for x in [grad1, grad2, grad3]]
            dq_fp8, dk_fp8, dv_fp8 = combine_and_quantize(
                ctx.qkv_layout, query_grad, key_grad, value_grad, ctx.quantizer
            )
            tensors = combine_and_dequantize(
                ctx.qkv_layout, dq_fp8, dk_fp8, dv_fp8, src_nominal_dtype=query_grad.dtype
            )
        else:
            # Backward pass-through (quantization was emulated in forward).
            tensors = grad1, grad2, grad3
        # Trailing Nones: no gradients for quantizer, quantizer_name, qkv_layout.
        return tensors[0], tensors[1], tensors[2], None, None, None
class UnfusedDotProductAttention(torch.nn.Module): class UnfusedDotProductAttention(torch.nn.Module):
"""Parallel attention w/o QKV and Proj Gemms """Parallel attention w/o QKV and Proj Gemms
...@@ -144,6 +202,7 @@ class UnfusedDotProductAttention(torch.nn.Module): ...@@ -144,6 +202,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
attention_dropout: float = 0.0, attention_dropout: float = 0.0,
attention_dropout_ctx: Optional[Callable] = nullcontext, attention_dropout_ctx: Optional[Callable] = nullcontext,
layer_number: Optional[int] = None, layer_number: Optional[int] = None,
softmax_type: str = "vanilla",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -151,6 +210,7 @@ class UnfusedDotProductAttention(torch.nn.Module): ...@@ -151,6 +210,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
self.attention_type = attention_type self.attention_type = attention_type
self.attention_dropout_ctx = attention_dropout_ctx self.attention_dropout_ctx = attention_dropout_ctx
self.layer_number = layer_number self.layer_number = layer_number
self.softmax_type = softmax_type
def mask_func(x, y): def mask_func(x, y):
return ( return (
...@@ -187,6 +247,11 @@ class UnfusedDotProductAttention(torch.nn.Module): ...@@ -187,6 +247,11 @@ class UnfusedDotProductAttention(torch.nn.Module):
core_attention_bias: Optional[torch.Tensor] = None, core_attention_bias: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None,
inference_params: Optional[InferenceParams] = None, inference_params: Optional[InferenceParams] = None,
softmax_offset: torch.Tensor = None,
fp8: bool = False,
fp8_meta: Optional[Dict[str, Any]] = None,
quantizers=None,
fp8_output: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
"""Unfused attention fprop""" """Unfused attention fprop"""
assert ( assert (
...@@ -284,6 +349,35 @@ class UnfusedDotProductAttention(torch.nn.Module): ...@@ -284,6 +349,35 @@ class UnfusedDotProductAttention(torch.nn.Module):
if apply_qk_layer_scaling: if apply_qk_layer_scaling:
scale /= self.layer_number scale /= self.layer_number
if fp8:
# get quantizers from DPA; all Nones if not fp8
QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = (
dpa_utils.get_attention_quantizers(fp8, quantizers)
)
# S/dP are forced to use DS quantizers in DPA.init_fp8_metadata; revert them here for true CS emulation
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None:
fp8_recipe = fp8_meta["local_recipes"][0]
if fp8_recipe.float8_current_scaling():
S_quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=S_quantizer.dtype, device="cuda"
)
dP_quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=dP_quantizer.dtype, device="cuda"
)
if "2" in qkv_layout or "3" in qkv_layout:
qkv_format, *_ = dpa_utils.get_qkv_format(qkv_layout)
qkv_layout = "_".join([qkv_format] * 3)
# quantize and dequantize QKV to emulate FP8
query_layer, key_layer, value_layer = FP8EmulationFunc.apply(
query_layer, key_layer, value_layer, QKV_quantizer, "QKV_quantizer", qkv_layout
)
# quantize and dequantize dQKV to emulate FP8
query_layer, key_layer, value_layer = FP8EmulationFunc.apply(
query_layer, key_layer, value_layer, dQKV_quantizer, "dQKV_quantizer", qkv_layout
)
# Raw attention scores. [b * np, sq, sk] # Raw attention scores. [b * np, sq, sk]
if core_attention_bias_type == "no_bias": if core_attention_bias_type == "no_bias":
matmul_result = torch.baddbmm( matmul_result = torch.baddbmm(
...@@ -328,7 +422,27 @@ class UnfusedDotProductAttention(torch.nn.Module): ...@@ -328,7 +422,27 @@ class UnfusedDotProductAttention(torch.nn.Module):
dtype=query_layer.dtype dtype=query_layer.dtype
) )
# attention scores and attention mask [b, np, sq, sk] if fp8:
# quantize and dequantize dP to emulate FP8
matmul_result, *_ = FP8EmulationFunc.apply(
matmul_result, None, None, dP_quantizer, "dP_quantizer", None
)
# add attention sink to the last column: [b, np, sq, sk+1]
if self.softmax_type != "vanilla":
matmul_result = torch.cat(
[
matmul_result,
softmax_offset.to(dtype=matmul_result.dtype).expand(
matmul_result.size(0), -1, matmul_result.size(2), -1
),
],
dim=-1,
)
attention_mask = F.pad(attention_mask, (0, 1), mode="constant", value=False)
attn_mask_type = "arbitrary"
# attention scores and attention mask
softmax_scale = self.layer_number if apply_qk_layer_scaling else None softmax_scale = self.layer_number if apply_qk_layer_scaling else None
attention_probs = self.scale_mask_softmax( attention_probs = self.scale_mask_softmax(
matmul_result, attention_mask, attn_mask_type, softmax_scale matmul_result, attention_mask, attn_mask_type, softmax_scale
...@@ -339,6 +453,10 @@ class UnfusedDotProductAttention(torch.nn.Module): ...@@ -339,6 +453,10 @@ class UnfusedDotProductAttention(torch.nn.Module):
if "padding" in attn_mask_type: if "padding" in attn_mask_type:
attention_probs = attention_probs.masked_fill(attention_mask, 0) attention_probs = attention_probs.masked_fill(attention_mask, 0)
# remove attention sink: [b, np, sq, sk]
if self.softmax_type != "vanilla":
attention_probs = attention_probs[..., :-1]
# This is actually dropping out entire tokens to attend to, which might # This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper. # seem a bit unusual, but is taken from the original Transformer paper.
with self.attention_dropout_ctx(): with self.attention_dropout_ctx():
...@@ -359,6 +477,12 @@ class UnfusedDotProductAttention(torch.nn.Module): ...@@ -359,6 +477,12 @@ class UnfusedDotProductAttention(torch.nn.Module):
# change view [b * np, sq, sk] # change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
if fp8:
# quantize and dequantize S to emulate FP8
attention_probs, *_ = FP8EmulationFunc.apply(
attention_probs, None, None, S_quantizer, "S_quantizer", None
)
# matmul: [b * np, sq, hn] # matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
...@@ -393,6 +517,20 @@ class UnfusedDotProductAttention(torch.nn.Module): ...@@ -393,6 +517,20 @@ class UnfusedDotProductAttention(torch.nn.Module):
# [tq, np, hn] --> [tq, hp] # [tq, np, hn] --> [tq, hp]
context_layer = context_layer.view(total_tokens, -1) context_layer = context_layer.view(total_tokens, -1)
if fp8:
# quantize and dequantize O to emulate FP8
context_layer, *_ = FP8EmulationFunc.apply(
context_layer, None, None, O_quantizer, "O_quantizer", None
)
# quantize and dequantize dO to emulate FP8
context_layer, *_ = FP8EmulationFunc.apply(
context_layer, None, None, dO_quantizer, "dO_quantizer", None
)
# quantize O
if fp8_output:
context_layer = O_quantizer(context_layer)
return context_layer return context_layer
...@@ -491,6 +629,7 @@ class FlashAttention(torch.nn.Module): ...@@ -491,6 +629,7 @@ class FlashAttention(torch.nn.Module):
quantizers=None, quantizers=None,
inference_params: Optional[InferenceParams] = None, inference_params: Optional[InferenceParams] = None,
flash_attention_backend: Optional[PkgVersion] = PkgVersion("0"), flash_attention_backend: Optional[PkgVersion] = PkgVersion("0"),
fp8_output: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
"""flash-attn fprop""" """flash-attn fprop"""
...@@ -696,6 +835,7 @@ class FlashAttention(torch.nn.Module): ...@@ -696,6 +835,7 @@ class FlashAttention(torch.nn.Module):
quantizers=quantizers, quantizers=quantizers,
pad_between_seqs=False, pad_between_seqs=False,
use_flash_attn_3=use_flash_attn_3, use_flash_attn_3=use_flash_attn_3,
fp8_output=fp8_output,
) )
else: else:
from transformer_engine.pytorch.cpu_offload import ( from transformer_engine.pytorch.cpu_offload import (
...@@ -795,8 +935,6 @@ class FlashAttention(torch.nn.Module): ...@@ -795,8 +935,6 @@ class FlashAttention(torch.nn.Module):
) )
return out return out
# "fp8_mha" decides outputs in fp8, while inputs are inferred from
# the real dtype
assert isinstance(key_layer, query_layer.__class__) and isinstance( assert isinstance(key_layer, query_layer.__class__) and isinstance(
value_layer, query_layer.__class__ value_layer, query_layer.__class__
), "q, k, and v must have the same type." ), "q, k, and v must have the same type."
...@@ -843,7 +981,7 @@ class FlashAttention(torch.nn.Module): ...@@ -843,7 +981,7 @@ class FlashAttention(torch.nn.Module):
if fp8: if fp8:
output = output.to(dtype=torch_orig_dtype) output = output.to(dtype=torch_orig_dtype)
if fp8 and fp8_meta["recipe"].fp8_mha: if fp8 and fp8_output:
O_quantizer = quantizers["scaling_fwd"][META_O] O_quantizer = quantizers["scaling_fwd"][META_O]
output = O_quantizer(output) output = O_quantizer(output)
...@@ -871,7 +1009,7 @@ class FlashAttention(torch.nn.Module): ...@@ -871,7 +1009,7 @@ class FlashAttention(torch.nn.Module):
if q_format == "sbhd": if q_format == "sbhd":
# (bs)hd -> bs(hd) -> sb(hd) # (bs)hd -> bs(hd) -> sb(hd)
if fp8 and fp8_meta["recipe"].fp8_mha: if fp8 and fp8_output:
output_data = ( output_data = (
output._data.reshape(batch_size, max_seqlen_q // cp_size, -1) output._data.reshape(batch_size, max_seqlen_q // cp_size, -1)
.transpose(0, 1) .transpose(0, 1)
...@@ -895,7 +1033,7 @@ class FlashAttention(torch.nn.Module): ...@@ -895,7 +1033,7 @@ class FlashAttention(torch.nn.Module):
class FusedAttnFunc(torch.autograd.Function): class FusedAttnFunc(torch.autograd.Function):
"""Function for FusedAttention with separate Q, K, V tensors""" """FusedAttention forward and backward implementation"""
@staticmethod @staticmethod
def forward( def forward(
...@@ -919,6 +1057,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -919,6 +1057,7 @@ class FusedAttnFunc(torch.autograd.Function):
qkv_layout, qkv_layout,
attn_bias_type, attn_bias_type,
attn_mask_type, attn_mask_type,
softmax_type,
window_size, window_size,
rng_gen, rng_gen,
fused_attention_backend, fused_attention_backend,
...@@ -927,55 +1066,72 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -927,55 +1066,72 @@ class FusedAttnFunc(torch.autograd.Function):
fp8_meta, fp8_meta,
quantizers, quantizers,
deterministic, deterministic,
softmax_offset,
fp8_output,
layer_number,
): ):
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
# "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
is_input_fp8 = False
is_output_fp8 = fp8_meta["recipe"].fp8_mha if "recipe" in fp8_meta else False
# FP16/BF16 attn: fake_dtype = torch.float16 or torch.bfloat16
# FP8 attn, is_output_fp8 = False: fake_dtype = torch.float16 or torch.bfloat16
# FP8 attn, is_output_fp8 = True: fake_dtype = torch.float8_e4m3fn
fake_dtype = q.dtype
# add NVTX range
nvtx_label = "transformer_engine.FusedAttnFunc.forward"
nvtx_range_push(f"{nvtx_label}")
# recipe passed in through autocast or set by NVTE_DPA_FP8_RECIPE;
# may be different from fp8_meta["recipe"]
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None:
fp8_recipe = fp8_meta["local_recipes"][0]
# input types are inferred from the real data while output types are controlled by fp8_output
# fp8_output should be set upstream as (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_mha)
assert isinstance(k, q.__class__) and isinstance(
v, q.__class__
), "q, k, v must be of the same class, e.g. torch.Tensor or Float8Tensor."
is_input_fp8 = isinstance(q, Float8Tensor)
is_output_fp8 = fp8_output
# whether fwd kernel in FP8: fp8 = (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_dpa)
# whether bwd kernel in FP8:
is_bwd_fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
# get quantizers from DPA; all Nones if not fp8
QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = (
dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False) dpa_utils.get_attention_quantizers(fp8, quantizers)
) )
# get nominal data type for out
# FP16/BF16 attention: torch.float16 or torch.bfloat16
# FP8 attention: torch.float16 or torch.bfloat16
out_nominal_dtype = q.dtype
if fp8: if fp8:
fused_attention_backend = FusedAttnBackend["FP8"] fused_attention_backend = FusedAttnBackend["FP8"]
assert isinstance(k, q.__class__) and isinstance(
v, q.__class__
), "q, k, and v must have the same type."
is_input_fp8 = isinstance(q, Float8Tensor) # q, k, v: torch.Tensor; dtype = torch.float16 or torch.bfloat16
q_fp8, k_fp8, v_fp8 = None, None, None # q_fp8, k_fp8, v_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16
# fp8_dtype = tex.DType.kFloat8E4M3
if is_input_fp8: if is_input_fp8:
q_fp8, k_fp8, v_fp8 = q, k, v q_fp8, k_fp8, v_fp8 = q, k, v
else: else:
# 1: qkv packed, 2: kv packed, 3: qkv separate q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer)
qkv_group = len(qkv_layout.replace("paged_kv_", "").split("_"))
match qkv_group: # print quantizers
case 1: print_quantizers(
dim = qkv_layout.find("3") "FusedAttnFunc.forward >> before: ",
qkv = combine_tensors([q, k, v], dim) layer_number,
qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) QKV_quantizer,
qkv_fp8 = QKV_quantizer(qkv) O_quantizer,
q_fp8, k_fp8, v_fp8 = SplitAlongDim.apply(qkv_fp8, dim, [1, 1, 1], True) S_quantizer,
case 2: dQKV_quantizer,
q_fp8 = QKV_quantizer(q) dO_quantizer,
dim = qkv_layout.split("_")[1].find("2") dP_quantizer,
kv = combine_tensors([k, v], dim) )
kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
kv_fp8 = QKV_quantizer(kv_c) # out_:
k_fp8, v_fp8 = SplitAlongDim.apply(kv_fp8, dim, [1, 1], True) # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16
case 3: # fp8_dtype = tex.DType.kFloat8E4M3
q_fp8 = QKV_quantizer(q) # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16
k_fp8 = QKV_quantizer(k) out_, aux_ctx_tensors = fused_attn_fwd(
v_fp8 = QKV_quantizer(v)
case _:
raise "Invalid qkv_layout " + qkv_layout
# q_fp8, k_fp8, v_fp8, out_fp8: torch.float8_e4m3fn
out_fp8, aux_ctx_tensors = fused_attn_fwd(
is_training, is_training,
max_seqlen_q, max_seqlen_q,
max_seqlen_kv, max_seqlen_kv,
...@@ -984,7 +1140,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -984,7 +1140,7 @@ class FusedAttnFunc(torch.autograd.Function):
q_fp8, q_fp8,
k_fp8, k_fp8,
v_fp8, v_fp8,
fake_dtype, out_nominal_dtype,
fused_attention_backend, fused_attention_backend,
attn_bias, attn_bias,
cu_seqlens_q_padded, cu_seqlens_q_padded,
...@@ -999,45 +1155,59 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -999,45 +1155,59 @@ class FusedAttnFunc(torch.autograd.Function):
qkv_layout, qkv_layout,
attn_bias_type, attn_bias_type,
attn_mask_type, attn_mask_type,
softmax_type,
window_size, window_size,
rng_gen, rng_gen,
softmax_offset,
) )
if is_output_fp8:
out_ret = out_fp8 # out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16
# fp8_dtype = tex.DType.kFloat8E4M3
# out: torch.Tensor; dtype = torch.float16 or torch.bfloat16
out_fp8 = out_
out = out_
if isinstance(out_, Float8Tensor):
if not is_output_fp8 or not is_bwd_fp8:
out = out_.dequantize().view(out_.shape)
else: else:
out_ret = out_fp8.dequantize().view(out_fp8.shape) if is_output_fp8 or (
# is_output_fp8 = False: out_save.dtype = torch.float16 or torch.bfloat16 is_bwd_fp8
# is_output_fp8 = True: out_save.dtype = torch.float8_e4m3fn and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16)
out_save = out_ret ):
out_fp8 = O_quantizer(out_)
if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): # print quantizers
# 1: qkv packed, 2: kv packed, 3: qkv separate print_quantizers(
"FusedAttnFunc.forward >> after: ",
layer_number,
QKV_quantizer,
O_quantizer,
S_quantizer,
dQKV_quantizer,
dO_quantizer,
dP_quantizer,
)
# return appropriate tensors
out_ret = out_fp8 if is_output_fp8 else out
# save appropriate tensors
fp8_tensors = (None, None, None, None)
qkvo_tensors = (None, None, None, None)
if is_bwd_fp8:
if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16:
fp8_tensors = (q_fp8, k_fp8, v_fp8, None)
qkvo_tensors = (None, None, None, out)
else:
fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8)
else:
if is_input_fp8: if is_input_fp8:
qkv_group = len(qkv_layout.replace("paged_kv_", "").split("_")) q, k, v = combine_and_dequantize(qkv_layout, q_fp8, k_fp8, v_fp8)
if qkv_group == 1: qkvo_tensors = (q, k, v, out)
dim = qkv_layout.find("3")
qkv = combine_tensors([q, k, v], dim)
qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
qkv_no_fp8 = qkv_c.dequantize().view(qkv.shape)
q, k, v = SplitAlongDim.apply(qkv_no_fp8, dim, [1, 1, 1], True)
if qkv_group == 2:
q = q.dequantize()
dim = qkv_layout.replace("paged_kv_", "").split("_")[1].find("2")
kv = combine_tensors([k, v], dim)
kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
kv_no_fp8 = kv.dequantize()
k, v = SplitAlongDim.apply(kv_no_fp8, dim, [1, 1], True)
if qkv_group == 3:
q = q.dequantize()
k = k.dequantize()
v = v.dequantize()
if is_output_fp8:
out_save = out_fp8.dequantize()
fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8)
else: else:
# q, k, v, out_ret: torch.float16 or torch.bfloat16 # q, k, v, out_: torch.Tensor; dtype = torch.float16 or torch.bfloat16
out_ret, aux_ctx_tensors = fused_attn_fwd( out_, aux_ctx_tensors = fused_attn_fwd(
is_training, is_training,
max_seqlen_q, max_seqlen_q,
max_seqlen_kv, max_seqlen_kv,
...@@ -1046,7 +1216,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1046,7 +1216,7 @@ class FusedAttnFunc(torch.autograd.Function):
q, q,
k, k,
v, v,
fake_dtype, out_nominal_dtype,
fused_attention_backend, fused_attention_backend,
attn_bias, attn_bias,
cu_seqlens_q_padded, cu_seqlens_q_padded,
...@@ -1061,13 +1231,23 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1061,13 +1231,23 @@ class FusedAttnFunc(torch.autograd.Function):
qkv_layout, qkv_layout,
attn_bias_type, attn_bias_type,
attn_mask_type, attn_mask_type,
softmax_type,
window_size, window_size,
rng_gen, rng_gen,
softmax_offset,
) )
out_save = out_ret out = out_
out_ret = out_
fp8_tensors = (None, None, None, None) fp8_tensors = (None, None, None, None)
qkvo_tensors = (q, k, v, out)
ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) nvtx_range_pop(f"{nvtx_label}")
ctx.fp8_recipe = fp8_recipe
ctx.fp8 = is_bwd_fp8
# assume fwd and bwd always use the same high precision, i.e. torch.float16 or torch.bfloat16
# used when some tensors are base tensors and loose the "dtype" attribute
ctx.nominal_dtype = out_nominal_dtype
from transformer_engine.pytorch.cpu_offload import ( from transformer_engine.pytorch.cpu_offload import (
CPUOffloadEnabled, CPUOffloadEnabled,
...@@ -1078,15 +1258,13 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1078,15 +1258,13 @@ class FusedAttnFunc(torch.autograd.Function):
if ctx.fp8: if ctx.fp8:
tensor_list = fp8_tensors tensor_list = fp8_tensors
else: else:
tensor_list = [q, k, v, out_save] tensor_list = [q, k, v, out]
qkv_layout = "sbhd_sbhd_sbhd"
mark_activation_offload(*tensor_list) mark_activation_offload(*tensor_list)
mark_activation_offload(*aux_ctx_tensors) mark_activation_offload(*aux_ctx_tensors)
ctx.is_input_fp8 = is_input_fp8 ctx.is_input_fp8 = is_input_fp8
ctx.is_output_fp8 = is_output_fp8 ctx.is_output_fp8 = is_output_fp8
qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None)
tensors_to_save, tensor_objects = prepare_for_saving( tensors_to_save, tensor_objects = prepare_for_saving(
*fp8_tensors, *fp8_tensors,
*qkvo_tensors, *qkvo_tensors,
...@@ -1100,11 +1278,14 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1100,11 +1278,14 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.tensor_objects = tensor_objects ctx.tensor_objects = tensor_objects
ctx.fp8_meta = fp8_meta ctx.fp8_meta = fp8_meta
ctx.layer_number = layer_number
ctx.QKV_quantizer = QKV_quantizer
ctx.O_quantizer = O_quantizer
ctx.dQKV_quantizer = dQKV_quantizer ctx.dQKV_quantizer = dQKV_quantizer
ctx.dO_quantizer = dO_quantizer ctx.dO_quantizer = dO_quantizer
ctx.dP_quantizer = dP_quantizer ctx.dP_quantizer = dP_quantizer
ctx.S_quantizer = S_quantizer ctx.S_quantizer = S_quantizer
if ctx.fp8: if ctx.fp8 and isinstance(ctx.S_quantizer, Float8Quantizer):
ctx.S_quantizer = S_quantizer.copy() ctx.S_quantizer = S_quantizer.copy()
ctx.S_quantizer.scale = S_quantizer.scale.clone() ctx.S_quantizer.scale = S_quantizer.scale.clone()
...@@ -1113,9 +1294,34 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1113,9 +1294,34 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.attn_scale = attn_scale ctx.attn_scale = attn_scale
ctx.dropout_p = dropout_p ctx.dropout_p = dropout_p
ctx.fast_zero_fill = fast_zero_fill ctx.fast_zero_fill = fast_zero_fill
ctx.qkv_layout = qkv_layout
from transformer_engine.pytorch.cpu_offload import (
CPUOffloadedLayer,
)
# If interleaved tensor is offloaded, reloaded tensor will be
# non-interleaved, so we need to modify the QKV layout
# for backward
if CPUOffloadedLayer and CPUOffloadEnabled:
reload_layout = ""
split_list = qkv_layout.split("_")
for split in split_list:
temp_layout = ""
rep_count = 1
for s in split:
if s.isalpha():
temp_layout = temp_layout + s
else:
rep_count = int(s)
for _ in range(rep_count):
reload_layout = reload_layout + temp_layout + "_"
ctx.qkv_layout = reload_layout[:-1]
else:
ctx.qkv_layout = qkv_layout
ctx.attn_bias_type = attn_bias_type ctx.attn_bias_type = attn_bias_type
ctx.attn_mask_type = attn_mask_type ctx.attn_mask_type = attn_mask_type
ctx.softmax_type = softmax_type
ctx.window_size = window_size ctx.window_size = window_size
ctx.fused_attention_backend = ( ctx.fused_attention_backend = (
fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"] fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
...@@ -1128,17 +1334,15 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1128,17 +1334,15 @@ class FusedAttnFunc(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, d_out): def backward(ctx, d_out):
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
if ctx.is_output_fp8:
assert isinstance(
d_out, Float8Tensor
), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
# FP16/BF16 attn: fake_dtype = torch.float16 or torch.bfloat16
# FP8 attn, is_output_fp8 = False: fake_dtype = torch.float16 or torch.bfloat16
# FP8 attn, is_output_fp8 = True: fake_dtype = torch.float8_e5m2
fake_dtype = d_out.dtype
d_out = d_out.contiguous() # d_out is expected to be in FP8 if is_output_fp8=True,
# but in the case it's not, convert it to FP8 before any operation
if ctx.fp8 and ctx.is_output_fp8 and not isinstance(d_out, QuantizedTensorStorage):
d_out = ctx.dO_quantizer(d_out)
if not ctx.use_FAv2_bwd:
d_out._data = d_out._data.contiguous()
elif not ctx.use_FAv2_bwd:
d_out = d_out.contiguous()
( (
q_fp8, q_fp8,
k_fp8, k_fp8,
...@@ -1192,16 +1396,55 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1192,16 +1396,55 @@ class FusedAttnFunc(torch.autograd.Function):
dk = dk[..., : d_out.shape[-1]] dk = dk[..., : d_out.shape[-1]]
dv = dv[..., : d_out.shape[-1]] dv = dv[..., : d_out.shape[-1]]
else: else:
with torch.cuda.nvtx.range("_FusedAttn"): with torch.cuda.nvtx.range("FusedAttnFunc.backward"):
# get nominal data type of dq, dk, dv
# FP16/BF16 attention: torch.float16 or torch.bfloat16
# FP8 attention: torch.float16 or torch.bfloat16
dqkv_nominal_dtype = ctx.nominal_dtype
if ctx.fp8: if ctx.fp8:
# d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16
# d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16
# fp8_dtype = tex.DType.kFloat8E5M2
if ctx.is_output_fp8: if ctx.is_output_fp8:
d_out_fp8 = d_out d_out_fp8 = d_out
else: else:
d_out_fp8 = ctx.dO_quantizer(d_out) d_out_fp8 = ctx.dO_quantizer(d_out)
dqkv_dtype = TE_DType[d_out_fp8._data.dtype]
# q_fp8, k_fp8, v_fp8, out_fp8: torch.float8_e4m3fn # print quantizers
# d_out_fp8, dq_fp8, dk_fp8, dv_fp8: torch.float8_e5m2 print_quantizers(
dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd( "FusedAttnFunc.backward >> before: ",
ctx.layer_number,
ctx.QKV_quantizer,
ctx.O_quantizer,
ctx.S_quantizer,
ctx.dQKV_quantizer,
ctx.dO_quantizer,
ctx.dP_quantizer,
)
# get tex.DType for dq, dk, dv data
dqkv_te_dtype = d_out_fp8._fp8_dtype
# q_fp8, k_fp8, v_fp8, out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16,
# fp8_dtype = tex.DType.kFloat8E4M3
# d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16
# fp8_dtype = tex.DType.kFloat8E5M2
# out_:
# DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16
# fp8_dtype = tex.DType.kFloat8E4M3
# Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16
#
# dq_, dk_, dv_:
# DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16
# fp8_dtype = tex.DType.kFloat8E5M2
# Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16
out_ = (
out
if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16
else out_fp8
)
dq_, dk_, dv_, *rest = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_q,
ctx.max_seqlen_kv, ctx.max_seqlen_kv,
cu_seqlens_q, cu_seqlens_q,
...@@ -1209,10 +1452,10 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1209,10 +1452,10 @@ class FusedAttnFunc(torch.autograd.Function):
q_fp8, q_fp8,
k_fp8, k_fp8,
v_fp8, v_fp8,
out_fp8, out_,
d_out_fp8, d_out_fp8,
fake_dtype, dqkv_nominal_dtype,
dqkv_dtype, dqkv_te_dtype,
aux_ctx_tensors, aux_ctx_tensors,
ctx.fused_attention_backend, ctx.fused_attention_backend,
cu_seqlens_q_padded, cu_seqlens_q_padded,
...@@ -1226,44 +1469,45 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1226,44 +1469,45 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.qkv_layout, ctx.qkv_layout,
ctx.attn_bias_type, ctx.attn_bias_type,
ctx.attn_mask_type, ctx.attn_mask_type,
ctx.softmax_type,
ctx.window_size, ctx.window_size,
ctx.deterministic, ctx.deterministic,
) )
# is_input_fp8 = False: dq, dk, dv: torch.float16 or torch.bfloat16 # dq, dk, dv: torch.Tensor; dtype = torch.float16 or torch.bfloat16
# is_input_fp8 = True: dq, dk, dv: torch.float8_e5m2 dq, dk, dv = dq_, dk_, dv_
if not ctx.is_input_fp8: is_float8tensor = isinstance(dq_, Float8Tensor)
qkv_group = len(ctx.qkv_layout.replace("paged_kv_", "").split("_")) if is_float8tensor and not ctx.is_input_fp8:
if qkv_group == 1: # return in F16
dim = ctx.qkv_layout.find("3") dq, dk, dv = combine_and_dequantize(
dqkv_fp8_data = combine_tensors( ctx.qkv_layout,
[dq_fp8._data, dk_fp8._data, dv_fp8._data], dim dq_,
) dk_,
dqkv_fp8 = dq_fp8.make_like( dv_,
tensor=dq_fp8, data=dqkv_fp8_data, shape=dqkv_fp8_data.shape src_nominal_dtype=dq_.dtype,
) )
dqkv = dqkv_fp8.dequantize() if not is_float8tensor and ctx.is_input_fp8:
dq, dk, dv = SplitAlongDim.apply(dqkv, dim, [1, 1, 1], True) # return in FP8
if qkv_group == 2: dq, dk, dv = combine_and_quantize(
dq = dq_fp8.dequantize() ctx.qkv_layout, dq_, dk_, dv_, ctx.dQKV_quantizer
dim = ctx.qkv_layout.split("_")[1].find("2") )
dkv_fp8 = combine_tensors([dk_fp8, dv_fp8], dim)
dkv_c_fp8 = dkv_fp8.view( # print quantizers
-1, dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1] print_quantizers(
) "FusedAttnFunc.backward >> after: ",
dkv = dkv_c_fp8.dequantize() ctx.layer_number,
dk, dv = SplitAlongDim.apply(dkv, dim, [1, 1], True) ctx.QKV_quantizer,
if qkv_group == 3: ctx.O_quantizer,
dq = dq_fp8.dequantize() ctx.S_quantizer,
dk = dk_fp8.dequantize() ctx.dQKV_quantizer,
dv = dv_fp8.dequantize() ctx.dO_quantizer,
else: ctx.dP_quantizer,
dq, dk, dv = dq_fp8, dk_fp8, dv_fp8 )
else: else:
if isinstance(d_out, QuantizedTensor): if isinstance(d_out, QuantizedTensorStorage):
d_out = d_out.dequantize() d_out = d_out.dequantize(dtype=ctx.nominal_dtype)
dqkv_dtype = TE_DType[d_out.dtype] dqkv_te_dtype = TE_DType[d_out.dtype]
# q, k, v, out, d_out, dq, dk, dv: torch.float16 or torch.bfloat16 # q, k, v, out, d_out, dq, dk, dv: torch.Tensor; torch.float16 or torch.bfloat16
dq, dk, dv, *rest = fused_attn_bwd( dq, dk, dv, *rest = fused_attn_bwd(
ctx.max_seqlen_q, ctx.max_seqlen_q,
ctx.max_seqlen_kv, ctx.max_seqlen_kv,
...@@ -1274,8 +1518,8 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1274,8 +1518,8 @@ class FusedAttnFunc(torch.autograd.Function):
v, v,
out, out,
d_out, d_out,
fake_dtype, dqkv_nominal_dtype,
dqkv_dtype, dqkv_te_dtype,
aux_ctx_tensors, aux_ctx_tensors,
ctx.fused_attention_backend, ctx.fused_attention_backend,
cu_seqlens_q_padded, cu_seqlens_q_padded,
...@@ -1289,42 +1533,17 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1289,42 +1533,17 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.qkv_layout, ctx.qkv_layout,
ctx.attn_bias_type, ctx.attn_bias_type,
ctx.attn_mask_type, ctx.attn_mask_type,
ctx.softmax_type,
ctx.window_size, ctx.window_size,
ctx.deterministic, ctx.deterministic,
) )
# if no_bias or alibi, return dqkv d_bias = None
if ctx.attn_bias_type in ["no_bias", "alibi"]: if ctx.attn_bias_type not in ["no_bias", "alibi"]:
return ( d_bias = rest[0]
None, d_softmax_offset = None
None, if ctx.softmax_type != "vanilla":
None, d_softmax_offset = rest[1]
None,
None,
None,
None,
None,
None,
dq,
dk,
dv,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
# else, return (dqkv, dbias)
return ( return (
None, None,
None, None,
...@@ -1338,7 +1557,9 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1338,7 +1557,9 @@ class FusedAttnFunc(torch.autograd.Function):
dq, dq,
dk, dk,
dv, dv,
rest[0], d_bias,
None,
None,
None, None,
None, None,
None, None,
...@@ -1352,6 +1573,8 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1352,6 +1573,8 @@ class FusedAttnFunc(torch.autograd.Function):
None, None,
None, None,
None, None,
d_softmax_offset,
None,
None, None,
) )
...@@ -1392,6 +1615,7 @@ class FusedAttention(torch.nn.Module): ...@@ -1392,6 +1615,7 @@ class FusedAttention(torch.nn.Module):
attention_type: str = "self", attention_type: str = "self",
layer_number: Optional[int] = None, layer_number: Optional[int] = None,
deterministic: bool = False, deterministic: bool = False,
softmax_type: str = "vanilla",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -1404,6 +1628,7 @@ class FusedAttention(torch.nn.Module): ...@@ -1404,6 +1628,7 @@ class FusedAttention(torch.nn.Module):
) == "1" and get_device_compute_capability() == (9, 0) ) == "1" and get_device_compute_capability() == (9, 0)
self.layer_number = 1 if layer_number is None else layer_number self.layer_number = 1 if layer_number is None else layer_number
self.deterministic = deterministic self.deterministic = deterministic
self.softmax_type = softmax_type
def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument
""" """
...@@ -1455,6 +1680,8 @@ class FusedAttention(torch.nn.Module): ...@@ -1455,6 +1680,8 @@ class FusedAttention(torch.nn.Module):
quantizers=None, quantizers=None,
pad_between_seqs: bool = False, pad_between_seqs: bool = False,
inference_params: Optional[InferenceParams] = None, inference_params: Optional[InferenceParams] = None,
softmax_offset: torch.Tensor = None,
fp8_output: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
"""fused attention fprop""" """fused attention fprop"""
assert ( assert (
...@@ -1555,15 +1782,27 @@ class FusedAttention(torch.nn.Module): ...@@ -1555,15 +1782,27 @@ class FusedAttention(torch.nn.Module):
) )
if fp8: if fp8:
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None:
fp8_recipe = fp8_meta["local_recipes"][0]
assert fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, ( assert fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, (
f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}" f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}"
" is required for FP8 attention!" " is required for FP8 attention!"
) )
assert fp8_meta is not None, "FP8 metadata fp8_meta is required for FP8 attention!" assert fp8_meta is not None, "FP8 metadata fp8_meta is required for FP8 attention!"
assert not context_parallel or fp8_meta["recipe"].reduce_amax, ( if fp8_recipe.delayed():
"Amax reduction across TP+CP group is necessary when using context parallelism with" assert not context_parallel or fp8_recipe.reduce_amax, (
" FP8!" "Amax reduction across TP+CP group is necessary when using context parallelism"
) " with FP8!"
)
if fp8_recipe.float8_current_scaling() and context_parallel:
all_quantizers = dpa_utils.get_attention_quantizers(fp8, quantizers)
for q in all_quantizers:
if isinstance(q, Float8CurrentScalingQuantizer):
q.with_amax_reduction = True
q.amax_reduction_group = (
cp_group[0] if cp_comm_type == "a2a+p2p" else cp_group
)
if context_parallel: if context_parallel:
assert ( assert (
...@@ -1605,6 +1844,10 @@ class FusedAttention(torch.nn.Module): ...@@ -1605,6 +1844,10 @@ class FusedAttention(torch.nn.Module):
fp8_meta=fp8_meta, fp8_meta=fp8_meta,
quantizers=quantizers, quantizers=quantizers,
pad_between_seqs=pad_between_seqs, pad_between_seqs=pad_between_seqs,
softmax_type=self.softmax_type,
softmax_offset=softmax_offset,
fp8_output=fp8_output,
layer_number=self.layer_number,
) )
else: else:
with self.attention_dropout_ctx(): with self.attention_dropout_ctx():
...@@ -1628,6 +1871,7 @@ class FusedAttention(torch.nn.Module): ...@@ -1628,6 +1871,7 @@ class FusedAttention(torch.nn.Module):
qkv_layout, qkv_layout,
core_attention_bias_type, core_attention_bias_type,
attn_mask_type, attn_mask_type,
self.softmax_type,
window_size, window_size,
None, # rng_gen None, # rng_gen
fused_attention_backend, fused_attention_backend,
...@@ -1636,6 +1880,9 @@ class FusedAttention(torch.nn.Module): ...@@ -1636,6 +1880,9 @@ class FusedAttention(torch.nn.Module):
fp8_meta, fp8_meta,
quantizers, quantizers,
self.deterministic, self.deterministic,
softmax_offset,
fp8_output,
self.layer_number,
) )
# ...hd -> ...(hd) # ...hd -> ...(hd)
......
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -11,11 +11,26 @@ import warnings ...@@ -11,11 +11,26 @@ import warnings
import logging import logging
import torch import torch
from torch.nn.parameter import Parameter
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.common.recipe import (
Format,
Recipe,
DelayedScaling,
Float8CurrentScaling,
)
from transformer_engine.pytorch.utils import get_cudnn_version from transformer_engine.pytorch.utils import get_cudnn_version
from transformer_engine.pytorch.fp8 import get_fp8_te_dtype from transformer_engine.pytorch.quantization import (
from transformer_engine.pytorch.float8_tensor import Float8Tensor get_fp8_te_dtype,
FP8GlobalStateManager,
RecipeState,
DelayedScalingRecipeState,
MXFP8BlockScalingRecipeState,
Float8CurrentScalingRecipeState,
Float8BlockScalingRecipeState,
)
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.export import is_in_onnx_export_mode from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.constants import ( from transformer_engine.pytorch.constants import (
...@@ -72,6 +87,67 @@ _alibi_cache = { ...@@ -72,6 +87,67 @@ _alibi_cache = {
"_alibi_bias_require_update": False, "_alibi_bias_require_update": False,
} }
"""
This feature is **experimental** and subject to change.
Some models may use different FP8 recipes for their linear layers and attention layers. To support this,
users can either use multiple, nested autocast() contexts to assign a distinct recipe for each layer,
or use a single autocast() for the non-attention layers and configure the recipe for the attention
layers as follows.
+-------------------+-----------+-----------------------------------------------------------------------------------+
| Linear | Attention | Configuration |
+===================+===========+===================================================================================+
| FP8DS/FP8CS/NVFP4 | FP16/BF16 | Pass FP8DS, FP8CS or NVFP4 to autocast(); |
| | | export NVTE_DPA_FP8_RECIPE="F16" |
+-------------------+-----------+-----------------------------------------------------------------------------------+
| FP8DS | FP8DS | Pass FP8DS to autocast(); |
+-------------------+-----------+-----------------------------------------------------------------------------------+
| FP8CS | FP8DS | Pass FP8CS to autocast(); |
| | | Attention FP8DS reuses the fp8_format, fp8_dpa, fp8_mha values from linear FP8CS; |
| | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS |
| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" |
| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer |
| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 |
+-------------------+-----------+-----------------------------------------------------------------------------------+
| NVFP4 | FP8DS | Pass NVFP4 to autocast(); |
| | | Attention FP8DS reuses the fp8_dpa, fp8_mha values from linear NVFP4; |
| | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS |
| | | export NVTE_DPA_FP8_FORMAT="HYBRID" # or "E4M3", "E5M2" |
| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" |
| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer |
| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 |
+-------------------+-----------+-----------------------------------------------------------------------------------+
| FP8DS | FP8CS | Pass FP8DS to autocast(); |
| | | Attention uses FP8DS for S, dP tensors, and creates a new FP8CS recipe for QKV, O,|
| | | dO, dQKV tensors based on fp8_format, fp8_dpa, fp8_mha from linear FP8DS; |
| | | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling" # switch to CS |
+-------------------+-----------+-----------------------------------------------------------------------------------+
| FP8CS | FP8CS | Pass FP8CS to autocast(); |
| | | Attention uses FP8CS for QKV, O, dO, dQKV tensors, and creates a new FP8DS recipe |
| | | for S, dP tensors based on fp8_format, fp8_dpa, fp8_mha from linear FP8CS and: |
| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" |
| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer |
| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 |
+-------------------+-----------+-----------------------------------------------------------------------------------+
| NVFP4 | FP8CS | Pass NVFP4 to autocast(); |
| | | Attention creates a new FP8CS recipe for QKV, O, dO, dQKV, and a new FP8DS recipe |
| | | for S, dP, based on the fp8_dpa, fp8_mha values from linear NVFP4 and: |
| | | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling" # switch to CS |
| | | export NVTE_DPA_FP8_FORMAT="HYBRID" # or "E4M3", "E5M2" |
| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" |
| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer |
| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 |
+-------------------+-----------+-----------------------------------------------------------------------------------+
"""
_dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "")
formats = {"HYBRID": Format.HYBRID, "E4M3": Format.E4M3, "E5M2": Format.E5M2}
_dpa_fp8_format = formats[os.getenv("NVTE_DPA_FP8_FORMAT", "HYBRID")]
_dpa_fp8ds_amax_algo = os.getenv("NVTE_DPA_FP8DS_AMAX_ALGO", "most_recent")
_dpa_fp8ds_amax_histlen = int(os.getenv("NVTE_DPA_FP8DS_AMAX_HISTLEN", "1"))
_dpa_fp8ds_reduce_amax = os.getenv("NVTE_DPA_FP8DS_REDUCE_AMAX", "1") == "1"
__all__ = ["DotProductAttention"] __all__ = ["DotProductAttention"]
...@@ -168,6 +244,17 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -168,6 +244,17 @@ class DotProductAttention(TransformerEngineBaseModule):
softmax_scale: Optional[float], default = `None` softmax_scale: Optional[float], default = `None`
softmax scale for the attention scores. If `None`, defaults to softmax scale for the attention scores. If `None`, defaults to
`1.0/math.sqrt(kv_channels if isinstance(kv_channels, int) else kv_channels[0])`. `1.0/math.sqrt(kv_channels if isinstance(kv_channels, int) else kv_channels[0])`.
softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
softmax type as described in this paper:
`Efficient Streaming Language Models with Attention Sinks
<https://arxiv.org/pdf/2309.17453v3>`_.
For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv],
'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
where alpha is a learnable parameter in shape [h].
'off-by-one' and 'learnable' softmax types are also called sink attention
('zero sink' and 'learnable sink').
Parallelism parameters Parallelism parameters
---------------------- ----------------------
...@@ -223,6 +310,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -223,6 +310,7 @@ class DotProductAttention(TransformerEngineBaseModule):
cp_stream: torch.cuda.Stream = None, cp_stream: torch.cuda.Stream = None,
cp_comm_type: str = "p2p", cp_comm_type: str = "p2p",
softmax_scale: Optional[float] = None, softmax_scale: Optional[float] = None,
softmax_type: str = "vanilla",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -307,6 +395,20 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -307,6 +395,20 @@ class DotProductAttention(TransformerEngineBaseModule):
self.attention_type = attention_type self.attention_type = attention_type
self.attention_dropout = attention_dropout self.attention_dropout = attention_dropout
self.softmax_type = softmax_type
if self.softmax_type == "vanilla":
self.softmax_offset = None
if self.softmax_type == "off-by-one":
self.softmax_offset = torch.zeros(
self.num_attention_heads // self.tp_size, device="cuda"
)
if self.softmax_type == "learnable":
self.register_parameter(
"softmax_offset",
Parameter(torch.empty(self.num_attention_heads // self.tp_size, device="cuda")),
get_rng_state_tracker=get_rng_state_tracker,
)
attn_kwargs = { attn_kwargs = {
"attention_dropout": attention_dropout, "attention_dropout": attention_dropout,
"attention_dropout_ctx": attention_dropout_ctx, "attention_dropout_ctx": attention_dropout_ctx,
...@@ -328,6 +430,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -328,6 +430,7 @@ class DotProductAttention(TransformerEngineBaseModule):
layer_number=layer_number, layer_number=layer_number,
deterministic=self.deterministic, deterministic=self.deterministic,
**attn_kwargs, **attn_kwargs,
softmax_type=self.softmax_type,
) )
self.unfused_attention = UnfusedDotProductAttention( self.unfused_attention = UnfusedDotProductAttention(
...@@ -335,6 +438,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -335,6 +438,7 @@ class DotProductAttention(TransformerEngineBaseModule):
attention_type=attention_type, attention_type=attention_type,
**attn_kwargs, **attn_kwargs,
layer_number=layer_number, layer_number=layer_number,
softmax_type=self.softmax_type,
) )
def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument
...@@ -433,6 +537,234 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -433,6 +537,234 @@ class DotProductAttention(TransformerEngineBaseModule):
self.cp_stream = cp_stream self.cp_stream = cp_stream
self.cp_comm_type = cp_comm_type self.cp_comm_type = cp_comm_type
def init_fp8_metadata(self, num_gemms: int = 1) -> None:
"""
Override TransformerEngineBaseModule.init_fp8_metadata to allow for more flexible recipe support.
Initialize fp8 related metadata and tensors during fprop.
"""
_original_recipe = self.fp8_meta.get("recipe", None)
# global recipe set in autocast()
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
if fp8_recipe.custom():
return
# switch/append recipe: fp8_recipe stays unchanged, but DPA.fp8_meta["recipe"] may be set to
# a different recipe than fp8_recipe. DPA.quantizers may be a mix of different quantizers as well.
#
# fp8_recipe | NVTE_DPA_FP8_RECIPE | self.fp8_meta["recipe"] | self.quantizers
# --------------------------------------------------------------------------------------------
# DelayedScaling (DS) | unset | DS | all DS
# Float8CurrentScaling (CS) | unset | DS | CS for QKV, O, dO, dQKV; DS for S, dP
# x={DS, CS} | y | refer to row x=y | refer to row x=y
fp8_recipe_dpa = fp8_recipe
fp8_recipes = fp8_recipe
if _dpa_fp8_recipe == "F16":
# ignore the recipe from autocast, set fp8_dpa = False, fp8_mha = False
fp8_recipe.fp8_dpa = False
fp8_recipe.fp8_mha = False
elif fp8_recipe.float8_current_scaling() and _dpa_fp8_recipe == "DelayedScaling":
# reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe, and construct a DS recipe
fake_recipe = DelayedScaling(
fp8_format=fp8_recipe.fp8_format,
amax_history_len=_dpa_fp8ds_amax_histlen,
amax_compute_algo=_dpa_fp8ds_amax_algo,
fp8_dpa=fp8_recipe.fp8_dpa,
fp8_mha=fp8_recipe.fp8_mha,
reduce_amax=_dpa_fp8ds_reduce_amax,
)
fp8_recipe_dpa = fake_recipe
fp8_recipes = fp8_recipe_dpa
elif fp8_recipe.nvfp4() and _dpa_fp8_recipe == "DelayedScaling":
# reuse fp8_dpa, fp8_mha from fp8_recipe but not fp8_format; construct a DS recipe
fake_recipe = DelayedScaling(
fp8_format=_dpa_fp8_format,
amax_history_len=_dpa_fp8ds_amax_histlen,
amax_compute_algo=_dpa_fp8ds_amax_algo,
fp8_dpa=fp8_recipe.fp8_dpa,
fp8_mha=fp8_recipe.fp8_mha,
reduce_amax=_dpa_fp8ds_reduce_amax,
)
fp8_recipe_dpa = fake_recipe
fp8_recipes = fp8_recipe_dpa
elif fp8_recipe.delayed() and _dpa_fp8_recipe == "Float8CurrentScaling":
# reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe, and construct a CS+DS recipe
fake_recipes = [
Float8CurrentScaling(
fp8_format=fp8_recipe.fp8_format,
fp8_dpa=fp8_recipe.fp8_dpa,
fp8_mha=fp8_recipe.fp8_mha,
),
fp8_recipe,
]
fp8_recipe_dpa = fake_recipes[1]
fp8_recipes = fake_recipes
elif (
fp8_recipe.float8_current_scaling()
and _dpa_fp8_recipe in ("", "Float8CurrentScaling")
and (fp8_recipe.fp8_dpa or fp8_recipe.fp8_mha)
):
# use fp8_recipe for QKV, O, dO, dQKV, and construct a DS recipe for S, dP
# reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe
fake_recipe = DelayedScaling(
fp8_format=fp8_recipe.fp8_format,
amax_history_len=_dpa_fp8ds_amax_histlen,
amax_compute_algo=_dpa_fp8ds_amax_algo,
fp8_dpa=fp8_recipe.fp8_dpa,
fp8_mha=fp8_recipe.fp8_mha,
reduce_amax=_dpa_fp8ds_reduce_amax,
)
fp8_recipe_dpa = fake_recipe
fp8_recipes = [fp8_recipe, fp8_recipe_dpa]
elif fp8_recipe.nvfp4() and _dpa_fp8_recipe == "Float8CurrentScaling":
# reuse fp8_dpa, fp8_mha from fp8_recipe but not fp8_format
# construct a CS recipe for QKV, O, dO, dQKV and a DS recipe for S, dP
fake_recipes = [
Float8CurrentScaling(
fp8_format=_dpa_fp8_format,
fp8_dpa=fp8_recipe.fp8_dpa,
fp8_mha=fp8_recipe.fp8_mha,
),
DelayedScaling(
fp8_format=_dpa_fp8_format,
amax_history_len=_dpa_fp8ds_amax_histlen,
amax_compute_algo=_dpa_fp8ds_amax_algo,
fp8_dpa=fp8_recipe.fp8_dpa,
fp8_mha=fp8_recipe.fp8_mha,
reduce_amax=_dpa_fp8ds_reduce_amax,
),
]
fp8_recipe_dpa = fake_recipes[1]
fp8_recipes = fake_recipes
# DPA only support DS and CS; other recipes should have fp8_dpa=False, fp8_mha=False
if not fp8_recipe_dpa.float8_per_tensor_scaling():
assert not (
fp8_recipe_dpa.fp8_dpa or fp8_recipe_dpa.fp8_mha
), f"DotProductAttention does not support {fp8_recipe_dpa.__class__.__name__} recipe"
# reduce over TP+CP groups; expect fp8_group to be set up so
# assume attention uses the same fp8_group as GEMMs
fp8_group = FP8GlobalStateManager.get_fp8_group()
self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
fp8_enabled = self.fp8 or self.fp8_calibration
self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration
if self.fp8_parameters or fp8_enabled:
self.fp8_meta["global_recipe"] = fp8_recipe
self.fp8_meta["local_recipes"] = (
fp8_recipes if isinstance(fp8_recipes, List) else [fp8_recipes]
)
if self.fp8_parameters or fp8_enabled:
if self.fp8_initialized and fp8_recipe_dpa == self.fp8_meta["recipe"]:
# FP8 init has already been run and recipe is the same, don't do anything.
return
self.fp8_meta["recipe"] = fp8_recipe_dpa
if fp8_recipe != fp8_recipe_dpa:
# fp8_recipe has changed, rehash the key.
autocast_key = FP8GlobalStateManager.get_unique_autocast_key(
fp8_recipe_dpa, fp8_group
)
FP8GlobalStateManager.autocast_arguments[autocast_key] = (
fp8_recipe_dpa,
fp8_group,
)
else:
# If fp8 isn't enabled, turn off and return.
self.fp8_initialized = False
return
if self.fp8_parameters and not self.fp8_initialized:
self.fp8_meta["num_gemms"] = num_gemms
self.init_fp8_meta_tensors(fp8_recipes)
if fp8_enabled:
# Set FP8 and other FP8 metadata
self.fp8_meta["num_gemms"] = num_gemms
self.fp8_meta["fp8_group"] = fp8_group
# Set FP8_MAX per tensor according to recipe
self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd
# Allocate scales and amaxes
self.init_fp8_meta_tensors(fp8_recipes)
self.fp8_initialized = True
self.fp8_meta["recipe"] = fp8_recipe_dpa
if fp8_recipe != fp8_recipe_dpa:
# fp8_recipe has changed, rehash the key.
autocast_key = FP8GlobalStateManager.get_unique_autocast_key(
fp8_recipe_dpa, fp8_group
)
FP8GlobalStateManager.autocast_arguments[autocast_key] = (
fp8_recipe_dpa,
fp8_group,
)
_current_recipe = self.fp8_meta["recipe"]
if _original_recipe is not None and not (
issubclass(_current_recipe.__class__, _original_recipe.__class__)
or issubclass(_original_recipe.__class__, _current_recipe.__class__)
):
warnings.warn(
f"Recipe type changed from {_original_recipe.__class__.__name__} "
f"to {_current_recipe.__class__.__name__}. "
"This may affect model behavior."
)
# Clear cached workspaces as they were created with the old recipe/quantizer type
self._fp8_workspaces.clear()
def set_meta_tensor(self, fwd: bool, recipe: Union[Recipe, List[Recipe]]) -> None:
    """Override to allow multiple recipes. Init scales and amaxes for fwd | bwd.

    Parameters
    ----------
    fwd : bool
        True initializes the forward metadata ("scaling_fwd"),
        False the backward metadata ("scaling_bwd").
    recipe : Union[Recipe, List[Recipe]]
        A single recipe, or a list of recipes. When a list of two recipes is
        given, QKV/dQKV and O/dO use the first recipe's quantizers while
        S/dP use the second (see the table in init_fp8_metadata); the last
        entry is the recipe applied to the attention (DPA) tensors.
    """
    if isinstance(recipe, Recipe):
        recipe = [recipe]
    # The last recipe in the list governs the DPA (S/dP) scaling state.
    fp8_recipe_dpa = recipe[-1]
    fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd"
    # Return early if recipe state matches recipe
    if self.fp8_meta_tensors_initialized:
        recipe_state = self.fp8_meta[fp8_meta_tensor_key]
        if fp8_recipe_dpa.delayed() and isinstance(recipe_state, DelayedScalingRecipeState):
            # Same recipe class; only the amax history length may need resizing.
            self.adjust_amax_history_length(fp8_recipe_dpa.amax_history_len, fwd=fwd)
            return
        if fp8_recipe_dpa.mxfp8() and isinstance(recipe_state, MXFP8BlockScalingRecipeState):
            return
        if fp8_recipe_dpa.float8_current_scaling() and isinstance(
            recipe_state, Float8CurrentScalingRecipeState
        ):
            return
        if fp8_recipe_dpa.float8_block_scaling() and isinstance(
            recipe_state, Float8BlockScalingRecipeState
        ):
            return
    # When fp8_recipe=Float8CurrentScaling, recipe=[CS, DS], and QKV/dQKV, O/dO use CS quantizers, S/dP use DS quantizers.
    # See table above in init_fp8_metadata for more detail.
    num_gemms = [2, 1] if len(recipe) == 2 else [3]
    # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
    # 2 (grad_output and grad_input) for bwd
    num_fp8_tensors = [x * 3 if fwd else x * 2 for x in num_gemms]
    # Initialize recipe state and quantizers
    recipe_states = [
        RecipeState.create(
            recipe[i],
            mode=("forward" if fwd else "backward"),
            num_quantizers=num_fp8_tensors[i],
        )
        for i in range(len(recipe))
    ]
    # With two recipes, the second (DPA) recipe's state is stored as the
    # canonical scaling state for this direction.
    self.fp8_meta[fp8_meta_tensor_key] = (
        recipe_states[-1] if len(recipe) == 2 else recipe_states[0]
    )
    # Quantizers from all recipe states are concatenated in recipe order so
    # the META_* indices resolve into a single flat list.
    self.quantizers[fp8_meta_tensor_key] = []
    for recipe_state in recipe_states:
        self.quantizers[fp8_meta_tensor_key].extend(recipe_state.make_quantizers())
@no_torch_dynamo(recursive=False) @no_torch_dynamo(recursive=False)
def forward( def forward(
self, self,
...@@ -456,6 +788,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -456,6 +788,7 @@ class DotProductAttention(TransformerEngineBaseModule):
fast_zero_fill: bool = True, fast_zero_fill: bool = True,
inference_params: Optional[InferenceParams] = None, inference_params: Optional[InferenceParams] = None,
pad_between_seqs: Optional[bool] = None, pad_between_seqs: Optional[bool] = None,
fp8_output: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Dot Product Attention Layer. Dot Product Attention Layer.
...@@ -628,12 +961,15 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -628,12 +961,15 @@ class DotProductAttention(TransformerEngineBaseModule):
pad_between_seqs: Optional[bool], default = `None` pad_between_seqs: Optional[bool], default = `None`
If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded. If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded.
If true, there are padding tokens between individual sequences in a packed batch. If true, there are padding tokens between individual sequences in a packed batch.
fp8_output: Optional[bool], default = `False`
Whether to enforce output to be in FP8 or not.
""" """
with torch.cuda.device(query_layer.device), self.prepare_forward( with torch.cuda.device(query_layer.device), self.prepare_forward(
query_layer, query_layer,
num_gemms=3, num_gemms=3,
allow_non_contiguous=True, allow_non_contiguous=True,
allow_different_data_and_param_types=self.softmax_type != "vanilla",
) as query_layer: ) as query_layer:
# checks for RNG # checks for RNG
if self.rng_states_tracker is not None and is_graph_capturing(): if self.rng_states_tracker is not None and is_graph_capturing():
...@@ -663,6 +999,8 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -663,6 +999,8 @@ class DotProductAttention(TransformerEngineBaseModule):
tex.DType.kFloat8E4M3, tex.DType.kFloat8E4M3,
tex.DType.kFloat8E5M2, tex.DType.kFloat8E5M2,
], """DotProductAttention only supports "E4M3" and "E5M2" FP8 data types.""" ], """DotProductAttention only supports "E4M3" and "E5M2" FP8 data types."""
else:
fp8_output = False
# checks for q/k/v shapes # checks for q/k/v shapes
assert ( assert (
...@@ -922,6 +1260,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -922,6 +1260,7 @@ class DotProductAttention(TransformerEngineBaseModule):
False False
), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes" ), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes"
# check if there is padding between sequences when qkv_format='thd'
if pad_between_seqs is None: if pad_between_seqs is None:
if qkv_format == "thd": if qkv_format == "thd":
pad_between_seqs = ( pad_between_seqs = (
...@@ -957,11 +1296,13 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -957,11 +1296,13 @@ class DotProductAttention(TransformerEngineBaseModule):
pad_between_seqs=pad_between_seqs, pad_between_seqs=pad_between_seqs,
attention_dropout=self.attention_dropout, attention_dropout=self.attention_dropout,
context_parallel=context_parallel, context_parallel=context_parallel,
cp_comm_type=self.cp_comm_type,
deterministic=self.deterministic, deterministic=self.deterministic,
is_training=self.training, is_training=self.training,
fp8=self.fp8, fp8=self.fp8,
fp8_meta=self.fp8_meta, fp8_meta=self.fp8_meta,
inference_params=inference_params, inference_params=inference_params,
softmax_type=self.softmax_type,
) )
global _attention_backends global _attention_backends
if is_in_onnx_export_mode(): if is_in_onnx_export_mode():
...@@ -1022,6 +1363,12 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -1022,6 +1363,12 @@ class DotProductAttention(TransformerEngineBaseModule):
) )
# run attention # run attention
softmax_offset = (
self.softmax_offset.reshape(1, -1, 1, 1).to(torch.float32)
if self.softmax_offset is not None
else None
)
if use_flash_attention: if use_flash_attention:
if core_attention_bias_type == "alibi": if core_attention_bias_type == "alibi":
alibi_slopes, _ = dpa_utils.get_alibi( alibi_slopes, _ = dpa_utils.get_alibi(
...@@ -1053,6 +1400,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -1053,6 +1400,7 @@ class DotProductAttention(TransformerEngineBaseModule):
quantizers=self.quantizers, quantizers=self.quantizers,
inference_params=inference_params, inference_params=inference_params,
flash_attention_backend=flash_attention_backend, flash_attention_backend=flash_attention_backend,
fp8_output=fp8_output,
) )
if use_fused_attention: if use_fused_attention:
...@@ -1071,7 +1419,6 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -1071,7 +1419,6 @@ class DotProductAttention(TransformerEngineBaseModule):
bias_dtype=query_layer.dtype, bias_dtype=query_layer.dtype,
bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"], bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
) )
# checkpoint_core_attention=False
if checkpoint_core_attention: if checkpoint_core_attention:
return self._checkpointed_attention_forward( return self._checkpointed_attention_forward(
self.fused_attention, self.fused_attention,
...@@ -1101,6 +1448,8 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -1101,6 +1448,8 @@ class DotProductAttention(TransformerEngineBaseModule):
quantizers=self.quantizers, quantizers=self.quantizers,
pad_between_seqs=pad_between_seqs, pad_between_seqs=pad_between_seqs,
inference_params=inference_params, inference_params=inference_params,
softmax_offset=softmax_offset,
fp8_output=fp8_output,
) )
return self.fused_attention( return self.fused_attention(
query_layer, query_layer,
...@@ -1129,6 +1478,8 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -1129,6 +1478,8 @@ class DotProductAttention(TransformerEngineBaseModule):
quantizers=self.quantizers, quantizers=self.quantizers,
pad_between_seqs=pad_between_seqs, pad_between_seqs=pad_between_seqs,
inference_params=inference_params, inference_params=inference_params,
softmax_offset=softmax_offset,
fp8_output=fp8_output,
) )
from transformer_engine.pytorch.cpu_offload import CPUOffloadEnabled from transformer_engine.pytorch.cpu_offload import CPUOffloadEnabled
...@@ -1140,6 +1491,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -1140,6 +1491,7 @@ class DotProductAttention(TransformerEngineBaseModule):
) )
if use_unfused_attention: if use_unfused_attention:
allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1"
if checkpoint_core_attention: if checkpoint_core_attention:
return self._checkpointed_attention_forward( return self._checkpointed_attention_forward(
self.unfused_attention, self.unfused_attention,
...@@ -1157,6 +1509,11 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -1157,6 +1509,11 @@ class DotProductAttention(TransformerEngineBaseModule):
core_attention_bias=core_attention_bias, core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
inference_params=inference_params, inference_params=inference_params,
softmax_offset=softmax_offset,
fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa and allow_emulation,
fp8_meta=self.fp8_meta,
quantizers=self.quantizers,
fp8_output=fp8_output,
) )
return self.unfused_attention( return self.unfused_attention(
_alibi_cache, _alibi_cache,
...@@ -1173,5 +1530,10 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -1173,5 +1530,10 @@ class DotProductAttention(TransformerEngineBaseModule):
core_attention_bias=core_attention_bias, core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
inference_params=inference_params, inference_params=inference_params,
softmax_offset=softmax_offset,
fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa and allow_emulation,
fp8_meta=self.fp8_meta,
quantizers=self.quantizers,
fp8_output=fp8_output,
) )
return None return None
...@@ -17,6 +17,7 @@ import numpy as np ...@@ -17,6 +17,7 @@ import numpy as np
from packaging.version import Version as PkgVersion from packaging.version import Version as PkgVersion
import torch import torch
import torch.distributed as dist
import torch.nn.functional as F import torch.nn.functional as F
import transformer_engine_torch as tex import transformer_engine_torch as tex
import transformer_engine as te import transformer_engine as te
...@@ -24,6 +25,7 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import ( ...@@ -24,6 +25,7 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
QKVLayout, QKVLayout,
AttnBiasType, AttnBiasType,
AttnMaskType, AttnMaskType,
SoftmaxType,
FusedAttnBackend, FusedAttnBackend,
META_QKV, META_QKV,
META_DQKV, META_DQKV,
...@@ -31,18 +33,22 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import ( ...@@ -31,18 +33,22 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
META_DO, META_DO,
META_S, META_S,
META_DP, META_DP,
META_O_CP,
META_DQKV_CP,
) )
from transformer_engine.pytorch.attention.inference import InferenceParams from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.tensor.float8_tensor import (
from transformer_engine.pytorch.fp8 import get_fp8_te_dtype Float8Tensor,
Float8Quantizer,
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.quantization import get_fp8_te_dtype
from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.constants import TE_DType
from torch.utils.cpp_extension import IS_HIP_EXTENSION from torch.utils.cpp_extension import IS_HIP_EXTENSION
from transformer_engine.pytorch.utils import ( from transformer_engine.pytorch.utils import (
get_device_compute_capability, get_device_compute_capability,
get_cudnn_version, get_cudnn_version,
SplitAlongDim,
combine_tensors,
) )
from transformer_engine.pytorch.export import is_in_onnx_export_mode from transformer_engine.pytorch.export import is_in_onnx_export_mode
...@@ -53,6 +59,9 @@ _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) ...@@ -53,6 +59,9 @@ _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0 # NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0")) _NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1")) _NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
# print quantizer info for a particular layer on a particular rank
_print_layer = int(os.getenv("NVTE_PRINT_LAYER_NUMBER", "1"))
_print_rank = int(os.getenv("NVTE_PRINT_RANK", "0"))
_cu_seqlens_cache = {} _cu_seqlens_cache = {}
...@@ -206,16 +215,20 @@ class AttentionParams: ...@@ -206,16 +215,20 @@ class AttentionParams:
Attention dropout. Attention dropout.
context_parallel: bool, default = `False` context_parallel: bool, default = `False`
Whether context parallelism is used or not. Whether context parallelism is used or not.
cp_comm_type: str, default = "p2p"
The communication type of context parallelism.
deterministic: bool, default = `False` deterministic: bool, default = `False`
Whether to run `DotProductAttention` with determinism or not. Whether to run `DotProductAttention` with determinism or not.
is_training: bool, default = `True` is_training: bool, default = `True`
Whether in training mode (`True`) or inference mode (`False`) Whether in training mode (`True`) or inference mode (`False`)
fp8: bool, default = `False` fp8: bool, default = `False`
Whether `DotProductAttention` is in an `fp8_autocast` region. Whether `DotProductAttention` is in an `autocast` region.
fp8_meta: Optional[Dict[str Any]], default = `None` fp8_meta: Optional[Dict[str Any]], default = `None`
The FP8 metadata tensor of `DotProductAttention`. The FP8 metadata tensor of `DotProductAttention`.
inference_params: Optional[InferenceParams], default = `None` inference_params: Optional[InferenceParams], default = `None`
Inference-related parameters. See InferenceParams for details. Inference-related parameters. See InferenceParams for details.
softmax_type: str, default = "vanilla"
The type of softmax operation. See DotProductAttention for details.
""" """
qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor
...@@ -237,11 +250,13 @@ class AttentionParams: ...@@ -237,11 +250,13 @@ class AttentionParams:
pad_between_seqs: bool = False pad_between_seqs: bool = False
attention_dropout: float = 0.0 attention_dropout: float = 0.0
context_parallel: bool = False context_parallel: bool = False
cp_comm_type: str = "p2p"
deterministic: bool = False deterministic: bool = False
is_training: bool = True is_training: bool = True
fp8: bool = False fp8: bool = False
fp8_meta: Union[Dict[str, Any], None] = None fp8_meta: Union[Dict[str, Any], None] = None
inference_params: Optional[InferenceParams] = None inference_params: Optional[InferenceParams] = None
softmax_type: str = "vanilla"
def __eq__(self, other): def __eq__(self, other):
""" """
...@@ -308,11 +323,13 @@ def get_attention_backend( ...@@ -308,11 +323,13 @@ def get_attention_backend(
pad_between_seqs = attention_params.pad_between_seqs pad_between_seqs = attention_params.pad_between_seqs
attention_dropout = attention_params.attention_dropout attention_dropout = attention_params.attention_dropout
context_parallel = attention_params.context_parallel context_parallel = attention_params.context_parallel
cp_comm_type = attention_params.cp_comm_type
deterministic = attention_params.deterministic deterministic = attention_params.deterministic
is_training = attention_params.is_training is_training = attention_params.is_training
fp8 = attention_params.fp8 fp8 = attention_params.fp8
fp8_meta = attention_params.fp8_meta fp8_meta = attention_params.fp8_meta
inference_params = attention_params.inference_params inference_params = attention_params.inference_params
softmax_type = attention_params.softmax_type
# Run config # Run config
logger = logging.getLogger("DotProductAttention") logger = logging.getLogger("DotProductAttention")
...@@ -341,8 +358,31 @@ def get_attention_backend( ...@@ -341,8 +358,31 @@ def get_attention_backend(
field.name: getattr(attention_params, field.name) for field in fields(attention_params) field.name: getattr(attention_params, field.name) for field in fields(attention_params)
} }
run_config.update(attention_params_dict) run_config.update(attention_params_dict)
# Add FP8 environment variables to config
if fp8: if fp8:
# all FP8 recipes: 1: (FP8 fwd, FP8 bwd), 0: (FP8 fwd, F16 bwd)
run_config["NVTE_FP8_DPA_BWD"] = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) run_config["NVTE_FP8_DPA_BWD"] = int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
# Float8CurrentScaling: 1: use F16 O in bwd, 0: use FP8 O in bwd
run_config["NVTE_DPA_FP8CS_O_in_F16"] = int(os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1"))
# switch recipe to "F16", "DelayedScaling", or "Float8CurrentScaling"
_dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "")
run_config["NVTE_DPA_FP8_RECIPE"] = _dpa_fp8_recipe
if _dpa_fp8_recipe != "":
# config new recipe if switched
run_config["NVTE_DPA_FP8_FORMAT"] = os.getenv("NVTE_DPA_FP8_FORMAT", "HYBRID")
run_config["NVTE_DPA_FP8DS_AMAX_ALGO"] = os.getenv(
"NVTE_DPA_FP8DS_AMAX_ALGO", "most_recent"
)
run_config["NVTE_DPA_FP8DS_AMAX_HISTLEN"] = int(
os.getenv("NVTE_DPA_FP8DS_AMAX_HISTLEN", "1")
)
run_config["NVTE_DPA_FP8DS_REDUCE_AMAX"] = int(
os.getenv("NVTE_DPA_FP8DS_REDUCE_AMAX", "1")
)
# UnfusedDotProductAttention: 1: allow FP8 emulation, 0: do not allow
run_config["NVTE_UnfusedDPA_Emulate_FP8"] = int(
os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0")
)
logger.debug("Running with config=%s", run_config) logger.debug("Running with config=%s", run_config)
# The following sections check if `FlashAttention` supports the provided attention params, # The following sections check if `FlashAttention` supports the provided attention params,
...@@ -422,8 +462,20 @@ def get_attention_backend( ...@@ -422,8 +462,20 @@ def get_attention_backend(
logger.debug("Disabling FlashAttention 3 for FP8 training") logger.debug("Disabling FlashAttention 3 for FP8 training")
use_flash_attention_3 = False use_flash_attention_3 = False
if use_unfused_attention: if use_unfused_attention:
logger.debug("Disabling UnfusedDotProductAttention for FP8 attention") allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1"
use_unfused_attention = False if not allow_emulation:
logger.debug("Disabling UnfusedDotProductAttention for FP8 attention")
use_unfused_attention = False
fp8_recipe = fp8_meta["recipe"]
if fp8_meta.get("local_recipes", None) is not None:
fp8_recipe = fp8_meta["local_recipes"][0]
if use_fused_attention and fp8_recipe.float8_current_scaling():
if device_compute_capability < (10, 0):
logger.debug("Disabling FusedAttention for FP8 current scaling on arch < sm100")
use_fused_attention = False
elif cudnn_version < (9, 14, 0):
logger.debug("Disabling FusedAttention for FP8 current scaling with cuDNN < 9.14.0")
use_fused_attention = False
# TODO: rocm fused attention backends does not support fp8 yet # TODO: rocm fused attention backends does not support fp8 yet
if IS_HIP_EXTENSION and use_fused_attention: if IS_HIP_EXTENSION and use_fused_attention:
logger.debug("Disabling ROCm FusedAttention as it does not support FP8") logger.debug("Disabling ROCm FusedAttention as it does not support FP8")
...@@ -581,6 +633,51 @@ def get_attention_backend( ...@@ -581,6 +633,51 @@ def get_attention_backend(
logger.debug("Disabling FlashAttention 3 for dropout") logger.debug("Disabling FlashAttention 3 for dropout")
use_flash_attention_3 = False use_flash_attention_3 = False
# Filter: Softmax type
# context_parallel | softmax_type | supported backends
# ----------------------------------------------------------------------------------------------------
# no | vanilla | All
# no | off-by-one | FusedAttention, UnfusedDotProductAttention
# no | learnable | FusedAttention, UnfusedDotProductAttention
# yes | vanilla | FusedAttention, FlashAttention
# yes | off-by-one | FusedAttention
# yes | learnable | FusedAttention
if softmax_type != "vanilla":
logger.debug("Disabling FlashAttention for softmax_type = %s", softmax_type)
use_flash_attention = False
if fp8 and fp8_meta["recipe"].fp8_dpa:
logger.debug("Disabling FusedAttention for softmax_type = %s in FP8", softmax_type)
use_fused_attention = False
logger.debug(
"Disabling UnfusedDotProductAttention for softmax_type = %s in FP8", softmax_type
)
use_unfused_attention = False
if qkv_format == "thd":
logger.debug(
"Disabling FusedAttention for softmax_type = %s and qkv_format = thd", softmax_type
)
use_fused_attention = False
logger.debug(
"Disabling UnfusedDotProductAttention for softmax_type = %s and qkv_format = thd",
softmax_type,
)
use_unfused_attention = False
if context_parallel:
logger.debug(
"Disabling UnfusedDotProductAttention for context parallelism with softmax_type"
" = %s",
softmax_type,
)
use_unfused_attention = False
if cp_comm_type != "a2a":
logger.debug(
"Disabling FusedAttention for context parallelism with softmax_type = %s and"
" cp_comm_type = %s",
softmax_type,
cp_comm_type,
)
use_fused_attention = False
# Filter: Context parallelism # Filter: Context parallelism
# qkv_format | attn_mask_type | attn_bias_type | supported backends # qkv_format | attn_mask_type | attn_bias_type | supported backends
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
...@@ -822,6 +919,7 @@ def get_attention_backend( ...@@ -822,6 +919,7 @@ def get_attention_backend(
QKVLayout[qkv_layout], QKVLayout[qkv_layout],
AttnBiasType[fu_core_attention_bias_type], AttnBiasType[fu_core_attention_bias_type],
AttnMaskType[attn_mask_type], AttnMaskType[attn_mask_type],
SoftmaxType[softmax_type],
attention_dropout, attention_dropout,
num_heads, num_heads,
num_gqa_groups, num_gqa_groups,
...@@ -1836,11 +1934,10 @@ def check_set_window_size( ...@@ -1836,11 +1934,10 @@ def check_set_window_size(
return window_size return window_size
def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False): def get_attention_quantizers(fp8, quantizers):
"""Get the list of quantizers used in attention from the quantizers list.""" """Get the list of quantizers used in attention from the quantizers list."""
if not fp8: if not fp8:
num_of_nones = 8 if cp_specific_quantizers else 6 return [None] * 6
return [None] * num_of_nones
QKV_quantizer = quantizers["scaling_fwd"][META_QKV] QKV_quantizer = quantizers["scaling_fwd"][META_QKV]
QKV_quantizer.internal = True QKV_quantizer.internal = True
QKV_quantizer.set_usage(rowwise=True, columnwise=False) QKV_quantizer.set_usage(rowwise=True, columnwise=False)
...@@ -1849,6 +1946,7 @@ def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False): ...@@ -1849,6 +1946,7 @@ def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False):
S_quantizer = quantizers["scaling_fwd"][META_S] S_quantizer = quantizers["scaling_fwd"][META_S]
S_quantizer.internal = True S_quantizer.internal = True
S_quantizer.set_usage(rowwise=True, columnwise=False) S_quantizer.set_usage(rowwise=True, columnwise=False)
dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV] dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV]
dQKV_quantizer.interal = True dQKV_quantizer.interal = True
dQKV_quantizer.set_usage(rowwise=True, columnwise=False) dQKV_quantizer.set_usage(rowwise=True, columnwise=False)
...@@ -1858,22 +1956,158 @@ def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False): ...@@ -1858,22 +1956,158 @@ def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False):
dP_quantizer = quantizers["scaling_bwd"][META_DP] dP_quantizer = quantizers["scaling_bwd"][META_DP]
dP_quantizer.set_usage(rowwise=True, columnwise=False) dP_quantizer.set_usage(rowwise=True, columnwise=False)
dP_quantizer.interal = True dP_quantizer.interal = True
dQKV_CP_quantizer = quantizers["scaling_bwd"][META_DQKV_CP]
dQKV_CP_quantizer.set_usage(rowwise=True, columnwise=False) return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer
dQKV_CP_quantizer.internal = True
O_CP_quantizer = quantizers["scaling_fwd"][META_O_CP]
O_CP_quantizer.set_usage(rowwise=True, columnwise=False) def print_quantizers(
label,
if cp_specific_quantizers: layer_number,
return ( QKV_quantizer,
O_quantizer,
S_quantizer,
dQKV_quantizer,
dO_quantizer,
dP_quantizer,
):
"""Print the type and scale/amax of attention quantizers"""
_to_print = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL == 2
if (
_to_print
and _print_layer == layer_number
and (
not dist.is_initialized() or (dist.is_initialized() and dist.get_rank() == _print_rank)
)
):
names = [
"QKV_quantizer",
"S_quantizer",
"O_quantizer",
"dO_quantizer",
"dP_quantizer",
"dQKV_quantizer",
]
quantizers = [
QKV_quantizer, QKV_quantizer,
O_quantizer,
O_CP_quantizer,
S_quantizer, S_quantizer,
dQKV_quantizer, O_quantizer,
dQKV_CP_quantizer,
dO_quantizer, dO_quantizer,
dP_quantizer, dP_quantizer,
) dQKV_quantizer,
]
if "forward" in label:
names = names[:3]
quantizers = quantizers[:3]
if "backward" in label:
names = names[3:]
quantizers = quantizers[3:]
for i, q in enumerate(quantizers):
type_str = ""
if q is None:
type_str = "None"
elif isinstance(q, Float8Quantizer):
type_str = "DS"
elif isinstance(q, Float8CurrentScalingQuantizer):
type_str = "CS"
print(
f"{label} >> {names[i]:14s}: {type_str}, {q.scale.item():.4e} x"
f" {q.amax.item():.4e} = {q.scale.item()*q.amax.item():.4e}"
)
return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer
def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer):
    """Pack q, k, v according to qkv_layout, quantize the packed tensor with a
    single quantizer call, and return the three results as Float8Tensors that
    share the packed tensor's scaling metadata."""

    # Number of "_"-separated groups: 1 = qkv packed, 2 = kv packed, 3 = all separate.
    layout = qkv_layout.replace("paged_kv_", "")
    group_count = len(layout.split("_"))
    nominal_dtype = q.dtype

    def _cumulative_offsets(tensors):
        # Running element offsets [0, n0, n0+n1, ...] of each tensor in a flat concat.
        offsets = [0]
        for t in tensors:
            offsets.append(offsets[-1] + t.numel())
        return offsets

    if group_count == 1:
        # q, k, v interleave along the "3" dimension of the layout.
        axis = layout.find("3")
        packed = combine_tensors([q, k, v], axis)
        packed_fp8 = qkv_quantizer(packed)
        q_data, k_data, v_data = SplitAlongDim.apply(packed_fp8._data, axis, [1, 1, 1], True)
    elif group_count == 2:
        # k and v interleave along the "2" dimension; q is flattened alongside.
        axis = layout.split("_")[1].find("2")
        kv = combine_tensors([k, v], axis)
        parts = [q, kv]
        shapes = [t.shape for t in parts]
        offsets = _cumulative_offsets(parts)
        packed = torch.cat([t.view(-1) for t in parts], dim=0)
        packed_fp8 = qkv_quantizer(packed)
        q_data = packed_fp8._data[offsets[0] : offsets[1]].view(shapes[0])
        kv_data = packed_fp8._data[offsets[1] : offsets[2]].view(shapes[1])
        k_data, v_data = SplitAlongDim.apply(kv_data, axis, [1, 1], True)
    elif group_count == 3:
        # Fully separate tensors: flatten, concatenate, quantize once, slice back.
        parts = [q, k, v]
        shapes = [t.shape for t in parts]
        offsets = _cumulative_offsets(parts)
        packed = torch.cat([t.view(-1) for t in parts], dim=0)
        packed_fp8 = qkv_quantizer(packed)
        q_data, k_data, v_data = [
            packed_fp8._data[offsets[i] : offsets[i + 1]].view(shapes[i]) for i in range(3)
        ]
    else:
        raise RuntimeError("Invalid qkv_layout " + qkv_layout)

    # Re-wrap the raw FP8 data slices with the packed tensor's quantization
    # metadata and the inputs' original nominal dtype.
    return tuple(
        Float8Tensor.make_like(packed_fp8, data=d, dtype=nominal_dtype)
        for d in (q_data, k_data, v_data)
    )
def combine_and_dequantize(
qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=None, des_nominal_dtype=None
):
"""Combine q,k,v based on qkv_layout and dequantize them together"""
# 1: qkv packed, 2: kv packed, 3: qkv separate
qkv_layout = qkv_layout.replace("paged_kv_", "")
qkv_group = len(qkv_layout.split("_"))
if all(isinstance(x, Float8Tensor) for x in [q_fp8, k_fp8, v_fp8]):
src_nominal_dtype = q_fp8.dtype
else:
assert src_nominal_dtype is not None, "The nominal dtype of input tensors is required!"
if des_nominal_dtype is None:
des_nominal_dtype = src_nominal_dtype
q_data, k_data, v_data = [x._data for x in [q_fp8, k_fp8, v_fp8]]
match qkv_group:
case 1:
dim = qkv_layout.find("3")
qkv_data = combine_tensors([q_data, k_data, v_data], dim)
qkv_fp8 = Float8Tensor.make_like(q_fp8, data=qkv_data)
qkv = qkv_fp8.dequantize(dtype=des_nominal_dtype)
q, k, v = SplitAlongDim.apply(qkv, dim, [1, 1, 1], True)
case 2:
dim = qkv_layout.split("_")[1].find("2")
kv_data = combine_tensors([k_data, v_data], dim)
tensors = [q_data, kv_data]
num_tensors = len(tensors)
shapes = [x.shape for x in tensors]
numels = [x.numel() for x in tensors]
numels = [sum(numels[:i]) for i in range(num_tensors + 1)]
qkv_data = torch.cat([x.reshape(-1) for x in tensors], dim=0)
qkv_fp8 = Float8Tensor.make_like(q_fp8, data=qkv_data, dtype=src_nominal_dtype)
qkv = qkv_fp8.dequantize(dtype=des_nominal_dtype)
q, kv = [qkv[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors)]
k, v = SplitAlongDim.apply(kv, dim, [1, 1], True)
case 3:
tensors = [q_data, k_data, v_data]
num_tensors = len(tensors)
shapes = [x.shape for x in tensors]
numels = [x.numel() for x in tensors]
numels = [sum(numels[:i]) for i in range(num_tensors + 1)]
qkv_data = torch.cat([x.contiguous().reshape(-1) for x in tensors], dim=0)
qkv_fp8 = Float8Tensor.make_like(q_fp8, data=qkv_data, dtype=src_nominal_dtype)
qkv = qkv_fp8.dequantize(dtype=des_nominal_dtype)
q, k, v = [qkv[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors)]
case _:
raise RuntimeError("Invalid qkv_layout " + qkv_layout)
return q, k, v
...@@ -3,13 +3,14 @@ ...@@ -3,13 +3,14 @@
# See LICENSE for license information. # See LICENSE for license information.
"""Multi-head Attention.""" """Multi-head Attention."""
import os
import collections import collections
from typing import Callable, List, Optional, Tuple, Union from typing import Callable, List, Optional, Tuple, Union
import torch import torch
from transformer_engine.debug.pytorch.debug_state import TEDebugState from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.quantization import FP8GlobalStateManager
from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.module import LayerNormLinear, Linear, RMSNorm, LayerNorm from transformer_engine.pytorch.module import LayerNormLinear, Linear, RMSNorm, LayerNorm
from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization
...@@ -31,7 +32,13 @@ from transformer_engine.pytorch.distributed import ( ...@@ -31,7 +32,13 @@ from transformer_engine.pytorch.distributed import (
from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention
from transformer_engine.pytorch.attention.inference import InferenceParams from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb from transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
# Force DotProductAttention to use a different recipe than the fp8_recipe set in autocast().
# Useful when GEMMs and attention use different recipes. Supported values are "DelayedScaling"
# and "Float8CurrentScaling". Use other relevant variables here to define the recipe, e.g. fp8_dpa.
_dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "")
_dpa_fp8_recipe_dpa = os.getenv("NVTE_DPA_FP8_RECIPE_DPA", "0") == "1"
_dpa_fp8_recipe_mha = os.getenv("NVTE_DPA_FP8_RECIPE_MHA", "0") == "1"
class MultiheadAttention(torch.nn.Module): class MultiheadAttention(torch.nn.Module):
...@@ -135,6 +142,17 @@ class MultiheadAttention(torch.nn.Module): ...@@ -135,6 +142,17 @@ class MultiheadAttention(torch.nn.Module):
For that, please use `get_qkv_layout` to gain the layout information. For that, please use `get_qkv_layout` to gain the layout information.
name: str, default = `None` name: str, default = `None`
name of the module, currently used for debugging purposes. name of the module, currently used for debugging purposes.
softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
softmax type as described in this paper:
`Efficient Streaming Language Models with Attention Sinks
<https://arxiv.org/pdf/2309.17453v3>`_.
For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv],
'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
where alpha is a learnable parameter in shape [h].
'off-by-one' and 'learnable' softmax types are also called sink attention
('zero sink' and 'learnable sink').
Parallelism parameters Parallelism parameters
---------------------- ----------------------
...@@ -245,6 +263,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -245,6 +263,7 @@ class MultiheadAttention(torch.nn.Module):
qk_norm_before_rope: bool = False, qk_norm_before_rope: bool = False,
seq_length: Optional[int] = None, seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None, micro_batch_size: Optional[int] = None,
softmax_type: str = "vanilla",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -262,6 +281,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -262,6 +281,7 @@ class MultiheadAttention(torch.nn.Module):
self.return_bias = return_bias self.return_bias = return_bias
self.cp_size = 1 self.cp_size = 1
self.cp_rank = 0 self.cp_rank = 0
self.softmax_type = softmax_type
kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads) kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads)
...@@ -416,6 +436,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -416,6 +436,7 @@ class MultiheadAttention(torch.nn.Module):
tp_group=tp_group, tp_group=tp_group,
layer_number=self.layer_number, layer_number=self.layer_number,
attention_type=self.attention_type, attention_type=self.attention_type,
softmax_type=self.softmax_type,
) )
# Linear # Linear
...@@ -556,10 +577,12 @@ class MultiheadAttention(torch.nn.Module): ...@@ -556,10 +577,12 @@ class MultiheadAttention(torch.nn.Module):
self.cp_size = get_distributed_world_size(cp_group) self.cp_size = get_distributed_world_size(cp_group)
self.cp_rank = get_distributed_rank(cp_group) self.cp_rank = get_distributed_rank(cp_group)
elif isinstance(cp_group, list): elif isinstance(cp_group, list):
assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!"
assert ( assert (
cp_comm_type == "a2a+p2p" cp_comm_type == "a2a+p2p"
), "Only cp_comm_type of a2a+p2p requires hierarchical CP groups!" ), "Only cp_comm_type of a2a+p2p requires hierarchical CP groups!"
assert (
len(cp_group) == 2
), "cp_comm_type = a2a+p2p requires cp_group = [a2a_cp_group, p2p_cp_group]!"
cp_size_a2a = get_distributed_world_size(cp_group[0]) cp_size_a2a = get_distributed_world_size(cp_group[0])
cp_rank_a2a = get_distributed_rank(cp_group[0]) cp_rank_a2a = get_distributed_rank(cp_group[0])
cp_size_p2p = get_distributed_world_size(cp_group[1]) cp_size_p2p = get_distributed_world_size(cp_group[1])
...@@ -716,10 +739,22 @@ class MultiheadAttention(torch.nn.Module): ...@@ -716,10 +739,22 @@ class MultiheadAttention(torch.nn.Module):
# Query, Key, and Value # Query, Key, and Value
# ====================== # ======================
fp8_mha = ( fp8 = FP8GlobalStateManager.is_fp8_enabled()
FP8GlobalStateManager.is_fp8_enabled() if _dpa_fp8_recipe == "":
and FP8GlobalStateManager.get_fp8_recipe().fp8_mha fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
) fp8_dpa = fp8_recipe.fp8_dpa
fp8_mha = fp8_recipe.fp8_mha
float8_current_scaling = fp8_recipe.float8_current_scaling()
else:
fp8_dpa = _dpa_fp8_recipe_dpa
fp8_mha = _dpa_fp8_recipe_mha
float8_current_scaling = _dpa_fp8_recipe == "Float8CurrentScaling"
# QKV Gemm: do not produce FP8 output when in Float8CurrentScaling recipe
qkv_fp8_output = fp8 and fp8_mha and rotary_pos_emb is None and not float8_current_scaling
# DPA: always produce FP8 output when fp8=True to take advantage of the O amax
dpa_fp8_output = fp8 and (fp8_dpa or fp8_mha)
# Proj Gemm: match DPA output except for Float8CurrentScaling
proj_fp8_grad = dpa_fp8_output and not float8_current_scaling
layernorm_output = None layernorm_output = None
if self.attention_type == "self": if self.attention_type == "self":
...@@ -728,7 +763,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -728,7 +763,7 @@ class MultiheadAttention(torch.nn.Module):
layernorm_qkv_outputs = self.layernorm_qkv( layernorm_qkv_outputs = self.layernorm_qkv(
hidden_states, hidden_states,
is_first_microbatch=is_first_microbatch, is_first_microbatch=is_first_microbatch,
fp8_output=fp8_mha and rotary_pos_emb is None, fp8_output=qkv_fp8_output,
) )
if self.return_layernorm_output: if self.return_layernorm_output:
mixed_x_layer, layernorm_output = layernorm_qkv_outputs mixed_x_layer, layernorm_output = layernorm_qkv_outputs
...@@ -738,7 +773,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -738,7 +773,7 @@ class MultiheadAttention(torch.nn.Module):
mixed_x_layer = self.qkv( mixed_x_layer = self.qkv(
hidden_states, hidden_states,
is_first_microbatch=is_first_microbatch, is_first_microbatch=is_first_microbatch,
fp8_output=fp8_mha and rotary_pos_emb is None, fp8_output=qkv_fp8_output,
) )
num_queries_per_key_value = ( num_queries_per_key_value = (
...@@ -792,7 +827,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -792,7 +827,7 @@ class MultiheadAttention(torch.nn.Module):
mixed_kv_layer = self.key_value( mixed_kv_layer = self.key_value(
encoder_output, encoder_output,
is_first_microbatch=is_first_microbatch, is_first_microbatch=is_first_microbatch,
fp8_output=fp8_mha and rotary_pos_emb is None, fp8_output=qkv_fp8_output,
) )
if self.qkv_weight_interleaved: if self.qkv_weight_interleaved:
...@@ -847,7 +882,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -847,7 +882,7 @@ class MultiheadAttention(torch.nn.Module):
layernorm_query_outputs = self.layernorm_query( layernorm_query_outputs = self.layernorm_query(
hidden_states, hidden_states,
is_first_microbatch=is_first_microbatch, is_first_microbatch=is_first_microbatch,
fp8_output=fp8_mha and rotary_pos_emb is None, fp8_output=qkv_fp8_output,
) )
if self.return_layernorm_output: if self.return_layernorm_output:
query_layer, layernorm_output = layernorm_query_outputs query_layer, layernorm_output = layernorm_query_outputs
...@@ -857,7 +892,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -857,7 +892,7 @@ class MultiheadAttention(torch.nn.Module):
query_layer = self.query_layer( query_layer = self.query_layer(
hidden_states, hidden_states,
is_first_microbatch=is_first_microbatch, is_first_microbatch=is_first_microbatch,
fp8_output=fp8_mha and rotary_pos_emb is None, fp8_output=qkv_fp8_output,
) )
# [sq, b, hp] --> [sq, b, np, hn] # [sq, b, hp] --> [sq, b, np, hn]
...@@ -958,6 +993,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -958,6 +993,7 @@ class MultiheadAttention(torch.nn.Module):
fast_zero_fill=fast_zero_fill, fast_zero_fill=fast_zero_fill,
inference_params=inference_params, inference_params=inference_params,
pad_between_seqs=pad_between_seqs, pad_between_seqs=pad_between_seqs,
fp8_output=dpa_fp8_output,
) )
# =================== # ===================
...@@ -966,7 +1002,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -966,7 +1002,7 @@ class MultiheadAttention(torch.nn.Module):
projection_output = self.proj( projection_output = self.proj(
context_layer, context_layer,
is_first_microbatch=is_first_microbatch, is_first_microbatch=is_first_microbatch,
fp8_grad=isinstance(context_layer, QuantizedTensor), fp8_grad=proj_fp8_grad,
) )
if self.return_bias: if self.return_bias:
......
...@@ -66,6 +66,9 @@ class RotaryPositionEmbedding(torch.nn.Module): ...@@ -66,6 +66,9 @@ class RotaryPositionEmbedding(torch.nn.Module):
""" """
Create rotary position embedding frequencies. Create rotary position embedding frequencies.
This function is particularly sensitive to the use of mixed precision, so we disable the
autocast context if it is enabled.
Parameters Parameters
---------- ----------
max_seq_len: int max_seq_len: int
...@@ -73,26 +76,27 @@ class RotaryPositionEmbedding(torch.nn.Module): ...@@ -73,26 +76,27 @@ class RotaryPositionEmbedding(torch.nn.Module):
offset: int, default = 0 offset: int, default = 0
Fixed offset for frequencies. Fixed offset for frequencies.
""" """
seq = ( with torch.autocast(enabled=False, device_type="cuda"):
torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype) seq = (
+ offset torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
) + offset
)
if (
self.pretrained_max_position_embeddings is not None
and self.seq_len_interpolation_factor is not None
):
if ( if (
max_seq_len self.pretrained_max_position_embeddings is not None
> self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor and self.seq_len_interpolation_factor is not None
): ):
# dynamic linear scaling (length > position we have learned) if (
seq *= 1 / (max_seq_len / self.pretrained_max_position_embeddings) max_seq_len
else: > self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor
# fixed linear scaling ):
seq *= 1 / self.seq_len_interpolation_factor # dynamic linear scaling (length > position we have learned)
seq *= 1 / (max_seq_len / self.pretrained_max_position_embeddings)
freqs = torch.einsum("i , j -> i j", seq, self.inv_freq) else:
# fixed linear scaling
seq *= 1 / self.seq_len_interpolation_factor
freqs = torch.einsum("i , j -> i j", seq, self.inv_freq)
# first part even vector components, second part odd vector components, # first part even vector components, second part odd vector components,
# 2 * dim in dimension size # 2 * dim in dimension size
if not self.interleaved: if not self.interleaved:
......
...@@ -91,3 +91,5 @@ GemmParallelModes = ("row", "column", None) ...@@ -91,3 +91,5 @@ GemmParallelModes = ("row", "column", None)
dist_group_type = torch.distributed.ProcessGroup dist_group_type = torch.distributed.ProcessGroup
MXFP8_BLOCK_SCALING_SIZE = 32 MXFP8_BLOCK_SCALING_SIZE = 32
NVFP4_BLOCK_SCALING_SIZE = 16
...@@ -12,6 +12,7 @@ from transformer_engine_torch import ( ...@@ -12,6 +12,7 @@ from transformer_engine_torch import (
NVTE_QKV_Format, NVTE_QKV_Format,
NVTE_Bias_Type, NVTE_Bias_Type,
NVTE_Mask_Type, NVTE_Mask_Type,
NVTE_Softmax_Type,
NVTE_Fused_Attn_Backend, NVTE_Fused_Attn_Backend,
) )
from ..tensor.quantized_tensor import Quantizer from ..tensor.quantized_tensor import Quantizer
...@@ -86,6 +87,12 @@ AttnMaskType = { ...@@ -86,6 +87,12 @@ AttnMaskType = {
"padding_causal_bottom_right": NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK, "padding_causal_bottom_right": NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
} }
SoftmaxType = {
"vanilla": NVTE_Softmax_Type.NVTE_VANILLA_SOFTMAX,
"off-by-one": NVTE_Softmax_Type.NVTE_OFF_BY_ONE_SOFTMAX,
"learnable": NVTE_Softmax_Type.NVTE_LEARNABLE_SOFTMAX,
}
FusedAttnBackend = { FusedAttnBackend = {
"F16_max512_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen, "F16_max512_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen,
"F16_arbitrary_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, "F16_arbitrary_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
...@@ -102,9 +109,6 @@ META_O = tex.FP8FwdTensors.GEMM2_INPUT ...@@ -102,9 +109,6 @@ META_O = tex.FP8FwdTensors.GEMM2_INPUT
META_DO = tex.FP8BwdTensors.GRAD_INPUT2 META_DO = tex.FP8BwdTensors.GRAD_INPUT2
META_S = tex.FP8FwdTensors.GEMM3_OUTPUT META_S = tex.FP8FwdTensors.GEMM3_OUTPUT
META_DP = tex.FP8BwdTensors.GRAD_INPUT3 META_DP = tex.FP8BwdTensors.GRAD_INPUT3
# repurpose some unused amax history buffers for partial results of CP fwd and bwd
META_O_CP = tex.FP8FwdTensors.GEMM2_OUTPUT
META_DQKV_CP = tex.FP8BwdTensors.GRAD_INPUT1
def fused_attn_fwd( def fused_attn_fwd(
...@@ -131,8 +135,10 @@ def fused_attn_fwd( ...@@ -131,8 +135,10 @@ def fused_attn_fwd(
qkv_layout: str = "sbh3d", qkv_layout: str = "sbh3d",
attn_bias_type: str = "no_bias", attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding", attn_mask_type: str = "padding",
softmax_type: str = "vanilla",
window_size: Tuple[int, int] = (-1, -1), window_size: Tuple[int, int] = (-1, -1),
rng_gen: torch.Generator = None, rng_gen: torch.Generator = None,
softmax_offset: torch.Tensor = None,
) -> Tuple[Union[torch.Tensor, None], ...]: ) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention FWD for separate QKV input. """Fused Attention FWD for separate QKV input.
...@@ -197,6 +203,8 @@ def fused_attn_fwd( ...@@ -197,6 +203,8 @@ def fused_attn_fwd(
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding" attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"} type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
softmax_type: str, default = "vanilla"
type of the attention softmax; {"vanilla", "off-by-one", "learnable"}
window_size: Tuple[int, int], default = (-1, -1) window_size: Tuple[int, int], default = (-1, -1)
sliding window size for local attention, where query at position i attends to keys sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
...@@ -205,6 +213,9 @@ def fused_attn_fwd( ...@@ -205,6 +213,9 @@ def fused_attn_fwd(
rng_gen: torch.Generator, default = None rng_gen: torch.Generator, default = None
random number generator; random number generator;
if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
softmax_offset: torch.Tensor, default = None
softmax offset tensor in shape [1, h_q, 1, 1].
See softmax_type in DotProductAttention for details.
Returns Returns
---------- ----------
...@@ -286,6 +297,7 @@ def fused_attn_fwd( ...@@ -286,6 +297,7 @@ def fused_attn_fwd(
QKVLayout[qkv_layout], QKVLayout[qkv_layout],
AttnBiasType[attn_bias_type], AttnBiasType[attn_bias_type],
AttnMaskType[attn_mask_type], AttnMaskType[attn_mask_type],
SoftmaxType[softmax_type],
window_size, window_size,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_kv, cu_seqlens_kv,
...@@ -300,6 +312,7 @@ def fused_attn_fwd( ...@@ -300,6 +312,7 @@ def fused_attn_fwd(
s_quantizer, s_quantizer,
o_quantizer, o_quantizer,
attn_bias, attn_bias,
softmax_offset,
rng_gen, rng_gen,
rng_elts_per_thread, rng_elts_per_thread,
) )
...@@ -333,6 +346,7 @@ def fused_attn_bwd( ...@@ -333,6 +346,7 @@ def fused_attn_bwd(
qkv_layout: str = "sbh3d", qkv_layout: str = "sbh3d",
attn_bias_type: str = "no_bias", attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding", attn_mask_type: str = "padding",
softmax_type: str = "vanilla",
window_size: Tuple[int, int] = (-1, -1), window_size: Tuple[int, int] = (-1, -1),
deterministic: bool = False, deterministic: bool = False,
) -> Tuple[Union[torch.Tensor, None], ...]: ) -> Tuple[Union[torch.Tensor, None], ...]:
...@@ -398,6 +412,8 @@ def fused_attn_bwd( ...@@ -398,6 +412,8 @@ def fused_attn_bwd(
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding" attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"} type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
softmax_type: str, default = "vanilla"
type of the attention softmax; {"vanilla", "off-by-one", "learnable"}
window_size: Tuple[int, int], default = (-1, -1) window_size: Tuple[int, int], default = (-1, -1)
sliding window size for local attention, where query at position i attends to keys sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
...@@ -417,6 +433,9 @@ def fused_attn_bwd( ...@@ -417,6 +433,9 @@ def fused_attn_bwd(
d_bias: torch.Tensor, optional d_bias: torch.Tensor, optional
gradient tensor of Bias when attn_bias_type is "pre_scale_bias" gradient tensor of Bias when attn_bias_type is "pre_scale_bias"
or "post_scale_bias"; same data type and shape as Bias or "post_scale_bias"; same data type and shape as Bias
d_softmax_offset: torch.Tensor, optional
gradient tensor of softmax offset in shape [1, h_q, 1, 1].
See softmax_type in DotProductAttention for details.
""" """
if attn_scale is None: if attn_scale is None:
d = q.size(-1) d = q.size(-1)
...@@ -454,6 +473,7 @@ def fused_attn_bwd( ...@@ -454,6 +473,7 @@ def fused_attn_bwd(
QKVLayout[qkv_layout], QKVLayout[qkv_layout],
AttnBiasType[attn_bias_type], AttnBiasType[attn_bias_type],
AttnMaskType[attn_mask_type], AttnMaskType[attn_mask_type],
SoftmaxType[softmax_type],
window_size, window_size,
deterministic, deterministic,
cu_seqlens_q, cu_seqlens_q,
......
...@@ -19,8 +19,10 @@ from ..utils import get_sm_count, _empty_tensor ...@@ -19,8 +19,10 @@ from ..utils import get_sm_count, _empty_tensor
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul, w8a8_block_int8_matmul_batched from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul, w8a8_block_int8_matmul_batched
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad, w8a8_block_int8_matmul_wgrad_batched from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad, w8a8_block_int8_matmul_wgrad_batched
from ..tensor.quantized_tensor import Quantizer from ..tensor.quantized_tensor import Quantizer
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from ..tensor._internal.float8_tensor_base import Float8TensorBase from ..tensor.utils import is_experimental
from ..experimental.gemm import experimental_gemm
from ..tensor.storage.float8_tensor_storage import Float8TensorStorage
from ...debug.pytorch.debug_quantization import DebugQuantizer from ...debug.pytorch.debug_quantization import DebugQuantizer
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
from transformer_engine.pytorch.triton.per_token_group_quant import (per_token_quant_fp8_to_int8, from transformer_engine.pytorch.triton.per_token_group_quant import (per_token_quant_fp8_to_int8,
...@@ -169,6 +171,24 @@ def general_gemm( ...@@ -169,6 +171,24 @@ def general_gemm(
if not out.is_contiguous(): if not out.is_contiguous():
raise ValueError("Output tensor is not contiguous.") raise ValueError("Output tensor is not contiguous.")
# If A or B are experimental tensors -> dispatch to quantizers's qgemm implementation
if is_experimental(A) or is_experimental(B):
return experimental_gemm(
A,
B,
workspace,
out_dtype,
quantization_params,
gelu,
gelu_in,
accumulate,
layout,
out,
bias,
use_split_accumulator,
grad,
)
debug_quantizer = None debug_quantizer = None
if isinstance(quantization_params, DebugQuantizer): if isinstance(quantization_params, DebugQuantizer):
debug_quantizer = quantization_params debug_quantizer = quantization_params
...@@ -179,9 +199,9 @@ def general_gemm( ...@@ -179,9 +199,9 @@ def general_gemm(
# Use bfloat16 as default bias_dtype # Use bfloat16 as default bias_dtype
bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype] bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype]
if isinstance(A, Float8BlockwiseQTensorBase) or isinstance(B, Float8BlockwiseQTensorBase): if isinstance(A, Float8BlockwiseQTensorStorage) or isinstance(B, Float8BlockwiseQTensorStorage):
# There is not use_split_accumulator == False # There is not use_split_accumulator == False
# implementation for Float8BlockwiseQTensorBase GEMM # implementation for Float8BlockwiseQTensorStorage GEMM
use_split_accumulator = True use_split_accumulator = True
# Check that data format is supported # Check that data format is supported
...@@ -191,7 +211,7 @@ def general_gemm( ...@@ -191,7 +211,7 @@ def general_gemm(
): ):
raise RuntimeError("GEMM with Float8BlockwiseQTensor requires GEMM_READY format") raise RuntimeError("GEMM with Float8BlockwiseQTensor requires GEMM_READY format")
if int8_simulation_fp8 and (isinstance(A, Float8BlockwiseQTensorBase) or isinstance(B, Float8BlockwiseQTensorBase)): if int8_simulation_fp8 and (isinstance(A, Float8BlockwiseQTensorStorage) or isinstance(B, Float8BlockwiseQTensorStorage)):
assert not gelu, "GELU not supported with int8 simulation" assert not gelu, "GELU not supported with int8 simulation"
assert gelu_in is None, "GELU input not supported with int8 simulation" assert gelu_in is None, "GELU input not supported with int8 simulation"
assert bias is None, "Bias not supported with int8 simulation" assert bias is None, "Bias not supported with int8 simulation"
...@@ -210,7 +230,7 @@ def general_gemm( ...@@ -210,7 +230,7 @@ def general_gemm(
) )
return y, None, None, None return y, None, None, None
if int8_simulation_fp8 and (isinstance(A, Float8TensorBase) or isinstance(B, Float8TensorBase)) and int8_simulation_fp8_tensorwise: if int8_simulation_fp8 and (isinstance(A, Float8TensorStorage) or isinstance(B, Float8TensorStorage)) and int8_simulation_fp8_tensorwise:
assert not gelu, "GELU not supported with int8 simulation" assert not gelu, "GELU not supported with int8 simulation"
assert gelu_in is None, "GELU input not supported with int8 simulation" assert gelu_in is None, "GELU input not supported with int8 simulation"
assert ub is None, "User buffer not supported with int8 simulation" assert ub is None, "User buffer not supported with int8 simulation"
...@@ -251,7 +271,7 @@ def general_gemm( ...@@ -251,7 +271,7 @@ def general_gemm(
return out, bias_grad, gelu_input, extra_output return out, bias_grad, gelu_input, extra_output
if int8_simulation_fp8 and (isinstance(A, Float8TensorBase) or isinstance(B, Float8TensorBase)): if int8_simulation_fp8 and (isinstance(A, Float8TensorStorage) or isinstance(B, Float8TensorStorage)):
assert not gelu, "GELU not supported with int8 simulation" assert not gelu, "GELU not supported with int8 simulation"
assert gelu_in is None, "GELU input not supported with int8 simulation" assert gelu_in is None, "GELU input not supported with int8 simulation"
assert bias is None, "Bias not supported with int8 simulation" assert bias is None, "Bias not supported with int8 simulation"
...@@ -440,7 +460,7 @@ def general_grouped_gemm( ...@@ -440,7 +460,7 @@ def general_grouped_gemm(
for o in out for o in out
] # this should differ with respect to single output ] # this should differ with respect to single output
if int8_simulation_fp8 and (isinstance(A[0], Float8BlockwiseQTensorBase) or isinstance(B[0], Float8BlockwiseQTensorBase)): if int8_simulation_fp8 and (isinstance(A[0], Float8BlockwiseQTensorStorage) or isinstance(B[0], Float8BlockwiseQTensorStorage)):
assert len(set(m_splits)) == 1, "Int8 simulation groupgemm just surpport token pad as same as batchgemm for now." assert len(set(m_splits)) == 1, "Int8 simulation groupgemm just surpport token pad as same as batchgemm for now."
assert not gelu, "GELU not supported with int8 simulation groupgemm." assert not gelu, "GELU not supported with int8 simulation groupgemm."
assert not use_bias, "Bias not supported with int8 simulation groupgemm." assert not use_bias, "Bias not supported with int8 simulation groupgemm."
...@@ -502,7 +522,7 @@ def general_grouped_gemm( ...@@ -502,7 +522,7 @@ def general_grouped_gemm(
raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8") raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8")
if int8_simulation_fp8 and (isinstance(A[0], Float8TensorBase) or isinstance(B[0], Float8TensorBase)) and int8_simulation_fp8_tensorwise and int8_simulation_fp8_tensorwise_batched: if int8_simulation_fp8 and (isinstance(A[0], Float8TensorStorage) or isinstance(B[0], Float8TensorStorage)) and int8_simulation_fp8_tensorwise and int8_simulation_fp8_tensorwise_batched:
assert len(set(m_splits)) == 1, "Need token pad as same as batchgemm for NVTE_INT8_SIM_FP8_TENSORWISE_BATCHED." assert len(set(m_splits)) == 1, "Need token pad as same as batchgemm for NVTE_INT8_SIM_FP8_TENSORWISE_BATCHED."
assert not gelu, "GELU not supported with int8 simulation groupgemm." assert not gelu, "GELU not supported with int8 simulation groupgemm."
assert not use_bias, "Bias not supported with int8 simulation groupgemm." assert not use_bias, "Bias not supported with int8 simulation groupgemm."
...@@ -616,7 +636,7 @@ def general_grouped_gemm( ...@@ -616,7 +636,7 @@ def general_grouped_gemm(
return out, bias, gelu_input return out, bias, gelu_input
if int8_simulation_fp8 and (isinstance(A[0], Float8TensorBase) or isinstance(B[0], Float8TensorBase)) and int8_simulation_fp8_tensorwise: if int8_simulation_fp8 and (isinstance(A[0], Float8TensorStorage) or isinstance(B[0], Float8TensorStorage)) and int8_simulation_fp8_tensorwise:
assert not gelu, "GELU not supported with int8 simulation groupgemm." assert not gelu, "GELU not supported with int8 simulation groupgemm."
assert TE_DType_To_Torch[out_dtype] is torch.bfloat16 or TE_DType_To_Torch[out_dtype] is torch.float32, "Out_dtype must be bfloat16 or float32 for int8 simulation" assert TE_DType_To_Torch[out_dtype] is torch.bfloat16 or TE_DType_To_Torch[out_dtype] is torch.float32, "Out_dtype must be bfloat16 or float32 for int8 simulation"
bias = tex.te_general_grouped_gemm( bias = tex.te_general_grouped_gemm(
...@@ -642,7 +662,7 @@ def general_grouped_gemm( ...@@ -642,7 +662,7 @@ def general_grouped_gemm(
return out, bias, gelu_input return out, bias, gelu_input
if int8_simulation_fp8 and (isinstance(A[0], Float8TensorBase) or isinstance(B[0], Float8TensorBase)): if int8_simulation_fp8 and (isinstance(A[0], Float8TensorStorage) or isinstance(B[0], Float8TensorStorage)):
assert len(set(m_splits)) == 1, "Int8 simulation groupgemm just surpport token pad as same as batchgemm for now." assert len(set(m_splits)) == 1, "Int8 simulation groupgemm just surpport token pad as same as batchgemm for now."
assert not gelu, "GELU not supported with int8 simulation groupgemm." assert not gelu, "GELU not supported with int8 simulation groupgemm."
assert not use_bias, "Bias not supported with int8 simulation groupgemm." assert not use_bias, "Bias not supported with int8 simulation groupgemm."
......
...@@ -10,12 +10,13 @@ from typing import Any, Dict, Optional ...@@ -10,12 +10,13 @@ from typing import Any, Dict, Optional
import torch import torch
from transformer_engine.debug.pytorch.debug_state import TEDebugState from transformer_engine.debug.pytorch.debug_state import TEDebugState
from .tensor.quantized_tensor import QuantizedTensorBase from .tensor.quantized_tensor import QuantizedTensorStorage
from .tensor.float8_tensor import Float8Tensor from .tensor.float8_tensor import Float8Tensor
__all__ = ["get_cpu_offload_context"] __all__ = ["get_cpu_offload_context"]
CPUOffloadEnabled = False CPUOffloadEnabled = False
CPUOffloadedLayer = False
def get_cpu_offloading(): def get_cpu_offloading():
global CPUOffloadEnabled global CPUOffloadEnabled
...@@ -42,7 +43,7 @@ def mark_activation_offload(*tensors): ...@@ -42,7 +43,7 @@ def mark_activation_offload(*tensors):
if tensor is not None: if tensor is not None:
tensor.activation_offloading = True tensor.activation_offloading = True
# This is a hack to force clear the tensor after it is offloaded. # This is a hack to force clear the tensor after it is offloaded.
# It is needed, because .*TensorBase classes are saved in the ctx, # It is needed, because .*TensorStorage classes are saved in the ctx,
# and they contain the reference to their data tensors. # and they contain the reference to their data tensors.
tensor.needs_force_clear = True tensor.needs_force_clear = True
...@@ -361,6 +362,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -361,6 +362,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
self.h2d_stream = torch.cuda.Stream() self.h2d_stream = torch.cuda.Stream()
def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:
global CPUOffloadedLayer
torch_stray_tensor = isinstance( torch_stray_tensor = isinstance(
tensor, tensor,
...@@ -370,7 +372,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -370,7 +372,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
), ),
) )
is_quantized_tensor = isinstance(tensor, QuantizedTensorBase) is_quantized_tensor = isinstance(tensor, QuantizedTensorStorage)
if not torch_stray_tensor: if not torch_stray_tensor:
...@@ -416,6 +418,11 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -416,6 +418,11 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
tensor.clear() tensor.clear()
else: else:
self.tensor_tag_to_buf[tensor_tag] = t self.tensor_tag_to_buf[tensor_tag] = t
# Needed to differentiate non offloaded layer's attention
# QKV layout of attention of non-offloaded layer needs
# to be modified while reloading
CPUOffloadedLayer = True
else: else:
tensor_tag = (-1, self.torch_tensor_count) tensor_tag = (-1, self.torch_tensor_count)
self.torch_tensor_count += 1 self.torch_tensor_count += 1
...@@ -425,6 +432,8 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -425,6 +432,8 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
def tensor_pop(self, tensor_tag, **kwargs): def tensor_pop(self, tensor_tag, **kwargs):
"""Tensor pop.""" """Tensor pop."""
global CPUOffloadedLayer
assert tensor_tag in self.tensor_tag_to_state assert tensor_tag in self.tensor_tag_to_state
tensor = self.tensor_tag_to_state.pop(tensor_tag) tensor = self.tensor_tag_to_state.pop(tensor_tag)
...@@ -488,6 +497,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -488,6 +497,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
def synchronize_on_group_commit_forward(self, current_group): def synchronize_on_group_commit_forward(self, current_group):
"""Synchronize on group commit forward.""" """Synchronize on group commit forward."""
global CPUOffloadedLayer
# For the first group, kickstart the offload after we have # For the first group, kickstart the offload after we have
# the first compute completion # the first compute completion
...@@ -522,7 +532,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -522,7 +532,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
if tensor_tag[0] == self.offloaded_group_count: if tensor_tag[0] == self.offloaded_group_count:
if hasattr(tensor_buf, "needs_force_clear"): if hasattr(tensor_buf, "needs_force_clear"):
# Need to clear activation tensor - sometimes references persist in the code. # Need to clear activation tensor - sometimes references persist in the code.
# This is the case for example with the Float8TensorBase class, # This is the case for example with the Float8TensorStorage class,
# which is saved directly inside the ctx while its internal tensors are # which is saved directly inside the ctx while its internal tensors are
# saved inside save_for_backward. # saved inside save_for_backward.
tensor_buf.data = torch.Tensor() tensor_buf.data = torch.Tensor()
...@@ -536,6 +546,9 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -536,6 +546,9 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
# Increment the offload group count to keep track # Increment the offload group count to keep track
self.offloaded_group_count += 1 self.offloaded_group_count += 1
if current_group == (self.num_offload_group - 1):
CPUOffloadedLayer = False
if not self.double_buffer_created: if not self.double_buffer_created:
# Creating second copy of double buffer for tensors that are offloaded # Creating second copy of double buffer for tensors that are offloaded
if current_group == (self.num_layers - 1): if current_group == (self.num_layers - 1):
......
...@@ -12,6 +12,20 @@ ...@@ -12,6 +12,20 @@
namespace transformer_engine::pytorch { namespace transformer_engine::pytorch {
/*! @brief Recover the logical (unpacked) shape of an FP4 tensor.
 *
 *  FP4 stores two values per byte, so the packed innermost dimension is
 *  half the logical one; this doubles it back. For a transposed layout the
 *  leading packed dimension is moved to the end of the recovered shape.
 *
 *  \param shape     Packed FP4 data shape (must be non-empty).
 *  \param transpose Whether the data is stored in transposed layout.
 *  \return Logical shape with the last dimension un-halved.
 */
std::vector<size_t> convert_shape_back_from_fp4(const std::vector<size_t>& shape, bool transpose) {
  const size_t first = transpose ? 1 : 0;
  std::vector<size_t> out;
  out.reserve(shape.size() + (transpose ? 1 : 0));
  // Copy every dimension except the last (and, if transposed, the first).
  for (size_t i = first; i + 1 < shape.size(); ++i) {
    out.push_back(shape[i]);
  }
  out.push_back(shape.back() * 2);  // two FP4 values are packed per byte
  if (transpose) {
    out.push_back(shape.front());  // leading dim migrates to the end
  }
  return out;
}
std::vector<size_t> getTensorShape(const at::Tensor& t) { std::vector<size_t> getTensorShape(const at::Tensor& t) {
std::vector<size_t> shape; std::vector<size_t> shape;
for (auto s : t.sizes()) { for (auto s : t.sizes()) {
...@@ -291,4 +305,20 @@ size_t roundup(const size_t value, const size_t multiple) { ...@@ -291,4 +305,20 @@ size_t roundup(const size_t value, const size_t multiple) {
return ((value + multiple - 1) / multiple) * multiple; return ((value + multiple - 1) / multiple) * multiple;
} }
/*! @brief Write the Philox RNG seed and offset into a device-side state buffer.
 *
 *  Launches nvte_extract_seed_and_offset on the current CUDA stream with the
 *  GIL released. Both the pointer and immediate forms of seed/offset from the
 *  PhiloxCudaState are forwarded; `arg.captured_` presumably selects between
 *  them (CUDA-graph capture vs. eager mode) — matches PyTorch's
 *  PhiloxCudaState convention, verify against nvte_extract_seed_and_offset.
 *
 *  \param arg           Philox state obtained from a CUDA generator.
 *  \param rng_state_ptr Device buffer receiving the extracted seed/offset.
 */
void philox_unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr) {
  NVTE_SCOPED_GIL_RELEASE({
    nvte_extract_seed_and_offset(rng_state_ptr, arg.captured_, arg.seed_.ptr, arg.seed_.val,
                                 arg.offset_.ptr, arg.offset_.val, arg.offset_intragraph_,
                                 at::cuda::getCurrentCUDAStream());
  });
}
/*! @brief Extract a PhiloxCudaState snapshot from a CUDA random number generator.
 *
 *  \param gen             CUDA generator to read from.
 *  \param elts_per_thread Number of random elements consumed per thread; the
 *                         generator advances its offset accordingly.
 *  \return Philox seed/offset state for kernel consumption.
 */
at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl* gen, size_t elts_per_thread) {
  // The generator state must be read (and advanced) under its mutex so
  // concurrent users of the same generator see a consistent offset.
  const std::lock_guard<std::mutex> lock(gen->mutex_);
  return gen->philox_cuda_state(elts_per_thread);
}
} // namespace transformer_engine::pytorch } // namespace transformer_engine::pytorch
...@@ -35,6 +35,7 @@ ...@@ -35,6 +35,7 @@
#include <transformer_engine/fused_rope.h> #include <transformer_engine/fused_rope.h>
#include <transformer_engine/fused_router.h> #include <transformer_engine/fused_router.h>
#include <transformer_engine/gemm.h> #include <transformer_engine/gemm.h>
#include <transformer_engine/hadamard_transform.h>
#include <transformer_engine/multi_stream.h> #include <transformer_engine/multi_stream.h>
#include <transformer_engine/multi_tensor.h> #include <transformer_engine/multi_tensor.h>
#include <transformer_engine/normalization.h> #include <transformer_engine/normalization.h>
...@@ -212,20 +213,25 @@ class Float8CurrentScalingQuantizer : public Quantizer { ...@@ -212,20 +213,25 @@ class Float8CurrentScalingQuantizer : public Quantizer {
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const override; DType dtype) const override;
/*! @brief Construct a high precision tensor giving it this quantizer's amax /*! @brief Construct an unquantized tensor that shares the quantizer's amax pointer.
*
Note: this member function also zeros out the amax, as it is meant to be used in conjunction with * The amax is zeroed out. Most TE kernels that output amax expect
a kernel computing the amax, which might expect the amax to be initialized to zero * amax to be initialized to zero.
*/ */
std::pair<TensorWrapper, py::object> create_hp_tensor_with_amax(const std::vector<size_t>& shape, std::pair<TensorWrapper, py::object> create_unquantized_tensor_with_amax(
DType dtype); const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> data = std::nullopt);
std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override; std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
void quantize(const TensorWrapper& input, TensorWrapper& out, void quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt) override; const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
/*! @brief Convert to a quantized data format avoiding amax computation */ /*! @brief Quantize to FP8, skipping local amax computation
*
* The quantizer's amax pointer is assumed to already hold the local
* amax. The amax may still be reduced across the amax reduction
* group.
*/
void quantize_with_amax(TensorWrapper& input, TensorWrapper& out, void quantize_with_amax(TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt); const std::optional<TensorWrapper>& noop_flag = std::nullopt);
...@@ -295,6 +301,60 @@ class MXFP8Quantizer : public Quantizer { ...@@ -295,6 +301,60 @@ class MXFP8Quantizer : public Quantizer {
std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const; std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
}; };
/*! @brief Quantizer producing NVFP4 (4-bit floating point) tensors.
 *
 *  Configured from a Python-side quantizer object; supports optional amax
 *  reduction across a process group, random Hadamard transform (RHT),
 *  2D block scaling, and stochastic rounding.
 */
class NVFP4Quantizer : public Quantizer {
 public:
  // fp4 dtype
  DType dtype;
  // amax reduction for low precision FP4 AG
  bool with_amax_reduction;
  // Process group over which amax is reduced when with_amax_reduction is set.
  c10::intrusive_ptr<dist_group_type> amax_reduction_group;
  // random hadamard transform
  bool with_rht;
  // When set together with with_rht, amax is computed after the RHT is
  // applied (handled inside the quantizer, not by the caller).
  bool with_post_rht_amax;
  // 2D block scaling
  bool with_2d_quantization;
  // Enable stochastic rounding during quantization.
  bool stochastic_rounding;
  // Random sign mask applied to the RHT matrix (columnwise/transpose path —
  // TODO confirm which path the _t suffix refers to).
  int rht_matrix_random_sign_mask_t;
  // Hadamard transform matrix used when with_rht is set.
  at::Tensor rht_matrix;

  // Reads configuration fields from the Python quantizer object.
  explicit NVFP4Quantizer(const py::handle& quantizer);

  NVTEScalingMode get_scaling_mode() const override { return NVTE_NVFP4_1D_SCALING; }

  void set_quantization_params(TensorWrapper* tensor) const override;

  // Create an NVFP4 tensor of the given logical shape and fake dtype.
  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
                                                     DType dtype) const override;

  /*! @brief Construct an unquantized tensor that shares NVFP4 tensor's amax pointer
   *
   * The amax is zeroed out. Most TE kernels that output amax expect
   * amax to be initialized to zero.
   */
  std::pair<TensorWrapper, py::object> create_unquantized_tensor_with_amax(
      TensorWrapper& quantized_tensor, DType dtype);

  std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;

  void quantize(const TensorWrapper& input, TensorWrapper& out,
                const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;

  /*! @brief Quantize to NVFP4, skipping local amax computation
   *
   * The input tensor's amax pointer is assumed to already hold the
   * local amax. The amax may still be reduced across the amax
   * reduction group.
   */
  void quantize_with_amax(TensorWrapper& input, TensorWrapper& out);

  // Shape of the scale-factor tensor for a data tensor of the given shape.
  std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;

 private:
  // Shared implementation behind quantize()/quantize_with_amax();
  // compute_amax selects whether the local amax is computed here.
  void quantize_impl(const TensorWrapper& input, TensorWrapper& out,
                     const std::optional<TensorWrapper>& noop_flag, bool compute_amax);
};
std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer); std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer);
std::vector<size_t> getTensorShape(const at::Tensor& t); std::vector<size_t> getTensorShape(const at::Tensor& t);
...@@ -445,6 +505,15 @@ std::vector<size_t> convertShape(const NVTEShape& shape); ...@@ -445,6 +505,15 @@ std::vector<size_t> convertShape(const NVTEShape& shape);
size_t roundup(const size_t value, const size_t multiple); size_t roundup(const size_t value, const size_t multiple);
NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape); NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape);
std::vector<size_t> convert_shape_back_from_fp4(const std::vector<size_t>& shape, bool transpose);
// unpack the PhiloxCudaState into CUDA tensor
void philox_unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr);
// extract PhiloxCudaState from CUDA random number generator
at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl* gen, size_t elts_per_thread);
} // namespace transformer_engine::pytorch } // namespace transformer_engine::pytorch
namespace std { namespace std {
......
...@@ -73,28 +73,36 @@ std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::T ...@@ -73,28 +73,36 @@ std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::T
NVTE_Fused_Attn_Backend get_fused_attn_backend( NVTE_Fused_Attn_Backend get_fused_attn_backend(
bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t head_dim_v, int64_t window_size_left, int64_t window_size_right); size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right);
std::pair<TensorWrapper, py::object> quantizer_helper(py::handle quantizer,
const std::vector<size_t> &shape, DType dtype,
bool create_hp_tensor_for_cs,
std::optional<at::Tensor> data);
std::vector<py::object> fused_attn_fwd( std::vector<py::object> fused_attn_fwd(
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, const std::vector<int64_t> window_size, const at::Tensor cu_seqlens_q,
const py::handle K, const py::handle V, const at::ScalarType fake_dtype, const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
const std::optional<at::Tensor> cu_seqlens_q_padded, const at::ScalarType fake_dtype, const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded, const std::optional<at::Tensor> cu_seqlens_kv_padded,
const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v, const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias, py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
const std::optional<at::Generator> rng_gen, size_t rng_elts_per_thread); const std::optional<at::Tensor> SoftmaxOffset, const std::optional<at::Generator> rng_gen,
size_t rng_elts_per_thread);
std::vector<py::object> fused_attn_bwd( std::vector<py::object> fused_attn_bwd(
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const std::vector<int64_t> window_size, bool deterministic, const at::Tensor cu_seqlens_q, NVTE_Softmax_Type softmax_type, const std::vector<int64_t> window_size, bool deterministic,
const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, const py::handle K, const py::handle V, const py::handle O, const py::handle dO,
const at::ScalarType fake_dtype, const DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors, const std::vector<at::Tensor> Aux_CTX_Tensors,
const std::optional<at::Tensor> cu_seqlens_q_padded, const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer, const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
...@@ -233,6 +241,10 @@ py::object swiglu(const at::Tensor &input, py::handle quantizer); ...@@ -233,6 +241,10 @@ py::object swiglu(const at::Tensor &input, py::handle quantizer);
py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
py::object clamped_swiglu(const at::Tensor &input, py::handle quantizer, float limit, float alpha);
py::object clamped_dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer,
float limit, float alpha);
/*************************************************************************************************** /***************************************************************************************************
* LayerNorm * LayerNorm
**************************************************************************************************/ **************************************************************************************************/
......
...@@ -3,184 +3,331 @@ ...@@ -3,184 +3,331 @@
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "../extensions.h" #include "../extensions.h"
#include "common.h" #include "common.h"
#include "pybind.h" #include "pybind.h"
namespace transformer_engine::pytorch { namespace transformer_engine {
namespace pytorch {
namespace {
using FuncType = void(const NVTETensor, NVTETensor, cudaStream_t);
using DFuncType = void(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t);
template <void (*act_func)(const NVTETensor, NVTETensor, cudaStream_t)> template <FuncType* act_func, auto act_func_with_args, typename... Args>
py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1) { py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1,
Args&&... args) {
init_extension(); init_extension();
// Input tensor // Input tensor
auto input_tensor = input.contiguous(); auto input_tensor = input.contiguous();
const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor); const TensorWrapper& input_nvte = makeTransformerEngineTensor(input_tensor);
// Construct output tensor // Construct output tensor
auto quantizer_cpp = convert_quantizer(quantizer); auto quantizer_cpp = convert_quantizer(quantizer);
const auto input_shape = input_cpp.shape(); const auto input_shape = input_nvte.shape();
std::vector<size_t> output_shape(input_shape.data, input_shape.data + input_shape.ndim); std::vector<size_t> output_shape(input_shape.data, input_shape.data + input_shape.ndim);
output_shape.back() /= shape_divisor; output_shape.back() /= shape_divisor;
auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type()); auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type());
auto [out_cpp, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype); auto [out_nvte, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype);
// Compute activation // Choose implementation
enum class Impl { UNFUSED, FULLY_FUSED, FUSED_ACTIVATION_AMAX_FP8, FUSED_ACTIVATION_AMAX_NVFP4 };
Impl impl = Impl::UNFUSED;
if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) || if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) ||
detail::IsMXFP8Quantizers(quantizer.ptr())) { detail::IsMXFP8Quantizers(quantizer.ptr())) {
// Compute activation directly impl = Impl::FULLY_FUSED;
NVTE_SCOPED_GIL_RELEASE(
{ act_func(input_cpp.data(), out_cpp.data(), at::cuda::getCurrentCUDAStream()); });
} else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// Compute activation in high-precision fused together with amax, then quantize. impl = Impl::FUSED_ACTIVATION_AMAX_FP8;
} else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
auto quantizer_cpp_cs = dynamic_cast<Float8CurrentScalingQuantizer*>(quantizer_cpp.get()); auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer*>(quantizer_cpp.get());
auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(output_shape, fake_dtype); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
NVTE_SCOPED_GIL_RELEASE( if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
{ act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); }); // Post-RHT amax is handled within NVFP4 quantizer
quantizer_cpp_cs->quantize_with_amax(temp_cpp, out_cpp); impl = Impl::UNFUSED;
} else { } else {
// Compute activation in high-precision, then quantize impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4;
}
auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype); }
NVTE_SCOPED_GIL_RELEASE(
{ act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); }); // Perform compute
quantizer_cpp->quantize(temp_cpp, out_cpp); auto stream = at::cuda::getCurrentCUDAStream();
switch (impl) {
case Impl::UNFUSED:
// Compute activation in high precision, then quantize
{
auto [temp_nvte, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype);
NVTE_SCOPED_GIL_RELEASE({
if constexpr (act_func == nullptr) {
act_func_with_args(input_nvte.data(), temp_nvte.data(), std::forward<Args>(args)...,
stream);
} else {
act_func(input_nvte.data(), temp_nvte.data(), stream);
}
});
quantizer_cpp->quantize(temp_nvte, out_nvte);
}
break;
case Impl::FULLY_FUSED:
// Compute activation directly
{
NVTE_SCOPED_GIL_RELEASE({
if constexpr (act_func == nullptr) {
act_func_with_args(input_nvte.data(), out_nvte.data(), std::forward<Args>(args)...,
stream);
} else {
act_func(input_nvte.data(), out_nvte.data(), stream);
}
});
}
break;
case Impl::FUSED_ACTIVATION_AMAX_FP8:
// Compute activation and amax in high precision, then quantize to FP8
{
auto fp8_quantizer_cpp = dynamic_cast<Float8CurrentScalingQuantizer*>(quantizer_cpp.get());
NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer");
auto [temp_nvte, _] =
fp8_quantizer_cpp->create_unquantized_tensor_with_amax(output_shape, fake_dtype);
NVTE_SCOPED_GIL_RELEASE({
if constexpr (act_func == nullptr) {
act_func_with_args(input_nvte.data(), temp_nvte.data(), std::forward<Args>(args)...,
stream);
} else {
act_func(input_nvte.data(), temp_nvte.data(), stream);
}
});
fp8_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte);
}
break;
case Impl::FUSED_ACTIVATION_AMAX_NVFP4:
// Compute activation and amax in high precision, then quantize to NVFP4
{
auto nvfp4_quantizer_cpp =
static_cast<NVFP4Quantizer*>(quantizer_cpp.get()); // Already checked cast is valid
auto [temp_nvte, _] =
nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, fake_dtype);
NVTE_SCOPED_GIL_RELEASE({
if constexpr (act_func == nullptr) {
act_func_with_args(input_nvte.data(), temp_nvte.data(), std::forward<Args>(args)...,
stream);
} else {
act_func(input_nvte.data(), temp_nvte.data(), stream);
}
});
nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte);
}
break;
default:
NVTE_ERROR("Invalid activation implementation (", static_cast<int>(impl), ")");
} }
return out_py; return out_py;
} }
template <void (*dact_func)(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t)> template <DFuncType* dact_func, auto dact_func_with_args, typename... Args>
py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& input, py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& input,
py::handle quantizer) { py::handle quantizer, Args&&... args) {
init_extension(); init_extension();
// Grad output and input tensors // Grad output and input tensors
auto grad_output_tensor = grad_output.contiguous(); auto grad_output_tensor = grad_output.contiguous();
auto input_tensor = input.contiguous(); auto input_tensor = input.contiguous();
const TensorWrapper& grad_output_cpp = makeTransformerEngineTensor(grad_output_tensor); const TensorWrapper& grad_output_nvte = makeTransformerEngineTensor(grad_output_tensor);
const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor); const TensorWrapper& input_nvte = makeTransformerEngineTensor(input_tensor);
// Construct grad input tensor // Construct grad input tensor
auto quantizer_cpp = convert_quantizer(quantizer); auto quantizer_cpp = convert_quantizer(quantizer);
const auto input_shape_te = input_cpp.shape(); const auto input_shape_te = input_nvte.shape();
const std::vector<size_t> input_shape(input_shape_te.data, const std::vector<size_t> input_shape(input_shape_te.data,
input_shape_te.data + input_shape_te.ndim); input_shape_te.data + input_shape_te.ndim);
auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type()); auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type());
auto [grad_input_cpp, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype); auto [grad_input_nvte, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype);
// Compute activation backward // Choose implementation
enum class Impl { UNFUSED, FULLY_FUSED, FUSED_ACTIVATION_AMAX_FP8, FUSED_ACTIVATION_AMAX_NVFP4 };
Impl impl = Impl::UNFUSED;
if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) || if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) ||
detail::IsMXFP8Quantizers(quantizer.ptr())) { detail::IsMXFP8Quantizers(quantizer.ptr())) {
// Compute activation backward directly impl = Impl::FULLY_FUSED;
NVTE_SCOPED_GIL_RELEASE({
dact_func(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(),
at::cuda::getCurrentCUDAStream());
});
} else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// Compute activation backward in high-precision fused together with amax, then quantize. impl = Impl::FUSED_ACTIVATION_AMAX_FP8;
auto quantizer_cpp_cs = dynamic_cast<Float8CurrentScalingQuantizer*>(quantizer_cpp.get()); } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, fake_dtype); auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer*>(quantizer_cpp.get());
NVTE_SCOPED_GIL_RELEASE({ NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
at::cuda::getCurrentCUDAStream()); // Post-RHT amax is handled within NVFP4 quantizer
}); impl = Impl::UNFUSED;
quantizer_cpp_cs->quantize_with_amax(temp_cpp, grad_input_cpp); } else {
} else { impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4;
// Compute activation backward in high-precision, then quantize }
auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype); }
NVTE_SCOPED_GIL_RELEASE({
dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), // Perform compute
at::cuda::getCurrentCUDAStream()); auto stream = at::cuda::getCurrentCUDAStream();
}); switch (impl) {
quantizer_cpp->quantize(temp_cpp, grad_input_cpp); case Impl::UNFUSED:
// Compute activation backward in high precision, then quantize
{
auto [temp_nvte, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype);
NVTE_SCOPED_GIL_RELEASE({
if constexpr (dact_func == nullptr) {
dact_func_with_args(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(),
std::forward<Args>(args)..., stream);
} else {
dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream);
}
});
quantizer_cpp->quantize(temp_nvte, grad_input_nvte);
}
break;
case Impl::FULLY_FUSED:
// Compute activation backward directly
{
NVTE_SCOPED_GIL_RELEASE({
if constexpr (dact_func == nullptr) {
dact_func_with_args(grad_output_nvte.data(), input_nvte.data(), grad_input_nvte.data(),
std::forward<Args>(args)..., stream);
} else {
dact_func(grad_output_nvte.data(), input_nvte.data(), grad_input_nvte.data(), stream);
}
});
}
break;
case Impl::FUSED_ACTIVATION_AMAX_FP8:
// Compute activation and amax in high precision, then quantize to FP8
{
auto fp8_quantizer_cpp = dynamic_cast<Float8CurrentScalingQuantizer*>(quantizer_cpp.get());
NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer");
auto [temp_nvte, _] =
fp8_quantizer_cpp->create_unquantized_tensor_with_amax(input_shape, fake_dtype);
NVTE_SCOPED_GIL_RELEASE({
if constexpr (dact_func == nullptr) {
dact_func_with_args(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(),
std::forward<Args>(args)..., stream);
} else {
dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream);
}
});
fp8_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte);
}
break;
case Impl::FUSED_ACTIVATION_AMAX_NVFP4:
// Compute activation and amax in high precision, then quantize to NVFP4
{
auto nvfp4_quantizer_cpp =
static_cast<NVFP4Quantizer*>(quantizer_cpp.get()); // Already checked cast is valid
auto [temp_nvte, _] =
nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(grad_input_nvte, fake_dtype);
NVTE_SCOPED_GIL_RELEASE({
if constexpr (dact_func == nullptr) {
dact_func_with_args(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(),
std::forward<Args>(args)..., stream);
} else {
dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream);
}
});
nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte);
}
break;
default:
NVTE_ERROR("Invalid activation implementation (", static_cast<int>(impl), ")");
} }
return grad_input_py; return grad_input_py;
} }
} // namespace
/* GELU and variants*/ /* GELU and variants */
py::object gelu(const at::Tensor& input, py::handle quantizer) { py::object gelu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_gelu>(input, quantizer); return activation_helper<nvte_gelu, nullptr>(input, quantizer);
} }
py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dgelu>(grad, input, quantizer); return dactivation_helper<nvte_dgelu, nullptr>(grad, input, quantizer);
} }
py::object geglu(const at::Tensor& input, py::handle quantizer) { py::object geglu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_geglu>(input, quantizer, 2); return activation_helper<nvte_geglu, nullptr>(input, quantizer, 2);
} }
py::object dgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { py::object dgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dgeglu>(grad, input, quantizer); return dactivation_helper<nvte_dgeglu, nullptr>(grad, input, quantizer);
} }
py::object qgelu(const at::Tensor& input, py::handle quantizer) { py::object qgelu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_qgelu>(input, quantizer); return activation_helper<nvte_qgelu, nullptr>(input, quantizer);
} }
py::object dqgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { py::object dqgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dqgelu>(grad, input, quantizer); return dactivation_helper<nvte_dqgelu, nullptr>(grad, input, quantizer);
} }
py::object qgeglu(const at::Tensor& input, py::handle quantizer) { py::object qgeglu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_qgeglu>(input, quantizer, 2); return activation_helper<nvte_qgeglu, nullptr>(input, quantizer, 2);
} }
py::object dqgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { py::object dqgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dqgeglu>(grad, input, quantizer); return dactivation_helper<nvte_dqgeglu, nullptr>(grad, input, quantizer);
} }
/* ReLU and variants*/ /* ReLU and variants */
py::object relu(const at::Tensor& input, py::handle quantizer) { py::object relu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_relu>(input, quantizer); return activation_helper<nvte_relu, nullptr>(input, quantizer);
} }
py::object drelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { py::object drelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_drelu>(grad, input, quantizer); return dactivation_helper<nvte_drelu, nullptr>(grad, input, quantizer);
} }
py::object reglu(const at::Tensor& input, py::handle quantizer) { py::object reglu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_reglu>(input, quantizer, 2); return activation_helper<nvte_reglu, nullptr>(input, quantizer, 2);
} }
py::object dreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { py::object dreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dreglu>(grad, input, quantizer); return dactivation_helper<nvte_dreglu, nullptr>(grad, input, quantizer);
} }
py::object srelu(const at::Tensor& input, py::handle quantizer) { py::object srelu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_srelu>(input, quantizer); return activation_helper<nvte_srelu, nullptr>(input, quantizer);
} }
py::object dsrelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { py::object dsrelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dsrelu>(grad, input, quantizer); return dactivation_helper<nvte_dsrelu, nullptr>(grad, input, quantizer);
} }
py::object sreglu(const at::Tensor& input, py::handle quantizer) { py::object sreglu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_sreglu>(input, quantizer, 2); return activation_helper<nvte_sreglu, nullptr>(input, quantizer, 2);
} }
py::object dsreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { py::object dsreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dsreglu>(grad, input, quantizer); return dactivation_helper<nvte_dsreglu, nullptr>(grad, input, quantizer);
} }
/* Silu and variants */
/* Silu and variants*/
py::object silu(const at::Tensor& input, py::handle quantizer) { py::object silu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_silu>(input, quantizer); return activation_helper<nvte_silu, nullptr>(input, quantizer);
} }
py::object dsilu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { py::object dsilu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dsilu>(grad, input, quantizer); return dactivation_helper<nvte_dsilu, nullptr>(grad, input, quantizer);
} }
py::object swiglu(const at::Tensor& input, py::handle quantizer) { py::object swiglu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_swiglu>(input, quantizer, 2); return activation_helper<nvte_swiglu, nullptr>(input, quantizer, 2);
} }
py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dswiglu>(grad, input, quantizer); return dactivation_helper<nvte_dswiglu, nullptr>(grad, input, quantizer);
}
/* clamped functions */
/*! @brief Clamped SwiGLU activation (forward).
 *
 *  Forwards to nvte_clamped_swiglu through the generic activation helper.
 *  The shape divisor of 2 reflects that the gated activation consumes a
 *  doubled (gate + value) innermost dimension and halves it in the output.
 *
 *  \param input     Input tensor (gate and value halves concatenated).
 *  \param quantizer Python-side quantizer controlling the output format.
 *  \param limit     Clamp limit forwarded to the kernel.
 *  \param alpha     Alpha parameter forwarded to the kernel.
 */
py::object clamped_swiglu(const at::Tensor& input, py::handle quantizer, float limit, float alpha) {
  return activation_helper<nullptr, nvte_clamped_swiglu>(input, quantizer, 2, limit, alpha);
}
} // namespace transformer_engine::pytorch
/*! @brief Clamped SwiGLU activation backward (gradient).
 *
 *  Forwards to nvte_clamped_dswiglu through the generic dactivation helper;
 *  limit and alpha must match the values used in the forward pass.
 *
 *  \param grad      Incoming gradient w.r.t. the activation output.
 *  \param input     Forward-pass input tensor.
 *  \param quantizer Python-side quantizer controlling the grad-input format.
 *  \param limit     Clamp limit forwarded to the kernel.
 *  \param alpha     Alpha parameter forwarded to the kernel.
 */
py::object clamped_dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer,
                           float limit, float alpha) {
  return dactivation_helper<nullptr, nvte_clamped_dswiglu>(grad, input, quantizer, limit, alpha);
}
} // namespace pytorch
} // namespace transformer_engine
...@@ -35,22 +35,6 @@ void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &s ...@@ -35,22 +35,6 @@ void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &s
{ nvte_memset(base_ptr, 0, total_bytes, at::cuda::getCurrentCUDAStream()); }); { nvte_memset(base_ptr, 0, total_bytes, at::cuda::getCurrentCUDAStream()); });
} }
// Write the RNG seed and offset carried by a PhiloxCudaState into a caller-provided
// buffer of two int64 values pointed to by rng_state_ptr.
// NOTE(review): both the pointer and value forms of seed_/offset_ are forwarded
// together with captured_; presumably nvte_extract_seed_and_offset picks pointer
// vs. value depending on CUDA-graph capture — confirm in the NVTE API docs.
void unpack(at::PhiloxCudaState arg, int64_t *rng_state_ptr) {
  // GIL is released for the duration of the (potentially stream-ordered) NVTE call.
  NVTE_SCOPED_GIL_RELEASE({
    nvte_extract_seed_and_offset(rng_state_ptr, arg.captured_, arg.seed_.ptr, arg.seed_.val,
                                 arg.offset_.ptr, arg.offset_.val, arg.offset_intragraph_,
                                 at::cuda::getCurrentCUDAStream());
  });
}
// Snapshot the Philox RNG state from a CUDA generator, reserving
// elts_per_thread random elements per thread for the upcoming kernel launch.
at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl *gen, size_t elts_per_thread) {
  // The generator's mutex must be held while advancing its Philox offset so
  // concurrent callers obtain non-overlapping random streams.
  const std::lock_guard<std::mutex> guard(gen->mutex_);
  return gen->philox_cuda_state(elts_per_thread);
}
} // namespace } // namespace
namespace transformer_engine::pytorch { namespace transformer_engine::pytorch {
...@@ -58,73 +42,102 @@ namespace transformer_engine::pytorch { ...@@ -58,73 +42,102 @@ namespace transformer_engine::pytorch {
// get the fused attention backend // get the fused attention backend
NVTE_Fused_Attn_Backend get_fused_attn_backend( NVTE_Fused_Attn_Backend get_fused_attn_backend(
bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) { size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right) {
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
return NVTE_Fused_Attn_Backend::NVTE_No_Backend; return NVTE_Fused_Attn_Backend::NVTE_No_Backend;
#else #else
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
bias_type, attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, bias_type, attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups,
max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right); max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right);
return fused_attention_backend; return fused_attention_backend;
#endif #endif
} }
// Helper that builds a (TensorWrapper, py::object) pair for the attention S and
// dP tensors (and for dQ/dK/dV in the backward path) according to the kind of
// quantizer handed in from Python.
//
// Parameters:
//   quantizer              - Python quantizer handle; py::none() means high precision.
//   shape, dtype           - shape/dtype of the tensor to create.
//   create_hp_tensor_for_cs- for Float8CurrentScalingQuantizer only: when true,
//                            create an unquantized (high-precision) tensor with an
//                            amax buffer instead of a quantized tensor.
//   data                   - optional pre-existing torch tensor to wrap instead of
//                            allocating new storage.
//
// Returns: the TE tensor wrapper and the corresponding Python object.
std::pair<TensorWrapper, py::object> quantizer_helper(py::handle quantizer,
                                                      const std::vector<size_t> &shape, DType dtype,
                                                      bool create_hp_tensor_for_cs,
                                                      std::optional<at::Tensor> data) {
  std::unique_ptr<Quantizer> T_quantizer = convert_quantizer(quantizer);
  TensorWrapper te_T;
  py::object py_T;
  if (quantizer.is_none()) {
    // high precision: no quantization, optionally wrapping the provided data.
    auto *none_quantizer = dynamic_cast<NoneQuantizer *>(T_quantizer.get());
    if (data.has_value()) {
      std::tie(te_T, py_T) = none_quantizer->create_tensor(shape, dtype, data.value());
    } else {
      std::tie(te_T, py_T) = none_quantizer->create_tensor(shape, dtype);
    }
  } else if (detail::IsFloat8Quantizers(quantizer.ptr())) {
    // delayed scaling; this create_tensor overload also initializes scale_inv.
    auto *T_quantizer_fp8 = dynamic_cast<Float8Quantizer *>(T_quantizer.get());
    std::tie(te_T, py_T) =
        T_quantizer_fp8->create_tensor(shape, dtype, data, std::nullopt, std::nullopt);
  } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
    // current scaling
    auto *T_quantizer_fp8 = dynamic_cast<Float8CurrentScalingQuantizer *>(T_quantizer.get());
    if (create_hp_tensor_for_cs) {
      if (data.has_value()) {
        std::tie(te_T, py_T) =
            T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype, data.value());
      } else {
        std::tie(te_T, py_T) = T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype);
      }
    } else {
      std::tie(te_T, py_T) = T_quantizer_fp8->create_tensor(shape, dtype);
      // NOTE(review): this check fires only AFTER create_tensor has already run;
      // consider validating !data.has_value() before the call instead.
      NVTE_CHECK(
          !data.has_value(),
          "Float8CurrentScalingQuantizer::create_tensor() does not take data tensor as input!");
    }
  }
  // NOTE(review): any quantizer type not matched above (e.g. MXFP8) falls through
  // and returns default-constructed te_T/py_T — confirm whether an explicit
  // NVTE_ERROR for unsupported quantizers is intended here.
  return {std::move(te_T), std::move(py_T)};
}
// fused attention FWD with separate Q, K and V tensors // fused attention FWD with separate Q, K and V tensors
std::vector<py::object> fused_attn_fwd( std::vector<py::object> fused_attn_fwd(
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, const std::vector<int64_t> window_size, const at::Tensor cu_seqlens_q,
const py::handle K, const py::handle V, const at::ScalarType fake_dtype, const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
const std::optional<at::Tensor> cu_seqlens_q_padded, const at::ScalarType fake_dtype, const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded, const std::optional<at::Tensor> cu_seqlens_kv_padded,
const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v, const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias, py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
const std::optional<at::Generator> rng_gen, size_t rng_elts_per_thread) { const std::optional<at::Tensor> SoftmaxOffset, const std::optional<at::Generator> rng_gen,
size_t rng_elts_per_thread) {
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
assert(false); assert(false);
#else #else
TensorWrapper te_Q, te_K, te_V, te_O, te_S;
auto none = py::none(); auto none = py::none();
std::unique_ptr<Quantizer> S_quantizer = convert_quantizer(s_quantizer);
std::unique_ptr<Quantizer> O_quantizer = convert_quantizer(o_quantizer);
// create QKV tensor wrappers
TensorWrapper te_Q, te_K, te_V;
te_Q = makeTransformerEngineTensor(Q, none); te_Q = makeTransformerEngineTensor(Q, none);
te_K = makeTransformerEngineTensor(K, none); te_K = makeTransformerEngineTensor(K, none);
te_V = makeTransformerEngineTensor(V, none); te_V = makeTransformerEngineTensor(V, none);
// If qkv has FP8 dtype, fake_dtype_te is equal to the fake dtype of q, k, v - needed since torch do not have fp8 types.
const DType qkv_type = te_Q.dtype(); const DType qkv_type = te_Q.dtype();
const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype);
// create S tensor
TensorWrapper te_S;
py::object py_S;
std::tie(te_S, py_S) = quantizer_helper(s_quantizer, {0}, DType::kFloat32, false, std::nullopt);
// create O tensor
TensorWrapper te_O;
py::object py_O;
std::unique_ptr<Quantizer> O_quantizer = convert_quantizer(o_quantizer);
std::vector<size_t> q_shape = convertShape(te_Q.shape()); std::vector<size_t> q_shape = convertShape(te_Q.shape());
std::vector<size_t> k_shape = convertShape(te_K.shape());
std::vector<size_t> v_shape = convertShape(te_V.shape()); std::vector<size_t> v_shape = convertShape(te_V.shape());
auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
// create output tensor O
auto o_shape = std::vector<size_t>{q_shape.begin(), q_shape.end()}; auto o_shape = std::vector<size_t>{q_shape.begin(), q_shape.end()};
o_shape[o_shape.size() - 1] = v_shape[v_shape.size() - 1]; o_shape[o_shape.size() - 1] = v_shape[v_shape.size() - 1];
py::object o_python, s_python; const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype);
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { std::tie(te_O, py_O) = quantizer_helper(o_quantizer, o_shape, fake_dtype_te, true, std::nullopt);
// Initialize FP8 tensor with scale-inverse
auto *O_quantizer_fp8 = dynamic_cast<Float8Quantizer *>(O_quantizer.get());
auto *S_quantizer_fp8 = dynamic_cast<Float8Quantizer *>(S_quantizer.get());
NVTE_CHECK(O_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8");
NVTE_CHECK(S_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8");
std::tie(te_O, o_python) = O_quantizer_fp8->create_tensor(o_shape, fake_dtype_te, std::nullopt,
std::nullopt, std::nullopt);
std::tie(te_S, s_python) = S_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt,
std::nullopt, std::nullopt);
} else {
std::tie(te_O, o_python) = O_quantizer->create_tensor(o_shape, fake_dtype_te);
std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32);
}
auto o_shape_int64 = std::vector<int64_t>{o_shape.begin(), o_shape.end()};
// construct NVTE tensors // construct NVTE tensors
TensorWrapper te_Bias; TensorWrapper te_Bias;
...@@ -135,11 +148,12 @@ std::vector<py::object> fused_attn_fwd( ...@@ -135,11 +148,12 @@ std::vector<py::object> fused_attn_fwd(
// FP8 // FP8
auto h = q_shape[q_shape.size() - 2]; auto h = q_shape[q_shape.size() - 2];
auto d = q_shape[q_shape.size() - 1]; auto d = q_shape[q_shape.size() - 1];
if (set_zero && ((h * d) % block_size == 0) && if (set_zero && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) {
(nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { if ((h * d) % block_size == 0) {
mha_fill(te_O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); mha_fill(te_O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
} else { } else {
te_O.zero_(at::cuda::getCurrentCUDAStream()); te_O.zero_(at::cuda::getCurrentCUDAStream());
}
} }
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) {
...@@ -188,12 +202,23 @@ std::vector<py::object> fused_attn_fwd( ...@@ -188,12 +202,23 @@ std::vector<py::object> fused_attn_fwd(
DType::kInt32, nullptr, nullptr, nullptr); DType::kInt32, nullptr, nullptr, nullptr);
} }
// softmax offset
TensorWrapper te_SoftmaxOffset;
if ((softmax_type != NVTE_VANILLA_SOFTMAX) && (SoftmaxOffset.has_value())) {
auto SoftmaxOffset_sizes = SoftmaxOffset.value().sizes().vec();
std::vector<size_t> SoftmaxOffset_shape{SoftmaxOffset_sizes.begin(), SoftmaxOffset_sizes.end()};
te_SoftmaxOffset =
makeTransformerEngineTensor(SoftmaxOffset.value().data_ptr(), SoftmaxOffset_shape,
DType::kFloat32, nullptr, nullptr, nullptr);
}
// extract rng seed and offset // extract rng seed and offset
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>( auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); rng_gen, at::cuda::detail::getDefaultCUDAGenerator());
at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread);
auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); auto options = torch::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA);
unpack(philox_args, static_cast<int64_t *>(rng_state.data_ptr())); auto rng_state = torch::empty({2}, options);
philox_unpack(philox_args, static_cast<int64_t *>(rng_state.data_ptr()));
auto te_rng_state = makeTransformerEngineTensor(rng_state); auto te_rng_state = makeTransformerEngineTensor(rng_state);
// create auxiliary output tensors // create auxiliary output tensors
...@@ -206,11 +231,11 @@ std::vector<py::object> fused_attn_fwd( ...@@ -206,11 +231,11 @@ std::vector<py::object> fused_attn_fwd(
// populate tensors with appropriate shapes and dtypes // populate tensors with appropriate shapes and dtypes
NVTE_SCOPED_GIL_RELEASE({ NVTE_SCOPED_GIL_RELEASE({
nvte_fused_attn_fwd( nvte_fused_attn_fwd(
te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), te_O.data(), te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_SoftmaxOffset.data(), te_S.data(),
&nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(),
te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0],
window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream()); window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream());
}); });
...@@ -221,52 +246,53 @@ std::vector<py::object> fused_attn_fwd( ...@@ -221,52 +246,53 @@ std::vector<py::object> fused_attn_fwd(
// output_tensors = [O, nvte_aux_tensor_pack.tensors] // output_tensors = [O, nvte_aux_tensor_pack.tensors]
std::vector<py::object> output_tensors; std::vector<py::object> output_tensors;
output_tensors.push_back(o_python); output_tensors.push_back(py_O);
for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { auto set_tensor_param = [&](size_t i, const at::Tensor &output_tensor) {
// allocate memory for nvte_aux_tensor_pack.tensors
at::Tensor output_tensor;
if (nvte_aux_tensor_pack.size >= 2) {
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) {
if (i < nvte_aux_tensor_pack.size - 2) {
NVTEShape temp_shape = nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i]);
output_tensor = allocateSpace(
nvte_shape_to_vector(temp_shape),
static_cast<DType>(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false);
} else if (i == nvte_aux_tensor_pack.size - 2) {
output_tensor = rng_state;
} else if (i == nvte_aux_tensor_pack.size - 1) {
output_tensor = Bias.value();
}
} else {
NVTEShape temp_shape = nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i]);
output_tensor =
(i < nvte_aux_tensor_pack.size - 1)
? allocateSpace(
nvte_shape_to_vector(temp_shape),
static_cast<DType>(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false)
: rng_state;
}
} else {
NVTEShape temp_shape = nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i]);
output_tensor = allocateSpace(
nvte_shape_to_vector(temp_shape),
static_cast<DType>(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false);
}
output_tensors.push_back(py::cast(output_tensor)); output_tensors.push_back(py::cast(output_tensor));
NVTEBasicTensor temp_data = {output_tensor.data_ptr(), NVTEBasicTensor temp_data = {output_tensor.data_ptr(),
nvte_tensor_type(nvte_aux_tensor_pack.tensors[i]), nvte_tensor_type(nvte_aux_tensor_pack.tensors[i]),
nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])}; nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])};
nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data); nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data);
};
// allocate memory for nvte_aux_tensor_pack.tensors
// f16_max512 : S [b, h, sq, skv]
// f16_arbitrary: S [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1]
// fp8 : M [b, h, sq, 1], ZInv [b, h, sq, 1], rng_state [2]
size_t i = 0;
at::Tensor output_tensor;
// intermediate softmax tensor, S or M
output_tensor =
allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])),
static_cast<DType>(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false);
set_tensor_param(i++, output_tensor);
// fp8 has an additional softmax stats tensor, ZInv
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
output_tensor =
allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])),
static_cast<DType>(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false);
set_tensor_param(i++, output_tensor);
}
// rng_state
if (i < nvte_aux_tensor_pack.size) {
set_tensor_param(i++, rng_state);
}
// bias (optional)
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) {
set_tensor_param(i++, Bias.value());
}
// softmax_offset (optional)
if ((softmax_type != NVTE_VANILLA_SOFTMAX) && (SoftmaxOffset.has_value())) {
set_tensor_param(i++, SoftmaxOffset.value());
} }
// execute the kernel // execute the kernel
NVTE_SCOPED_GIL_RELEASE({ NVTE_SCOPED_GIL_RELEASE({
nvte_fused_attn_fwd( nvte_fused_attn_fwd(
te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), te_O.data(), te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_SoftmaxOffset.data(), te_S.data(),
&nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(),
te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0],
window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream()); window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream());
}); });
...@@ -282,9 +308,10 @@ std::vector<py::object> fused_attn_fwd( ...@@ -282,9 +308,10 @@ std::vector<py::object> fused_attn_fwd(
std::vector<py::object> fused_attn_bwd( std::vector<py::object> fused_attn_bwd(
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const std::vector<int64_t> window_size, bool deterministic, const at::Tensor cu_seqlens_q, NVTE_Softmax_Type softmax_type, const std::vector<int64_t> window_size, bool deterministic,
const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, const py::handle K, const py::handle V, const py::handle O, const py::handle dO,
const at::ScalarType fake_dtype, const DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors, const std::vector<at::Tensor> Aux_CTX_Tensors,
const std::optional<at::Tensor> cu_seqlens_q_padded, const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer, const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
...@@ -293,50 +320,44 @@ std::vector<py::object> fused_attn_bwd( ...@@ -293,50 +320,44 @@ std::vector<py::object> fused_attn_bwd(
assert(false); assert(false);
#else #else
auto none = py::none(); auto none = py::none();
TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV;
// create QKV, O, dO tensor wrappers
TensorWrapper te_Q, te_K, te_V, te_O, te_dO;
te_Q = makeTransformerEngineTensor(Q, none); te_Q = makeTransformerEngineTensor(Q, none);
te_K = makeTransformerEngineTensor(K, none); te_K = makeTransformerEngineTensor(K, none);
te_V = makeTransformerEngineTensor(V, none); te_V = makeTransformerEngineTensor(V, none);
te_O = makeTransformerEngineTensor(O, none); te_O = makeTransformerEngineTensor(O, none);
te_dO = makeTransformerEngineTensor(dO, none); te_dO = makeTransformerEngineTensor(dO, none);
// qkv type from the te_Q
std::unique_ptr<Quantizer> dQKV_quantizer = convert_quantizer(dqkv_quantizer);
const DType qkv_type = te_Q.dtype();
const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype);
py::object s_python, dp_python;
std::unique_ptr<Quantizer> S_quantizer = convert_quantizer(s_quantizer);
std::unique_ptr<Quantizer> dP_quantizer = convert_quantizer(dp_quantizer);
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { // create S and dP tensors
auto *S_quantizer_fp8 = dynamic_cast<Float8Quantizer *>(S_quantizer.get()); TensorWrapper te_S, te_dP;
auto *dP_quantizer_fp8 = dynamic_cast<Float8Quantizer *>(dP_quantizer.get()); py::object py_S, py_dP;
NVTE_CHECK(S_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); std::tie(te_S, py_S) = quantizer_helper(s_quantizer, {0}, DType::kFloat32, false, std::nullopt);
NVTE_CHECK(dP_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); std::tie(te_dP, py_dP) =
std::tie(te_S, s_python) = S_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt, quantizer_helper(dp_quantizer, {0}, DType::kFloat32, false, std::nullopt);
std::nullopt, std::nullopt);
std::tie(te_dP, dp_python) = dP_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt,
std::nullopt, std::nullopt);
} else {
std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32);
std::tie(te_dP, dp_python) = dP_quantizer->create_tensor({0}, DType::kFloat32);
}
// create dQ, dK, dV tensors
TensorWrapper te_dQ, te_dK, te_dV;
py::object py_dQ, py_dK, py_dV;
std::unique_ptr<Quantizer> dQKV_quantizer = convert_quantizer(dqkv_quantizer);
std::vector<size_t> q_shape = convertShape(te_Q.shape()); std::vector<size_t> q_shape = convertShape(te_Q.shape());
std::vector<size_t> k_shape = convertShape(te_K.shape()); std::vector<size_t> k_shape = convertShape(te_K.shape());
std::vector<size_t> v_shape = convertShape(te_V.shape()); std::vector<size_t> v_shape = convertShape(te_V.shape());
auto h_q = q_shape[q_shape.size() - 2]; auto h_q = q_shape[q_shape.size() - 2];
auto h_kv = k_shape[k_shape.size() - 2]; auto h_kv = k_shape[k_shape.size() - 2];
auto d_qk = q_shape[q_shape.size() - 1]; auto d_qk = q_shape[q_shape.size() - 1];
auto d_v = v_shape[v_shape.size() - 1]; const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype);
auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA);
std::vector<size_t> o_shape{q_shape.begin(), q_shape.end()};
o_shape[o_shape.size() - 1] = d_v;
at::Tensor dQ, dK, dV, dQKV, dKV; at::Tensor dQ, dK, dV, dQKV, dKV;
py::object py_dQ, py_dK, py_dV;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
std::vector<int64_t> tmp_shape; std::vector<int64_t> tmp_shape;
auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA);
if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) {
options = options.dtype(torch::kUInt8);
}
if (detail::IsFloat8CurrentScalingQuantizers(dqkv_quantizer.ptr())) {
options = options.dtype(fake_dtype);
}
switch (layout_group) { switch (layout_group) {
case NVTE_QKV_Layout_Group::NVTE_3HD: case NVTE_QKV_Layout_Group::NVTE_3HD:
...@@ -409,39 +430,27 @@ std::vector<py::object> fused_attn_bwd( ...@@ -409,39 +430,27 @@ std::vector<py::object> fused_attn_bwd(
default: default:
NVTE_ERROR("QKV layout not supported!"); NVTE_ERROR("QKV layout not supported!");
} }
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
auto *fp8_quantizer = dynamic_cast<Float8Quantizer *>(dQKV_quantizer.get()); std::tie(te_dQ, py_dQ) = quantizer_helper(dqkv_quantizer, q_shape, fake_dtype_te, true, dQ);
NVTE_CHECK(fp8_quantizer != nullptr, "Expected Float8Quantizer when dtype is FP8"); std::tie(te_dK, py_dK) = quantizer_helper(dqkv_quantizer, k_shape, fake_dtype_te, true, dK);
std::tie(te_dQ, py_dQ) = std::tie(te_dV, py_dV) = quantizer_helper(dqkv_quantizer, v_shape, fake_dtype_te, true, dV);
fp8_quantizer->create_tensor(q_shape, fake_dtype_te, dQ, std::nullopt, std::nullopt);
std::tie(te_dK, py_dK) =
fp8_quantizer->create_tensor(k_shape, fake_dtype_te, dK, std::nullopt, std::nullopt);
std::tie(te_dV, py_dV) =
fp8_quantizer->create_tensor(v_shape, fake_dtype_te, dV, std::nullopt, std::nullopt);
} else {
auto *none_quantizer = dynamic_cast<NoneQuantizer *>(dQKV_quantizer.get());
NVTE_CHECK(none_quantizer != nullptr, "Expected NoneQuantizer when dtype is not FP8");
std::tie(te_dQ, py_dQ) = none_quantizer->create_tensor(q_shape, fake_dtype_te, dQ);
std::tie(te_dK, py_dK) = none_quantizer->create_tensor(k_shape, fake_dtype_te, dK);
std::tie(te_dV, py_dV) = none_quantizer->create_tensor(v_shape, fake_dtype_te, dV);
}
// construct NVTE tensors // construct NVTE tensors
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) {
// FP8 // FP8
if (set_zero && ((h_q * d_qk) % block_size == 0) && ((h_kv * d_qk) % block_size == 0) && if (set_zero && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) {
dQ.is_contiguous() && dK.is_contiguous() && dV.is_contiguous() && if (((h_q * d_qk) % block_size == 0) && ((h_kv * d_qk) % block_size == 0) &&
(nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { dQ.is_contiguous() && dK.is_contiguous() && dV.is_contiguous()) {
mha_fill(te_dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); mha_fill(te_dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
mha_fill(te_dK, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); mha_fill(te_dK, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)}));
mha_fill(te_dV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); mha_fill(te_dV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)}));
} else { } else {
dQ.fill_(0); dQ.fill_(0);
dK.fill_(0); dK.fill_(0);
dV.fill_(0); dV.fill_(0);
}
} }
} else if (dqkv_type == DType::kBFloat16 || dqkv_type == DType::kFloat16) {
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) {
dQ.fill_(0); dQ.fill_(0);
dK.fill_(0); dK.fill_(0);
...@@ -510,6 +519,15 @@ std::vector<py::object> fused_attn_bwd( ...@@ -510,6 +519,15 @@ std::vector<py::object> fused_attn_bwd(
} }
} }
// create dSoftmaxOffset in the same shape as SoftmaxOffset
at::Tensor dSoftmaxOffset;
TensorWrapper te_dSoftmaxOffset;
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
options = torch::TensorOptions().dtype(at::kFloat).device(torch::kCUDA);
dSoftmaxOffset = torch::empty({1, static_cast<int64_t>(h_q), 1, 1}, options);
te_dSoftmaxOffset = makeTransformerEngineTensor(dSoftmaxOffset);
}
// create workspace // create workspace
TensorWrapper workspace; TensorWrapper workspace;
...@@ -518,10 +536,10 @@ std::vector<py::object> fused_attn_bwd( ...@@ -518,10 +536,10 @@ std::vector<py::object> fused_attn_bwd(
nvte_fused_attn_bwd( nvte_fused_attn_bwd(
te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
&nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(),
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv,
qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1], deterministic, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0],
workspace.data(), at::cuda::getCurrentCUDAStream()); window_size[1], deterministic, workspace.data(), at::cuda::getCurrentCUDAStream());
}); });
// allocate memory for workspace // allocate memory for workspace
...@@ -534,16 +552,16 @@ std::vector<py::object> fused_attn_bwd( ...@@ -534,16 +552,16 @@ std::vector<py::object> fused_attn_bwd(
nvte_fused_attn_bwd( nvte_fused_attn_bwd(
te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
&nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(),
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv,
qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1], deterministic, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0],
workspace.data(), at::cuda::getCurrentCUDAStream()); window_size[1], deterministic, workspace.data(), at::cuda::getCurrentCUDAStream());
}); });
// destroy tensor wrappers // destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
return {py_dQ, py_dK, py_dV, py::cast(dBias)}; return {py_dQ, py_dK, py_dV, py::cast(dBias), py::cast(dSoftmaxOffset)};
#endif #endif
} }
...@@ -610,7 +628,6 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s ...@@ -610,7 +628,6 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s
// Shapes of kv and dkv are [2, t, h, d], so the dimension of "t" is 1 // Shapes of kv and dkv are [2, t, h, d], so the dimension of "t" is 1
int seq_dim = tensor.dim() == 3 ? 0 : 1; int seq_dim = tensor.dim() == 3 ? 0 : 1;
int batch = cu_seqlens.size(0) - 1;
int num_heads = tensor.size(seq_dim + 1); int num_heads = tensor.size(seq_dim + 1);
int dim_per_head = tensor.size(seq_dim + 2); int dim_per_head = tensor.size(seq_dim + 2);
int hidden_size_in_bytes = num_heads * dim_per_head * c10::elementSize(tensor.scalar_type()); int hidden_size_in_bytes = num_heads * dim_per_head * c10::elementSize(tensor.scalar_type());
...@@ -774,8 +791,6 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t ...@@ -774,8 +791,6 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t
NVTE_CHECK(world_size > 0); NVTE_CHECK(world_size > 0);
NVTE_CHECK(total_tokens > 0 && total_tokens % (world_size * 2) == 0); NVTE_CHECK(total_tokens > 0 && total_tokens % (world_size * 2) == 0);
int batch = cu_seqlens.size(0) - 1;
std::vector<int64_t> shape = {total_tokens / world_size}; std::vector<int64_t> shape = {total_tokens / world_size};
at::Tensor output = at::empty(shape, at::CUDA(at::ScalarType::Int)); at::Tensor output = at::empty(shape, at::CUDA(at::ScalarType::Int));
...@@ -813,7 +828,6 @@ at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, ...@@ -813,7 +828,6 @@ at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b,
**************************************************************************************************/ **************************************************************************************************/
at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t) { at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t) {
int max_seq_len = tensor.size(1);
int h = tensor.size(2); int h = tensor.size(2);
int d = tensor.size(3); int d = tensor.size(3);
std::vector<int64_t> shape = {t, h, d}; std::vector<int64_t> shape = {t, h, d};
......
...@@ -54,10 +54,25 @@ std::vector<py::object> bgrad_quantize(const at::Tensor &grad_output, py::handle ...@@ -54,10 +54,25 @@ std::vector<py::object> bgrad_quantize(const at::Tensor &grad_output, py::handle
return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)}; return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)};
} }
// Unfused impl if quantizer is not supported // Check if fused kernel is supported
const bool with_fused_dbias_quantize_kernel = bool with_fused_kernel = false;
detail::IsFloat8Quantizers(quantizer.ptr()) || detail::IsMXFP8Quantizers(quantizer.ptr()); if (detail::IsFloat8Quantizers(quantizer.ptr())) {
if (!with_fused_dbias_quantize_kernel) { auto prop = at::cuda::getCurrentDeviceProperties();
const size_t sm_arch = 10 * prop->major + prop->minor;
if (sm_arch >= 100) {
// Fused kernel for dbias + FP8 cast on SM arch 10.0+
with_fused_kernel = true;
} else if (quantizer_cpp->rowwise_usage && quantizer_cpp->columnwise_usage) {
// Fused kernel for dbias + FP8 cast + FP8 transpose
with_fused_kernel = true;
}
} else if (detail::IsMXFP8Quantizers(quantizer.ptr())) {
// Fused kernel for dbias + MXFP8 quantize
with_fused_kernel = true;
}
// Apply unfused impl if fused kernel is not supported
if (!with_fused_kernel) {
at::sum_out(grad_bias_torch, grad_output_torch.reshape({-1, bias_size}), {0}); at::sum_out(grad_bias_torch, grad_output_torch.reshape({-1, bias_size}), {0});
quantizer_cpp->quantize(grad_output_nvte, grad_input_nvte); quantizer_cpp->quantize(grad_output_nvte, grad_input_nvte);
return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)}; return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)};
...@@ -122,13 +137,27 @@ std::vector<py::object> dact_dbias( ...@@ -122,13 +137,27 @@ std::vector<py::object> dact_dbias(
} }
// Choose implementation // Choose implementation
enum class Impl { UNFUSED, FUSED_DACT_DBIAS_QUANTIZE, FUSED_DACT_AMAX }; enum class Impl {
UNFUSED,
FUSED_DACT_DBIAS_QUANTIZE,
FUSED_DACT_AMAX_FP8,
FUSED_DACT_AMAX_NVFP4
};
Impl impl = Impl::UNFUSED; Impl impl = Impl::UNFUSED;
if (detail::IsFloat8Quantizers(quantizer_py.ptr()) || if (detail::IsFloat8Quantizers(quantizer_py.ptr()) ||
detail::IsMXFP8Quantizers(quantizer_py.ptr())) { detail::IsMXFP8Quantizers(quantizer_py.ptr())) {
impl = Impl::FUSED_DACT_DBIAS_QUANTIZE; impl = Impl::FUSED_DACT_DBIAS_QUANTIZE;
} else if (detail::IsFloat8CurrentScalingQuantizers(quantizer_py.ptr())) { } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer_py.ptr())) {
impl = Impl::FUSED_DACT_AMAX; impl = Impl::FUSED_DACT_AMAX_FP8;
} else if (detail::IsNVFP4Quantizers(quantizer_py.ptr())) {
auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer *>(quantizer_cpp.get());
NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
// Post-RHT amax is handled within NVFP4 quantizer
impl = Impl::UNFUSED;
} else {
impl = Impl::FUSED_DACT_AMAX_NVFP4;
}
} }
// Perform compute // Perform compute
...@@ -172,20 +201,38 @@ std::vector<py::object> dact_dbias( ...@@ -172,20 +201,38 @@ std::vector<py::object> dact_dbias(
}); });
break; break;
} }
case Impl::FUSED_DACT_AMAX: case Impl::FUSED_DACT_AMAX_FP8:
// Fused dact-amax kernel, unfused dbias and quantize // Fused dact-amax kernel, unfused dbias and FP8 quantize
{ {
auto *quantizer_cpp_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get()); auto *fp8_quantizer_cpp =
NVTE_CHECK(quantizer_cpp_cs != nullptr, dynamic_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
NVTE_CHECK(fp8_quantizer_cpp != nullptr,
"Invalid quantizer for fused dact-amax kernel impl"); "Invalid quantizer for fused dact-amax kernel impl");
auto [temp_nvte, temp_py] = auto [temp_nvte, temp_py] =
quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, grad_output_dtype); fp8_quantizer_cpp->create_unquantized_tensor_with_amax(input_shape, grad_output_dtype);
NVTE_SCOPED_GIL_RELEASE({
dact_func(grad_output_nvte.data(), act_input_nvte.data(), temp_nvte.data(), stream);
});
const auto temp_torch = temp_py.cast<at::Tensor>();
at::sum_out(grad_bias_torch, temp_torch.reshape({-1, bias_size}), {0});
fp8_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte);
break;
}
case Impl::FUSED_DACT_AMAX_NVFP4:
// Fused dact-amax kernel, unfused dbias and NVFP4 quantize
{
auto *nvfp4_quantizer_cpp =
static_cast<NVFP4Quantizer *>(quantizer_cpp.get()); // Already checked cast is valid
NVTE_CHECK(nvfp4_quantizer_cpp != nullptr,
"Invalid quantizer for fused dact-amax kernel impl");
auto [temp_nvte, temp_py] = nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(
grad_input_nvte, grad_output_dtype);
NVTE_SCOPED_GIL_RELEASE({ NVTE_SCOPED_GIL_RELEASE({
dact_func(grad_output_nvte.data(), act_input_nvte.data(), temp_nvte.data(), stream); dact_func(grad_output_nvte.data(), act_input_nvte.data(), temp_nvte.data(), stream);
}); });
const auto temp_torch = temp_py.cast<at::Tensor>(); const auto temp_torch = temp_py.cast<at::Tensor>();
at::sum_out(grad_bias_torch, temp_torch.reshape({-1, bias_size}), {0}); at::sum_out(grad_bias_torch, temp_torch.reshape({-1, bias_size}), {0});
quantizer_cpp_cs->quantize_with_amax(temp_nvte, grad_input_nvte); nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte);
break; break;
} }
default: default:
......
...@@ -37,7 +37,18 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob ...@@ -37,7 +37,18 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob
// Convert input tensor to C++ object // Convert input tensor to C++ object
auto input_contiguous = tensor.contiguous(); auto input_contiguous = tensor.contiguous();
const auto input_cpp = makeTransformerEngineTensor(input_contiguous); auto input_cpp = makeTransformerEngineTensor(input_contiguous);
// Set amax if use_existing_amax = true (only valid for CS)
bool use_existing_amax = false;
if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
use_existing_amax = quantizer.attr("use_existing_amax").cast<bool>();
if (use_existing_amax) {
const at::Tensor &amax = quantizer.attr("amax").cast<at::Tensor>();
input_cpp.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()),
getTensorShape(amax));
}
}
// Initialize output tensor // Initialize output tensor
TensorWrapper output_cpp; TensorWrapper output_cpp;
...@@ -57,7 +68,12 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob ...@@ -57,7 +68,12 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob
} }
// Perform quantization // Perform quantization
quantizer_cpp->quantize(input_cpp, output_cpp, noop_flag_cpp); if (use_existing_amax) {
auto *quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
quantizer_cs->quantize_with_amax(input_cpp, output_cpp, noop_flag_cpp);
} else {
quantizer_cpp->quantize(input_cpp, output_cpp, noop_flag_cpp);
}
return output_py; return output_py;
} }
...@@ -298,7 +314,7 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_fp ...@@ -298,7 +314,7 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_fp
// Construct FP8 block-wise tensors // Construct FP8 block-wise tensors
py::handle Float8BlockwiseQTensorClass( py::handle Float8BlockwiseQTensorClass(
reinterpret_cast<PyObject *>(Float8BlockwiseQTensorBasePythonClass)); reinterpret_cast<PyObject *>(Float8BlockwiseQTensorStoragePythonClass));
for (size_t i = 0; i < num_tensors; ++i) { for (size_t i = 0; i < num_tensors; ++i) {
// Create tensor objects with proper reference counting // Create tensor objects with proper reference counting
py::object rowwise_data = rowwise_usage ? py::cast(rowwise_data_list[i]) : py::none(); py::object rowwise_data = rowwise_usage ? py::cast(rowwise_data_list[i]) : py::none();
...@@ -445,7 +461,7 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx ...@@ -445,7 +461,7 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
} }
// Construct mxfp8 tensors // Construct mxfp8 tensors
py::handle MXFP8TensorClass(reinterpret_cast<PyObject *>(MXFP8TensorBasePythonClass)); py::handle MXFP8TensorClass(reinterpret_cast<PyObject *>(MXFP8TensorStoragePythonClass));
for (size_t i = 0; i < num_tensors; ++i) { for (size_t i = 0; i < num_tensors; ++i) {
// Create tensor objects with proper reference counting // Create tensor objects with proper reference counting
py::object rowwise_data = rowwise_usage ? py::cast(rowwise_data_list[i]) : py::none(); py::object rowwise_data = rowwise_usage ? py::cast(rowwise_data_list[i]) : py::none();
......
...@@ -106,6 +106,10 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans ...@@ -106,6 +106,10 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
const bool low_precision = const bool low_precision =
detail::is_low_precision(A_tensor.dtype()) || detail::is_low_precision(B_tensor.dtype()); detail::is_low_precision(A_tensor.dtype()) || detail::is_low_precision(B_tensor.dtype());
const bool fp8_block_scaling = A_tensor.scaling_mode() == NVTE_BLOCK_SCALING_1D ||
A_tensor.scaling_mode() == NVTE_BLOCK_SCALING_2D ||
B_tensor.scaling_mode() == NVTE_BLOCK_SCALING_1D ||
B_tensor.scaling_mode() == NVTE_BLOCK_SCALING_2D;
// Check tensor dimensions // Check tensor dimensions
const auto& A_shape = A_tensor.shape(); const auto& A_shape = A_tensor.shape();
...@@ -215,6 +219,19 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans ...@@ -215,6 +219,19 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
const int sm_count = transformer_engine::cuda::sm_count(device_id); const int sm_count = transformer_engine::cuda::sm_count(device_id);
int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count); int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);
// Construct GEMM config
transformer_engine::MatmulConfigWrapper config;
if (grad) {
config.set_dbias_tensor(bias_tensor.data());
config.set_with_dgelu_epilogue(gelu);
} else {
config.set_bias_tensor(bias_tensor.data());
config.set_with_gelu_epilogue(gelu);
}
config.set_epilogue_aux_tensor(te_pre_gelu_out.data());
config.set_use_split_accumulator(use_split_accumulator);
config.set_sm_count(num_math_sms);
// Keep the swizzled scaling factor tensors alive during the GEMM. // Keep the swizzled scaling factor tensors alive during the GEMM.
std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list; std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;
auto main_stream = at::cuda::getCurrentCUDAStream(); auto main_stream = at::cuda::getCurrentCUDAStream();
...@@ -224,6 +241,19 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans ...@@ -224,6 +241,19 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
swizzled_scale_inverses_list.emplace_back( swizzled_scale_inverses_list.emplace_back(
std::move(swizzle_scaling_factors(B_tensor, !transb))); std::move(swizzle_scaling_factors(B_tensor, !transb)));
// Emulate the FP8 block scaling recipe with MXFP8 on Blackwell and newer
// as it is not natively supported by cublasLt
if (fp8_block_scaling && transformer_engine::cuda::sm_arch() >= 100) {
// Convert tensors to mxfp8 and swizzle their scaling factors
swizzled_scale_inverses_list.emplace_back(
std::move(convert_block_scaling_to_mxfp8_tensor(A_tensor, transa)));
swizzled_scale_inverses_list.emplace_back(
std::move(convert_block_scaling_to_mxfp8_tensor(B_tensor, !transb)));
// Use TN GEMM to avoid having to transpose data.
transa = true;
transb = false;
}
if (comm_overlap) { if (comm_overlap) {
// Prepare extra output tensor // Prepare extra output tensor
TensorWrapper extra_output_tensor; TensorWrapper extra_output_tensor;
...@@ -278,10 +308,9 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans ...@@ -278,10 +308,9 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
} else { } else {
// Launch GEMM // Launch GEMM
NVTE_SCOPED_GIL_RELEASE({ NVTE_SCOPED_GIL_RELEASE({
nvte_cublas_gemm_scaled(A_tensor.data(), B_tensor.data(), out_tensor.data(), nvte_cublas_gemm_v2(transa, transb, &alpha, A_tensor.data(), B_tensor.data(), &beta.value(),
bias_tensor.data(), te_pre_gelu_out.data(), transa, transb, grad, out_tensor.data(), out_tensor.data(), te_workspace.data(), config,
te_workspace.data(), alpha, *beta, use_split_accumulator, main_stream);
num_math_sms, main_stream);
}); });
} }
} else { } else {
...@@ -369,15 +398,6 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm( ...@@ -369,15 +398,6 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
std::vector<at::Tensor> bias, DType bias_type, bool single_output, std::vector<at::Tensor> bias, DType bias_type, bool single_output,
std::vector<at::Tensor> pre_gelu_out, bool grad, std::vector<at::Tensor> workspace, std::vector<at::Tensor> pre_gelu_out, bool grad, std::vector<at::Tensor> workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count) { size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count) {
std::vector<NVTETensor> te_A_vector, te_B_vector, te_D_vector, te_bias_vector,
te_pre_gelu_out_vector, te_workspace_vector;
std::vector<TensorWrapper> te_A_wrappers, te_B_wrappers, wrappers;
std::vector<at::Tensor> D_vectors;
auto none = py::none();
std::vector<size_t> single_output_begins;
std::vector<size_t> single_output_ends;
if (single_output && D == std::nullopt) { if (single_output && D == std::nullopt) {
NVTE_ERROR("not implemented, D should be allocated for single output case."); NVTE_ERROR("not implemented, D should be allocated for single output case.");
} }
...@@ -387,6 +407,10 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm( ...@@ -387,6 +407,10 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
output_data_ptr = (*D)[0].data_ptr(); output_data_ptr = (*D)[0].data_ptr();
} }
const auto none = py::none();
std::vector<TensorWrapper> te_A_wrappers, te_B_wrappers, te_D_wrappers, te_bias_wrappers,
te_pre_gelu_out_wrappers;
std::vector<at::Tensor> D_vectors;
for (size_t i = 0; i < A.size(); i++) { for (size_t i = 0; i < A.size(); i++) {
auto te_A = makeTransformerEngineTensor(A[i], none); auto te_A = makeTransformerEngineTensor(A[i], none);
auto te_B = makeTransformerEngineTensor(B[i], none); auto te_B = makeTransformerEngineTensor(B[i], none);
...@@ -452,29 +476,72 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm( ...@@ -452,29 +476,72 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
te_pre_gelu_out = te_pre_gelu_out =
makeTransformerEngineTensor(get_data_ptr(pre_gelu_out[i]), gelu_shape, gelu_type); makeTransformerEngineTensor(get_data_ptr(pre_gelu_out[i]), gelu_shape, gelu_type);
te_A_vector.emplace_back(te_A.data());
te_B_vector.emplace_back(te_B.data());
te_D_vector.emplace_back(te_D.data());
te_bias_vector.emplace_back(te_bias.data());
te_pre_gelu_out_vector.emplace_back(te_pre_gelu_out.data());
te_A_wrappers.emplace_back(std::move(te_A)); te_A_wrappers.emplace_back(std::move(te_A));
te_B_wrappers.emplace_back(std::move(te_B)); te_B_wrappers.emplace_back(std::move(te_B));
wrappers.emplace_back(std::move(te_D)); te_D_wrappers.emplace_back(std::move(te_D));
wrappers.emplace_back(std::move(te_bias)); te_bias_wrappers.emplace_back(std::move(te_bias));
wrappers.emplace_back(std::move(te_pre_gelu_out)); te_pre_gelu_out_wrappers.emplace_back(std::move(te_pre_gelu_out));
} }
// Keep the swizzled scaling factor tensors alive during the GEMM.
std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;
// Optionally swizzle the scaling factors // Optionally swizzle the scaling factors
// Keep the swizzled scaling factor tensors alive during the GEMMs. swizzled_scale_inverses_list.emplace_back(
auto swizzled_scale_inv_A = multi_tensor_swizzle_scaling_factors(te_A_wrappers, transa); multi_tensor_swizzle_scaling_factors(te_A_wrappers, transa));
auto swizzled_scale_inv_B = multi_tensor_swizzle_scaling_factors(te_B_wrappers, !transb); swizzled_scale_inverses_list.emplace_back(
multi_tensor_swizzle_scaling_factors(te_B_wrappers, !transb));
// Emulate the FP8 block scaling recipe with MXFP8 on Blackwell and newer
// as it is not natively supported by cublasLt
if (transformer_engine::cuda::sm_arch() >= 100) {
// Check if is using FP8 block scaling
bool exists_tensor_using_fp8_block_scaling = false;
bool exists_tensor_not_using_fp8_block_scaling = false;
for (const auto& tensor_wrappers : {&te_A_wrappers, &te_B_wrappers}) {
for (const TensorWrapper& tensor : *tensor_wrappers) {
const NVTEScalingMode scaling_mode = tensor.scaling_mode();
if (scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D)
exists_tensor_using_fp8_block_scaling = true;
else
exists_tensor_not_using_fp8_block_scaling = true;
}
}
if (exists_tensor_using_fp8_block_scaling) {
NVTE_CHECK(!exists_tensor_not_using_fp8_block_scaling,
"Either all tensors or no tensor must be FP8 block scaling tensors");
// Convert tensors to mxfp8 and swizzle their scaling factors
for (TensorWrapper& A_tensor : te_A_wrappers) {
swizzled_scale_inverses_list.emplace_back(
convert_block_scaling_to_mxfp8_tensor(A_tensor, transa));
}
for (TensorWrapper& B_tensor : te_B_wrappers) {
swizzled_scale_inverses_list.emplace_back(
convert_block_scaling_to_mxfp8_tensor(B_tensor, !transb));
}
// Use TN GEMM to avoid having to transpose data.
transa = true;
transb = false;
}
}
std::vector<NVTETensor> te_A_vector, te_B_vector, te_D_vector, te_bias_vector,
te_pre_gelu_out_vector;
for (size_t i = 0; i < te_A_wrappers.size(); i++) {
te_A_vector.emplace_back(te_A_wrappers[i].data());
te_B_vector.emplace_back(te_B_wrappers[i].data());
te_D_vector.emplace_back(te_D_wrappers[i].data());
te_bias_vector.emplace_back(te_bias_wrappers[i].data());
te_pre_gelu_out_vector.emplace_back(te_pre_gelu_out_wrappers[i].data());
}
std::vector<NVTETensor> te_workspace_vector;
std::vector<TensorWrapper> te_workspace_wrappers;
for (size_t i = 0; i < workspace.size(); i++) { for (size_t i = 0; i < workspace.size(); i++) {
auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(),
std::vector<size_t>{workspaceSize}, DType::kByte); std::vector<size_t>{workspaceSize}, DType::kByte);
te_workspace_vector.emplace_back(wsp.data()); te_workspace_vector.emplace_back(wsp.data());
wrappers.emplace_back(std::move(wsp)); te_workspace_wrappers.emplace_back(std::move(wsp));
} }
// For now, we only have multi-stream cublas backend. // For now, we only have multi-stream cublas backend.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment