Unverified commit 88c0c914, authored by Charlene Yang and committed by GitHub

[PyTorch] Update docs/example and benchmarks/ scripts (#1075)



* update example/benchmark scripts
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix head_dim after MLA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update notebook
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent b8d453ef
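For context, the benchmark and example scripts in this change move from the per-backend _is_*_supported helpers to a single _get_attention_backends query. The following is a minimal sketch of the new pattern, assuming a TransformerEngine checkout where the test helpers are importable; the ModelConfig arguments shown are illustrative, not taken from the diff.

import torch
from tests.pytorch.fused_attn.test_fused_attn import (
    ModelConfig,
    _get_attention_backends,
)

dtype = torch.bfloat16
qkv_layout = "bshd_bshd_bshd"
pad_between_seqs = False
# Illustrative config (batch 2, 16 heads, head dim 64, seqlen 512); see the
# test module for the authoritative ModelConfig signature.
config = ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias")

# _get_attention_backends replaces _is_flash_attention_supported,
# _is_fused_attention_supported and _is_unfused_attention_supported.
available_backends, fused_attn_backends = _get_attention_backends(
    config,
    qkv_dtype=dtype,
    qkv_layout=qkv_layout,
    window_size=config.window_size,
    pad_between_seqs=pad_between_seqs,
)
# available_backends unpacks into the three support flags used by the scripts below.
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends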
......@@ -11,9 +11,7 @@ import nvtx
import transformer_engine
from tests.pytorch.fused_attn.test_fused_attn import (
ModelConfig,
_is_flash_attention_supported,
_is_fused_attention_supported,
_is_unfused_attention_supported,
_get_attention_backends,
_run_dot_product_attention,
)
......@@ -29,8 +27,6 @@ ckpt_attn = False
workspace_opt = True
# QKV memory layout
qkv_layout = "bshd_bshd_bshd"
# sliding window attention
swa = False
# padding between sequences for qkv_format=thd
pad_between_seqs = False
# training mode
......@@ -64,7 +60,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
ckpt_attn,
qkv_layout,
workspace_opt,
swa,
pad_between_seqs,
is_training,
)
......@@ -76,7 +71,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
ckpt_attn,
qkv_layout,
workspace_opt,
swa,
pad_between_seqs,
is_training,
)
......@@ -97,7 +91,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
ckpt_attn,
qkv_layout,
workspace_opt,
swa,
pad_between_seqs,
is_training,
)
......@@ -115,7 +108,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
ckpt_attn,
qkv_layout,
workspace_opt,
swa,
pad_between_seqs,
is_training,
)
......@@ -205,13 +197,15 @@ def main():
)
for model in model_configs.keys():
config = model_configs[model]
fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
available_backends, fused_attn_backends = _get_attention_backends(
config,
dtype,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
window_size=config.window_size,
pad_between_seqs=pad_between_seqs,
)
fused_attn_supported = fused_attn_supported and not swa
flash_attn_supported = _is_flash_attention_supported(config)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
print(
f'Running {model} with {"cuDNN attention" if fused_attn_supported else ""}'
f'{" and flash-attention" if flash_attn_supported else ""}...'
......
......@@ -6,7 +6,6 @@ import os
import torch
from typing import Tuple
from tests.pytorch.fused_attn.test_fused_attn import ModelConfig
from transformer_engine.pytorch.distributed import _set_cuda_rng_state
from transformer_engine.pytorch.attention import DotProductAttention
# Initialize RNG state
......@@ -22,7 +21,7 @@ _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
def reset_rng_states() -> None:
"""Revert back to initial RNG state"""
torch.set_rng_state(_cpu_rng_state)
_set_cuda_rng_state(_cuda_rng_state)
torch.cuda.set_rng_state(_cuda_rng_state)
def _run_dot_product_attention(
......@@ -40,7 +39,7 @@ def _run_dot_product_attention(
[config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
)
inp = torch.randn(
[config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim],
[config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim_qk],
dtype=dtype,
device="cuda",
)
......@@ -51,7 +50,7 @@ def _run_dot_product_attention(
k.requires_grad = True
v.requires_grad = True
out_grad = torch.randn(
[config.batch_size, config.max_seqlen_q, config.num_heads * config.head_dim],
[config.batch_size, config.max_seqlen_q, config.num_heads * config.head_dim_v],
dtype=dtype,
device="cuda",
)
......@@ -80,7 +79,7 @@ def _run_dot_product_attention(
block = DotProductAttention(
config.num_heads,
config.head_dim,
config.head_dim_qk,
num_gqa_groups=config.num_gqa_groups,
qkv_format="bshd",
attention_dropout=config.dropout_p,
......@@ -89,6 +88,8 @@ def _run_dot_product_attention(
get_rng_state_tracker=None,
tp_group=None,
layer_number=1,
attn_mask_type="no_mask",
window_size=(-1, -1),
).to(dtype=dtype, device="cuda")
# Run a forward and backward pass
......@@ -103,6 +104,7 @@ def _run_dot_product_attention(
attn_mask_type=config.attn_mask_type, # 'arbitrary'
core_attention_bias_type=config.attn_bias_type, # 'no_bias'
core_attention_bias=bias, # None
window_size=(-1, -1),
)
out.backward(out_grad)
......@@ -116,6 +118,7 @@ def _run_dot_product_attention(
attn_mask_type=config.attn_mask_type, # no_mask
core_attention_bias_type=config.attn_bias_type, # 'post_scale_bias'
core_attention_bias=bias, # bias
window_size=(-1, -1),
)
out.backward(out_grad)
......@@ -133,6 +136,7 @@ print("Run with post_scale_bias:")
config = model_configs["test_bias"]
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(dtype, config, "bs3hd")
print()
print("Run with arbitrary mask:")
config = model_configs["test_mask"]
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(dtype, config, "bs3hd")
......@@ -140,4 +144,6 @@ unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(dtype, config, "
torch.testing.assert_close(unfused_attn_fwd, fused_attn_fwd, atol=2.5e-2, rtol=2.5e-2)
for i in range(3):
torch.testing.assert_close(unfused_attn_bwd[i], fused_attn_bwd[i], atol=2.5e-2, rtol=2.5e-2)
print()
print("Test passed!")
......@@ -11,9 +11,7 @@ import nvtx
import transformer_engine
from tests.pytorch.fused_attn.test_fused_attn import (
ModelConfig,
_is_flash_attention_supported,
_is_fused_attention_supported,
_is_unfused_attention_supported,
_get_attention_backends,
_run_dot_product_attention,
)
......@@ -60,7 +58,6 @@ def example_attention(model, fused_attn_supported, flash_attn_supported):
ckpt_attn,
qkv_layout,
workspace_opt,
swa,
pad_between_seqs,
is_training,
)
......@@ -75,7 +72,6 @@ def example_attention(model, fused_attn_supported, flash_attn_supported):
ckpt_attn,
qkv_layout,
workspace_opt,
swa,
pad_between_seqs,
is_training,
)
......@@ -94,13 +90,14 @@ def main():
models = ["test_0"]
for model in models:
config = model_configs[model]
fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
available_backends, fused_attn_backends = _get_attention_backends(
config,
dtype,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
window_size=config.window_size,
pad_between_seqs=pad_between_seqs,
)
fused_attn_supported = fused_attn_supported and not swa
flash_attn_supported = _is_flash_attention_supported(config)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
example_attention(model, fused_attn_supported, flash_attn_supported)
......
......@@ -8,6 +8,7 @@ import math
import os
from importlib.metadata import version
from typing import Any, Dict, List, Tuple, Union, Optional
from contextlib import contextmanager
import pytest
import torch
......@@ -108,6 +109,16 @@ class ModelConfig:
self.window_size = window_size
@contextmanager
def logging_context(highest_level=logging.WARNING):
previous_level = logging.root.manager.disable
logging.disable(highest_level)
try:
yield
finally:
logging.disable(previous_level)
def _get_attention_backends(
config: ModelConfig,
qkv_dtype: torch.dtype,
......@@ -180,6 +191,7 @@ def _get_attention_backends(
return available_backends, fused_attention_backend
backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
with logging_context():
for i in range(3):
os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
_attention_backends["backend_selection_requires_update"] = True
......
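The logging_context helper added above temporarily raises the logging.disable threshold while the backend-selection loop toggles NVTE_FUSED_ATTN_BACKEND, so repeated probing does not emit warnings. A small standalone illustration of that behavior follows; the helper body mirrors the diff, and the surrounding calls are only for demonstration.

import logging
from contextlib import contextmanager

@contextmanager
def logging_context(highest_level=logging.WARNING):
    previous_level = logging.root.manager.disable
    logging.disable(highest_level)        # mute records at or below highest_level
    try:
        yield
    finally:
        logging.disable(previous_level)   # restore the previous threshold

logging.warning("shown before the context")       # emitted as usual
with logging_context():
    logging.warning("hidden inside the context")  # suppressed while probing backends
logging.warning("shown after the context")        # logging behaves normally again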