Commit 063ef88d authored by wenjh

Merge nv main up to v2.10.0.dev0


Signed-off-by: wenjh <wenjh@sugon.com>
parents 91670b05 5624dbb4
......@@ -13,6 +13,7 @@ from contextlib import contextmanager
from dataclasses import dataclass
from typing import Callable, Optional
import warnings
import jax
import jax.numpy as jnp
from jax.interpreters import pxla
......@@ -130,7 +131,22 @@ def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec):
# We want to exclude the axes that are already used by shard_map, and shard_map
# only sets those in the abstract_mesh, not the physical one
manual_axis_names = get_abstract_mesh().manual_axes
cleaned_axis_names = tuple(name if name not in manual_axis_names else None for name in pspec)
# Multiple mesh axes can be mapped to a single shape axis, so we need to unpack and process tuples here too
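# e.g. PartitionSpec(("dp", "tp"), None) maps the "dp" and "tp" mesh axes onto a single array axis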
def filter_manual_axes(name_or_tuple):
if isinstance(name_or_tuple, tuple):
out = tuple(n for n in name_or_tuple if n not in manual_axis_names)
if len(out) == 0:
return None
return out
if name_or_tuple in manual_axis_names:
return None
return name_or_tuple
cleaned_axis_names = tuple(filter_manual_axes(name_or_tuple) for name_or_tuple in pspec)
if cleaned_axis_names == (None,) * len(cleaned_axis_names):
return x
cleaned_pspec = PartitionSpec(*cleaned_axis_names)
return jax.lax.with_sharding_constraint(x, cleaned_pspec)
......@@ -349,6 +365,21 @@ def all_reduce_sum_along_dp_fsdp(x: jnp.array, mesh: jax.sharding.Mesh):
return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource, mesh)
def all_reduce_sum_along_dp_fsdp_tpsp(x: jnp.array, mesh: jax.sharding.Mesh):
"""Perform all-reduce sum operation along data parallelism and sequence parallelism axes.
Args:
x: Input tensor to reduce
mesh: JAX mesh for distributed computation
Returns:
Reduced tensor
"""
x = lax_paral_op(x, jax.lax.psum, global_mesh_resource().tpsp_resource, mesh)
x = lax_paral_op(x, jax.lax.psum, global_mesh_resource().dp_resource, mesh)
return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource, mesh)
def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mesh):
"""Perform all-reduce max operation along all axes except pipeline parallelism.
......@@ -364,3 +395,21 @@ def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mes
if axis != global_mesh_resource().pp_resource:
x = lax_paral_op(x, jax.lax.pmax, axis, mesh)
return x
def tpsp_axis_size():
"""
Get the size of the tensor parallelism (with sequence parallelism) axis.
Return 1 if no TP-SP axis is set.
"""
return get_mesh_axis_size(global_mesh_resource().tpsp_resource)
def dp_or_fsdp_axis_size():
"""
Get the size of the data parallelism or FSDP axis.
Return 1 if no DP/FSDP axis is set.
"""
dp_size = get_mesh_axis_size(global_mesh_resource().dp_resource)
fsdp_size = get_mesh_axis_size(global_mesh_resource().fsdp_resource)
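# Prefer the DP axis when it is actually sharded (size > 1); otherwise fall back to FSDP.
# Both default to 1 when unset, so the function returns 1 if neither axis is configured.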
return dp_size if dp_size > 1 else fsdp_size
......@@ -46,8 +46,18 @@ from transformer_engine.pytorch.permutation import (
moe_sort_chunks_by_index,
moe_sort_chunks_by_index_with_probs,
)
from transformer_engine.pytorch.fp8 import fp8_autocast
from transformer_engine.pytorch.fp8 import fp8_model_init
from transformer_engine.pytorch.quantization import fp8_autocast
from transformer_engine.pytorch.quantization import fp8_model_init
from transformer_engine.pytorch.quantization import autocast
from transformer_engine.pytorch.quantization import quantized_model_init
from transformer_engine.pytorch.quantization import is_fp8_available
from transformer_engine.pytorch.quantization import is_mxfp8_available
from transformer_engine.pytorch.quantization import is_fp8_block_scaling_available
from transformer_engine.pytorch.quantization import is_nvfp4_available
from transformer_engine.pytorch.quantization import get_default_recipe
from transformer_engine.pytorch.utils import get_cudnn_version
from transformer_engine.pytorch.utils import get_device_compute_capability
from transformer_engine.pytorch.utils import is_bf16_available
from transformer_engine.pytorch.graph import make_graphed_callables
from transformer_engine.pytorch.distributed import checkpoint
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
......@@ -56,6 +66,24 @@ from transformer_engine.pytorch import ops
from transformer_engine.pytorch import optimizers
from transformer_engine.pytorch.export import onnx_export
from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy
from transformer_engine.pytorch.tensor import Quantizer
from transformer_engine.pytorch.tensor import Float8Quantizer
from transformer_engine.pytorch.tensor import Float8CurrentScalingQuantizer
from transformer_engine.pytorch.tensor import MXFP8Quantizer
from transformer_engine.pytorch.tensor import Float8BlockQuantizer
from transformer_engine.pytorch.tensor import NVFP4Quantizer
from transformer_engine.pytorch.tensor import QuantizedTensorStorage
from transformer_engine.pytorch.tensor import Float8TensorStorage
from transformer_engine.pytorch.tensor import MXFP8TensorStorage
from transformer_engine.pytorch.tensor import Float8BlockwiseQTensorStorage
from transformer_engine.pytorch.tensor import NVFP4TensorStorage
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor import Float8Tensor
from transformer_engine.pytorch.tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor import Float8BlockwiseQTensor
from transformer_engine.pytorch.tensor import NVFP4Tensor
from transformer_engine.pytorch.tensor import prepare_for_saving
from transformer_engine.pytorch.tensor import restore_from_saved
try:
torch._dynamo.config.error_on_nested_jit_trace = False
......
......@@ -13,21 +13,24 @@ import logging
from packaging.version import Version as PkgVersion
import torch
import torch.nn.functional as F
from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine_torch as tex
from transformer_engine.pytorch.utils import (
SplitAlongDim,
get_device_compute_capability,
combine_tensors,
split_tensor_along_dim,
)
from transformer_engine.pytorch.utils import attention_mask_func
from transformer_engine.pytorch.utils import attention_mask_func, nvtx_range_push, nvtx_range_pop
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer,
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.tensor.quantized_tensor import (
QuantizedTensor,
QuantizedTensorStorage,
prepare_for_saving,
restore_from_saved,
)
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.constants import (
TE_DType,
QKVLayouts,
......@@ -40,7 +43,7 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
META_O,
META_QKV,
)
from transformer_engine.pytorch.fp8 import get_fp8_torch_dtype
from transformer_engine.pytorch.quantization import get_fp8_torch_dtype, FP8GlobalStateManager
from transformer_engine.pytorch.distributed import get_distributed_world_size
from transformer_engine.pytorch.jit import no_torch_dynamo
from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
......@@ -53,6 +56,9 @@ from transformer_engine.pytorch.attention.inference import InferenceParams
import transformer_engine.pytorch.attention.dot_product_attention.utils as dpa_utils
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
FlashAttentionUtils as fa_utils,
combine_and_quantize,
combine_and_dequantize,
print_quantizers,
)
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
AttentionLogging as attn_log,
......@@ -131,6 +137,58 @@ if not IS_HIP_EXTENSION:
fa_utils.set_flash_attention_3_params()
# Float8CurrentScaling: fused_attn_bwd takes O in FP8 by default; this flag allows it to be passed in F16
_dpa_fp8_cs_o_in_f16 = os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1") == "1"
class FP8EmulationFunc(torch.autograd.Function):
"""
Emulate the effects of FP8 quantization on tensors. Used in UnfusedDotProductAttention as follows:
- forward : QKV (quantize+dequantize), P (pass-through), S (quantize+dequantize), O (pass-through)
- backward: dO (quantize+dequantize), dS (pass-through), dP (quantize+dequantize), dQKV (pass-through)
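Quantization is emulated as a quantize followed by an immediate dequantize, so tensors keep their
original F16/BF16 dtype but carry the FP8 rounding error of the corresponding quantizer.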
"""
@staticmethod
def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout):
# pylint: disable=missing-function-docstring
if quantizer_name == "QKV_quantizer":
query_layer, key_layer, value_layer = [
x.contiguous() for x in [tensor1, tensor2, tensor3]
]
q_fp8, k_fp8, v_fp8 = combine_and_quantize(
qkv_layout, query_layer, key_layer, value_layer, quantizer
)
tensors = combine_and_dequantize(
qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=query_layer.dtype
)
elif quantizer_name in ["S_quantizer", "O_quantizer"]:
t_fp8 = quantizer(tensor1)
tensors = (t_fp8.dequantize(dtype=tensor1.dtype), tensor2, tensor3)
else:
tensors = (tensor1, tensor2, tensor3)
ctx.quantizer = quantizer
ctx.quantizer_name = quantizer_name
ctx.qkv_layout = qkv_layout
return tensors[0], tensors[1], tensors[2]
@staticmethod
def backward(ctx, grad1, grad2, grad3):
# pylint: disable=missing-function-docstring
if ctx.quantizer_name in ["dO_quantizer", "dP_quantizer"]:
dt_fp8 = ctx.quantizer(grad1)
tensors = dt_fp8.dequantize(dtype=grad1.dtype), grad2, grad3
elif ctx.quantizer_name == "dQKV_quantizer":
query_grad, key_grad, value_grad = [x.contiguous() for x in [grad1, grad2, grad3]]
dq_fp8, dk_fp8, dv_fp8 = combine_and_quantize(
ctx.qkv_layout, query_grad, key_grad, value_grad, ctx.quantizer
)
tensors = combine_and_dequantize(
ctx.qkv_layout, dq_fp8, dk_fp8, dv_fp8, src_nominal_dtype=query_grad.dtype
)
else:
tensors = grad1, grad2, grad3
return tensors[0], tensors[1], tensors[2], None, None, None
class UnfusedDotProductAttention(torch.nn.Module):
"""Parallel attention w/o QKV and Proj Gemms
......@@ -144,6 +202,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
attention_dropout: float = 0.0,
attention_dropout_ctx: Optional[Callable] = nullcontext,
layer_number: Optional[int] = None,
softmax_type: str = "vanilla",
) -> None:
super().__init__()
......@@ -151,6 +210,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
self.attention_type = attention_type
self.attention_dropout_ctx = attention_dropout_ctx
self.layer_number = layer_number
self.softmax_type = softmax_type
def mask_func(x, y):
return (
......@@ -187,6 +247,11 @@ class UnfusedDotProductAttention(torch.nn.Module):
core_attention_bias: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
inference_params: Optional[InferenceParams] = None,
softmax_offset: torch.Tensor = None,
fp8: bool = False,
fp8_meta: Optional[Dict[str, Any]] = None,
quantizers=None,
fp8_output: bool = False,
) -> torch.Tensor:
"""Unfused attention fprop"""
assert (
......@@ -284,6 +349,35 @@ class UnfusedDotProductAttention(torch.nn.Module):
if apply_qk_layer_scaling:
scale /= self.layer_number
if fp8:
# get quantizers from DPA; all Nones if not fp8
QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = (
dpa_utils.get_attention_quantizers(fp8, quantizers)
)
# S/dP are forced to use DS quantizers in DPA.init_fp8_metadata; revert them here for true CS emulation
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None:
fp8_recipe = fp8_meta["local_recipes"][0]
if fp8_recipe.float8_current_scaling():
S_quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=S_quantizer.dtype, device="cuda"
)
dP_quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=dP_quantizer.dtype, device="cuda"
)
if "2" in qkv_layout or "3" in qkv_layout:
qkv_format, *_ = dpa_utils.get_qkv_format(qkv_layout)
qkv_layout = "_".join([qkv_format] * 3)
# quantize and dequantize QKV to emulate FP8
query_layer, key_layer, value_layer = FP8EmulationFunc.apply(
query_layer, key_layer, value_layer, QKV_quantizer, "QKV_quantizer", qkv_layout
)
# quantize and dequantize dQKV to emulate FP8
query_layer, key_layer, value_layer = FP8EmulationFunc.apply(
query_layer, key_layer, value_layer, dQKV_quantizer, "dQKV_quantizer", qkv_layout
)
# Raw attention scores. [b * np, sq, sk]
if core_attention_bias_type == "no_bias":
matmul_result = torch.baddbmm(
......@@ -328,7 +422,27 @@ class UnfusedDotProductAttention(torch.nn.Module):
dtype=query_layer.dtype
)
# attention scores and attention mask [b, np, sq, sk]
if fp8:
# quantize and dequantize dP to emulate FP8
matmul_result, *_ = FP8EmulationFunc.apply(
matmul_result, None, None, dP_quantizer, "dP_quantizer", None
)
# add attention sink to the last column: [b, np, sq, sk+1]
if self.softmax_type != "vanilla":
matmul_result = torch.cat(
[
matmul_result,
softmax_offset.to(dtype=matmul_result.dtype).expand(
matmul_result.size(0), -1, matmul_result.size(2), -1
),
],
dim=-1,
)
attention_mask = F.pad(attention_mask, (0, 1), mode="constant", value=False)
attn_mask_type = "arbitrary"
# attention scores and attention mask
softmax_scale = self.layer_number if apply_qk_layer_scaling else None
attention_probs = self.scale_mask_softmax(
matmul_result, attention_mask, attn_mask_type, softmax_scale
......@@ -339,6 +453,10 @@ class UnfusedDotProductAttention(torch.nn.Module):
if "padding" in attn_mask_type:
attention_probs = attention_probs.masked_fill(attention_mask, 0)
# remove attention sink: [b, np, sq, sk]
if self.softmax_type != "vanilla":
attention_probs = attention_probs[..., :-1]
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
with self.attention_dropout_ctx():
......@@ -359,6 +477,12 @@ class UnfusedDotProductAttention(torch.nn.Module):
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
if fp8:
# quantize and dequantize S to emulate FP8
attention_probs, *_ = FP8EmulationFunc.apply(
attention_probs, None, None, S_quantizer, "S_quantizer", None
)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
......@@ -393,6 +517,20 @@ class UnfusedDotProductAttention(torch.nn.Module):
# [tq, np, hn] --> [tq, hp]
context_layer = context_layer.view(total_tokens, -1)
if fp8:
# quantize and dequantize O to emulate FP8
context_layer, *_ = FP8EmulationFunc.apply(
context_layer, None, None, O_quantizer, "O_quantizer", None
)
# quantize and dequantize dO to emulate FP8
context_layer, *_ = FP8EmulationFunc.apply(
context_layer, None, None, dO_quantizer, "dO_quantizer", None
)
# quantize O
if fp8_output:
context_layer = O_quantizer(context_layer)
return context_layer
......@@ -491,6 +629,7 @@ class FlashAttention(torch.nn.Module):
quantizers=None,
inference_params: Optional[InferenceParams] = None,
flash_attention_backend: Optional[PkgVersion] = PkgVersion("0"),
fp8_output: bool = False,
) -> torch.Tensor:
"""flash-attn fprop"""
......@@ -696,6 +835,7 @@ class FlashAttention(torch.nn.Module):
quantizers=quantizers,
pad_between_seqs=False,
use_flash_attn_3=use_flash_attn_3,
fp8_output=fp8_output,
)
else:
from transformer_engine.pytorch.cpu_offload import (
......@@ -795,8 +935,6 @@ class FlashAttention(torch.nn.Module):
)
return out
# "fp8_mha" decides outputs in fp8, while inputs are inferred from
# the real dtype
assert isinstance(key_layer, query_layer.__class__) and isinstance(
value_layer, query_layer.__class__
), "q, k, and v must have the same type."
......@@ -843,7 +981,7 @@ class FlashAttention(torch.nn.Module):
if fp8:
output = output.to(dtype=torch_orig_dtype)
if fp8 and fp8_meta["recipe"].fp8_mha:
if fp8 and fp8_output:
O_quantizer = quantizers["scaling_fwd"][META_O]
output = O_quantizer(output)
......@@ -871,7 +1009,7 @@ class FlashAttention(torch.nn.Module):
if q_format == "sbhd":
# (bs)hd -> bs(hd) -> sb(hd)
if fp8 and fp8_meta["recipe"].fp8_mha:
if fp8 and fp8_output:
output_data = (
output._data.reshape(batch_size, max_seqlen_q // cp_size, -1)
.transpose(0, 1)
......@@ -895,7 +1033,7 @@ class FlashAttention(torch.nn.Module):
class FusedAttnFunc(torch.autograd.Function):
"""Function for FusedAttention with separate Q, K, V tensors"""
"""FusedAttention forward and backward implementation"""
@staticmethod
def forward(
......@@ -919,6 +1057,7 @@ class FusedAttnFunc(torch.autograd.Function):
qkv_layout,
attn_bias_type,
attn_mask_type,
softmax_type,
window_size,
rng_gen,
fused_attention_backend,
......@@ -927,55 +1066,72 @@ class FusedAttnFunc(torch.autograd.Function):
fp8_meta,
quantizers,
deterministic,
softmax_offset,
fp8_output,
layer_number,
):
# pylint: disable=missing-function-docstring
# "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype
is_input_fp8 = False
is_output_fp8 = fp8_meta["recipe"].fp8_mha if "recipe" in fp8_meta else False
# FP16/BF16 attn: fake_dtype = torch.float16 or torch.bfloat16
# FP8 attn, is_output_fp8 = False: fake_dtype = torch.float16 or torch.bfloat16
# FP8 attn, is_output_fp8 = True: fake_dtype = torch.float8_e4m3fn
fake_dtype = q.dtype
# add NVTX range
nvtx_label = "transformer_engine.FusedAttnFunc.forward"
nvtx_range_push(f"{nvtx_label}")
# recipe passed in through autocast or set by NVTE_DPA_FP8_RECIPE;
# may be different from fp8_meta["recipe"]
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None:
fp8_recipe = fp8_meta["local_recipes"][0]
# input types are inferred from the real data while output types are controlled by fp8_output
# fp8_output should be set upstream as (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_mha)
assert isinstance(k, q.__class__) and isinstance(
v, q.__class__
), "q, k, v must be of the same class, e.g. torch.Tensor or Float8Tensor."
is_input_fp8 = isinstance(q, Float8Tensor)
is_output_fp8 = fp8_output
# whether fwd kernel in FP8: fp8 = (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_dpa)
# whether bwd kernel in FP8:
is_bwd_fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
# get quantizers from DPA; all Nones if not fp8
QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = (
dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False)
dpa_utils.get_attention_quantizers(fp8, quantizers)
)
# get nominal data type for out
# FP16/BF16 attention: torch.float16 or torch.bfloat16
# FP8 attention: torch.float16 or torch.bfloat16
out_nominal_dtype = q.dtype
if fp8:
fused_attention_backend = FusedAttnBackend["FP8"]
assert isinstance(k, q.__class__) and isinstance(
v, q.__class__
), "q, k, and v must have the same type."
is_input_fp8 = isinstance(q, Float8Tensor)
q_fp8, k_fp8, v_fp8 = None, None, None
# q, k, v: torch.Tensor; dtype = torch.float16 or torch.bfloat16
# q_fp8, k_fp8, v_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16
# fp8_dtype = tex.DType.kFloat8E4M3
if is_input_fp8:
q_fp8, k_fp8, v_fp8 = q, k, v
else:
# 1: qkv packed, 2: kv packed, 3: qkv separate
qkv_group = len(qkv_layout.replace("paged_kv_", "").split("_"))
match qkv_group:
case 1:
dim = qkv_layout.find("3")
qkv = combine_tensors([q, k, v], dim)
qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
qkv_fp8 = QKV_quantizer(qkv)
q_fp8, k_fp8, v_fp8 = SplitAlongDim.apply(qkv_fp8, dim, [1, 1, 1], True)
case 2:
q_fp8 = QKV_quantizer(q)
dim = qkv_layout.split("_")[1].find("2")
kv = combine_tensors([k, v], dim)
kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
kv_fp8 = QKV_quantizer(kv_c)
k_fp8, v_fp8 = SplitAlongDim.apply(kv_fp8, dim, [1, 1], True)
case 3:
q_fp8 = QKV_quantizer(q)
k_fp8 = QKV_quantizer(k)
v_fp8 = QKV_quantizer(v)
case _:
raise "Invalid qkv_layout " + qkv_layout
# q_fp8, k_fp8, v_fp8, out_fp8: torch.float8_e4m3fn
out_fp8, aux_ctx_tensors = fused_attn_fwd(
q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer)
# print quantizers
print_quantizers(
"FusedAttnFunc.forward >> before: ",
layer_number,
QKV_quantizer,
O_quantizer,
S_quantizer,
dQKV_quantizer,
dO_quantizer,
dP_quantizer,
)
# out_:
# DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16
# fp8_dtype = tex.DType.kFloat8E4M3
# Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16
out_, aux_ctx_tensors = fused_attn_fwd(
is_training,
max_seqlen_q,
max_seqlen_kv,
......@@ -984,7 +1140,7 @@ class FusedAttnFunc(torch.autograd.Function):
q_fp8,
k_fp8,
v_fp8,
fake_dtype,
out_nominal_dtype,
fused_attention_backend,
attn_bias,
cu_seqlens_q_padded,
......@@ -999,45 +1155,59 @@ class FusedAttnFunc(torch.autograd.Function):
qkv_layout,
attn_bias_type,
attn_mask_type,
softmax_type,
window_size,
rng_gen,
softmax_offset,
)
if is_output_fp8:
out_ret = out_fp8
# out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16
# fp8_dtype = tex.DType.kFloat8E4M3
# out: torch.Tensor; dtype = torch.float16 or torch.bfloat16
out_fp8 = out_
out = out_
if isinstance(out_, Float8Tensor):
if not is_output_fp8 or not is_bwd_fp8:
out = out_.dequantize().view(out_.shape)
else:
out_ret = out_fp8.dequantize().view(out_fp8.shape)
# is_output_fp8 = False: out_save.dtype = torch.float16 or torch.bfloat16
# is_output_fp8 = True: out_save.dtype = torch.float8_e4m3fn
out_save = out_ret
if is_output_fp8 or (
is_bwd_fp8
and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16)
):
out_fp8 = O_quantizer(out_)
if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
# 1: qkv packed, 2: kv packed, 3: qkv separate
if is_input_fp8:
qkv_group = len(qkv_layout.replace("paged_kv_", "").split("_"))
if qkv_group == 1:
dim = qkv_layout.find("3")
qkv = combine_tensors([q, k, v], dim)
qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
qkv_no_fp8 = qkv_c.dequantize().view(qkv.shape)
q, k, v = SplitAlongDim.apply(qkv_no_fp8, dim, [1, 1, 1], True)
if qkv_group == 2:
q = q.dequantize()
dim = qkv_layout.replace("paged_kv_", "").split("_")[1].find("2")
kv = combine_tensors([k, v], dim)
kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
kv_no_fp8 = kv.dequantize()
k, v = SplitAlongDim.apply(kv_no_fp8, dim, [1, 1], True)
if qkv_group == 3:
q = q.dequantize()
k = k.dequantize()
v = v.dequantize()
if is_output_fp8:
out_save = out_fp8.dequantize()
# print quantizers
print_quantizers(
"FusedAttnFunc.forward >> after: ",
layer_number,
QKV_quantizer,
O_quantizer,
S_quantizer,
dQKV_quantizer,
dO_quantizer,
dP_quantizer,
)
# return appropriate tensors
out_ret = out_fp8 if is_output_fp8 else out
# save appropriate tensors
fp8_tensors = (None, None, None, None)
qkvo_tensors = (None, None, None, None)
if is_bwd_fp8:
if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16:
fp8_tensors = (q_fp8, k_fp8, v_fp8, None)
qkvo_tensors = (None, None, None, out)
else:
fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8)
else:
# q, k, v, out_ret: torch.float16 or torch.bfloat16
out_ret, aux_ctx_tensors = fused_attn_fwd(
if is_input_fp8:
q, k, v = combine_and_dequantize(qkv_layout, q_fp8, k_fp8, v_fp8)
qkvo_tensors = (q, k, v, out)
else:
# q, k, v, out_: torch.Tensor; dtype = torch.float16 or torch.bfloat16
out_, aux_ctx_tensors = fused_attn_fwd(
is_training,
max_seqlen_q,
max_seqlen_kv,
......@@ -1046,7 +1216,7 @@ class FusedAttnFunc(torch.autograd.Function):
q,
k,
v,
fake_dtype,
out_nominal_dtype,
fused_attention_backend,
attn_bias,
cu_seqlens_q_padded,
......@@ -1061,13 +1231,23 @@ class FusedAttnFunc(torch.autograd.Function):
qkv_layout,
attn_bias_type,
attn_mask_type,
softmax_type,
window_size,
rng_gen,
softmax_offset,
)
out_save = out_ret
out = out_
out_ret = out_
fp8_tensors = (None, None, None, None)
qkvo_tensors = (q, k, v, out)
ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
nvtx_range_pop(f"{nvtx_label}")
ctx.fp8_recipe = fp8_recipe
ctx.fp8 = is_bwd_fp8
# assume fwd and bwd always use the same high precision, i.e. torch.float16 or torch.bfloat16
# used when some tensors are base tensors and lose the "dtype" attribute
ctx.nominal_dtype = out_nominal_dtype
from transformer_engine.pytorch.cpu_offload import (
CPUOffloadEnabled,
......@@ -1078,15 +1258,13 @@ class FusedAttnFunc(torch.autograd.Function):
if ctx.fp8:
tensor_list = fp8_tensors
else:
tensor_list = [q, k, v, out_save]
tensor_list = [q, k, v, out]
qkv_layout = "sbhd_sbhd_sbhd"
mark_activation_offload(*tensor_list)
mark_activation_offload(*aux_ctx_tensors)
ctx.is_input_fp8 = is_input_fp8
ctx.is_output_fp8 = is_output_fp8
qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None)
tensors_to_save, tensor_objects = prepare_for_saving(
*fp8_tensors,
*qkvo_tensors,
......@@ -1100,11 +1278,14 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.tensor_objects = tensor_objects
ctx.fp8_meta = fp8_meta
ctx.layer_number = layer_number
ctx.QKV_quantizer = QKV_quantizer
ctx.O_quantizer = O_quantizer
ctx.dQKV_quantizer = dQKV_quantizer
ctx.dO_quantizer = dO_quantizer
ctx.dP_quantizer = dP_quantizer
ctx.S_quantizer = S_quantizer
if ctx.fp8:
if ctx.fp8 and isinstance(ctx.S_quantizer, Float8Quantizer):
ctx.S_quantizer = S_quantizer.copy()
ctx.S_quantizer.scale = S_quantizer.scale.clone()
......@@ -1113,9 +1294,34 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.attn_scale = attn_scale
ctx.dropout_p = dropout_p
ctx.fast_zero_fill = fast_zero_fill
from transformer_engine.pytorch.cpu_offload import (
CPUOffloadedLayer,
)
# If an interleaved tensor is offloaded, the reloaded tensor will be
# non-interleaved, so we need to modify the QKV layout
# for backward
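# e.g. an offloaded "sbh3d" layout is reloaded as "sbhd_sbhd_sbhd",
# and "sbhd_sb2hd" is reloaded as "sbhd_sbhd_sbhd"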
if CPUOffloadedLayer and CPUOffloadEnabled:
reload_layout = ""
split_list = qkv_layout.split("_")
for split in split_list:
temp_layout = ""
rep_count = 1
for s in split:
if s.isalpha():
temp_layout = temp_layout + s
else:
rep_count = int(s)
for _ in range(rep_count):
reload_layout = reload_layout + temp_layout + "_"
ctx.qkv_layout = reload_layout[:-1]
else:
ctx.qkv_layout = qkv_layout
ctx.attn_bias_type = attn_bias_type
ctx.attn_mask_type = attn_mask_type
ctx.softmax_type = softmax_type
ctx.window_size = window_size
ctx.fused_attention_backend = (
fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
......@@ -1128,16 +1334,14 @@ class FusedAttnFunc(torch.autograd.Function):
@staticmethod
def backward(ctx, d_out):
# pylint: disable=missing-function-docstring
if ctx.is_output_fp8:
assert isinstance(
d_out, Float8Tensor
), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
# FP16/BF16 attn: fake_dtype = torch.float16 or torch.bfloat16
# FP8 attn, is_output_fp8 = False: fake_dtype = torch.float16 or torch.bfloat16
# FP8 attn, is_output_fp8 = True: fake_dtype = torch.float8_e5m2
fake_dtype = d_out.dtype
# d_out is expected to be in FP8 if is_output_fp8=True,
# but in the case it's not, convert it to FP8 before any operation
if ctx.fp8 and ctx.is_output_fp8 and not isinstance(d_out, QuantizedTensorStorage):
d_out = ctx.dO_quantizer(d_out)
if not ctx.use_FAv2_bwd:
d_out._data = d_out._data.contiguous()
elif not ctx.use_FAv2_bwd:
d_out = d_out.contiguous()
(
q_fp8,
......@@ -1192,16 +1396,55 @@ class FusedAttnFunc(torch.autograd.Function):
dk = dk[..., : d_out.shape[-1]]
dv = dv[..., : d_out.shape[-1]]
else:
with torch.cuda.nvtx.range("_FusedAttn"):
with torch.cuda.nvtx.range("FusedAttnFunc.backward"):
# get nominal data type of dq, dk, dv
# FP16/BF16 attention: torch.float16 or torch.bfloat16
# FP8 attention: torch.float16 or torch.bfloat16
dqkv_nominal_dtype = ctx.nominal_dtype
if ctx.fp8:
# d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16
# d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16
# fp8_dtype = tex.DType.kFloat8E5M2
if ctx.is_output_fp8:
d_out_fp8 = d_out
else:
d_out_fp8 = ctx.dO_quantizer(d_out)
dqkv_dtype = TE_DType[d_out_fp8._data.dtype]
# q_fp8, k_fp8, v_fp8, out_fp8: torch.float8_e4m3fn
# d_out_fp8, dq_fp8, dk_fp8, dv_fp8: torch.float8_e5m2
dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd(
# print quantizers
print_quantizers(
"FusedAttnFunc.backward >> before: ",
ctx.layer_number,
ctx.QKV_quantizer,
ctx.O_quantizer,
ctx.S_quantizer,
ctx.dQKV_quantizer,
ctx.dO_quantizer,
ctx.dP_quantizer,
)
# get tex.DType for dq, dk, dv data
dqkv_te_dtype = d_out_fp8._fp8_dtype
# q_fp8, k_fp8, v_fp8, out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16,
# fp8_dtype = tex.DType.kFloat8E4M3
# d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16
# fp8_dtype = tex.DType.kFloat8E5M2
# out_:
# DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16
# fp8_dtype = tex.DType.kFloat8E4M3
# Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16
#
# dq_, dk_, dv_:
# DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16
# fp8_dtype = tex.DType.kFloat8E5M2
# Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16
out_ = (
out
if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16
else out_fp8
)
dq_, dk_, dv_, *rest = fused_attn_bwd(
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
cu_seqlens_q,
......@@ -1209,10 +1452,10 @@ class FusedAttnFunc(torch.autograd.Function):
q_fp8,
k_fp8,
v_fp8,
out_fp8,
out_,
d_out_fp8,
fake_dtype,
dqkv_dtype,
dqkv_nominal_dtype,
dqkv_te_dtype,
aux_ctx_tensors,
ctx.fused_attention_backend,
cu_seqlens_q_padded,
......@@ -1226,44 +1469,45 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.qkv_layout,
ctx.attn_bias_type,
ctx.attn_mask_type,
ctx.softmax_type,
ctx.window_size,
ctx.deterministic,
)
# is_input_fp8 = False: dq, dk, dv: torch.float16 or torch.bfloat16
# is_input_fp8 = True: dq, dk, dv: torch.float8_e5m2
if not ctx.is_input_fp8:
qkv_group = len(ctx.qkv_layout.replace("paged_kv_", "").split("_"))
if qkv_group == 1:
dim = ctx.qkv_layout.find("3")
dqkv_fp8_data = combine_tensors(
[dq_fp8._data, dk_fp8._data, dv_fp8._data], dim
)
dqkv_fp8 = dq_fp8.make_like(
tensor=dq_fp8, data=dqkv_fp8_data, shape=dqkv_fp8_data.shape
)
dqkv = dqkv_fp8.dequantize()
dq, dk, dv = SplitAlongDim.apply(dqkv, dim, [1, 1, 1], True)
if qkv_group == 2:
dq = dq_fp8.dequantize()
dim = ctx.qkv_layout.split("_")[1].find("2")
dkv_fp8 = combine_tensors([dk_fp8, dv_fp8], dim)
dkv_c_fp8 = dkv_fp8.view(
-1, dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1]
)
dkv = dkv_c_fp8.dequantize()
dk, dv = SplitAlongDim.apply(dkv, dim, [1, 1], True)
if qkv_group == 3:
dq = dq_fp8.dequantize()
dk = dk_fp8.dequantize()
dv = dv_fp8.dequantize()
else:
dq, dk, dv = dq_fp8, dk_fp8, dv_fp8
# dq, dk, dv: torch.Tensor; dtype = torch.float16 or torch.bfloat16
dq, dk, dv = dq_, dk_, dv_
is_float8tensor = isinstance(dq_, Float8Tensor)
if is_float8tensor and not ctx.is_input_fp8:
# return in F16
dq, dk, dv = combine_and_dequantize(
ctx.qkv_layout,
dq_,
dk_,
dv_,
src_nominal_dtype=dq_.dtype,
)
if not is_float8tensor and ctx.is_input_fp8:
# return in FP8
dq, dk, dv = combine_and_quantize(
ctx.qkv_layout, dq_, dk_, dv_, ctx.dQKV_quantizer
)
# print quantizers
print_quantizers(
"FusedAttnFunc.backward >> after: ",
ctx.layer_number,
ctx.QKV_quantizer,
ctx.O_quantizer,
ctx.S_quantizer,
ctx.dQKV_quantizer,
ctx.dO_quantizer,
ctx.dP_quantizer,
)
else:
if isinstance(d_out, QuantizedTensor):
d_out = d_out.dequantize()
dqkv_dtype = TE_DType[d_out.dtype]
# q, k, v, out, d_out, dq, dk, dv: torch.float16 or torch.bfloat16
if isinstance(d_out, QuantizedTensorStorage):
d_out = d_out.dequantize(dtype=ctx.nominal_dtype)
dqkv_te_dtype = TE_DType[d_out.dtype]
# q, k, v, out, d_out, dq, dk, dv: torch.Tensor; torch.float16 or torch.bfloat16
dq, dk, dv, *rest = fused_attn_bwd(
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
......@@ -1274,8 +1518,8 @@ class FusedAttnFunc(torch.autograd.Function):
v,
out,
d_out,
fake_dtype,
dqkv_dtype,
dqkv_nominal_dtype,
dqkv_te_dtype,
aux_ctx_tensors,
ctx.fused_attention_backend,
cu_seqlens_q_padded,
......@@ -1289,12 +1533,17 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.qkv_layout,
ctx.attn_bias_type,
ctx.attn_mask_type,
ctx.softmax_type,
ctx.window_size,
ctx.deterministic,
)
# if no_bias or alibi, return dqkv
if ctx.attn_bias_type in ["no_bias", "alibi"]:
d_bias = None
if ctx.attn_bias_type not in ["no_bias", "alibi"]:
d_bias = rest[0]
d_softmax_offset = None
if ctx.softmax_type != "vanilla":
d_softmax_offset = rest[1]
return (
None,
None,
......@@ -1308,6 +1557,7 @@ class FusedAttnFunc(torch.autograd.Function):
dq,
dk,
dv,
d_bias,
None,
None,
None,
......@@ -1323,34 +1573,7 @@ class FusedAttnFunc(torch.autograd.Function):
None,
None,
None,
)
# else, return (dqkv, dbias)
return (
None,
None,
None,
None,
None,
None,
None,
None,
None,
dq,
dk,
dv,
rest[0],
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
d_softmax_offset,
None,
None,
)
......@@ -1392,6 +1615,7 @@ class FusedAttention(torch.nn.Module):
attention_type: str = "self",
layer_number: Optional[int] = None,
deterministic: bool = False,
softmax_type: str = "vanilla",
) -> None:
super().__init__()
......@@ -1404,6 +1628,7 @@ class FusedAttention(torch.nn.Module):
) == "1" and get_device_compute_capability() == (9, 0)
self.layer_number = 1 if layer_number is None else layer_number
self.deterministic = deterministic
self.softmax_type = softmax_type
def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument
"""
......@@ -1455,6 +1680,8 @@ class FusedAttention(torch.nn.Module):
quantizers=None,
pad_between_seqs: bool = False,
inference_params: Optional[InferenceParams] = None,
softmax_offset: torch.Tensor = None,
fp8_output: bool = False,
) -> torch.Tensor:
"""fused attention fprop"""
assert (
......@@ -1555,14 +1782,26 @@ class FusedAttention(torch.nn.Module):
)
if fp8:
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None:
fp8_recipe = fp8_meta["local_recipes"][0]
assert fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, (
f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}"
" is required for FP8 attention!"
)
assert fp8_meta is not None, "FP8 metadata fp8_meta is required for FP8 attention!"
assert not context_parallel or fp8_meta["recipe"].reduce_amax, (
"Amax reduction across TP+CP group is necessary when using context parallelism with"
" FP8!"
if fp8_recipe.delayed():
assert not context_parallel or fp8_recipe.reduce_amax, (
"Amax reduction across TP+CP group is necessary when using context parallelism"
" with FP8!"
)
if fp8_recipe.float8_current_scaling() and context_parallel:
all_quantizers = dpa_utils.get_attention_quantizers(fp8, quantizers)
for q in all_quantizers:
if isinstance(q, Float8CurrentScalingQuantizer):
q.with_amax_reduction = True
q.amax_reduction_group = (
cp_group[0] if cp_comm_type == "a2a+p2p" else cp_group
)
if context_parallel:
......@@ -1605,6 +1844,10 @@ class FusedAttention(torch.nn.Module):
fp8_meta=fp8_meta,
quantizers=quantizers,
pad_between_seqs=pad_between_seqs,
softmax_type=self.softmax_type,
softmax_offset=softmax_offset,
fp8_output=fp8_output,
layer_number=self.layer_number,
)
else:
with self.attention_dropout_ctx():
......@@ -1628,6 +1871,7 @@ class FusedAttention(torch.nn.Module):
qkv_layout,
core_attention_bias_type,
attn_mask_type,
self.softmax_type,
window_size,
None, # rng_gen
fused_attention_backend,
......@@ -1636,6 +1880,9 @@ class FusedAttention(torch.nn.Module):
fp8_meta,
quantizers,
self.deterministic,
softmax_offset,
fp8_output,
self.layer_number,
)
# ...hd -> ...(hd)
......
......@@ -11,11 +11,26 @@ import warnings
import logging
import torch
from torch.nn.parameter import Parameter
import transformer_engine_torch as tex
from transformer_engine.common.recipe import (
Format,
Recipe,
DelayedScaling,
Float8CurrentScaling,
)
from transformer_engine.pytorch.utils import get_cudnn_version
from transformer_engine.pytorch.fp8 import get_fp8_te_dtype
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.quantization import (
get_fp8_te_dtype,
FP8GlobalStateManager,
RecipeState,
DelayedScalingRecipeState,
MXFP8BlockScalingRecipeState,
Float8CurrentScalingRecipeState,
Float8BlockScalingRecipeState,
)
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.constants import (
......@@ -72,6 +87,67 @@ _alibi_cache = {
"_alibi_bias_require_update": False,
}
"""
This feature is **experimental** and subject to change.
Some models may use different FP8 recipes for their linear layers and attention layers. To support this,
users can either use multiple, nested autocast() contexts to assign a distinct recipe to each layer,
or use a single autocast() and override the recipe for the attention layers with the environment
variables listed below.
+-------------------+-----------+-----------------------------------------------------------------------------------+
| Linear | Attention | Configuration |
+===================+===========+===================================================================================+
| FP8DS/FP8CS/NVFP4 | FP16/BF16 | Pass FP8DS, FP8CS or NVFP4 to autocast(); |
| | | export NVTE_DPA_FP8_RECIPE="F16" |
+-------------------+-----------+-----------------------------------------------------------------------------------+
| FP8DS | FP8DS | Pass FP8DS to autocast(); |
+-------------------+-----------+-----------------------------------------------------------------------------------+
| FP8CS | FP8DS | Pass FP8CS to autocast(); |
| | | Attention FP8DS reuses the fp8_format, fp8_dpa, fp8_mha values from linear FP8CS; |
| | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS |
| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" |
| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer |
| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 |
+-------------------+-----------+-----------------------------------------------------------------------------------+
| NVFP4 | FP8DS | Pass NVFP4 to autocast(); |
| | | Attention FP8DS reuses the fp8_dpa, fp8_mha values from linear NVFP4; |
| | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS |
| | | export NVTE_DPA_FP8_FORMAT="HYBRID" # or "E4M3", "E5M2" |
| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" |
| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer |
| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 |
+-------------------+-----------+-----------------------------------------------------------------------------------+
| FP8DS | FP8CS | Pass FP8DS to autocast(); |
| | | Attention uses FP8DS for S, dP tensors, and creates a new FP8CS recipe for QKV, O,|
| | | dO, dQKV tensors based on fp8_format, fp8_dpa, fp8_mha from linear FP8DS; |
| | | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling" # switch to CS |
+-------------------+-----------+-----------------------------------------------------------------------------------+
| FP8CS | FP8CS | Pass FP8CS to autocast(); |
| | | Attention uses FP8CS for QKV, O, dO, dQKV tensors, and creates a new FP8DS recipe |
| | | for S, dP tensors based on fp8_format, fp8_dpa, fp8_mha from linear FP8CS and: |
| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" |
| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer |
| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 |
+-------------------+-----------+-----------------------------------------------------------------------------------+
| NVFP4 | FP8CS | Pass NVFP4 to autocast(); |
| | | Attention creates a new FP8CS recipe for QKV, O, dO, dQKV, and a new FP8DS recipe |
| | | for S, dP, based on the fp8_dpa, fp8_mha values from linear NVFP4 and: |
| | | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling" # switch to CS |
| | | export NVTE_DPA_FP8_FORMAT="HYBRID" # or "E4M3", "E5M2" |
| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" |
| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer |
| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 |
+-------------------+-----------+-----------------------------------------------------------------------------------+
"""
_dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "")
formats = {"HYBRID": Format.HYBRID, "E4M3": Format.E4M3, "E5M2": Format.E5M2}
_dpa_fp8_format = formats[os.getenv("NVTE_DPA_FP8_FORMAT", "HYBRID")]
_dpa_fp8ds_amax_algo = os.getenv("NVTE_DPA_FP8DS_AMAX_ALGO", "most_recent")
_dpa_fp8ds_amax_histlen = int(os.getenv("NVTE_DPA_FP8DS_AMAX_HISTLEN", "1"))
_dpa_fp8ds_reduce_amax = os.getenv("NVTE_DPA_FP8DS_REDUCE_AMAX", "1") == "1"
__all__ = ["DotProductAttention"]
......@@ -168,6 +244,17 @@ class DotProductAttention(TransformerEngineBaseModule):
softmax_scale: Optional[float], default = `None`
softmax scale for the attention scores. If `None`, defaults to
`1.0/math.sqrt(kv_channels if isinstance(kv_channels, int) else kv_channels[0])`.
softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
softmax type as described in this paper:
`Efficient Streaming Language Models with Attention Sinks
<https://arxiv.org/pdf/2309.17453v3>`_.
For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv],
'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
where alpha is a learnable parameter of shape [h].
'off-by-one' and 'learnable' softmax types are also called sink attention
('zero sink' and 'learnable sink', respectively).
Parallelism parameters
----------------------
......@@ -223,6 +310,7 @@ class DotProductAttention(TransformerEngineBaseModule):
cp_stream: torch.cuda.Stream = None,
cp_comm_type: str = "p2p",
softmax_scale: Optional[float] = None,
softmax_type: str = "vanilla",
) -> None:
super().__init__()
......@@ -307,6 +395,20 @@ class DotProductAttention(TransformerEngineBaseModule):
self.attention_type = attention_type
self.attention_dropout = attention_dropout
self.softmax_type = softmax_type
if self.softmax_type == "vanilla":
self.softmax_offset = None
if self.softmax_type == "off-by-one":
self.softmax_offset = torch.zeros(
self.num_attention_heads // self.tp_size, device="cuda"
)
if self.softmax_type == "learnable":
self.register_parameter(
"softmax_offset",
Parameter(torch.empty(self.num_attention_heads // self.tp_size, device="cuda")),
get_rng_state_tracker=get_rng_state_tracker,
)
attn_kwargs = {
"attention_dropout": attention_dropout,
"attention_dropout_ctx": attention_dropout_ctx,
......@@ -328,6 +430,7 @@ class DotProductAttention(TransformerEngineBaseModule):
layer_number=layer_number,
deterministic=self.deterministic,
**attn_kwargs,
softmax_type=self.softmax_type,
)
self.unfused_attention = UnfusedDotProductAttention(
......@@ -335,6 +438,7 @@ class DotProductAttention(TransformerEngineBaseModule):
attention_type=attention_type,
**attn_kwargs,
layer_number=layer_number,
softmax_type=self.softmax_type,
)
def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument
......@@ -433,6 +537,234 @@ class DotProductAttention(TransformerEngineBaseModule):
self.cp_stream = cp_stream
self.cp_comm_type = cp_comm_type
def init_fp8_metadata(self, num_gemms: int = 1) -> None:
"""
Override TransformerEngineBaseModule.init_fp8_metadata to allow more flexible recipe support.
Initialize FP8-related metadata and tensors during fprop.
"""
_original_recipe = self.fp8_meta.get("recipe", None)
# global recipe set in autocast()
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
if fp8_recipe.custom():
return
# switch/append recipe: fp8_recipe stays unchanged, but DPA.fp8_meta["recipe"] may be set to
# a different recipe than fp8_recipe. DPA.quantizers may be a mix of different quantizers as well.
#
# fp8_recipe | NVTE_DPA_FP8_RECIPE | self.fp8_meta["recipe"] | self.quantizers
# --------------------------------------------------------------------------------------------
# DelayedScaling (DS) | unset | DS | all DS
# Float8CurrentScaling (CS) | unset | DS | CS for QKV, O, dO, dQKV; DS for S, dP
# x={DS, CS} | y | refer to row x=y | refer to row x=y
fp8_recipe_dpa = fp8_recipe
fp8_recipes = fp8_recipe
if _dpa_fp8_recipe == "F16":
# ignore the recipe from autocast, set fp8_dpa = False, fp8_mha = False
fp8_recipe.fp8_dpa = False
fp8_recipe.fp8_mha = False
elif fp8_recipe.float8_current_scaling() and _dpa_fp8_recipe == "DelayedScaling":
# reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe, and construct a DS recipe
fake_recipe = DelayedScaling(
fp8_format=fp8_recipe.fp8_format,
amax_history_len=_dpa_fp8ds_amax_histlen,
amax_compute_algo=_dpa_fp8ds_amax_algo,
fp8_dpa=fp8_recipe.fp8_dpa,
fp8_mha=fp8_recipe.fp8_mha,
reduce_amax=_dpa_fp8ds_reduce_amax,
)
fp8_recipe_dpa = fake_recipe
fp8_recipes = fp8_recipe_dpa
elif fp8_recipe.nvfp4() and _dpa_fp8_recipe == "DelayedScaling":
# reuse fp8_dpa, fp8_mha from fp8_recipe but not fp8_format; construct a DS recipe
fake_recipe = DelayedScaling(
fp8_format=_dpa_fp8_format,
amax_history_len=_dpa_fp8ds_amax_histlen,
amax_compute_algo=_dpa_fp8ds_amax_algo,
fp8_dpa=fp8_recipe.fp8_dpa,
fp8_mha=fp8_recipe.fp8_mha,
reduce_amax=_dpa_fp8ds_reduce_amax,
)
fp8_recipe_dpa = fake_recipe
fp8_recipes = fp8_recipe_dpa
elif fp8_recipe.delayed() and _dpa_fp8_recipe == "Float8CurrentScaling":
# reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe, and construct a CS+DS recipe
fake_recipes = [
Float8CurrentScaling(
fp8_format=fp8_recipe.fp8_format,
fp8_dpa=fp8_recipe.fp8_dpa,
fp8_mha=fp8_recipe.fp8_mha,
),
fp8_recipe,
]
fp8_recipe_dpa = fake_recipes[1]
fp8_recipes = fake_recipes
elif (
fp8_recipe.float8_current_scaling()
and _dpa_fp8_recipe in ("", "Float8CurrentScaling")
and (fp8_recipe.fp8_dpa or fp8_recipe.fp8_mha)
):
# use fp8_recipe for QKV, O, dO, dQKV, and construct a DS recipe for S, dP
# reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe
fake_recipe = DelayedScaling(
fp8_format=fp8_recipe.fp8_format,
amax_history_len=_dpa_fp8ds_amax_histlen,
amax_compute_algo=_dpa_fp8ds_amax_algo,
fp8_dpa=fp8_recipe.fp8_dpa,
fp8_mha=fp8_recipe.fp8_mha,
reduce_amax=_dpa_fp8ds_reduce_amax,
)
fp8_recipe_dpa = fake_recipe
fp8_recipes = [fp8_recipe, fp8_recipe_dpa]
elif fp8_recipe.nvfp4() and _dpa_fp8_recipe == "Float8CurrentScaling":
# reuse fp8_dpa, fp8_mha from fp8_recipe but not fp8_format
# construct a CS recipe for QKV, O, dO, dQKV and a DS recipe for S, dP
fake_recipes = [
Float8CurrentScaling(
fp8_format=_dpa_fp8_format,
fp8_dpa=fp8_recipe.fp8_dpa,
fp8_mha=fp8_recipe.fp8_mha,
),
DelayedScaling(
fp8_format=_dpa_fp8_format,
amax_history_len=_dpa_fp8ds_amax_histlen,
amax_compute_algo=_dpa_fp8ds_amax_algo,
fp8_dpa=fp8_recipe.fp8_dpa,
fp8_mha=fp8_recipe.fp8_mha,
reduce_amax=_dpa_fp8ds_reduce_amax,
),
]
fp8_recipe_dpa = fake_recipes[1]
fp8_recipes = fake_recipes
# DPA only supports DS and CS; other recipes should have fp8_dpa=False, fp8_mha=False
if not fp8_recipe_dpa.float8_per_tensor_scaling():
assert not (
fp8_recipe_dpa.fp8_dpa or fp8_recipe_dpa.fp8_mha
), f"DotProductAttention does not support {fp8_recipe_dpa.__class__.__name__} recipe"
# reduce over TP+CP groups; expect fp8_group to be set up accordingly, and
# assume attention uses the same fp8_group as the GEMMs
fp8_group = FP8GlobalStateManager.get_fp8_group()
self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
fp8_enabled = self.fp8 or self.fp8_calibration
self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration
if self.fp8_parameters or fp8_enabled:
self.fp8_meta["global_recipe"] = fp8_recipe
self.fp8_meta["local_recipes"] = (
fp8_recipes if isinstance(fp8_recipes, List) else [fp8_recipes]
)
if self.fp8_parameters or fp8_enabled:
if self.fp8_initialized and fp8_recipe_dpa == self.fp8_meta["recipe"]:
# FP8 init has already been run and recipe is the same, don't do anything.
return
self.fp8_meta["recipe"] = fp8_recipe_dpa
if fp8_recipe != fp8_recipe_dpa:
# fp8_recipe has changed, rehash the key.
autocast_key = FP8GlobalStateManager.get_unique_autocast_key(
fp8_recipe_dpa, fp8_group
)
FP8GlobalStateManager.autocast_arguments[autocast_key] = (
fp8_recipe_dpa,
fp8_group,
)
else:
# If fp8 isn't enabled, turn off and return.
self.fp8_initialized = False
return
if self.fp8_parameters and not self.fp8_initialized:
self.fp8_meta["num_gemms"] = num_gemms
self.init_fp8_meta_tensors(fp8_recipes)
if fp8_enabled:
# Set FP8 and other FP8 metadata
self.fp8_meta["num_gemms"] = num_gemms
self.fp8_meta["fp8_group"] = fp8_group
# Set FP8_MAX per tensor according to recipe
self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd
# Allocate scales and amaxes
self.init_fp8_meta_tensors(fp8_recipes)
self.fp8_initialized = True
self.fp8_meta["recipe"] = fp8_recipe_dpa
if fp8_recipe != fp8_recipe_dpa:
# fp8_recipe has changed, rehash the key.
autocast_key = FP8GlobalStateManager.get_unique_autocast_key(
fp8_recipe_dpa, fp8_group
)
FP8GlobalStateManager.autocast_arguments[autocast_key] = (
fp8_recipe_dpa,
fp8_group,
)
_current_recipe = self.fp8_meta["recipe"]
if _original_recipe is not None and not (
issubclass(_current_recipe.__class__, _original_recipe.__class__)
or issubclass(_original_recipe.__class__, _current_recipe.__class__)
):
warnings.warn(
f"Recipe type changed from {_original_recipe.__class__.__name__} "
f"to {_current_recipe.__class__.__name__}. "
"This may affect model behavior."
)
# Clear cached workspaces as they were created with the old recipe/quantizer type
self._fp8_workspaces.clear()
def set_meta_tensor(self, fwd: bool, recipe: Union[Recipe, List[Recipe]]) -> None:
"""Override to allow multiple recipes. Init scales and amaxes for fwd | bwd."""
if isinstance(recipe, Recipe):
recipe = [recipe]
fp8_recipe_dpa = recipe[-1]
fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd"
# Return early if recipe state matches recipe
if self.fp8_meta_tensors_initialized:
recipe_state = self.fp8_meta[fp8_meta_tensor_key]
if fp8_recipe_dpa.delayed() and isinstance(recipe_state, DelayedScalingRecipeState):
self.adjust_amax_history_length(fp8_recipe_dpa.amax_history_len, fwd=fwd)
return
if fp8_recipe_dpa.mxfp8() and isinstance(recipe_state, MXFP8BlockScalingRecipeState):
return
if fp8_recipe_dpa.float8_current_scaling() and isinstance(
recipe_state, Float8CurrentScalingRecipeState
):
return
if fp8_recipe_dpa.float8_block_scaling() and isinstance(
recipe_state, Float8BlockScalingRecipeState
):
return
# When fp8_recipe=Float8CurrentScaling, recipe=[CS, DS], and QKV/dQKV, O/dO use CS quantizers, S/dP use DS quantizers.
# See table above in init_fp8_metadata for more detail.
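# num_gemms = [2, 1]: the CS recipe state allocates the QKV/O (fwd) and dO/dQKV (bwd) quantizers,
# and the DS recipe state allocates the S (fwd) and dP (bwd) quantizers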
num_gemms = [2, 1] if len(recipe) == 2 else [3]
# Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
# 2 (grad_output and grad_input) for bwd
num_fp8_tensors = [x * 3 if fwd else x * 2 for x in num_gemms]
# Initialize recipe state and quantizers
recipe_states = [
RecipeState.create(
recipe[i],
mode=("forward" if fwd else "backward"),
num_quantizers=num_fp8_tensors[i],
)
for i in range(len(recipe))
]
self.fp8_meta[fp8_meta_tensor_key] = (
recipe_states[-1] if len(recipe) == 2 else recipe_states[0]
)
self.quantizers[fp8_meta_tensor_key] = []
for recipe_state in recipe_states:
self.quantizers[fp8_meta_tensor_key].extend(recipe_state.make_quantizers())
@no_torch_dynamo(recursive=False)
def forward(
self,
......@@ -456,6 +788,7 @@ class DotProductAttention(TransformerEngineBaseModule):
fast_zero_fill: bool = True,
inference_params: Optional[InferenceParams] = None,
pad_between_seqs: Optional[bool] = None,
fp8_output: Optional[bool] = False,
) -> torch.Tensor:
"""
Dot Product Attention Layer.
......@@ -628,12 +961,15 @@ class DotProductAttention(TransformerEngineBaseModule):
pad_between_seqs: Optional[bool], default = `None`
If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded.
If true, there are padding tokens between individual sequences in a packed batch.
fp8_output: Optional[bool], default = `False`
Whether to enforce output to be in FP8 or not.
"""
with torch.cuda.device(query_layer.device), self.prepare_forward(
query_layer,
num_gemms=3,
allow_non_contiguous=True,
allow_different_data_and_param_types=self.softmax_type != "vanilla",
) as query_layer:
# checks for RNG
if self.rng_states_tracker is not None and is_graph_capturing():
......@@ -663,6 +999,8 @@ class DotProductAttention(TransformerEngineBaseModule):
tex.DType.kFloat8E4M3,
tex.DType.kFloat8E5M2,
], """DotProductAttention only supports "E4M3" and "E5M2" FP8 data types."""
else:
fp8_output = False
# checks for q/k/v shapes
assert (
......@@ -922,6 +1260,7 @@ class DotProductAttention(TransformerEngineBaseModule):
False
), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes"
# check if there is padding between sequences when qkv_format='thd'
if pad_between_seqs is None:
if qkv_format == "thd":
pad_between_seqs = (
......@@ -957,11 +1296,13 @@ class DotProductAttention(TransformerEngineBaseModule):
pad_between_seqs=pad_between_seqs,
attention_dropout=self.attention_dropout,
context_parallel=context_parallel,
cp_comm_type=self.cp_comm_type,
deterministic=self.deterministic,
is_training=self.training,
fp8=self.fp8,
fp8_meta=self.fp8_meta,
inference_params=inference_params,
softmax_type=self.softmax_type,
)
global _attention_backends
if is_in_onnx_export_mode():
......@@ -1022,6 +1363,12 @@ class DotProductAttention(TransformerEngineBaseModule):
)
# run attention
softmax_offset = (
self.softmax_offset.reshape(1, -1, 1, 1).to(torch.float32)
if self.softmax_offset is not None
else None
)
if use_flash_attention:
if core_attention_bias_type == "alibi":
alibi_slopes, _ = dpa_utils.get_alibi(
......@@ -1053,6 +1400,7 @@ class DotProductAttention(TransformerEngineBaseModule):
quantizers=self.quantizers,
inference_params=inference_params,
flash_attention_backend=flash_attention_backend,
fp8_output=fp8_output,
)
if use_fused_attention:
......@@ -1071,7 +1419,6 @@ class DotProductAttention(TransformerEngineBaseModule):
bias_dtype=query_layer.dtype,
bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
)
# checkpoint_core_attention=False
if checkpoint_core_attention:
return self._checkpointed_attention_forward(
self.fused_attention,
......@@ -1101,6 +1448,8 @@ class DotProductAttention(TransformerEngineBaseModule):
quantizers=self.quantizers,
pad_between_seqs=pad_between_seqs,
inference_params=inference_params,
softmax_offset=softmax_offset,
fp8_output=fp8_output,
)
return self.fused_attention(
query_layer,
......@@ -1129,6 +1478,8 @@ class DotProductAttention(TransformerEngineBaseModule):
quantizers=self.quantizers,
pad_between_seqs=pad_between_seqs,
inference_params=inference_params,
softmax_offset=softmax_offset,
fp8_output=fp8_output,
)
from transformer_engine.pytorch.cpu_offload import CPUOffloadEnabled
......@@ -1140,6 +1491,7 @@ class DotProductAttention(TransformerEngineBaseModule):
)
if use_unfused_attention:
allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1"
if checkpoint_core_attention:
return self._checkpointed_attention_forward(
self.unfused_attention,
......@@ -1157,6 +1509,11 @@ class DotProductAttention(TransformerEngineBaseModule):
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
inference_params=inference_params,
softmax_offset=softmax_offset,
fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa and allow_emulation,
fp8_meta=self.fp8_meta,
quantizers=self.quantizers,
fp8_output=fp8_output,
)
return self.unfused_attention(
_alibi_cache,
......@@ -1173,5 +1530,10 @@ class DotProductAttention(TransformerEngineBaseModule):
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
inference_params=inference_params,
softmax_offset=softmax_offset,
fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa and allow_emulation,
fp8_meta=self.fp8_meta,
quantizers=self.quantizers,
fp8_output=fp8_output,
)
return None
......@@ -17,6 +17,7 @@ import numpy as np
from packaging.version import Version as PkgVersion
import torch
import torch.distributed as dist
import torch.nn.functional as F
import transformer_engine_torch as tex
import transformer_engine as te
......@@ -24,6 +25,7 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
QKVLayout,
AttnBiasType,
AttnMaskType,
SoftmaxType,
FusedAttnBackend,
META_QKV,
META_DQKV,
......@@ -31,18 +33,22 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
META_DO,
META_S,
META_DP,
META_O_CP,
META_DQKV_CP,
)
from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.fp8 import get_fp8_te_dtype
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Tensor,
Float8Quantizer,
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.quantization import get_fp8_te_dtype
from transformer_engine.pytorch.constants import TE_DType
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from transformer_engine.pytorch.utils import (
get_device_compute_capability,
get_cudnn_version,
SplitAlongDim,
combine_tensors,
)
from transformer_engine.pytorch.export import is_in_onnx_export_mode
......@@ -53,6 +59,9 @@ _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
# print quantizer info for a particular layer on a particular rank
_print_layer = int(os.getenv("NVTE_PRINT_LAYER_NUMBER", "1"))
_print_rank = int(os.getenv("NVTE_PRINT_RANK", "0"))
_cu_seqlens_cache = {}
......@@ -206,16 +215,20 @@ class AttentionParams:
Attention dropout.
context_parallel: bool, default = `False`
Whether context parallelism is used or not.
cp_comm_type: str, default = "p2p"
The communication type of context parallelism.
deterministic: bool, default = `False`
Whether to run `DotProductAttention` with determinism or not.
is_training: bool, default = `True`
Whether in training mode (`True`) or inference mode (`False`).
fp8: bool, default = `False`
Whether `DotProductAttention` is in an `fp8_autocast` region.
Whether `DotProductAttention` is in an `autocast` region.
fp8_meta: Optional[Dict[str Any]], default = `None`
The FP8 metadata tensor of `DotProductAttention`.
inference_params: Optional[InferenceParams], default = `None`
Inference-related parameters. See InferenceParams for details.
softmax_type: str, default = "vanilla"
The type of softmax operation. See DotProductAttention for details.
"""
qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor
......@@ -237,11 +250,13 @@ class AttentionParams:
pad_between_seqs: bool = False
attention_dropout: float = 0.0
context_parallel: bool = False
cp_comm_type: str = "p2p"
deterministic: bool = False
is_training: bool = True
fp8: bool = False
fp8_meta: Union[Dict[str, Any], None] = None
inference_params: Optional[InferenceParams] = None
softmax_type: str = "vanilla"
def __eq__(self, other):
"""
......@@ -308,11 +323,13 @@ def get_attention_backend(
pad_between_seqs = attention_params.pad_between_seqs
attention_dropout = attention_params.attention_dropout
context_parallel = attention_params.context_parallel
cp_comm_type = attention_params.cp_comm_type
deterministic = attention_params.deterministic
is_training = attention_params.is_training
fp8 = attention_params.fp8
fp8_meta = attention_params.fp8_meta
inference_params = attention_params.inference_params
softmax_type = attention_params.softmax_type
# Run config
logger = logging.getLogger("DotProductAttention")
......@@ -341,8 +358,31 @@ def get_attention_backend(
field.name: getattr(attention_params, field.name) for field in fields(attention_params)
}
run_config.update(attention_params_dict)
# Add FP8 environment variables to config
if fp8:
# all FP8 recipes: 1: (FP8 fwd, FP8 bwd), 0: (FP8 fwd, F16 bwd)
run_config["NVTE_FP8_DPA_BWD"] = int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
# Float8CurrentScaling: 1: use F16 O in bwd, 0: use FP8 O in bwd
run_config["NVTE_DPA_FP8CS_O_in_F16"] = int(os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1"))
# switch recipe to "F16", "DelayedScaling", or "Float8CurrentScaling"
_dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "")
run_config["NVTE_DPA_FP8_RECIPE"] = _dpa_fp8_recipe
if _dpa_fp8_recipe != "":
# configure the new recipe if it has been switched
run_config["NVTE_DPA_FP8_FORMAT"] = os.getenv("NVTE_DPA_FP8_FORMAT", "HYBRID")
run_config["NVTE_DPA_FP8DS_AMAX_ALGO"] = os.getenv(
"NVTE_DPA_FP8DS_AMAX_ALGO", "most_recent"
)
run_config["NVTE_DPA_FP8DS_AMAX_HISTLEN"] = int(
os.getenv("NVTE_DPA_FP8DS_AMAX_HISTLEN", "1")
)
run_config["NVTE_DPA_FP8DS_REDUCE_AMAX"] = int(
os.getenv("NVTE_DPA_FP8DS_REDUCE_AMAX", "1")
)
# UnfusedDotProductAttention: 1: allow FP8 emulation, 0: do not allow
run_config["NVTE_UnfusedDPA_Emulate_FP8"] = int(
os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0")
)
logger.debug("Running with config=%s", run_config)
# The following sections check if `FlashAttention` supports the provided attention params,
......@@ -422,8 +462,20 @@ def get_attention_backend(
logger.debug("Disabling FlashAttention 3 for FP8 training")
use_flash_attention_3 = False
if use_unfused_attention:
allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1"
if not allow_emulation:
logger.debug("Disabling UnfusedDotProductAttention for FP8 attention")
use_unfused_attention = False
fp8_recipe = fp8_meta["recipe"]
if fp8_meta.get("local_recipes", None) is not None:
fp8_recipe = fp8_meta["local_recipes"][0]
if use_fused_attention and fp8_recipe.float8_current_scaling():
if device_compute_capability < (10, 0):
logger.debug("Disabling FusedAttention for FP8 current scaling on arch < sm100")
use_fused_attention = False
elif cudnn_version < (9, 14, 0):
logger.debug("Disabling FusedAttention for FP8 current scaling with cuDNN < 9.14.0")
use_fused_attention = False
# TODO: rocm fused attention backends does not support fp8 yet
if IS_HIP_EXTENSION and use_fused_attention:
logger.debug("Disabling ROCm FusedAttention as it does not support FP8")
......@@ -581,6 +633,51 @@ def get_attention_backend(
logger.debug("Disabling FlashAttention 3 for dropout")
use_flash_attention_3 = False
# Filter: Softmax type
# context_parallel | softmax_type | supported backends
# ----------------------------------------------------------------------------------------------------
# no | vanilla | All
# no | off-by-one | FusedAttention, UnfusedDotProductAttention
# no | learnable | FusedAttention, UnfusedDotProductAttention
# yes | vanilla | FusedAttention, FlashAttention
# yes | off-by-one | FusedAttention
# yes | learnable | FusedAttention
if softmax_type != "vanilla":
logger.debug("Disabling FlashAttention for softmax_type = %s", softmax_type)
use_flash_attention = False
if fp8 and fp8_meta["recipe"].fp8_dpa:
logger.debug("Disabling FusedAttention for softmax_type = %s in FP8", softmax_type)
use_fused_attention = False
logger.debug(
"Disabling UnfusedDotProductAttention for softmax_type = %s in FP8", softmax_type
)
use_unfused_attention = False
if qkv_format == "thd":
logger.debug(
"Disabling FusedAttention for softmax_type = %s and qkv_format = thd", softmax_type
)
use_fused_attention = False
logger.debug(
"Disabling UnfusedDotProductAttention for softmax_type = %s and qkv_format = thd",
softmax_type,
)
use_unfused_attention = False
if context_parallel:
logger.debug(
"Disabling UnfusedDotProductAttention for context parallelism with softmax_type"
" = %s",
softmax_type,
)
use_unfused_attention = False
if cp_comm_type != "a2a":
logger.debug(
"Disabling FusedAttention for context parallelism with softmax_type = %s and"
" cp_comm_type = %s",
softmax_type,
cp_comm_type,
)
use_fused_attention = False
# Filter: Context parallelism
# qkv_format | attn_mask_type | attn_bias_type | supported backends
# ----------------------------------------------------------------------------------------------------
......@@ -822,6 +919,7 @@ def get_attention_backend(
QKVLayout[qkv_layout],
AttnBiasType[fu_core_attention_bias_type],
AttnMaskType[attn_mask_type],
SoftmaxType[softmax_type],
attention_dropout,
num_heads,
num_gqa_groups,
......@@ -1836,11 +1934,10 @@ def check_set_window_size(
return window_size
def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False):
def get_attention_quantizers(fp8, quantizers):
"""Get the list of quantizers used in attention from the quantizers list."""
if not fp8:
num_of_nones = 8 if cp_specific_quantizers else 6
return [None] * num_of_nones
return [None] * 6
QKV_quantizer = quantizers["scaling_fwd"][META_QKV]
QKV_quantizer.internal = True
QKV_quantizer.set_usage(rowwise=True, columnwise=False)
......@@ -1849,6 +1946,7 @@ def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False):
S_quantizer = quantizers["scaling_fwd"][META_S]
S_quantizer.internal = True
S_quantizer.set_usage(rowwise=True, columnwise=False)
dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV]
dQKV_quantizer.internal = True
dQKV_quantizer.set_usage(rowwise=True, columnwise=False)
......@@ -1858,22 +1956,158 @@ def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False):
dP_quantizer = quantizers["scaling_bwd"][META_DP]
dP_quantizer.set_usage(rowwise=True, columnwise=False)
dP_quantizer.internal = True
dQKV_CP_quantizer = quantizers["scaling_bwd"][META_DQKV_CP]
dQKV_CP_quantizer.set_usage(rowwise=True, columnwise=False)
dQKV_CP_quantizer.internal = True
O_CP_quantizer = quantizers["scaling_fwd"][META_O_CP]
O_CP_quantizer.set_usage(rowwise=True, columnwise=False)
if cp_specific_quantizers:
return (
return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer
def print_quantizers(
label,
layer_number,
QKV_quantizer,
O_quantizer,
O_CP_quantizer,
S_quantizer,
dQKV_quantizer,
dQKV_CP_quantizer,
dO_quantizer,
dP_quantizer,
):
"""Print the type and scale/amax of attention quantizers"""
_to_print = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL == 2
if (
_to_print
and _print_layer == layer_number
and (
not dist.is_initialized() or (dist.is_initialized() and dist.get_rank() == _print_rank)
)
):
names = [
"QKV_quantizer",
"S_quantizer",
"O_quantizer",
"dO_quantizer",
"dP_quantizer",
"dQKV_quantizer",
]
quantizers = [
QKV_quantizer,
S_quantizer,
O_quantizer,
dO_quantizer,
dP_quantizer,
dQKV_quantizer,
]
if "forward" in label:
names = names[:3]
quantizers = quantizers[:3]
if "backward" in label:
names = names[3:]
quantizers = quantizers[3:]
for i, q in enumerate(quantizers):
type_str = ""
if q is None:
type_str = "None"
elif isinstance(q, Float8Quantizer):
type_str = "DS"
elif isinstance(q, Float8CurrentScalingQuantizer):
type_str = "CS"
print(
f"{label} >> {names[i]:14s}: {type_str}, {q.scale.item():.4e} x"
f" {q.amax.item():.4e} = {q.scale.item()*q.amax.item():.4e}"
)
return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer
def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer):
"""Combine q,k,v based on qkv_layout and quantize them together"""
# 1: qkv packed, 2: kv packed, 3: qkv separate
qkv_layout = qkv_layout.replace("paged_kv_", "")
qkv_group = len(qkv_layout.split("_"))
src_nominal_dtype = q.dtype
match qkv_group:
case 1:
dim = qkv_layout.find("3")
qkv = combine_tensors([q, k, v], dim)
qkv_fp8 = qkv_quantizer(qkv)
q_data, k_data, v_data = SplitAlongDim.apply(qkv_fp8._data, dim, [1, 1, 1], True)
case 2:
dim = qkv_layout.split("_")[1].find("2")
kv = combine_tensors([k, v], dim)
tensors = [q, kv]
num_tensors = len(tensors)
shapes = [x.shape for x in tensors]
numels = [x.numel() for x in tensors]
numels = [sum(numels[:i]) for i in range(num_tensors + 1)]
qkv = torch.cat([x.view(-1) for x in tensors], dim=0)
qkv_fp8 = qkv_quantizer(qkv)
q_data, kv_data = [
qkv_fp8._data[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors)
]
k_data, v_data = SplitAlongDim.apply(kv_data, dim, [1, 1], True)
case 3:
tensors = [q, k, v]
num_tensors = len(tensors)
shapes = [x.shape for x in tensors]
numels = [x.numel() for x in tensors]
numels = [sum(numels[:i]) for i in range(num_tensors + 1)]
qkv = torch.cat([x.view(-1) for x in tensors], dim=0)
qkv_fp8 = qkv_quantizer(qkv)
q_data, k_data, v_data = [
qkv_fp8._data[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors)
]
case _:
raise RuntimeError("Invalid qkv_layout " + qkv_layout)
q_fp8, k_fp8, v_fp8 = [
Float8Tensor.make_like(qkv_fp8, data=x, dtype=src_nominal_dtype)
for x in [q_data, k_data, v_data]
]
return q_fp8, k_fp8, v_fp8
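For reference, a small sketch of the layout grouping used above, with a few illustrative qkv_layout strings (only "sbh3d" is the documented default elsewhere in this change; the other strings are assumed examples of the packed-kv and fully separate layouts):
    # qkv_group = number of '_'-separated segments after stripping an optional "paged_kv_" prefix
    for layout in ("sbh3d", "sbhd_sbh2d", "sbhd_sbhd_sbhd"):
        qkv_group = len(layout.replace("paged_kv_", "").split("_"))
        print(layout, "->", qkv_group)  # 1: qkv packed, 2: kv packed, 3: q, k, v separate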
def combine_and_dequantize(
qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=None, des_nominal_dtype=None
):
"""Combine q,k,v based on qkv_layout and dequantize them together"""
# 1: qkv packed, 2: kv packed, 3: qkv separate
qkv_layout = qkv_layout.replace("paged_kv_", "")
qkv_group = len(qkv_layout.split("_"))
if all(isinstance(x, Float8Tensor) for x in [q_fp8, k_fp8, v_fp8]):
src_nominal_dtype = q_fp8.dtype
else:
assert src_nominal_dtype is not None, "The nominal dtype of input tensors is required!"
if des_nominal_dtype is None:
des_nominal_dtype = src_nominal_dtype
q_data, k_data, v_data = [x._data for x in [q_fp8, k_fp8, v_fp8]]
match qkv_group:
case 1:
dim = qkv_layout.find("3")
qkv_data = combine_tensors([q_data, k_data, v_data], dim)
qkv_fp8 = Float8Tensor.make_like(q_fp8, data=qkv_data)
qkv = qkv_fp8.dequantize(dtype=des_nominal_dtype)
q, k, v = SplitAlongDim.apply(qkv, dim, [1, 1, 1], True)
case 2:
dim = qkv_layout.split("_")[1].find("2")
kv_data = combine_tensors([k_data, v_data], dim)
tensors = [q_data, kv_data]
num_tensors = len(tensors)
shapes = [x.shape for x in tensors]
numels = [x.numel() for x in tensors]
numels = [sum(numels[:i]) for i in range(num_tensors + 1)]
qkv_data = torch.cat([x.reshape(-1) for x in tensors], dim=0)
qkv_fp8 = Float8Tensor.make_like(q_fp8, data=qkv_data, dtype=src_nominal_dtype)
qkv = qkv_fp8.dequantize(dtype=des_nominal_dtype)
q, kv = [qkv[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors)]
k, v = SplitAlongDim.apply(kv, dim, [1, 1], True)
case 3:
tensors = [q_data, k_data, v_data]
num_tensors = len(tensors)
shapes = [x.shape for x in tensors]
numels = [x.numel() for x in tensors]
numels = [sum(numels[:i]) for i in range(num_tensors + 1)]
qkv_data = torch.cat([x.contiguous().reshape(-1) for x in tensors], dim=0)
qkv_fp8 = Float8Tensor.make_like(q_fp8, data=qkv_data, dtype=src_nominal_dtype)
qkv = qkv_fp8.dequantize(dtype=des_nominal_dtype)
q, k, v = [qkv[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors)]
case _:
raise RuntimeError("Invalid qkv_layout " + qkv_layout)
return q, k, v
......@@ -3,13 +3,14 @@
# See LICENSE for license information.
"""Multi-head Attention."""
import os
import collections
from typing import Callable, List, Optional, Tuple, Union
import torch
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.quantization import FP8GlobalStateManager
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.module import LayerNormLinear, Linear, RMSNorm, LayerNorm
from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization
......@@ -31,7 +32,13 @@ from transformer_engine.pytorch.distributed import (
from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention
from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
# Force DotProductAttention to use a different recipe than the fp8_recipe set in autocast().
# Useful when GEMMs and attention use different recipes. Supported values are "DelayedScaling"
# and "Float8CurrentScaling". Use other relevant variables here to define the recipe, e.g. fp8_dpa.
_dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "")
_dpa_fp8_recipe_dpa = os.getenv("NVTE_DPA_FP8_RECIPE_DPA", "0") == "1"
_dpa_fp8_recipe_mha = os.getenv("NVTE_DPA_FP8_RECIPE_MHA", "0") == "1"
class MultiheadAttention(torch.nn.Module):
......@@ -135,6 +142,17 @@ class MultiheadAttention(torch.nn.Module):
For that, please use `get_qkv_layout` to gain the layout information.
name: str, default = `None`
name of the module, currently used for debugging purposes.
softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
softmax type as described in this paper:
`Efficient Streaming Language Models with Attention Sinks
<https://arxiv.org/pdf/2309.17453v3>`_.
For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv],
'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
where alpha is a learnable parameter of shape [h].
'off-by-one' and 'learnable' softmax types are also called sink attention
('zero sink' and 'learnable sink').
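A minimal PyTorch sketch of the three variants for a score tensor of shape [b, h, s_q, s_kv] (illustrative only and not numerically stable; the fused and flash backends never materialize S like this):
    import torch

    def sink_softmax(scores, softmax_type="vanilla", alpha=None):
        # scores: [b, h, s_q, s_kv]; alpha: learnable per-head offset of shape [h]
        exp_scores = torch.exp(scores)
        denom = exp_scores.sum(dim=-1, keepdim=True)
        if softmax_type == "off-by-one":
            denom = denom + 1.0
        elif softmax_type == "learnable":
            denom = denom + torch.exp(alpha).reshape(1, -1, 1, 1)
        return exp_scores / denom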
Parallelism parameters
----------------------
......@@ -245,6 +263,7 @@ class MultiheadAttention(torch.nn.Module):
qk_norm_before_rope: bool = False,
seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None,
softmax_type: str = "vanilla",
) -> None:
super().__init__()
......@@ -262,6 +281,7 @@ class MultiheadAttention(torch.nn.Module):
self.return_bias = return_bias
self.cp_size = 1
self.cp_rank = 0
self.softmax_type = softmax_type
kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads)
......@@ -416,6 +436,7 @@ class MultiheadAttention(torch.nn.Module):
tp_group=tp_group,
layer_number=self.layer_number,
attention_type=self.attention_type,
softmax_type=self.softmax_type,
)
# Linear
......@@ -556,10 +577,12 @@ class MultiheadAttention(torch.nn.Module):
self.cp_size = get_distributed_world_size(cp_group)
self.cp_rank = get_distributed_rank(cp_group)
elif isinstance(cp_group, list):
assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!"
assert (
cp_comm_type == "a2a+p2p"
), "Only cp_comm_type of a2a+p2p requires hierarchical CP groups!"
assert (
len(cp_group) == 2
), "cp_comm_type = a2a+p2p requires cp_group = [a2a_cp_group, p2p_cp_group]!"
cp_size_a2a = get_distributed_world_size(cp_group[0])
cp_rank_a2a = get_distributed_rank(cp_group[0])
cp_size_p2p = get_distributed_world_size(cp_group[1])
......@@ -716,10 +739,22 @@ class MultiheadAttention(torch.nn.Module):
# Query, Key, and Value
# ======================
fp8_mha = (
FP8GlobalStateManager.is_fp8_enabled()
and FP8GlobalStateManager.get_fp8_recipe().fp8_mha
)
fp8 = FP8GlobalStateManager.is_fp8_enabled()
if _dpa_fp8_recipe == "":
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
fp8_dpa = fp8_recipe.fp8_dpa
fp8_mha = fp8_recipe.fp8_mha
float8_current_scaling = fp8_recipe.float8_current_scaling()
else:
fp8_dpa = _dpa_fp8_recipe_dpa
fp8_mha = _dpa_fp8_recipe_mha
float8_current_scaling = _dpa_fp8_recipe == "Float8CurrentScaling"
# QKV Gemm: do not produce FP8 output when in Float8CurrentScaling recipe
qkv_fp8_output = fp8 and fp8_mha and rotary_pos_emb is None and not float8_current_scaling
# DPA: always produce FP8 output when fp8=True to take advantage of the O amax
dpa_fp8_output = fp8 and (fp8_dpa or fp8_mha)
# Proj Gemm: match DPA output except for Float8CurrentScaling
proj_fp8_grad = dpa_fp8_output and not float8_current_scaling
layernorm_output = None
if self.attention_type == "self":
......@@ -728,7 +763,7 @@ class MultiheadAttention(torch.nn.Module):
layernorm_qkv_outputs = self.layernorm_qkv(
hidden_states,
is_first_microbatch=is_first_microbatch,
fp8_output=fp8_mha and rotary_pos_emb is None,
fp8_output=qkv_fp8_output,
)
if self.return_layernorm_output:
mixed_x_layer, layernorm_output = layernorm_qkv_outputs
......@@ -738,7 +773,7 @@ class MultiheadAttention(torch.nn.Module):
mixed_x_layer = self.qkv(
hidden_states,
is_first_microbatch=is_first_microbatch,
fp8_output=fp8_mha and rotary_pos_emb is None,
fp8_output=qkv_fp8_output,
)
num_queries_per_key_value = (
......@@ -792,7 +827,7 @@ class MultiheadAttention(torch.nn.Module):
mixed_kv_layer = self.key_value(
encoder_output,
is_first_microbatch=is_first_microbatch,
fp8_output=fp8_mha and rotary_pos_emb is None,
fp8_output=qkv_fp8_output,
)
if self.qkv_weight_interleaved:
......@@ -847,7 +882,7 @@ class MultiheadAttention(torch.nn.Module):
layernorm_query_outputs = self.layernorm_query(
hidden_states,
is_first_microbatch=is_first_microbatch,
fp8_output=fp8_mha and rotary_pos_emb is None,
fp8_output=qkv_fp8_output,
)
if self.return_layernorm_output:
query_layer, layernorm_output = layernorm_query_outputs
......@@ -857,7 +892,7 @@ class MultiheadAttention(torch.nn.Module):
query_layer = self.query_layer(
hidden_states,
is_first_microbatch=is_first_microbatch,
fp8_output=fp8_mha and rotary_pos_emb is None,
fp8_output=qkv_fp8_output,
)
# [sq, b, hp] --> [sq, b, np, hn]
......@@ -958,6 +993,7 @@ class MultiheadAttention(torch.nn.Module):
fast_zero_fill=fast_zero_fill,
inference_params=inference_params,
pad_between_seqs=pad_between_seqs,
fp8_output=dpa_fp8_output,
)
# ===================
......@@ -966,7 +1002,7 @@ class MultiheadAttention(torch.nn.Module):
projection_output = self.proj(
context_layer,
is_first_microbatch=is_first_microbatch,
fp8_grad=isinstance(context_layer, QuantizedTensor),
fp8_grad=proj_fp8_grad,
)
if self.return_bias:
......
......@@ -66,6 +66,9 @@ class RotaryPositionEmbedding(torch.nn.Module):
"""
Create rotary position embedding frequencies.
This function is particularly sensitive to the use of mixed precision, so we disable the
autocast context if it is enabled.
Parameters
----------
max_seq_len: int
......@@ -73,6 +76,7 @@ class RotaryPositionEmbedding(torch.nn.Module):
offset: int, default = 0
Fixed offset for frequencies.
"""
with torch.autocast(enabled=False, device_type="cuda"):
seq = (
torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+ offset
......
......@@ -91,3 +91,5 @@ GemmParallelModes = ("row", "column", None)
dist_group_type = torch.distributed.ProcessGroup
MXFP8_BLOCK_SCALING_SIZE = 32
NVFP4_BLOCK_SCALING_SIZE = 16
......@@ -12,6 +12,7 @@ from transformer_engine_torch import (
NVTE_QKV_Format,
NVTE_Bias_Type,
NVTE_Mask_Type,
NVTE_Softmax_Type,
NVTE_Fused_Attn_Backend,
)
from ..tensor.quantized_tensor import Quantizer
......@@ -86,6 +87,12 @@ AttnMaskType = {
"padding_causal_bottom_right": NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
}
SoftmaxType = {
"vanilla": NVTE_Softmax_Type.NVTE_VANILLA_SOFTMAX,
"off-by-one": NVTE_Softmax_Type.NVTE_OFF_BY_ONE_SOFTMAX,
"learnable": NVTE_Softmax_Type.NVTE_LEARNABLE_SOFTMAX,
}
FusedAttnBackend = {
"F16_max512_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen,
"F16_arbitrary_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
......@@ -102,9 +109,6 @@ META_O = tex.FP8FwdTensors.GEMM2_INPUT
META_DO = tex.FP8BwdTensors.GRAD_INPUT2
META_S = tex.FP8FwdTensors.GEMM3_OUTPUT
META_DP = tex.FP8BwdTensors.GRAD_INPUT3
# repurpose some unused amax history buffers for partial results of CP fwd and bwd
META_O_CP = tex.FP8FwdTensors.GEMM2_OUTPUT
META_DQKV_CP = tex.FP8BwdTensors.GRAD_INPUT1
def fused_attn_fwd(
......@@ -131,8 +135,10 @@ def fused_attn_fwd(
qkv_layout: str = "sbh3d",
attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding",
softmax_type: str = "vanilla",
window_size: Tuple[int, int] = (-1, -1),
rng_gen: torch.Generator = None,
softmax_offset: torch.Tensor = None,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention FWD for separate QKV input.
......@@ -197,6 +203,8 @@ def fused_attn_fwd(
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
softmax_type: str, default = "vanilla"
type of the attention softmax; {"vanilla", "off-by-one", "learnable"}
window_size: Tuple[int, int], default = (-1, -1)
sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
......@@ -205,6 +213,9 @@ def fused_attn_fwd(
rng_gen: torch.Generator, default = None
random number generator;
if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
softmax_offset: torch.Tensor, default = None
softmax offset tensor of shape [1, h_q, 1, 1].
See softmax_type in DotProductAttention for details.
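For the 'learnable' case this is typically the per-head parameter broadcast to that shape, e.g. (a sketch; `alpha` and `num_heads` are hypothetical names):
    alpha = torch.zeros(num_heads)  # learnable per-head offset, shape [h_q]
    softmax_offset = alpha.reshape(1, -1, 1, 1).to(torch.float32)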
Returns
----------
......@@ -286,6 +297,7 @@ def fused_attn_fwd(
QKVLayout[qkv_layout],
AttnBiasType[attn_bias_type],
AttnMaskType[attn_mask_type],
SoftmaxType[softmax_type],
window_size,
cu_seqlens_q,
cu_seqlens_kv,
......@@ -300,6 +312,7 @@ def fused_attn_fwd(
s_quantizer,
o_quantizer,
attn_bias,
softmax_offset,
rng_gen,
rng_elts_per_thread,
)
......@@ -333,6 +346,7 @@ def fused_attn_bwd(
qkv_layout: str = "sbh3d",
attn_bias_type: str = "no_bias",
attn_mask_type: str = "padding",
softmax_type: str = "vanilla",
window_size: Tuple[int, int] = (-1, -1),
deterministic: bool = False,
) -> Tuple[Union[torch.Tensor, None], ...]:
......@@ -398,6 +412,8 @@ def fused_attn_bwd(
type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"}
attn_mask_type: str, default = "padding"
type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"}
softmax_type: str, default = "vanilla"
type of the attention softmax; {"vanilla", "off-by-one", "learnable"}
window_size: Tuple[int, int], default = (-1, -1)
sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
......@@ -417,6 +433,9 @@ def fused_attn_bwd(
d_bias: torch.Tensor, optional
gradient tensor of Bias when attn_bias_type is "pre_scale_bias"
or "post_scale_bias"; same data type and shape as Bias
d_softmax_offset: torch.Tensor, optional
gradient tensor of the softmax offset, of shape [1, h_q, 1, 1].
See softmax_type in DotProductAttention for details.
"""
if attn_scale is None:
d = q.size(-1)
......@@ -454,6 +473,7 @@ def fused_attn_bwd(
QKVLayout[qkv_layout],
AttnBiasType[attn_bias_type],
AttnMaskType[attn_mask_type],
SoftmaxType[softmax_type],
window_size,
deterministic,
cu_seqlens_q,
......
......@@ -19,8 +19,10 @@ from ..utils import get_sm_count, _empty_tensor
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul, w8a8_block_int8_matmul_batched
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad, w8a8_block_int8_matmul_wgrad_batched
from ..tensor.quantized_tensor import Quantizer
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from ..tensor.utils import is_experimental
from ..experimental.gemm import experimental_gemm
from ..tensor.storage.float8_tensor_storage import Float8TensorStorage
from ...debug.pytorch.debug_quantization import DebugQuantizer
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
from transformer_engine.pytorch.triton.per_token_group_quant import (per_token_quant_fp8_to_int8,
......@@ -169,6 +171,24 @@ def general_gemm(
if not out.is_contiguous():
raise ValueError("Output tensor is not contiguous.")
# If A or B are experimental tensors -> dispatch to quantizers's qgemm implementation
if is_experimental(A) or is_experimental(B):
return experimental_gemm(
A,
B,
workspace,
out_dtype,
quantization_params,
gelu,
gelu_in,
accumulate,
layout,
out,
bias,
use_split_accumulator,
grad,
)
debug_quantizer = None
if isinstance(quantization_params, DebugQuantizer):
debug_quantizer = quantization_params
......@@ -179,9 +199,9 @@ def general_gemm(
# Use bfloat16 as default bias_dtype
bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype]
if isinstance(A, Float8BlockwiseQTensorBase) or isinstance(B, Float8BlockwiseQTensorBase):
if isinstance(A, Float8BlockwiseQTensorStorage) or isinstance(B, Float8BlockwiseQTensorStorage):
# There is no use_split_accumulator == False
# implementation for Float8BlockwiseQTensorBase GEMM
# implementation for Float8BlockwiseQTensorStorage GEMM
use_split_accumulator = True
# Check that data format is supported
......@@ -191,7 +211,7 @@ def general_gemm(
):
raise RuntimeError("GEMM with Float8BlockwiseQTensor requires GEMM_READY format")
if int8_simulation_fp8 and (isinstance(A, Float8BlockwiseQTensorBase) or isinstance(B, Float8BlockwiseQTensorBase)):
if int8_simulation_fp8 and (isinstance(A, Float8BlockwiseQTensorStorage) or isinstance(B, Float8BlockwiseQTensorStorage)):
assert not gelu, "GELU not supported with int8 simulation"
assert gelu_in is None, "GELU input not supported with int8 simulation"
assert bias is None, "Bias not supported with int8 simulation"
......@@ -210,7 +230,7 @@ def general_gemm(
)
return y, None, None, None
if int8_simulation_fp8 and (isinstance(A, Float8TensorBase) or isinstance(B, Float8TensorBase)) and int8_simulation_fp8_tensorwise:
if int8_simulation_fp8 and (isinstance(A, Float8TensorStorage) or isinstance(B, Float8TensorStorage)) and int8_simulation_fp8_tensorwise:
assert not gelu, "GELU not supported with int8 simulation"
assert gelu_in is None, "GELU input not supported with int8 simulation"
assert ub is None, "User buffer not supported with int8 simulation"
......@@ -251,7 +271,7 @@ def general_gemm(
return out, bias_grad, gelu_input, extra_output
if int8_simulation_fp8 and (isinstance(A, Float8TensorBase) or isinstance(B, Float8TensorBase)):
if int8_simulation_fp8 and (isinstance(A, Float8TensorStorage) or isinstance(B, Float8TensorStorage)):
assert not gelu, "GELU not supported with int8 simulation"
assert gelu_in is None, "GELU input not supported with int8 simulation"
assert bias is None, "Bias not supported with int8 simulation"
......@@ -440,7 +460,7 @@ def general_grouped_gemm(
for o in out
] # this should differ with respect to single output
if int8_simulation_fp8 and (isinstance(A[0], Float8BlockwiseQTensorBase) or isinstance(B[0], Float8BlockwiseQTensorBase)):
if int8_simulation_fp8 and (isinstance(A[0], Float8BlockwiseQTensorStorage) or isinstance(B[0], Float8BlockwiseQTensorStorage)):
assert len(set(m_splits)) == 1, "Int8 simulation groupgemm just surpport token pad as same as batchgemm for now."
assert not gelu, "GELU not supported with int8 simulation groupgemm."
assert not use_bias, "Bias not supported with int8 simulation groupgemm."
......@@ -502,7 +522,7 @@ def general_grouped_gemm(
raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8")
if int8_simulation_fp8 and (isinstance(A[0], Float8TensorBase) or isinstance(B[0], Float8TensorBase)) and int8_simulation_fp8_tensorwise and int8_simulation_fp8_tensorwise_batched:
if int8_simulation_fp8 and (isinstance(A[0], Float8TensorStorage) or isinstance(B[0], Float8TensorStorage)) and int8_simulation_fp8_tensorwise and int8_simulation_fp8_tensorwise_batched:
assert len(set(m_splits)) == 1, "Need token pad as same as batchgemm for NVTE_INT8_SIM_FP8_TENSORWISE_BATCHED."
assert not gelu, "GELU not supported with int8 simulation groupgemm."
assert not use_bias, "Bias not supported with int8 simulation groupgemm."
......@@ -616,7 +636,7 @@ def general_grouped_gemm(
return out, bias, gelu_input
if int8_simulation_fp8 and (isinstance(A[0], Float8TensorBase) or isinstance(B[0], Float8TensorBase)) and int8_simulation_fp8_tensorwise:
if int8_simulation_fp8 and (isinstance(A[0], Float8TensorStorage) or isinstance(B[0], Float8TensorStorage)) and int8_simulation_fp8_tensorwise:
assert not gelu, "GELU not supported with int8 simulation groupgemm."
assert TE_DType_To_Torch[out_dtype] is torch.bfloat16 or TE_DType_To_Torch[out_dtype] is torch.float32, "Out_dtype must be bfloat16 or float32 for int8 simulation"
bias = tex.te_general_grouped_gemm(
......@@ -642,7 +662,7 @@ def general_grouped_gemm(
return out, bias, gelu_input
if int8_simulation_fp8 and (isinstance(A[0], Float8TensorBase) or isinstance(B[0], Float8TensorBase)):
if int8_simulation_fp8 and (isinstance(A[0], Float8TensorStorage) or isinstance(B[0], Float8TensorStorage)):
assert len(set(m_splits)) == 1, "Int8 simulation groupgemm just surpport token pad as same as batchgemm for now."
assert not gelu, "GELU not supported with int8 simulation groupgemm."
assert not use_bias, "Bias not supported with int8 simulation groupgemm."
......
......@@ -10,12 +10,13 @@ from typing import Any, Dict, Optional
import torch
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from .tensor.quantized_tensor import QuantizedTensorBase
from .tensor.quantized_tensor import QuantizedTensorStorage
from .tensor.float8_tensor import Float8Tensor
__all__ = ["get_cpu_offload_context"]
CPUOffloadEnabled = False
CPUOffloadedLayer = False
def get_cpu_offloading():
global CPUOffloadEnabled
......@@ -42,7 +43,7 @@ def mark_activation_offload(*tensors):
if tensor is not None:
tensor.activation_offloading = True
# This is a hack to force clear the tensor after it is offloaded.
# It is needed, because .*TensorBase classes are saved in the ctx,
# It is needed, because .*TensorStorage classes are saved in the ctx,
# and they contain the reference to their data tensors.
tensor.needs_force_clear = True
......@@ -361,6 +362,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
self.h2d_stream = torch.cuda.Stream()
def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:
global CPUOffloadedLayer
torch_stray_tensor = isinstance(
tensor,
......@@ -370,7 +372,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
),
)
is_quantized_tensor = isinstance(tensor, QuantizedTensorBase)
is_quantized_tensor = isinstance(tensor, QuantizedTensorStorage)
if not torch_stray_tensor:
......@@ -416,6 +418,11 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
tensor.clear()
else:
self.tensor_tag_to_buf[tensor_tag] = t
# Needed to distinguish attention in non-offloaded layers:
# the QKV layout of a non-offloaded layer's attention needs
# to be adjusted when its tensors are reloaded
CPUOffloadedLayer = True
else:
tensor_tag = (-1, self.torch_tensor_count)
self.torch_tensor_count += 1
......@@ -425,6 +432,8 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
def tensor_pop(self, tensor_tag, **kwargs):
"""Tensor pop."""
global CPUOffloadedLayer
assert tensor_tag in self.tensor_tag_to_state
tensor = self.tensor_tag_to_state.pop(tensor_tag)
......@@ -488,6 +497,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
def synchronize_on_group_commit_forward(self, current_group):
"""Synchronize on group commit forward."""
global CPUOffloadedLayer
# For the first group, kickstart the offload after we have
# the first compute completion
......@@ -522,7 +532,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
if tensor_tag[0] == self.offloaded_group_count:
if hasattr(tensor_buf, "needs_force_clear"):
# Need to clear activation tensor - sometimes references persist in the code.
# This is the case for example with the Float8TensorBase class,
# This is the case for example with the Float8TensorStorage class,
# which is saved directly inside the ctx while its internal tensors are
# saved inside save_for_backward.
tensor_buf.data = torch.Tensor()
......@@ -536,6 +546,9 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
# Increment the offload group count to keep track
self.offloaded_group_count += 1
if current_group == (self.num_offload_group - 1):
CPUOffloadedLayer = False
if not self.double_buffer_created:
# Creating second copy of double buffer for tensors that are offloaded
if current_group == (self.num_layers - 1):
......
......@@ -12,6 +12,20 @@
namespace transformer_engine::pytorch {
/*! convert fp4 data shape back to original shape */
std::vector<size_t> convert_shape_back_from_fp4(const std::vector<size_t>& shape, bool transpose) {
std::vector<size_t> ret;
size_t start_idx = (transpose) ? 1 : 0;
for (size_t i = start_idx; i < shape.size() - 1; ++i) {
ret.push_back(shape[i]);
}
ret.push_back(shape.back() * 2);
if (transpose) {
ret.push_back(shape.front());
}
return ret;
}
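A Python sketch of the same shape logic, to make the FP4 packing explicit (two FP4 values are stored per element of the packed last dimension, so the logical last dimension is twice the stored one; for a transposed tensor the leading dimension is moved to the back):
    def convert_shape_back_from_fp4(shape, transpose):
        # shape: packed FP4 data shape; returns the logical (unpacked) shape
        start = 1 if transpose else 0
        ret = list(shape[start:-1])
        ret.append(shape[-1] * 2)  # unpack: two FP4 values per stored element
        if transpose:
            ret.append(shape[0])   # restore the leading dim at the end
        return ret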
std::vector<size_t> getTensorShape(const at::Tensor& t) {
std::vector<size_t> shape;
for (auto s : t.sizes()) {
......@@ -291,4 +305,20 @@ size_t roundup(const size_t value, const size_t multiple) {
return ((value + multiple - 1) / multiple) * multiple;
}
void philox_unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr) {
NVTE_SCOPED_GIL_RELEASE({
nvte_extract_seed_and_offset(rng_state_ptr, arg.captured_, arg.seed_.ptr, arg.seed_.val,
arg.offset_.ptr, arg.offset_.val, arg.offset_intragraph_,
at::cuda::getCurrentCUDAStream());
});
}
// extract PhiloxCudaState from CUDA random number generator
at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl* gen, size_t elts_per_thread) {
at::PhiloxCudaState philox_args;
std::lock_guard<std::mutex> lock(gen->mutex_);
philox_args = gen->philox_cuda_state(elts_per_thread);
return philox_args;
}
} // namespace transformer_engine::pytorch
......@@ -35,6 +35,7 @@
#include <transformer_engine/fused_rope.h>
#include <transformer_engine/fused_router.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/hadamard_transform.h>
#include <transformer_engine/multi_stream.h>
#include <transformer_engine/multi_tensor.h>
#include <transformer_engine/normalization.h>
......@@ -212,20 +213,25 @@ class Float8CurrentScalingQuantizer : public Quantizer {
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const override;
/*! @brief Construct a high precision tensor giving it this quantizer's amax
Note: this member function also zeros out the amax, as it is meant to be used in conjunction with
a kernel computing the amax, which might expect the amax to be initialized to zero
/*! @brief Construct an unquantized tensor that shares the quantizer's amax pointer.
*
* The amax is zeroed out. Most TE kernels that output amax expect
* amax to be initialized to zero.
*/
std::pair<TensorWrapper, py::object> create_hp_tensor_with_amax(const std::vector<size_t>& shape,
DType dtype);
std::pair<TensorWrapper, py::object> create_unquantized_tensor_with_amax(
const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> data = std::nullopt);
std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
void quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
/*! @brief Convert to a quantized data format avoiding amax computation */
/*! @brief Quantize to FP8, skipping local amax computation
*
* The quantizer's amax pointer is assumed to already hold the local
* amax. The amax may still be reduced across the amax reduction
* group.
*/
void quantize_with_amax(TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt);
......@@ -295,6 +301,60 @@ class MXFP8Quantizer : public Quantizer {
std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
};
class NVFP4Quantizer : public Quantizer {
public:
// fp4 dtype
DType dtype;
// amax reduction for low precision FP4 AG
bool with_amax_reduction;
c10::intrusive_ptr<dist_group_type> amax_reduction_group;
// random hadamard transform
bool with_rht;
bool with_post_rht_amax;
// 2D block scaling
bool with_2d_quantization;
bool stochastic_rounding;
int rht_matrix_random_sign_mask_t;
at::Tensor rht_matrix;
explicit NVFP4Quantizer(const py::handle& quantizer);
NVTEScalingMode get_scaling_mode() const override { return NVTE_NVFP4_1D_SCALING; }
void set_quantization_params(TensorWrapper* tensor) const override;
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const override;
/*! @brief Construct an unquantized tensor that shares NVFP4 tensor's amax pointer
*
* The amax is zeroed out. Most TE kernels that output amax expect
* amax to be initialized to zero.
*/
std::pair<TensorWrapper, py::object> create_unquantized_tensor_with_amax(
TensorWrapper& quantized_tensor, DType dtype);
std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
void quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
/*! @brief Quantize to NVFP4, skipping local amax computation
*
* The input tensor's amax pointer is assumed to already hold the
* local amax. The amax may still be reduced across the amax
* reduction group.
*/
void quantize_with_amax(TensorWrapper& input, TensorWrapper& out);
std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
private:
void quantize_impl(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag, bool compute_amax);
};
std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer);
std::vector<size_t> getTensorShape(const at::Tensor& t);
......@@ -445,6 +505,15 @@ std::vector<size_t> convertShape(const NVTEShape& shape);
size_t roundup(const size_t value, const size_t multiple);
NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape);
std::vector<size_t> convert_shape_back_from_fp4(const std::vector<size_t>& shape, bool transpose);
// unpack the PhiloxCudaState into CUDA tensor
void philox_unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr);
// extract PhiloxCudaState from CUDA random number generator
at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl* gen, size_t elts_per_thread);
} // namespace transformer_engine::pytorch
namespace std {
......
......@@ -73,28 +73,36 @@ std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::T
NVTE_Fused_Attn_Backend get_fused_attn_backend(
bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads,
size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
size_t head_dim_v, int64_t window_size_left, int64_t window_size_right);
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right);
std::pair<TensorWrapper, py::object> quantizer_helper(py::handle quantizer,
const std::vector<size_t> &shape, DType dtype,
bool create_hp_tensor_for_cs,
std::optional<at::Tensor> data);
std::vector<py::object> fused_attn_fwd(
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
const std::optional<at::Tensor> cu_seqlens_q_padded,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
const std::vector<int64_t> window_size, const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
const at::ScalarType fake_dtype, const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded,
const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
const std::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);
const std::optional<at::Tensor> SoftmaxOffset, const std::optional<at::Generator> rng_gen,
size_t rng_elts_per_thread);
std::vector<py::object> fused_attn_bwd(
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const std::vector<int64_t> window_size, bool deterministic, const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type,
NVTE_Softmax_Type softmax_type, const std::vector<int64_t> window_size, bool deterministic,
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
const py::handle K, const py::handle V, const py::handle O, const py::handle dO,
const at::ScalarType fake_dtype, const DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
......@@ -233,6 +241,10 @@ py::object swiglu(const at::Tensor &input, py::handle quantizer);
py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
py::object clamped_swiglu(const at::Tensor &input, py::handle quantizer, float limit, float alpha);
py::object clamped_dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer,
float limit, float alpha);
/***************************************************************************************************
* LayerNorm
**************************************************************************************************/
......
......@@ -3,184 +3,331 @@
*
* See LICENSE for license information.
************************************************************************/
#include "../extensions.h"
#include "common.h"
#include "pybind.h"
namespace transformer_engine::pytorch {
namespace transformer_engine {
namespace pytorch {
namespace {
using FuncType = void(const NVTETensor, NVTETensor, cudaStream_t);
using DFuncType = void(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t);
template <void (*act_func)(const NVTETensor, NVTETensor, cudaStream_t)>
py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1) {
template <FuncType* act_func, auto act_func_with_args, typename... Args>
py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1,
Args&&... args) {
init_extension();
// Input tensor
auto input_tensor = input.contiguous();
const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor);
const TensorWrapper& input_nvte = makeTransformerEngineTensor(input_tensor);
// Construct output tensor
auto quantizer_cpp = convert_quantizer(quantizer);
const auto input_shape = input_cpp.shape();
const auto input_shape = input_nvte.shape();
std::vector<size_t> output_shape(input_shape.data, input_shape.data + input_shape.ndim);
output_shape.back() /= shape_divisor;
auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type());
auto [out_cpp, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype);
auto [out_nvte, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype);
// Compute activation
// Choose implementation
enum class Impl { UNFUSED, FULLY_FUSED, FUSED_ACTIVATION_AMAX_FP8, FUSED_ACTIVATION_AMAX_NVFP4 };
Impl impl = Impl::UNFUSED;
if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) ||
detail::IsMXFP8Quantizers(quantizer.ptr())) {
// Compute activation directly
NVTE_SCOPED_GIL_RELEASE(
{ act_func(input_cpp.data(), out_cpp.data(), at::cuda::getCurrentCUDAStream()); });
impl = Impl::FULLY_FUSED;
} else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// Compute activation in high-precision fused together with amax, then quantize.
auto quantizer_cpp_cs = dynamic_cast<Float8CurrentScalingQuantizer*>(quantizer_cpp.get());
auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(output_shape, fake_dtype);
NVTE_SCOPED_GIL_RELEASE(
{ act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); });
quantizer_cpp_cs->quantize_with_amax(temp_cpp, out_cpp);
impl = Impl::FUSED_ACTIVATION_AMAX_FP8;
} else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer*>(quantizer_cpp.get());
NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
// Post-RHT amax is handled within NVFP4 quantizer
impl = Impl::UNFUSED;
} else {
// Compute activation in high-precision, then quantize
impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4;
}
}
auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype);
NVTE_SCOPED_GIL_RELEASE(
{ act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); });
quantizer_cpp->quantize(temp_cpp, out_cpp);
// Perform compute
auto stream = at::cuda::getCurrentCUDAStream();
switch (impl) {
case Impl::UNFUSED:
// Compute activation in high precision, then quantize
{
auto [temp_nvte, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype);
NVTE_SCOPED_GIL_RELEASE({
if constexpr (act_func == nullptr) {
act_func_with_args(input_nvte.data(), temp_nvte.data(), std::forward<Args>(args)...,
stream);
} else {
act_func(input_nvte.data(), temp_nvte.data(), stream);
}
});
quantizer_cpp->quantize(temp_nvte, out_nvte);
}
break;
case Impl::FULLY_FUSED:
// Compute activation directly
{
NVTE_SCOPED_GIL_RELEASE({
if constexpr (act_func == nullptr) {
act_func_with_args(input_nvte.data(), out_nvte.data(), std::forward<Args>(args)...,
stream);
} else {
act_func(input_nvte.data(), out_nvte.data(), stream);
}
});
}
break;
case Impl::FUSED_ACTIVATION_AMAX_FP8:
// Compute activation and amax in high precision, then quantize to FP8
{
auto fp8_quantizer_cpp = dynamic_cast<Float8CurrentScalingQuantizer*>(quantizer_cpp.get());
NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer");
auto [temp_nvte, _] =
fp8_quantizer_cpp->create_unquantized_tensor_with_amax(output_shape, fake_dtype);
NVTE_SCOPED_GIL_RELEASE({
if constexpr (act_func == nullptr) {
act_func_with_args(input_nvte.data(), temp_nvte.data(), std::forward<Args>(args)...,
stream);
} else {
act_func(input_nvte.data(), temp_nvte.data(), stream);
}
});
fp8_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte);
}
break;
case Impl::FUSED_ACTIVATION_AMAX_NVFP4:
// Compute activation and amax in high precision, then quantize to NVFP4
{
auto nvfp4_quantizer_cpp =
static_cast<NVFP4Quantizer*>(quantizer_cpp.get()); // Already checked cast is valid
auto [temp_nvte, _] =
nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, fake_dtype);
NVTE_SCOPED_GIL_RELEASE({
if constexpr (act_func == nullptr) {
act_func_with_args(input_nvte.data(), temp_nvte.data(), std::forward<Args>(args)...,
stream);
} else {
act_func(input_nvte.data(), temp_nvte.data(), stream);
}
});
nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte);
}
break;
default:
NVTE_ERROR("Invalid activation implementation (", static_cast<int>(impl), ")");
}
return out_py;
}
template <void (*dact_func)(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t)>
template <DFuncType* dact_func, auto dact_func_with_args, typename... Args>
py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& input,
py::handle quantizer) {
py::handle quantizer, Args&&... args) {
init_extension();
// Grad output and input tensors
auto grad_output_tensor = grad_output.contiguous();
auto input_tensor = input.contiguous();
const TensorWrapper& grad_output_cpp = makeTransformerEngineTensor(grad_output_tensor);
const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor);
const TensorWrapper& grad_output_nvte = makeTransformerEngineTensor(grad_output_tensor);
const TensorWrapper& input_nvte = makeTransformerEngineTensor(input_tensor);
// Construct grad input tensor
auto quantizer_cpp = convert_quantizer(quantizer);
const auto input_shape_te = input_cpp.shape();
const auto input_shape_te = input_nvte.shape();
const std::vector<size_t> input_shape(input_shape_te.data,
input_shape_te.data + input_shape_te.ndim);
auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type());
auto [grad_input_cpp, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype);
auto [grad_input_nvte, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype);
// Compute activation backward
// Choose implementation
enum class Impl { UNFUSED, FULLY_FUSED, FUSED_ACTIVATION_AMAX_FP8, FUSED_ACTIVATION_AMAX_NVFP4 };
Impl impl = Impl::UNFUSED;
if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) ||
detail::IsMXFP8Quantizers(quantizer.ptr())) {
// Compute activation backward directly
impl = Impl::FULLY_FUSED;
} else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
impl = Impl::FUSED_ACTIVATION_AMAX_FP8;
} else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer*>(quantizer_cpp.get());
NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
// Post-RHT amax is handled within NVFP4 quantizer
impl = Impl::UNFUSED;
} else {
impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4;
}
}
// Perform compute
auto stream = at::cuda::getCurrentCUDAStream();
switch (impl) {
case Impl::UNFUSED:
// Compute activation backward in high precision, then quantize
{
auto [temp_nvte, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype);
NVTE_SCOPED_GIL_RELEASE({
dact_func(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(),
at::cuda::getCurrentCUDAStream());
if constexpr (dact_func == nullptr) {
dact_func_with_args(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(),
std::forward<Args>(args)..., stream);
} else {
dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream);
}
});
} else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// Compute activation backward in high-precision fused together with amax, then quantize.
auto quantizer_cpp_cs = dynamic_cast<Float8CurrentScalingQuantizer*>(quantizer_cpp.get());
auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, fake_dtype);
quantizer_cpp->quantize(temp_nvte, grad_input_nvte);
}
break;
case Impl::FULLY_FUSED:
// Compute activation backward directly
{
NVTE_SCOPED_GIL_RELEASE({
dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(),
at::cuda::getCurrentCUDAStream());
if constexpr (dact_func == nullptr) {
dact_func_with_args(grad_output_nvte.data(), input_nvte.data(), grad_input_nvte.data(),
std::forward<Args>(args)..., stream);
} else {
dact_func(grad_output_nvte.data(), input_nvte.data(), grad_input_nvte.data(), stream);
}
});
quantizer_cpp_cs->quantize_with_amax(temp_cpp, grad_input_cpp);
}
break;
case Impl::FUSED_ACTIVATION_AMAX_FP8:
// Compute activation and amax in high precision, then quantize to FP8
{
auto fp8_quantizer_cpp = dynamic_cast<Float8CurrentScalingQuantizer*>(quantizer_cpp.get());
NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer");
auto [temp_nvte, _] =
fp8_quantizer_cpp->create_unquantized_tensor_with_amax(input_shape, fake_dtype);
NVTE_SCOPED_GIL_RELEASE({
if constexpr (dact_func == nullptr) {
dact_func_with_args(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(),
std::forward<Args>(args)..., stream);
} else {
// Compute activation backward in high-precision, then quantize
auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype);
dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream);
}
});
fp8_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte);
}
break;
case Impl::FUSED_ACTIVATION_AMAX_NVFP4:
// Compute activation and amax in high precision, then quantize to NVFP4
{
auto nvfp4_quantizer_cpp =
static_cast<NVFP4Quantizer*>(quantizer_cpp.get()); // Already checked cast is valid
auto [temp_nvte, _] =
nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(grad_input_nvte, fake_dtype);
NVTE_SCOPED_GIL_RELEASE({
dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(),
at::cuda::getCurrentCUDAStream());
if constexpr (dact_func == nullptr) {
dact_func_with_args(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(),
std::forward<Args>(args)..., stream);
} else {
dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream);
}
});
quantizer_cpp->quantize(temp_cpp, grad_input_cpp);
nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte);
}
break;
default:
NVTE_ERROR("Invalid activation implementation (", static_cast<int>(impl), ")");
}
return grad_input_py;
}
} // namespace
/* GELU and variants*/
/* GELU and variants */
py::object gelu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_gelu>(input, quantizer);
return activation_helper<nvte_gelu, nullptr>(input, quantizer);
}
py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dgelu>(grad, input, quantizer);
return dactivation_helper<nvte_dgelu, nullptr>(grad, input, quantizer);
}
py::object geglu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_geglu>(input, quantizer, 2);
return activation_helper<nvte_geglu, nullptr>(input, quantizer, 2);
}
py::object dgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dgeglu>(grad, input, quantizer);
return dactivation_helper<nvte_dgeglu, nullptr>(grad, input, quantizer);
}
py::object qgelu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_qgelu>(input, quantizer);
return activation_helper<nvte_qgelu, nullptr>(input, quantizer);
}
py::object dqgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dqgelu>(grad, input, quantizer);
return dactivation_helper<nvte_dqgelu, nullptr>(grad, input, quantizer);
}
py::object qgeglu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_qgeglu>(input, quantizer, 2);
return activation_helper<nvte_qgeglu, nullptr>(input, quantizer, 2);
}
py::object dqgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dqgeglu>(grad, input, quantizer);
return dactivation_helper<nvte_dqgeglu, nullptr>(grad, input, quantizer);
}
/* ReLU and variants*/
/* ReLU and variants */
py::object relu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_relu>(input, quantizer);
return activation_helper<nvte_relu, nullptr>(input, quantizer);
}
py::object drelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_drelu>(grad, input, quantizer);
return dactivation_helper<nvte_drelu, nullptr>(grad, input, quantizer);
}
py::object reglu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_reglu>(input, quantizer, 2);
return activation_helper<nvte_reglu, nullptr>(input, quantizer, 2);
}
py::object dreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dreglu>(grad, input, quantizer);
return dactivation_helper<nvte_dreglu, nullptr>(grad, input, quantizer);
}
py::object srelu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_srelu>(input, quantizer);
return activation_helper<nvte_srelu, nullptr>(input, quantizer);
}
py::object dsrelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dsrelu>(grad, input, quantizer);
return dactivation_helper<nvte_dsrelu, nullptr>(grad, input, quantizer);
}
py::object sreglu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_sreglu>(input, quantizer, 2);
return activation_helper<nvte_sreglu, nullptr>(input, quantizer, 2);
}
py::object dsreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dsreglu>(grad, input, quantizer);
return dactivation_helper<nvte_dsreglu, nullptr>(grad, input, quantizer);
}
/* Silu and variants*/
/* Silu and variants */
py::object silu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_silu>(input, quantizer);
return activation_helper<nvte_silu, nullptr>(input, quantizer);
}
py::object dsilu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dsilu>(grad, input, quantizer);
return dactivation_helper<nvte_dsilu, nullptr>(grad, input, quantizer);
}
py::object swiglu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_swiglu>(input, quantizer, 2);
return activation_helper<nvte_swiglu, nullptr>(input, quantizer, 2);
}
py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dswiglu>(grad, input, quantizer);
return dactivation_helper<nvte_dswiglu, nullptr>(grad, input, quantizer);
}
/* clamped functions */
py::object clamped_swiglu(const at::Tensor& input, py::handle quantizer, float limit, float alpha) {
return activation_helper<nullptr, nvte_clamped_swiglu>(input, quantizer, 2, limit, alpha);
}
} // namespace transformer_engine::pytorch
py::object clamped_dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer,
float limit, float alpha) {
return dactivation_helper<nullptr, nvte_clamped_dswiglu>(grad, input, quantizer, limit, alpha);
}
} // namespace pytorch
} // namespace transformer_engine
......@@ -35,22 +35,6 @@ void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &s
{ nvte_memset(base_ptr, 0, total_bytes, at::cuda::getCurrentCUDAStream()); });
}
void unpack(at::PhiloxCudaState arg, int64_t *rng_state_ptr) {
NVTE_SCOPED_GIL_RELEASE({
nvte_extract_seed_and_offset(rng_state_ptr, arg.captured_, arg.seed_.ptr, arg.seed_.val,
arg.offset_.ptr, arg.offset_.val, arg.offset_intragraph_,
at::cuda::getCurrentCUDAStream());
});
}
// extract PhiloxCudaState from CUDA random number generator
at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl *gen, size_t elts_per_thread) {
at::PhiloxCudaState philox_args;
std::lock_guard<std::mutex> lock(gen->mutex_);
philox_args = gen->philox_cuda_state(elts_per_thread);
return philox_args;
}
} // namespace
namespace transformer_engine::pytorch {
......@@ -58,73 +42,102 @@ namespace transformer_engine::pytorch {
// get the fused attention backend
NVTE_Fused_Attn_Backend get_fused_attn_backend(
bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads,
size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) {
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right) {
#ifdef __HIP_PLATFORM_AMD__
return NVTE_Fused_Attn_Backend::NVTE_No_Backend;
#else
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
bias_type, attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q,
max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right);
bias_type, attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups,
max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right);
return fused_attention_backend;
#endif
}
// helper function for S and dP quantizers
std::pair<TensorWrapper, py::object> quantizer_helper(py::handle quantizer,
const std::vector<size_t> &shape, DType dtype,
bool create_hp_tensor_for_cs,
std::optional<at::Tensor> data) {
std::unique_ptr<Quantizer> T_quantizer = convert_quantizer(quantizer);
TensorWrapper te_T;
py::object py_T;
if (quantizer.is_none()) {
// high precision
auto *none_quantizer = dynamic_cast<NoneQuantizer *>(T_quantizer.get());
if (data.has_value()) {
std::tie(te_T, py_T) = none_quantizer->create_tensor(shape, dtype, data.value());
} else {
std::tie(te_T, py_T) = none_quantizer->create_tensor(shape, dtype);
}
} else if (detail::IsFloat8Quantizers(quantizer.ptr())) {
// delayed scaling; this helps initialize scale_inv
auto *T_quantizer_fp8 = dynamic_cast<Float8Quantizer *>(T_quantizer.get());
std::tie(te_T, py_T) =
T_quantizer_fp8->create_tensor(shape, dtype, data, std::nullopt, std::nullopt);
} else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// current scaling
auto *T_quantizer_fp8 = dynamic_cast<Float8CurrentScalingQuantizer *>(T_quantizer.get());
if (create_hp_tensor_for_cs) {
if (data.has_value()) {
std::tie(te_T, py_T) =
T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype, data.value());
} else {
std::tie(te_T, py_T) = T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype);
}
} else {
std::tie(te_T, py_T) = T_quantizer_fp8->create_tensor(shape, dtype);
NVTE_CHECK(
!data.has_value(),
"Float8CurrentScalingQuantizer::create_tensor() does not take data tensor as input!");
}
}
return {std::move(te_T), std::move(py_T)};
}
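// A minimal standalone sketch (hypothetical types, not the real Quantizer hierarchy) of
// the dynamic_cast dispatch used by quantizer_helper above: the concrete quantizer type
// decides whether a high-precision tensor, a delayed-scaling FP8 tensor, or an
// unquantized tensor carrying an amax is created.

#include <iostream>

struct Quantizer { virtual ~Quantizer() = default; };
struct NoneQ : Quantizer {};
struct DelayedScalingQ : Quantizer {};
struct CurrentScalingQ : Quantizer {};

void create_tensor_for(const Quantizer& q) {
  if (dynamic_cast<const NoneQ*>(&q) != nullptr) {
    std::cout << "high-precision tensor\n";         // NoneQuantizer path
  } else if (dynamic_cast<const DelayedScalingQ*>(&q) != nullptr) {
    std::cout << "FP8 tensor with scale_inv\n";     // delayed-scaling path
  } else if (dynamic_cast<const CurrentScalingQ*>(&q) != nullptr) {
    std::cout << "unquantized tensor with amax\n";  // current-scaling path
  }
}

int main() {
  CurrentScalingQ q;
  create_tensor_for(q);  // prints "unquantized tensor with amax"
  return 0;
}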
// fused attention FWD with separate Q, K and V tensors
std::vector<py::object> fused_attn_fwd(
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
const std::optional<at::Tensor> cu_seqlens_q_padded,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
const std::vector<int64_t> window_size, const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
const at::ScalarType fake_dtype, const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded,
const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
const std::optional<at::Generator> rng_gen, size_t rng_elts_per_thread) {
const std::optional<at::Tensor> SoftmaxOffset, const std::optional<at::Generator> rng_gen,
size_t rng_elts_per_thread) {
#ifdef __HIP_PLATFORM_AMD__
assert(false);
#else
TensorWrapper te_Q, te_K, te_V, te_O, te_S;
auto none = py::none();
std::unique_ptr<Quantizer> S_quantizer = convert_quantizer(s_quantizer);
std::unique_ptr<Quantizer> O_quantizer = convert_quantizer(o_quantizer);
// create QKV tensor wrappers
TensorWrapper te_Q, te_K, te_V;
te_Q = makeTransformerEngineTensor(Q, none);
te_K = makeTransformerEngineTensor(K, none);
te_V = makeTransformerEngineTensor(V, none);
// If QKV has an FP8 dtype, fake_dtype_te is the fake (high-precision) dtype of q, k, v - needed since torch does not have FP8 types.
const DType qkv_type = te_Q.dtype();
const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype);
// create S tensor
TensorWrapper te_S;
py::object py_S;
std::tie(te_S, py_S) = quantizer_helper(s_quantizer, {0}, DType::kFloat32, false, std::nullopt);
// create O tensor
TensorWrapper te_O;
py::object py_O;
std::unique_ptr<Quantizer> O_quantizer = convert_quantizer(o_quantizer);
std::vector<size_t> q_shape = convertShape(te_Q.shape());
std::vector<size_t> k_shape = convertShape(te_K.shape());
std::vector<size_t> v_shape = convertShape(te_V.shape());
auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
// create output tensor O
auto o_shape = std::vector<size_t>{q_shape.begin(), q_shape.end()};
o_shape[o_shape.size() - 1] = v_shape[v_shape.size() - 1];
py::object o_python, s_python;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// Initialize FP8 tensor with scale-inverse
auto *O_quantizer_fp8 = dynamic_cast<Float8Quantizer *>(O_quantizer.get());
auto *S_quantizer_fp8 = dynamic_cast<Float8Quantizer *>(S_quantizer.get());
NVTE_CHECK(O_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8");
NVTE_CHECK(S_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8");
std::tie(te_O, o_python) = O_quantizer_fp8->create_tensor(o_shape, fake_dtype_te, std::nullopt,
std::nullopt, std::nullopt);
std::tie(te_S, s_python) = S_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt,
std::nullopt, std::nullopt);
} else {
std::tie(te_O, o_python) = O_quantizer->create_tensor(o_shape, fake_dtype_te);
std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32);
}
auto o_shape_int64 = std::vector<int64_t>{o_shape.begin(), o_shape.end()};
const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype);
std::tie(te_O, py_O) = quantizer_helper(o_quantizer, o_shape, fake_dtype_te, true, std::nullopt);
// construct NVTE tensors
TensorWrapper te_Bias;
......@@ -135,12 +148,13 @@ std::vector<py::object> fused_attn_fwd(
// FP8
auto h = q_shape[q_shape.size() - 2];
auto d = q_shape[q_shape.size() - 1];
if (set_zero && ((h * d) % block_size == 0) &&
(nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) {
if (set_zero && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) {
if ((h * d) % block_size == 0) {
mha_fill(te_O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
} else {
te_O.zero_(at::cuda::getCurrentCUDAStream());
}
}
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) {
te_O.zero_(at::cuda::getCurrentCUDAStream());
......@@ -188,12 +202,23 @@ std::vector<py::object> fused_attn_fwd(
DType::kInt32, nullptr, nullptr, nullptr);
}
// softmax offset
TensorWrapper te_SoftmaxOffset;
if ((softmax_type != NVTE_VANILLA_SOFTMAX) && (SoftmaxOffset.has_value())) {
auto SoftmaxOffset_sizes = SoftmaxOffset.value().sizes().vec();
std::vector<size_t> SoftmaxOffset_shape{SoftmaxOffset_sizes.begin(), SoftmaxOffset_sizes.end()};
te_SoftmaxOffset =
makeTransformerEngineTensor(SoftmaxOffset.value().data_ptr(), SoftmaxOffset_shape,
DType::kFloat32, nullptr, nullptr, nullptr);
}
// extract rng seed and offset
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
rng_gen, at::cuda::detail::getDefaultCUDAGenerator());
at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread);
auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
unpack(philox_args, static_cast<int64_t *>(rng_state.data_ptr()));
auto options = torch::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA);
auto rng_state = torch::empty({2}, options);
philox_unpack(philox_args, static_cast<int64_t *>(rng_state.data_ptr()));
auto te_rng_state = makeTransformerEngineTensor(rng_state);
// create auxiliary output tensors
......@@ -206,11 +231,11 @@ std::vector<py::object> fused_attn_fwd(
// populate tensors with appropriate shapes and dtypes
NVTE_SCOPED_GIL_RELEASE({
nvte_fused_attn_fwd(
te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), te_O.data(),
&nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_SoftmaxOffset.data(), te_S.data(),
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(),
te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0],
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0],
window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream());
});
......@@ -221,52 +246,53 @@ std::vector<py::object> fused_attn_fwd(
// output_tensors = [O, nvte_aux_tensor_pack.tensors]
std::vector<py::object> output_tensors;
output_tensors.push_back(o_python);
for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) {
output_tensors.push_back(py_O);
auto set_tensor_param = [&](size_t i, const at::Tensor &output_tensor) {
output_tensors.push_back(py::cast(output_tensor));
NVTEBasicTensor temp_data = {output_tensor.data_ptr(),
nvte_tensor_type(nvte_aux_tensor_pack.tensors[i]),
nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])};
nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data);
};
// allocate memory for nvte_aux_tensor_pack.tensors
// f16_max512 : S [b, h, sq, skv]
// f16_arbitrary: S [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1]
// fp8 : M [b, h, sq, 1], ZInv [b, h, sq, 1], rng_state [2]
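// Illustrative ordering per the layout comment above: with the f16_arbitrary backend,
// an explicit Bias tensor, and a non-vanilla softmax type, the pack is filled below as
// [S, rng_state, Bias, SoftmaxOffset]; the fp8 backend inserts ZInv right after S/M.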
size_t i = 0;
at::Tensor output_tensor;
if (nvte_aux_tensor_pack.size >= 2) {
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) {
if (i < nvte_aux_tensor_pack.size - 2) {
NVTEShape temp_shape = nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i]);
output_tensor = allocateSpace(
nvte_shape_to_vector(temp_shape),
// intermediate softmax tensor, S or M
output_tensor =
allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])),
static_cast<DType>(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false);
} else if (i == nvte_aux_tensor_pack.size - 2) {
output_tensor = rng_state;
} else if (i == nvte_aux_tensor_pack.size - 1) {
output_tensor = Bias.value();
}
} else {
NVTEShape temp_shape = nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i]);
set_tensor_param(i++, output_tensor);
// fp8 has an additional softmax stats tensor, ZInv
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
output_tensor =
(i < nvte_aux_tensor_pack.size - 1)
? allocateSpace(
nvte_shape_to_vector(temp_shape),
static_cast<DType>(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false)
: rng_state;
}
} else {
NVTEShape temp_shape = nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i]);
output_tensor = allocateSpace(
nvte_shape_to_vector(temp_shape),
allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])),
static_cast<DType>(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false);
set_tensor_param(i++, output_tensor);
}
output_tensors.push_back(py::cast(output_tensor));
NVTEBasicTensor temp_data = {output_tensor.data_ptr(),
nvte_tensor_type(nvte_aux_tensor_pack.tensors[i]),
nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])};
nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data);
// rng_state
if (i < nvte_aux_tensor_pack.size) {
set_tensor_param(i++, rng_state);
}
// bias (optional)
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) {
set_tensor_param(i++, Bias.value());
}
// softmax_offset (optional)
if ((softmax_type != NVTE_VANILLA_SOFTMAX) && (SoftmaxOffset.has_value())) {
set_tensor_param(i++, SoftmaxOffset.value());
}
// execute the kernel
NVTE_SCOPED_GIL_RELEASE({
nvte_fused_attn_fwd(
te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), te_O.data(),
&nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_SoftmaxOffset.data(), te_S.data(),
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(),
te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0],
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0],
window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream());
});
......@@ -282,9 +308,10 @@ std::vector<py::object> fused_attn_fwd(
std::vector<py::object> fused_attn_bwd(
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const std::vector<int64_t> window_size, bool deterministic, const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type,
NVTE_Softmax_Type softmax_type, const std::vector<int64_t> window_size, bool deterministic,
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
const py::handle K, const py::handle V, const py::handle O, const py::handle dO,
const at::ScalarType fake_dtype, const DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
......@@ -293,50 +320,44 @@ std::vector<py::object> fused_attn_bwd(
assert(false);
#else
auto none = py::none();
TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV;
// create QKV, O, dO tensor wrappers
TensorWrapper te_Q, te_K, te_V, te_O, te_dO;
te_Q = makeTransformerEngineTensor(Q, none);
te_K = makeTransformerEngineTensor(K, none);
te_V = makeTransformerEngineTensor(V, none);
te_O = makeTransformerEngineTensor(O, none);
te_dO = makeTransformerEngineTensor(dO, none);
// QKV dtype is taken from te_Q
std::unique_ptr<Quantizer> dQKV_quantizer = convert_quantizer(dqkv_quantizer);
const DType qkv_type = te_Q.dtype();
const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype);
py::object s_python, dp_python;
std::unique_ptr<Quantizer> S_quantizer = convert_quantizer(s_quantizer);
std::unique_ptr<Quantizer> dP_quantizer = convert_quantizer(dp_quantizer);
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
auto *S_quantizer_fp8 = dynamic_cast<Float8Quantizer *>(S_quantizer.get());
auto *dP_quantizer_fp8 = dynamic_cast<Float8Quantizer *>(dP_quantizer.get());
NVTE_CHECK(S_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8");
NVTE_CHECK(dP_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8");
std::tie(te_S, s_python) = S_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt,
std::nullopt, std::nullopt);
std::tie(te_dP, dp_python) = dP_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt,
std::nullopt, std::nullopt);
} else {
std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32);
std::tie(te_dP, dp_python) = dP_quantizer->create_tensor({0}, DType::kFloat32);
}
// create S and dP tensors
TensorWrapper te_S, te_dP;
py::object py_S, py_dP;
std::tie(te_S, py_S) = quantizer_helper(s_quantizer, {0}, DType::kFloat32, false, std::nullopt);
std::tie(te_dP, py_dP) =
quantizer_helper(dp_quantizer, {0}, DType::kFloat32, false, std::nullopt);
// create dQ, dK, dV tensors
TensorWrapper te_dQ, te_dK, te_dV;
py::object py_dQ, py_dK, py_dV;
std::unique_ptr<Quantizer> dQKV_quantizer = convert_quantizer(dqkv_quantizer);
std::vector<size_t> q_shape = convertShape(te_Q.shape());
std::vector<size_t> k_shape = convertShape(te_K.shape());
std::vector<size_t> v_shape = convertShape(te_V.shape());
auto h_q = q_shape[q_shape.size() - 2];
auto h_kv = k_shape[k_shape.size() - 2];
auto d_qk = q_shape[q_shape.size() - 1];
auto d_v = v_shape[v_shape.size() - 1];
auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA);
std::vector<size_t> o_shape{q_shape.begin(), q_shape.end()};
o_shape[o_shape.size() - 1] = d_v;
const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype);
at::Tensor dQ, dK, dV, dQKV, dKV;
py::object py_dQ, py_dK, py_dV;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
std::vector<int64_t> tmp_shape;
auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA);
if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) {
options = options.dtype(torch::kUInt8);
}
if (detail::IsFloat8CurrentScalingQuantizers(dqkv_quantizer.ptr())) {
options = options.dtype(fake_dtype);
}
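// Note: FP8 dgrads are backed by uint8 storage here, while current scaling overrides
// this with the high-precision fake dtype - presumably so dQ/dK/dV can carry an amax
// and be quantized afterwards (see the quantizer_helper calls with
// create_hp_tensor_for_cs = true below).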
switch (layout_group) {
case NVTE_QKV_Layout_Group::NVTE_3HD:
......@@ -409,29 +430,17 @@ std::vector<py::object> fused_attn_bwd(
default:
NVTE_ERROR("QKV layout not supported!");
}
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
auto *fp8_quantizer = dynamic_cast<Float8Quantizer *>(dQKV_quantizer.get());
NVTE_CHECK(fp8_quantizer != nullptr, "Expected Float8Quantizer when dtype is FP8");
std::tie(te_dQ, py_dQ) =
fp8_quantizer->create_tensor(q_shape, fake_dtype_te, dQ, std::nullopt, std::nullopt);
std::tie(te_dK, py_dK) =
fp8_quantizer->create_tensor(k_shape, fake_dtype_te, dK, std::nullopt, std::nullopt);
std::tie(te_dV, py_dV) =
fp8_quantizer->create_tensor(v_shape, fake_dtype_te, dV, std::nullopt, std::nullopt);
} else {
auto *none_quantizer = dynamic_cast<NoneQuantizer *>(dQKV_quantizer.get());
NVTE_CHECK(none_quantizer != nullptr, "Expected NoneQuantizer when dtype is not FP8");
std::tie(te_dQ, py_dQ) = none_quantizer->create_tensor(q_shape, fake_dtype_te, dQ);
std::tie(te_dK, py_dK) = none_quantizer->create_tensor(k_shape, fake_dtype_te, dK);
std::tie(te_dV, py_dV) = none_quantizer->create_tensor(v_shape, fake_dtype_te, dV);
}
std::tie(te_dQ, py_dQ) = quantizer_helper(dqkv_quantizer, q_shape, fake_dtype_te, true, dQ);
std::tie(te_dK, py_dK) = quantizer_helper(dqkv_quantizer, k_shape, fake_dtype_te, true, dK);
std::tie(te_dV, py_dV) = quantizer_helper(dqkv_quantizer, v_shape, fake_dtype_te, true, dV);
// construct NVTE tensors
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) {
// FP8
if (set_zero && ((h_q * d_qk) % block_size == 0) && ((h_kv * d_qk) % block_size == 0) &&
dQ.is_contiguous() && dK.is_contiguous() && dV.is_contiguous() &&
(nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) {
if (set_zero && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) {
if (((h_q * d_qk) % block_size == 0) && ((h_kv * d_qk) % block_size == 0) &&
dQ.is_contiguous() && dK.is_contiguous() && dV.is_contiguous()) {
mha_fill(te_dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
mha_fill(te_dK, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)}));
mha_fill(te_dV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)}));
......@@ -440,8 +449,8 @@ std::vector<py::object> fused_attn_bwd(
dK.fill_(0);
dV.fill_(0);
}
} else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
}
} else if (dqkv_type == DType::kBFloat16 || dqkv_type == DType::kFloat16) {
if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) {
dQ.fill_(0);
dK.fill_(0);
......@@ -510,6 +519,15 @@ std::vector<py::object> fused_attn_bwd(
}
}
// create dSoftmaxOffset in the same shape as SoftmaxOffset
at::Tensor dSoftmaxOffset;
TensorWrapper te_dSoftmaxOffset;
if (softmax_type != NVTE_VANILLA_SOFTMAX) {
options = torch::TensorOptions().dtype(at::kFloat).device(torch::kCUDA);
dSoftmaxOffset = torch::empty({1, static_cast<int64_t>(h_q), 1, 1}, options);
te_dSoftmaxOffset = makeTransformerEngineTensor(dSoftmaxOffset);
}
// create workspace
TensorWrapper workspace;
......@@ -518,10 +536,10 @@ std::vector<py::object> fused_attn_bwd(
nvte_fused_attn_bwd(
te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
&nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(),
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(),
te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1], deterministic,
workspace.data(), at::cuda::getCurrentCUDAStream());
te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv,
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0],
window_size[1], deterministic, workspace.data(), at::cuda::getCurrentCUDAStream());
});
// allocate memory for workspace
......@@ -534,16 +552,16 @@ std::vector<py::object> fused_attn_bwd(
nvte_fused_attn_bwd(
te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
&nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(),
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(),
te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1], deterministic,
workspace.data(), at::cuda::getCurrentCUDAStream());
te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv,
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0],
window_size[1], deterministic, workspace.data(), at::cuda::getCurrentCUDAStream());
});
// destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
return {py_dQ, py_dK, py_dV, py::cast(dBias)};
return {py_dQ, py_dK, py_dV, py::cast(dBias), py::cast(dSoftmaxOffset)};
#endif
}
......@@ -610,7 +628,6 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s
// Shapes of kv and dkv are [2, t, h, d], so the dimension of "t" is 1
int seq_dim = tensor.dim() == 3 ? 0 : 1;
int batch = cu_seqlens.size(0) - 1;
int num_heads = tensor.size(seq_dim + 1);
int dim_per_head = tensor.size(seq_dim + 2);
int hidden_size_in_bytes = num_heads * dim_per_head * c10::elementSize(tensor.scalar_type());
......@@ -774,8 +791,6 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t
NVTE_CHECK(world_size > 0);
NVTE_CHECK(total_tokens > 0 && total_tokens % (world_size * 2) == 0);
int batch = cu_seqlens.size(0) - 1;
std::vector<int64_t> shape = {total_tokens / world_size};
at::Tensor output = at::empty(shape, at::CUDA(at::ScalarType::Int));
......@@ -813,7 +828,6 @@ at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b,
**************************************************************************************************/
at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t) {
int max_seq_len = tensor.size(1);
int h = tensor.size(2);
int d = tensor.size(3);
std::vector<int64_t> shape = {t, h, d};
......
......@@ -54,10 +54,25 @@ std::vector<py::object> bgrad_quantize(const at::Tensor &grad_output, py::handle
return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)};
}
// Unfused impl if quantizer is not supported
const bool with_fused_dbias_quantize_kernel =
detail::IsFloat8Quantizers(quantizer.ptr()) || detail::IsMXFP8Quantizers(quantizer.ptr());
if (!with_fused_dbias_quantize_kernel) {
// Check if fused kernel is supported
bool with_fused_kernel = false;
if (detail::IsFloat8Quantizers(quantizer.ptr())) {
auto prop = at::cuda::getCurrentDeviceProperties();
const size_t sm_arch = 10 * prop->major + prop->minor;
if (sm_arch >= 100) {
// Fused kernel for dbias + FP8 cast on SM arch 10.0+
with_fused_kernel = true;
} else if (quantizer_cpp->rowwise_usage && quantizer_cpp->columnwise_usage) {
// Fused kernel for dbias + FP8 cast + FP8 transpose
with_fused_kernel = true;
}
} else if (detail::IsMXFP8Quantizers(quantizer.ptr())) {
// Fused kernel for dbias + MXFP8 quantize
with_fused_kernel = true;
}
// Apply unfused impl if fused kernel is not supported
if (!with_fused_kernel) {
at::sum_out(grad_bias_torch, grad_output_torch.reshape({-1, bias_size}), {0});
quantizer_cpp->quantize(grad_output_nvte, grad_input_nvte);
return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)};
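// A small worked illustration (assumed device values, not from this diff) of the
// sm_arch check above: sm_arch = 10 * major + minor, so compute capability 9.0 (Hopper)
// gives 90 and only takes the fused path when both rowwise and columnwise usage are set,
// while 10.0 (Blackwell) gives 100 and always qualifies for the fused dbias + FP8 cast.

#include <cstdio>

int main() {
  const int hopper = 10 * 9 + 0;      // compute capability 9.0 -> 90
  const int blackwell = 10 * 10 + 0;  // compute capability 10.0 -> 100
  std::printf("hopper fused-by-arch: %d, blackwell fused-by-arch: %d\n",
              hopper >= 100, blackwell >= 100);  // prints 0 and 1
  return 0;
}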
......@@ -122,13 +137,27 @@ std::vector<py::object> dact_dbias(
}
// Choose implementation
enum class Impl { UNFUSED, FUSED_DACT_DBIAS_QUANTIZE, FUSED_DACT_AMAX };
enum class Impl {
UNFUSED,
FUSED_DACT_DBIAS_QUANTIZE,
FUSED_DACT_AMAX_FP8,
FUSED_DACT_AMAX_NVFP4
};
Impl impl = Impl::UNFUSED;
if (detail::IsFloat8Quantizers(quantizer_py.ptr()) ||
detail::IsMXFP8Quantizers(quantizer_py.ptr())) {
impl = Impl::FUSED_DACT_DBIAS_QUANTIZE;
} else if (detail::IsFloat8CurrentScalingQuantizers(quantizer_py.ptr())) {
impl = Impl::FUSED_DACT_AMAX;
impl = Impl::FUSED_DACT_AMAX_FP8;
} else if (detail::IsNVFP4Quantizers(quantizer_py.ptr())) {
auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer *>(quantizer_cpp.get());
NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
// Post-RHT amax is handled within NVFP4 quantizer
impl = Impl::UNFUSED;
} else {
impl = Impl::FUSED_DACT_AMAX_NVFP4;
}
}
// Perform compute
......@@ -172,20 +201,38 @@ std::vector<py::object> dact_dbias(
});
break;
}
case Impl::FUSED_DACT_AMAX:
// Fused dact-amax kernel, unfused dbias and quantize
case Impl::FUSED_DACT_AMAX_FP8:
// Fused dact-amax kernel, unfused dbias and FP8 quantize
{
auto *quantizer_cpp_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
NVTE_CHECK(quantizer_cpp_cs != nullptr,
auto *fp8_quantizer_cpp =
dynamic_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
NVTE_CHECK(fp8_quantizer_cpp != nullptr,
"Invalid quantizer for fused dact-amax kernel impl");
auto [temp_nvte, temp_py] =
quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, grad_output_dtype);
fp8_quantizer_cpp->create_unquantized_tensor_with_amax(input_shape, grad_output_dtype);
NVTE_SCOPED_GIL_RELEASE({
dact_func(grad_output_nvte.data(), act_input_nvte.data(), temp_nvte.data(), stream);
});
const auto temp_torch = temp_py.cast<at::Tensor>();
at::sum_out(grad_bias_torch, temp_torch.reshape({-1, bias_size}), {0});
fp8_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte);
break;
}
case Impl::FUSED_DACT_AMAX_NVFP4:
// Fused dact-amax kernel, unfused dbias and NVFP4 quantize
{
auto *nvfp4_quantizer_cpp =
static_cast<NVFP4Quantizer *>(quantizer_cpp.get()); // Already checked cast is valid
NVTE_CHECK(nvfp4_quantizer_cpp != nullptr,
"Invalid quantizer for fused dact-amax kernel impl");
auto [temp_nvte, temp_py] = nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(
grad_input_nvte, grad_output_dtype);
NVTE_SCOPED_GIL_RELEASE({
dact_func(grad_output_nvte.data(), act_input_nvte.data(), temp_nvte.data(), stream);
});
const auto temp_torch = temp_py.cast<at::Tensor>();
at::sum_out(grad_bias_torch, temp_torch.reshape({-1, bias_size}), {0});
quantizer_cpp_cs->quantize_with_amax(temp_nvte, grad_input_nvte);
nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte);
break;
}
default:
......
......@@ -37,7 +37,18 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob
// Convert input tensor to C++ object
auto input_contiguous = tensor.contiguous();
const auto input_cpp = makeTransformerEngineTensor(input_contiguous);
auto input_cpp = makeTransformerEngineTensor(input_contiguous);
// Set amax if use_existing_amax = true (only valid for current scaling)
bool use_existing_amax = false;
if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
use_existing_amax = quantizer.attr("use_existing_amax").cast<bool>();
if (use_existing_amax) {
const at::Tensor &amax = quantizer.attr("amax").cast<at::Tensor>();
input_cpp.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()),
getTensorShape(amax));
}
}
// Initialize output tensor
TensorWrapper output_cpp;
......@@ -57,7 +68,12 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob
}
// Perform quantization
if (use_existing_amax) {
auto *quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
quantizer_cs->quantize_with_amax(input_cpp, output_cpp, noop_flag_cpp);
} else {
quantizer_cpp->quantize(input_cpp, output_cpp, noop_flag_cpp);
}
return output_py;
}
......@@ -298,7 +314,7 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_fp
// Construct FP8 block-wise tensors
py::handle Float8BlockwiseQTensorClass(
reinterpret_cast<PyObject *>(Float8BlockwiseQTensorBasePythonClass));
reinterpret_cast<PyObject *>(Float8BlockwiseQTensorStoragePythonClass));
for (size_t i = 0; i < num_tensors; ++i) {
// Create tensor objects with proper reference counting
py::object rowwise_data = rowwise_usage ? py::cast(rowwise_data_list[i]) : py::none();
......@@ -445,7 +461,7 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
}
// Construct mxfp8 tensors
py::handle MXFP8TensorClass(reinterpret_cast<PyObject *>(MXFP8TensorBasePythonClass));
py::handle MXFP8TensorClass(reinterpret_cast<PyObject *>(MXFP8TensorStoragePythonClass));
for (size_t i = 0; i < num_tensors; ++i) {
// Create tensor objects with proper reference counting
py::object rowwise_data = rowwise_usage ? py::cast(rowwise_data_list[i]) : py::none();
......
......@@ -106,6 +106,10 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
const bool low_precision =
detail::is_low_precision(A_tensor.dtype()) || detail::is_low_precision(B_tensor.dtype());
const bool fp8_block_scaling = A_tensor.scaling_mode() == NVTE_BLOCK_SCALING_1D ||
A_tensor.scaling_mode() == NVTE_BLOCK_SCALING_2D ||
B_tensor.scaling_mode() == NVTE_BLOCK_SCALING_1D ||
B_tensor.scaling_mode() == NVTE_BLOCK_SCALING_2D;
// Check tensor dimensions
const auto& A_shape = A_tensor.shape();
......@@ -215,6 +219,19 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
const int sm_count = transformer_engine::cuda::sm_count(device_id);
int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);
// Construct GEMM config
transformer_engine::MatmulConfigWrapper config;
if (grad) {
config.set_dbias_tensor(bias_tensor.data());
config.set_with_dgelu_epilogue(gelu);
} else {
config.set_bias_tensor(bias_tensor.data());
config.set_with_gelu_epilogue(gelu);
}
config.set_epilogue_aux_tensor(te_pre_gelu_out.data());
config.set_use_split_accumulator(use_split_accumulator);
config.set_sm_count(num_math_sms);
// Keep the swizzled scaling factor tensors alive during the GEMM.
std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;
auto main_stream = at::cuda::getCurrentCUDAStream();
......@@ -224,6 +241,19 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
swizzled_scale_inverses_list.emplace_back(
std::move(swizzle_scaling_factors(B_tensor, !transb)));
// Emulate the FP8 block scaling recipe with MXFP8 on Blackwell and newer
// as it is not natively supported by cublasLt
if (fp8_block_scaling && transformer_engine::cuda::sm_arch() >= 100) {
// Convert tensors to mxfp8 and swizzle their scaling factors
swizzled_scale_inverses_list.emplace_back(
std::move(convert_block_scaling_to_mxfp8_tensor(A_tensor, transa)));
swizzled_scale_inverses_list.emplace_back(
std::move(convert_block_scaling_to_mxfp8_tensor(B_tensor, !transb)));
// Use TN GEMM to avoid having to transpose data.
transa = true;
transb = false;
}
if (comm_overlap) {
// Prepare extra output tensor
TensorWrapper extra_output_tensor;
......@@ -278,10 +308,9 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
} else {
// Launch GEMM
NVTE_SCOPED_GIL_RELEASE({
nvte_cublas_gemm_scaled(A_tensor.data(), B_tensor.data(), out_tensor.data(),
bias_tensor.data(), te_pre_gelu_out.data(), transa, transb, grad,
te_workspace.data(), alpha, *beta, use_split_accumulator,
num_math_sms, main_stream);
nvte_cublas_gemm_v2(transa, transb, &alpha, A_tensor.data(), B_tensor.data(), &beta.value(),
out_tensor.data(), out_tensor.data(), te_workspace.data(), config,
main_stream);
});
}
} else {
......@@ -369,15 +398,6 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
std::vector<at::Tensor> bias, DType bias_type, bool single_output,
std::vector<at::Tensor> pre_gelu_out, bool grad, std::vector<at::Tensor> workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count) {
std::vector<NVTETensor> te_A_vector, te_B_vector, te_D_vector, te_bias_vector,
te_pre_gelu_out_vector, te_workspace_vector;
std::vector<TensorWrapper> te_A_wrappers, te_B_wrappers, wrappers;
std::vector<at::Tensor> D_vectors;
auto none = py::none();
std::vector<size_t> single_output_begins;
std::vector<size_t> single_output_ends;
if (single_output && D == std::nullopt) {
NVTE_ERROR("not implemented, D should be allocated for single output case.");
}
......@@ -387,6 +407,10 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
output_data_ptr = (*D)[0].data_ptr();
}
const auto none = py::none();
std::vector<TensorWrapper> te_A_wrappers, te_B_wrappers, te_D_wrappers, te_bias_wrappers,
te_pre_gelu_out_wrappers;
std::vector<at::Tensor> D_vectors;
for (size_t i = 0; i < A.size(); i++) {
auto te_A = makeTransformerEngineTensor(A[i], none);
auto te_B = makeTransformerEngineTensor(B[i], none);
......@@ -452,29 +476,72 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
te_pre_gelu_out =
makeTransformerEngineTensor(get_data_ptr(pre_gelu_out[i]), gelu_shape, gelu_type);
te_A_vector.emplace_back(te_A.data());
te_B_vector.emplace_back(te_B.data());
te_D_vector.emplace_back(te_D.data());
te_bias_vector.emplace_back(te_bias.data());
te_pre_gelu_out_vector.emplace_back(te_pre_gelu_out.data());
te_A_wrappers.emplace_back(std::move(te_A));
te_B_wrappers.emplace_back(std::move(te_B));
wrappers.emplace_back(std::move(te_D));
wrappers.emplace_back(std::move(te_bias));
wrappers.emplace_back(std::move(te_pre_gelu_out));
te_D_wrappers.emplace_back(std::move(te_D));
te_bias_wrappers.emplace_back(std::move(te_bias));
te_pre_gelu_out_wrappers.emplace_back(std::move(te_pre_gelu_out));
}
// Keep the swizzled scaling factor tensors alive during the GEMM.
std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;
// Optionally swizzle the scaling factors
// Keep the swizzled scaling factor tensors alive during the GEMMs.
auto swizzled_scale_inv_A = multi_tensor_swizzle_scaling_factors(te_A_wrappers, transa);
auto swizzled_scale_inv_B = multi_tensor_swizzle_scaling_factors(te_B_wrappers, !transb);
swizzled_scale_inverses_list.emplace_back(
multi_tensor_swizzle_scaling_factors(te_A_wrappers, transa));
swizzled_scale_inverses_list.emplace_back(
multi_tensor_swizzle_scaling_factors(te_B_wrappers, !transb));
// Emulate the FP8 block scaling recipe with MXFP8 on Blackwell and newer
// as it is not natively supported by cublasLt
if (transformer_engine::cuda::sm_arch() >= 100) {
// Check whether FP8 block scaling is in use
bool exists_tensor_using_fp8_block_scaling = false;
bool exists_tensor_not_using_fp8_block_scaling = false;
for (const auto& tensor_wrappers : {&te_A_wrappers, &te_B_wrappers}) {
for (const TensorWrapper& tensor : *tensor_wrappers) {
const NVTEScalingMode scaling_mode = tensor.scaling_mode();
if (scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D)
exists_tensor_using_fp8_block_scaling = true;
else
exists_tensor_not_using_fp8_block_scaling = true;
}
}
if (exists_tensor_using_fp8_block_scaling) {
NVTE_CHECK(!exists_tensor_not_using_fp8_block_scaling,
"Either all tensors or no tensor must be FP8 block scaling tensors");
// Convert tensors to mxfp8 and swizzle their scaling factors
for (TensorWrapper& A_tensor : te_A_wrappers) {
swizzled_scale_inverses_list.emplace_back(
convert_block_scaling_to_mxfp8_tensor(A_tensor, transa));
}
for (TensorWrapper& B_tensor : te_B_wrappers) {
swizzled_scale_inverses_list.emplace_back(
convert_block_scaling_to_mxfp8_tensor(B_tensor, !transb));
}
// Use TN GEMM to avoid having to transpose data.
transa = true;
transb = false;
}
}
std::vector<NVTETensor> te_A_vector, te_B_vector, te_D_vector, te_bias_vector,
te_pre_gelu_out_vector;
for (size_t i = 0; i < te_A_wrappers.size(); i++) {
te_A_vector.emplace_back(te_A_wrappers[i].data());
te_B_vector.emplace_back(te_B_wrappers[i].data());
te_D_vector.emplace_back(te_D_wrappers[i].data());
te_bias_vector.emplace_back(te_bias_wrappers[i].data());
te_pre_gelu_out_vector.emplace_back(te_pre_gelu_out_wrappers[i].data());
}
std::vector<NVTETensor> te_workspace_vector;
std::vector<TensorWrapper> te_workspace_wrappers;
for (size_t i = 0; i < workspace.size(); i++) {
auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(),
std::vector<size_t>{workspaceSize}, DType::kByte);
te_workspace_vector.emplace_back(wsp.data());
wrappers.emplace_back(std::move(wsp));
te_workspace_wrappers.emplace_back(std::move(wsp));
}
// For now, we only have multi-stream cublas backend.
......