Commit 4f83cf8f authored by Junxian

[release] v0.0.1

# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py
import torch
from block_sparse_attn import (
    block_sparse_attn_func,
    flash_attn_varlen_func,
)
from utils import (
    time_fwd,
    flops,
    efficiency,
    write_to_excel,
)
def generate_base_sparsity_mask(max_seqlen_q, max_seqlen_k, round_base, m_block_dim, n_block_dim, sparsity, causal=False, device="cuda"):
    def round_to_multiple(x, base):
        return ((x + base - 1) // base) * base
    nrow, ncol = round_to_multiple(max_seqlen_q, round_base) // m_block_dim, round_to_multiple(max_seqlen_k, round_base) // n_block_dim
    base_mask = torch.zeros(1, nrow, ncol, device=device, dtype=torch.bool)
    total_block_num = 0
    density = 1.0 - sparsity
    if density != 0.0 and density != 1.0:
        for i in range(nrow):  # do in reverse order
            idx = nrow - i - 1
            if causal:
                available_col_num = max(0, ncol - i)
                total_block_num += available_col_num
                num_one = max(1, int(density * available_col_num))
                base_mask[0][idx, torch.randperm(available_col_num)[:num_one]] = True
            else:
                available_col_num = ncol
                total_block_num += available_col_num
                num_one = max(1, int(density * available_col_num))
                base_mask[0][idx, torch.randperm(available_col_num)[:num_one]] = True
    elif density == 1.0:
        base_mask[0] = torch.ones_like(base_mask[0])
        total_block_num = nrow * ncol
    else:
        total_block_num = nrow * ncol
    calculated_block_num = base_mask.sum().item()
    real_sparsity = 1.0 - calculated_block_num / total_block_num
    return base_mask, real_sparsity
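

# Usage sketch (illustrative only, not called by the benchmark): with 128x128 blocks and
# a 1024-token causal problem, the returned mask has shape (1, 8, 8) and `real_sparsity`
# reports the achieved fraction of masked-out blocks, which can differ from the requested
# `sparsity` because each row keeps at least one active block.
# mask, real_sparsity = generate_base_sparsity_mask(
#     1024, 1024, round_base=128, m_block_dim=128, n_block_dim=128,
#     sparsity=0.5, causal=True, device="cuda")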


block_size = 128


def get_sparsity_list(sampling_steps, seqlen, causal):
    blockmask_element_num = (seqlen // block_size) ** 2 // (2 if causal else 1)
    stride = max(blockmask_element_num // sampling_steps, 1)
    actual_steps = (blockmask_element_num + stride - 1) // stride
    sparsity_list = []
    for i in range(actual_steps):
        sparse_rate = (1 + i * stride) / blockmask_element_num
        if sparse_rate > 0.95 or sparse_rate < 0.0:
            continue
        sparsity_list.append(sparse_rate)
    return sparsity_list
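

# Worked example (sketch): for seqlen=1024, block_size=128, causal=True and
# sampling_steps=20, blockmask_element_num = (1024 // 128)**2 // 2 = 32 and
# stride = max(32 // 20, 1) = 1, so the sampled rates are 1/32, 2/32, ...,
# with any rate above 0.95 skipped.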


def profile_blocksparse_fwd():
    repeats = 15
    block_sparse_repeats = 3
    device = 'cuda:0'
    dtype = torch.float16
    causal = True
    batch_size = 8
    sparsity_sampling_steps = 20
    seqlen_vals = [1024, 2048, 4096, 8192, 16384, 32768, 65536]
    headdim = 128
    dim = 4096
    dropout_p = 0.0
    method = "Block_Sparse_Flash2"
    time_f = {}
    speed_f = {}
    excel_label = ["batch_size", "seqlen", "actual_sparsity", "speed", "latency", "speedup", "base_speed", "base_latency"]
    excel_data = []
    excel_dir_path = "./excel/blocksparse/"
    excel_file_name = f"hdim{headdim}_nheads{dim // headdim}_bts{batch_size}_fwd"
    if causal:
        excel_file_name += "_causal"
    all_results = {}
    for seqlen in seqlen_vals:
        results = {}
        nheads = dim // headdim
        shape = (batch_size * seqlen, nheads, headdim)
        q = torch.randn(shape, device=device, dtype=dtype)
        k = torch.randn(shape, device=device, dtype=dtype)
        v = torch.randn(shape, device=device, dtype=dtype)
        cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=device)
        base_f = time_fwd(flash_attn_varlen_func, q, k, v, cu_seqlens, cu_seqlens, seqlen, seqlen, dropout_p, None, causal, repeats=repeats, verbose=False)
        base_speed = efficiency(flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"), base_f)
        results["base"] = [[base_f], [base_speed]]
        sparsity_list = get_sparsity_list(sparsity_sampling_steps, seqlen, causal)
        print(f"sparsity_list: {sparsity_list}")
        for sparsity in sparsity_list:
            sum_sparsity, sum_speed, sum_latency = 0, 0, 0
            for _ in range(block_sparse_repeats):
                cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=device)
                head_mask_type = torch.tensor([1] * nheads, device=device, dtype=torch.int32)
                base_blockmask, real_sparsity = generate_base_sparsity_mask(seqlen, seqlen, block_size, block_size, block_size, sparsity, causal=causal, device=device)
                base_blockmask = base_blockmask.unsqueeze(0).repeat(batch_size, nheads, 1, 1)
                config = (causal, headdim, nheads, batch_size, seqlen, sparsity, real_sparsity)
                f = time_fwd(block_sparse_attn_func, q, k, v, cu_seqlens, cu_seqlens, head_mask_type, None, base_blockmask, seqlen, seqlen, dropout_p, is_causal=causal, exact_streaming=False, repeats=repeats, verbose=False)
                time_f[config, method] = f
                print(f"### causal={causal}, headdim={headdim}, nheads = {nheads}, batch_size={batch_size}, seqlen={seqlen}, real_sparsity={real_sparsity} ###")
                speed_f[config, method] = efficiency(flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"), time_f[config, method])
                print(
                    f"{method} "
                    f"fwd: {speed_f[config, method]:.2f} TFLOPs/s, {(time_f[config, method]*1000):.2f} ms, "
                    f"fwd base: {base_speed:.2f} TFLOPs/s, {base_f*1000:.2f} ms"
                )
                sum_sparsity += real_sparsity
                sum_speed += speed_f[config, method]
                sum_latency += time_f[config, method]
            avg_sparsity = sum_sparsity / block_sparse_repeats
            avg_speed = sum_speed / block_sparse_repeats
            avg_latency = sum_latency / block_sparse_repeats
            if avg_sparsity not in results:
                results[avg_sparsity] = [[], []]
            results[avg_sparsity][0].append(avg_latency)
            results[avg_sparsity][1].append(avg_speed)
            excel_data.append([batch_size, seqlen, avg_sparsity, avg_speed, avg_latency, avg_speed / base_speed, base_speed, base_f])
        for key in results.keys():
            avg_latency = sum(results[key][0]) / len(results[key][0])
            avg_speed = sum(results[key][1]) / len(results[key][1])
            results[key] = [avg_latency, avg_speed]
        all_results[seqlen] = results
        import json
        with open(f"all_results_{excel_file_name}.json", "w") as f:
            json.dump(all_results, f)
    write_to_excel(excel_label, excel_data, excel_dir_path, excel_file_name)


profile_blocksparse_fwd()
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py
import openpyxl
from block_sparse_attn.utils.benchmark import benchmark_forward
import math
import torch
from block_sparse_attn import (
    token_streaming_attn_func,
    flash_attn_varlen_func,
)
from utils import (
    time_fwd,
    flops,
    efficiency,
    write_to_excel,
)


def profile_exact_streaming_fwd():
    repeats = 20
    block_sparse_repeats = 10
    device = 'cuda:0'
    dtype = torch.float16
    causal = True
    batch_size = 8
    sink_local_num = [64, 256]
    seqlen_vals = [4096, 8192, 16384, 32768, 65536]
    headdim_vals = [128]
    dim = 4096
    dropout_p = 0.0
    methods = ["Flash2"]
    time_f = {}
    speed_f = {}
    for headdim in headdim_vals:
        excel_label = ["batch_size", "seqlen", "speed", "latency", "speedup", "base_speed", "base_latency"]
        excel_data = []
        excel_dir_path = "./excel/streaming/"
        excel_file_name = f"hdim{headdim}_nheads{dim // headdim}_bts{batch_size}_sink{sink_local_num[0]}_local{sink_local_num[1]}_fwd"
        for seqlen in seqlen_vals:
            nheads = dim // headdim
            shape = (batch_size * seqlen, nheads, headdim)
            q = torch.randn(shape, device=device, dtype=dtype)
            k = torch.randn(shape, device=device, dtype=dtype)
            v = torch.randn(shape, device=device, dtype=dtype)
            cu_seqlens = torch.arange(
                0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=device)
            base_f = time_fwd(flash_attn_varlen_func, q, k, v, cu_seqlens, cu_seqlens, seqlen, seqlen, dropout_p, None, causal, repeats=repeats, verbose=False)
            base_speed = efficiency(flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"), base_f)
            head_mask_type = torch.tensor([-1] * (nheads // 2) + [0] * (nheads - nheads // 2), device=device, dtype=torch.int32)
            streaming_info = torch.tensor([sink_local_num[0], sink_local_num[1]] * nheads, device=device, dtype=torch.int32)
            config = (causal, headdim, nheads, batch_size, seqlen, sink_local_num[0], sink_local_num[1])
            sum_speed, sum_latency = 0, 0
            for _ in range(block_sparse_repeats):
                f = time_fwd(
                    token_streaming_attn_func, q, k, v, cu_seqlens, cu_seqlens, head_mask_type, streaming_info, seqlen, seqlen, repeats=repeats, verbose=False
                )
                time_f[config, "Flash2"] = f
                print(f"### causal={causal}, headdim={headdim}, nheads = {nheads}, batch_size={batch_size}, seqlen={seqlen}, sink={sink_local_num[0]}, local={sink_local_num[1]} ###")
                for method in methods:
                    speed_f[config, method] = efficiency(
                        flops(batch_size, seqlen, headdim,
                              nheads, causal, mode="fwd"),
                        time_f[config, method]
                    )
                    print(f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, {(time_f[config, method]*1000):.2f} ms")
                sum_speed += speed_f[config, "Flash2"]
                sum_latency += time_f[config, "Flash2"]
            excel_data.append([batch_size, seqlen, sum_speed / block_sparse_repeats, sum_latency / block_sparse_repeats, (sum_speed / block_sparse_repeats) / base_speed, base_speed, base_f])
        write_to_excel(excel_label, excel_data, excel_dir_path, excel_file_name)


profile_exact_streaming_fwd()
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py
import openpyxl
from block_sparse_attn.utils.benchmark import benchmark_forward
import math
import torch
import os
def benchmark_fwd(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the forward pass of an arbitrary function."""
    return benchmark_forward(
        fn,
        *inputs,
        repeats=repeats,
        desc=desc,
        verbose=verbose,
        amp=amp,
        amp_dtype=amp_dtype,
        **kwinputs,
    )


def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
    assert mode in ["fwd"]
    f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)


def efficiency(flop, time):
    return (flop / time / 10**12) if not math.isnan(time) else 0.0
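

# Worked example (sketch): for batch=8, seqlen=1024, headdim=128, nheads=32 and
# causal=True, flops(...) = 4 * 8 * 1024**2 * 32 * 128 // 2 ≈ 6.87e10 FLOPs; a
# measured forward time of 1 ms then gives efficiency(6.87e10, 1e-3) ≈ 68.7 TFLOPs/s.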


def time_fwd(func, *args, **kwargs):
    time_f = benchmark_fwd(func, *args, **kwargs)
    return time_f[1].mean


def write_to_excel(label, data, dir_path, file_name):
    workbook = openpyxl.Workbook()
    sheet = workbook.active
    sheet.append(label)
    os.makedirs(dir_path, exist_ok=True)
    for row in data:
        sheet.append(row)
    workbook.save(dir_path + file_name + ".xlsx")
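

# Usage sketch (illustrative, with made-up numbers): every row of `data` is appended
# under the header row `label`, and the workbook is saved to dir_path + file_name + ".xlsx".
# write_to_excel(["seqlen", "latency"], [[1024, 0.0012], [2048, 0.0045]],
#                "./excel/example/", "demo_fwd")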
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/tests/test_flash_attn.py
import pytest
import torch
from einops import repeat
from block_sparse_attn import (
    block_sparse_attn_func,
)
from utils import (
    generate_random_padding_mask,
    generate_base_sparsity_mask,
    generate_qkv,
    generate_streaming_mask,
    prepare_mixed_exact_mask,
    prepare_mixed_mask,
    convert_flash_attn_S_to_softmax,
    normalize_flash_attn_S,
    get_dropout_fraction,
    attention_blocksparse_ref,
)
MAX_HEADDIM_SM8x = 192
block_size = 128
is_sm75 = torch.cuda.get_device_capability("cuda") == (7, 5)
is_sm8x = torch.cuda.get_device_capability("cuda")[0] == 8
is_sm80 = torch.cuda.get_device_capability("cuda") == (8, 0)
is_sm90 = torch.cuda.get_device_capability("cuda") == (9, 0)
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
@pytest.mark.parametrize("d", [32, 64, 128])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (2048, 2048),
    ],
)
@pytest.mark.parametrize(
    "causal, exact_streaming, sink_num, local_num",
    [
        # (True, True, 1, 3),
        # (True, True, 64, 256),
        (True, False, 1, 3),
        (False, False, 1, 3),
    ],
)
@pytest.mark.parametrize("p_dropout", [0.17, 0.0])
@pytest.mark.parametrize("sparsity", [0, 0.1, 0.3, 0.7, 1.0])
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("nheads", [16, 32])
def test_flash_attn_varlen_block_output(
    seqlen_q, seqlen_k, d, p_dropout, causal, exact_streaming, sink_num, local_num, mha_type, dtype, sparsity, batch_size, nheads
):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    device = "cuda:0"
    # set seed
    torch.random.manual_seed(42)
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 8)
    assert nheads % nheads_k == 0
    window_size = (-1, -1)
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True)
    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
    alibi_slopes, attn_bias = None, None
    (
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        q,
        k,
        v,
        output_pad_fn,
        dq_pad_fn,
        dk_pad_fn,
    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
    num_streaming_heads = nheads // 3
    num_blocksparse_heads = nheads // 3
    num_dense_heads = nheads - num_streaming_heads - num_blocksparse_heads
    sparsity_list = [sparsity] * num_blocksparse_heads
    head_mask_type = torch.tensor([0] * num_dense_heads + [1] * num_blocksparse_heads + [-1] * num_streaming_heads, device=device, dtype=torch.int32)
    base_blockmask = generate_base_sparsity_mask(max_seqlen_q, max_seqlen_k, block_size, block_size, block_size, batch_size, num_blocksparse_heads, sparsity_list, causal=causal, device=device)
    streaming_info = torch.tensor([sink_num, local_num] * nheads, device=device, dtype=torch.int32)
    streaming_mask = generate_streaming_mask(max_seqlen_q, max_seqlen_k, batch_size, nheads, cu_seqlens_q, cu_seqlens_k, block_size, block_size, block_size, streaming_info, causal=causal, device=device)
    if exact_streaming:
        assert causal
    print(f"exact_streaming: {exact_streaming}")
    if exact_streaming:
        mixed_mask = prepare_mixed_exact_mask(base_blockmask, streaming_info, head_mask_type, batch_size, nheads, block_size, block_size, block_size, max_seqlen_q, max_seqlen_k, q.shape[1], k.shape[1], query_padding_mask, key_padding_mask, device=device)
    else:
        mixed_mask = prepare_mixed_mask(base_blockmask, streaming_mask, head_mask_type, batch_size, nheads, block_size, block_size, block_size, max_seqlen_q, max_seqlen_k, q.shape[1], k.shape[1], device=device)
    out_unpad, sm_lse, S_dmask = block_sparse_attn_func(
        q_unpad, k_unpad, v_unpad,
        cu_seqlens_q, cu_seqlens_k,
        head_mask_type,
        streaming_info,
        base_blockmask,
        max_seqlen_q, max_seqlen_k,
        p_dropout,
        deterministic=True,
        softmax_scale=None,
        is_causal=causal,
        exact_streaming=exact_streaming,
        return_attn_probs=True,
    )
    out = output_pad_fn(out_unpad)
    if p_dropout > 0.0:
        assert S_dmask is not None
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask,
            seqlen_q,
            seqlen_k,
            query_padding_mask,
            key_padding_mask,
            d,
            p_dropout > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
        k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
        v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
        attn = normalize_flash_attn_S(
            attn_unnorm,
            q,
            k_rep,
            v_rep,
            query_padding_mask,
            key_padding_mask,
            attn_bias,
            p_dropout > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_fraction = get_dropout_fraction(
            dropout_mask,
            mixed_mask,
            block_size, block_size,
            query_padding_mask,
            key_padding_mask,
            causal=causal,
            window_size=window_size,
        ).item()
        print(f"Actual dropout fraction: {dropout_fraction}")
    else:
        dropout_mask = None
    out_ref, attn_ref = attention_blocksparse_ref(
        q,
        k,
        v,
        mixed_mask,
        block_size, block_size,
        query_padding_mask,
        key_padding_mask,
        p_dropout,
        dropout_mask,
        causal=causal,
        window_size=window_size,
    )
    out_pt, attn_pt = attention_blocksparse_ref(
        q,
        k,
        v,
        mixed_mask,
        block_size, block_size,
        query_padding_mask,
        key_padding_mask,
        p_dropout,
        dropout_mask,
        causal=causal,
        window_size=window_size,
        upcast=False,
        reorder_ops=True,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    g = torch.randn_like(out)
    # g = torch.zeros_like(out)
    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        (
            dq_unpad,
            dk_unpad,
            dv_unpad,
        ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
        dk = dk_pad_fn(dk_unpad)
        dv = dk_pad_fn(dv_unpad)
        (
            dq_ref,
            dk_ref,
            dv_ref,
        ) = torch.autograd.grad(out_ref, (q, k, v), g)
        (
            dq_pt,
            dk_pt,
            dv_pt,
        ) = torch.autograd.grad(out_pt, (q, k, v), g)
        dq = dq_pad_fn(dq_unpad)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
        assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
        assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()
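

# Note (sketch): the assertions above follow the flash-attention testing convention,
# requiring the fused kernel's error against the upcast (fp32) reference to stay within
# a small multiple (2x for outputs, 3x for gradients) of the error of the plain
# fp16/bf16 PyTorch reference. To run only this test, assuming pytest is installed:
#   pytest -q -k test_flash_attn_varlen_block_output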
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py
import openpyxl
from block_sparse_attn.utils.benchmark import benchmark_forward
import math
import torch
from block_sparse_attn import (
    block_streaming_attn_func,
    flash_attn_varlen_func,
)
from utils import (
    time_fwd_bwd,
    flops,
    efficiency,
    write_to_excel,
)


def profile_block_streaming_fwd_bwd():
    repeats = 10
    block_sparse_repeats = 5
    device = 'cuda:0'
    dtype = torch.float16
    causal = True
    batch_size = 1
    sink_local_block_num = [1, 3]
    seqlen_vals = [1024, 2048, 4096, 8192, 16384, 20480, 24576, 28672, 32768, 65536, 131072]
    headdim_vals = [128]
    dim = 4096
    p_dropout = 0.0
    methods = ["Flash2"]
    time_f = {}
    time_b = {}
    time_f_b = {}
    speed_f = {}
    speed_b = {}
    speed_f_b = {}
    for headdim in headdim_vals:
        excel_label = ["batch_size", "seqlen", "speed", "latency", "speedup", "base_speed", "base_latency"]
        excel_data = []
        excel_dir_path = "./excel/block_streaming/"
        excel_file_name = f"hdim{headdim}_nheads{dim // headdim}_bts{batch_size}_sink_block{sink_local_block_num[0]}_local_block{sink_local_block_num[1]}_fwd_bwd"
        for seqlen in seqlen_vals:
            nheads = dim // headdim
            shape = (batch_size * seqlen, nheads, headdim)
            q = torch.randn(shape, device=device, dtype=dtype, requires_grad=True)
            k = torch.randn(shape, device=device, dtype=dtype, requires_grad=True)
            v = torch.randn(shape, device=device, dtype=dtype, requires_grad=True)
            cu_seqlens = torch.arange(
                0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=device)
            base_f, base_b = time_fwd_bwd(flash_attn_varlen_func, q, k, v, cu_seqlens, cu_seqlens, seqlen, seqlen, p_dropout, None, causal, repeats=repeats, verbose=False)
            base_speed = efficiency(flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd_bwd"), base_f + base_b)
            head_mask_type = torch.tensor([-1] * (nheads // 2) + [0] * (nheads - nheads // 2), device=device, dtype=torch.int32)
            streaming_info = torch.tensor([sink_local_block_num[0], sink_local_block_num[1]] * nheads, device=device, dtype=torch.int32)
            config = (causal, headdim, nheads, batch_size, seqlen, sink_local_block_num[0], sink_local_block_num[1])
            sum_speed, sum_latency = 0, 0
            for _ in range(block_sparse_repeats):
                f, b = time_fwd_bwd(
                    block_streaming_attn_func, q, k, v, cu_seqlens, cu_seqlens, head_mask_type, streaming_info, seqlen, seqlen, p_dropout, False, None, causal, repeats=repeats, verbose=False
                )
                time_f[config, "Flash2"] = f
                time_b[config, "Flash2"] = b
                print(f"### causal={causal}, headdim={headdim}, nheads = {nheads}, batch_size={batch_size}, seqlen={seqlen}, sink={sink_local_block_num[0]}, local={sink_local_block_num[1]} ###")
                for method in methods:
                    time_f_b[config, method] = time_f[config, method] + time_b[config, method]
                    speed_f[config, method] = efficiency(
                        flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
                        time_f[config, method]
                    )
                    speed_b[config, method] = efficiency(
                        flops(batch_size, seqlen, headdim, nheads, causal, mode="bwd"),
                        time_b[config, method]
                    )
                    speed_f_b[config, method] = efficiency(
                        flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd_bwd"),
                        time_f_b[config, method]
                    )
                    print(
                        f"{method} "
                        f"fwd: {speed_f[config, method]:.2f} TFLOPs/s, {(time_f[config, method]*1000):.2f} ms, "
                        f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, {(time_b[config, method]*1000):.2f} ms, "
                        f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s, {(time_f_b[config, method]*1000):.2f} ms, "
                        f"fwd + bwd base: {base_speed:.2f} TFLOPs/s, {(base_f + base_b)*1000:.2f} ms"
                    )
                sum_speed += speed_f_b[config, "Flash2"]
                sum_latency += time_f_b[config, "Flash2"]
            excel_data.append([batch_size, seqlen, sum_speed / block_sparse_repeats, sum_latency / block_sparse_repeats, (sum_speed / block_sparse_repeats) / base_speed, base_speed, base_f + base_b])
        write_to_excel(excel_label, excel_data, excel_dir_path, excel_file_name)


profile_block_streaming_fwd_bwd()
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py
import torch
from block_sparse_attn import (
    block_sparse_attn_func,
    flash_attn_varlen_func,
)
from utils import (
    time_fwd_bwd,
    flops,
    efficiency,
    write_to_excel,
)
def generate_base_sparsity_mask(max_seqlen_q, max_seqlen_k, round_base, m_block_dim, n_block_dim, sparsity, causal=False, device="cuda"):
    def round_to_multiple(x, base):
        return ((x + base - 1) // base) * base
    nrow, ncol = round_to_multiple(max_seqlen_q, round_base) // m_block_dim, round_to_multiple(max_seqlen_k, round_base) // n_block_dim
    base_mask = torch.zeros(1, nrow, ncol, device=device, dtype=torch.bool)
    total_block_num = 0
    density = 1.0 - sparsity
    if density != 0.0 and density != 1.0:
        for i in range(nrow):  # do in reverse order
            idx = nrow - i - 1
            if causal:
                available_col_num = max(0, ncol - i)
                total_block_num += available_col_num
                num_one = max(1, int(density * available_col_num))
                base_mask[0][idx, torch.randperm(available_col_num)[:num_one]] = True
            else:
                available_col_num = ncol
                total_block_num += available_col_num
                num_one = max(1, int(density * available_col_num))
                base_mask[0][idx, torch.randperm(available_col_num)[:num_one]] = True
    elif density == 1.0:
        base_mask[0] = torch.ones_like(base_mask[0])
        total_block_num = nrow * ncol
    else:
        total_block_num = nrow * ncol
    calculated_block_num = base_mask.sum().item()
    real_sparsity = 1.0 - calculated_block_num / total_block_num
    return base_mask, real_sparsity


block_size = 128


def get_sparsity_list(sampling_steps, seqlen, causal):
    blockmask_element_num = (seqlen // block_size) ** 2 // (2 if causal else 1)
    stride = max(blockmask_element_num // sampling_steps, 1)
    actual_steps = (blockmask_element_num + stride - 1) // stride
    sparsity_list = []
    for i in range(actual_steps):
        sparse_rate = (1 + i * stride) / blockmask_element_num
        if sparse_rate > 0.95 or sparse_rate < 0.0:
            continue
        sparsity_list.append(sparse_rate)
    return sparsity_list


def profile_blocksparse_fwd_bwd():
    repeats = 10
    block_sparse_repeats = 5
    device = 'cuda:0'
    dtype = torch.float16
    causal = True
    batch_size = 1
    sparsity_sampling_steps = 20
    seqlen_vals = [8192, 16384, 32768]
    headdim = 128
    dim = 4096
    dropout_p = 0.0
    method = "Block_Sparse_Attn"
    time_f = {}
    time_b = {}
    time_f_b = {}
    speed_f = {}
    speed_b = {}
    speed_f_b = {}
    excel_label = ["batch_size", "seqlen", "actual_sparsity", "speed", "latency", "speedup", "base_speed", "base_latency"]
    excel_data = []
    excel_dir_path = "./excel/blocksparse/"
    excel_file_name = f"hdim{headdim}_nheads{dim // headdim}_bts{batch_size}_fwd_bwd"
    if causal:
        excel_file_name += "_causal"
    all_results = {}
    for seqlen in seqlen_vals:
        results = {}
        nheads = dim // headdim
        shape = (batch_size * seqlen, nheads, headdim)
        q = torch.randn(shape, device=device, dtype=dtype, requires_grad=True)
        k = torch.randn(shape, device=device, dtype=dtype, requires_grad=True)
        v = torch.randn(shape, device=device, dtype=dtype, requires_grad=True)
        cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=device)
        base_f, base_b = time_fwd_bwd(flash_attn_varlen_func, q, k, v, cu_seqlens, cu_seqlens, seqlen, seqlen, dropout_p, None, causal, repeats=repeats, verbose=False)
        base_speed = efficiency(flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd_bwd"), base_f + base_b)
        results["base"] = [[base_f + base_b], [base_speed]]
        sparsity_list = get_sparsity_list(sparsity_sampling_steps, seqlen, causal)
        print(f"sparsity_list: {sparsity_list}")
        for sparsity in sparsity_list:
            sum_sparsity, sum_speed, sum_latency = 0, 0, 0
            for _ in range(block_sparse_repeats):
                cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=device)
                head_mask_type = torch.tensor([1] * nheads, device=device, dtype=torch.int32)
                base_blockmask, real_sparsity = generate_base_sparsity_mask(seqlen, seqlen, block_size, block_size, block_size, sparsity, causal=causal, device=device)
                base_blockmask = base_blockmask.unsqueeze(0).repeat(batch_size, nheads, 1, 1)
                config = (causal, headdim, nheads, batch_size, seqlen, sparsity, real_sparsity)
                f, b = time_fwd_bwd(block_sparse_attn_func, q, k, v, cu_seqlens, cu_seqlens, head_mask_type, None, base_blockmask, seqlen, seqlen, dropout_p, is_causal=causal, exact_streaming=False, repeats=repeats, verbose=False)
                time_f[config, method] = f
                time_b[config, method] = b
                print(f"### causal={causal}, headdim={headdim}, nheads = {nheads}, batch_size={batch_size}, seqlen={seqlen}, real_sparsity={real_sparsity} ###")
                time_f_b[config, method] = time_f[config, method] + time_b[config, method]
                speed_f[config, method] = efficiency(
                    flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
                    time_f[config, method]
                )
                speed_b[config, method] = efficiency(
                    flops(batch_size, seqlen, headdim, nheads, causal, mode="bwd"),
                    time_b[config, method]
                )
                speed_f_b[config, method] = efficiency(
                    flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd_bwd"),
                    time_f_b[config, method]
                )
                print(
                    f"{method} "
                    f"fwd: {speed_f[config, method]:.2f} TFLOPs/s, {(time_f[config, method]*1000):.2f} ms, "
                    f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, {(time_b[config, method]*1000):.2f} ms, "
                    f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s, {(time_f_b[config, method]*1000):.2f} ms, "
                    f"fwd + bwd base: {base_speed:.2f} TFLOPs/s, {(base_f + base_b)*1000:.2f} ms"
                )
                sum_sparsity += real_sparsity
                sum_speed += speed_f_b[config, method]
                sum_latency += time_f_b[config, method]
            avg_sparsity = sum_sparsity / block_sparse_repeats
            avg_speed = sum_speed / block_sparse_repeats
            avg_latency = sum_latency / block_sparse_repeats
            if avg_sparsity not in results:
                results[avg_sparsity] = [[], []]
            results[avg_sparsity][0].append(avg_latency)
            results[avg_sparsity][1].append(avg_speed)
            excel_data.append([batch_size, seqlen, avg_sparsity, avg_speed, avg_latency, avg_speed / base_speed, base_speed, base_f + base_b])
        for key in results.keys():
            avg_latency = sum(results[key][0]) / len(results[key][0])
            avg_speed = sum(results[key][1]) / len(results[key][1])
            results[key] = [avg_latency, avg_speed]
        all_results[seqlen] = results
        import json
        with open(f"all_results_{excel_file_name}.json", "w") as f:
            json.dump(all_results, f)
    write_to_excel(excel_label, excel_data, excel_dir_path, excel_file_name)


profile_blocksparse_fwd_bwd()
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py
import openpyxl
from block_sparse_attn.utils.benchmark import benchmark_forward, benchmark_backward
import math
import torch
import os
def benchmark_fwd_bwd(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
    return (
        benchmark_forward(
            fn,
            *inputs,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
        benchmark_backward(
            fn,
            *inputs,
            grad=grad,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
    )


def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
    assert mode in ["fwd", "bwd", "fwd_bwd"]
    f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)


def efficiency(flop, time):
    return (flop / time / 10**12) if not math.isnan(time) else 0.0


def time_fwd_bwd(func, *args, **kwargs):
    time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
    return time_f[1].mean, time_b[1].mean
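

# Usage sketch (illustrative; `my_attn_fn` and its arguments are placeholders, not part
# of this repo): benchmark_fwd_bwd returns one (timer, measurement) pair per direction,
# so time_fwd_bwd unpacks them into mean forward and backward seconds, which the
# benchmark scripts convert to TFLOPs/s via flops(..., mode="fwd"/"bwd"/"fwd_bwd").
# fwd_s, bwd_s = time_fwd_bwd(my_attn_fn, q, k, v, repeats=10, verbose=False)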


def write_to_excel(label, data, dir_path, file_name):
    workbook = openpyxl.Workbook()
    sheet = workbook.active
    sheet.append(label)
    os.makedirs(dir_path, exist_ok=True)
    for row in data:
        sheet.append(row)
    workbook.save(dir_path + file_name + ".xlsx")
// Copyright (c) 2023, Tri Dao.
// Adapted by Junxian Guo from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_block_<cutlass::bfloat16_t, 128>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
run_mha_bwd_block_hdim128<cutlass::bfloat16_t>(params, stream, configure);
}
// Copyright (c) 2023, Tri Dao.
// Adapted by Junxian Guo from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_block_<cutlass::half_t, 128>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
run_mha_bwd_block_hdim128<cutlass::half_t>(params, stream, configure);
}