Unverified commit 43569381, authored by Charlene Yang, committed by GitHub

Add documentation for dot product attention (#889)



* add attention docs
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: update attention doc
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: update attention doc
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: update attention doc
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: update attn doc
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: update attn doc
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: update attn doc
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: update attention doc
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* first draft
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor tweak to first draft
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up pictures
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* first draft for review
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add logging info/debug
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fix of an SWA message
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* use subprocess instead of os.sys
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up benchmark script
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add example script and update notebook
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor tweak
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor tweaks
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix Jax/Paddle related comments
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* rerun H100 benchmark
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* restrict fp8 tests to sm90+
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* move get_cudnn_version from common to pytorch utils
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent 905d94f4
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os, sys, time
import subprocess
import pandas as pd
import numpy as np
import torch
import nvtx
import transformer_engine
from tests.pytorch.fused_attn.test_fused_attn import (
    ModelConfig,
    _is_flash_attention_supported,
    _is_fused_attention_supported,
    _is_unfused_attention_supported,
    _run_dot_product_attention,
)

pd.set_option("display.precision", 4)
# data type
dtype = torch.bfloat16
# number of iterations after 3 warmup iterations
num_iters = 3
# checkpointing
ckpt_attn = False
# workspace optimization path for cuDNN attention
workspace_opt = True
# QKV memory layout
qkv_layout = 'bshd_bshd_bshd'
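# (in 'bshd_bshd_bshd', q, k and v are three separate tensors, each laid out as
#  [batch, seqlen, heads, head_dim])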
# sliding window attention
swa = False
# padding between sequences for qkv_format=thd
pad_between_seqs = False
# training mode
is_training = True
model_configs = {
    #     test:            b,  h, hg,   d,   sq,  skv,   p,      mask,              bias
    "test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),  # short seq
    "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"),  # longer seq, mask
    "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),  # bias
    "test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"),  # GQA
}
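# For reference, the positional ModelConfig fields above are: batch size (b), number of
# attention heads (h), number of GQA groups (hg), head dimension (d), q and kv sequence
# lengths (sq, skv), dropout probability (p), attention mask type, and attention bias type.
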
def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supported):
    config = model_configs[model]
    if dtype == torch.bfloat16:
        tols = dict(atol=2.5e-2, rtol=2.5e-2)
    else:
        tols = dict(atol=5e-3, rtol=5e-3)
    cudnn_times = []
    flash_times = []
    warmup_iters = 3
    for i in range(warmup_iters):
        if fused_attn_supported:
            fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
                dtype, config, "FusedAttention",
                ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
            )
        if flash_attn_supported:
            flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
                dtype, config, "FlashAttention",
                ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
            )
        if fused_attn_supported and flash_attn_supported:
            torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
            for i, _ in enumerate(flash_attn_bwd):
                torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)

    torch.cuda.cudart().cudaProfilerStart()
    torch.cuda.synchronize()
    fused_attn_start = time.time()
    if fused_attn_supported:
        for i in range(num_iters):
            fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
                dtype, config, "FusedAttention",
                ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
            )
    torch.cuda.synchronize()
    fused_attn_time = time.time() - fused_attn_start if fused_attn_supported else 0

    torch.cuda.synchronize()
    flash_attn_start = time.time()
    if flash_attn_supported:
        for i in range(num_iters):
            flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
                dtype, config, "FlashAttention",
                ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
            )
    torch.cuda.synchronize()
    flash_attn_time = time.time() - flash_attn_start if flash_attn_supported else 0

    df = pd.read_csv('times.csv')
    df = pd.concat([
        df,
        pd.DataFrame(
            [[fused_attn_time*1e3/num_iters, 0, 0, 0,
              flash_attn_time*1e3/num_iters, 0, 0, 0, 0]], columns=df.columns)],
        ignore_index=True,
    )
    df.to_csv('times.csv', index=False)
    torch.cuda.cudart().cudaProfilerStop()


def parse_results(per_cudnn, per_flash, model):
    filename = f'prof_{model}_cuda_gpu_trace.csv'
    df = pd.read_csv(os.path.join('./', filename))
    df_times = pd.read_csv('times.csv')
    row = len(df_times.index) - 1
    if per_cudnn > 0:
        t_cudnn_all = df[df['Name'].str.contains('cudnn')]['Duration (ns)'].to_numpy()
        t_cudnn_all = t_cudnn_all.reshape(-1, per_cudnn)
        t_cudnn_avg = np.average(t_cudnn_all, axis=0)
        df_times.loc[row, 'FusedAttention Kernels (fwd)'] = t_cudnn_avg[0] / 1e6
        df_times.loc[row, 'FusedAttention Kernels (bwd)'] = t_cudnn_avg[1:4].sum() / 1e6
        df_times.loc[row, 'FusedAttention Kernels (fwd+bwd)'] = t_cudnn_avg.sum() / 1e6
    if per_flash > 0:
        t_flash_all = df[df['Name'].str.contains('void flash')]['Duration (ns)'].to_numpy()
        t_flash_all = t_flash_all.reshape(-1, per_flash)
        t_flash_avg = np.average(t_flash_all, axis=0)
        df_times.loc[row, 'FlashAttention Kernels (fwd)'] = t_flash_avg[0] / 1e6
        df_times.loc[row, 'FlashAttention Kernels (bwd)'] = t_flash_avg[1:4].sum() / 1e6
        df_times.loc[row, 'FlashAttention Kernels (fwd+bwd)'] = t_flash_avg.sum() / 1e6
    if per_cudnn > 0 and per_flash > 0:
        df_times.loc[row, 'Fused vs Flash Kernels Speedup (fwd+bwd)'] = \
            df_times.loc[row, 'FlashAttention Kernels (fwd+bwd)'] / \
            df_times.loc[row, 'FusedAttention Kernels (fwd+bwd)']
    df_times.to_csv('times.csv', index=False)


def main():
    times = pd.DataFrame(
        columns=[
            'FusedAttention Module',
            'FusedAttention Kernels (fwd)',
            'FusedAttention Kernels (bwd)',
            'FusedAttention Kernels (fwd+bwd)',
            'FlashAttention Module',
            'FlashAttention Kernels (fwd)',
            'FlashAttention Kernels (bwd)',
            'FlashAttention Kernels (fwd+bwd)',
            'Fused vs Flash Kernels Speedup (fwd+bwd)',
        ])
    times.to_csv('times.csv', index=False)

    device_id = torch.cuda.current_device()
    device_properties = torch.cuda.get_device_properties(device_id)
    print(f"Device {device_id}: "
          f"{device_properties.name} GPU, "
          f"sm{device_properties.major}{device_properties.minor} compute capability, "
          f"{device_properties.total_memory/1024**3:.1f}GB memory")

    for model in model_configs.keys():
        config = model_configs[model]
        fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
            config, dtype, qkv_layout=qkv_layout,
        )
        fused_attn_supported = fused_attn_supported and not swa
        flash_attn_supported = _is_flash_attention_supported(config)
        print(f'Running {model} with {"cuDNN attention" if fused_attn_supported else ""}'
              f'{" and flash-attention" if flash_attn_supported else ""}...')
        prof_cmd = [
            "nsys",
            "profile",
            "--capture-range=cudaProfilerApi",
            "--capture-range-end=stop-shutdown",
            "--force-overwrite=true",
            f"--output=prof_{model}",
            "python",
            "-c",
            f""" "import benchmark_attention;""",
            f"""benchmark_attention.benchmark_dot_product_attention("""
            f"""'{model}', {fused_attn_supported}, {flash_attn_supported})" """,
        ]
        prof_cmd = ' '.join(prof_cmd)
        subprocess.call(prof_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)
        stats_cmd = [
            "nsys",
            "stats",
            "-q",
            "-r",
            "cuda_gpu_trace",
            "--format",
            "csv,column",
            "--force-overwrite=true",
            "--force-export=true",
            f"--output=prof_{model}",
            f"prof_{model}.nsys-rep",
        ]
        if fused_attn_supported:
            num_kernels_cudnn = 4
            if config.attn_bias_type == 'post_scale_bias':
                num_kernels_cudnn = num_kernels_cudnn + 1
            if config.num_heads != config.num_gqa_groups:
                num_kernels_cudnn = num_kernels_cudnn + 2
        else:
            num_kernels_cudnn = 0
        num_kernels_flash = 4 if flash_attn_supported else 0
        stats_cmd = ' '.join(stats_cmd)
        subprocess.call(stats_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)
        parse_cmd = [
            "python",
            "-c",
            f""" "import benchmark_attention;""",
            f"""benchmark_attention.parse_results("""
            f"""{num_kernels_cudnn}, {num_kernels_flash}, '{model}')" """,
        ]
        parse_cmd = ' '.join(parse_cmd)
        subprocess.call(parse_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)

    df_times = pd.read_csv('times.csv')
    df_times.index = list(model_configs.keys())
    a = df_times[['FusedAttention Kernels (fwd+bwd)',
                  'FlashAttention Kernels (fwd+bwd)',
                  'Fused vs Flash Kernels Speedup (fwd+bwd)']]
    a.columns = ['cuDNN fwd+bwd (ms)', 'flash-attn fwd+bwd (ms)', 'cuDNN vs flash speedup']
    print()
    print(a)


if __name__ == "__main__":
    main()
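
For context, `parse_results` above aggregates the Nsight Systems GPU trace by assuming a fixed number of attention kernels per iteration: the first kernel in each iteration is counted toward the forward pass and the next three toward the backward pass. Below is a minimal, self-contained sketch of that reshape-and-average step, using synthetic durations rather than a real `nsys` trace.

import numpy as np

per_cudnn = 4                                      # kernels per fwd+bwd iteration (see main())
durations_ns = np.array([100, 200, 300, 400] * 3)  # synthetic trace: 3 iterations of 4 kernels
t_all = durations_ns.reshape(-1, per_cudnn)        # one row per iteration
t_avg = np.average(t_all, axis=0)                  # average each kernel position over iterations
fwd_ms = t_avg[0] / 1e6                            # kernel 0 -> forward time
bwd_ms = t_avg[1:4].sum() / 1e6                    # kernels 1..3 -> backward time
print(f"fwd: {fwd_ms:.6f} ms, bwd: {bwd_ms:.6f} ms")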
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import torch
from typing import Tuple
from tests.pytorch.fused_attn.test_fused_attn import ModelConfig
from transformer_engine.pytorch.distributed import _set_cuda_rng_state
from transformer_engine.pytorch.attention import DotProductAttention
# Initialize RNG state
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
def reset_rng_states() -> None:
    """Revert back to initial RNG state"""
    torch.set_rng_state(_cpu_rng_state)
    _set_cuda_rng_state(_cuda_rng_state)


def _run_dot_product_attention(
    dtype: torch.dtype,
    config: ModelConfig,
    qkv_layout: str,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Run DotProductAttention module with one forward pass and one backward pass"""
    reset_rng_states()

    seqlens_q = torch.full([config.batch_size], config.max_seqlen_q,
                           dtype=torch.int32, device="cuda")
    seqlens_kv = torch.full([config.batch_size], config.max_seqlen_kv,
                            dtype=torch.int32, device="cuda")
    inp = torch.randn(
        [config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim],
        dtype=dtype, device="cuda")
    q = inp[:, :, 0, :, :]
    k = inp[:, :, 1, :, :]
    v = inp[:, :, 2, :, :]
    q.requires_grad = True
    k.requires_grad = True
    v.requires_grad = True
    out_grad = torch.randn(
        [config.batch_size, config.max_seqlen_q, config.num_heads * config.head_dim],
        dtype=dtype, device="cuda")

    # Create attention mask / bias
    attention_mask = None
    bias = None
    if config.attn_mask_type == "arbitrary":
        attention_mask = torch.randint(
            -10, 10,
            [config.batch_size, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv],
        ).to(dtype=torch.bool, device="cuda")
    if config.attn_bias_type == "post_scale_bias":
        # convert mask to bias
        attention_mask = torch.randint(
            -10, 10,
            [config.batch_size, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv],
        ).to(dtype=torch.bool, device="cuda")
        bias = attention_mask.clone()
        neginf = -2**50 if dtype == torch.bfloat16 else -2**15
        bias = torch.where(bias == 0, 0, neginf).to(dtype=dtype, device='cuda')
        bias.requires_grad = False
        attention_mask = None

    block = (
        DotProductAttention(
            config.num_heads,
            config.head_dim,
            num_gqa_groups=config.num_gqa_groups,
            qkv_format='bshd',
            attention_dropout=config.dropout_p,
            sequence_parallel=False,
            tp_size=1,
            get_rng_state_tracker=None,
            tp_group=None,
            layer_number=1,
        ).to(dtype=dtype, device="cuda")
    )

    # Run a forward and backward pass
    out = None
    if config.attn_mask_type == "arbitrary":
        out = block(q, k, v,
                    attention_mask=attention_mask,                    # attention_mask
                    qkv_format='bshd',
                    attn_mask_type=config.attn_mask_type,             # 'arbitrary'
                    core_attention_bias_type=config.attn_bias_type,   # 'no_bias'
                    core_attention_bias=bias,                         # None
                    )
        out.backward(out_grad)
    if config.attn_bias_type == "post_scale_bias":
        out = block(q, k, v,
                    attention_mask=attention_mask,                    # None
                    qkv_format='bshd',
                    attn_mask_type=config.attn_mask_type,             # 'no_mask'
                    core_attention_bias_type=config.attn_bias_type,   # 'post_scale_bias'
                    core_attention_bias=bias,                         # bias
                    )
        out.backward(out_grad)

    return out, (q.grad, k.grad, v.grad)

dtype = torch.bfloat16
model_configs = {
    #       test:            b,  h, hg,  d,   sq,  skv,   p,       mask,              bias
    "test_mask": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "arbitrary", "no_bias"),
    "test_bias": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "post_scale_bias"),
}
print('Run with post_scale_bias:')
config = model_configs["test_bias"]
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(dtype, config, 'bs3hd')
print('Run with arbitrary mask:')
config = model_configs["test_mask"]
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(dtype, config, 'bs3hd')
torch.testing.assert_close(unfused_attn_fwd, fused_attn_fwd, atol=2.5e-2, rtol=2.5e-2)
for i in range(3):
    torch.testing.assert_close(unfused_attn_bwd[i], fused_attn_bwd[i], atol=2.5e-2, rtol=2.5e-2)
print('Test passed!')
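
The example above relies on the fact that a boolean attention mask can be rewritten as an additive post-scale bias: positions that should not be attended to receive a large negative bias, so softmax assigns them essentially zero weight. Because the RNG state is reset before each run, both configs draw the same random boolean pattern, consumed once as a mask ("test_mask") and once as a bias ("test_bias"), which is why the two runs are expected to match within tolerance. A minimal standalone sketch of this equivalence, independent of Transformer Engine and following the same convention as the script (a True mask entry becomes a large negative bias):

import torch

scores = torch.randn(2, 4)                               # raw attention scores for two queries
mask = torch.tensor([[False, True, False, True],         # True = do not attend to this position
                     [False, False, False, True]])
bias = torch.zeros_like(scores).masked_fill(mask, -1e4)  # additive-bias equivalent of the mask

probs_masked = torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)
probs_biased = torch.softmax(scores + bias, dim=-1)
print(torch.allclose(probs_masked, probs_biased, atol=1e-4))  # True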
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os, sys, time
import subprocess
import pandas as pd
import numpy as np
import torch
import nvtx
import transformer_engine
from tests.pytorch.fused_attn.test_fused_attn import (
    ModelConfig,
    _is_flash_attention_supported,
    _is_fused_attention_supported,
    _is_unfused_attention_supported,
    _run_dot_product_attention,
)

# data type
dtype = torch.bfloat16
# number of iterations after 3 warmup iterations
num_iters = 3
# checkpointing
ckpt_attn = False
# workspace optimization path for cuDNN attention
workspace_opt = True
# QKV memory layout
qkv_layout = 'bshd_bshd_bshd'
# sliding window attention
swa = False
# padding between sequences for qkv_format=thd
pad_between_seqs = False
# training mode
is_training = True
model_configs = {
    #     test:            b,  h, hg,   d,   sq,  skv,   p,      mask,              bias
    "test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),  # short seq
    "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"),  # longer seq, mask
    "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),  # bias
    "test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"),  # GQA
}
def example_attention(model, fused_attn_supported, flash_attn_supported):
    config = model_configs[model]
    if dtype == torch.bfloat16:
        tols = dict(atol=2.5e-2, rtol=2.5e-2)
    else:
        tols = dict(atol=5e-3, rtol=5e-3)
    if fused_attn_supported:
        print()
        print('Run cuDNN attention...')
        fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
            dtype, config, "FusedAttention",
            ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
        )
    if flash_attn_supported:
        print()
        print('Run flash-attention...')
        flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
            dtype, config, "FlashAttention",
            ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
        )
    if fused_attn_supported and flash_attn_supported:
        torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
        for i, _ in enumerate(flash_attn_bwd):
            torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
        print()
        print('Test passed.')


def main():
    models = ['test_0']
    for model in models:
        config = model_configs[model]
        fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
            config, dtype, qkv_layout=qkv_layout,
        )
        fused_attn_supported = fused_attn_supported and not swa
        flash_attn_supported = _is_flash_attention_supported(config)
        example_attention(model, fused_attn_supported, flash_attn_supported)


if __name__ == "__main__":
    main()
@@ -4,8 +4,10 @@
 """Utility functions for Transformer Engine modules"""

 import math
+import functools
 from typing import Any, Callable, Optional, Tuple

 import torch
+import transformer_engine.pytorch.cpp_extensions as ext


 def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
@@ -241,8 +243,19 @@ def assert_dim_for_fp8_exec(tensor: torch.Tensor) -> None:
         f"but got tensor with dims={list(tensor.size())}"
     )


 def is_bf16_compatible() -> None:
     """Replaces torch.cuda.is_bf16_compatible() with an explicit
     check on device compute capability to enforce sm_80 or higher.
     """
     return torch.cuda.get_device_capability()[0] >= 8
+
+
+@functools.cache
+def get_cudnn_version() -> Tuple[int, int, int]:
+    """Runtime cuDNN version (major, minor, patch)"""
+    encoded_version = ext.get_cudnn_version()
+    major_version_magnitude = 1000 if encoded_version < 90000 else 10000
+    major, encoded_version = divmod(encoded_version, major_version_magnitude)
+    minor, patch = divmod(encoded_version, 100)
+    return (major, minor, patch)
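
For reference, the decoding above reflects how cuDNN packs its version into a single integer: before 9.0 the major version is multiplied by 1000, and from 9.0 onward by 10000, with the minor version multiplied by 100 in both schemes. A small standalone sketch of the same arithmetic on two illustrative encoded values (the helper name below is made up for the example; the real function reads the encoded value from transformer_engine.pytorch.cpp_extensions):

from typing import Tuple

def decode_cudnn_version(encoded_version: int) -> Tuple[int, int, int]:
    """Same arithmetic as get_cudnn_version() above, applied to an explicit input."""
    major_version_magnitude = 1000 if encoded_version < 90000 else 10000
    major, encoded_version = divmod(encoded_version, major_version_magnitude)
    minor, patch = divmod(encoded_version, 100)
    return (major, minor, patch)

print(decode_cudnn_version(8907))   # cuDNN 8.9.7 -> (8, 9, 7)
print(decode_cudnn_version(90100))  # cuDNN 9.1.0 -> (9, 1, 0)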