Unverified Commit 901e5d2b authored by Charlene Yang, committed by GitHub

Add support for flash-attn 3 (#1019)



* WIP: add fa3
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: clean up
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* WIP: add benchmarks
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* differentiate func/varlen_func
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix parsing keyword for FA3 and remove bshd->thd conversion for flash_attn_func
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: add FP8 fwd support
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add FA3 FP8 fwd code and test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix assert for FA3
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix FA3 FP8 logic and add tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update FA2 to <=2.6.3
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweak unit tests for base/mask
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* set constraints for FA3 for sm90 and causal_bottom_right
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revert debug changes in benchmark script
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 2215fa5c
Subproject commit 2533f5e5c1877fd76266133c1479ef1643ce3a8b
Subproject commit 98ca4e1941fe3263f128f74f10063a3ea35c7019
@@ -156,7 +156,7 @@ def parse_results(per_cudnn, per_flash, model):
df_times.loc[row, "FusedAttention Kernels (fwd+bwd)"] = t_cudnn_avg.sum() / 1e6
if per_flash > 0:
t_flash_all = df[df["Name"].str.contains("void flash")]["Duration (ns)"].to_numpy()
t_flash_all = df[df["Name"].str.contains("flash")]["Duration (ns)"].to_numpy()
t_flash_all = t_flash_all.reshape(-1, per_flash)
t_flash_avg = np.average(t_flash_all, axis=0)
df_times.loc[row, "FlashAttention Kernels (fwd)"] = t_flash_avg[0] / 1e6
@@ -92,7 +92,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
# Framework-specific requirements
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
if "pytorch" in frameworks:
install_reqs.extend(["torch", "flash-attn>=2.0.6,<=2.5.8,!=2.0.9,!=2.1.0"])
install_reqs.extend(["torch", "flash-attn>=2.0.6,<=2.6.3,!=2.0.9,!=2.1.0"])
test_reqs.extend(["numpy", "onnxruntime", "torchvision", "prettytable"])
if "jax" in frameworks:
install_reqs.extend(["jax", "flax>=0.7.1"])
@@ -420,6 +420,10 @@ model_configs_mask = {
"mask_8_1": ModelConfig(
1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"
),
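# The four decode-style entries below follow the same positional reading as the configs above
# (batch, heads, GQA groups, head dim, max_seqlen_q, max_seqlen_kv, dropout, mask, bias):
# a single query token attending to a 2048-token context, once with a top-left-aligned
# "causal" mask and once with its bottom-right-aligned counterpart.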
"mask_9_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal", "no_bias"),
"mask_9_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal", "no_bias"),
"mask_10_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal_bottom_right", "no_bias"),
"mask_10_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal_bottom_right", "no_bias"),
}
@@ -1301,6 +1305,7 @@ model_configs_fp8_vs_f16 = {
"fp8_12": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"fp8_13": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "no_mask", "no_bias"),
"fp8_14": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"),
"fp8_15": ModelConfig(1, 16, 16, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
}
param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16]
@@ -1312,6 +1317,27 @@ def _rmse(a, b):
return math.sqrt((torch.pow((a - b), 2) / a.numel()).sum())
def _error(a, b, name_a, name_b, atol, rtol, rmse_tol):
logging.debug(name_a + " min {:.6f} max {:.6f}".format(a.min().item(), a.max().item()))
logging.debug(name_b + " min {:.6f} max {:.6f}".format(b.min().item(), b.max().item()))
try:
torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
except Exception as e:
logging.debug(e)
rmse = _rmse(a, b)
logging.debug(name_a + " vs " + name_b + " RMSE: {:.6f}".format(rmse))
rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item())
assert rmse < rmse_tol * rmse_range, (
name_a
+ " vs "
+ name_b
+ " RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
rmse, rmse_tol * rmse_range, rmse_tol, rmse_range
)
)
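A minimal usage sketch of the _error helper above (illustrative tensors; assumes it is called from within this test module, where torch and logging are already imported):

import torch

a = torch.randn(4, 8, dtype=torch.float16)
b = a + 1e-3 * torch.randn_like(a)
# Passes if the tensors are point-wise close within atol/rtol, or, failing that,
# if their RMSE stays below rmse_tol * (joint max - joint min).
_error(a, b, "a", "b", atol=5e-1, rtol=5e-1, rmse_tol=0.1)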
@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@@ -1320,86 +1346,74 @@ def _rmse(a, b):
@pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16)
@pytest.mark.parametrize("input_layernorm", [True, False])
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd):
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
@pytest.mark.parametrize("is_training", [True, False])
def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, is_training):
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
global _attention_backends
_attention_backends["backend_selection_requires_update"] = True
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
config = model_configs_fp8_vs_f16[model]
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
global _attention_backends
if not is_training:
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
dtype, config, True, qkv_format, input_layernorm, is_training
)
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
dtype, config, True, qkv_format, input_layernorm
dtype, config, True, qkv_format, input_layernorm, is_training
)
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
dtype, config, False, qkv_format, input_layernorm
)
tols = dict(atol=5e-1, rtol=5e-1)
rmse_tol = 0.1
fwd_rmse = _rmse(fused_attn_fwd_fp8, fused_attn_fwd_f16)
fwd_range = max(fused_attn_fwd_fp8.max().item(), fused_attn_fwd_f16.max().item()) - min(
fused_attn_fwd_fp8.min().item(), fused_attn_fwd_f16.min().item()
dtype, config, False, qkv_format, input_layernorm, is_training
)
atol = 5e-1
rtol = 5e-1
rmse_tol = 0.15
logging.debug("========== {:^25s} ==========".format("forward output"))
logging.debug(
"fused_attn_fwd_fp8 min {:.6f} max {:.6f}".format(
fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.max().item()
)
)
logging.debug(
"fused_attn_fwd_f16 min {:.6f} max {:.6f}".format(
fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item()
if not is_training:
_error(
flash_attn_fwd_fp8,
fused_attn_fwd_f16,
"flash_attn_fwd_fp8",
"fused_attn_fwd_f16",
atol,
rtol,
rmse_tol,
)
_error(
fused_attn_fwd_fp8,
fused_attn_fwd_f16,
"fused_attn_fwd_fp8",
"fused_attn_fwd_f16",
atol,
rtol,
rmse_tol,
)
logging.debug("fused_attn_fwd RMSE: {:.6f}".format(fwd_rmse))
try:
torch.testing.assert_close(fused_attn_fwd_fp8, fused_attn_fwd_f16, **tols)
except Exception as e:
logging.debug(e)
assert (
fwd_rmse < rmse_tol * fwd_range
), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range
)
for i in range(len(param_names[:1])):
bwd_rmse = _rmse(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i])
bwd_range = max(
fused_attn_bwd_fp8[i].max().item(), fused_attn_bwd_f16[i].max().item()
) - min(fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_f16[i].min().item())
logging.debug("========== {:^25s} ==========".format(param_names[i]))
logging.debug(
"fused_attn_bwd_fp8[{}] min {:.6f} max {:.6f}".format(
i, fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item()
)
)
logging.debug(
"fused_attn_bwd_f16[{}] min {:.6f} max {:.6f}".format(
i, fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item()
if is_training:
for i in range(len(param_names[:1])):
logging.debug("========== {:^25s} ==========".format(param_names[i]))
_error(
fused_attn_bwd_fp8[i],
fused_attn_bwd_f16[i],
f"fused_attn_bwd_fp8[{i}]",
f"fused_attn_bwd_f16[{i}]",
atol,
rtol,
rmse_tol,
)
)
logging.debug("fused_attn_bwd RMSE[{}]: {:.6f}".format(i, bwd_rmse))
try:
torch.testing.assert_close(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], **tols)
except Exception as e:
logging.debug(e)
assert (
bwd_rmse < rmse_tol * bwd_range
), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range
)
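The updated tests steer backend selection through environment variables before each run. A condensed sketch of that pattern (the helper name is illustrative; assumes the module-level _attention_backends cache from transformer_engine.pytorch.attention):

import os
from transformer_engine.pytorch.attention import _attention_backends

def _force_backend(use_flash: bool) -> None:
    # Toggles read by get_attention_backend() on the next forward pass.
    os.environ["NVTE_FLASH_ATTN"] = "1" if use_flash else "0"
    os.environ["NVTE_FUSED_ATTN"] = "0" if use_flash else "1"
    # Invalidate the cached choice so it is re-evaluated with the new settings.
    _attention_backends["backend_selection_requires_update"] = True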
def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, is_training):
reset_rng_states()
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
@@ -1434,6 +1448,8 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
qkv_weight_interleaved=True,
qkv_format=qkv_format,
).to(dtype=dtype, device="cuda")
if not is_training:
mha = mha.eval()
seqlens_q = torch.full(
[config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
@@ -1464,7 +1480,8 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
tensor_shape = [dim_to_num[j] for j in layout.split("_")]
tensor = 0.01 * torch.randint(-100, 100, tensor_shape, dtype=dtype, device="cuda")
hidden_states = tensor.view(*tensor.shape[:-2], -1)
hidden_states.requires_grad = True
if is_training:
hidden_states.requires_grad = True
tensor = 0.01 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
out_grad = tensor.view(*tensor.shape[:-2], -1)
@@ -1476,7 +1493,8 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
core_attention_bias_type=config.attn_bias_type,
is_first_microbatch=None,
)
out.backward(out_grad)
if is_training:
out.backward(out_grad)
param_names = []
param_names.append("hidden_states.grad")
@@ -1487,7 +1505,9 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
param_names.append(name + ".grad")
params.append(param)
return out, param_names, tuple(x.grad for x in params)
if is_training:
return out, param_names, tuple(x.grad for x in params)
return out, param_names, tuple(None for x in params)
@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@@ -1497,7 +1517,8 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16)
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd):
@pytest.mark.parametrize("is_training", [True, False])
def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
config = model_configs_fp8_vs_f16[model]
if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
@@ -1505,76 +1526,69 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd):
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
global _attention_backends
_attention_backends["backend_selection_requires_update"] = True
if not is_training:
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True")
flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
dtype, config, True, qkv_layout, is_training
)
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True")
fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(dtype, config, True, qkv_layout)
fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
dtype, config, True, qkv_layout, is_training
)
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False")
fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16(dtype, config, False, qkv_layout)
fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16(
dtype, config, False, qkv_layout, is_training
)
tols = dict(atol=5e-1, rtol=5e-2)
atol = 5e-1
rtol = 5e-2
rmse_tol = 0.1
bwd_names = ["dq", "dk", "dv"]
fwd_rmse = _rmse(fused_attn_fwd_fp8, fused_attn_fwd_f16)
fwd_range = max(fused_attn_fwd_fp8.max().item(), fused_attn_fwd_f16.max().item()) - min(
fused_attn_fwd_fp8.min().item(), fused_attn_fwd_f16.min().item()
)
logging.debug("========== {:^25s} ==========".format("forward output"))
logging.debug(
"fused_attn_fwd_fp8 min {:.6f} max {:.6f}".format(
fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.max().item()
if not is_training:
_error(
flash_attn_fwd_fp8,
fused_attn_fwd_f16,
"flash_attn_fwd_fp8",
"fused_attn_fwd_f16",
atol,
rtol,
rmse_tol,
)
_error(
fused_attn_fwd_fp8,
fused_attn_fwd_f16,
"fused_attn_fwd_fp8",
"fused_attn_fwd_f16",
atol,
rtol,
rmse_tol,
)
logging.debug(
"fused_attn_fwd_f16 min {:.6f} max {:.6f}".format(
fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item()
)
)
logging.debug("fused_attn_fwd RMSE: {:.6f}".format(fwd_rmse))
try:
torch.testing.assert_close(fused_attn_fwd_fp8, fused_attn_fwd_f16, **tols)
except Exception as e:
logging.debug(e)
assert (
fwd_rmse < rmse_tol * fwd_range
), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range
)
for i, _ in enumerate(fused_attn_bwd_f16):
bwd_rmse = _rmse(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i])
bwd_range = max(
fused_attn_bwd_fp8[i].max().item(), fused_attn_bwd_f16[i].max().item()
) - min(fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_f16[i].min().item())
logging.debug("========== {:^25s} ==========".format(bwd_names[i]))
logging.debug(
"fused_attn_bwd_fp8[{}] min {:.6f} max {:.6f}".format(
i, fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item()
)
)
logging.debug(
"fused_attn_bwd_f16[{}] min {:.6f} max {:.6f}".format(
i, fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item()
if is_training:
for i, _ in enumerate(fused_attn_bwd_f16):
logging.debug("========== {:^25s} ==========".format(bwd_names[i]))
_error(
fused_attn_bwd_fp8[i],
fused_attn_bwd_f16[i],
f"fused_attn_bwd_fp8[{i}]",
f"fused_attn_bwd_f16[{i}]",
atol,
rtol,
rmse_tol,
)
)
logging.debug("fused_attn_bwd RMSE[{}]: {:.6f}".format(i, bwd_rmse))
try:
torch.testing.assert_close(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], **tols)
except Exception as e:
logging.debug(e)
assert (
bwd_rmse < rmse_tol * bwd_range
), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range
)
def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training):
reset_rng_states()
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
@@ -1607,6 +1621,8 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
attention_type="self",
qkv_format=qkv_format,
).to(dtype=dtype, device="cuda")
if not is_training:
dpa = dpa.eval()
seqlens_q = torch.full(
[config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
@@ -1680,9 +1696,12 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
core_attention_bias_type=config.attn_bias_type,
is_first_microbatch=True,
)
out.backward(out_grad)
if is_training:
out.backward(out_grad)
return out, (inp[0].grad, inp[1].grad, inp[2].grad)
if is_training:
return out, (inp[0].grad, inp[1].grad, inp[2].grad)
return out, (None, None, None)
model_configs_fp8 = {
@@ -1726,58 +1745,26 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention")
unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention")
tols = dict(atol=5e-1, rtol=5e-1)
atol = 5e-1
rtol = 5e-1
rmse_tol = 0.1
fwd_rmse = _rmse(fused_attn_fwd_fp8, unfused_attn_fwd_f16)
fwd_range = max(fused_attn_fwd_fp8.max().item(), unfused_attn_fwd_f16.max().item()) - min(
fused_attn_fwd_fp8.min().item(), unfused_attn_fwd_f16.min().item()
)
bwd_rmse = _rmse(fused_attn_bwd_fp8, unfused_attn_bwd_f16)
bwd_range = max(fused_attn_bwd_fp8.max().item(), unfused_attn_bwd_f16.max().item()) - min(
fused_attn_bwd_fp8.min().item(), unfused_attn_bwd_f16.min().item()
)
logging.debug(
"fused_attn_fwd_fp8 min {:.6f} max {:.6f}".format(
fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.max().item()
)
)
logging.debug(
"unfused_attn_fwd_f16 min {:.6f} max {:.6f}".format(
unfused_attn_fwd_f16.min().item(), unfused_attn_fwd_f16.max().item()
)
)
logging.debug("fused_attn_fwd_fp8 vs unfused_attn_fwd_f16 RMSE: {:.6f}".format(fwd_rmse))
try:
torch.testing.assert_close(fused_attn_fwd_fp8, unfused_attn_fwd_f16, **tols)
except Exception as e:
logging.debug(e)
logging.debug(
"fused_attn_bwd_fp8 min {:.6f} max {:.6f}".format(
fused_attn_bwd_fp8.min().item(), fused_attn_bwd_fp8.max().item()
)
)
logging.debug(
"unfused_attn_bwd_f16 min {:.6f} max {:.6f}".format(
unfused_attn_bwd_f16.min().item(), unfused_attn_bwd_f16.max().item()
)
)
logging.debug("fused_attn_bwd_fp8 vs unfused_attn_bwd_f16 RMSE: {:.6f}".format(bwd_rmse))
try:
torch.testing.assert_close(fused_attn_bwd_fp8, unfused_attn_bwd_f16, **tols)
except Exception as e:
logging.debug(e)
assert (
fwd_rmse < rmse_tol * fwd_range
), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range
_error(
fused_attn_fwd_fp8,
unfused_attn_fwd_f16,
"fused_attn_fwd_fp8",
"unfused_attn_fwd_f16",
atol,
rtol,
rmse_tol,
)
assert (
bwd_rmse < rmse_tol * bwd_range
), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range
_error(
fused_attn_bwd_fp8,
unfused_attn_bwd_f16,
"fused_attn_bwd_fp8",
"unfused_attn_bwd_f16",
atol,
rtol,
rmse_tol,
)
@@ -6,6 +6,7 @@
import collections
from contextlib import nullcontext
from importlib.metadata import version as get_pkg_version
from importlib.metadata import PackageNotFoundError
import math
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -38,7 +39,7 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
AttnMaskType,
FusedAttnBackend,
)
from transformer_engine.pytorch.fp8 import get_fp8_te_dtype
from transformer_engine.pytorch.fp8 import get_fp8_te_dtype, get_fp8_torch_dtype
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.module import LayerNormLinear, Linear
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
@@ -75,16 +76,42 @@ from transformer_engine.pytorch.graph import is_graph_capturing
_flash_attn_version = PkgVersion(get_pkg_version("flash-attn"))
_flash_attn_version_required = PkgVersion("2.0.6")
_flash_attn_max_version = PkgVersion("2.5.8")
_flash_attn_max_version = PkgVersion("2.6.3")
_flash_attn_2_plus = _flash_attn_version >= PkgVersion("2")
_flash_attn_2_1_plus = _flash_attn_version >= PkgVersion("2.1")
_flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3")
_flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4")
_flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1")
_flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7")
_flash_attn_3_plus = False
_use_flash_attn_3 = False
try:
_flash_attn_v3_version = PkgVersion(get_pkg_version("flashattn-hopper"))
_flash_attn_3_plus = _flash_attn_v3_version >= PkgVersion("2.6.1")
except PackageNotFoundError:
warnings.warn(
"To use flash-attn v3, please use the following commands to install: \n"
"""(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper" \n"""
"""(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"` \n"""
"""(3) mkdir -p $python_path/flashattn_hopper \n"""
"""(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py"""
)
else:
from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3
from flashattn_hopper.flash_attn_interface import (
flash_attn_varlen_func as flash_attn_varlen_func_v3,
)
from flashattn_hopper.flash_attn_interface import ( # pylint: disable=unused-import
_flash_attn_forward as _flash_attn_forward_v3,
)
from flashattn_hopper.flash_attn_interface import ( # pylint: disable=unused-import
_flash_attn_backward as _flash_attn_backward_v3,
)
_use_flash_attn_3 = True
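A quick probe of which flash-attention packages are visible at import time (a sketch using the same package names checked above):

from importlib.metadata import PackageNotFoundError, version as get_pkg_version

for pkg in ("flash-attn", "flashattn-hopper"):
    try:
        print(pkg, get_pkg_version(pkg))
    except PackageNotFoundError:
        print(pkg, "not installed")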
if _flash_attn_version >= _flash_attn_version_required:
from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward
from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward
from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd
@@ -318,6 +345,7 @@ def get_attention_backend(
use_fused_attention = False
# Filter: Compute capability
global _flash_attn_3_plus, _use_flash_attn_3
if device_compute_capability < (8, 0):
if use_flash_attention:
logger.debug("Disabling FlashAttention as it requires compute capability sm80+")
@@ -325,32 +353,37 @@ def get_attention_backend(
if use_fused_attention:
logger.debug("Disabling FusedAttention as it requires compute capability sm80+")
use_fused_attention = False
if device_compute_capability < (9, 0):
if use_flash_attention and _flash_attn_3_plus:
logger.debug("Disabling FlashAttention 3 as it requires compute capability sm90+")
_use_flash_attn_3 = False
# Filter: Data type
if use_flash_attention and (
qkv_dtype not in [torch.bfloat16, torch.float16] or qkv_type == Float8Tensor
):
logger.debug(
"Disabling FlashAttention due to unsupported QKV data type. "
"Supported: qkv_type = torch.Tensor, qkv_dtype = {torch.bfloat16, torch.float16}. "
"Found: qkv_type = %s, qkv_dtype = %s.",
qkv_type,
qkv_dtype,
)
use_flash_attention = False
if use_fused_attention and (qkv_dtype not in [torch.bfloat16, torch.float16]):
logger.debug(
"Disabling FusedAttention due to unsupported QKV data type. "
"Supported: qkv_dtype = {torch.bfloat16, torch.float16}. "
"Found: qkv_dtype = %s.",
qkv_dtype,
)
use_fused_attention = False
if qkv_dtype not in [torch.bfloat16, torch.float16] or qkv_type not in [
torch.Tensor,
Float8Tensor,
]:
if use_flash_attention:
logger.debug(
"Disabling FlashAttention due to unsupported QKV data type. "
"Supported: qkv_dtype = {torch.bfloat16, torch.float16}. "
"Found: qkv_dtype = %s.",
qkv_dtype,
)
use_flash_attention = False
if use_fused_attention:
logger.debug(
"Disabling FusedAttention due to unsupported QKV data type. "
"Supported: qkv_dtype = {torch.bfloat16, torch.float16}. "
"Found: qkv_dtype = %s.",
qkv_dtype,
)
use_fused_attention = False
# Filter: Execution type
if fp8 and fp8_meta["recipe"].fp8_dpa:
if use_flash_attention:
logger.debug("Disabling FlashAttention as it does not support FP8")
if use_flash_attention and is_training:
logger.debug("Disabling FlashAttention as it does not support FP8 training")
use_flash_attention = False
if use_unfused_attention:
logger.debug("Disabling UnfusedDotProductAttention as it does not support FP8")
@@ -396,6 +429,12 @@ def get_attention_backend(
)
use_flash_attention = False
# Filter: Dropout
if attention_dropout != 0.0 and use_flash_attention:
if _flash_attn_3_plus and _use_flash_attn_3:
logger.debug("Disabling FlashAttention 3 for dropout")
_use_flash_attn_3 = False
# Filter: Context parallelism
# qkv_format | attn_mask_type | attn_bias_type | supported backends
# ----------------------------------------------------------------------------------------------------
@@ -414,6 +453,14 @@
)
use_unfused_attention = False
if context_parallel and use_flash_attention:
if _flash_attn_3_plus and _use_flash_attn_3:
logger.debug("Disabling FlashAttention 3 for context parallelism")
_use_flash_attn_3 = False
if fp8 and fp8_meta["recipe"].fp8_dpa:
logger.debug(
"Disabling FlashAttention as it does not support context parallelism with FP8"
)
use_flash_attention = False
if "bottom_right" in attn_mask_type:
logger.debug(
"Disabling FlashAttention as it does not support context parallelism with"
@@ -439,6 +486,7 @@
" bias for THD format"
)
use_flash_attention = False
if context_parallel and use_fused_attention:
if "bottom_right" in attn_mask_type:
logger.debug(
@@ -498,6 +546,18 @@
if use_fused_attention:
logger.debug("Disabling FusedAttention for arbitrary mask")
use_fused_attention = False
if (
use_flash_attention
and _flash_attn_3_plus
and attn_mask_type in ["causal", "padding_causal"]
and max_seqlen_q != max_seqlen_kv
):
logger.warning(
"Disabling FlashAttention 3 as it only supports bottom-right-diagonal "
"causal mask since flash-attn 2.1. See "
"https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
)
_use_flash_attn_3 = False
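For reference, a small sketch of the two diagonal alignments this check distinguishes when max_seqlen_q != max_seqlen_kv (True = attended):

import torch

sq, skv = 2, 4
# Top-left alignment, i.e. what TE's "causal"/"padding_causal" means with unequal lengths:
top_left = torch.ones(sq, skv, dtype=torch.bool).tril(diagonal=0)
# [[1, 0, 0, 0],
#  [1, 1, 0, 0]]
# Bottom-right alignment, the behavior of flash-attn >= 2.1 and FA3 noted above:
bottom_right = torch.ones(sq, skv, dtype=torch.bool).tril(diagonal=skv - sq)
# [[1, 1, 1, 0],
#  [1, 1, 1, 1]]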
if (
use_flash_attention
and _flash_attn_2_1_plus
@@ -571,6 +631,15 @@
attn_mask_type,
)
use_fused_attention = False
if (
use_flash_attention
and (window_size[0] != -1 or window_size[1] not in [-1, 0])
and _flash_attn_3_plus
):
logger.debug(
"Disabling FlashAttention 3 as it does not support sliding window attention"
)
_use_flash_attn_3 = False
if (
use_flash_attention
and (window_size[0] != -1 or window_size[1] not in [-1, 0])
@@ -590,6 +659,14 @@
# | | bottom_right (converts to a 'post_scale_bias' bias)
# UnfusedDotProductAttention | no_bias, pre/post_scale_bias |
# | alibi/alibi_slopes | both; converts to a 'post_scale_bias' bias
if use_flash_attention and core_attention_bias_type == "alibi":
if _flash_attn_3_plus and _use_flash_attn_3:
logger.debug("Disabling FlashAttention 3 for ALiBi")
_use_flash_attn_3 = False
if not _flash_attn_2_4_plus:
logger.debug("Disabling FlashAttention for ALiBi")
use_flash_attention = False
if use_flash_attention and (
core_attention_bias_type not in ["no_bias", "alibi"]
or core_attention_bias_shape is not None
@@ -1071,7 +1148,7 @@ def _get_full_cu_seqlens(
return _cu_seqlens_cache[(batch_size, max_seqlen)]
@jit_fuser
@torch.compile
def pack_tensor(
indices: torch.Tensor,
tensor: torch.Tensor,
@@ -1082,14 +1159,19 @@ def pack_tensor(
padding_indice = torch.zeros(
1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device
)
tensor = torch.cat((tensor, padding_indice), dim=0)
indices = indices.repeat(1, tensor.shape[1], tensor.shape[2])
packed = torch.gather(tensor, 0, indices)
if isinstance(tensor, Float8Tensor):
tensor_data = torch.cat((tensor._data, padding_indice), dim=0)
packed = Float8Tensor.make_like(tensor, data=torch.gather(tensor_data, 0, indices))
else:
tensor = torch.cat((tensor, padding_indice), dim=0)
packed = torch.gather(tensor, 0, indices)
return packed
@jit_fuser
@torch.compile
def pack_2_tensors(
indices: torch.Tensor,
t1: torch.Tensor,
@@ -1103,7 +1185,7 @@ def pack_2_tensors(
return t1_packed, t2_packed
@jit_fuser
@torch.compile
def pack_3_tensors(
indices: torch.Tensor,
t1: torch.Tensor,
@@ -1119,7 +1201,7 @@ def pack_3_tensors(
return t1_packed, t2_packed, t3_packed
@jit_fuser
@torch.compile
def unpack_tensor(
indices: torch.Tensor,
dim0: int,
@@ -1132,12 +1214,16 @@ def unpack_tensor(
unpacked = torch.zeros(
dim0 + 1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device
)
unpacked.scatter_(0, indices, tensor)
unpacked = unpacked[0:-1, :, :]
if isinstance(tensor, Float8Tensor):
unpacked.scatter_(0, indices, tensor._data)
unpacked = Float8Tensor.make_like(tensor, data=unpacked[0:-1, :, :])
else:
unpacked.scatter_(0, indices, tensor)
unpacked = unpacked[0:-1, :, :]
return unpacked
@jit_fuser
@torch.compile
def unpack_2_tensors(
indices: torch.Tensor,
dim0: int,
@@ -1152,7 +1238,7 @@ def unpack_2_tensors(
return t1_unpacked, t2_unpacked
@jit_fuser
@torch.compile
def unpack_3_tensors(
indices: torch.Tensor,
dim0: int,
@@ -4212,14 +4298,15 @@ class FlashAttention(torch.nn.Module):
cp_global_ranks: List[int] = None,
cp_stream: torch.cuda.Stream = None,
cp_comm_type: str = "p2p",
fp8: bool = False,
fp8_meta: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
"""flash-attn fprop"""
assert (
query_layer.dtype in [torch.float16, torch.bfloat16]
and key_layer.dtype in [torch.float16, torch.bfloat16]
and value_layer.dtype in [torch.float16, torch.bfloat16]
), "FlashAttention currently only supports FP16 and BF16."
assert all(
x.dtype in [torch.float16, torch.bfloat16] or isinstance(x, Float8Tensor)
for x in [query_layer, key_layer, value_layer]
), "FlashAttention only supports FP16 and BF16 data types, or Float8Tensors."
assert (
query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
), "FlashAttention currently only supports CUDA tensors."
@@ -4232,24 +4319,36 @@ class FlashAttention(torch.nn.Module):
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
if qkv_format == "sbhd":
# For now just 128, will make it more general in the future
if (
query_layer.shape[-1] == 128
and query_layer.shape[0] * query_layer.shape[1] >= 512
and qkv_layout == "sbh3d"
):
query_layer, key_layer, value_layer = _PrepareQKVForFA.apply(
query_layer, key_layer, value_layer
)
else:
if all(not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]):
if qkv_format == "sbhd":
# For now just 128, will make it more general in the future
if (
query_layer.shape[-1] == 128
and query_layer.shape[0] * query_layer.shape[1] >= 512
and qkv_layout == "sbh3d"
):
query_layer, key_layer, value_layer = _PrepareQKVForFA.apply(
query_layer, key_layer, value_layer
)
else:
query_layer, key_layer, value_layer = [
x.transpose(0, 1).contiguous()
for x in (query_layer, key_layer, value_layer)
]
elif qkv_format in ["bshd", "thd"]:
query_layer, key_layer, value_layer = [
x.transpose(0, 1).contiguous() for x in (query_layer, key_layer, value_layer)
x.contiguous() for x in (query_layer, key_layer, value_layer)
]
else:
if qkv_format == "sbhd":
query_layer._data, key_layer._data, value_layer._data = [
x.transpose(0, 1).contiguous()
for x in (query_layer._data, key_layer._data, value_layer._data)
]
elif qkv_format in ["bshd", "thd"]:
query_layer._data, key_layer._data, value_layer._data = [
x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data)
]
elif qkv_format in ["bshd", "thd"]:
query_layer, key_layer, value_layer = [
x.contiguous() for x in (query_layer, key_layer, value_layer)
]
batch_size = query_layer.shape[0]
@@ -4257,16 +4356,15 @@ class FlashAttention(torch.nn.Module):
max_seqlen_q, max_seqlen_kv = query_layer.shape[1], key_layer.shape[1]
max_seqlen_q *= cp_size
max_seqlen_kv *= cp_size
if not context_parallel:
if "padding" in attn_mask_type:
assert not context_parallel, "Padding mask not supported with context parallelism!"
# [b * s, h, d]
query_layer, key_layer, value_layer = [
x.view(x.shape[0] * x.shape[1], *x.shape[2:])
x.reshape(x.shape[0] * x.shape[1], *x.shape[2:])
for x in [query_layer, key_layer, value_layer]
]
if "padding" in attn_mask_type:
assert not context_parallel, "Padding mask not supported with context parallelism!"
if self.attention_type == "self":
assert (
max_seqlen_q == max_seqlen_kv
@@ -4319,7 +4417,9 @@ class FlashAttention(torch.nn.Module):
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
max_seqlen_kv = seqlens_kv.max().item()
if context_parallel:
if context_parallel and all(
not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]
):
assert (
alibi_slopes is None
), "Alibi slope bias addition is not supported with context parallelism."
@@ -4366,34 +4466,94 @@ class FlashAttention(torch.nn.Module):
fa_optional_forward_kwargs["deterministic"] = self.deterministic
if _flash_attn_2_5_7_plus:
fa_optional_forward_kwargs["block_table"] = None
output = flash_attn_forward_func(
query_layer,
key_layer,
value_layer,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
self.attention_dropout if self.training else 0.0,
softmax_scale=self.softmax_scale,
causal="causal" in attn_mask_type,
**fa_optional_forward_kwargs,
)
fa_optional_forward_args_thd = []
if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type:
func = flash_attn_func if not _use_flash_attn_3 else flash_attn_func_v3
else:
func = (
flash_attn_varlen_func
if not _use_flash_attn_3
else flash_attn_varlen_func_v3
)
fa_optional_forward_args_thd.append(cu_seqlens_q)
fa_optional_forward_args_thd.append(cu_seqlens_kv)
fa_optional_forward_args_thd.append(max_seqlen_q)
fa_optional_forward_args_thd.append(max_seqlen_kv)
if _use_flash_attn_3:
if fp8:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
activation_dtype = query_layer.dtype
torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True)
if fp8_meta["recipe"].fp8_mha:
assert all(
isinstance(x, Float8Tensor)
for x in [query_layer, key_layer, value_layer]
), "q/k/v must be Float8Tensors for FP8 MHA."
fp8_meta["scaling_fwd"].scale_inv[META_QKV] = query_layer._scale_inv
query_layer, key_layer, value_layer = (
x.to(activation_dtype).to(torch_dtype)
for x in [query_layer, key_layer, value_layer]
)
else:
query_layer, key_layer, value_layer = (
x.to(torch_dtype) for x in [query_layer, key_layer, value_layer]
)
output, _ = func(
query_layer,
key_layer,
value_layer,
*fa_optional_forward_args_thd,
softmax_scale=self.softmax_scale,
causal="causal" in attn_mask_type,
deterministic=self.deterministic,
)
if fp8 and fp8_meta["recipe"].fp8_mha:
output = cast_to_fp8(
output,
fp8_meta["scaling_fwd"],
META_O,
fp8_dtype_forward,
)
output = Float8Tensor(
data=output,
fp8_meta=fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=META_O,
fp8_dtype=fp8_dtype_forward,
dtype=activation_dtype,
)
else:
output = func(
query_layer,
key_layer,
value_layer,
*fa_optional_forward_args_thd,
self.attention_dropout if self.training else 0.0,
softmax_scale=self.softmax_scale,
causal="causal" in attn_mask_type,
**fa_optional_forward_kwargs,
)
if qkv_format in ["sbhd", "bshd"] and "padding" in attn_mask_type:
output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output)
if qkv_format == "sbhd":
# (bs)hd -> bs(hd) -> sb(hd)
output = (
output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1).contiguous()
)
if fp8 and fp8_meta["recipe"].fp8_mha:
output.reshape(batch_size * max_seqlen_q // cp_size, -1).transpose_2d()
output = output.reshape(batch_size, max_seqlen_q // cp_size, -1)
else:
output = (
output.view(batch_size, max_seqlen_q // cp_size, -1)
.transpose(0, 1)
.contiguous()
)
elif qkv_format == "bshd":
# (bs)hd -> bs(hd)
output = output.view(batch_size, max_seqlen_q // cp_size, -1).contiguous()
output = output.reshape(batch_size, max_seqlen_q // cp_size, -1)
elif qkv_format == "thd":
# thd -> t(hd)
output = output.view(output.shape[0], -1).contiguous()
output = output.reshape(output.shape[0], -1)
return output
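A rough end-to-end sketch of driving this FP8 inference path through DotProductAttention (assumptions: an sm90 GPU, the flashattn-hopper package installed, and fp8_dpa available on DelayedScaling as referenced in this diff):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

fp8_recipe = DelayedScaling(fp8_dpa=True)
dpa = te.DotProductAttention(
    num_attention_heads=16, kv_channels=128, attn_mask_type="no_mask"
).eval()
# Default qkv_format is "sbhd": (seq, batch, heads, head_dim).
q, k, v = (
    torch.randn(2048, 2, 16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(3)
)
with torch.no_grad(), te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = dpa(q, k, v)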
@@ -5897,11 +6057,10 @@ class FusedAttention(torch.nn.Module):
assert (
fused_attention_backend != tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend
), "No fused attention backend supports this input combination!"
assert (
(query_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8])
and (key_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8])
and (value_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8])
), "FusedAttention only supports FP16 and BF16 data types."
assert all(
x.dtype in [torch.float16, torch.bfloat16] or isinstance(x, Float8Tensor)
for x in [query_layer, key_layer, value_layer]
), "FusedAttention only supports FP16 and BF16 data types, or Float8Tensors."
assert (
query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
), "FusedAttention only supports CUDA tensors."
@@ -6812,7 +6971,7 @@ class DotProductAttention(TransformerEngineBaseModule):
fp8=self.fp8,
fp8_meta=self.fp8_meta,
)
global _attention_backends
global _attention_backends, _flash_attn_3_plus, _use_flash_attn_3
if (
_attention_backends["attention_params"] is None
or attention_params != _attention_backends["attention_params"]
@@ -6820,6 +6979,7 @@ class DotProductAttention(TransformerEngineBaseModule):
_attention_backends["attention_params"] = attention_params
_attention_backends["backend_selection_requires_update"] = True
if _attention_backends["backend_selection_requires_update"]:
_use_flash_attn_3 = _flash_attn_3_plus
(
use_flash_attention,
use_fused_attention,
@@ -6828,7 +6988,10 @@ class DotProductAttention(TransformerEngineBaseModule):
_,
) = get_attention_backend(attention_params)
if use_flash_attention:
self.logger.info("Running with FlashAttention backend")
self.logger.info(
"Running with FlashAttention backend (version %s)",
_flash_attn_version if not _use_flash_attn_3 else _flash_attn_v3_version,
)
elif use_fused_attention:
self.logger.info(
"Running with FusedAttention backend (sub-backend %s)",
@@ -6867,6 +7030,8 @@ class DotProductAttention(TransformerEngineBaseModule):
cp_comm_type=self.cp_comm_type,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
fp8_meta=self.fp8_meta,
)
if use_fused_attention:
@@ -38,6 +38,15 @@ def get_default_fp8_recipe() -> DelayedScaling:
return DelayedScaling()
def get_fp8_torch_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> torch.dtype:
"""Get fp8 data type according to recipe and tensor"""
if fp8_recipe.fp8_format == Format.E4M3 or (
fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor
):
return torch.float8_e4m3fn
return torch.float8_e5m2
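A small usage sketch of the helper above (assuming it is imported from transformer_engine.pytorch.fp8 alongside the recipe classes it inspects):

from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.pytorch.fp8 import get_fp8_torch_dtype

hybrid = DelayedScaling(fp8_format=Format.HYBRID)
get_fp8_torch_dtype(hybrid, fprop_tensor=True)   # torch.float8_e4m3fn for forward tensors
get_fp8_torch_dtype(hybrid, fprop_tensor=False)  # torch.float8_e5m2 for gradient tensors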
def get_fp8_te_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> tex.DType:
"""Get fp8 data type according to recipe and tensor"""
if fp8_recipe.fp8_format == Format.E4M3 or (
@@ -56,7 +56,7 @@ if __name__ == "__main__":
description="Transformer acceleration library - Torch Lib",
ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension},
install_requires=["torch", "flash-attn>=2.0.6,<=2.4.2,!=2.0.9,!=2.1.0"],
install_requires=["torch", "flash-attn>=2.0.6,<=2.6.3,!=2.0.9,!=2.1.0"],
tests_require=["numpy", "onnxruntime", "torchvision"],
)
if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
@@ -23,9 +23,6 @@ _default_causal_mask = {}
def _get_default_causal_mask(mask_type: str, sq: int, sk: int) -> torch.Tensor:
"""Return the causal upper triangular mask for softmax input"""
if sq == 1:
return torch.zeros((1, sk), dtype=torch.bool, device="cuda")
matrix_identifiers = (mask_type, sq, sk)
if matrix_identifiers not in _default_causal_mask:
diagonal_offset = sk - sq + 1 if "bottom_right" in mask_type else 1