Unverified commit 901e5d2b, authored by Charlene Yang and committed by GitHub

Add support for flash-attn 3 (#1019)



* WIP: add fa3
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: clean up
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* WIP: add benchmarks
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* differentiate func/varlen_func
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix parsing keyword for FA3 and remove bshd->thd conversion for flash_attn_func
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: add FP8 fwd support
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add FA3 FP8 fwd code and test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix assert for FA3
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix FA3 FP8 logic and add tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update FA2 to <=2.6.3
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweak unit tests for base/mask
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* set constraints for FA3 for sm90 and causal_bottom_right
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revert debug changes in benchmark script
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 2215fa5c
-Subproject commit 2533f5e5c1877fd76266133c1479ef1643ce3a8b
+Subproject commit 98ca4e1941fe3263f128f74f10063a3ea35c7019
@@ -156,7 +156,7 @@ def parse_results(per_cudnn, per_flash, model):
         df_times.loc[row, "FusedAttention Kernels (fwd+bwd)"] = t_cudnn_avg.sum() / 1e6
     if per_flash > 0:
-        t_flash_all = df[df["Name"].str.contains("void flash")]["Duration (ns)"].to_numpy()
+        t_flash_all = df[df["Name"].str.contains("flash")]["Duration (ns)"].to_numpy()
         t_flash_all = t_flash_all.reshape(-1, per_flash)
         t_flash_avg = np.average(t_flash_all, axis=0)
         df_times.loc[row, "FlashAttention Kernels (fwd)"] = t_flash_avg[0] / 1e6
@@ -92,7 +92,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
     # Framework-specific requirements
     if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
         if "pytorch" in frameworks:
-            install_reqs.extend(["torch", "flash-attn>=2.0.6,<=2.5.8,!=2.0.9,!=2.1.0"])
+            install_reqs.extend(["torch", "flash-attn>=2.0.6,<=2.6.3,!=2.0.9,!=2.1.0"])
             test_reqs.extend(["numpy", "onnxruntime", "torchvision", "prettytable"])
         if "jax" in frameworks:
             install_reqs.extend(["jax", "flax>=0.7.1"])
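Note: the supported flash-attn 2 range moves from <=2.5.8 to <=2.6.3 here, matching the "update FA2 to <=2.6.3" item in the commit message. A minimal sketch for checking an installed wheel against the new pin (the `packaging` library is assumed to be available; it is not added by this commit):

    # Sketch: verify the installed flash-attn version against the new constraint.
    from importlib.metadata import version
    from packaging.specifiers import SpecifierSet

    spec = SpecifierSet(">=2.0.6,<=2.6.3,!=2.0.9,!=2.1.0")
    fa_version = version("flash-attn")
    print(fa_version, fa_version in spec)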
@@ -420,6 +420,10 @@ model_configs_mask = {
     "mask_8_1": ModelConfig(
         1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"
     ),
+    "mask_9_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal", "no_bias"),
+    "mask_9_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal", "no_bias"),
+    "mask_10_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal_bottom_right", "no_bias"),
+    "mask_10_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal_bottom_right", "no_bias"),
 }

@@ -1301,6 +1305,7 @@ model_configs_fp8_vs_f16 = {
     "fp8_12": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias"),
     "fp8_13": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "no_mask", "no_bias"),
     "fp8_14": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"),
+    "fp8_15": ModelConfig(1, 16, 16, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
 }

 param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16]
@@ -1312,6 +1317,27 @@ def _rmse(a, b):
     return math.sqrt((torch.pow((a - b), 2) / a.numel()).sum())


+def _error(a, b, name_a, name_b, atol, rtol, rmse_tol):
+    logging.debug(name_a + " min {:.6f} max {:.6f}".format(a.min().item(), a.max().item()))
+    logging.debug(name_b + " min {:.6f} max {:.6f}".format(b.min().item(), b.max().item()))
+    try:
+        torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
+    except Exception as e:
+        logging.debug(e)
+
+    rmse = _rmse(a, b)
+    logging.debug(name_a + " vs " + name_b + " RMSE: {:.6f}".format(rmse))
+    rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item())
+    assert rmse < rmse_tol * rmse_range, (
+        name_a
+        + " vs "
+        + name_b
+        + " RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
+            rmse, rmse_tol * rmse_range, rmse_tol, rmse_range
+        )
+    )
+
+
 @pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
 @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
 @pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
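The `_error` helper added above factors out the FP8-vs-F16 comparison used by the tests below: it logs the min/max of both tensors, runs `torch.testing.assert_close` with the given element-wise tolerances (logging rather than failing on a mismatch), and then asserts that the RMSE stays below `rmse_tol` times the combined dynamic range of the two tensors. A minimal usage sketch, assuming it is called from within this test module (values are illustrative only):

    # Sketch: compare two nearly-equal tensors with the helper's tolerances.
    a = torch.randn(4, 8)
    b = a + 1e-3 * torch.randn(4, 8)
    _error(a, b, "a", "b", atol=5e-1, rtol=5e-1, rmse_tol=0.1)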
@@ -1320,86 +1346,74 @@ def _rmse(a, b):
 @pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16)
 @pytest.mark.parametrize("input_layernorm", [True, False])
 @pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
-def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd):
-    os.environ["NVTE_FLASH_ATTN"] = "0"
-    os.environ["NVTE_FUSED_ATTN"] = "1"
+@pytest.mark.parametrize("is_training", [True, False])
+def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, is_training):
     os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
-    global _attention_backends
-    _attention_backends["backend_selection_requires_update"] = True
-    config = model_configs_fp8_vs_f16[model]
     os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
+    config = model_configs_fp8_vs_f16[model]
+
+    global _attention_backends
+    if not is_training:
+        os.environ["NVTE_FLASH_ATTN"] = "1"
+        os.environ["NVTE_FUSED_ATTN"] = "0"
+        _attention_backends["backend_selection_requires_update"] = True
+        logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
+        flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
+            dtype, config, True, qkv_format, input_layernorm, is_training
+        )

+    os.environ["NVTE_FLASH_ATTN"] = "0"
+    os.environ["NVTE_FUSED_ATTN"] = "1"
+    _attention_backends["backend_selection_requires_update"] = True
     logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
     fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
-        dtype, config, True, qkv_format, input_layernorm
+        dtype, config, True, qkv_format, input_layernorm, is_training
     )
     logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
     fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
-        dtype, config, False, qkv_format, input_layernorm
+        dtype, config, False, qkv_format, input_layernorm, is_training
     )

-    tols = dict(atol=5e-1, rtol=5e-1)
-    rmse_tol = 0.1
-    fwd_rmse = _rmse(fused_attn_fwd_fp8, fused_attn_fwd_f16)
-    fwd_range = max(fused_attn_fwd_fp8.max().item(), fused_attn_fwd_f16.max().item()) - min(
-        fused_attn_fwd_fp8.min().item(), fused_attn_fwd_f16.min().item()
-    )
-
+    atol = 5e-1
+    rtol = 5e-1
+    rmse_tol = 0.15
     logging.debug("========== {:^25s} ==========".format("forward output"))
-    logging.debug(
-        "fused_attn_fwd_fp8 min {:.6f} max {:.6f}".format(
-            fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.max().item()
-        )
-    )
-    logging.debug(
-        "fused_attn_fwd_f16 min {:.6f} max {:.6f}".format(
-            fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item()
-        )
-    )
-    logging.debug("fused_attn_fwd RMSE: {:.6f}".format(fwd_rmse))
-    try:
-        torch.testing.assert_close(fused_attn_fwd_fp8, fused_attn_fwd_f16, **tols)
-    except Exception as e:
-        logging.debug(e)
-
-    assert (
-        fwd_rmse < rmse_tol * fwd_range
-    ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
-        fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range
-    )
-    for i in range(len(param_names[:1])):
-        bwd_rmse = _rmse(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i])
-        bwd_range = max(
-            fused_attn_bwd_fp8[i].max().item(), fused_attn_bwd_f16[i].max().item()
-        ) - min(fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_f16[i].min().item())
-
-        logging.debug("========== {:^25s} ==========".format(param_names[i]))
-        logging.debug(
-            "fused_attn_bwd_fp8[{}] min {:.6f} max {:.6f}".format(
-                i, fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item()
-            )
-        )
-        logging.debug(
-            "fused_attn_bwd_f16[{}] min {:.6f} max {:.6f}".format(
-                i, fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item()
-            )
-        )
-        logging.debug("fused_attn_bwd RMSE[{}]: {:.6f}".format(i, bwd_rmse))
-        try:
-            torch.testing.assert_close(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], **tols)
-        except Exception as e:
-            logging.debug(e)
-        assert (
-            bwd_rmse < rmse_tol * bwd_range
-        ), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
-            bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range
-        )
+    if not is_training:
+        _error(
+            flash_attn_fwd_fp8,
+            fused_attn_fwd_f16,
+            "flash_attn_fwd_fp8",
+            "fused_attn_fwd_f16",
+            atol,
+            rtol,
+            rmse_tol,
+        )
+    _error(
+        fused_attn_fwd_fp8,
+        fused_attn_fwd_f16,
+        "fused_attn_fwd_fp8",
+        "fused_attn_fwd_f16",
+        atol,
+        rtol,
+        rmse_tol,
+    )
+
+    if is_training:
+        for i in range(len(param_names[:1])):
+            logging.debug("========== {:^25s} ==========".format(param_names[i]))
+            _error(
+                fused_attn_bwd_fp8[i],
+                fused_attn_bwd_f16[i],
+                f"fused_attn_bwd_fp8[{i}]",
+                f"fused_attn_bwd_f16[{i}]",
+                atol,
+                rtol,
+                rmse_tol,
+            )


-def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
+def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, is_training):
     reset_rng_states()
     _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
     _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
@@ -1434,6 +1448,8 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
         qkv_weight_interleaved=True,
         qkv_format=qkv_format,
     ).to(dtype=dtype, device="cuda")
+    if not is_training:
+        mha = mha.eval()

     seqlens_q = torch.full(
         [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
@@ -1464,7 +1480,8 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
     tensor_shape = [dim_to_num[j] for j in layout.split("_")]
     tensor = 0.01 * torch.randint(-100, 100, tensor_shape, dtype=dtype, device="cuda")
     hidden_states = tensor.view(*tensor.shape[:-2], -1)
-    hidden_states.requires_grad = True
+    if is_training:
+        hidden_states.requires_grad = True
     tensor = 0.01 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
     out_grad = tensor.view(*tensor.shape[:-2], -1)
@@ -1476,7 +1493,8 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
         core_attention_bias_type=config.attn_bias_type,
         is_first_microbatch=None,
     )
-    out.backward(out_grad)
+    if is_training:
+        out.backward(out_grad)

     param_names = []
     param_names.append("hidden_states.grad")
@@ -1487,7 +1505,9 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
         param_names.append(name + ".grad")
         params.append(param)

-    return out, param_names, tuple(x.grad for x in params)
+    if is_training:
+        return out, param_names, tuple(x.grad for x in params)
+    return out, param_names, tuple(None for x in params)


 @pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@@ -1497,7 +1517,8 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
 @pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
 @pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16)
 @pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
-def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd):
+@pytest.mark.parametrize("is_training", [True, False])
+def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
     config = model_configs_fp8_vs_f16[model]

     if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
@@ -1505,76 +1526,69 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd):
     os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
     os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
     global _attention_backends
-    _attention_backends["backend_selection_requires_update"] = True
+    if not is_training:
+        os.environ["NVTE_FLASH_ATTN"] = "1"
+        os.environ["NVTE_FUSED_ATTN"] = "0"
+        _attention_backends["backend_selection_requires_update"] = True
+        logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True")
+        flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
+            dtype, config, True, qkv_layout, is_training
+        )

+    os.environ["NVTE_FLASH_ATTN"] = "0"
+    os.environ["NVTE_FUSED_ATTN"] = "1"
+    _attention_backends["backend_selection_requires_update"] = True
     logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True")
-    fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(dtype, config, True, qkv_layout)
+    fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
+        dtype, config, True, qkv_layout, is_training
+    )
     logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False")
-    fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16(dtype, config, False, qkv_layout)
+    fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16(
+        dtype, config, False, qkv_layout, is_training
+    )

-    tols = dict(atol=5e-1, rtol=5e-2)
+    atol = 5e-1
+    rtol = 5e-2
     rmse_tol = 0.1
     bwd_names = ["dq", "dk", "dv"]
-    fwd_rmse = _rmse(fused_attn_fwd_fp8, fused_attn_fwd_f16)
-    fwd_range = max(fused_attn_fwd_fp8.max().item(), fused_attn_fwd_f16.max().item()) - min(
-        fused_attn_fwd_fp8.min().item(), fused_attn_fwd_f16.min().item()
-    )
-
     logging.debug("========== {:^25s} ==========".format("forward output"))
-    logging.debug(
-        "fused_attn_fwd_fp8 min {:.6f} max {:.6f}".format(
-            fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.max().item()
-        )
-    )
-    logging.debug(
-        "fused_attn_fwd_f16 min {:.6f} max {:.6f}".format(
-            fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item()
-        )
-    )
-    logging.debug("fused_attn_fwd RMSE: {:.6f}".format(fwd_rmse))
-    try:
-        torch.testing.assert_close(fused_attn_fwd_fp8, fused_attn_fwd_f16, **tols)
-    except Exception as e:
-        logging.debug(e)
-
-    assert (
-        fwd_rmse < rmse_tol * fwd_range
-    ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
-        fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range
-    )
-    for i, _ in enumerate(fused_attn_bwd_f16):
-        bwd_rmse = _rmse(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i])
-        bwd_range = max(
-            fused_attn_bwd_fp8[i].max().item(), fused_attn_bwd_f16[i].max().item()
-        ) - min(fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_f16[i].min().item())
-
-        logging.debug("========== {:^25s} ==========".format(bwd_names[i]))
-        logging.debug(
-            "fused_attn_bwd_fp8[{}] min {:.6f} max {:.6f}".format(
-                i, fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item()
-            )
-        )
-        logging.debug(
-            "fused_attn_bwd_f16[{}] min {:.6f} max {:.6f}".format(
-                i, fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item()
-            )
-        )
-        logging.debug("fused_attn_bwd RMSE[{}]: {:.6f}".format(i, bwd_rmse))
-        try:
-            torch.testing.assert_close(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], **tols)
-        except Exception as e:
-            logging.debug(e)
-        assert (
-            bwd_rmse < rmse_tol * bwd_range
-        ), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
-            bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range
-        )
+    if not is_training:
+        _error(
+            flash_attn_fwd_fp8,
+            fused_attn_fwd_f16,
+            "flash_attn_fwd_fp8",
+            "fused_attn_fwd_f16",
+            atol,
+            rtol,
+            rmse_tol,
+        )
+    _error(
+        fused_attn_fwd_fp8,
+        fused_attn_fwd_f16,
+        "fused_attn_fwd_fp8",
+        "fused_attn_fwd_f16",
+        atol,
+        rtol,
+        rmse_tol,
+    )
+
+    if is_training:
+        for i, _ in enumerate(fused_attn_bwd_f16):
+            logging.debug("========== {:^25s} ==========".format(bwd_names[i]))
+            _error(
+                fused_attn_bwd_fp8[i],
+                fused_attn_bwd_f16[i],
+                f"fused_attn_bwd_fp8[{i}]",
+                f"fused_attn_bwd_f16[{i}]",
+                atol,
+                rtol,
+                rmse_tol,
+            )


-def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
+def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training):
     reset_rng_states()
     _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
@@ -1607,6 +1621,8 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
         attention_type="self",
         qkv_format=qkv_format,
     ).to(dtype=dtype, device="cuda")
+    if not is_training:
+        dpa = dpa.eval()

     seqlens_q = torch.full(
         [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
@@ -1680,9 +1696,12 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
         core_attention_bias_type=config.attn_bias_type,
         is_first_microbatch=True,
     )
-    out.backward(out_grad)
+    if is_training:
+        out.backward(out_grad)

-    return out, (inp[0].grad, inp[1].grad, inp[2].grad)
+    if is_training:
+        return out, (inp[0].grad, inp[1].grad, inp[2].grad)
+    return out, (None, None, None)


 model_configs_fp8 = {
@@ -1726,58 +1745,26 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
     fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention")
     unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention")

-    tols = dict(atol=5e-1, rtol=5e-1)
+    atol = 5e-1
+    rtol = 5e-1
     rmse_tol = 0.1
-    fwd_rmse = _rmse(fused_attn_fwd_fp8, unfused_attn_fwd_f16)
-    fwd_range = max(fused_attn_fwd_fp8.max().item(), unfused_attn_fwd_f16.max().item()) - min(
-        fused_attn_fwd_fp8.min().item(), unfused_attn_fwd_f16.min().item()
-    )
-    bwd_rmse = _rmse(fused_attn_bwd_fp8, unfused_attn_bwd_f16)
-    bwd_range = max(fused_attn_bwd_fp8.max().item(), unfused_attn_bwd_f16.max().item()) - min(
-        fused_attn_bwd_fp8.min().item(), unfused_attn_bwd_f16.min().item()
-    )
-
-    logging.debug(
-        "fused_attn_fwd_fp8 min {:.6f} max {:.6f}".format(
-            fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.max().item()
-        )
-    )
-    logging.debug(
-        "unfused_attn_fwd_f16 min {:.6f} max {:.6f}".format(
-            unfused_attn_fwd_f16.min().item(), unfused_attn_fwd_f16.max().item()
-        )
-    )
-    logging.debug("fused_attn_fwd_fp8 vs unfused_attn_fwd_f16 RMSE: {:.6f}".format(fwd_rmse))
-    try:
-        torch.testing.assert_close(fused_attn_fwd_fp8, unfused_attn_fwd_f16, **tols)
-    except Exception as e:
-        logging.debug(e)
-
-    logging.debug(
-        "fused_attn_bwd_fp8 min {:.6f} max {:.6f}".format(
-            fused_attn_bwd_fp8.min().item(), fused_attn_bwd_fp8.max().item()
-        )
-    )
-    logging.debug(
-        "unfused_attn_bwd_f16 min {:.6f} max {:.6f}".format(
-            unfused_attn_bwd_f16.min().item(), unfused_attn_bwd_f16.max().item()
-        )
-    )
-    logging.debug("fused_attn_bwd_fp8 vs unfused_attn_bwd_f16 RMSE: {:.6f}".format(bwd_rmse))
-    try:
-        torch.testing.assert_close(fused_attn_bwd_fp8, unfused_attn_bwd_f16, **tols)
-    except Exception as e:
-        logging.debug(e)
-
-    assert (
-        fwd_rmse < rmse_tol * fwd_range
-    ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
-        fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range
-    )
-    assert (
-        bwd_rmse < rmse_tol * bwd_range
-    ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
-        bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range
-    )
+    _error(
+        fused_attn_fwd_fp8,
+        unfused_attn_fwd_f16,
+        "fused_attn_fwd_fp8",
+        "unfused_attn_fwd_f16",
+        atol,
+        rtol,
+        rmse_tol,
+    )
+    _error(
+        fused_attn_bwd_fp8,
+        unfused_attn_bwd_f16,
+        "fused_attn_bwd_fp8",
+        "unfused_attn_bwd_f16",
+        atol,
+        rtol,
+        rmse_tol,
+    )
This diff is collapsed.
@@ -38,6 +38,15 @@ def get_default_fp8_recipe() -> DelayedScaling:
     return DelayedScaling()


+def get_fp8_torch_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> torch.dtype:
+    """Get fp8 data type according to recipe and tensor"""
+    if fp8_recipe.fp8_format == Format.E4M3 or (
+        fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor
+    ):
+        return torch.float8_e4m3fn
+    return torch.float8_e5m2
+
+
 def get_fp8_te_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> tex.DType:
     """Get fp8 data type according to recipe and tensor"""
     if fp8_recipe.fp8_format == Format.E4M3 or (
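The new `get_fp8_torch_dtype` mirrors the existing `get_fp8_te_dtype`, but returns a native `torch.dtype`, presumably so FP8 tensors can be viewed or cast with a plain torch float8 type when they are handed to the flash-attn 3 path. A short usage sketch under the HYBRID recipe (the imports shown are the standard Transformer Engine recipe module, included here only for illustration):

    # Sketch: HYBRID recipe -> E4M3 for forward tensors, E5M2 otherwise.
    import torch
    from transformer_engine.common.recipe import DelayedScaling, Format

    recipe = DelayedScaling(fp8_format=Format.HYBRID)
    assert get_fp8_torch_dtype(recipe, fprop_tensor=True) == torch.float8_e4m3fn
    assert get_fp8_torch_dtype(recipe, fprop_tensor=False) == torch.float8_e5m2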
@@ -56,7 +56,7 @@ if __name__ == "__main__":
         description="Transformer acceleration library - Torch Lib",
         ext_modules=ext_modules,
         cmdclass={"build_ext": CMakeBuildExtension},
-        install_requires=["torch", "flash-attn>=2.0.6,<=2.4.2,!=2.0.9,!=2.1.0"],
+        install_requires=["torch", "flash-attn>=2.0.6,<=2.6.3,!=2.0.9,!=2.1.0"],
         tests_require=["numpy", "onnxruntime", "torchvision"],
     )
     if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
@@ -23,9 +23,6 @@ _default_causal_mask = {}
 def _get_default_causal_mask(mask_type: str, sq: int, sk: int) -> torch.Tensor:
     """Return the causal upper triangular mask for softmax input"""
-    if sq == 1:
-        return torch.zeros((1, sk), dtype=torch.bool, device="cuda")
     matrix_identifiers = (mask_type, sq, sk)
     if matrix_identifiers not in _default_causal_mask:
         diagonal_offset = sk - sq + 1 if "bottom_right" in mask_type else 1
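The removed `sq == 1` shortcut returned an all-False row (nothing masked) regardless of mask type, which is only correct for bottom-right alignment; with it gone, single-query (decode-style) shapes take the general path, where `diagonal_offset = sk - sq + 1` shifts the causal diagonal for the `*_bottom_right` mask types. A small illustration of the offset, assuming the mask is built with `torch.triu` over an all-ones matrix and `True` marks positions to mask out:

    import torch

    sq, sk = 1, 6
    # Top-left "causal": the lone query sits at position 0, so only key 0 is visible.
    print(torch.triu(torch.ones(sq, sk, dtype=torch.bool), diagonal=1))
    # [[False,  True,  True,  True,  True,  True]]
    # "causal_bottom_right": the query aligns with the last key, so every key is visible.
    print(torch.triu(torch.ones(sq, sk, dtype=torch.bool), diagonal=sk - sq + 1))
    # [[False, False, False, False, False, False]]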