"transformer_engine/pytorch/csrc/quantizer.cpp" did not exist on "e5369541eface67d5a76e99bfec861636c28985a"
Unverified Commit 26c8fcc9 authored by Xiaowei Ren, committed by GitHub

Add FP8 support to CP implementation with KV P2P (#1114)



* add window_size to AttnFuncWithCP
* add seq_offsets_qkvo for cudnn thd
* add seq_offsets_qkvo to AttnFuncWithCP
* fix seq_offsets calculation of cudnn thd
* remove a thd assert
* fix bias for thd test
* add thd test for cudnn FA with CP
* skip GQA/MQA test for cuDNN THD
* make sure seq_offsets are computed with qkv_group of hd_hd_hd while CP>1
* fix seq_offsets inputs
* remove two comments
* fix attn mask type for cudnn thd with cp
* fix attn_mask_type check
* fix attn_mask_type for cudnn fa with thd
* fix a typo
* fix out dout in bwd
* assert cudnn+thd does not support attn bias
* check if attn_mask_type has padding
* minor change
* change cp test batch size to 2
* fix code format
* fix two assert info
* fix assert comment
* fix assert comments
* minor fix
* fix assert comments
* assert swa+CP cannot work with thd format
* add a new CP function for swa
* add a missing dgrads
* minor change
* add draft fwd function for swa+cp
* minor change
* enable flash attention for swa+cp
* remove an assert of swa+cp
* call SWAFuncWithCP for swa+cp
* typo fix
* use 2hd layout
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* change qkv_format check
* add a code comment
* tensor shape bug fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* tensor shape fix
* add function to compute cu_seqlens of a cp rank
* add cu_seqlens and cu_seqlens_padded to context parallelism
* typo fix
* minor change
* fix FlashAttention output sequence length
* fix cu_seqlens_kv_per_step calculation
* zero dQKV for ending padded tokens
* zero dQKV tensors of FlashAttention
* fix softmax_lse correction
* remove padded tokens of KV to save communication
* do not need to zero dkv for FlashAttention any more
* zero out tensors
* remove redundant code
* fix CP unit test
* fix kv shape of cp test with thd format
* update cp unit test
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* add simple code framework
* try not to have a separate CP function for SWA
* backup some code change
* back up code
* clean up fwd implementation of SWAFuncWithCP
* remove redundant code
* code cleaning
* fix assert info
* reduce kv chunk concat overheads
* minor change
* make AttnFuncWithCP and SWAFuncWithCP have same API
* add a docstring
* preliminary implementation of SWAFuncWithCP forward seems working
* fix output shape of SWAFuncWithCP
* code refactoring for FlashAttention and add a code placeholder for bwd
* use gather_along_first_dim
* finish the preliminary implementation of bwd
* remove redundant code
* fix assert condition
* add draft implementation of SWA+CP with FusedAttention
* fix attention mask type of swa+cp
* code cleaning
* add qkv_layout
* add missing window_size argument
* typo fix
* fix kv shape of swa+cp
* bug and typo fix
* fix dout shape
* add multi stream in fwd of swa+cp
* save chunk_ids_to_kv_ag in fwd
* add multi stream in bwd of swa+cp
* minor fix to cp stream sync
* rename AttnFuncWithCP
* check if window size is None
* fix docstring of AttnFuncWithCP
* minor fix
* add env var for users to choose KV ag or KV p2p
* update cp tests
* fix window size in cp unit test
* fix pytest skip messages
* add cp_comm_type into API
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* code cleaning
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* add deterministic knob in cuDNN fused attn backend
* pass fp8 and fp8_meta to attn_func_with_cp
* assert only Fused Attn can support FP8+CP
* remove redundant assert
* add a fwd draft implementation of FP8 + CP
* save fp8 and fp8_meta
* assert sequence length divisible requirements
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* remove a redundant qkv_layout compute
* if condition change
* some typo fix
* add support table of context parallelism
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* typo and code format fix
* do not print multiple disabling messages
* bug fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix aux_ctx_tensors of FP8
* bug fix
* fix device in torch.arange and adjust code for the PR of MLA
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* commit code change for FP8+CP
* commit more code change for FP8+CP
* commit more fp8 code for FP8+CP
* bug fixes
* bug fix
* cast merged CP results from FP32 to BF16
* typo fix
* minor change
* fix softmax_lse
* fix some bugs of FP8 dkv exchange
* typo fix
* add FP8 unit test
* fix typos and clean asserts
* fix get_p2p_comm_info
* fix dkv p2p exchange
* minor fix
* change FP8 dkv P2P to A2A
* add FP8+CP unit test
* typo fix
* assert amax reduction is needed for FP8+CP
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* remove duplicated code
* destroy process group in CP unit test
* remove interval from fp8_recipe because it has been deprecated
* try to fix the failed CP test with the latest CI pipeline
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* remove redundant f before string
* change META_O_CP

---------
Signed-off-by: Xiaowei Ren <xren@nvidia.com>
Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Xiaowei Ren <xren@cs-cw-dfw-login-01.cm.cluster>
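
For reference, a minimal sketch of how the FP8 + CP path added here is exercised end to end. It follows the updated unit test in this diff; the DotProductAttention constructor arguments, tensor shapes, and process-group setup below are illustrative assumptions, not values taken from the commit.

```python
import torch
import torch.distributed as dist
from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.fp8 import fp8_autocast
from transformer_engine.common.recipe import DelayedScaling

# Assumes torch.distributed is already initialized with one GPU per rank.
world_size = dist.get_world_size()
cp_comm_ranks = list(range(world_size))
cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")

# FP8 dot-product attention; amax reduction (enabled by default in the recipe)
# is required when combining FP8 with context parallelism.
fp8_recipe = DelayedScaling(fp8_dpa=True)

# Head count, head size, and layout here are illustrative.
core_attn = DotProductAttention(16, 128, qkv_format="bshd", attn_mask_type="causal").cuda()

# FP8 + CP currently requires the KV P2P implementation ("p2p"), not "all_gather".
core_attn.set_context_parallel_group(
    cp_comm_group, cp_comm_ranks, torch.cuda.Stream(), "p2p"
)

# Each rank holds its local shard of the sequence dimension.
q, k, v = [
    torch.randn(
        2, 4096 // world_size, 16, 128,
        dtype=torch.bfloat16, device="cuda", requires_grad=True,
    )
    for _ in range(3)
]
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group):
    out = core_attn(q, k, v)
out.backward(torch.randn_like(out))
```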
parent 525de6cc
......@@ -2,15 +2,18 @@
#
# See LICENSE for license information.
import os, sys
import os, sys, logging
from contextlib import nullcontext
import torch
import torch.distributed as dist
from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.attention import get_cu_seqlens_on_cp_rank
import transformer_engine_torch as tex
from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn
from transformer_engine.pytorch.fp8 import fp8_autocast
from transformer_engine.common.recipe import DelayedScaling
dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16}
dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
def run_dpa_with_cp(
......@@ -57,6 +60,9 @@ def run_dpa_with_cp(
assert rank in cp_comm_ranks
cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")
if dtype == "fp8":
fp8_recipe = DelayedScaling(fp8_dpa=True)
# instantiate core attn module
core_attn = DotProductAttention(
config.num_heads,
......@@ -171,18 +177,27 @@ def run_dpa_with_cp(
# run core_attn without CP
for x in [q, k, v]:
x.requires_grad = True
out = core_attn(
q,
k,
v,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1],
cu_seqlens_kv_padded=None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1],
)
out.backward(dout)
if dtype == "fp8":
fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group)
else:
fp8_context = nullcontext()
with fp8_context:
out = core_attn(
q,
k,
v,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1],
cu_seqlens_kv_padded=(
None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1]
),
)
out.backward(dout)
# run core_attn with CP
q_, k_, v_, dout_, *rest = [
......@@ -226,31 +241,34 @@ def run_dpa_with_cp(
core_attn.set_context_parallel_group(
cp_comm_group, cp_comm_ranks, torch.cuda.Stream(), cp_comm_type
)
out_ = core_attn(
q_,
k_,
v_,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias_,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1],
cu_seqlens_kv_padded=None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1],
)
out_.backward(dout_)
if dtype == "fp8":
core_attn.reset_fp8_meta_tensors()
fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group)
else:
fp8_context = nullcontext()
with fp8_context:
out_ = core_attn(
q_,
k_,
v_,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias_,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1],
cu_seqlens_kv_padded=(
None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1]
),
)
out_.backward(dout_)
for x in [out_, q_.grad, k_.grad, v_.grad]:
assert torch.all(~torch.isnan(x))
assert torch.all(~torch.isinf(x))
# compare results with and without CP
tols = dict(atol=5e-3, rtol=5e-3)
if dtype == "bf16":
if config.num_heads == config.num_gqa_groups:
tols = dict(atol=2.5e-2, rtol=2.5e-2)
else:
tols = dict(atol=3.5e-2, rtol=3.5e-2)
if qkv_format == "bshd" or qkv_format == "sbhd":
dq, dk, dv, out = [
x.view(
......@@ -309,32 +327,55 @@ def run_dpa_with_cp(
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
if dtype == "bf16":
if config.num_heads == config.num_gqa_groups:
tols = dict(atol=2.5e-2, rtol=2.5e-2)
else:
tols = dict(atol=3.5e-2, rtol=3.5e-2)
elif dtype == "fp16":
tols = dict(atol=5e-3, rtol=5e-3)
elif dtype == "fp8":
tols = dict(atol=5e-1, rtol=5e-1)
rmse_tol = 0.1
else:
assert False, f"{dtype} is an unsupported dtype!"
def _rmse(a, b):
return torch.sqrt((a - b).square().mean()).item()
def _error(a, b):
if dtype != "fp8":
torch.testing.assert_close(a, b, **tols)
else:
try:
torch.testing.assert_close(a, b, **tols)
except Exception as e:
logging.debug(e)
rmse = _rmse(a, b)
rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item())
assert (
rmse < rmse_tol * rmse_range
), "RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
rmse, rmse_tol * rmse_range, rmse_tol, rmse_range
)
if qkv_format == "bshd":
torch.testing.assert_close(out_[:, 0], out[:, 0], **tols)
torch.testing.assert_close(dq_[:, 0], dq[:, 0], **tols)
torch.testing.assert_close(dk_[:, 0], dk[:, 0], **tols)
torch.testing.assert_close(dv_[:, 0], dv[:, 0], **tols)
torch.testing.assert_close(out_[:, 1], out[:, 1], **tols)
torch.testing.assert_close(dq_[:, 1], dq[:, 1], **tols)
torch.testing.assert_close(dk_[:, 1], dk[:, 1], **tols)
torch.testing.assert_close(dv_[:, 1], dv[:, 1], **tols)
for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]):
_error(a[:, 0], b[:, 0])
_error(a[:, 1], b[:, 1])
elif qkv_format == "sbhd":
torch.testing.assert_close(out_[0], out[0], **tols)
torch.testing.assert_close(dq_[0], dq[0], **tols)
torch.testing.assert_close(dk_[0], dk[0], **tols)
torch.testing.assert_close(dv_[0], dv[0], **tols)
torch.testing.assert_close(out_[1], out[1], **tols)
torch.testing.assert_close(dq_[1], dq[1], **tols)
torch.testing.assert_close(dk_[1], dk[1], **tols)
torch.testing.assert_close(dv_[1], dv[1], **tols)
for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]):
_error(a[0], b[0])
_error(a[1], b[1])
elif qkv_format == "thd":
torch.testing.assert_close(out_, out, **tols)
torch.testing.assert_close(dq_, dq, **tols)
torch.testing.assert_close(dk_, dk, **tols)
torch.testing.assert_close(dv_, dv, **tols)
for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]):
_error(a, b)
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
dist.destroy_process_group()
def main(**kwargs):
run_dpa_with_cp(**kwargs)
......
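
As an aside, the FP8 comparison in the test above falls back to a relative-RMSE check when the element-wise tolerances fail. A small self-contained illustration of that acceptance criterion (hypothetical tensors, same formula as the `_rmse`/`_error` helpers in the diff):

```python
import torch

torch.manual_seed(0)
a = torch.randn(1024)             # stand-in for the CP output
b = a + 0.05 * torch.randn(1024)  # stand-in for the non-CP reference

rmse = torch.sqrt((a - b).square().mean()).item()
value_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item())
rmse_tol = 0.1
# Accepted when the RMSE stays below 10% of the combined value range.
assert rmse < rmse_tol * value_range, f"RMSE {rmse:.5f} exceeds {rmse_tol * value_range:.5f}"
```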
......@@ -90,7 +90,7 @@ model_configs_fused_attn = {
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.")
@pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"])
@pytest.mark.parametrize("model", model_configs_fused_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather"])
......@@ -121,8 +121,16 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
)
if config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip(
f"Fused attention does not support sliding window attention + context parallelism yet!"
"Fused attention does not support sliding window attention + context parallelism yet!"
)
if cp_comm_type == "all_gather" and dtype == "fp8":
pytest.skip(
"CP implementation with KV all-gather does not support FP8 + context parallelism yet!"
)
if dtype == "fp8" and qkv_format == "thd":
pytest.skip("FP8 attention cannot work with THD format yet!")
if dtype == "fp8" and config.attn_bias_type != "no_bias":
pytest.skip("FP8 attention cannot work with bias yet!")
subprocess.run(
get_bash_arguments(
......
......@@ -95,6 +95,9 @@ META_O = tex.FP8FwdTensors.GEMM2_INPUT
META_DO = tex.FP8BwdTensors.GRAD_INPUT2
META_S = tex.FP8FwdTensors.GEMM3_OUTPUT
META_DP = tex.FP8BwdTensors.GRAD_INPUT3
# repurpose some unused amax history buffers for partial results of CP fwd and bwd
META_O_CP = tex.FP8FwdTensors.GEMM2_OUTPUT
META_DQKV_CP = tex.FP8BwdTensors.GRAD_INPUT1
# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
......@@ -654,18 +657,6 @@ def get_attention_backend(
logger.debug("Disabling FusedAttention as no backend supports the provided input")
use_fused_attention = False
fused_attention_backend = None
if (
use_fused_attention
and context_parallel
and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"]
):
logger.debug(
"Disabling FusedAttention as only sub-backend %s does not support "
"context parallellism",
int(fused_attention_backend),
)
use_fused_attention = False
fused_attention_backend = None
if (
use_fused_attention
and window_size is not None
......@@ -1322,6 +1313,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
attn_bias,
deterministic,
use_fused_attention,
fp8,
fp8_meta,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
......@@ -1407,6 +1400,43 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
# synchronize fwd results correction across steps
fwd_results_correction_done = torch.cuda.Event()
if fp8:
if use_fused_attention:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
fused_attn_qkv_dtype = fp8_dtype_forward
fused_attn_backend = FusedAttnBackend["FP8"]
if fp8_meta["recipe"].fp8_mha:
assert (
isinstance(q, Float8Tensor)
and isinstance(k, Float8Tensor)
and isinstance(v, Float8Tensor)
), "q/k/v must be Float8Tensors for FP8 MHA!"
fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv
q_fp8, k_fp8, v_fp8 = q, k, v
q, k, v = q_fp8._data, k_fp8._data, v_fp8._data
else:
q_f16, k_f16, v_f16 = q, k, v
q = cast_to_fp8(q_f16, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward)
if int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
k, v = [
cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward)
for x in [k_f16, v_f16]
]
fp8_meta_kwargs = {}
fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv[META_QKV]
fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv[META_S]
fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale[META_S]
fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale[META_O_CP]
amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device)
else:
assert False, "FP8 is only supported with Fused Attention!"
else:
q_f16 = q
if use_fused_attention:
fp8_meta_kwargs = {}
fused_attn_qkv_dtype = TE_DType[q.dtype]
fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]
p2p_comm_buffers = [None for _ in range(cp_size)]
if use_fused_attention and qkv_format in ["bshd", "sbhd"]:
p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3)
......@@ -1433,7 +1463,23 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
batch_p2p_comm,
)
kv_inputs[i % 2] = p2p_comm_buffers[i]
if (
not fp8
or fp8_meta["recipe"].fp8_mha
or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
):
kv_inputs[i % 2] = p2p_comm_buffers[i]
else:
# KV exchange is in BF16/FP16, cast received KV in each step
kv_inputs[i % 2] = cast_to_fp8(
p2p_comm_buffers[i],
fp8_meta["scaling_fwd"],
META_QKV,
fp8_dtype_forward,
)
if fp8 and use_fused_attention:
fp8_meta_kwargs["amax_s"] = amax_per_step[0][i]
fp8_meta_kwargs["amax_o"] = amax_per_step[1][i]
if causal:
if i == 0:
if pad_between_seqs_q:
......@@ -1474,38 +1520,40 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
),
dim=-1,
).contiguous()
out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = (
fused_attn_fwd(
is_training,
max_seqlen_q,
max_seqlen_kv,
cu_seqlens_q_per_step[i],
cu_seqlens_kv_per_step[i],
q_inputs[i % 2],
(
kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][0]
),
(
kv_inputs[i % 2][..., 1, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1]
),
TE_DType[q.dtype],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=softmax_scale,
dropout=dropout_p,
qkv_layout=qkv_layout,
attn_mask_type=attn_mask_type,
attn_bias_type=attn_bias_type,
attn_bias=attn_bias_inputs[i % 2],
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
)
out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
is_training,
max_seqlen_q,
max_seqlen_kv,
cu_seqlens_q_per_step[i],
cu_seqlens_kv_per_step[i],
q_inputs[i % 2],
(
kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][0]
),
(
kv_inputs[i % 2][..., 1, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1]
),
fused_attn_qkv_dtype,
fused_attn_backend,
attn_scale=softmax_scale,
dropout=dropout_p,
qkv_layout=qkv_layout,
attn_mask_type=attn_mask_type,
attn_bias_type=attn_bias_type,
attn_bias=attn_bias_inputs[i % 2],
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
**fp8_meta_kwargs,
)
if len(rest) > 0:
attn_biases[i] = rest[0]
if fp8:
softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
else:
softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
attn_biases[i] = rest[0] if len(rest) > 0 else None
else:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_inputs[i % 2] = q.view(-1, *q.shape[-2:])
......@@ -1572,42 +1620,44 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if attn_bias is not None:
idx = (rank - i) % cp_size
attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous()
out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = (
fused_attn_fwd(
is_training,
max_seqlen_q,
max_seqlen_kv // 2,
cu_seqlens_q_per_step[i],
cu_seqlens_kv_per_step[i],
q_inputs[i % 2],
(
kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][0]
),
(
kv_inputs[i % 2][..., 1, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1]
),
TE_DType[q.dtype],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=softmax_scale,
dropout=dropout_p,
qkv_layout=qkv_layout,
attn_mask_type="padding" if padding else "no_mask",
attn_bias_type=attn_bias_type,
attn_bias=attn_bias_inputs[i % 2],
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=(
None
if cu_seqlens_kv_padded is None
else cu_seqlens_kv_padded // 2
),
)
out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
is_training,
max_seqlen_q,
max_seqlen_kv // 2,
cu_seqlens_q_per_step[i],
cu_seqlens_kv_per_step[i],
q_inputs[i % 2],
(
kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][0]
),
(
kv_inputs[i % 2][..., 1, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1]
),
fused_attn_qkv_dtype,
fused_attn_backend,
attn_scale=softmax_scale,
dropout=dropout_p,
qkv_layout=qkv_layout,
attn_mask_type="padding" if padding else "no_mask",
attn_bias_type=attn_bias_type,
attn_bias=attn_bias_inputs[i % 2],
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=(
None
if cu_seqlens_kv_padded is None
else cu_seqlens_kv_padded // 2
),
**fp8_meta_kwargs,
)
if len(rest) > 0:
attn_biases[i] = rest[0]
if fp8:
softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
else:
softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
attn_biases[i] = rest[0] if len(rest) > 0 else None
else:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
q_inputs[i % 2] = q.view(-1, *q.shape[-2:])
......@@ -1693,42 +1743,44 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
),
dim=-1,
).contiguous()
out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = (
fused_attn_fwd(
is_training,
max_seqlen_q // 2,
max_seqlen_kv,
cu_seqlens_q_per_step[i],
cu_seqlens_kv_per_step[i],
q_inputs[i % 2],
(
kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][0]
),
(
kv_inputs[i % 2][..., 1, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1]
),
TE_DType[q.dtype],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=softmax_scale,
dropout=dropout_p,
qkv_layout=qkv_layout,
attn_mask_type="padding" if padding else "no_mask",
attn_bias_type=attn_bias_type,
attn_bias=attn_bias_inputs[i % 2],
cu_seqlens_q_padded=(
None
if cu_seqlens_q_padded is None
else cu_seqlens_q_padded // 2
),
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
)
out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
is_training,
max_seqlen_q // 2,
max_seqlen_kv,
cu_seqlens_q_per_step[i],
cu_seqlens_kv_per_step[i],
q_inputs[i % 2],
(
kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][0]
),
(
kv_inputs[i % 2][..., 1, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1]
),
fused_attn_qkv_dtype,
fused_attn_backend,
attn_scale=softmax_scale,
dropout=dropout_p,
qkv_layout=qkv_layout,
attn_mask_type="padding" if padding else "no_mask",
attn_bias_type=attn_bias_type,
attn_bias=attn_bias_inputs[i % 2],
cu_seqlens_q_padded=(
None
if cu_seqlens_q_padded is None
else cu_seqlens_q_padded // 2
),
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
**fp8_meta_kwargs,
)
if len(rest) > 0:
attn_biases[i] = rest[0]
if fp8:
softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
else:
softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
attn_biases[i] = rest[0] if len(rest) > 0 else None
else:
if qkv_format == "thd":
# [t, np, hn] -> [t/2, np, hn]
......@@ -1795,38 +1847,40 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
),
dim=-1,
).contiguous()
out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = (
fused_attn_fwd(
is_training,
max_seqlen_q,
max_seqlen_kv,
cu_seqlens_q_per_step[i],
cu_seqlens_kv_per_step[i],
q,
(
kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][0]
),
(
kv_inputs[i % 2][..., 1, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1]
),
TE_DType[q.dtype],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=softmax_scale,
dropout=dropout_p,
qkv_layout=qkv_layout,
attn_mask_type=attn_mask_type,
attn_bias_type=attn_bias_type,
attn_bias=attn_bias_inputs[i % 2],
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
)
out_per_step[i], aux_ctx_tensors = fused_attn_fwd(
is_training,
max_seqlen_q,
max_seqlen_kv,
cu_seqlens_q_per_step[i],
cu_seqlens_kv_per_step[i],
q,
(
kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][0]
),
(
kv_inputs[i % 2][..., 1, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1]
),
fused_attn_qkv_dtype,
fused_attn_backend,
attn_scale=softmax_scale,
dropout=dropout_p,
qkv_layout=qkv_layout,
attn_mask_type=attn_mask_type,
attn_bias_type=attn_bias_type,
attn_bias=attn_bias_inputs[i % 2],
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
**fp8_meta_kwargs,
)
if len(rest) > 0:
attn_biases[i] = rest[0]
if fp8:
softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors
else:
softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
attn_biases[i] = rest[0] if len(rest) > 0 else None
else:
# [b, sq, np, hn] -> [b*sq, np, hn]
q_inputs[i % 2] = q.view(-1, *q.shape[-2:])
......@@ -1866,8 +1920,16 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse_per_step[i - 1].squeeze_(-1)
with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]):
if fp8:
out_per_step[i - 1] = cast_from_fp8(
out_per_step[i - 1],
fp8_meta["scaling_fwd"],
META_O_CP,
fp8_dtype_forward,
TE_DType[torch.float32],
)
if i == 1:
out = torch.zeros_like(q)
out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape)
softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
if causal and qkv_format != "thd":
# [b, np, sq] -> [b, np, 2, sq//2]
......@@ -1951,13 +2013,55 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
else:
out = out.view(-1, *out.shape[-2:])
if fp8 and use_fused_attention:
amax_cp_fwd = amax_per_step.amax(dim=1)
fp8_meta["scaling_fwd"].amax_history[0][META_S] = amax_cp_fwd[0]
fp8_meta["scaling_fwd"].amax_history[0][META_O_CP] = amax_cp_fwd[1]
out_f16 = out.to(q_fp8.dtype if fp8 and fp8_meta["recipe"].fp8_mha else q_f16.dtype)
if fp8 and (fp8_meta["recipe"].fp8_mha or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))):
out_fp8 = cast_to_fp8(out_f16, fp8_meta["scaling_fwd"], META_O, fp8_dtype_forward)
if fp8 and fp8_meta["recipe"].fp8_mha:
out_ret = Float8Tensor(
data=out_fp8,
fp8_meta=fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=META_O,
fp8_dtype=fp8_dtype_forward,
dtype=q_fp8.dtype,
)
else:
out_ret = out_f16
if fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
q_save, kv_save, out_save = q, kv, out_fp8
fp8_fwd_scales = fp8_meta["scaling_fwd"].scale.clone()
fp8_fwd_scale_invs = fp8_meta["scaling_fwd"].scale_inv.clone()
elif fp8 and fp8_meta["recipe"].fp8_mha:
kv_fp8 = Float8Tensor(
data=kv,
fp8_meta=fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=META_QKV,
fp8_dtype=fp8_dtype_forward,
dtype=k_fp8.dtype,
)
q_save, kv_save, out_save = q_fp8, kv_fp8, out_f16
fp8_fwd_scales, fp8_fwd_scale_invs = None, None
else:
q_save, kv_save, out_save = q_f16, kv, out_f16
fp8_fwd_scales, fp8_fwd_scale_invs = None, None
ctx.save_for_backward(
q,
kv,
out,
q_save,
kv_save,
out_save,
softmax_lse,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
fp8_fwd_scales,
fp8_fwd_scale_invs,
*cu_seqlens_q_per_step,
*cu_seqlens_kv_per_step,
*rng_states,
......@@ -1976,7 +2080,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape
ctx.deterministic = deterministic
ctx.use_fused_attention = use_fused_attention
return out
ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
ctx.fp8_meta = fp8_meta
return out_ret
@staticmethod
def backward(ctx, dout):
......@@ -1987,10 +2093,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)
(q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = ctx.saved_tensors[:6]
cu_seqlens_q_per_step = ctx.saved_tensors[6 : 6 + cp_size]
cu_seqlens_kv_per_step = ctx.saved_tensors[6 + cp_size : 6 + cp_size * 2]
rng_states = ctx.saved_tensors[6 + cp_size * 2 : 6 + cp_size * 3]
attn_biases = ctx.saved_tensors[6 + cp_size * 3 : 6 + cp_size * 4]
(fp8_fwd_scales, fp8_fwd_scale_invs) = ctx.saved_tensors[6:8]
cu_seqlens_q_per_step = ctx.saved_tensors[8 : 8 + cp_size]
cu_seqlens_kv_per_step = ctx.saved_tensors[8 + cp_size : 8 + cp_size * 2]
rng_states = ctx.saved_tensors[8 + cp_size * 2 : 8 + cp_size * 3]
attn_biases = ctx.saved_tensors[8 + cp_size * 3 : 8 + cp_size * 4]
causal = "causal" in ctx.attn_mask_type
padding = "padding" in ctx.attn_mask_type
......@@ -2025,22 +2132,60 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if ctx.use_fused_attention:
# [b, np, sq//2] -> [b, np, sq//2, 1]
softmax_lse_.unsqueeze_(-1)
if ctx.use_fused_attention:
# [b, np, sq] -> [b, np, sq, 1]
softmax_lse.unsqueeze_(-1)
if ctx.fp8:
if ctx.use_fused_attention:
fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
fused_attn_qkv_dtype = fp8_dtype_backward
fused_attn_dqkv_dtype = fp8_dtype_backward
fused_attn_backend = FusedAttnBackend["FP8"]
dq_fp8 = torch.empty((cp_size, *q.shape), dtype=q.dtype, device=q.device)
dkv_fp8 = torch.empty((cp_size, *kv.shape), dtype=kv.dtype, device=kv.device)
dkv_fp8_ = torch.empty_like(dkv_fp8)
dout_dtype = dout.dtype
if ctx.fp8_meta["recipe"].fp8_mha:
assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!"
ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = dout._scale_inv
dout = dout._data
else:
dout = cast_to_fp8(
dout, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
)
p2p_comm_buffers = [[kv, dkv_fp8], [torch.empty_like(kv), dkv_fp8_]]
fp8_meta_kwargs = {}
fp8_meta_kwargs["d_scale_qkv"] = fp8_fwd_scale_invs[META_QKV]
fp8_meta_kwargs["d_scale_s"] = fp8_fwd_scale_invs[META_S]
fp8_meta_kwargs["d_scale_o"] = fp8_fwd_scale_invs[META_O]
fp8_meta_kwargs["d_scale_do"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO]
fp8_meta_kwargs["d_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP]
fp8_meta_kwargs["q_scale_s"] = fp8_fwd_scales[META_S]
fp8_meta_kwargs["q_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale[META_DP]
fp8_meta_kwargs["q_scale_dqkv"] = ctx.fp8_meta["scaling_bwd"].scale[META_DQKV_CP]
amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device)
else:
assert False, "FP8 is only supported with Fused Attention!"
else:
if ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha:
q, kv, dout = [x.from_float8(x.dtype) for x in [q, kv, dout]]
dq = torch.empty_like(q)
if ctx.qkv_format == "thd" and causal:
dq[cu_seqlens_q_padded[-1] :].fill_(0)
p2p_comm_buffers = [
torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device),
torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device),
]
p2p_comm_buffers[0][0].copy_(kv)
if ctx.use_fused_attention:
fp8_meta_kwargs = {}
fused_attn_qkv_dtype = TE_DType[q.dtype]
fused_attn_dqkv_dtype = TE_DType[q.dtype]
fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]
out = out.view(*q.shape)
dout = dout.view(*q.shape)
# Flash Attn outputs
dq = torch.empty_like(q)
if ctx.qkv_format == "thd" and causal:
dq[cu_seqlens_q_padded[-1] :].fill_(0)
p2p_comm_buffers = [
torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device),
torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device),
]
p2p_comm_buffers[0][0].copy_(kv)
send_recv_reqs = []
fa_optional_backward_kwargs = {}
......@@ -2056,18 +2201,40 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
send_tensor = p2p_comm_buffers[i % 2]
recv_tensor = p2p_comm_buffers[(i + 1) % 2]
if i == 0:
send_tensor = send_tensor[0]
recv_tensor = recv_tensor[0]
if i == (cp_size - 1):
send_tensor = send_tensor[1]
recv_tensor = recv_tensor[1]
send_recv_reqs = flash_attn_p2p_communicate(
rank, send_tensor, send_dst, recv_tensor, recv_src, ctx.cp_group, batch_p2p_comm
)
if ctx.fp8:
if i < cp_size - 1:
send_recv_reqs = flash_attn_p2p_communicate(
rank,
send_tensor[0],
send_dst,
recv_tensor[0],
recv_src,
ctx.cp_group,
batch_p2p_comm,
)
else:
dkv_a2a_req = torch.distributed.all_to_all_single(
dkv_fp8,
dkv_fp8_,
group=ctx.cp_group,
async_op=True,
)
send_recv_reqs = [dkv_a2a_req]
else:
if i == 0:
send_tensor = send_tensor[0]
recv_tensor = recv_tensor[0]
if i == (cp_size - 1):
send_tensor = send_tensor[1]
recv_tensor = recv_tensor[1]
send_recv_reqs = flash_attn_p2p_communicate(
rank, send_tensor, send_dst, recv_tensor, recv_src, ctx.cp_group, batch_p2p_comm
)
kv = p2p_comm_buffers[i % 2][0]
if ctx.fp8 and ctx.use_fused_attention:
fp8_meta_kwargs["amax_dp"] = amax_per_step[0][i]
fp8_meta_kwargs["amax_dqkv"] = amax_per_step[0][i]
# In reversed order of fwd
if causal:
if i == (cp_size - 1):
......@@ -2090,7 +2257,14 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
dout_ = dout.view(-1, *dout.shape[-3:])
elif ctx.qkv_format == "thd":
q_, kv_, out_, dout_ = q, kv, out, dout
aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
if ctx.fp8:
aux_ctx_tensors = [
softmax_lse,
softmax_lse,
rng_states[cp_size - i - 1],
]
else:
aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
if attn_dbias is not None:
aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
dq_, dk_, dv_, dbias_ = fused_attn_bwd(
......@@ -2103,10 +2277,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
out_,
dout_,
TE_DType[q.dtype],
TE_DType[kv.dtype],
fused_attn_qkv_dtype,
fused_attn_dqkv_dtype,
aux_ctx_tensors,
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
fused_attn_backend,
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
attn_scale=ctx.softmax_scale,
......@@ -2114,6 +2288,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
qkv_layout=qkv_layout,
attn_mask_type=ctx.attn_mask_type,
attn_bias_type=ctx.attn_bias_type,
deterministic=ctx.deterministic,
**fp8_meta_kwargs,
)
else:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
......@@ -2169,7 +2345,14 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
q_, out_, dout_ = q, out, dout
# [2, t, np, hn] -> [2, t/2, np, hn]
kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0)
aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
if ctx.fp8:
aux_ctx_tensors = [
softmax_lse,
softmax_lse,
rng_states[cp_size - i - 1],
]
else:
aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
if attn_dbias is not None:
aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
dq_, dk_, dv_, dbias_ = fused_attn_bwd(
......@@ -2182,10 +2365,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
out_,
dout_,
TE_DType[q.dtype],
TE_DType[kv.dtype],
fused_attn_qkv_dtype,
fused_attn_dqkv_dtype,
aux_ctx_tensors,
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
fused_attn_backend,
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=(
None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // 2
......@@ -2195,6 +2378,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
qkv_layout=qkv_layout,
attn_mask_type="padding" if padding else "no_mask",
attn_bias_type=ctx.attn_bias_type,
deterministic=ctx.deterministic,
**fp8_meta_kwargs,
)
else:
# [b, 2, sq//2, np, hn] -> [b*sq, np, hn]
......@@ -2256,7 +2441,14 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
out_ = tex.thd_read_half_tensor(out, cu_seqlens_q_padded, 1)
dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q_padded, 1)
kv_ = kv
aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]]
if ctx.fp8:
aux_ctx_tensors = [
softmax_lse_,
softmax_lse_,
rng_states[cp_size - i - 1],
]
else:
aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]]
if attn_dbias is not None:
aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
dq_, dk_, dv_, dbias_ = fused_attn_bwd(
......@@ -2269,10 +2461,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
out_,
dout_,
TE_DType[q.dtype],
TE_DType[kv.dtype],
fused_attn_qkv_dtype,
fused_attn_dqkv_dtype,
aux_ctx_tensors,
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
fused_attn_backend,
cu_seqlens_q_padded=(
None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // 2
),
......@@ -2282,6 +2474,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
qkv_layout=qkv_layout,
attn_mask_type="padding" if padding else "no_mask",
attn_bias_type=ctx.attn_bias_type,
deterministic=ctx.deterministic,
**fp8_meta_kwargs,
)
else:
if ctx.qkv_format == "thd":
......@@ -2325,7 +2519,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
)
else:
if ctx.use_fused_attention:
aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
if ctx.fp8:
aux_ctx_tensors = [softmax_lse, softmax_lse, rng_states[cp_size - i - 1]]
else:
aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]]
if attn_dbias is not None:
aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
dq_, dk_, dv_, dbias_ = fused_attn_bwd(
......@@ -2338,10 +2535,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1],
out,
dout,
TE_DType[q.dtype],
TE_DType[kv.dtype],
fused_attn_qkv_dtype,
fused_attn_dqkv_dtype,
aux_ctx_tensors,
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
fused_attn_backend,
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
attn_scale=ctx.softmax_scale,
......@@ -2349,6 +2546,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
qkv_layout=qkv_layout,
attn_mask_type=ctx.attn_mask_type,
attn_bias_type=ctx.attn_bias_type,
deterministic=ctx.deterministic,
**fp8_meta_kwargs,
)
else:
# [b, sq, np, hn] -> [b*sq, np, hn]
......@@ -2383,6 +2582,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
**fa_optional_backward_kwargs,
)
if ctx.fp8:
dq = dq_fp8[(rank + i + 1) % cp_size]
if i >= (cp_size - rank - 1) or not causal:
# [b*sq, np, hn] -> [b, 2, sq//2, np, hn] if causal
# [b*sq, np, hn] -> [b, sq, np, hn] if not causal
......@@ -2395,7 +2596,17 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
# [b*sq//2, np, hn] -> [sq//2, b, np, hn]
dq_ = dq_.view(-1, *dq.shape[-3:])
if causal:
if ctx.fp8:
if i >= (cp_size - rank - 1) or not causal:
dq.copy_(dq_)
else:
if ctx.qkv_format == "bshd":
dq[:, 0, ...].fill_(0)
dq[:, 1, ...].copy_(dq_)
elif ctx.qkv_format == "sbhd":
dq[0].fill_(0)
dq[1].copy_(dq_)
elif causal:
if i > (cp_size - rank - 1):
dq.add_(dq_)
elif i == (cp_size - rank - 1):
......@@ -2450,7 +2661,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
for req in send_recv_reqs:
req.wait()
dkv = p2p_comm_buffers[(i + 1) % 2][1]
if ctx.fp8:
if i < cp_size - 1:
dkv = dkv_fp8_[(rank + i + 1) % cp_size]
else:
dkv = dkv_fp8[(rank + i + 1) % cp_size]
else:
dkv = p2p_comm_buffers[(i + 1) % 2][1]
if ctx.use_fused_attention:
dkv_ = torch.cat((dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0)
if ctx.qkv_format in ["bshd", "sbhd"]:
......@@ -2469,7 +2686,17 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
# [2, b*sk, np, hn] -> [2, b, sk, np, hn] if not causal
dkv_ = dkv_.view(*dkv.shape)
if causal:
if ctx.fp8:
if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1):
if ctx.qkv_format == "bshd":
dkv[:, :, 0, ...].copy_(dkv_)
dkv[:, :, 1, ...].fill_(0)
elif ctx.qkv_format == "sbhd":
dkv[:, 0, ...].copy_(dkv_)
dkv[:, 1, ...].fill_(0)
else:
dkv.copy_(dkv_)
elif causal:
if i == (cp_size - 1):
if rank == 0:
if ctx.qkv_format == "bshd":
......@@ -2507,6 +2734,26 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
else:
dkv.add_(dkv_)
if ctx.fp8 and ctx.use_fused_attention:
amax_cp_bwd = amax_per_step.amax(dim=1)
ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP] = amax_cp_bwd[0]
ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV_CP] = amax_cp_bwd[1]
if ctx.qkv_format in ["bshd", "sbhd"]:
# [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or
# [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn]
dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:])
dq, dkv = [
cast_from_fp8(
x,
ctx.fp8_meta["scaling_bwd"],
META_DQKV_CP,
fp8_dtype_backward,
TE_DType[torch.float32],
)
for x in [dq_fp8, dkv_fp8]
]
dq, dkv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dkv]]
if causal:
if ctx.qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
......@@ -2527,6 +2774,25 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
dkv_[:, cu_seqlens_kv_padded[-1] :].fill_(0)
dkv = dkv_
if ctx.fp8 and ctx.fp8_meta["recipe"].fp8_mha:
dq, dkv = [
cast_to_fp8(x, ctx.fp8_meta["scaling_bwd"], META_DQKV, fp8_dtype_backward)
for x in [dq, dkv]
]
dq, dk, dv = [
Float8Tensor(
data=x,
fp8_meta=ctx.fp8_meta,
fp8_meta_forward=False,
fp8_meta_index=META_DQKV,
fp8_dtype=fp8_dtype_backward,
dtype=dout_dtype,
)
for x in [dq, dkv[0], dkv[1]]
]
else:
dk, dv = dkv[0], dkv[1]
if attn_dbias is not None:
# [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk]
attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1)
......@@ -2534,8 +2800,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
return (
None,
dq,
dkv[0],
dkv[1],
dk,
dv,
None,
None,
None,
......@@ -2553,12 +2819,14 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
attn_dbias,
None,
None,
None,
None,
)
@jit_fuser
@torch.compile
def get_seq_chunk_ids_to_all_gathered_kv(
local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size_left
local_chunk_id, cp_size, max_seqlen_q, max_seqlen_kv, window_size_left, device
):
"""Compute sequence chunk ids to the all-gathered KV."""
seq_end_idx = (local_chunk_id + 1) * max_seqlen_kv
......@@ -2569,7 +2837,7 @@ def get_seq_chunk_ids_to_all_gathered_kv(
local_chunk_id - num_chunks + 1,
local_chunk_id + 1,
dtype=torch.int32,
device="cuda",
device=device,
)
chunk_ids_to_all_gathered_kv = torch.where(
chunk_ids < cp_size, 2 * chunk_ids, 2 * (2 * cp_size - chunk_ids) - 1
......@@ -2683,6 +2951,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
if (window_size is None or window_size[0] == -1)
else window_size[0]
),
k.device,
)
chunk_ids_to_kv_ag_per_step[i] = chunk_ids_to_kv_ag
num_kv_chunks = chunk_ids_to_kv_ag.numel()
......@@ -3029,6 +3298,8 @@ def attn_forward_func_with_cp(
deterministic=False,
use_fused_attention=False,
window_size=None,
fp8=False,
fp8_meta=None,
) -> torch.Tensor:
"""
Attention implementation with context parallelism.
......@@ -3109,6 +3380,8 @@ def attn_forward_func_with_cp(
attn_bias,
deterministic,
use_fused_attention,
fp8,
fp8_meta,
)
else:
raise ValueError(f"Unsupported communication type: {cp_comm_type}!")
......@@ -5638,9 +5911,21 @@ class FusedAttention(torch.nn.Module):
and (fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen)
)
if fp8:
assert fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, (
f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}"
" is required for FP8 attention!"
)
assert fp8_meta is not None, "FP8 metadata fp8_meta is required for FP8 attention!"
assert not context_parallel or fp8_meta["recipe"].reduce_amax, (
"Amax reduction across TP+CP group is necessary when using context parallelism with"
" FP8!"
)
if context_parallel:
assert (
fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen
fp8
or fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen
), f"{fused_attention_backend} does not work with context parallelism!"
assert core_attention_bias_type not in [
"alibi"
......@@ -5670,19 +5955,14 @@ class FusedAttention(torch.nn.Module):
attn_mask_type=attn_mask_type,
attn_bias_type=core_attention_bias_type,
attn_bias=core_attention_bias,
deterministic=self.deterministic,
use_fused_attention=True,
window_size=window_size,
fp8=fp8,
fp8_meta=fp8_meta,
)
else:
with self.attention_dropout_ctx():
if fp8:
assert fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, (
f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}"
" is required for FP8 attention!"
)
assert (
fp8_meta is not None
), "FP8 metadata fp8_meta is required for FP8 attention!"
output = FusedAttnFunc.apply(
self.training,
max_seqlen_q,
......
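
One behavioral note on the forward/backward code above: several branches key off the NVTE_FP8_DPA_BWD environment variable, which defaults to "1" (FP8 backward). A hedged sketch of how a launcher might pin it; the variable name is taken from the diff, the surrounding snippet is illustrative:

```python
import os

# "1" (default): dgrads are computed and exchanged in FP8, as in the code above.
# "0": the forward still runs in FP8, but KV is exchanged in BF16/FP16 and cast
#      per step, and the backward keeps higher-precision tensors.
os.environ.setdefault("NVTE_FP8_DPA_BWD", "1")
```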