Unverified commit 5fafeb0e, authored by Xin Yao, committed by GitHub

[PyTorch] FP8 MHA with RoPE and Miscellaneous Improvements (#1100)



* fp8 mha with rope
Signed-off-by: Xin Yao <xiny@nvidia.com>

* avoid index select in cast ops
Signed-off-by: Xin Yao <xiny@nvidia.com>

* avoid index select in fused_attn_fwd
Signed-off-by: Xin Yao <xiny@nvidia.com>

* rename is_first_module_in_mha to fp8_output
Signed-off-by: Xin Yao <xiny@nvidia.com>

* resolve comments
Signed-off-by: Xin Yao <xiny@nvidia.com>

* resolve comments
Signed-off-by: Xin Yao <xiny@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* move transpose to backward for fp8 input
Signed-off-by: Xin Yao <xiny@nvidia.com>

* fix ut
Signed-off-by: Xin Yao <xiny@nvidia.com>

* resolve comments
Signed-off-by: Xin Yao <xiny@nvidia.com>

* update argument list for CP
Signed-off-by: Xin Yao <xiny@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix for FA3
Signed-off-by: Xin Yao <xiny@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove unnecessary copy of scale_inv
Signed-off-by: Xin Yao <xiny@nvidia.com>

* skip fp8 dpa/mha tests when fa3 is not available
Signed-off-by: Xin Yao <xiny@nvidia.com>

* fix a merge bug
Signed-off-by: Xin Yao <xiny@nvidia.com>

---------
Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 247850e8
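For orientation before the diff: a minimal usage sketch of the feature this commit enables. This is not part of the commit; it assumes the public Transformer Engine PyTorch API that the updated tests below exercise, an FP8-capable GPU, and the fp8_dpa/fp8_mha recipe flags.

import torch
from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.pytorch import MultiheadAttention, fp8_autocast, fp8_model_init
from transformer_engine.pytorch.attention import RotaryPositionEmbedding

seqlen, batch, hidden_size, num_heads = 512, 2, 1024, 16

# fp8_dpa runs dot-product attention in FP8; fp8_mha additionally keeps the
# QKV/O activations between modules in FP8 (the "FP8 MHA" path touched here).
recipe = DelayedScaling(fp8_dpa=True, fp8_mha=True)

with fp8_model_init(enabled=True):
    mha = MultiheadAttention(
        hidden_size, num_heads, attention_type="self", params_dtype=torch.bfloat16
    ).cuda()

# Rotary table is built once and handed to the forward pass, as in the test below.
rope = RotaryPositionEmbedding(hidden_size // num_heads)
rotary_pos_emb = rope(seqlen).to(device="cuda")

x = torch.randn(seqlen, batch, hidden_size, device="cuda", dtype=torch.bfloat16)
with fp8_autocast(enabled=True, fp8_recipe=recipe):
    out = mha(x, rotary_pos_emb=rotary_pos_emb)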
@@ -1344,19 +1344,22 @@ def _error(a, b, name_a, name_b, atol, rtol, rmse_tol):
 @pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16)
 @pytest.mark.parametrize("input_layernorm", [True, False])
 @pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
+@pytest.mark.parametrize("RoPE", [True, False])
 @pytest.mark.parametrize("is_training", [True, False])
-def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, is_training):
+def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training):
     os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
     os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
     config = model_configs_fp8_vs_f16[model]
 
     if _flash_attn_3_plus and not is_training:
+        if RoPE:
+            pytest.skip("Flash Attention doesn't support FP8 MHA with RoPE.")
         os.environ["NVTE_FLASH_ATTN"] = "1"
         os.environ["NVTE_FUSED_ATTN"] = "0"
         _attention_backends["backend_selection_requires_update"] = True
         logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
         flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
-            dtype, config, True, qkv_format, input_layernorm, is_training
+            dtype, config, True, qkv_format, input_layernorm, RoPE, is_training
         )
 
     os.environ["NVTE_FLASH_ATTN"] = "0"
@@ -1364,12 +1367,12 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
     _attention_backends["backend_selection_requires_update"] = True
     logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
     fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
-        dtype, config, True, qkv_format, input_layernorm, is_training
+        dtype, config, True, qkv_format, input_layernorm, RoPE, is_training
     )
 
     logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
     fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
-        dtype, config, False, qkv_format, input_layernorm, is_training
+        dtype, config, False, qkv_format, input_layernorm, RoPE, is_training
     )
 
     atol = 5e-1
@@ -1410,7 +1413,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
         )
 
 
-def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, is_training):
+def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE, is_training):
     reset_rng_states()
     _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
     _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
@@ -1429,6 +1432,10 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, is_
     )
 
     with fp8_model_init(enabled=fp8_mha):
+        rotary_pos_emb = None
+        if RoPE:
+            PE = RotaryPositionEmbedding(dim=config.head_dim_qk)
+            rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda")
         mha = MultiheadAttention(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_heads,
@@ -1489,6 +1496,7 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, is_
             checkpoint_core_attention=False,
             core_attention_bias_type=config.attn_bias_type,
             is_first_microbatch=None,
+            rotary_pos_emb=rotary_pos_emb,
         )
     if is_training:
         out.backward(out_grad)
@@ -1977,12 +1985,18 @@ class _custom_mha_fp8(torch.autograd.Function):
             None,
             None,
             None,
-            fp8_meta["scaling_fwd"].scale_inv[META_QKV],
-            fp8_meta["scaling_fwd"].scale_inv[META_S],
-            fp8_meta["scaling_fwd"].scale[META_S],
-            fp8_meta["scaling_fwd"].scale[META_O],
-            fp8_meta["scaling_fwd"].amax_history[0][META_S],
-            fp8_meta["scaling_fwd"].amax_history[0][META_O],
+            fp8_meta["scaling_fwd"].scale_inv,  # d_scale_qkv
+            META_QKV,  # d_scale_qkv_offset
+            fp8_meta["scaling_fwd"].scale_inv,  # d_scale_s
+            META_S,  # d_scale_s_offset
+            fp8_meta["scaling_fwd"].scale,  # q_scale_s
+            META_S,  # q_scale_s_offset
+            fp8_meta["scaling_fwd"].scale,  # q_scale_o
+            META_O,  # q_scale_o_offset
+            fp8_meta["scaling_fwd"].amax_history,  # amax_s
+            META_S,  # amax_s_offset
+            fp8_meta["scaling_fwd"].amax_history,  # amax_o
+            META_O,  # amax_o_offset
             attn_scale=None,
             dropout=p_dropout,
             fast_zero_fill=fast_zero_fill,
......
@@ -38,8 +38,20 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
     AttnBiasType,
     AttnMaskType,
     FusedAttnBackend,
+    META_QKV,
+    META_DQKV,
+    META_O,
+    META_DO,
+    META_S,
+    META_DP,
+    META_O_CP,
+    META_DQKV_CP,
+)
+from transformer_engine.pytorch.fp8 import (
+    FP8GlobalStateManager,
+    get_fp8_te_dtype,
+    get_fp8_torch_dtype,
 )
-from transformer_engine.pytorch.fp8 import get_fp8_te_dtype, get_fp8_torch_dtype
 from transformer_engine.pytorch.float8_tensor import Float8Tensor
 from transformer_engine.pytorch.module import LayerNormLinear, Linear
 from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
@@ -120,15 +132,6 @@ if _flash_attn_version >= _flash_attn_version_required:
     from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward
     from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd
 
-META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT
-META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1
-META_O = tex.FP8FwdTensors.GEMM2_INPUT
-META_DO = tex.FP8BwdTensors.GRAD_INPUT2
-META_S = tex.FP8FwdTensors.GEMM3_OUTPUT
-META_DP = tex.FP8BwdTensors.GRAD_INPUT3
-# repurpose some unused amax history buffers for partial results of CP fwd and bwd
-META_O_CP = tex.FP8FwdTensors.GEMM2_OUTPUT
-META_DQKV_CP = tex.FP8BwdTensors.GRAD_INPUT1
 
 # NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
 _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
@@ -1546,10 +1549,14 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                     for x in [k_f16, v_f16]
                 ]
                 fp8_meta_kwargs = {}
-                fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv[META_QKV]
-                fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv[META_S]
-                fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale[META_S]
-                fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale[META_O_CP]
+                fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv
+                fp8_meta_kwargs["d_scale_qkv_offset"] = META_QKV
+                fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv
+                fp8_meta_kwargs["d_scale_s_offset"] = META_S
+                fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale
+                fp8_meta_kwargs["q_scale_s_offset"] = META_S
+                fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale
+                fp8_meta_kwargs["q_scale_o_offset"] = META_O_CP
                 amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device)
             else:
                 assert False, "FP8 is only supported with Fused Attention!"
@@ -1601,8 +1608,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                             fp8_dtype_forward,
                         )
                     if fp8 and use_fused_attention:
-                        fp8_meta_kwargs["amax_s"] = amax_per_step[0][i]
-                        fp8_meta_kwargs["amax_o"] = amax_per_step[1][i]
+                        fp8_meta_kwargs["amax_s"] = amax_per_step
+                        fp8_meta_kwargs["amax_s_offset"] = i
+                        fp8_meta_kwargs["amax_o"] = amax_per_step
+                        fp8_meta_kwargs["amax_o_offset"] = cp_size + i
                     if causal:
                         if i == 0:
                             if pad_between_seqs_q:
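Not part of the diff: a tiny offset-arithmetic sketch of how the single [2, cp_size] amax buffer above replaces per-step scalar views. Offsets i and cp_size + i address row 0 (amax of S) and row 1 (amax of O) of the flattened buffer, so no Python-side slicing is needed per step; the values written below are stand-ins for what the kernel records.

import torch

cp_size = 4
# One buffer for all steps: row 0 holds amax(S) per step, row 1 holds amax(O) per step.
amax_per_step = torch.zeros((2, cp_size))

for i in range(cp_size):
    amax_s_offset = i            # element [0, i] of the flattened buffer
    amax_o_offset = cp_size + i  # element [1, i] of the flattened buffer
    flat = amax_per_step.view(-1)
    flat[amax_s_offset] = i + 1.0         # stand-in for the kernel recording amax(S)
    flat[amax_o_offset] = 10.0 * (i + 1)  # stand-in for the kernel recording amax(O)

assert torch.equal(amax_per_step[0], torch.tensor([1.0, 2.0, 3.0, 4.0]))
assert torch.equal(amax_per_step[1], torch.tensor([10.0, 20.0, 30.0, 40.0]))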
@@ -4153,9 +4162,8 @@ def get_qkv_layout(
 
     stride = q.stride()
     check_strides_qkv = all(stride == x.stride() for x in [q, k, v])
-    stride = k.stride()
-    check_strides_kv = torch.equal(
-        torch.Tensor(stride[:-1]) / k.shape[-1], torch.Tensor(v.stride()[:-1]) / v.shape[-1]
+    check_strides_kv = tuple(sk / k.shape[-1] for sk in k.stride()[:-1]) == tuple(
+        sv / v.shape[-1] for sv in v.stride()[:-1]
     )
 
     shape = q.shape
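Not part of the diff: a standalone sketch of the new K/V stride check, which compares plain Python tuples instead of building temporary CPU tensors as the old torch.equal(torch.Tensor(...)) form did. Shapes below are illustrative.

import torch

def strides_match_kv(k: torch.Tensor, v: torch.Tensor) -> bool:
    # Normalize each stride (except the last dim's) by the last-dim size, then
    # compare plain Python tuples; no temporary tensors are allocated.
    k_norm = tuple(s / k.shape[-1] for s in k.stride()[:-1])
    v_norm = tuple(s / v.shape[-1] for s in v.stride()[:-1])
    return k_norm == v_norm

k = torch.empty(128, 2, 16, 64)
v = torch.empty(128, 2, 16, 64)
print(strides_match_kv(k, v))                  # True: identical layouts
print(strides_match_kv(k, v.transpose(0, 1)))  # False: layouts differ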
@@ -4635,19 +4643,20 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
         fp8_meta,
         deterministic,
     ):
+        is_input_fp8 = False
+        is_output_fp8 = fp8_meta["recipe"].fp8_mha
         if fp8:
-            if fp8_meta["recipe"].fp8_mha:
-                assert isinstance(qkv, Float8Tensor), "qkv must be Float8Tensors for FP8 MHA."
+            is_input_fp8 = isinstance(qkv, Float8Tensor)
+            if is_input_fp8:
                 fp8_meta["scaling_fwd"].scale_inv[META_QKV] = qkv._scale_inv
             fused_attention_backend = FusedAttnBackend["FP8"]
             fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
             # 1: qkv packed, 2: kv packed, 3: qkv separate
             qkv_group = len(qkv_layout.split("_"))
-            assert qkv_group == 1, (
-                "qkv layout should conform to 3hd or h3d, e.g. sb3hd, but found"
-                f" {qkv_layout}."
-            )
-            if fp8_meta["recipe"].fp8_mha:
+            assert (
+                qkv_group == 1
+            ), f"qkv layout should conform to 3hd or h3d, e.g. sb3hd, but found {qkv_layout}."
+            if is_input_fp8:
                 qkv_fp8 = qkv._data
             else:
                 qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
@@ -4663,12 +4672,18 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
                 fused_attention_backend,
                 attn_bias,
                 cu_seqlens_padded,
-                fp8_meta["scaling_fwd"].scale_inv[META_QKV],
-                fp8_meta["scaling_fwd"].scale_inv[META_S],
-                fp8_meta["scaling_fwd"].scale[META_S],
-                fp8_meta["scaling_fwd"].scale[META_O],
-                fp8_meta["scaling_fwd"].amax_history[0][META_S],
-                fp8_meta["scaling_fwd"].amax_history[0][META_O],
+                fp8_meta["scaling_fwd"].scale_inv,  # d_scale_qkv
+                META_QKV,  # d_scale_qkv_offset
+                fp8_meta["scaling_fwd"].scale_inv,  # d_scale_s
+                META_S,  # d_scale_s_offset
+                fp8_meta["scaling_fwd"].scale,  # q_scale_s
+                META_S,  # q_scale_s_offset
+                fp8_meta["scaling_fwd"].scale,  # q_scale_o
+                META_O,  # q_scale_o_offset
+                fp8_meta["scaling_fwd"].amax_history,  # amax_s
+                META_S,  # amax_s_offset
+                fp8_meta["scaling_fwd"].amax_history,  # amax_o
+                META_O,  # amax_o_offset
                 attn_scale,
                 dropout_p,
                 fast_zero_fill,
@@ -4678,7 +4693,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
                 window_size,
                 rng_gen,
             )
-            if fp8_meta["recipe"].fp8_mha:
+            if is_output_fp8:
                 out_ret = Float8Tensor(
                     data=out_fp8,
                     fp8_meta=fp8_meta,
@@ -4696,7 +4711,8 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
                     qkv_dtype,
                 ).view(out_fp8.shape)
             out_save = out_ret
-            if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
+            if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
+                if is_input_fp8:
                     qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1])
                     qkv = cast_from_fp8(
                         qkv_c._data,
@@ -4705,6 +4721,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
                         fp8_dtype_forward,
                         TE_DType[qkv.dtype],
                     ).view(qkv.shape)
+                if is_output_fp8:
                     out_save = cast_from_fp8(
                         out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
                         fp8_meta["scaling_fwd"],
@@ -4728,12 +4745,18 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
                 fused_attention_backend,
                 attn_bias,
                 cu_seqlens_padded,
-                None,
-                None,
-                None,
-                None,
-                None,
-                None,
+                None,  # d_scale_qkv
+                0,  # d_scale_qkv_offset
+                None,  # d_scale_s
+                0,  # d_scale_s_offset
+                None,  # q_scale_s
+                0,  # q_scale_s_offset
+                None,  # q_scale_o
+                0,  # q_scale_o_offset
+                None,  # amax_s
+                0,  # amax_s_offset
+                None,  # amax_o
+                0,  # amax_o_offset
                 attn_scale,
                 dropout_p,
                 fast_zero_fill,
@@ -4747,6 +4770,8 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
             out_save = out_ret
 
         ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
+        ctx.is_input_fp8 = is_input_fp8
+        ctx.is_output_fp8 = is_output_fp8
         qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None)
         ctx.save_for_backward(
             *qkvo_tensors, cu_seqlens, cu_seqlens_padded, *fp8_tensors, *aux_ctx_tensors
@@ -4771,7 +4796,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
 
     @staticmethod
    def backward(ctx, d_out):
-        if ctx.fp8_meta["recipe"].fp8_mha:
+        if ctx.is_output_fp8:
             assert isinstance(
                 d_out, Float8Tensor
             ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
@@ -4828,7 +4853,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
                 fp8_dtype_backward = get_fp8_te_dtype(
                     ctx.fp8_meta["recipe"], fprop_tensor=False
                 )
-                if ctx.fp8_meta["recipe"].fp8_mha:
+                if ctx.is_output_fp8:
                     d_out_fp8 = d_out
                     ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv
                 else:
@@ -4868,7 +4893,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
                     ctx.window_size,
                     ctx.deterministic,
                 )
-                if ctx.fp8_meta["recipe"].fp8_mha:
+                if ctx.is_input_fp8:
                     dqkv = Float8Tensor(
                         data=dqkv_fp8,
                         fp8_meta=ctx.fp8_meta,
...@@ -5006,22 +5031,23 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): ...@@ -5006,22 +5031,23 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
fp8_meta, fp8_meta,
deterministic, deterministic,
): ):
is_input_fp8 = False
is_output_fp8 = fp8_meta["recipe"].fp8_mha
if fp8: if fp8:
if fp8_meta["recipe"].fp8_mha: assert isinstance(kv, q.__class__), "q and kv must have the same type."
assert isinstance(q, Float8Tensor) and isinstance( is_input_fp8 = isinstance(q, Float8Tensor)
kv, Float8Tensor if is_input_fp8:
), "q/kv must be Float8Tensors for FP8 MHA."
fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
fused_attention_backend = FusedAttnBackend["FP8"] fused_attention_backend = FusedAttnBackend["FP8"]
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if fp8_meta["recipe"].fp8_mha: if is_input_fp8:
q_fp8, kv_fp8 = q._data, kv._data q_fp8, kv_fp8 = q._data, kv._data
else: else:
# 1: qkv packed, 2: kv packed, 3: qkv separate # 1: qkv packed, 2: kv packed, 3: qkv separate
qkv_group = len(qkv_layout.split("_")) qkv_group = len(qkv_layout.split("_"))
assert qkv_group == 2, ( assert qkv_group == 2, (
"qkv layout should conform to hd_2hd or hd_h2d, e.g. sbhd_sb2hd, " "qkv layout should conform to hd_2hd or hd_h2d, e.g. sbhd_sb2hd, "
f" but found {qkv_layout}." f"but found {qkv_layout}."
) )
q_fp8 = cast_to_fp8(q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward).view( q_fp8 = cast_to_fp8(q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward).view(
q.shape q.shape
...@@ -5043,12 +5069,18 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): ...@@ -5043,12 +5069,18 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
attn_bias, attn_bias,
cu_seqlens_q_padded, cu_seqlens_q_padded,
cu_seqlens_kv_padded, cu_seqlens_kv_padded,
fp8_meta["scaling_fwd"].scale_inv[META_QKV], fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv
fp8_meta["scaling_fwd"].scale_inv[META_S], META_QKV, # d_scale_qkv_offset
fp8_meta["scaling_fwd"].scale[META_S], fp8_meta["scaling_fwd"].scale_inv, # d_scale_s
fp8_meta["scaling_fwd"].scale[META_O], META_S, # d_scale_s_offset
fp8_meta["scaling_fwd"].amax_history[0][META_S], fp8_meta["scaling_fwd"].scale, # q_scale_s
fp8_meta["scaling_fwd"].amax_history[0][META_O], META_S, # q_scale_s_offset
fp8_meta["scaling_fwd"].scale, # q_scale_o
META_O, # q_scale_o_offset
fp8_meta["scaling_fwd"].amax_history, # amax_s
META_S, # amax_s_offset
fp8_meta["scaling_fwd"].amax_history, # amax_o
META_O, # amax_o_offset
attn_scale, attn_scale,
dropout_p, dropout_p,
fast_zero_fill, fast_zero_fill,
...@@ -5058,7 +5090,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): ...@@ -5058,7 +5090,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
window_size, window_size,
rng_gen, rng_gen,
) )
if fp8_meta["recipe"].fp8_mha: if is_output_fp8:
out_ret = Float8Tensor( out_ret = Float8Tensor(
data=out_fp8, data=out_fp8,
fp8_meta=fp8_meta, fp8_meta=fp8_meta,
...@@ -5076,9 +5108,14 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): ...@@ -5076,9 +5108,14 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
qkv_dtype, qkv_dtype,
).view(out_fp8.shape) ).view(out_fp8.shape)
out_save = out_ret out_save = out_ret
if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
if is_input_fp8:
q = cast_from_fp8( q = cast_from_fp8(
q._data, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward, TE_DType[q.dtype] q._data,
fp8_meta["scaling_fwd"],
META_QKV,
fp8_dtype_forward,
TE_DType[q.dtype],
).view(q.shape) ).view(q.shape)
kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
kv = cast_from_fp8( kv = cast_from_fp8(
...@@ -5088,6 +5125,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): ...@@ -5088,6 +5125,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
fp8_dtype_forward, fp8_dtype_forward,
TE_DType[kv.dtype], TE_DType[kv.dtype],
).view(kv.shape) ).view(kv.shape)
if is_output_fp8:
out_save = cast_from_fp8( out_save = cast_from_fp8(
out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
fp8_meta["scaling_fwd"], fp8_meta["scaling_fwd"],
...@@ -5116,12 +5154,18 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): ...@@ -5116,12 +5154,18 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
attn_bias, attn_bias,
cu_seqlens_q_padded, cu_seqlens_q_padded,
cu_seqlens_kv_padded, cu_seqlens_kv_padded,
None, None, # d_scale_qkv
None, 0, # d_scale_qkv_offset
None, None, # d_scale_s
None, 0, # d_scale_s_offset
None, None, # q_scale_s
None, 0, # q_scale_s_offset
None, # q_scale_o
0, # q_scale_o_offset
None, # amax_s
0, # amax_s_offset
None, # amax_o
0, # amax_o_offset
attn_scale, attn_scale,
dropout_p, dropout_p,
fast_zero_fill, fast_zero_fill,
...@@ -5135,6 +5179,8 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): ...@@ -5135,6 +5179,8 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
fp8_tensors = (None, None, None, None, None) fp8_tensors = (None, None, None, None, None)
ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
ctx.is_input_fp8 = is_input_fp8
ctx.is_output_fp8 = is_output_fp8
qkvo_tensors = (q, kv, out_save) if not ctx.fp8 else (None, None, None) qkvo_tensors = (q, kv, out_save) if not ctx.fp8 else (None, None, None)
ctx.save_for_backward( ctx.save_for_backward(
*qkvo_tensors, *qkvo_tensors,
...@@ -5166,7 +5212,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): ...@@ -5166,7 +5212,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, d_out): def backward(ctx, d_out):
if ctx.fp8_meta["recipe"].fp8_mha: if ctx.is_output_fp8:
assert isinstance( assert isinstance(
d_out, Float8Tensor d_out, Float8Tensor
), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
...@@ -5227,7 +5273,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): ...@@ -5227,7 +5273,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
fp8_dtype_backward = get_fp8_te_dtype( fp8_dtype_backward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False ctx.fp8_meta["recipe"], fprop_tensor=False
) )
if ctx.fp8_meta["recipe"].fp8_mha: if ctx.is_output_fp8:
d_out_fp8 = d_out d_out_fp8 = d_out
ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv
else: else:
...@@ -5271,7 +5317,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): ...@@ -5271,7 +5317,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
ctx.window_size, ctx.window_size,
ctx.deterministic, ctx.deterministic,
) )
if ctx.fp8_meta["recipe"].fp8_mha: if ctx.is_input_fp8:
dq = Float8Tensor( dq = Float8Tensor(
data=dq_fp8, data=dq_fp8,
fp8_meta=ctx.fp8_meta, fp8_meta=ctx.fp8_meta,
...@@ -5437,15 +5483,16 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -5437,15 +5483,16 @@ class FusedAttnFunc(torch.autograd.Function):
fp8_meta, fp8_meta,
deterministic, deterministic,
): ):
is_input_fp8 = False
is_output_fp8 = fp8_meta["recipe"].fp8_mha
if fp8: if fp8:
fused_attention_backend = FusedAttnBackend["FP8"] fused_attention_backend = FusedAttnBackend["FP8"]
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if fp8_meta["recipe"].fp8_mha: assert isinstance(k, q.__class__) and isinstance(
assert ( v, q.__class__
isinstance(q, Float8Tensor) ), "q, k, and v must have the same type."
and isinstance(k, Float8Tensor) is_input_fp8 = isinstance(q, Float8Tensor)
and isinstance(v, Float8Tensor) if is_input_fp8:
), "q/k/v must be Float8Tensors for FP8 MHA."
fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
q_fp8, k_fp8, v_fp8 = q._data, k._data, v._data q_fp8, k_fp8, v_fp8 = q._data, k._data, v._data
else: else:
...@@ -5496,12 +5543,18 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -5496,12 +5543,18 @@ class FusedAttnFunc(torch.autograd.Function):
attn_bias, attn_bias,
cu_seqlens_q_padded, cu_seqlens_q_padded,
cu_seqlens_kv_padded, cu_seqlens_kv_padded,
fp8_meta["scaling_fwd"].scale_inv[META_QKV], fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv
fp8_meta["scaling_fwd"].scale_inv[META_S], META_QKV, # d_scale_qkv_offset
fp8_meta["scaling_fwd"].scale[META_S], fp8_meta["scaling_fwd"].scale_inv, # d_scale_s
fp8_meta["scaling_fwd"].scale[META_O], META_S, # d_scale_s_offset
fp8_meta["scaling_fwd"].amax_history[0][META_S], fp8_meta["scaling_fwd"].scale, # q_scale_s
fp8_meta["scaling_fwd"].amax_history[0][META_O], META_S, # q_scale_s_offset
fp8_meta["scaling_fwd"].scale, # q_scale_o
META_O, # q_scale_o_offset
fp8_meta["scaling_fwd"].amax_history, # amax_s
META_S, # amax_s_offset
fp8_meta["scaling_fwd"].amax_history, # amax_o
META_O, # amax_o_offset
attn_scale, attn_scale,
dropout_p, dropout_p,
fast_zero_fill, fast_zero_fill,
...@@ -5511,7 +5564,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -5511,7 +5564,7 @@ class FusedAttnFunc(torch.autograd.Function):
window_size, window_size,
rng_gen, rng_gen,
) )
if fp8_meta["recipe"].fp8_mha: if is_output_fp8:
out_ret = Float8Tensor( out_ret = Float8Tensor(
data=out_fp8, data=out_fp8,
fp8_meta=fp8_meta, fp8_meta=fp8_meta,
...@@ -5530,8 +5583,9 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -5530,8 +5583,9 @@ class FusedAttnFunc(torch.autograd.Function):
).view(out_fp8.shape) ).view(out_fp8.shape)
out_save = out_ret out_save = out_ret
if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
# 1: qkv packed, 2: kv packed, 3: qkv separate # 1: qkv packed, 2: kv packed, 3: qkv separate
if is_input_fp8:
qkv_group = len(qkv_layout.split("_")) qkv_group = len(qkv_layout.split("_"))
if qkv_group == 1: if qkv_group == 1:
dim = qkv_layout.find("3") dim = qkv_layout.find("3")
...@@ -5588,6 +5642,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -5588,6 +5642,7 @@ class FusedAttnFunc(torch.autograd.Function):
fp8_dtype_forward, fp8_dtype_forward,
TE_DType[v.dtype], TE_DType[v.dtype],
).view(v.shape) ).view(v.shape)
if is_output_fp8:
out_save = cast_from_fp8( out_save = cast_from_fp8(
out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]),
fp8_meta["scaling_fwd"], fp8_meta["scaling_fwd"],
...@@ -5619,12 +5674,18 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -5619,12 +5674,18 @@ class FusedAttnFunc(torch.autograd.Function):
attn_bias, attn_bias,
cu_seqlens_q_padded, cu_seqlens_q_padded,
cu_seqlens_kv_padded, cu_seqlens_kv_padded,
None, None, # d_scale_qkv
None, 0, # d_scale_qkv_offset
None, None, # d_scale_s
None, 0, # d_scale_s_offset
None, None, # q_scale_s
None, 0, # q_scale_s_offset
None, # q_scale_o
0, # q_scale_o_offset
None, # amax_s
0, # amax_s_offset
None, # amax_o
0, # amax_o_offset
attn_scale, attn_scale,
dropout_p, dropout_p,
fast_zero_fill, fast_zero_fill,
...@@ -5647,6 +5708,8 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -5647,6 +5708,8 @@ class FusedAttnFunc(torch.autograd.Function):
tensor.activation_offloading = True tensor.activation_offloading = True
ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
ctx.is_input_fp8 = is_input_fp8
ctx.is_output_fp8 = is_output_fp8
qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None) qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None)
ctx.save_for_backward( ctx.save_for_backward(
*qkvo_tensors, *qkvo_tensors,
...@@ -5678,7 +5741,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -5678,7 +5741,7 @@ class FusedAttnFunc(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, d_out): def backward(ctx, d_out):
if ctx.fp8_meta["recipe"].fp8_mha: if ctx.is_output_fp8:
assert isinstance( assert isinstance(
d_out, Float8Tensor d_out, Float8Tensor
), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA."
...@@ -5743,7 +5806,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -5743,7 +5806,7 @@ class FusedAttnFunc(torch.autograd.Function):
fp8_dtype_backward = get_fp8_te_dtype( fp8_dtype_backward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False ctx.fp8_meta["recipe"], fprop_tensor=False
) )
if ctx.fp8_meta["recipe"].fp8_mha: if ctx.is_output_fp8:
d_out_fp8 = d_out d_out_fp8 = d_out
ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv
else: else:
...@@ -5789,7 +5852,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -5789,7 +5852,7 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.deterministic, ctx.deterministic,
) )
if ctx.fp8_meta["recipe"].fp8_mha: if ctx.is_input_fp8:
dq = Float8Tensor( dq = Float8Tensor(
data=dq_fp8, data=dq_fp8,
fp8_meta=ctx.fp8_meta, fp8_meta=ctx.fp8_meta,
@@ -7719,12 +7782,18 @@ class MultiheadAttention(torch.nn.Module):
         # Query, Key, and Value
         # ======================
 
+        fp8_mha = (
+            FP8GlobalStateManager.is_fp8_enabled()
+            and FP8GlobalStateManager.get_fp8_recipe().fp8_mha
+        )
+
         if self.attention_type == "self":
             # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
             if self.input_layernorm:
                 layernorm_qkv_outputs = self.layernorm_qkv(
                     hidden_states,
                     is_first_microbatch=is_first_microbatch,
+                    fp8_output=fp8_mha and rotary_pos_emb is None,
                 )
                 if self.return_layernorm_output:
                     mixed_x_layer, layernorm_output = layernorm_qkv_outputs
@@ -7734,7 +7803,7 @@ class MultiheadAttention(torch.nn.Module):
                 mixed_x_layer = self.qkv(
                     hidden_states,
                     is_first_microbatch=is_first_microbatch,
-                    is_first_module_in_mha=True,  # specific to FP8 MHA
+                    fp8_output=fp8_mha and rotary_pos_emb is None,
                 )
 
             num_queries_per_key_value = (
@@ -7795,7 +7864,7 @@ class MultiheadAttention(torch.nn.Module):
             mixed_kv_layer = self.key_value(
                 encoder_output,
                 is_first_microbatch=is_first_microbatch,
-                is_first_module_in_mha=True,  # specific to FP8 MHA
+                fp8_output=fp8_mha and rotary_pos_emb is None,
             )
 
             if self.qkv_weight_interleaved:
@@ -7845,6 +7914,7 @@ class MultiheadAttention(torch.nn.Module):
                 layernorm_query_outputs = self.layernorm_query(
                     hidden_states,
                     is_first_microbatch=is_first_microbatch,
+                    fp8_output=fp8_mha and rotary_pos_emb is None,
                 )
                 if self.return_layernorm_output:
                     query_layer, layernorm_output = layernorm_query_outputs
@@ -7854,7 +7924,7 @@ class MultiheadAttention(torch.nn.Module):
                 query_layer = self.query_layer(
                     hidden_states,
                     is_first_microbatch=is_first_microbatch,
-                    is_first_module_in_mha=True,  # specific to FP8 MHA
+                    fp8_output=fp8_mha and rotary_pos_emb is None,
                 )
 
             # [sq, b, hp] --> [sq, b, np, hn]
......
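Not part of the diff: the gating decision above, written out as a toy function. The QKV/LayerNormLinear projection is asked for FP8 output only when the FP8-MHA recipe is active and no rotary embedding follows, presumably so that RoPE can be applied to Q/K in higher precision before the fused-attention kernel quantizes them.

def wants_fp8_projection_output(fp8_enabled: bool, recipe_fp8_mha: bool, rotary_pos_emb) -> bool:
    # Mirrors fp8_output=fp8_mha and rotary_pos_emb is None from the hunks above.
    fp8_mha = fp8_enabled and recipe_fp8_mha
    return fp8_mha and rotary_pos_emb is None

print(wants_fp8_projection_output(True, True, None))      # True: projection may return a Float8Tensor
print(wants_fp8_projection_output(True, True, object()))  # False: high-precision output, RoPE applied first
print(wants_fp8_projection_output(False, True, None))     # False: no FP8 context active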
@@ -78,6 +78,16 @@ FusedAttnBackend = {
 BACKEND_F16m512_FP8_THREADS_PER_CTA = 128
 BACKEND_F16arb_ELTS_PER_THREADS = 16
 
+META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT
+META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1
+META_O = tex.FP8FwdTensors.GEMM2_INPUT
+META_DO = tex.FP8BwdTensors.GRAD_INPUT2
+META_S = tex.FP8FwdTensors.GEMM3_OUTPUT
+META_DP = tex.FP8BwdTensors.GRAD_INPUT3
+# repurpose some unused amax history buffers for partial results of CP fwd and bwd
+META_O_CP = tex.FP8FwdTensors.GEMM2_OUTPUT
+META_DQKV_CP = tex.FP8BwdTensors.GRAD_INPUT1
+
 
 def fused_attn_fwd_qkvpacked(
     is_training: bool,
@@ -89,11 +99,17 @@ def fused_attn_fwd_qkvpacked(
     attn_bias: torch.Tensor = None,
     cu_seqlens_padded: torch.Tensor = None,
     d_scale_qkv: torch.Tensor = None,
+    d_scale_qkv_offset: int = META_QKV,
     d_scale_s: torch.Tensor = None,
+    d_scale_s_offset: int = META_S,
     q_scale_s: torch.Tensor = None,
+    q_scale_s_offset: int = META_S,
     q_scale_o: torch.Tensor = None,
+    q_scale_o_offset: int = META_O,
     amax_s: torch.Tensor = None,
+    amax_s_offset: int = META_S,
     amax_o: torch.Tensor = None,
+    amax_o_offset: int = META_O,
     attn_scale: float = None,
     dropout: float = 0.0,
     fast_zero_fill: bool = True,
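Not part of the diff: a sketch of the new calling convention these parameters enable. Instead of slicing out scalar views such as scale_inv[META_QKV] (an index_select plus a small copy per call), callers pass the whole meta tensors together with integer offsets, and the C++ bindings resolve data_ptr() plus offset. The helper name fp8_meta_args is hypothetical, not part of TE.

from transformer_engine.pytorch.cpp_extensions.fused_attn import META_O, META_QKV, META_S

def fp8_meta_args(scaling_fwd):
    # Hypothetical helper: bundle the FP8 scale/amax arguments for fused_attn_fwd*.
    return dict(
        d_scale_qkv=scaling_fwd.scale_inv, d_scale_qkv_offset=META_QKV,
        d_scale_s=scaling_fwd.scale_inv, d_scale_s_offset=META_S,
        q_scale_s=scaling_fwd.scale, q_scale_s_offset=META_S,
        q_scale_o=scaling_fwd.scale, q_scale_o_offset=META_O,
        amax_s=scaling_fwd.amax_history, amax_s_offset=META_S,
        amax_o=scaling_fwd.amax_history, amax_o_offset=META_O,
    )

# e.g. fused_attn_fwd_qkvpacked(..., **fp8_meta_args(fp8_meta["scaling_fwd"]), ...)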
...@@ -128,16 +144,28 @@ def fused_attn_fwd_qkvpacked( ...@@ -128,16 +144,28 @@ def fused_attn_fwd_qkvpacked(
cumulative sequence offsets for QKV; shape [batch_size + 1] cumulative sequence offsets for QKV; shape [batch_size + 1]
d_scale_qkv: torch.Tensor, default = None d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations input tensor for the dequantization of QKV in FP8 computations
d_scale_qkv_offset: int, default = META_QKV
offset in d_scale_qkv for QKV
d_scale_s: torch.Tensor, default = None d_scale_s: torch.Tensor, default = None
input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T)
d_scale_s_offset: int, default = META_S
offset in d_scale_s for S
q_scale_s: torch.Tensor, default = None q_scale_s: torch.Tensor, default = None
input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T) input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T)
q_scale_s_offset: int, default = META_S
offset in q_scale_s for S
q_scale_o: torch.Tensor, default = None q_scale_o: torch.Tensor, default = None
input tensor for the quantization of O in FP8 computations input tensor for the quantization of O in FP8 computations
q_scale_o_offset: int, default = META_O
offset in q_scale_o for O
amax_s: torch.Tensor, default = None amax_s: torch.Tensor, default = None
output tensor, amax of S, used by the next iteration in FP8 computations output tensor, amax of S, used by the next iteration in FP8 computations
amax_s_offset: int, default = META_S
offset in amax_s for S
amax_o: torch.Tensor, default = None amax_o: torch.Tensor, default = None
output tensor, amax of O, used by the next iteration in FP8 computations output tensor, amax of O, used by the next iteration in FP8 computations
amax_o_offset: int, default = META_O
offset in amax_o for O
attn_scale: float, default = None attn_scale: float, default = None
if not None, use attn_scale as the attention scale for Q*K.T BMM; if not None, use attn_scale as the attention scale for Q*K.T BMM;
if None, use 1.0/sqrt(head_dim_qk) as the default if None, use 1.0/sqrt(head_dim_qk) as the default
...@@ -248,11 +276,17 @@ def fused_attn_fwd_qkvpacked( ...@@ -248,11 +276,17 @@ def fused_attn_fwd_qkvpacked(
qkv_dtype, qkv_dtype,
cu_seqlens_padded, cu_seqlens_padded,
d_scale_qkv, d_scale_qkv,
d_scale_qkv_offset,
d_scale_s, d_scale_s,
d_scale_s_offset,
q_scale_s, q_scale_s,
q_scale_s_offset,
q_scale_o, q_scale_o,
q_scale_o_offset,
amax_s, amax_s,
amax_s_offset,
amax_o, amax_o,
amax_o_offset,
attn_bias, attn_bias,
rng_gen, rng_gen,
rng_elts_per_thread, rng_elts_per_thread,
...@@ -448,11 +482,17 @@ def fused_attn_fwd_kvpacked( ...@@ -448,11 +482,17 @@ def fused_attn_fwd_kvpacked(
cu_seqlens_q_padded: torch.Tensor = None, cu_seqlens_q_padded: torch.Tensor = None,
cu_seqlens_kv_padded: torch.Tensor = None, cu_seqlens_kv_padded: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None, d_scale_qkv: torch.Tensor = None,
d_scale_qkv_offset: int = META_QKV,
d_scale_s: torch.Tensor = None, d_scale_s: torch.Tensor = None,
d_scale_s_offset: int = META_S,
q_scale_s: torch.Tensor = None, q_scale_s: torch.Tensor = None,
q_scale_s_offset: int = META_S,
q_scale_o: torch.Tensor = None, q_scale_o: torch.Tensor = None,
q_scale_o_offset: int = META_O,
amax_s: torch.Tensor = None, amax_s: torch.Tensor = None,
amax_s_offset: int = META_S,
amax_o: torch.Tensor = None, amax_o: torch.Tensor = None,
amax_o_offset: int = META_O,
attn_scale: float = None, attn_scale: float = None,
dropout: float = 0.0, dropout: float = 0.0,
fast_zero_fill: bool = True, fast_zero_fill: bool = True,
...@@ -496,16 +536,28 @@ def fused_attn_fwd_kvpacked( ...@@ -496,16 +536,28 @@ def fused_attn_fwd_kvpacked(
cumulative sequence offsets for KV; shape [batch_size + 1] cumulative sequence offsets for KV; shape [batch_size + 1]
d_scale_qkv: torch.Tensor, default = None d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations input tensor for the dequantization of QKV in FP8 computations
d_scale_qkv_offset: int, default = META_QKV
offset in d_scale_qkv for QKV
d_scale_s: torch.Tensor, default = None d_scale_s: torch.Tensor, default = None
input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T)
d_scale_s_offset: int, default = META_S
offset in d_scale_s for S
q_scale_s: torch.Tensor, default = None q_scale_s: torch.Tensor, default = None
input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T) input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T)
q_scale_s_offset: int, default = META_S
offset in q_scale_s for S
q_scale_o: torch.Tensor, default = None q_scale_o: torch.Tensor, default = None
input tensor for the quantization of O in FP8 computations input tensor for the quantization of O in FP8 computations
q_scale_o_offset: int, default = META_O
offset in q_scale_o for O
amax_s: torch.Tensor, default = None amax_s: torch.Tensor, default = None
output tensor, amax of S, used by the next iteration in FP8 computations output tensor, amax of S, used by the next iteration in FP8 computations
amax_s_offset: int, default = META_S
offset in amax_s for S
amax_o: torch.Tensor, default = None amax_o: torch.Tensor, default = None
output tensor, amax of O, used by the next iteration in FP8 computations output tensor, amax of O, used by the next iteration in FP8 computations
amax_o_offset: int, default = META_O
offset in amax_o for O
attn_scale: float, default = None attn_scale: float, default = None
if not None, use attn_scale as the attention scale for Q*K.T BMM; if not None, use attn_scale as the attention scale for Q*K.T BMM;
if None, use 1.0/sqrt(head_dim_qk) as the default if None, use 1.0/sqrt(head_dim_qk) as the default
...@@ -621,11 +673,17 @@ def fused_attn_fwd_kvpacked( ...@@ -621,11 +673,17 @@ def fused_attn_fwd_kvpacked(
cu_seqlens_q_padded, cu_seqlens_q_padded,
cu_seqlens_kv_padded, cu_seqlens_kv_padded,
d_scale_qkv, d_scale_qkv,
d_scale_qkv_offset,
d_scale_s, d_scale_s,
d_scale_s_offset,
q_scale_s, q_scale_s,
q_scale_s_offset,
q_scale_o, q_scale_o,
q_scale_o_offset,
amax_s, amax_s,
amax_s_offset,
amax_o, amax_o,
amax_o_offset,
attn_bias, attn_bias,
rng_gen, rng_gen,
rng_elts_per_thread, rng_elts_per_thread,
...@@ -843,11 +901,17 @@ def fused_attn_fwd( ...@@ -843,11 +901,17 @@ def fused_attn_fwd(
cu_seqlens_q_padded: torch.Tensor = None, cu_seqlens_q_padded: torch.Tensor = None,
cu_seqlens_kv_padded: torch.Tensor = None, cu_seqlens_kv_padded: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None, d_scale_qkv: torch.Tensor = None,
d_scale_qkv_offset: int = META_QKV,
d_scale_s: torch.Tensor = None, d_scale_s: torch.Tensor = None,
d_scale_s_offset: int = META_S,
q_scale_s: torch.Tensor = None, q_scale_s: torch.Tensor = None,
q_scale_s_offset: int = META_S,
q_scale_o: torch.Tensor = None, q_scale_o: torch.Tensor = None,
q_scale_o_offset: int = META_O,
amax_s: torch.Tensor = None, amax_s: torch.Tensor = None,
amax_s_offset: int = META_S,
amax_o: torch.Tensor = None, amax_o: torch.Tensor = None,
amax_o_offset: int = META_O,
attn_scale: float = None, attn_scale: float = None,
dropout: float = 0.0, dropout: float = 0.0,
fast_zero_fill: bool = True, fast_zero_fill: bool = True,
...@@ -894,17 +958,29 @@ def fused_attn_fwd( ...@@ -894,17 +958,29 @@ def fused_attn_fwd(
cu_seqlens_kv_padded: torch.Tensor, default = None cu_seqlens_kv_padded: torch.Tensor, default = None
cumulative sequence offsets for KV; shape [batch_size + 1] cumulative sequence offsets for KV; shape [batch_size + 1]
d_scale_qkv: torch.Tensor, default = None d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of Q, K and V in FP8 computations input tensor for the dequantization of QKV in FP8 computations
d_scale_qkv_offset: int, default = META_QKV
offset in d_scale_qkv for QKV
d_scale_s: torch.Tensor, default = None d_scale_s: torch.Tensor, default = None
input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T)
d_scale_s_offset: int, default = META_S
offset in d_scale_s for S
q_scale_s: torch.Tensor, default = None q_scale_s: torch.Tensor, default = None
input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T) input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T)
q_scale_s_offset: int, default = META_S
offset in q_scale_s for S
q_scale_o: torch.Tensor, default = None q_scale_o: torch.Tensor, default = None
input tensor for the quantization of O in FP8 computations input tensor for the quantization of O in FP8 computations
q_scale_o_offset: int, default = META_O
offset in q_scale_o for O
amax_s: torch.Tensor, default = None amax_s: torch.Tensor, default = None
output tensor, amax of S, used by the next iteration in FP8 computations output tensor, amax of S, used by the next iteration in FP8 computations
amax_s_offset: int, default = META_S
offset in amax_s for S
amax_o: torch.Tensor, default = None amax_o: torch.Tensor, default = None
output tensor, amax of O, used by the next iteration in FP8 computations output tensor, amax of O, used by the next iteration in FP8 computations
amax_o_offset: int, default = META_O
offset in amax_o for O
attn_scale: float, default = None attn_scale: float, default = None
if not None, use attn_scale as the attention scale for Q*K.T BMM; if not None, use attn_scale as the attention scale for Q*K.T BMM;
if None, use 1.0/sqrt(head_dim_qk) as the default if None, use 1.0/sqrt(head_dim_qk) as the default
...@@ -1023,11 +1099,17 @@ def fused_attn_fwd( ...@@ -1023,11 +1099,17 @@ def fused_attn_fwd(
cu_seqlens_q_padded, cu_seqlens_q_padded,
cu_seqlens_kv_padded, cu_seqlens_kv_padded,
d_scale_qkv, d_scale_qkv,
d_scale_qkv_offset,
d_scale_s, d_scale_s,
d_scale_s_offset,
q_scale_s, q_scale_s,
q_scale_s_offset,
q_scale_o, q_scale_o,
q_scale_o_offset,
amax_s, amax_s,
amax_s_offset,
amax_o, amax_o,
amax_o_offset,
attn_bias, attn_bias,
rng_gen, rng_gen,
rng_elts_per_thread, rng_elts_per_thread,
......
@@ -48,11 +48,13 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
     const std::vector<int64_t> window_size, const at::Tensor cu_seqlens, const at::Tensor QKV,
     const transformer_engine::DType qkv_type, const c10::optional<at::Tensor> cu_seqlens_padded,
-    const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
-    const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_O,
-    c10::optional<at::Tensor> amax_S, c10::optional<at::Tensor> amax_O,
-    const c10::optional<at::Tensor> Bias, const c10::optional<at::Generator> rng_gen,
-    size_t rng_elts_per_thread);
+    const c10::optional<at::Tensor> descale_QKV, const int descale_QKV_offset,
+    const c10::optional<at::Tensor> descale_S, const int descale_S_offset,
+    const c10::optional<at::Tensor> scale_S, const int scale_S_offset,
+    const c10::optional<at::Tensor> scale_O, const int scale_O_offset,
+    c10::optional<at::Tensor> amax_S, const int amax_S_offset, c10::optional<at::Tensor> amax_O,
+    const int amax_O_offset, const c10::optional<at::Tensor> Bias,
+    const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);
 
 std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
     size_t max_seqlen, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout,
...@@ -75,11 +77,13 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked( ...@@ -75,11 +77,13 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
const at::Tensor KV, const transformer_engine::DType qkv_type, const at::Tensor KV, const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> cu_seqlens_q_padded, const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> cu_seqlens_kv_padded, const c10::optional<at::Tensor> cu_seqlens_kv_padded,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_QKV, const int descale_QKV_offset,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_O, const c10::optional<at::Tensor> descale_S, const int descale_S_offset,
c10::optional<at::Tensor> amax_S, c10::optional<at::Tensor> amax_O, const c10::optional<at::Tensor> scale_S, const int scale_S_offset,
const c10::optional<at::Tensor> Bias, const c10::optional<at::Generator> rng_gen, const c10::optional<at::Tensor> scale_O, const int scale_O_offset,
size_t rng_elts_per_thread); c10::optional<at::Tensor> amax_S, const int amax_S_offset, c10::optional<at::Tensor> amax_O,
const int amax_O_offset, const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);
std::vector<at::Tensor> fused_attn_bwd_kvpacked( std::vector<at::Tensor> fused_attn_bwd_kvpacked(
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
...@@ -104,11 +108,13 @@ std::vector<at::Tensor> fused_attn_fwd( ...@@ -104,11 +108,13 @@ std::vector<at::Tensor> fused_attn_fwd(
const at::Tensor K, const at::Tensor V, const transformer_engine::DType qkv_type, const at::Tensor K, const at::Tensor V, const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> cu_seqlens_q_padded, const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> cu_seqlens_kv_padded, const c10::optional<at::Tensor> cu_seqlens_kv_padded,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_QKV, const int descale_QKV_offset,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_O, const c10::optional<at::Tensor> descale_S, const int descale_S_offset,
c10::optional<at::Tensor> amax_S, c10::optional<at::Tensor> amax_O, const c10::optional<at::Tensor> scale_S, const int scale_S_offset,
const c10::optional<at::Tensor> Bias, const c10::optional<at::Generator> rng_gen, const c10::optional<at::Tensor> scale_O, const int scale_O_offset,
size_t rng_elts_per_thread); c10::optional<at::Tensor> amax_S, const int amax_S_offset, c10::optional<at::Tensor> amax_O,
const int amax_O_offset, const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);
std::vector<at::Tensor> fused_attn_bwd( std::vector<at::Tensor> fused_attn_bwd(
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
...@@ -335,13 +341,18 @@ at::Tensor rmsnorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, fl ...@@ -335,13 +341,18 @@ at::Tensor rmsnorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, fl
**************************************************************************************************/ **************************************************************************************************/
at::Tensor cast_to_fp8(const at::Tensor &input, const at::Tensor &scale, at::Tensor amax, at::Tensor cast_to_fp8(const at::Tensor &input, const at::Tensor &scale, at::Tensor amax,
at::Tensor scale_inv, transformer_engine::DType otype); at::Tensor scale_inv, transformer_engine::DType otype,
const int scale_offset = 0, const int amax_offset = 0,
const int scale_inv_offset = 0);
void cast_to_fp8_noalloc(const at::Tensor &input, const at::Tensor &scale, at::Tensor output, void cast_to_fp8_noalloc(const at::Tensor &input, const at::Tensor &scale, at::Tensor output,
at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype); at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype,
const int scale_offset = 0, const int amax_offset = 0,
const int scale_inv_offset = 0);
at::Tensor cast_from_fp8(const at::Tensor &input, const at::Tensor &scale_inv, at::Tensor cast_from_fp8(const at::Tensor &input, const at::Tensor &scale_inv,
transformer_engine::DType itype, transformer_engine::DType otype); transformer_engine::DType itype, transformer_engine::DType otype,
const int scale_inv_offset = 0);
/*************************************************************************************************** /***************************************************************************************************
* Softmax * Softmax
......
...@@ -83,11 +83,13 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked( ...@@ -83,11 +83,13 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const std::vector<int64_t> window_size, const at::Tensor cu_seqlens, const at::Tensor QKV, const std::vector<int64_t> window_size, const at::Tensor cu_seqlens, const at::Tensor QKV,
const transformer_engine::DType qkv_type, const c10::optional<at::Tensor> cu_seqlens_padded, const transformer_engine::DType qkv_type, const c10::optional<at::Tensor> cu_seqlens_padded,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_QKV, const int descale_QKV_offset,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_O, const c10::optional<at::Tensor> descale_S, const int descale_S_offset,
c10::optional<at::Tensor> amax_S, c10::optional<at::Tensor> amax_O, const c10::optional<at::Tensor> scale_S, const int scale_S_offset,
const c10::optional<at::Tensor> Bias, const c10::optional<at::Generator> rng_gen, const c10::optional<at::Tensor> scale_O, const int scale_O_offset,
size_t rng_elts_per_thread) { c10::optional<at::Tensor> amax_S, const int amax_S_offset, c10::optional<at::Tensor> amax_O,
const int amax_O_offset, const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread) {
using namespace transformer_engine; using namespace transformer_engine;
auto qkv_sizes = QKV.sizes().vec(); auto qkv_sizes = QKV.sizes().vec();
@@ -122,11 +124,14 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
       NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n"));
     }
     te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape, qkv_type, nullptr, nullptr,
-                                         descale_QKV.value().data_ptr());
-    te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_S.value().data_ptr(),
-                                       scale_S.value().data_ptr(), descale_S.value().data_ptr());
-    te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, amax_O.value().data_ptr(),
-                                       scale_O.value().data_ptr(), nullptr);
+                                         getDataPtr(descale_QKV.value(), descale_QKV_offset));
+    te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32,
+                                       getDataPtr(amax_S.value(), amax_S_offset),
+                                       getDataPtr(scale_S.value(), scale_S_offset),
+                                       getDataPtr(descale_S.value(), descale_S_offset));
+    te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type,
+                                       getDataPtr(amax_O.value(), amax_O_offset),
+                                       getDataPtr(scale_O.value(), scale_O_offset), nullptr);
   } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
     if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) {
       O.fill_(0);
@@ -393,11 +398,13 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
     const at::Tensor KV, const transformer_engine::DType qkv_type,
     const c10::optional<at::Tensor> cu_seqlens_q_padded,
     const c10::optional<at::Tensor> cu_seqlens_kv_padded,
-    const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
-    const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_O,
-    c10::optional<at::Tensor> amax_S, c10::optional<at::Tensor> amax_O,
-    const c10::optional<at::Tensor> Bias, const c10::optional<at::Generator> rng_gen,
-    size_t rng_elts_per_thread) {
+    const c10::optional<at::Tensor> descale_QKV, const int descale_QKV_offset,
+    const c10::optional<at::Tensor> descale_S, const int descale_S_offset,
+    const c10::optional<at::Tensor> scale_S, const int scale_S_offset,
+    const c10::optional<at::Tensor> scale_O, const int scale_O_offset,
+    c10::optional<at::Tensor> amax_S, const int amax_S_offset, c10::optional<at::Tensor> amax_O,
+    const int amax_O_offset, const c10::optional<at::Tensor> Bias,
+    const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread) {
   using namespace transformer_engine;
   auto q_sizes = Q.sizes().vec();
@@ -429,13 +436,16 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
       NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n"));
     }
     te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr,
-                                       descale_QKV.value().data_ptr());
+                                       getDataPtr(descale_QKV.value(), descale_QKV_offset));
     te_KV = makeTransformerEngineTensor(KV.data_ptr(), kv_shape, qkv_type, nullptr, nullptr,
-                                        descale_QKV.value().data_ptr());
-    te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_S.value().data_ptr(),
-                                       scale_S.value().data_ptr(), descale_S.value().data_ptr());
-    te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, amax_O.value().data_ptr(),
-                                       scale_O.value().data_ptr(), nullptr);
+                                        getDataPtr(descale_QKV.value(), descale_QKV_offset));
+    te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32,
+                                       getDataPtr(amax_S.value(), amax_S_offset),
+                                       getDataPtr(scale_S.value(), scale_S_offset),
+                                       getDataPtr(descale_S.value(), descale_S_offset));
+    te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type,
+                                       getDataPtr(amax_O.value(), amax_O_offset),
+                                       getDataPtr(scale_O.value(), scale_O_offset), nullptr);
   } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
     if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) {
       O.fill_(0);
@@ -747,11 +757,13 @@ std::vector<at::Tensor> fused_attn_fwd(
     const at::Tensor K, const at::Tensor V, const transformer_engine::DType qkv_type,
     const c10::optional<at::Tensor> cu_seqlens_q_padded,
     const c10::optional<at::Tensor> cu_seqlens_kv_padded,
-    const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
-    const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_O,
-    c10::optional<at::Tensor> amax_S, c10::optional<at::Tensor> amax_O,
-    const c10::optional<at::Tensor> Bias, const c10::optional<at::Generator> rng_gen,
-    size_t rng_elts_per_thread) {
+    const c10::optional<at::Tensor> descale_QKV, const int descale_QKV_offset,
+    const c10::optional<at::Tensor> descale_S, const int descale_S_offset,
+    const c10::optional<at::Tensor> scale_S, const int scale_S_offset,
+    const c10::optional<at::Tensor> scale_O, const int scale_O_offset,
+    c10::optional<at::Tensor> amax_S, const int amax_S_offset, c10::optional<at::Tensor> amax_O,
+    const int amax_O_offset, const c10::optional<at::Tensor> Bias,
+    const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread) {
   using namespace transformer_engine;
   auto q_sizes = Q.sizes().vec();
@@ -788,15 +800,18 @@ std::vector<at::Tensor> fused_attn_fwd(
       NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n"));
     }
     te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr,
-                                       descale_QKV.value().data_ptr());
+                                       getDataPtr(descale_QKV.value(), descale_QKV_offset));
     te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr,
-                                       descale_QKV.value().data_ptr());
+                                       getDataPtr(descale_QKV.value(), descale_QKV_offset));
     te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr,
-                                       descale_QKV.value().data_ptr());
-    te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_S.value().data_ptr(),
-                                       scale_S.value().data_ptr(), descale_S.value().data_ptr());
-    te_O = makeTransformerEngineTensor(O.data_ptr(), o_shape, qkv_type, amax_O.value().data_ptr(),
-                                       scale_O.value().data_ptr(), nullptr);
+                                       getDataPtr(descale_QKV.value(), descale_QKV_offset));
+    te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32,
+                                       getDataPtr(amax_S.value(), amax_S_offset),
+                                       getDataPtr(scale_S.value(), scale_S_offset),
+                                       getDataPtr(descale_S.value(), descale_S_offset));
+    te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type,
+                                       getDataPtr(amax_O.value(), amax_O_offset),
+                                       getDataPtr(scale_O.value(), scale_O_offset), nullptr);
   } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
     if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) {
       O.fill_(0);
......
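Note on the pattern in the fused-attention forwards above: every optional FP8 scaling tensor (descale_QKV, scale_S, amax_O, and so on) is now paired with an integer element offset, and the getDataPtr helper presumably resolves to the tensor's base pointer advanced by that many elements. The caller can therefore hand over the whole delayed-scaling buffers once instead of slicing out a one-element view per tensor (the index_select the commit messages refer to). A minimal Python sketch of the pointer arithmetic being assumed here (plain PyTorch, sizes and names illustrative):

import torch

# One scale entry per FP8 tensor tracked by delayed scaling (size is illustrative).
scale = torch.ones(8, dtype=torch.float32)
offset = 3  # element index of some FP8 tensor, e.g. the attention output O

# Address obtained by slicing out a view ...
addr_from_view = scale[offset:].data_ptr()
# ... equals "base pointer + offset * element size", which is what passing an
# integer offset lets the C++ side compute directly, without creating the view.
addr_from_offset = scale.data_ptr() + offset * scale.element_size()
assert addr_from_view == addr_from_offset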
@@ -6,8 +6,9 @@
 #include "extensions.h"
-at::Tensor cast_to_fp8(const at::Tensor &input, const at::Tensor &scale, at::Tensor amax,
-                       at::Tensor scale_inv, transformer_engine::DType otype) {
+at::Tensor cast_to_fp8(const at::Tensor& input, const at::Tensor& scale, at::Tensor amax,
+                       at::Tensor scale_inv, transformer_engine::DType otype,
+                       const int scale_offset, const int amax_offset, const int scale_inv_offset) {
   using namespace transformer_engine;
   auto input_shape = input.sizes().vec();
   std::vector<size_t> shape{input_shape.begin(), input_shape.end()};
@@ -16,32 +17,45 @@ at::Tensor cast_to_fp8(const at::Tensor &input, const at::Tensor &scale, at::Ten
   if (input.numel() == 0) return output;
+  // Get pointers for FP8 scale, amax, scale-inverse
+  void* scale_dptr = getDataPtr(scale, scale_offset);
+  void* amax_dptr = getDataPtr(amax, amax_offset);
+  void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset);
   auto input_cu = makeTransformerEngineTensor(input);
-  auto output_cu = makeTransformerEngineTensor(output.data_ptr(), shape, otype, amax.data_ptr(),
-                                               scale.data_ptr(), scale_inv.data_ptr());
+  auto output_cu = makeTransformerEngineTensor(output.data_ptr(), shape, otype, amax_dptr,
+                                               scale_dptr, scale_inv_dptr);
   nvte_fp8_quantize(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
   return output;
 }
-void cast_to_fp8_noalloc(const at::Tensor &input, const at::Tensor &scale, at::Tensor output,
-                         at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype) {
+void cast_to_fp8_noalloc(const at::Tensor& input, const at::Tensor& scale, at::Tensor output,
+                         at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype,
+                         const int scale_offset, const int amax_offset,
+                         const int scale_inv_offset) {
   using namespace transformer_engine;
   size_t N = static_cast<size_t>(input.size(0));
   size_t H = static_cast<size_t>(input.size(1));
+  // Get pointers for FP8 scale, amax, scale-inverse
+  void* scale_dptr = getDataPtr(scale, scale_offset);
+  void* amax_dptr = getDataPtr(amax, amax_offset);
+  void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset);
   auto input_cu = makeTransformerEngineTensor(input);
-  auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, H}, otype, amax.data_ptr(),
-                                               scale.data_ptr(), scale_inv.data_ptr());
+  auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, H}, otype, amax_dptr,
+                                               scale_dptr, scale_inv_dptr);
   nvte_fp8_quantize(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
   return;
 }
-at::Tensor cast_from_fp8(const at::Tensor &input, const at::Tensor &scale_inv,
-                         transformer_engine::DType itype, transformer_engine::DType otype) {
+at::Tensor cast_from_fp8(const at::Tensor& input, const at::Tensor& scale_inv,
+                         transformer_engine::DType itype, transformer_engine::DType otype,
+                         const int scale_inv_offset) {
   using namespace transformer_engine;
   auto input_shape = input.sizes().vec();
   std::vector<size_t> shape{input_shape.begin(), input_shape.end()};
@@ -49,7 +63,7 @@ at::Tensor cast_from_fp8(const at::Tensor &input, const at::Tensor &scale_inv,
   auto output = at::empty_like(input, at::CUDA(GetATenDType(otype)));
   auto input_cu = makeTransformerEngineTensor(input.data_ptr(), shape, itype, nullptr, nullptr,
-                                              scale_inv.data_ptr());
+                                              getDataPtr(scale_inv, scale_inv_offset));
   auto output_cu = makeTransformerEngineTensor(output);
   nvte_fp8_dequantize(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
......
@@ -93,10 +93,16 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("fused_multi_cast_transpose_alloc", &fused_multi_cast_transpose_alloc,
         "Fused Multi-tensor Cast + Transpose with allocating output tensors",
         py::call_guard<py::gil_scoped_release>());
-  m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8", py::call_guard<py::gil_scoped_release>());
+  m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8", py::call_guard<py::gil_scoped_release>(),
+        py::arg("input"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), py::arg("otype"),
+        py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
   m.def("cast_to_fp8_noalloc", &cast_to_fp8_noalloc, "Cast to FP8",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8", py::call_guard<py::gil_scoped_release>());
+        py::call_guard<py::gil_scoped_release>(), py::arg("input"), py::arg("scale"),
+        py::arg("output"), py::arg("amax"), py::arg("scale_inv"), py::arg("otype"),
+        py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
+  m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8", py::call_guard<py::gil_scoped_release>(),
+        py::arg("input"), py::arg("scale_inv"), py::arg("itype"), py::arg("otype"),
+        py::arg("scale_inv_offset") = 0);
   m.def("te_gemm", &te_gemm, "CublasLt GEMM");  /// TODO Think
   m.def("te_grouped_gemm", &te_grouped_gemm, "Grouped GEMM");
   m.def("fused_attn_fwd_qkvpacked", &fused_attn_fwd_qkvpacked,
......
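The bindings above now expose named arguments with zero defaults, so existing positional calls behave exactly as before, while callers that keep their FP8 metadata in the shared delayed-scaling buffers can pass a meta index directly. A hedged usage sketch: only the bound function and argument names come from the diff; the import name, enum spellings, and buffer shapes are assumptions about the extension module.

import torch
import transformer_engine_torch as tex  # extension module; the import name may differ across TE versions

fp8_dtype = tex.DType.kFloat8E4M3          # assumed forward FP8 format enum
idx = int(tex.FP8FwdTensors.GEMM1_OUTPUT)  # element offset into the shared fwd meta buffers

x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
# Delayed-scaling style buffers: one entry/column per FP8 tensor (sizes illustrative).
scale = torch.ones(16, device="cuda")
amax_history = torch.zeros(1024, 16, device="cuda")
scale_inv = torch.ones(16, device="cuda")

# New style: hand over the whole buffers plus offsets; omitting the offsets
# (defaults of 0) reproduces the old call form.
x_fp8 = tex.cast_to_fp8(x, scale, amax_history, scale_inv, fp8_dtype,
                        scale_offset=idx, amax_offset=idx, scale_inv_offset=idx)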
@@ -26,7 +26,7 @@ at::Tensor cast_to_fp8_ts(const at::Tensor &input, const at::Tensor &scale, at::
                           at::Tensor scale_inv, int64_t fp8_tensor, int64_t otype) {
   transformer_engine::DType otype_arg = reverse_map_dtype(otype);
   at::Tensor output =
-      cast_to_fp8(input, scale[fp8_tensor], amax[0][fp8_tensor], scale_inv[fp8_tensor], otype_arg);
+      cast_to_fp8(input, scale, amax, scale_inv, otype_arg, fp8_tensor, fp8_tensor, fp8_tensor);
   return output;
 }
@@ -34,8 +34,8 @@ at::Tensor cast_to_fp8_noalloc_ts(const at::Tensor &input, const at::Tensor &sca
                                   at::Tensor output, at::Tensor amax, at::Tensor scale_inv,
                                   int64_t fp8_tensor, int64_t otype) {
   transformer_engine::DType otype_arg = reverse_map_dtype(otype);
-  cast_to_fp8_noalloc(input, scale[fp8_tensor], output, amax[0][fp8_tensor], scale_inv[fp8_tensor],
-                      otype_arg);
+  cast_to_fp8_noalloc(input, scale, output, amax, scale_inv, otype_arg, fp8_tensor, fp8_tensor,
+                      fp8_tensor);
   return output;
 }
@@ -43,7 +43,7 @@ at::Tensor cast_from_fp8_ts(const at::Tensor &input, const at::Tensor &scale_inv
                             int64_t fp8_tensor, int64_t itype, int64_t otype) {
   transformer_engine::DType itype_arg = reverse_map_dtype(itype);
   transformer_engine::DType otype_arg = reverse_map_dtype(otype);
-  at::Tensor output = cast_from_fp8(input, scale_inv[fp8_tensor], itype_arg, otype_arg);
+  at::Tensor output = cast_from_fp8(input, scale_inv, itype_arg, otype_arg, fp8_tensor);
   return output;
 }
......
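The TorchScript wrappers above now forward fp8_tensor as an element offset instead of materializing scale[fp8_tensor] and amax[0][fp8_tensor] slices. For the 1-D scale and scale_inv buffers the two are trivially the same element; for the 2-D, row-major amax history, an offset smaller than the number of FP8 tensors also stays inside row 0, so it addresses the same value the old indexing did. A small self-contained check (plain PyTorch, sizes illustrative):

import torch

history_len, num_fp8_tensors = 1024, 8
amax_history = torch.arange(history_len * num_fp8_tensors, dtype=torch.float32)
amax_history = amax_history.reshape(history_len, num_fp8_tensors)  # row-major, contiguous

fp8_tensor = 5  # meta index that used to be read as amax[0][fp8_tensor]
# A flat element offset of fp8_tensor lands on row 0, column fp8_tensor.
assert amax_history.reshape(-1)[fp8_tensor] == amax_history[0][fp8_tensor]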
@@ -91,6 +91,7 @@ class _LayerNormLinear(torch.autograd.Function):
         ub_overlap_rs_dgrad: bool,
         ub_overlap_ag: bool,
         ub_name: str,
+        fp8_output: bool,
         fsdp_group: Union[dist_group_type, None],
     ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
         # Make sure input dimensions are compatible
@@ -220,7 +221,7 @@ class _LayerNormLinear(torch.autograd.Function):
                 if is_in_onnx_export_mode():
                     ln_out_scale_inv.fill_(ln_out_scale_inv.item())
-            if fp8_meta["recipe"].fp8_mha:
+            if fp8_output:
                 out_index, meta_tensor, output_te_dtype, output_dtype = (
                     tex.FP8FwdTensors.GEMM1_OUTPUT,
                     fp8_meta["scaling_fwd"],
@@ -765,6 +766,7 @@ class _LayerNormLinear(torch.autograd.Function):
            None,  # ub_overlap_rs_dgrad
            None,  # ub_overlap_ag
            None,  # ub_name
+           None,  # fp8_output
            None,  # fsdp_group
        )
@@ -1117,6 +1119,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
         self,
         inp: torch.Tensor,
         is_first_microbatch: Optional[bool] = None,
+        fp8_output: Optional[bool] = False,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
         """
         Apply layer normalization to the input followed by a linear transformation.
@@ -1244,6 +1247,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
                 self.ub_overlap_rs_dgrad,
                 self.ub_overlap_ag,
                 self.ub_name,
+                fp8_output,
                 self.fsdp_group,
             )
             out = fwd_fn(*args)
......
@@ -82,12 +82,10 @@ class _Linear(torch.autograd.Function):
         ub_overlap_rs: bool,
         ub_overlap_ag: bool,
         ub_name: str,
-        is_first_module_in_mha: bool,
+        fp8_output: bool,
         fsdp_group: Union[dist_group_type, None],
     ) -> torch.Tensor:
         is_input_fp8 = isinstance(inp, Float8Tensor)
-        if is_input_fp8:
-            fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_INPUT] = inp._scale_inv[0]
         # Make sure input dimensions are compatible
         in_features = weight.shape[-1]
@@ -110,14 +108,6 @@ class _Linear(torch.autograd.Function):
             fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
             if isinstance(inputmat, Float8Tensor):
                 inputmat_scale_inv = inputmat._scale_inv
-                if (
-                    not fp8_meta["recipe"].override_linear_precision.wgrad
-                    and is_grad_enabled
-                    and weight.requires_grad
-                    and not sequence_parallel
-                ):
-                    # FP8 input for forward, FP8 input transpose for backward wgrad
-                    inputmat_t = inputmat.transpose_2d()
             else:
                 inputmat_scale_inv = torch.empty([1], dtype=torch.float32, device=inputmat.device)
             if (
@@ -171,7 +161,7 @@ class _Linear(torch.autograd.Function):
                 assert isinstance(weight_fp8, Float8Tensor)
-            if is_first_module_in_mha:
+            if fp8_output:
                 proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = (
                     tex.FP8FwdTensors.GEMM1_OUTPUT,
                     fp8_meta["scaling_fwd"],
@@ -240,7 +230,7 @@ class _Linear(torch.autograd.Function):
                 fp8_meta_tensor=meta_tensor,
                 D_dtype=proj_out_tetype,
             )
-            if is_first_module_in_mha:
+            if fp8_output:
                 out = Float8Tensor(
                     data=out,
                     fp8_meta=fp8_meta,
@@ -639,7 +629,7 @@ class _Linear(torch.autograd.Function):
            None,  # ub_overlap_rs
            None,  # ub_overlap_ag
            None,  # ub_name
-           None,  # is_first_module_in_mha
+           None,  # fp8_output
            None,  # fsdp_group
        )
@@ -917,7 +907,7 @@ class Linear(TransformerEngineBaseModule):
         self,
         inp: torch.Tensor,
         is_first_microbatch: Optional[bool] = None,
-        is_first_module_in_mha: Optional[bool] = False,
+        fp8_output: Optional[bool] = False,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
         """
         Apply the linear transformation to the input.
@@ -951,8 +941,6 @@ class Linear(TransformerEngineBaseModule):
             allow_non_contiguous=isinstance(inp, Float8Tensor),
         ) as inp:
-            is_first_module_in_mha = is_first_module_in_mha and self.fp8_meta["recipe"].fp8_mha
             # Get concatenated weight and bias tensors
             unfused_weights = [getattr(self, name) for name in self.weight_names]
             if any(isinstance(w, Float8Tensor) for w in unfused_weights):
@@ -1037,7 +1025,7 @@ class Linear(TransformerEngineBaseModule):
                 self.ub_overlap_rs,
                 self.ub_overlap_ag,
                 self.ub_name,
-                is_first_module_in_mha,
+                fp8_output,
                 self.fsdp_group,
             )
             out = linear_fn(*args)
......
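At the module level, the is_first_module_in_mha flag is gone: Linear and LayerNormLinear now take an explicit fp8_output argument in forward, and the caller (for example the attention block) decides per call whether it wants an FP8 result. A minimal usage sketch under FP8 autocast; only the fp8_output keyword is taken from this diff, the rest is ordinary TE usage, and the output is a Float8Tensor only when FP8 execution is actually active:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Illustrative QKV-style projection; shapes chosen to satisfy FP8 GEMM alignment.
qkv_proj = te.Linear(1024, 3 * 1024, bias=True).cuda()
x = torch.randn(32, 1024, device="cuda")

with te.fp8_autocast(enabled=True, fp8_recipe=recipe.DelayedScaling()):
    y = qkv_proj(x)                       # default fp8_output=False: high-precision output, as before
    y_fp8 = qkv_proj(x, fp8_output=True)  # request FP8 output for a downstream FP8 attention kernel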