Unverified Commit 901e5d2b authored by Charlene Yang, committed by GitHub

Add support for flash-attn 3 (#1019)



* WIP: add fa3
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: clean up
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
For more information, see https://pre-commit.ci

* WIP: add benchmarks
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
For more information, see https://pre-commit.ci

* differentiate func/varlen_func
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix parsing keyword for FA3 and remove bshd->thd conversion for flash_attn_func
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: add FP8 fwd support
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
For more information, see https://pre-commit.ci

* add FA3 FP8 fwd code and test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
For more information, see https://pre-commit.ci

* fix assert for FA3
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
For more information, see https://pre-commit.ci

* fix FA3 FP8 logic and add tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
For more information, see https://pre-commit.ci

* update FA2 to <=2.6.3
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweak unit tests for base/mask
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
For more information, see https://pre-commit.ci

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
For more information, see https://pre-commit.ci

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* set constraints for FA3 for sm90 and causal_bottom_right
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revert debug changes in benchmark script
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
For more information, see https://pre-commit.ci

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 2215fa5c
-Subproject commit 2533f5e5c1877fd76266133c1479ef1643ce3a8b
+Subproject commit 98ca4e1941fe3263f128f74f10063a3ea35c7019
@@ -156,7 +156,7 @@ def parse_results(per_cudnn, per_flash, model):
         df_times.loc[row, "FusedAttention Kernels (fwd+bwd)"] = t_cudnn_avg.sum() / 1e6
     if per_flash > 0:
-        t_flash_all = df[df["Name"].str.contains("void flash")]["Duration (ns)"].to_numpy()
+        t_flash_all = df[df["Name"].str.contains("flash")]["Duration (ns)"].to_numpy()
         t_flash_all = t_flash_all.reshape(-1, per_flash)
         t_flash_avg = np.average(t_flash_all, axis=0)
         df_times.loc[row, "FlashAttention Kernels (fwd)"] = t_flash_avg[0] / 1e6
...
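The kernel-name filter in the benchmark parser is relaxed from "void flash" to "flash", presumably so that FlashAttention 3 kernels, whose profiled names do not necessarily begin with "void flash", are still counted from the trace. A minimal sketch of that filtering step; the kernel names below are made up for illustration:

import pandas as pd

# Hypothetical kernel rows from a profiler trace; the names are illustrative only.
df = pd.DataFrame(
    {
        "Name": [
            "void flash_fwd_kernel<...>",    # flash-attn 2 style name
            "flash::fwd_kernel_sm90<...>",   # flash-attn 3 style name (assumed)
            "void cudnn::fusion::attn<...>",
        ],
        "Duration (ns)": [105000, 80000, 120000],
    }
)

# The relaxed substring match now picks up both FA2- and FA3-style kernel names.
t_flash_all = df[df["Name"].str.contains("flash")]["Duration (ns)"].to_numpy()
print(t_flash_all)  # [105000  80000]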
@@ -92,7 +92,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
     # Framework-specific requirements
     if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
         if "pytorch" in frameworks:
-            install_reqs.extend(["torch", "flash-attn>=2.0.6,<=2.5.8,!=2.0.9,!=2.1.0"])
+            install_reqs.extend(["torch", "flash-attn>=2.0.6,<=2.6.3,!=2.0.9,!=2.1.0"])
             test_reqs.extend(["numpy", "onnxruntime", "torchvision", "prettytable"])
         if "jax" in frameworks:
             install_reqs.extend(["jax", "flax>=0.7.1"])
...
@@ -420,6 +420,10 @@ model_configs_mask = {
     "mask_8_1": ModelConfig(
         1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"
     ),
+    "mask_9_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal", "no_bias"),
+    "mask_9_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal", "no_bias"),
+    "mask_10_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal_bottom_right", "no_bias"),
+    "mask_10_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal_bottom_right", "no_bias"),
 }
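Going by how the rest of this test file consumes config attributes (config.batch_size, config.num_heads, config.num_gqa_groups, config.max_seqlen_q, config.max_seqlen_kv, config.attn_mask_type, config.attn_bias_type), the new entries appear to exercise a single-query, decode-style shape against a long key/value context with top-left versus bottom-right causal masks. The field order below is an assumption, sketched with a stand-in type rather than the test's real ModelConfig:

from collections import namedtuple

# Illustrative stand-in; the real ModelConfig and its exact field order are assumptions
# inferred from how config.* attributes are used elsewhere in this test file.
FakeModelConfig = namedtuple(
    "FakeModelConfig",
    "batch_size num_heads num_gqa_groups head_dim max_seqlen_q max_seqlen_kv "
    "dropout_p attn_mask_type attn_bias_type",
)

# "mask_10_0": one query token attending to a 2048-token context with a
# bottom-right-aligned causal mask, i.e. an inference/decode-style case.
cfg = FakeModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal_bottom_right", "no_bias")
print(cfg.max_seqlen_q, cfg.max_seqlen_kv, cfg.attn_mask_type)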
@@ -1301,6 +1305,7 @@ model_configs_fp8_vs_f16 = {
     "fp8_12": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias"),
     "fp8_13": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "no_mask", "no_bias"),
     "fp8_14": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"),
+    "fp8_15": ModelConfig(1, 16, 16, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
 }

 param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16]
@@ -1312,6 +1317,27 @@ def _rmse(a, b):
     return math.sqrt((torch.pow((a - b), 2) / a.numel()).sum())


+def _error(a, b, name_a, name_b, atol, rtol, rmse_tol):
+    logging.debug(name_a + " min {:.6f} max {:.6f}".format(a.min().item(), a.max().item()))
+    logging.debug(name_b + " min {:.6f} max {:.6f}".format(b.min().item(), b.max().item()))
+    try:
+        torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
+    except Exception as e:
+        logging.debug(e)
+
+    rmse = _rmse(a, b)
+    logging.debug(name_a + " vs " + name_b + " RMSE: {:.6f}".format(rmse))
+    rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item())
+    assert rmse < rmse_tol * rmse_range, (
+        name_a
+        + " vs "
+        + name_b
+        + " RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
+            rmse, rmse_tol * rmse_range, rmse_tol, rmse_range
+        )
+    )
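The new _error helper above factors out the comparison pattern the FP8 tests repeat: log min/max of both tensors, run torch.testing.assert_close at loose point-wise tolerances (failures are only logged), then enforce an RMSE bound relative to the combined value range. A minimal usage sketch in the test module's namespace, with placeholder tensors:

import logging
import torch

logging.basicConfig(level=logging.DEBUG)

a = torch.randn(4, 16)
b = a + 0.01 * torch.randn(4, 16)  # small perturbation, so the RMSE check passes

# Same tolerances the MHA test uses: atol/rtol 5e-1 and an RMSE bound of 10%
# of the combined value range of the two tensors.
_error(a, b, "tensor_a", "tensor_b", atol=5e-1, rtol=5e-1, rmse_tol=0.1)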
 @pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
 @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
 @pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@@ -1320,86 +1346,74 @@ def _rmse(a, b):
 @pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16)
 @pytest.mark.parametrize("input_layernorm", [True, False])
 @pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
-def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd):
-    os.environ["NVTE_FLASH_ATTN"] = "0"
-    os.environ["NVTE_FUSED_ATTN"] = "1"
+@pytest.mark.parametrize("is_training", [True, False])
+def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, is_training):
     os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
-    global _attention_backends
-    _attention_backends["backend_selection_requires_update"] = True
-    config = model_configs_fp8_vs_f16[model]
     os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
+    config = model_configs_fp8_vs_f16[model]
+
+    global _attention_backends
+    if not is_training:
+        os.environ["NVTE_FLASH_ATTN"] = "1"
+        os.environ["NVTE_FUSED_ATTN"] = "0"
+        _attention_backends["backend_selection_requires_update"] = True
+        logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
+        flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
+            dtype, config, True, qkv_format, input_layernorm, is_training
+        )
+
+    os.environ["NVTE_FLASH_ATTN"] = "0"
+    os.environ["NVTE_FUSED_ATTN"] = "1"
+    _attention_backends["backend_selection_requires_update"] = True
     logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
     fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
-        dtype, config, True, qkv_format, input_layernorm
+        dtype, config, True, qkv_format, input_layernorm, is_training
     )
     logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
     fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
-        dtype, config, False, qkv_format, input_layernorm
+        dtype, config, False, qkv_format, input_layernorm, is_training
     )

-    tols = dict(atol=5e-1, rtol=5e-1)
-    rmse_tol = 0.1
-    fwd_rmse = _rmse(fused_attn_fwd_fp8, fused_attn_fwd_f16)
-    fwd_range = max(fused_attn_fwd_fp8.max().item(), fused_attn_fwd_f16.max().item()) - min(
-        fused_attn_fwd_fp8.min().item(), fused_attn_fwd_f16.min().item()
-    )
+    atol = 5e-1
+    rtol = 5e-1
+    rmse_tol = 0.15
     logging.debug("========== {:^25s} ==========".format("forward output"))
-    logging.debug(
-        "fused_attn_fwd_fp8 min {:.6f} max {:.6f}".format(
-            fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.max().item()
-        )
-    )
-    logging.debug(
-        "fused_attn_fwd_f16 min {:.6f} max {:.6f}".format(
-            fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item()
-        )
-    )
-    logging.debug("fused_attn_fwd RMSE: {:.6f}".format(fwd_rmse))
-    try:
-        torch.testing.assert_close(fused_attn_fwd_fp8, fused_attn_fwd_f16, **tols)
-    except Exception as e:
-        logging.debug(e)
-
-    assert (
-        fwd_rmse < rmse_tol * fwd_range
-    ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
-        fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range
-    )
-    for i in range(len(param_names[:1])):
-        bwd_rmse = _rmse(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i])
-        bwd_range = max(
-            fused_attn_bwd_fp8[i].max().item(), fused_attn_bwd_f16[i].max().item()
-        ) - min(fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_f16[i].min().item())
-
-        logging.debug("========== {:^25s} ==========".format(param_names[i]))
-        logging.debug(
-            "fused_attn_bwd_fp8[{}] min {:.6f} max {:.6f}".format(
-                i, fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item()
-            )
-        )
-        logging.debug(
-            "fused_attn_bwd_f16[{}] min {:.6f} max {:.6f}".format(
-                i, fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item()
-            )
-        )
-        logging.debug("fused_attn_bwd RMSE[{}]: {:.6f}".format(i, bwd_rmse))
-        try:
-            torch.testing.assert_close(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], **tols)
-        except Exception as e:
-            logging.debug(e)
-        assert (
-            bwd_rmse < rmse_tol * bwd_range
-        ), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
-            bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range
-        )
+    if not is_training:
+        _error(
+            flash_attn_fwd_fp8,
+            fused_attn_fwd_f16,
+            "flash_attn_fwd_fp8",
+            "fused_attn_fwd_f16",
+            atol,
+            rtol,
+            rmse_tol,
+        )
+    _error(
+        fused_attn_fwd_fp8,
+        fused_attn_fwd_f16,
+        "fused_attn_fwd_fp8",
+        "fused_attn_fwd_f16",
+        atol,
+        rtol,
+        rmse_tol,
+    )
+
+    if is_training:
+        for i in range(len(param_names[:1])):
+            logging.debug("========== {:^25s} ==========".format(param_names[i]))
+            _error(
+                fused_attn_bwd_fp8[i],
+                fused_attn_bwd_f16[i],
+                f"fused_attn_bwd_fp8[{i}]",
+                f"fused_attn_bwd_f16[{i}]",
+                atol,
+                rtol,
+                rmse_tol,
+            )
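The new inference-only branch leans on Transformer Engine's environment-variable overrides to force a specific backend, and then invalidates the cached backend decision so the next forward pass re-runs selection. A standalone sketch of that toggle pattern as the test uses it; _attention_backends is a private module-level cache inside transformer_engine.pytorch.attention:

import os

# Force the FlashAttention path (the FP8 FA3 forward is exercised only for inference here).
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"

# Transformer Engine caches its backend choice per set of attention parameters; flipping
# this flag makes the next DotProductAttention call re-run get_attention_backend().
from transformer_engine.pytorch.attention import _attention_backends

_attention_backends["backend_selection_requires_update"] = True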
-def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
+def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, is_training):
     reset_rng_states()
     _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
     _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
@@ -1434,6 +1448,8 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
         qkv_weight_interleaved=True,
         qkv_format=qkv_format,
     ).to(dtype=dtype, device="cuda")
+    if not is_training:
+        mha = mha.eval()

     seqlens_q = torch.full(
         [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
@@ -1464,7 +1480,8 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
     tensor_shape = [dim_to_num[j] for j in layout.split("_")]
     tensor = 0.01 * torch.randint(-100, 100, tensor_shape, dtype=dtype, device="cuda")
     hidden_states = tensor.view(*tensor.shape[:-2], -1)
-    hidden_states.requires_grad = True
+    if is_training:
+        hidden_states.requires_grad = True
     tensor = 0.01 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
     out_grad = tensor.view(*tensor.shape[:-2], -1)
@@ -1476,7 +1493,8 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
         core_attention_bias_type=config.attn_bias_type,
         is_first_microbatch=None,
     )
-    out.backward(out_grad)
+    if is_training:
+        out.backward(out_grad)

     param_names = []
     param_names.append("hidden_states.grad")
@@ -1487,7 +1505,9 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
             param_names.append(name + ".grad")
             params.append(param)

-    return out, param_names, tuple(x.grad for x in params)
+    if is_training:
+        return out, param_names, tuple(x.grad for x in params)
+    return out, param_names, tuple(None for x in params)
 @pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@@ -1497,7 +1517,8 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
 @pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
 @pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16)
 @pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
-def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd):
+@pytest.mark.parametrize("is_training", [True, False])
+def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
     config = model_configs_fp8_vs_f16[model]

     if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
@@ -1505,76 +1526,69 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd):
     os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
     os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
     global _attention_backends
-    _attention_backends["backend_selection_requires_update"] = True
+    if not is_training:
+        os.environ["NVTE_FLASH_ATTN"] = "1"
+        os.environ["NVTE_FUSED_ATTN"] = "0"
+        _attention_backends["backend_selection_requires_update"] = True
+        logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True")
+        flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
+            dtype, config, True, qkv_layout, is_training
+        )
+
+    os.environ["NVTE_FLASH_ATTN"] = "0"
+    os.environ["NVTE_FUSED_ATTN"] = "1"
+    _attention_backends["backend_selection_requires_update"] = True
     logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True")
-    fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(dtype, config, True, qkv_layout)
+    fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
+        dtype, config, True, qkv_layout, is_training
+    )
     logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False")
-    fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16(dtype, config, False, qkv_layout)
+    fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16(
+        dtype, config, False, qkv_layout, is_training
+    )

-    tols = dict(atol=5e-1, rtol=5e-2)
+    atol = 5e-1
+    rtol = 5e-2
     rmse_tol = 0.1
     bwd_names = ["dq", "dk", "dv"]
-    fwd_rmse = _rmse(fused_attn_fwd_fp8, fused_attn_fwd_f16)
-    fwd_range = max(fused_attn_fwd_fp8.max().item(), fused_attn_fwd_f16.max().item()) - min(
-        fused_attn_fwd_fp8.min().item(), fused_attn_fwd_f16.min().item()
-    )
     logging.debug("========== {:^25s} ==========".format("forward output"))
-    logging.debug(
-        "fused_attn_fwd_fp8 min {:.6f} max {:.6f}".format(
-            fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.max().item()
-        )
-    )
-    logging.debug(
-        "fused_attn_fwd_f16 min {:.6f} max {:.6f}".format(
-            fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item()
-        )
-    )
-    logging.debug("fused_attn_fwd RMSE: {:.6f}".format(fwd_rmse))
-    try:
-        torch.testing.assert_close(fused_attn_fwd_fp8, fused_attn_fwd_f16, **tols)
-    except Exception as e:
-        logging.debug(e)
-
-    assert (
-        fwd_rmse < rmse_tol * fwd_range
-    ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
-        fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range
-    )
-    for i, _ in enumerate(fused_attn_bwd_f16):
-        bwd_rmse = _rmse(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i])
-        bwd_range = max(
-            fused_attn_bwd_fp8[i].max().item(), fused_attn_bwd_f16[i].max().item()
-        ) - min(fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_f16[i].min().item())
-        logging.debug("========== {:^25s} ==========".format(bwd_names[i]))
-        logging.debug(
-            "fused_attn_bwd_fp8[{}] min {:.6f} max {:.6f}".format(
-                i, fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item()
-            )
-        )
-        logging.debug(
-            "fused_attn_bwd_f16[{}] min {:.6f} max {:.6f}".format(
-                i, fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item()
-            )
-        )
-        logging.debug("fused_attn_bwd RMSE[{}]: {:.6f}".format(i, bwd_rmse))
-        try:
-            torch.testing.assert_close(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], **tols)
-        except Exception as e:
-            logging.debug(e)
-        assert (
-            bwd_rmse < rmse_tol * bwd_range
-        ), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
-            bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range
-        )
+    if not is_training:
+        _error(
+            flash_attn_fwd_fp8,
+            fused_attn_fwd_f16,
+            "flash_attn_fwd_fp8",
+            "fused_attn_fwd_f16",
+            atol,
+            rtol,
+            rmse_tol,
+        )
+    _error(
+        fused_attn_fwd_fp8,
+        fused_attn_fwd_f16,
+        "fused_attn_fwd_fp8",
+        "fused_attn_fwd_f16",
+        atol,
+        rtol,
+        rmse_tol,
+    )
+
+    if is_training:
+        for i, _ in enumerate(fused_attn_bwd_f16):
+            logging.debug("========== {:^25s} ==========".format(bwd_names[i]))
+            _error(
+                fused_attn_bwd_fp8[i],
+                fused_attn_bwd_f16[i],
+                f"fused_attn_bwd_fp8[{i}]",
+                f"fused_attn_bwd_f16[{i}]",
+                atol,
+                rtol,
+                rmse_tol,
+            )
-def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
+def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training):
     reset_rng_states()
     _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
@@ -1607,6 +1621,8 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
         attention_type="self",
         qkv_format=qkv_format,
     ).to(dtype=dtype, device="cuda")
+    if not is_training:
+        dpa = dpa.eval()

     seqlens_q = torch.full(
         [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
@@ -1680,9 +1696,12 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
         core_attention_bias_type=config.attn_bias_type,
         is_first_microbatch=True,
     )
-    out.backward(out_grad)
+    if is_training:
+        out.backward(out_grad)

-    return out, (inp[0].grad, inp[1].grad, inp[2].grad)
+    if is_training:
+        return out, (inp[0].grad, inp[1].grad, inp[2].grad)
+    return out, (None, None, None)
 model_configs_fp8 = {
@@ -1726,58 +1745,26 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
     fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention")
     unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention")

-    tols = dict(atol=5e-1, rtol=5e-1)
+    atol = 5e-1
+    rtol = 5e-1
     rmse_tol = 0.1
-    fwd_rmse = _rmse(fused_attn_fwd_fp8, unfused_attn_fwd_f16)
-    fwd_range = max(fused_attn_fwd_fp8.max().item(), unfused_attn_fwd_f16.max().item()) - min(
-        fused_attn_fwd_fp8.min().item(), unfused_attn_fwd_f16.min().item()
-    )
-    bwd_rmse = _rmse(fused_attn_bwd_fp8, unfused_attn_bwd_f16)
-    bwd_range = max(fused_attn_bwd_fp8.max().item(), unfused_attn_bwd_f16.max().item()) - min(
-        fused_attn_bwd_fp8.min().item(), unfused_attn_bwd_f16.min().item()
-    )
-
-    logging.debug(
-        "fused_attn_fwd_fp8 min {:.6f} max {:.6f}".format(
-            fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.max().item()
-        )
-    )
-    logging.debug(
-        "unfused_attn_fwd_f16 min {:.6f} max {:.6f}".format(
-            unfused_attn_fwd_f16.min().item(), unfused_attn_fwd_f16.max().item()
-        )
-    )
-    logging.debug("fused_attn_fwd_fp8 vs unfused_attn_fwd_f16 RMSE: {:.6f}".format(fwd_rmse))
-    try:
-        torch.testing.assert_close(fused_attn_fwd_fp8, unfused_attn_fwd_f16, **tols)
-    except Exception as e:
-        logging.debug(e)
-
-    logging.debug(
-        "fused_attn_bwd_fp8 min {:.6f} max {:.6f}".format(
-            fused_attn_bwd_fp8.min().item(), fused_attn_bwd_fp8.max().item()
-        )
-    )
-    logging.debug(
-        "unfused_attn_bwd_f16 min {:.6f} max {:.6f}".format(
-            unfused_attn_bwd_f16.min().item(), unfused_attn_bwd_f16.max().item()
-        )
-    )
-    logging.debug("fused_attn_bwd_fp8 vs unfused_attn_bwd_f16 RMSE: {:.6f}".format(bwd_rmse))
-    try:
-        torch.testing.assert_close(fused_attn_bwd_fp8, unfused_attn_bwd_f16, **tols)
-    except Exception as e:
-        logging.debug(e)
-
-    assert (
-        fwd_rmse < rmse_tol * fwd_range
-    ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
-        fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range
-    )
-    assert (
-        bwd_rmse < rmse_tol * bwd_range
-    ), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
-        bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range
-    )
+    _error(
+        fused_attn_fwd_fp8,
+        unfused_attn_fwd_f16,
+        "fused_attn_fwd_fp8",
+        "unfused_attn_fwd_f16",
+        atol,
+        rtol,
+        rmse_tol,
+    )
+    _error(
+        fused_attn_bwd_fp8,
+        unfused_attn_bwd_f16,
+        "fused_attn_bwd_fp8",
+        "unfused_attn_bwd_f16",
+        atol,
+        rtol,
+        rmse_tol,
+    )
...
@@ -6,6 +6,7 @@
 import collections
 from contextlib import nullcontext
 from importlib.metadata import version as get_pkg_version
+from importlib.metadata import PackageNotFoundError
 import math
 import os
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -38,7 +39,7 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
     AttnMaskType,
     FusedAttnBackend,
 )
-from transformer_engine.pytorch.fp8 import get_fp8_te_dtype
+from transformer_engine.pytorch.fp8 import get_fp8_te_dtype, get_fp8_torch_dtype
 from transformer_engine.pytorch.float8_tensor import Float8Tensor
 from transformer_engine.pytorch.module import LayerNormLinear, Linear
 from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
@@ -75,16 +76,42 @@ from transformer_engine.pytorch.graph import is_graph_capturing
 _flash_attn_version = PkgVersion(get_pkg_version("flash-attn"))
 _flash_attn_version_required = PkgVersion("2.0.6")
-_flash_attn_max_version = PkgVersion("2.5.8")
+_flash_attn_max_version = PkgVersion("2.6.3")
 _flash_attn_2_plus = _flash_attn_version >= PkgVersion("2")
 _flash_attn_2_1_plus = _flash_attn_version >= PkgVersion("2.1")
 _flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3")
 _flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4")
 _flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1")
 _flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7")
+_flash_attn_3_plus = False
+_use_flash_attn_3 = False
+try:
+    _flash_attn_v3_version = PkgVersion(get_pkg_version("flashattn-hopper"))
+    _flash_attn_3_plus = _flash_attn_v3_version >= PkgVersion("2.6.1")
+except PackageNotFoundError:
+    warnings.warn(
+        "To use flash-attn v3, please use the following commands to install: \n"
+        """(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper" \n"""
+        """(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"` \n"""
+        """(3) mkdir -p $python_path/flashattn_hopper \n"""
+        """(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py"""
+    )
+else:
+    from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3
+    from flashattn_hopper.flash_attn_interface import (
+        flash_attn_varlen_func as flash_attn_varlen_func_v3,
+    )
+    from flashattn_hopper.flash_attn_interface import (  # pylint: disable=unused-import
+        _flash_attn_forward as _flash_attn_forward_v3,
+    )
+    from flashattn_hopper.flash_attn_interface import (  # pylint: disable=unused-import
+        _flash_attn_backward as _flash_attn_backward_v3,
+    )
+
+    _use_flash_attn_3 = True

 if _flash_attn_version >= _flash_attn_version_required:
-    from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func
+    from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
     from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward
     from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward
     from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd
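The optional flashattn-hopper package is probed with importlib.metadata, so the module degrades gracefully (with install instructions) when FlashAttention 3 is not available. A condensed sketch of that gating pattern, assuming PkgVersion is packaging's Version as Transformer Engine aliases it:

import warnings
from importlib.metadata import version as get_pkg_version, PackageNotFoundError
from packaging.version import Version as PkgVersion  # assumed alias, as in TE

use_fa3 = False
try:
    fa3_version = PkgVersion(get_pkg_version("flashattn-hopper"))
except PackageNotFoundError:
    warnings.warn("flashattn-hopper not found; falling back to flash-attn 2.x kernels")
else:
    if fa3_version >= PkgVersion("2.6.1"):
        # Import only when a new-enough FA3 build is present.
        from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3
        use_fa3 = True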
@@ -318,6 +345,7 @@ def get_attention_backend(
         use_fused_attention = False

     # Filter: Compute capability
+    global _flash_attn_3_plus, _use_flash_attn_3
     if device_compute_capability < (8, 0):
         if use_flash_attention:
             logger.debug("Disabling FlashAttention as it requires compute capability sm80+")
@@ -325,32 +353,37 @@ def get_attention_backend(
         if use_fused_attention:
             logger.debug("Disabling FusedAttention as it requires compute capability sm80+")
             use_fused_attention = False
+    if device_compute_capability < (9, 0):
+        if use_flash_attention and _flash_attn_3_plus:
+            logger.debug("Disabling FlashAttention 3 as it requires compute capability sm90+")
+            _use_flash_attn_3 = False

     # Filter: Data type
-    if use_flash_attention and (
-        qkv_dtype not in [torch.bfloat16, torch.float16] or qkv_type == Float8Tensor
-    ):
-        logger.debug(
-            "Disabling FlashAttention due to unsupported QKV data type. "
-            "Supported: qkv_type = torch.Tensor, qkv_dtype = {torch.bfloat16, torch.float16}. "
-            "Found: qkv_type = %s, qkv_dtype = %s.",
-            qkv_type,
-            qkv_dtype,
-        )
-        use_flash_attention = False
-    if use_fused_attention and (qkv_dtype not in [torch.bfloat16, torch.float16]):
-        logger.debug(
-            "Disabling FusedAttention due to unsupported QKV data type. "
-            "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. "
-            "Found: qkv_dtype = %s.",
-            qkv_dtype,
-        )
-        use_fused_attention = False
+    if qkv_dtype not in [torch.bfloat16, torch.float16] or qkv_type not in [
+        torch.Tensor,
+        Float8Tensor,
+    ]:
+        if use_flash_attention:
+            logger.debug(
+                "Disabling FlashAttention due to unsupported QKV data type. "
+                "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. "
+                "Found: qkv_dtype = %s.",
+                qkv_dtype,
+            )
+            use_flash_attention = False
+        if use_fused_attention:
+            logger.debug(
+                "Disabling FusedAttention due to unsupported QKV data type. "
+                "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. "
+                "Found: qkv_dtype = %s.",
+                qkv_dtype,
+            )
+            use_fused_attention = False

     # Filter: Execution type
     if fp8 and fp8_meta["recipe"].fp8_dpa:
-        if use_flash_attention:
-            logger.debug("Disabling FlashAttention as it does not support FP8")
+        if use_flash_attention and is_training:
+            logger.debug("Disabling FlashAttention as it does not support FP8 training")
             use_flash_attention = False
         if use_unfused_attention:
             logger.debug("Disabling UnfusedDotProductAttention as it does not support FP8")
@@ -396,6 +429,12 @@ def get_attention_backend(
         )
         use_flash_attention = False

+    # Filter: Dropout
+    if attention_dropout != 0.0 and use_flash_attention:
+        if _flash_attn_3_plus and _use_flash_attn_3:
+            logger.debug("Disabling FlashAttention 3 for dropout")
+            _use_flash_attn_3 = False
+
     # Filter: Context parallelism
     # qkv_format | attn_mask_type | attn_bias_type | supported backends
     # ----------------------------------------------------------------------------------------------------
@@ -414,6 +453,14 @@ def get_attention_backend(
         )
         use_unfused_attention = False
     if context_parallel and use_flash_attention:
+        if _flash_attn_3_plus and _use_flash_attn_3:
+            logger.debug("Disabling FlashAttention 3 for context parallelism")
+            _use_flash_attn_3 = False
+        if fp8 and fp8_meta["recipe"].fp8_dpa:
+            logger.debug(
+                "Disabling FlashAttention as it does not support context parallelism with FP8"
+            )
+            use_flash_attention = False
         if "bottom_right" in attn_mask_type:
             logger.debug(
                 "Disabling FlashAttention as it does not support context parallelism with"
@@ -439,6 +486,7 @@ def get_attention_backend(
                 " bias for THD format"
             )
             use_flash_attention = False
+
     if context_parallel and use_fused_attention:
         if "bottom_right" in attn_mask_type:
             logger.debug(
@@ -498,6 +546,18 @@ def get_attention_backend(
         if use_fused_attention:
             logger.debug("Disabling FusedAttention for arbitrary mask")
             use_fused_attention = False
+    if (
+        use_flash_attention
+        and _flash_attn_3_plus
+        and attn_mask_type in ["causal", "padding_causal"]
+        and max_seqlen_q != max_seqlen_kv
+    ):
+        logger.warning(
+            "Disabling FlashAttention 3 as it only supports bottom-right-diagonal "
+            "causal mask since flash-attn 2.1. See "
+            "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
+        )
+        _use_flash_attn_3 = False
     if (
         use_flash_attention
         and _flash_attn_2_1_plus
@@ -571,6 +631,15 @@ def get_attention_backend(
                 attn_mask_type,
             )
             use_fused_attention = False
+    if (
+        use_flash_attention
+        and (window_size[0] != -1 or window_size[1] not in [-1, 0])
+        and _flash_attn_3_plus
+    ):
+        logger.debug(
+            "Disabling FlashAttention 3 as it does not support sliding window attention"
+        )
+        _use_flash_attn_3 = False
     if (
         use_flash_attention
         and (window_size[0] != -1 or window_size[1] not in [-1, 0])
@@ -590,6 +659,14 @@ def get_attention_backend(
     #                            |                              | bottom_right (converts to a 'post_scale_bias' bias)
     # UnfusedDotProductAttention | no_bias, pre/post_scale_bias |
     #                            | alibi/alibi_slopes           | both; converts to a 'post_scale_bias' bias
+    if use_flash_attention and core_attention_bias_type == "alibi":
+        if _flash_attn_3_plus and _use_flash_attn_3:
+            logger.debug("Disabling FlashAttention 3 for ALiBi")
+            _use_flash_attn_3 = False
+        if not _flash_attn_2_4_plus:
+            logger.debug("Disabling FlashAttention for ALiBi")
+            use_flash_attention = False
+
     if use_flash_attention and (
         core_attention_bias_type not in ["no_bias", "alibi"]
         or core_attention_bias_shape is not None
@@ -1071,7 +1148,7 @@ def _get_full_cu_seqlens(
     return _cu_seqlens_cache[(batch_size, max_seqlen)]


-@jit_fuser
+@torch.compile
 def pack_tensor(
     indices: torch.Tensor,
     tensor: torch.Tensor,
@@ -1082,14 +1159,19 @@ def pack_tensor(
     padding_indice = torch.zeros(
         1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device
     )
-    tensor = torch.cat((tensor, padding_indice), dim=0)
     indices = indices.repeat(1, tensor.shape[1], tensor.shape[2])
-    packed = torch.gather(tensor, 0, indices)
+    if isinstance(tensor, Float8Tensor):
+        tensor_data = torch.cat((tensor._data, padding_indice), dim=0)
+        packed = Float8Tensor.make_like(tensor, data=torch.gather(tensor_data, 0, indices))
+    else:
+        tensor = torch.cat((tensor, padding_indice), dim=0)
+        packed = torch.gather(tensor, 0, indices)
     return packed


-@jit_fuser
+@torch.compile
 def pack_2_tensors(
     indices: torch.Tensor,
     t1: torch.Tensor,
@@ -1103,7 +1185,7 @@ def pack_2_tensors(
     return t1_packed, t2_packed


-@jit_fuser
+@torch.compile
 def pack_3_tensors(
     indices: torch.Tensor,
     t1: torch.Tensor,
@@ -1119,7 +1201,7 @@ def pack_3_tensors(
     return t1_packed, t2_packed, t3_packed


-@jit_fuser
+@torch.compile
 def unpack_tensor(
     indices: torch.Tensor,
     dim0: int,
@@ -1132,12 +1214,16 @@ def unpack_tensor(
     unpacked = torch.zeros(
         dim0 + 1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device
     )
-    unpacked.scatter_(0, indices, tensor)
-    unpacked = unpacked[0:-1, :, :]
+    if isinstance(tensor, Float8Tensor):
+        unpacked.scatter_(0, indices, tensor._data)
+        unpacked = Float8Tensor.make_like(tensor, data=unpacked[0:-1, :, :])
+    else:
+        unpacked.scatter_(0, indices, tensor)
+        unpacked = unpacked[0:-1, :, :]
     return unpacked


-@jit_fuser
+@torch.compile
 def unpack_2_tensors(
     indices: torch.Tensor,
     dim0: int,
@@ -1152,7 +1238,7 @@ def unpack_2_tensors(
     return t1_unpacked, t2_unpacked


-@jit_fuser
+@torch.compile
 def unpack_3_tensors(
     indices: torch.Tensor,
     dim0: int,
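pack_tensor/unpack_tensor implement the padded-to-packed (thd) conversion by gathering the valid-token rows into a contiguous buffer and scattering them back; the new Float8Tensor branches do the same on the raw ._data payload so FP8 storage survives the round trip. A minimal gather/scatter round trip on plain tensors, with hypothetical shapes:

import torch

# A flattened [b*s, h, d] buffer: two sequences of lengths 3 and 2, each padded to 4 tokens.
h, d = 2, 4
tokens = torch.arange(8 * h * d, dtype=torch.float32).reshape(8, h, d)
valid = torch.tensor([0, 1, 2, 4, 5])             # row indices of non-padding tokens
indices = valid.reshape(-1, 1, 1).repeat(1, h, d)

packed = torch.gather(tokens, 0, indices)         # packed thd-style tensor, shape [5, h, d]

unpacked = torch.zeros(8 + 1, h, d)               # extra row mirrors the helpers' padding-row trick
unpacked.scatter_(0, indices, packed)
unpacked = unpacked[0:-1]                         # drop the extra row, back to [8, h, d]

assert torch.equal(unpacked[valid], tokens[valid])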
@@ -4212,14 +4298,15 @@ class FlashAttention(torch.nn.Module):
         cp_global_ranks: List[int] = None,
         cp_stream: torch.cuda.Stream = None,
         cp_comm_type: str = "p2p",
+        fp8: bool = False,
+        fp8_meta: Optional[Dict[str, Any]] = None,
     ) -> torch.Tensor:
         """flash-attn fprop"""

-        assert (
-            query_layer.dtype in [torch.float16, torch.bfloat16]
-            and key_layer.dtype in [torch.float16, torch.bfloat16]
-            and value_layer.dtype in [torch.float16, torch.bfloat16]
-        ), "FlashAttention currently only supports FP16 and BF16."
+        assert all(
+            x.dtype in [torch.float16, torch.bfloat16] or isinstance(x, Float8Tensor)
+            for x in [query_layer, key_layer, value_layer]
+        ), "FlashAttention only supports FP16 and BF16 data types, or Float8Tensors."
         assert (
             query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
         ), "FlashAttention currently only supports CUDA tensors."
@@ -4232,24 +4319,36 @@ class FlashAttention(torch.nn.Module):
         qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])

-        if qkv_format == "sbhd":
-            # For now just 128, will make it more general in the future
-            if (
-                query_layer.shape[-1] == 128
-                and query_layer.shape[0] * query_layer.shape[1] >= 512
-                and qkv_layout == "sbh3d"
-            ):
-                query_layer, key_layer, value_layer = _PrepareQKVForFA.apply(
-                    query_layer, key_layer, value_layer
-                )
-            else:
-                query_layer, key_layer, value_layer = [
-                    x.transpose(0, 1).contiguous() for x in (query_layer, key_layer, value_layer)
-                ]
-        elif qkv_format in ["bshd", "thd"]:
-            query_layer, key_layer, value_layer = [
-                x.contiguous() for x in (query_layer, key_layer, value_layer)
-            ]
+        if all(not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]):
+            if qkv_format == "sbhd":
+                # For now just 128, will make it more general in the future
+                if (
+                    query_layer.shape[-1] == 128
+                    and query_layer.shape[0] * query_layer.shape[1] >= 512
+                    and qkv_layout == "sbh3d"
+                ):
+                    query_layer, key_layer, value_layer = _PrepareQKVForFA.apply(
+                        query_layer, key_layer, value_layer
+                    )
+                else:
+                    query_layer, key_layer, value_layer = [
+                        x.transpose(0, 1).contiguous()
+                        for x in (query_layer, key_layer, value_layer)
+                    ]
+            elif qkv_format in ["bshd", "thd"]:
+                query_layer, key_layer, value_layer = [
+                    x.contiguous() for x in (query_layer, key_layer, value_layer)
+                ]
+        else:
+            if qkv_format == "sbhd":
+                query_layer._data, key_layer._data, value_layer._data = [
+                    x.transpose(0, 1).contiguous()
+                    for x in (query_layer._data, key_layer._data, value_layer._data)
+                ]
+            elif qkv_format in ["bshd", "thd"]:
+                query_layer._data, key_layer._data, value_layer._data = [
+                    x.contiguous()
+                    for x in (query_layer._data, key_layer._data, value_layer._data)
+                ]

         batch_size = query_layer.shape[0]
@@ -4257,16 +4356,15 @@ class FlashAttention(torch.nn.Module):
         max_seqlen_q, max_seqlen_kv = query_layer.shape[1], key_layer.shape[1]
         max_seqlen_q *= cp_size
         max_seqlen_kv *= cp_size
-        if not context_parallel:
-            # [b * s, h, d]
-            query_layer, key_layer, value_layer = [
-                x.view(x.shape[0] * x.shape[1], *x.shape[2:])
-                for x in [query_layer, key_layer, value_layer]
-            ]
-
         if "padding" in attn_mask_type:
             assert not context_parallel, "Padding mask not supported with context parallelism!"
+
+        # [b * s, h, d]
+        query_layer, key_layer, value_layer = [
+            x.reshape(x.shape[0] * x.shape[1], *x.shape[2:])
+            for x in [query_layer, key_layer, value_layer]
+        ]

         if self.attention_type == "self":
             assert (
                 max_seqlen_q == max_seqlen_kv
@@ -4319,7 +4417,9 @@ class FlashAttention(torch.nn.Module):
             seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
             max_seqlen_kv = seqlens_kv.max().item()

-        if context_parallel:
+        if context_parallel and all(
+            not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]
+        ):
             assert (
                 alibi_slopes is None
             ), "Alibi slope bias addition is not supported with context parallelism."
fa_optional_forward_kwargs["deterministic"] = self.deterministic fa_optional_forward_kwargs["deterministic"] = self.deterministic
if _flash_attn_2_5_7_plus: if _flash_attn_2_5_7_plus:
fa_optional_forward_kwargs["block_table"] = None fa_optional_forward_kwargs["block_table"] = None
output = flash_attn_forward_func( fa_optional_forward_args_thd = []
query_layer, if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type:
key_layer, func = flash_attn_func if not _use_flash_attn_3 else flash_attn_func_v3
value_layer, else:
cu_seqlens_q, func = (
cu_seqlens_kv, flash_attn_varlen_func
max_seqlen_q, if not _use_flash_attn_3
max_seqlen_kv, else flash_attn_varlen_func_v3
self.attention_dropout if self.training else 0.0, )
softmax_scale=self.softmax_scale, fa_optional_forward_args_thd.append(cu_seqlens_q)
causal="causal" in attn_mask_type, fa_optional_forward_args_thd.append(cu_seqlens_kv)
**fa_optional_forward_kwargs, fa_optional_forward_args_thd.append(max_seqlen_q)
) fa_optional_forward_args_thd.append(max_seqlen_kv)
if _use_flash_attn_3:
if fp8:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
activation_dtype = query_layer.dtype
torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True)
if fp8_meta["recipe"].fp8_mha:
assert all(
isinstance(x, Float8Tensor)
for x in [query_layer, key_layer, value_layer]
), "q/k/v must be Float8Tensors for FP8 MHA."
fp8_meta["scaling_fwd"].scale_inv[META_QKV] = query_layer._scale_inv
query_layer, key_layer, value_layer = (
x.to(activation_dtype).to(torch_dtype)
for x in [query_layer, key_layer, value_layer]
)
else:
query_layer, key_layer, value_layer = (
x.to(torch_dtype) for x in [query_layer, key_layer, value_layer]
)
output, _ = func(
query_layer,
key_layer,
value_layer,
*fa_optional_forward_args_thd,
softmax_scale=self.softmax_scale,
causal="causal" in attn_mask_type,
deterministic=self.deterministic,
)
if fp8 and fp8_meta["recipe"].fp8_mha:
output = cast_to_fp8(
output,
fp8_meta["scaling_fwd"],
META_O,
fp8_dtype_forward,
)
output = Float8Tensor(
data=output,
fp8_meta=fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=META_O,
fp8_dtype=fp8_dtype_forward,
dtype=activation_dtype,
)
else:
output = func(
query_layer,
key_layer,
value_layer,
*fa_optional_forward_args_thd,
self.attention_dropout if self.training else 0.0,
softmax_scale=self.softmax_scale,
causal="causal" in attn_mask_type,
**fa_optional_forward_kwargs,
)
if qkv_format in ["sbhd", "bshd"] and "padding" in attn_mask_type: if qkv_format in ["sbhd", "bshd"] and "padding" in attn_mask_type:
output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output) output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output)
if qkv_format == "sbhd": if qkv_format == "sbhd":
# (bs)hd -> bs(hd) -> sb(hd) # (bs)hd -> bs(hd) -> sb(hd)
output = ( if fp8 and fp8_meta["recipe"].fp8_mha:
output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1).contiguous() output.reshape(batch_size * max_seqlen_q // cp_size, -1).transpose_2d()
) output = output.reshape(batch_size, max_seqlen_q // cp_size, -1)
else:
output = (
output.view(batch_size, max_seqlen_q // cp_size, -1)
.transpose(0, 1)
.contiguous()
)
elif qkv_format == "bshd": elif qkv_format == "bshd":
# (bs)hd -> bs(hd) # (bs)hd -> bs(hd)
output = output.view(batch_size, max_seqlen_q // cp_size, -1).contiguous() output = output.reshape(batch_size, max_seqlen_q // cp_size, -1)
elif qkv_format == "thd": elif qkv_format == "thd":
# thd -> t(hd) # thd -> t(hd)
output = output.view(output.shape[0], -1).contiguous() output = output.reshape(output.shape[0], -1)
return output return output
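On the FA3 path, FP8 execution casts q/k/v to a torch float8 dtype chosen from the recipe, calls the FlashAttention 3 kernel (which also returns the softmax LSE), and re-quantizes the output into a Float8Tensor when fp8_mha is enabled. A much-simplified sketch of the forward call only, assuming flashattn-hopper is installed and ignoring Transformer Engine's Float8Tensor scaling metadata:

import torch
from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3

b, s, h, d = 2, 2048, 16, 128
q, k, v = (torch.randn(b, s, h, d, dtype=torch.bfloat16, device="cuda") for _ in range(3))

# Cast activations to the recipe's forward dtype (E4M3 for E4M3/HYBRID recipes); the real
# code derives this via get_fp8_torch_dtype() and carries per-tensor scales in fp8_meta.
q8, k8, v8 = (x.to(torch.float8_e4m3fn) for x in (q, k, v))

# FA3 returns (output, softmax_lse); dropout is not passed on this path, mirroring the diff.
out, _ = flash_attn_func_v3(q8, k8, v8, softmax_scale=d**-0.5, causal=True, deterministic=False)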
@@ -5897,11 +6057,10 @@ class FusedAttention(torch.nn.Module):
         assert (
             fused_attention_backend != tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend
         ), "No fused attention backend supports this input combination!"
-        assert (
-            (query_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8])
-            and (key_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8])
-            and (value_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8])
-        ), "FusedAttention only supports FP16 and BF16 data types."
+        assert all(
+            x.dtype in [torch.float16, torch.bfloat16] or isinstance(x, Float8Tensor)
+            for x in [query_layer, key_layer, value_layer]
+        ), "FusedAttention only supports FP16 and BF16 data types, or Float8Tensors."
         assert (
             query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
         ), "FusedAttention only supports CUDA tensors."
@@ -6812,7 +6971,7 @@ class DotProductAttention(TransformerEngineBaseModule):
             fp8=self.fp8,
             fp8_meta=self.fp8_meta,
         )
-        global _attention_backends
+        global _attention_backends, _flash_attn_3_plus, _use_flash_attn_3
         if (
             _attention_backends["attention_params"] is None
             or attention_params != _attention_backends["attention_params"]
@@ -6820,6 +6979,7 @@ class DotProductAttention(TransformerEngineBaseModule):
             _attention_backends["attention_params"] = attention_params
             _attention_backends["backend_selection_requires_update"] = True
         if _attention_backends["backend_selection_requires_update"]:
+            _use_flash_attn_3 = _flash_attn_3_plus
             (
                 use_flash_attention,
                 use_fused_attention,
@@ -6828,7 +6988,10 @@ class DotProductAttention(TransformerEngineBaseModule):
                 _,
             ) = get_attention_backend(attention_params)
             if use_flash_attention:
-                self.logger.info("Running with FlashAttention backend")
+                self.logger.info(
+                    "Running with FlashAttention backend (version %s)",
+                    _flash_attn_version if not _use_flash_attn_3 else _flash_attn_v3_version,
+                )
             elif use_fused_attention:
                 self.logger.info(
                     "Running with FusedAttention backend (sub-backend %s)",
@@ -6867,6 +7030,8 @@ class DotProductAttention(TransformerEngineBaseModule):
                 cp_comm_type=self.cp_comm_type,
                 max_seqlen_q=max_seqlen_q,
                 max_seqlen_kv=max_seqlen_kv,
+                fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
+                fp8_meta=self.fp8_meta,
             )
         if use_fused_attention:
...
@@ -38,6 +38,15 @@ def get_default_fp8_recipe() -> DelayedScaling:
     return DelayedScaling()


+def get_fp8_torch_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> torch.dtype:
+    """Get fp8 data type according to recipe and tensor"""
+    if fp8_recipe.fp8_format == Format.E4M3 or (
+        fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor
+    ):
+        return torch.float8_e4m3fn
+    return torch.float8_e5m2fn
+
+
 def get_fp8_te_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> tex.DType:
     """Get fp8 data type according to recipe and tensor"""
     if fp8_recipe.fp8_format == Format.E4M3 or (
...
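get_fp8_torch_dtype mirrors get_fp8_te_dtype but returns native torch float8 dtypes, which is what the FA3 path needs for its .to(torch_dtype) casts (note the E5M2 dtype is spelled torch.float8_e5m2 in current PyTorch). A small usage sketch of the recipe-to-dtype mapping:

import torch
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.pytorch.fp8 import get_fp8_torch_dtype

# HYBRID keeps E4M3 for forward tensors and E5M2 for backward (gradient) tensors.
recipe = DelayedScaling(fp8_format=Format.HYBRID)

fwd_dtype = get_fp8_torch_dtype(recipe, fprop_tensor=True)   # expected: torch.float8_e4m3fn
bwd_dtype = get_fp8_torch_dtype(recipe, fprop_tensor=False)  # expected: the E5M2 torch dtype
print(fwd_dtype, bwd_dtype)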
@@ -56,7 +56,7 @@ if __name__ == "__main__":
         description="Transformer acceleration library - Torch Lib",
         ext_modules=ext_modules,
         cmdclass={"build_ext": CMakeBuildExtension},
-        install_requires=["torch", "flash-attn>=2.0.6,<=2.4.2,!=2.0.9,!=2.1.0"],
+        install_requires=["torch", "flash-attn>=2.0.6,<=2.6.3,!=2.0.9,!=2.1.0"],
         tests_require=["numpy", "onnxruntime", "torchvision"],
     )
     if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
...
@@ -23,9 +23,6 @@ _default_causal_mask = {}

 def _get_default_causal_mask(mask_type: str, sq: int, sk: int) -> torch.Tensor:
     """Return the causal upper triangular mask for softmax input"""
-    if sq == 1:
-        return torch.zeros((1, sk), dtype=torch.bool, device="cuda")
-
     matrix_identifiers = (mask_type, sq, sk)
     if matrix_identifiers not in _default_causal_mask:
         diagonal_offset = sk - sq + 1 if "bottom_right" in mask_type else 1
...