Commit 26c8fcc9 (unverified), authored by Xiaowei Ren, committed by GitHub

Add FP8 support to CP implementation with KV P2P (#1114)



* add window_size to AttnFuncWithCP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add seq_offsets_qkvo for cudnn thd
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add seq_offsets_qkvo to AttnFuncWithCP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix seq_offsets calculation of cudnn thd
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove a thd assert
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix bias for thd test
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add thd test for cudnn FA with CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* skip GQA/MQA test for cuDNN THD
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* make sure seq_offsets are computed with qkv_group of hd_hd_hd while CP>1
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix seq_offsets inputs
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove two comments
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix attn mask type for cudnn thd with cp
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix attn_mask_type check
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix attn_mask_type for cudnn fa with thd
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix a typo
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix out dout in bwd
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* assert cudnn+thd does not support attn bias
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* check if attn_mask_type has padding
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor change
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* change cp test batch size to 2
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix code format
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix two assert info
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix assert comment
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix assert comments
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix assert comments
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* assert swa+CP cannot work with thd format
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add a new CP function for swa
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add a missing dgrads
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor change
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add draft fwd function for swa+cp
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor change
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* enable flash attention for swa+cp
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove an assert of swa+cp
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* call SWAFuncWithCP for swa+cp
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* typo fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* use 2hd layout
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* change qkv_format check
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add a code comment
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* tensor shape bug fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* tensor shape fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add function to compute cu_seqlens of a cp rank
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add cu_seqlens and cu_seqlens_padded to context parallelism
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* typo fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor change
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix FlashAttention output sequence length
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix cu_seqlens_kv_per_step calculation
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* zero dQKV for ending padded tokens
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* zero dQKV tensors of FlashAttention
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix softmax_lse correction
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove padded tokens of KV to save communication
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* do not need to zero dkv for FlashAttention any more
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* zero out tensors
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove redundant code
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix CP unit test
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix kv shape of cp test with thd format
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* update cp unit test
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add simple code framework
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* try not to have a separate CP function for SWA
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* backup some code change
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* back up code
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* clean up fwd implementation of SWAFuncWithCP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove redundant code
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* code cleaning
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix assert info
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* reduce kv chunk concat overheads
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor change
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* make AttnFuncWithCP and SWAFuncWithCP have same API
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add a docstring
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* preliminary implementation of SWAFuncWithCP forward seems working
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix output shape of SWAFuncWithCP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* code refactoring for FlashAttention and add a code placeholder for bwd
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* use gather_along_first_dim
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* finish the preliminary implementation of bwd
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove redundant code
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix assert condition
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add draft implementation of SWA+CP with FusedAttention
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix attention mask type of swa+cp
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* code cleaning
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add qkv_layout
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add missing window_size argument
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* typo fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix kv shape of swa+cp
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* bug and typo fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix dout shape
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add multi stream in fwd of swa+cp
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* save chunk_ids_to_kv_ag in fwd
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add multi stream in bwd of swa+cp
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor fix to cp stream sync
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* rename AttnFuncWithCP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* check if window size is None
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix docstring of AttnFuncWithCP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add env var for users to choose KV ag or KV p2p
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* update cp tests
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix window size in cp unit test
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix pytest skip messages
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add cp_comm_type into API
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* code cleaning
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add deterministic knob in cuDNN fused attn backend
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* pass fp8 and fp8_meta to attn_func_with_cp
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* assert only Fused Attn can support FP8+CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove redundant assert
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add a fwd draft implementation of FP8 + CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* save fp8 and fp8_meta
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* assert sequence length divisible requirements
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* remove a redundant qkv_layout compute
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* if condition change
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* some typo fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add support table of context parallelism
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* typo and code format fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* do not print multiple disabling messages
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* bug fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix aux_ctx_tensors of FP8
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* bug fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix device in torch.arange and adjust code for the PR of MLA
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* commit code change for FP8+CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* commit more code change for FP8+CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* commit more fp8 code for FP8+CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* bug fixes
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* bug fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* cast merged CP results from FP32 to BF16
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* typo fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor change
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix softmax_lse
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix some bugs of FP8 dkv exchange
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* typo fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add FP8 unit test
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix typos and clean asserts
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix get_p2p_comm_info
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* fix dkv p2p exchange
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* minor fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* change FP8 dkv P2P to A2A
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add FP8+CP unit test
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* typo fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* assert amax reduction is needed for FP8+CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* remove duplicated code
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* destroy process group in CP unit test
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* remove interval from fp8_recipe because it has been deprecated
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* try to fix the failed CP test with the latest CI pipeline
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* remove redundant f before string
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* change META_O_CP
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

---------
Signed-off-by: Xiaowei Ren <xren@nvidia.com>
Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Xiaowei Ren <xren@cs-cw-dfw-login-01.cm.cluster>
parent 525de6cc
...@@ -2,15 +2,18 @@ ...@@ -2,15 +2,18 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
import os, sys import os, sys, logging
from contextlib import nullcontext
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from transformer_engine.pytorch.attention import DotProductAttention from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.attention import get_cu_seqlens_on_cp_rank from transformer_engine.pytorch.attention import get_cu_seqlens_on_cp_rank
import transformer_engine_torch as tex import transformer_engine_torch as tex
from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn
from transformer_engine.pytorch.fp8 import fp8_autocast
from transformer_engine.common.recipe import DelayedScaling
dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16} dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
def run_dpa_with_cp( def run_dpa_with_cp(
...@@ -57,6 +60,9 @@ def run_dpa_with_cp( ...@@ -57,6 +60,9 @@ def run_dpa_with_cp(
assert rank in cp_comm_ranks assert rank in cp_comm_ranks
cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl") cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")
if dtype == "fp8":
fp8_recipe = DelayedScaling(fp8_dpa=True)
# instantiate core attn module # instantiate core attn module
core_attn = DotProductAttention( core_attn = DotProductAttention(
config.num_heads, config.num_heads,
...@@ -171,6 +177,13 @@ def run_dpa_with_cp( ...@@ -171,6 +177,13 @@ def run_dpa_with_cp(
# run core_attn without CP # run core_attn without CP
for x in [q, k, v]: for x in [q, k, v]:
x.requires_grad = True x.requires_grad = True
if dtype == "fp8":
fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group)
else:
fp8_context = nullcontext()
with fp8_context:
out = core_attn( out = core_attn(
q, q,
k, k,
...@@ -180,7 +193,9 @@ def run_dpa_with_cp( ...@@ -180,7 +193,9 @@ def run_dpa_with_cp(
cu_seqlens_q=cu_seqlens_q, cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_kv=cu_seqlens_kv,
cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1], cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1],
cu_seqlens_kv_padded=None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1], cu_seqlens_kv_padded=(
None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1]
),
) )
out.backward(dout) out.backward(dout)
...@@ -226,6 +241,14 @@ def run_dpa_with_cp( ...@@ -226,6 +241,14 @@ def run_dpa_with_cp(
core_attn.set_context_parallel_group( core_attn.set_context_parallel_group(
cp_comm_group, cp_comm_ranks, torch.cuda.Stream(), cp_comm_type cp_comm_group, cp_comm_ranks, torch.cuda.Stream(), cp_comm_type
) )
if dtype == "fp8":
core_attn.reset_fp8_meta_tensors()
fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group)
else:
fp8_context = nullcontext()
with fp8_context:
out_ = core_attn( out_ = core_attn(
q_, q_,
k_, k_,
...@@ -235,7 +258,9 @@ def run_dpa_with_cp( ...@@ -235,7 +258,9 @@ def run_dpa_with_cp(
cu_seqlens_q=cu_seqlens_q, cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_kv=cu_seqlens_kv,
cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1], cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1],
cu_seqlens_kv_padded=None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1], cu_seqlens_kv_padded=(
None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1]
),
) )
out_.backward(dout_) out_.backward(dout_)
...@@ -244,13 +269,6 @@ def run_dpa_with_cp( ...@@ -244,13 +269,6 @@ def run_dpa_with_cp(
assert torch.all(~torch.isinf(x)) assert torch.all(~torch.isinf(x))
# compare results with and without CP # compare results with and without CP
tols = dict(atol=5e-3, rtol=5e-3)
if dtype == "bf16":
if config.num_heads == config.num_gqa_groups:
tols = dict(atol=2.5e-2, rtol=2.5e-2)
else:
tols = dict(atol=3.5e-2, rtol=3.5e-2)
if qkv_format == "bshd" or qkv_format == "sbhd": if qkv_format == "bshd" or qkv_format == "sbhd":
dq, dk, dv, out = [ dq, dk, dv, out = [
x.view( x.view(
...@@ -309,32 +327,55 @@ def run_dpa_with_cp( ...@@ -309,32 +327,55 @@ def run_dpa_with_cp(
else: else:
assert False, f"{qkv_format} is an unsupported qkv_format!" assert False, f"{qkv_format} is an unsupported qkv_format!"
if dtype == "bf16":
if config.num_heads == config.num_gqa_groups:
tols = dict(atol=2.5e-2, rtol=2.5e-2)
else:
tols = dict(atol=3.5e-2, rtol=3.5e-2)
elif dtype == "fp16":
tols = dict(atol=5e-3, rtol=5e-3)
elif dtype == "fp8":
tols = dict(atol=5e-1, rtol=5e-1)
rmse_tol = 0.1
else:
assert False, f"{dtype} is an unsupported dtype!"
def _rmse(a, b):
return torch.sqrt((a - b).square().mean()).item()
def _error(a, b):
if dtype != "fp8":
torch.testing.assert_close(a, b, **tols)
else:
try:
torch.testing.assert_close(a, b, **tols)
except Exception as e:
logging.debug(e)
rmse = _rmse(a, b)
rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item())
assert (
rmse < rmse_tol * rmse_range
), "RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
rmse, rmse_tol * rmse_range, rmse_tol, rmse_range
)
if qkv_format == "bshd": if qkv_format == "bshd":
torch.testing.assert_close(out_[:, 0], out[:, 0], **tols) for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]):
torch.testing.assert_close(dq_[:, 0], dq[:, 0], **tols) _error(a[:, 0], b[:, 0])
torch.testing.assert_close(dk_[:, 0], dk[:, 0], **tols) _error(a[:, 1], b[:, 1])
torch.testing.assert_close(dv_[:, 0], dv[:, 0], **tols)
torch.testing.assert_close(out_[:, 1], out[:, 1], **tols)
torch.testing.assert_close(dq_[:, 1], dq[:, 1], **tols)
torch.testing.assert_close(dk_[:, 1], dk[:, 1], **tols)
torch.testing.assert_close(dv_[:, 1], dv[:, 1], **tols)
elif qkv_format == "sbhd": elif qkv_format == "sbhd":
torch.testing.assert_close(out_[0], out[0], **tols) for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]):
torch.testing.assert_close(dq_[0], dq[0], **tols) _error(a[0], b[0])
torch.testing.assert_close(dk_[0], dk[0], **tols) _error(a[1], b[1])
torch.testing.assert_close(dv_[0], dv[0], **tols)
torch.testing.assert_close(out_[1], out[1], **tols)
torch.testing.assert_close(dq_[1], dq[1], **tols)
torch.testing.assert_close(dk_[1], dk[1], **tols)
torch.testing.assert_close(dv_[1], dv[1], **tols)
elif qkv_format == "thd": elif qkv_format == "thd":
torch.testing.assert_close(out_, out, **tols) for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]):
torch.testing.assert_close(dq_, dq, **tols) _error(a, b)
torch.testing.assert_close(dk_, dk, **tols)
torch.testing.assert_close(dv_, dv, **tols)
else: else:
assert False, f"{qkv_format} is an unsupported qkv_format!" assert False, f"{qkv_format} is an unsupported qkv_format!"
dist.destroy_process_group()
def main(**kwargs): def main(**kwargs):
run_dpa_with_cp(**kwargs) run_dpa_with_cp(**kwargs)
......
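
Note on the FP8 comparison in the test script above: for `dtype == "fp8"`, a failed `torch.testing.assert_close` is only logged, and the check falls back to requiring the RMSE to stay below `rmse_tol` times the combined value range of the two tensors. The following is a minimal pure-Python sketch of that fallback, assuming plain lists of floats in place of tensors. The helper names `rmse` and `rmse_within_range` are illustrative and do not appear in the commit.

```python
import math


def rmse(a, b):
    """Root-mean-square error between two equal-length sequences of floats."""
    assert len(a) == len(b) and len(a) > 0
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)) / len(a))


def rmse_within_range(a, b, rmse_tol=0.1):
    """Return True if RMSE(a, b) < rmse_tol * (combined max - combined min).

    Mirrors the shape of the fallback check in the diff: the tolerance is
    scaled by the value range spanned by both inputs, not by a fixed atol/rtol.
    """
    value_range = max(max(a), max(b)) - min(min(a), min(b))
    return rmse(a, b) < rmse_tol * value_range


# Small deviations relative to the value range pass the check.
close = rmse_within_range([0.0, 1.0, 2.0, 3.0], [0.1, 0.9, 2.1, 3.0])
# Swapped values produce an RMSE as large as the range itself and fail.
far = rmse_within_range([0.0, 1.0], [1.0, 0.0])
```

Because the comparison is strict, two identical constant inputs (range 0, RMSE 0) would not pass this sketch on their own. In the test script that case is never reached, since the fallback only runs after `assert_close` has already failed.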
...@@ -90,7 +90,7 @@ model_configs_fused_attn = { ...@@ -90,7 +90,7 @@ model_configs_fused_attn = {
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.") @pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.")
@pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.") @pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
@pytest.mark.parametrize("dtype", ["bf16", "fp16"]) @pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"])
@pytest.mark.parametrize("model", model_configs_fused_attn.keys()) @pytest.mark.parametrize("model", model_configs_fused_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"]) @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"]) @pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"])
...@@ -121,8 +121,16 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type): ...@@ -121,8 +121,16 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
) )
if config.window_size != (-1, 0) and config.window_size != (-1, -1): if config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip( pytest.skip(
f"Fused attention does not support sliding window attention + context parallelism yet!" "Fused attention does not support sliding window attention + context parallelism yet!"
)
if cp_comm_type == "all_gather" and dtype == "fp8":
pytest.skip(
"CP implementation with KV all-gather does not support FP8 + context parallelism yet!"
) )
if dtype == "fp8" and qkv_format == "thd":
pytest.skip("FP8 attention cannot work with THD format yet!")
if dtype == "fp8" and config.attn_bias_type != "no_bias":
pytest.skip("FP8 attention cannot work with bias yet!")
subprocess.run( subprocess.run(
get_bash_arguments( get_bash_arguments(
......
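
Note on the skip conditions added to `test_cp_with_fused_attention` above: they can be read as a small decision table over the test parameters. The sketch below restates the three FP8-related skips as a pure function. The function name `fp8_cp_skip_reason` and its argument list are hypothetical, introduced only for illustration, while the messages are copied from the diff.

```python
def fp8_cp_skip_reason(dtype, qkv_format, cp_comm_type, attn_bias_type):
    """Return the skip message for an unsupported FP8 + CP combination, else None.

    The checks are evaluated in the same order as in the test, so the first
    matching condition determines the message.
    """
    if cp_comm_type == "all_gather" and dtype == "fp8":
        return (
            "CP implementation with KV all-gather does not support "
            "FP8 + context parallelism yet!"
        )
    if dtype == "fp8" and qkv_format == "thd":
        return "FP8 attention cannot work with THD format yet!"
    if dtype == "fp8" and attn_bias_type != "no_bias":
        return "FP8 attention cannot work with bias yet!"
    return None


# FP8 with P2P KV exchange, a non-THD format, and no bias is the combination
# this sketch lets through.
supported = fp8_cp_skip_reason("fp8", "bshd", "p2p", "no_bias")
```

In the actual test these conditions call `pytest.skip` directly. Returning the message instead only makes the table easy to check in isolation.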
...@@ -95,6 +95,9 @@ META_O = tex.FP8FwdTensors.GEMM2_INPUT ...@@ -95,6 +95,9 @@ META_O = tex.FP8FwdTensors.GEMM2_INPUT
META_DO = tex.FP8BwdTensors.GRAD_INPUT2 META_DO = tex.FP8BwdTensors.GRAD_INPUT2
META_S = tex.FP8FwdTensors.GEMM3_OUTPUT META_S = tex.FP8FwdTensors.GEMM3_OUTPUT
META_DP = tex.FP8BwdTensors.GRAD_INPUT3 META_DP = tex.FP8BwdTensors.GRAD_INPUT3
# repurpose some unused amax history buffers for partial results of CP fwd and bwd
META_O_CP = tex.FP8FwdTensors.GEMM2_OUTPUT
META_DQKV_CP = tex.FP8BwdTensors.GRAD_INPUT1
# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0 # NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
...@@ -654,18 +657,6 @@ def get_attention_backend( ...@@ -654,18 +657,6 @@ def get_attention_backend(
logger.debug("Disabling FusedAttention as no backend supports the provided input") logger.debug("Disabling FusedAttention as no backend supports the provided input")
use_fused_attention = False use_fused_attention = False
fused_attention_backend = None fused_attention_backend = None
if (
use_fused_attention
and context_parallel
and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"]
):
logger.debug(
"Disabling FusedAttention as only sub-backend %s does not support "
"context parallellism",
int(fused_attention_backend),
)
use_fused_attention = False
fused_attention_backend = None
if ( if (
use_fused_attention use_fused_attention
and window_size is not None and window_size is not None
...@@ -1322,6 +1313,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1322,6 +1313,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
attn_bias, attn_bias,
deterministic, deterministic,
use_fused_attention, use_fused_attention,
fp8,
fp8_meta,
): ):
if softmax_scale is None: if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5) softmax_scale = q.shape[-1] ** (-0.5)
...@@ -1407,6 +1400,43 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1407,6 +1400,43 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
# synchronize fwd results correction across steps # synchronize fwd results correction across steps
fwd_results_correction_done = torch.cuda.Event() fwd_results_correction_done = torch.cuda.Event()
if fp8:
if use_fused_attention:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
fused_attn_qkv_dtype = fp8_dtype_forward
fused_attn_backend = FusedAttnBackend["FP8"]
if fp8_meta["recipe"].fp8_mha:
assert (
isinstance(q, Float8Tensor)
and isinstance(k, Float8Tensor)
and isinstance(v, Float8Tensor)
), "q/k/v must be Float8Tensors for FP8 MHA!"
fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
q_fp8, k_fp8, v_fp8 = q, k, v
q, k, v = q_fp8._data, k_fp8._data, v_fp8._data
else:
q_f16, k_f16, v_f16 = q, k, v
q = cast_to_fp8(q_f16, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward)
if int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
k, v = [
cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward)
for x in [k_f16, v_f16]
]
fp8_meta_kwargs = {}
fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv[META_QKV]
fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv[META_S]
fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale[META_S]
fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale[META_O_CP]
amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device)
else:
assert False, "FP8 is only supported with Fused Attention!"
else:
q_f16 = q
if use_fused_attention:
fp8_meta_kwargs = {}
fused_attn_qkv_dtype = TE_DType[q.dtype]
fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]
p2p_comm_buffers = [None for _ in range(cp_size)] p2p_comm_buffers = [None for _ in range(cp_size)]
if use_fused_attention and qkv_format in ["bshd", "sbhd"]: if use_fused_attention and qkv_format in ["bshd", "sbhd"]:
p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3) p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3)
...@@ -1433,7 +1463,23 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1433,7 +1463,23 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
batch_p2p_comm, batch_p2p_comm,
) )
if (
not fp8
or fp8_meta["recipe"].fp8_mha
or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
):
kv_inputs[i % 2] = p2p_comm_buffers[i] kv_inputs[i % 2] = p2p_comm_buffers[i]
else:
# KV exchange is in BF16/FP16, cast received KV in each step
kv_inputs[i % 2] = cast_to_fp8(
p2p_comm_buffers[i],
fp8_meta["scaling_fwd"],
META_QKV,
fp8_dtype_forward,
)
if fp8 and use_fused_attention:
fp8_meta_kwargs["amax_s"] = amax_per_step[0][i]
fp8_meta_kwargs["amax_o"] = amax_per_step[1][i]
if causal: if causal:
if i == 0: if i == 0:
if pad_between_seqs_q: if pad_between_seqs_q:
...@@ -1474,8 +1520,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1474,8 +1520,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
), ),
dim=-1, dim=-1,
).contiguous() ).contiguous()
out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = ( out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
fused_attn_fwd(
is_training, is_training,
max_seqlen_q, max_seqlen_q,
max_seqlen_kv, max_seqlen_kv,
...@@ -1492,8 +1537,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1492,8 +1537,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if qkv_format in ["bshd", "sbhd"] if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1] else kv_inputs[i % 2][1]
), ),
TE_DType[q.dtype], fused_attn_qkv_dtype,
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, fused_attn_backend,
attn_scale=softmax_scale, attn_scale=softmax_scale,
dropout=dropout_p, dropout=dropout_p,
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
...@@ -1502,10 +1547,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1502,10 +1547,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
attn_bias=attn_bias_inputs[i % 2], attn_bias=attn_bias_inputs[i % 2],
cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded,
**fp8_meta_kwargs,
) )
) if fp8:
if len(rest) > 0: softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
attn_biases[i] = rest[0] else:
softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
attn_biases[i] = rest[0] if len(rest) > 0 else None
else: else:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn] # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_inputs[i % 2] = q.view(-1, *q.shape[-2:]) q_inputs[i % 2] = q.view(-1, *q.shape[-2:])
...@@ -1572,8 +1620,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1572,8 +1620,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if attn_bias is not None: if attn_bias is not None:
idx = (rank - i) % cp_size idx = (rank - i) % cp_size
attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous() attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous()
out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = ( out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
fused_attn_fwd(
is_training, is_training,
max_seqlen_q, max_seqlen_q,
max_seqlen_kv // 2, max_seqlen_kv // 2,
...@@ -1590,8 +1637,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1590,8 +1637,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if qkv_format in ["bshd", "sbhd"] if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1] else kv_inputs[i % 2][1]
), ),
TE_DType[q.dtype], fused_attn_qkv_dtype,
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, fused_attn_backend,
attn_scale=softmax_scale, attn_scale=softmax_scale,
dropout=dropout_p, dropout=dropout_p,
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
...@@ -1604,10 +1651,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1604,10 +1651,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if cu_seqlens_kv_padded is None if cu_seqlens_kv_padded is None
else cu_seqlens_kv_padded // 2 else cu_seqlens_kv_padded // 2
), ),
**fp8_meta_kwargs,
) )
) if fp8:
if len(rest) > 0: softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
attn_biases[i] = rest[0] else:
softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
attn_biases[i] = rest[0] if len(rest) > 0 else None
else: else:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn] # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_inputs[i % 2] = q.view(-1, *q.shape[-2:]) q_inputs[i % 2] = q.view(-1, *q.shape[-2:])
...@@ -1693,8 +1743,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1693,8 +1743,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
), ),
dim=-1, dim=-1,
).contiguous() ).contiguous()
out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = ( out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
fused_attn_fwd(
is_training, is_training,
max_seqlen_q // 2, max_seqlen_q // 2,
max_seqlen_kv, max_seqlen_kv,
...@@ -1711,8 +1760,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1711,8 +1760,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if qkv_format in ["bshd", "sbhd"] if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1] else kv_inputs[i % 2][1]
), ),
TE_DType[q.dtype], fused_attn_qkv_dtype,
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, fused_attn_backend,
attn_scale=softmax_scale, attn_scale=softmax_scale,
dropout=dropout_p, dropout=dropout_p,
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
...@@ -1725,10 +1774,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1725,10 +1774,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
else cu_seqlens_q_padded // 2 else cu_seqlens_q_padded // 2
), ),
cu_seqlens_kv_padded=cu_seqlens_kv_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded,
**fp8_meta_kwargs,
) )
) if fp8:
if len(rest) > 0: softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
attn_biases[i] = rest[0] else:
softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
attn_biases[i] = rest[0] if len(rest) > 0 else None
else: else:
if qkv_format == "thd": if qkv_format == "thd":
# [t, np, hn] -> [t/2, np, hn] # [t, np, hn] -> [t/2, np, hn]
...@@ -1795,8 +1847,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1795,8 +1847,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
), ),
dim=-1, dim=-1,
).contiguous() ).contiguous()
out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = ( out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
fused_attn_fwd(
is_training, is_training,
max_seqlen_q, max_seqlen_q,
max_seqlen_kv, max_seqlen_kv,
...@@ -1813,8 +1864,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1813,8 +1864,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if qkv_format in ["bshd", "sbhd"] if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1] else kv_inputs[i % 2][1]
), ),
TE_DType[q.dtype], fused_attn_qkv_dtype,
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, fused_attn_backend,
attn_scale=softmax_scale, attn_scale=softmax_scale,
dropout=dropout_p, dropout=dropout_p,
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
...@@ -1823,10 +1874,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1823,10 +1874,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
attn_bias=attn_bias_inputs[i % 2], attn_bias=attn_bias_inputs[i % 2],
cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded,
**fp8_meta_kwargs,
) )
) if fp8:
if len(rest) > 0: softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
attn_biases[i] = rest[0] else:
softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
attn_biases[i] = rest[0] if len(rest) > 0 else None
else: else:
# [b, sq, np, hn] -> [b*sq, np, hn] # [b, sq, np, hn] -> [b*sq, np, hn]
q_inputs[i % 2] = q.view(-1, *q.shape[-2:]) q_inputs[i % 2] = q.view(-1, *q.shape[-2:])
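Review note: the three forward hunks above replace the inline destructuring of the `fused_attn_fwd` result with an explicit branch on `fp8`. The following is a minimal sketch of that unpacking, with plain Python objects standing in for tensors. `unpack_aux_ctx_tensors` is a hypothetical helper name that does not appear in the patch, and returning `None` for the bias in the FP8 case assumes the corresponding `attn_biases` entry is otherwise left unset.

```python
def unpack_aux_ctx_tensors(aux_ctx_tensors, fp8):
    """Return (softmax_lse, rng_state, attn_bias) from a fused-attention aux list.

    Hypothetical helper that mirrors the branch added in the diff.
    """
    if fp8:
        # FP8 path: exactly three entries; the middle one is ignored.
        softmax_lse, _, rng_state = aux_ctx_tensors
        return softmax_lse, rng_state, None
    # Non-FP8 path: softmax_lse, rng_state, then an optional bias entry.
    softmax_lse, rng_state, *rest = aux_ctx_tensors
    attn_bias = rest[0] if len(rest) > 0 else None
    return softmax_lse, rng_state, attn_bias
```

Keeping the branch in one place would make the differing arity of the two return lists explicit; the patch itself repeats the branch at each call site.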
...@@ -1866,8 +1920,16 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1866,8 +1920,16 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse_per_step[i - 1].squeeze_(-1) softmax_lse_per_step[i - 1].squeeze_(-1)
with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]): with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]):
if fp8:
out_per_step[i - 1] = cast_from_fp8(
out_per_step[i - 1],
fp8_meta["scaling_fwd"],
META_O_CP,
fp8_dtype_forward,
TE_DType[torch.float32],
)
if i == 1: if i == 1:
out = torch.zeros_like(q) out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape)
softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double) softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
if causal and qkv_format != "thd": if causal and qkv_format != "thd":
# [b, np, sq] -> [b, np, 2, sq//2] # [b, np, sq] -> [b, np, 2, sq//2]
...@@ -1951,13 +2013,55 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1951,13 +2013,55 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
else: else:
out = out.view(-1, *out.shape[-2:]) out = out.view(-1, *out.shape[-2:])
if fp8 and use_fused_attention:
amax_cp_fwd = amax_per_step.amax(dim=1)
fp8_meta["scaling_fwd"].amax_history[0][META_S] = amax_cp_fwd[0]
fp8_meta["scaling_fwd"].amax_history[0][META_O_CP] = amax_cp_fwd[1]
out_f16 = out.to(q_fp8.dtype if fp8 and fp8_meta["recipe"].fp8_mha else q_f16.dtype)
if fp8 and (fp8_meta["recipe"].fp8_mha or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))):
out_fp8 = cast_to_fp8(out_f16, fp8_meta["scaling_fwd"], META_O, fp8_dtype_forward)
if fp8 and fp8_meta["recipe"].fp8_mha:
out_ret = Float8Tensor(
data=out_fp8,
fp8_meta=fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=META_O,
fp8_dtype=fp8_dtype_forward,
dtype=q_fp8.dtype,
)
else:
out_ret = out_f16
if fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
q_save, kv_save, out_save = q, kv, out_fp8
fp8_fwd_scales = fp8_meta["scaling_fwd"].scale.clone()
fp8_fwd_scale_invs = fp8_meta["scaling_fwd"].scale_inv.clone()
elif fp8 and fp8_meta["recipe"].fp8_mha:
kv_fp8 = Float8Tensor(
data=kv,
fp8_meta=fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=META_QKV,
fp8_dtype=fp8_dtype_forward,
dtype=k_fp8.dtype,
)
q_save, kv_save, out_save = q_fp8, kv_fp8, out_f16
fp8_fwd_scales, fp8_fwd_scale_invs = None, None
else:
q_save, kv_save, out_save = q_f16, kv, out_f16
fp8_fwd_scales, fp8_fwd_scale_invs = None, None
ctx.save_for_backward( ctx.save_for_backward(
q, q_save,
kv, kv_save,
out, out_save,
softmax_lse, softmax_lse,
cu_seqlens_q_padded, cu_seqlens_q_padded,
cu_seqlens_kv_padded, cu_seqlens_kv_padded,
fp8_fwd_scales,
fp8_fwd_scale_invs,
*cu_seqlens_q_per_step, *cu_seqlens_q_per_step,
*cu_seqlens_kv_per_step, *cu_seqlens_kv_per_step,
*rng_states, *rng_states,
...@@ -1976,7 +2080,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1976,7 +2080,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape
ctx.deterministic = deterministic ctx.deterministic = deterministic
ctx.use_fused_attention = use_fused_attention ctx.use_fused_attention = use_fused_attention
return out ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
ctx.fp8_meta = fp8_meta
return out_ret
@staticmethod @staticmethod
def backward(ctx, dout): def backward(ctx, dout):
...@@ -1987,10 +2093,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1987,10 +2093,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)
(q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = ctx.saved_tensors[:6] (q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = ctx.saved_tensors[:6]
cu_seqlens_q_per_step = ctx.saved_tensors[6 : 6 + cp_size] (fp8_fwd_scales, fp8_fwd_scale_invs) = ctx.saved_tensors[6:8]
cu_seqlens_kv_per_step = ctx.saved_tensors[6 + cp_size : 6 + cp_size * 2] cu_seqlens_q_per_step = ctx.saved_tensors[8 : 8 + cp_size]
rng_states = ctx.saved_tensors[6 + cp_size * 2 : 6 + cp_size * 3] cu_seqlens_kv_per_step = ctx.saved_tensors[8 + cp_size : 8 + cp_size * 2]
attn_biases = ctx.saved_tensors[6 + cp_size * 3 : 6 + cp_size * 4] rng_states = ctx.saved_tensors[8 + cp_size * 2 : 8 + cp_size * 3]
attn_biases = ctx.saved_tensors[8 + cp_size * 3 : 8 + cp_size * 4]
causal = "causal" in ctx.attn_mask_type causal = "causal" in ctx.attn_mask_type
padding = "padding" in ctx.attn_mask_type padding = "padding" in ctx.attn_mask_type
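Review note: because `fp8_fwd_scales` and `fp8_fwd_scale_invs` are now saved right after the six fixed tensors, every per-step slice in `backward` shifts from a base offset of 6 to 8. The layout can be sketched as follows, with a plain list standing in for `ctx.saved_tensors`. `split_saved_tensors` and `num_fixed` are hypothetical names used only for illustration.

```python
def split_saved_tensors(saved, cp_size, num_fixed=8):
    """Split a flat saved-tensor sequence into the fixed head and four per-step groups.

    The groups correspond, in order, to cu_seqlens_q_per_step,
    cu_seqlens_kv_per_step, rng_states and attn_biases in the diff.
    """
    head = saved[:num_fixed]
    groups = [
        saved[num_fixed + k * cp_size : num_fixed + (k + 1) * cp_size]
        for k in range(4)
    ]
    return head, groups


# Eight fixed entries followed by four groups of cp_size entries each.
cp_size = 2
saved = list(range(8 + 4 * cp_size))
head, (cu_q, cu_kv, rng_states, attn_biases) = split_saved_tensors(saved, cp_size)
```

Deriving the offsets from a single base value avoids having to touch each slice by hand whenever another tensor is added to the fixed head.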
...@@ -2025,22 +2132,60 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2025,22 +2132,60 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if ctx.use_fused_attention: if ctx.use_fused_attention:
# [b, np, sq//2] -> [b, np, sq//2, 1] # [b, np, sq//2] -> [b, np, sq//2, 1]
softmax_lse_.unsqueeze_(-1) softmax_lse_.unsqueeze_(-1)
if ctx.use_fused_attention: if ctx.use_fused_attention:
# [b, np, sq] -> [b, np, sq, 1] # [b, np, sq] -> [b, np, sq, 1]
softmax_lse.unsqueeze_(-1) softmax_lse.unsqueeze_(-1)
out = out.view(*q.shape)
dout = dout.view(*q.shape) if ctx.fp8:
# Flash Attn outputs if ctx.use_fused_attention:
fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
fused_attn_qkv_dtype = fp8_dtype_backward
fused_attn_dqkv_dtype = fp8_dtype_backward
fused_attn_backend = FusedAttnBackend["FP8"]
dq_fp8 = torch.empty((cp_size, *q.shape), dtype=q.dtype, device=q.device)
dkv_fp8 = torch.empty((cp_size, *kv.shape), dtype=kv.dtype, device=kv.device)
dkv_fp8_ = torch.empty_like(dkv_fp8)
dout_dtype = dout.dtype
if ctx.fp8_meta["recipe"].fp8_mha:
assert isinstance(dout, Float8Tensor), "dout must be a Float8Tensor for FP8 MHA!"
ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = dout._scale_inv
dout = dout._data
else:
dout = cast_to_fp8(
dout, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
)
p2p_comm_buffers = [[kv, dkv_fp8], [torch.empty_like(kv), dkv_fp8_]]
fp8_meta_kwargs = {}
fp8_meta_kwargs["d_scale_qkv"] = fp8_fwd_scale_invs[META_QKV]
fp8_meta_kwargs["d_scale_s"] = fp8_fwd_scale_invs[META_S]
fp8_meta_kwargs["d_scale_o"] = fp8_fwd_scale_invs[META_O]
fp8_meta_kwargs["d_scale_do"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO]
fp8_meta_kwargs["d_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP]
fp8_meta_kwargs["q_scale_s"] = fp8_fwd_scales[META_S]
fp8_meta_kwargs["q_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale[META_DP]
fp8_meta_kwargs["q_scale_dqkv"] = ctx.fp8_meta["scaling_bwd"].scale[META_DQKV_CP]
amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device)
else:
assert False, "FP8 is only supported with Fused Attention!"
else:
if ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha:
q, kv, dout = [x.from_float8(x.dtype) for x in [q, kv, dout]]
dq = torch.empty_like(q) dq = torch.empty_like(q)
if ctx.qkv_format == "thd" and causal: if ctx.qkv_format == "thd" and causal:
dq[cu_seqlens_q_padded[-1] :].fill_(0) dq[cu_seqlens_q_padded[-1] :].fill_(0)
p2p_comm_buffers = [ p2p_comm_buffers = [
torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device),
torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device),
] ]
p2p_comm_buffers[0][0].copy_(kv) p2p_comm_buffers[0][0].copy_(kv)
if ctx.use_fused_attention:
fp8_meta_kwargs = {}
fused_attn_qkv_dtype = TE_DType[q.dtype]
fused_attn_dqkv_dtype = TE_DType[q.dtype]
fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]
out = out.view(*q.shape)
dout = dout.view(*q.shape)
send_recv_reqs = [] send_recv_reqs = []
fa_optional_backward_kwargs = {} fa_optional_backward_kwargs = {}
...@@ -2056,18 +2201,40 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2056,18 +2201,40 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
send_tensor = p2p_comm_buffers[i % 2] send_tensor = p2p_comm_buffers[i % 2]
recv_tensor = p2p_comm_buffers[(i + 1) % 2] recv_tensor = p2p_comm_buffers[(i + 1) % 2]
if ctx.fp8:
if i < cp_size - 1:
send_recv_reqs = flash_attn_p2p_communicate(
rank,
send_tensor[0],
send_dst,
recv_tensor[0],
recv_src,
ctx.cp_group,
batch_p2p_comm,
)
else:
dkv_a2a_req = torch.distributed.all_to_all_single(
dkv_fp8,
dkv_fp8_,
group=ctx.cp_group,
async_op=True,
)
send_recv_reqs = [dkv_a2a_req]
else:
if i == 0: if i == 0:
send_tensor = send_tensor[0] send_tensor = send_tensor[0]
recv_tensor = recv_tensor[0] recv_tensor = recv_tensor[0]
if i == (cp_size - 1): if i == (cp_size - 1):
send_tensor = send_tensor[1] send_tensor = send_tensor[1]
recv_tensor = recv_tensor[1] recv_tensor = recv_tensor[1]
send_recv_reqs = flash_attn_p2p_communicate( send_recv_reqs = flash_attn_p2p_communicate(
rank, send_tensor, send_dst, recv_tensor, recv_src, ctx.cp_group, batch_p2p_comm rank, send_tensor, send_dst, recv_tensor, recv_src, ctx.cp_group, batch_p2p_comm
) )
kv = p2p_comm_buffers[i % 2][0] kv = p2p_comm_buffers[i % 2][0]
if ctx.fp8 and ctx.use_fused_attention:
fp8_meta_kwargs["amax_dp"] = amax_per_step[0][i]
fp8_meta_kwargs["amax_dqkv"] = amax_per_step[0][i]
# In reversed order of fwd # In reversed order of fwd
if causal: if causal:
if i == (cp_size - 1): if i == (cp_size - 1):
...@@ -2090,6 +2257,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2090,6 +2257,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
dout_ = dout.view(-1, *dout.shape[-3:]) dout_ = dout.view(-1, *dout.shape[-3:])
elif ctx.qkv_format == "thd": elif ctx.qkv_format == "thd":
q_, kv_, out_, dout_ = q, kv, out, dout q_, kv_, out_, dout_ = q, kv, out, dout
if ctx.fp8:
aux_ctx_tensors = [
softmax_lse,
softmax_lse,
rng_states[cp_size - i - 1],
]
else:
aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
if attn_dbias is not None: if attn_dbias is not None:
aux_ctx_tensors += [attn_biases[cp_size - i - 1]] aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
...@@ -2103,10 +2277,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2103,10 +2277,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
out_, out_,
dout_, dout_,
TE_DType[q.dtype], fused_attn_qkv_dtype,
TE_DType[kv.dtype], fused_attn_dqkv_dtype,
aux_ctx_tensors, aux_ctx_tensors,
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, fused_attn_backend,
cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded,
attn_scale=ctx.softmax_scale, attn_scale=ctx.softmax_scale,
...@@ -2114,6 +2288,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2114,6 +2288,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
attn_mask_type=ctx.attn_mask_type, attn_mask_type=ctx.attn_mask_type,
attn_bias_type=ctx.attn_bias_type, attn_bias_type=ctx.attn_bias_type,
deterministic=ctx.deterministic,
**fp8_meta_kwargs,
) )
else: else:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn] # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
...@@ -2169,6 +2345,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2169,6 +2345,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
q_, out_, dout_ = q, out, dout q_, out_, dout_ = q, out, dout
# [2, t, np, hn] -> [2, t/2, np, hn] # [2, t, np, hn] -> [2, t/2, np, hn]
kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0) kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0)
if ctx.fp8:
aux_ctx_tensors = [
softmax_lse,
softmax_lse,
rng_states[cp_size - i - 1],
]
else:
aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
if attn_dbias is not None: if attn_dbias is not None:
aux_ctx_tensors += [attn_biases[cp_size - i - 1]] aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
...@@ -2182,10 +2365,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2182,10 +2365,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
out_, out_,
dout_, dout_,
TE_DType[q.dtype], fused_attn_qkv_dtype,
TE_DType[kv.dtype], fused_attn_dqkv_dtype,
aux_ctx_tensors, aux_ctx_tensors,
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, fused_attn_backend,
cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=( cu_seqlens_kv_padded=(
None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // 2 None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // 2
...@@ -2195,6 +2378,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2195,6 +2378,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
attn_mask_type="padding" if padding else "no_mask", attn_mask_type="padding" if padding else "no_mask",
attn_bias_type=ctx.attn_bias_type, attn_bias_type=ctx.attn_bias_type,
deterministic=ctx.deterministic,
**fp8_meta_kwargs,
) )
else: else:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn] # [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
...@@ -2256,6 +2441,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2256,6 +2441,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
out_ = tex.thd_read_half_tensor(out, cu_seqlens_q_padded, 1) out_ = tex.thd_read_half_tensor(out, cu_seqlens_q_padded, 1)
dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q_padded, 1) dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q_padded, 1)
kv_ = kv kv_ = kv
if ctx.fp8:
aux_ctx_tensors = [
softmax_lse_,
softmax_lse_,
rng_states[cp_size - i - 1],
]
else:
aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]] aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]]
if attn_dbias is not None: if attn_dbias is not None:
aux_ctx_tensors += [attn_biases[cp_size - i - 1]] aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
...@@ -2269,10 +2461,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2269,10 +2461,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
out_, out_,
dout_, dout_,
TE_DType[q.dtype], fused_attn_qkv_dtype,
TE_DType[kv.dtype], fused_attn_dqkv_dtype,
aux_ctx_tensors, aux_ctx_tensors,
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, fused_attn_backend,
cu_seqlens_q_padded=( cu_seqlens_q_padded=(
None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // 2 None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // 2
), ),
...@@ -2282,6 +2474,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2282,6 +2474,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
attn_mask_type="padding" if padding else "no_mask", attn_mask_type="padding" if padding else "no_mask",
attn_bias_type=ctx.attn_bias_type, attn_bias_type=ctx.attn_bias_type,
deterministic=ctx.deterministic,
**fp8_meta_kwargs,
) )
else: else:
if ctx.qkv_format == "thd": if ctx.qkv_format == "thd":
...@@ -2325,6 +2519,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2325,6 +2519,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
) )
else: else:
if ctx.use_fused_attention: if ctx.use_fused_attention:
if ctx.fp8:
aux_ctx_tensors = [softmax_lse, softmax_lse, rng_states[cp_size - i - 1]]
else:
aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
if attn_dbias is not None: if attn_dbias is not None:
aux_ctx_tensors += [attn_biases[cp_size - i - 1]] aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
...@@ -2338,10 +2535,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2338,10 +2535,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1], kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1],
out, out,
dout, dout,
TE_DType[q.dtype], fused_attn_qkv_dtype,
TE_DType[kv.dtype], fused_attn_dqkv_dtype,
aux_ctx_tensors, aux_ctx_tensors,
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, fused_attn_backend,
cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded,
attn_scale=ctx.softmax_scale, attn_scale=ctx.softmax_scale,
...@@ -2349,6 +2546,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2349,6 +2546,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
attn_mask_type=ctx.attn_mask_type, attn_mask_type=ctx.attn_mask_type,
attn_bias_type=ctx.attn_bias_type, attn_bias_type=ctx.attn_bias_type,
deterministic=ctx.deterministic,
**fp8_meta_kwargs,
) )
else: else:
# [b, sq, np, hn] -> [b*sq, np, hn] # [b, sq, np, hn] -> [b*sq, np, hn]
...@@ -2383,6 +2582,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2383,6 +2582,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
**fa_optional_backward_kwargs, **fa_optional_backward_kwargs,
) )
if ctx.fp8:
dq = dq_fp8[(rank + i + 1) % cp_size]
if i >= (cp_size - rank - 1) or not causal: if i >= (cp_size - rank - 1) or not causal:
# [b*sq, np, hn] -> [b, 2, sq//2, np, hn] if causal # [b*sq, np, hn] -> [b, 2, sq//2, np, hn] if causal
# [b*sq, np, hn] -> [b, sq, np, hn] if not causal # [b*sq, np, hn] -> [b, sq, np, hn] if not causal
...@@ -2395,7 +2596,17 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2395,7 +2596,17 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
# [b*sq//2, np, hn] -> [sq//2, b, np, hn] # [b*sq//2, np, hn] -> [sq//2, b, np, hn]
dq_ = dq_.view(-1, *dq.shape[-3:]) dq_ = dq_.view(-1, *dq.shape[-3:])
if causal: if ctx.fp8:
if i >= (cp_size - rank - 1) or not causal:
dq.copy_(dq_)
else:
if ctx.qkv_format == "bshd":
dq[:, 0, ...].fill_(0)
dq[:, 1, ...].copy_(dq_)
elif ctx.qkv_format == "sbhd":
dq[0].fill_(0)
dq[1].copy_(dq_)
elif causal:
if i > (cp_size - rank - 1): if i > (cp_size - rank - 1):
dq.add_(dq_) dq.add_(dq_)
elif i == (cp_size - rank - 1): elif i == (cp_size - rank - 1):
...@@ -2450,6 +2661,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2450,6 +2661,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
for req in send_recv_reqs: for req in send_recv_reqs:
req.wait() req.wait()
if ctx.fp8:
if i < cp_size - 1:
dkv = dkv_fp8_[(rank + i + 1) % cp_size]
else:
dkv = dkv_fp8[(rank + i + 1) % cp_size]
else:
dkv = p2p_comm_buffers[(i + 1) % 2][1] dkv = p2p_comm_buffers[(i + 1) % 2][1]
if ctx.use_fused_attention: if ctx.use_fused_attention:
dkv_ = torch.cat((dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0) dkv_ = torch.cat((dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0)
...@@ -2469,7 +2686,17 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2469,7 +2686,17 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
# [2, b*sk, np, hn] -> [2, b, sk, np, hn] if not causal # [2, b*sk, np, hn] -> [2, b, sk, np, hn] if not causal
dkv_ = dkv_.view(*dkv.shape) dkv_ = dkv_.view(*dkv.shape)
if causal: if ctx.fp8:
if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1):
if ctx.qkv_format == "bshd":
dkv[:, :, 0, ...].copy_(dkv_)
dkv[:, :, 1, ...].fill_(0)
elif ctx.qkv_format == "sbhd":
dkv[:, 0, ...].copy_(dkv_)
dkv[:, 1, ...].fill_(0)
else:
dkv.copy_(dkv_)
elif causal:
if i == (cp_size - 1): if i == (cp_size - 1):
if rank == 0: if rank == 0:
if ctx.qkv_format == "bshd": if ctx.qkv_format == "bshd":
...@@ -2507,6 +2734,26 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2507,6 +2734,26 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
else: else:
dkv.add_(dkv_) dkv.add_(dkv_)
if ctx.fp8 and ctx.use_fused_attention:
amax_cp_bwd = amax_per_step.amax(dim=1)
ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP] = amax_cp_bwd[0]
ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV_CP] = amax_cp_bwd[1]
if ctx.qkv_format in ["bshd", "sbhd"]:
# [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or
# [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn]
dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:])
dq, dkv = [
cast_from_fp8(
x,
ctx.fp8_meta["scaling_bwd"],
META_DQKV_CP,
fp8_dtype_backward,
TE_DType[torch.float32],
)
for x in [dq_fp8, dkv_fp8]
]
dq, dkv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dkv]]
if causal: if causal:
if ctx.qkv_format == "bshd": if ctx.qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn] # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
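Review note: the hunk above collapses the `(2, cp_size)` buffer `amax_per_step` with `amax(dim=1)`, leaving one value per tracked quantity before writing it into `amax_history`. Below is a list-based sketch of that reduction, assuming row 0 and row 1 hold the per-step maxima of the two quantities; `reduce_amax_per_step` is a hypothetical name and the numbers are made up.

```python
def reduce_amax_per_step(amax_per_step):
    """Collapse a [num_quantities][cp_size] table to one amax per quantity."""
    return [max(row) for row in amax_per_step]


# Two tracked quantities, four context-parallel steps (illustrative values).
amax_per_step = [
    [0.5, 2.0, 1.25, 0.75],
    [3.0, 0.25, 1.5, 2.5],
]
amax_cp_bwd = reduce_amax_per_step(amax_per_step)
```

Taking the maximum over steps means a single scaling factor covers every step's output, which matches how the patch dequantizes all per-step gradients with the one `META_DQKV_CP` entry.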
...@@ -2527,6 +2774,25 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2527,6 +2774,25 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
dkv_[:, cu_seqlens_kv_padded[-1] :].fill_(0) dkv_[:, cu_seqlens_kv_padded[-1] :].fill_(0)
dkv = dkv_ dkv = dkv_
if ctx.fp8 and ctx.fp8_meta["recipe"].fp8_mha:
dq, dkv = [
cast_to_fp8(x, ctx.fp8_meta["scaling_bwd"], META_DQKV, fp8_dtype_backward)
for x in [dq, dkv]
]
dq, dk, dv = [
Float8Tensor(
data=x,
fp8_meta=ctx.fp8_meta,
fp8_meta_forward=False,
fp8_meta_index=META_DQKV,
fp8_dtype=fp8_dtype_backward,
dtype=dout_dtype,
)
for x in [dq, dkv[0], dkv[1]]
]
else:
dk, dv = dkv[0], dkv[1]
if attn_dbias is not None: if attn_dbias is not None:
# [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk] # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk]
attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1) attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1)
...@@ -2534,8 +2800,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2534,8 +2800,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
return ( return (
None, None,
dq, dq,
dkv[0], dk,
dkv[1], dv,
None, None,
None, None,
None, None,
...@@ -2553,12 +2819,14 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2553,12 +2819,14 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
attn_dbias, attn_dbias,
None, None,
None, None,
None,
None,
) )
@jit_fuser @torch.compile
def get_seq_chunk_ids_to_all_gathered_kv( def get_seq_chunk_ids_to_all_gathered_kv(
local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size_left local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size_left, device
): ):
"""Compute sequence chunk ids to the all-gathered KV.""" """Compute sequence chunk ids to the all-gathered KV."""
seq_end_idx = (local_chunk_id + 1) * max_seqlen_kv seq_end_idx = (local_chunk_id + 1) * max_seqlen_kv
...@@ -2569,7 +2837,7 @@ def get_seq_chunk_ids_to_all_gathered_kv( ...@@ -2569,7 +2837,7 @@ def get_seq_chunk_ids_to_all_gathered_kv(
local_chunk_id - num_chunks + 1, local_chunk_id - num_chunks + 1,
local_chunk_id + 1, local_chunk_id + 1,
dtype=torch.int32, dtype=torch.int32,
device="cuda", device=device,
) )
chunk_ids_to_all_gathered_kv = torch.where( chunk_ids_to_all_gathered_kv = torch.where(
chunk_ids < cp_size, 2 * chunk_ids, 2 * (2 * cp_size - chunk_ids) - 1 chunk_ids < cp_size, 2 * chunk_ids, 2 * (2 * cp_size - chunk_ids) - 1
...@@ -2683,6 +2951,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function): ...@@ -2683,6 +2951,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
if (window_size is None or window_size[0] == -1) if (window_size is None or window_size[0] == -1)
else window_size[0] else window_size[0]
), ),
k.device,
) )
chunk_ids_to_kv_ag_per_step[i] = chunk_ids_to_kv_ag chunk_ids_to_kv_ag_per_step[i] = chunk_ids_to_kv_ag
num_kv_chunks = chunk_ids_to_kv_ag.numel() num_kv_chunks = chunk_ids_to_kv_ag.numel()
...@@ -3029,6 +3298,8 @@ def attn_forward_func_with_cp( ...@@ -3029,6 +3298,8 @@ def attn_forward_func_with_cp(
deterministic=False, deterministic=False,
use_fused_attention=False, use_fused_attention=False,
window_size=None, window_size=None,
fp8=False,
fp8_meta=None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Attention implementation with context parallelism. Attention implementation with context parallelism.
...@@ -3109,6 +3380,8 @@ def attn_forward_func_with_cp( ...@@ -3109,6 +3380,8 @@ def attn_forward_func_with_cp(
attn_bias, attn_bias,
deterministic, deterministic,
use_fused_attention, use_fused_attention,
fp8,
fp8_meta,
) )
else: else:
raise ValueError(f"Unsupported communication type: {cp_comm_type}!") raise ValueError(f"Unsupported communication type: {cp_comm_type}!")
...@@ -5638,9 +5911,21 @@ class FusedAttention(torch.nn.Module): ...@@ -5638,9 +5911,21 @@ class FusedAttention(torch.nn.Module):
and (fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen) and (fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen)
) )
if fp8:
assert fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, (
f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}"
" is required for FP8 attention!"
)
assert fp8_meta is not None, "FP8 metadata fp8_meta is required for FP8 attention!"
assert not context_parallel or fp8_meta["recipe"].reduce_amax, (
"Amax reduction across TP+CP group is necessary when using context parallelism with"
" FP8!"
)
if context_parallel: if context_parallel:
assert ( assert (
fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen fp8
or fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen
), f"{fused_attention_backend} does not work with context parallelism!" ), f"{fused_attention_backend} does not work with context parallelism!"
assert core_attention_bias_type not in [ assert core_attention_bias_type not in [
"alibi" "alibi"
...@@ -5670,19 +5955,14 @@ class FusedAttention(torch.nn.Module): ...@@ -5670,19 +5955,14 @@ class FusedAttention(torch.nn.Module):
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
attn_bias_type=core_attention_bias_type, attn_bias_type=core_attention_bias_type,
attn_bias=core_attention_bias, attn_bias=core_attention_bias,
deterministic=self.deterministic,
use_fused_attention=True, use_fused_attention=True,
window_size=window_size, window_size=window_size,
fp8=fp8,
fp8_meta=fp8_meta,
) )
else: else:
with self.attention_dropout_ctx(): with self.attention_dropout_ctx():
if fp8:
assert fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, (
f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}"
" is required for FP8 attention!"
)
assert (
fp8_meta is not None
), "FP8 metadata fp8_meta is required for FP8 attention!"
output = FusedAttnFunc.apply( output = FusedAttnFunc.apply(
self.training, self.training,
max_seqlen_q, max_seqlen_q,
......