"git@developer.sourcefind.cn:kecinstone/2024-pra-vllm.git" did not exist on "1a7eb7da6157541ed7867c9aff94231695f2cee9"
Unverified Commit 901e5d2b authored by Charlene Yang's avatar Charlene Yang Committed by GitHub
Browse files

Add support for flash-attn 3 (#1019)



* WIP: add fa3
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: clean up
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* WIP: add benchmarks
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* differentiate func/varlen_func
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix parsing keyword for FA3 and remove bshd->thd conversion for flash_attn_func
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: add FP8 fwd support
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add FA3 FP8 fwd code and test
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix assert for FA3
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix FA3 FP8 logic and add tests
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update FA2 to <=2.6.3
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweak unit tests for base/mask
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix lint
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix lint
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* set constraints for FA3 for sm90 and causal_bottom_right
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revert debug changes in benchmark script
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 2215fa5c
Subproject commit 2533f5e5c1877fd76266133c1479ef1643ce3a8b
Subproject commit 98ca4e1941fe3263f128f74f10063a3ea35c7019
......@@ -156,7 +156,7 @@ def parse_results(per_cudnn, per_flash, model):
df_times.loc[row, "FusedAttention Kernels (fwd+bwd)"] = t_cudnn_avg.sum() / 1e6
if per_flash > 0:
t_flash_all = df[df["Name"].str.contains("void flash")]["Duration (ns)"].to_numpy()
t_flash_all = df[df["Name"].str.contains("flash")]["Duration (ns)"].to_numpy()
t_flash_all = t_flash_all.reshape(-1, per_flash)
t_flash_avg = np.average(t_flash_all, axis=0)
df_times.loc[row, "FlashAttention Kernels (fwd)"] = t_flash_avg[0] / 1e6
......
......@@ -92,7 +92,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
# Framework-specific requirements
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
if "pytorch" in frameworks:
install_reqs.extend(["torch", "flash-attn>=2.0.6,<=2.5.8,!=2.0.9,!=2.1.0"])
install_reqs.extend(["torch", "flash-attn>=2.0.6,<=2.6.3,!=2.0.9,!=2.1.0"])
test_reqs.extend(["numpy", "onnxruntime", "torchvision", "prettytable"])
if "jax" in frameworks:
install_reqs.extend(["jax", "flax>=0.7.1"])
......
......@@ -420,6 +420,10 @@ model_configs_mask = {
"mask_8_1": ModelConfig(
1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"
),
"mask_9_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal", "no_bias"),
"mask_9_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal", "no_bias"),
"mask_10_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal_bottom_right", "no_bias"),
"mask_10_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal_bottom_right", "no_bias"),
}
......@@ -1301,6 +1305,7 @@ model_configs_fp8_vs_f16 = {
"fp8_12": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"fp8_13": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "no_mask", "no_bias"),
"fp8_14": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"),
"fp8_15": ModelConfig(1, 16, 16, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
}
param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16]
......@@ -1312,6 +1317,27 @@ def _rmse(a, b):
return math.sqrt((torch.pow((a - b), 2) / a.numel()).sum())
def _error(a, b, name_a, name_b, atol, rtol, rmse_tol):
logging.debug(name_a + " min {:.6f} max {:.6f}".format(a.min().item(), a.max().item()))
logging.debug(name_b + " min {:.6f} max {:.6f}".format(b.min().item(), b.max().item()))
try:
torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
except Exception as e:
logging.debug(e)
rmse = _rmse(a, b)
logging.debug(name_a + " vs " + name_b + " RMSE: {:.6f}".format(rmse))
rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item())
assert rmse < rmse_tol * rmse_range, (
name_a
+ " vs "
+ name_b
+ " RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
rmse, rmse_tol * rmse_range, rmse_tol, rmse_range
)
)
@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
......@@ -1320,86 +1346,74 @@ def _rmse(a, b):
@pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16)
@pytest.mark.parametrize("input_layernorm", [True, False])
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd):
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
@pytest.mark.parametrize("is_training", [True, False])
def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, is_training):
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
global _attention_backends
_attention_backends["backend_selection_requires_update"] = True
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
config = model_configs_fp8_vs_f16[model]
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
global _attention_backends
if not is_training:
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
dtype, config, True, qkv_format, input_layernorm, is_training
)
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
dtype, config, True, qkv_format, input_layernorm
dtype, config, True, qkv_format, input_layernorm, is_training
)
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
dtype, config, False, qkv_format, input_layernorm
)
tols = dict(atol=5e-1, rtol=5e-1)
rmse_tol = 0.1
fwd_rmse = _rmse(fused_attn_fwd_fp8, fused_attn_fwd_f16)
fwd_range = max(fused_attn_fwd_fp8.max().item(), fused_attn_fwd_f16.max().item()) - min(
fused_attn_fwd_fp8.min().item(), fused_attn_fwd_f16.min().item()
dtype, config, False, qkv_format, input_layernorm, is_training
)
atol = 5e-1
rtol = 5e-1
rmse_tol = 0.15
logging.debug("========== {:^25s} ==========".format("forward output"))
logging.debug(
"fused_attn_fwd_fp8 min {:.6f} max {:.6f}".format(
fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.max().item()
)
)
logging.debug(
"fused_attn_fwd_f16 min {:.6f} max {:.6f}".format(
fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item()
if not is_training:
_error(
flash_attn_fwd_fp8,
fused_attn_fwd_f16,
"flash_attn_fwd_fp8",
"fused_attn_fwd_f16",
atol,
rtol,
rmse_tol,
)
_error(
fused_attn_fwd_fp8,
fused_attn_fwd_f16,
"fused_attn_fwd_fp8",
"fused_attn_fwd_f16",
atol,
rtol,
rmse_tol,
)
logging.debug("fused_attn_fwd RMSE: {:.6f}".format(fwd_rmse))
try:
torch.testing.assert_close(fused_attn_fwd_fp8, fused_attn_fwd_f16, **tols)
except Exception as e:
logging.debug(e)
assert (
fwd_rmse < rmse_tol * fwd_range
), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range
)
for i in range(len(param_names[:1])):
bwd_rmse = _rmse(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i])
bwd_range = max(
fused_attn_bwd_fp8[i].max().item(), fused_attn_bwd_f16[i].max().item()
) - min(fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_f16[i].min().item())
logging.debug("========== {:^25s} ==========".format(param_names[i]))
logging.debug(
"fused_attn_bwd_fp8[{}] min {:.6f} max {:.6f}".format(
i, fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item()
)
)
logging.debug(
"fused_attn_bwd_f16[{}] min {:.6f} max {:.6f}".format(
i, fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item()
if is_training:
for i in range(len(param_names[:1])):
logging.debug("========== {:^25s} ==========".format(param_names[i]))
_error(
fused_attn_bwd_fp8[i],
fused_attn_bwd_f16[i],
f"fused_attn_bwd_fp8[{i}]",
f"fused_attn_bwd_f16[{i}]",
atol,
rtol,
rmse_tol,
)
)
logging.debug("fused_attn_bwd RMSE[{}]: {:.6f}".format(i, bwd_rmse))
try:
torch.testing.assert_close(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], **tols)
except Exception as e:
logging.debug(e)
assert (
bwd_rmse < rmse_tol * bwd_range
), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range
)
def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, is_training):
reset_rng_states()
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
......@@ -1434,6 +1448,8 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
qkv_weight_interleaved=True,
qkv_format=qkv_format,
).to(dtype=dtype, device="cuda")
if not is_training:
mha = mha.eval()
seqlens_q = torch.full(
[config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
......@@ -1464,7 +1480,8 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
tensor_shape = [dim_to_num[j] for j in layout.split("_")]
tensor = 0.01 * torch.randint(-100, 100, tensor_shape, dtype=dtype, device="cuda")
hidden_states = tensor.view(*tensor.shape[:-2], -1)
hidden_states.requires_grad = True
if is_training:
hidden_states.requires_grad = True
tensor = 0.01 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
out_grad = tensor.view(*tensor.shape[:-2], -1)
......@@ -1476,7 +1493,8 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
core_attention_bias_type=config.attn_bias_type,
is_first_microbatch=None,
)
out.backward(out_grad)
if is_training:
out.backward(out_grad)
param_names = []
param_names.append("hidden_states.grad")
......@@ -1487,7 +1505,9 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
param_names.append(name + ".grad")
params.append(param)
return out, param_names, tuple(x.grad for x in params)
if is_training:
return out, param_names, tuple(x.grad for x in params)
return out, param_names, tuple(None for x in params)
@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
......@@ -1497,7 +1517,8 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16)
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd):
@pytest.mark.parametrize("is_training", [True, False])
def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
config = model_configs_fp8_vs_f16[model]
if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
......@@ -1505,76 +1526,69 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd):
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
global _attention_backends
_attention_backends["backend_selection_requires_update"] = True
if not is_training:
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True")
flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
dtype, config, True, qkv_layout, is_training
)
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True")
fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(dtype, config, True, qkv_layout)
fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
dtype, config, True, qkv_layout, is_training
)
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False")
fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16(dtype, config, False, qkv_layout)
fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16(
dtype, config, False, qkv_layout, is_training
)
tols = dict(atol=5e-1, rtol=5e-2)
atol = 5e-1
rtol = 5e-2
rmse_tol = 0.1
bwd_names = ["dq", "dk", "dv"]
fwd_rmse = _rmse(fused_attn_fwd_fp8, fused_attn_fwd_f16)
fwd_range = max(fused_attn_fwd_fp8.max().item(), fused_attn_fwd_f16.max().item()) - min(
fused_attn_fwd_fp8.min().item(), fused_attn_fwd_f16.min().item()
)
logging.debug("========== {:^25s} ==========".format("forward output"))
logging.debug(
"fused_attn_fwd_fp8 min {:.6f} max {:.6f}".format(
fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.max().item()
if not is_training:
_error(
flash_attn_fwd_fp8,
fused_attn_fwd_f16,
"flash_attn_fwd_fp8",
"fused_attn_fwd_f16",
atol,
rtol,
rmse_tol,
)
_error(
fused_attn_fwd_fp8,
fused_attn_fwd_f16,
"fused_attn_fwd_fp8",
"fused_attn_fwd_f16",
atol,
rtol,
rmse_tol,
)
logging.debug(
"fused_attn_fwd_f16 min {:.6f} max {:.6f}".format(
fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item()
)
)
logging.debug("fused_attn_fwd RMSE: {:.6f}".format(fwd_rmse))
try:
torch.testing.assert_close(fused_attn_fwd_fp8, fused_attn_fwd_f16, **tols)
except Exception as e:
logging.debug(e)
assert (
fwd_rmse < rmse_tol * fwd_range
), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range
)
for i, _ in enumerate(fused_attn_bwd_f16):
bwd_rmse = _rmse(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i])
bwd_range = max(
fused_attn_bwd_fp8[i].max().item(), fused_attn_bwd_f16[i].max().item()
) - min(fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_f16[i].min().item())
logging.debug("========== {:^25s} ==========".format(bwd_names[i]))
logging.debug(
"fused_attn_bwd_fp8[{}] min {:.6f} max {:.6f}".format(
i, fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item()
)
)
logging.debug(
"fused_attn_bwd_f16[{}] min {:.6f} max {:.6f}".format(
i, fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item()
if is_training:
for i, _ in enumerate(fused_attn_bwd_f16):
logging.debug("========== {:^25s} ==========".format(bwd_names[i]))
_error(
fused_attn_bwd_fp8[i],
fused_attn_bwd_f16[i],
f"fused_attn_bwd_fp8[{i}]",
f"fused_attn_bwd_f16[{i}]",
atol,
rtol,
rmse_tol,
)
)
logging.debug("fused_attn_bwd RMSE[{}]: {:.6f}".format(i, bwd_rmse))
try:
torch.testing.assert_close(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], **tols)
except Exception as e:
logging.debug(e)
assert (
bwd_rmse < rmse_tol * bwd_range
), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range
)
def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training):
reset_rng_states()
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
......@@ -1607,6 +1621,8 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
attention_type="self",
qkv_format=qkv_format,
).to(dtype=dtype, device="cuda")
if not is_training:
dpa = dpa.eval()
seqlens_q = torch.full(
[config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
......@@ -1680,9 +1696,12 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
core_attention_bias_type=config.attn_bias_type,
is_first_microbatch=True,
)
out.backward(out_grad)
if is_training:
out.backward(out_grad)
return out, (inp[0].grad, inp[1].grad, inp[2].grad)
if is_training:
return out, (inp[0].grad, inp[1].grad, inp[2].grad)
return out, (None, None, None)
model_configs_fp8 = {
......@@ -1726,58 +1745,26 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention")
unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention")
tols = dict(atol=5e-1, rtol=5e-1)
atol = 5e-1
rtol = 5e-1
rmse_tol = 0.1
fwd_rmse = _rmse(fused_attn_fwd_fp8, unfused_attn_fwd_f16)
fwd_range = max(fused_attn_fwd_fp8.max().item(), unfused_attn_fwd_f16.max().item()) - min(
fused_attn_fwd_fp8.min().item(), unfused_attn_fwd_f16.min().item()
)
bwd_rmse = _rmse(fused_attn_bwd_fp8, unfused_attn_bwd_f16)
bwd_range = max(fused_attn_bwd_fp8.max().item(), unfused_attn_bwd_f16.max().item()) - min(
fused_attn_bwd_fp8.min().item(), unfused_attn_bwd_f16.min().item()
)
logging.debug(
"fused_attn_fwd_fp8 min {:.6f} max {:.6f}".format(
fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.max().item()
)
)
logging.debug(
"unfused_attn_fwd_f16 min {:.6f} max {:.6f}".format(
unfused_attn_fwd_f16.min().item(), unfused_attn_fwd_f16.max().item()
)
)
logging.debug("fused_attn_fwd_fp8 vs unfused_attn_fwd_f16 RMSE: {:.6f}".format(fwd_rmse))
try:
torch.testing.assert_close(fused_attn_fwd_fp8, unfused_attn_fwd_f16, **tols)
except Exception as e:
logging.debug(e)
logging.debug(
"fused_attn_bwd_fp8 min {:.6f} max {:.6f}".format(
fused_attn_bwd_fp8.min().item(), fused_attn_bwd_fp8.max().item()
)
)
logging.debug(
"unfused_attn_bwd_f16 min {:.6f} max {:.6f}".format(
unfused_attn_bwd_f16.min().item(), unfused_attn_bwd_f16.max().item()
)
)
logging.debug("fused_attn_bwd_fp8 vs unfused_attn_bwd_f16 RMSE: {:.6f}".format(bwd_rmse))
try:
torch.testing.assert_close(fused_attn_bwd_fp8, unfused_attn_bwd_f16, **tols)
except Exception as e:
logging.debug(e)
assert (
fwd_rmse < rmse_tol * fwd_range
), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range
_error(
fused_attn_fwd_fp8,
unfused_attn_fwd_f16,
"fused_attn_fwd_fp8",
"unfused_attn_fwd_f16",
atol,
rtol,
rmse_tol,
)
assert (
bwd_rmse < rmse_tol * bwd_range
), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range
_error(
fused_attn_bwd_fp8,
unfused_attn_bwd_f16,
"fused_attn_bwd_fp8",
"unfused_attn_bwd_f16",
atol,
rtol,
rmse_tol,
)
......
This diff is collapsed.
......@@ -38,6 +38,15 @@ def get_default_fp8_recipe() -> DelayedScaling:
return DelayedScaling()
def get_fp8_torch_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> torch.dtype:
"""Get fp8 data type according to recipe and tensor"""
if fp8_recipe.fp8_format == Format.E4M3 or (
fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor
):
return torch.float8_e4m3fn
return torch.float8_e5m2fn
def get_fp8_te_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> tex.DType:
"""Get fp8 data type according to recipe and tensor"""
if fp8_recipe.fp8_format == Format.E4M3 or (
......
......@@ -56,7 +56,7 @@ if __name__ == "__main__":
description="Transformer acceleration library - Torch Lib",
ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension},
install_requires=["torch", "flash-attn>=2.0.6,<=2.4.2,!=2.0.9,!=2.1.0"],
install_requires=["torch", "flash-attn>=2.0.6,<=2.6.3,!=2.0.9,!=2.1.0"],
tests_require=["numpy", "onnxruntime", "torchvision"],
)
if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
......
......@@ -23,9 +23,6 @@ _default_causal_mask = {}
def _get_default_causal_mask(mask_type: str, sq: int, sk: int) -> torch.Tensor:
"""Return the causal upper triangular mask for softmax input"""
if sq == 1:
return torch.zeros((1, sk), dtype=torch.bool, device="cuda")
matrix_identifiers = (mask_type, sq, sk)
if matrix_identifiers not in _default_causal_mask:
diagonal_offset = sk - sq + 1 if "bottom_right" in mask_type else 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment