"tests/cpp/vscode:/vscode.git/clone" did not exist on "96f9c6dedc72be13ab4374c4ab453690f4d7d072"
Unverified commit 25252e9f, authored by Charlene Yang and committed by GitHub

[PyTorch] Add FP8 attention with current scaling (#2012)
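In short, this PR lets the PyTorch attention modules run FP8 attention under a current-scaling recipe in addition to delayed scaling. A minimal usage sketch (module arguments and tensor shapes are illustrative only, assuming sm90+ and a build that contains this PR):

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import Float8CurrentScaling

    # current-scaling FP8 recipe; fp8_dpa=True enables FP8 inside attention
    fp8_recipe = Float8CurrentScaling(fp8_dpa=True)

    dpa = te.DotProductAttention(num_attention_heads=16, kv_channels=64)
    # default qkv_format is "sbhd": (seq, batch, heads, head_dim)
    q, k, v = [
        torch.randn(512, 2, 16, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True)
        for _ in range(3)
    ]

    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        out = dpa(q, k, v)  # attention math runs in FP8 with per-tensor current scaling
    out.backward(torch.randn_like(out))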



* debug existing usage
* fix fp8_dpa
* reimplement fp8_dpa
* clean up
* more clean up
* update FE develop
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* redesign CS; need cleanup
* clean up
* clean up s/dP quantizers
* return dP to DS
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* improve quantizer_helper; tweak dP DS/CS logic
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* debug CP
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* update FE commit
* clean up non-CP; debug dq/dk mismatches
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* minor success with CP; need to remove debug info
* remove debug info
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* disable fp8 output for fp8_mha + CS
* add output_tensor_type to FADescriptor
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* minor fixes for CP
* remove print
* more fixes for non-CP and CP
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* enable non-determinism for blackwell
* fix indent; remove print
* switch from create_tensor_from_data to make_like
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* enable a2a+p2p for CS CP and require additional cp_group_global
* fix last commit
* condense tests; only create dist groups once
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* consolidate CP P2P per-tile calls for fwd/bwd and fused/flash
* fix flash-attn from last commit
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* minor fixes for previous commit
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* fix attn_mask_type in f16 causal
* revert bb6a0a59 temporarily
* reenable comparison for some tensors in CP tests
* fix dbias for fused attn CP
* clean up prints/comments and add back NVTE_CS_dP_SCALE
* first attempt at mixed DS/CS reduction
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* minor fix for last commit for mixed DS/CS reduction
* remove prints from 69639024
* fix DS recipe for dP
* add NVTE_DPA_FORCE_DS to force DS for all DPA tensors, not just dP
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* fix NVTE_DPA_FORCE_DS and add NVTE_PRINT
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* fix last commit
* modify DS recipe for MLPerf
* reduce only over TP group; need to think about CP group later
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* streamline fake_recipe/quantizer generation; allow NVTE_DPA_Fixed_Scales or DS-update S/dP
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* add more print: NVTE_LAYER_NUMBER
* split S/dP in env vars: NVTE_DPA_Fix_S_Scale and NVTE_DPA_Fix_dP_Scale
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* fix autocast_key for DS
* add NVTE_REPEAT_in_F16 to repeat FP8 fwd/bwd passes
* add FP8 CS to UnfusedDPA; unsuccessful; does not affect other backends
* temporary: print min/max and save tensors for debugging
* emulate q/dq+bf16 with NVTE_Emulate_in_F16; add NVTE_DPA_FORCE_MXFP8 for MXFP8 q/dq
* add RHT to BMM1 with NVTE_RHT_BMM1 for the size
* re-enable fused attn in dpa_fp8_vs_f16 test; changed during unfused attn implementation
* add NVTE_FP8_CS_POWER_OF_2, NVTE_DPA_FORCE_BLOCKFP8, NVTE_Emulate_QDQ_QKV, NVTE_Emulate_QDQ_O, NVTE_Emulate_QDQ_dO
* add F16 O support for FP8 kernels
* revert to TE FE commit
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* return to FE develop
* tidy up; untested
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* minor fix for last commit
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* minor fixes and improvements for last commit
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* more minor fixes and improvements
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* more small fixes/improvements; mostly for CP
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* fix CS/DS recipe switch in DPA
* avoid quantizing/saving of O when CS bwd uses F16 O
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* move fp8_autocast(fp8_recipe) print to utils.py
* add debug logging to unit tests
* add back prints of quantizers/layer_number for debugging
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* enable amax reduction for both CS and DS tensors
* fix NVTE_FP8_DPA_BWD=0 for CP
* fix last commit for F16 fwd/bwd a2a+p2p
* small fixes for float8_current_scaling(), nominal types, and unruly d_out types
* fix fp8_output in MHA and some CP tests
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* minor fixes to CP tests
* minor fixes for CP A2A
* clamp input data in tests
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* remove rmse and tighten atol/rtol for tests
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* restructure fp8_recipes
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* fix linter
* Revert "remove rmse and tighten atol/rtol for tests" (reverts commit 15dba6a59a5323d414f02cf22f099cb00d880532)
* more fixes for linter
* fix fp8 recipe changes for F16 code path
* revert to FE on main to help with merges
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* switch back to FE develop after merge
* update FE develop commit
* fix last merge
* revert to GitHub FE 1.14.1
* update FE to its latest main
* minor fix for A2A
* fix last commit for A2A DS
* remove memset for BSHD/SBHD FP8
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* remove concat for qkv quantization in CS
* improve/simplify the logic for last commit
* add nominal_type for UnfusedDPA FP8 EmuFunc
* WIP: update env vars for DPA recipes
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* fix last commit
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* fix typo in last commit
* fix DS recipe creation for NVFP4 global recipe
* replace python max with torch.maximum
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* fix linter
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* fix CP A2A for FA
* reduce prints in print_quantizers
* add FP8 env vars to NVTE_DEBUG prints
* add reduce_amax to DS repr
* separate fp8_dpa/fp8_mha in CP tests; fix A2A for them; add f16_O tests
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* address some reviews
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* make data optional in create_hp_tensor_with_amax
* minor fix for comments in bwd
* print cudnn version in attn tests
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* disable CS for Hopper
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* alternative tests to reduce CI time
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* make NVTE_DPA_FP8CS_O_in_F16 default to 1
* remove _fp8 variables to avoid confusion
* return to requiring two cp_groups for a2a+p2p
* replace NVTE_PRINT with NVTE_DEBUG/_LEVEL for quantizer prints
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* provide a basic set of tests for CP
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* fix the last merge with nvfp4 PR
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* disable for Hopper
* fix fp8 backend selection for Hopper
* reduce CP CI to essential tests
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* minor fix to CP test
* fix recipe logic in tests
* revert to concat for qkv quantization
* remove cudnn version in qa scripts
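
Several of the knobs named in the commits above are plain environment variables read by the PyTorch attention path. Their exact semantics are inferred here from the commit titles and the test changes below, so treat the defaults as assumptions:

    import os

    # keep the attention backward pass in FP8 (the CP tests toggle this per parametrization)
    os.environ["NVTE_FP8_DPA_BWD"] = "1"
    # with current scaling, keep the attention output O in F16 (default 1 per the commit above)
    os.environ["NVTE_DPA_FP8CS_O_in_F16"] = "1"
    # quantizer/recipe prints now go through the standard debug logging
    os.environ["NVTE_DEBUG"] = "1"
    os.environ["NVTE_DEBUG_LEVEL"] = "2"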

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 2354fb8b

Subproject commit: 1a7b4b78db44712fb9707d21cd2e3179f1fd88b8 → 80a8e4af4d89d33a2c59d51fcf9fda1c9d368cd4
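
The test diffs below alternate between the two quantizer flavors that back these recipes. A rough sketch of the difference, with constructor arguments taken from the diff (the behavioral description is an assumption):

    import torch
    import transformer_engine_torch as tex
    from transformer_engine.pytorch.tensor.float8_tensor import (
        Float8Quantizer,
        Float8CurrentScalingQuantizer,
    )

    # delayed scaling: the caller owns persistent scale/amax state carried across steps
    delayed_q = Float8Quantizer(
        fp8_dtype=tex.DType.kFloat8E4M3,
        scale=torch.ones(1, dtype=torch.float32, device="cuda"),
        amax=torch.zeros(1, dtype=torch.float32, device="cuda"),
    )
    # current scaling: the scale is recomputed from each tensor's own amax at quantization time
    current_q = Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda")

    x = torch.randn(128, 128, dtype=torch.bfloat16, device="cuda")
    x_fp8 = current_q(x)  # returns a Float8Tensor, as in the tests' dout_quantizer(dout)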
@@ -12,14 +12,18 @@ from transformer_engine.pytorch.attention import DotProductAttention
 from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
     get_cu_seqlens_on_cp_rank,
 )
+from transformer_engine.pytorch.attention.dot_product_attention.utils import combine_and_quantize
 import transformer_engine_torch as tex
 from test_attention_with_cp import model_configs_flash_attn, model_configs_fused_attn
 from transformer_engine.pytorch.fp8 import fp8_autocast
-from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
-from transformer_engine.common.recipe import DelayedScaling
+from transformer_engine.pytorch.tensor.float8_tensor import (
+    Float8Tensor,
+    Float8Quantizer,
+    Float8CurrentScalingQuantizer,
+)
+from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling
 from utils import ModelConfig, compare_and_assert

 dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
@@ -151,7 +155,7 @@ def get_tols(config, dtype):
     elif dtype == "fp8":
         atol = 5e-1
         rtol = 5e-1
-        rmse_tol = 0.1
+        rmse_tol = 0.15
     else:
         assert False, f"{dtype=} is not supported!"
@@ -164,14 +168,23 @@ def run_dpa_with_cp(
     qkv_format="bshd",
     kernel_backend="FlashAttention",
     cp_comm_type="p2p",
-    fp8_mha=False,
+    fp8_bwd="True",
+    fp8_dpa="False",
+    fp8_mha="False",
+    scaling_mode="delayed",
+    f16_O="False",
     log_level=logging.WARNING,
 ):
     """Test DotProductAttention module with context parallelism"""
     logging.root.setLevel(log_level)

     # set up environment variables and config
-    fp8_mha = fp8_mha == "True"
+    fp8_bwd = fp8_bwd == "True" and dtype == "fp8"
+    os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_bwd else "0"
+    fp8_dpa = fp8_dpa == "True" and dtype == "fp8"
+    fp8_mha = fp8_mha == "True" and dtype == "fp8"
+    f16_O = dtype == "fp8" and scaling_mode == "current" and f16_O == "True"
+    os.environ["NVTE_DPA_FP8CS_O_in_F16"] = "1" if f16_O else "0"
     os.environ["NVTE_FLASH_ATTN"] = "0"
     os.environ["NVTE_FUSED_ATTN"] = "0"
     if kernel_backend == "FlashAttention":
@@ -219,8 +232,12 @@ def run_dpa_with_cp(
             sub_group = dist.new_group(sub_ranks, backend="nccl")
             if rank in sub_ranks:
                 cp_comm_sub_groups.append(sub_group)

     if dtype == "fp8":
-        fp8_recipe = DelayedScaling(fp8_dpa=True, fp8_mha=fp8_mha)
+        if scaling_mode == "delayed":
+            fp8_recipe = DelayedScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha)
+        if scaling_mode == "current":
+            fp8_recipe = Float8CurrentScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha)

     # instantiate attention module
     core_attn = DotProductAttention(
@@ -247,19 +264,38 @@ def run_dpa_with_cp(
         cu_seqlens_q_padded,
         cu_seqlens_kv_padded,
     ) = generate_input_shapes(qkv_format, config, world_size, kernel_backend)
-    q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda()
-    k = torch.randn(k_input_shape, dtype=dtypes[dtype]).cuda()
-    v = torch.randn(v_input_shape, dtype=dtypes[dtype]).cuda()
-    for x in [q, k, v]:
-        x.requires_grad = True
-    dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda()
-    if fp8_mha:
+    q_orig = torch.clamp(torch.randn(q_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda()
+    k_orig = torch.clamp(torch.randn(k_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda()
+    v_orig = torch.clamp(torch.randn(v_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda()
+    dout_orig = torch.clamp(
+        torch.randn(attn_output_shape, dtype=dtypes[dtype]), min=-1, max=1
+    ).cuda()
+    if scaling_mode == "delayed":
+        qkv_quantizer = Float8Quantizer(
+            fp8_dtype=tex.DType.kFloat8E4M3,
+            scale=torch.tensor([1], dtype=torch.float32).cuda(),
+            amax=torch.tensor([0], dtype=torch.float32).cuda(),
+        )
         dout_quantizer = Float8Quantizer(
             fp8_dtype=tex.DType.kFloat8E5M2,
             scale=torch.tensor([1], dtype=torch.float32).cuda(),
             amax=torch.tensor([0], dtype=torch.float32).cuda(),
         )
+    if scaling_mode == "current":
+        qkv_quantizer = Float8CurrentScalingQuantizer(
+            fp8_dtype=tex.DType.kFloat8E4M3,
+            device="cuda",
+        )
+        dout_quantizer = Float8CurrentScalingQuantizer(
+            fp8_dtype=tex.DType.kFloat8E5M2,
+            device="cuda",
+        )
+    qkv_layout = "_".join([qkv_format] * 3)
+    q, k, v, dout = [x.clone().detach() for x in [q_orig, k_orig, v_orig, dout_orig]]
+    if fp8_mha:
+        q, k, v = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer)
+    for x in [q, k, v]:
+        x.requires_grad = True

     if config.attn_bias_type not in ["no_bias", "alibi"]:
         attn_bias_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv)
@@ -274,6 +310,7 @@ def run_dpa_with_cp(
     else:
         fp8_context = nullcontext()

     with fp8_context:
+        # q, k, v, out in FP8; dout in F16
         out = core_attn(
             q,
             k,
@@ -284,8 +321,9 @@ def run_dpa_with_cp(
             cu_seqlens_kv=cu_seqlens_kv,
             cu_seqlens_q_padded=cu_seqlens_q_padded,
             cu_seqlens_kv_padded=cu_seqlens_kv_padded,
+            fp8_output=fp8_mha,
         )
-    if fp8_mha:
+    if fp8_bwd and fp8_mha:
         dout_fp8 = dout_quantizer(dout)
         out.backward(dout_fp8)
     else:
@@ -298,24 +336,10 @@ def run_dpa_with_cp(
     ############ run with CP ############
     logging.info(f"[Rank {rank}] Run with context parallelism")

-    # set up environment
-    core_attn.set_context_parallel_group(
-        cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group,
-        cp_comm_ranks,
-        torch.cuda.Stream(),
-        cp_comm_type,
-    )
-    if config.softmax_type != "vanilla":
-        core_attn.softmax_offset.grad.zero_()
-    if dtype == "fp8":
-        core_attn.reset_fp8_meta_tensors()
-        fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group)
-    else:
-        fp8_context = nullcontext()
-
     # set up inputs
     q_, k_, v_, dout_, *rest = [
-        x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias])
+        x.clone().detach()
+        for x in [q_orig, k_orig, v_orig, dout_orig] + ([] if bias is None else [bias])
     ]
     bias_ = rest[0] if len(rest) else None
     if qkv_format == "bshd" or qkv_format == "sbhd":
@@ -343,6 +367,16 @@ def run_dpa_with_cp(
         )
         q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]]
         k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]]
+    else:
+        assert False, f"{qkv_format} is an unsupported qkv_format!"
+    q_, k_, v_, dout_ = [x.contiguous() for x in [q_, k_, v_, dout_]]
+    if scaling_mode == "delayed":
+        qkv_quantizer.scale.fill_(1.0)
+        qkv_quantizer.amax.fill_(0.0)
+        dout_quantizer.scale.fill_(1.0)
+        dout_quantizer.amax.fill_(0.0)
+    if fp8_mha:
+        q_, k_, v_ = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer)
     q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]]
     if bias_ is not None:
         bias_ = bias_.view(
@@ -350,9 +384,25 @@ def run_dpa_with_cp(
         )
         bias_ = bias_.index_select(2, seq_idx)
         bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1])

+    # set up environment
+    core_attn.set_context_parallel_group(
+        cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group,
+        cp_comm_ranks,
+        torch.cuda.Stream(),
+        cp_comm_type,
+    )
+    if config.softmax_type != "vanilla":
+        core_attn.softmax_offset.grad.zero_()
+    if dtype == "fp8":
+        core_attn.fp8_initialized = False
+        core_attn.fp8_meta_tensors_initialized = False
+        fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group)
+    else:
+        fp8_context = nullcontext()
+
     # run attention
     with fp8_context:
+        # q, k, v, out in FP8; dout in F16
         out_ = core_attn(
             q_,
             k_,
@@ -363,27 +413,30 @@ def run_dpa_with_cp(
             cu_seqlens_kv=cu_seqlens_kv,
             cu_seqlens_q_padded=cu_seqlens_q_padded,
             cu_seqlens_kv_padded=cu_seqlens_kv_padded,
+            fp8_output=fp8_mha,
         )
-    if fp8_mha:
+    if fp8_bwd and fp8_mha:
         dout_fp8_ = dout_quantizer(dout_)
         out_.backward(dout_fp8_)
     else:
         out_.backward(dout_)

-    if fp8_mha:
-        assert isinstance(out, Float8Tensor)
-        assert isinstance(out_, Float8Tensor)
-        out = out.dequantize()
-        out_ = out_.dequantize()
-
-    # get outputs
     dq_, dk_, dv_ = q_.grad, k_.grad, v_.grad
     d_softmax_offset_ = None
     if config.softmax_type != "vanilla":
         d_softmax_offset_ = core_attn.softmax_offset.grad.clone()
-    for x in [out_, dq_, dk_, dv_, d_softmax_offset_]:
-        if x is not None:
-            assert torch.all(~torch.isnan(x))
-            assert torch.all(~torch.isinf(x))
+
+    # get outputs
+    tensors = [out, dq, dk, dv, out_, dq_, dk_, dv_]
+    if fp8_mha:
+        tensors_to_deq = [out, out_] if not fp8_bwd else tensors
+        for i, tensor in enumerate(tensors_to_deq):
+            tensors_to_deq[i] = tensor.dequantize()
+        if not fp8_bwd:
+            tensors[0], tensors[4] = tensors_to_deq
+    for tensor in tensors:
+        assert torch.all(~torch.isnan(tensor))
+        assert torch.all(~torch.isinf(tensor))
+    out, dq, dk, dv, out_, dq_, dk_, dv_ = tensors

     ############ compare results between CP and no-CP ############
     if qkv_format == "bshd" or qkv_format == "sbhd":
@@ -394,17 +447,17 @@ def run_dpa_with_cp(
                 x.shape[seq_dim] // (2 * world_size),
                 *x.shape[(seq_dim + 1) :],
             )
-            for x in [q.grad, k.grad, v.grad, out]
+            for x in [dq, dk, dv, out]
         ]
         dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]]
         dq_, dk_, dv_, out_ = [
             x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :])
-            for x in [q_.grad, k_.grad, v_.grad, out_]
+            for x in [dq_, dk_, dv_, out_]
         ]
     elif qkv_format == "thd":
-        dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [q.grad, out]]
-        dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [k.grad, v.grad]]
-        dq_, dk_, dv_, out_ = [q_.grad, k_.grad, v_.grad, out_]
+        dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]]
+        dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]]
+        dq_, dk_, dv_, out_ = [dq_, dk_, dv_, out_]
         cu_seqlens_q_padded = cu_seqlens_q_padded // world_size
         cu_seqlens_q = get_cu_seqlens_on_cp_rank(
             cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True
...
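
The reworked output handling above dequantizes any FP8 tensors before the NaN/Inf checks and the CP/no-CP comparison. A self-contained sketch of that pattern (the helper name is mine, not from the diff):

    import torch
    from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor

    def check_finite_in_f16(tensors):
        """Dequantize Float8Tensors, then assert there are no NaNs/Infs, as the CP test does."""
        checked = []
        for t in tensors:
            if isinstance(t, Float8Tensor):
                t = t.dequantize()
            assert torch.all(~torch.isnan(t))
            assert torch.all(~torch.isinf(t))
            checked.append(t)
        return checked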
@@ -1693,23 +1693,44 @@ qkv_format_fp8_vs_f16 = ["bshd", "sbhd"]
 @pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
 @pytest.mark.parametrize("RoPE", [True, False])
 @pytest.mark.parametrize("is_training", [True, False])
-def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training):
+@pytest.mark.parametrize("scaling_mode", ["delayed", "current"])
+def test_mha_fp8_vs_f16(
+    dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training, scaling_mode
+):
     """Test MultiHeadAttention module in FP8"""
     os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
     os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
     config = model_configs_fp8_vs_f16[model]

     # Test backend availability
+    if scaling_mode == "delayed":
+        fp8_recipe = recipe.DelayedScaling(
+            margin=0,
+            fp8_format=recipe.Format.HYBRID,
+            amax_history_len=1,
+            amax_compute_algo="most_recent",
+            fp8_dpa=True,
+            fp8_mha=True,
+        )
+    elif scaling_mode == "current":
+        fp8_recipe = recipe.Float8CurrentScaling(
+            fp8_format=recipe.Format.HYBRID,
+            fp8_dpa=True,
+            fp8_mha=True,
+        )
+    fp8_meta = {}
+    fp8_meta["recipe"] = fp8_recipe
     available_backends, _, fused_attn_backends = get_available_attention_backends(
         config,
         qkv_dtype=torch.float8_e4m3fn,
         qkv_layout=qkv_format.replace("hd", "h3d"),
+        fp8=True,
+        fp8_meta=fp8_meta,
         is_training=is_training,
     )
     flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
-    # Skip if only unfused backend is supported
-    if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
-        pytest.skip("Less than two backends to compare.")
+    if flash_attn_supported + fused_attn_supported < 1:
+        pytest.skip("No FP8 attention backend available.")

     if not fp8_dpa_bwd:
         available_backends, _, fused_attn_backends = get_available_attention_backends(
             config,
@@ -1727,7 +1748,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
     _attention_backends["backend_selection_requires_update"] = True
     logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
     flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
-        dtype, config, True, qkv_format, input_layernorm, RoPE, is_training
+        dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
     )
     os.environ["NVTE_FLASH_ATTN"] = "0"
@@ -1735,19 +1756,20 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
     _attention_backends["backend_selection_requires_update"] = True
     logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
     fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
-        dtype, config, True, qkv_format, input_layernorm, RoPE, is_training
+        dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
     )

     logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
     fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
-        dtype, config, False, qkv_format, input_layernorm, RoPE, is_training
+        dtype, config, False, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
     )

     atol = 5e-1
     rtol = 5e-1
     rmse_tol = 0.15
-    logging.debug("========== {:^25s} ==========".format("forward output"))
     if flash_attn_supported:
+        logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:"))
+        logging.debug("========== {:^25s} ==========".format("forward output"))
         compare_and_assert(
             flash_attn_fwd_fp8,
             fused_attn_fwd_f16,
@@ -1758,6 +1780,8 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
             rmse_tol,
             True,
         )
+    logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:"))
+    logging.debug("========== {:^25s} ==========".format("forward output"))
     compare_and_assert(
         fused_attn_fwd_fp8,
         fused_attn_fwd_f16,
@@ -1784,7 +1808,9 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
     )

-def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE, is_training):
+def _run_mha_fp8_vs_f16(
+    dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
+):
     """Run MultiHeadAttention module in FP8"""
     reset_rng_states()
     _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
@@ -1794,15 +1820,6 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoP
         """Get cuda rng tracker."""
         return _DUMMY_CUDA_RNG_STATE_TRACKER

-    fp8_recipe = recipe.DelayedScaling(
-        margin=0,
-        fp8_format=recipe.Format.HYBRID,
-        amax_history_len=1,
-        amax_compute_algo="most_recent",
-        fp8_dpa=fp8_mha,
-        fp8_mha=fp8_mha,
-    )
-
     with fp8_model_init(enabled=fp8_mha, recipe=fp8_recipe):
         rotary_pos_emb = None
         if RoPE:
@@ -1911,7 +1928,8 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoP
 @pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16)
 @pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
 @pytest.mark.parametrize("is_training", [True, False])
-def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
+@pytest.mark.parametrize("scaling_mode", ["delayed", "current"])
+def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scaling_mode):
     """Test DotProductAttention module in FP8"""
     config = model_configs_fp8_vs_f16[model]
@@ -1927,16 +1945,33 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
     os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
     os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
+    os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "1"

     # Test backend availability
+    if scaling_mode == "delayed":
+        fp8_recipe = recipe.DelayedScaling(
+            margin=0,
+            fp8_format=recipe.Format.HYBRID,
+            amax_history_len=1,
+            amax_compute_algo="most_recent",
+            fp8_dpa=True,
+        )
+    elif scaling_mode == "current":
+        fp8_recipe = recipe.Float8CurrentScaling(
+            fp8_format=recipe.Format.HYBRID,
+            fp8_dpa=True,
+        )
+    fp8_meta = {}
+    fp8_meta["recipe"] = fp8_recipe
     available_backends, _, fused_attn_backends = get_available_attention_backends(
         config,
         qkv_dtype=torch.float8_e4m3fn,
         qkv_layout=qkv_layout,
+        fp8=True,
+        fp8_meta=fp8_meta,
         is_training=is_training,
     )
     flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
-    # Skip if only unfused backend is supported
     if flash_attn_supported + fused_attn_supported < 1:
         pytest.skip("No FP8 attention backend available.")

     if not fp8_dpa_bwd:
@@ -1956,32 +1991,44 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
         os.environ["NVTE_FLASH_ATTN"] = "1"
         os.environ["NVTE_FUSED_ATTN"] = "0"
         _attention_backends["backend_selection_requires_update"] = True
-        logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True")
+        logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FlashAttention)")
         flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
-            dtype, config, True, qkv_layout, is_training
+            dtype, config, True, qkv_layout, is_training, fp8_recipe
         )

+    if unfused_attn_supported:
+        os.environ["NVTE_FLASH_ATTN"] = "0"
+        os.environ["NVTE_FUSED_ATTN"] = "0"
+        _attention_backends["backend_selection_requires_update"] = True
+        logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (UnfusedDotProductAttention)")
+        unfused_attn_fwd_fp8, unfused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
+            dtype, config, True, qkv_layout, is_training, fp8_recipe
+        )
+
     os.environ["NVTE_FLASH_ATTN"] = "0"
     os.environ["NVTE_FUSED_ATTN"] = "1"
     _attention_backends["backend_selection_requires_update"] = True
-    logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True")
+    logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FusedAttention)")
     fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
-        dtype, config, True, qkv_layout, is_training
+        dtype, config, True, qkv_layout, is_training, fp8_recipe
     )
+    os.environ["NVTE_FLASH_ATTN"] = "0"
+    os.environ["NVTE_FUSED_ATTN"] = "1"

     if config.dropout_p == 0.0:
         # test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell
-        logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False")
+        logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)")
         fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16(
-            dtype, config, False, qkv_layout, is_training
+            dtype, config, False, qkv_layout, is_training, fp8_recipe
         )

     atol = 5e-1
     rtol = 5e-2
     rmse_tol = 0.11
     bwd_names = ["dq", "dk", "dv"]
-    logging.debug("========== {:^25s} ==========".format("forward output"))
     if flash_attn_supported:
+        logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:"))
+        logging.debug("========== {:^25s} ==========".format("forward output"))
         compare_and_assert(
             flash_attn_fwd_fp8,
             fused_attn_fwd_f16,
@@ -1992,12 +2039,40 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
             rmse_tol,
             True,
         )
+    if unfused_attn_supported:
+        logging.debug("========== {:^25s} ==========".format("unfused fp8 vs fused f16:"))
+        logging.debug("========== {:^25s} ==========".format("forward output"))
+        compare_and_assert(
+            unfused_attn_fwd_fp8,
+            fused_attn_fwd_f16,
+            "unfused_attn_fwd_fp8",
+            "fused_attn_fwd_f16",
+            atol,
+            rtol,
+            rmse_tol,
+            True,
+        )
+        if is_training:
+            for i, _ in enumerate(fused_attn_bwd_f16):
+                logging.debug("========== {:^25s} ==========".format(bwd_names[i]))
+                compare_and_assert(
+                    unfused_attn_bwd_fp8[i],
+                    fused_attn_bwd_f16[i],
+                    f"unfused_attn_bwd_fp8[{i}]",
+                    f"fused_attn_bwd_f16[{i}]",
+                    atol,
+                    rtol,
+                    rmse_tol,
+                    True,
+                )
     if config.dropout_p != 0.0:
         # test cuDNN FP8 dropout
         assert torch.all(
             fused_attn_fwd_fp8 == 1
         ), "fused_attn_fwd_fp8 must be all 1s when Q/K/V are all 1s."
     else:
+        logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:"))
+        logging.debug("========== {:^25s} ==========".format("forward output"))
         compare_and_assert(
             fused_attn_fwd_fp8,
             fused_attn_fwd_f16,
@@ -2021,9 +2096,10 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
             rmse_tol,
             True,
         )
+    os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "0"


-def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training):
+def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training, fp8_recipe):
     """Run DotProductAttention module in FP8"""
     reset_rng_states()
     _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
@@ -2033,14 +2109,6 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training):
         """Get cuda rng tracker."""
         return _DUMMY_CUDA_RNG_STATE_TRACKER

-    fp8_recipe = recipe.DelayedScaling(
-        margin=0,
-        fp8_format=recipe.Format.HYBRID,
-        amax_history_len=1,
-        amax_compute_algo="most_recent",
-        fp8_dpa=fp8_dpa,
-    )
-
     qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
     with fp8_model_init(enabled=fp8_dpa):
         dpa = DotProductAttention(
@@ -2147,6 +2215,7 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training):
             attn_mask_type=config.attn_mask_type,
             checkpoint_core_attention=False,
             core_attention_bias_type=config.attn_bias_type,
+            fp8_output=fp8_dpa,
         )
     if is_training:
         out.backward(out_grad)
...
@@ -14,6 +14,10 @@ from transformer_engine.pytorch.utils import (
     get_device_compute_capability,
     get_cudnn_version,
 )
+from transformer_engine.common.recipe import (
+    DelayedScaling,
+    Float8CurrentScaling,
+)
 from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils

 _current_file = pathlib.Path(__file__).resolve()
@@ -27,6 +31,8 @@ seed = 1234
 torch.manual_seed(seed)
 torch.cuda.manual_seed(seed)

+test_essential = True
+
 model_configs_flash_attn = {
     # test: ModelConfig(b, sq, hq, dqk)
     "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"),  # MHA
@@ -63,12 +69,22 @@ def get_bash_arguments(num_gpus_per_node, **kwargs):
     return args

+dtypes = ["bf16", "fp16"]
+qkv_formats = ["bshd", "sbhd", "thd"]
+cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
+if test_essential:
+    configs = ["cp_1_0", "cp_2_1", "cp_3_2", "cp_3_3"]
+    model_configs_flash_attn = {k: model_configs_flash_attn[k] for k in configs}
+    dtypes = ["bf16"]
+    qkv_formats = ["sbhd", "thd"]
+
+
 @pytest.mark.skipif(not FlashAttentionUtils.v2_plus, reason="Flash-attn 2.0+ is required.")
 @pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
-@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
+@pytest.mark.parametrize("dtype", dtypes)
 @pytest.mark.parametrize("model", model_configs_flash_attn.keys())
-@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
-@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"])
+@pytest.mark.parametrize("qkv_format", qkv_formats)
+@pytest.mark.parametrize("cp_comm_type", cp_comm_types)
 def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
     num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2
     if num_gpus > torch.cuda.device_count():
@@ -77,6 +93,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
     config = model_configs_flash_attn[model]
     config.context_parallel = True
     config.cp_comm_type = cp_comm_type
+
     if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1):
         pytest.skip("CP implementation with KV P2P does not support sliding window yet!")
     if cp_comm_type == "all_gather" and qkv_format == "thd":
@@ -162,14 +179,30 @@ model_configs_fused_attn = {
 }

+dtypes = ["bf16", "fp16", "fp8"]
+qkv_formats = ["bshd", "sbhd", "thd"]
+cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
+if test_essential:
+    configs = ["cp_1_0", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
+    model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs}
+    dtypes = ["bf16", "fp8"]
+    qkv_formats = ["sbhd", "thd"]
+
+
 @pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.")
 @pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
-@pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"])
+@pytest.mark.parametrize("dtype", dtypes)
 @pytest.mark.parametrize("model", model_configs_fused_attn.keys())
-@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
-@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"])
-@pytest.mark.parametrize("fp8_mha", [False, True])
-def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha):
+@pytest.mark.parametrize("qkv_format", qkv_formats)
+@pytest.mark.parametrize("cp_comm_type", cp_comm_types)
+@pytest.mark.parametrize("fp8_bwd", [True, False])
+@pytest.mark.parametrize("fp8_mha", [True, False])
+@pytest.mark.parametrize("fp8_dpa", [True, False])
+@pytest.mark.parametrize("scaling_mode", [None, "delayed", "current"])
+@pytest.mark.parametrize("f16_O", [True, False])
+def test_cp_with_fused_attention(
+    dtype, model, qkv_format, cp_comm_type, fp8_bwd, fp8_mha, fp8_dpa, scaling_mode, f16_O
+):
     num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2
     if num_gpus > torch.cuda.device_count():
         pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}")
@@ -180,10 +213,15 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
         pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!")
     if dtype == "fp8" and get_device_compute_capability() < (9, 0):
         pytest.skip("FP8 attention is only supported on sm90+!")
+    if dtype == "fp8" and not fp8_dpa and fp8_mha:
+        pytest.skip("Duplicate tests to fp8_dpa=True and fp8_mha=True!")
+    if dtype != "fp8" and fp8_bwd:
+        pytest.skip("Only fp8 works with fp8_bwd=True!")

     config = model_configs_fused_attn[model]
     config.context_parallel = True
     config.cp_comm_type = cp_comm_type
+
     if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias":
         pytest.skip("THD format does not support post_scale_bias yet!")
     if qkv_format == "thd" and cp_comm_type == "all_gather":
@@ -211,8 +249,22 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
             f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and"
             f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!"
         )
-    if dtype != "fp8" and fp8_mha:
-        pytest.skip("Only fp8 works with fp8_mha=True!")
+    if dtype != "fp8" and (fp8_mha or fp8_dpa):
+        pytest.skip("Only fp8 works with fp8_dpa=True or fp8_mha=True!")
+    if dtype == "fp8" and not (fp8_mha or fp8_dpa):
+        pytest.skip("fp8 only works with fp8_dpa=True or fp8_mha=True!")
+    if dtype != "fp8" and scaling_mode is not None:
+        pytest.skip("Only fp8 works with scaling_mode != None!")
+    if dtype == "fp8" and scaling_mode is None:
+        pytest.skip("fp8 only works with scaling_mode != None!")
+    if (
+        dtype == "fp8"
+        and scaling_mode == "current"
+        and cp_comm_type not in ["p2p", "a2a+p2p", "a2a"]
+    ):
+        pytest.skip("fp8 only works with P2P, A2A and A2A+P2P for scaling_mode = current!")
+    if f16_O and (dtype != "fp8" or scaling_mode != "current"):
+        pytest.skip("f16_O only needs to be tested for dtype = fp8 and scaling_mode = current!")
     if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v:
         pytest.skip("MLA CP currently only support KV P2P!")
     if dtype == "fp8" and config.head_dim_qk != config.head_dim_v:
@@ -229,10 +281,25 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
     )

     dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
+    fp8_meta = {}
+    fp8_meta["recipe"] = None
+    fp8_meta["local_recipes"] = []
+    fp8 = dtype == "fp8" and (fp8_dpa or fp8_mha)
+    if fp8 and scaling_mode == "delayed":
+        fp8_meta["recipe"] = DelayedScaling(fp8_dpa=True)
+        fp8_meta["local_recipes"] = [DelayedScaling(fp8_dpa=True)]
+    if fp8 and scaling_mode == "current":
+        fp8_meta["recipe"] = DelayedScaling(fp8_dpa=True)
+        fp8_meta["local_recipes"] = [
+            Float8CurrentScaling(fp8_dpa=True),
+            DelayedScaling(fp8_dpa=True),
+        ]
     available_backends, _, fused_attn_backends = get_available_attention_backends(
         config,
         qkv_dtype=dtypes[dtype] if dtype != "fp8" else torch.float8_e4m3fn,
qkv_layout="_".join([qkv_format] * 3), qkv_layout="_".join([qkv_format] * 3),
fp8=fp8,
fp8_meta=fp8_meta,
) )
_, fused_attn_supported, _ = available_backends _, fused_attn_supported, _ = available_backends
if not fused_attn_supported: if not fused_attn_supported:
...@@ -246,7 +313,11 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha ...@@ -246,7 +313,11 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
qkv_format=qkv_format, qkv_format=qkv_format,
kernel_backend="FusedAttention", kernel_backend="FusedAttention",
cp_comm_type=cp_comm_type, cp_comm_type=cp_comm_type,
fp8_bwd=fp8_bwd,
fp8_dpa=fp8_dpa,
fp8_mha=fp8_mha, fp8_mha=fp8_mha,
scaling_mode=scaling_mode,
f16_O=f16_O,
log_level=pytest_logging_level, log_level=pytest_logging_level,
), ),
check=True, check=True,
......
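End-to-end, the `scaling_mode="current"` rows of this test matrix correspond to user code along the following lines. This is a minimal sketch, assuming the post-PR behavior where `Float8CurrentScaling(fp8_dpa=True)` is accepted (the old `__post_init__` assertion is removed later in this diff); module sizes and tensor shapes are chosen purely for illustration.

```python
# Minimal sketch: FP8 attention with current scaling on a single GPU.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8CurrentScaling

recipe = Float8CurrentScaling(fp8_dpa=True)  # previously rejected in __post_init__

attn = te.DotProductAttention(num_attention_heads=16, kv_channels=64)
q, k, v = [
    torch.randn(128, 2, 16, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    for _ in range(3)
]
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    out = attn(q, k, v)  # qkv_format defaults to "sbhd"
out.backward(torch.randn_like(out))
```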
@@ -129,7 +129,9 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
                              window_size_right,
                              true,
                              tensorType,
-                             tensorType};
+                             cudnn_frontend::DataType_t::NOT_SET,
+                             cudnn_frontend::DataType_t::NOT_SET,
+                             cudnn_frontend::DataType_t::NOT_SET};
   namespace fe = cudnn_frontend;
   using graph_and_tensors =
@@ -585,7 +587,9 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
                              window_size_right,
                              deterministic,
                              tensorType,
-                             tensorType};
+                             cudnn_frontend::DataType_t::NOT_SET,
+                             cudnn_frontend::DataType_t::NOT_SET,
+                             cudnn_frontend::DataType_t::NOT_SET};
   namespace fe = cudnn_frontend;
   using graph_and_tensors =
...
@@ -1658,8 +1658,9 @@ void fused_attn_fp8_fwd_impl_v1(
     void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrDescaleQ, void* devPtrDescaleK,
     void* devPtrDescaleV, void* devPtrDescaleS, void* devPtrScaleS, void* devPtrScaleO,
     void* devPtrAmaxO, void* devPtrAmaxS, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV,
-    void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t fwd_tensor_type,
-    void* workspace, size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
+    void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type,
+    cudnn_frontend::DataType_t o_tensor_type, void* workspace, size_t* workspace_size,
+    cudaStream_t stream, cudnnHandle_t handle) {
   using namespace transformer_engine;
   bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
   bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
@@ -1672,6 +1673,13 @@ void fused_attn_fp8_fwd_impl_v1(
   auto bias_h = h;
   NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!");
   NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!");
+  bool is_current_scaling = (o_tensor_type == cudnn_frontend::DataType_t::HALF ||
+                             o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16);
+  bool is_delayed_scaling = (o_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 ||
+                             o_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2);
+  NVTE_CHECK(is_current_scaling || is_delayed_scaling,
+             "FP8 fused attention only supports O tensor in kFloat16, kBFloat16, kFloat8E4M3 or "
+             "kFloat8E5M2!");
   try {
     FADescriptor_v1 descriptor{b,
@@ -1699,8 +1707,10 @@ void fused_attn_fp8_fwd_impl_v1(
                                0,
                                0,
                                true,
-                               fwd_tensor_type,
-                               fwd_tensor_type};
+                               qkv_tensor_type,
+                               o_tensor_type,
+                               cudnn_frontend::DataType_t::NOT_SET,
+                               cudnn_frontend::DataType_t::NOT_SET};
     namespace fe = cudnn_frontend;
     using graph_and_tensors =
@@ -1739,7 +1749,7 @@ void fused_attn_fp8_fwd_impl_v1(
       // otherwise, build the op_graph and the plan. Then update cache
       auto mha_graph = std::make_shared<fe::graph::Graph>();
-      mha_graph->set_io_data_type(fwd_tensor_type)
+      mha_graph->set_io_data_type(qkv_tensor_type)
           .set_intermediate_data_type(fe::DataType_t::FLOAT)
           .set_compute_data_type(fe::DataType_t::FLOAT);
@@ -1787,7 +1797,13 @@ void fused_attn_fp8_fwd_impl_v1(
       descale_v = mha_graph->tensor_like(descale_q, "Descale_V");
       descale_s = mha_graph->tensor_like(descale_q, "Descale_S");
       scale_s = mha_graph->tensor_like(descale_q, "Scale_S");
-      scale_o = mha_graph->tensor_like(descale_q, "Scale_O");
+      if (is_delayed_scaling) {
+        scale_o = mha_graph->tensor_like(descale_q, "Scale_O");
+      }
+      if (is_current_scaling) {
+        scale_o = mha_graph->tensor(1.0f);
+      }
       fe::graph::SDPA_fp8_attributes sdpa_options;
       sdpa_options = fe::graph::SDPA_fp8_attributes()
@@ -1839,11 +1855,12 @@ void fused_attn_fp8_fwd_impl_v1(
       std::vector<int64_t> o_stride(4);
       generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout,
                             NVTE_QKV_Matrix::NVTE_O_Matrix);
-      O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride);
+      O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride).set_data_type(o_tensor_type);
       amax_o->set_output(true)
           .set_dim({1, 1, 1, 1})
           .set_stride({1, 1, 1, 1})
           .set_data_type(fe::DataType_t::FLOAT);
       amax_s->set_output(true)
           .set_dim({1, 1, 1, 1})
          .set_stride({1, 1, 1, 1})
@@ -1916,13 +1933,16 @@ void fused_attn_fp8_fwd_impl_v1(
           {descale_v, devPtrDescaleV},
           {descale_s, devPtrDescaleS},
           {scale_s, devPtrScaleS},
-          {scale_o, devPtrScaleO},
           {attn_scale, &scaling_factor},
           {O, devPtrO},
           {amax_s, devPtrAmaxS},
           {amax_o, devPtrAmaxO},
           {Stats, devPtrM}};
+      if (is_delayed_scaling) {
+        variant_pack[scale_o] = devPtrScaleO;
+      }
       /* if (is_bias) {
         variant_pack[bias] = devPtrBias;
       } */
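A note on the `scale_o = mha_graph->tensor(1.0f)` branch above: with delayed scaling, the O scale is derived from amax history and is known before the kernel launches, so it can be bound as a device pointer in the variant pack; with current scaling, the amax of O only exists after the kernel has run, so the graph bakes in a constant scale of 1.0, emits O in F16, and any FP8 cast happens in a separate pass. The contrast between the two scale computations, sketched in plain PyTorch (not TE internals; 448 is the E4M3 max-representable value):

```python
import torch

F8_MAX = 448.0  # FP8 E4M3 max-representable magnitude

def delayed_scale(amax_history: torch.Tensor) -> torch.Tensor:
    # scale comes from *past* amaxes, so it is ready before the kernel runs
    return F8_MAX / amax_history.max()

def current_scale(o: torch.Tensor) -> torch.Tensor:
    # scale needs the *actual* amax of O, which only exists after the kernel
    # finishes; hence the graph keeps scale_o = 1.0 and emits O in F16
    return F8_MAX / o.abs().max()

o_f16 = torch.randn(4, 8) * 3.0               # kernel output with scale_o = 1.0
s = current_scale(o_f16)
o_fp8 = (o_f16 * s).to(torch.float8_e4m3fn)   # post-kernel quantization
```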
@@ -1963,8 +1983,9 @@ void fused_attn_fp8_bwd_impl_v1(
     void* devPtrScaledP, void* devPtrScaledQ, void* devPtrScaledK, void* devPtrScaledV,
     void* devPtrAmaxdP, void* devPtrAmaxdQ, void* devPtrAmaxdK, void* devPtrAmaxdV,
     void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed,
-    void* devPtrDropoutOffset, cudnn_frontend::DataType_t fwd_tensor_type,
-    cudnn_frontend::DataType_t bwd_tensor_type, void* workspace, size_t* workspace_size,
+    void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type,
+    cudnn_frontend::DataType_t o_tensor_type, cudnn_frontend::DataType_t do_tensor_type,
+    cudnn_frontend::DataType_t dqkv_tensor_type, void* workspace, size_t* workspace_size,
     cudaStream_t stream, cudnnHandle_t handle) {
   using namespace transformer_engine;
   bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
@@ -1978,6 +1999,15 @@ void fused_attn_fp8_bwd_impl_v1(
   auto bias_h = h;
   NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!");
   NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!");
+  bool is_current_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF ||
+                             dqkv_tensor_type == cudnn_frontend::DataType_t::BFLOAT16);
+  bool is_delayed_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 ||
+                             dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2);
+  NVTE_CHECK(is_current_scaling || is_delayed_scaling,
+             "FP8 fused attention only supports dQKV tensor in kFloat16, kBFloat16, kFloat8E4M3 or "
+             "kFloat8E5M2!");
+  bool is_O_in_F16 = (o_tensor_type == cudnn_frontend::DataType_t::HALF ||
+                      o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16);
   try {
     FADescriptor_v1 descriptor{b,
@@ -2005,8 +2035,10 @@ void fused_attn_fp8_bwd_impl_v1(
                                0,
                                0,
                                false,
-                               fwd_tensor_type,
-                               bwd_tensor_type};
+                               qkv_tensor_type,
+                               o_tensor_type,
+                               do_tensor_type,
+                               dqkv_tensor_type};
     namespace fe = cudnn_frontend;
     using graph_and_tensors =
@@ -2059,7 +2091,7 @@ void fused_attn_fp8_bwd_impl_v1(
       // otherwise, build the op_graph and the plan. Then update cache
       auto mha_graph = std::make_shared<fe::graph::Graph>();
-      mha_graph->set_io_data_type(fwd_tensor_type)
+      mha_graph->set_io_data_type(qkv_tensor_type)
          .set_intermediate_data_type(fe::DataType_t::FLOAT)
          .set_compute_data_type(fe::DataType_t::FLOAT);
@@ -2099,7 +2131,8 @@ void fused_attn_fp8_bwd_impl_v1(
       o = mha_graph->tensor(fe::graph::Tensor_attributes()
                                 .set_name("O")
                                 .set_dim({b, h, s_q, d})
-                                .set_stride(o_stride));
+                                .set_stride(o_stride)
+                                .set_data_type(o_tensor_type));
       dO = mha_graph->tensor(fe::graph::Tensor_attributes()
                                  .set_name("dO")
                                  .set_dim({b, h, s_q, d})
@@ -2125,14 +2158,26 @@ void fused_attn_fp8_bwd_impl_v1(
       descale_k = mha_graph->tensor_like(descale_q, "Descale_q");
       descale_v = mha_graph->tensor_like(descale_q, "Descale_V");
       descale_s = mha_graph->tensor_like(descale_q, "Descale_S");
-      descale_o = mha_graph->tensor_like(descale_q, "Descale_O");
       descale_dP = mha_graph->tensor_like(descale_q, "Descale_dP");
+      if (is_O_in_F16) {
+        descale_o = mha_graph->tensor(1.0f);
+      } else {
+        descale_o = mha_graph->tensor_like(descale_q, "Descale_O");
+      }
       descale_dO = mha_graph->tensor_like(descale_q, "Descale_dO");
       scale_s = mha_graph->tensor_like(descale_q, "Scale_S");
       scale_dP = mha_graph->tensor_like(descale_q, "Scale_dP");
-      scale_dQ = mha_graph->tensor_like(descale_q, "Scale_dQ");
-      scale_dK = mha_graph->tensor_like(descale_q, "Scale_dK");
-      scale_dV = mha_graph->tensor_like(descale_q, "Scale_dV");
+      if (is_delayed_scaling) {
+        scale_dQ = mha_graph->tensor_like(descale_q, "Scale_dQ");
+        scale_dK = mha_graph->tensor_like(descale_q, "Scale_dK");
+        scale_dV = mha_graph->tensor_like(descale_q, "Scale_dV");
+      }
+      if (is_current_scaling) {
+        scale_dQ = mha_graph->tensor(1.0f);
+        scale_dK = mha_graph->tensor(1.0f);
+        scale_dV = mha_graph->tensor(1.0f);
+      }
       fe::graph::SDPA_fp8_backward_attributes sdpa_backward_options;
       sdpa_backward_options = fe::graph::SDPA_fp8_backward_attributes()
@@ -2214,10 +2259,10 @@ void fused_attn_fp8_bwd_impl_v1(
           .set_stride({1, 1, 1, 1})
           .set_data_type(fe::DataType_t::FLOAT);
-      dO->set_data_type(bwd_tensor_type);
-      dQ->set_data_type(bwd_tensor_type);
-      dK->set_data_type(bwd_tensor_type);
-      dV->set_data_type(bwd_tensor_type);
+      dO->set_data_type(do_tensor_type);
+      dQ->set_data_type(dqkv_tensor_type);
+      dK->set_data_type(dqkv_tensor_type);
+      dV->set_data_type(dqkv_tensor_type);
       std::tuple<std::shared_ptr<fe::graph::Tensor_attributes>,  // q
                  std::shared_ptr<fe::graph::Tensor_attributes>,  // k
@@ -2298,14 +2343,10 @@ void fused_attn_fp8_bwd_impl_v1(
           {descale_q, devPtrDescaleQ},
           {descale_k, devPtrDescaleK},
           {descale_v, devPtrDescaleV},
-          {descale_o, devPtrDescaleO},
           {descale_dO, devPtrDescaledO},
           {descale_s, devPtrDescaleS},
           {descale_dP, devPtrDescaledP},
           {scale_s, devPtrScaleS},
-          {scale_dQ, devPtrScaledQ},
-          {scale_dK, devPtrScaledK},
-          {scale_dV, devPtrScaledV},
           {scale_dP, devPtrScaledP},
           {dQ, devPtrdQ},
           {dK, devPtrdK},
@@ -2316,6 +2357,15 @@ void fused_attn_fp8_bwd_impl_v1(
           {amax_dP, devPtrAmaxdP},
       };
+      if (is_delayed_scaling) {
+        variant_pack[scale_dQ] = devPtrScaledQ;
+        variant_pack[scale_dK] = devPtrScaledK;
+        variant_pack[scale_dV] = devPtrScaledV;
+      }
+      if (!is_O_in_F16) {
+        variant_pack[descale_o] = devPtrDescaleO;
+      }
       /* if (is_bias) {
         variant_pack[bias] = devPtrBias;
         if ((bias_b == 1) && (bias_h == h)) {
@@ -2366,6 +2416,7 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma
                                   cudnnHandle_t handle) {
   using namespace transformer_engine;
   const DType QKV_type = input_QKV->data.dtype;
+  const DType O_type = output_O->data.dtype;
   void* devPtrQKV = input_QKV->data.dptr;
   NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
   size_t stride = 0;
@@ -2432,8 +2483,8 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma
         attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM,
         devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS,
         devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlens, devPtrcuSeqlens,
-        devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
-        &workspace_size, stream, handle);
+        devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
+        get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle);
   } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
     fused_attn::fused_attn_fp8_fwd_impl(
         batch, num_attn_heads, max_seqlen, max_seqlen, head_dim, is_training, attn_scale, p_dropout,
@@ -2467,6 +2518,7 @@ void fused_attn_fp8_bwd_qkvpacked(
     cudaStream_t stream, cudnnHandle_t handle) {
   using namespace transformer_engine;
   const DType QKV_type = input_QKV->data.dtype;
+  const DType dO_type = input_dO->data.dtype;
   const DType dQKV_type = output_dQKV->data.dtype;
   void* devPtrQKV = input_QKV->data.dptr;
   NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
@@ -2484,7 +2536,11 @@ void fused_attn_fp8_bwd_qkvpacked(
   void* devPtrDescaleV = input_QKV->scale_inv.dptr;
   void* devPtrO = input_O->data.dptr;
-  void* devPtrDescaleO = input_O->scale_inv.dptr;
+  const DType O_type = input_O->data.dtype;
+  void* devPtrDescaleO = nullptr;
+  if (O_type == DType::kFloat8E4M3 || O_type == DType::kFloat8E5M2) {
+    devPtrDescaleO = input_O->scale_inv.dptr;
+  }
   void* devPtrdO = input_dO->data.dptr;
   void* devPtrDescaledO = input_dO->scale_inv.dptr;
@@ -2527,7 +2583,8 @@ void fused_attn_fp8_bwd_qkvpacked(
         devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP,
         devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlens, devPtrcuSeqlens,
         devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
-        get_cudnn_fe_dtype(dQKV_type), workspace->data.dptr, &workspace_size, stream, handle);
+        get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type),
+        workspace->data.dptr, &workspace_size, stream, handle);
   } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
     fused_attn::fused_attn_fp8_bwd_impl(
         batch, num_attn_heads, max_seqlen, max_seqlen, head_dim, attn_scale, p_dropout, qkv_layout,
@@ -2565,6 +2622,7 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num
                                  Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle) {
   using namespace transformer_engine;
   const DType QKV_type = input_Q->data.dtype;
+  const DType O_type = output_O->data.dtype;
   void* devPtrQ = input_Q->data.dptr;
   void* devPtrKV = input_KV->data.dptr;
   NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
@@ -2633,8 +2691,8 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num
         attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM,
         devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS,
         devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV,
-        devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
-        &workspace_size, stream, handle);
+        devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
+        get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle);
   } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
     fused_attn::fused_attn_fp8_fwd_impl(
         batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale,
@@ -2671,6 +2729,7 @@ void fused_attn_fp8_bwd_kvpacked(
     cudnnHandle_t handle) {
   using namespace transformer_engine;
   const DType QKV_type = input_Q->data.dtype;
+  const DType dO_type = input_dO->data.dtype;
   const DType dQKV_type = output_dQ->data.dtype;
   void* devPtrQ = input_Q->data.dptr;
   void* devPtrKV = input_KV->data.dptr;
@@ -2688,7 +2747,11 @@ void fused_attn_fp8_bwd_kvpacked(
   void* devPtrDescaleV = input_KV->scale_inv.dptr;
   void* devPtrO = input_O->data.dptr;
-  void* devPtrDescaleO = input_O->scale_inv.dptr;
+  const DType O_type = input_O->data.dtype;
+  void* devPtrDescaleO = nullptr;
+  if (O_type == DType::kFloat8E4M3 || O_type == DType::kFloat8E5M2) {
+    devPtrDescaleO = input_O->scale_inv.dptr;
+  }
   void* devPtrdO = input_dO->data.dptr;
   void* devPtrDescaledO = input_dO->scale_inv.dptr;
@@ -2733,7 +2796,8 @@ void fused_attn_fp8_bwd_kvpacked(
         devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP,
         devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV,
         devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
-        get_cudnn_fe_dtype(dQKV_type), workspace->data.dptr, &workspace_size, stream, handle);
+        get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type),
+        workspace->data.dptr, &workspace_size, stream, handle);
   } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
     fused_attn::fused_attn_fp8_bwd_impl(
         batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, p_dropout,
@@ -2822,6 +2886,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
       reinterpret_cast<void*>(reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
   const DType QKV_type = input_Q->data.dtype;
+  const DType O_type = output_O->data.dtype;
   size_t workspace_size = 0;
   NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
@@ -2831,8 +2896,8 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
         attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM,
         devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS,
         devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV,
-        devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
-        &workspace_size, stream, handle);
+        devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
+        get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle);
   } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
     fused_attn::fused_attn_fp8_fwd_impl(
         batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale,
@@ -2878,7 +2943,11 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
   void* devPtrDescaleV = input_Q->scale_inv.dptr;
   void* devPtrO = input_O->data.dptr;
-  void* devPtrDescaleO = input_O->scale_inv.dptr;
+  const DType O_type = input_O->data.dtype;
+  void* devPtrDescaleO = nullptr;
+  if (O_type == DType::kFloat8E4M3 || O_type == DType::kFloat8E5M2) {
+    devPtrDescaleO = input_O->scale_inv.dptr;
+  }
   void* devPtrdO = input_dO->data.dptr;
   void* devPtrDescaledO = input_dO->scale_inv.dptr;
@@ -2911,6 +2980,7 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
       reinterpret_cast<void*>(reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
   const DType QKV_type = input_Q->data.dtype;
+  const DType dO_type = input_dO->data.dtype;
   const DType dQKV_type = output_dQ->data.dtype;
   size_t workspace_size = 0;
@@ -2924,7 +2994,8 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
         devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP,
         devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV,
         devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
-        get_cudnn_fe_dtype(dQKV_type), workspace->data.dptr, &workspace_size, stream, handle);
+        get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type),
+        workspace->data.dptr, &workspace_size, stream, handle);
   } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
     fused_attn::fused_attn_fp8_bwd_impl(
         batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, p_dropout,
...
@@ -111,21 +111,24 @@ struct FADescriptor_v1 {
   std::int64_t window_size_left;
   std::int64_t window_size_right;
   bool deterministic;
-  cudnn_frontend::DataType_t fwd_tensor_type;
-  cudnn_frontend::DataType_t bwd_tensor_type;
+  cudnn_frontend::DataType_t qkv_tensor_type;
+  cudnn_frontend::DataType_t o_tensor_type;
+  cudnn_frontend::DataType_t do_tensor_type;
+  cudnn_frontend::DataType_t dqkv_tensor_type;
   bool operator<(const FADescriptor_v1 &rhs) const {
     return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k,
                     page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h,
                     attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type,
-                    window_size_left, window_size_right, deterministic, bias_type, fwd_tensor_type,
-                    bwd_tensor_type) <
+                    window_size_left, window_size_right, deterministic, bias_type, qkv_tensor_type,
+                    o_tensor_type, do_tensor_type, dqkv_tensor_type) <
            std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k,
                     rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k,
                     rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining,
                     rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type,
                     rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type,
-                    rhs.fwd_tensor_type, rhs.bwd_tensor_type);
+                    rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type,
+                    rhs.dqkv_tensor_type);
   }
 };
...
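The descriptor doubles as the key of the cuDNN graph cache, so each of the four tensor types has to participate in `operator<`; otherwise a delayed-scaling graph (FP8 O/dQKV) and a current-scaling graph (F16 O/dQKV) with otherwise identical shapes would collide on the same cached plan. A Python analogue of the keying, with field names mirroring the struct and values purely illustrative:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class FADescriptorKey:
    # remaining struct fields elided for brevity
    qkv_tensor_type: str
    o_tensor_type: str
    do_tensor_type: str
    dqkv_tensor_type: str

graph_cache = {}
key_ds = FADescriptorKey("FP8_E4M3", "FP8_E4M3", "FP8_E5M2", "FP8_E5M2")  # delayed scaling
key_cs = FADescriptorKey("FP8_E4M3", "BFLOAT16", "FP8_E5M2", "BFLOAT16")  # current scaling
graph_cache[key_ds] = "graph built for FP8 O/dQKV"
graph_cache[key_cs] = "graph built for BF16 O/dQKV"
assert len(graph_cache) == 2  # the two recipes never share a cached graph
```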
@@ -209,6 +209,7 @@ class DelayedScaling(Recipe):
             f"margin={self.margin}, "
             f"format={str(self.fp8_format).split('.')[1]}, "
             f"amax_history_len={self.amax_history_len}, "
+            f"reduce_amax={self.reduce_amax}, "
             f"fp8_dpa={self.fp8_dpa}, "
             f"fp8_mha={self.fp8_mha}"
         )
@@ -226,10 +227,11 @@ class Float8CurrentScaling(Recipe):
     pass.
     """
+    use_power_2_scales: bool = os.getenv("NVTE_FP8_CURRENT_SCALING_POWER_2_SCALES", "0") == "1"
     fp8_format: Format = Format.HYBRID
-    fp8_quant_fwd_inp = QParams(power_2_scale=False, amax_epsilon=0.0)
-    fp8_quant_fwd_weight = QParams(power_2_scale=False, amax_epsilon=0.0)
-    fp8_quant_bwd_grad = QParams(power_2_scale=False, amax_epsilon=0.0)
+    fp8_quant_fwd_inp = QParams(power_2_scale=use_power_2_scales, amax_epsilon=0.0)
+    fp8_quant_fwd_weight = QParams(power_2_scale=use_power_2_scales, amax_epsilon=0.0)
+    fp8_quant_bwd_grad = QParams(power_2_scale=use_power_2_scales, amax_epsilon=0.0)
     fp8_gemm_fprop: MMParams = MMParams(use_split_accumulator=False)
     fp8_gemm_dgrad: MMParams = MMParams(use_split_accumulator=True)
     fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True)
@@ -238,9 +240,6 @@ class Float8CurrentScaling(Recipe):
     def __post_init__(self) -> None:
         assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."
-        assert (
-            not self.fp8_dpa and not self.fp8_mha
-        ), "FP8 attention is not supported for Float8CurrentScaling."
 
     def __repr__(self) -> str:
         return (
...
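The `use_power_2_scales` default is evaluated when `recipe.py` is imported, so the environment variable has to be set before the first `transformer_engine` import. A hedged usage sketch of the new toggle:

```python
# Must be set before transformer_engine is imported, since the class
# attribute default reads the env var at import time.
import os
os.environ["NVTE_FP8_CURRENT_SCALING_POWER_2_SCALES"] = "1"

from transformer_engine.common.recipe import Float8CurrentScaling

recipe = Float8CurrentScaling()
assert recipe.fp8_quant_fwd_inp.power_2_scale  # scales rounded to powers of 2
```

Power-of-2 scales trade a small amount of dynamic-range utilization for exactly invertible scale/descale arithmetic in the exponent bits.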
@@ -17,6 +17,7 @@ import numpy as np
 from packaging.version import Version as PkgVersion
 import torch
+import torch.distributed as dist
 import torch.nn.functional as F
 import transformer_engine_torch as tex
 import transformer_engine as te
@@ -32,11 +33,13 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
     META_DO,
     META_S,
     META_DP,
-    META_O_CP,
-    META_DQKV_CP,
 )
 from transformer_engine.pytorch.attention.inference import InferenceParams
 from transformer_engine.pytorch.float8_tensor import Float8Tensor
+from transformer_engine.pytorch.tensor.float8_tensor import (
+    Float8Quantizer,
+    Float8CurrentScalingQuantizer,
+)
 from transformer_engine.pytorch.fp8 import get_fp8_te_dtype
 from transformer_engine.pytorch.constants import TE_DType
@@ -44,6 +47,8 @@ from transformer_engine.pytorch.constants import TE_DType
 from transformer_engine.pytorch.utils import (
     get_device_compute_capability,
     get_cudnn_version,
+    SplitAlongDim,
+    combine_tensors,
 )
 from transformer_engine.pytorch.export import is_in_onnx_export_mode
@@ -54,6 +59,9 @@ _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
 # NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
 _NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
 _NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
+# print quantizer info for a particular layer on a particular rank
+_print_layer = int(os.getenv("NVTE_PRINT_LAYER_NUMBER", "1"))
+_print_rank = int(os.getenv("NVTE_PRINT_RANK", "0"))
 _cu_seqlens_cache = {}
@@ -350,8 +358,31 @@ def get_attention_backend(
         field.name: getattr(attention_params, field.name) for field in fields(attention_params)
     }
     run_config.update(attention_params_dict)
+    # Add FP8 environment variables to config
     if fp8:
+        # all FP8 recipes: 1: (FP8 fwd, FP8 bwd), 0: (FP8 fwd, F16 bwd)
         run_config["NVTE_FP8_DPA_BWD"] = int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
+        # Float8CurrentScaling: 1: use F16 O in bwd, 0: use FP8 O in bwd
+        run_config["NVTE_DPA_FP8CS_O_in_F16"] = int(os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1"))
+        # switch recipe to "F16", "DelayedScaling", or "Float8CurrentScaling"
+        _dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "")
+        run_config["NVTE_DPA_FP8_RECIPE"] = _dpa_fp8_recipe
+        if _dpa_fp8_recipe != "":
+            # config new recipe if switched
+            run_config["NVTE_DPA_FP8_FORMAT"] = os.getenv("NVTE_DPA_FP8_FORMAT", "HYBRID")
+            run_config["NVTE_DPA_FP8DS_AMAX_ALGO"] = os.getenv(
+                "NVTE_DPA_FP8DS_AMAX_ALGO", "most_recent"
+            )
+            run_config["NVTE_DPA_FP8DS_AMAX_HISTLEN"] = int(
+                os.getenv("NVTE_DPA_FP8DS_AMAX_HISTLEN", "1")
+            )
+            run_config["NVTE_DPA_FP8DS_REDUCE_AMAX"] = int(
+                os.getenv("NVTE_DPA_FP8DS_REDUCE_AMAX", "1")
+            )
+        # UnfusedDotProductAttention: 1: allow FP8 emulation, 0: do not allow
+        run_config["NVTE_UnfusedDPA_Emulate_FP8"] = int(
+            os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0")
+        )
     logger.debug("Running with config=%s", run_config)
 
     # The following sections check if `FlashAttention` supports the provided attention params,
@@ -431,8 +462,20 @@ def get_attention_backend(
             logger.debug("Disabling FlashAttention 3 for FP8 training")
             use_flash_attention_3 = False
         if use_unfused_attention:
-            logger.debug("Disabling UnfusedDotProductAttention for FP8 attention")
-            use_unfused_attention = False
+            allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1"
+            if not allow_emulation:
+                logger.debug("Disabling UnfusedDotProductAttention for FP8 attention")
+                use_unfused_attention = False
+        fp8_recipe = fp8_meta["recipe"]
+        if fp8_meta.get("local_recipes", None) is not None:
+            fp8_recipe = fp8_meta["local_recipes"][0]
+        if (
+            use_fused_attention
+            and fp8_recipe.float8_current_scaling()
+            and device_compute_capability < (10, 0)
+        ):
+            logger.debug("Disabling FusedAttention for FP8 current scaling on arch < sm100")
+            use_fused_attention = False
 
     # Filter: KV cache
     # backend | precision | KV cache | architecture | qkv_format | page_size
@@ -1875,11 +1918,10 @@ def check_set_window_size(
     return window_size
 
-def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False):
+def get_attention_quantizers(fp8, quantizers):
     """Get the list of quantizers used in attention from the quantizers list."""
     if not fp8:
-        num_of_nones = 8 if cp_specific_quantizers else 6
-        return [None] * num_of_nones
+        return [None] * 6
     QKV_quantizer = quantizers["scaling_fwd"][META_QKV]
     QKV_quantizer.internal = True
     QKV_quantizer.set_usage(rowwise=True, columnwise=False)
@@ -1888,6 +1930,7 @@ def get_attention_quantizers(fp8, quantizers):
     S_quantizer = quantizers["scaling_fwd"][META_S]
     S_quantizer.internal = True
     S_quantizer.set_usage(rowwise=True, columnwise=False)
+
     dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV]
     dQKV_quantizer.interal = True
     dQKV_quantizer.set_usage(rowwise=True, columnwise=False)
@@ -1897,22 +1940,158 @@ def get_attention_quantizers(fp8, quantizers):
     dP_quantizer = quantizers["scaling_bwd"][META_DP]
     dP_quantizer.set_usage(rowwise=True, columnwise=False)
     dP_quantizer.interal = True
-    dQKV_CP_quantizer = quantizers["scaling_bwd"][META_DQKV_CP]
-    dQKV_CP_quantizer.set_usage(rowwise=True, columnwise=False)
-    dQKV_CP_quantizer.internal = True
-    O_CP_quantizer = quantizers["scaling_fwd"][META_O_CP]
-    O_CP_quantizer.set_usage(rowwise=True, columnwise=False)
-
-    if cp_specific_quantizers:
-        return (
-            QKV_quantizer,
-            O_quantizer,
-            O_CP_quantizer,
-            S_quantizer,
-            dQKV_quantizer,
-            dQKV_CP_quantizer,
-            dO_quantizer,
-            dP_quantizer,
-        )
-
-    return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer
+
+    return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer
+
+
+def print_quantizers(
+    label,
+    layer_number,
+    QKV_quantizer,
+    O_quantizer,
+    S_quantizer,
+    dQKV_quantizer,
+    dO_quantizer,
+    dP_quantizer,
+):
+    """Print the type and scale/amax of attention quantizers"""
+    _to_print = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL == 2
+    if (
+        _to_print
+        and _print_layer == layer_number
+        and (
+            not dist.is_initialized() or (dist.is_initialized() and dist.get_rank() == _print_rank)
+        )
+    ):
+        names = [
+            "QKV_quantizer",
+            "S_quantizer",
+            "O_quantizer",
+            "dO_quantizer",
+            "dP_quantizer",
+            "dQKV_quantizer",
+        ]
+        quantizers = [
+            QKV_quantizer,
+            S_quantizer,
+            O_quantizer,
+            dO_quantizer,
+            dP_quantizer,
+            dQKV_quantizer,
+        ]
+        if "forward" in label:
+            names = names[:3]
+            quantizers = quantizers[:3]
+        if "backward" in label:
+            names = names[3:]
+            quantizers = quantizers[3:]
+        for i, q in enumerate(quantizers):
+            type_str = ""
+            if q is None:
+                type_str = "None"
+            elif isinstance(q, Float8Quantizer):
+                type_str = "DS"
+            elif isinstance(q, Float8CurrentScalingQuantizer):
+                type_str = "CS"
+            print(
+                f"{label} >> {names[i]:14s}: {type_str}, {q.scale.item():.4e} x"
+                f" {q.amax.item():.4e} = {q.scale.item()*q.amax.item():.4e}"
+            )
+
+
+def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer):
+    """Combine q,k,v based on qkv_layout and quantize them together"""
+    # 1: qkv packed, 2: kv packed, 3: qkv separate
+    qkv_layout = qkv_layout.replace("paged_kv_", "")
+    qkv_group = len(qkv_layout.split("_"))
+    src_nominal_dtype = q.dtype
+    match qkv_group:
+        case 1:
+            dim = qkv_layout.find("3")
+            qkv = combine_tensors([q, k, v], dim)
+            qkv_fp8 = qkv_quantizer(qkv)
+            q_data, k_data, v_data = SplitAlongDim.apply(qkv_fp8._data, dim, [1, 1, 1], True)
+        case 2:
+            dim = qkv_layout.split("_")[1].find("2")
+            kv = combine_tensors([k, v], dim)
+            tensors = [q, kv]
+            num_tensors = len(tensors)
+            shapes = [x.shape for x in tensors]
+            numels = [x.numel() for x in tensors]
+            numels = [sum(numels[:i]) for i in range(num_tensors + 1)]
+            qkv = torch.cat([x.view(-1) for x in tensors], dim=0)
+            qkv_fp8 = qkv_quantizer(qkv)
+            q_data, kv_data = [
+                qkv_fp8._data[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors)
+            ]
+            k_data, v_data = SplitAlongDim.apply(kv_data, dim, [1, 1], True)
+        case 3:
+            tensors = [q, k, v]
+            num_tensors = len(tensors)
+            shapes = [x.shape for x in tensors]
+            numels = [x.numel() for x in tensors]
+            numels = [sum(numels[:i]) for i in range(num_tensors + 1)]
+            qkv = torch.cat([x.view(-1) for x in tensors], dim=0)
+            qkv_fp8 = qkv_quantizer(qkv)
+            q_data, k_data, v_data = [
+                qkv_fp8._data[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors)
+            ]
+        case _:
+            raise RuntimeError("Invalid qkv_layout " + qkv_layout)
+    q_fp8, k_fp8, v_fp8 = [
+        Float8Tensor.make_like(qkv_fp8, data=x, dtype=src_nominal_dtype)
+        for x in [q_data, k_data, v_data]
+    ]
+    return q_fp8, k_fp8, v_fp8
+
+
+def combine_and_dequantize(
+    qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=None, des_nominal_dtype=None
+):
+    """Combine q,k,v based on qkv_layout and dequantize them together"""
+    # 1: qkv packed, 2: kv packed, 3: qkv separate
+    qkv_layout = qkv_layout.replace("paged_kv_", "")
+    qkv_group = len(qkv_layout.split("_"))
+    if all(isinstance(x, Float8Tensor) for x in [q_fp8, k_fp8, v_fp8]):
+        src_nominal_dtype = q_fp8.dtype
+    else:
+        assert src_nominal_dtype is not None, "The nominal dtype of input tensors is required!"
+    if des_nominal_dtype is None:
+        des_nominal_dtype = src_nominal_dtype
+    q_data, k_data, v_data = [x._data for x in [q_fp8, k_fp8, v_fp8]]
+    match qkv_group:
+        case 1:
+            dim = qkv_layout.find("3")
+            qkv_data = combine_tensors([q_data, k_data, v_data], dim)
+            qkv_fp8 = Float8Tensor.make_like(q_fp8, data=qkv_data)
+            qkv = qkv_fp8.dequantize(dtype=des_nominal_dtype)
+            q, k, v = SplitAlongDim.apply(qkv, dim, [1, 1, 1], True)
+        case 2:
+            dim = qkv_layout.split("_")[1].find("2")
+            kv_data = combine_tensors([k_data, v_data], dim)
+            tensors = [q_data, kv_data]
+            num_tensors = len(tensors)
+            shapes = [x.shape for x in tensors]
+            numels = [x.numel() for x in tensors]
+            numels = [sum(numels[:i]) for i in range(num_tensors + 1)]
+            qkv_data = torch.cat([x.reshape(-1) for x in tensors], dim=0)
+            qkv_fp8 = Float8Tensor.make_like(q_fp8, data=qkv_data, dtype=src_nominal_dtype)
+            qkv = qkv_fp8.dequantize(dtype=des_nominal_dtype)
+            q, kv = [qkv[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors)]
+            k, v = SplitAlongDim.apply(kv, dim, [1, 1], True)
+        case 3:
+            tensors = [q_data, k_data, v_data]
+            num_tensors = len(tensors)
+            shapes = [x.shape for x in tensors]
+            numels = [x.numel() for x in tensors]
+            numels = [sum(numels[:i]) for i in range(num_tensors + 1)]
+            qkv_data = torch.cat([x.contiguous().reshape(-1) for x in tensors], dim=0)
+            qkv_fp8 = Float8Tensor.make_like(q_fp8, data=qkv_data, dtype=src_nominal_dtype)
+            qkv = qkv_fp8.dequantize(dtype=des_nominal_dtype)
+            q, k, v = [qkv[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors)]
+        case _:
+            raise RuntimeError("Invalid qkv_layout " + qkv_layout)
+    return q, k, v
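The motivation for quantizing Q, K and V jointly rather than one at a time: under current scaling each quantizer derives its scale from the tensor it sees, while the fused kernel receives a single descale for QKV, so all three tensors must share one scale/amax. A rough equivalent of the `case 3` path in plain PyTorch (no `Float8Tensor`/`SplitAlongDim`, which are TE internals):

```python
import torch

def joint_quantize(q, k, v, f8_max=448.0):
    """Flatten q/k/v, quantize once so they share one scale, split back."""
    shapes = [x.shape for x in (q, k, v)]
    numels = [x.numel() for x in (q, k, v)]
    offsets = [sum(numels[:i]) for i in range(4)]        # [0, nq, nq+nk, total]
    flat = torch.cat([x.reshape(-1) for x in (q, k, v)])
    scale = f8_max / flat.abs().max()                    # one scale for all three
    flat_fp8 = (flat * scale).to(torch.float8_e4m3fn)
    q8, k8, v8 = [
        flat_fp8[offsets[i] : offsets[i + 1]].view(shapes[i]) for i in range(3)
    ]
    return q8, k8, v8, scale
```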
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
# See LICENSE for license information. # See LICENSE for license information.
"""Multi-head Attention.""" """Multi-head Attention."""
import os
import collections import collections
from typing import Callable, List, Optional, Tuple, Union from typing import Callable, List, Optional, Tuple, Union
import torch import torch
...@@ -31,7 +32,13 @@ from transformer_engine.pytorch.distributed import ( ...@@ -31,7 +32,13 @@ from transformer_engine.pytorch.distributed import (
from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention
from transformer_engine.pytorch.attention.inference import InferenceParams from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb from transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
# Force DotProductAttention to use a different recipe than the fp8_recipe set in fp8_autocast().
# Useful when GEMMs and attention use different recipes. Supported values are "DelayedScaling"
# and "Float8CurrentScaling". Use other relevant variables here to define the recipe, e.g. fp8_dpa.
_dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "")
_dpa_fp8_recipe_dpa = os.getenv("NVTE_DPA_FP8_RECIPE_DPA", "0") == "1"
_dpa_fp8_recipe_mha = os.getenv("NVTE_DPA_FP8_RECIPE_MHA", "0") == "1"
class MultiheadAttention(torch.nn.Module): class MultiheadAttention(torch.nn.Module):
...@@ -570,10 +577,12 @@ class MultiheadAttention(torch.nn.Module): ...@@ -570,10 +577,12 @@ class MultiheadAttention(torch.nn.Module):
self.cp_size = get_distributed_world_size(cp_group) self.cp_size = get_distributed_world_size(cp_group)
self.cp_rank = get_distributed_rank(cp_group) self.cp_rank = get_distributed_rank(cp_group)
elif isinstance(cp_group, list): elif isinstance(cp_group, list):
assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!"
assert ( assert (
cp_comm_type == "a2a+p2p" cp_comm_type == "a2a+p2p"
), "Only cp_comm_type of a2a+p2p requires hierarchical CP groups!" ), "Only cp_comm_type of a2a+p2p requires hierarchical CP groups!"
assert (
len(cp_group) == 2
), "cp_comm_type = a2a+p2p requires cp_group = [a2a_cp_group, p2p_cp_group]!"
cp_size_a2a = get_distributed_world_size(cp_group[0]) cp_size_a2a = get_distributed_world_size(cp_group[0])
cp_rank_a2a = get_distributed_rank(cp_group[0]) cp_rank_a2a = get_distributed_rank(cp_group[0])
cp_size_p2p = get_distributed_world_size(cp_group[1]) cp_size_p2p = get_distributed_world_size(cp_group[1])
...@@ -730,10 +739,22 @@ class MultiheadAttention(torch.nn.Module): ...@@ -730,10 +739,22 @@ class MultiheadAttention(torch.nn.Module):
# Query, Key, and Value # Query, Key, and Value
# ====================== # ======================
fp8_mha = ( fp8 = FP8GlobalStateManager.is_fp8_enabled()
FP8GlobalStateManager.is_fp8_enabled() if _dpa_fp8_recipe == "":
and FP8GlobalStateManager.get_fp8_recipe().fp8_mha fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
) fp8_dpa = fp8_recipe.fp8_dpa
fp8_mha = fp8_recipe.fp8_mha
float8_current_scaling = fp8_recipe.float8_current_scaling()
else:
fp8_dpa = _dpa_fp8_recipe_dpa
fp8_mha = _dpa_fp8_recipe_mha
float8_current_scaling = _dpa_fp8_recipe == "Float8CurrentScaling"
# QKV Gemm: do not produce FP8 output when in Float8CurrentScaling recipe
qkv_fp8_output = fp8 and fp8_mha and rotary_pos_emb is None and not float8_current_scaling
# DPA: always produce FP8 output when fp8=True to take advantage of the O amax
dpa_fp8_output = fp8 and (fp8_dpa or fp8_mha)
# Proj Gemm: match DPA output except for Float8CurrentScaling
proj_fp8_grad = dpa_fp8_output and not float8_current_scaling
layernorm_output = None layernorm_output = None
if self.attention_type == "self": if self.attention_type == "self":
...@@ -742,7 +763,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -742,7 +763,7 @@ class MultiheadAttention(torch.nn.Module):
layernorm_qkv_outputs = self.layernorm_qkv( layernorm_qkv_outputs = self.layernorm_qkv(
hidden_states, hidden_states,
is_first_microbatch=is_first_microbatch, is_first_microbatch=is_first_microbatch,
fp8_output=fp8_mha and rotary_pos_emb is None, fp8_output=qkv_fp8_output,
) )
if self.return_layernorm_output: if self.return_layernorm_output:
mixed_x_layer, layernorm_output = layernorm_qkv_outputs mixed_x_layer, layernorm_output = layernorm_qkv_outputs
...@@ -752,7 +773,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -752,7 +773,7 @@ class MultiheadAttention(torch.nn.Module):
mixed_x_layer = self.qkv( mixed_x_layer = self.qkv(
hidden_states, hidden_states,
is_first_microbatch=is_first_microbatch, is_first_microbatch=is_first_microbatch,
fp8_output=fp8_mha and rotary_pos_emb is None, fp8_output=qkv_fp8_output,
) )
num_queries_per_key_value = ( num_queries_per_key_value = (
...@@ -806,7 +827,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -806,7 +827,7 @@ class MultiheadAttention(torch.nn.Module):
mixed_kv_layer = self.key_value( mixed_kv_layer = self.key_value(
encoder_output, encoder_output,
is_first_microbatch=is_first_microbatch, is_first_microbatch=is_first_microbatch,
fp8_output=fp8_mha and rotary_pos_emb is None, fp8_output=qkv_fp8_output,
) )
if self.qkv_weight_interleaved: if self.qkv_weight_interleaved:
...@@ -861,7 +882,7 @@ class MultiheadAttention(torch.nn.Module):
                layernorm_query_outputs = self.layernorm_query(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
-                   fp8_output=fp8_mha and rotary_pos_emb is None,
+                   fp8_output=qkv_fp8_output,
                )
                if self.return_layernorm_output:
                    query_layer, layernorm_output = layernorm_query_outputs
...@@ -871,7 +892,7 @@ class MultiheadAttention(torch.nn.Module):
                query_layer = self.query_layer(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
-                   fp8_output=fp8_mha and rotary_pos_emb is None,
+                   fp8_output=qkv_fp8_output,
                )
        # [sq, b, hp] --> [sq, b, np, hn]
...@@ -972,6 +993,7 @@ class MultiheadAttention(torch.nn.Module):
            fast_zero_fill=fast_zero_fill,
            inference_params=inference_params,
            pad_between_seqs=pad_between_seqs,
+           fp8_output=dpa_fp8_output,
        )
        # ===================
...@@ -980,7 +1002,7 @@ class MultiheadAttention(torch.nn.Module):
        projection_output = self.proj(
            context_layer,
            is_first_microbatch=is_first_microbatch,
-           fp8_grad=isinstance(context_layer, QuantizedTensor),
+           fp8_grad=proj_fp8_grad,
        )
        if self.return_bias:
......
...@@ -109,9 +109,6 @@ META_O = tex.FP8FwdTensors.GEMM2_INPUT
META_DO = tex.FP8BwdTensors.GRAD_INPUT2
META_S = tex.FP8FwdTensors.GEMM3_OUTPUT
META_DP = tex.FP8BwdTensors.GRAD_INPUT3
-# repurpose some unused amax history buffers for partial results of CP fwd and bwd
-META_O_CP = tex.FP8FwdTensors.GEMM2_OUTPUT
-META_DQKV_CP = tex.FP8BwdTensors.GRAD_INPUT1
def fused_attn_fwd(
......
...@@ -201,7 +201,7 @@ class Float8CurrentScalingQuantizer : public Quantizer {
   * amax to be initialized to zero.
   */
  std::pair<TensorWrapper, py::object> create_unquantized_tensor_with_amax(
-      const std::vector<size_t>& shape, DType dtype);
+      const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> data = std::nullopt);
  std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
......
...@@ -78,6 +78,11 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
    size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
    int64_t window_size_right);
+std::pair<TensorWrapper, py::object> quantizer_helper(py::handle quantizer,
+                                                      const std::vector<size_t> &shape, DType dtype,
+                                                      bool create_hp_tensor_for_cs,
+                                                      std::optional<at::Tensor> data);
std::vector<py::object> fused_attn_fwd(
    size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
    bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
......
...@@ -53,6 +53,47 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
  return fused_attention_backend;
}
// helper function for S and dP quantizers
std::pair<TensorWrapper, py::object> quantizer_helper(py::handle quantizer,
const std::vector<size_t> &shape, DType dtype,
bool create_hp_tensor_for_cs,
std::optional<at::Tensor> data) {
std::unique_ptr<Quantizer> T_quantizer = convert_quantizer(quantizer);
TensorWrapper te_T;
py::object py_T;
if (quantizer.is_none()) {
// high precision
auto *none_quantizer = dynamic_cast<NoneQuantizer *>(T_quantizer.get());
if (data.has_value()) {
std::tie(te_T, py_T) = none_quantizer->create_tensor(shape, dtype, data.value());
} else {
std::tie(te_T, py_T) = none_quantizer->create_tensor(shape, dtype);
}
} else if (detail::IsFloat8Quantizers(quantizer.ptr())) {
// delayed scaling; this helps initialize scale_inv
auto *T_quantizer_fp8 = dynamic_cast<Float8Quantizer *>(T_quantizer.get());
std::tie(te_T, py_T) =
T_quantizer_fp8->create_tensor(shape, dtype, data, std::nullopt, std::nullopt);
} else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// current scaling
auto *T_quantizer_fp8 = dynamic_cast<Float8CurrentScalingQuantizer *>(T_quantizer.get());
if (create_hp_tensor_for_cs) {
if (data.has_value()) {
std::tie(te_T, py_T) =
T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype, data.value());
} else {
std::tie(te_T, py_T) = T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype);
}
} else {
std::tie(te_T, py_T) = T_quantizer_fp8->create_tensor(shape, dtype);
NVTE_CHECK(
!data.has_value(),
"Float8CurrentScalingQuantizer::create_tensor() does not take data tensor as input!");
}
}
return {std::move(te_T), std::move(py_T)};
}
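The helper's dispatch is easier to see stripped of the pybind plumbing. Below is a minimal Python analogue of the three branches, using a toy quantizer class rather than TE's actual quantizer API:

```python
import torch

class CurrentScalingQuantizer:
    """Toy stand-in for Float8CurrentScalingQuantizer: owns the amax buffer
    that the fused attention kernel fills in while computing its output."""
    def __init__(self):
        self.amax = torch.zeros(1)  # the real buffer lives on the GPU

def quantizer_helper(quantizer, shape, dtype, create_hp_tensor_for_cs, data=None):
    # None quantizer: plain high-precision tensor, optionally reusing storage.
    if quantizer is None:
        out = data if data is not None else torch.empty(*shape, dtype=dtype)
        return out, None
    # Current scaling: return a high-precision tensor plus the amax slot the
    # kernel should populate; quantization happens afterwards with that amax.
    if isinstance(quantizer, CurrentScalingQuantizer) and create_hp_tensor_for_cs:
        quantizer.amax.zero_()  # kernels max() into it, so it must start at 0
        out = data if data is not None else torch.empty(*shape, dtype=dtype)
        return out, quantizer.amax
    # Delayed scaling would create an FP8 tensor with a pre-initialized
    # scale_inv here; omitted from this sketch.
    raise NotImplementedError

out, amax = quantizer_helper(CurrentScalingQuantizer(), (2, 4), torch.bfloat16, True)
```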
// fused attention FWD with separate Q, K and V tensors
std::vector<py::object> fused_attn_fwd(
    size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
...@@ -66,44 +107,30 @@ std::vector<py::object> fused_attn_fwd(
    py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
    const std::optional<at::Tensor> SoftmaxOffset, const std::optional<at::Generator> rng_gen,
    size_t rng_elts_per_thread) {
-  TensorWrapper te_Q, te_K, te_V, te_O, te_S;
  auto none = py::none();
-  std::unique_ptr<Quantizer> S_quantizer = convert_quantizer(s_quantizer);
-  std::unique_ptr<Quantizer> O_quantizer = convert_quantizer(o_quantizer);
+  // create QKV tensor wrappers
+  TensorWrapper te_Q, te_K, te_V;
  te_Q = makeTransformerEngineTensor(Q, none);
  te_K = makeTransformerEngineTensor(K, none);
  te_V = makeTransformerEngineTensor(V, none);
-  // If qkv has FP8 dtype, fake_dtype_te is equal to the fake dtype of q, k, v - needed since torch do not have fp8 types.
  const DType qkv_type = te_Q.dtype();
-  const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype);
+  // create S tensor
+  TensorWrapper te_S;
+  py::object py_S;
+  std::tie(te_S, py_S) = quantizer_helper(s_quantizer, {0}, DType::kFloat32, false, std::nullopt);
+  // create O tensor
+  TensorWrapper te_O;
+  py::object py_O;
+  std::unique_ptr<Quantizer> O_quantizer = convert_quantizer(o_quantizer);
  std::vector<size_t> q_shape = convertShape(te_Q.shape());
-  std::vector<size_t> k_shape = convertShape(te_K.shape());
  std::vector<size_t> v_shape = convertShape(te_V.shape());
-  auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
-  // create output tensor O
  auto o_shape = std::vector<size_t>{q_shape.begin(), q_shape.end()};
  o_shape[o_shape.size() - 1] = v_shape[v_shape.size() - 1];
-  py::object o_python, s_python;
-  if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
-    // Initialize FP8 tensor with scale-inverse
-    auto *O_quantizer_fp8 = dynamic_cast<Float8Quantizer *>(O_quantizer.get());
-    auto *S_quantizer_fp8 = dynamic_cast<Float8Quantizer *>(S_quantizer.get());
-    NVTE_CHECK(O_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8");
-    NVTE_CHECK(S_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8");
-    std::tie(te_O, o_python) = O_quantizer_fp8->create_tensor(o_shape, fake_dtype_te, std::nullopt,
-                                                              std::nullopt, std::nullopt);
-    std::tie(te_S, s_python) = S_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt,
-                                                              std::nullopt, std::nullopt);
-  } else {
-    std::tie(te_O, o_python) = O_quantizer->create_tensor(o_shape, fake_dtype_te);
-    std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32);
-  }
-  auto o_shape_int64 = std::vector<int64_t>{o_shape.begin(), o_shape.end()};
+  const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype);
+  std::tie(te_O, py_O) = quantizer_helper(o_quantizer, o_shape, fake_dtype_te, true, std::nullopt);
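One detail worth noting above: O's shape is derived from Q's, with only the last (head) dimension taken from V, since head_dim_v need not equal head_dim_qk (both appear separately in the backend signature). In Python terms, with illustrative shapes:

```python
q_shape = [2, 4096, 16, 192]  # e.g. [b, s, h, d_qk]
v_shape = [2, 4096, 16, 128]  # d_v may differ from d_qk
o_shape = q_shape[:-1] + v_shape[-1:]  # O inherits q's layout but v's head dim
assert o_shape == [2, 4096, 16, 128]
```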
  // construct NVTE tensors
  TensorWrapper te_Bias;
...@@ -114,11 +141,12 @@ std::vector<py::object> fused_attn_fwd(
    // FP8
    auto h = q_shape[q_shape.size() - 2];
    auto d = q_shape[q_shape.size() - 1];
-    if (set_zero && ((h * d) % block_size == 0) &&
-        (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) {
+    if (set_zero && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) {
+      if ((h * d) % block_size == 0) {
        mha_fill(te_O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
      } else {
        te_O.zero_(at::cuda::getCurrentCUDAStream());
+      }
    }
  } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
    if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) {
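The restructured zeroing only ever fires for the THD (packed, variable-sequence) format, where tokens past the last cu_seqlens offset are padding. A rough Python picture of what the mha_fill fast path is for (the real mha_fill is a CUDA kernel, gated on the hidden size being divisible by block_size, hence the fallback to a full zero_()):

```python
import torch

def zero_pad_tail(o_thd: torch.Tensor, cu_seqlens: torch.Tensor) -> None:
    """Zero only the padding tokens of a [t, h, d] THD tensor in place,
    i.e. everything at or after the last sequence offset."""
    valid = int(cu_seqlens[-1])
    o_thd[valid:].zero_()

o = torch.randn(10, 2, 4)                        # [t, h, d]
cu = torch.tensor([0, 3, 7], dtype=torch.int32)  # 7 valid tokens, 3 padding
zero_pad_tail(o, cu)
assert o[7:].abs().sum() == 0
```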
...@@ -181,7 +209,8 @@ std::vector<py::object> fused_attn_fwd(
  auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
      rng_gen, at::cuda::detail::getDefaultCUDAGenerator());
  at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread);
-  auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
+  auto options = torch::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA);
+  auto rng_state = torch::empty({2}, options);
  philox_unpack(philox_args, static_cast<int64_t *>(rng_state.data_ptr()));
  auto te_rng_state = makeTransformerEngineTensor(rng_state);
...@@ -210,7 +239,7 @@ std::vector<py::object> fused_attn_fwd(
  // output_tensors = [O, nvte_aux_tensor_pack.tensors]
  std::vector<py::object> output_tensors;
-  output_tensors.push_back(o_python);
+  output_tensors.push_back(py_O);
  auto set_tensor_param = [&](size_t i, const at::Tensor &output_tensor) {
    output_tensors.push_back(py::cast(output_tensor));
    NVTEBasicTensor temp_data = {output_tensor.data_ptr(),
...@@ -280,50 +309,44 @@ std::vector<py::object> fused_attn_bwd(
    const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
    py::handle dp_quantizer, py::handle dqkv_quantizer) {
  auto none = py::none();
-  TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV;
+  // create QKV, O, dO tensor wrappers
+  TensorWrapper te_Q, te_K, te_V, te_O, te_dO;
  te_Q = makeTransformerEngineTensor(Q, none);
  te_K = makeTransformerEngineTensor(K, none);
  te_V = makeTransformerEngineTensor(V, none);
  te_O = makeTransformerEngineTensor(O, none);
  te_dO = makeTransformerEngineTensor(dO, none);
-  // qkv type from the te_Q
-  std::unique_ptr<Quantizer> dQKV_quantizer = convert_quantizer(dqkv_quantizer);
-  const DType qkv_type = te_Q.dtype();
-  const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype);
-  py::object s_python, dp_python;
-  std::unique_ptr<Quantizer> S_quantizer = convert_quantizer(s_quantizer);
-  std::unique_ptr<Quantizer> dP_quantizer = convert_quantizer(dp_quantizer);
-  if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
-    auto *S_quantizer_fp8 = dynamic_cast<Float8Quantizer *>(S_quantizer.get());
-    auto *dP_quantizer_fp8 = dynamic_cast<Float8Quantizer *>(dP_quantizer.get());
-    NVTE_CHECK(S_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8");
-    NVTE_CHECK(dP_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8");
-    std::tie(te_S, s_python) = S_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt,
-                                                              std::nullopt, std::nullopt);
-    std::tie(te_dP, dp_python) = dP_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt,
-                                                                 std::nullopt, std::nullopt);
-  } else {
-    std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32);
-    std::tie(te_dP, dp_python) = dP_quantizer->create_tensor({0}, DType::kFloat32);
-  }
+  // create S and dP tensors
+  TensorWrapper te_S, te_dP;
+  py::object py_S, py_dP;
+  std::tie(te_S, py_S) = quantizer_helper(s_quantizer, {0}, DType::kFloat32, false, std::nullopt);
+  std::tie(te_dP, py_dP) =
+      quantizer_helper(dp_quantizer, {0}, DType::kFloat32, false, std::nullopt);
+  // create dQ, dK, dV tensors
+  TensorWrapper te_dQ, te_dK, te_dV;
+  py::object py_dQ, py_dK, py_dV;
+  std::unique_ptr<Quantizer> dQKV_quantizer = convert_quantizer(dqkv_quantizer);
  std::vector<size_t> q_shape = convertShape(te_Q.shape());
  std::vector<size_t> k_shape = convertShape(te_K.shape());
  std::vector<size_t> v_shape = convertShape(te_V.shape());
  auto h_q = q_shape[q_shape.size() - 2];
  auto h_kv = k_shape[k_shape.size() - 2];
  auto d_qk = q_shape[q_shape.size() - 1];
-  auto d_v = v_shape[v_shape.size() - 1];
-  auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA);
-  std::vector<size_t> o_shape{q_shape.begin(), q_shape.end()};
-  o_shape[o_shape.size() - 1] = d_v;
+  const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype);
  at::Tensor dQ, dK, dV, dQKV, dKV;
-  py::object py_dQ, py_dK, py_dV;
  NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
  std::vector<int64_t> tmp_shape;
+  auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA);
+  if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) {
+    options = options.dtype(torch::kUInt8);
+  }
+  if (detail::IsFloat8CurrentScalingQuantizers(dqkv_quantizer.ptr())) {
+    options = options.dtype(fake_dtype);
+  }
  switch (layout_group) {
    case NVTE_QKV_Layout_Group::NVTE_3HD:
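The buffer dtype selection above is worth calling out: the dQ/dK/dV workspaces hold raw FP8 bytes under delayed scaling, but stay in high precision under current scaling, because the backward kernel must record their amax before they can be quantized. A toy version of the same override chain (names are ours):

```python
import torch

def dgrad_buffer_dtype(dqkv_aten_dtype, is_fp8, is_current_scaling, fake_dtype):
    dtype = dqkv_aten_dtype
    if is_fp8:
        dtype = torch.uint8  # FP8 payload stored as raw bytes (delayed scaling)
    if is_current_scaling:
        dtype = fake_dtype   # HP buffer; quantized only after amax is known
    return dtype

assert dgrad_buffer_dtype(torch.bfloat16, False, False, torch.bfloat16) == torch.bfloat16
assert dgrad_buffer_dtype(torch.uint8, True, False, torch.bfloat16) == torch.uint8
assert dgrad_buffer_dtype(torch.uint8, True, True, torch.bfloat16) == torch.bfloat16
```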
...@@ -396,39 +419,27 @@ std::vector<py::object> fused_attn_bwd(
    default:
      NVTE_ERROR("QKV layout not supported!");
  }
-  if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
-    auto *fp8_quantizer = dynamic_cast<Float8Quantizer *>(dQKV_quantizer.get());
-    NVTE_CHECK(fp8_quantizer != nullptr, "Expected Float8Quantizer when dtype is FP8");
-    std::tie(te_dQ, py_dQ) =
-        fp8_quantizer->create_tensor(q_shape, fake_dtype_te, dQ, std::nullopt, std::nullopt);
-    std::tie(te_dK, py_dK) =
-        fp8_quantizer->create_tensor(k_shape, fake_dtype_te, dK, std::nullopt, std::nullopt);
-    std::tie(te_dV, py_dV) =
-        fp8_quantizer->create_tensor(v_shape, fake_dtype_te, dV, std::nullopt, std::nullopt);
-  } else {
-    auto *none_quantizer = dynamic_cast<NoneQuantizer *>(dQKV_quantizer.get());
-    NVTE_CHECK(none_quantizer != nullptr, "Expected NoneQuantizer when dtype is not FP8");
-    std::tie(te_dQ, py_dQ) = none_quantizer->create_tensor(q_shape, fake_dtype_te, dQ);
-    std::tie(te_dK, py_dK) = none_quantizer->create_tensor(k_shape, fake_dtype_te, dK);
-    std::tie(te_dV, py_dV) = none_quantizer->create_tensor(v_shape, fake_dtype_te, dV);
-  }
+  std::tie(te_dQ, py_dQ) = quantizer_helper(dqkv_quantizer, q_shape, fake_dtype_te, true, dQ);
+  std::tie(te_dK, py_dK) = quantizer_helper(dqkv_quantizer, k_shape, fake_dtype_te, true, dK);
+  std::tie(te_dV, py_dV) = quantizer_helper(dqkv_quantizer, v_shape, fake_dtype_te, true, dV);
  // construct NVTE tensors
-  if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
+  if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) {
    // FP8
-    if (set_zero && ((h_q * d_qk) % block_size == 0) && ((h_kv * d_qk) % block_size == 0) &&
-        dQ.is_contiguous() && dK.is_contiguous() && dV.is_contiguous() &&
-        (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) {
+    if (set_zero && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) {
+      if (((h_q * d_qk) % block_size == 0) && ((h_kv * d_qk) % block_size == 0) &&
+          dQ.is_contiguous() && dK.is_contiguous() && dV.is_contiguous()) {
        mha_fill(te_dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
        mha_fill(te_dK, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)}));
        mha_fill(te_dV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)}));
      } else {
        dQ.fill_(0);
        dK.fill_(0);
        dV.fill_(0);
+      }
    }
-  } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
+  } else if (dqkv_type == DType::kBFloat16 || dqkv_type == DType::kFloat16) {
    if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) {
      dQ.fill_(0);
      dK.fill_(0);
...@@ -605,7 +616,6 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s
  // Shapes of kv and dkv are [2, t, h, d], so the dimension of "t" is 1
  int seq_dim = tensor.dim() == 3 ? 0 : 1;
-  int batch = cu_seqlens.size(0) - 1;
  int num_heads = tensor.size(seq_dim + 1);
  int dim_per_head = tensor.size(seq_dim + 2);
  int hidden_size_in_bytes = num_heads * dim_per_head * c10::elementSize(tensor.scalar_type());
...@@ -769,8 +779,6 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t
  NVTE_CHECK(world_size > 0);
  NVTE_CHECK(total_tokens > 0 && total_tokens % (world_size * 2) == 0);
-  int batch = cu_seqlens.size(0) - 1;
  std::vector<int64_t> shape = {total_tokens / world_size};
  at::Tensor output = at::empty(shape, at::CUDA(at::ScalarType::Int));
...@@ -808,7 +816,6 @@ at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b,
**************************************************************************************************/

at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t) {
-  int max_seq_len = tensor.size(1);
  int h = tensor.size(2);
  int d = tensor.size(3);
  std::vector<int64_t> shape = {t, h, d};
......
...@@ -37,7 +37,18 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob
  // Convert input tensor to C++ object
  auto input_contiguous = tensor.contiguous();
-  const auto input_cpp = makeTransformerEngineTensor(input_contiguous);
+  auto input_cpp = makeTransformerEngineTensor(input_contiguous);
+  // Set amax if use_existing_amax = true (only valid for CS)
+  bool use_existing_amax = false;
+  if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
+    use_existing_amax = quantizer.attr("use_existing_amax").cast<bool>();
+    if (use_existing_amax) {
+      const at::Tensor &amax = quantizer.attr("amax").cast<at::Tensor>();
+      input_cpp.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()),
+                         getTensorShape(amax));
+    }
+  }
  // Initialize output tensor
  TensorWrapper output_cpp;
...@@ -57,7 +68,12 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob
  }
  // Perform quantization
-  quantizer_cpp->quantize(input_cpp, output_cpp, noop_flag_cpp);
+  if (use_existing_amax) {
+    auto *quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
+    quantizer_cs->quantize_with_amax(input_cpp, output_cpp, noop_flag_cpp);
+  } else {
+    quantizer_cpp->quantize(input_cpp, output_cpp, noop_flag_cpp);
+  }
  return output_py;
}
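Numerically, the two paths above differ only in where the amax comes from. A minimal Python sketch of current-scaling quantization, assuming the conventional E4M3 max of 448 and leaving values pre-cast (no actual FP8 dtype) for clarity:

```python
import torch

def quantize_current_scaling(x, amax=None, fp8_max=448.0):
    """Derive amax from the input, or trust a precomputed one
    (use_existing_amax=True), e.g. an amax already recorded by the
    fused attention kernel; then apply one per-tensor scale."""
    if amax is None:
        amax = x.abs().max()            # quantize(): compute amax on the fly
    scale = fp8_max / torch.clamp(amax, min=1e-12)
    return x * scale, 1.0 / scale       # quantized values and scale_inv

x = torch.randn(8, 8)
qx, scale_inv = quantize_current_scaling(x, amax=torch.tensor(3.0))
```

Skipping the on-the-fly amax computation is what makes the fused-attention path cheap: the kernel has already produced the amax as a side effect of writing its high-precision output.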
......
...@@ -390,9 +390,13 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tenso
std::pair<TensorWrapper, py::object>
Float8CurrentScalingQuantizer::create_unquantized_tensor_with_amax(const std::vector<size_t>& shape,
-                                                                   DType dtype) {
+                                                                   DType dtype,
+                                                                   std::optional<at::Tensor> data) {
  amax.zero_();
-  auto [out_cpp, out_py] = NoneQuantizer(py::none()).create_tensor(shape, dtype);
+  auto out = data.has_value() ? NoneQuantizer(py::none()).create_tensor(shape, dtype, data.value())
+                              : NoneQuantizer(py::none()).create_tensor(shape, dtype);
+  TensorWrapper out_cpp = std::move(out.first);
+  py::object out_py = std::move(out.second);
  out_cpp.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()),
                   getTensorShape(amax));
  return {std::move(out_cpp), std::move(out_py)};
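The amax.zero_() at the top of this function matters because downstream code accumulates into the buffer with a max reduction; anything stale would only inflate the final scale. A two-line Python illustration (the per-chunk framing, e.g. partial results across steps, is our assumption):

```python
import torch

amax = torch.zeros(1)  # must start at zero: producers max() into it
for partial in torch.randn(4, 1024).unbind(0):  # stand-ins for partial outputs
    torch.maximum(amax, partial.abs().amax(), out=amax)
```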
......
...@@ -970,7 +970,9 @@ class Float8CurrentScalingRecipeState(RecipeState):
        from .tensor.float8_tensor import Float8CurrentScalingQuantizer
        return [
-           Float8CurrentScalingQuantizer(self.dtype, device=self.device)
+           Float8CurrentScalingQuantizer(
+               self.dtype, device=self.device, force_pow_2_scales=self.recipe.use_power_2_scales
+           )
            for i in range(self.num_quantizers)
        ]
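The newly threaded-through use_power_2_scales / force_pow_2_scales option rounds the current-scaling factor down to a power of two, so scale and scale_inv are exact binary reciprocals of each other. A sketch under the usual convention (E4M3 max of 448; the rounding rule is the standard one, not quoted from this diff):

```python
import torch

def pow2_scale(amax: torch.Tensor, fp8_max: float = 448.0) -> torch.Tensor:
    scale = fp8_max / torch.clamp(amax, min=1e-12)
    return torch.exp2(torch.floor(torch.log2(scale)))

print(pow2_scale(torch.tensor(3.0)))  # tensor(128.) instead of ~149.33
```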
......