Unverified Commit 013f0c4f authored by Luka Govedič, committed by GitHub

CMake build, allowing parent build (#19)

parent 344c988d
Pipeline #2020 failed with stages in 0 seconds
# Copyright (c) 2023, Tri Dao.
# To run the huggingface implementation of LLaMa (1), we first need to convert the weights:
# https://github.com/huggingface/transformers/pull/21955
# python -m transformers.models.llama.convert_llama_weights_to_hf --input_dir $CHECKPOINT_DIR/llama --model_size 7B --output_dir $CHECKPOINT_DIR/llama/7B-hf
# and repeat for 13B, 30B, 65B
import os
import time
from pathlib import Path
current_dir = Path(__file__).parent.absolute()
import shutil
import pytest
import torch
from einops import rearrange
from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp
from flash_attn.models.llama import (
config_from_checkpoint,
inv_remap_state_dict_hf_llama,
llama_config_to_gpt2_config,
remap_state_dict_hf_llama,
remap_state_dict_meta_llama,
state_dicts_from_checkpoint,
)
from flash_attn.utils.distributed import all_gather_raw
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import LlamaConfig, LlamaTokenizer
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers import AutoConfig
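# Helper shared by the tests below: load a pretrained LLaMa state dict either from the
# Meta-format shards (remapping each shard, then combining the tensor-parallel pieces)
# or from a converted Hugging Face checkpoint, and return it with the GPT-2-style
# parameter names that GPTLMHeadModel expects.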
def _pretrained_state_dict_from_checkpoint(checkpoint_path, model_name, config, checkpoint_format):
if checkpoint_format == "meta":
ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts]
pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
else:
pretrained_state_dict = state_dict_from_pretrained(
Path(checkpoint_path) / f"{model_name}-hf"
)
pretrained_state_dict = remap_state_dict_hf_llama(pretrained_state_dict, config)
return pretrained_state_dict
@pytest.mark.parametrize("model_name", ["7B"])
def test_llama_state_dict(model_name):
checkpoint_path = (
Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
)
config = llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name))
ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
pretrained_state_dict = remap_state_dict_meta_llama(ckpt_state_dicts[0], config)
model = GPTLMHeadModel(config, device="meta") # Without device='meta' init is very slow
state_dict = model.state_dict()
assert state_dict.keys() == pretrained_state_dict.keys()
for k in state_dict.keys():
assert state_dict[k].shape == pretrained_state_dict[k].shape
# TinyLlama-1.1B is used to test MQA
@pytest.mark.parametrize(
"model_name", ["meta-llama/Llama-2-7b-hf", "PY007/TinyLlama-1.1B-step-50K-105b"]
)
def test_inv_remap_state_dict_hf_llama(model_name):
config = llama_config_to_gpt2_config(
AutoConfig.from_pretrained(model_name, trust_remote_code=True)
)
state_dict = state_dict_from_pretrained(model_name)
# inv_remap_state_dict_hf_llama should be the inverse of remap_state_dict_hf_llama
state_dict = {key: val for key, val in state_dict.items() if "rotary_emb.inv_freq" not in key}
pretrained_state_dict = remap_state_dict_hf_llama(state_dict, config)
state_dict_recover = inv_remap_state_dict_hf_llama(pretrained_state_dict, config)
assert set(state_dict_recover.keys()) == set(state_dict.keys())
for key in state_dict_recover.keys():
torch.testing.assert_close(state_dict_recover[key], state_dict[key])
# TinyLlama-1.1B is used to test MQA
@pytest.mark.parametrize(
"model_name",
[
"7B", # Llama 1
"13B", # Llama 1
"meta-llama/Llama-2-13b-hf",
"codellama/CodeLlama-7b-hf",
"codellama/CodeLlama-13b-hf",
"codellama/CodeLlama-34b-hf",
"PY007/TinyLlama-1.1B-step-50K-105b",
],
)
def test_llama_optimized(model_name):
"""Check that our implementation of LLaMa (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
checkpoint_path = (
Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
)
dtype = torch.float16
device = "cuda"
if "/" in model_name: # Download from HF
config = llama_config_to_gpt2_config(
AutoConfig.from_pretrained(model_name, trust_remote_code=True)
)
else:
config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta")
config = llama_config_to_gpt2_config(config)
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = False # We don't have fused GatedMLP yet
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True
if "/" in model_name: # Download from HF
pretrained_state_dict = remap_state_dict_hf_llama(
state_dict_from_pretrained(model_name), config
)
else:
pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
checkpoint_path, model_name, config, checkpoint_format="meta"
)
model = GPTLMHeadModel(config, device=device, dtype=dtype)
model.load_state_dict(pretrained_state_dict)
model.eval()
torch.manual_seed(0)
batch_size = 2
max_seqlen = 256
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
)
with torch.no_grad():
out = model.transformer(input_ids)
logits = model(input_ids).logits
del model
# Without device_map, the model is loaded on the CPU, which is very slow
# Need auto here since the 13B fp32 model doesn't fit in memory on an A100 40GB
model_ref = LlamaForCausalLM.from_pretrained(
model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
device_map="auto",
)
model_ref.eval()
with torch.no_grad():
out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
logits_ref = model_ref(input_ids).logits.to(device=device)
del model_ref
model_hf = LlamaForCausalLM.from_pretrained(
model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
torch_dtype=dtype,
device_map={"": device},
)
model_hf.eval()
with torch.no_grad():
out_hf = model_hf.model(input_ids).last_hidden_state
logits_hf = model_hf(input_ids).logits
del model_hf
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
assert (logits - logits_ref).abs().max().item() < 3 * (
logits_hf - logits_ref
).abs().max().item()
# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "parallel"
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize(
"model_name", ["13B", "meta-llama/Llama-2-13b-hf", "codellama/CodeLlama-34b-hf"]
)
def test_llama_parallel(model_name, world_size):
"""Check that our implementation of LLaMa (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
from apex.transformer import parallel_state
checkpoint_path = (
Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
)
dtype = torch.float16
if "/" in model_name: # Download from HF
config = llama_config_to_gpt2_config(
AutoConfig.from_pretrained(model_name, trust_remote_code=True)
)
else:
config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta")
config = llama_config_to_gpt2_config(config)
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = False # We don't have fused GatedMLP yet
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
process_group = parallel_state.get_tensor_model_parallel_group()
if "/" in model_name: # Download from HF
pretrained_state_dict = remap_state_dict_hf_llama(
state_dict_from_pretrained(model_name), config
)
else:
pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
checkpoint_path, model_name, config, checkpoint_format="meta"
)
model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
model.eval()
torch.manual_seed(0)
batch_size = 2
max_seqlen = 256
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
)
with torch.no_grad():
out = model.transformer(input_ids)
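# The transformer output and the logits are sharded across ranks (sequence-parallel
# hidden states, vocab-parallel logits), so gather them back into full tensors before
# comparing against the unsharded HF reference below.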
out, _ = all_gather_raw(out, process_group=process_group)
out = rearrange(out, "(b s) d -> b s d", b=batch_size)
logits = model(input_ids).logits
logits = rearrange(logits, "(b s) d -> b s d", b=batch_size)
logits, _ = all_gather_raw(logits, process_group)
logits = rearrange(logits, "(n b) ... d -> b ... (n d)", b=batch_size)
del model
if rank == 0:
# Without device_map, the model is loaded on the CPU, which is very slow
model_ref = LlamaForCausalLM.from_pretrained(
model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
device_map="auto",
)
model_ref.eval()
with torch.no_grad():
out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
logits_ref = model_ref(input_ids).logits.to(device=device)
del model_ref
model_hf = LlamaForCausalLM.from_pretrained(
model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
torch_dtype=dtype,
device_map="auto",
)
model_hf.eval()
with torch.no_grad():
out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device)
logits_hf = model_hf(input_ids).logits.to(device=device)
del model_hf
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
assert (logits - logits_ref).abs().max().item() < 2 * (
logits_hf - logits_ref
).abs().max().item()
# @pytest.mark.parametrize('model_name', ["7B", "13B"])
@pytest.mark.parametrize("model_name", ["7B"])
@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"])
def test_llama_generation(model_name, checkpoint_format):
checkpoint_path = (
Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
)
dtype = torch.float16
device = "cuda"
config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
config = llama_config_to_gpt2_config(config)
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = False # We don't have fused GatedMLP yet
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True
tokenizer = LlamaTokenizer.from_pretrained(Path(checkpoint_path) / f"{model_name}-hf")
eos_token_id = tokenizer.eos_token_id
torch.manual_seed(0)
batch_size = 1
seqlen = 100
max_length = 150
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
)
model_hf = LlamaForCausalLM.from_pretrained(
Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map={"": device}
)
model_hf.eval()
print("HF fp16")
torch.cuda.synchronize()
start = time.time()
out_hf = model_hf.generate(
input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
)
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
del model_hf
# Need auto here since the 13B fp32 model doesn't fit in memory on an A100 40GB
model_ref = LlamaForCausalLM.from_pretrained(
Path(checkpoint_path) / f"{model_name}-hf", device_map="auto"
)
model_ref.eval()
with torch.no_grad():
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device)
del model_ref
pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
checkpoint_path, model_name, config, checkpoint_format
)
model = GPTLMHeadModel(config, device=device, dtype=dtype)
model.load_state_dict(pretrained_state_dict)
model.eval()
print("Without CUDA graph")
torch.cuda.synchronize()
start = time.time()
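# teacher_outputs makes generation follow the token sequence produced by the HF model
# above instead of sampling, so the per-step scores are directly comparable.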
out = model.generate(
input_ids=input_ids,
max_length=max_length,
eos_token_id=eos_token_id,
return_dict_in_generate=True,
output_scores=True,
enable_timing=True,
teacher_outputs=out_hf.sequences,
)
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
# Capture graph outside the timing loop
batch_size, seqlen_og = input_ids.shape
model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
print("With CUDA graph")
torch.cuda.synchronize()
start = time.time()
out_cg = model.generate(
input_ids=input_ids,
max_length=max_length,
cg=True,
return_dict_in_generate=True,
output_scores=True,
enable_timing=True,
teacher_outputs=out_hf.sequences,
)
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
with torch.no_grad():
logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]
logits_hf = torch.stack(out_hf.scores, dim=1)
logits = torch.stack(out.scores, dim=1)
logits_cg = torch.stack(out_cg.scores, dim=1)
del model
hf_error = (logits_hf - logits_ref).abs().max().item()
print(f"HF fp16 logits max diff: {hf_error}")
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}")
assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
assert (logits - logits_ref).abs().max().item() < 2 * hf_error
assert torch.equal(logits_cg, logits)
# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "llama_parallel_generation"
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize(
"model_name", ["13B", "meta-llama/Llama-2-13b-hf", "codellama/CodeLlama-34b-hf"]
)
def test_llama_parallel_generation(model_name, world_size):
"""Check that our implementation matches the HF implementation:
the scores in fp16 should be around the same as the HF scores in fp16, when compared to
the HF scores in fp32.
"""
from apex.transformer import parallel_state
checkpoint_path = (
Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
)
dtype = torch.float16
if "/" in model_name: # Download from HF
config = llama_config_to_gpt2_config(
AutoConfig.from_pretrained(model_name, trust_remote_code=True)
)
else:
config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta")
config = llama_config_to_gpt2_config(config)
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = False # We don't have fused GatedMLP yet
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True
config.pad_vocab_size_multiple = 8 * world_size
config.sequence_parallel = False # Need to set this to False for generation
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
process_group = parallel_state.get_tensor_model_parallel_group()
torch.manual_seed(0)
batch_size = 1
seqlen = 100
max_length = 150
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
)
# Need this, otherwise when we capture the graph the process for GPU 1 would run on both
# GPU0 and GPU1 and things would hang
torch.cuda.set_device(device)
if "/" in model_name: # Download from HF
pretrained_state_dict = remap_state_dict_hf_llama(
state_dict_from_pretrained(model_name), config
)
else:
pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
checkpoint_path, model_name, config, checkpoint_format="meta"
)
model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
model.eval()
print("Without CUDA graph")
out = model.generate(
input_ids=input_ids,
max_length=max_length,
tensor_parallel=world_size,
vocab_size=config.vocab_size,
# teacher_outputs=out_hf.sequences,
return_dict_in_generate=True,
output_scores=True,
enable_timing=True,
)
# Capture graph outside the timing loop
batch_size, seqlen_og = input_ids.shape
model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
print("With CUDA graph")
out_cg = model.generate(
input_ids=input_ids,
max_length=max_length,
tensor_parallel=world_size,
vocab_size=config.vocab_size,
cg=True,
# teacher_outputs=out_hf.sequences,
return_dict_in_generate=True,
output_scores=True,
enable_timing=True,
)
del model
parallel_state.destroy_model_parallel()
if rank == 0:
# Without device_map, the model is loaded on the CPU, which is very slow
model_hf = LlamaForCausalLM.from_pretrained(
model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
torch_dtype=dtype,
device_map="auto",
)
model_hf.eval()
print("HF fp16")
torch.cuda.synchronize()
start = time.time()
with torch.inference_mode():
out_hf = model_hf.generate(
input_ids=input_ids,
max_length=max_length,
return_dict_in_generate=True,
output_scores=True,
)
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
del model_hf
model_ref = LlamaForCausalLM.from_pretrained(
model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
device_map="auto",
)
model_ref.eval()
with torch.inference_mode():
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
del model_ref
logits_hf = torch.stack(out_hf.scores, dim=1)
logits = torch.stack(out.scores, dim=1)
logits_cg = torch.stack(out_cg.scores, dim=1)
hf_error = (logits_hf - logits_ref).abs().max().item()
print(f"HF fp16 logits max diff: {hf_error}")
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
assert (logits - logits_ref).abs().max().item() < 2 * hf_error
print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}")
assert torch.equal(logits_cg, logits)
@torch.no_grad()
@pytest.mark.parametrize("world_size", [2])
def test_llama_parallel_uneven_num_heads(world_size):
from apex.transformer import parallel_state
checkpoint_path = (
Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
)
num_attention_heads = world_size + 1
model_name = f"teeny-{num_attention_heads}-heads"
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
process_group = parallel_state.get_tensor_model_parallel_group()
dtype = torch.float16
llama_config = LlamaConfig(
hidden_size=256
* num_attention_heads, # ParallelGatedMlp hidden_features must be divisible by 256
intermediate_size=256 * num_attention_heads * 4,
num_hidden_layers=4,
num_attention_heads=num_attention_heads,
initializer_range=0.5,  # Large init range so we don't get near-zero weights, which would make the test vacuous.
)
config = llama_config_to_gpt2_config(llama_config)
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = False # We don't have fused GatedMLP yet
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True
torch.manual_seed(0)
batch_size = 2
max_seqlen = 256
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
)
# Create a shared test model.
if rank == 0:
LlamaForCausalLM(config=llama_config).save_pretrained(checkpoint_path / f"{model_name}-hf")
torch.distributed.barrier()
# Run the standard forward pass test.
pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
checkpoint_path, model_name, config, checkpoint_format="hf"
)
model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
model.eval()
# TODO: Avoid duplicate code. Modularize the comparison of two forward pass diffs.
out = model.transformer(input_ids)
out, _ = all_gather_raw(out, process_group=process_group)
out = rearrange(out, "(b s) d -> b s d", b=batch_size)
logits = model(input_ids).logits
logits = rearrange(logits, "(b s) d -> b s d", b=batch_size)
logits, _ = all_gather_raw(logits, process_group)
logits = rearrange(logits, "(n b) ... d -> b ... (n d)", b=batch_size)
if rank == 0:
model_ref = LlamaForCausalLM.from_pretrained(
Path(checkpoint_path) / f"{model_name}-hf", device_map={"": device}
)
model_ref = model_ref.to(device=device)
model_ref.eval()
out_ref = model_ref.model(input_ids).last_hidden_state
logits_ref = model_ref(input_ids).logits
del model_ref
model_hf = LlamaForCausalLM.from_pretrained(
Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map={"": device}
)
model_hf.eval()
out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device)
logits_hf = model_hf(input_ids).logits.to(device=device)
del model_hf
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
assert (logits - logits_ref).abs().max().item() < 2 * (
logits_hf - logits_ref
).abs().max().item()
if os.path.exists(checkpoint_path / f"{model_name}-hf"):
shutil.rmtree(checkpoint_path / f"{model_name}-hf")
import re
import time
import pytest
import torch
from einops import rearrange
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.opt import opt_config_to_gpt2_config, remap_state_dict_hf_opt
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import AutoTokenizer, OPTConfig
from transformers.models.opt.modeling_opt import OPTForCausalLM
@pytest.mark.parametrize(
"model_name", ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"]
)
# @pytest.mark.parametrize('model_name', ["facebook/opt-350m"])
def test_opt_state_dict(model_name):
config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
pretrained_state_dict = remap_state_dict_hf_opt(state_dict_from_pretrained(model_name), config)
model = GPTLMHeadModel(config)
state_dict = model.state_dict()
assert state_dict.keys() == pretrained_state_dict.keys()
for k in state_dict.keys():
assert state_dict[k].shape == pretrained_state_dict[k].shape
@pytest.mark.parametrize(
"model_name", ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"]
)
# @pytest.mark.parametrize('model_name', ["facebook/opt-350m"])
def test_opt_optimized(model_name):
"""Check that our implementation of OPT (without all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
device = "cuda"
config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = True
config.fused_dropout_add_ln = True
# Only prenorm supports residual_in_fp32
config.residual_in_fp32 = getattr(config, "prenorm", True)
config.pad_vocab_size_multiple = 8
model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
model_ref = OPTForCausalLM.from_pretrained(model_name).to(device=device)
model_hf = OPTForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device=device)
model.eval()
model_ref.eval()
model_hf.eval()
torch.manual_seed(0)
batch_size = 2
max_seqlen = 256
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
)
if model_name != "facebook/opt-350m": # The OPT-350m projects the embeddings to dimension 512
out = model.transformer(input_ids)
out_hf = model_hf.model(input_ids).last_hidden_state
out_ref = model_ref.model(input_ids).last_hidden_state
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
logits = model(input_ids).logits
logits_hf = model_hf(input_ids).logits
logits_ref = model_ref(input_ids).logits
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
assert (logits - logits_ref).abs().max().item() < 3 * (
logits_hf - logits_ref
).abs().max().item()
@pytest.mark.parametrize(
"model_name",
[
"facebook/opt-125m",
"facebook/opt-350m",
"facebook/opt-1.3b",
"facebook/opt-2.7b",
"facebook/opt-6.7b",
],
)
# @pytest.mark.parametrize('model_name', ["facebook/opt-125m"])
def test_opt_generation(model_name):
"""Check that our implementation of OPT generation matches the HF implementation:
the scores in fp16 should be around the same as the HF scores in fp16, when compared to
the HF scores in fp32.
"""
print(f"\nMODEL: {model_name}")
verbose = False
dtype = torch.float16
device = "cuda"
rtol, atol = 3e-3, 3e-1
config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
# Only prenorm supports residual_in_fp32
config.residual_in_fp32 = getattr(config, "prenorm", True)
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = True
config.fused_dropout_add_ln = True
model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
model.eval()
torch.manual_seed(0)
# OPT tokenizer requires use_fast=False
# https://huggingface.co/docs/transformers/model_doc/opt
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
eos_token_id = tokenizer.eos_token_id
input_ids = tokenizer("Hello, my dog is cute and he", return_tensors="pt").input_ids.to(
device=device
)
max_length = 25
# input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
# max_length = input_ids.shape[1] + 40
# Slow generation for reference
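# (greedy decoding, re-running the whole prefix through the model at every step, no KV cache)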
sequences = []
scores = []
cur_input_ids = input_ids
with torch.inference_mode():
scores.append(model(cur_input_ids).logits[:, -1])
sequences.append(scores[-1].argmax(dim=-1))
for _ in range(input_ids.shape[1] + 1, max_length):
cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], "b -> b 1")], dim=-1)
scores.append(model(cur_input_ids).logits[:, -1])
sequences.append(scores[-1].argmax(dim=-1))
if eos_token_id is not None and (sequences[-1] == eos_token_id).all():
break
sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
scores = tuple(scores)
print("Without CUDA graph")
torch.cuda.synchronize()
start = time.time()
out = model.generate(
input_ids=input_ids,
max_length=max_length,
eos_token_id=eos_token_id,
return_dict_in_generate=True,
output_scores=True,
enable_timing=True,
)
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
if verbose:
print(out.sequences)
print(tokenizer.batch_decode(out.sequences.tolist()))
if getattr(config, "use_flash_attn", False):
# Capture graph outside the timing loop
batch_size, seqlen_og = input_ids.shape
model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
print("With CUDA graph")
torch.cuda.synchronize()
start = time.time()
out_cg = model.generate(
input_ids=input_ids,
max_length=max_length,
cg=True,
return_dict_in_generate=True,
output_scores=True,
enable_timing=True,
)
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
if verbose:
print(out_cg.sequences)
print(tokenizer.batch_decode(out_cg.sequences.tolist()))
del model
model_hf = OPTForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device=device)
model_hf.eval()
print("HF fp16")
torch.cuda.synchronize()
start = time.time()
out_hf = model_hf.generate(
input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
)
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
del model_hf
model_ref = OPTForCausalLM.from_pretrained(model_name).to(device=device)
model_ref.eval()
print("HF fp32")
torch.cuda.synchronize()
start = time.time()
out_ref = model_ref.generate(
input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
)
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
del model_ref
print(tokenizer.batch_decode(out_ref.sequences.tolist()))
if verbose:
print(
f"Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
)
print(
f"Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
)
print(
f"HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
)
print(
f"HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
)
assert torch.all(out.sequences == sequences)
assert torch.allclose(
torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol
)
assert torch.all(out.sequences == out_ref.sequences)
assert torch.all(out.sequences == out_hf.sequences)
assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (
torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)
).abs().max().item()
import re
import pytest
import torch
from flash_attn.models.vit import vit_base_patch16_224 as flash_vit_base_patch16_224
from timm.models.vision_transformer import vit_base_patch16_224
@pytest.mark.parametrize("fused_mlp", [False, True])
# @pytest.mark.parametrize('fused_mlp', [False])
@pytest.mark.parametrize("optimized", [False, True])
# @pytest.mark.parametrize('optimized', [True])
def test_vit(optimized, fused_mlp):
"""Check that our implementation of ViT matches the timm's implementation:
the output of our forward pass in fp16 should be around the same as
timm' forward pass in fp16, when compared to timm's forward pass in fp32.
"""
dtype = torch.float16
device = "cuda"
kwargs = {}
if optimized:
kwargs = dict(use_flash_attn=True, fused_bias_fc=True, fused_dropout_add_ln=True)
kwargs["fused_mlp"] = fused_mlp
model = flash_vit_base_patch16_224(**kwargs).to(device=device, dtype=dtype)
model_ref = vit_base_patch16_224(pretrained=True).to(device=device)
model_timm = vit_base_patch16_224(pretrained=True).to(device=device, dtype=dtype)
model.load_state_dict(model_ref.state_dict())
model.eval()
model_ref.eval()
model_timm.eval()
torch.manual_seed(0)
batch_size = 2
x = torch.randn(batch_size, 3, 224, 224, device=device, dtype=dtype)
out = model(x)
out_timm = model_timm(x)
out_ref = model_ref(x.float())
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"timm fp16 max diff: {(out_timm - out_ref).abs().max().item()}")
print(f"timm fp16 mean diff: {(out_timm - out_ref).abs().mean().item()}")
rtol = 2 if not fused_mlp else 8
assert (out - out_ref).abs().max().item() < rtol * (out_timm - out_ref).abs().max().item()
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_block_parallel.py
import math
from functools import partial
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from apex.transformer import parallel_state, tensor_parallel
from einops import rearrange
from flash_attn.modules.block import Block
from flash_attn.modules.mha import MHA, ParallelMHA
from flash_attn.modules.mlp import FusedMLP, ParallelFusedMLP
from flash_attn.utils.distributed import allreduce_sequence_parallel_grad
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [True])
@pytest.mark.parametrize("dim", [1024])
def test_block_parallel(dim, sequence_parallel, world_size, dtype):
head_dim = 64
assert dim % head_dim == 0
num_heads = dim // head_dim
assert num_heads % world_size == 0
rtol, atol = (3e-3, 5e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
# set seed
torch.random.manual_seed(0)
batch_size = 2
seqlen = 1024
assert (batch_size * seqlen) % world_size == 0
x_pt = torch.randn(batch_size * seqlen, dim, device=device, dtype=dtype, requires_grad=True)
residual_pt = torch.randn(batch_size * seqlen, dim, device=device, requires_grad=True)
# We need to generate g here so that all processes get the same gradient,
# as rank 0 will have an extra bias that changes the RNG.
# If we don't scale g down (we divide by 32 here), the gradient gets a bit too large.
g = torch.randn_like(x_pt) / 32
if sequence_parallel:
x = (
tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
.detach()
.clone()
.requires_grad_()
)
residual = (
tensor_parallel.scatter_to_sequence_parallel_region(residual_pt)
.detach()
.clone()
.requires_grad_()
)
else:
x = x_pt.detach().clone().requires_grad_()
residual = residual_pt.detach().clone().requires_grad_()
mixer_cls_pt = partial(
MHA,
num_heads=num_heads,
rotary_emb_dim=int(head_dim // 2),
use_flash_attn=True,
device=device,
dtype=dtype,
)
mlp_cls_pt = partial(FusedMLP, hidden_features=4 * dim, device=device, dtype=dtype)
norm_cls = partial(nn.LayerNorm, device=device, dtype=dtype)
model_pt = Block(dim, mixer_cls_pt, mlp_cls_pt, norm_cls, fused_dropout_add_ln=True)
with torch.no_grad():
nn.init.normal_(model_pt.norm1.weight)
nn.init.normal_(model_pt.norm1.bias)
nn.init.normal_(model_pt.norm2.weight)
nn.init.normal_(model_pt.norm2.bias)
mixer_cls = partial(
ParallelMHA,
num_heads=num_heads,
process_group=parallel_state.get_tensor_model_parallel_group(),
rotary_emb_dim=int(head_dim // 2),
use_flash_attn=True,
sequence_parallel=sequence_parallel,
device=device,
dtype=dtype,
)
mlp_cls = partial(
ParallelFusedMLP,
hidden_features=4 * dim,
process_group=parallel_state.get_tensor_model_parallel_group(),
sequence_parallel=sequence_parallel,
device=device,
dtype=dtype,
)
model = Block(
dim,
mixer_cls,
mlp_cls,
norm_cls,
fused_dropout_add_ln=True,
sequence_parallel=sequence_parallel,
mark_shared_params=True,
)
partition_dim = dim // world_size
partition_hidden_dim = 4 * dim // world_size
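# Shard the reference weights by hand: Wqkv and fc1 are split along the output
# dimension (column-parallel), out_proj and fc2 along the input dimension
# (row-parallel); only rank 0 keeps the out_proj and fc2 biases.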
with torch.no_grad():
model.mixer.Wqkv.weight.copy_(
rearrange(
rearrange(model_pt.mixer.Wqkv.weight, "(three o) i -> three o i", three=3)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"three o i -> (three o) i",
)
)
model.mixer.Wqkv.bias.copy_(
rearrange(
rearrange(model_pt.mixer.Wqkv.bias, "(three o) -> three o", three=3)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"three o -> (three o)",
)
)
model.mixer.out_proj.weight.copy_(
model_pt.mixer.out_proj.weight[:, rank * partition_dim : (rank + 1) * partition_dim]
)
if rank == 0:
model.mixer.out_proj.bias.copy_(model_pt.mixer.out_proj.bias)
model.mlp.fc1.weight.copy_(
model_pt.mlp.fc1.weight[rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim]
)
model.mlp.fc1.bias.copy_(
model_pt.mlp.fc1.bias[rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim]
)
model.mlp.fc2.weight.copy_(
model_pt.mlp.fc2.weight[
:, rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim
]
)
if rank == 0:
model.mlp.fc2.bias.copy_(model_pt.mlp.fc2.bias)
model.norm1.weight.copy_(model_pt.norm1.weight)
model.norm1.bias.copy_(model_pt.norm1.bias)
model.norm2.weight.copy_(model_pt.norm2.weight)
model.norm2.bias.copy_(model_pt.norm2.bias)
mixer_kwargs = {"seqlen": seqlen}
out, out_residual = model(x, residual, mixer_kwargs=mixer_kwargs)
out_pt, out_residual_pt = model_pt(
rearrange(x_pt, "(b s) d -> b s d", s=seqlen),
rearrange(residual_pt, "(b s) d -> b s d", s=seqlen),
)
out_pt, out_residual_pt = [rearrange(x, "b s d -> (b s) d") for x in [out_pt, out_residual_pt]]
partition_batch_dim = batch_size * seqlen // world_size
assert torch.allclose(
out,
out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel
else out_pt,
rtol=rtol,
atol=atol,
)
assert torch.allclose(
out_residual,
out_residual_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel
else out_residual_pt,
rtol=rtol,
atol=atol,
)
(out_pt + 2 * out_residual_pt).backward(g)
(out + 2 * out_residual).backward(
g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g
)
allreduce_sequence_parallel_grad(model, parallel_state.get_tensor_model_parallel_group())
parallel_state.destroy_model_parallel()
assert torch.allclose(
x.grad,
x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel
else x_pt.grad,
rtol=rtol,
atol=atol / 10, # magnitude of x.grad is quite small
)
assert torch.allclose(
residual.grad,
residual_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel
else residual_pt.grad,
rtol=rtol,
atol=atol,
)
# The error for d_weight and d_bias is quite a bit higher
assert torch.allclose(
model.mixer.Wqkv.weight.grad,
rearrange(
rearrange(model_pt.mixer.Wqkv.weight.grad, "(three o) i -> three o i", three=3)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"three o i -> (three o) i",
),
rtol=rtol,
atol=atol * 10,
)
assert torch.allclose(
model.mixer.Wqkv.bias.grad,
rearrange(
rearrange(model_pt.mixer.Wqkv.bias.grad, "(three o) -> three o", three=3)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"three o -> (three o)",
),
rtol=rtol,
atol=atol * 5,
)
assert torch.allclose(
model.mixer.out_proj.weight.grad,
model_pt.mixer.out_proj.weight.grad[:, rank * partition_dim : (rank + 1) * partition_dim],
rtol=rtol,
atol=atol * 10,
)
if rank == 0:
assert torch.allclose(
model.mixer.out_proj.bias.grad,
model_pt.mixer.out_proj.bias.grad,
rtol=rtol,
atol=atol * 5,
)
assert torch.allclose(
model.mlp.fc1.weight.grad,
model_pt.mlp.fc1.weight.grad[
rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim
],
rtol=rtol,
atol=atol * 10,
)
assert torch.allclose(
model.mlp.fc1.bias.grad,
model_pt.mlp.fc1.bias.grad[rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim],
rtol=rtol,
atol=atol * 5,
)
assert torch.allclose(
model.mlp.fc2.weight.grad,
model_pt.mlp.fc2.weight.grad[
:, rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim
],
rtol=rtol,
atol=atol * 10,
)
if rank == 0:
assert torch.allclose(
model.mlp.fc2.bias.grad, model_pt.mlp.fc2.bias.grad, rtol=rtol, atol=atol * 5
)
assert torch.allclose(
model.norm1.weight.grad, model_pt.norm1.weight.grad, rtol=rtol, atol=atol * 5
)
assert torch.allclose(model.norm1.bias.grad, model_pt.norm1.bias.grad, rtol=rtol, atol=atol * 5)
assert torch.allclose(
model.norm2.weight.grad, model_pt.norm2.weight.grad, rtol=rtol, atol=atol * 5
)
assert torch.allclose(model.norm2.bias.grad, model_pt.norm2.bias.grad, rtol=rtol, atol=atol * 5)
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_embedding_parallel.py
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from apex.transformer import parallel_state
from einops import rearrange
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize("has_pos_emb", [True, False])
# @pytest.mark.parametrize('has_pos_emb', [True])
@pytest.mark.parametrize("dim", [1024])
def test_embedding_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
vocab_size = 50264
seqlen = 2048
assert vocab_size % world_size == 0
assert dim % world_size == 0
rtol, atol = (3e-3, 5e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 1024
assert (batch_size * seqlen) % world_size == 0
input_ids_pt = torch.randint(0, vocab_size, (batch_size, seqlen), device=device)
input_ids = input_ids_pt.detach().clone()
model_pt = GPT2Embeddings(
dim, vocab_size, seqlen if has_pos_emb else 0, device=device, dtype=dtype
)
model = ParallelGPT2Embeddings(
dim,
vocab_size,
seqlen if has_pos_emb else 0,
parallel_state.get_tensor_model_parallel_group(),
sequence_parallel=sequence_parallel,
device=device,
dtype=dtype,
)
partition_vocab_size = vocab_size // world_size
partition_dim = dim // world_size
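# Shard the reference embeddings by hand: the word embeddings are split along the
# vocab dimension, the position embeddings along the hidden dimension.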
with torch.no_grad():
model.word_embeddings.weight.copy_(
model_pt.word_embeddings.weight[
rank * partition_vocab_size : (rank + 1) * partition_vocab_size
]
)
if has_pos_emb:
model.position_embeddings.weight.copy_(
model_pt.position_embeddings.weight[
:, rank * partition_dim : (rank + 1) * partition_dim
]
)
out = model(input_ids, combine_batch_seqlen_dim=True)
out_pt = rearrange(model_pt(input_ids), "b s d -> (b s) d")
partition_batch_dim = batch_size * seqlen // world_size
assert torch.allclose(
out,
out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel
else out_pt,
rtol=rtol,
atol=atol,
)
g = torch.randn_like(out_pt)
out_pt.backward(g)
out.backward(
g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g
)
parallel_state.destroy_model_parallel()
assert torch.allclose(
model.word_embeddings.weight.grad,
model_pt.word_embeddings.weight.grad[
rank * partition_vocab_size : (rank + 1) * partition_vocab_size
],
rtol=rtol,
atol=atol,
)
if has_pos_emb:
assert torch.allclose(
model.position_embeddings.weight.grad,
model_pt.position_embeddings.weight.grad[
:, rank * partition_dim : (rank + 1) * partition_dim
],
rtol=rtol,
atol=atol,
)
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_mha_parallel.py
import math
import pytest
import torch
import torch.nn.functional as F
from apex.transformer import parallel_state, tensor_parallel
from einops import rearrange
from flash_attn.modules.mha import MHA, ParallelMHA
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize("head_dim", [64, 128])
# @pytest.mark.parametrize('head_dim', [64])
@pytest.mark.parametrize("embed_dim", [1024, 4096])
# @pytest.mark.parametrize('embed_dim', [1024])
def test_mha_parallel(embed_dim, head_dim, sequence_parallel, world_size, dtype):
assert embed_dim % head_dim == 0
num_heads = embed_dim // head_dim
assert num_heads % world_size == 0
rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
# set seed
torch.random.manual_seed(0)
batch_size = 2
seqlen = 1024
assert (batch_size * seqlen) % world_size == 0
x_pt = torch.randn(
batch_size * seqlen, embed_dim, device=device, dtype=dtype, requires_grad=True
)
# We need to generate g here so that all processes get the same gradient,
# as rank 0 will have an extra bias that changes the RNG.
# If we don't scale g down (we divide by 32 here), the gradient gets a bit too large.
g = torch.randn_like(x_pt) / 32
if sequence_parallel:
x = (
tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
.detach()
.clone()
.requires_grad_()
)
else:
x = x_pt.detach().clone().requires_grad_()
model_pt = MHA(
embed_dim,
num_heads,
rotary_emb_dim=int(head_dim // 2),
use_flash_attn=True,
device=device,
dtype=dtype,
)
partition_dim = embed_dim // world_size
model = ParallelMHA(
embed_dim,
num_heads,
parallel_state.get_tensor_model_parallel_group(),
rotary_emb_dim=int(head_dim // 2),
use_flash_attn=True,
sequence_parallel=sequence_parallel,
device=device,
dtype=dtype,
)
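# Shard the reference MHA weights by hand: Wqkv is split along the output dimension
# (column-parallel), out_proj along the input dimension (row-parallel); only rank 0
# keeps the out_proj bias.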
with torch.no_grad():
model.Wqkv.weight.copy_(
rearrange(
rearrange(model_pt.Wqkv.weight, "(three o) i -> three o i", three=3)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"three o i -> (three o) i",
)
)
model.Wqkv.bias.copy_(
rearrange(
rearrange(model_pt.Wqkv.bias, "(three o) -> three o", three=3)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"three o -> (three o)",
)
)
model.out_proj.weight.copy_(
model_pt.out_proj.weight[:, rank * partition_dim : (rank + 1) * partition_dim]
)
if rank == 0:
model.out_proj.bias.copy_(model_pt.out_proj.bias)
out = model(x, seqlen=seqlen)
out_pt = rearrange(model_pt(rearrange(x_pt, "(b s) d -> b s d", s=seqlen)), "b s d -> (b s) d")
partition_batch_dim = batch_size * seqlen // world_size
assert torch.allclose(
out,
out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel
else out_pt,
rtol=rtol,
atol=atol,
)
out_pt.backward(g)
out.backward(
g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g
)
parallel_state.destroy_model_parallel()
assert torch.allclose(
x.grad,
x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel
else x_pt.grad,
rtol=rtol,
atol=atol / 100, # magnitude of x.grad is quite small
)
# The error for d_weight and d_bias is quite a bit higher
assert torch.allclose(
model.Wqkv.weight.grad,
rearrange(
rearrange(model_pt.Wqkv.weight.grad, "(three o) i -> three o i", three=3)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"three o i -> (three o) i",
),
rtol=rtol,
atol=atol * 10,
)
assert torch.allclose(
model.Wqkv.bias.grad,
rearrange(
rearrange(model_pt.Wqkv.bias.grad, "(three o) -> three o", three=3)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"three o -> (three o)",
),
rtol=rtol,
atol=atol * 5,
)
assert torch.allclose(
model.out_proj.weight.grad,
model_pt.out_proj.weight.grad[:, rank * partition_dim : (rank + 1) * partition_dim],
rtol=rtol,
atol=atol * 10,
)
if rank == 0:
assert torch.allclose(
model.out_proj.bias.grad, model_pt.out_proj.bias.grad, rtol=rtol, atol=atol * 5
)
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_mlp_parallel.py
import pytest
import torch
import torch.nn.functional as F
from apex.transformer import parallel_state, tensor_parallel
from einops import rearrange
from flash_attn.modules.mlp import GatedMlp, ParallelGatedMlp
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize("activation", [F.silu, F.sigmoid])
# @pytest.mark.parametrize('activation', [F.silu])
@pytest.mark.parametrize("dim", [1024, 4096])
# @pytest.mark.parametrize('dim', [1024])
def test_mlp_parallel(dim, activation, sequence_parallel, world_size, dtype):
rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
# set seed
torch.random.manual_seed(0)
batch_size = 2
seqlen = 1024
assert (batch_size * seqlen) % world_size == 0
x_pt = torch.randn(batch_size * seqlen, dim, device=device, dtype=dtype, requires_grad=True)
# We need to generate g here so that all processes get the same gradient,
# as rank 0 will have an extra bias that changes the RNG.
# If we don't scale g down (we divide by 32 here), the gradient gets a bit too large.
g = torch.randn_like(x_pt) / 32
if sequence_parallel:
x = (
tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
.detach()
.clone()
.requires_grad_()
)
else:
x = x_pt.detach().clone().requires_grad_()
model_pt = GatedMlp(dim, activation=activation, device=device, dtype=dtype)
partition_dim = model_pt.fc1.weight.shape[0] // 2 // world_size
model = ParallelGatedMlp(
dim,
parallel_state.get_tensor_model_parallel_group(),
activation=activation,
sequence_parallel=sequence_parallel,
device=device,
dtype=dtype,
)
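# Shard the reference GatedMlp weights by hand: fc1 (both of its two output halves)
# is split along the output dimension, fc2 along the input dimension; only rank 0
# keeps the fc2 bias.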
with torch.no_grad():
model.fc1.weight.copy_(
rearrange(
rearrange(model_pt.fc1.weight, "(two o) i -> two o i", two=2)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"two o i -> (two o) i",
)
)
model.fc1.bias.copy_(
rearrange(
rearrange(model_pt.fc1.bias, "(two o) -> two o", two=2)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"two o -> (two o)",
)
)
model.fc2.weight.copy_(
model_pt.fc2.weight[:, rank * partition_dim : (rank + 1) * partition_dim]
)
if rank == 0:
model.fc2.bias.copy_(model_pt.fc2.bias)
out = model(x)
out_pt = model_pt(x_pt)
partition_batch_dim = batch_size * seqlen // world_size
assert torch.allclose(
out,
out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel
else out_pt,
rtol=rtol,
atol=atol,
)
out_pt.backward(g)
out.backward(
g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g
)
parallel_state.destroy_model_parallel()
assert torch.allclose(
x.grad,
x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel
else x_pt.grad,
rtol=rtol,
atol=atol,
)
assert torch.allclose(
model.fc1.weight.grad,
rearrange(
rearrange(model_pt.fc1.weight.grad, "(two o) i -> two o i", two=2)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"two o i -> (two o) i",
),
rtol=rtol,
atol=atol,
)
assert torch.allclose(
model.fc1.bias.grad,
rearrange(
rearrange(model_pt.fc1.bias.grad, "(two o) -> two o", two=2)[
:, rank * partition_dim : (rank + 1) * partition_dim
],
"two o -> (two o)",
),
rtol=rtol,
atol=atol,
)
assert torch.allclose(
model.fc2.weight.grad,
model_pt.fc2.weight.grad[:, rank * partition_dim : (rank + 1) * partition_dim],
rtol=rtol,
atol=atol,
)
if rank == 0:
assert torch.allclose(model.fc2.bias.grad, model_pt.fc2.bias.grad, rtol=rtol, atol=atol)
import math
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from flash_attn.ops.layer_norm import (
DropoutAddLayerNorm,
dropout_add_layer_norm,
dropout_add_layer_norm_parallel_residual,
dropout_add_layer_norm_subset,
)
from flash_attn.ops.rms_norm import (
DropoutAddRMSNorm,
dropout_add_rms_norm,
dropout_add_rms_norm_parallel_residual,
dropout_add_rms_norm_subset,
)
try:
from apex.normalization import FusedRMSNorm
from apex.normalization.fused_layer_norm import fused_rms_norm_affine
except ImportError:
FusedRMSNorm, fused_rms_norm_affine = None, None
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@pytest.mark.parametrize("is_rms_norm", [False, True])
@pytest.mark.parametrize("has_colscale", [True, False])
# @pytest.mark.parametrize('has_colscale', [False])
@pytest.mark.parametrize("has_rowscale", [True, False])
# @pytest.mark.parametrize('has_rowscale', [True])
@pytest.mark.parametrize("has_residual", [True, False])
# @pytest.mark.parametrize('has_residual', [False])
@pytest.mark.parametrize("dropout_p", [0.37, 0.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
@pytest.mark.parametrize(
"input_dtype,residual_dtype",
[(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
@pytest.mark.parametrize(
"hidden_size",
[192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],
)
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_training(
hidden_size,
input_dtype,
residual_dtype,
weight_dtype,
dropout_p,
has_residual,
has_rowscale,
has_colscale,
is_rms_norm,
):
if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
pytest.skip() # Not supported
if is_rms_norm and FusedRMSNorm is None:
pytest.skip() # We need Apex's FusedRMSNorm to test
layer_norm_cls = torch.nn.LayerNorm if not is_rms_norm else FusedRMSNorm
our_layer_norm_cls = DropoutAddLayerNorm if not is_rms_norm else DropoutAddRMSNorm
our_layer_norm_func = dropout_add_layer_norm if not is_rms_norm else dropout_add_rms_norm
device = "cuda"
# rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
rtol, atol = (1e-3, 1e-4)
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 512
x0_pt = torch.randn(
batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
)
x0 = x0_pt.detach().clone().requires_grad_()
x0_ref = x0_pt.detach().clone().float().requires_grad_()
if has_colscale:
colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
colscale_pt = colscale.detach().clone().requires_grad_()
colscale_ref = colscale.detach().clone().float().requires_grad_()
else:
colscale = None
if has_residual:
res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
res = res_pt.detach().clone().requires_grad_()
res_ref = res_pt.detach().clone().float().requires_grad_()
else:
res = None
if has_rowscale:
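        # rowscale emulates per-row dropout (drop path): each row survives with probability
        # survival_rate and surviving rows are rescaled by 1 / survival_rate.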
rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype)
survival_rate = 0.87
rowscale = rowscale.bernoulli_(survival_rate) / survival_rate
x0_scaled_pt = x0_pt * rearrange(rowscale, "... -> ... 1")
x0_scaled_ref = x0_ref * rearrange(rowscale, "... -> ... 1")
else:
rowscale = None
x0_scaled_pt = x0_pt
x0_scaled_ref = x0_ref
if has_colscale:
x0_scaled_pt = x0_scaled_pt * colscale_pt
x0_scaled_ref = x0_scaled_ref * colscale_ref
model_pt = layer_norm_cls(hidden_size).to(device=device, dtype=weight_dtype)
torch.nn.init.normal_(model_pt.weight)
if not is_rms_norm:
torch.nn.init.normal_(model_pt.bias)
model_ref = layer_norm_cls(hidden_size).to(device=device, dtype=torch.float32)
model = our_layer_norm_cls(hidden_size, p=dropout_p, device=device, dtype=weight_dtype)
with torch.no_grad():
model.weight.copy_(model_pt.weight)
model_ref.weight.copy_(model_pt.weight)
if not is_rms_norm:
model.bias.copy_(model_pt.bias)
model_ref.bias.copy_(model_pt.bias)
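    # With no residual input, residual_in_fp32 makes the kernel keep its internal residual in fp32,
    # matching the fp32 residual_dtype used by the reference computation here.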
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
out, dmask = our_layer_norm_func(
x0,
res,
model.weight,
model.bias,
model.p,
model.eps,
rowscale=rowscale,
layerscale=colscale,
residual_in_fp32=residual_in_fp32,
return_dropout_mask=True,
)
assert out.dtype == input_dtype
print(f"Actual dropout fraction: {1 - dmask.float().mean().item()}")
if has_residual:
residual_pt = (
(x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + res_pt.float()
).to(dtype=residual_dtype)
residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + res_ref
else:
residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(
dtype=residual_dtype
)
residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p)
out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)
out_ref = model_ref(residual_ref)
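    # The fused kernel's error vs. the fp32 reference should be within a small multiple of the
    # PyTorch baseline's error, plus a tiny absolute slack.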
assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
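    # Divide by batch_size so the upstream gradient does not get too large.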
g = torch.randn_like(out) / batch_size
out_pt.backward(g)
out.backward(g)
out_ref.backward(g)
assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
if has_residual:
assert (res.grad - res_ref.grad).abs().max() <= 4 * (
res_pt.grad - res_ref.grad
).abs().max() + 1e-4
assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 3 * (
model_pt.weight.grad - model_ref.weight.grad
).abs().max() + 3e-5
if not is_rms_norm:
assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (
model_pt.bias.grad - model_ref.bias.grad
).abs().max() + 3e-5
if has_colscale:
assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (
colscale_pt.grad - colscale_ref.grad
).abs().max() + 2e-4
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize(
"input_dtype,residual_dtype",
[(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
@pytest.mark.parametrize("hidden_size", [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weight_dtype):
if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
pytest.skip() # Not supported
device = "cuda"
# rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
rtol, atol = (1e-3, 1e-4)
dropout_p = 0.37
# set seed
torch.random.manual_seed(0)
batch_size = 32
seqlen = 512
x0_pt = torch.randn(
batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
)
x0 = x0_pt.detach().clone().requires_grad_()
x0_ref = x0_pt.detach().clone().float().requires_grad_()
res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
res = res_pt.detach().clone().requires_grad_()
res_ref = res_pt.detach().clone().float().requires_grad_()
model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
torch.nn.init.normal_(model_pt.weight)
torch.nn.init.normal_(model_pt.bias)
model = DropoutAddLayerNorm(hidden_size, p=dropout_p, device=device, dtype=weight_dtype)
model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
with torch.no_grad():
model.weight.copy_(model_pt.weight)
model.bias.copy_(model_pt.bias)
model_ref.weight.copy_(model_pt.weight)
model_ref.bias.copy_(model_pt.bias)
model_pt.eval()
model.eval()
model_ref.eval()
out = model(x0, res)
residual_pt = (x0_pt.float() + res_pt.float()).to(dtype=residual_dtype)
residual_ref = x0_ref + res_ref
out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(input_dtype)
out_ref = model_ref(residual_ref)
assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
@pytest.mark.parametrize("is_rms_norm", [False, True])
@pytest.mark.parametrize("has_colscale", [True, False])
@pytest.mark.parametrize("has_rowscale", [True, False])
@pytest.mark.parametrize("has_residual", [True, False])
@pytest.mark.parametrize("dropout_p", [0.37, 0.0])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize(
"input_dtype,residual_dtype",
[(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize('has_colscale', [True])
# @pytest.mark.parametrize('has_rowscale', [False])
# @pytest.mark.parametrize('has_residual', [True])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
@pytest.mark.parametrize(
"hidden_size",
[192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],
)
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_prenorm_training(
hidden_size,
input_dtype,
residual_dtype,
weight_dtype,
dropout_p,
has_residual,
has_rowscale,
has_colscale,
is_rms_norm,
):
if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
pytest.skip() # Not supported
if is_rms_norm and FusedRMSNorm is None:
pytest.skip() # We need Apex's FusedRMSNorm to test
layer_norm_cls = torch.nn.LayerNorm if not is_rms_norm else FusedRMSNorm
our_layer_norm_cls = DropoutAddLayerNorm if not is_rms_norm else DropoutAddRMSNorm
our_layer_norm_func = dropout_add_layer_norm if not is_rms_norm else dropout_add_rms_norm
device = "cuda"
# rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
rtol, atol = (1e-3, 2e-4)
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 512
x0_pt = torch.randn(
batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
)
x0 = x0_pt.detach().clone().requires_grad_()
x0_ref = x0_pt.detach().clone().float().requires_grad_()
if has_colscale:
colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
colscale_pt = colscale.detach().clone().requires_grad_()
colscale_ref = colscale.detach().clone().float().requires_grad_()
else:
colscale = None
if has_residual:
res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
res = res_pt.detach().clone().requires_grad_()
res_ref = res_pt.detach().clone().float().requires_grad_()
else:
res = None
if has_rowscale:
rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype)
survival_rate = 0.87
rowscale = rowscale.bernoulli_(survival_rate) / survival_rate
x0_scaled_pt = x0_pt * rearrange(rowscale, "... -> ... 1")
x0_scaled_ref = x0_ref * rearrange(rowscale, "... -> ... 1")
else:
rowscale = None
x0_scaled_pt = x0_pt
x0_scaled_ref = x0_ref
if has_colscale:
x0_scaled_pt = x0_scaled_pt * colscale_pt
x0_scaled_ref = x0_scaled_ref * colscale_ref
model_pt = layer_norm_cls(hidden_size).to(device=device, dtype=weight_dtype)
torch.nn.init.normal_(model_pt.weight)
if not is_rms_norm:
torch.nn.init.normal_(model_pt.bias)
model_ref = layer_norm_cls(hidden_size).to(device=device, dtype=torch.float32)
model = our_layer_norm_cls(
hidden_size, prenorm=True, p=dropout_p, device=device, dtype=weight_dtype
)
with torch.no_grad():
model.weight.copy_(model_pt.weight)
model_ref.weight.copy_(model_pt.weight)
if not is_rms_norm:
model.bias.copy_(model_pt.bias)
model_ref.bias.copy_(model_pt.bias)
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
out, residual, dmask = our_layer_norm_func(
x0,
res,
model.weight,
model.bias,
model.p,
model.eps,
rowscale=rowscale,
layerscale=colscale,
prenorm=True,
residual_in_fp32=residual_in_fp32,
return_dropout_mask=True,
)
print(f"Actual dropout fraction: {1 - dmask.float().mean().item()}")
if has_residual:
residual_pt = (
(x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + res_pt.float()
).to(dtype=residual_dtype)
residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + res_ref
else:
residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(
dtype=residual_dtype
)
residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p)
out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)
out_ref = model_ref(residual_ref)
assert out.dtype == input_dtype
assert residual.dtype == residual_dtype
assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
assert (residual - residual_ref).abs().max() <= 4 * (
residual_pt - residual_ref
).abs().max() + 1e-4
g = torch.randn_like(out) / batch_size
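    # Combine the normalized output with a nonlinearity of the prenorm residual so the backward
    # pass exercises gradients through both outputs.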
(out_pt * F.sigmoid(residual_pt)).backward(g)
(out * F.sigmoid(residual)).backward(g)
(out_ref * F.sigmoid(residual_ref.to(dtype=residual_dtype))).backward(g)
assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
if has_residual:
assert (res.grad - res_ref.grad).abs().max() <= 4 * (
res_pt.grad - res_ref.grad
).abs().max() + 1e-4
assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (
model_pt.weight.grad - model_ref.weight.grad
).abs().max() + 2e-4
if not is_rms_norm:
assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (
model_pt.bias.grad - model_ref.bias.grad
).abs().max() + 2e-4
if has_colscale:
assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (
colscale_pt.grad - colscale_ref.grad
).abs().max() + 2e-4
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize(
"input_dtype,residual_dtype",
[(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
@pytest.mark.parametrize("hidden_size", [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, residual_dtype, weight_dtype):
if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
pytest.skip() # Not supported
device = "cuda"
# rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
rtol, atol = (1e-3, 1e-4)
dropout_p = 0.37
# set seed
torch.random.manual_seed(0)
batch_size = 32
seqlen = 512
x0_pt = torch.randn(
batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
)
x0 = x0_pt.detach().clone().requires_grad_()
x0_ref = x0_pt.detach().clone().float().requires_grad_()
res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
res = res_pt.detach().clone().requires_grad_()
res_ref = res_pt.detach().clone().float().requires_grad_()
model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
torch.nn.init.normal_(model_pt.weight)
torch.nn.init.normal_(model_pt.bias)
model = DropoutAddLayerNorm(
hidden_size, prenorm=True, p=dropout_p, device=device, dtype=weight_dtype
)
model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
with torch.no_grad():
model.weight.copy_(model_pt.weight)
model.bias.copy_(model_pt.bias)
model_ref.weight.copy_(model_pt.weight)
model_ref.bias.copy_(model_pt.bias)
model_pt.eval()
model.eval()
model_ref.eval()
out, residual = model(x0, res)
residual_pt = (x0_pt.float() + res_pt.float()).to(dtype=residual_dtype)
residual_ref = x0_ref + res_ref
out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(input_dtype)
out_ref = model_ref(residual_ref)
assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
assert (residual - residual_ref).abs().max() <= 4 * (
residual_pt - residual_ref
).abs().max() + 1e-4
@pytest.mark.parametrize("has_colscale", [True, False])
@pytest.mark.parametrize("has_residual", [True, False])
@pytest.mark.parametrize("dropout_p", [0.37, 0.0])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize(
"input_dtype,residual_dtype",
[(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize('has_colscale', [True])
# @pytest.mark.parametrize('has_residual', [True])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
@pytest.mark.parametrize(
"hidden_size",
[192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],
)
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_subset_training(
hidden_size, input_dtype, residual_dtype, weight_dtype, dropout_p, has_residual, has_colscale
):
if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
pytest.skip() # Not supported
device = "cuda"
# rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
rtol, atol = (1e-3, 2e-4)
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 512
drop_path_rate = 0.4
drop_path_scale = 1 / (1 - drop_path_rate)
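    # Drop path (stochastic depth): each sequence in the batch is kept with probability
    # 1 - drop_path_rate, and kept rows are rescaled by drop_path_scale to preserve the expectation.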
def generate_droppath_masks(batch_size, seqlen, drop_path_rate, device):
# Do it on CPU so we can get the numrows (with .item()) without GPU-CPU sync
mask_batch = torch.rand(batch_size) < 1 - drop_path_rate
        numrows = mask_batch.sum().item() * seqlen
mask_batch = mask_batch.to(device=device, non_blocking=True)
mask_batch_seqlen = repeat(mask_batch, "b -> (b s)", s=seqlen)
subset = torch.cumsum(mask_batch_seqlen, dim=0, dtype=torch.int32).masked_fill_(
~mask_batch_seqlen, 0
)
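        # `subset` holds the 1-indexed position (via cumsum) of each kept row; dropped rows are 0.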
return mask_batch, numrows, rearrange(subset, "(b s) -> b s", b=batch_size)
x0_mask_batch, x0_numrows, x0_subset = generate_droppath_masks(
batch_size, seqlen, drop_path_rate, device
)
out_mask_batch, out_numrows, out_subset = generate_droppath_masks(
batch_size, seqlen, drop_path_rate, device
)
x0_pt = torch.randn(
batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
)
x0 = x0_pt.detach().clone()[x0_mask_batch].requires_grad_()
x0_ref = x0_pt.detach().clone().float().requires_grad_()
if has_colscale:
colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
colscale_pt = colscale.detach().clone().requires_grad_()
colscale_ref = colscale.detach().clone().float().requires_grad_()
else:
colscale = None
if has_residual:
res_pt = torch.randn_like(x0_pt, dtype=residual_dtype, requires_grad=True)
res = res_pt.detach().clone().requires_grad_()
res_ref = res_pt.detach().clone().float().requires_grad_()
else:
res = None
if has_colscale:
x0_scaled_pt = x0_pt * colscale_pt
x0_scaled_ref = x0_ref * colscale_ref
else:
x0_scaled_pt = x0_pt
x0_scaled_ref = x0_ref
model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
torch.nn.init.normal_(model_pt.weight)
torch.nn.init.normal_(model_pt.bias)
model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
model = DropoutAddLayerNorm(
hidden_size, prenorm=False, p=dropout_p, device=device, dtype=weight_dtype
)
with torch.no_grad():
model.weight.copy_(model_pt.weight)
model.bias.copy_(model_pt.bias)
model_ref.weight.copy_(model_pt.weight)
model_ref.bias.copy_(model_pt.bias)
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
out, dmask = dropout_add_layer_norm_subset(
x0,
res,
model.weight,
model.bias,
model.p,
model.eps,
layerscale=colscale,
x0_subset=x0_subset,
out_subset=out_subset,
rowscale_const=drop_path_scale,
out_numrows=out_numrows,
prenorm=False,
residual_in_fp32=residual_in_fp32,
return_dropout_mask=True,
)
print(f"Actual dropout fraction: {1 - dmask.float().mean().item()}")
x0_scaled_pt = (
x0_scaled_pt.masked_fill(repeat(~x0_mask_batch, "b -> b s d", s=seqlen, d=hidden_size), 0)
* drop_path_scale
)
x0_scaled_ref = (
x0_scaled_ref.masked_fill(repeat(~x0_mask_batch, "b -> b s d", s=seqlen, d=hidden_size), 0)
* drop_path_scale
)
dmask_expanded = torch.zeros_like(x0_pt, dtype=torch.uint8)
dmask_expanded[x0_mask_batch] = dmask
if has_residual:
residual_pt = (
(x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + res_pt.float()
).to(dtype=residual_dtype)
residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + res_ref
else:
residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)).to(
dtype=residual_dtype
)
residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p)
out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)[out_mask_batch]
out_ref = model_ref(residual_ref)[out_mask_batch]
assert out.dtype == input_dtype
assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
g = torch.randn_like(out) / batch_size
out_pt.backward(g)
out.backward(g)
out_ref.backward(g)
assert (x0.grad - x0_ref.grad[x0_mask_batch]).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad)[
x0_mask_batch
].abs().max() + 1e-4
if has_residual:
assert (res.grad - res_ref.grad).abs().max() <= 4 * (
res_pt.grad - res_ref.grad
).abs().max() + 1e-4
assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (
model_pt.weight.grad - model_ref.weight.grad
).abs().max() + 2e-4
assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (
model_pt.bias.grad - model_ref.bias.grad
).abs().max() + 2e-4
if has_colscale:
assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (
colscale_pt.grad - colscale_ref.grad
).abs().max() + 2e-4
@pytest.mark.parametrize("has_colscale", [True, False])
@pytest.mark.parametrize("has_residual", [True, False])
@pytest.mark.parametrize("dropout_p", [0.37, 0.0])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize(
"input_dtype,residual_dtype",
[(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize('has_colscale', [True])
# @pytest.mark.parametrize('has_residual', [True])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
@pytest.mark.parametrize(
"hidden_size",
[192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],
)
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_subset_prenorm_training(
hidden_size, input_dtype, residual_dtype, weight_dtype, dropout_p, has_residual, has_colscale
):
if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
pytest.skip() # Not supported
device = "cuda"
# rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
rtol, atol = (1e-3, 2e-4)
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 512
drop_path_rate = 0.4
drop_path_scale = 1 / (1 - drop_path_rate)
def generate_droppath_masks(batch_size, seqlen, drop_path_rate, device):
# Do it on CPU so we can get the numrows (with .item()) without GPU-CPU sync
mask_batch = torch.rand(batch_size) < 1 - drop_path_rate
        numrows = mask_batch.sum().item() * seqlen
mask_batch = mask_batch.to(device=device, non_blocking=True)
mask_batch_seqlen = repeat(mask_batch, "b -> (b s)", s=seqlen)
subset = torch.cumsum(mask_batch_seqlen, dim=0, dtype=torch.int32).masked_fill_(
~mask_batch_seqlen, 0
)
return mask_batch, numrows, rearrange(subset, "(b s) -> b s", b=batch_size)
x0_mask_batch, x0_numrows, x0_subset = generate_droppath_masks(
batch_size, seqlen, drop_path_rate, device
)
out_mask_batch, out_numrows, out_subset = generate_droppath_masks(
batch_size, seqlen, drop_path_rate, device
)
x0_pt = torch.randn(
batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
)
x0 = x0_pt.detach().clone()[x0_mask_batch].requires_grad_()
x0_ref = x0_pt.detach().clone().float().requires_grad_()
if has_colscale:
colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
colscale_pt = colscale.detach().clone().requires_grad_()
colscale_ref = colscale.detach().clone().float().requires_grad_()
else:
colscale = None
if has_residual:
res_pt = torch.randn_like(x0_pt, dtype=residual_dtype, requires_grad=True)
res = res_pt.detach().clone().requires_grad_()
res_ref = res_pt.detach().clone().float().requires_grad_()
else:
res = None
if has_colscale:
x0_scaled_pt = x0_pt * colscale_pt
x0_scaled_ref = x0_ref * colscale_ref
else:
x0_scaled_pt = x0_pt
x0_scaled_ref = x0_ref
model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
torch.nn.init.normal_(model_pt.weight)
torch.nn.init.normal_(model_pt.bias)
model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
model = DropoutAddLayerNorm(
hidden_size, prenorm=True, p=dropout_p, device=device, dtype=weight_dtype
)
with torch.no_grad():
model.weight.copy_(model_pt.weight)
model.bias.copy_(model_pt.bias)
model_ref.weight.copy_(model_pt.weight)
model_ref.bias.copy_(model_pt.bias)
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
out, residual, dmask = dropout_add_layer_norm_subset(
x0,
res,
model.weight,
model.bias,
model.p,
model.eps,
layerscale=colscale,
x0_subset=x0_subset,
out_subset=out_subset,
rowscale_const=drop_path_scale,
out_numrows=out_numrows,
prenorm=True,
residual_in_fp32=residual_in_fp32,
return_dropout_mask=True,
)
print(f"Actual dropout fraction: {1 - dmask.float().mean().item()}")
x0_scaled_pt = (
x0_scaled_pt.masked_fill(repeat(~x0_mask_batch, "b -> b s d", s=seqlen, d=hidden_size), 0)
* drop_path_scale
)
x0_scaled_ref = (
x0_scaled_ref.masked_fill(repeat(~x0_mask_batch, "b -> b s d", s=seqlen, d=hidden_size), 0)
* drop_path_scale
)
dmask_expanded = torch.zeros_like(x0_pt, dtype=torch.uint8)
dmask_expanded[x0_mask_batch] = dmask
if has_residual:
residual_pt = (
(x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + res_pt.float()
).to(dtype=residual_dtype)
residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + res_ref
else:
residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)).to(
dtype=residual_dtype
)
residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p)
out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)[out_mask_batch]
out_ref = model_ref(residual_ref)[out_mask_batch]
assert out.dtype == input_dtype
assert residual.dtype == residual_dtype
assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
assert (residual - residual_ref).abs().max() <= 4 * (
residual_pt - residual_ref
).abs().max() + 1e-4
g = torch.randn_like(out) / batch_size
(out_pt * F.sigmoid(residual_pt[out_mask_batch]) + residual_pt.mean(0, keepdim=True)).backward(
g
)
(out * F.sigmoid(residual[out_mask_batch]) + residual.mean(0, keepdim=True)).backward(g)
(
out_ref * F.sigmoid(residual_ref[out_mask_batch].to(dtype=residual_dtype))
+ residual_ref.mean(0, keepdim=True)
).backward(g)
assert (x0.grad - x0_ref.grad[x0_mask_batch]).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad)[
x0_mask_batch
].abs().max() + 1e-4
if has_residual:
assert (res.grad - res_ref.grad).abs().max() <= 4 * (
res_pt.grad - res_ref.grad
).abs().max() + 1e-4
assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (
model_pt.weight.grad - model_ref.weight.grad
).abs().max() + 2e-4
assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (
model_pt.bias.grad - model_ref.bias.grad
).abs().max() + 2e-4
if has_colscale:
assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (
colscale_pt.grad - colscale_ref.grad
).abs().max() + 2e-4
@pytest.mark.parametrize("is_rms_norm", [False, True])
# @pytest.mark.parametrize('is_rms_norm', [False])
@pytest.mark.parametrize("tied_norm", [False, True])
# @pytest.mark.parametrize('tied_norm', [False])
@pytest.mark.parametrize("has_residual", [True, False])
# @pytest.mark.parametrize('has_residual', [False])
@pytest.mark.parametrize("has_x1", [True, False])
# @pytest.mark.parametrize('has_x1', [True])
@pytest.mark.parametrize("dropout_p", [0.37, 0.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
# @pytest.mark.parametrize('weight_dtype', [torch.float16])
@pytest.mark.parametrize(
"input_dtype,residual_dtype",
[(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
@pytest.mark.parametrize(
"hidden_size",
[192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],
)
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_parallel_residual_training(
hidden_size,
input_dtype,
residual_dtype,
weight_dtype,
dropout_p,
has_x1,
has_residual,
tied_norm,
is_rms_norm,
):
if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
pytest.skip() # Not supported
if is_rms_norm and fused_rms_norm_affine is None:
pytest.skip() # We need Apex's FusedRMSNorm to test
our_layer_norm_func = (
dropout_add_layer_norm_parallel_residual
if not is_rms_norm
else dropout_add_rms_norm_parallel_residual
)
device = "cuda"
# rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
rtol, atol = (1e-3, 1e-4)
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 512
x0_pt = torch.randn(
batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
)
x0 = x0_pt.detach().clone().requires_grad_()
x0_ref = x0_pt.detach().clone().float().requires_grad_()
if has_x1:
x1_pt = torch.randn(
batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
)
x1 = x1_pt.detach().clone().requires_grad_()
x1_ref = x1_pt.detach().clone().float().requires_grad_()
else:
x1 = None
if has_residual:
res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
res = res_pt.detach().clone().requires_grad_()
res_ref = res_pt.detach().clone().float().requires_grad_()
else:
res = None
weight0 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
bias0 = (
torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
if not is_rms_norm
else None
)
weight0_pt = weight0.detach().clone().requires_grad_()
weight0_ref = weight0.detach().clone().float().requires_grad_()
bias0_pt = bias0.detach().clone().requires_grad_() if bias0 is not None else None
bias0_ref = bias0.detach().clone().float().requires_grad_() if bias0 is not None else None
if not tied_norm:
weight1 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
bias1 = (
torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
if not is_rms_norm
else None
)
weight1_pt = weight1.detach().clone().requires_grad_()
weight1_ref = weight1.detach().clone().float().requires_grad_()
bias1_pt = bias1.detach().clone().requires_grad_() if bias1 is not None else None
bias1_ref = bias1.detach().clone().float().requires_grad_() if bias1 is not None else None
else:
weight1, bias1 = None, None
epsilon = 1e-5
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
out0, out1, dmask0, dmask1 = our_layer_norm_func(
x0,
x1,
res,
weight0,
bias0,
weight1,
bias1,
dropout_p,
epsilon,
residual_in_fp32=residual_in_fp32,
return_dropout_mask=True,
)
assert out0.dtype == input_dtype
if not tied_norm:
assert out1.dtype == input_dtype
print(f"Actual dropout fraction: {1 - dmask0.float().mean().item()}")
if has_residual:
if has_x1:
residual_pt = (
(x0_pt.float() * dmask0.float()) / (1 - dropout_p)
+ (x1_pt.float() * dmask1.float()) / (1 - dropout_p)
+ res_pt.float()
).to(dtype=residual_dtype)
residual_ref = (
(x0_ref * dmask0.float()) / (1 - dropout_p)
+ (x1_ref * dmask1.float()) / (1 - dropout_p)
) + res_ref
else:
residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p) + res_pt.float()).to(
dtype=residual_dtype
)
residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + res_ref
else:
if has_x1:
residual_pt = (
(x0_pt.float() * dmask0.float()) / (1 - dropout_p)
+ (x1_pt.float() * dmask1.float()) / (1 - dropout_p)
).to(dtype=residual_dtype)
residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + (
x1_ref * dmask1.float()
) / (1 - dropout_p)
else:
residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)).to(
dtype=residual_dtype
)
residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p)
if not is_rms_norm:
out0_pt = F.layer_norm(
residual_pt.to(dtype=weight_dtype), (hidden_size,), weight0_pt, bias0_pt, eps=epsilon
).to(dtype=input_dtype)
out0_ref = F.layer_norm(residual_ref, (hidden_size,), weight0_ref, bias0_ref, eps=epsilon)
if not tied_norm:
out1_pt = F.layer_norm(
residual_pt.to(dtype=weight_dtype),
(hidden_size,),
weight1_pt,
bias1_pt,
eps=epsilon,
).to(dtype=input_dtype)
out1_ref = F.layer_norm(
residual_ref, (hidden_size,), weight1_ref, bias1_ref, eps=epsilon
)
else:
out0_pt = fused_rms_norm_affine(
residual_pt.to(dtype=weight_dtype), weight0_pt, (hidden_size,), eps=epsilon
).to(dtype=input_dtype)
out0_ref = fused_rms_norm_affine(residual_ref, weight0_ref, (hidden_size,), eps=epsilon)
if not tied_norm:
out1_pt = fused_rms_norm_affine(
residual_pt.to(dtype=weight_dtype), weight1_pt, (hidden_size,), eps=epsilon
).to(dtype=input_dtype)
out1_ref = fused_rms_norm_affine(residual_ref, weight1_ref, (hidden_size,), eps=epsilon)
assert (out0 - out0_ref).abs().max() <= 4 * (out0_pt - out0_ref).abs().max() + 1e-4
if not tied_norm:
assert (out1 - out1_ref).abs().max() <= 4 * (out1_pt - out1_ref).abs().max() + 1e-4
g0 = torch.randn_like(out0) / batch_size
if tied_norm:
out0.backward(g0)
out0_pt.backward(g0)
out0_ref.backward(g0)
else:
g1 = torch.randn_like(out1) / batch_size
(out0 * g0 + out1 * g1).sum().backward()
(out0_pt * g0 + out1_pt * g1).sum().backward()
(out0_ref * g0 + out1_ref * g1).sum().backward()
assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
if has_x1:
assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (
x1_pt.grad - x1_ref.grad
).abs().max() + 1e-4
if has_residual:
assert (res.grad - res_ref.grad).abs().max() <= 4 * (
res_pt.grad - res_ref.grad
).abs().max() + 1e-4
assert (weight0.grad - weight0_ref.grad).abs().max() <= 3 * (
weight0_pt.grad - weight0_ref.grad
).abs().max() + 3e-5
if not is_rms_norm:
assert (bias0.grad - bias0_ref.grad).abs().max() <= 2 * (
bias0_pt.grad - bias0_ref.grad
).abs().max() + 3e-5
if not tied_norm:
assert (weight1.grad - weight1_ref.grad).abs().max() <= 3 * (
weight1_pt.grad - weight1_ref.grad
).abs().max() + 3e-5
if not is_rms_norm:
assert (bias1.grad - bias1_ref.grad).abs().max() <= 2 * (
bias1_pt.grad - bias1_ref.grad
).abs().max() + 3e-5
@pytest.mark.parametrize("is_rms_norm", [False, True])
# @pytest.mark.parametrize('is_rms_norm', [False])
@pytest.mark.parametrize("tied_norm", [False, True])
# @pytest.mark.parametrize('tied_norm', [False])
@pytest.mark.parametrize("has_residual", [True, False])
# @pytest.mark.parametrize('has_residual', [False])
@pytest.mark.parametrize("has_x1", [True, False])
# @pytest.mark.parametrize('has_x1', [True])
@pytest.mark.parametrize("dropout_p", [0.37, 0.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
# @pytest.mark.parametrize('weight_dtype', [torch.float16])
@pytest.mark.parametrize(
"input_dtype,residual_dtype",
[(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
@pytest.mark.parametrize(
"hidden_size",
[192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],
)
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_parallel_residual_prenorm_training(
hidden_size,
input_dtype,
residual_dtype,
weight_dtype,
dropout_p,
has_x1,
has_residual,
tied_norm,
is_rms_norm,
):
if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
pytest.skip() # Not supported
if is_rms_norm and fused_rms_norm_affine is None:
pytest.skip() # We need Apex's FusedRMSNorm to test
our_layer_norm_func = (
dropout_add_layer_norm_parallel_residual
if not is_rms_norm
else dropout_add_rms_norm_parallel_residual
)
device = "cuda"
# rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
rtol, atol = (1e-3, 1e-4)
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 512
x0_pt = torch.randn(
batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
)
x0 = x0_pt.detach().clone().requires_grad_()
x0_ref = x0_pt.detach().clone().float().requires_grad_()
if has_x1:
x1_pt = torch.randn(
batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
)
x1 = x1_pt.detach().clone().requires_grad_()
x1_ref = x1_pt.detach().clone().float().requires_grad_()
else:
x1 = None
if has_residual:
res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
res = res_pt.detach().clone().requires_grad_()
res_ref = res_pt.detach().clone().float().requires_grad_()
else:
res = None
weight0 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
bias0 = (
torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
if not is_rms_norm
else None
)
weight0_pt = weight0.detach().clone().requires_grad_()
weight0_ref = weight0.detach().clone().float().requires_grad_()
bias0_pt = bias0.detach().clone().requires_grad_() if bias0 is not None else None
bias0_ref = bias0.detach().clone().float().requires_grad_() if bias0 is not None else None
if not tied_norm:
weight1 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
bias1 = (
torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
if not is_rms_norm
else None
)
weight1_pt = weight1.detach().clone().requires_grad_()
weight1_ref = weight1.detach().clone().float().requires_grad_()
bias1_pt = bias1.detach().clone().requires_grad_() if bias1 is not None else None
bias1_ref = bias1.detach().clone().float().requires_grad_() if bias1 is not None else None
else:
weight1, bias1 = None, None
epsilon = 1e-5
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
out0, out1, residual, dmask0, dmask1 = our_layer_norm_func(
x0,
x1,
res,
weight0,
bias0,
weight1,
bias1,
dropout_p,
epsilon,
prenorm=True,
residual_in_fp32=residual_in_fp32,
return_dropout_mask=True,
)
assert out0.dtype == input_dtype
if not tied_norm:
assert out1.dtype == input_dtype
print(f"Actual dropout fraction: {1 - dmask0.float().mean().item()}")
if has_residual:
if has_x1:
residual_pt = (
(x0_pt.float() * dmask0.float()) / (1 - dropout_p)
+ (x1_pt.float() * dmask1.float()) / (1 - dropout_p)
+ res_pt.float()
).to(dtype=residual_dtype)
residual_ref = (
(x0_ref * dmask0.float()) / (1 - dropout_p)
+ (x1_ref * dmask1.float()) / (1 - dropout_p)
) + res_ref
else:
residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p) + res_pt.float()).to(
dtype=residual_dtype
)
residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + res_ref
else:
if has_x1:
residual_pt = (
(x0_pt.float() * dmask0.float()) / (1 - dropout_p)
+ (x1_pt.float() * dmask1.float()) / (1 - dropout_p)
).to(dtype=residual_dtype)
residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + (
x1_ref * dmask1.float()
) / (1 - dropout_p)
else:
residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)).to(
dtype=residual_dtype
)
residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p)
if not is_rms_norm:
out0_pt = F.layer_norm(
residual_pt.to(dtype=weight_dtype), (hidden_size,), weight0_pt, bias0_pt, eps=epsilon
).to(dtype=input_dtype)
out0_ref = F.layer_norm(residual_ref, (hidden_size,), weight0_ref, bias0_ref, eps=epsilon)
if not tied_norm:
out1_pt = F.layer_norm(
residual_pt.to(dtype=weight_dtype),
(hidden_size,),
weight1_pt,
bias1_pt,
eps=epsilon,
).to(dtype=input_dtype)
out1_ref = F.layer_norm(
residual_ref, (hidden_size,), weight1_ref, bias1_ref, eps=epsilon
)
else:
out0_pt = fused_rms_norm_affine(
residual_pt.to(dtype=weight_dtype), weight0_pt, (hidden_size,), eps=epsilon
).to(dtype=input_dtype)
out0_ref = fused_rms_norm_affine(residual_ref, weight0_ref, (hidden_size,), eps=epsilon)
if not tied_norm:
out1_pt = fused_rms_norm_affine(
residual_pt.to(dtype=weight_dtype), weight1_pt, (hidden_size,), eps=epsilon
).to(dtype=input_dtype)
out1_ref = fused_rms_norm_affine(residual_ref, weight1_ref, (hidden_size,), eps=epsilon)
assert (out0 - out0_ref).abs().max() <= 4 * (out0_pt - out0_ref).abs().max() + 1e-4
if not tied_norm:
assert (out1 - out1_ref).abs().max() <= 4 * (out1_pt - out1_ref).abs().max() + 1e-4
assert (residual - residual_ref).abs().max() <= 4 * (
residual_pt - residual_ref
).abs().max() + 1e-4
g0 = torch.randn_like(out0) / batch_size
if tied_norm:
(out0 * F.sigmoid(residual)).backward(g0)
(out0_pt * F.sigmoid(residual_pt)).backward(g0)
(out0_ref * F.sigmoid(residual_ref)).backward(g0)
else:
g1 = torch.randn_like(out1) / batch_size
(out0 * F.sigmoid(residual) * g0 + out1 * g1).sum().backward()
(out0_pt * F.sigmoid(residual_pt) * g0 + out1_pt * g1).sum().backward()
(out0_ref * F.sigmoid(residual_ref) * g0 + out1_ref * g1).sum().backward()
assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
if has_x1:
assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (
x1_pt.grad - x1_ref.grad
).abs().max() + 1e-4
if has_residual:
assert (res.grad - res_ref.grad).abs().max() <= 4 * (
res_pt.grad - res_ref.grad
).abs().max() + 1e-4
assert (weight0.grad - weight0_ref.grad).abs().max() <= 3 * (
weight0_pt.grad - weight0_ref.grad
).abs().max() + 3e-5
if not is_rms_norm:
assert (bias0.grad - bias0_ref.grad).abs().max() <= 2 * (
bias0_pt.grad - bias0_ref.grad
).abs().max() + 3e-5
if not tied_norm:
assert (weight1.grad - weight1_ref.grad).abs().max() <= 3 * (
weight1_pt.grad - weight1_ref.grad
).abs().max() + 3e-5
if not is_rms_norm:
assert (bias1.grad - bias1_ref.grad).abs().max() <= 2 * (
bias1_pt.grad - bias1_ref.grad
).abs().max() + 3e-5
def test_dropout_layer_norm_randomness():
hidden_size = 256
dtype = torch.float32
dropout_p = 0.1
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 512
x0 = torch.randn(
batch_size, seqlen, hidden_size, device=device, dtype=dtype, requires_grad=True
)
res = torch.randn_like(x0, dtype=dtype, requires_grad=True)
model = DropoutAddLayerNorm(hidden_size, p=dropout_p, device=device, dtype=dtype)
torch.random.manual_seed(42)
_, dmask0 = dropout_add_layer_norm(
x0, res, model.weight, model.bias, model.p, model.eps, return_dropout_mask=True
)
# Subsequent call should have a different dropout mask
_, dmask1 = dropout_add_layer_norm(
x0, res, model.weight, model.bias, model.p, model.eps, return_dropout_mask=True
)
torch.random.manual_seed(42)
# Resetting the seed, should get the same dropout mask
_, dmask2 = dropout_add_layer_norm(
x0, res, model.weight, model.bias, model.p, model.eps, return_dropout_mask=True
)
assert not torch.equal(dmask0, dmask1)
assert torch.equal(dmask0, dmask2)
import math
from functools import partial
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn.ops.fused_dense import FusedDense, FusedMLP
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("return_residual", [False, True])
@pytest.mark.parametrize("has_bias", [True, False])
@pytest.mark.parametrize("out_features", [1024, 4096])
@pytest.mark.parametrize("in_features", [1024, 4096])
def test_fused_linear_bias(in_features, out_features, has_bias, return_residual, dtype):
device = "cuda"
rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 512
x_pt = torch.randn(
batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True
)
x = x_pt.detach().clone().requires_grad_()
model_pt = torch.nn.Linear(in_features, out_features, bias=has_bias, device=device, dtype=dtype)
model = FusedDense(
in_features,
out_features,
bias=has_bias,
return_residual=return_residual,
device=device,
dtype=dtype,
)
with torch.no_grad():
model.weight.copy_(model_pt.weight)
if has_bias:
model.bias.copy_(model_pt.bias)
out_pt = model_pt(x_pt)
if not return_residual:
out = model(x)
else:
out, x_copy = model(x)
x_copy = (
x_copy[..., :out_features]
if out_features < in_features
else F.pad(x_copy, (0, out_features - in_features))
)
x_pt_copy = (
x_pt[..., :out_features]
if out_features < in_features
else F.pad(x_pt, (0, out_features - in_features))
)
# Just add some random function of the residual
out_pt = out_pt + F.gelu(x_pt_copy)
out = out + F.gelu(x_copy)
# with torch.no_grad():
# out_fl = F.linear(x_pt.float(), model.weight.float(), model.bias.float()).half()
assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)
# If we don't divide by batch_size, the gradient gets a bit too large.
g = torch.randn_like(out) / 32
out_pt.backward(g)
out.backward(g)
assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
# The error for d_weight and d_bias is quite a bit higher
assert torch.allclose(model.weight.grad, model_pt.weight.grad, rtol=rtol, atol=atol * 10)
if has_bias:
assert torch.allclose(model.bias.grad, model_pt.bias.grad, rtol=rtol, atol=atol * 5)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("heuristic", ["auto", -1])
# @pytest.mark.parametrize('heuristic', ['auto'])
@pytest.mark.parametrize("checkpoint_lvl", [0, 1, 2])
# @pytest.mark.parametrize('checkpoint_lvl', [1])
@pytest.mark.parametrize("return_residual", [False, True])
# @pytest.mark.parametrize('return_residual', [False])
@pytest.mark.parametrize("has_bias2", [True, False])
@pytest.mark.parametrize("has_bias1", [True, False])
# @pytest.mark.parametrize('has_bias2', [True])
# @pytest.mark.parametrize('has_bias1', [True])
@pytest.mark.parametrize("activation", ["gelu_approx", "relu"])
# @pytest.mark.parametrize('activation', ['relu'])
@pytest.mark.parametrize("out_features", [1024, 4096])
@pytest.mark.parametrize("in_features", [1024, 4096])
# @pytest.mark.parametrize('out_features', [4096])
# @pytest.mark.parametrize('in_features', [1024])
def test_fused_mlp(
in_features,
out_features,
activation,
has_bias1,
has_bias2,
return_residual,
checkpoint_lvl,
heuristic,
dtype,
):
device = "cuda"
rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 512
x_pt = torch.randn(
batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True
)
x = x_pt.detach().clone().requires_grad_()
model_pt_fc1 = torch.nn.Linear(
in_features, out_features, bias=has_bias1, device=device, dtype=dtype
)
model_pt_fc2 = torch.nn.Linear(
out_features, in_features, bias=has_bias2, device=device, dtype=dtype
)
model = FusedMLP(
in_features,
out_features,
in_features,
activation=activation,
bias1=has_bias1,
bias2=has_bias2,
return_residual=return_residual,
checkpoint_lvl=checkpoint_lvl,
heuristic=heuristic,
device=device,
dtype=dtype,
)
with torch.no_grad():
model.fc1.weight.copy_(model_pt_fc1.weight)
if has_bias1:
model.fc1.bias.copy_(model_pt_fc1.bias)
model.fc2.weight.copy_(model_pt_fc2.weight)
if has_bias2:
model.fc2.bias.copy_(model_pt_fc2.bias)
activation_fn = (
partial(F.gelu, approximate="tanh")
if activation == "gelu_approx"
else partial(F.relu, inplace=True)
)
out_pt = model_pt_fc2(activation_fn(model_pt_fc1(x_pt)))
if not return_residual:
out = model(x)
else:
out, x_copy = model(x)
# Just add some random function of the residual
out_pt = out_pt + F.gelu(x_pt)
out = out + F.gelu(x_copy)
assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)
# If we don't divide by batch_size, the gradient gets a bit too large.
g = torch.randn_like(out) / 32
out_pt.backward(g)
out.backward(g)
# The error for relu is higher still
if activation == "relu":
atol = 1e-1 if dtype == torch.bfloat16 else 5e-2
assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
# The error for d_weight and d_bias is quite a bit higher
assert torch.allclose(
model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10
)
if has_bias1:
assert torch.allclose(model.fc1.bias.grad, model_pt_fc1.bias.grad, rtol=rtol, atol=atol * 5)
assert torch.allclose(
model.fc2.weight.grad, model_pt_fc2.weight.grad, rtol=rtol, atol=atol * 10
)
if has_bias2:
assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/ops/test_fused_dense_parallel.py
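# Note: the parametrized world_size values must not exceed --nproc_per_node, since each test
# asserts world_size <= torch.distributed.get_world_size().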
import math
import pytest
import torch
import torch.nn.functional as F
from apex.transformer import parallel_state, tensor_parallel
from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, FusedMLP, ParallelFusedMLP
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize("has_bias", [True, False])
# @pytest.mark.parametrize('has_bias', [False])
@pytest.mark.parametrize("out_features", [1024])
@pytest.mark.parametrize("in_features", [4096])
def test_fused_linear_bias(
in_features, out_features, has_bias, sequence_parallel, world_size, dtype
):
assert out_features % world_size == 0
rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
# set seed
torch.random.manual_seed(0)
batch_size = 2
seqlen = 512
assert batch_size * seqlen % world_size == 0
x_pt = torch.randn(
batch_size * seqlen, in_features, device=device, dtype=dtype, requires_grad=True
)
if sequence_parallel:
x = (
tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
.detach()
.clone()
.requires_grad_()
)
else:
x = x_pt.detach().clone().requires_grad_()
model_pt = torch.nn.Linear(in_features, out_features, bias=has_bias, device=device, dtype=dtype)
partition_out_features = out_features // world_size
model = ColumnParallelLinear(
in_features,
out_features,
parallel_state.get_tensor_model_parallel_group(),
bias=has_bias,
sequence_parallel=sequence_parallel,
device=device,
dtype=dtype,
)
with torch.no_grad():
model.weight.copy_(
model_pt.weight[rank * partition_out_features : (rank + 1) * partition_out_features]
)
if has_bias:
model.bias.copy_(
model_pt.bias[rank * partition_out_features : (rank + 1) * partition_out_features]
)
out = model(x)
out_pt = model_pt(x_pt)
assert torch.allclose(
out,
out_pt[:, rank * partition_out_features : (rank + 1) * partition_out_features],
rtol=rtol,
atol=atol,
)
# If we don't divide by batch_size, the gradient gets a bit too large.
g = torch.randn_like(out_pt) / 32
out_pt.backward(g)
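    # Each rank backpropagates only the slice of the upstream gradient corresponding to its
    # column-parallel output partition.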
out.backward(g[:, rank * partition_out_features : (rank + 1) * partition_out_features])
parallel_state.destroy_model_parallel()
partition_batch_dim = batch_size * seqlen // world_size
assert torch.allclose(
x.grad,
x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel
else x_pt.grad,
rtol=rtol,
atol=atol,
)
# The error for d_weight and d_bias is quite a bit higher
assert torch.allclose(
model.weight.grad,
model_pt.weight.grad[rank * partition_out_features : (rank + 1) * partition_out_features],
rtol=rtol,
atol=atol * 10,
)
if has_bias:
assert torch.allclose(
model.bias.grad,
model_pt.bias.grad[rank * partition_out_features : (rank + 1) * partition_out_features],
rtol=rtol,
atol=atol * 5,
)
@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize("has_bias2", [True, False])
# @pytest.mark.parametrize('has_bias2', [True])
@pytest.mark.parametrize("out_features", [4096])
@pytest.mark.parametrize("in_features", [1024])
def test_fused_mlp(in_features, out_features, has_bias2, sequence_parallel, world_size, dtype):
assert out_features % world_size == 0
rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
# set seed
torch.random.manual_seed(0)
batch_size = 2
seqlen = 512
assert batch_size * seqlen % world_size == 0
x_pt = torch.randn(
batch_size * seqlen, in_features, device=device, dtype=dtype, requires_grad=True
)
# We need to generate g here so that all processes get the same gradient,
# as rank 0 will have an extra bias that changes the RNG.
# If we don't divide by batch_size, the gradient gets a bit too large.
g = torch.randn_like(x_pt) / 32
if sequence_parallel:
x = (
tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
.detach()
.clone()
.requires_grad_()
)
else:
x = x_pt.detach().clone().requires_grad_()
model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
model_pt_fc2 = torch.nn.Linear(
out_features, in_features, bias=has_bias2, device=device, dtype=dtype
)
partition_out_features = out_features // world_size
partition_in_features = in_features // world_size
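    # Only rank 0 carries the fc2 bias (bias2=has_bias2 and rank == 0), so after the
    # tensor-parallel reduction the bias is added once rather than world_size times.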
model = ParallelFusedMLP(
in_features,
out_features,
in_features,
process_group=parallel_state.get_tensor_model_parallel_group(),
bias2=has_bias2 and rank == 0,
sequence_parallel=sequence_parallel,
device=device,
dtype=dtype,
)
with torch.no_grad():
model.fc1.weight.copy_(
model_pt_fc1.weight[rank * partition_out_features : (rank + 1) * partition_out_features]
)
model.fc1.bias.copy_(
model_pt_fc1.bias[rank * partition_out_features : (rank + 1) * partition_out_features]
)
model.fc2.weight.copy_(
model_pt_fc2.weight[
:, rank * partition_out_features : (rank + 1) * partition_out_features
]
)
if has_bias2 and rank == 0:
model.fc2.bias.copy_(model_pt_fc2.bias)
out = model(x)
out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate="tanh"))
partition_batch_dim = batch_size * seqlen // world_size
assert torch.allclose(
out,
out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel
else out_pt,
rtol=rtol,
atol=atol,
)
out_pt.backward(g)
out.backward(
g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g
)
parallel_state.destroy_model_parallel()
assert torch.allclose(
x.grad,
x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
if sequence_parallel
else x_pt.grad,
rtol=rtol,
atol=atol,
)
# The error for d_weight and d_bias is quite a bit higher
assert torch.allclose(
model.fc1.weight.grad,
model_pt_fc1.weight.grad[
rank * partition_out_features : (rank + 1) * partition_out_features
],
rtol=rtol,
atol=atol * 10,
)
assert torch.allclose(
model.fc1.bias.grad,
model_pt_fc1.bias.grad[rank * partition_out_features : (rank + 1) * partition_out_features],
rtol=rtol,
atol=atol * 5,
)
assert torch.allclose(
model.fc2.weight.grad,
model_pt_fc2.weight.grad[
:, rank * partition_out_features : (rank + 1) * partition_out_features
],
rtol=rtol,
atol=atol * 10,
)
if has_bias2 and rank == 0:
assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)
# Copyright (c) 2024, Tri Dao.
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from flash_attn.ops.triton.layer_norm import (
layer_norm_fn,
layer_norm_ref,
rms_norm_ref,
layer_norm_linear_fn,
)
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@pytest.mark.parametrize("has_weight1", [False, True])
# @pytest.mark.parametrize("has_weight1", [True])
@pytest.mark.parametrize("has_x1", [False, True])
# @pytest.mark.parametrize("has_x1", [False])
@pytest.mark.parametrize("has_rowscale", [False, True])
# @pytest.mark.parametrize("has_rowscale", [False])
@pytest.mark.parametrize("dropout_p", [0.0, 0.27])
# @pytest.mark.parametrize("dropout_p", [0.0])
@pytest.mark.parametrize("prenorm", [True, False])
# @pytest.mark.parametrize("prenorm", [False])
@pytest.mark.parametrize("is_rms_norm", [False, True])
# @pytest.mark.parametrize("is_rms_norm", [True])
@pytest.mark.parametrize("has_residual", [True, False])
# @pytest.mark.parametrize("has_residual", [False])
@pytest.mark.parametrize(
"weight_dtype", [torch.float32, torch.float16] + ([torch.bfloat16] if is_sm8x else [])
)
# @pytest.mark.parametrize("weight_dtype", [torch.float32])
@pytest.mark.parametrize(
"input_dtype,residual_dtype",
[(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize("input_dtype,residual_dtype", [(torch.float16, torch.float16)])
@pytest.mark.parametrize("hidden_size", [192, 2048, 2560, 3000, 4096])
# @pytest.mark.parametrize("hidden_size", [256])
def test_layer_norm(
hidden_size,
input_dtype,
residual_dtype,
weight_dtype,
has_residual,
is_rms_norm,
prenorm,
dropout_p,
has_rowscale,
has_x1,
has_weight1,
):
if has_rowscale and has_x1:
pytest.skip("Not supported")
device = "cuda"
if any(x == torch.bfloat16 for x in [input_dtype, residual_dtype, weight_dtype]):
atol = 5e-2
elif any(x == torch.float16 for x in [input_dtype, residual_dtype, weight_dtype]):
atol = 1e-2
else:
atol = 1e-4
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 512
layer_norm_ref_fn = layer_norm_ref if not is_rms_norm else rms_norm_ref
allclose = (
# Sometimes x0_pt.grad is NaN
lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max()
<= 2 * (x_pt[~x_pt.isnan()] - x_ref[~x_pt.isnan()]).abs().max() + atol
or (
            # Sometimes x_pt and x_ref are the same (e.g. bfloat16) so we want to perturb it a bit
            # by multiplying and dividing by 0.3
(x_pt[~x_pt.isnan()] - x_ref[~x_pt.isnan()]).abs().max() == 0.0
and (x - x_ref).abs().max()
<= 2 * (x_pt[~x_pt.isnan()] * 0.3 / 0.3 - x_ref[~x_pt.isnan()]).abs().max() + atol
)
)
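    # Three-way check: x is the Triton kernel output, x_pt the reference run in the original
    # (possibly low-precision) dtypes, and x_ref the reference upcast to fp32.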
x0 = torch.randn(
batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
)
x0_pt = x0.detach().clone().requires_grad_()
x0_ref = x0.detach().clone().requires_grad_()
if has_residual:
res = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
res_pt = res.detach().clone().requires_grad_()
res_ref = res.detach().clone().requires_grad_()
else:
res, res_pt, res_ref = None, None, None
weight = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
if not is_rms_norm:
bias = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
else:
bias = None
weight_pt = weight.detach().clone().requires_grad_()
weight_ref = weight.detach().clone().requires_grad_()
bias_pt = bias.detach().clone().requires_grad_() if bias is not None else None
bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
if has_x1:
x1 = torch.randn_like(x0, dtype=input_dtype, requires_grad=True)
x1_pt = x1.detach().clone().requires_grad_()
x1_ref = x1.detach().clone().requires_grad_()
else:
x1, x1_pt, x1_ref = None, None, None
if has_weight1:
weight1 = torch.randn(
hidden_size, device=device, dtype=weight_dtype, requires_grad=True
)
weight1_pt = weight1.detach().clone().requires_grad_()
weight1_ref = weight1.detach().clone().requires_grad_()
if not is_rms_norm:
bias1 = torch.randn(
hidden_size, device=device, dtype=weight_dtype, requires_grad=True
)
else:
bias1 = None
bias1_pt = bias1.detach().clone().requires_grad_() if bias1 is not None else None
bias1_ref = bias1.detach().clone().requires_grad_() if bias1 is not None else None
else:
weight1, weight1_pt, weight1_ref = None, None, None
bias1, bias1_pt, bias1_ref = None, None, None
rowscale = (
torch.randn(batch_size, seqlen, dtype=input_dtype, device=device)
if has_rowscale
else None
)
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
out, *rest = layer_norm_fn(
x0,
weight,
bias,
residual=res,
x1=x1,
weight1=weight1,
bias1=bias1,
eps=1e-6,
dropout_p=dropout_p,
rowscale=rowscale,
prenorm=prenorm,
residual_in_fp32=residual_in_fp32,
is_rms_norm=is_rms_norm,
return_dropout_mask=True,
)
dropout_mask = rest[-2] if dropout_p > 0.0 else None
dropout_mask1 = rest[-1] if dropout_p > 0.0 and x1 is not None else None
out_pt = layer_norm_ref_fn(
x0_pt,
weight_pt,
bias_pt,
residual=res_pt,
x1=x1_pt,
weight1=weight1_pt,
bias1=bias1_pt,
eps=1e-6,
dropout_p=dropout_p,
rowscale=rowscale,
prenorm=prenorm,
dropout_mask=dropout_mask,
dropout_mask1=dropout_mask1,
)
out_ref = layer_norm_ref_fn(
x0_ref,
weight_ref,
bias_ref,
residual=res_ref,
x1=x1_ref,
weight1=weight1_ref,
bias1=bias1_ref,
eps=1e-6,
dropout_p=dropout_p,
rowscale=rowscale,
prenorm=prenorm,
dropout_mask=dropout_mask,
dropout_mask1=dropout_mask1,
upcast=True,
)
if not has_weight1:
if prenorm:
residual = rest[0]
out_pt, residual_pt = out_pt
out_ref, residual_ref = out_ref
out1, out1_pt, out1_ref = None, None, None
else:
out1 = rest.pop(0)
if prenorm:
residual = rest[0]
out_pt, out1_pt, residual_pt = out_pt
out_ref, out1_ref, residual_ref = out_ref
else:
out_pt, out1_pt = out_pt
out_ref, out1_ref = out_ref
assert out.dtype == input_dtype
if prenorm:
assert residual.dtype == residual_dtype
assert allclose(residual, residual_pt, residual_ref)
assert allclose(out, out_pt, out_ref)
if out1 is not None:
assert out1.dtype == input_dtype
assert allclose(out1, out1_pt, out1_ref)
if dropout_mask is not None:
dropout_fraction = 1.0 - dropout_mask.float().mean()
assert abs(dropout_fraction - dropout_p) < 0.01
if dropout_mask1 is not None:
dropout_fraction = 1.0 - dropout_mask1.float().mean()
assert abs(dropout_fraction - dropout_p) < 0.01
assert not torch.equal(dropout_mask, dropout_mask1)
g = torch.randn_like(out) / batch_size
if has_weight1:
out = out * F.gelu(out1)
out_pt = out_pt * F.gelu(out1_pt)
out_ref = out_ref * F.gelu(out1_ref)
if not prenorm:
out.backward(g)
out_pt.backward(g)
out_ref.backward(g)
else:
(out * F.sigmoid(residual)).backward(g)
(out_pt * F.sigmoid(residual_pt)).backward(g)
(out_ref * F.sigmoid(residual_ref.to(dtype=residual_dtype))).backward(g)
assert allclose(x0.grad, x0_pt.grad, x0_ref.grad)
if has_residual:
assert allclose(res.grad, res_pt.grad, res_ref.grad)
if has_x1:
assert allclose(x1.grad, x1_pt.grad, x1_ref.grad)
assert allclose(weight.grad, weight_pt.grad, weight_ref.grad)
if bias is not None:
assert allclose(bias.grad, bias_pt.grad, bias_ref.grad)
if has_weight1:
assert allclose(weight1.grad, weight1_pt.grad, weight1_ref.grad)
if bias1 is not None:
assert allclose(bias1.grad, bias1_pt.grad, bias1_ref.grad)
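# Illustrative sketch (added for exposition, not part of the original test suite): the
# pure-PyTorch references exercised above also run on CPU. The call below assumes the
# (x, weight, bias, ..., eps=...) signature that the tests in this file already use;
# the tiny shapes are arbitrary.
def _example_layer_norm_ref():
    x = torch.randn(2, 4, 8)
    weight = torch.ones(8)
    bias = torch.zeros(8)
    out_ln = layer_norm_ref(x, weight, bias, eps=1e-6)
    out_rms = rms_norm_ref(x, weight, None, eps=1e-6)  # the tests pass bias=None for RMSNorm
    return out_ln.shape, out_rms.shape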
@pytest.mark.parametrize("prenorm", [True, False])
# @pytest.mark.parametrize("prenorm", [True])
@pytest.mark.parametrize("is_rms_norm", [False, True])
# @pytest.mark.parametrize("is_rms_norm", [True])
@pytest.mark.parametrize("has_residual", [True, False])
# @pytest.mark.parametrize("has_residual", [False])
@pytest.mark.parametrize("weight_dtype", [torch.float32])
@pytest.mark.parametrize(
"input_dtype,residual_dtype",
[(torch.float16, torch.float16), (torch.float16, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize("input_dtype,residual_dtype", [(torch.bfloat16, torch.float32)])
@pytest.mark.parametrize("hidden_size", [192, 2048, 2560, 3000])
# @pytest.mark.parametrize("hidden_size", [256])
def test_layer_norm_linear(
hidden_size, input_dtype, residual_dtype, weight_dtype, has_residual, is_rms_norm, prenorm
):
device = "cuda"
if any(x == torch.bfloat16 for x in [input_dtype, residual_dtype, weight_dtype]):
atol = 5e-2
elif any(x == torch.float16 for x in [input_dtype, residual_dtype, weight_dtype]):
atol = 1e-2
else:
atol = 1e-4
# set seed
torch.random.manual_seed(0)
batch_size = 4
seqlen = 512
# batch_size = 1
# seqlen = 1
layer_norm_ref_fn = layer_norm_ref if not is_rms_norm else rms_norm_ref
allclose = (
lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max()
<= 2 * (x_pt - x_ref).abs().max() + atol
)
x0 = torch.randn(
batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
)
x0_pt = x0.detach().clone().requires_grad_()
x0_ref = x0.detach().clone().requires_grad_()
if has_residual:
res = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
res_pt = res.detach().clone().requires_grad_()
res_ref = res.detach().clone().requires_grad_()
else:
res, res_pt, res_ref = None, None, None
norm_weight = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
if not is_rms_norm:
norm_bias = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
else:
norm_bias = None
norm_weight_pt = norm_weight.detach().clone().requires_grad_()
norm_weight_ref = norm_weight.detach().clone().requires_grad_()
norm_bias_pt = norm_bias.detach().clone().requires_grad_() if norm_bias is not None else None
norm_bias_ref = norm_bias.detach().clone().requires_grad_() if norm_bias is not None else None
linear_weight = torch.empty(
2 * hidden_size, hidden_size, device=device, dtype=weight_dtype, requires_grad=True
)
torch.nn.init.xavier_uniform_(linear_weight)
if not is_rms_norm:
linear_bias = torch.randn(
2 * hidden_size, device=device, dtype=weight_dtype, requires_grad=True
)
else:
linear_bias = None
linear_weight_pt = linear_weight.detach().clone().requires_grad_()
linear_weight_ref = linear_weight.detach().clone().requires_grad_()
linear_bias_pt = (
linear_bias.detach().clone().requires_grad_() if linear_bias is not None else None
)
linear_bias_ref = (
linear_bias.detach().clone().requires_grad_() if linear_bias is not None else None
)
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
with torch.autocast(device_type="cuda", dtype=input_dtype):
out, *rest = layer_norm_linear_fn(
x0,
norm_weight,
norm_bias,
linear_weight,
linear_bias,
residual=res,
eps=1e-6,
prenorm=prenorm,
residual_in_fp32=residual_in_fp32,
is_rms_norm=is_rms_norm,
)
out_pt, *rest_pt = layer_norm_ref_fn(
x0_pt, norm_weight_pt, norm_bias_pt, residual=res_pt, eps=1e-6, prenorm=prenorm
)
with torch.autocast(device_type="cuda", dtype=input_dtype):
out_pt = F.linear(out_pt, linear_weight_pt, linear_bias_pt)
out_ref, *rest_ref = layer_norm_ref_fn(
x0_ref,
norm_weight_ref,
norm_bias_ref,
residual=res_ref,
eps=1e-6,
prenorm=prenorm,
upcast=True,
)
out_ref = F.linear(out_ref.to(linear_weight_ref.dtype), linear_weight_ref, linear_bias_ref)
if prenorm:
residual = rest[0]
residual_pt = rest_pt[0]
residual_ref = rest_ref[0]
assert out.dtype == input_dtype
if prenorm:
assert residual.dtype == residual_dtype
assert allclose(residual, residual_pt, residual_ref)
assert allclose(out, out_pt, out_ref)
g = torch.randn_like(out) / batch_size
out.backward(g)
out_pt.backward(g)
out_ref.backward(g)
assert allclose(x0.grad, x0_pt.grad, x0_ref.grad)
if has_residual:
assert allclose(res.grad, res_pt.grad, res_ref.grad)
assert allclose(norm_weight.grad, norm_weight_pt.grad, norm_weight_ref.grad)
if norm_bias is not None:
assert allclose(norm_bias.grad, norm_bias_pt.grad, norm_bias_ref.grad)
assert allclose(linear_weight.grad, linear_weight_pt.grad, linear_weight_ref.grad)
if linear_bias is not None:
assert allclose(linear_bias.grad, linear_bias_pt.grad, linear_bias_ref.grad)
[tool.black]
line-length = 100
target-version = ['py38']
import math
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from flash_attn import (
flash_attn_func,
flash_attn_kvpacked_func,
flash_attn_qkvpacked_func,
flash_attn_varlen_func,
flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func,
flash_attn_with_kvcache,
)
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import _get_block_size_n
from flash_attn.layers.rotary import apply_rotary_emb
MAX_HEADDIM_SM8x = 192
is_sm75 = torch.cuda.get_device_capability("cuda") == (7, 5)
is_sm8x = torch.cuda.get_device_capability("cuda")[0] == 8
is_sm80 = torch.cuda.get_device_capability("cuda") == (8, 0)
is_sm90 = torch.cuda.get_device_capability("cuda") == (9, 0)
def attn_bias_from_alibi_slopes(
slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False
):
batch, nheads = slopes.shape
device = slopes.device
slopes = rearrange(slopes, "b h -> b h 1 1")
if causal:
return torch.arange(-seqlen_k + 1, 1, device=device, dtype=torch.float32) * slopes
else:
row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
sk = (
seqlen_k
if key_padding_mask is None
else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
)
sq = (
seqlen_q
if query_padding_mask is None
else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
)
relative_pos = torch.abs(row_idx + sk - sq - col_idx)
return -slopes * relative_pos.to(dtype=slopes.dtype)
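# Illustrative sketch (added for exposition, not part of the original tests): building an
# ALiBi bias on CPU and adding it to a dummy score tensor. The tiny shapes are assumptions
# chosen only for readability.
def _example_alibi_bias():
    slopes = torch.rand(2, 3) * 0.3  # (batch, nheads)
    bias = attn_bias_from_alibi_slopes(slopes, seqlen_q=4, seqlen_k=4, causal=False)
    scores = torch.zeros(2, 3, 4, 4)  # (batch, nheads, seqlen_q, seqlen_k)
    return scores + bias  # the bias broadcasts against the attention scores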
def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
assert mode in ["full", "random", "third"]
if mode == "full":
lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
elif mode == "random":
lengths = torch.randint(
max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
)
elif mode == "third":
lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
padding_mask = (
repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
)
return padding_mask
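# Illustrative sketch (added for exposition): the padding mask marks real tokens with True
# and padding with False; summing over the sequence dimension recovers each length.
def _example_padding_mask():
    mask = generate_random_padding_mask(max_seqlen=16, batch_size=2, device="cpu", mode="random")
    lengths = mask.sum(dim=-1)  # valid tokens per sequence
    return mask, lengths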
def generate_qkv(
q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False
):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, d)
k: (batch_size, seqlen_k, nheads_k, d)
v: (batch_size, seqlen_k, nheads_k, d)
query_padding_mask: (batch_size, seqlen), bool
key_padding_mask: (batch_size, seqlen), bool
"""
assert not (kvpacked and qkvpacked)
batch_size, seqlen_q, nheads, d = q.shape
_, seqlen_k, nheads_k, _ = k.shape
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
if query_padding_mask is not None:
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask)
output_pad_fn = lambda output_unpad: pad_input(
output_unpad, indices_q, batch_size, seqlen_q
)
else:
q_unpad = rearrange(q, "b s h d -> (b s) h d")
cu_seqlens_q = torch.arange(
0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
)
max_seqlen_q = seqlen_q
output_pad_fn = lambda output_unpad: rearrange(
output_unpad, "(b s) h d -> b s h d", b=batch_size
)
if key_padding_mask is not None:
k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
else:
k_unpad = rearrange(k, "b s h d -> (b s) h d")
v_unpad = rearrange(v, "b s h d -> (b s) h d")
cu_seqlens_k = torch.arange(
0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
)
max_seqlen_k = seqlen_k
if qkvpacked:
assert (query_padding_mask == key_padding_mask).all()
assert nheads == nheads_k
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
qkv = torch.stack([q, k, v], dim=2)
if query_padding_mask is not None:
dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
else:
dqkv_pad_fn = lambda dqkv_unpad: rearrange(
dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
)
return (
qkv_unpad.detach().requires_grad_(),
cu_seqlens_q,
max_seqlen_q,
qkv.detach().requires_grad_(),
output_pad_fn,
dqkv_pad_fn,
)
elif kvpacked:
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
kv = torch.stack([k, v], dim=2)
dq_pad_fn = output_pad_fn
if key_padding_mask is not None:
dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
else:
dkv_pad_fn = lambda dkv_unpad: rearrange(
dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
)
return (
q_unpad.detach().requires_grad_(),
kv_unpad.detach().requires_grad_(),
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q.detach().requires_grad_(),
kv.detach().requires_grad_(),
output_pad_fn,
dq_pad_fn,
dkv_pad_fn,
)
else:
dq_pad_fn = output_pad_fn
if key_padding_mask is not None:
dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)
else:
dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size)
return (
q_unpad.detach().requires_grad_(),
k_unpad.detach().requires_grad_(),
v_unpad.detach().requires_grad_(),
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q.detach().requires_grad_(),
k.detach().requires_grad_(),
v.detach().requires_grad_(),
output_pad_fn,
dq_pad_fn,
dk_pad_fn,
)
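# Illustrative sketch (added for exposition): unpadding and re-padding round-trip on CPU.
# The 13-tuple layout matches the non-packed branch of generate_qkv above; the small
# shapes are arbitrary assumptions.
def _example_generate_qkv_roundtrip():
    batch_size, seqlen, nheads, d = 2, 8, 3, 16
    q = torch.randn(batch_size, seqlen, nheads, d)
    k = torch.randn(batch_size, seqlen, nheads, d)
    v = torch.randn(batch_size, seqlen, nheads, d)
    mask = generate_random_padding_mask(seqlen, batch_size, q.device, mode="random")
    (
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        q_pad,
        k_pad,
        v_pad,
        output_pad_fn,
        dq_pad_fn,
        dk_pad_fn,
    ) = generate_qkv(q, k, v, mask, mask)
    # output_pad_fn scatters unpadded rows back into the (batch, seqlen, nheads, d) layout.
    return output_pad_fn(q_unpad).shape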
def construct_local_mask(
seqlen_q,
seqlen_k,
window_size=(-1, -1), # -1 means infinite window size
query_padding_mask=None,
key_padding_mask=None,
device=None,
):
row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
sk = (
seqlen_k
if key_padding_mask is None
else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
)
sq = (
seqlen_q
if query_padding_mask is None
else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
)
if window_size[0] < 0:
return col_idx > row_idx + sk - sq + window_size[1]
else:
sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
return torch.logical_or(
col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
col_idx < row_idx + sk - sq - window_size[0],
)
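# Illustrative sketch (added for exposition): a sliding-window mask for a 5x5 case with
# window_size=(2, 1). True entries are the positions that get masked out.
def _example_local_mask():
    mask = construct_local_mask(seqlen_q=5, seqlen_k=5, window_size=(2, 1), device="cpu")
    return mask  # each query keeps itself, 2 keys to its left and 1 to its right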
def attention_ref(
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
attn_bias=None,
dropout_p=0.0,
dropout_mask=None,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
softcap=0.0,
upcast=True,
reorder_ops=False,
):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, head_dim)
k: (batch_size, seqlen_k, nheads_k, head_dim)
v: (batch_size, seqlen_k, nheads_k, head_dim)
query_padding_mask: (batch_size, seqlen_q)
key_padding_mask: (batch_size, seqlen_k)
attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
dropout_p: float
dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
causal: whether to apply causal masking
window_size: (int, int), left and right window size
upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
output back to fp16/bf16.
reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
without changing the math. This is to estimate the numerical error from operation
reordering.
Output:
output: (batch_size, seqlen_q, nheads, head_dim)
attention: (batch_size, nheads, seqlen_q, seqlen_k), the softmax attention probabilities
(before dropout is applied)
"""
if causal:
window_size = (window_size[0], 0)
dtype_og = q.dtype
if upcast:
q, k, v = q.float(), k.float(), v.float()
seqlen_q, seqlen_k = q.shape[1], k.shape[1]
k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
d = q.shape[-1]
if not reorder_ops:
scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
else:
scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
if softcap > 0:
scores /= softcap
scores = scores.tanh()
scores *= softcap
if key_padding_mask is not None:
scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
if window_size[0] >= 0 or window_size[1] >= 0:
local_mask = construct_local_mask(
seqlen_q,
seqlen_k,
window_size,
query_padding_mask,
key_padding_mask,
q.device,
)
scores.masked_fill_(local_mask, float("-inf"))
if attn_bias is not None:
scores = scores + attn_bias
attention = torch.softmax(scores, dim=-1).to(v.dtype)
# Some rows might be completely masked out so we fill them with zero instead of NaN
if window_size[0] >= 0 or window_size[1] >= 0:
attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
# We want to mask here so that the attention matrix doesn't have any NaNs
# Otherwise we'll get NaN in dV
if query_padding_mask is not None:
attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
dropout_scaling = 1.0 / (1 - dropout_p)
# attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
# output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
if dropout_mask is not None:
attention_drop = attention.masked_fill(~dropout_mask, 0.0)
else:
attention_drop = attention
output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
if query_padding_mask is not None:
output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
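# Illustrative sketch (added for exposition): a minimal causal call to the reference
# implementation on CPU; the shapes are arbitrary assumptions.
def _example_attention_ref():
    q = torch.randn(1, 4, 2, 8)  # (batch, seqlen_q, nheads, head_dim)
    k = torch.randn(1, 4, 2, 8)
    v = torch.randn(1, 4, 2, 8)
    out, attn = attention_ref(q, k, v, causal=True)
    return out.shape, attn.shape  # (1, 4, 2, 8), (1, 2, 4, 4)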
def attention_kvpacked_ref(
q,
kv,
query_padding_mask=None,
key_padding_mask=None,
attn_bias=None,
dropout_p=0.0,
dropout_mask=None,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
softcap=0.0,
upcast=True,
reorder_ops=False,
):
return attention_ref(
q,
kv[:, :, 0],
kv[:, :, 1],
query_padding_mask,
key_padding_mask,
attn_bias,
dropout_p,
dropout_mask,
upcast=upcast,
causal=causal,
window_size=window_size,
softcap=softcap,
reorder_ops=reorder_ops,
)
def attention_qkvpacked_ref(
qkv,
key_padding_mask=None,
attn_bias=None,
dropout_p=0.0,
dropout_mask=None,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
softcap=0.0,
upcast=True,
reorder_ops=False,
):
return attention_ref(
qkv[:, :, 0],
qkv[:, :, 1],
qkv[:, :, 2],
key_padding_mask,
key_padding_mask,
attn_bias,
dropout_p,
dropout_mask,
upcast=upcast,
causal=causal,
window_size=window_size,
softcap=softcap,
reorder_ops=reorder_ops,
)
def generate_sparsity_mask(seqlen, sparsity=0.3):
repeats = seqlen // 16 // 2
# mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda'),
# torch.tensor([0, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
# mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda'),
# torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
# mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
# mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
nrow, ncol = seqlen // 16, seqlen // 256
mask = torch.rand(nrow, ncol, device="cuda") < sparsity
return mask
def attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, dropout_mask):
"""
Arguments:
qkv: (batch_size, seqlen, 3, nheads, head_dim)
blockmask: (seqlen / 16, seqlen / 256)
attn_mask: (batch_size, seqlen)
dropout_p: float
dropout_mask: (batch_size, nheads, seqlen, seqlen)
Output:
output: (batch_size, seqlen, nheads, head_dim)
attention: softmax attention probabilities (before dropout is applied)
"""
q, k, v = qkv.float().unbind(dim=2)
d = qkv.shape[-1]
seqlen = qkv.shape[1]
scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
scores.masked_fill_(rearrange(~attn_mask, "b s -> b 1 1 s"), float("-inf"))
blockmask = repeat(blockmask, "s_16 s_256 -> (s_16 16) (s_256 256)")
blockmask = blockmask[:seqlen, :seqlen]
scores.masked_fill_(rearrange(~blockmask, "t s -> 1 1 t s"), float("-inf"))
attention = torch.softmax(scores, dim=-1)
attention = attention.masked_fill(rearrange(~attn_mask, "b s -> b 1 s 1"), 0.0)
attention = attention.masked_fill_(rearrange(~blockmask, "t s -> 1 1 t s"), 0.0)
attention_drop = attention.masked_fill(~dropout_mask, 0.0) / (1 - dropout_p)
output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
output.masked_fill_(rearrange(~attn_mask, "b s -> b s 1 1"), 0)
return output.to(dtype=qkv.dtype), attention.to(dtype=qkv.dtype)
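# Illustrative sketch (added for exposition): calling the block-sparse reference on CPU.
# The block mask is built by hand here (generate_sparsity_mask above is CUDA-only), and
# the first key block is always kept so that no query row is fully masked out.
def _example_blocksparse_ref():
    seqlen, nheads, d = 512, 2, 16
    qkv = torch.randn(1, seqlen, 3, nheads, d)
    blockmask = torch.rand(seqlen // 16, seqlen // 256) < 0.5
    blockmask[:, 0] = True  # avoid fully masked rows, which would produce NaNs
    attn_mask = torch.ones(1, seqlen, dtype=torch.bool)
    dropout_mask = torch.ones(1, nheads, seqlen, seqlen, dtype=torch.bool)
    out, attn = attention_blocksparse_ref(qkv, blockmask, attn_mask, 0.0, dropout_mask)
    return out.shape, attn.shape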
def convert_flash_attn_S_to_softmax(
S,
seqlen_q,
seqlen_k,
query_padding_mask,
key_padding_mask,
head_dim,
is_dropout,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
):
"""FlashAttention stores the S matrix in a different way.
Arguments:
S: (batch_size, nheads, seqlen_q_rounded, seqlen_k_rounded)
query_padding_mask: (batch_size, seqlen_q)
key_padding_mask: (batch_size, seqlen_k)
"""
if causal:
window_size = (window_size[0], 0)
seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:]
S_converted = S
if window_size[0] >= 0 or window_size[1] >= 0:
local_mask = construct_local_mask(
seqlen_q,
seqlen_k,
window_size,
query_padding_mask,
key_padding_mask,
S.device,
)
local_mask = F.pad(
local_mask,
(0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q),
value=True,
)
S_converted = S_converted.masked_fill(local_mask, 0.0)
# Need to zero out things not in attention_mask in case S was initialized with random values
# and some of those values aren't overwritten.
seqlen_q_og = (
query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded
)
if query_padding_mask is not None:
query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q_rounded - seqlen_q_og))
S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k
if key_padding_mask is not None:
key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k_rounded - seqlen_k_og))
S_converted = S_converted.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0)
S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q_rounded))
S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded))
return S_converted[:, :, :seqlen_q, :seqlen_k]
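# Illustrative sketch (added for exposition): the kernel returns S padded to rounded
# sequence lengths; a random tensor stands in for it here just to show the shape handling
# (the values themselves are meaningless).
def _example_convert_S_shapes():
    batch, nheads, seqlen_q, seqlen_k = 1, 2, 5, 7
    S = torch.randn(batch, nheads, 128, 128)  # rounded shape, as the kernel would return
    S_conv = convert_flash_attn_S_to_softmax(
        S, seqlen_q, seqlen_k, None, None, head_dim=32, is_dropout=False, causal=True
    )
    return S_conv.shape  # (batch, nheads, seqlen_q, seqlen_k)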
def normalize_flash_attn_S(
attn_unnorm,
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
attn_bias=None,
is_dropout=False,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, head_dim)
k, v: (batch_size, seqlen_k, nheads, head_dim)
key_padding_mask: (batch_size, seqlen_k)
attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
Output:
attn_norm: (batch_size, nheads, seqlen_q, seqlen_k), the attention probabilities
    renormalized from FlashAttention's unnormalized, per-block-scaled S matrix
"""
if causal:
window_size = (window_size[0], 0)
q, k, v = q.float(), k.float(), v.float()
_, seqlen_q, _, head_dim = q.shape
seqlen_k = k.shape[1]
scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(head_dim), k)
if key_padding_mask is not None:
scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
if window_size[0] >= 0 or window_size[1] >= 0:
local_mask = construct_local_mask(
seqlen_q,
seqlen_k,
window_size,
query_padding_mask,
key_padding_mask,
q.device,
)
scores.masked_fill_(local_mask, float("-inf"))
if attn_bias is not None:
scores = scores + attn_bias.to(dtype=scores.dtype)
block_size_n = _get_block_size_n(scores.device, head_dim, is_dropout, causal)
scores_block = scores.split(block_size_n, dim=-1)
lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1)
lse = torch.logsumexp(lse_block, dim=-1)
# lse could be -inf (i.e. all values in scores are -inf), and we want to set those to inf
# so that when we do torch.exp(m - lse), we get 0.0 instead of NaN.
lse[lse == float("-inf")] = float("inf")
scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1)
cummax_block = torch.cummax(scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1)
attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1)
attn_norm = torch.cat(
[
a * rearrange(torch.exp(m - lse), "b h s -> b h s 1")
for a, m in zip(attn_unnorm_block, cummax_block)
],
dim=-1,
)
if query_padding_mask is not None:
attn_norm.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
return attn_norm.to(dtype=attn_unnorm.dtype)
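# Illustrative sketch (added for exposition): the same renormalization as above, done on
# CPU with an assumed block size of 4 (normalize_flash_attn_S instead queries the GPU via
# _get_block_size_n). Multiplying the per-block unnormalized probs exp(s - m_j) by
# exp(m_j - lse) recovers the ordinary softmax exp(s - lse).
def _example_block_renormalization():
    torch.manual_seed(0)
    scores = torch.randn(1, 1, 3, 8)
    block = 4
    scores_block = scores.split(block, dim=-1)
    lse = torch.logsumexp(
        torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1), dim=-1
    )
    # Running max over the current and all later blocks, matching the cummax above.
    cummax = (
        torch.cummax(
            torch.stack([s.amax(dim=-1) for s in scores_block], dim=-1).flip(-1), dim=-1
        ).values.flip(-1).unbind(dim=-1)
    )
    attn_unnorm = torch.cat(
        [torch.exp(s - rearrange(m, "b h s -> b h s 1")) for s, m in zip(scores_block, cummax)],
        dim=-1,
    )
    attn = torch.cat(
        [
            a * rearrange(torch.exp(m - lse), "b h s -> b h s 1")
            for a, m in zip(attn_unnorm.split(block, dim=-1), cummax)
        ],
        dim=-1,
    )
    return torch.allclose(attn, torch.softmax(scores, dim=-1), atol=1e-6)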
def get_dropout_fraction(
dropout_mask,
query_padding_mask=None,
key_padding_mask=None,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
):
"""
dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k), bool. True means keep, False means drop.
query_padding_mask: (batch_size, seqlen_q)
key_padding_mask: (batch_size, seqlen_k)
"""
if causal:
window_size = (window_size[0], 0)
batch_size, nheads, seqlen_q, seqlen_k = dropout_mask.shape
dropped = ~dropout_mask
valid = torch.ones_like(dropout_mask)
if query_padding_mask is not None:
dropped.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), False)
valid.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), False)
if key_padding_mask is not None:
dropped.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False)
valid.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False)
if window_size[0] >= 0 or window_size[1] >= 0:
local_mask = construct_local_mask(
seqlen_q,
seqlen_k,
window_size,
query_padding_mask,
key_padding_mask,
dropout_mask.device,
)
dropped.masked_fill_(local_mask, False)
valid.masked_fill_(local_mask, False)
dropped_total = dropped.sum()
return dropped_total / valid.sum()
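# Illustrative sketch (added for exposition): with a Bernoulli(keep=0.8) mask and no
# padding or windowing, the measured dropout fraction should land near 0.2.
def _example_dropout_fraction():
    torch.manual_seed(0)
    dropout_mask = torch.rand(2, 3, 64, 64) < 0.8  # True means the entry was kept
    return get_dropout_fraction(dropout_mask)  # tensor close to 0.2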
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [False])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [False])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128])
# @pytest.mark.parametrize("d", [64])
# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 384, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize("seqlen", [512])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize("dropout_p", [0.0])
def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype):
if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
pytest.skip() # Reference implementation OOM
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 4
nheads = 9
window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))
qkv = torch.randn(
batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
)
if alibi:
alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen, seqlen, causal=causal)
else:
alibi_slopes, attn_bias = None, None
out, lse, S_dmask = flash_attn_qkvpacked_func(
qkv,
dropout_p,
causal=causal,
window_size=window_size,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
)
if dropout_p > 0.0:
S_dmask_converted = convert_flash_attn_S_to_softmax(
S_dmask,
seqlen,
seqlen,
None,
None,
d,
dropout_p > 0.0,
causal=causal,
window_size=window_size,
)
dropout_mask = S_dmask_converted >= 0
attn_unnorm = S_dmask_converted.abs()
attn = normalize_flash_attn_S(
attn_unnorm,
qkv[:, :, 0],
qkv[:, :, 1],
qkv[:, :, 2],
None,
None,
attn_bias,
dropout_p > 0.0,
causal=causal,
window_size=window_size,
)
dropout_fraction = get_dropout_fraction(
dropout_mask, None, None, causal=causal, window_size=window_size
).item()
print(f"Actual dropout fraction: {dropout_fraction}")
else:
dropout_mask = None
out_ref, attn_ref = attention_qkvpacked_ref(
qkv, None, attn_bias, dropout_p, dropout_mask, causal=causal, window_size=window_size
)
out_pt, attn_pt = attention_qkvpacked_ref(
qkv,
None,
attn_bias,
dropout_p,
dropout_mask,
causal=causal,
window_size=window_size,
upcast=False,
reorder_ops=True,
)
# v = qkv[:, :, 2].float()
# qk = torch.einsum('bshd,bthd->bhst', qkv[:, :, 0], qkv[:, :, 1]).float()
# if causal:
# causal_mask = torch.triu(torch.ones(seqlen, seqlen, dtype=torch.bool, device=qkv.device), 1)
# qk.masked_fill_(causal_mask, float('-inf'))
# m = qk.amax(-1, keepdim=True)
# s_tmp = torch.exp((qk - m) / math.sqrt(d))
# p_tmp = torch.softmax(qk / math.sqrt(d), -1)
# p_dropped = p_tmp if dropout_mask is None else p_tmp.masked_fill(~dropout_mask, 0)
# lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
# qk_max1 = torch.max(qk[:, :, 128:, 192:], -1, keepdim=True).values
# qk_max2 = torch.max(qk[:, :, 128:, 128:], -1, keepdim=True).values
# qk_max3 = torch.max(qk[:, :, 128:, 64:], -1, keepdim=True).values
# qk_max4 = torch.max(qk[:, :, 128:, :], -1, keepdim=True).values
# o1 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 192:] - qk_max1) / math.sqrt(d)), v[:, 192:])
# o2 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 128:] - qk_max2) / math.sqrt(d)), v[:, 128:])
# o3 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 64:] - qk_max3) / math.sqrt(d)), v[:, 64:])
# o4 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, :] - qk_max4) / math.sqrt(d)), v[:, :])
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
if dropout_p > 0.0:
print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")
g = torch.randn_like(out)
# do_o = (g.float() * out.float()).sum(-1)
# dv_tmp = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, :64], g[:, :64])
# dv_tmp1 = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, 64:], g[:, 64:])
if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
(dqkv,) = torch.autograd.grad(out, qkv, g)
(dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
(dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}")
print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}")
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
if dropout_p > 0.0:
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
# With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
if not alibi:
assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [True])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 257, 384, 512, 768, 1025, 2048])
# @pytest.mark.parametrize('seqlen', [128])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_varlen_qkvpacked(
seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype
):
if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
pytest.skip() # Reference implementation OOM
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 5
nheads = 6
window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))
qkv = torch.randn(
batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
)
key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode="random")
# key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
if alibi:
alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
attn_bias = attn_bias_from_alibi_slopes(
alibi_slopes, seqlen, seqlen, key_padding_mask, key_padding_mask, causal=causal
)
else:
alibi_slopes, attn_bias = None, None
qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
*qkv.unbind(dim=2), key_padding_mask, key_padding_mask, qkvpacked=True
)
out_unpad, sm_lse, S_dmask = flash_attn_varlen_qkvpacked_func(
qkv_unpad,
cu_seqlens,
max_seqlen,
dropout_p,
causal=causal,
window_size=window_size,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
)
out = output_pad_fn(out_unpad)
if dropout_p > 0.0:
S_dmask_converted = convert_flash_attn_S_to_softmax(
S_dmask,
seqlen,
seqlen,
key_padding_mask,
key_padding_mask,
d,
dropout_p > 0.0,
causal=causal,
window_size=window_size,
)
dropout_mask = S_dmask_converted >= 0
attn_unnorm = S_dmask_converted.abs()
attn = normalize_flash_attn_S(
attn_unnorm,
qkv[:, :, 0],
qkv[:, :, 1],
qkv[:, :, 2],
key_padding_mask,
key_padding_mask,
attn_bias,
dropout_p > 0.0,
causal=causal,
window_size=window_size,
)
dropout_fraction = get_dropout_fraction(
dropout_mask, key_padding_mask, key_padding_mask, causal=causal, window_size=window_size
).item()
print(f"Actual dropout fraction: {dropout_fraction}")
else:
dropout_mask = None
out_ref, attn_ref = attention_qkvpacked_ref(
qkv,
key_padding_mask,
attn_bias,
dropout_p,
dropout_mask,
causal=causal,
window_size=window_size,
)
out_pt, attn_pt = attention_qkvpacked_ref(
qkv,
key_padding_mask,
attn_bias,
dropout_p,
dropout_mask,
causal=causal,
window_size=window_size,
upcast=False,
reorder_ops=True,
)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
if dropout_p > 0.0:
print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")
g = torch.randn_like(out)
if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
(dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g)
dqkv = dqkv_pad_fn(dqkv_unpad)
(dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
(dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}")
print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}")
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
if dropout_p > 0.0:
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
# With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
if not alibi:
assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
@pytest.mark.parametrize("kvpacked", [True, False])
# @pytest.mark.parametrize("kvpacked", [False])
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [True])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(113, 203),
(128, 217),
(113, 211),
(108, 256),
(256, 512),
(512, 256),
(1024, 1024),
(1023, 1024),
(1024, 1023),
(2048, 2048),
],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize("dropout_p", [0.17])
@pytest.mark.parametrize("softcap", [0.0, 50.0])
def test_flash_attn_output(
seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap
):
if (
max(seqlen_q, seqlen_k) >= 2048
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
):
pytest.skip() # Reference implementation OOM
if softcap > 0.0 and dropout_p > 0.0:
pytest.skip("Softcap and dropout not supported together")
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 4
nheads = 6 if softcap == 0.0 else 4 # softcap reference impl takes more memory
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
assert nheads % nheads_k == 0
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
if softcap > 0:
# Scale q so the qk logits actually reach the softcap range; otherwise softcap has little effect.
q = q * softcap
if kvpacked:
kv = torch.randn(
batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
)
else:
k = torch.randn(
batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
)
v = torch.randn(
batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
)
if alibi:
alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal)
else:
alibi_slopes, attn_bias = None, None
if kvpacked:
out, lse, S_dmask = flash_attn_kvpacked_func(
q,
kv,
dropout_p,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
)
else:
out, lse, S_dmask = flash_attn_func(
q,
k,
v,
dropout_p,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
)
if dropout_p > 0.0:
S_dmask_converted = convert_flash_attn_S_to_softmax(
S_dmask,
seqlen_q,
seqlen_k,
None,
None,
d,
dropout_p > 0.0,
causal=causal,
window_size=window_size,
)
dropout_mask = S_dmask_converted >= 0
attn_unnorm = S_dmask_converted.abs()
if kvpacked:
kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k)
k_rep, v_rep = kv_rep.unbind(dim=2)
else:
k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
attn = normalize_flash_attn_S(
attn_unnorm,
q,
k_rep,
v_rep,
None,
None,
attn_bias,
dropout_p > 0.0,
causal=causal,
window_size=window_size,
)
dropout_fraction = get_dropout_fraction(
dropout_mask, None, None, causal=causal, window_size=window_size
).item()
print(f"Actual dropout fraction: {dropout_fraction}")
else:
dropout_mask = None
if kvpacked:
out_ref, attn_ref = attention_kvpacked_ref(
q,
kv,
None,
None,
attn_bias,
dropout_p,
dropout_mask,
causal=causal,
window_size=window_size,
softcap=softcap,
)
out_pt, attn_pt = attention_kvpacked_ref(
q,
kv,
None,
None,
attn_bias,
dropout_p,
dropout_mask,
causal=causal,
window_size=window_size,
softcap=softcap,
upcast=False,
reorder_ops=True,
)
else:
out_ref, attn_ref = attention_ref(
q,
k,
v,
None,
None,
attn_bias,
dropout_p,
dropout_mask,
causal=causal,
window_size=window_size,
softcap=softcap,
)
out_pt, attn_pt = attention_ref(
q,
k,
v,
None,
None,
attn_bias,
dropout_p,
dropout_mask,
causal=causal,
window_size=window_size,
softcap=softcap,
upcast=False,
reorder_ops=True,
)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
if dropout_p > 0.0:
print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")
g = torch.randn_like(out)
do_o = (g.float() * out.float()).sum(-1)
if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
if kvpacked:
(
dq,
dkv,
) = torch.autograd.grad(out, (q, kv), g)
dk, dv = dkv.unbind(2)
(
dq_ref,
dkv_ref,
) = torch.autograd.grad(out_ref, (q, kv), g)
dk_ref, dv_ref = dkv_ref.unbind(2)
(
dq_pt,
dkv_pt,
) = torch.autograd.grad(out_pt, (q, kv), g)
dk_pt, dv_pt = dkv_pt.unbind(2)
else:
(
dq,
dk,
dv,
) = torch.autograd.grad(out, (q, k, v), g)
(
dq_ref,
dk_ref,
dv_ref,
) = torch.autograd.grad(out_ref, (q, k, v), g)
(
dq_pt,
dk_pt,
dv_pt,
) = torch.autograd.grad(out_pt, (q, k, v), g)
print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
if dropout_p > 0.0:
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
# With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
if not alibi:
assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
@pytest.mark.parametrize("kvpacked", [True, False])
# @pytest.mark.parametrize('kvpacked', [False])
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize('mha_type', ["mqa"])
@pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [True])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(1, 147),
(113, 203),
(128, 217),
(113, 211),
(108, 256),
(256, 512),
(512, 256),
(1024, 1024),
(1023, 1024),
(1024, 1023),
(2048, 2048),
],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
@pytest.mark.parametrize("softcap", [0.0, 50.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_varlen_output(
seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap
):
if (
max(seqlen_q, seqlen_k) >= 2048
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
):
pytest.skip() # Reference implementation OOM
if softcap > 0.0 and dropout_p > 0.0:
pytest.skip("Softcap and dropout not supported together")
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 4
nheads = 6 if softcap == 0.0 else 4 # softcap reference impl takes more memory
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
assert nheads % nheads_k == 0
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
if softcap > 0:
# Scale q so the qk logits actually reach the softcap range; otherwise softcap has little effect.
q = q * softcap
if kvpacked:
kv = torch.randn(
batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
)
else:
k = torch.randn(
batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
)
v = torch.randn(
batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
)
query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
# key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')
if alibi:
alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
attn_bias = attn_bias_from_alibi_slopes(
alibi_slopes, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, causal=causal
)
else:
alibi_slopes, attn_bias = None, None
if kvpacked:
(
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
kv,
output_pad_fn,
dq_pad_fn,
dkv_pad_fn,
) = generate_qkv(q, *kv.unbind(dim=2), query_padding_mask, key_padding_mask, kvpacked=True)
out_unpad, sm_lse, S_dmask = flash_attn_varlen_kvpacked_func(
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
)
else:
(
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
k,
v,
output_pad_fn,
dq_pad_fn,
dk_pad_fn,
) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
out_unpad, sm_lse, S_dmask = flash_attn_varlen_func(
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
)
out = output_pad_fn(out_unpad)
if dropout_p > 0.0:
S_dmask_converted = convert_flash_attn_S_to_softmax(
S_dmask,
seqlen_q,
seqlen_k,
query_padding_mask,
key_padding_mask,
d,
dropout_p > 0.0,
causal=causal,
window_size=window_size,
)
dropout_mask = S_dmask_converted >= 0
attn_unnorm = S_dmask_converted.abs()
if kvpacked:
kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k)
k_rep, v_rep = kv_rep.unbind(dim=2)
else:
k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
attn = normalize_flash_attn_S(
attn_unnorm,
q,
k_rep,
v_rep,
query_padding_mask,
key_padding_mask,
attn_bias,
dropout_p > 0.0,
causal=causal,
window_size=window_size,
)
dropout_fraction = get_dropout_fraction(
dropout_mask,
query_padding_mask,
key_padding_mask,
causal=causal,
window_size=window_size,
).item()
print(f"Actual dropout fraction: {dropout_fraction}")
else:
dropout_mask = None
if kvpacked:
out_ref, attn_ref = attention_kvpacked_ref(
q,
kv,
query_padding_mask,
key_padding_mask,
attn_bias,
dropout_p,
dropout_mask,
causal=causal,
window_size=window_size,
softcap=softcap,
)
out_pt, attn_pt = attention_kvpacked_ref(
q,
kv,
query_padding_mask,
key_padding_mask,
attn_bias,
dropout_p,
dropout_mask,
causal=causal,
window_size=window_size,
softcap=softcap,
upcast=False,
reorder_ops=True,
)
else:
out_ref, attn_ref = attention_ref(
q,
k,
v,
query_padding_mask,
key_padding_mask,
attn_bias,
dropout_p,
dropout_mask,
causal=causal,
window_size=window_size,
softcap=softcap,
)
out_pt, attn_pt = attention_ref(
q,
k,
v,
query_padding_mask,
key_padding_mask,
attn_bias,
dropout_p,
dropout_mask,
causal=causal,
window_size=window_size,
softcap=softcap,
upcast=False,
reorder_ops=True,
)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
if dropout_p > 0.0:
print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")
g = torch.randn_like(out)
if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
if kvpacked:
(
dq_unpad,
dkv_unpad,
) = torch.autograd.grad(out, (q_unpad, kv_unpad), g)
dk, dv = dkv_pad_fn(dkv_unpad).unbind(2)
(
dq_ref,
dkv_ref,
) = torch.autograd.grad(out_ref, (q, kv), g)
dk_ref, dv_ref = dkv_ref.unbind(2)
(
dq_pt,
dkv_pt,
) = torch.autograd.grad(out_pt, (q, kv), g)
dk_pt, dv_pt = dkv_pt.unbind(2)
else:
(
dq_unpad,
dk_unpad,
dv_unpad,
) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
dk = dk_pad_fn(dk_unpad)
dv = dk_pad_fn(dv_unpad)
(
dq_ref,
dk_ref,
dv_ref,
) = torch.autograd.grad(out_ref, (q, k, v), g)
(
dq_pt,
dk_pt,
dv_pt,
) = torch.autograd.grad(out_pt, (q, k, v), g)
dq = dq_pad_fn(dq_unpad)
print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
if dropout_p > 0.0:
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
# With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
if not alibi:
assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.04)
if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64, 128])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
# @pytest.mark.parametrize("swap_sq_sk", [True])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(1, 239),
(3, 799),
(127, 512),
(127, 513),
(113, 203),
(128, 217),
(113, 211),
(108, 256),
(256, 512),
(1023, 1024),
],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
if (
max(seqlen_q, seqlen_k) >= 2048
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
):
pytest.skip() # Reference implementation OOM
if swap_sq_sk:
seqlen_q, seqlen_k = seqlen_k, seqlen_q
device = "cuda"
causal = True
# set seed
torch.random.manual_seed(0)
batch_size = 8
nheads = 9
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size)
out_ref, attn_ref = attention_ref(
q, k, v, None, None, None, 0.0, None, causal=causal, window_size=window_size
)
out_pt, attn_pt = attention_ref(
q,
k,
v,
None,
None,
None,
0.0,
None,
causal=causal,
window_size=window_size,
upcast=False,
reorder_ops=True,
)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
g = torch.randn_like(out)
do_o = (g.float() * out.float()).sum(-1)
if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
(
dq,
dk,
dv,
) = torch.autograd.grad(out, (q, k, v), g)
(
dq_ref,
dk_ref,
dv_ref,
) = torch.autograd.grad(out_ref, (q, k, v), g)
(
dq_pt,
dk_pt,
dv_pt,
) = torch.autograd.grad(out_pt, (q, k, v), g)
print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5
if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
# @pytest.mark.parametrize("swap_sq_sk", [True])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(1, 239),
(3, 799),
(127, 512),
(127, 513),
(113, 203),
(128, 217),
(113, 211),
(108, 256),
(256, 512),
(1023, 1024),
],
)
# TODO: add smaller page sizes when https://github.com/Dao-AILab/flash-attention/pull/824 is merged
@pytest.mark.parametrize("paged_kv_block_size", [None, 16, 256, 512])
# @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)])
def test_flash_attn_varlen_causal(
seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype
):
if (
max(seqlen_q, seqlen_k) >= 2048
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
):
pytest.skip() # Reference implementation OOM
if swap_sq_sk:
seqlen_q, seqlen_k = seqlen_k, seqlen_q
device = "cuda"
causal = True
# set seed
torch.random.manual_seed(0)
batch_size = 8
nheads = 9
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
if paged_kv_block_size is None:
k = torch.randn(
batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
)
v = torch.randn(
batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
)
block_table = None
else:
k, v, block_table, k_cache_paged, v_cache_paged, num_blocks = _generate_block_kvcache(
seqlen_k, paged_kv_block_size, batch_size, nheads, d, device, dtype
)
query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
(
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
k,
v,
output_pad_fn,
dq_pad_fn,
dk_pad_fn,
) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
out_unpad = flash_attn_varlen_func(
q_unpad,
k_unpad if paged_kv_block_size is None else k_cache_paged,
v_unpad if paged_kv_block_size is None else v_cache_paged,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
0.0,
causal=causal,
window_size=window_size,
block_table=block_table,
)
out = output_pad_fn(out_unpad)
out_ref, attn_ref = attention_ref(
q,
k,
v,
query_padding_mask,
key_padding_mask,
None,
0.0,
None,
causal=causal,
window_size=window_size,
)
out_pt, attn_pt = attention_ref(
q,
k,
v,
query_padding_mask,
key_padding_mask,
None,
0.0,
None,
causal=causal,
window_size=window_size,
upcast=False,
reorder_ops=True,
)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
g = torch.randn_like(out)
do_o = (g.float() * out.float()).sum(-1)
test_backward = (d <= MAX_HEADDIM_SM8x or d > 224 or is_sm80 or is_sm90) and block_table is None
if test_backward:
(
dq_unpad,
dk_unpad,
dv_unpad,
) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
dq = dq_pad_fn(dq_unpad)
dk = dk_pad_fn(dk_unpad)
        dv = dk_pad_fn(dv_unpad)  # dv has the same shape as dk, so the dk padding function is reused
(
dq_ref,
dk_ref,
dv_ref,
) = torch.autograd.grad(out_ref, (q, k, v), g)
(
dq_pt,
dk_pt,
dv_pt,
) = torch.autograd.grad(out_pt, (q, k, v), g)
print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5
if test_backward:
assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [True])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
# @pytest.mark.parametrize("swap_sq_sk", [False])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(3, 1024),
(1, 339),
(64, 800),
(3, 799),
(64, 2048),
(16, 20000),
(16, 100000),
(128, 128),
(256, 256),
],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_splitkv(
seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, alibi, deterministic, dtype
):
if swap_sq_sk:
seqlen_q, seqlen_k = seqlen_k, seqlen_q
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 1
nheads = 12
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
if alibi:
alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal)
else:
alibi_slopes, attn_bias = None, None
out, lse, _ = flash_attn_func(
q,
k,
v,
0.0,
causal=causal,
window_size=window_size,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
)
out_ref, attn_ref = attention_ref(
q, k, v, None, None, attn_bias, 0.0, None, causal=causal, window_size=window_size
)
out_pt, attn_pt = attention_ref(
q,
k,
v,
None,
None,
attn_bias,
0.0,
None,
causal=causal,
window_size=window_size,
upcast=False,
reorder_ops=True,
)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
g = torch.randn_like(out)
do_o = (g.float() * out.float()).sum(-1)
if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
(
dq,
dk,
dv,
) = torch.autograd.grad(out, (q, k, v), g)
(
dq_ref,
dk_ref,
dv_ref,
) = torch.autograd.grad(out_ref, (q, k, v), g)
(
dq_pt,
dk_pt,
dv_pt,
) = torch.autograd.grad(out_pt, (q, k, v), g)
print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5
mult = 2 if not alibi else 8
if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
assert (dq - dq_ref).abs().max().item() <= mult * (dq_pt - dq_ref).abs().max().item() + 2e-4
assert (dk - dk_ref).abs().max().item() <= mult * (dk_pt - dk_ref).abs().max().item() + 2e-4
assert (dv - dv_ref).abs().max().item() <= mult * (dv_pt - dv_ref).abs().max().item() + 2e-4
# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("num_splits", [1, 0])
# @pytest.mark.parametrize("num_splits", [1])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("new_kv", [False, True])
# @pytest.mark.parametrize("new_kv", [False])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [False])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True])
@pytest.mark.parametrize("rotary_interleaved", [False, True])
# @pytest.mark.parametrize("rotary_interleaved", [False])
@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
# @pytest.mark.parametrize("rotary_fraction", [0.0])
# @pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
@pytest.mark.parametrize("paged_kv_block_size", [None, 16, 256, 512])
@pytest.mark.parametrize("has_batch_idx", [False, True])
# @pytest.mark.parametrize("has_batch_idx", [False])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(1, 128),
(1, 339),
(3, 1024),
(64, 800),
(64, 256),
(3, 799),
(64, 2048),
(16, 20000),
(1, 128 * 1024),
(16, 128 * 1024),
(128, 128),
],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_kvcache(
seqlen_q,
seqlen_k,
d,
has_batch_idx,
paged_kv_block_size,
rotary_fraction,
rotary_interleaved,
seqlen_new_eq_seqlen_q,
causal,
local,
alibi,
new_kv,
mha_type,
num_splits,
dtype,
):
if seqlen_q > seqlen_k and new_kv:
pytest.skip()
if not new_kv and rotary_fraction > 0.0:
pytest.skip()
if has_batch_idx and paged_kv_block_size is not None:
pytest.skip()
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 2
batch_size_cache = batch_size if not has_batch_idx else batch_size * 2
nheads = 6
# rotary_dim must be a multiple of 16, and must be <= d
rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16
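    # e.g. d=128, rotary_fraction=0.5 -> int(64.0) = 64 -> floor(64 / 16) * 16 = 64;
    #      d=59,  rotary_fraction=0.5 -> int(29.5) = 29 -> floor(29 / 16) * 16 = 16.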
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
assert nheads % nheads_k == 0
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item()
if new_kv:
k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
else:
k, v = None, None
if paged_kv_block_size is None:
k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
block_table = None
else:
(
k_cache,
v_cache,
block_table,
k_cache_paged,
v_cache_paged,
num_blocks,
) = _generate_block_kvcache(
seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype
)
cache_seqlens = torch.randint(
0 if new_kv else 1,
# If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
(
(seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
if new_kv
else (seqlen_k + 1)
),
(batch_size,),
dtype=torch.int32,
device=device,
)
arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
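    # key_padding_mask[b, s] is True for cache positions holding valid tokens: the first
    # cache_seqlens[b] entries already in the cache, plus the seqlen_new appended tokens when new_kv=True.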
if has_batch_idx:
cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[
:batch_size
]
else:
cache_batch_idx = None
if alibi:
alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
attn_bias = attn_bias_from_alibi_slopes(
alibi_slopes, seqlen_q, seqlen_k, None, key_padding_mask, causal=causal
)
else:
alibi_slopes, attn_bias = None, None
# cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
if rotary_dim > 0:
angle = (
torch.rand(
seqlen_k if paged_kv_block_size is None else num_blocks * paged_kv_block_size,
rotary_dim // 2,
device=device,
)
* 2
* math.pi
)
cos = torch.cos(angle).to(dtype=dtype)
sin = torch.sin(angle).to(dtype=dtype)
if causal or local:
q_ro = apply_rotary_emb(
q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
)
else:
q_ro = rearrange(
apply_rotary_emb(
rearrange(q, "b s h d -> b 1 (s h) d"),
cos,
sin,
seqlen_offsets=cache_seqlens,
interleaved=rotary_interleaved,
),
"b 1 (s h) d -> b s h d",
s=seqlen_q,
)
# q_ro = q
k_ro = apply_rotary_emb(
k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
)
else:
cos, sin = None, None
q_ro, k_ro = q, k
# k_cache[:, 64:] = -1
k_cache_ref = (
k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]
).clone()
v_cache_ref = (
v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]
).clone()
if new_kv:
update_mask = torch.logical_and(
cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
)
k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...")
v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...")
k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
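    # For MQA/GQA (nheads_k < nheads), repeat each KV head nheads // nheads_k times so the
    # reference attention sees one KV head per query head.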
out = flash_attn_with_kvcache(
q,
k_cache if paged_kv_block_size is None else k_cache_paged,
v_cache if paged_kv_block_size is None else v_cache_paged,
k,
v,
rotary_cos=cos,
rotary_sin=sin,
cache_seqlens=cache_seqlens,
cache_batch_idx=cache_batch_idx,
block_table=block_table,
causal=causal,
window_size=window_size,
rotary_interleaved=rotary_interleaved,
alibi_slopes=alibi_slopes,
num_splits=num_splits,
)
# out = flash_attn_with_kvcache(
# q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size
# )
# out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size)
# qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
# m = qk.amax(-1, keepdim=True)
# s_tmp = torch.exp((qk - m) / math.sqrt(d))
# o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
# lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
# probs = torch.softmax(qk, dim=-1)
out_ref, _ = attention_ref(
q_ro,
k_cache_rep,
v_cache_rep,
None,
key_padding_mask,
attn_bias,
0.0,
None,
causal=causal,
window_size=window_size,
)
out_pt, _ = attention_ref(
q_ro,
k_cache_rep,
v_cache_rep,
None,
key_padding_mask,
attn_bias,
0.0,
None,
causal=causal,
window_size=window_size,
upcast=False,
reorder_ops=True,
)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
if new_kv:
if paged_kv_block_size is None:
k_cache_select = (
k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]
)
v_cache_select = (
v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]
)
else:
k_cache_select = rearrange(
k_cache_paged[block_table.to(dtype=torch.long).flatten()],
"(b nblocks) block_size ... -> b (nblocks block_size) ...",
b=batch_size,
)[:, :seqlen_k]
v_cache_select = rearrange(
v_cache_paged[block_table.to(dtype=torch.long).flatten()],
"(b nblocks) block_size ... -> b (nblocks block_size) ...",
b=batch_size,
)[:, :seqlen_k]
assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)
assert torch.equal(v_cache_select, v_cache_ref)
mult = 3 if not alibi else 5
assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5
def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype):
num_blocks = math.ceil(seqlen_k / paged_kv_block_size) * batch_size * 3
k_cache_paged = torch.randn(
num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
)
v_cache_paged = torch.randn(
num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
)
block_table = rearrange(
torch.randperm(num_blocks, dtype=torch.int32, device=device),
"(b nblocks) -> b nblocks",
b=batch_size,
)
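    # block_table[b, i] is the physical block holding logical positions
    # [i * paged_kv_block_size, (i + 1) * paged_kv_block_size) of sequence b; gathering the paged
    # cache in that order and truncating to seqlen_k reconstructs the contiguous reference cache below.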
k_cache = rearrange(
# pytorch 1.12 doesn't have indexing with int32
k_cache_paged[block_table.to(dtype=torch.long).flatten()],
"(b nblocks) block_size ... -> b (nblocks block_size) ...",
b=batch_size,
)[:, :seqlen_k]
v_cache = rearrange(
v_cache_paged[block_table.to(dtype=torch.long).flatten()],
"(b nblocks) block_size ... -> b (nblocks block_size) ...",
b=batch_size,
)[:, :seqlen_k]
return k_cache, v_cache, block_table, k_cache_paged, v_cache_paged, num_blocks
# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [128])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(1, 239),
(239, 1),
(3, 799),
(799, 3),
(1024, 128),
(97, 97),
(128, 128),
(200, 200),
(256, 256),
(257, 257),
(384, 384),
(512, 512),
(768, 768),
(1024, 1024),
],
)
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize("dropout_p", [0.0])
def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dtype):
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 60 # Sometimes we need large batch size for the race conditions to trigger
nheads = 4
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
torch.random.manual_seed(42)
out0, lse0, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)
g = torch.randn_like(out0)
if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
(
dq0,
dk0,
dv0,
) = torch.autograd.grad(out0, (q, k, v), g)
# Numerical error if we just do any arithmetic on dq
dq_atol = 2 * ((dq0 + 0.3 - 0.3) - dq0).abs().max().item()
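        # (dq0 + 0.3 - 0.3) differs from dq0 only by floating-point rounding, so this gives a
        # data-dependent tolerance at the magnitude of dq0 rather than a hard-coded constant.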
for i in range(250):
torch.random.manual_seed(42)
out, lse, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)
assert torch.equal(out, out0)
assert torch.equal(lse, lse0)
if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
(
dq,
dk,
dv,
) = torch.autograd.grad(out, (q, k, v), g)
dq_equal = torch.allclose(dq, dq0, atol=dq_atol)
if not dq_equal:
print(f"Iter {i}, {dq_atol = }, dQ max diff: {(dq - dq0).abs().max().item()}")
assert torch.equal(dv, dv0)
assert torch.equal(dk, dk0)
assert dq_equal
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize("d", [16, 32, 64])
# @pytest.mark.parametrize('d', [16])
@pytest.mark.parametrize("seqlen", [1, 2, 5, 17, 128])
# @pytest.mark.parametrize('seqlen', [2])
def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype):
"""We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,
in the case where seqlen % 128 != 0.
"""
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 2
nheads = 5
q = torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 5
k, v = [
torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 3
for _ in range(2)
]
q.requires_grad_(True)
k.requires_grad_(True)
v.requires_grad_(True)
out = flash_attn_func(q, k, v, causal=causal)
g = torch.randn_like(out)
out.backward(g)
q_pt = q.detach().clone().requires_grad_(True)
k_pt = k.detach().clone().requires_grad_(True)
v_pt = v.detach().clone().requires_grad_(True)
out_pt, _ = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True)
out_pt.backward(g)
q_ref = q.detach().clone().requires_grad_(True)
k_ref = k.detach().clone().requires_grad_(True)
v_ref = v.detach().clone().requires_grad_(True)
out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal)
out_ref.backward(g)
print(f"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}")
print(f"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}")
print(f"dV max diff: {(v.grad - v_ref.grad).abs().max().item()}")
print(f"dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}")
print(f"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}")
print(f"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}")
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
assert (q.grad - q_ref.grad).abs().max().item() <= 5 * (
q_pt.grad - q_ref.grad
).abs().max().item() + 1e-3
assert (k.grad - k_ref.grad).abs().max().item() <= 5 * (
k_pt.grad - k_ref.grad
).abs().max().item() + 1e-3
assert (v.grad - v_ref.grad).abs().max().item() <= 5 * (
v_pt.grad - v_ref.grad
).abs().max().item() + 1e-3
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize("d", [64, 128])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 256])
# @pytest.mark.parametrize('seqlen', [128])
def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype):
"""We previously had a bug where we were using the wrong strides of dout, which shows up
when dout is not contiguous.
"""
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 5
nheads = 2
q, k, v = [
torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda", requires_grad=True)
for _ in range(3)
]
out = rearrange(flash_attn_func(q, k, v, causal=causal), "b s ... -> s b ...")
# So g is not contiguous
g = torch.randn(seqlen, 2 * batch_size, nheads, d, dtype=dtype, device="cuda")[:, ::2]
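    # Slicing every other batch out of a twice-as-large tensor makes g a strided view, so the
    # dout reaching the backward kernel is non-contiguous, which is what this regression test targets.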
out.backward(g)
q_pt = q.detach().clone().requires_grad_(True)
k_pt = k.detach().clone().requires_grad_(True)
v_pt = v.detach().clone().requires_grad_(True)
out_pt, attn_pt = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True)
out_pt = rearrange(out_pt, "b s ... -> s b ...")
out_pt.backward(g)
q_ref = q.detach().clone().requires_grad_(True)
k_ref = k.detach().clone().requires_grad_(True)
v_ref = v.detach().clone().requires_grad_(True)
out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal)
out_ref = rearrange(out_ref, "b s ... -> s b ...")
out_ref.backward(g)
print(f"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}")
print(f"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}")
print(f"dV max diff: {(v.grad - v_ref.grad).abs().max().item()}")
print(f"dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}")
print(f"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}")
print(f"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}")
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
assert (q.grad - q_ref.grad).abs().max().item() <= 2 * (
q_pt.grad - q_ref.grad
).abs().max().item()
assert (k.grad - k_ref.grad).abs().max().item() <= 2 * (
k_pt.grad - k_ref.grad
).abs().max().item()
assert (v.grad - v_ref.grad).abs().max().item() <= 2 * (
v_pt.grad - v_ref.grad
).abs().max().item()
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize("d", [16, 32, 64])
# @pytest.mark.parametrize('d', [16])
def test_flash_attn_bwd_varlen_overflow(d, causal, dtype):
"""We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,
in the case where seqlen % 128 != 0 or varlen.
"""
device = "cuda"
# set seed
torch.random.manual_seed(0)
nheads = 5
q_cuseqlen = torch.tensor([0, 76, 110, 256], device=device, dtype=torch.int32)
k_cuseqlen = torch.tensor([0, 1, 2, 3], device=device, dtype=torch.int32)
Mq = 256
Mk = 3
q = torch.randn([Mq, nheads, d], dtype=dtype, device=device) * 3
k, v = [torch.randn([Mk, nheads, d], dtype=dtype, device=device) * 3 for _ in range(2)]
q.requires_grad_(True)
k.requires_grad_(True)
v.requires_grad_(True)
out = flash_attn_varlen_func(q, k, v, q_cuseqlen, k_cuseqlen, Mq, Mk, causal=causal)
g = torch.randn_like(out)
out.backward(g)
assert not q.grad.isnan().any()
assert not k.grad.isnan().any()
assert not v.grad.isnan().any()
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
# @pytest.mark.parametrize("swap_sq_sk", [False])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(1, 239),
(3, 799),
(127, 512),
(127, 513),
(113, 203),
(128, 217),
(113, 211),
(108, 256),
(256, 512),
(1023, 1024),
],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):
if (
max(seqlen_q, seqlen_k) >= 2048
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
):
pytest.skip() # Reference implementation OOM
if swap_sq_sk:
seqlen_q, seqlen_k = seqlen_k, seqlen_q
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 4
nheads = 9
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size, deterministic=True)
g = torch.randn_like(out)
if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
dq0, dk0, dv0 = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
for _ in range(50):
dq, dk, dv = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
assert torch.equal(dv, dv0)
assert torch.equal(dk, dk0)
assert torch.equal(dq, dq0)
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
# @pytest.mark.parametrize("swap_sq_sk", [True])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(1, 239),
(3, 799),
(127, 512),
(127, 513),
(113, 203),
(128, 217),
(113, 211),
(108, 256),
(256, 512),
(1023, 1024),
],
)
# @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)])
def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):
if (
max(seqlen_q, seqlen_k) >= 2048
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
):
pytest.skip() # Reference implementation OOM
if swap_sq_sk:
seqlen_q, seqlen_k = seqlen_k, seqlen_q
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 2
nheads = 9
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
(
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
k,
v,
output_pad_fn,
dq_pad_fn,
dk_pad_fn,
) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
out = flash_attn_varlen_func(
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
0.0,
causal=causal,
window_size=window_size,
deterministic=True,
)
g = torch.randn_like(out)
if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
dq0, dk0, dv0 = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
for _ in range(50):
dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
assert torch.equal(dv, dv0)
assert torch.equal(dk, dk0)
assert torch.equal(dq, dq0)
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("paged_kv_block_size", [16])
# @pytest.mark.parametrize("has_batch_idx", [False])
@pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize("nheads", [32])
@pytest.mark.parametrize("b", [4])
@pytest.mark.parametrize("n", [10])
@pytest.mark.parametrize("seqlen_q,seqlen_k", [(170, 170)])
def test_flash_attn_paged_kvcache_overflow(
seqlen_q,
seqlen_k,
d,
nheads,
b,
n,
paged_kv_block_size,
causal,
dtype,
):
device = "cuda"
    # Keep total cache capacity fixed at 1000 * 16 = 16000 token slots regardless of block size.
    num_blocks = 1000 * 16 // paged_kv_block_size
key_cache = torch.rand([num_blocks, paged_kv_block_size, nheads, d], dtype=dtype, device=device)
value_cache = torch.rand([num_blocks, paged_kv_block_size, nheads, d], dtype=dtype, device=device)
cache_seqlens = torch.zeros(b, dtype=torch.int32, device=device)
for _ in range(n):
query = torch.rand([b, seqlen_q, nheads, d], dtype=dtype, device=device)
key = torch.rand([b, seqlen_k, nheads, d], dtype=dtype, device=device)
value = torch.rand([b, seqlen_k, nheads, d], dtype=dtype, device=device)
        block_tables = torch.randint(
            0,
            num_blocks,
            size=(b, (seqlen_k + paged_kv_block_size - 1) // paged_kv_block_size),
            dtype=torch.int32,
            device=device,
        )
output = flash_attn_with_kvcache(
query,
key_cache,
value_cache,
k=key,
v=value,
cache_seqlens=cache_seqlens,
block_table=block_tables,
causal=causal,
)
import math
import random
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn.layers.rotary import apply_rotary_emb, apply_rotary_emb_torch
from flash_attn.layers.rotary import apply_rotary_emb_qkv_, apply_rotary_emb_kv_
from flash_attn.bert_padding import pad_input, unpad_input
is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0)
def generate_cos_sin(seqlen, rotary_dim, device, dtype):
assert rotary_dim % 2 == 0
angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
cos = torch.cos(angle).to(dtype=dtype)
sin = torch.sin(angle).to(dtype=dtype)
return cos, sin
def generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device):
if seqlen_offsets_type == 0:
return 0
elif seqlen_offsets_type is int:
return torch.randint(0, seqlen + 1, (1,)).item()
elif seqlen_offsets_type is torch.Tensor:
return torch.randint(0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device)
def index_cos_sin(cos, sin, seqlen_offsets, seqlen):
if isinstance(seqlen_offsets, torch.Tensor):
batch_size = seqlen_offsets.shape[0]
arange = rearrange(torch.arange(seqlen, device=cos.device), "s -> 1 s")
idx = rearrange(seqlen_offsets, "b -> b 1") + arange
cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
else:
cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
return cos_pt, sin_pt
@pytest.mark.parametrize(
"dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [True])
@pytest.mark.parametrize("inplace", [False, True])
# @pytest.mark.parametrize('inplace', [False])
def test_rotary_emb_func(inplace, interleaved, rotary_fraction, seqlen_offsets_type, dtype):
rtol = 1e-3
batch_size = 32
nheads = 4
seqlen = 217
headdim = 128
device = "cuda"
rotary_dim = int(rotary_fraction * headdim)
torch.manual_seed(42)
x = torch.randn(
batch_size, seqlen, nheads, headdim, dtype=dtype, device=device, requires_grad=True
)
x_pt = x.detach().clone().requires_grad_()
cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)
out = apply_rotary_emb(
x, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=inplace
)
cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
out_pt = apply_rotary_emb_torch(
x_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
).to(dtype=dtype)
print(f"Output max diff: {(out - out_pt).abs().max().item()}")
g = torch.randn_like(out)
g_pt = g.clone() # If inplace=True, we might modify the gradient inplace
out.backward(g)
out_pt.backward(g_pt)
print(f"Grad max diff: {(x.grad - x_pt.grad).abs().max().item()}")
if not inplace:
assert torch.equal(x, x_pt)
# Numerical error if we just do any arithmetic
atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item()
assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=2 * atol)
@pytest.mark.parametrize(
"dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False])
def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_type, dtype):
rtol = 1e-3
batch_size = 32
nheads = 4
seqlen = 512
headdim = 128
device = "cuda"
rotary_dim = int(rotary_fraction * headdim)
torch.manual_seed(42)
qkv = torch.randn(
batch_size, seqlen, 3, nheads, headdim, dtype=dtype, device=device, requires_grad=True
)
qkv_pt = qkv.detach().clone().requires_grad_()
cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)
out = apply_rotary_emb_qkv_(
qkv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved
)
cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
q_pt = apply_rotary_emb_torch(
qkv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
).to(dtype=dtype)
k_pt = apply_rotary_emb_torch(
qkv_pt[:, :, 1].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
).to(dtype=dtype)
out_pt = torch.stack([q_pt, k_pt, qkv_pt[:, :, 2]], dim=2)
print(f"Output max diff: {(out - out_pt).abs().max().item()}")
g = torch.randn_like(out)
    g_pt = g.clone()  # apply_rotary_emb_qkv_ is in-place, so its backward modifies the gradient in place; keep a copy
out.backward(g)
out_pt.backward(g_pt)
print(f"Grad max diff: {(qkv.grad - qkv_pt.grad).abs().max().item()}")
# Numerical error if we just do any arithmetic
atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
atol = ((qkv_pt.grad + 0.3 - 0.3) - qkv_pt.grad).abs().max().item()
assert torch.allclose(qkv.grad, qkv_pt.grad, rtol=rtol, atol=2 * atol)
@pytest.mark.parametrize(
"dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False])
def test_rotary_emb_kv(interleaved, rotary_fraction, seqlen_offsets_type, dtype):
rtol = 1e-3
batch_size = 32
nheads = 4
seqlen = 781
headdim = 64
device = "cuda"
rotary_dim = int(rotary_fraction * headdim)
torch.manual_seed(42)
kv = torch.randn(
batch_size, seqlen, 2, nheads, headdim, dtype=dtype, device=device, requires_grad=True
)
kv_pt = kv.detach().clone().requires_grad_()
cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)
out = apply_rotary_emb_kv_(kv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved)
cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
k_pt = apply_rotary_emb_torch(
kv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
).to(dtype=dtype)
out_pt = torch.stack([k_pt, kv_pt[:, :, 1]], dim=2)
print(f"Output max diff: {(out - out_pt).abs().max().item()}")
g = torch.randn_like(out)
    g_pt = g.clone()  # apply_rotary_emb_kv_ is in-place, so its backward modifies the gradient in place; keep a copy
out.backward(g)
out_pt.backward(g_pt)
print(f"Grad max diff: {(kv.grad - kv_pt.grad).abs().max().item()}")
# Numerical error if we just do any arithmetic
atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
atol = ((kv_pt.grad + 0.3 - 0.3) - kv_pt.grad).abs().max().item()
assert torch.allclose(kv.grad, kv_pt.grad, rtol=rtol, atol=2 * atol)
@pytest.mark.parametrize(
"dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize("dtype", ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize("rotary_fraction", [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize("interleaved", [True])
@pytest.mark.parametrize("inplace", [False, True])
# @pytest.mark.parametrize("inplace", [False])
def test_rotary_emb_varlen_func(inplace, interleaved, rotary_fraction, seqlen_offsets_type, dtype):
rtol = 1e-3
batch_size = 32
nheads = 4
seqlen = 217
headdim = 128
device = "cuda"
rotary_dim = int(rotary_fraction * headdim)
torch.manual_seed(42)
x = torch.randn(batch_size, seqlen, nheads, headdim, dtype=dtype, device=device)
x_pt = x.detach().clone().requires_grad_()
lengths = torch.randint(max(1, seqlen - 20), seqlen + 1, (batch_size, 1), device=device)
padding_mask = rearrange(torch.arange(seqlen, device=device), "s -> 1 s") < lengths
x_unpad, indices, cu_seqlens, max_seqlen = unpad_input(x, padding_mask)
x_unpad_clone = x_unpad.clone()
x_unpad = x_unpad.requires_grad_()
cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)
out_unpad = apply_rotary_emb(
x_unpad,
cos,
sin,
seqlen_offsets=seqlen_offsets,
interleaved=interleaved,
inplace=inplace,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
out = pad_input(out_unpad, indices, batch_size, seqlen)
cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
out_pt = apply_rotary_emb_torch(
x_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
).to(dtype=dtype)
out_pt = out_pt.masked_fill(rearrange(~padding_mask, "b s -> b s 1 1"), 0.0)
print(f"Output max diff: {(out - out_pt).abs().max().item()}")
g = torch.randn_like(out)
g_pt = g.clone() # If inplace=True, we might modify the gradient inplace
out.backward(g)
out_pt.backward(g_pt)
x_grad = pad_input(x_unpad.grad, indices, batch_size, seqlen)
print(f"Grad max diff: {(x_grad - x_pt.grad).abs().max().item()}")
if not inplace:
assert torch.equal(x_unpad, x_unpad_clone)
# Numerical error if we just do any arithmetic
atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item()
assert torch.allclose(x_grad, x_pt.grad, rtol=rtol, atol=2 * atol)
def test_compilation_count():
batch_size = 1
headdim = 128
device = "cuda"
dtype = torch.float16
torch.manual_seed(42)
from triton.runtime.jit import JITFunction
from flash_attn.ops.triton.rotary import rotary_kernel
compilation_count = 0
def count_compilations(*args, **kwargs):
nonlocal compilation_count
compilation_count += 1
old_cache_func = JITFunction.cache_hook
try:
rotary_kernel.cache.clear()
JITFunction.cache_hook = count_compilations
for seqlen in (128, 256):
for nheads in (4, 32):
x = torch.randn(batch_size, seqlen, nheads, headdim, dtype=dtype, device=device)
x.requires_grad_()
cos, sin = generate_cos_sin(seqlen, headdim, device, dtype)
out = apply_rotary_emb(x, cos, sin)
out.backward(torch.randn_like(out))
# Only two kernels are expected to be compiled:
# * for the forward pass (conjugate=False)
# * for the backward pass (conjugate=True)
assert compilation_count == 2
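        # The loop above varies both seqlen and nheads, so this also checks that changing those
        # runtime shapes reuses the cached kernels instead of triggering recompilation.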
finally:
JITFunction.cache_hook = old_cache_func
#
# This file is copied verbatim from vLLM:
# https://github.com/vllm-project/vllm/blob/main/tests/kernels/test_flash_attn.py
#
from typing import List, Optional, Tuple
import pytest
import torch
import flash_attn_wrapper # noqa: F401
NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 32]
DTYPES = [torch.float16, torch.bfloat16]
# one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check
NUM_BLOCKS = [32768, 2048]
def ref_paged_attn(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
query_lens: List[int],
kv_lens: List[int],
block_tables: torch.Tensor,
scale: float,
sliding_window: Optional[int] = None,
soft_cap: Optional[float] = None,
) -> torch.Tensor:
num_seqs = len(query_lens)
block_tables = block_tables.cpu().numpy()
_, block_size, num_kv_heads, head_size = key_cache.shape
outputs: List[torch.Tensor] = []
start_idx = 0
for i in range(num_seqs):
query_len = query_lens[i]
kv_len = kv_lens[i]
        q = query[start_idx:start_idx + query_len]
        # Scale out of place so the caller's query tensor is not modified through this view
        q = q * scale
num_kv_blocks = (kv_len + block_size - 1) // block_size
block_indices = block_tables[i, :num_kv_blocks]
k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
k = k[:kv_len]
v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
v = v[:kv_len]
if q.shape[1] != k.shape[1]:
k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
attn = torch.einsum("qhd,khd->hqk", q, k).float()
empty_mask = torch.ones(query_len, kv_len)
mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
if sliding_window is not None:
sliding_window_mask = torch.triu(empty_mask,
diagonal=kv_len -
(query_len + sliding_window) +
1).bool().logical_not()
mask |= sliding_window_mask
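        # Worked example of the causal mask: query_len=2, kv_len=5 -> diagonal = 5 - 2 + 1 = 4, so
        # query row 0 may attend to keys 0..3 and query row 1 to keys 0..4 (masked entries get -inf below).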
if soft_cap is not None:
attn = soft_cap * torch.tanh(attn / soft_cap)
attn.masked_fill_(mask, float("-inf"))
attn = torch.softmax(attn, dim=-1).to(v.dtype)
out = torch.einsum("hqk,khd->qhd", attn, v)
outputs.append(out)
start_idx += query_len
return torch.cat(outputs, dim=0)
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@torch.inference_mode()
def test_flash_attn_with_paged_kv(
kv_lens: List[int],
num_heads: Tuple[int, int],
head_size: int,
dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float],
num_blocks: int,
) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
num_seqs = len(kv_lens)
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
max_kv_len = max(kv_lens)
scale = head_size**-0.5
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
key_cache = torch.randn(num_blocks,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
value_cache = torch.randn_like(key_cache)
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
num_blocks,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)
output = torch.ops.vllm.flash_attn_with_kvcache(
decode_query=query.unsqueeze(1),
key_cache=key_cache,
value_cache=value_cache,
softmax_scale=scale,
causal=True,
block_table=block_tables,
cache_seqlens=kv_lens_tensor,
softcap=soft_cap if soft_cap is not None else 0,
).squeeze(1)
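    # query is (num_seqs, num_query_heads, head_size); unsqueeze(1) inserts a seqlen_q=1 dimension
    # for single-token decode, and squeeze(1) drops it from the output again.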
if num_blocks <= 2048:
test_utils = ["test_faketensor", "test_schema"]
else:
test_utils = ["test_faketensor"]
torch.library.opcheck(torch.ops.vllm.flash_attn_with_kvcache,
args=tuple(),
kwargs=dict(
decode_query=query.unsqueeze(1),
key_cache=key_cache,
value_cache=value_cache,
softmax_scale=scale,
causal=True,
block_table=block_tables,
cache_seqlens=kv_lens_tensor,
softcap=soft_cap if soft_cap is not None else 0,
),
test_utils=test_utils)
ref_output = ref_paged_attn(
query=query,
key_cache=key_cache,
value_cache=value_cache,
query_lens=[1] * num_seqs,
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap,
)
    torch.testing.assert_close(
        output,
        ref_output,
        atol=2e-2,
        rtol=1e-2,
        msg=f"{torch.max(torch.abs(output - ref_output))}",
    )
@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("sliding_window", [None])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@torch.inference_mode()
def test_varlen_with_paged_kv(
seq_lens: List[Tuple[int, int]],
num_heads: Tuple[int, int],
head_size: int,
sliding_window: Optional[int],
dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float],
num_blocks: int,
) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens]
kv_lens = [x[1] for x in seq_lens]
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
max_query_len = max(query_lens)
max_kv_len = max(kv_lens)
window_size = ((sliding_window,
sliding_window) if sliding_window is not None else
(-1, -1))
scale = head_size**-0.5
query = torch.randn(sum(query_lens),
num_query_heads,
head_size,
dtype=dtype)
key_cache = torch.randn(num_blocks,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
value_cache = torch.randn_like(key_cache)
cu_query_lens = torch.tensor([0] + query_lens,
dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32)
cu_kv_lens = torch.tensor([0] + kv_lens,
dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32)
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
num_blocks,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)
output = torch.ops.vllm.flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_query_lens,
cu_seqlens_k=cu_kv_lens,
max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len,
softmax_scale=scale,
causal=True,
window_size=window_size,
block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0,
)
if num_blocks <= 2048:
test_utils = ["test_faketensor", "test_schema"]
else:
test_utils = ["test_faketensor"]
torch.library.opcheck(torch.ops.vllm.flash_attn_varlen_func,
args=tuple(),
kwargs=dict(
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_query_lens,
cu_seqlens_k=cu_kv_lens,
max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len,
softmax_scale=scale,
causal=True,
window_size=window_size,
block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0,
),
test_utils=test_utils)
ref_output = ref_paged_attn(
query=query,
key_cache=key_cache,
value_cache=value_cache,
query_lens=query_lens,
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
sliding_window=sliding_window,
soft_cap=soft_cap,
)
    torch.testing.assert_close(
        output,
        ref_output,
        atol=2e-2,
        rtol=1e-2,
        msg=f"{torch.max(torch.abs(output - ref_output))}",
    )
 __version__ = "2.6.2"
-from vllm_flash_attn.flash_attn_interface import (
+# Use relative import to support build-from-source installation in vLLM
+from .flash_attn_interface import (
     flash_attn_func,
     flash_attn_kvpacked_func,
     flash_attn_qkvpacked_func,
......
......
@@ -7,7 +7,8 @@ import torch.nn as nn
 # isort: off
 # We need to import the CUDA kernels after importing torch
-import vllm_flash_attn_2_cuda as flash_attn_cuda
+# Use relative import to support build-from-source installation in vLLM
+from . import vllm_flash_attn_c  # noqa: F401
 # isort: on
......
@@ -49,7 +50,7 @@ def _flash_attn_forward(
     q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax, *, out=None
 ):
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
-    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
+    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = torch.ops.vllm_flash_attn_c.fwd(
         q,
         k,
         v,
......
@@ -87,7 +88,7 @@ def _flash_attn_varlen_forward(
     out=None
 ):
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
-    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
+    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = torch.ops.vllm_flash_attn_c.varlen_fwd(
         q,
         k,
         v,
......
@@ -140,7 +141,7 @@ def _flash_attn_backward(
         dk,
         dv,
         softmax_d,
-    ) = flash_attn_cuda.bwd(
+    ) = torch.ops.vllm_flash_attn_c.bwd(
         dout,
         q,
         k,
......
@@ -194,7 +195,7 @@ def _flash_attn_varlen_backward(
         dk,
         dv,
         softmax_d,
-    ) = flash_attn_cuda.varlen_bwd(
+    ) = torch.ops.vllm_flash_attn_c.varlen_bwd(
         dout,
         q,
         k,
......
@@ -1292,7 +1293,7 @@ def flash_attn_with_kvcache(
     cache_seqlens = maybe_contiguous(cache_seqlens)
     cache_batch_idx = maybe_contiguous(cache_batch_idx)
     block_table = maybe_contiguous(block_table)
-    out, softmax_lse = flash_attn_cuda.fwd_kvcache(
+    out, softmax_lse = torch.ops.vllm_flash_attn_c.fwd_kvcache(
         q,
         k_cache,
         v_cache,
......