gaoqiong / flash-attention · Commits

Commit 013f0c4f (unverified)
CMake build, allowing parent build (#19)

Authored Sep 20, 2024 by Luka Govedič; committed via GitHub on Sep 20, 2024
Parent: 344c988d
Pipeline #2020: failed in 0 seconds
Showing 17 changed files with 278 additions and 6410 deletions (+278 −6410):

  tests/models/test_llama.py                  +0    −633
  tests/models/test_opt.py                    +0    −237
  tests/models/test_vit.py                    +0    −48
  tests/modules/test_block_parallel.py        +0    −273
  tests/modules/test_embedding_parallel.py    +0    −106
  tests/modules/test_mha_parallel.py          +0    −160
  tests/modules/test_mlp_parallel.py          +0    −143
  tests/ops/test_dropout_layer_norm.py        +0    −1189
  tests/ops/test_fused_dense.py               +0    −172
  tests/ops/test_fused_dense_parallel.py      +0    −237
  tests/ops/triton/test_layer_norm.py         +0    −368
  tests/pyproject.toml                        +0    −3
  tests/test_flash_attn.py                    +0    −2543
  tests/test_rotary.py                        +0    −291
  tests/test_vllm_flash_attn.py               +269  −0
  vllm_flash_attn/__init__.py                 +2    −1
  vllm_flash_attn/flash_attn_interface.py     +7    −6
tests/models/test_llama.py — deleted (mode 100644 → 0)

# Copyright (c) 2023, Tri Dao.

# To run the huggingface implementation of LLaMa (1), we first need to convert the weights:
# https://github.com/huggingface/transformers/pull/21955
# python -m transformers.models.llama.convert_llama_weights_to_hf --input_dir $CHECKPOINT_DIR/llama --model_size 7B --output_dir $CHECKPOINT_DIR/llama/7B-hf
# and repeat for 13B, 30B, 65B

import os
import time
from pathlib import Path

current_dir = Path(__file__).parent.absolute()

import shutil

import pytest
import torch
from einops import rearrange
from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp
from flash_attn.models.llama import (
    config_from_checkpoint,
    inv_remap_state_dict_hf_llama,
    llama_config_to_gpt2_config,
    remap_state_dict_hf_llama,
    remap_state_dict_meta_llama,
    state_dicts_from_checkpoint,
)
from flash_attn.utils.distributed import all_gather_raw
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import LlamaConfig, LlamaTokenizer
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers import AutoConfig


def _pretrained_state_dict_from_checkpoint(checkpoint_path, model_name, config, checkpoint_format):
    if checkpoint_format == "meta":
        ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
        pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts]
        pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
    else:
        pretrained_state_dict = state_dict_from_pretrained(Path(checkpoint_path) / f"{model_name}-hf")
        pretrained_state_dict = remap_state_dict_hf_llama(pretrained_state_dict, config)
    return pretrained_state_dict


@pytest.mark.parametrize("model_name", ["7B"])
def test_llama_state_dict(model_name):
    checkpoint_path = (
        Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
    )
    config = llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name))
    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
    pretrained_state_dict = remap_state_dict_meta_llama(ckpt_state_dicts[0], config)
    model = GPTLMHeadModel(config, device="meta")  # Without device='meta' init is very slow
    state_dict = model.state_dict()
    assert state_dict.keys() == pretrained_state_dict.keys()
    for k in state_dict.keys():
        assert state_dict[k].shape == pretrained_state_dict[k].shape


# TinyLlama-1.1B is to test MQA
@pytest.mark.parametrize(
    "model_name", ["meta-llama/Llama-2-7b-hf", "PY007/TinyLlama-1.1B-step-50K-105b"]
)
def test_inv_remap_state_dict_hf_llama(model_name):
    config = llama_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    state_dict = state_dict_from_pretrained(model_name)
    # inv_remap_state_dict_hf_llama should be the inverse of remap_state_dict_hf_llama
    state_dict = {key: val for key, val in state_dict.items() if "rotary_emb.inv_freq" not in key}
    pretrained_state_dict = remap_state_dict_hf_llama(state_dict, config)
    state_dict_recover = inv_remap_state_dict_hf_llama(pretrained_state_dict, config)
    assert set(state_dict_recover.keys()) == set(state_dict.keys())
    for key in state_dict_recover.keys():
        torch.testing.assert_close(state_dict_recover[key], state_dict[key])


# TinyLlama-1.1B is to test MQA
@pytest.mark.parametrize(
    "model_name",
    [
        "7B",  # Llama 1
        "13B",  # Llama 1
        "meta-llama/Llama-2-13b-hf",
        "codellama/CodeLlama-7b-hf",
        "codellama/CodeLlama-13b-hf",
        "codellama/CodeLlama-34b-hf",
        "PY007/TinyLlama-1.1B-step-50K-105b",
    ],
)
def test_llama_optimized(model_name):
    """Check that our implementation of LLaMa (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    checkpoint_path = (
        Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
    )
    dtype = torch.float16
    device = "cuda"
    if "/" in model_name:  # Download from HF
        config = llama_config_to_gpt2_config(
            AutoConfig.from_pretrained(model_name, trust_remote_code=True)
        )
    else:
        config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta")
        config = llama_config_to_gpt2_config(config)
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    if "/" in model_name:  # Download from HF
        pretrained_state_dict = remap_state_dict_hf_llama(state_dict_from_pretrained(model_name), config)
    else:
        pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
            checkpoint_path, model_name, config, checkpoint_format="meta"
        )
    model = GPTLMHeadModel(config, device=device, dtype=dtype)
    model.load_state_dict(pretrained_state_dict)
    model.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device)
    with torch.no_grad():
        out = model.transformer(input_ids)
        logits = model(input_ids).logits
    del model

    # Without device_map, the model is loaded on the CPU, which is very slow
    # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
    model_ref = LlamaForCausalLM.from_pretrained(
        model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
        device_map="auto",
    )
    model_ref.eval()
    with torch.no_grad():
        out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
        logits_ref = model_ref(input_ids).logits.to(device=device)
    del model_ref

    model_hf = LlamaForCausalLM.from_pretrained(
        model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
        torch_dtype=dtype,
        device_map={"": device},
    )
    model_hf.eval()
    with torch.no_grad():
        out_hf = model_hf.model(input_ids).last_hidden_state
        logits_hf = model_hf(input_ids).logits
    del model_hf

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()

    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
    assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()


# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "parallel"
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize(
    "model_name", ["13B", "meta-llama/Llama-2-13b-hf", "codellama/CodeLlama-34b-hf"]
)
def test_llama_parallel(model_name, world_size):
    """Check that our implementation of LLaMa (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    from apex.transformer import parallel_state

    checkpoint_path = (
        Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
    )
    dtype = torch.float16
    if "/" in model_name:  # Download from HF
        config = llama_config_to_gpt2_config(
            AutoConfig.from_pretrained(model_name, trust_remote_code=True)
        )
    else:
        config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta")
        config = llama_config_to_gpt2_config(config)
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

    if "/" in model_name:  # Download from HF
        pretrained_state_dict = remap_state_dict_hf_llama(state_dict_from_pretrained(model_name), config)
    else:
        pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
            checkpoint_path, model_name, config, checkpoint_format="meta"
        )
    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
    model.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device)
    with torch.no_grad():
        out = model.transformer(input_ids)
        out, _ = all_gather_raw(out, process_group=process_group)
        out = rearrange(out, "(b s) d -> b s d", b=batch_size)
        logits = model(input_ids).logits
        logits = rearrange(logits, "(b s) d -> b s d", b=batch_size)
        logits, _ = all_gather_raw(logits, process_group)
        logits = rearrange(logits, "(n b) ... d -> b ... (n d)", b=batch_size)
    del model

    if rank == 0:
        # Without device_map, the model is loaded on the CPU, which is very slow
        model_ref = LlamaForCausalLM.from_pretrained(
            model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
            device_map="auto",
        )
        model_ref.eval()
        with torch.no_grad():
            out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
            logits_ref = model_ref(input_ids).logits.to(device=device)
        del model_ref

        model_hf = LlamaForCausalLM.from_pretrained(
            model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
            torch_dtype=dtype,
            device_map="auto",
        )
        model_hf.eval()
        with torch.no_grad():
            out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device)
            logits_hf = model_hf(input_ids).logits.to(device=device)
        del model_hf

        print(f"Output max diff: {(out - out_ref).abs().max().item()}")
        print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
        print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
        print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
        assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()

        print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
        print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
        print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
        print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
        assert (logits - logits_ref).abs().max().item() < 2 * (logits_hf - logits_ref).abs().max().item()


# @pytest.mark.parametrize('model_name', ["7B", "13B"])
@pytest.mark.parametrize("model_name", ["7B"])
@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"])
def test_llama_generation(model_name, checkpoint_format):
    checkpoint_path = (
        Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
    )
    dtype = torch.float16
    device = "cuda"
    config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
    config = llama_config_to_gpt2_config(config)
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    tokenizer = LlamaTokenizer.from_pretrained(Path(checkpoint_path) / f"{model_name}-hf")
    eos_token_id = tokenizer.eos_token_id

    torch.manual_seed(0)
    batch_size = 1
    seqlen = 100
    max_length = 150
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device)

    model_hf = LlamaForCausalLM.from_pretrained(
        Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map={"": device}
    )
    model_hf.eval()
    print("HF fp16")
    torch.cuda.synchronize()
    start = time.time()
    out_hf = model_hf.generate(
        input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    del model_hf

    # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
    model_ref = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f"{model_name}-hf", device_map="auto")
    model_ref.eval()
    with torch.no_grad():
        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device)
    del model_ref

    pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
        checkpoint_path, model_name, config, checkpoint_format
    )
    model = GPTLMHeadModel(config, device=device, dtype=dtype)
    model.load_state_dict(pretrained_state_dict)
    model.eval()

    print("Without CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        eos_token_id=eos_token_id,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")

    # Capture graph outside the timing loop
    batch_size, seqlen_og = input_ids.shape
    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
    print("With CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
    out_cg = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")

    with torch.no_grad():
        logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]
    logits_hf = torch.stack(out_hf.scores, dim=1)
    logits = torch.stack(out.scores, dim=1)
    logits_cg = torch.stack(out_cg.scores, dim=1)
    del model

    hf_error = (logits_hf - logits_ref).abs().max().item()
    print(f"HF fp16 logits max diff: {hf_error}")
    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}")
    assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
    assert (logits - logits_ref).abs().max().item() < 2 * hf_error
    assert torch.equal(logits_cg, logits)


# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "llama_parallel_generation"
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize(
    "model_name", ["13B", "meta-llama/Llama-2-13b-hf", "codellama/CodeLlama-34b-hf"]
)
def test_llama_parallel_generation(model_name, world_size):
    """Check that our implementation matches the HF implementation:
    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
    the HF scores in fp32.
    """
    from apex.transformer import parallel_state

    checkpoint_path = (
        Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
    )
    dtype = torch.float16
    if "/" in model_name:  # Download from HF
        config = llama_config_to_gpt2_config(
            AutoConfig.from_pretrained(model_name, trust_remote_code=True)
        )
    else:
        config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta")
        config = llama_config_to_gpt2_config(config)
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True
    config.pad_vocab_size_multiple = 8 * world_size
    config.sequence_parallel = False  # Need to set this to False for generation

    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

    torch.manual_seed(0)
    batch_size = 1
    seqlen = 100
    max_length = 150
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device)

    # Need this, otherwise when we capture the graph the process for GPU 1 would run on both
    # GPU0 and GPU1 and things would hang
    torch.cuda.set_device(device)

    if "/" in model_name:  # Download from HF
        pretrained_state_dict = remap_state_dict_hf_llama(state_dict_from_pretrained(model_name), config)
    else:
        pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
            checkpoint_path, model_name, config, checkpoint_format="meta"
        )
    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
    model.eval()

    print("Without CUDA graph")
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        tensor_parallel=world_size,
        vocab_size=config.vocab_size,
        # teacher_outputs=out_hf.sequences,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
    )
    # Capture graph outside the timing loop
    batch_size, seqlen_og = input_ids.shape
    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
    print("With CUDA graph")
    out_cg = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        tensor_parallel=world_size,
        vocab_size=config.vocab_size,
        cg=True,
        # teacher_outputs=out_hf.sequences,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
    )
    del model
    parallel_state.destroy_model_parallel()

    if rank == 0:
        # Without device_map, the model is loaded on the CPU, which is very slow
        model_hf = LlamaForCausalLM.from_pretrained(
            model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
            torch_dtype=dtype,
            device_map="auto",
        )
        model_hf.eval()
        print("HF fp16")
        torch.cuda.synchronize()
        start = time.time()
        with torch.inference_mode():
            out_hf = model_hf.generate(
                input_ids=input_ids,
                max_length=max_length,
                return_dict_in_generate=True,
                output_scores=True,
            )
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
        del model_hf

        model_ref = LlamaForCausalLM.from_pretrained(
            model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
            device_map="auto",
        )
        model_ref.eval()
        with torch.inference_mode():
            logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
        del model_ref
        logits_hf = torch.stack(out_hf.scores, dim=1)

        logits = torch.stack(out.scores, dim=1)
        logits_cg = torch.stack(out_cg.scores, dim=1)
        hf_error = (logits_hf - logits_ref).abs().max().item()
        print(f"HF fp16 logits max diff: {hf_error}")
        print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
        assert (logits - logits_ref).abs().max().item() < 2 * hf_error
        print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}")
        assert torch.equal(logits_cg, logits)


@torch.no_grad()
@pytest.mark.parametrize("world_size", [2])
def test_llama_parallel_uneven_num_heads(world_size):
    from apex.transformer import parallel_state

    checkpoint_path = (
        Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
    )
    num_attention_heads = world_size + 1
    model_name = f"teeny-{num_attention_heads}-heads"

    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

    dtype = torch.float16
    llama_config = LlamaConfig(
        hidden_size=256 * num_attention_heads,
        # ParallelGatedMlp hidden_features must be divisible by 256
        intermediate_size=256 * num_attention_heads * 4,
        num_hidden_layers=4,
        num_attention_heads=num_attention_heads,
        initializer_range=0.5,
        # Set crazy init range so we don't have near zero weights implying a vacuous test.
    )
    config = llama_config_to_gpt2_config(llama_config)
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device)

    # Create a shared test model.
    if rank == 0:
        LlamaForCausalLM(config=llama_config).save_pretrained(checkpoint_path / f"{model_name}-hf")
    torch.distributed.barrier()

    # Run the standard forward pass test.
    pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
        checkpoint_path, model_name, config, checkpoint_format="hf"
    )
    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
    model.eval()

    # TODO: Avoid duplicate code. Modularize the comparison of two forward pass diffs.
    out = model.transformer(input_ids)
    out, _ = all_gather_raw(out, process_group=process_group)
    out = rearrange(out, "(b s) d -> b s d", b=batch_size)
    logits = model(input_ids).logits
    logits = rearrange(logits, "(b s) d -> b s d", b=batch_size)
    logits, _ = all_gather_raw(logits, process_group)
    logits = rearrange(logits, "(n b) ... d -> b ... (n d)", b=batch_size)

    if rank == 0:
        model_ref = LlamaForCausalLM.from_pretrained(
            Path(checkpoint_path) / f"{model_name}-hf", device_map={"": device}
        )
        model_ref = model_ref.to(device=device)
        model_ref.eval()
        out_ref = model_ref.model(input_ids).last_hidden_state
        logits_ref = model_ref(input_ids).logits
        del model_ref

        model_hf = LlamaForCausalLM.from_pretrained(
            Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map={"": device}
        )
        model_hf.eval()
        out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device)
        logits_hf = model_hf(input_ids).logits.to(device=device)
        del model_hf

        print(f"Output max diff: {(out - out_ref).abs().max().item()}")
        print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
        print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
        print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
        assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()

        print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
        print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
        print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
        print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
        assert (logits - logits_ref).abs().max().item() < 2 * (logits_hf - logits_ref).abs().max().item()

        if os.path.exists(checkpoint_path / f"{model_name}-hf"):
            shutil.rmtree(checkpoint_path / f"{model_name}-hf")
tests/models/test_opt.py — deleted (mode 100644 → 0)

import re
import time

import pytest
import torch
from einops import rearrange
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.opt import opt_config_to_gpt2_config, remap_state_dict_hf_opt
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import AutoTokenizer, OPTConfig
from transformers.models.opt.modeling_opt import OPTForCausalLM


@pytest.mark.parametrize("model_name", ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"])
# @pytest.mark.parametrize('model_name', ["facebook/opt-350m"])
def test_opt_state_dict(model_name):
    config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
    pretrained_state_dict = remap_state_dict_hf_opt(state_dict_from_pretrained(model_name), config)
    model = GPTLMHeadModel(config)
    state_dict = model.state_dict()
    assert state_dict.keys() == pretrained_state_dict.keys()
    for k in state_dict.keys():
        assert state_dict[k].shape == pretrained_state_dict[k].shape


@pytest.mark.parametrize("model_name", ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"])
# @pytest.mark.parametrize('model_name', ["facebook/opt-350m"])
def test_opt_optimized(model_name):
    """Check that our implementation of OPT (without all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    device = "cuda"
    config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True
    # Only prenorm supports residual_in_fp32
    config.residual_in_fp32 = getattr(config, "prenorm", True)
    config.pad_vocab_size_multiple = 8

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model_ref = OPTForCausalLM.from_pretrained(model_name).to(device=device)
    model_hf = OPTForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device=device)

    model.eval()
    model_ref.eval()
    model_hf.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda")
    if model_name != "facebook/opt-350m":  # The OPT-350m projects the embeddings to dimension 512
        out = model.transformer(input_ids)
        out_hf = model_hf.model(input_ids).last_hidden_state
        out_ref = model_ref.model(input_ids).last_hidden_state

        print(f"Output max diff: {(out - out_ref).abs().max().item()}")
        print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
        print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
        print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
        assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()

    logits = model(input_ids).logits
    logits_hf = model_hf(input_ids).logits
    logits_ref = model_ref(input_ids).logits

    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
    assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()


@pytest.mark.parametrize(
    "model_name",
    [
        "facebook/opt-125m",
        "facebook/opt-350m",
        "facebook/opt-1.3b",
        "facebook/opt-2.7b",
        "facebook/opt-6.7b",
    ],
)
# @pytest.mark.parametrize('model_name', ["facebook/opt-125m"])
def test_opt_generation(model_name):
    """Check that our implementation of OPT generation matches the HF implementation:
    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
    the HF scores in fp32.
    """
    print(f"\nMODEL: {model_name}")
    verbose = False
    dtype = torch.float16
    device = "cuda"
    rtol, atol = 3e-3, 3e-1
    config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
    # Only prenorm supports residual_in_fp32
    config.residual_in_fp32 = getattr(config, "prenorm", True)
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()

    torch.manual_seed(0)
    # OPT tokenizer requires use_fast=False
    # https://huggingface.co/docs/transformers/model_doc/opt
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    eos_token_id = tokenizer.eos_token_id

    input_ids = tokenizer("Hello, my dog is cute and he", return_tensors="pt").input_ids.to(device=device)
    max_length = 25
    # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
    # max_length = input_ids.shape[1] + 40

    # Slow generation for reference
    sequences = []
    scores = []
    cur_input_ids = input_ids
    with torch.inference_mode():
        scores.append(model(cur_input_ids).logits[:, -1])
        sequences.append(scores[-1].argmax(dim=-1))
        for _ in range(input_ids.shape[1] + 1, max_length):
            cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], "b -> b 1")], dim=-1)
            scores.append(model(cur_input_ids).logits[:, -1])
            sequences.append(scores[-1].argmax(dim=-1))
            if eos_token_id is not None and (sequences[-1] == eos_token_id).all():
                break
    sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
    scores = tuple(scores)

    print("Without CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        eos_token_id=eos_token_id,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    if verbose:
        print(out.sequences)
        print(tokenizer.batch_decode(out.sequences.tolist()))
    if getattr(config, "use_flash_attn", False):
        # Capture graph outside the timing loop
        batch_size, seqlen_og = input_ids.shape
        model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
        print("With CUDA graph")
        torch.cuda.synchronize()
        start = time.time()
        out_cg = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            cg=True,
            return_dict_in_generate=True,
            output_scores=True,
            enable_timing=True,
        )
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
        if verbose:
            print(out_cg.sequences)
            print(tokenizer.batch_decode(out_cg.sequences.tolist()))

    del model

    model_hf = OPTForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device=device)
    model_hf.eval()
    print("HF fp16")
    torch.cuda.synchronize()
    start = time.time()
    out_hf = model_hf.generate(
        input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    del model_hf

    model_ref = OPTForCausalLM.from_pretrained(model_name).to(device=device)
    model_ref.eval()
    print("HF fp32")
    torch.cuda.synchronize()
    start = time.time()
    out_ref = model_ref.generate(
        input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    del model_ref
    print(tokenizer.batch_decode(out_ref.sequences.tolist()))

    if verbose:
        print(f"Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}")
        print(f"Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}")
        print(f"HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}")
        print(f"HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}")

    assert torch.all(out.sequences == sequences)
    assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol)
    assert torch.all(out.sequences == out_ref.sequences)
    assert torch.all(out.sequences == out_hf.sequences)
    assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (
        torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)
    ).abs().max().item()
tests/models/test_vit.py — deleted (mode 100644 → 0)

import re

import pytest
import torch
from flash_attn.models.vit import vit_base_patch16_224 as flash_vit_base_patch16_224
from timm.models.vision_transformer import vit_base_patch16_224


@pytest.mark.parametrize("fused_mlp", [False, True])
# @pytest.mark.parametrize('fused_mlp', [False])
@pytest.mark.parametrize("optimized", [False, True])
# @pytest.mark.parametrize('optimized', [True])
def test_vit(optimized, fused_mlp):
    """Check that our implementation of ViT matches the timm's implementation:
    the output of our forward pass in fp16 should be around the same as
    timm' forward pass in fp16, when compared to timm's forward pass in fp32.
    """
    dtype = torch.float16
    device = "cuda"

    kwargs = {}
    if optimized:
        kwargs = dict(use_flash_attn=True, fused_bias_fc=True, fused_dropout_add_ln=True)
    kwargs["fused_mlp"] = fused_mlp
    model = flash_vit_base_patch16_224(**kwargs).to(device=device, dtype=dtype)

    model_ref = vit_base_patch16_224(pretrained=True).to(device=device)
    model_timm = vit_base_patch16_224(pretrained=True).to(device=device, dtype=dtype)

    model.load_state_dict(model_ref.state_dict())

    model.eval()
    model_ref.eval()
    model_timm.eval()

    torch.manual_seed(0)
    batch_size = 2
    x = torch.randn(batch_size, 3, 224, 224, device=device, dtype=dtype)
    out = model(x)
    out_timm = model_timm(x)
    out_ref = model_ref(x.float())

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"timm fp16 max diff: {(out_timm - out_ref).abs().max().item()}")
    print(f"timm fp16 mean diff: {(out_timm - out_ref).abs().mean().item()}")

    rtol = 2 if not fused_mlp else 8
    assert (out - out_ref).abs().max().item() < rtol * (out_timm - out_ref).abs().max().item()
tests/modules/test_block_parallel.py — deleted (mode 100644 → 0)

# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_block_parallel.py
import math
from functools import partial

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from apex.transformer import parallel_state, tensor_parallel
from einops import rearrange
from flash_attn.modules.block import Block
from flash_attn.modules.mha import MHA, ParallelMHA
from flash_attn.modules.mlp import FusedMLP, ParallelFusedMLP
from flash_attn.utils.distributed import allreduce_sequence_parallel_grad

is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [True])
@pytest.mark.parametrize("dim", [1024])
def test_block_parallel(dim, sequence_parallel, world_size, dtype):
    head_dim = 64
    assert dim % head_dim == 0
    num_heads = dim // head_dim
    assert num_heads % world_size == 0
    rtol, atol = (3e-3, 5e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    seqlen = 1024
    assert (batch_size * seqlen) % world_size == 0
    x_pt = torch.randn(batch_size * seqlen, dim, device=device, dtype=dtype, requires_grad=True)
    residual_pt = torch.randn(batch_size * seqlen, dim, device=device, requires_grad=True)
    # We need to generate g here so that all processes get the same gradient,
    # as rank 0 will have an extra bias that changes the RNG.
    # If we don't divide by batch_size, the gradient gets a bit too large.
    g = torch.randn_like(x_pt) / 32
    if sequence_parallel:
        x = (
            tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
            .detach()
            .clone()
            .requires_grad_()
        )
        residual = (
            tensor_parallel.scatter_to_sequence_parallel_region(residual_pt)
            .detach()
            .clone()
            .requires_grad_()
        )
    else:
        x = x_pt.detach().clone().requires_grad_()
        residual = residual_pt.detach().clone().requires_grad_()

    mixer_cls_pt = partial(
        MHA,
        num_heads=num_heads,
        rotary_emb_dim=int(head_dim // 2),
        use_flash_attn=True,
        device=device,
        dtype=dtype,
    )
    mlp_cls_pt = partial(FusedMLP, hidden_features=4 * dim, device=device, dtype=dtype)
    norm_cls = partial(nn.LayerNorm, device=device, dtype=dtype)
    model_pt = Block(dim, mixer_cls_pt, mlp_cls_pt, norm_cls, fused_dropout_add_ln=True)
    with torch.no_grad():
        nn.init.normal_(model_pt.norm1.weight)
        nn.init.normal_(model_pt.norm1.bias)
        nn.init.normal_(model_pt.norm2.weight)
        nn.init.normal_(model_pt.norm2.bias)

    mixer_cls = partial(
        ParallelMHA,
        num_heads=num_heads,
        process_group=parallel_state.get_tensor_model_parallel_group(),
        rotary_emb_dim=int(head_dim // 2),
        use_flash_attn=True,
        sequence_parallel=sequence_parallel,
        device=device,
        dtype=dtype,
    )
    mlp_cls = partial(
        ParallelFusedMLP,
        hidden_features=4 * dim,
        process_group=parallel_state.get_tensor_model_parallel_group(),
        sequence_parallel=sequence_parallel,
        device=device,
        dtype=dtype,
    )
    model = Block(
        dim,
        mixer_cls,
        mlp_cls,
        norm_cls,
        fused_dropout_add_ln=True,
        sequence_parallel=sequence_parallel,
        mark_shared_params=True,
    )

    partition_dim = dim // world_size
    partition_hidden_dim = 4 * dim // world_size
    with torch.no_grad():
        model.mixer.Wqkv.weight.copy_(
            rearrange(
                rearrange(model_pt.mixer.Wqkv.weight, "(three o) i -> three o i", three=3)[
                    :, rank * partition_dim : (rank + 1) * partition_dim
                ],
                "three o i -> (three o) i",
            )
        )
        model.mixer.Wqkv.bias.copy_(
            rearrange(
                rearrange(model_pt.mixer.Wqkv.bias, "(three o) -> three o", three=3)[
                    :, rank * partition_dim : (rank + 1) * partition_dim
                ],
                "three o -> (three o)",
            )
        )
        model.mixer.out_proj.weight.copy_(
            model_pt.mixer.out_proj.weight[:, rank * partition_dim : (rank + 1) * partition_dim]
        )
        if rank == 0:
            model.mixer.out_proj.bias.copy_(model_pt.mixer.out_proj.bias)
        model.mlp.fc1.weight.copy_(
            model_pt.mlp.fc1.weight[rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim]
        )
        model.mlp.fc1.bias.copy_(
            model_pt.mlp.fc1.bias[rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim]
        )
        model.mlp.fc2.weight.copy_(
            model_pt.mlp.fc2.weight[:, rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim]
        )
        if rank == 0:
            model.mlp.fc2.bias.copy_(model_pt.mlp.fc2.bias)
        model.norm1.weight.copy_(model_pt.norm1.weight)
        model.norm1.bias.copy_(model_pt.norm1.bias)
        model.norm2.weight.copy_(model_pt.norm2.weight)
        model.norm2.bias.copy_(model_pt.norm2.bias)

    mixer_kwargs = {"seqlen": seqlen}
    out, out_residual = model(x, residual, mixer_kwargs=mixer_kwargs)
    out_pt, out_residual_pt = model_pt(
        rearrange(x_pt, "(b s) d -> b s d", s=seqlen),
        rearrange(residual_pt, "(b s) d -> b s d", s=seqlen),
    )
    out_pt, out_residual_pt = [rearrange(x, "b s d -> (b s) d") for x in [out_pt, out_residual_pt]]
    partition_batch_dim = batch_size * seqlen // world_size
    assert torch.allclose(
        out,
        out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else out_pt,
        rtol=rtol,
        atol=atol,
    )
    assert torch.allclose(
        out_residual,
        out_residual_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else out_residual_pt,
        rtol=rtol,
        atol=atol,
    )

    (out_pt + 2 * out_residual_pt).backward(g)
    (out + 2 * out_residual).backward(
        g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g
    )
    allreduce_sequence_parallel_grad(model, parallel_state.get_tensor_model_parallel_group())
    parallel_state.destroy_model_parallel()

    assert torch.allclose(
        x.grad,
        x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else x_pt.grad,
        rtol=rtol,
        atol=atol / 10,  # magnitude of x.grad is quite small
    )
    assert torch.allclose(
        residual.grad,
        residual_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else residual_pt.grad,
        rtol=rtol,
        atol=atol,
    )
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(
        model.mixer.Wqkv.weight.grad,
        rearrange(
            rearrange(model_pt.mixer.Wqkv.weight.grad, "(three o) i -> three o i", three=3)[
                :, rank * partition_dim : (rank + 1) * partition_dim
            ],
            "three o i -> (three o) i",
        ),
        rtol=rtol,
        atol=atol * 10,
    )
    assert torch.allclose(
        model.mixer.Wqkv.bias.grad,
        rearrange(
            rearrange(model_pt.mixer.Wqkv.bias.grad, "(three o) -> three o", three=3)[
                :, rank * partition_dim : (rank + 1) * partition_dim
            ],
            "three o -> (three o)",
        ),
        rtol=rtol,
        atol=atol * 5,
    )
    assert torch.allclose(
        model.mixer.out_proj.weight.grad,
        model_pt.mixer.out_proj.weight.grad[:, rank * partition_dim : (rank + 1) * partition_dim],
        rtol=rtol,
        atol=atol * 10,
    )
    if rank == 0:
        assert torch.allclose(
            model.mixer.out_proj.bias.grad,
            model_pt.mixer.out_proj.bias.grad,
            rtol=rtol,
            atol=atol * 5,
        )
    assert torch.allclose(
        model.mlp.fc1.weight.grad,
        model_pt.mlp.fc1.weight.grad[rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim],
        rtol=rtol,
        atol=atol * 10,
    )
    assert torch.allclose(
        model.mlp.fc1.bias.grad,
        model_pt.mlp.fc1.bias.grad[rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim],
        rtol=rtol,
        atol=atol * 5,
    )
    assert torch.allclose(
        model.mlp.fc2.weight.grad,
        model_pt.mlp.fc2.weight.grad[:, rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim],
        rtol=rtol,
        atol=atol * 10,
    )
    if rank == 0:
        assert torch.allclose(model.mlp.fc2.bias.grad, model_pt.mlp.fc2.bias.grad, rtol=rtol, atol=atol * 5)
    assert torch.allclose(model.norm1.weight.grad, model_pt.norm1.weight.grad, rtol=rtol, atol=atol * 5)
    assert torch.allclose(model.norm1.bias.grad, model_pt.norm1.bias.grad, rtol=rtol, atol=atol * 5)
    assert torch.allclose(model.norm2.weight.grad, model_pt.norm2.weight.grad, rtol=rtol, atol=atol * 5)
    assert torch.allclose(model.norm2.bias.grad, model_pt.norm2.bias.grad, rtol=rtol, atol=atol * 5)
tests/modules/test_embedding_parallel.py — deleted (mode 100644 → 0)

# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_embedding_parallel.py
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from apex.transformer import parallel_state
from einops import rearrange
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings

is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize("has_pos_emb", [True, False])
# @pytest.mark.parametrize('has_pos_emb', [True])
@pytest.mark.parametrize("dim", [1024])
def test_embedding_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
    vocab_size = 50264
    seqlen = 2048
    assert vocab_size % world_size == 0
    assert dim % world_size == 0
    rtol, atol = (3e-3, 5e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 1024
    assert (batch_size * seqlen) % world_size == 0
    input_ids_pt = torch.randint(0, vocab_size, (batch_size, seqlen), device=device)
    input_ids = input_ids_pt.detach().clone()

    model_pt = GPT2Embeddings(dim, vocab_size, seqlen if has_pos_emb else 0, device=device, dtype=dtype)
    model = ParallelGPT2Embeddings(
        dim,
        vocab_size,
        seqlen if has_pos_emb else 0,
        parallel_state.get_tensor_model_parallel_group(),
        sequence_parallel=sequence_parallel,
        device=device,
        dtype=dtype,
    )
    partition_vocab_size = vocab_size // world_size
    partition_dim = dim // world_size
    with torch.no_grad():
        model.word_embeddings.weight.copy_(
            model_pt.word_embeddings.weight[
                rank * partition_vocab_size : (rank + 1) * partition_vocab_size
            ]
        )
        if has_pos_emb:
            model.position_embeddings.weight.copy_(
                model_pt.position_embeddings.weight[:, rank * partition_dim : (rank + 1) * partition_dim]
            )

    out = model(input_ids, combine_batch_seqlen_dim=True)
    out_pt = rearrange(model_pt(input_ids), "b s d -> (b s) d")
    partition_batch_dim = batch_size * seqlen // world_size
    assert torch.allclose(
        out,
        out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else out_pt,
        rtol=rtol,
        atol=atol,
    )

    g = torch.randn_like(out_pt)
    out_pt.backward(g)
    out.backward(
        g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g
    )
    parallel_state.destroy_model_parallel()

    assert torch.allclose(
        model.word_embeddings.weight.grad,
        model_pt.word_embeddings.weight.grad[
            rank * partition_vocab_size : (rank + 1) * partition_vocab_size
        ],
        rtol=rtol,
        atol=atol,
    )
    if has_pos_emb:
        assert torch.allclose(
            model.position_embeddings.weight.grad,
            model_pt.position_embeddings.weight.grad[:, rank * partition_dim : (rank + 1) * partition_dim],
            rtol=rtol,
            atol=atol,
        )
tests/modules/test_mha_parallel.py
deleted
100644 → 0
View file @
344c988d
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_mha_parallel.py
import
math
import
pytest
import
torch
import
torch.nn.functional
as
F
from
apex.transformer
import
parallel_state
,
tensor_parallel
from
einops
import
rearrange
from
flash_attn.modules.mha
import
MHA
,
ParallelMHA
is_sm8x
=
torch
.
cuda
.
get_device_capability
(
"cuda"
)[
0
]
>=
8
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
]
+
([
torch
.
bfloat16
]
if
is_sm8x
else
[]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
1
,
2
,
4
,
8
])
# @pytest.mark.parametrize('world_size', [2])
@
pytest
.
mark
.
parametrize
(
"sequence_parallel"
,
[
True
,
False
])
# @pytest.mark.parametrize('sequence_parallel', [False])
@
pytest
.
mark
.
parametrize
(
"head_dim"
,
[
64
,
128
])
# @pytest.mark.parametrize('head_dim', [64])
@
pytest
.
mark
.
parametrize
(
"embed_dim"
,
[
1024
,
4096
])
# @pytest.mark.parametrize('embed_dim', [1024])
def
test_mha_parallel(embed_dim, head_dim, sequence_parallel, world_size, dtype):
    assert embed_dim % head_dim == 0
    num_heads = embed_dim // head_dim
    assert num_heads % world_size == 0
    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    seqlen = 1024
    assert (batch_size * seqlen) % world_size == 0
    x_pt = torch.randn(
        batch_size * seqlen, embed_dim, device=device, dtype=dtype, requires_grad=True
    )
    # We need to generate g here so that all processes get the same gradient,
    # as rank 0 will have an extra bias that changes the RNG.
    # If we don't divide by batch_size, the gradient gets a bit too large.
    g = torch.randn_like(x_pt) / 32
    if sequence_parallel:
        x = (
            tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
            .detach()
            .clone()
            .requires_grad_()
        )
    else:
        x = x_pt.detach().clone().requires_grad_()

    model_pt = MHA(
        embed_dim,
        num_heads,
        rotary_emb_dim=int(head_dim // 2),
        use_flash_attn=True,
        device=device,
        dtype=dtype,
    )
    partition_dim = embed_dim // world_size
    model = ParallelMHA(
        embed_dim,
        num_heads,
        parallel_state.get_tensor_model_parallel_group(),
        rotary_emb_dim=int(head_dim // 2),
        use_flash_attn=True,
        sequence_parallel=sequence_parallel,
        device=device,
        dtype=dtype,
    )

    with torch.no_grad():
        model.Wqkv.weight.copy_(
            rearrange(
                rearrange(model_pt.Wqkv.weight, "(three o) i -> three o i", three=3)[
                    :, rank * partition_dim : (rank + 1) * partition_dim
                ],
                "three o i -> (three o) i",
            )
        )
        model.Wqkv.bias.copy_(
            rearrange(
                rearrange(model_pt.Wqkv.bias, "(three o) -> three o", three=3)[
                    :, rank * partition_dim : (rank + 1) * partition_dim
                ],
                "three o -> (three o)",
            )
        )
        model.out_proj.weight.copy_(
            model_pt.out_proj.weight[:, rank * partition_dim : (rank + 1) * partition_dim]
        )
        if rank == 0:
            model.out_proj.bias.copy_(model_pt.out_proj.bias)

    out = model(x, seqlen=seqlen)
    out_pt = rearrange(
        model_pt(rearrange(x_pt, "(b s) d -> b s d", s=seqlen)), "b s d -> (b s) d"
    )
    partition_batch_dim = batch_size * seqlen // world_size
    assert torch.allclose(
        out,
        out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else out_pt,
        rtol=rtol,
        atol=atol,
    )

    out_pt.backward(g)
    out.backward(
        g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else g
    )
    parallel_state.destroy_model_parallel()

    assert torch.allclose(
        x.grad,
        x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else x_pt.grad,
        rtol=rtol,
        atol=atol / 100,  # magnitude of x.grad is quite small
    )
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(
        model.Wqkv.weight.grad,
        rearrange(
            rearrange(model_pt.Wqkv.weight.grad, "(three o) i -> three o i", three=3)[
                :, rank * partition_dim : (rank + 1) * partition_dim
            ],
            "three o i -> (three o) i",
        ),
        rtol=rtol,
        atol=atol * 10,
    )
    assert torch.allclose(
        model.Wqkv.bias.grad,
        rearrange(
            rearrange(model_pt.Wqkv.bias.grad, "(three o) -> three o", three=3)[
                :, rank * partition_dim : (rank + 1) * partition_dim
            ],
            "three o -> (three o)",
        ),
        rtol=rtol,
        atol=atol * 5,
    )
    assert torch.allclose(
        model.out_proj.weight.grad,
        model_pt.out_proj.weight.grad[:, rank * partition_dim : (rank + 1) * partition_dim],
        rtol=rtol,
        atol=atol * 10,
    )
    if rank == 0:
        assert torch.allclose(
            model.out_proj.bias.grad, model_pt.out_proj.bias.grad, rtol=rtol, atol=atol * 5
        )
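The weight surgery above relies on the packed QKV layout: Wqkv.weight stacks the Q, K and V projections along the output dimension, so each tensor-parallel rank takes the same contiguous slice of heads from each of the three sub-matrices. A minimal, CPU-only sketch of that slicing (hypothetical sizes, no flash-attn or apex needed), mirroring the rearrange-based copy in the test:

import torch
from einops import rearrange

# Hypothetical sizes for illustration only.
embed_dim, world_size, rank = 8, 2, 1
partition_dim = embed_dim // world_size

# Full packed QKV weight: (3 * embed_dim, embed_dim), with Q/K/V stacked on dim 0.
wqkv = torch.randn(3 * embed_dim, embed_dim)

# Slice each of the three sub-matrices to this rank's output rows,
# then re-pack them, exactly as the test does before copying into ParallelMHA.
shard = rearrange(
    rearrange(wqkv, "(three o) i -> three o i", three=3)[
        :, rank * partition_dim : (rank + 1) * partition_dim
    ],
    "three o i -> (three o) i",
)
assert shard.shape == (3 * partition_dim, embed_dim)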
tests/modules/test_mlp_parallel.py deleted 100644 → 0
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_mlp_parallel.py
import pytest
import torch
import torch.nn.functional as F
from apex.transformer import parallel_state, tensor_parallel
from einops import rearrange
from flash_attn.modules.mlp import GatedMlp, ParallelGatedMlp

is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize("activation", [F.silu, F.sigmoid])
# @pytest.mark.parametrize('activation', [F.silu])
@pytest.mark.parametrize("dim", [1024, 4096])
# @pytest.mark.parametrize('dim', [1024])
def test_mlp_parallel(dim, activation, sequence_parallel, world_size, dtype):
    rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    seqlen = 1024
    assert (batch_size * seqlen) % world_size == 0
    x_pt = torch.randn(batch_size * seqlen, dim, device=device, dtype=dtype, requires_grad=True)
    # We need to generate g here so that all processes get the same gradient,
    # as rank 0 will have an extra bias that changes the RNG.
    # If we don't divide by batch_size, the gradient gets a bit too large.
    g = torch.randn_like(x_pt) / 32
    if sequence_parallel:
        x = (
            tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
            .detach()
            .clone()
            .requires_grad_()
        )
    else:
        x = x_pt.detach().clone().requires_grad_()

    model_pt = GatedMlp(dim, activation=activation, device=device, dtype=dtype)
    partition_dim = model_pt.fc1.weight.shape[0] // 2 // world_size
    model = ParallelGatedMlp(
        dim,
        parallel_state.get_tensor_model_parallel_group(),
        activation=activation,
        sequence_parallel=sequence_parallel,
        device=device,
        dtype=dtype,
    )

    with torch.no_grad():
        model.fc1.weight.copy_(
            rearrange(
                rearrange(model_pt.fc1.weight, "(two o) i -> two o i", two=2)[
                    :, rank * partition_dim : (rank + 1) * partition_dim
                ],
                "two o i -> (two o) i",
            )
        )
        model.fc1.bias.copy_(
            rearrange(
                rearrange(model_pt.fc1.bias, "(two o) -> two o", two=2)[
                    :, rank * partition_dim : (rank + 1) * partition_dim
                ],
                "two o -> (two o)",
            )
        )
        model.fc2.weight.copy_(
            model_pt.fc2.weight[:, rank * partition_dim : (rank + 1) * partition_dim]
        )
        if rank == 0:
            model.fc2.bias.copy_(model_pt.fc2.bias)

    out = model(x)
    out_pt = model_pt(x_pt)
    partition_batch_dim = batch_size * seqlen // world_size
    assert torch.allclose(
        out,
        out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else out_pt,
        rtol=rtol,
        atol=atol,
    )

    out_pt.backward(g)
    out.backward(
        g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else g
    )
    parallel_state.destroy_model_parallel()

    assert torch.allclose(
        x.grad,
        x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else x_pt.grad,
        rtol=rtol,
        atol=atol,
    )
    assert torch.allclose(
        model.fc1.weight.grad,
        rearrange(
            rearrange(model_pt.fc1.weight.grad, "(two o) i -> two o i", two=2)[
                :, rank * partition_dim : (rank + 1) * partition_dim
            ],
            "two o i -> (two o) i",
        ),
        rtol=rtol,
        atol=atol,
    )
    assert torch.allclose(
        model.fc1.bias.grad,
        rearrange(
            rearrange(model_pt.fc1.bias.grad, "(two o) -> two o", two=2)[
                :, rank * partition_dim : (rank + 1) * partition_dim
            ],
            "two o -> (two o)",
        ),
        rtol=rtol,
        atol=atol,
    )
    assert torch.allclose(
        model.fc2.weight.grad,
        model_pt.fc2.weight.grad[:, rank * partition_dim : (rank + 1) * partition_dim],
        rtol=rtol,
        atol=atol,
    )
    if rank == 0:
        assert torch.allclose(model.fc2.bias.grad, model_pt.fc2.bias.grad, rtol=rtol, atol=atol)
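The fc1 weight of the gated MLP packs two projections (value and gate) along its output dimension, which is why the test computes partition_dim = fc1.weight.shape[0] // 2 // world_size and slices both halves with the same two-level rearrange. A minimal plain-PyTorch sketch of the gated-MLP forward the test exercises; the (value, gate) chunk ordering here is an assumption for illustration, not taken from the source:

import torch
import torch.nn.functional as F

# Hypothetical sizes for illustration only.
dim, hidden_features, batch = 16, 32, 4
fc1 = torch.nn.Linear(dim, 2 * hidden_features)  # value and gate packed along the output dim
fc2 = torch.nn.Linear(hidden_features, dim)

x = torch.randn(batch, dim)
value, gate = fc1(x).chunk(2, dim=-1)  # assumed (value, gate) ordering
out = fc2(value * F.silu(gate))        # SwiGLU-style gating, matching activation=F.silu
print(out.shape)  # torch.Size([4, 16])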
tests/ops/test_dropout_layer_norm.py deleted 100644 → 0
import math

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from flash_attn.ops.layer_norm import (
    DropoutAddLayerNorm,
    dropout_add_layer_norm,
    dropout_add_layer_norm_parallel_residual,
    dropout_add_layer_norm_subset,
)
from flash_attn.ops.rms_norm import (
    DropoutAddRMSNorm,
    dropout_add_rms_norm,
    dropout_add_rms_norm_parallel_residual,
    dropout_add_rms_norm_subset,
)

try:
    from apex.normalization import FusedRMSNorm
    from apex.normalization.fused_layer_norm import fused_rms_norm_affine
except:
    FusedRMSNorm, fused_rms_norm_affine = None, None

is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize("is_rms_norm", [False, True])
@pytest.mark.parametrize("has_colscale", [True, False])
# @pytest.mark.parametrize('has_colscale', [False])
@pytest.mark.parametrize("has_rowscale", [True, False])
# @pytest.mark.parametrize('has_rowscale', [True])
@pytest.mark.parametrize("has_residual", [True, False])
# @pytest.mark.parametrize('has_residual', [False])
@pytest.mark.parametrize("dropout_p", [0.37, 0.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
@pytest.mark.parametrize(
    "hidden_size",
    [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],
)
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_training(
    hidden_size,
    input_dtype,
    residual_dtype,
    weight_dtype,
    dropout_p,
    has_residual,
    has_rowscale,
    has_colscale,
    is_rms_norm,
):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    if is_rms_norm and FusedRMSNorm is None:
        pytest.skip()  # We need Apex's FusedRMSNorm to test
    layer_norm_cls = torch.nn.LayerNorm if not is_rms_norm else FusedRMSNorm
    our_layer_norm_cls = DropoutAddLayerNorm if not is_rms_norm else DropoutAddRMSNorm
    our_layer_norm_func = dropout_add_layer_norm if not is_rms_norm else dropout_add_rms_norm
    device = "cuda"
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 1e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_colscale:
        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        colscale_pt = colscale.detach().clone().requires_grad_()
        colscale_ref = colscale.detach().clone().float().requires_grad_()
    else:
        colscale = None
    if has_residual:
        res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        res = res_pt.detach().clone().requires_grad_()
        res_ref = res_pt.detach().clone().float().requires_grad_()
    else:
        res = None
    if has_rowscale:
        rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype)
        survival_rate = 0.87
        rowscale = rowscale.bernoulli_(survival_rate) / survival_rate
        x0_scaled_pt = x0_pt * rearrange(rowscale, "... -> ... 1")
        x0_scaled_ref = x0_ref * rearrange(rowscale, "... -> ... 1")
    else:
        rowscale = None
        x0_scaled_pt = x0_pt
        x0_scaled_ref = x0_ref
    if has_colscale:
        x0_scaled_pt = x0_scaled_pt * colscale_pt
        x0_scaled_ref = x0_scaled_ref * colscale_ref
    model_pt = layer_norm_cls(hidden_size).to(device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    if not is_rms_norm:
        torch.nn.init.normal_(model_pt.bias)
    model_ref = layer_norm_cls(hidden_size).to(device=device, dtype=torch.float32)
    model = our_layer_norm_cls(hidden_size, p=dropout_p, device=device, dtype=weight_dtype)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model_ref.weight.copy_(model_pt.weight)
        if not is_rms_norm:
            model.bias.copy_(model_pt.bias)
            model_ref.bias.copy_(model_pt.bias)
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, dmask = our_layer_norm_func(
        x0,
        res,
        model.weight,
        model.bias,
        model.p,
        model.eps,
        rowscale=rowscale,
        layerscale=colscale,
        residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True,
    )
    assert out.dtype == input_dtype
    print(f"Actual dropout fraction: {1 - dmask.float().mean().item()}")
    if has_residual:
        residual_pt = (
            (x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + res_pt.float()
        ).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + res_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(
            dtype=residual_dtype
        )
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)
    out_ref = model_ref(residual_ref)
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4

    g = torch.randn_like(out) / batch_size
    out_pt.backward(g)
    out.backward(g)
    out_ref.backward(g)
    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 3 * (
        model_pt.weight.grad - model_ref.weight.grad
    ).abs().max() + 3e-5
    if not is_rms_norm:
        assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (
            model_pt.bias.grad - model_ref.bias.grad
        ).abs().max() + 3e-5
    if has_colscale:
        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (
            colscale_pt.grad - colscale_ref.grad
        ).abs().max() + 2e-4


@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
@pytest.mark.parametrize("hidden_size", [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weight_dtype):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    device = "cuda"
    # rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 1e-4)
    dropout_p = 0.37
    # set seed
    torch.random.manual_seed(0)
    batch_size = 32
    seqlen = 512
    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
    res = res_pt.detach().clone().requires_grad_()
    res_ref = res_pt.detach().clone().float().requires_grad_()
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model = DropoutAddLayerNorm(hidden_size, p=dropout_p, device=device, dtype=weight_dtype)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)
    model_pt.eval()
    model.eval()
    model_ref.eval()
    out = model(x0, res)
    residual_pt = (x0_pt.float() + res_pt.float()).to(dtype=residual_dtype)
    residual_ref = x0_ref + res_ref
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(input_dtype)
    out_ref = model_ref(residual_ref)
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4


@pytest.mark.parametrize("is_rms_norm", [False, True])
@pytest.mark.parametrize("has_colscale", [True, False])
@pytest.mark.parametrize("has_rowscale", [True, False])
@pytest.mark.parametrize("has_residual", [True, False])
@pytest.mark.parametrize("dropout_p", [0.37, 0.0])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize('has_colscale', [True])
# @pytest.mark.parametrize('has_rowscale', [False])
# @pytest.mark.parametrize('has_residual', [True])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
@pytest.mark.parametrize(
    "hidden_size",
    [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],
)
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_prenorm_training(
    hidden_size,
    input_dtype,
    residual_dtype,
    weight_dtype,
    dropout_p,
    has_residual,
    has_rowscale,
    has_colscale,
    is_rms_norm,
):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    if is_rms_norm and FusedRMSNorm is None:
        pytest.skip()  # We need Apex's FusedRMSNorm to test
    layer_norm_cls = torch.nn.LayerNorm if not is_rms_norm else FusedRMSNorm
    our_layer_norm_cls = DropoutAddLayerNorm if not is_rms_norm else DropoutAddRMSNorm
    our_layer_norm_func = dropout_add_layer_norm if not is_rms_norm else dropout_add_rms_norm
    device = "cuda"
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 2e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_colscale:
        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        colscale_pt = colscale.detach().clone().requires_grad_()
        colscale_ref = colscale.detach().clone().float().requires_grad_()
    else:
        colscale = None
    if has_residual:
        res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        res = res_pt.detach().clone().requires_grad_()
        res_ref = res_pt.detach().clone().float().requires_grad_()
    else:
        res = None
    if has_rowscale:
        rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype)
        survival_rate = 0.87
        rowscale = rowscale.bernoulli_(survival_rate) / survival_rate
        x0_scaled_pt = x0_pt * rearrange(rowscale, "... -> ... 1")
        x0_scaled_ref = x0_ref * rearrange(rowscale, "... -> ... 1")
    else:
        rowscale = None
        x0_scaled_pt = x0_pt
        x0_scaled_ref = x0_ref
    if has_colscale:
        x0_scaled_pt = x0_scaled_pt * colscale_pt
        x0_scaled_ref = x0_scaled_ref * colscale_ref
    model_pt = layer_norm_cls(hidden_size).to(device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    if not is_rms_norm:
        torch.nn.init.normal_(model_pt.bias)
    model_ref = layer_norm_cls(hidden_size).to(device=device, dtype=torch.float32)
    model = our_layer_norm_cls(
        hidden_size, prenorm=True, p=dropout_p, device=device, dtype=weight_dtype
    )
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model_ref.weight.copy_(model_pt.weight)
        if not is_rms_norm:
            model.bias.copy_(model_pt.bias)
            model_ref.bias.copy_(model_pt.bias)
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, residual, dmask = our_layer_norm_func(
        x0,
        res,
        model.weight,
        model.bias,
        model.p,
        model.eps,
        rowscale=rowscale,
        layerscale=colscale,
        prenorm=True,
        residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True,
    )
    print(f"Actual dropout fraction: {1 - dmask.float().mean().item()}")
    if has_residual:
        residual_pt = (
            (x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + res_pt.float()
        ).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + res_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(
            dtype=residual_dtype
        )
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)
    out_ref = model_ref(residual_ref)
    assert out.dtype == input_dtype
    assert residual.dtype == residual_dtype
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
    assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4

    g = torch.randn_like(out) / batch_size
    (out_pt * F.sigmoid(residual_pt)).backward(g)
    (out * F.sigmoid(residual)).backward(g)
    (out_ref * F.sigmoid(residual_ref.to(dtype=residual_dtype))).backward(g)
    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (
        model_pt.weight.grad - model_ref.weight.grad
    ).abs().max() + 2e-4
    if not is_rms_norm:
        assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (
            model_pt.bias.grad - model_ref.bias.grad
        ).abs().max() + 2e-4
    if has_colscale:
        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (
            colscale_pt.grad - colscale_ref.grad
        ).abs().max() + 2e-4


@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
@pytest.mark.parametrize("hidden_size", [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, residual_dtype, weight_dtype):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    device = "cuda"
    # rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 1e-4)
    dropout_p = 0.37
    # set seed
    torch.random.manual_seed(0)
    batch_size = 32
    seqlen = 512
    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
    res = res_pt.detach().clone().requires_grad_()
    res_ref = res_pt.detach().clone().float().requires_grad_()
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model = DropoutAddLayerNorm(
        hidden_size, prenorm=True, p=dropout_p, device=device, dtype=weight_dtype
    )
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)
    model_pt.eval()
    model.eval()
    model_ref.eval()
    out, residual = model(x0, res)
    residual_pt = (x0_pt.float() + res_pt.float()).to(dtype=residual_dtype)
    residual_ref = x0_ref + res_ref
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(input_dtype)
    out_ref = model_ref(residual_ref)
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
    assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4


@pytest.mark.parametrize("has_colscale", [True, False])
@pytest.mark.parametrize("has_residual", [True, False])
@pytest.mark.parametrize("dropout_p", [0.37, 0.0])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize('has_colscale', [True])
# @pytest.mark.parametrize('has_residual', [True])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
@pytest.mark.parametrize(
    "hidden_size",
    [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],
)
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_subset_training(
    hidden_size, input_dtype, residual_dtype, weight_dtype, dropout_p, has_residual, has_colscale
):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    device = "cuda"
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 2e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    drop_path_rate = 0.4
    drop_path_scale = 1 / (1 - drop_path_rate)

    def generate_droppath_masks(batch_size, seqlen, drop_path_rate, device):
        # Do it on CPU so we can get the numrows (with .item()) without GPU-CPU sync
        mask_batch = torch.rand(batch_size) < 1 - drop_path_rate
        numrows = (mask_batch).sum().item() * seqlen
        mask_batch = mask_batch.to(device=device, non_blocking=True)
        mask_batch_seqlen = repeat(mask_batch, "b -> (b s)", s=seqlen)
        subset = torch.cumsum(mask_batch_seqlen, dim=0, dtype=torch.int32).masked_fill_(
            ~mask_batch_seqlen, 0
        )
        return mask_batch, numrows, rearrange(subset, "(b s) -> b s", b=batch_size)

    x0_mask_batch, x0_numrows, x0_subset = generate_droppath_masks(
        batch_size, seqlen, drop_path_rate, device
    )
    out_mask_batch, out_numrows, out_subset = generate_droppath_masks(
        batch_size, seqlen, drop_path_rate, device
    )

    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0 = x0_pt.detach().clone()[x0_mask_batch].requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_colscale:
        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        colscale_pt = colscale.detach().clone().requires_grad_()
        colscale_ref = colscale.detach().clone().float().requires_grad_()
    else:
        colscale = None
    if has_residual:
        res_pt = torch.randn_like(x0_pt, dtype=residual_dtype, requires_grad=True)
        res = res_pt.detach().clone().requires_grad_()
        res_ref = res_pt.detach().clone().float().requires_grad_()
    else:
        res = None
    if has_colscale:
        x0_scaled_pt = x0_pt * colscale_pt
        x0_scaled_ref = x0_ref * colscale_ref
    else:
        x0_scaled_pt = x0_pt
        x0_scaled_ref = x0_ref

    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    model = DropoutAddLayerNorm(
        hidden_size, prenorm=False, p=dropout_p, device=device, dtype=weight_dtype
    )
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)

    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, dmask = dropout_add_layer_norm_subset(
        x0,
        res,
        model.weight,
        model.bias,
        model.p,
        model.eps,
        layerscale=colscale,
        x0_subset=x0_subset,
        out_subset=out_subset,
        rowscale_const=drop_path_scale,
        out_numrows=out_numrows,
        prenorm=False,
        residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True,
    )
    print(f"Actual dropout fraction: {1 - dmask.float().mean().item()}")

    x0_scaled_pt = (
        x0_scaled_pt.masked_fill(repeat(~x0_mask_batch, "b -> b s d", s=seqlen, d=hidden_size), 0)
        * drop_path_scale
    )
    x0_scaled_ref = (
        x0_scaled_ref.masked_fill(repeat(~x0_mask_batch, "b -> b s d", s=seqlen, d=hidden_size), 0)
        * drop_path_scale
    )
    dmask_expanded = torch.zeros_like(x0_pt, dtype=torch.uint8)
    dmask_expanded[x0_mask_batch] = dmask
    if has_residual:
        residual_pt = (
            (x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + res_pt.float()
        ).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + res_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)).to(
            dtype=residual_dtype
        )
        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)[out_mask_batch]
    out_ref = model_ref(residual_ref)[out_mask_batch]
    assert out.dtype == input_dtype
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4

    g = torch.randn_like(out) / batch_size
    out_pt.backward(g)
    out.backward(g)
    out_ref.backward(g)
    assert (x0.grad - x0_ref.grad[x0_mask_batch]).abs().max() <= 4 * (
        x0_pt.grad - x0_ref.grad
    )[x0_mask_batch].abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (
        model_pt.weight.grad - model_ref.weight.grad
    ).abs().max() + 2e-4
    assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (
        model_pt.bias.grad - model_ref.bias.grad
    ).abs().max() + 2e-4
    if has_colscale:
        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (
            colscale_pt.grad - colscale_ref.grad
        ).abs().max() + 2e-4


@pytest.mark.parametrize("has_colscale", [True, False])
@pytest.mark.parametrize("has_residual", [True, False])
@pytest.mark.parametrize("dropout_p", [0.37, 0.0])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize('has_colscale', [True])
# @pytest.mark.parametrize('has_residual', [True])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
@pytest.mark.parametrize(
    "hidden_size",
    [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],
)
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_subset_prenorm_training(
    hidden_size, input_dtype, residual_dtype, weight_dtype, dropout_p, has_residual, has_colscale
):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    device = "cuda"
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 2e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    drop_path_rate = 0.4
    drop_path_scale = 1 / (1 - drop_path_rate)

    def generate_droppath_masks(batch_size, seqlen, drop_path_rate, device):
        # Do it on CPU so we can get the numrows (with .item()) without GPU-CPU sync
        mask_batch = torch.rand(batch_size) < 1 - drop_path_rate
        numrows = (mask_batch).sum().item() * seqlen
        mask_batch = mask_batch.to(device=device, non_blocking=True)
        mask_batch_seqlen = repeat(mask_batch, "b -> (b s)", s=seqlen)
        subset = torch.cumsum(mask_batch_seqlen, dim=0, dtype=torch.int32).masked_fill_(
            ~mask_batch_seqlen, 0
        )
        return mask_batch, numrows, rearrange(subset, "(b s) -> b s", b=batch_size)

    x0_mask_batch, x0_numrows, x0_subset = generate_droppath_masks(
        batch_size, seqlen, drop_path_rate, device
    )
    out_mask_batch, out_numrows, out_subset = generate_droppath_masks(
        batch_size, seqlen, drop_path_rate, device
    )

    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0 = x0_pt.detach().clone()[x0_mask_batch].requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_colscale:
        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        colscale_pt = colscale.detach().clone().requires_grad_()
        colscale_ref = colscale.detach().clone().float().requires_grad_()
    else:
        colscale = None
    if has_residual:
        res_pt = torch.randn_like(x0_pt, dtype=residual_dtype, requires_grad=True)
        res = res_pt.detach().clone().requires_grad_()
        res_ref = res_pt.detach().clone().float().requires_grad_()
    else:
        res = None
    if has_colscale:
        x0_scaled_pt = x0_pt * colscale_pt
        x0_scaled_ref = x0_ref * colscale_ref
    else:
        x0_scaled_pt = x0_pt
        x0_scaled_ref = x0_ref

    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    model = DropoutAddLayerNorm(
        hidden_size, prenorm=True, p=dropout_p, device=device, dtype=weight_dtype
    )
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)

    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, residual, dmask = dropout_add_layer_norm_subset(
        x0,
        res,
        model.weight,
        model.bias,
        model.p,
        model.eps,
        layerscale=colscale,
        x0_subset=x0_subset,
        out_subset=out_subset,
        rowscale_const=drop_path_scale,
        out_numrows=out_numrows,
        prenorm=True,
        residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True,
    )
    print(f"Actual dropout fraction: {1 - dmask.float().mean().item()}")

    x0_scaled_pt = (
        x0_scaled_pt.masked_fill(repeat(~x0_mask_batch, "b -> b s d", s=seqlen, d=hidden_size), 0)
        * drop_path_scale
    )
    x0_scaled_ref = (
        x0_scaled_ref.masked_fill(repeat(~x0_mask_batch, "b -> b s d", s=seqlen, d=hidden_size), 0)
        * drop_path_scale
    )
    dmask_expanded = torch.zeros_like(x0_pt, dtype=torch.uint8)
    dmask_expanded[x0_mask_batch] = dmask
    if has_residual:
        residual_pt = (
            (x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + res_pt.float()
        ).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + res_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)).to(
            dtype=residual_dtype
        )
        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)[out_mask_batch]
    out_ref = model_ref(residual_ref)[out_mask_batch]
    assert out.dtype == input_dtype
    assert residual.dtype == residual_dtype
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
    assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4

    g = torch.randn_like(out) / batch_size
    (out_pt * F.sigmoid(residual_pt[out_mask_batch]) + residual_pt.mean(0, keepdim=True)).backward(g)
    (out * F.sigmoid(residual[out_mask_batch]) + residual.mean(0, keepdim=True)).backward(g)
    (
        out_ref * F.sigmoid(residual_ref[out_mask_batch].to(dtype=residual_dtype))
        + residual_ref.mean(0, keepdim=True)
    ).backward(g)
    assert (x0.grad - x0_ref.grad[x0_mask_batch]).abs().max() <= 4 * (
        x0_pt.grad - x0_ref.grad
    )[x0_mask_batch].abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (
        model_pt.weight.grad - model_ref.weight.grad
    ).abs().max() + 2e-4
    assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (
        model_pt.bias.grad - model_ref.bias.grad
    ).abs().max() + 2e-4
    if has_colscale:
        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (
            colscale_pt.grad - colscale_ref.grad
        ).abs().max() + 2e-4


@pytest.mark.parametrize("is_rms_norm", [False, True])
# @pytest.mark.parametrize('is_rms_norm', [False])
@pytest.mark.parametrize("tied_norm", [False, True])
# @pytest.mark.parametrize('tied_norm', [False])
@pytest.mark.parametrize("has_residual", [True, False])
# @pytest.mark.parametrize('has_residual', [False])
@pytest.mark.parametrize("has_x1", [True, False])
# @pytest.mark.parametrize('has_x1', [True])
@pytest.mark.parametrize("dropout_p", [0.37, 0.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
# @pytest.mark.parametrize('weight_dtype', [torch.float16])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
@pytest.mark.parametrize(
    "hidden_size",
    [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],
)
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_parallel_residual_training(
    hidden_size,
    input_dtype,
    residual_dtype,
    weight_dtype,
    dropout_p,
    has_x1,
    has_residual,
    tied_norm,
    is_rms_norm,
):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    if is_rms_norm and fused_rms_norm_affine is None:
        pytest.skip()  # We need Apex's FusedRMSNorm to test
    our_layer_norm_func = (
        dropout_add_layer_norm_parallel_residual
        if not is_rms_norm
        else dropout_add_rms_norm_parallel_residual
    )
    device = "cuda"
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 1e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_x1:
        x1_pt = torch.randn(
            batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
        )
        x1 = x1_pt.detach().clone().requires_grad_()
        x1_ref = x1_pt.detach().clone().float().requires_grad_()
    else:
        x1 = None
    if has_residual:
        res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        res = res_pt.detach().clone().requires_grad_()
        res_ref = res_pt.detach().clone().float().requires_grad_()
    else:
        res = None
    weight0 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
    bias0 = (
        torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        if not is_rms_norm
        else None
    )
    weight0_pt = weight0.detach().clone().requires_grad_()
    weight0_ref = weight0.detach().clone().float().requires_grad_()
    bias0_pt = bias0.detach().clone().requires_grad_() if bias0 is not None else None
    bias0_ref = bias0.detach().clone().float().requires_grad_() if bias0 is not None else None
    if not tied_norm:
        weight1 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        bias1 = (
            torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
            if not is_rms_norm
            else None
        )
        weight1_pt = weight1.detach().clone().requires_grad_()
        weight1_ref = weight1.detach().clone().float().requires_grad_()
        bias1_pt = bias1.detach().clone().requires_grad_() if bias1 is not None else None
        bias1_ref = bias1.detach().clone().float().requires_grad_() if bias1 is not None else None
    else:
        weight1, bias1 = None, None
    epsilon = 1e-5
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out0, out1, dmask0, dmask1 = our_layer_norm_func(
        x0,
        x1,
        res,
        weight0,
        bias0,
        weight1,
        bias1,
        dropout_p,
        epsilon,
        residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True,
    )
    assert out0.dtype == input_dtype
    if not tied_norm:
        assert out1.dtype == input_dtype
    print(f"Actual dropout fraction: {1 - dmask0.float().mean().item()}")
    if has_residual:
        if has_x1:
            residual_pt = (
                (x0_pt.float() * dmask0.float()) / (1 - dropout_p)
                + (x1_pt.float() * dmask1.float()) / (1 - dropout_p)
                + res_pt.float()
            ).to(dtype=residual_dtype)
            residual_ref = (
                (x0_ref * dmask0.float()) / (1 - dropout_p)
                + (x1_ref * dmask1.float()) / (1 - dropout_p)
            ) + res_ref
        else:
            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p) + res_pt.float()).to(
                dtype=residual_dtype
            )
            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + res_ref
    else:
        if has_x1:
            residual_pt = (
                (x0_pt.float() * dmask0.float()) / (1 - dropout_p)
                + (x1_pt.float() * dmask1.float()) / (1 - dropout_p)
            ).to(dtype=residual_dtype)
            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + (
                x1_ref * dmask1.float()
            ) / (1 - dropout_p)
        else:
            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)).to(
                dtype=residual_dtype
            )
            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p)
    if not is_rms_norm:
        out0_pt = F.layer_norm(
            residual_pt.to(dtype=weight_dtype), (hidden_size,), weight0_pt, bias0_pt, eps=epsilon
        ).to(dtype=input_dtype)
        out0_ref = F.layer_norm(residual_ref, (hidden_size,), weight0_ref, bias0_ref, eps=epsilon)
        if not tied_norm:
            out1_pt = F.layer_norm(
                residual_pt.to(dtype=weight_dtype),
                (hidden_size,),
                weight1_pt,
                bias1_pt,
                eps=epsilon,
            ).to(dtype=input_dtype)
            out1_ref = F.layer_norm(residual_ref, (hidden_size,), weight1_ref, bias1_ref, eps=epsilon)
    else:
        out0_pt = fused_rms_norm_affine(
            residual_pt.to(dtype=weight_dtype), weight0_pt, (hidden_size,), eps=epsilon
        ).to(dtype=input_dtype)
        out0_ref = fused_rms_norm_affine(residual_ref, weight0_ref, (hidden_size,), eps=epsilon)
        if not tied_norm:
            out1_pt = fused_rms_norm_affine(
                residual_pt.to(dtype=weight_dtype), weight1_pt, (hidden_size,), eps=epsilon
            ).to(dtype=input_dtype)
            out1_ref = fused_rms_norm_affine(residual_ref, weight1_ref, (hidden_size,), eps=epsilon)
    assert (out0 - out0_ref).abs().max() <= 4 * (out0_pt - out0_ref).abs().max() + 1e-4
    if not tied_norm:
        assert (out1 - out1_ref).abs().max() <= 4 * (out1_pt - out1_ref).abs().max() + 1e-4

    g0 = torch.randn_like(out0) / batch_size
    if tied_norm:
        out0.backward(g0)
        out0_pt.backward(g0)
        out0_ref.backward(g0)
    else:
        g1 = torch.randn_like(out1) / batch_size
        (out0 * g0 + out1 * g1).sum().backward()
        (out0_pt * g0 + out1_pt * g1).sum().backward()
        (out0_ref * g0 + out1_ref * g1).sum().backward()
    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
    if has_x1:
        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
    assert (weight0.grad - weight0_ref.grad).abs().max() <= 3 * (
        weight0_pt.grad - weight0_ref.grad
    ).abs().max() + 3e-5
    if not is_rms_norm:
        assert (bias0.grad - bias0_ref.grad).abs().max() <= 2 * (
            bias0_pt.grad - bias0_ref.grad
        ).abs().max() + 3e-5
    if not tied_norm:
        assert (weight1.grad - weight1_ref.grad).abs().max() <= 3 * (
            weight1_pt.grad - weight1_ref.grad
        ).abs().max() + 3e-5
        if not is_rms_norm:
            assert (bias1.grad - bias1_ref.grad).abs().max() <= 2 * (
                bias1_pt.grad - bias1_ref.grad
            ).abs().max() + 3e-5


@pytest.mark.parametrize("is_rms_norm", [False, True])
# @pytest.mark.parametrize('is_rms_norm', [False])
@pytest.mark.parametrize("tied_norm", [False, True])
# @pytest.mark.parametrize('tied_norm', [False])
@pytest.mark.parametrize("has_residual", [True, False])
# @pytest.mark.parametrize('has_residual', [False])
@pytest.mark.parametrize("has_x1", [True, False])
# @pytest.mark.parametrize('has_x1', [True])
@pytest.mark.parametrize("dropout_p", [0.37, 0.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
# @pytest.mark.parametrize('weight_dtype', [torch.float16])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
@pytest.mark.parametrize(
    "hidden_size",
    [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],
)
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_parallel_residual_prenorm_training(
    hidden_size,
    input_dtype,
    residual_dtype,
    weight_dtype,
    dropout_p,
    has_x1,
    has_residual,
    tied_norm,
    is_rms_norm,
):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    if is_rms_norm and fused_rms_norm_affine is None:
        pytest.skip()  # We need Apex's FusedRMSNorm to test
    our_layer_norm_func = (
        dropout_add_layer_norm_parallel_residual
        if not is_rms_norm
        else dropout_add_rms_norm_parallel_residual
    )
    device = "cuda"
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 1e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_x1:
        x1_pt = torch.randn(
            batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
        )
        x1 = x1_pt.detach().clone().requires_grad_()
        x1_ref = x1_pt.detach().clone().float().requires_grad_()
    else:
        x1 = None
    if has_residual:
        res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        res = res_pt.detach().clone().requires_grad_()
        res_ref = res_pt.detach().clone().float().requires_grad_()
    else:
        res = None
    weight0 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
    bias0 = (
        torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        if not is_rms_norm
        else None
    )
    weight0_pt = weight0.detach().clone().requires_grad_()
    weight0_ref = weight0.detach().clone().float().requires_grad_()
    bias0_pt = bias0.detach().clone().requires_grad_() if bias0 is not None else None
    bias0_ref = bias0.detach().clone().float().requires_grad_() if bias0 is not None else None
    if not tied_norm:
        weight1 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        bias1 = (
            torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
            if not is_rms_norm
            else None
        )
        weight1_pt = weight1.detach().clone().requires_grad_()
        weight1_ref = weight1.detach().clone().float().requires_grad_()
        bias1_pt = bias1.detach().clone().requires_grad_() if bias1 is not None else None
        bias1_ref = bias1.detach().clone().float().requires_grad_() if bias1 is not None else None
    else:
        weight1, bias1 = None, None
    epsilon = 1e-5
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out0, out1, residual, dmask0, dmask1 = our_layer_norm_func(
        x0,
        x1,
        res,
        weight0,
        bias0,
        weight1,
        bias1,
        dropout_p,
        epsilon,
        prenorm=True,
        residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True,
    )
    assert out0.dtype == input_dtype
    if not tied_norm:
        assert out1.dtype == input_dtype
    print(f"Actual dropout fraction: {1 - dmask0.float().mean().item()}")
    if has_residual:
        if has_x1:
            residual_pt = (
                (x0_pt.float() * dmask0.float()) / (1 - dropout_p)
                + (x1_pt.float() * dmask1.float()) / (1 - dropout_p)
                + res_pt.float()
            ).to(dtype=residual_dtype)
            residual_ref = (
                (x0_ref * dmask0.float()) / (1 - dropout_p)
                + (x1_ref * dmask1.float()) / (1 - dropout_p)
            ) + res_ref
        else:
            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p) + res_pt.float()).to(
                dtype=residual_dtype
            )
            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + res_ref
    else:
        if has_x1:
            residual_pt = (
                (x0_pt.float() * dmask0.float()) / (1 - dropout_p)
                + (x1_pt.float() * dmask1.float()) / (1 - dropout_p)
            ).to(dtype=residual_dtype)
            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + (
                x1_ref * dmask1.float()
            ) / (1 - dropout_p)
        else:
            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)).to(
                dtype=residual_dtype
            )
            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p)
    if not is_rms_norm:
        out0_pt = F.layer_norm(
            residual_pt.to(dtype=weight_dtype), (hidden_size,), weight0_pt, bias0_pt, eps=epsilon
        ).to(dtype=input_dtype)
        out0_ref = F.layer_norm(residual_ref, (hidden_size,), weight0_ref, bias0_ref, eps=epsilon)
        if not tied_norm:
            out1_pt = F.layer_norm(
                residual_pt.to(dtype=weight_dtype),
                (hidden_size,),
                weight1_pt,
                bias1_pt,
                eps=epsilon,
            ).to(dtype=input_dtype)
            out1_ref = F.layer_norm(residual_ref, (hidden_size,), weight1_ref, bias1_ref, eps=epsilon)
    else:
        out0_pt = fused_rms_norm_affine(
            residual_pt.to(dtype=weight_dtype), weight0_pt, (hidden_size,), eps=epsilon
        ).to(dtype=input_dtype)
        out0_ref = fused_rms_norm_affine(residual_ref, weight0_ref, (hidden_size,), eps=epsilon)
        if not tied_norm:
            out1_pt = fused_rms_norm_affine(
                residual_pt.to(dtype=weight_dtype), weight1_pt, (hidden_size,), eps=epsilon
            ).to(dtype=input_dtype)
            out1_ref = fused_rms_norm_affine(residual_ref, weight1_ref, (hidden_size,), eps=epsilon)
    assert (out0 - out0_ref).abs().max() <= 4 * (out0_pt - out0_ref).abs().max() + 1e-4
    if not tied_norm:
        assert (out1 - out1_ref).abs().max() <= 4 * (out1_pt - out1_ref).abs().max() + 1e-4
    assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4

    g0 = torch.randn_like(out0) / batch_size
    if tied_norm:
        (out0 * F.sigmoid(residual)).backward(g0)
        (out0_pt * F.sigmoid(residual_pt)).backward(g0)
        (out0_ref * F.sigmoid(residual_ref)).backward(g0)
    else:
        g1 = torch.randn_like(out1) / batch_size
        (out0 * F.sigmoid(residual) * g0 + out1 * g1).sum().backward()
        (out0_pt * F.sigmoid(residual_pt) * g0 + out1_pt * g1).sum().backward()
        (out0_ref * F.sigmoid(residual_ref) * g0 + out1_ref * g1).sum().backward()
    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
    if has_x1:
        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
    assert (weight0.grad - weight0_ref.grad).abs().max() <= 3 * (
        weight0_pt.grad - weight0_ref.grad
    ).abs().max() + 3e-5
    if not is_rms_norm:
        assert (bias0.grad - bias0_ref.grad).abs().max() <= 2 * (
            bias0_pt.grad - bias0_ref.grad
        ).abs().max() + 3e-5
    if not tied_norm:
        assert (weight1.grad - weight1_ref.grad).abs().max() <= 3 * (
            weight1_pt.grad - weight1_ref.grad
        ).abs().max() + 3e-5
        if not is_rms_norm:
            assert (bias1.grad - bias1_ref.grad).abs().max() <= 2 * (
                bias1_pt.grad - bias1_ref.grad
            ).abs().max() + 3e-5


def test_dropout_layer_norm_randomness():
    hidden_size = 256
    dtype = torch.float32
    dropout_p = 0.1
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x0 = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=dtype, requires_grad=True
    )
    res = torch.randn_like(x0, dtype=dtype, requires_grad=True)
    model = DropoutAddLayerNorm(hidden_size, p=dropout_p, device=device, dtype=dtype)
    torch.random.manual_seed(42)
    _, dmask0 = dropout_add_layer_norm(
        x0, res, model.weight, model.bias, model.p, model.eps, return_dropout_mask=True
    )
    # Subsequent call should have a different dropout mask
    _, dmask1 = dropout_add_layer_norm(
        x0, res, model.weight, model.bias, model.p, model.eps, return_dropout_mask=True
    )
    torch.random.manual_seed(42)
    # Resetting the seed, should get the same dropout mask
    _, dmask2 = dropout_add_layer_norm(
        x0, res, model.weight, model.bias, model.p, model.eps, return_dropout_mask=True
    )
    assert not torch.equal(dmask0, dmask1)
    assert torch.equal(dmask0, dmask2)
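All of these tests compare the fused kernel against the same unfused composition: optionally row- and column-scale x0, apply dropout with 1/(1-p) rescaling, add the residual, then LayerNorm (or RMSNorm) the sum. A minimal plain-PyTorch sketch of that reference composition follows; the helper name and signature are hypothetical, and F.dropout draws its own mask rather than reusing the kernel's returned dmask, so it mirrors the structure rather than being bit-exact:

import torch
import torch.nn.functional as F

def dropout_add_layer_norm_reference(x0, residual, weight, bias, p, eps,
                                     rowscale=None, layerscale=None):
    # Hypothetical unfused reference: scale, dropout, add residual, then LayerNorm.
    if rowscale is not None:
        x0 = x0 * rowscale.unsqueeze(-1)
    if layerscale is not None:
        x0 = x0 * layerscale
    x0 = F.dropout(x0, p=p)  # kept rows are scaled by 1 / (1 - p)
    pre_norm = x0 if residual is None else x0 + residual
    out = F.layer_norm(pre_norm.float(), (pre_norm.shape[-1],), weight.float(), bias.float(), eps)
    return out.to(x0.dtype), pre_norm

x0 = torch.randn(2, 4, 16)
res = torch.randn_like(x0)
w, b = torch.ones(16), torch.zeros(16)
out, pre_norm = dropout_add_layer_norm_reference(x0, res, w, b, p=0.1, eps=1e-5)
print(out.shape, pre_norm.shape)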
tests/ops/test_fused_dense.py deleted 100644 → 0
import math
from functools import partial

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn.ops.fused_dense import FusedDense, FusedMLP


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("return_residual", [False, True])
@pytest.mark.parametrize("has_bias", [True, False])
@pytest.mark.parametrize("out_features", [1024, 4096])
@pytest.mark.parametrize("in_features", [1024, 4096])
def test_fused_linear_bias(in_features, out_features, has_bias, return_residual, dtype):
    device = "cuda"
    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x_pt = torch.randn(
        batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True
    )
    x = x_pt.detach().clone().requires_grad_()
    model_pt = torch.nn.Linear(in_features, out_features, bias=has_bias, device=device, dtype=dtype)
    model = FusedDense(
        in_features,
        out_features,
        bias=has_bias,
        return_residual=return_residual,
        device=device,
        dtype=dtype,
    )
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        if has_bias:
            model.bias.copy_(model_pt.bias)
    out_pt = model_pt(x_pt)
    if not return_residual:
        out = model(x)
    else:
        out, x_copy = model(x)
        x_copy = (
            x_copy[..., :out_features]
            if out_features < in_features
            else F.pad(x_copy, (0, out_features - in_features))
        )
        x_pt_copy = (
            x_pt[..., :out_features]
            if out_features < in_features
            else F.pad(x_pt, (0, out_features - in_features))
        )
        # Just add some random function of the residual
        out_pt = out_pt + F.gelu(x_pt_copy)
        out = out + F.gelu(x_copy)
    # with torch.no_grad():
    #     out_fl = F.linear(x_pt.float(), model.weight.float(), model.bias.float()).half()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)

    # If we don't divide by batch_size, the gradient gets a bit too large.
    g = torch.randn_like(out) / 32
    out_pt.backward(g)
    out.backward(g)
    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(model.weight.grad, model_pt.weight.grad, rtol=rtol, atol=atol * 10)
    if has_bias:
        assert torch.allclose(model.bias.grad, model_pt.bias.grad, rtol=rtol, atol=atol * 5)


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("heuristic", ["auto", -1])
# @pytest.mark.parametrize('heuristic', ['auto'])
@pytest.mark.parametrize("checkpoint_lvl", [0, 1, 2])
# @pytest.mark.parametrize('checkpoint_lvl', [1])
@pytest.mark.parametrize("return_residual", [False, True])
# @pytest.mark.parametrize('return_residual', [False])
@pytest.mark.parametrize("has_bias2", [True, False])
@pytest.mark.parametrize("has_bias1", [True, False])
# @pytest.mark.parametrize('has_bias2', [True])
# @pytest.mark.parametrize('has_bias1', [True])
@pytest.mark.parametrize("activation", ["gelu_approx", "relu"])
# @pytest.mark.parametrize('activation', ['relu'])
@pytest.mark.parametrize("out_features", [1024, 4096])
@pytest.mark.parametrize("in_features", [1024, 4096])
# @pytest.mark.parametrize('out_features', [4096])
# @pytest.mark.parametrize('in_features', [1024])
def test_fused_mlp(
    in_features,
    out_features,
    activation,
    has_bias1,
    has_bias2,
    return_residual,
    checkpoint_lvl,
    heuristic,
    dtype,
):
    device = "cuda"
    rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x_pt = torch.randn(
        batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True
    )
    x = x_pt.detach().clone().requires_grad_()
    model_pt_fc1 = torch.nn.Linear(
        in_features, out_features, bias=has_bias1, device=device, dtype=dtype
    )
    model_pt_fc2 = torch.nn.Linear(
        out_features, in_features, bias=has_bias2, device=device, dtype=dtype
    )
    model = FusedMLP(
        in_features,
        out_features,
        in_features,
        activation=activation,
        bias1=has_bias1,
        bias2=has_bias2,
        return_residual=return_residual,
        checkpoint_lvl=checkpoint_lvl,
        heuristic=heuristic,
        device=device,
        dtype=dtype,
    )
    with torch.no_grad():
        model.fc1.weight.copy_(model_pt_fc1.weight)
        if has_bias1:
            model.fc1.bias.copy_(model_pt_fc1.bias)
        model.fc2.weight.copy_(model_pt_fc2.weight)
        if has_bias2:
            model.fc2.bias.copy_(model_pt_fc2.bias)
    activation_fn = (
        partial(F.gelu, approximate="tanh")
        if activation == "gelu_approx"
        else partial(F.relu, inplace=True)
    )
    out_pt = model_pt_fc2(activation_fn(model_pt_fc1(x_pt)))
    if not return_residual:
        out = model(x)
    else:
        out, x_copy = model(x)
        # Just add some random function of the residual
        out_pt = out_pt + F.gelu(x_pt)
        out = out + F.gelu(x_copy)
    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)

    # If we don't divide by batch_size, the gradient gets a bit too large.
    g = torch.randn_like(out) / 32
    out_pt.backward(g)
    out.backward(g)
    # The error for relu is higher still
    if activation == "relu":
        atol = 1e-1 if dtype == torch.bfloat16 else 5e-2
    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(
        model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10
    )
    if has_bias1:
        assert torch.allclose(model.fc1.bias.grad, model_pt_fc1.bias.grad, rtol=rtol, atol=atol * 5)
    assert torch.allclose(
        model.fc2.weight.grad, model_pt_fc2.weight.grad, rtol=rtol, atol=atol * 10
    )
    if has_bias2:
        assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)
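For context on what the test above checks: FusedDense is exercised as a drop-in replacement for torch.nn.Linear, with an optional return_residual output that also hands back the layer input. A minimal usage sketch, assuming a CUDA device and a flash-attn build that includes the fused dense ops (the sizes are illustrative, taken from the parametrizations above):

import torch
from flash_attn.ops.fused_dense import FusedDense

layer = FusedDense(
    1024, 4096, bias=True, return_residual=True, device="cuda", dtype=torch.float16
)
x = torch.randn(8, 512, 1024, device="cuda", dtype=torch.float16)
# With return_residual=True the forward returns (output, input), matching the test's unpacking.
out, x_residual = layer(x)
print(out.shape, x_residual.shape)  # torch.Size([8, 512, 4096]) torch.Size([8, 512, 1024])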
tests/ops/test_fused_dense_parallel.py deleted 100644 → 0 View file @ 344c988d
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/ops/test_fused_dense_parallel.py
import math

import pytest
import torch
import torch.nn.functional as F
from apex.transformer import parallel_state, tensor_parallel
from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, FusedMLP, ParallelFusedMLP

is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize("has_bias", [True, False])
# @pytest.mark.parametrize('has_bias', [False])
@pytest.mark.parametrize("out_features", [1024])
@pytest.mark.parametrize("in_features", [4096])
def test_fused_linear_bias(in_features, out_features, has_bias, sequence_parallel, world_size, dtype):
    assert out_features % world_size == 0
    rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    seqlen = 512
    assert batch_size * seqlen % world_size == 0
    x_pt = torch.randn(
        batch_size * seqlen, in_features, device=device, dtype=dtype, requires_grad=True
    )
    if sequence_parallel:
        x = (
            tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
            .detach()
            .clone()
            .requires_grad_()
        )
    else:
        x = x_pt.detach().clone().requires_grad_()

    model_pt = torch.nn.Linear(in_features, out_features, bias=has_bias, device=device, dtype=dtype)
    partition_out_features = out_features // world_size
    model = ColumnParallelLinear(
        in_features,
        out_features,
        parallel_state.get_tensor_model_parallel_group(),
        bias=has_bias,
        sequence_parallel=sequence_parallel,
        device=device,
        dtype=dtype,
    )
    with torch.no_grad():
        model.weight.copy_(
            model_pt.weight[rank * partition_out_features : (rank + 1) * partition_out_features]
        )
        if has_bias:
            model.bias.copy_(
                model_pt.bias[rank * partition_out_features : (rank + 1) * partition_out_features]
            )

    out = model(x)
    out_pt = model_pt(x_pt)
    assert torch.allclose(
        out,
        out_pt[:, rank * partition_out_features : (rank + 1) * partition_out_features],
        rtol=rtol,
        atol=atol,
    )

    # If we don't divide by batch_size, the gradient gets a bit too large.
    g = torch.randn_like(out_pt) / 32
    out_pt.backward(g)
    out.backward(g[:, rank * partition_out_features : (rank + 1) * partition_out_features])
    parallel_state.destroy_model_parallel()

    partition_batch_dim = batch_size * seqlen // world_size
    assert torch.allclose(
        x.grad,
        x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else x_pt.grad,
        rtol=rtol,
        atol=atol,
    )
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(
        model.weight.grad,
        model_pt.weight.grad[rank * partition_out_features : (rank + 1) * partition_out_features],
        rtol=rtol,
        atol=atol * 10,
    )
    if has_bias:
        assert torch.allclose(
            model.bias.grad,
            model_pt.bias.grad[rank * partition_out_features : (rank + 1) * partition_out_features],
            rtol=rtol,
            atol=atol * 5,
        )


@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize("has_bias2", [True, False])
# @pytest.mark.parametrize('has_bias2', [True])
@pytest.mark.parametrize("out_features", [4096])
@pytest.mark.parametrize("in_features", [1024])
def test_fused_mlp(in_features, out_features, has_bias2, sequence_parallel, world_size, dtype):
    assert out_features % world_size == 0
    rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    seqlen = 512
    assert batch_size * seqlen % world_size == 0
    x_pt = torch.randn(
        batch_size * seqlen, in_features, device=device, dtype=dtype, requires_grad=True
    )
    # We need to generate g here so that all processes get the same gradient,
    # as rank 0 will have an extra bias that changes the RNG.
    # If we don't divide by batch_size, the gradient gets a bit too large.
    g = torch.randn_like(x_pt) / 32
    if sequence_parallel:
        x = (
            tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
            .detach()
            .clone()
            .requires_grad_()
        )
    else:
        x = x_pt.detach().clone().requires_grad_()

    model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
    model_pt_fc2 = torch.nn.Linear(
        out_features, in_features, bias=has_bias2, device=device, dtype=dtype
    )
    partition_out_features = out_features // world_size
    partition_in_features = in_features // world_size
    model = ParallelFusedMLP(
        in_features,
        out_features,
        in_features,
        process_group=parallel_state.get_tensor_model_parallel_group(),
        bias2=has_bias2 and rank == 0,
        sequence_parallel=sequence_parallel,
        device=device,
        dtype=dtype,
    )

    with torch.no_grad():
        model.fc1.weight.copy_(
            model_pt_fc1.weight[rank * partition_out_features : (rank + 1) * partition_out_features]
        )
        model.fc1.bias.copy_(
            model_pt_fc1.bias[rank * partition_out_features : (rank + 1) * partition_out_features]
        )
        model.fc2.weight.copy_(
            model_pt_fc2.weight[
                :, rank * partition_out_features : (rank + 1) * partition_out_features
            ]
        )
        if has_bias2 and rank == 0:
            model.fc2.bias.copy_(model_pt_fc2.bias)

    out = model(x)
    out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate="tanh"))
    partition_batch_dim = batch_size * seqlen // world_size
    assert torch.allclose(
        out,
        out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else out_pt,
        rtol=rtol,
        atol=atol,
    )

    out_pt.backward(g)
    out.backward(
        g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g
    )
    parallel_state.destroy_model_parallel()

    assert torch.allclose(
        x.grad,
        x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else x_pt.grad,
        rtol=rtol,
        atol=atol,
    )
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(
        model.fc1.weight.grad,
        model_pt_fc1.weight.grad[rank * partition_out_features : (rank + 1) * partition_out_features],
        rtol=rtol,
        atol=atol * 10,
    )
    assert torch.allclose(
        model.fc1.bias.grad,
        model_pt_fc1.bias.grad[rank * partition_out_features : (rank + 1) * partition_out_features],
        rtol=rtol,
        atol=atol * 5,
    )
    assert torch.allclose(
        model.fc2.weight.grad,
        model_pt_fc2.weight.grad[
            :, rank * partition_out_features : (rank + 1) * partition_out_features
        ],
        rtol=rtol,
        atol=atol * 10,
    )
    if has_bias2 and rank == 0:
        assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)
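The parallel tests above lean on the standard tensor-parallel identity: splitting a linear layer's weight (and bias) along the output dimension and letting each rank compute its slice reproduces the corresponding columns of the full output. A CPU-only sketch of that identity, without apex or any process group (the rank loop simply stands in for the distributed ranks):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
world_size = 4
in_features, out_features = 16, 32
partition = out_features // world_size

full = torch.nn.Linear(in_features, out_features)
x = torch.randn(8, in_features)
out_full = full(x)

# Each "rank" owns a contiguous row-slice of the weight/bias (column-parallel layout),
# exactly the slice the test copies into ColumnParallelLinear above.
partial_outputs = []
for rank in range(world_size):
    w = full.weight[rank * partition : (rank + 1) * partition]
    b = full.bias[rank * partition : (rank + 1) * partition]
    partial_outputs.append(F.linear(x, w, b))

# Concatenating the per-rank outputs recovers the unpartitioned result.
assert torch.allclose(torch.cat(partial_outputs, dim=-1), out_full, atol=1e-6)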
tests/ops/triton/test_layer_norm.py deleted 100644 → 0 View file @ 344c988d
# Copyright (c) 2024, Tri Dao.

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from flash_attn.ops.triton.layer_norm import (
    layer_norm_fn,
    layer_norm_ref,
    rms_norm_ref,
    layer_norm_linear_fn,
)

is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize("has_weight1", [False, True])
# @pytest.mark.parametrize("has_weight1", [True])
@pytest.mark.parametrize("has_x1", [False, True])
# @pytest.mark.parametrize("has_x1", [False])
@pytest.mark.parametrize("has_rowscale", [False, True])
# @pytest.mark.parametrize("has_rowscale", [False])
@pytest.mark.parametrize("dropout_p", [0.0, 0.27])
# @pytest.mark.parametrize("dropout_p", [0.0])
@pytest.mark.parametrize("prenorm", [True, False])
# @pytest.mark.parametrize("prenorm", [False])
@pytest.mark.parametrize("is_rms_norm", [False, True])
# @pytest.mark.parametrize("is_rms_norm", [True])
@pytest.mark.parametrize("has_residual", [True, False])
# @pytest.mark.parametrize("has_residual", [False])
@pytest.mark.parametrize(
    "weight_dtype", [torch.float32, torch.float16] + ([torch.bfloat16] if is_sm8x else [])
)
# @pytest.mark.parametrize("weight_dtype", [torch.float32])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize("input_dtype,residual_dtype", [(torch.float16, torch.float16)])
@pytest.mark.parametrize("hidden_size", [192, 2048, 2560, 3000, 4096])
# @pytest.mark.parametrize("hidden_size", [256])
def test_layer_norm(
    hidden_size,
    input_dtype,
    residual_dtype,
    weight_dtype,
    has_residual,
    is_rms_norm,
    prenorm,
    dropout_p,
    has_rowscale,
    has_x1,
    has_weight1,
):
    if has_rowscale and has_x1:
        pytest.skip("Not supported")
    device = "cuda"
    if any(x == torch.bfloat16 for x in [input_dtype, residual_dtype, weight_dtype]):
        atol = 5e-2
    elif any(x == torch.float16 for x in [input_dtype, residual_dtype, weight_dtype]):
        atol = 1e-2
    else:
        atol = 1e-4
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    layer_norm_ref_fn = layer_norm_ref if not is_rms_norm else rms_norm_ref
    allclose = (
        # Sometimes x0_pt.grad is NaN
        lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max()
        <= 2 * (x_pt[~x_pt.isnan()] - x_ref[~x_pt.isnan()]).abs().max() + atol
        or (
            # Sometimes x_pt and x_ref are the same (e.g. bfloat16) so we want to perturb is a bit
            # by multiply and divide by 0.3
            (x_pt[~x_pt.isnan()] - x_ref[~x_pt.isnan()]).abs().max() == 0.0
            and (x - x_ref).abs().max()
            <= 2 * (x_pt[~x_pt.isnan()] * 0.3 / 0.3 - x_ref[~x_pt.isnan()]).abs().max() + atol
        )
    )
    x0 = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0_pt = x0.detach().clone().requires_grad_()
    x0_ref = x0.detach().clone().requires_grad_()
    if has_residual:
        res = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        res_pt = res.detach().clone().requires_grad_()
        res_ref = res.detach().clone().requires_grad_()
    else:
        res, res_pt, res_ref = None, None, None
    weight = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
    if not is_rms_norm:
        bias = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
    else:
        bias = None
    weight_pt = weight.detach().clone().requires_grad_()
    weight_ref = weight.detach().clone().requires_grad_()
    bias_pt = bias.detach().clone().requires_grad_() if bias is not None else None
    bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
    if has_x1:
        x1 = torch.randn_like(x0, dtype=input_dtype, requires_grad=True)
        x1_pt = x1.detach().clone().requires_grad_()
        x1_ref = x1.detach().clone().requires_grad_()
    else:
        x1, x1_pt, x1_ref = None, None, None
    if has_weight1:
        weight1 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        weight1_pt = weight1.detach().clone().requires_grad_()
        weight1_ref = weight1.detach().clone().requires_grad_()
        if not is_rms_norm:
            bias1 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        else:
            bias1 = None
        bias1_pt = bias1.detach().clone().requires_grad_() if bias1 is not None else None
        bias1_ref = bias1.detach().clone().requires_grad_() if bias1 is not None else None
    else:
        weight1, weight1_pt, weight1_ref = None, None, None
        bias1, bias1_pt, bias1_ref = None, None, None

    rowscale = (
        torch.randn(batch_size, seqlen, dtype=input_dtype, device=device)
        if has_rowscale
        else None
    )
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, *rest = layer_norm_fn(
        x0,
        weight,
        bias,
        residual=res,
        x1=x1,
        weight1=weight1,
        bias1=bias1,
        eps=1e-6,
        dropout_p=dropout_p,
        rowscale=rowscale,
        prenorm=prenorm,
        residual_in_fp32=residual_in_fp32,
        is_rms_norm=is_rms_norm,
        return_dropout_mask=True,
    )
    dropout_mask = rest[-2] if dropout_p > 0.0 else None
    dropout_mask1 = rest[-1] if dropout_p > 0.0 and x1 is not None else None
    out_pt = layer_norm_ref_fn(
        x0_pt,
        weight_pt,
        bias_pt,
        residual=res_pt,
        x1=x1_pt,
        weight1=weight1_pt,
        bias1=bias1_pt,
        eps=1e-6,
        dropout_p=dropout_p,
        rowscale=rowscale,
        prenorm=prenorm,
        dropout_mask=dropout_mask,
        dropout_mask1=dropout_mask1,
    )
    out_ref = layer_norm_ref_fn(
        x0_ref,
        weight_ref,
        bias_ref,
        residual=res_ref,
        x1=x1_ref,
        weight1=weight1_ref,
        bias1=bias1_ref,
        eps=1e-6,
        dropout_p=dropout_p,
        rowscale=rowscale,
        prenorm=prenorm,
        dropout_mask=dropout_mask,
        dropout_mask1=dropout_mask1,
        upcast=True,
    )
    if not has_weight1:
        if prenorm:
            residual = rest[0]
            out_pt, residual_pt = out_pt
            out_ref, residual_ref = out_ref
        out1, out1_pt, out1_ref = None, None, None
    else:
        out1 = rest.pop(0)
        if prenorm:
            residual = rest[0]
            out_pt, out1_pt, residual_pt = out_pt
            out_ref, out1_ref, residual_ref = out_ref
        else:
            out_pt, out1_pt = out_pt
            out_ref, out1_ref = out_ref
    assert out.dtype == input_dtype
    if prenorm:
        assert residual.dtype == residual_dtype
        assert allclose(residual, residual_pt, residual_ref)
    assert allclose(out, out_pt, out_ref)
    if out1 is not None:
        assert out1.dtype == input_dtype
        assert allclose(out1, out1_pt, out1_ref)
    if dropout_mask is not None:
        dropout_fraction = 1.0 - dropout_mask.float().mean()
        assert abs(dropout_fraction - dropout_p) < 0.01
    if dropout_mask1 is not None:
        dropout_fraction = 1.0 - dropout_mask1.float().mean()
        assert abs(dropout_fraction - dropout_p) < 0.01
        assert not torch.equal(dropout_mask, dropout_mask1)

    g = torch.randn_like(out) / batch_size
    if has_weight1:
        out = out * F.gelu(out1)
        out_pt = out_pt * F.gelu(out1_pt)
        out_ref = out_ref * F.gelu(out1_ref)
    if not prenorm:
        out.backward(g)
        out_pt.backward(g)
        out_ref.backward(g)
    else:
        (out * F.sigmoid(residual)).backward(g)
        (out_pt * F.sigmoid(residual_pt)).backward(g)
        (out_ref * F.sigmoid(residual_ref.to(dtype=residual_dtype))).backward(g)
    assert allclose(x0.grad, x0_pt.grad, x0_ref.grad)
    if has_residual:
        assert allclose(res.grad, res_pt.grad, res_ref.grad)
    if has_x1:
        assert allclose(x1.grad, x1_pt.grad, x1_ref.grad)
    assert allclose(weight.grad, weight_pt.grad, weight_ref.grad)
    if bias is not None:
        assert allclose(bias.grad, bias_pt.grad, bias_ref.grad)
    if has_weight1:
        assert allclose(weight1.grad, weight1_pt.grad, weight1_ref.grad)
        if bias1 is not None:
            assert allclose(bias1.grad, bias1_pt.grad, bias1_ref.grad)


@pytest.mark.parametrize("prenorm", [True, False])
# @pytest.mark.parametrize("prenorm", [True])
@pytest.mark.parametrize("is_rms_norm", [False, True])
# @pytest.mark.parametrize("is_rms_norm", [True])
@pytest.mark.parametrize("has_residual", [True, False])
# @pytest.mark.parametrize("has_residual", [False])
@pytest.mark.parametrize("weight_dtype", [torch.float32])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize("input_dtype,residual_dtype", [(torch.bfloat16, torch.float32)])
@pytest.mark.parametrize("hidden_size", [192, 2048, 2560, 3000])
# @pytest.mark.parametrize("hidden_size", [256])
def test_layer_norm_linear(
    hidden_size, input_dtype, residual_dtype, weight_dtype, has_residual, is_rms_norm, prenorm
):
    device = "cuda"
    if any(x == torch.bfloat16 for x in [input_dtype, residual_dtype, weight_dtype]):
        atol = 5e-2
    elif any(x == torch.float16 for x in [input_dtype, residual_dtype, weight_dtype]):
        atol = 1e-2
    else:
        atol = 1e-4
    # set seed
    torch.random.manual_seed(0)
    batch_size = 4
    seqlen = 512
    # batch_size = 1
    # seqlen = 1
    layer_norm_ref_fn = layer_norm_ref if not is_rms_norm else rms_norm_ref
    allclose = (
        lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max()
        <= 2 * (x_pt - x_ref).abs().max() + atol
    )
    x0 = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0_pt = x0.detach().clone().requires_grad_()
    x0_ref = x0.detach().clone().requires_grad_()
    if has_residual:
        res = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        res_pt = res.detach().clone().requires_grad_()
        res_ref = res.detach().clone().requires_grad_()
    else:
        res, res_pt, res_ref = None, None, None
    norm_weight = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
    if not is_rms_norm:
        norm_bias = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
    else:
        norm_bias = None
    norm_weight_pt = norm_weight.detach().clone().requires_grad_()
    norm_weight_ref = norm_weight.detach().clone().requires_grad_()
    norm_bias_pt = norm_bias.detach().clone().requires_grad_() if norm_bias is not None else None
    norm_bias_ref = norm_bias.detach().clone().requires_grad_() if norm_bias is not None else None
    linear_weight = torch.empty(
        2 * hidden_size, hidden_size, device=device, dtype=weight_dtype, requires_grad=True
    )
    torch.nn.init.xavier_uniform_(linear_weight)
    if not is_rms_norm:
        linear_bias = torch.randn(
            2 * hidden_size, device=device, dtype=weight_dtype, requires_grad=True
        )
    else:
        linear_bias = None
    linear_weight_pt = linear_weight.detach().clone().requires_grad_()
    linear_weight_ref = linear_weight.detach().clone().requires_grad_()
    linear_bias_pt = (
        linear_bias.detach().clone().requires_grad_() if linear_bias is not None else None
    )
    linear_bias_ref = (
        linear_bias.detach().clone().requires_grad_() if linear_bias is not None else None
    )

    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    with torch.autocast(device_type="cuda", dtype=input_dtype):
        out, *rest = layer_norm_linear_fn(
            x0,
            norm_weight,
            norm_bias,
            linear_weight,
            linear_bias,
            residual=res,
            eps=1e-6,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32,
            is_rms_norm=is_rms_norm,
        )
    out_pt, *rest_pt = layer_norm_ref_fn(
        x0_pt, norm_weight_pt, norm_bias_pt, residual=res_pt, eps=1e-6, prenorm=prenorm
    )
    with torch.autocast(device_type="cuda", dtype=input_dtype):
        out_pt = F.linear(out_pt, linear_weight_pt, linear_bias_pt)
    out_ref, *rest_ref = layer_norm_ref_fn(
        x0_ref,
        norm_weight_ref,
        norm_bias_ref,
        residual=res_ref,
        eps=1e-6,
        prenorm=prenorm,
        upcast=True,
    )
    out_ref = F.linear(out_ref.to(linear_weight_ref.dtype), linear_weight_ref, linear_bias_ref)
    if prenorm:
        residual = rest[0]
        residual_pt = rest_pt[0]
        residual_ref = rest_ref[0]
    assert out.dtype == input_dtype
    if prenorm:
        assert residual.dtype == residual_dtype
        assert allclose(residual, residual_pt, residual_ref)
    assert allclose(out, out_pt, out_ref)

    g = torch.randn_like(out) / batch_size
    out.backward(g)
    out_pt.backward(g)
    out_ref.backward(g)
    assert allclose(x0.grad, x0_pt.grad, x0_ref.grad)
    if has_residual:
        assert allclose(res.grad, res_pt.grad, res_ref.grad)
    assert allclose(norm_weight.grad, norm_weight_pt.grad, norm_weight_ref.grad)
    if norm_bias is not None:
        assert allclose(norm_bias.grad, norm_bias_pt.grad, norm_bias_ref.grad)
    assert allclose(linear_weight.grad, linear_weight_pt.grad, linear_weight_ref.grad)
    if linear_bias is not None:
        assert allclose(linear_bias.grad, linear_bias_pt.grad, linear_bias_ref.grad)
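The reference functions these tests compare against implement the familiar fused pattern: apply dropout (and an optional row scale) to the incoming branch, add it to the residual stream, then layer-normalize, returning the pre-norm sum as well when prenorm=True. A simplified pure-PyTorch sketch of that semantics (not the actual layer_norm_ref from the deleted file, which also handles x1/weight1, rowscale, and RMSNorm):

import torch
import torch.nn.functional as F

def prenorm_dropout_add_layer_norm(x0, residual, weight, bias, dropout_p=0.0, eps=1e-6, prenorm=True):
    # Dropout on the incoming branch, then the residual add done in fp32 for stability.
    x0 = F.dropout(x0, p=dropout_p, training=True)
    pre_norm = x0.float() + residual.float() if residual is not None else x0.float()
    out = F.layer_norm(pre_norm, (pre_norm.shape[-1],), weight.float(), bias.float(), eps)
    out = out.to(x0.dtype)
    # prenorm=True also returns the un-normalized sum so the next block can reuse it.
    return (out, pre_norm) if prenorm else out

x0 = torch.randn(2, 4, 64, dtype=torch.float16)
res = torch.randn(2, 4, 64, dtype=torch.float32)
w, b = torch.ones(64), torch.zeros(64)
out, residual_out = prenorm_dropout_add_layer_norm(x0, res, w, b, dropout_p=0.1)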
tests/pyproject.toml deleted 100644 → 0 View file @ 344c988d
[tool.black]
line-length = 100
target-version = ['py38']
\ No newline at end of file
tests/test_flash_attn.py deleted 100644 → 0 View file @ 344c988d
import
math
import
pytest
import
torch
import
torch.nn.functional
as
F
from
einops
import
rearrange
,
repeat
from
flash_attn
import
(
flash_attn_func
,
flash_attn_kvpacked_func
,
flash_attn_qkvpacked_func
,
flash_attn_varlen_func
,
flash_attn_varlen_kvpacked_func
,
flash_attn_varlen_qkvpacked_func
,
flash_attn_with_kvcache
,
)
from
flash_attn.bert_padding
import
pad_input
,
unpad_input
from
flash_attn.flash_attn_interface
import
_get_block_size_n
from
flash_attn.layers.rotary
import
apply_rotary_emb
MAX_HEADDIM_SM8x
=
192
is_sm75
=
torch
.
cuda
.
get_device_capability
(
"cuda"
)
==
(
7
,
5
)
is_sm8x
=
torch
.
cuda
.
get_device_capability
(
"cuda"
)[
0
]
==
8
is_sm80
=
torch
.
cuda
.
get_device_capability
(
"cuda"
)
==
(
8
,
0
)
is_sm90
=
torch
.
cuda
.
get_device_capability
(
"cuda"
)
==
(
9
,
0
)
def
attn_bias_from_alibi_slopes
(
slopes
,
seqlen_q
,
seqlen_k
,
query_padding_mask
=
None
,
key_padding_mask
=
None
,
causal
=
False
):
batch
,
nheads
=
slopes
.
shape
device
=
slopes
.
device
slopes
=
rearrange
(
slopes
,
"b h -> b h 1 1"
)
if
causal
:
return
torch
.
arange
(
-
seqlen_k
+
1
,
1
,
device
=
device
,
dtype
=
torch
.
float32
)
*
slopes
else
:
row_idx
=
rearrange
(
torch
.
arange
(
seqlen_q
,
device
=
device
,
dtype
=
torch
.
long
),
"s -> s 1"
)
col_idx
=
torch
.
arange
(
seqlen_k
,
device
=
device
,
dtype
=
torch
.
long
)
sk
=
(
seqlen_k
if
key_padding_mask
is
None
else
rearrange
(
key_padding_mask
.
sum
(
-
1
),
"b -> b 1 1 1"
)
)
sq
=
(
seqlen_q
if
query_padding_mask
is
None
else
rearrange
(
query_padding_mask
.
sum
(
-
1
),
"b -> b 1 1 1"
)
)
relative_pos
=
torch
.
abs
(
row_idx
+
sk
-
sq
-
col_idx
)
return
-
slopes
*
relative_pos
.
to
(
dtype
=
slopes
.
dtype
)
def
generate_random_padding_mask
(
max_seqlen
,
batch_size
,
device
,
mode
=
"random"
):
assert
mode
in
[
"full"
,
"random"
,
"third"
]
if
mode
==
"full"
:
lengths
=
torch
.
full
((
batch_size
,
1
),
max_seqlen
,
device
=
device
,
dtype
=
torch
.
int32
)
elif
mode
==
"random"
:
lengths
=
torch
.
randint
(
max
(
1
,
max_seqlen
-
20
),
max_seqlen
+
1
,
(
batch_size
,
1
),
device
=
device
)
elif
mode
==
"third"
:
lengths
=
torch
.
randint
(
max_seqlen
//
3
,
max_seqlen
+
1
,
(
batch_size
,
1
),
device
=
device
)
padding_mask
=
(
repeat
(
torch
.
arange
(
max_seqlen
,
device
=
device
),
"s -> b s"
,
b
=
batch_size
)
<
lengths
)
return
padding_mask
def
generate_qkv
(
q
,
k
,
v
,
query_padding_mask
=
None
,
key_padding_mask
=
None
,
kvpacked
=
False
,
qkvpacked
=
False
):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, d)
k: (batch_size, seqlen_k, nheads_k, d)
v: (batch_size, seqlen_k, nheads_k, d)
query_padding_mask: (batch_size, seqlen), bool
key_padding_mask: (batch_size, seqlen), bool
"""
assert
not
(
kvpacked
and
qkvpacked
)
batch_size
,
seqlen_q
,
nheads
,
d
=
q
.
shape
_
,
seqlen_k
,
nheads_k
,
_
=
k
.
shape
assert
k
.
shape
==
(
batch_size
,
seqlen_k
,
nheads_k
,
d
)
assert
v
.
shape
==
(
batch_size
,
seqlen_k
,
nheads_k
,
d
)
if
query_padding_mask
is
not
None
:
q_unpad
,
indices_q
,
cu_seqlens_q
,
max_seqlen_q
=
unpad_input
(
q
,
query_padding_mask
)
output_pad_fn
=
lambda
output_unpad
:
pad_input
(
output_unpad
,
indices_q
,
batch_size
,
seqlen_q
)
else
:
q_unpad
=
rearrange
(
q
,
"b s h d -> (b s) h d"
)
cu_seqlens_q
=
torch
.
arange
(
0
,
(
batch_size
+
1
)
*
seqlen_q
,
step
=
seqlen_q
,
dtype
=
torch
.
int32
,
device
=
q_unpad
.
device
)
max_seqlen_q
=
seqlen_q
output_pad_fn
=
lambda
output_unpad
:
rearrange
(
output_unpad
,
"(b s) h d -> b s h d"
,
b
=
batch_size
)
if
key_padding_mask
is
not
None
:
k_unpad
,
indices_k
,
cu_seqlens_k
,
max_seqlen_k
=
unpad_input
(
k
,
key_padding_mask
)
v_unpad
,
_
,
_
,
_
=
unpad_input
(
v
,
key_padding_mask
)
else
:
k_unpad
=
rearrange
(
k
,
"b s h d -> (b s) h d"
)
v_unpad
=
rearrange
(
v
,
"b s h d -> (b s) h d"
)
cu_seqlens_k
=
torch
.
arange
(
0
,
(
batch_size
+
1
)
*
seqlen_k
,
step
=
seqlen_k
,
dtype
=
torch
.
int32
,
device
=
k_unpad
.
device
)
max_seqlen_k
=
seqlen_k
if
qkvpacked
:
assert
(
query_padding_mask
==
key_padding_mask
).
all
()
assert
nheads
==
nheads_k
qkv_unpad
=
torch
.
stack
([
q_unpad
,
k_unpad
,
v_unpad
],
dim
=
1
)
qkv
=
torch
.
stack
([
q
,
k
,
v
],
dim
=
2
)
if
query_padding_mask
is
not
None
:
dqkv_pad_fn
=
lambda
dqkv_unpad
:
pad_input
(
dqkv_unpad
,
indices_q
,
batch_size
,
seqlen_q
)
else
:
dqkv_pad_fn
=
lambda
dqkv_unpad
:
rearrange
(
dqkv_unpad
,
"(b s) t h d -> b s t h d"
,
b
=
batch_size
)
return
(
qkv_unpad
.
detach
().
requires_grad_
(),
cu_seqlens_q
,
max_seqlen_q
,
qkv
.
detach
().
requires_grad_
(),
output_pad_fn
,
dqkv_pad_fn
,
)
elif
kvpacked
:
kv_unpad
=
torch
.
stack
([
k_unpad
,
v_unpad
],
dim
=
1
)
kv
=
torch
.
stack
([
k
,
v
],
dim
=
2
)
dq_pad_fn
=
output_pad_fn
if
key_padding_mask
is
not
None
:
dkv_pad_fn
=
lambda
dkv_unpad
:
pad_input
(
dkv_unpad
,
indices_k
,
batch_size
,
seqlen_k
)
else
:
dkv_pad_fn
=
lambda
dkv_unpad
:
rearrange
(
dkv_unpad
,
"(b s) t h d -> b s t h d"
,
b
=
batch_size
)
return
(
q_unpad
.
detach
().
requires_grad_
(),
kv_unpad
.
detach
().
requires_grad_
(),
cu_seqlens_q
,
cu_seqlens_k
,
max_seqlen_q
,
max_seqlen_k
,
q
.
detach
().
requires_grad_
(),
kv
.
detach
().
requires_grad_
(),
output_pad_fn
,
dq_pad_fn
,
dkv_pad_fn
,
)
else
:
dq_pad_fn
=
output_pad_fn
if
key_padding_mask
is
not
None
:
dk_pad_fn
=
lambda
dk_unpad
:
pad_input
(
dk_unpad
,
indices_k
,
batch_size
,
seqlen_k
)
else
:
dk_pad_fn
=
lambda
dk_unpad
:
rearrange
(
dk_unpad
,
"(b s) h d -> b s h d"
,
b
=
batch_size
)
return
(
q_unpad
.
detach
().
requires_grad_
(),
k_unpad
.
detach
().
requires_grad_
(),
v_unpad
.
detach
().
requires_grad_
(),
cu_seqlens_q
,
cu_seqlens_k
,
max_seqlen_q
,
max_seqlen_k
,
q
.
detach
().
requires_grad_
(),
k
.
detach
().
requires_grad_
(),
v
.
detach
().
requires_grad_
(),
output_pad_fn
,
dq_pad_fn
,
dk_pad_fn
,
)
def
construct_local_mask
(
seqlen_q
,
seqlen_k
,
window_size
=
(
-
1
,
-
1
),
# -1 means infinite window size
query_padding_mask
=
None
,
key_padding_mask
=
None
,
device
=
None
,
):
row_idx
=
rearrange
(
torch
.
arange
(
seqlen_q
,
device
=
device
,
dtype
=
torch
.
long
),
"s -> s 1"
)
col_idx
=
torch
.
arange
(
seqlen_k
,
device
=
device
,
dtype
=
torch
.
long
)
sk
=
(
seqlen_k
if
key_padding_mask
is
None
else
rearrange
(
key_padding_mask
.
sum
(
-
1
),
"b -> b 1 1 1"
)
)
sq
=
(
seqlen_q
if
query_padding_mask
is
None
else
rearrange
(
query_padding_mask
.
sum
(
-
1
),
"b -> b 1 1 1"
)
)
if
window_size
[
0
]
<
0
:
return
col_idx
>
row_idx
+
sk
-
sq
+
window_size
[
1
]
else
:
sk
=
torch
.
full_like
(
col_idx
,
seqlen_k
)
if
key_padding_mask
is
None
else
sk
return
torch
.
logical_or
(
col_idx
>
torch
.
minimum
(
row_idx
+
sk
-
sq
+
window_size
[
1
],
sk
),
col_idx
<
row_idx
+
sk
-
sq
-
window_size
[
0
],
)
def
attention_ref
(
q
,
k
,
v
,
query_padding_mask
=
None
,
key_padding_mask
=
None
,
attn_bias
=
None
,
dropout_p
=
0.0
,
dropout_mask
=
None
,
causal
=
False
,
window_size
=
(
-
1
,
-
1
),
# -1 means infinite window size
softcap
=
0.0
,
upcast
=
True
,
reorder_ops
=
False
,
):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, head_dim)
k: (batch_size, seqlen_k, nheads_k, head_dim)
v: (batch_size, seqlen_k, nheads_k, head_dim)
query_padding_mask: (batch_size, seqlen_q)
key_padding_mask: (batch_size, seqlen_k)
attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
dropout_p: float
dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
causal: whether to apply causal masking
window_size: (int, int), left and right window size
upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
output back to fp16/bf16.
reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
without changing the math. This is to estimate the numerical error from operation
reordering.
Output:
output: (batch_size, seqlen_q, nheads, head_dim)
attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
"""
if
causal
:
window_size
=
(
window_size
[
0
],
0
)
dtype_og
=
q
.
dtype
if
upcast
:
q
,
k
,
v
=
q
.
float
(),
k
.
float
(),
v
.
float
()
seqlen_q
,
seqlen_k
=
q
.
shape
[
1
],
k
.
shape
[
1
]
k
=
repeat
(
k
,
"b s h d -> b s (h g) d"
,
g
=
q
.
shape
[
2
]
//
k
.
shape
[
2
])
v
=
repeat
(
v
,
"b s h d -> b s (h g) d"
,
g
=
q
.
shape
[
2
]
//
v
.
shape
[
2
])
d
=
q
.
shape
[
-
1
]
if
not
reorder_ops
:
scores
=
torch
.
einsum
(
"bthd,bshd->bhts"
,
q
/
math
.
sqrt
(
d
),
k
)
else
:
scores
=
torch
.
einsum
(
"bthd,bshd->bhts"
,
q
,
k
/
math
.
sqrt
(
d
))
if
softcap
>
0
:
scores
/=
softcap
scores
=
scores
.
tanh
()
scores
*=
softcap
if
key_padding_mask
is
not
None
:
scores
.
masked_fill_
(
rearrange
(
~
key_padding_mask
,
"b s -> b 1 1 s"
),
float
(
"-inf"
))
if
window_size
[
0
]
>=
0
or
window_size
[
1
]
>=
0
:
local_mask
=
construct_local_mask
(
seqlen_q
,
seqlen_k
,
window_size
,
query_padding_mask
,
key_padding_mask
,
q
.
device
,
)
scores
.
masked_fill_
(
local_mask
,
float
(
"-inf"
))
if
attn_bias
is
not
None
:
scores
=
scores
+
attn_bias
attention
=
torch
.
softmax
(
scores
,
dim
=-
1
).
to
(
v
.
dtype
)
# Some rows might be completely masked out so we fill them with zero instead of NaN
if
window_size
[
0
]
>=
0
or
window_size
[
1
]
>=
0
:
attention
=
attention
.
masked_fill
(
torch
.
all
(
local_mask
,
dim
=-
1
,
keepdim
=
True
),
0.0
)
# We want to mask here so that the attention matrix doesn't have any NaNs
# Otherwise we'll get NaN in dV
if
query_padding_mask
is
not
None
:
attention
=
attention
.
masked_fill
(
rearrange
(
~
query_padding_mask
,
"b s -> b 1 s 1"
),
0.0
)
dropout_scaling
=
1.0
/
(
1
-
dropout_p
)
# attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
# output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
if
dropout_mask
is
not
None
:
attention_drop
=
attention
.
masked_fill
(
~
dropout_mask
,
0.0
)
else
:
attention_drop
=
attention
output
=
torch
.
einsum
(
"bhts,bshd->bthd"
,
attention_drop
,
v
*
dropout_scaling
)
if
query_padding_mask
is
not
None
:
output
.
masked_fill_
(
rearrange
(
~
query_padding_mask
,
"b s -> b s 1 1"
),
0.0
)
return
output
.
to
(
dtype
=
dtype_og
),
attention
.
to
(
dtype
=
dtype_og
)
def
attention_kvpacked_ref
(
q
,
kv
,
query_padding_mask
=
None
,
key_padding_mask
=
None
,
attn_bias
=
None
,
dropout_p
=
0.0
,
dropout_mask
=
None
,
causal
=
False
,
window_size
=
(
-
1
,
-
1
),
# -1 means infinite window size
softcap
=
0.0
,
upcast
=
True
,
reorder_ops
=
False
,
):
return
attention_ref
(
q
,
kv
[:,
:,
0
],
kv
[:,
:,
1
],
query_padding_mask
,
key_padding_mask
,
attn_bias
,
dropout_p
,
dropout_mask
,
upcast
=
upcast
,
causal
=
causal
,
window_size
=
window_size
,
softcap
=
softcap
,
reorder_ops
=
reorder_ops
,
)
def
attention_qkvpacked_ref
(
qkv
,
key_padding_mask
=
None
,
attn_bias
=
None
,
dropout_p
=
0.0
,
dropout_mask
=
None
,
causal
=
False
,
window_size
=
(
-
1
,
-
1
),
# -1 means infinite window size
softcap
=
0.0
,
upcast
=
True
,
reorder_ops
=
False
,
):
return
attention_ref
(
qkv
[:,
:,
0
],
qkv
[:,
:,
1
],
qkv
[:,
:,
2
],
key_padding_mask
,
key_padding_mask
,
attn_bias
,
dropout_p
,
dropout_mask
,
upcast
=
upcast
,
causal
=
causal
,
window_size
=
window_size
,
softcap
=
softcap
,
reorder_ops
=
reorder_ops
,
)
def
generate_sparsity_mask
(
seqlen
,
sparsity
=
0.3
):
repeats
=
seqlen
//
16
//
2
# mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda'),
# torch.tensor([0, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
# mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda'),
# torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
# mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
# mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
nrow
,
ncol
=
seqlen
//
16
,
seqlen
//
256
mask
=
torch
.
rand
(
nrow
,
ncol
,
device
=
"cuda"
)
<
sparsity
return
mask
def
attention_blocksparse_ref
(
qkv
,
blockmask
,
attn_mask
,
dropout_p
,
dropout_mask
):
"""
Arguments:
qkv: (batch_size, seqlen, 3, nheads, head_dim)
blockmask: (seqlen / 16, seqlen / 256)
attn_mask: (batch_size, seqlen)
dropout_p: float
dropout_mask: (batch_size, nheads, seqlen, seqlen)
Output:
output: (batch_size, seqlen, nheads, head_dim)
attention: softmax after dropout
"""
q
,
k
,
v
=
qkv
.
float
().
unbind
(
dim
=
2
)
d
=
qkv
.
shape
[
-
1
]
seqlen
=
qkv
.
shape
[
1
]
scores
=
torch
.
einsum
(
"bthd,bshd->bhts"
,
q
/
math
.
sqrt
(
d
),
k
)
scores
.
masked_fill_
(
rearrange
(
~
attn_mask
,
"b s -> b 1 1 s"
),
float
(
"-inf"
))
blockmask
=
repeat
(
blockmask
,
"s_16 s_256 -> (s_16 16) (s_256 256)"
)
blockmask
=
blockmask
[:
seqlen
,
:
seqlen
]
scores
.
masked_fill_
(
rearrange
(
~
blockmask
,
"t s -> 1 1 t s"
),
float
(
"-inf"
))
attention
=
torch
.
softmax
(
scores
,
dim
=-
1
)
attention
=
attention
.
masked_fill
(
rearrange
(
~
attn_mask
,
"b s -> b 1 s 1"
),
0.0
)
attention
=
attention
.
masked_fill_
(
rearrange
(
~
blockmask
,
"t s -> 1 1 t s"
),
0.0
)
attention_drop
=
attention
.
masked_fill
(
~
dropout_mask
,
0.0
)
/
(
1
-
dropout_p
)
output
=
torch
.
einsum
(
"bhts,bshd->bthd"
,
attention_drop
,
v
)
output
.
masked_fill_
(
rearrange
(
~
attn_mask
,
"b s -> b s 1 1"
),
0
)
return
output
.
to
(
dtype
=
qkv
.
dtype
),
attention
.
to
(
dtype
=
qkv
.
dtype
)
def
convert_flash_attn_S_to_softmax
(
S
,
seqlen_q
,
seqlen_k
,
query_padding_mask
,
key_padding_mask
,
head_dim
,
is_dropout
,
causal
=
False
,
window_size
=
(
-
1
,
-
1
),
# -1 means infinite window size
):
"""FlashAttention stores the S matrix in a different way.
Arguments:
S: (batch_size, nheads, seqlen_q_rounded, seqlen_k_rounded)
query_padding_mask: (batch_size, seqlen_q_rounded)
key_padding_mask: (batch_size, seqlen_k_rounded)
"""
if
causal
:
window_size
=
(
window_size
[
0
],
0
)
seqlen_q_rounded
,
seqlen_k_rounded
=
S
.
shape
[
-
2
:]
S_converted
=
S
if
window_size
[
0
]
>=
0
or
window_size
[
1
]
>=
0
:
local_mask
=
construct_local_mask
(
seqlen_q
,
seqlen_k
,
window_size
,
query_padding_mask
,
key_padding_mask
,
S
.
device
,
)
local_mask
=
F
.
pad
(
local_mask
,
(
0
,
seqlen_k_rounded
-
seqlen_k
,
0
,
seqlen_q_rounded
-
seqlen_q
),
value
=
True
,
)
S_converted
=
S_converted
.
masked_fill
(
local_mask
,
0.0
)
# Need to zero out things not in attention_mask in case S was initialized with random values
# and some of those values aren't overwritten.
seqlen_q_og
=
(
query_padding_mask
.
shape
[
-
1
]
if
query_padding_mask
is
not
None
else
seqlen_q_rounded
)
if
query_padding_mask
is
not
None
:
query_padding_mask
=
F
.
pad
(
query_padding_mask
,
(
0
,
seqlen_q_rounded
-
seqlen_q_og
))
S_converted
=
S_converted
.
masked_fill
(
rearrange
(
~
query_padding_mask
,
"b s -> b 1 s 1"
),
0.0
)
seqlen_k_og
=
key_padding_mask
.
shape
[
-
1
]
if
key_padding_mask
is
not
None
else
seqlen_k
if
key_padding_mask
is
not
None
:
key_padding_mask
=
F
.
pad
(
key_padding_mask
,
(
0
,
seqlen_k_rounded
-
seqlen_k_og
))
S_converted
=
S_converted
.
masked_fill
(
rearrange
(
~
key_padding_mask
,
"b s -> b 1 1 s"
),
0.0
)
S_converted
=
F
.
pad
(
S_converted
,
(
0
,
0
,
0
,
seqlen_q_og
-
seqlen_q_rounded
))
S_converted
=
F
.
pad
(
S_converted
,
(
0
,
seqlen_k_og
-
seqlen_k_rounded
))
return
S_converted
[:,
:,
:
seqlen_q
,
:
seqlen_k
]
def
normalize_flash_attn_S
(
attn_unnorm
,
q
,
k
,
v
,
query_padding_mask
=
None
,
key_padding_mask
=
None
,
attn_bias
=
None
,
is_dropout
=
False
,
causal
=
False
,
window_size
=
(
-
1
,
-
1
),
# -1 means infinite window size
):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, head_dim)
k, v: (batch_size, seqlen_k, nheads, head_dim)
key_padding_mask: (batch_size, seqlen_q)
attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
Output:
softmax_lse: (batch_size, nheads, seqlen_q)
softmax_max: (batch_size, nheads, seqlen_q)
"""
if
causal
:
window_size
=
(
window_size
[
0
],
0
)
q
,
k
,
v
=
q
.
float
(),
k
.
float
(),
v
.
float
()
_
,
seqlen_q
,
_
,
head_dim
=
q
.
shape
seqlen_k
=
k
.
shape
[
1
]
scores
=
torch
.
einsum
(
"bthd,bshd->bhts"
,
q
/
math
.
sqrt
(
head_dim
),
k
)
if
key_padding_mask
is
not
None
:
scores
.
masked_fill_
(
rearrange
(
~
key_padding_mask
,
"b s -> b 1 1 s"
),
float
(
"-inf"
))
if
window_size
[
0
]
>=
0
or
window_size
[
1
]
>=
0
:
local_mask
=
construct_local_mask
(
seqlen_q
,
seqlen_k
,
window_size
,
query_padding_mask
,
key_padding_mask
,
q
.
device
,
)
scores
.
masked_fill_
(
local_mask
,
float
(
"-inf"
))
if
attn_bias
is
not
None
:
scores
=
scores
+
attn_bias
.
to
(
dtype
=
scores
.
dtype
)
block_size_n
=
_get_block_size_n
(
scores
.
device
,
head_dim
,
is_dropout
,
causal
)
scores_block
=
scores
.
split
(
block_size_n
,
dim
=-
1
)
lse_block
=
torch
.
stack
([
torch
.
logsumexp
(
s
,
dim
=-
1
)
for
s
in
scores_block
],
dim
=-
1
)
lse
=
torch
.
logsumexp
(
lse_block
,
dim
=-
1
)
# lse could be -inf (i.e. all values in scores are -inf), and we want to set those to inf
# so that when we do torch.exp(m - lse), we get 0.0 instead of NaN.
lse
[
lse
==
float
(
"-inf"
)]
=
float
(
"inf"
)
scores_max_block
=
torch
.
stack
([
torch
.
amax
(
s
,
dim
=-
1
)
for
s
in
scores_block
],
dim
=-
1
)
cummax_block
=
torch
.
cummax
(
scores_max_block
.
flip
(
-
1
),
dim
=-
1
).
values
.
flip
(
-
1
).
unbind
(
dim
=-
1
)
attn_unnorm_block
=
attn_unnorm
.
split
(
block_size_n
,
dim
=-
1
)
attn_norm
=
torch
.
cat
(
[
a
*
rearrange
(
torch
.
exp
(
m
-
lse
),
"b h s -> b h s 1"
)
for
a
,
m
in
zip
(
attn_unnorm_block
,
cummax_block
)
],
dim
=-
1
,
)
if
query_padding_mask
is
not
None
:
attn_norm
.
masked_fill_
(
rearrange
(
~
query_padding_mask
,
"b s -> b 1 s 1"
),
0.0
)
return
attn_norm
.
to
(
dtype
=
attn_unnorm
.
dtype
)
def
get_dropout_fraction
(
dropout_mask
,
query_padding_mask
=
None
,
key_padding_mask
=
None
,
causal
=
False
,
window_size
=
(
-
1
,
-
1
),
# -1 means infinite window size
):
"""
dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k), bool. True means keep, False means drop.
query_padding_mask: (batch_size, seqlen_q)
key_padding_mask: (batch_size, seqlen_k)
"""
if
causal
:
window_size
=
(
window_size
[
0
],
0
)
batch_size
,
nheads
,
seqlen_q
,
seqlen_k
=
dropout_mask
.
shape
dropped
=
~
dropout_mask
valid
=
torch
.
ones_like
(
dropout_mask
)
if
query_padding_mask
is
not
None
:
dropped
.
masked_fill_
(
rearrange
(
~
query_padding_mask
,
"b s -> b 1 s 1"
),
False
)
valid
.
masked_fill_
(
rearrange
(
~
query_padding_mask
,
"b s -> b 1 s 1"
),
False
)
if
key_padding_mask
is
not
None
:
dropped
.
masked_fill_
(
rearrange
(
~
key_padding_mask
,
"b s -> b 1 1 s"
),
False
)
valid
.
masked_fill_
(
rearrange
(
~
key_padding_mask
,
"b s -> b 1 1 s"
),
False
)
if
window_size
[
0
]
>=
0
or
window_size
[
1
]
>=
0
:
local_mask
=
construct_local_mask
(
seqlen_q
,
seqlen_k
,
window_size
,
query_padding_mask
,
key_padding_mask
,
dropout_mask
.
device
,
)
dropped
.
masked_fill_
(
local_mask
,
False
)
valid
.
masked_fill_
(
local_mask
,
False
)
dropped_total
=
dropped
.
sum
()
return
dropped
.
sum
()
/
valid
.
sum
()
@
pytest
.
mark
.
parametrize
(
"dtype"
,
([
torch
.
float16
]
if
is_sm75
else
[
torch
.
float16
,
torch
.
bfloat16
]))
# @pytest.mark.parametrize("dtype", [torch.float16])
@
pytest
.
mark
.
parametrize
(
"deterministic"
,
[
False
,
True
])
# @pytest.mark.parametrize("deterministic", [False])
@
pytest
.
mark
.
parametrize
(
"alibi"
,
[
False
,
True
])
# @pytest.mark.parametrize("alibi", [False])
@
pytest
.
mark
.
parametrize
(
"local"
,
[
False
,
True
])
# @pytest.mark.parametrize("local", [False])
@
pytest
.
mark
.
parametrize
(
"causal"
,
[
False
,
True
])
# @pytest.mark.parametrize("causal", [False])
@
pytest
.
mark
.
parametrize
(
"d"
,
[
32
,
40
,
59
,
64
,
80
,
96
,
111
,
128
,
160
,
192
,
224
,
256
])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128])
# @pytest.mark.parametrize("d", [64])
# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
@
pytest
.
mark
.
parametrize
(
"seqlen"
,
[
97
,
128
,
200
,
384
,
768
,
1024
,
1025
,
2048
])
# @pytest.mark.parametrize("seqlen", [512])
@
pytest
.
mark
.
parametrize
(
"dropout_p"
,
[
0.0
,
0.17
])
# @pytest.mark.parametrize("dropout_p", [0.0])
def
test_flash_attn_qkvpacked
(
seqlen
,
d
,
dropout_p
,
causal
,
local
,
alibi
,
deterministic
,
dtype
):
if
seqlen
>=
2048
and
torch
.
cuda
.
get_device_properties
(
"cuda"
).
total_memory
<=
16
*
2
**
30
:
pytest
.
skip
()
# Reference implementation OOM
device
=
"cuda"
# set seed
torch
.
random
.
manual_seed
(
0
)
batch_size
=
4
nheads
=
9
window_size
=
(
-
1
,
-
1
)
if
not
local
else
torch
.
randint
(
0
,
seqlen
,
(
2
,))
qkv
=
torch
.
randn
(
batch_size
,
seqlen
,
3
,
nheads
,
d
,
device
=
device
,
dtype
=
dtype
,
requires_grad
=
True
)
if
alibi
:
alibi_slopes
=
torch
.
rand
(
batch_size
,
nheads
,
device
=
device
,
dtype
=
torch
.
float32
)
*
0.3
attn_bias
=
attn_bias_from_alibi_slopes
(
alibi_slopes
,
seqlen
,
seqlen
,
causal
=
causal
)
else
:
alibi_slopes
,
attn_bias
=
None
,
None
out
,
lse
,
S_dmask
=
flash_attn_qkvpacked_func
(
qkv
,
dropout_p
,
causal
=
causal
,
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
deterministic
=
deterministic
,
return_attn_probs
=
True
,
)
if
dropout_p
>
0.0
:
S_dmask_converted
=
convert_flash_attn_S_to_softmax
(
S_dmask
,
seqlen
,
seqlen
,
None
,
None
,
d
,
dropout_p
>
0.0
,
causal
=
causal
,
window_size
=
window_size
,
)
dropout_mask
=
S_dmask_converted
>=
0
attn_unnorm
=
S_dmask_converted
.
abs
()
attn
=
normalize_flash_attn_S
(
attn_unnorm
,
qkv
[:,
:,
0
],
qkv
[:,
:,
1
],
qkv
[:,
:,
2
],
None
,
None
,
attn_bias
,
dropout_p
>
0.0
,
causal
=
causal
,
window_size
=
window_size
,
)
dropout_fraction
=
get_dropout_fraction
(
dropout_mask
,
None
,
None
,
causal
=
causal
,
window_size
=
window_size
).
item
()
print
(
f
"Actual dropout fraction:
{
dropout_fraction
}
"
)
else
:
dropout_mask
=
None
out_ref
,
attn_ref
=
attention_qkvpacked_ref
(
qkv
,
None
,
attn_bias
,
dropout_p
,
dropout_mask
,
causal
=
causal
,
window_size
=
window_size
)
out_pt
,
attn_pt
=
attention_qkvpacked_ref
(
qkv
,
None
,
attn_bias
,
dropout_p
,
dropout_mask
,
causal
=
causal
,
window_size
=
window_size
,
upcast
=
False
,
reorder_ops
=
True
,
)
# v = qkv[:, :, 2].float()
# qk = torch.einsum('bshd,bthd->bhst', qkv[:, :, 0], qkv[:, :, 1]).float()
# if causal:
# causal_mask = torch.triu(torch.ones(seqlen, seqlen, dtype=torch.bool, device=qkv.device), 1)
# qk.masked_fill_(causal_mask, float('-inf'))
# m = qk.amax(-1, keepdim=True)
# s_tmp = torch.exp((qk - m) / math.sqrt(d))
# p_tmp = torch.softmax(qk / math.sqrt(d), -1)
# p_dropped = p_tmp if dropout_mask is None else p_tmp.masked_fill(~dropout_mask, 0)
# lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
# qk_max1 = torch.max(qk[:, :, 128:, 192:], -1, keepdim=True).values
# qk_max2 = torch.max(qk[:, :, 128:, 128:], -1, keepdim=True).values
# qk_max3 = torch.max(qk[:, :, 128:, 64:], -1, keepdim=True).values
# qk_max4 = torch.max(qk[:, :, 128:, :], -1, keepdim=True).values
# o1 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 192:] - qk_max1) / math.sqrt(d)), v[:, 192:])
# o2 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 128:] - qk_max2) / math.sqrt(d)), v[:, 128:])
# o3 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 64:] - qk_max3) / math.sqrt(d)), v[:, 64:])
# o4 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, :] - qk_max4) / math.sqrt(d)), v[:, :])
print
(
f
"Output max diff:
{
(
out
-
out_ref
).
abs
().
max
().
item
()
}
"
)
print
(
f
"Output mean diff:
{
(
out
-
out_ref
).
abs
().
mean
().
item
()
}
"
)
print
(
f
"Pytorch max diff:
{
(
out_pt
-
out_ref
).
abs
().
max
().
item
()
}
"
)
print
(
f
"Pytorch mean diff:
{
(
out_pt
-
out_ref
).
abs
().
mean
().
item
()
}
"
)
if
dropout_p
>
0.0
:
print
(
f
"Attention max diff:
{
(
attn
-
attn_ref
).
abs
().
max
().
item
()
}
"
)
print
(
f
"Attention Pytorch max diff:
{
(
attn_pt
-
attn_ref
).
abs
().
max
().
item
()
}
"
)
g
=
torch
.
randn_like
(
out
)
# do_o = (g.float() * out.float()).sum(-1)
# dv_tmp = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, :64], g[:, :64])
# dv_tmp1 = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, 64:], g[:, 64:])
if
(
d
<=
MAX_HEADDIM_SM8x
or
(
d
>
224
and
dropout_p
==
0
))
or
(
is_sm80
or
is_sm90
):
(
dqkv
,)
=
torch
.
autograd
.
grad
(
out
,
qkv
,
g
)
(
dqkv_ref
,)
=
torch
.
autograd
.
grad
(
out_ref
,
qkv
,
g
)
(
dqkv_pt
,)
=
torch
.
autograd
.
grad
(
out_pt
,
qkv
,
g
)
print
(
f
"dQ max diff:
{
(
dqkv
[:,
:,
0
]
-
dqkv_ref
[:,
:,
0
]).
abs
().
max
().
item
()
}
"
)
print
(
f
"dK max diff:
{
(
dqkv
[:,
:,
1
]
-
dqkv_ref
[:,
:,
1
]).
abs
().
max
().
item
()
}
"
)
print
(
f
"dV max diff:
{
(
dqkv
[:,
:,
2
]
-
dqkv_ref
[:,
:,
2
]).
abs
().
max
().
item
()
}
"
)
print
(
f
"dQKV mean diff:
{
(
dqkv
-
dqkv_ref
).
abs
().
mean
().
item
()
}
"
)
print
(
f
"dQ Pytorch max diff:
{
(
dqkv_pt
[:,
:,
0
]
-
dqkv_ref
[:,
:,
0
]).
abs
().
max
().
item
()
}
"
)
print
(
f
"dK Pytorch max diff:
{
(
dqkv_pt
[:,
:,
1
]
-
dqkv_ref
[:,
:,
1
]).
abs
().
max
().
item
()
}
"
)
print
(
f
"dV Pytorch max diff:
{
(
dqkv_pt
[:,
:,
2
]
-
dqkv_ref
[:,
:,
2
]).
abs
().
max
().
item
()
}
"
)
print
(
f
"dQKV Pytorch mean diff:
{
(
dqkv_pt
-
dqkv_ref
).
abs
().
mean
().
item
()
}
"
)
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
assert
(
out
-
out_ref
).
abs
().
max
().
item
()
<=
2
*
(
out_pt
-
out_ref
).
abs
().
max
().
item
()
if
dropout_p
>
0.0
:
assert
(
attn
-
attn_ref
).
abs
().
max
().
item
()
<=
2
*
(
attn_pt
-
attn_ref
).
abs
().
max
().
item
()
# With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
if
not
alibi
:
assert
abs
(
dropout_fraction
-
dropout_p
)
<=
(
0.01
if
not
local
else
0.025
)
if
(
d
<=
MAX_HEADDIM_SM8x
or
(
d
>
224
and
dropout_p
==
0
))
or
(
is_sm80
or
is_sm90
):
assert
(
dqkv
-
dqkv_ref
).
abs
().
max
().
item
()
<=
2
*
(
dqkv_pt
-
dqkv_ref
).
abs
().
max
().
item
()
@
pytest
.
mark
.
parametrize
(
"dtype"
,
([
torch
.
float16
]
if
is_sm75
else
[
torch
.
float16
,
torch
.
bfloat16
]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@
pytest
.
mark
.
parametrize
(
"deterministic"
,
[
False
,
True
])
# @pytest.mark.parametrize("deterministic", [True])
@
pytest
.
mark
.
parametrize
(
"alibi"
,
[
False
,
True
])
# @pytest.mark.parametrize("alibi", [True])
@
pytest
.
mark
.
parametrize
(
"local"
,
[
False
,
True
])
# @pytest.mark.parametrize("local", [True])
@
pytest
.
mark
.
parametrize
(
"causal"
,
[
False
,
True
])
# @pytest.mark.parametrize('causal', [False])
@
pytest
.
mark
.
parametrize
(
"d"
,
[
32
,
59
,
64
,
80
,
96
,
128
,
160
,
192
,
224
,
256
])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [64])
@
pytest
.
mark
.
parametrize
(
"seqlen"
,
[
97
,
128
,
200
,
257
,
384
,
512
,
768
,
1025
,
2048
])
# @pytest.mark.parametrize('seqlen', [128])
@
pytest
.
mark
.
parametrize
(
"dropout_p"
,
[
0.0
,
0.17
])
# @pytest.mark.parametrize('dropout_p', [0.0])
def
test_flash_attn_varlen_qkvpacked
(
seqlen
,
d
,
dropout_p
,
causal
,
local
,
alibi
,
deterministic
,
dtype
):
if
seqlen
>=
2048
and
torch
.
cuda
.
get_device_properties
(
"cuda"
).
total_memory
<=
16
*
2
**
30
:
pytest
.
skip
()
# Reference implementation OOM
device
=
"cuda"
# set seed
torch
.
random
.
manual_seed
(
0
)
batch_size
=
5
nheads
=
6
window_size
=
(
-
1
,
-
1
)
if
not
local
else
torch
.
randint
(
0
,
seqlen
,
(
2
,))
qkv
=
torch
.
randn
(
batch_size
,
seqlen
,
3
,
nheads
,
d
,
device
=
device
,
dtype
=
dtype
,
requires_grad
=
True
)
key_padding_mask
=
generate_random_padding_mask
(
seqlen
,
batch_size
,
device
,
mode
=
"random"
)
# key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
if
alibi
:
alibi_slopes
=
torch
.
rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(
            alibi_slopes, seqlen, seqlen, key_padding_mask, key_padding_mask, causal=causal
        )
    else:
        alibi_slopes, attn_bias = None, None
    qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
        *qkv.unbind(dim=2), key_padding_mask, key_padding_mask, qkvpacked=True
    )
    out_unpad, sm_lse, S_dmask = flash_attn_varlen_qkvpacked_func(
        qkv_unpad, cu_seqlens, max_seqlen, dropout_p, causal=causal, window_size=window_size,
        alibi_slopes=alibi_slopes, deterministic=deterministic, return_attn_probs=True,
    )
    out = output_pad_fn(out_unpad)
    if dropout_p > 0.0:
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask, seqlen, seqlen, key_padding_mask, key_padding_mask, d, dropout_p > 0.0,
            causal=causal, window_size=window_size,
        )
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
        attn = normalize_flash_attn_S(
            attn_unnorm, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], key_padding_mask, key_padding_mask,
            attn_bias, dropout_p > 0.0, causal=causal, window_size=window_size,
        )
        dropout_fraction = get_dropout_fraction(
            dropout_mask, key_padding_mask, key_padding_mask, causal=causal, window_size=window_size
        ).item()
        print(f"Actual dropout fraction: {dropout_fraction}")
    else:
        dropout_mask = None

    out_ref, attn_ref = attention_qkvpacked_ref(
        qkv, key_padding_mask, attn_bias, dropout_p, dropout_mask, causal=causal, window_size=window_size,
    )
    out_pt, attn_pt = attention_qkvpacked_ref(
        qkv, key_padding_mask, attn_bias, dropout_p, dropout_mask, causal=causal, window_size=window_size,
        upcast=False, reorder_ops=True,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    if dropout_p > 0.0:
        print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
        print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")

    g = torch.randn_like(out)
    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
        (dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g)
        dqkv = dqkv_pad_fn(dqkv_unpad)
        (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
        (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
        print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

    if dropout_p > 0.0:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
        if not alibi:
            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)

    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
        assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
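

# --- Illustrative sketch (added for exposition; not part of the upstream test suite) ---
# Every test in this file compares FlashAttention against an upcast (fp32) reference and
# against a pure-PyTorch baseline run in the same low precision, and requires FlashAttention's
# max error to stay within a small multiple of the baseline's error. The helper below is a
# hypothetical convenience wrapper for that check; the tests themselves spell it out inline.
def _within_error_budget(actual, reference, baseline, mult=2.0, atol=0.0):
    """Return True if |actual - reference| stays within mult * |baseline - reference| + atol."""
    err = (actual - reference).abs().max().item()
    budget = mult * (baseline - reference).abs().max().item() + atol
    return err <= budget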


@pytest.mark.parametrize("kvpacked", [True, False])
# @pytest.mark.parametrize("kvpacked", [False])
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [True])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (113, 203), (128, 217), (113, 211), (108, 256), (256, 512),
        (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize("dropout_p", [0.17])
@pytest.mark.parametrize("softcap", [0.0, 50.0])
def test_flash_attn_output(
    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap
):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    if softcap > 0.0 and dropout_p > 0.0:
        pytest.skip("Softcap and dropout not supported together")
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 4
    nheads = 6 if softcap == 0.0 else 4  # softcap reference impl takes more memory
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
    assert nheads % nheads_k == 0
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    if softcap > 0:
        # Ensure the values of qk are at least within softcap range.
        q = q * softcap
    if kvpacked:
        kv = torch.randn(
            batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
    else:
        k = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True)
        v = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True)
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal)
    else:
        alibi_slopes, attn_bias = None, None
    if kvpacked:
        out, lse, S_dmask = flash_attn_kvpacked_func(
            q, kv, dropout_p, causal=causal, window_size=window_size, softcap=softcap,
            alibi_slopes=alibi_slopes, deterministic=deterministic, return_attn_probs=True,
        )
    else:
        out, lse, S_dmask = flash_attn_func(
            q, k, v, dropout_p, causal=causal, window_size=window_size, softcap=softcap,
            alibi_slopes=alibi_slopes, deterministic=deterministic, return_attn_probs=True,
        )
    if dropout_p > 0.0:
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask, seqlen_q, seqlen_k, None, None, d, dropout_p > 0.0,
            causal=causal, window_size=window_size,
        )
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
        if kvpacked:
            kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k)
            k_rep, v_rep = kv_rep.unbind(dim=2)
        else:
            k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
            v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
        attn = normalize_flash_attn_S(
            attn_unnorm, q, k_rep, v_rep, None, None, attn_bias, dropout_p > 0.0,
            causal=causal, window_size=window_size,
        )
        dropout_fraction = get_dropout_fraction(
            dropout_mask, None, None, causal=causal, window_size=window_size
        ).item()
        print(f"Actual dropout fraction: {dropout_fraction}")
    else:
        dropout_mask = None

    if kvpacked:
        out_ref, attn_ref = attention_kvpacked_ref(
            q, kv, None, None, attn_bias, dropout_p, dropout_mask,
            causal=causal, window_size=window_size, softcap=softcap,
        )
        out_pt, attn_pt = attention_kvpacked_ref(
            q, kv, None, None, attn_bias, dropout_p, dropout_mask,
            causal=causal, window_size=window_size, softcap=softcap, upcast=False, reorder_ops=True,
        )
    else:
        out_ref, attn_ref = attention_ref(
            q, k, v, None, None, attn_bias, dropout_p, dropout_mask,
            causal=causal, window_size=window_size, softcap=softcap,
        )
        out_pt, attn_pt = attention_ref(
            q, k, v, None, None, attn_bias, dropout_p, dropout_mask,
            causal=causal, window_size=window_size, softcap=softcap, upcast=False, reorder_ops=True,
        )

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    if dropout_p > 0.0:
        print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
        print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")

    g = torch.randn_like(out)
    do_o = (g.float() * out.float()).sum(-1)
    if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
        if kvpacked:
            (dq, dkv,) = torch.autograd.grad(out, (q, kv), g)
            dk, dv = dkv.unbind(2)
            (dq_ref, dkv_ref,) = torch.autograd.grad(out_ref, (q, kv), g)
            dk_ref, dv_ref = dkv_ref.unbind(2)
            (dq_pt, dkv_pt,) = torch.autograd.grad(out_pt, (q, kv), g)
            dk_pt, dv_pt = dkv_pt.unbind(2)
        else:
            (dq, dk, dv,) = torch.autograd.grad(out, (q, k, v), g)
            (dq_ref, dk_ref, dv_ref,) = torch.autograd.grad(out_ref, (q, k, v), g)
            (dq_pt, dk_pt, dv_pt,) = torch.autograd.grad(out_pt, (q, k, v), g)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

    if dropout_p > 0.0:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
        if not alibi:
            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)

    if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()


@pytest.mark.parametrize("kvpacked", [True, False])
# @pytest.mark.parametrize('kvpacked', [False])
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize('mha_type', ["mqa"])
@pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [True])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 147), (113, 203), (128, 217), (113, 211), (108, 256), (256, 512),
        (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
@pytest.mark.parametrize("softcap", [0.0, 50.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_varlen_output(
    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap
):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    if softcap > 0.0 and dropout_p > 0.0:
        pytest.skip("Softcap and dropout not supported together")
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 4
    nheads = 6 if softcap == 0.0 else 4  # softcap reference impl takes more memory
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
    assert nheads % nheads_k == 0
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    if softcap > 0:
        # Ensure the values of qk are at least within softcap range.
        q = q * softcap
    if kvpacked:
        kv = torch.randn(
            batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
    else:
        k = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True)
        v = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True)
    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
    # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(
            alibi_slopes, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, causal=causal
        )
    else:
        alibi_slopes, attn_bias = None, None
    if kvpacked:
        (
            q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
            q, kv, output_pad_fn, dq_pad_fn, dkv_pad_fn,
        ) = generate_qkv(q, *kv.unbind(dim=2), query_padding_mask, key_padding_mask, kvpacked=True)
        out_unpad, sm_lse, S_dmask = flash_attn_varlen_kvpacked_func(
            q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
            causal=causal, window_size=window_size, softcap=softcap, alibi_slopes=alibi_slopes,
            deterministic=deterministic, return_attn_probs=True,
        )
    else:
        (
            q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
            q, k, v, output_pad_fn, dq_pad_fn, dk_pad_fn,
        ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
        out_unpad, sm_lse, S_dmask = flash_attn_varlen_func(
            q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
            causal=causal, window_size=window_size, softcap=softcap, alibi_slopes=alibi_slopes,
            deterministic=deterministic, return_attn_probs=True,
        )
    out = output_pad_fn(out_unpad)
    if dropout_p > 0.0:
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, d, dropout_p > 0.0,
            causal=causal, window_size=window_size,
        )
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
        if kvpacked:
            kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k)
            k_rep, v_rep = kv_rep.unbind(dim=2)
        else:
            k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
            v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
        attn = normalize_flash_attn_S(
            attn_unnorm, q, k_rep, v_rep, query_padding_mask, key_padding_mask, attn_bias,
            dropout_p > 0.0, causal=causal, window_size=window_size,
        )
        dropout_fraction = get_dropout_fraction(
            dropout_mask, query_padding_mask, key_padding_mask, causal=causal, window_size=window_size,
        ).item()
        print(f"Actual dropout fraction: {dropout_fraction}")
    else:
        dropout_mask = None

    if kvpacked:
        out_ref, attn_ref = attention_kvpacked_ref(
            q, kv, query_padding_mask, key_padding_mask, attn_bias, dropout_p, dropout_mask,
            causal=causal, window_size=window_size, softcap=softcap,
        )
        out_pt, attn_pt = attention_kvpacked_ref(
            q, kv, query_padding_mask, key_padding_mask, attn_bias, dropout_p, dropout_mask,
            causal=causal, window_size=window_size, softcap=softcap, upcast=False, reorder_ops=True,
        )
    else:
        out_ref, attn_ref = attention_ref(
            q, k, v, query_padding_mask, key_padding_mask, attn_bias, dropout_p, dropout_mask,
            causal=causal, window_size=window_size, softcap=softcap,
        )
        out_pt, attn_pt = attention_ref(
            q, k, v, query_padding_mask, key_padding_mask, attn_bias, dropout_p, dropout_mask,
            causal=causal, window_size=window_size, softcap=softcap, upcast=False, reorder_ops=True,
        )

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    if dropout_p > 0.0:
        print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
        print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")

    g = torch.randn_like(out)
    if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
        if kvpacked:
            (dq_unpad, dkv_unpad,) = torch.autograd.grad(out, (q_unpad, kv_unpad), g)
            dk, dv = dkv_pad_fn(dkv_unpad).unbind(2)
            (dq_ref, dkv_ref,) = torch.autograd.grad(out_ref, (q, kv), g)
            dk_ref, dv_ref = dkv_ref.unbind(2)
            (dq_pt, dkv_pt,) = torch.autograd.grad(out_pt, (q, kv), g)
            dk_pt, dv_pt = dkv_pt.unbind(2)
        else:
            (dq_unpad, dk_unpad, dv_unpad,) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
            dk = dk_pad_fn(dk_unpad)
            dv = dk_pad_fn(dv_unpad)
            (dq_ref, dk_ref, dv_ref,) = torch.autograd.grad(out_ref, (q, k, v), g)
            (dq_pt, dk_pt, dv_pt,) = torch.autograd.grad(out_pt, (q, k, v), g)
        dq = dq_pad_fn(dq_unpad)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

    if dropout_p > 0.0:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
        if not alibi:
            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.04)

    if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
        assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
        assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
        assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()
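

# --- Illustrative sketch (added for exposition; not part of the upstream test suite) ---
# The varlen tests pack variable-length sequences into one flat tensor and describe the
# boundaries with cumulative sequence lengths (cu_seqlens), which generate_qkv derives from
# the padding masks. A minimal sketch of that convention, with hypothetical lengths:
# lengths [3, 5, 2] -> cu_seqlens [0, 3, 8, 10], so sequence i occupies rows
# cu_seqlens[i]:cu_seqlens[i + 1] of the packed tensor.
def _cu_seqlens_from_lengths(lengths, device="cuda"):
    """Build the int32 cumulative-length tensor that the varlen kernels expect."""
    lengths = torch.tensor(lengths, dtype=torch.int32, device=device)
    return torch.nn.functional.pad(torch.cumsum(lengths, dim=0, dtype=torch.int32), (1, 0))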


@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64, 128])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
# @pytest.mark.parametrize("swap_sq_sk", [True])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239), (3, 799), (127, 512), (127, 513), (113, 203),
        (128, 217), (113, 211), (108, 256), (256, 512), (1023, 1024),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    causal = True
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size)
    out_ref, attn_ref = attention_ref(
        q, k, v, None, None, None, 0.0, None, causal=causal, window_size=window_size
    )
    out_pt, attn_pt = attention_ref(
        q, k, v, None, None, None, 0.0, None, causal=causal, window_size=window_size,
        upcast=False, reorder_ops=True,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

    g = torch.randn_like(out)
    do_o = (g.float() * out.float()).sum(-1)
    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
        (dq, dk, dv,) = torch.autograd.grad(out, (q, k, v), g)
        (dq_ref, dk_ref, dv_ref,) = torch.autograd.grad(out_ref, (q, k, v), g)
        (dq_pt, dk_pt, dv_pt,) = torch.autograd.grad(out_pt, (q, k, v), g)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5

    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5


@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
# @pytest.mark.parametrize("swap_sq_sk", [True])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239), (3, 799), (127, 512), (127, 513), (113, 203),
        (128, 217), (113, 211), (108, 256), (256, 512), (1023, 1024),
    ],
)
# TODO: add smaller page sizes when https://github.com/Dao-AILab/flash-attention/pull/824 is merged
@pytest.mark.parametrize("paged_kv_block_size", [None, 16, 256, 512])
# @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)])
def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    causal = True
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    if paged_kv_block_size is None:
        k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
        v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
        block_table = None
    else:
        k, v, block_table, k_cache_paged, v_cache_paged, num_blocks = _generate_block_kvcache(
            seqlen_k, paged_kv_block_size, batch_size, nheads, d, device, dtype
        )
    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
    (
        q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
        q, k, v, output_pad_fn, dq_pad_fn, dk_pad_fn,
    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
    out_unpad = flash_attn_varlen_func(
        q_unpad,
        k_unpad if paged_kv_block_size is None else k_cache_paged,
        v_unpad if paged_kv_block_size is None else v_cache_paged,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        0.0,
        causal=causal,
        window_size=window_size,
        block_table=block_table,
    )
    out = output_pad_fn(out_unpad)
    out_ref, attn_ref = attention_ref(
        q, k, v, query_padding_mask, key_padding_mask, None, 0.0, None,
        causal=causal, window_size=window_size,
    )
    out_pt, attn_pt = attention_ref(
        q, k, v, query_padding_mask, key_padding_mask, None, 0.0, None,
        causal=causal, window_size=window_size, upcast=False, reorder_ops=True,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

    g = torch.randn_like(out)
    do_o = (g.float() * out.float()).sum(-1)
    test_backward = (d <= MAX_HEADDIM_SM8x or d > 224 or is_sm80 or is_sm90) and block_table is None
    if test_backward:
        (dq_unpad, dk_unpad, dv_unpad,) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
        dq = dq_pad_fn(dq_unpad)
        dk = dk_pad_fn(dk_unpad)
        dv = dk_pad_fn(dv_unpad)
        (dq_ref, dk_ref, dv_ref,) = torch.autograd.grad(out_ref, (q, k, v), g)
        (dq_pt, dk_pt, dv_pt,) = torch.autograd.grad(out_pt, (q, k, v), g)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5

    if test_backward:
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5


@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [True])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
# @pytest.mark.parametrize("swap_sq_sk", [False])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (3, 1024), (1, 339), (64, 800), (3, 799), (64, 2048),
        (16, 20000), (16, 100000), (128, 128), (256, 256),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, alibi, deterministic, dtype):
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 1
    nheads = 12
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal)
    else:
        alibi_slopes, attn_bias = None, None
    out, lse, _ = flash_attn_func(
        q, k, v, 0.0, causal=causal, window_size=window_size, alibi_slopes=alibi_slopes,
        deterministic=deterministic, return_attn_probs=True,
    )
    out_ref, attn_ref = attention_ref(
        q, k, v, None, None, attn_bias, 0.0, None, causal=causal, window_size=window_size
    )
    out_pt, attn_pt = attention_ref(
        q, k, v, None, None, attn_bias, 0.0, None, causal=causal, window_size=window_size,
        upcast=False, reorder_ops=True,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

    g = torch.randn_like(out)
    do_o = (g.float() * out.float()).sum(-1)
    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
        (dq, dk, dv,) = torch.autograd.grad(out, (q, k, v), g)
        (dq_ref, dk_ref, dv_ref,) = torch.autograd.grad(out_ref, (q, k, v), g)
        (dq_pt, dk_pt, dv_pt,) = torch.autograd.grad(out_pt, (q, k, v), g)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5

    mult = 2 if not alibi else 8
    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
        assert (dq - dq_ref).abs().max().item() <= mult * (dq_pt - dq_ref).abs().max().item() + 2e-4
        assert (dk - dk_ref).abs().max().item() <= mult * (dk_pt - dk_ref).abs().max().item() + 2e-4
        assert (dv - dv_ref).abs().max().item() <= mult * (dv_pt - dv_ref).abs().max().item() + 2e-4


# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("num_splits", [1, 0])
# @pytest.mark.parametrize("num_splits", [1])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("new_kv", [False, True])
# @pytest.mark.parametrize("new_kv", [False])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [False])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True])
@pytest.mark.parametrize("rotary_interleaved", [False, True])
# @pytest.mark.parametrize("rotary_interleaved", [False])
@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
# @pytest.mark.parametrize("rotary_fraction", [0.0])
# @pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
@pytest.mark.parametrize("paged_kv_block_size", [None, 16, 256, 512])
@pytest.mark.parametrize("has_batch_idx", [False, True])
# @pytest.mark.parametrize("has_batch_idx", [False])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 128), (1, 339), (3, 1024), (64, 800), (64, 256), (3, 799),
        (64, 2048), (16, 20000), (1, 128 * 1024), (16, 128 * 1024), (128, 128),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_kvcache(
    seqlen_q, seqlen_k, d, has_batch_idx, paged_kv_block_size, rotary_fraction, rotary_interleaved,
    seqlen_new_eq_seqlen_q, causal, local, alibi, new_kv, mha_type, num_splits, dtype,
):
    if seqlen_q > seqlen_k and new_kv:
        pytest.skip()
    if not new_kv and rotary_fraction > 0.0:
        pytest.skip()
    if has_batch_idx and paged_kv_block_size is not None:
        pytest.skip()
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    batch_size_cache = batch_size if not has_batch_idx else batch_size * 2
    nheads = 6
    # rotary_dim must be a multiple of 16, and must be <= d
    rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
    assert nheads % nheads_k == 0
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
    seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item()
    if new_kv:
        k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
        v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
    else:
        k, v = None, None
    if paged_kv_block_size is None:
        k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
        v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
        block_table = None
    else:
        (
            k_cache, v_cache, block_table, k_cache_paged, v_cache_paged, num_blocks,
        ) = _generate_block_kvcache(
            seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype
        )
    cache_seqlens = torch.randint(
        0 if new_kv else 1,
        # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
        (
            (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
            if new_kv
            else (seqlen_k + 1)
        ),
        (batch_size,),
        dtype=torch.int32,
        device=device,
    )
    arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
    cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
    key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
    if has_batch_idx:
        cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[:batch_size]
    else:
        cache_batch_idx = None
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(
            alibi_slopes, seqlen_q, seqlen_k, None, key_padding_mask, causal=causal
        )
    else:
        alibi_slopes, attn_bias = None, None
    # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
    if rotary_dim > 0:
        angle = (
            torch.rand(
                seqlen_k if paged_kv_block_size is None else num_blocks * paged_kv_block_size,
                rotary_dim // 2,
                device=device,
            )
            * 2
            * math.pi
        )
        cos = torch.cos(angle).to(dtype=dtype)
        sin = torch.sin(angle).to(dtype=dtype)
        if causal or local:
            q_ro = apply_rotary_emb(
                q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
            )
        else:
            q_ro = rearrange(
                apply_rotary_emb(
                    rearrange(q, "b s h d -> b 1 (s h) d"),
                    cos,
                    sin,
                    seqlen_offsets=cache_seqlens,
                    interleaved=rotary_interleaved,
                ),
                "b 1 (s h) d -> b s h d",
                s=seqlen_q,
            )
        # q_ro = q
        k_ro = apply_rotary_emb(
            k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
        )
    else:
        cos, sin = None, None
        q_ro, k_ro = q, k
    # k_cache[:, 64:] = -1
    k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]).clone()
    v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]).clone()
    if new_kv:
        update_mask = torch.logical_and(
            cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
        )
        k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...")
        v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...")
    k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
    v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
    out = flash_attn_with_kvcache(
        q,
        k_cache if paged_kv_block_size is None else k_cache_paged,
        v_cache if paged_kv_block_size is None else v_cache_paged,
        k,
        v,
        rotary_cos=cos,
        rotary_sin=sin,
        cache_seqlens=cache_seqlens,
        cache_batch_idx=cache_batch_idx,
        block_table=block_table,
        causal=causal,
        window_size=window_size,
        rotary_interleaved=rotary_interleaved,
        alibi_slopes=alibi_slopes,
        num_splits=num_splits,
    )
    # out = flash_attn_with_kvcache(
    #     q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size
    # )
    # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size)
    # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
    # m = qk.amax(-1, keepdim=True)
    # s_tmp = torch.exp((qk - m) / math.sqrt(d))
    # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
    # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
    # probs = torch.softmax(qk, dim=-1)
    out_ref, _ = attention_ref(
        q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, attn_bias, 0.0, None,
        causal=causal, window_size=window_size,
    )
    out_pt, _ = attention_ref(
        q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, attn_bias, 0.0, None,
        causal=causal, window_size=window_size, upcast=False, reorder_ops=True,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    if new_kv:
        if paged_kv_block_size is None:
            k_cache_select = k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]
            v_cache_select = v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]
        else:
            k_cache_select = rearrange(
                k_cache_paged[block_table.to(dtype=torch.long).flatten()],
                "(b nblocks) block_size ... -> b (nblocks block_size) ...",
                b=batch_size,
            )[:, :seqlen_k]
            v_cache_select = rearrange(
                v_cache_paged[block_table.to(dtype=torch.long).flatten()],
                "(b nblocks) block_size ... -> b (nblocks block_size) ...",
                b=batch_size,
            )[:, :seqlen_k]
        assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)
        assert torch.equal(v_cache_select, v_cache_ref)
    mult = 3 if not alibi else 5
    assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5


def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype):
    num_blocks = math.ceil(seqlen_k / paged_kv_block_size) * batch_size * 3
    k_cache_paged = torch.randn(
        num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
    )
    v_cache_paged = torch.randn(
        num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
    )
    block_table = rearrange(
        torch.randperm(num_blocks, dtype=torch.int32, device=device),
        "(b nblocks) -> b nblocks",
        b=batch_size,
    )
    k_cache = rearrange(
        # pytorch 1.12 doesn't have indexing with int32
        k_cache_paged[block_table.to(dtype=torch.long).flatten()],
        "(b nblocks) block_size ... -> b (nblocks block_size) ...",
        b=batch_size,
    )[:, :seqlen_k]
    v_cache = rearrange(
        v_cache_paged[block_table.to(dtype=torch.long).flatten()],
        "(b nblocks) block_size ... -> b (nblocks block_size) ...",
        b=batch_size,
    )[:, :seqlen_k]
    return k_cache, v_cache, block_table, k_cache_paged, v_cache_paged, num_blocks
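

# --- Illustrative sketch (added for exposition; not part of the upstream test suite) ---
# In the paged layout built above, block_table[b, i] names the physical block that holds
# tokens [i * block_size, (i + 1) * block_size) of sequence b. The hypothetical helper below
# gathers one sequence back out of a paged cache under the same layout assumptions, mirroring
# the rearrange-based reconstruction used in _generate_block_kvcache.
def _gather_paged_sequence(cache_paged, block_table, batch_idx, seqlen_k):
    """Reassemble a (seqlen_k, nheads_k, d) view of one sequence from a paged KV cache."""
    blocks = cache_paged[block_table[batch_idx].to(torch.long)]  # (nblocks, block_size, h, d)
    return blocks.reshape(-1, *blocks.shape[2:])[:seqlen_k]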


# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [128])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239), (239, 1), (3, 799), (799, 3), (1024, 128), (97, 97), (128, 128),
        (200, 200), (256, 256), (257, 257), (384, 384), (512, 512), (768, 768), (1024, 1024),
    ],
)
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize("dropout_p", [0.0])
def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dtype):
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 60  # Sometimes we need large batch size for the race conditions to trigger
    nheads = 4
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    torch.random.manual_seed(42)
    out0, lse0, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)
    g = torch.randn_like(out0)
    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
        (dq0, dk0, dv0,) = torch.autograd.grad(out0, (q, k, v), g)
        # Numerical error if we just do any arithmetic on dq
        dq_atol = 2 * ((dq0 + 0.3 - 0.3) - dq0).abs().max().item()

    for i in range(250):
        torch.random.manual_seed(42)
        out, lse, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)
        assert torch.equal(out, out0)
        assert torch.equal(lse, lse0)

        if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
            (dq, dk, dv,) = torch.autograd.grad(out, (q, k, v), g)
            dq_equal = torch.allclose(dq, dq0, atol=dq_atol)
            if not dq_equal:
                print(f"Iter {i}, {dq_atol = }, dQ max diff: {(dq - dq0).abs().max().item()}")
            assert torch.equal(dv, dv0)
            assert torch.equal(dk, dk0)
            assert dq_equal


@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize("d", [16, 32, 64])
# @pytest.mark.parametrize('d', [16])
@pytest.mark.parametrize("seqlen", [1, 2, 5, 17, 128])
# @pytest.mark.parametrize('seqlen', [2])
def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype):
    """We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,
    in the case where seqlen % 128 != 0.
    """
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    nheads = 5
    q = torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 5
    k, v = [
        torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 3 for _ in range(2)
    ]
    q.requires_grad_(True)
    k.requires_grad_(True)
    v.requires_grad_(True)
    out = flash_attn_func(q, k, v, causal=causal)
    g = torch.randn_like(out)
    out.backward(g)
    q_pt = q.detach().clone().requires_grad_(True)
    k_pt = k.detach().clone().requires_grad_(True)
    v_pt = v.detach().clone().requires_grad_(True)
    out_pt, _ = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True)
    out_pt.backward(g)
    q_ref = q.detach().clone().requires_grad_(True)
    k_ref = k.detach().clone().requires_grad_(True)
    v_ref = v.detach().clone().requires_grad_(True)
    out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal)
    out_ref.backward(g)
    print(f"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}")
    print(f"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}")
    print(f"dV max diff: {(v.grad - v_ref.grad).abs().max().item()}")
    print(f"dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}")
    print(f"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}")
    print(f"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}")
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
    assert (q.grad - q_ref.grad).abs().max().item() <= 5 * (q_pt.grad - q_ref.grad).abs().max().item() + 1e-3
    assert (k.grad - k_ref.grad).abs().max().item() <= 5 * (k_pt.grad - k_ref.grad).abs().max().item() + 1e-3
    assert (v.grad - v_ref.grad).abs().max().item() <= 5 * (v_pt.grad - v_ref.grad).abs().max().item() + 1e-3


@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize("d", [64, 128])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 256])
# @pytest.mark.parametrize('seqlen', [128])
def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype):
    """We previously had a bug where we were using the wrong strides of dout, which shows up
    when dout is not contiguous.
    """
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 5
    nheads = 2
    q, k, v = [
        torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda", requires_grad=True)
        for _ in range(3)
    ]
    out = rearrange(flash_attn_func(q, k, v, causal=causal), "b s ... -> s b ...")
    # So g is not contiguous
    g = torch.randn(seqlen, 2 * batch_size, nheads, d, dtype=dtype, device="cuda")[:, ::2]
    out.backward(g)
    q_pt = q.detach().clone().requires_grad_(True)
    k_pt = k.detach().clone().requires_grad_(True)
    v_pt = v.detach().clone().requires_grad_(True)
    out_pt, attn_pt = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True)
    out_pt = rearrange(out_pt, "b s ... -> s b ...")
    out_pt.backward(g)
    q_ref = q.detach().clone().requires_grad_(True)
    k_ref = k.detach().clone().requires_grad_(True)
    v_ref = v.detach().clone().requires_grad_(True)
    out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal)
    out_ref = rearrange(out_ref, "b s ... -> s b ...")
    out_ref.backward(g)
    print(f"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}")
    print(f"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}")
    print(f"dV max diff: {(v.grad - v_ref.grad).abs().max().item()}")
    print(f"dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}")
    print(f"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}")
    print(f"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}")
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
    assert (q.grad - q_ref.grad).abs().max().item() <= 2 * (q_pt.grad - q_ref.grad).abs().max().item()
    assert (k.grad - k_ref.grad).abs().max().item() <= 2 * (k_pt.grad - k_ref.grad).abs().max().item()
    assert (v.grad - v_ref.grad).abs().max().item() <= 2 * (v_pt.grad - v_ref.grad).abs().max().item()


@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize("d", [16, 32, 64])
# @pytest.mark.parametrize('d', [16])
def test_flash_attn_bwd_varlen_overflow(d, causal, dtype):
    """We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,
    in the case where seqlen % 128 != 0 or varlen.
    """
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    nheads = 5
    q_cuseqlen = torch.tensor([0, 76, 110, 256], device=device, dtype=torch.int32)
    k_cuseqlen = torch.tensor([0, 1, 2, 3], device=device, dtype=torch.int32)
    Mq = 256
    Mk = 3

    q = torch.randn([Mq, nheads, d], dtype=dtype, device=device) * 3
    k, v = [torch.randn([Mk, nheads, d], dtype=dtype, device=device) * 3 for _ in range(2)]
    q.requires_grad_(True)
    k.requires_grad_(True)
    v.requires_grad_(True)

    out = flash_attn_varlen_func(q, k, v, q_cuseqlen, k_cuseqlen, Mq, Mk, causal=causal)
    g = torch.randn_like(out)
    out.backward(g)

    assert not q.grad.isnan().any()
    assert not k.grad.isnan().any()
    assert not v.grad.isnan().any()
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
# @pytest.mark.parametrize("swap_sq_sk", [False])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239),
        (3, 799),
        (127, 512),
        (127, 513),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (1023, 1024),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 4
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size, deterministic=True)

    g = torch.randn_like(out)
    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
        dq0, dk0, dv0 = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
        for _ in range(50):
            dq, dk, dv = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
            assert torch.equal(dv, dv0)
            assert torch.equal(dk, dk0)
            assert torch.equal(dq, dq0)
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
# @pytest.mark.parametrize("swap_sq_sk", [True])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239),
        (3, 799),
        (127, 512),
        (127, 513),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (1023, 1024),
    ],
)
# @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)])
def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
    (
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        q,
        k,
        v,
        output_pad_fn,
        dq_pad_fn,
        dk_pad_fn,
    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
    out = flash_attn_varlen_func(
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        0.0,
        causal=causal,
        window_size=window_size,
        deterministic=True,
    )

    g = torch.randn_like(out)
    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
        dq0, dk0, dv0 = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
        for _ in range(50):
            dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
            assert torch.equal(dv, dv0)
            assert torch.equal(dk, dk0)
            assert torch.equal(dq, dq0)
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("paged_kv_block_size", [16])
# @pytest.mark.parametrize("has_batch_idx", [False])
@pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize("nheads", [32])
@pytest.mark.parametrize("b", [4])
@pytest.mark.parametrize("n", [10])
@pytest.mark.parametrize("seqlen_q,seqlen_k", [(170, 170)])
def test_flash_attn_paged_kvcache_overflow(
    seqlen_q,
    seqlen_k,
    d,
    nheads,
    b,
    n,
    paged_kv_block_size,
    causal,
    dtype,
):
    device = "cuda"
    num_blocks = 1000 * 16 // paged_kv_block_size
    key_cache = torch.rand([num_blocks, paged_kv_block_size, nheads, d], dtype=dtype, device=device)
    value_cache = torch.rand([num_blocks, paged_kv_block_size, nheads, d], dtype=dtype, device=device)
    cache_seqlens = torch.zeros(b, dtype=torch.int32, device=device)

    for _ in range(n):
        query = torch.rand([b, seqlen_q, nheads, d], dtype=dtype, device=device)
        key = torch.rand([b, seqlen_k, nheads, d], dtype=dtype, device=device)
        value = torch.rand([b, seqlen_k, nheads, d], dtype=dtype, device=device)
        block_tables = torch.randint(
            0,
            num_blocks,
            size=(b, (seqlen_k + paged_kv_block_size - 1) // paged_kv_block_size),
            dtype=torch.int32,
            device=device,
        )
        output = flash_attn_with_kvcache(
            query,
            key_cache,
            value_cache,
            k=key,
            v=value,
            cache_seqlens=cache_seqlens,
            block_table=block_tables,
            causal=causal,
        )
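A quick check of the shapes used above, as a sketch outside the test file (values taken from the parametrization):

# block_tables sizing for the paged KV cache test above.
seqlen_k, paged_kv_block_size, b = 170, 16, 4
num_blocks = 1000 * 16 // paged_kv_block_size                                   # 1000 physical cache blocks
blocks_per_seq = (seqlen_k + paged_kv_block_size - 1) // paged_kv_block_size    # ceil(170 / 16) = 11
assert num_blocks == 1000 and blocks_per_seq == 11
# block_tables therefore has shape (b, 11): one row of physical block indices per sequence.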
tests/test_rotary.py  deleted (100644 → 0)
import math
import random

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn.layers.rotary import apply_rotary_emb, apply_rotary_emb_torch
from flash_attn.layers.rotary import apply_rotary_emb_qkv_, apply_rotary_emb_kv_
from flash_attn.bert_padding import pad_input, unpad_input

is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0)
def generate_cos_sin(seqlen, rotary_dim, device, dtype):
    assert rotary_dim % 2 == 0
    angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
    cos = torch.cos(angle).to(dtype=dtype)
    sin = torch.sin(angle).to(dtype=dtype)
    return cos, sin
def generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device):
    if seqlen_offsets_type == 0:
        return 0
    elif seqlen_offsets_type is int:
        return torch.randint(0, seqlen + 1, (1,)).item()
    elif seqlen_offsets_type is torch.Tensor:
        return torch.randint(0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device)
def index_cos_sin(cos, sin, seqlen_offsets, seqlen):
    if isinstance(seqlen_offsets, torch.Tensor):
        batch_size = seqlen_offsets.shape[0]
        arange = rearrange(torch.arange(seqlen, device=cos.device), "s -> 1 s")
        idx = rearrange(seqlen_offsets, "b -> b 1") + arange
        cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
        sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
    else:
        cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
        sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
    return cos_pt, sin_pt
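For orientation, the helpers above feed a pure-PyTorch reference that the fused kernel is checked against. Below is a minimal sketch of the non-interleaved rotation being tested; it is an illustrative reimplementation under the standard rotary-embedding formula, not the library's apply_rotary_emb_torch:

# Sketch: rotate the first rotary_dim channels of each head pairwise by per-position angles;
# the remaining channels pass through unchanged.
import torch

def rotary_reference(x, cos, sin):
    # x: (batch, seqlen, nheads, headdim); cos, sin: (seqlen, rotary_dim // 2)
    rotary_dim = cos.shape[-1] * 2
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x1, x2 = x_rot.chunk(2, dim=-1)
    cos = cos[:, None, :]  # broadcast over heads: (seqlen, 1, rotary_dim // 2)
    sin = sin[:, None, :]
    out_rot = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return torch.cat([out_rot, x_pass], dim=-1)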
@pytest.mark.parametrize(
    "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [True])
@pytest.mark.parametrize("inplace", [False, True])
# @pytest.mark.parametrize('inplace', [False])
def test_rotary_emb_func(inplace, interleaved, rotary_fraction, seqlen_offsets_type, dtype):
    rtol = 1e-3
    batch_size = 32
    nheads = 4
    seqlen = 217
    headdim = 128
    device = "cuda"
    rotary_dim = int(rotary_fraction * headdim)
    torch.manual_seed(42)
    x = torch.randn(
        batch_size, seqlen, nheads, headdim, dtype=dtype, device=device, requires_grad=True
    )
    x_pt = x.detach().clone().requires_grad_()
    cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
    seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)
    out = apply_rotary_emb(
        x, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=inplace
    )
    cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
    out_pt = apply_rotary_emb_torch(
        x_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    print(f"Output max diff: {(out - out_pt).abs().max().item()}")

    g = torch.randn_like(out)
    g_pt = g.clone()  # If inplace=True, we might modify the gradient inplace
    out.backward(g)
    out_pt.backward(g_pt)
    print(f"Grad max diff: {(x.grad - x_pt.grad).abs().max().item()}")

    if not inplace:
        assert torch.equal(x, x_pt)
    # Numerical error if we just do any arithmetic
    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
    atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item()
    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=2 * atol)
@pytest.mark.parametrize(
    "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False])
def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_type, dtype):
    rtol = 1e-3
    batch_size = 32
    nheads = 4
    seqlen = 512
    headdim = 128
    device = "cuda"
    rotary_dim = int(rotary_fraction * headdim)
    torch.manual_seed(42)
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, headdim, dtype=dtype, device=device, requires_grad=True
    )
    qkv_pt = qkv.detach().clone().requires_grad_()
    cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
    seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)
    out = apply_rotary_emb_qkv_(
        qkv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved
    )
    cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
    q_pt = apply_rotary_emb_torch(
        qkv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    k_pt = apply_rotary_emb_torch(
        qkv_pt[:, :, 1].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    out_pt = torch.stack([q_pt, k_pt, qkv_pt[:, :, 2]], dim=2)
    print(f"Output max diff: {(out - out_pt).abs().max().item()}")

    g = torch.randn_like(out)
    g_pt = g.clone()  # Since inplace=True, we modify the gradient inplace
    out.backward(g)
    out_pt.backward(g_pt)
    print(f"Grad max diff: {(qkv.grad - qkv_pt.grad).abs().max().item()}")

    # Numerical error if we just do any arithmetic
    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
    atol = ((qkv_pt.grad + 0.3 - 0.3) - qkv_pt.grad).abs().max().item()
    assert torch.allclose(qkv.grad, qkv_pt.grad, rtol=rtol, atol=2 * atol)
@pytest.mark.parametrize(
    "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False])
def test_rotary_emb_kv(interleaved, rotary_fraction, seqlen_offsets_type, dtype):
    rtol = 1e-3
    batch_size = 32
    nheads = 4
    seqlen = 781
    headdim = 64
    device = "cuda"
    rotary_dim = int(rotary_fraction * headdim)
    torch.manual_seed(42)
    kv = torch.randn(
        batch_size, seqlen, 2, nheads, headdim, dtype=dtype, device=device, requires_grad=True
    )
    kv_pt = kv.detach().clone().requires_grad_()
    cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
    seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)
    out = apply_rotary_emb_kv_(
        kv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved
    )
    cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
    k_pt = apply_rotary_emb_torch(
        kv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    out_pt = torch.stack([k_pt, kv_pt[:, :, 1]], dim=2)
    print(f"Output max diff: {(out - out_pt).abs().max().item()}")

    g = torch.randn_like(out)
    g_pt = g.clone()  # Since inplace=True, we modify the gradient inplace
    out.backward(g)
    out_pt.backward(g_pt)
    print(f"Grad max diff: {(kv.grad - kv_pt.grad).abs().max().item()}")

    # Numerical error if we just do any arithmetic
    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
    atol = ((kv_pt.grad + 0.3 - 0.3) - kv_pt.grad).abs().max().item()
    assert torch.allclose(kv.grad, kv_pt.grad, rtol=rtol, atol=2 * atol)
@pytest.mark.parametrize(
    "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize("dtype", ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize("rotary_fraction", [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize("interleaved", [True])
@pytest.mark.parametrize("inplace", [False, True])
# @pytest.mark.parametrize("inplace", [False])
def test_rotary_emb_varlen_func(inplace, interleaved, rotary_fraction, seqlen_offsets_type, dtype):
    rtol = 1e-3
    batch_size = 32
    nheads = 4
    seqlen = 217
    headdim = 128
    device = "cuda"
    rotary_dim = int(rotary_fraction * headdim)
    torch.manual_seed(42)
    x = torch.randn(batch_size, seqlen, nheads, headdim, dtype=dtype, device=device)
    x_pt = x.detach().clone().requires_grad_()
    lengths = torch.randint(max(1, seqlen - 20), seqlen + 1, (batch_size, 1), device=device)
    padding_mask = rearrange(torch.arange(seqlen, device=device), "s -> 1 s") < lengths
    x_unpad, indices, cu_seqlens, max_seqlen = unpad_input(x, padding_mask)
    x_unpad_clone = x_unpad.clone()
    x_unpad = x_unpad.requires_grad_()
    cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
    seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)
    out_unpad = apply_rotary_emb(
        x_unpad,
        cos,
        sin,
        seqlen_offsets=seqlen_offsets,
        interleaved=interleaved,
        inplace=inplace,
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
    )
    out = pad_input(out_unpad, indices, batch_size, seqlen)
    cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
    out_pt = apply_rotary_emb_torch(
        x_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    out_pt = out_pt.masked_fill(rearrange(~padding_mask, "b s -> b s 1 1"), 0.0)
    print(f"Output max diff: {(out - out_pt).abs().max().item()}")

    g = torch.randn_like(out)
    g_pt = g.clone()  # If inplace=True, we might modify the gradient inplace
    out.backward(g)
    out_pt.backward(g_pt)
    x_grad = pad_input(x_unpad.grad, indices, batch_size, seqlen)
    print(f"Grad max diff: {(x_grad - x_pt.grad).abs().max().item()}")

    if not inplace:
        assert torch.equal(x_unpad, x_unpad_clone)
    # Numerical error if we just do any arithmetic
    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
    atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item()
    assert torch.allclose(x_grad, x_pt.grad, rtol=rtol, atol=2 * atol)
def test_compilation_count():
    batch_size = 1
    headdim = 128
    device = "cuda"
    dtype = torch.float16
    torch.manual_seed(42)

    from triton.runtime.jit import JITFunction
    from flash_attn.ops.triton.rotary import rotary_kernel

    compilation_count = 0

    def count_compilations(*args, **kwargs):
        nonlocal compilation_count
        compilation_count += 1

    old_cache_func = JITFunction.cache_hook
    try:
        rotary_kernel.cache.clear()
        JITFunction.cache_hook = count_compilations

        for seqlen in (128, 256):
            for nheads in (4, 32):
                x = torch.randn(batch_size, seqlen, nheads, headdim, dtype=dtype, device=device)
                x.requires_grad_()
                cos, sin = generate_cos_sin(seqlen, headdim, device, dtype)
                out = apply_rotary_emb(x, cos, sin)
                out.backward(torch.randn_like(out))

        # Only two kernels are expected to be compiled:
        # * for the forward pass (conjugate=False)
        # * for the backward pass (conjugate=True)
        assert compilation_count == 2
    finally:
        JITFunction.cache_hook = old_cache_func
tests/test_vllm_flash_attn.py  new file (0 → 100644)
#
# This file is copied verbatim from vLLM:
# https://github.com/vllm-project/vllm/blob/main/tests/kernels/test_flash_attn.py
#
from typing import List, Optional, Tuple

import pytest
import torch

import flash_attn_wrapper  # noqa: F401

NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 32]
DTYPES = [torch.float16, torch.bfloat16]
# one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check
NUM_BLOCKS = [32768, 2048]
def ref_paged_attn(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    query_lens: List[int],
    kv_lens: List[int],
    block_tables: torch.Tensor,
    scale: float,
    sliding_window: Optional[int] = None,
    soft_cap: Optional[float] = None,
) -> torch.Tensor:
    num_seqs = len(query_lens)
    block_tables = block_tables.cpu().numpy()
    _, block_size, num_kv_heads, head_size = key_cache.shape

    outputs: List[torch.Tensor] = []
    start_idx = 0
    for i in range(num_seqs):
        query_len = query_lens[i]
        kv_len = kv_lens[i]
        q = query[start_idx:start_idx + query_len]
        q *= scale

        num_kv_blocks = (kv_len + block_size - 1) // block_size
        block_indices = block_tables[i, :num_kv_blocks]

        k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
        k = k[:kv_len]
        v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
        v = v[:kv_len]

        if q.shape[1] != k.shape[1]:
            k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
            v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
        attn = torch.einsum("qhd,khd->hqk", q, k).float()
        empty_mask = torch.ones(query_len, kv_len)
        mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
        if sliding_window is not None:
            sliding_window_mask = torch.triu(
                empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1
            ).bool().logical_not()
            mask |= sliding_window_mask
        if soft_cap is not None:
            attn = soft_cap * torch.tanh(attn / soft_cap)
        attn.masked_fill_(mask, float("-inf"))
        attn = torch.softmax(attn, dim=-1).to(v.dtype)
        out = torch.einsum("hqk,khd->qhd", attn, v)

        outputs.append(out)
        start_idx += query_len

    return torch.cat(outputs, dim=0)
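As a usage sketch (shapes and values are made up for illustration; this is not part of the vLLM test file), the reference above runs on CPU with tiny tensors:

# Two sequences with KV lengths 5 and 3, one query token each, a cache of 4 blocks of size 4.
query = torch.randn(2, 4, 8)          # (total query tokens, num_query_heads, head_size)
key_cache = torch.randn(4, 4, 4, 8)   # (num_blocks, block_size, num_kv_heads, head_size)
value_cache = torch.randn_like(key_cache)
block_tables = torch.tensor([[0, 1], [2, 3]], dtype=torch.int32)  # physical block ids per sequence
out = ref_paged_attn(
    query, key_cache, value_cache,
    query_lens=[1, 1], kv_lens=[5, 3],
    block_tables=block_tables, scale=8 ** -0.5,
)
# out has the same layout as `query`: one attention output row per query token.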
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@torch.inference_mode()
def test_flash_attn_with_paged_kv(
    kv_lens: List[int],
    num_heads: Tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    block_size: int,
    soft_cap: Optional[float],
    num_blocks: int,
) -> None:
    torch.set_default_device("cuda")
    torch.cuda.manual_seed_all(0)
    num_seqs = len(kv_lens)
    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]
    assert num_query_heads % num_kv_heads == 0
    max_kv_len = max(kv_lens)
    scale = head_size**-0.5

    query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
    key_cache = torch.randn(num_blocks, block_size, num_kv_heads, head_size, dtype=dtype)
    value_cache = torch.randn_like(key_cache)
    kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)

    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(
        0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
    )

    output = torch.ops.vllm.flash_attn_with_kvcache(
        decode_query=query.unsqueeze(1),
        key_cache=key_cache,
        value_cache=value_cache,
        softmax_scale=scale,
        causal=True,
        block_table=block_tables,
        cache_seqlens=kv_lens_tensor,
        softcap=soft_cap if soft_cap is not None else 0,
    ).squeeze(1)

    if num_blocks <= 2048:
        test_utils = ["test_faketensor", "test_schema"]
    else:
        test_utils = ["test_faketensor"]

    torch.library.opcheck(
        torch.ops.vllm.flash_attn_with_kvcache,
        args=tuple(),
        kwargs=dict(
            decode_query=query.unsqueeze(1),
            key_cache=key_cache,
            value_cache=value_cache,
            softmax_scale=scale,
            causal=True,
            block_table=block_tables,
            cache_seqlens=kv_lens_tensor,
            softcap=soft_cap if soft_cap is not None else 0,
        ),
        test_utils=test_utils,
    )

    ref_output = ref_paged_attn(
        query=query,
        key_cache=key_cache,
        value_cache=value_cache,
        query_lens=[1] * num_seqs,
        kv_lens=kv_lens,
        block_tables=block_tables,
        scale=scale,
        soft_cap=soft_cap,
    )
    torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
        f"{torch.max(torch.abs(output - ref_output))}"
@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("sliding_window", [None])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@torch.inference_mode()
def test_varlen_with_paged_kv(
    seq_lens: List[Tuple[int, int]],
    num_heads: Tuple[int, int],
    head_size: int,
    sliding_window: Optional[int],
    dtype: torch.dtype,
    block_size: int,
    soft_cap: Optional[float],
    num_blocks: int,
) -> None:
    torch.set_default_device("cuda")
    torch.cuda.manual_seed_all(0)
    num_seqs = len(seq_lens)
    query_lens = [x[0] for x in seq_lens]
    kv_lens = [x[1] for x in seq_lens]
    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]
    assert num_query_heads % num_kv_heads == 0
    max_query_len = max(query_lens)
    max_kv_len = max(kv_lens)
    window_size = ((sliding_window, sliding_window) if sliding_window is not None else (-1, -1))
    scale = head_size**-0.5

    query = torch.randn(sum(query_lens), num_query_heads, head_size, dtype=dtype)
    key_cache = torch.randn(num_blocks, block_size, num_kv_heads, head_size, dtype=dtype)
    value_cache = torch.randn_like(key_cache)
    cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum(
        dim=0, dtype=torch.int32
    )
    cu_kv_lens = torch.tensor([0] + kv_lens, dtype=torch.int32).cumsum(dim=0, dtype=torch.int32)

    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(
        0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
    )

    output = torch.ops.vllm.flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_query_lens,
        cu_seqlens_k=cu_kv_lens,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len,
        softmax_scale=scale,
        causal=True,
        window_size=window_size,
        block_table=block_tables,
        softcap=soft_cap if soft_cap is not None else 0,
    )

    if num_blocks <= 2048:
        test_utils = ["test_faketensor", "test_schema"]
    else:
        test_utils = ["test_faketensor"]

    torch.library.opcheck(
        torch.ops.vllm.flash_attn_varlen_func,
        args=tuple(),
        kwargs=dict(
            q=query,
            k=key_cache,
            v=value_cache,
            cu_seqlens_q=cu_query_lens,
            cu_seqlens_k=cu_kv_lens,
            max_seqlen_q=max_query_len,
            max_seqlen_k=max_kv_len,
            softmax_scale=scale,
            causal=True,
            window_size=window_size,
            block_table=block_tables,
            softcap=soft_cap if soft_cap is not None else 0,
        ),
        test_utils=test_utils,
    )

    ref_output = ref_paged_attn(
        query=query,
        key_cache=key_cache,
        value_cache=value_cache,
        query_lens=query_lens,
        kv_lens=kv_lens,
        block_tables=block_tables,
        scale=scale,
        sliding_window=sliding_window,
        soft_cap=soft_cap,
    )
    torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
        f"{torch.max(torch.abs(output - ref_output))}"
vllm_flash_attn/__init__.py
 __version__ = "2.6.2"

-from vllm_flash_attn.flash_attn_interface import (
+# Use relative import to support build-from-source installation in vLLM
+from .flash_attn_interface import (
     flash_attn_func,
     flash_attn_kvpacked_func,
     flash_attn_qkvpacked_func,
...
vllm_flash_attn/flash_attn_interface.py
...
@@ -7,7 +7,8 @@ import torch.nn as nn
 # isort: off
 # We need to import the CUDA kernels after importing torch
-import vllm_flash_attn_2_cuda as flash_attn_cuda
+# Use relative import to support build-from-source installation in vLLM
+from . import vllm_flash_attn_c  # noqa: F401
 # isort: on
...
@@ -49,7 +50,7 @@ def _flash_attn_forward(
     q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax, *, out=None
 ):
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
-    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
+    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = torch.ops.vllm_flash_attn_c.fwd(
         q,
         k,
         v,
...
@@ -87,7 +88,7 @@ def _flash_attn_varlen_forward(
     out=None
 ):
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
-    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
+    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = torch.ops.vllm_flash_attn_c.varlen_fwd(
         q,
         k,
         v,
...
@@ -140,7 +141,7 @@ def _flash_attn_backward(
         dk,
         dv,
         softmax_d,
-    ) = flash_attn_cuda.bwd(
+    ) = torch.ops.vllm_flash_attn_c.bwd(
         dout,
         q,
         k,
...
@@ -194,7 +195,7 @@ def _flash_attn_varlen_backward(
         dk,
         dv,
         softmax_d,
-    ) = flash_attn_cuda.varlen_bwd(
+    ) = torch.ops.vllm_flash_attn_c.varlen_bwd(
         dout,
         q,
         k,
...
@@ -1292,7 +1293,7 @@ def flash_attn_with_kvcache(
     cache_seqlens = maybe_contiguous(cache_seqlens)
     cache_batch_idx = maybe_contiguous(cache_batch_idx)
     block_table = maybe_contiguous(block_table)
-    out, softmax_lse = flash_attn_cuda.fwd_kvcache(
+    out, softmax_lse = torch.ops.vllm_flash_attn_c.fwd_kvcache(
         q,
         k_cache,
         v_cache,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment