Unverified Commit 43569381 authored by Charlene Yang, committed by GitHub

Add documentation for dot product attention (#889)



* add attention docs
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: update attention doc
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: update attention doc
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: update attention doc
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: update attn doc
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: update attn doc
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: update attn doc
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: update attention doc
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* first draft
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor tweak to first draft
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up pictures
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* first draft for review
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add logging info/debug
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fix of an SWA message
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* use subprocess instead of os.sys
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up benchmark script
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add example script and update notebook
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor tweak
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor tweaks
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix Jax/Paddle related comments
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* rerun H100 benchmark
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* restrict fp8 tests to sm90+
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* move get_cudnn_version from common to pytorch utils
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent 905d94f4
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os, sys, time
import subprocess
import pandas as pd
import numpy as np
import torch
import nvtx
import transformer_engine
from tests.pytorch.fused_attn.test_fused_attn import (
    ModelConfig,
    _is_flash_attention_supported,
    _is_fused_attention_supported,
    _is_unfused_attention_supported,
    _run_dot_product_attention,
)
pd.set_option("display.precision", 4)
# data type
dtype = torch.bfloat16
# number of iterations after 3 warmup iterations
num_iters = 3
# checkpointing
ckpt_attn = False
# workspace optimization path for cuDNN attention
workspace_opt = True
# QKV memory layout
qkv_layout = 'bshd_bshd_bshd'
# sliding window attention
swa = False
# padding between sequences for qkv_format=thd
pad_between_seqs = False
# training mode
is_training = True
model_configs = {
    # test: b, h, hg, d, sq, skv, p, mask, bias
    "test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),  # short seq
    "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"),  # longer seq, mask
    "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),  # bias
    "test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"),  # GQA
}


def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supported):
    """Benchmark DotProductAttention with the cuDNN and flash-attention backends."""
    config = model_configs[model]
    if dtype == torch.bfloat16:
        tols = dict(atol=2.5e-2, rtol=2.5e-2)
    else:
        tols = dict(atol=5e-3, rtol=5e-3)

    cudnn_times = []
    flash_times = []
    warmup_iters = 3
    # Warm up both backends and cross-check their forward/backward results
    for i in range(warmup_iters):
        if fused_attn_supported:
            fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
                dtype, config, "FusedAttention",
                ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
            )
        if flash_attn_supported:
            flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
                dtype, config, "FlashAttention",
                ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
            )
    if fused_attn_supported and flash_attn_supported:
        torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
        for i, _ in enumerate(flash_attn_bwd):
            torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)

    # Time the cuDNN (fused) attention path
    torch.cuda.cudart().cudaProfilerStart()
    torch.cuda.synchronize()
    fused_attn_start = time.time()
    if fused_attn_supported:
        for i in range(num_iters):
            fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
                dtype, config, "FusedAttention",
                ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
            )
    torch.cuda.synchronize()
    fused_attn_time = time.time() - fused_attn_start if fused_attn_supported else 0

    # Time the flash-attention path
    torch.cuda.synchronize()
    flash_attn_start = time.time()
    if flash_attn_supported:
        for i in range(num_iters):
            flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
                dtype, config, "FlashAttention",
                ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
            )
    torch.cuda.synchronize()
    flash_attn_time = time.time() - flash_attn_start if flash_attn_supported else 0

    # Append module-level times (ms per iteration) to times.csv
    df = pd.read_csv('times.csv')
    df = pd.concat([
        df,
        pd.DataFrame(
            [[fused_attn_time * 1e3 / num_iters, 0, 0, 0,
              flash_attn_time * 1e3 / num_iters, 0, 0, 0, 0]], columns=df.columns)],
        ignore_index=True,
    )
    df.to_csv('times.csv', index=False)
    torch.cuda.cudart().cudaProfilerStop()


def parse_results(per_cudnn, per_flash, model):
    """Parse the Nsight Systems GPU trace and record per-kernel times in times.csv."""
    filename = f'prof_{model}_cuda_gpu_trace.csv'
    df = pd.read_csv(os.path.join('./', filename))
    df_times = pd.read_csv('times.csv')
    row = len(df_times.index) - 1
    if per_cudnn > 0:
        # Average cuDNN kernel times over iterations; convert ns to ms
        t_cudnn_all = df[df['Name'].str.contains('cudnn')]['Duration (ns)'].to_numpy()
        t_cudnn_all = t_cudnn_all.reshape(-1, per_cudnn)
        t_cudnn_avg = np.average(t_cudnn_all, axis=0)
        df_times.loc[row, 'FusedAttention Kernels (fwd)'] = t_cudnn_avg[0] / 1e6
        df_times.loc[row, 'FusedAttention Kernels (bwd)'] = t_cudnn_avg[1:4].sum() / 1e6
        df_times.loc[row, 'FusedAttention Kernels (fwd+bwd)'] = t_cudnn_avg.sum() / 1e6
    if per_flash > 0:
        # Average flash-attention kernel times over iterations; convert ns to ms
        t_flash_all = df[df['Name'].str.contains('void flash')]['Duration (ns)'].to_numpy()
        t_flash_all = t_flash_all.reshape(-1, per_flash)
        t_flash_avg = np.average(t_flash_all, axis=0)
        df_times.loc[row, 'FlashAttention Kernels (fwd)'] = t_flash_avg[0] / 1e6
        df_times.loc[row, 'FlashAttention Kernels (bwd)'] = t_flash_avg[1:4].sum() / 1e6
        df_times.loc[row, 'FlashAttention Kernels (fwd+bwd)'] = t_flash_avg.sum() / 1e6
    if per_cudnn > 0 and per_flash > 0:
        df_times.loc[row, 'Fused vs Flash Kernels Speedup (fwd+bwd)'] = \
            df_times.loc[row, 'FlashAttention Kernels (fwd+bwd)'] / \
            df_times.loc[row, 'FusedAttention Kernels (fwd+bwd)']
    df_times.to_csv('times.csv', index=False)


def main():
    """Profile each model config with Nsight Systems and summarize the results."""
    times = pd.DataFrame(
        columns=[
            'FusedAttention Module',
            'FusedAttention Kernels (fwd)',
            'FusedAttention Kernels (bwd)',
            'FusedAttention Kernels (fwd+bwd)',
            'FlashAttention Module',
            'FlashAttention Kernels (fwd)',
            'FlashAttention Kernels (bwd)',
            'FlashAttention Kernels (fwd+bwd)',
            'Fused vs Flash Kernels Speedup (fwd+bwd)',
        ])
    times.to_csv('times.csv', index=False)

    device_id = torch.cuda.current_device()
    device_properties = torch.cuda.get_device_properties(device_id)
    print(f"Device {device_id}: "
          f"{device_properties.name} GPU, "
          f"sm{device_properties.major}{device_properties.minor} compute capability, "
          f"{device_properties.total_memory/1024**3:.1f}GB memory")

    for model in model_configs.keys():
        config = model_configs[model]
        fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
            config, dtype, qkv_layout=qkv_layout,
        )
        fused_attn_supported = fused_attn_supported and not swa
        flash_attn_supported = _is_flash_attention_supported(config)
        print(f'Running {model} with {"cuDNN attention" if fused_attn_supported else ""}'
              f'{" and flash-attention" if flash_attn_supported else ""}...')

        # Profile the benchmark in a subprocess so nsys captures only this model
        prof_cmd = [
            "nsys",
            "profile",
            "--capture-range=cudaProfilerApi",
            "--capture-range-end=stop-shutdown",
            "--force-overwrite=true",
            f"--output=prof_{model}",
            "python",
            "-c",
            f""" "import benchmark_attention;"""
            f"""benchmark_attention.benchmark_dot_product_attention("""
            f"""'{model}', {fused_attn_supported}, {flash_attn_supported})" """,
        ]
        prof_cmd = ' '.join(prof_cmd)
        subprocess.call(prof_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)

        # Export the GPU trace to CSV
        stats_cmd = [
            "nsys",
            "stats",
            "-q",
            "-r",
            "cuda_gpu_trace",
            "--format",
            "csv,column",
            "--force-overwrite=true",
            "--force-export=true",
            f"--output=prof_{model}",
            f"prof_{model}.nsys-rep",
        ]

        # Expected number of attention kernels per iteration for each backend
        if fused_attn_supported:
            num_kernels_cudnn = 4
            if config.attn_bias_type == 'post_scale_bias':
                num_kernels_cudnn = num_kernels_cudnn + 1
            if config.num_heads != config.num_gqa_groups:
                num_kernels_cudnn = num_kernels_cudnn + 2
        else:
            num_kernels_cudnn = 0
        num_kernels_flash = 4 if flash_attn_supported else 0

        stats_cmd = ' '.join(stats_cmd)
        subprocess.call(stats_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)

        # Parse the trace in a subprocess and append kernel times to times.csv
        parse_cmd = [
            "python",
            "-c",
            f""" "import benchmark_attention;"""
            f"""benchmark_attention.parse_results("""
            f"""{num_kernels_cudnn}, {num_kernels_flash}, '{model}')" """,
        ]
        parse_cmd = ' '.join(parse_cmd)
        subprocess.call(parse_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)

    # Summarize kernel times and the cuDNN vs flash-attention speedup
    df_times = pd.read_csv('times.csv')
    df_times.index = list(model_configs.keys())
    a = df_times[['FusedAttention Kernels (fwd+bwd)',
                  'FlashAttention Kernels (fwd+bwd)',
                  'Fused vs Flash Kernels Speedup (fwd+bwd)']]
    a.columns = ['cuDNN fwd+bwd (ms)', 'flash-attn fwd+bwd (ms)', 'cuDNN vs flash speedup']
    print()
    print(a)


if __name__ == "__main__":
    main()
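
A usage sketch follows. It assumes the benchmark above is saved as benchmark_attention.py (the name it re-imports itself under) and that Nsight Systems' nsys is available on PATH; under those assumptions, running it writes per-model module and kernel timings to times.csv, which can then be inspected with pandas:

# Usage sketch (assumptions: the script above is saved as benchmark_attention.py
# and `nsys` from Nsight Systems is on PATH; these names are not confirmed by the diff).
import subprocess
import pandas as pd

subprocess.run(["python", "benchmark_attention.py"], check=True)
times = pd.read_csv("times.csv")
# Column names come from the DataFrame created in main() above
print(times[["FusedAttention Kernels (fwd+bwd)",
             "FlashAttention Kernels (fwd+bwd)",
             "Fused vs Flash Kernels Speedup (fwd+bwd)"]])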
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import torch
from typing import Tuple
from tests.pytorch.fused_attn.test_fused_attn import ModelConfig
from transformer_engine.pytorch.distributed import _set_cuda_rng_state
from transformer_engine.pytorch.attention import DotProductAttention
# Initialize RNG state
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))


def reset_rng_states() -> None:
    """Revert back to initial RNG state"""
    torch.set_rng_state(_cpu_rng_state)
    _set_cuda_rng_state(_cuda_rng_state)


def _run_dot_product_attention(
    dtype: torch.dtype,
    config: ModelConfig,
    qkv_layout: str,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Run DotProductAttention module with one forward pass and one backward pass"""
    reset_rng_states()

    # Create input tensors q, k, v and the output gradient
    seqlens_q = torch.full([config.batch_size], config.max_seqlen_q,
                           dtype=torch.int32, device="cuda")
    seqlens_kv = torch.full([config.batch_size], config.max_seqlen_kv,
                            dtype=torch.int32, device="cuda")
    inp = torch.randn([config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim],
                      dtype=dtype, device="cuda")
    q = inp[:, :, 0, :, :]
    k = inp[:, :, 1, :, :]
    v = inp[:, :, 2, :, :]
    q.requires_grad = True
    k.requires_grad = True
    v.requires_grad = True
    out_grad = torch.randn([config.batch_size, config.max_seqlen_q, config.num_heads * config.head_dim],
                           dtype=dtype, device="cuda")

    # Create attention mask / bias. Because reset_rng_states() restores the same RNG
    # state for every call, the random boolean mask generated here is identical across
    # runs, so the "arbitrary" mask and the "post_scale_bias" bias mark the same positions.
    attention_mask = None
    bias = None
    if config.attn_mask_type == "arbitrary":
        attention_mask = torch.randint(
            -10, 10,
            [config.batch_size, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv]).to(
            dtype=torch.bool, device="cuda")
    if config.attn_bias_type == "post_scale_bias":
        # Convert the mask to a bias: masked positions (True) get a large negative
        # value so that softmax assigns them ~zero attention probability
        attention_mask = torch.randint(
            -10, 10,
            [config.batch_size, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv]).to(
            dtype=torch.bool, device="cuda")
        bias = attention_mask.clone()
        neginf = -2**50 if dtype == torch.bfloat16 else -2**15
        bias = torch.where(bias == 0, 0, neginf).to(dtype=dtype, device='cuda')
        bias.requires_grad = False
        attention_mask = None

    block = (
        DotProductAttention(
            config.num_heads,
            config.head_dim,
            num_gqa_groups=config.num_gqa_groups,
            qkv_format='bshd',
            attention_dropout=config.dropout_p,
            sequence_parallel=False,
            tp_size=1,
            get_rng_state_tracker=None,
            tp_group=None,
            layer_number=1,
        ).to(dtype=dtype, device="cuda")
    )

    # Run a forward and backward pass
    out = None
    if config.attn_mask_type == "arbitrary":
        out = block(q, k, v,
                    attention_mask=attention_mask,                   # attention_mask
                    qkv_format='bshd',
                    attn_mask_type=config.attn_mask_type,            # 'arbitrary'
                    core_attention_bias_type=config.attn_bias_type,  # 'no_bias'
                    core_attention_bias=bias,                        # None
                    )
        out.backward(out_grad)
    if config.attn_bias_type == "post_scale_bias":
        out = block(q, k, v,
                    attention_mask=attention_mask,                   # None
                    qkv_format='bshd',
                    attn_mask_type=config.attn_mask_type,            # 'no_mask'
                    core_attention_bias_type=config.attn_bias_type,  # 'post_scale_bias'
                    core_attention_bias=bias,                        # bias
                    )
        out.backward(out_grad)

    return out, (q.grad, k.grad, v.grad)
dtype = torch.bfloat16
model_configs = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"test_mask": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "arbitrary", "no_bias"),
"test_bias": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "post_scale_bias"),
}
print('Run with post_scale_bias:')
config = model_configs["test_bias"]
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(dtype, config, 'bs3hd')
print('Run with arbitrary mask:')
config = model_configs["test_mask"]
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(dtype, config, 'bs3hd')
torch.testing.assert_close(unfused_attn_fwd, fused_attn_fwd, atol=2.5e-2, rtol=2.5e-2)
for i in range(3):
    torch.testing.assert_close(unfused_attn_bwd[i], fused_attn_bwd[i], atol=2.5e-2, rtol=2.5e-2)
print('Test passed!')
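
The equivalence the script relies on can be seen on a toy case. The following is an illustration only, independent of the script above: positions where a boolean mask is True receive a large negative additive bias, so softmax assigns them essentially zero probability, which is exactly what masking does.

# Illustration only: a boolean mask (True = masked out) and an additive bias with a
# large negative value at the same positions produce the same softmax output.
import torch

scores = torch.randn(4, 4)
mask = torch.tensor([[False, True, False, True],
                     [True, False, False, False],
                     [False, False, True, False],
                     [True, True, False, False]])
masked_scores = scores.masked_fill(mask, float("-inf"))
bias = torch.where(mask, torch.tensor(-1e9), torch.tensor(0.0))
biased_scores = scores + bias
print(torch.allclose(torch.softmax(masked_scores, dim=-1),
                     torch.softmax(biased_scores, dim=-1), atol=1e-6))  # True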
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os, sys, time
import subprocess
import pandas as pd
import numpy as np
import torch
import nvtx
import transformer_engine
from tests.pytorch.fused_attn.test_fused_attn import (
    ModelConfig,
    _is_flash_attention_supported,
    _is_fused_attention_supported,
    _is_unfused_attention_supported,
    _run_dot_product_attention,
)
# data type
dtype = torch.bfloat16
# number of iterations after 3 warmup iterations
num_iters = 3
# checkpointing
ckpt_attn = False
# workspace optimization path for cuDNN attention
workspace_opt = True
# QKV memory layout
qkv_layout = 'bshd_bshd_bshd'
# sliding window attention
swa = False
# padding between sequences for qkv_format=thd
pad_between_seqs = False
# training mode
is_training = True
model_configs = {
    # test: b, h, hg, d, sq, skv, p, mask, bias
    "test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),  # short seq
    "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"),  # longer seq, mask
    "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),  # bias
    "test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"),  # GQA
}


def example_attention(model, fused_attn_supported, flash_attn_supported):
    """Run cuDNN attention and flash-attention on one config and compare results."""
    config = model_configs[model]
    if dtype == torch.bfloat16:
        tols = dict(atol=2.5e-2, rtol=2.5e-2)
    else:
        tols = dict(atol=5e-3, rtol=5e-3)

    if fused_attn_supported:
        print()
        print('Run cuDNN attention...')
        fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
            dtype, config, "FusedAttention",
            ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
        )

    if flash_attn_supported:
        print()
        print('Run flash-attention...')
        flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
            dtype, config, "FlashAttention",
            ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
        )

    if fused_attn_supported and flash_attn_supported:
        torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
        for i, _ in enumerate(flash_attn_bwd):
            torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
        print()
        print('Test passed.')


def main():
    models = ['test_0']
    for model in models:
        config = model_configs[model]
        fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
            config, dtype, qkv_layout=qkv_layout,
        )
        fused_attn_supported = fused_attn_supported and not swa
        flash_attn_supported = _is_flash_attention_supported(config)
        example_attention(model, fused_attn_supported, flash_attn_supported)


if __name__ == "__main__":
    main()
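
A minimal sketch of how the same helper could sweep every predefined config instead of only 'test_0'. This is hypothetical and not part of the PR; it only reuses names defined in the script above.

# Hypothetical extension (not part of the PR): run every predefined config,
# skipping whichever backend is unsupported for a given config on this GPU.
def run_all_examples():
    for model, config in model_configs.items():
        fused_ok, _ = _is_fused_attention_supported(config, dtype, qkv_layout=qkv_layout)
        fused_ok = fused_ok and not swa
        flash_ok = _is_flash_attention_supported(config)
        print(f"\n=== {model} ===")
        example_attention(model, fused_ok, flash_ok)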
@@ -4,8 +4,10 @@
"""Utility functions for Transformer Engine modules"""
import math
import functools
from typing import Any, Callable, Optional, Tuple
import torch
import transformer_engine.pytorch.cpp_extensions as ext
def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
@@ -241,8 +243,19 @@ def assert_dim_for_fp8_exec(tensor: torch.Tensor) -> None:
f"but got tensor with dims={list(tensor.size())}"
)
def is_bf16_compatible() -> None:
"""Replaces torch.cuda.is_bf16_compatible() with an explicit
check on device compute capability to enforce sm_80 or higher.
"""
return torch.cuda.get_device_capability()[0] >= 8


@functools.cache
def get_cudnn_version() -> Tuple[int, int, int]:
    """Runtime cuDNN version (major, minor, patch)"""
    encoded_version = ext.get_cudnn_version()
    major_version_magnitude = 1000 if encoded_version < 90000 else 10000
    major, encoded_version = divmod(encoded_version, major_version_magnitude)
    minor, patch = divmod(encoded_version, 100)
    return (major, minor, patch)
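
A quick worked example of the decoding above, for illustration and not part of the diff: cuDNN 8.x encodes 8.9.7 as 8907, while cuDNN 9.x encodes 9.1.0 as 90100, which is why the major-version magnitude switches at 90000.

# Worked example of the decoding (illustration only; mirrors the function above):
#   8907  -> divmod(8907, 1000)   = (8, 907) -> (8, 9, 7)
#   90100 -> divmod(90100, 10000) = (9, 100) -> (9, 1, 0)
def _decode(encoded_version):
    major_version_magnitude = 1000 if encoded_version < 90000 else 10000
    major, encoded_version = divmod(encoded_version, major_version_magnitude)
    minor, patch = divmod(encoded_version, 100)
    return (major, minor, patch)

assert _decode(8907) == (8, 9, 7)
assert _decode(90100) == (9, 1, 0)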