Unverified commit 88c0c914 authored by Charlene Yang, committed by GitHub

[PyTorch] Update docs/example and benchmarks/ scripts (#1075)

* update example/benchmark scripts
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix head_dim after MLA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update notebook
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent b8d453ef
@@ -11,9 +11,7 @@ import nvtx
 import transformer_engine
 from tests.pytorch.fused_attn.test_fused_attn import (
     ModelConfig,
-    _is_flash_attention_supported,
-    _is_fused_attention_supported,
-    _is_unfused_attention_supported,
+    _get_attention_backends,
     _run_dot_product_attention,
 )
@@ -29,8 +27,6 @@ ckpt_attn = False
 workspace_opt = True
 # QKV memory layout
 qkv_layout = "bshd_bshd_bshd"
-# sliding window attention
-swa = False
 # padding between sequences for qkv_format=thd
 pad_between_seqs = False
 # training mode
@@ -64,7 +60,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
         ckpt_attn,
         qkv_layout,
         workspace_opt,
-        swa,
         pad_between_seqs,
         is_training,
     )
@@ -76,7 +71,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
         ckpt_attn,
         qkv_layout,
         workspace_opt,
-        swa,
         pad_between_seqs,
         is_training,
     )
@@ -97,7 +91,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
         ckpt_attn,
         qkv_layout,
         workspace_opt,
-        swa,
         pad_between_seqs,
         is_training,
     )
@@ -115,7 +108,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
         ckpt_attn,
         qkv_layout,
         workspace_opt,
-        swa,
         pad_between_seqs,
         is_training,
     )
@@ -205,13 +197,15 @@ def main():
     )
     for model in model_configs.keys():
         config = model_configs[model]
-        fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
+        available_backends, fused_attn_backends = _get_attention_backends(
             config,
-            dtype,
+            qkv_dtype=dtype,
             qkv_layout=qkv_layout,
+            window_size=config.window_size,
+            pad_between_seqs=pad_between_seqs,
         )
-        fused_attn_supported = fused_attn_supported and not swa
-        flash_attn_supported = _is_flash_attention_supported(config)
+        flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
         print(
             f'Running {model} with {"cuDNN attention" if fused_attn_supported else ""}'
             f'{" and flash-attention" if flash_attn_supported else ""}...'
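For orientation, here is a minimal sketch of how a script queries backend availability with the new helper after this change. It assumes TransformerEngine and its test utilities are importable and a CUDA device is present; the ModelConfig values and their positional order are illustrative placeholders, not the benchmark's actual settings.

```python
import torch
from tests.pytorch.fused_attn.test_fused_attn import ModelConfig, _get_attention_backends

# Placeholder config: batch 2, 16 heads, 16 GQA groups, head_dim 64, seqlen 512
# (argument order assumed; adjust to the helper's actual signature).
config = ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias")

# One call now reports all three backends, replacing the removed per-backend
# helpers (_is_flash/_is_fused/_is_unfused_attention_supported).
available_backends, fused_attn_backends = _get_attention_backends(
    config,
    qkv_dtype=torch.bfloat16,
    qkv_layout="bshd_bshd_bshd",
    window_size=config.window_size,
    pad_between_seqs=False,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
print(fused_attn_backends)  # cuDNN sub-backend(s) that can serve this config
```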
@@ -6,7 +6,6 @@ import os
 import torch
 from typing import Tuple
 from tests.pytorch.fused_attn.test_fused_attn import ModelConfig
-from transformer_engine.pytorch.distributed import _set_cuda_rng_state
 from transformer_engine.pytorch.attention import DotProductAttention

 # Initialize RNG state
@@ -22,7 +21,7 @@ _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
 def reset_rng_states() -> None:
     """Revert back to initial RNG state"""
     torch.set_rng_state(_cpu_rng_state)
-    _set_cuda_rng_state(_cuda_rng_state)
+    torch.cuda.set_rng_state(_cuda_rng_state)


 def _run_dot_product_attention(
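The example now restores the GPU random state through the public torch.cuda.set_rng_state API rather than TE's private _set_cuda_rng_state helper. A self-contained sketch of the same pattern, assuming a CUDA device is available:

```python
import torch

# Capture the initial RNG state once; restore it before every run so that
# back-to-back attention configurations see identical random inputs.
torch.manual_seed(1234)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()

def reset_rng_states() -> None:
    """Revert CPU and current-device CUDA RNG to the captured state."""
    torch.set_rng_state(_cpu_rng_state)
    torch.cuda.set_rng_state(_cuda_rng_state)

reset_rng_states()
```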
@@ -40,7 +39,7 @@ def _run_dot_product_attention(
         [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
     )
     inp = torch.randn(
-        [config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim],
+        [config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim_qk],
         dtype=dtype,
         device="cuda",
     )
@@ -51,7 +50,7 @@ def _run_dot_product_attention(
     k.requires_grad = True
     v.requires_grad = True
     out_grad = torch.randn(
-        [config.batch_size, config.max_seqlen_q, config.num_heads * config.head_dim],
+        [config.batch_size, config.max_seqlen_q, config.num_heads * config.head_dim_v],
         dtype=dtype,
         device="cuda",
     )
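The single head_dim field has been split into head_dim_qk and head_dim_v so that configurations such as MLA, where the value head dimension differs from the query/key one, can be described. A rough illustration of the resulting shapes; the sizes below are made up for the example:

```python
import torch

batch_size, seqlen, num_heads = 2, 128, 16
head_dim_qk, head_dim_v = 192, 128  # QK and V head dimensions may now differ

q = torch.randn(batch_size, seqlen, num_heads, head_dim_qk)
k = torch.randn(batch_size, seqlen, num_heads, head_dim_qk)
v = torch.randn(batch_size, seqlen, num_heads, head_dim_v)

# The attention output inherits the V head dimension, which is why the
# gradient tensor in the example above is sized with num_heads * head_dim_v.
out_grad = torch.randn(batch_size, seqlen, num_heads * head_dim_v)
```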
@@ -80,7 +79,7 @@ def _run_dot_product_attention(
     block = DotProductAttention(
         config.num_heads,
-        config.head_dim,
+        config.head_dim_qk,
         num_gqa_groups=config.num_gqa_groups,
         qkv_format="bshd",
         attention_dropout=config.dropout_p,
@@ -89,6 +88,8 @@ def _run_dot_product_attention(
         get_rng_state_tracker=None,
         tp_group=None,
         layer_number=1,
+        attn_mask_type="no_mask",
+        window_size=(-1, -1),
     ).to(dtype=dtype, device="cuda")

     # Run a forward and backward pass
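With sliding-window behavior now expressed through window_size, the example pins attn_mask_type and window_size explicitly: (-1, -1) means full attention with no window, and the value passed to forward() overrides the constructor setting. A minimal sketch, assuming TransformerEngine is installed and a GPU is available (sizes are placeholders):

```python
import torch
from transformer_engine.pytorch.attention import DotProductAttention

num_heads, head_dim_qk = 16, 64
block = DotProductAttention(
    num_heads,
    head_dim_qk,
    qkv_format="bshd",
    attn_mask_type="no_mask",
    window_size=(-1, -1),  # (-1, -1) disables sliding-window attention
).to(dtype=torch.bfloat16, device="cuda")

q = torch.randn(2, 128, num_heads, head_dim_qk, dtype=torch.bfloat16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)
out = block(q, k, v, window_size=(-1, -1))  # per-call value takes precedence
```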
@@ -103,6 +104,7 @@ def _run_dot_product_attention(
         attn_mask_type=config.attn_mask_type,  # 'arbitrary'
         core_attention_bias_type=config.attn_bias_type,  # 'no_bias'
         core_attention_bias=bias,  # None
+        window_size=(-1, -1),
     )
     out.backward(out_grad)
@@ -116,6 +118,7 @@ def _run_dot_product_attention(
         attn_mask_type=config.attn_mask_type,  # no_mask
         core_attention_bias_type=config.attn_bias_type,  # 'post_scale_bias'
         core_attention_bias=bias,  # bias
+        window_size=(-1, -1),
     )
     out.backward(out_grad)
@@ -133,6 +136,7 @@ print("Run with post_scale_bias:")
 config = model_configs["test_bias"]
 fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(dtype, config, "bs3hd")

+print()
 print("Run with arbitrary mask:")
 config = model_configs["test_mask"]
 unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(dtype, config, "bs3hd")
@@ -140,4 +144,6 @@ unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(dtype, config, "
 torch.testing.assert_close(unfused_attn_fwd, fused_attn_fwd, atol=2.5e-2, rtol=2.5e-2)
 for i in range(3):
     torch.testing.assert_close(unfused_attn_bwd[i], fused_attn_bwd[i], atol=2.5e-2, rtol=2.5e-2)
+print()
 print("Test passed!")
This diff is collapsed.
@@ -11,9 +11,7 @@ import nvtx
 import transformer_engine
 from tests.pytorch.fused_attn.test_fused_attn import (
     ModelConfig,
-    _is_flash_attention_supported,
-    _is_fused_attention_supported,
-    _is_unfused_attention_supported,
+    _get_attention_backends,
     _run_dot_product_attention,
 )
@@ -60,7 +58,6 @@ def example_attention(model, fused_attn_supported, flash_attn_supported):
         ckpt_attn,
         qkv_layout,
         workspace_opt,
-        swa,
         pad_between_seqs,
         is_training,
     )
@@ -75,7 +72,6 @@ def example_attention(model, fused_attn_supported, flash_attn_supported):
         ckpt_attn,
         qkv_layout,
         workspace_opt,
-        swa,
         pad_between_seqs,
         is_training,
     )
@@ -94,13 +90,14 @@ def main():
     models = ["test_0"]
     for model in models:
         config = model_configs[model]
-        fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
+        available_backends, fused_attn_backends = _get_attention_backends(
            config,
-            dtype,
+            qkv_dtype=dtype,
             qkv_layout=qkv_layout,
+            window_size=config.window_size,
+            pad_between_seqs=pad_between_seqs,
         )
-        fused_attn_supported = fused_attn_supported and not swa
-        flash_attn_supported = _is_flash_attention_supported(config)
+        flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
         example_attention(model, fused_attn_supported, flash_attn_supported)
@@ -8,6 +8,7 @@ import math
 import os
 from importlib.metadata import version
 from typing import Any, Dict, List, Tuple, Union, Optional
+from contextlib import contextmanager

 import pytest
 import torch
@@ -108,6 +109,16 @@ class ModelConfig:
         self.window_size = window_size


+@contextmanager
+def logging_context(highest_level=logging.WARNING):
+    previous_level = logging.root.manager.disable
+    logging.disable(highest_level)
+    try:
+        yield
+    finally:
+        logging.disable(previous_level)
+
 def _get_attention_backends(
     config: ModelConfig,
     qkv_dtype: torch.dtype,
@@ -180,6 +191,7 @@ def _get_attention_backends(
         return available_backends, fused_attention_backend

     backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
+    with logging_context():
         for i in range(3):
             os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
             _attention_backends["backend_selection_requires_update"] = True
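The new logging_context helper suppresses log records at or below the given level while _get_attention_backends toggles NVTE_FUSED_ATTN_BACKEND to probe each cuDNN sub-backend, then restores the previous disable level. A standalone sketch of the same pattern:

```python
import logging
from contextlib import contextmanager

@contextmanager
def logging_context(highest_level=logging.WARNING):
    # Remember the global disable level, silence everything up to
    # highest_level, and restore the old level on exit, even on exceptions.
    previous_level = logging.root.manager.disable
    logging.disable(highest_level)
    try:
        yield
    finally:
        logging.disable(previous_level)

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("demo")

with logging_context():
    log.warning("suppressed while probing backends")
    log.error("ERROR is above WARNING, so it still appears")
log.warning("visible again after the context exits")
```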