Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
29051439
Unverified
Commit
29051439
authored
Dec 12, 2025
by
Lei Wang
Committed by
GitHub
Dec 12, 2025
Browse files
[Lint] Phaseout Yapf format and embrace ruff format (#1417)
parent
e84b24bc
Changes
426
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
623 additions
and
986 deletions
+623
-986
examples/bitnet-1.58b/vllm_workspace/utils.py
examples/bitnet-1.58b/vllm_workspace/utils.py
+6
-17
examples/blocksparse_attention/block_sparse_attn_triton.py
examples/blocksparse_attention/block_sparse_attn_triton.py
+24
-47
examples/blocksparse_attention/example_tilelang_block_sparse_attn.py
...ocksparse_attention/example_tilelang_block_sparse_attn.py
+37
-44
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py
...rse_attention/example_tilelang_sparse_gqa_decode_paged.py
+101
-139
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py
...ntion/example_tilelang_sparse_gqa_decode_varlen_indice.py
+96
-144
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py
...tention/example_tilelang_sparse_gqa_decode_varlen_mask.py
+93
-139
examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py
...tention/example_triton_sparse_gqa_decode_varlen_indice.py
+53
-100
examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py
...attention/example_triton_sparse_gqa_decode_varlen_mask.py
+52
-95
examples/blocksparse_attention/heuristic.py
examples/blocksparse_attention/heuristic.py
+1
-2
examples/blocksparse_attention/test_example_blocksparse_attention.py
...ocksparse_attention/test_example_blocksparse_attention.py
+4
-16
examples/blocksparse_gemm/example_blocksparse_gemm.py
examples/blocksparse_gemm/example_blocksparse_gemm.py
+30
-39
examples/cast/example_group_per_split_token_cast_to_fp8.py
examples/cast/example_group_per_split_token_cast_to_fp8.py
+25
-26
examples/cast/example_per_token_cast_to_fp8.py
examples/cast/example_per_token_cast_to_fp8.py
+11
-16
examples/cast/example_triton_cast_to_fp8.py
examples/cast/example_triton_cast_to_fp8.py
+1
-3
examples/cast/test_example_cast.py
examples/cast/test_example_cast.py
+1
-2
examples/compile_flags/usecase.py
examples/compile_flags/usecase.py
+4
-6
examples/conftest.py
examples/conftest.py
+2
-5
examples/convolution/example_convolution.py
examples/convolution/example_convolution.py
+23
-41
examples/convolution/example_convolution_autotune.py
examples/convolution/example_convolution_autotune.py
+46
-89
examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py
examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py
+13
-16
No files found.
Too many changes to show.
To preserve performance only
426 of 426+
files are displayed.
Plain diff
Email patch
examples/bitnet-1.58b/vllm_workspace/utils.py
View file @
29051439
...
...
@@ -3,8 +3,7 @@ from typing import Dict, List, Tuple
TokensText
=
Tuple
[
List
[
int
],
str
]
def
check_outputs_equal
(
outputs_0_lst
:
List
[
TokensText
],
outputs_1_lst
:
List
[
TokensText
],
name_0
:
str
,
name_1
:
str
):
def
check_outputs_equal
(
outputs_0_lst
:
List
[
TokensText
],
outputs_1_lst
:
List
[
TokensText
],
name_0
:
str
,
name_1
:
str
):
"""
Compare the two sequences generated by different models,
which should be equal.
...
...
@@ -15,19 +14,14 @@ def check_outputs_equal(outputs_0_lst: List[TokensText], outputs_1_lst: List[Tok
output_ids_0
,
output_str_0
=
outputs_0
output_ids_1
,
output_str_1
=
outputs_1
assert
output_str_0
==
output_str_1
,
(
f
"Test
{
prompt_idx
}
:"
f
"
\n
{
name_0
}
:
\t
{
output_str_0
!
r
}
"
f
"
\n
{
name_1
}
:
\t
{
output_str_1
!
r
}
"
)
assert
output_ids_0
==
output_ids_1
,
(
f
"Test
{
prompt_idx
}
:"
f
"
\n
{
name_0
}
:
\t
{
output_str_0
!
r
}
"
f
"
\n
{
name_1
}
:
\t
{
output_str_1
!
r
}
"
)
assert
output_str_0
==
output_str_1
,
f
"Test
{
prompt_idx
}
:
\n
{
name_0
}
:
\t
{
output_str_0
!
r
}
\n
{
name_1
}
:
\t
{
output_str_1
!
r
}
"
assert
output_ids_0
==
output_ids_1
,
f
"Test
{
prompt_idx
}
:
\n
{
name_0
}
:
\t
{
output_str_0
!
r
}
\n
{
name_1
}
:
\t
{
output_str_1
!
r
}
"
TokensTextLogprobs
=
Tuple
[
List
[
int
],
str
,
List
[
Dict
[
int
,
float
]]]
def
check_logprobs_close
(
outputs_0_lst
:
List
[
TokensTextLogprobs
],
outputs_1_lst
:
List
[
TokensTextLogprobs
],
name_0
:
str
,
name_1
:
str
):
def
check_logprobs_close
(
outputs_0_lst
:
List
[
TokensTextLogprobs
],
outputs_1_lst
:
List
[
TokensTextLogprobs
],
name_0
:
str
,
name_1
:
str
):
"""
Compare the logprobs of two sequences generated by different models,
which should be similar but not necessarily equal.
...
...
@@ -41,16 +35,11 @@ def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs],
# Loop through generated tokens.
for
idx
,
(
output_id_0
,
output_id_1
)
in
enumerate
(
zip
(
output_ids_0
,
output_ids_1
)):
# If generated tokens don't match, then
if
output_id_0
!=
output_id_1
:
# Each predicted token must be in top N logprobs of the other
assert
output_id_0
in
logprobs_1
[
idx
],
(
f
"Test
{
prompt_idx
}
:"
f
"
\n
{
name_0
}
:
\t
{
output_str_0
!
r
}
"
f
"
\n
{
name_1
}
:
\t
{
output_str_1
!
r
}
"
)
assert
output_id_1
in
logprobs_0
[
idx
],
(
f
"Test
{
prompt_idx
}
:"
f
"
\n
{
name_0
}
:
\t
{
output_str_0
!
r
}
"
f
"
\n
{
name_1
}
:
\t
{
output_str_1
!
r
}
"
)
assert
output_id_0
in
logprobs_1
[
idx
],
f
"Test
{
prompt_idx
}
:
\n
{
name_0
}
:
\t
{
output_str_0
!
r
}
\n
{
name_1
}
:
\t
{
output_str_1
!
r
}
"
assert
output_id_1
in
logprobs_0
[
idx
],
f
"Test
{
prompt_idx
}
:
\n
{
name_0
}
:
\t
{
output_str_0
!
r
}
\n
{
name_1
}
:
\t
{
output_str_1
!
r
}
"
# Break out since sequences will now diverge.
break
examples/blocksparse_attention/block_sparse_attn_triton.py
View file @
29051439
...
...
@@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz
,
num_head
,
downsample_len
,
_
=
x
.
shape
# N_CTX = downsample_len * BLOCK
sparse_index
=
torch
.
topk
(
x
,
topk
,
dim
=-
1
).
indices
dense_mask
=
torch
.
full
([
bsz
,
num_head
,
downsample_len
,
downsample_len
],
False
,
dtype
=
torch
.
bool
,
device
=
x
.
device
)
dense_mask
=
torch
.
full
([
bsz
,
num_head
,
downsample_len
,
downsample_len
],
False
,
dtype
=
torch
.
bool
,
device
=
x
.
device
)
dense_mask
.
scatter_
(
-
1
,
sparse_index
,
True
)
if
use_dense_for_last_block
:
dense_mask
[:,
:,
-
2
:,
:]
=
True
...
...
@@ -56,7 +53,6 @@ def _fwd_kernel_inner(
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
):
mask_val
=
tl
.
load
(
block_mask_ptr
+
k_block_col_idx
*
stride_bmask_n
)
# print
...
...
@@ -73,8 +69,7 @@ def _fwd_kernel_inner(
# the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
if
LAST_K_BLOCK
:
qk
+=
tl
.
where
(
offs_m
[:,
None
]
+
past_len
>=
(
start_n
+
offs_n
[
None
,
:]),
0
,
float
(
'-inf'
))
qk
+=
tl
.
where
(
offs_m
[:,
None
]
+
past_len
>=
(
start_n
+
offs_n
[
None
,
:]),
0
,
float
(
"-inf"
))
m_ij
=
tl
.
maximum
(
m_i
,
tl
.
max
(
qk
,
1
))
qk
-=
m_ij
[:,
None
]
...
...
@@ -154,7 +149,7 @@ def _fwd_kernel(
v_ptrs
=
V
+
off_v
mask_ptrs
=
block_mask_ptr
+
start_m
*
stride_bmm
m_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
-
float
(
'
inf
'
)
m_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
-
float
(
"
inf
"
)
l_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
acc
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_DMODEL
],
dtype
=
tl
.
float32
)
...
...
@@ -192,24 +187,12 @@ def _fwd_kernel(
acc
=
acc
*
l_recip
acc
=
acc
.
to
(
Out
.
dtype
.
element_ty
)
off_o
=
off_z
*
stride_oz
+
off_h
*
stride_oh
+
offs_m
[:,
None
]
*
stride_om
+
offs_d
[
None
,
:]
*
stride_od
off_o
=
off_z
*
stride_oz
+
off_h
*
stride_oh
+
offs_m
[:,
None
]
*
stride_om
+
offs_d
[
None
,
:]
*
stride_od
out_ptrs
=
Out
+
off_o
tl
.
store
(
out_ptrs
,
acc
,
mask
=
offs_m
[:,
None
]
<
N_CTX
)
def
_forward
(
ctx
,
q
,
k
,
v
,
block_sparse_mask
,
sm_scale
,
BLOCK_M
=
64
,
BLOCK_N
=
64
,
num_warps
=
None
,
num_stages
=
1
,
out
=
None
):
def
_forward
(
ctx
,
q
,
k
,
v
,
block_sparse_mask
,
sm_scale
,
BLOCK_M
=
64
,
BLOCK_N
=
64
,
num_warps
=
None
,
num_stages
=
1
,
out
=
None
):
assert
q
.
shape
[
-
1
]
==
k
.
shape
[
-
1
]
==
v
.
shape
[
-
1
]
assert
k
.
shape
[
2
]
==
v
.
shape
[
2
]
o
=
out
if
out
is
not
None
else
torch
.
empty_like
(
q
).
contiguous
()
...
...
@@ -254,7 +237,6 @@ def _forward(ctx,
class
_sparse_attention
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
q
,
k
,
v
,
block_sparse_dense
,
sm_scale
):
# shape constraints
...
...
@@ -278,9 +260,9 @@ def test_topk_sparse_attention():
torch
.
manual_seed
(
0
)
# Create inputs
q
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
bfloat16
)
k
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
bfloat16
)
v
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
bfloat16
)
q
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
bfloat16
)
k
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
bfloat16
)
v
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
bfloat16
)
sm_scale
=
1.0
/
(
D_HEAD
**
0.5
)
# Create sparse mask (downsampled to block level)
...
...
@@ -288,9 +270,7 @@ def test_topk_sparse_attention():
downsample_len
=
math
.
ceil
(
SEQ_LEN
/
downsample_factor
)
print
(
"downsample_len"
,
downsample_len
)
x_ds
=
torch
.
randn
([
BATCH
,
N_HEADS
,
downsample_len
,
downsample_len
],
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
x_ds
=
torch
.
randn
([
BATCH
,
N_HEADS
,
downsample_len
,
downsample_len
],
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
x_ds
[:,
:,
:,
0
]
=
100
print
(
"x_ds.shape"
,
x_ds
.
shape
)
block_mask
=
get_sparse_attn_mask_from_topk
(
x_ds
,
topk
=
TOPK
)
...
...
@@ -302,22 +282,21 @@ def test_topk_sparse_attention():
# Compute reference
# Expand block mask to full attention matrix
full_mask
=
torch
.
kron
(
block_mask
.
float
(),
torch
.
ones
(
BLOCK
,
BLOCK
,
device
=
'
cuda
'
))
full_mask
=
torch
.
kron
(
block_mask
.
float
(),
torch
.
ones
(
BLOCK
,
BLOCK
,
device
=
"
cuda
"
))
full_mask
=
full_mask
[...,
:
SEQ_LEN
,
:
SEQ_LEN
].
bool
()
full_mask
=
full_mask
&
torch
.
tril
(
torch
.
ones_like
(
full_mask
))
# Apply causal
# PyTorch reference implementation
attn
=
torch
.
einsum
(
'
bhsd,bhtd->bhst
'
,
q
,
k
)
*
sm_scale
attn
=
attn
.
masked_fill
(
~
full_mask
,
float
(
'
-inf
'
))
attn
=
torch
.
einsum
(
"
bhsd,bhtd->bhst
"
,
q
,
k
)
*
sm_scale
attn
=
attn
.
masked_fill
(
~
full_mask
,
float
(
"
-inf
"
))
attn
=
F
.
softmax
(
attn
,
dim
=-
1
)
ref_output
=
torch
.
einsum
(
'
bhst,bhtd->bhsd
'
,
attn
,
v
)
ref_output
=
torch
.
einsum
(
"
bhst,bhtd->bhsd
"
,
attn
,
v
)
# print("ref_output", ref_output)
# print("triton_output", triton_output)
# Verify accuracy
assert
torch
.
allclose
(
triton_output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
\
"Triton output doesn't match reference"
assert
torch
.
allclose
(
triton_output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
"Triton output doesn't match reference"
print
(
"Pass topk sparse attention test with qlen == klen"
)
...
...
@@ -329,9 +308,9 @@ def test_topk_sparse_attention_qlt_kl():
torch
.
manual_seed
(
0
)
# Create inputs.
q
=
torch
.
randn
(
BATCH
,
N_HEADS
,
Q_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
bfloat16
)
k
=
torch
.
randn
(
BATCH
,
N_HEADS
,
K_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
bfloat16
)
v
=
torch
.
randn
(
BATCH
,
N_HEADS
,
K_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
bfloat16
)
q
=
torch
.
randn
(
BATCH
,
N_HEADS
,
Q_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
bfloat16
)
k
=
torch
.
randn
(
BATCH
,
N_HEADS
,
K_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
bfloat16
)
v
=
torch
.
randn
(
BATCH
,
N_HEADS
,
K_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
bfloat16
)
# softmax scale
sm_scale
=
1.0
/
(
D_HEAD
**
0.5
)
...
...
@@ -339,8 +318,7 @@ def test_topk_sparse_attention_qlt_kl():
print
(
"downsample_factor"
,
downsample_factor
)
downsample_len
=
math
.
ceil
(
K_LEN
/
downsample_factor
)
# number of blocks along one dimension
print
(
"downsample_len"
,
downsample_len
)
x_ds
=
torch
.
randn
(
BATCH
,
N_HEADS
,
downsample_len
,
downsample_len
,
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
x_ds
=
torch
.
randn
(
BATCH
,
N_HEADS
,
downsample_len
,
downsample_len
,
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
# Force the first column to be high so that the first block is always selected.
x_ds
[:,
:,
:,
0
]
=
100
block_mask
=
get_sparse_attn_mask_from_topk
(
x_ds
,
topk
=
TOPK
)
...
...
@@ -351,26 +329,25 @@ def test_topk_sparse_attention_qlt_kl():
past_len
=
K_LEN
-
Q_LEN
attn
=
torch
.
einsum
(
'
bhsd,bhtd->bhst
'
,
q
,
k
)
*
sm_scale
attn
=
torch
.
einsum
(
"
bhsd,bhtd->bhst
"
,
q
,
k
)
*
sm_scale
full_mask_full
=
torch
.
kron
(
block_mask
.
float
(),
torch
.
ones
(
BLOCK
,
BLOCK
,
device
=
'
cuda
'
)).
bool
()
full_mask_full
=
torch
.
kron
(
block_mask
.
float
(),
torch
.
ones
(
BLOCK
,
BLOCK
,
device
=
"
cuda
"
)).
bool
()
full_mask_full
=
full_mask_full
[...,
:
K_LEN
,
:
K_LEN
]
effective_mask
=
full_mask_full
[...,
past_len
:
K_LEN
,
:]
# shape: (B, H, Q_LEN, K_LEN)
i_global
=
torch
.
arange
(
past_len
,
K_LEN
,
device
=
k
.
device
).
unsqueeze
(
1
)
# shape: (Q_LEN, 1)
j_global
=
torch
.
arange
(
K_LEN
,
device
=
k
.
device
).
unsqueeze
(
0
)
# shape: (1, K_LEN)
causal_mask
=
(
j_global
<=
i_global
)
# shape: (Q_LEN, K_LEN)
causal_mask
=
j_global
<=
i_global
# shape: (Q_LEN, K_LEN)
final_mask
=
effective_mask
&
causal_mask
# shape: (B, H, Q_LEN, K_LEN)
attn
=
attn
.
masked_fill
(
~
final_mask
,
float
(
'
-inf
'
))
attn
=
attn
.
masked_fill
(
~
final_mask
,
float
(
"
-inf
"
))
attn
=
F
.
softmax
(
attn
,
dim
=-
1
)
ref_output
=
torch
.
einsum
(
'
bhst,bhtd->bhsd
'
,
attn
,
v
)
ref_output
=
torch
.
einsum
(
"
bhst,bhtd->bhsd
"
,
attn
,
v
)
# Verify accuracy.
assert
torch
.
allclose
(
triton_output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
\
"Triton output doesn't match reference when qlen < klen"
assert
torch
.
allclose
(
triton_output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
"Triton output doesn't match reference when qlen < klen"
print
(
"Pass topk sparse attention test with qlen < klen"
)
...
...
examples/blocksparse_attention/example_tilelang_block_sparse_attn.py
View file @
29051439
...
...
@@ -10,10 +10,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz
,
num_head
,
downsample_len
,
_
=
x
.
shape
# N_CTX = downsample_len * BLOCK
sparse_index
=
torch
.
topk
(
x
,
topk
,
dim
=-
1
).
indices
dense_mask
=
torch
.
full
([
bsz
,
num_head
,
downsample_len
,
downsample_len
],
False
,
dtype
=
torch
.
bool
,
device
=
x
.
device
)
dense_mask
=
torch
.
full
([
bsz
,
num_head
,
downsample_len
,
downsample_len
],
False
,
dtype
=
torch
.
bool
,
device
=
x
.
device
)
dense_mask
.
scatter_
(
-
1
,
sparse_index
,
True
)
if
use_dense_for_last_block
:
dense_mask
[:,
:,
-
2
:,
:]
=
True
...
...
@@ -30,15 +27,17 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
@
tilelang
.
jit
(
out_idx
=
[
4
],
pass_configs
=
{
out_idx
=
[
4
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
},
)
def
blocksparse_flashattn
(
batch
,
heads
,
seq_len
,
dim
,
downsample_len
,
is_causal
):
block_M
=
64
block_N
=
64
num_stages
=
1
threads
=
128
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
shape
=
[
batch
,
heads
,
seq_len
,
dim
]
block_mask_shape
=
[
batch
,
heads
,
downsample_len
,
downsample_len
]
...
...
@@ -47,7 +46,6 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
block_mask_dtype
=
"bool"
def
kernel_func
(
block_M
,
block_N
,
num_stages
,
threads
):
@
T
.
macro
def
MMA0
(
K
:
T
.
Tensor
(
shape
,
dtype
),
...
...
@@ -59,11 +57,10 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
by
:
T
.
int32
,
bz
:
T
.
int32
,
):
T
.
copy
(
K
[
bz
,
by
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
K_shared
)
T
.
copy
(
K
[
bz
,
by
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
K_shared
)
if
is_causal
:
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
bx
*
block_M
+
i
>=
k
*
block_N
+
j
,
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
bx
*
block_M
+
i
>=
k
*
block_N
+
j
,
0
,
-
T
.
infinity
(
acc_s
.
dtype
))
else
:
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
...
...
@@ -78,18 +75,18 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
by
:
T
.
int32
,
bz
:
T
.
int32
,
):
T
.
copy
(
V
[
bz
,
by
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
V_shared
)
T
.
copy
(
V
[
bz
,
by
,
k
*
block_N
:
(
k
+
1
)
*
block_N
,
:],
V_shared
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
@
T
.
macro
def
Softmax
(
acc_s
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
accum_dtype
),
acc_s_cast
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
dtype
),
scores_max
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_max_prev
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_scale
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_sum
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
logsum
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
acc_s
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
accum_dtype
),
acc_s_cast
:
T
.
FragmentBuffer
([
block_M
,
block_N
],
dtype
),
scores_max
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_max_prev
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_scale
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
scores_sum
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
logsum
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
):
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
...
...
@@ -113,22 +110,21 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
@
T
.
macro
def
Rescale
(
acc_o
:
T
.
FragmentBuffer
([
block_M
,
dim
],
accum_dtype
),
scores_scale
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
acc_o
:
T
.
FragmentBuffer
([
block_M
,
dim
],
accum_dtype
),
scores_scale
:
T
.
FragmentBuffer
([
block_M
],
accum_dtype
),
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
@
T
.
prim_func
def
blocksparse_flashattn
(
Q
:
T
.
Tensor
(
shape
,
dtype
),
K
:
T
.
Tensor
(
shape
,
dtype
),
V
:
T
.
Tensor
(
shape
,
dtype
),
BlockSparseMask
:
T
.
Tensor
(
block_mask_shape
,
block_mask_dtype
),
Output
:
T
.
Tensor
(
shape
,
dtype
),
Q
:
T
.
Tensor
(
shape
,
dtype
),
K
:
T
.
Tensor
(
shape
,
dtype
),
V
:
T
.
Tensor
(
shape
,
dtype
),
BlockSparseMask
:
T
.
Tensor
(
block_mask_shape
,
block_mask_dtype
),
Output
:
T
.
Tensor
(
shape
,
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
seq_len
,
block_M
),
heads
,
batch
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
T
.
ceildiv
(
seq_len
,
block_M
),
heads
,
batch
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
Q_shared
=
T
.
alloc_shared
([
block_M
,
dim
],
dtype
)
K_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
V_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
...
...
@@ -143,7 +139,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
logsum
=
T
.
alloc_fragment
([
block_M
],
accum_dtype
)
block_mask
=
T
.
alloc_local
([
downsample_len
],
block_mask_dtype
)
T
.
copy
(
Q
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:],
Q_shared
)
T
.
copy
(
Q
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:],
Q_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
...
...
@@ -152,20 +148,19 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
block_mask
[
vj
]
=
BlockSparseMask
[
bz
,
by
,
bx
,
vj
]
loop_range
=
(
T
.
min
(
T
.
ceildiv
(
seq_len
,
block_N
),
T
.
ceildiv
(
(
bx
+
1
)
*
block_M
,
block_N
))
if
is_causal
else
T
.
ceildiv
(
seq_len
,
block_N
)
)
T
.
min
(
T
.
ceildiv
(
seq_len
,
block_N
),
T
.
ceildiv
(
(
bx
+
1
)
*
block_M
,
block_N
))
if
is_causal
else
T
.
ceildiv
(
seq_len
,
block_N
)
)
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
if
block_mask
[
k
]
!=
0
:
MMA0
(
K
,
Q_shared
,
K_shared
,
acc_s
,
k
,
bx
,
by
,
bz
)
Softmax
(
acc_s
,
acc_s_cast
,
scores_max
,
scores_max_prev
,
scores_scale
,
scores_sum
,
logsum
)
Softmax
(
acc_s
,
acc_s_cast
,
scores_max
,
scores_max_prev
,
scores_scale
,
scores_sum
,
logsum
)
Rescale
(
acc_o
,
scores_scale
)
MMA1
(
V
,
V_shared
,
acc_s_cast
,
acc_o
,
k
,
by
,
bz
)
for
i
,
j
in
T
.
Parallel
(
block_M
,
dim
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
T
.
copy
(
acc_o
,
O_shared
)
T
.
copy
(
O_shared
,
Output
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:])
T
.
copy
(
O_shared
,
Output
[
bz
,
by
,
bx
*
block_M
:
(
bx
+
1
)
*
block_M
,
:])
return
blocksparse_flashattn
...
...
@@ -180,18 +175,16 @@ def test_topk_sparse_attention():
torch
.
manual_seed
(
0
)
# Create inputs
q
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
k
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
v
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
'
cuda
'
,
dtype
=
torch
.
float16
)
q
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
k
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
v
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"
cuda
"
,
dtype
=
torch
.
float16
)
sm_scale
=
1.0
/
(
D_HEAD
**
0.5
)
# Create sparse mask (downsampled to block level)
downsample_factor
=
BLOCK
downsample_len
=
math
.
ceil
(
SEQ_LEN
/
downsample_factor
)
x_ds
=
torch
.
randn
([
BATCH
,
N_HEADS
,
downsample_len
,
downsample_len
],
device
=
'cuda'
,
dtype
=
torch
.
bfloat16
)
x_ds
=
torch
.
randn
([
BATCH
,
N_HEADS
,
downsample_len
,
downsample_len
],
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
x_ds
[:,
:,
:,
0
]
=
100
block_mask
=
get_sparse_attn_mask_from_topk
(
x_ds
,
topk
=
TOPK
)
...
...
@@ -202,15 +195,15 @@ def test_topk_sparse_attention():
# Compute reference
# Expand block mask to full attention matrix
full_mask
=
torch
.
kron
(
block_mask
.
float
(),
torch
.
ones
(
BLOCK
,
BLOCK
,
device
=
'
cuda
'
))
full_mask
=
torch
.
kron
(
block_mask
.
float
(),
torch
.
ones
(
BLOCK
,
BLOCK
,
device
=
"
cuda
"
))
full_mask
=
full_mask
[...,
:
SEQ_LEN
,
:
SEQ_LEN
].
bool
()
full_mask
=
full_mask
&
torch
.
tril
(
torch
.
ones_like
(
full_mask
))
# Apply causal
# PyTorch reference implementation
attn
=
torch
.
einsum
(
'
bhsd,bhtd->bhst
'
,
q
,
k
)
*
sm_scale
attn
=
attn
.
masked_fill
(
~
full_mask
,
float
(
'
-inf
'
))
attn
=
torch
.
einsum
(
"
bhsd,bhtd->bhst
"
,
q
,
k
)
*
sm_scale
attn
=
attn
.
masked_fill
(
~
full_mask
,
float
(
"
-inf
"
))
attn
=
F
.
softmax
(
attn
,
dim
=-
1
)
ref_output
=
torch
.
einsum
(
'
bhst,bhtd->bhsd
'
,
attn
,
v
)
ref_output
=
torch
.
einsum
(
"
bhst,bhtd->bhsd
"
,
attn
,
v
)
print
(
"ref_output"
,
ref_output
)
print
(
"tilelang_output"
,
tilelang_output
)
...
...
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py
View file @
29051439
...
...
@@ -13,17 +13,20 @@ from heuristic import num_splits_heuristic
def
flashattn
(
batch
,
heads
,
heads_kv
,
dim
,
dim_v
):
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
dtype
=
"float16"
accum_dtype
=
"float"
kv_group_num
=
heads
//
heads_kv
@
tilelang
.
jit
(
out_idx
=
[
-
1
],
pass_configs
=
{
out_idx
=
[
-
1
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
def
kernel_func
(
block_N
,
block_H
,
page_block_size
,
num_split
,
num_stages
,
threads
,
num_pages
,
max_num_blocks_per_seq
,
max_selected_blocks
):
},
)
def
kernel_func
(
block_N
,
block_H
,
page_block_size
,
num_split
,
num_stages
,
threads
,
num_pages
,
max_num_blocks_per_seq
,
max_selected_blocks
):
shape_q
=
[
batch
,
heads
,
dim
]
shape_k
=
[
num_pages
,
page_block_size
,
heads_kv
,
dim
]
shape_v
=
[
num_pages
,
page_block_size
,
heads_kv
,
dim_v
]
...
...
@@ -37,17 +40,16 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
@
T
.
macro
def
flash_attn_split
(
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
block_indices
:
T
.
Tensor
(
shape_indices
,
"int32"
),
cache_seqlens
:
T
.
Tensor
([
batch
],
"int32"
),
block_table
:
T
.
Tensor
(
shape_block_table
,
"int32"
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
accum_dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
accum_dtype
),
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
block_indices
:
T
.
Tensor
(
shape_indices
,
"int32"
),
cache_seqlens
:
T
.
Tensor
([
batch
],
"int32"
),
block_table
:
T
.
Tensor
(
shape_block_table
,
"int32"
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
accum_dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
accum_dtype
),
):
with
T
.
Kernel
(
batch
,
heads
//
valid_block_H
,
num_split
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
batch
,
heads
//
valid_block_H
,
num_split
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
Q_shared
=
T
.
alloc_shared
([
block_H
,
dim
],
dtype
)
K_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
V_shared
=
T
.
alloc_shared
([
block_N
,
dim_v
],
dtype
)
...
...
@@ -67,7 +69,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
sid
=
bz
cur_kv_head
=
hid
//
(
kv_group_num
//
valid_block_H
)
T
.
copy
(
Q
[
bid
,
hid
*
valid_block_H
:
hid
*
valid_block_H
+
block_H
,
:],
Q_shared
)
T
.
copy
(
Q
[
bid
,
hid
*
valid_block_H
:
hid
*
valid_block_H
+
block_H
,
:],
Q_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
...
...
@@ -75,7 +77,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
num_blocks
=
max_selected_blocks
blocks_per_split
=
T
.
floordiv
(
num_blocks
,
num_split
)
remaining_blocks
=
T
.
floormod
(
num_blocks
,
num_split
)
loop_range
=
(
blocks_per_split
+
T
.
if_then_else
(
sid
<
remaining_blocks
,
1
,
0
)
)
loop_range
=
blocks_per_split
+
T
.
if_then_else
(
sid
<
remaining_blocks
,
1
,
0
)
start
=
blocks_per_split
*
sid
+
T
.
min
(
sid
,
remaining_blocks
)
has_valid_block
=
False
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
...
...
@@ -85,29 +87,20 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
block_table_idx
=
T
.
floordiv
(
logical_block_idx
,
block_ratio
)
block_tile_idx
=
T
.
floormod
(
logical_block_idx
,
block_ratio
)
physical_block_idx
=
block_table
[
bid
,
block_table_idx
]
T
.
copy
(
K
[
physical_block_idx
,
block_tile_idx
*
block_N
:(
block_tile_idx
+
1
)
*
block_N
,
cur_kv_head
,
:],
K_shared
)
T
.
copy
(
K
[
physical_block_idx
,
block_tile_idx
*
block_N
:
(
block_tile_idx
+
1
)
*
block_N
,
cur_kv_head
,
:],
K_shared
)
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
if
k
==
0
:
# assume block_indices is sorted in reverse order, otherwise, remove this if condition
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
logical_block_idx
*
block_N
+
j
>=
cache_seqlens
[
bid
],
-
T
.
infinity
(
accum_dtype
),
acc_s
[
i
,
j
]
)
logical_block_idx
*
block_N
+
j
>=
cache_seqlens
[
bid
],
-
T
.
infinity
(
accum_dtype
),
acc_s
[
i
,
j
]
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
for
i
in
T
.
Parallel
(
block_H
):
scores_max
[
i
]
=
T
.
max
(
scores_max
[
i
],
scores_max_prev
[
i
])
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
exp2
(
acc_s
[
i
,
j
]
*
scale
-
scores_max
[
i
]
*
scale
)
T
.
reduce_sum
(
acc_s
,
scores_sum
,
dim
=
1
)
...
...
@@ -116,10 +109,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
T
.
copy
(
acc_s
,
acc_s_cast
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim_v
):
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
T
.
copy
(
V
[
physical_block_idx
,
block_tile_idx
*
block_N
:(
block_tile_idx
+
1
)
*
block_N
,
cur_kv_head
,
:],
V_shared
)
T
.
copy
(
V
[
physical_block_idx
,
block_tile_idx
*
block_N
:
(
block_tile_idx
+
1
)
*
block_N
,
cur_kv_head
,
:],
V_shared
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
if
has_valid_block
:
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim_v
):
...
...
@@ -138,9 +128,9 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
@
T
.
macro
def
combine
(
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
accum_dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
accum_dtype
),
Output
:
T
.
Tensor
(
shape_o
,
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
accum_dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
accum_dtype
),
Output
:
T
.
Tensor
(
shape_o
,
dtype
),
):
with
T
.
Kernel
(
heads
,
batch
,
threads
=
128
)
as
(
by
,
bz
):
po_local
=
T
.
alloc_fragment
([
dim_v
],
accum_dtype
)
...
...
@@ -151,17 +141,18 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
scale_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
max_split
=
T
.
alloc_local
([
1
],
"int32"
)
T
.
annotate_layout
({
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
})
T
.
annotate_layout
(
{
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
}
)
T
.
clear
(
lse_logsum_local
)
T
.
clear
(
o_accum_local
)
lse_max_local
[
0
]
=
-
T
.
infinity
(
accum_dtype
)
for
k
in
T
.
serial
(
num_split
):
lse_local_split
[
0
]
=
glse
[
bz
,
by
,
k
]
if
(
lse_local_split
[
0
]
!=
0
)
:
if
lse_local_split
[
0
]
!=
0
:
max_split
[
0
]
=
k
lse_max_local
[
0
]
=
T
.
max
(
lse_max_local
[
0
],
glse
[
bz
,
by
,
k
])
...
...
@@ -183,18 +174,17 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
@
T
.
prim_func
def
main
(
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
block_indices
:
T
.
Tensor
(
shape_indices
,
"int32"
),
cache_seqlens
:
T
.
Tensor
([
batch
],
"int32"
),
block_table
:
T
.
Tensor
(
shape_block_table
,
"int32"
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
accum_dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
accum_dtype
),
Output
:
T
.
Tensor
(
shape_o
,
dtype
),
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
block_indices
:
T
.
Tensor
(
shape_indices
,
"int32"
),
cache_seqlens
:
T
.
Tensor
([
batch
],
"int32"
),
block_table
:
T
.
Tensor
(
shape_block_table
,
"int32"
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
accum_dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
accum_dtype
),
Output
:
T
.
Tensor
(
shape_o
,
dtype
),
):
flash_attn_split
(
Q
,
K
,
V
,
block_indices
,
cache_seqlens
,
block_table
,
glse
,
Output_partial
)
flash_attn_split
(
Q
,
K
,
V
,
block_indices
,
cache_seqlens
,
block_table
,
glse
,
Output_partial
)
combine
(
glse
,
Output_partial
,
Output
)
return
main
...
...
@@ -203,7 +193,6 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
class
SparseFlashAttn
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
batch
,
heads
,
heads_kv
,
dim
,
dim_v
,
page_block_size
,
block_N
,
num_pages
):
super
(
SparseFlashAttn
,
self
).
__init__
()
self
.
batch
=
batch
...
...
@@ -249,18 +238,11 @@ class SparseFlashAttn(torch.nn.Module):
num_sm
=
self
.
num_sm
num_split
=
num_splits_heuristic
(
total_mblocks
,
num_sm
,
num_n_blocks
,
num_m_blocks
,
size_one_kv_head
,
is_causal_or_local
=
True
,
max_splits
=
128
)
glse
=
torch
.
empty
((
batch
,
heads
,
num_split
),
dtype
=
torch
.
float32
,
device
=
'cuda'
)
output_partial
=
torch
.
empty
((
batch
,
heads
,
num_split
,
dim_v
),
dtype
=
torch
.
float32
,
device
=
'cuda'
)
total_mblocks
,
num_sm
,
num_n_blocks
,
num_m_blocks
,
size_one_kv_head
,
is_causal_or_local
=
True
,
max_splits
=
128
)
glse
=
torch
.
empty
((
batch
,
heads
,
num_split
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
output_partial
=
torch
.
empty
((
batch
,
heads
,
num_split
,
dim_v
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
output
=
self
.
kernel
(
query
,
...
...
@@ -275,14 +257,13 @@ class SparseFlashAttn(torch.nn.Module):
return
output
def
ref_program_torch_paged
(
query
,
key_cache
,
value_cache
,
block_indices
,
cache_seqlens
,
block_table
,
page_block_size
,
block_size
):
def
ref_program_torch_paged
(
query
,
key_cache
,
value_cache
,
block_indices
,
cache_seqlens
,
block_table
,
page_block_size
,
block_size
):
"""
Paged version of sparse attention reference implementation.
Args:
query: [batch, heads, dim]
key_cache: [num_pages, page_block_size, heads_kv, dim]
key_cache: [num_pages, page_block_size, heads_kv, dim]
value_cache: [num_pages, page_block_size, heads_kv, dim]
block_indices: [batch, heads_kv, max_selected_blocks] - logical block indices
cache_seqlens: [batch] - actual sequence lengths
...
...
@@ -298,12 +279,8 @@ def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_
# Reconstruct the full key and value tensors from paged cache
max_cache_seqlen
=
max
(
cache_seqlens
).
item
()
key_full
=
torch
.
zeros
((
batch
,
heads_kv
,
max_cache_seqlen
,
dim
),
dtype
=
key_cache
.
dtype
,
device
=
key_cache
.
device
)
value_full
=
torch
.
zeros
((
batch
,
heads_kv
,
max_cache_seqlen
,
dim_v
),
dtype
=
value_cache
.
dtype
,
device
=
value_cache
.
device
)
key_full
=
torch
.
zeros
((
batch
,
heads_kv
,
max_cache_seqlen
,
dim
),
dtype
=
key_cache
.
dtype
,
device
=
key_cache
.
device
)
value_full
=
torch
.
zeros
((
batch
,
heads_kv
,
max_cache_seqlen
,
dim_v
),
dtype
=
value_cache
.
dtype
,
device
=
value_cache
.
device
)
# Reconstruct full tensors from paged cache using block_table
for
b
in
range
(
batch
):
...
...
@@ -319,20 +296,14 @@ def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_
actual_block_size
=
end_token
-
start_token
# Copy from paged cache to full tensors
key_full
[
b
,
:,
start_token
:
end_token
,
:]
=
key_cache
[
physical_block_idx
,
:
actual_block_size
,
:,
:].
transpose
(
0
,
1
)
value_full
[
b
,
:,
start_token
:
end_token
,
:]
=
value_cache
[
physical_block_idx
,
:
actual_block_size
,
:,
:].
transpose
(
0
,
1
)
key_full
[
b
,
:,
start_token
:
end_token
,
:]
=
key_cache
[
physical_block_idx
,
:
actual_block_size
,
:,
:].
transpose
(
0
,
1
)
value_full
[
b
,
:,
start_token
:
end_token
,
:]
=
value_cache
[
physical_block_idx
,
:
actual_block_size
,
:,
:].
transpose
(
0
,
1
)
# Reshape query for grouped attention
query
=
rearrange
(
query
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, heads_kv, dim]
query
=
rearrange
(
query
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, heads_kv, dim]
# Compute attention scores
scores
=
einsum
(
query
,
key_full
,
'b g h d, b h s d -> b g h s'
)
# [batch_size, num_head_groups, heads_kv, seqlen_kv]
scores
=
einsum
(
query
,
key_full
,
"b g h d, b h s d -> b g h s"
)
# [batch_size, num_head_groups, heads_kv, seqlen_kv]
# Create sparse mask based on block_indices
sparse_mask
=
torch
.
zeros_like
(
scores
)
...
...
@@ -348,24 +319,23 @@ def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_
sparse_mask
[
b
,
:,
h
,
start_pos
:
end_pos
]
=
1
# Apply sparse mask
scores
=
scores
.
masked_fill
(
sparse_mask
==
0
,
float
(
'
-inf
'
))
scores
=
scores
.
masked_fill
(
sparse_mask
==
0
,
float
(
"
-inf
"
))
# Apply causal mask based on actual sequence lengths
range_len
=
torch
.
arange
(
scores
.
shape
[
-
1
],
device
=
scores
.
device
).
unsqueeze
(
0
)
cache_seqlens_expanded
=
cache_seqlens
.
unsqueeze
(
1
)
pad_mask
=
range_len
>=
cache_seqlens_expanded
pad_mask
=
pad_mask
[:,
None
,
None
,
:]
scores
=
scores
.
masked_fill
(
pad_mask
,
float
(
'
-inf
'
))
scores
=
scores
.
masked_fill
(
pad_mask
,
float
(
"
-inf
"
))
# Compute attention weights
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# Apply attention to values
out
=
einsum
(
attention
,
value_full
,
'b g h s, b h s d -> b g h d'
)
# [batch_size, num_head_groups, heads_kv, dim]
out
=
einsum
(
attention
,
value_full
,
"b g h s, b h s d -> b g h d"
)
# [batch_size, num_head_groups, heads_kv, dim]
# Reshape output back to original format
out
=
rearrange
(
out
,
'
b g h d -> b (h g) d
'
)
# [batch_size, heads, dim]
out
=
rearrange
(
out
,
"
b g h d -> b (h g) d
"
)
# [batch_size, heads, dim]
return
out
...
...
@@ -373,17 +343,23 @@ def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_
def
ref_program_fa
(
query
,
kcache
,
vcache
,
cache_seqlens
,
block_table
):
# latency reference
# from flash_attn_interface import flash_attn_with_kvcache # fa3
from
flash_attn
import
flash_attn_with_kvcache
#fa2
from
flash_attn
import
flash_attn_with_kvcache
# fa2
query
=
query
.
unsqueeze
(
1
)
output
=
flash_attn_with_kvcache
(
query
,
kcache
,
vcache
,
cache_seqlens
=
cache_seqlens
,
block_table
=
block_table
)
output
=
flash_attn_with_kvcache
(
query
,
kcache
,
vcache
,
cache_seqlens
=
cache_seqlens
,
block_table
=
block_table
)
output
=
output
.
squeeze
(
1
)
return
output
def
main
(
args
):
batch
,
heads
,
heads_kv
,
max_cache_seqlen
,
dim
,
dim_v
=
args
.
batch
,
args
.
heads
,
args
.
heads_kv
,
args
.
max_cache_seqlen
,
args
.
dim
,
args
.
dim_v
batch
,
heads
,
heads_kv
,
max_cache_seqlen
,
dim
,
dim_v
=
(
args
.
batch
,
args
.
heads
,
args
.
heads_kv
,
args
.
max_cache_seqlen
,
args
.
dim
,
args
.
dim_v
,
)
sparse_ratio
=
args
.
sparse_ratio
block_N
=
args
.
block_N
page_block_size
=
args
.
page_block_size
...
...
@@ -395,35 +371,30 @@ def main(args):
dtype
=
torch
.
float16
# Generate random inputs
Q
=
torch
.
randn
((
batch
,
heads
,
dim
),
dtype
=
dtype
,
device
=
'cuda'
)
cache_seqlens
=
torch
.
randint
(
max_cache_seqlen
//
2
,
max_cache_seqlen
+
1
,
(
batch
,),
dtype
=
torch
.
int32
,
device
=
'cuda'
)
Q
=
torch
.
randn
((
batch
,
heads
,
dim
),
dtype
=
dtype
,
device
=
"cuda"
)
cache_seqlens
=
torch
.
randint
(
max_cache_seqlen
//
2
,
max_cache_seqlen
+
1
,
(
batch
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
print
(
"cache_seqlens: "
,
cache_seqlens
)
K
=
torch
.
randn
((
batch
,
max_cache_seqlen
,
heads_kv
,
dim
),
dtype
=
dtype
,
device
=
'
cuda
'
)
V
=
torch
.
randn
((
batch
,
max_cache_seqlen
,
heads_kv
,
dim_v
),
dtype
=
dtype
,
device
=
'
cuda
'
)
K
=
torch
.
randn
((
batch
,
max_cache_seqlen
,
heads_kv
,
dim
),
dtype
=
dtype
,
device
=
"
cuda
"
)
V
=
torch
.
randn
((
batch
,
max_cache_seqlen
,
heads_kv
,
dim_v
),
dtype
=
dtype
,
device
=
"
cuda
"
)
# Create paged KV cache
K_cache
=
torch
.
zeros
((
num_blocks
,
page_block_size
,
heads_kv
,
dim
),
dtype
=
dtype
,
device
=
'cuda'
)
V_cache
=
torch
.
zeros
((
num_blocks
,
page_block_size
,
heads_kv
,
dim_v
),
dtype
=
dtype
,
device
=
'cuda'
)
K_cache
=
torch
.
zeros
((
num_blocks
,
page_block_size
,
heads_kv
,
dim
),
dtype
=
dtype
,
device
=
"cuda"
)
V_cache
=
torch
.
zeros
((
num_blocks
,
page_block_size
,
heads_kv
,
dim_v
),
dtype
=
dtype
,
device
=
"cuda"
)
# Create block table and block indices for dense case (all blocks selected)
max_num_blocks_per_seq
=
int
(
math
.
ceil
(
max_cache_seqlen
/
page_block_size
))
print
(
"max_num_blocks_per_seq: "
,
max_num_blocks_per_seq
)
block_table
=
torch
.
zeros
((
batch
,
max_num_blocks_per_seq
),
dtype
=
torch
.
int32
,
device
=
'cuda'
)
block_indices
=
torch
.
zeros
((
batch
,
heads_kv
,
max_selected_blocks
),
dtype
=
torch
.
int32
,
device
=
'cuda'
)
block_table
=
torch
.
zeros
((
batch
,
max_num_blocks_per_seq
),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
block_indices
=
torch
.
zeros
((
batch
,
heads_kv
,
max_selected_blocks
),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
# Fill block table and block indices and cache
# Create a pool of available physical blocks
total_blocks_needed
=
sum
(
int
(
math
.
ceil
(
cache_seqlens
[
seq_idx
].
item
()
/
page_block_size
))
for
seq_idx
in
range
(
batch
))
total_blocks_needed
=
sum
(
int
(
math
.
ceil
(
cache_seqlens
[
seq_idx
].
item
()
/
page_block_size
))
for
seq_idx
in
range
(
batch
))
available_blocks
=
list
(
range
(
total_blocks_needed
))
import
random
random
.
seed
(
42
)
# For reproducibility
random
.
shuffle
(
available_blocks
)
...
...
@@ -458,10 +429,8 @@ def main(args):
actual_block_size
=
end_token
-
start_token
# Copy K and V data to the paged cache
K_cache
[
physical_block_idx
,
:
actual_block_size
,
:,
:]
=
K
[
seq_idx
,
start_token
:
end_token
,
:,
:]
V_cache
[
physical_block_idx
,
:
actual_block_size
,
:,
:]
=
V
[
seq_idx
,
start_token
:
end_token
,
:,
:]
K_cache
[
physical_block_idx
,
:
actual_block_size
,
:,
:]
=
K
[
seq_idx
,
start_token
:
end_token
,
:,
:]
V_cache
[
physical_block_idx
,
:
actual_block_size
,
:,
:]
=
V
[
seq_idx
,
start_token
:
end_token
,
:,
:]
# Fill block_indices for sparse attention
# For dense case (verification), we select all blocks in reverse order
...
...
@@ -496,10 +465,9 @@ def main(args):
remaining_blocks
=
[
b
for
b
in
all_blocks
if
b
not
in
selected_blocks
]
if
remaining_blocks
:
import
random
random
.
seed
(
42
)
# For reproducibility
additional_blocks
=
random
.
sample
(
remaining_blocks
,
min
(
num_selected
-
recent_blocks
,
len
(
remaining_blocks
)))
additional_blocks
=
random
.
sample
(
remaining_blocks
,
min
(
num_selected
-
recent_blocks
,
len
(
remaining_blocks
)))
selected_blocks
.
extend
(
additional_blocks
)
# Sort selected blocks in reverse order (most recent first)
...
...
@@ -512,25 +480,20 @@ def main(args):
block_indices
[
seq_idx
,
head_idx
,
i
]
=
-
1
# Initialize sparse attention module
sparse_attn
=
SparseFlashAttn
(
batch
,
heads
,
heads_kv
,
dim
,
dim_v
,
page_block_size
,
block_N
,
num_blocks
)
output_sparse
=
sparse_attn
.
forward
(
Q
,
K_cache
,
V_cache
,
block_indices
,
cache_seqlens
,
block_table
)
sparse_attn
=
SparseFlashAttn
(
batch
,
heads
,
heads_kv
,
dim
,
dim_v
,
page_block_size
,
block_N
,
num_blocks
)
output_sparse
=
sparse_attn
.
forward
(
Q
,
K_cache
,
V_cache
,
block_indices
,
cache_seqlens
,
block_table
)
import
flash_attn
# noqa: F401
output_ref_torch
=
ref_program_torch_paged
(
Q
,
K_cache
,
V_cache
,
block_indices
,
cache_seqlens
,
block_table
,
page_block_size
,
block_N
)
output_ref_torch
=
ref_program_torch_paged
(
Q
,
K_cache
,
V_cache
,
block_indices
,
cache_seqlens
,
block_table
,
page_block_size
,
block_N
)
output_ref_fa
=
ref_program_fa
(
Q
,
K_cache
,
V_cache
,
cache_seqlens
,
block_table
)
# Check correctness
if
sparse_ratio
==
0.0
:
max_diff
=
torch
.
max
(
torch
.
abs
(
output_sparse
-
output_ref_fa
)).
item
()
mean_diff
=
torch
.
mean
(
torch
.
abs
(
output_sparse
-
output_ref_fa
)).
item
()
assert
torch
.
allclose
(
output_ref_fa
,
output_ref_torch
,
atol
=
1e-2
),
"Reference outputs do not match!"
assert
torch
.
allclose
(
output_ref_fa
,
output_ref_torch
,
atol
=
1e-2
),
"Reference outputs do not match!"
else
:
max_diff
=
torch
.
max
(
torch
.
abs
(
output_sparse
-
output_ref_torch
)).
item
()
mean_diff
=
torch
.
mean
(
torch
.
abs
(
output_sparse
-
output_ref_torch
)).
item
()
...
...
@@ -574,16 +537,15 @@ def main(args):
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--batch'
,
type
=
int
,
default
=
8
,
help
=
'batch size'
)
parser
.
add_argument
(
'--heads'
,
type
=
int
,
default
=
32
,
help
=
'heads'
)
parser
.
add_argument
(
'--heads_kv'
,
type
=
int
,
default
=
8
,
help
=
'heads_kv'
)
parser
.
add_argument
(
'--max_cache_seqlen'
,
type
=
int
,
default
=
8192
,
help
=
'kvcache sequence length'
)
parser
.
add_argument
(
'--dim'
,
type
=
int
,
default
=
128
,
help
=
'dim'
)
parser
.
add_argument
(
'--dim_v'
,
type
=
int
,
default
=
128
,
help
=
'dim_v'
)
parser
.
add_argument
(
'--sparse_ratio'
,
type
=
float
,
default
=
0.0
,
help
=
'sparse ratio'
)
parser
.
add_argument
(
'--block_N'
,
type
=
int
,
default
=
64
,
help
=
'block_N'
)
parser
.
add_argument
(
'--page_block_size'
,
type
=
int
,
default
=
256
,
help
=
'block size of pages'
)
parser
.
add_argument
(
'--num_pages'
,
type
=
int
,
default
=
1024
,
help
=
'total number of pages'
)
parser
.
add_argument
(
"--batch"
,
type
=
int
,
default
=
8
,
help
=
"batch size"
)
parser
.
add_argument
(
"--heads"
,
type
=
int
,
default
=
32
,
help
=
"heads"
)
parser
.
add_argument
(
"--heads_kv"
,
type
=
int
,
default
=
8
,
help
=
"heads_kv"
)
parser
.
add_argument
(
"--max_cache_seqlen"
,
type
=
int
,
default
=
8192
,
help
=
"kvcache sequence length"
)
parser
.
add_argument
(
"--dim"
,
type
=
int
,
default
=
128
,
help
=
"dim"
)
parser
.
add_argument
(
"--dim_v"
,
type
=
int
,
default
=
128
,
help
=
"dim_v"
)
parser
.
add_argument
(
"--sparse_ratio"
,
type
=
float
,
default
=
0.0
,
help
=
"sparse ratio"
)
parser
.
add_argument
(
"--block_N"
,
type
=
int
,
default
=
64
,
help
=
"block_N"
)
parser
.
add_argument
(
"--page_block_size"
,
type
=
int
,
default
=
256
,
help
=
"block size of pages"
)
parser
.
add_argument
(
"--num_pages"
,
type
=
int
,
default
=
1024
,
help
=
"total number of pages"
)
args
=
parser
.
parse_args
()
main
(
args
)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py
View file @
29051439
...
...
@@ -10,17 +10,18 @@ from heuristic import num_splits_heuristic
def
flashattn
(
batch
,
heads
,
heads_kv
,
dim
,
dim_v
):
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
dtype
=
"float16"
accum_dtype
=
"float"
kv_group_num
=
heads
//
heads_kv
@
tilelang
.
jit
(
out_idx
=
[
-
1
],
pass_configs
=
{
out_idx
=
[
-
1
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
}
)
def
kernel_func
(
block_N
,
block_H
,
num_split
,
num_stages
,
threads
,
max_cache_seqlen
,
max_selected_blocks
):
}
,
)
def
kernel_func
(
block_N
,
block_H
,
num_split
,
num_stages
,
threads
,
max_cache_seqlen
,
max_selected_blocks
):
shape_q
=
[
batch
,
heads
,
dim
]
shape_k
=
[
batch
,
max_cache_seqlen
,
heads_kv
,
dim
]
shape_v
=
[
batch
,
max_cache_seqlen
,
heads_kv
,
dim_v
]
...
...
@@ -31,17 +32,16 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
@
T
.
macro
def
flash_attn_split
(
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
block_indices
:
T
.
Tensor
(
shape_indices
,
"int32"
),
cache_seqlens
:
T
.
Tensor
([
batch
],
"int32"
),
# actual_num_blocks: T.Tensor([batch], "int32"),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
accum_dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
accum_dtype
),
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
block_indices
:
T
.
Tensor
(
shape_indices
,
"int32"
),
cache_seqlens
:
T
.
Tensor
([
batch
],
"int32"
),
# actual_num_blocks: T.Tensor([batch], "int32"),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
accum_dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
accum_dtype
),
):
with
T
.
Kernel
(
batch
,
heads
//
valid_block_H
,
num_split
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
batch
,
heads
//
valid_block_H
,
num_split
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
Q_shared
=
T
.
alloc_shared
([
block_H
,
dim
],
dtype
)
K_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
V_shared
=
T
.
alloc_shared
([
block_N
,
dim_v
],
dtype
)
...
...
@@ -62,7 +62,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
sid
=
bz
cur_kv_head
=
hid
//
(
kv_group_num
//
valid_block_H
)
T
.
copy
(
Q
[
bid
,
hid
*
valid_block_H
:
hid
*
valid_block_H
+
block_H
,
:],
Q_shared
)
T
.
copy
(
Q
[
bid
,
hid
*
valid_block_H
:
hid
*
valid_block_H
+
block_H
,
:],
Q_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
...
...
@@ -70,7 +70,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
num_blocks
=
max_selected_blocks
blocks_per_split
=
T
.
floordiv
(
num_blocks
,
num_split
)
remaining_blocks
=
T
.
floormod
(
num_blocks
,
num_split
)
loop_range
=
(
blocks_per_split
+
T
.
if_then_else
(
sid
<
remaining_blocks
,
1
,
0
)
)
loop_range
=
blocks_per_split
+
T
.
if_then_else
(
sid
<
remaining_blocks
,
1
,
0
)
start
=
blocks_per_split
*
sid
+
T
.
min
(
sid
,
remaining_blocks
)
has_valid_block
=
False
...
...
@@ -78,26 +78,18 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
i_s
=
block_indices
[
bid
,
cur_kv_head
,
start
+
k
]
if
i_s
>=
0
:
has_valid_block
=
True
T
.
copy
(
K
[
bid
,
i_s
*
block_N
:
(
i_s
+
1
)
*
block_N
,
cur_kv_head
,
:],
K_shared
)
T
.
copy
(
K
[
bid
,
i_s
*
block_N
:
(
i_s
+
1
)
*
block_N
,
cur_kv_head
,
:],
K_shared
)
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
if
k
==
0
:
# assume block_indices is sorted in reverse order, otherwise, remove this if condition
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
i_s
*
block_N
+
j
>=
cache_seqlens
[
bid
],
-
T
.
infinity
(
accum_dtype
),
acc_s
[
i
,
j
])
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
i_s
*
block_N
+
j
>=
cache_seqlens
[
bid
],
-
T
.
infinity
(
accum_dtype
),
acc_s
[
i
,
j
])
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
for
i
in
T
.
Parallel
(
block_H
):
scores_max
[
i
]
=
T
.
max
(
scores_max
[
i
],
scores_max_prev
[
i
])
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
exp2
(
acc_s
[
i
,
j
]
*
scale
-
scores_max
[
i
]
*
scale
)
T
.
reduce_sum
(
acc_s
,
scores_sum
,
dim
=
1
)
...
...
@@ -106,7 +98,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
T
.
copy
(
acc_s
,
acc_s_cast
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim_v
):
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
T
.
copy
(
V
[
bid
,
i_s
*
block_N
:
(
i_s
+
1
)
*
block_N
,
cur_kv_head
,
:],
V_shared
)
T
.
copy
(
V
[
bid
,
i_s
*
block_N
:
(
i_s
+
1
)
*
block_N
,
cur_kv_head
,
:],
V_shared
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
if
has_valid_block
:
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim_v
):
...
...
@@ -125,9 +117,9 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
@
T
.
macro
def
combine
(
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
accum_dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
accum_dtype
),
Output
:
T
.
Tensor
(
shape_o
,
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
accum_dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
accum_dtype
),
Output
:
T
.
Tensor
(
shape_o
,
dtype
),
):
with
T
.
Kernel
(
heads
,
batch
,
threads
=
128
)
as
(
by
,
bz
):
po_local
=
T
.
alloc_fragment
([
dim_v
],
accum_dtype
)
...
...
@@ -138,17 +130,18 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
scale_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
max_split
=
T
.
alloc_local
([
1
],
"int32"
)
T
.
annotate_layout
({
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
})
T
.
annotate_layout
(
{
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
}
)
T
.
clear
(
lse_logsum_local
)
T
.
clear
(
o_accum_local
)
lse_max_local
[
0
]
=
-
T
.
infinity
(
accum_dtype
)
for
k
in
T
.
serial
(
num_split
):
lse_local_split
[
0
]
=
glse
[
bz
,
by
,
k
]
if
(
lse_local_split
[
0
]
!=
0
)
:
if
lse_local_split
[
0
]
!=
0
:
max_split
[
0
]
=
k
lse_max_local
[
0
]
=
T
.
max
(
lse_max_local
[
0
],
glse
[
bz
,
by
,
k
])
...
...
@@ -170,15 +163,15 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
@
T
.
prim_func
def
main
(
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
block_indices
:
T
.
Tensor
(
shape_indices
,
"int32"
),
cache_seqlens
:
T
.
Tensor
([
batch
],
"int32"
),
# actual_num_blocks: T.Tensor([batch], "int32"),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
accum_dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
accum_dtype
),
Output
:
T
.
Tensor
(
shape_o
,
dtype
),
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
block_indices
:
T
.
Tensor
(
shape_indices
,
"int32"
),
cache_seqlens
:
T
.
Tensor
([
batch
],
"int32"
),
# actual_num_blocks: T.Tensor([batch], "int32"),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
accum_dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
accum_dtype
),
Output
:
T
.
Tensor
(
shape_o
,
dtype
),
):
# flash_attn_split(Q, K, V, block_indices, cache_seqlens, actual_num_blocks, glse, Output_partial)
flash_attn_split
(
Q
,
K
,
V
,
block_indices
,
cache_seqlens
,
glse
,
Output_partial
)
...
...
@@ -190,7 +183,6 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
class
SparseFlashAttn
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
batch
,
heads
,
heads_kv
,
dim
,
dim_v
,
block_size
):
super
(
SparseFlashAttn
,
self
).
__init__
()
self
.
batch
=
batch
...
...
@@ -209,7 +201,8 @@ class SparseFlashAttn(torch.nn.Module):
num_stages
=
2
,
threads
=
128
,
max_cache_seqlen
=
T
.
dynamic
(
"max_cache_seqlen"
),
max_selected_blocks
=
T
.
dynamic
(
"max_selected_blocks"
))
max_selected_blocks
=
T
.
dynamic
(
"max_selected_blocks"
),
)
props
=
torch
.
cuda
.
get_device_properties
(
torch
.
device
(
"cuda:0"
))
self
.
num_sm
=
props
.
multi_processor_count
...
...
@@ -232,25 +225,17 @@ class SparseFlashAttn(torch.nn.Module):
num_sm
=
self
.
num_sm
num_split
=
num_splits_heuristic
(
total_mblocks
,
num_sm
,
num_n_blocks
,
num_m_blocks
,
size_one_kv_head
,
is_causal_or_local
=
True
,
max_splits
=
128
)
glse
=
torch
.
empty
((
batch
,
heads
,
num_split
),
dtype
=
torch
.
float32
,
device
=
'cuda'
)
output_partial
=
torch
.
empty
((
batch
,
heads
,
num_split
,
dim_v
),
dtype
=
torch
.
float32
,
device
=
'cuda'
)
total_mblocks
,
num_sm
,
num_n_blocks
,
num_m_blocks
,
size_one_kv_head
,
is_causal_or_local
=
True
,
max_splits
=
128
)
glse
=
torch
.
empty
((
batch
,
heads
,
num_split
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
output_partial
=
torch
.
empty
((
batch
,
heads
,
num_split
,
dim_v
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
output
=
self
.
kernel
(
query
,
key
,
value
,
block_indices
,
cache_seqlens
,
glse
,
output_partial
)
return
output
def
sparse_gqa_decode_varlen_indice
(
query
,
key
,
value
,
block_indices
,
cache_seqlens
,
max_cache_seqlen
,
block_size
):
def
sparse_gqa_decode_varlen_indice
(
query
,
key
,
value
,
block_indices
,
cache_seqlens
,
max_cache_seqlen
,
block_size
):
"""
Args:
query: [batch, heads, dim]
...
...
@@ -272,31 +257,24 @@ def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seql
block_H
=
64
actual_num_blocks
=
torch
.
sum
(
block_indices
!=
-
1
,
dim
=-
1
).
to
(
torch
.
int32
)
actual_num_blocks
=
actual_num_blocks
[:,
0
]
#[batch], number of valid blocks, assume all groups in the same batch have the same number of blocks
actual_num_blocks
=
actual_num_blocks
[
:,
0
]
# [batch], number of valid blocks, assume all groups in the same batch have the same number of blocks
# get num_split
num_m_blocks
=
1
*
(
heads
//
heads_kv
+
block_H
-
1
)
//
block_H
num_n_blocks
=
max_selected_blocks
#(kv_seqlen + block_size - 1 ) // block_size
num_n_blocks
=
max_selected_blocks
#
(kv_seqlen + block_size - 1 ) // block_size
# num_n_blocks = torch.sum(actual_num_blocks, dim=-1).item() * heads_kv # total number of blocks
size_one_kv_head
=
max_selected_blocks
*
block_size
*
(
dim
+
dim_v
)
*
2
#kv_seqlen * (dim + dim_v) * 2
size_one_kv_head
=
max_selected_blocks
*
block_size
*
(
dim
+
dim_v
)
*
2
# kv_seqlen * (dim + dim_v) * 2
total_mblocks
=
batch
*
heads_kv
*
num_m_blocks
num_sm
=
132
num_split
=
num_splits_heuristic
(
total_mblocks
,
num_sm
,
num_n_blocks
,
num_m_blocks
,
size_one_kv_head
,
is_causal_or_local
=
True
,
max_splits
=
128
)
glse
=
torch
.
empty
((
batch
,
heads
,
num_split
),
dtype
=
torch
.
float32
,
device
=
'cuda'
)
Output_partial
=
torch
.
empty
((
batch
,
heads
,
num_split
,
dim_v
),
dtype
=
torch
.
float32
,
device
=
'cuda'
)
total_mblocks
,
num_sm
,
num_n_blocks
,
num_m_blocks
,
size_one_kv_head
,
is_causal_or_local
=
True
,
max_splits
=
128
)
glse
=
torch
.
empty
((
batch
,
heads
,
num_split
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
Output_partial
=
torch
.
empty
((
batch
,
heads
,
num_split
,
dim_v
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
kernel
=
flashattn
(
batch
,
heads
,
heads_kv
,
dim
,
dim_v
)(
block_N
=
block_size
,
block_H
=
block_H
,
...
...
@@ -304,29 +282,24 @@ def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seql
num_stages
=
2
,
threads
=
128
,
max_cache_seqlen
=
T
.
dynamic
(
"max_cache_seqlen"
),
max_selected_blocks
=
T
.
dynamic
(
"max_selected_blocks"
))
max_selected_blocks
=
T
.
dynamic
(
"max_selected_blocks"
),
)
output
=
kernel
(
query
,
key
,
value
,
block_indices
,
cache_seqlens
,
glse
,
Output_partial
)
return
output
def
ref_program_torch
(
query
,
key
,
value
,
block_indices
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
block_size
):
def
ref_program_torch
(
query
,
key
,
value
,
block_indices
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
block_size
):
batch
,
heads
,
dim
=
query
.
shape
heads_kv
=
key
.
shape
[
2
]
num_head_groups
=
query
.
shape
[
1
]
//
key
.
shape
[
2
]
scale
=
dim
**
0.5
key
=
rearrange
(
key
,
'
b n h d -> b h n d
'
)
# [batch_size, heads_kv, seqlen_kv, dim]
value
=
rearrange
(
value
,
'
b n h d -> b h n d
'
)
# [batch_size, heads_kv, seqlen_kv, dim]
key
=
rearrange
(
key
,
"
b n h d -> b h n d
"
)
# [batch_size, heads_kv, seqlen_kv, dim]
value
=
rearrange
(
value
,
"
b n h d -> b h n d
"
)
# [batch_size, heads_kv, seqlen_kv, dim]
query
=
rearrange
(
query
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, heads_kv, dim]
query
=
rearrange
(
query
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, heads_kv, dim]
scores
=
einsum
(
query
,
key
,
'b g h d, b h s d -> b g h s'
)
# [batch_size, num_head_groups, heads_kv, seqlen_kv]
scores
=
einsum
(
query
,
key
,
"b g h d, b h s d -> b g h s"
)
# [batch_size, num_head_groups, heads_kv, seqlen_kv]
sparse_mask
=
torch
.
zeros_like
(
scores
)
# Assign mask values based on block_indices
...
...
@@ -335,28 +308,26 @@ def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache
valid_indices
=
block_indices
[
b
,
h
]
# Extract indices for this batch and head
for
idx
in
valid_indices
:
if
idx
>=
0
:
sparse_mask
[
b
,
:,
h
,
idx
*
block_size
:
(
idx
+
1
)
*
block_size
]
=
1
scores
=
scores
.
masked_fill
(
sparse_mask
==
0
,
float
(
'
-inf
'
))
sparse_mask
[
b
,
:,
h
,
idx
*
block_size
:
(
idx
+
1
)
*
block_size
]
=
1
scores
=
scores
.
masked_fill
(
sparse_mask
==
0
,
float
(
"
-inf
"
))
range_len
=
torch
.
arange
(
scores
.
shape
[
-
1
],
device
=
'
cuda
'
).
unsqueeze
(
0
)
range_len
=
torch
.
arange
(
scores
.
shape
[
-
1
],
device
=
"
cuda
"
).
unsqueeze
(
0
)
cache_seqlens_expanded
=
cache_seqlens
.
unsqueeze
(
1
)
pad_mask
=
range_len
>=
cache_seqlens_expanded
pad_mask
=
pad_mask
[:,
None
,
None
,
:]
scores
=
scores
.
masked_fill
(
pad_mask
,
float
(
'-inf'
))
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, heads_kv, seqlen_kv]
scores
=
scores
.
masked_fill
(
pad_mask
,
float
(
"-inf"
))
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, heads_kv, seqlen_kv]
out
=
einsum
(
attention
,
value
,
'b g h s, b h s d -> b g h d'
)
# [batch_size, num_head_groups, heads_kv, dim]
out
=
rearrange
(
out
,
'b g h d -> b (h g) d'
)
# [batch_size, heads, dim]
out
=
einsum
(
attention
,
value
,
"b g h s, b h s d -> b g h d"
)
# [batch_size, num_head_groups, heads_kv, dim]
out
=
rearrange
(
out
,
"b g h d -> b (h g) d"
)
# [batch_size, heads, dim]
return
out
def
ref_program_fa
(
query
,
key
,
value
,
block_indices
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
block_size
):
def
ref_program_fa
(
query
,
key
,
value
,
block_indices
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
block_size
):
# latency reference
# from flash_attn_interface import flash_attn_with_kvcache # fa3
from
flash_attn
import
flash_attn_with_kvcache
#fa2
from
flash_attn
import
flash_attn_with_kvcache
# fa2
query
=
query
.
unsqueeze
(
1
)
output
=
flash_attn_with_kvcache
(
query
,
key
,
value
,
cache_seqlens
=
cache_seqlens
)
output
=
output
.
squeeze
(
1
)
...
...
@@ -368,23 +339,13 @@ def debug(name, expect, actual, atol=1e-3, rtol=1e-3):
print
(
name
+
" all_close={}"
.
format
(
all_close
))
if
not
all_close
:
diff
=
(
expect
-
actual
).
abs
()
print
(
"all_close={}, max={}, min={}, mean={}"
.
format
(
all_close
,
diff
.
max
().
item
(),
diff
.
min
().
item
(),
diff
.
mean
().
item
()))
print
(
"all_close={}, max={}, min={}, mean={}"
.
format
(
all_close
,
diff
.
max
().
item
(),
diff
.
min
().
item
(),
diff
.
mean
().
item
()))
max_indices
=
torch
.
nonzero
(
diff
==
diff
.
max
().
item
())
first_index
=
tuple
(
max_indices
[
0
].
tolist
())
print
(
f
"Index:
{
first_index
}
, expect:
{
expect
[
first_index
]
}
, actual:
{
actual
[
first_index
]
}
"
)
def
main
(
batch
=
8
,
heads
=
32
,
heads_kv
=
8
,
max_cache_seqlen
=
8192
,
dim
=
128
,
dim_v
=
128
,
sparse_ratio
=
0.8
,
block_size
=
32
):
def
main
(
batch
=
8
,
heads
=
32
,
heads_kv
=
8
,
max_cache_seqlen
=
8192
,
dim
=
128
,
dim_v
=
128
,
sparse_ratio
=
0.8
,
block_size
=
32
):
batch
,
heads
,
heads_kv
,
max_cache_seqlen
,
dim
,
dim_v
=
batch
,
heads
,
heads_kv
,
max_cache_seqlen
,
dim
,
dim_v
sparse_ratio
=
sparse_ratio
block_size
=
block_size
...
...
@@ -392,10 +353,10 @@ def main(batch=8,
print
(
"max_selected_blocks: "
,
max_selected_blocks
)
dtype
=
torch
.
float16
Q
=
torch
.
randn
((
batch
,
heads
,
dim
),
dtype
=
dtype
,
device
=
'
cuda
'
)
K
=
torch
.
randn
((
batch
,
max_cache_seqlen
,
heads_kv
,
dim
),
dtype
=
dtype
,
device
=
'
cuda
'
)
V
=
torch
.
randn
((
batch
,
max_cache_seqlen
,
heads_kv
,
dim_v
),
dtype
=
dtype
,
device
=
'
cuda
'
)
cache_seqlens
=
torch
.
randint
(
1
,
max_cache_seqlen
,
(
batch
,),
dtype
=
torch
.
int32
,
device
=
'
cuda
'
)
Q
=
torch
.
randn
((
batch
,
heads
,
dim
),
dtype
=
dtype
,
device
=
"
cuda
"
)
K
=
torch
.
randn
((
batch
,
max_cache_seqlen
,
heads_kv
,
dim
),
dtype
=
dtype
,
device
=
"
cuda
"
)
V
=
torch
.
randn
((
batch
,
max_cache_seqlen
,
heads_kv
,
dim_v
),
dtype
=
dtype
,
device
=
"
cuda
"
)
cache_seqlens
=
torch
.
randint
(
1
,
max_cache_seqlen
,
(
batch
,),
dtype
=
torch
.
int32
,
device
=
"
cuda
"
)
# cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda')
# # Ensure at least one element equals cache_seqlen
# random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index
...
...
@@ -406,10 +367,7 @@ def main(batch=8,
max_valid_num_blocks
=
torch
.
ceil
(
cache_seqlens
/
block_size
).
int
()
print
(
"max_valid_num_blocks: "
,
max_valid_num_blocks
)
# Initialize block_indices with -1 (for padding blocks)
block_indices
=
torch
.
full
((
batch
,
heads_kv
,
max_selected_blocks
),
-
1
,
dtype
=
torch
.
int32
,
device
=
'cuda'
)
block_indices
=
torch
.
full
((
batch
,
heads_kv
,
max_selected_blocks
),
-
1
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
# max_num_blocks = int((max_cache_seqlen + block_size - 1)/ block_size)
# block_indices = torch.full((batch, heads_kv, max_num_blocks), -1, dtype=torch.int32, device='cuda')
...
...
@@ -418,10 +376,9 @@ def main(batch=8,
max_valid_block
=
max_valid_num_blocks
[
b
].
item
()
# Max valid blocks for this batch
if
max_valid_block
>
0
:
# Ensure there's at least one valid block
for
h
in
range
(
heads_kv
):
valid_indices
=
torch
.
randperm
(
max_valid_block
,
device
=
'cuda'
,
dtype
=
torch
.
int32
)[:
max_selected_blocks
]
valid_indices
=
torch
.
randperm
(
max_valid_block
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)[:
max_selected_blocks
]
# valid_indices = torch.randperm(max_valid_block, device='cuda', dtype=torch.int32)[:max_num_blocks]
block_indices
[
b
,
h
,
:
len
(
valid_indices
)]
=
valid_indices
block_indices
[
b
,
h
,
:
len
(
valid_indices
)]
=
valid_indices
# Sort indices within each batch-group for consistency
block_indices
,
_
=
block_indices
.
sort
(
dim
=-
1
,
descending
=
True
)
...
...
@@ -434,8 +391,7 @@ def main(batch=8,
print
(
"max_num_blocks: "
,
max_num_blocks
)
# parity reference
ref
=
ref_program_torch
(
Q
,
K
,
V
,
block_indices
,
cache_seqlens
,
max_cache_seqlen
,
max_num_blocks
,
block_size
)
ref
=
ref_program_torch
(
Q
,
K
,
V
,
block_indices
,
cache_seqlens
,
max_cache_seqlen
,
max_num_blocks
,
block_size
)
sparse_kernel
=
SparseFlashAttn
(
batch
,
heads
,
heads_kv
,
dim
,
dim_v
,
block_size
)
out
=
sparse_kernel
(
Q
,
K
,
V
,
block_indices
,
cache_seqlens
)
...
...
@@ -445,13 +401,11 @@ def main(batch=8,
## latency reference
for
_
in
range
(
10
):
ref
=
ref_program_fa
(
Q
,
K
,
V
,
block_indices
,
cache_seqlens
,
max_cache_seqlen
,
max_num_blocks
,
block_size
)
ref
=
ref_program_fa
(
Q
,
K
,
V
,
block_indices
,
cache_seqlens
,
max_cache_seqlen
,
max_num_blocks
,
block_size
)
torch
.
cuda
.
synchronize
()
start
=
time
.
time
()
for
_
in
range
(
100
):
ref
=
ref_program_fa
(
Q
,
K
,
V
,
block_indices
,
cache_seqlens
,
max_cache_seqlen
,
max_num_blocks
,
block_size
)
ref
=
ref_program_fa
(
Q
,
K
,
V
,
block_indices
,
cache_seqlens
,
max_cache_seqlen
,
max_num_blocks
,
block_size
)
torch
.
cuda
.
synchronize
()
print
(
"dense time: "
,
(
time
.
time
()
-
start
)
/
100
*
1000
)
...
...
@@ -469,15 +423,13 @@ def main(batch=8,
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--batch'
,
type
=
int
,
default
=
8
,
help
=
'batch size'
)
parser
.
add_argument
(
'--heads'
,
type
=
int
,
default
=
32
,
help
=
'heads'
)
parser
.
add_argument
(
'--heads_kv'
,
type
=
int
,
default
=
8
,
help
=
'heads_kv'
)
parser
.
add_argument
(
'--max_cache_seqlen'
,
type
=
int
,
default
=
8192
,
help
=
'kvcache sequence length'
)
parser
.
add_argument
(
'--dim'
,
type
=
int
,
default
=
128
,
help
=
'dim'
)
parser
.
add_argument
(
'--dim_v'
,
type
=
int
,
default
=
128
,
help
=
'dim_v'
)
parser
.
add_argument
(
'--sparse_ratio'
,
type
=
float
,
default
=
0.8
,
help
=
'sparse ratio'
)
parser
.
add_argument
(
'--block_size'
,
type
=
int
,
default
=
32
,
help
=
'block_size'
)
parser
.
add_argument
(
"--batch"
,
type
=
int
,
default
=
8
,
help
=
"batch size"
)
parser
.
add_argument
(
"--heads"
,
type
=
int
,
default
=
32
,
help
=
"heads"
)
parser
.
add_argument
(
"--heads_kv"
,
type
=
int
,
default
=
8
,
help
=
"heads_kv"
)
parser
.
add_argument
(
"--max_cache_seqlen"
,
type
=
int
,
default
=
8192
,
help
=
"kvcache sequence length"
)
parser
.
add_argument
(
"--dim"
,
type
=
int
,
default
=
128
,
help
=
"dim"
)
parser
.
add_argument
(
"--dim_v"
,
type
=
int
,
default
=
128
,
help
=
"dim_v"
)
parser
.
add_argument
(
"--sparse_ratio"
,
type
=
float
,
default
=
0.8
,
help
=
"sparse ratio"
)
parser
.
add_argument
(
"--block_size"
,
type
=
int
,
default
=
32
,
help
=
"block_size"
)
args
=
parser
.
parse_args
()
main
(
args
.
batch
,
args
.
heads
,
args
.
heads_kv
,
args
.
max_cache_seqlen
,
args
.
dim
,
args
.
dim_v
,
args
.
sparse_ratio
,
args
.
block_size
)
main
(
args
.
batch
,
args
.
heads
,
args
.
heads_kv
,
args
.
max_cache_seqlen
,
args
.
dim
,
args
.
dim_v
,
args
.
sparse_ratio
,
args
.
block_size
)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py
View file @
29051439
...
...
@@ -12,15 +12,17 @@ from heuristic import num_splits_heuristic
def
flashattn
(
batch
,
heads
,
heads_kv
,
dim
,
dim_v
):
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
dtype
=
"float16"
accum_dtype
=
"float"
kv_group_num
=
heads
//
heads_kv
@
tilelang
.
jit
(
out_idx
=
[
-
1
],
pass_configs
=
{
out_idx
=
[
-
1
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
})
},
)
def
kernel_func
(
block_N
,
block_H
,
num_split
,
num_stages
,
threads
,
max_cache_seqlen
,
num_blocks
):
shape_q
=
[
batch
,
heads
,
dim
]
shape_k
=
[
batch
,
max_cache_seqlen
,
heads_kv
,
dim
]
...
...
@@ -32,16 +34,15 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
@
T
.
macro
def
flash_attn_split
(
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
block_mask
:
T
.
Tensor
(
shape_mask
,
"bool"
),
cache_seqlens
:
T
.
Tensor
([
batch
],
"int32"
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
accum_dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
accum_dtype
),
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
block_mask
:
T
.
Tensor
(
shape_mask
,
"bool"
),
cache_seqlens
:
T
.
Tensor
([
batch
],
"int32"
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
accum_dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
accum_dtype
),
):
with
T
.
Kernel
(
batch
,
heads
//
valid_block_H
,
num_split
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
batch
,
heads
//
valid_block_H
,
num_split
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
Q_shared
=
T
.
alloc_shared
([
block_H
,
dim
],
dtype
)
K_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
V_shared
=
T
.
alloc_shared
([
block_N
,
dim_v
],
dtype
)
...
...
@@ -62,39 +63,31 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
sid
=
bz
cur_kv_head
=
hid
//
(
kv_group_num
//
valid_block_H
)
T
.
copy
(
Q
[
bid
,
hid
*
valid_block_H
:
hid
*
valid_block_H
+
block_H
,
:],
Q_shared
)
T
.
copy
(
Q
[
bid
,
hid
*
valid_block_H
:
hid
*
valid_block_H
+
block_H
,
:],
Q_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
blocks_per_split
=
T
.
floordiv
(
num_blocks
,
num_split
)
remaining_blocks
=
T
.
floormod
(
num_blocks
,
num_split
)
loop_range
=
(
blocks_per_split
+
T
.
if_then_else
(
sid
<
remaining_blocks
,
1
,
0
)
)
loop_range
=
blocks_per_split
+
T
.
if_then_else
(
sid
<
remaining_blocks
,
1
,
0
)
start
=
blocks_per_split
*
sid
+
T
.
min
(
sid
,
remaining_blocks
)
has_valid_block
=
False
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
if
block_mask
[
bid
,
hid
,
start
+
k
]:
has_valid_block
=
True
T
.
copy
(
K
[
bid
,
(
start
+
k
)
*
block_N
:(
start
+
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
K_shared
)
T
.
copy
(
K
[
bid
,
(
start
+
k
)
*
block_N
:
(
start
+
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
K_shared
)
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
(
start
+
k
)
*
block_N
+
j
>=
cache_seqlens
[
bx
],
-
T
.
infinity
(
accum_dtype
),
acc_s
[
i
,
j
]
)
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
(
start
+
k
)
*
block_N
+
j
>=
cache_seqlens
[
bx
],
-
T
.
infinity
(
accum_dtype
),
acc_s
[
i
,
j
]
)
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
for
i
in
T
.
Parallel
(
block_H
):
scores_max
[
i
]
=
T
.
max
(
scores_max
[
i
],
scores_max_prev
[
i
])
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
exp2
(
acc_s
[
i
,
j
]
*
scale
-
scores_max
[
i
]
*
scale
)
T
.
reduce_sum
(
acc_s
,
scores_sum
,
dim
=
1
)
...
...
@@ -103,9 +96,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
T
.
copy
(
acc_s
,
acc_s_cast
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim_v
):
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
T
.
copy
(
V
[
bid
,
(
start
+
k
)
*
block_N
:(
start
+
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
V_shared
)
T
.
copy
(
V
[
bid
,
(
start
+
k
)
*
block_N
:
(
start
+
k
+
1
)
*
block_N
,
cur_kv_head
,
:],
V_shared
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
if
has_valid_block
:
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim_v
):
...
...
@@ -123,9 +114,9 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
@
T
.
macro
def
combine
(
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
accum_dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
accum_dtype
),
Output
:
T
.
Tensor
(
shape_o
,
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
accum_dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
accum_dtype
),
Output
:
T
.
Tensor
(
shape_o
,
dtype
),
):
with
T
.
Kernel
(
heads
,
batch
,
threads
=
128
)
as
(
by
,
bz
):
po_local
=
T
.
alloc_fragment
([
dim_v
],
accum_dtype
)
...
...
@@ -135,10 +126,11 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
lse_max_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
scale_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
T
.
annotate_layout
({
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
})
T
.
annotate_layout
(
{
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
}
)
T
.
clear
(
lse_logsum_local
)
T
.
clear
(
o_accum_local
)
...
...
@@ -161,14 +153,14 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
@
T
.
prim_func
def
main
(
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
block_mask
:
T
.
Tensor
(
shape_mask
,
"bool"
),
cache_seqlens
:
T
.
Tensor
([
batch
],
"int32"
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
accum_dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
accum_dtype
),
Output
:
T
.
Tensor
(
shape_o
,
dtype
),
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
block_mask
:
T
.
Tensor
(
shape_mask
,
"bool"
),
cache_seqlens
:
T
.
Tensor
([
batch
],
"int32"
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
accum_dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
accum_dtype
),
Output
:
T
.
Tensor
(
shape_o
,
dtype
),
):
flash_attn_split
(
Q
,
K
,
V
,
block_mask
,
cache_seqlens
,
glse
,
Output_partial
)
combine
(
glse
,
Output_partial
,
Output
)
...
...
@@ -179,7 +171,6 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
class
SparseFlashAttn
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
batch
,
heads
,
heads_kv
,
dim
,
dim_v
,
block_size
):
super
(
SparseFlashAttn
,
self
).
__init__
()
self
.
batch
=
batch
...
...
@@ -198,7 +189,8 @@ class SparseFlashAttn(torch.nn.Module):
num_stages
=
2
,
threads
=
128
,
max_cache_seqlen
=
T
.
dynamic
(
"max_cache_seqlen"
),
num_blocks
=
T
.
dynamic
(
"num_blocks"
))
num_blocks
=
T
.
dynamic
(
"num_blocks"
),
)
props
=
torch
.
cuda
.
get_device_properties
(
torch
.
device
(
"cuda:0"
))
self
.
num_sm
=
props
.
multi_processor_count
...
...
@@ -217,24 +209,16 @@ class SparseFlashAttn(torch.nn.Module):
num_m_blocks
=
1
*
(
heads
//
heads_kv
+
block_H
-
1
)
//
block_H
num_n_blocks
=
max_selected_blocks
size_one_kv_head
=
max_selected_blocks
*
block_size
*
(
dim
+
dim_v
)
*
2
#kv_seqlen * (dim + dim_v) * 2
size_one_kv_head
=
max_selected_blocks
*
block_size
*
(
dim
+
dim_v
)
*
2
# kv_seqlen * (dim + dim_v) * 2
total_mblocks
=
batch
*
heads_kv
*
num_m_blocks
# num_sm = 132
num_sm
=
self
.
num_sm
num_split
=
num_splits_heuristic
(
total_mblocks
,
num_sm
,
num_n_blocks
,
num_m_blocks
,
size_one_kv_head
,
is_causal_or_local
=
True
,
max_splits
=
128
)
total_mblocks
,
num_sm
,
num_n_blocks
,
num_m_blocks
,
size_one_kv_head
,
is_causal_or_local
=
True
,
max_splits
=
128
)
# print("num_split: ", num_split)
glse
=
torch
.
empty
((
batch
,
heads
,
num_split
),
dtype
=
torch
.
float32
,
device
=
'cuda'
)
Output_partial
=
torch
.
empty
((
batch
,
heads
,
num_split
,
dim_v
),
dtype
=
torch
.
float32
,
device
=
'cuda'
)
glse
=
torch
.
empty
((
batch
,
heads
,
num_split
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
Output_partial
=
torch
.
empty
((
batch
,
heads
,
num_split
,
dim_v
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
output
=
self
.
kernel
(
query
,
key
,
value
,
block_mask
,
cache_seqlens
,
glse
,
Output_partial
)
return
output
...
...
@@ -259,26 +243,21 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens,
block_H
=
64
actual_num_blocks
=
torch
.
sum
(
block_mask
,
dim
=-
1
).
to
(
torch
.
int32
)
actual_num_blocks
=
actual_num_blocks
[:,
0
]
#[batch], number of valid blocks, assume all groups in the same batch have the same number of blocks
actual_num_blocks
=
actual_num_blocks
[
:,
0
]
# [batch], number of valid blocks, assume all groups in the same batch have the same number of blocks
max_selected_blocks
=
actual_num_blocks
.
max
().
item
()
# get num_split
num_m_blocks
=
1
*
(
heads
//
heads_kv
+
block_H
-
1
)
//
block_H
num_n_blocks
=
max_selected_blocks
#(kv_seqlen + block_size - 1 ) // block_size
num_n_blocks
=
max_selected_blocks
#
(kv_seqlen + block_size - 1 ) // block_size
# num_n_blocks = torch.sum(actual_num_blocks, dim=-1).item() * heads_kv # total number of blocks
size_one_kv_head
=
max_selected_blocks
*
block_size
*
(
dim
+
dim_v
)
*
2
#kv_seqlen * (dim + dim_v) * 2
size_one_kv_head
=
max_selected_blocks
*
block_size
*
(
dim
+
dim_v
)
*
2
# kv_seqlen * (dim + dim_v) * 2
total_mblocks
=
batch
*
heads_kv
*
num_m_blocks
num_sm
=
132
num_split
=
num_splits_heuristic
(
total_mblocks
,
num_sm
,
num_n_blocks
,
num_m_blocks
,
size_one_kv_head
,
is_causal_or_local
=
True
,
max_splits
=
128
)
total_mblocks
,
num_sm
,
num_n_blocks
,
num_m_blocks
,
size_one_kv_head
,
is_causal_or_local
=
True
,
max_splits
=
128
)
kernel
=
flashattn
(
batch
,
heads
,
heads_kv
,
dim
,
dim_v
)(
block_N
=
block_size
,
...
...
@@ -287,11 +266,10 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens,
num_stages
=
2
,
threads
=
128
,
max_cache_seqlen
=
T
.
dynamic
(
"max_cache_seqlen"
),
num_blocks
=
T
.
dynamic
(
"num_blocks"
))
glse
=
torch
.
empty
((
batch
,
heads
,
num_split
),
dtype
=
torch
.
float32
,
device
=
'cuda'
)
Output_partial
=
torch
.
empty
((
batch
,
heads
,
num_split
,
dim_v
),
dtype
=
torch
.
float32
,
device
=
'cuda'
)
num_blocks
=
T
.
dynamic
(
"num_blocks"
),
)
glse
=
torch
.
empty
((
batch
,
heads
,
num_split
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
Output_partial
=
torch
.
empty
((
batch
,
heads
,
num_split
,
dim_v
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
# print(kernel.get_kernel_source())
output
=
kernel
(
query
,
key
,
value
,
block_mask
,
cache_seqlens
,
glse
,
Output_partial
)
...
...
@@ -299,24 +277,18 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens,
return
output
def
ref_program_torch
(
query
,
key
,
value
,
block_mask
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
block_size
):
def
ref_program_torch
(
query
,
key
,
value
,
block_mask
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
block_size
):
batch
,
heads
,
dim
=
query
.
shape
heads_kv
=
key
.
shape
[
2
]
num_head_groups
=
query
.
shape
[
1
]
//
key
.
shape
[
2
]
scale
=
dim
**
0.5
key
=
rearrange
(
key
,
'
b n h d -> b h n d
'
)
# [batch_size, heads_kv, seqlen_kv, dim]
value
=
rearrange
(
value
,
'
b n h d -> b h n d
'
)
# [batch_size, heads_kv, seqlen_kv, dim]
key
=
rearrange
(
key
,
"
b n h d -> b h n d
"
)
# [batch_size, heads_kv, seqlen_kv, dim]
value
=
rearrange
(
value
,
"
b n h d -> b h n d
"
)
# [batch_size, heads_kv, seqlen_kv, dim]
query
=
rearrange
(
query
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, heads_kv, dim]
query
=
rearrange
(
query
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, heads_kv, dim]
scores
=
einsum
(
query
,
key
,
'b g h d, b h s d -> b g h s'
)
# [batch_size, num_head_groups, heads_kv, seqlen_kv]
scores
=
einsum
(
query
,
key
,
"b g h d, b h s d -> b g h s"
)
# [batch_size, num_head_groups, heads_kv, seqlen_kv]
sparse_mask
=
torch
.
zeros_like
(
scores
)
# Assign mask values
...
...
@@ -324,29 +296,27 @@ def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_se
for
h
in
range
(
heads_kv
):
for
idx
in
range
(
num_blocks
):
if
block_mask
[
b
,
h
,
idx
]:
sparse_mask
[
b
,
:,
h
,
idx
*
block_size
:
(
idx
+
1
)
*
block_size
]
=
1
sparse_mask
[
b
,
:,
h
,
idx
*
block_size
:
(
idx
+
1
)
*
block_size
]
=
1
scores
=
scores
.
masked_fill
(
sparse_mask
==
0
,
float
(
'
-inf
'
))
scores
=
scores
.
masked_fill
(
sparse_mask
==
0
,
float
(
"
-inf
"
))
range_len
=
torch
.
arange
(
scores
.
shape
[
-
1
],
device
=
'
cuda
'
).
unsqueeze
(
0
)
range_len
=
torch
.
arange
(
scores
.
shape
[
-
1
],
device
=
"
cuda
"
).
unsqueeze
(
0
)
cache_seqlens_expanded
=
cache_seqlens
.
unsqueeze
(
1
)
pad_mask
=
range_len
>=
cache_seqlens_expanded
pad_mask
=
pad_mask
[:,
None
,
None
,
:]
scores
=
scores
.
masked_fill
(
pad_mask
,
float
(
'-inf'
))
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, heads_kv, seqlen_kv]
scores
=
scores
.
masked_fill
(
pad_mask
,
float
(
"-inf"
))
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, heads_kv, seqlen_kv]
out
=
einsum
(
attention
,
value
,
'b g h s, b h s d -> b g h d'
)
# [batch_size, num_head_groups, heads_kv, dim]
out
=
rearrange
(
out
,
'b g h d -> b (h g) d'
)
# [batch_size, heads, dim]
out
=
einsum
(
attention
,
value
,
"b g h s, b h s d -> b g h d"
)
# [batch_size, num_head_groups, heads_kv, dim]
out
=
rearrange
(
out
,
"b g h d -> b (h g) d"
)
# [batch_size, heads, dim]
return
out
def
ref_program_fa
(
query
,
key
,
value
,
block_indices
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
block_size
):
def
ref_program_fa
(
query
,
key
,
value
,
block_indices
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
block_size
):
# latency reference
# from flash_attn_interface import flash_attn_with_kvcache # fa3
from
flash_attn
import
flash_attn_with_kvcache
#fa2
from
flash_attn
import
flash_attn_with_kvcache
# fa2
query
=
query
.
unsqueeze
(
1
)
output
=
flash_attn_with_kvcache
(
query
,
key
,
value
,
cache_seqlens
=
cache_seqlens
)
output
=
output
.
squeeze
(
1
)
...
...
@@ -360,23 +330,13 @@ def debug(name, expect, actual, atol=1e-3, rtol=1e-3):
# print(expect[3, 28])
# print(actual[3, 28])
diff
=
(
expect
-
actual
).
abs
()
print
(
"all_close={}, max={}, min={}, mean={}"
.
format
(
all_close
,
diff
.
max
().
item
(),
diff
.
min
().
item
(),
diff
.
mean
().
item
()))
print
(
"all_close={}, max={}, min={}, mean={}"
.
format
(
all_close
,
diff
.
max
().
item
(),
diff
.
min
().
item
(),
diff
.
mean
().
item
()))
max_indices
=
torch
.
nonzero
(
diff
==
diff
.
max
().
item
())
first_index
=
tuple
(
max_indices
[
0
].
tolist
())
print
(
f
"Index:
{
first_index
}
, expect:
{
expect
[
first_index
]
}
, actual:
{
actual
[
first_index
]
}
"
)
def
main
(
batch
=
8
,
heads
=
32
,
heads_kv
=
8
,
max_cache_seqlen
=
8192
,
dim
=
128
,
dim_v
=
128
,
sparse_ratio
=
0.8
,
block_size
=
32
):
def
main
(
batch
=
8
,
heads
=
32
,
heads_kv
=
8
,
max_cache_seqlen
=
8192
,
dim
=
128
,
dim_v
=
128
,
sparse_ratio
=
0.8
,
block_size
=
32
):
batch
,
heads
,
heads_kv
,
max_cache_seqlen
,
dim
,
dim_v
=
batch
,
heads
,
heads_kv
,
max_cache_seqlen
,
dim
,
dim_v
sparse_ratio
=
sparse_ratio
block_size
=
block_size
...
...
@@ -384,14 +344,13 @@ def main(batch=8,
print
(
"max_selected_blocks: "
,
max_selected_blocks
)
dtype
=
torch
.
float16
Q
=
torch
.
randn
((
batch
,
heads
,
dim
),
dtype
=
dtype
,
device
=
'
cuda
'
)
K
=
torch
.
randn
((
batch
,
max_cache_seqlen
,
heads_kv
,
dim
),
dtype
=
dtype
,
device
=
'
cuda
'
)
V
=
torch
.
randn
((
batch
,
max_cache_seqlen
,
heads_kv
,
dim_v
),
dtype
=
dtype
,
device
=
'
cuda
'
)
cache_seqlens
=
torch
.
randint
(
1
,
max_cache_seqlen
,
(
batch
,),
dtype
=
torch
.
int32
,
device
=
'
cuda
'
)
Q
=
torch
.
randn
((
batch
,
heads
,
dim
),
dtype
=
dtype
,
device
=
"
cuda
"
)
K
=
torch
.
randn
((
batch
,
max_cache_seqlen
,
heads_kv
,
dim
),
dtype
=
dtype
,
device
=
"
cuda
"
)
V
=
torch
.
randn
((
batch
,
max_cache_seqlen
,
heads_kv
,
dim_v
),
dtype
=
dtype
,
device
=
"
cuda
"
)
cache_seqlens
=
torch
.
randint
(
1
,
max_cache_seqlen
,
(
batch
,),
dtype
=
torch
.
int32
,
device
=
"
cuda
"
)
# Ensure at least one element equals cache_seqlen
random_index
=
torch
.
randint
(
0
,
batch
,
(
1
,),
device
=
'cuda'
).
item
()
# Select a random index
cache_seqlens
[
random_index
]
=
max_cache_seqlen
# Assign cache_seqlen to ensure at least one occurrence
random_index
=
torch
.
randint
(
0
,
batch
,
(
1
,),
device
=
"cuda"
).
item
()
# Select a random index
cache_seqlens
[
random_index
]
=
max_cache_seqlen
# Assign cache_seqlen to ensure at least one occurrence
# cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda')
print
(
"cache_seqlens: "
,
cache_seqlens
)
...
...
@@ -403,7 +362,7 @@ def main(batch=8,
max_valid_num_blocks
=
torch
.
ceil
(
cache_seqlens
/
block_size
).
int
()
print
(
"max_valid_num_blocks: "
,
max_valid_num_blocks
)
# Initialize block_mask with false (for padding blocks)
block_mask
=
torch
.
zeros
((
batch
,
heads_kv
,
num_blocks
),
dtype
=
torch
.
bool
,
device
=
'
cuda
'
)
block_mask
=
torch
.
zeros
((
batch
,
heads_kv
,
num_blocks
),
dtype
=
torch
.
bool
,
device
=
"
cuda
"
)
# Assign valid indices while ensuring no duplicates within each batch-group
for
b
in
range
(
batch
):
...
...
@@ -411,13 +370,12 @@ def main(batch=8,
valid_num_block
=
valid_num_blocks
[
b
].
item
()
# Valid blocks for this batch
if
valid_num_block
>
0
:
# Ensure there's at least one valid block
for
h
in
range
(
heads_kv
):
perm
=
torch
.
randperm
(
max_valid_block
,
device
=
'
cuda
'
)[:
valid_num_block
]
perm
=
torch
.
randperm
(
max_valid_block
,
device
=
"
cuda
"
)[:
valid_num_block
]
block_mask
[
b
,
h
,
perm
]
=
True
# print("block_mask: ", block_mask)
# parity reference
ref
=
ref_program_torch
(
Q
,
K
,
V
,
block_mask
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
block_size
)
ref
=
ref_program_torch
(
Q
,
K
,
V
,
block_mask
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
block_size
)
# out = sparse_gqa_decode_varlen_mask(Q, K, V, block_mask, cache_seqlens, block_size)
model
=
SparseFlashAttn
(
batch
,
heads
,
heads_kv
,
dim
,
dim_v
,
block_size
)
out
=
model
(
Q
,
K
,
V
,
block_mask
,
cache_seqlens
)
...
...
@@ -427,13 +385,11 @@ def main(batch=8,
## latency reference
for
_
in
range
(
10
):
ref
=
ref_program_fa
(
Q
,
K
,
V
,
block_mask
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
block_size
)
ref
=
ref_program_fa
(
Q
,
K
,
V
,
block_mask
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
block_size
)
torch
.
cuda
.
synchronize
()
start
=
time
.
time
()
for
_
in
range
(
100
):
ref
=
ref_program_fa
(
Q
,
K
,
V
,
block_mask
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
block_size
)
ref
=
ref_program_fa
(
Q
,
K
,
V
,
block_mask
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
block_size
)
torch
.
cuda
.
synchronize
()
print
(
"dense time: "
,
(
time
.
time
()
-
start
)
/
100
*
1000
)
...
...
@@ -452,15 +408,13 @@ def main(batch=8,
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--batch'
,
type
=
int
,
default
=
8
,
help
=
'batch size'
)
parser
.
add_argument
(
'--heads'
,
type
=
int
,
default
=
32
,
help
=
'heads'
)
parser
.
add_argument
(
'--heads_kv'
,
type
=
int
,
default
=
8
,
help
=
'heads_kv'
)
parser
.
add_argument
(
'--max_cache_seqlen'
,
type
=
int
,
default
=
8192
,
help
=
'kvcache sequence length'
)
parser
.
add_argument
(
'--dim'
,
type
=
int
,
default
=
128
,
help
=
'dim'
)
parser
.
add_argument
(
'--dim_v'
,
type
=
int
,
default
=
128
,
help
=
'dim_v'
)
parser
.
add_argument
(
'--sparse_ratio'
,
type
=
float
,
default
=
0.8
,
help
=
'sparse ratio'
)
parser
.
add_argument
(
'--block_size'
,
type
=
int
,
default
=
32
,
help
=
'block_size'
)
parser
.
add_argument
(
"--batch"
,
type
=
int
,
default
=
8
,
help
=
"batch size"
)
parser
.
add_argument
(
"--heads"
,
type
=
int
,
default
=
32
,
help
=
"heads"
)
parser
.
add_argument
(
"--heads_kv"
,
type
=
int
,
default
=
8
,
help
=
"heads_kv"
)
parser
.
add_argument
(
"--max_cache_seqlen"
,
type
=
int
,
default
=
8192
,
help
=
"kvcache sequence length"
)
parser
.
add_argument
(
"--dim"
,
type
=
int
,
default
=
128
,
help
=
"dim"
)
parser
.
add_argument
(
"--dim_v"
,
type
=
int
,
default
=
128
,
help
=
"dim_v"
)
parser
.
add_argument
(
"--sparse_ratio"
,
type
=
float
,
default
=
0.8
,
help
=
"sparse ratio"
)
parser
.
add_argument
(
"--block_size"
,
type
=
int
,
default
=
32
,
help
=
"block_size"
)
args
=
parser
.
parse_args
()
main
(
args
.
batch
,
args
.
heads
,
args
.
heads_kv
,
args
.
max_cache_seqlen
,
args
.
dim
,
args
.
dim_v
,
args
.
sparse_ratio
,
args
.
block_size
)
main
(
args
.
batch
,
args
.
heads
,
args
.
heads_kv
,
args
.
max_cache_seqlen
,
args
.
dim
,
args
.
dim_v
,
args
.
sparse_ratio
,
args
.
block_size
)
examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py
View file @
29051439
...
...
@@ -12,12 +12,8 @@ from heuristic import num_splits_heuristic
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
num_warps
in
[
1
,
2
,
4
]
\
for
num_stages
in
[
1
,
2
,
3
,
4
,
7
]
],
key
=
[
'BLOCK_H'
,
'BLOCK_N'
,
'BLOCK_D'
],
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
num_warps
in
[
1
,
2
,
4
]
for
num_stages
in
[
1
,
2
,
3
,
4
,
7
]],
key
=
[
"BLOCK_H"
,
"BLOCK_N"
,
"BLOCK_D"
],
)
@
triton
.
jit
def
_split_kernel
(
...
...
@@ -79,16 +75,11 @@ def _split_kernel(
loop_range
=
blocks_per_split
q_ptr
+=
batch_idx
*
stride_q_b
+
head_idx_q
*
stride_q_h
k_cache_ptr
+=
batch_idx
*
stride_k_b
+
head_idx_kv
*
stride_k_h
+
offs_n
[
None
,
:]
*
stride_k_s
+
offs_d
[:,
None
]
*
stride_k_d
v_cache_ptr
+=
batch_idx
*
stride_v_b
+
head_idx_kv
*
stride_v_h
+
offs_n
[:,
None
]
*
stride_v_s
+
offs_d
[
None
,
:]
*
stride_v_d
k_cache_ptr
+=
batch_idx
*
stride_k_b
+
head_idx_kv
*
stride_k_h
+
offs_n
[
None
,
:]
*
stride_k_s
+
offs_d
[:,
None
]
*
stride_k_d
v_cache_ptr
+=
batch_idx
*
stride_v_b
+
head_idx_kv
*
stride_v_h
+
offs_n
[:,
None
]
*
stride_v_s
+
offs_d
[
None
,
:]
*
stride_v_d
mask_ptr
+=
batch_idx
*
stride_mask_b
+
head_idx_kv
*
stride_mask_h
q
=
tl
.
load
(
q_ptr
+
offs_h
[:,
None
]
*
stride_q_h
+
offs_d
[
None
,
:]
*
stride_q_d
,
mask
=
offs_h
[:,
None
]
<
gqa_group_size
)
q
=
tl
.
load
(
q_ptr
+
offs_h
[:,
None
]
*
stride_q_h
+
offs_d
[
None
,
:]
*
stride_q_d
,
mask
=
offs_h
[:,
None
]
<
gqa_group_size
)
start
=
blocks_per_split
*
split_idx
+
tl
.
minimum
(
split_idx
,
remaining_blocks
)
for
i
in
range
(
loop_range
):
block_idx
=
tl
.
load
(
mask_ptr
+
(
start
+
i
)
*
stride_mask_s
)
...
...
@@ -119,23 +110,18 @@ def _split_kernel(
acc
=
acc
*
l_recip
acc
=
acc
.
to
(
o_partial_ptr
.
dtype
.
element_ty
)
lse_partial_ptr
+=
batch_idx
*
stride_lse_b
+
(
head_idx_q
+
offs_h
)
*
stride_lse_h
+
split_idx
*
stride_lse_split
lse_partial_ptr
+=
batch_idx
*
stride_lse_b
+
(
head_idx_q
+
offs_h
)
*
stride_lse_h
+
split_idx
*
stride_lse_split
tl
.
store
(
lse_partial_ptr
,
m_i
,
mask
=
offs_h
<
gqa_group_size
)
o_partial_ptr
+=
batch_idx
*
stride_o_b
+
(
head_idx_q
+
offs_h
[:,
None
])
*
stride_o_h
+
split_idx
*
stride_o_split
+
offs_d
[
None
,
:]
*
stride_o_d
o_partial_ptr
+=
(
batch_idx
*
stride_o_b
+
(
head_idx_q
+
offs_h
[:,
None
])
*
stride_o_h
+
split_idx
*
stride_o_split
+
offs_d
[
None
,
:]
*
stride_o_d
)
tl
.
store
(
o_partial_ptr
,
acc
,
mask
=
offs_h
[:,
None
]
<
gqa_group_size
)
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
num_warps
in
[
1
,
2
,
4
]
\
for
num_stages
in
[
1
,
2
,
3
,
4
,
7
]
],
key
=
[
'BLOCK_D'
],
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
num_warps
in
[
1
,
2
,
4
]
for
num_stages
in
[
1
,
2
,
3
,
4
,
7
]],
key
=
[
"BLOCK_D"
],
)
@
triton
.
jit
def
_merge_kernel
(
...
...
@@ -163,18 +149,15 @@ def _merge_kernel(
offs_d
=
tl
.
arange
(
0
,
BLOCK_D
)
lse_offsets
=
lse_partial_ptr
+
batch_idx
*
lse_partial_stride_b
+
head_idx
*
lse_partial_stride_h
lse
=
tl
.
load
(
lse_offsets
+
offs_splits
*
lse_partial_stride_split
,
mask
=
offs_splits
<
num_splits
,
other
=
float
(
"-inf"
))
lse
=
tl
.
load
(
lse_offsets
+
offs_splits
*
lse_partial_stride_split
,
mask
=
offs_splits
<
num_splits
,
other
=
float
(
"-inf"
))
lse_max
=
tl
.
max
(
lse
)
o_offsets
=
o_partial_ptr
+
batch_idx
*
o_partial_stride_b
+
head_idx
*
o_partial_stride_h
o_partial
=
tl
.
load
(
o_offsets
+
offs_splits
[:,
None
]
*
o_partial_stride_split
+
offs_d
[
None
,
:]
*
o_partial_stride_d
,
mask
=
offs_splits
[:,
None
]
<
num_splits
)
o_offsets
+
offs_splits
[:,
None
]
*
o_partial_stride_split
+
offs_d
[
None
,
:]
*
o_partial_stride_d
,
mask
=
offs_splits
[:,
None
]
<
num_splits
,
)
sumexp_normalized_splitk
=
tl
.
exp
(
lse
-
lse_max
)
sumexp_normalized
=
tl
.
sum
(
sumexp_normalized_splitk
,
axis
=
0
)
numerator_normalized
=
tl
.
sum
(
o_partial
*
sumexp_normalized_splitk
[:,
None
],
axis
=
0
)
...
...
@@ -209,19 +192,13 @@ def block_sparse_flash_decode_gqa_indice_triton(
num_m_blocks
=
1
*
(
heads
//
heads_kv
+
block_H
-
1
)
//
block_H
num_n_blocks
=
max_selected_blocks
size_one_kv_head
=
max_selected_blocks
*
block_size
*
(
dim
+
dim_v
)
*
2
#kv_seqlen * (dim + dim_v) * 2
size_one_kv_head
=
max_selected_blocks
*
block_size
*
(
dim
+
dim_v
)
*
2
# kv_seqlen * (dim + dim_v) * 2
total_mblocks
=
batch
*
heads_kv
*
num_m_blocks
num_sm
=
64
# num_sm = self.num_sm
num_splits
=
num_splits_heuristic
(
total_mblocks
,
num_sm
,
num_n_blocks
,
num_m_blocks
,
size_one_kv_head
,
is_causal_or_local
=
True
,
max_splits
=
128
)
total_mblocks
,
num_sm
,
num_n_blocks
,
num_m_blocks
,
size_one_kv_head
,
is_causal_or_local
=
True
,
max_splits
=
128
)
# print("num_splits:", num_splits, "num_blocks:", num_n_blocks)
...
...
@@ -295,24 +272,18 @@ def block_sparse_flash_decode_gqa_indice_triton(
return
output
def
ref_program_torch
(
query
,
key
,
value
,
block_indices
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
block_size
):
def
ref_program_torch
(
query
,
key
,
value
,
block_indices
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
block_size
):
batch
,
heads
,
dim
=
query
.
shape
heads_kv
=
key
.
shape
[
2
]
dim_v
=
value
.
shape
[
-
1
]
num_head_groups
=
query
.
shape
[
1
]
//
key
.
shape
[
2
]
scale
=
dim
**
0.5
key
=
rearrange
(
key
,
'
b n h d -> b h n d
'
)
# [batch_size, heads_kv, seqlen_kv, dim]
value
=
rearrange
(
value
,
'
b n h d -> b h n d
'
)
# [batch_size, heads_kv, seqlen_kv, dim]
key
=
rearrange
(
key
,
"
b n h d -> b h n d
"
)
# [batch_size, heads_kv, seqlen_kv, dim]
value
=
rearrange
(
value
,
"
b n h d -> b h n d
"
)
# [batch_size, heads_kv, seqlen_kv, dim]
query
=
rearrange
(
query
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, heads_kv, dim]
query
=
rearrange
(
query
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, heads_kv, dim]
scores
=
einsum
(
query
,
key
,
'b g h d, b h s d -> b g h s'
)
# [batch_size, num_head_groups, heads_kv, seqlen_kv]
scores
=
einsum
(
query
,
key
,
"b g h d, b h s d -> b g h s"
)
# [batch_size, num_head_groups, heads_kv, seqlen_kv]
sparse_mask
=
torch
.
zeros_like
(
scores
)
# Assign mask values based on block_indices
...
...
@@ -321,42 +292,33 @@ def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache
valid_indices
=
block_indices
[
b
,
h
]
# Extract indices for this batch and head
for
idx
in
valid_indices
:
if
idx
>=
0
:
sparse_mask
[
b
,
:,
h
,
idx
*
block_size
:
(
idx
+
1
)
*
block_size
]
=
1
scores
=
scores
.
masked_fill
(
sparse_mask
==
0
,
float
(
'
-inf
'
))
sparse_mask
[
b
,
:,
h
,
idx
*
block_size
:
(
idx
+
1
)
*
block_size
]
=
1
scores
=
scores
.
masked_fill
(
sparse_mask
==
0
,
float
(
"
-inf
"
))
range_len
=
torch
.
arange
(
scores
.
shape
[
-
1
],
device
=
'
cuda
'
).
unsqueeze
(
0
)
range_len
=
torch
.
arange
(
scores
.
shape
[
-
1
],
device
=
"
cuda
"
).
unsqueeze
(
0
)
cache_seqlens_expanded
=
cache_seqlens
.
unsqueeze
(
1
)
pad_mask
=
range_len
>=
cache_seqlens_expanded
pad_mask
=
pad_mask
[:,
None
,
None
,
:]
scores
=
scores
.
masked_fill
(
pad_mask
,
float
(
'-inf'
))
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, heads_kv, seqlen_kv]
scores
=
scores
.
masked_fill
(
pad_mask
,
float
(
"-inf"
))
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, heads_kv, seqlen_kv]
out
=
einsum
(
attention
,
value
,
'b g h s, b h s d -> b g h d'
)
# [batch_size, num_head_groups, heads_kv, dim]
out
=
rearrange
(
out
,
'b g h d -> b (h g) d'
)
# [batch_size, heads, dim]
out
=
einsum
(
attention
,
value
,
"b g h s, b h s d -> b g h d"
)
# [batch_size, num_head_groups, heads_kv, dim]
out
=
rearrange
(
out
,
"b g h d -> b (h g) d"
)
# [batch_size, heads, dim]
return
out
def
ref_program_fa
(
query
,
key
,
value
,
cache_seqlens
):
# latency reference
# from flash_attn_interface import flash_attn_with_kvcache # fa3
from
flash_attn
import
flash_attn_with_kvcache
#fa2
from
flash_attn
import
flash_attn_with_kvcache
# fa2
query
=
query
.
unsqueeze
(
1
)
output
=
flash_attn_with_kvcache
(
query
,
key
,
value
,
cache_seqlens
=
cache_seqlens
)
output
=
output
.
squeeze
(
1
)
return
output
def
main
(
batch
=
64
,
heads
=
32
,
heads_kv
=
8
,
max_cache_seqlen
=
8192
,
dim
=
128
,
dim_v
=
128
,
sparse_ratio
=
0.8
,
block_size
=
32
):
def
main
(
batch
=
64
,
heads
=
32
,
heads_kv
=
8
,
max_cache_seqlen
=
8192
,
dim
=
128
,
dim_v
=
128
,
sparse_ratio
=
0.8
,
block_size
=
32
):
batch
,
heads
,
heads_kv
,
max_cache_seqlen
,
dim
,
dim_v
=
batch
,
heads
,
heads_kv
,
max_cache_seqlen
,
dim
,
dim_v
sparse_ratio
=
sparse_ratio
block_size
=
block_size
...
...
@@ -369,34 +331,29 @@ def main(batch=64,
dtype
=
torch
.
float16
block_H
=
64
Q
=
torch
.
randn
((
batch
,
heads
,
dim
),
dtype
=
dtype
,
device
=
'
cuda
'
)
K
=
torch
.
randn
((
batch
,
max_cache_seqlen
,
heads_kv
,
dim
),
dtype
=
dtype
,
device
=
'
cuda
'
)
V
=
torch
.
randn
((
batch
,
max_cache_seqlen
,
heads_kv
,
dim_v
),
dtype
=
dtype
,
device
=
'
cuda
'
)
cache_seqlens
=
torch
.
randint
(
1
,
max_cache_seqlen
,
(
batch
,),
dtype
=
torch
.
int32
,
device
=
'
cuda
'
)
Q
=
torch
.
randn
((
batch
,
heads
,
dim
),
dtype
=
dtype
,
device
=
"
cuda
"
)
K
=
torch
.
randn
((
batch
,
max_cache_seqlen
,
heads_kv
,
dim
),
dtype
=
dtype
,
device
=
"
cuda
"
)
V
=
torch
.
randn
((
batch
,
max_cache_seqlen
,
heads_kv
,
dim_v
),
dtype
=
dtype
,
device
=
"
cuda
"
)
cache_seqlens
=
torch
.
randint
(
1
,
max_cache_seqlen
,
(
batch
,),
dtype
=
torch
.
int32
,
device
=
"
cuda
"
)
# cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda')
# Ensure at least one element equals cache_seqlen
random_index
=
torch
.
randint
(
0
,
batch
,
(
1
,),
device
=
'cuda'
).
item
()
# Select a random index
cache_seqlens
[
random_index
]
=
max_cache_seqlen
# Assign cache_seqlen to ensure at least one occurrence
random_index
=
torch
.
randint
(
0
,
batch
,
(
1
,),
device
=
"cuda"
).
item
()
# Select a random index
cache_seqlens
[
random_index
]
=
max_cache_seqlen
# Assign cache_seqlen to ensure at least one occurrence
print
(
"cache_seqlens: "
,
cache_seqlens
)
max_valid_num_blocks
=
torch
.
ceil
(
cache_seqlens
/
block_size
).
int
()
print
(
"max_valid_num_blocks: "
,
max_valid_num_blocks
)
# Initialize block_indices with -1 (for padding blocks)
block_indices
=
torch
.
full
((
batch
,
heads_kv
,
max_selected_blocks
),
-
1
,
dtype
=
torch
.
int32
,
device
=
'cuda'
)
block_indices
=
torch
.
full
((
batch
,
heads_kv
,
max_selected_blocks
),
-
1
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
# Assign valid indices while ensuring no duplicates within each batch-group
for
b
in
range
(
batch
):
max_valid_block
=
max_valid_num_blocks
[
b
].
item
()
# Max valid blocks for this batch
if
max_valid_block
>
0
:
# Ensure there's at least one valid block
for
h
in
range
(
heads_kv
):
valid_indices
=
torch
.
randperm
(
max_valid_block
,
device
=
'cuda'
,
dtype
=
torch
.
int32
)[:
max_selected_blocks
]
block_indices
[
b
,
h
,
:
len
(
valid_indices
)]
=
valid_indices
valid_indices
=
torch
.
randperm
(
max_valid_block
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)[:
max_selected_blocks
]
block_indices
[
b
,
h
,
:
len
(
valid_indices
)]
=
valid_indices
# Sort indices within each batch-group for consistency
block_indices
,
_
=
block_indices
.
sort
(
dim
=-
1
,
descending
=
True
)
...
...
@@ -408,8 +365,7 @@ def main(batch=64,
max_num_blocks
=
torch
.
max
(
max_valid_num_blocks
).
item
()
print
(
"max_num_blocks: "
,
max_num_blocks
)
ref
=
ref_program_torch
(
Q
,
K
,
V
,
block_indices
,
cache_seqlens
,
max_cache_seqlen
,
max_num_blocks
,
block_size
)
ref
=
ref_program_torch
(
Q
,
K
,
V
,
block_indices
,
cache_seqlens
,
max_cache_seqlen
,
max_num_blocks
,
block_size
)
triton_out
=
block_sparse_flash_decode_gqa_indice_triton
(
Q
,
...
...
@@ -423,8 +379,7 @@ def main(batch=64,
)
print
(
"max difference: "
,
torch
.
max
(
torch
.
abs
(
ref
-
triton_out
)))
assert
torch
.
allclose
(
ref
,
triton_out
,
atol
=
1e-2
),
"Output mismatch between Triton and reference implementation"
assert
torch
.
allclose
(
ref
,
triton_out
,
atol
=
1e-2
),
"Output mismatch between Triton and reference implementation"
print
(
"Passed the ref test!"
)
# Measure performance
...
...
@@ -466,15 +421,13 @@ def main(batch=64,
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--batch'
,
type
=
int
,
default
=
64
,
help
=
'batch size'
)
parser
.
add_argument
(
'--heads'
,
type
=
int
,
default
=
32
,
help
=
'heads'
)
parser
.
add_argument
(
'--heads_kv'
,
type
=
int
,
default
=
8
,
help
=
'heads_kv'
)
parser
.
add_argument
(
'--max_cache_seqlen'
,
type
=
int
,
default
=
8192
,
help
=
'kvcache sequence length'
)
parser
.
add_argument
(
'--dim'
,
type
=
int
,
default
=
128
,
help
=
'dim'
)
parser
.
add_argument
(
'--dim_v'
,
type
=
int
,
default
=
128
,
help
=
'dim_v'
)
parser
.
add_argument
(
'--sparse_ratio'
,
type
=
float
,
default
=
0.8
,
help
=
'sparse ratio'
)
parser
.
add_argument
(
'--block_size'
,
type
=
int
,
default
=
32
,
help
=
'block_size'
)
parser
.
add_argument
(
"--batch"
,
type
=
int
,
default
=
64
,
help
=
"batch size"
)
parser
.
add_argument
(
"--heads"
,
type
=
int
,
default
=
32
,
help
=
"heads"
)
parser
.
add_argument
(
"--heads_kv"
,
type
=
int
,
default
=
8
,
help
=
"heads_kv"
)
parser
.
add_argument
(
"--max_cache_seqlen"
,
type
=
int
,
default
=
8192
,
help
=
"kvcache sequence length"
)
parser
.
add_argument
(
"--dim"
,
type
=
int
,
default
=
128
,
help
=
"dim"
)
parser
.
add_argument
(
"--dim_v"
,
type
=
int
,
default
=
128
,
help
=
"dim_v"
)
parser
.
add_argument
(
"--sparse_ratio"
,
type
=
float
,
default
=
0.8
,
help
=
"sparse ratio"
)
parser
.
add_argument
(
"--block_size"
,
type
=
int
,
default
=
32
,
help
=
"block_size"
)
args
=
parser
.
parse_args
()
main
(
args
.
batch
,
args
.
heads
,
args
.
heads_kv
,
args
.
max_cache_seqlen
,
args
.
dim
,
args
.
dim_v
,
args
.
sparse_ratio
,
args
.
block_size
)
main
(
args
.
batch
,
args
.
heads
,
args
.
heads_kv
,
args
.
max_cache_seqlen
,
args
.
dim
,
args
.
dim_v
,
args
.
sparse_ratio
,
args
.
block_size
)
examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py
View file @
29051439
...
...
@@ -11,12 +11,8 @@ from heuristic import num_splits_heuristic
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
num_warps
in
[
1
,
2
,
4
]
\
for
num_stages
in
[
1
,
2
,
3
,
4
,
7
]
],
key
=
[
'BLOCK_H'
,
'BLOCK_N'
,
'BLOCK_D'
],
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
num_warps
in
[
1
,
2
,
4
]
for
num_stages
in
[
1
,
2
,
3
,
4
,
7
]],
key
=
[
"BLOCK_H"
,
"BLOCK_N"
,
"BLOCK_D"
],
)
@
triton
.
jit
def
_split_kernel
(
...
...
@@ -77,16 +73,11 @@ def _split_kernel(
loop_range
=
blocks_per_split
q_ptr
+=
batch_idx
*
stride_q_b
+
head_idx_q
*
stride_q_h
k_cache_ptr
+=
batch_idx
*
stride_k_b
+
head_idx_kv
*
stride_k_h
+
offs_n
[
None
,
:]
*
stride_k_s
+
offs_d
[:,
None
]
*
stride_k_d
v_cache_ptr
+=
batch_idx
*
stride_v_b
+
head_idx_kv
*
stride_v_h
+
offs_n
[:,
None
]
*
stride_v_s
+
offs_d
[
None
,
:]
*
stride_v_d
k_cache_ptr
+=
batch_idx
*
stride_k_b
+
head_idx_kv
*
stride_k_h
+
offs_n
[
None
,
:]
*
stride_k_s
+
offs_d
[:,
None
]
*
stride_k_d
v_cache_ptr
+=
batch_idx
*
stride_v_b
+
head_idx_kv
*
stride_v_h
+
offs_n
[:,
None
]
*
stride_v_s
+
offs_d
[
None
,
:]
*
stride_v_d
mask_ptr
+=
batch_idx
*
stride_mask_b
+
head_idx_kv
*
stride_mask_h
q
=
tl
.
load
(
q_ptr
+
offs_h
[:,
None
]
*
stride_q_h
+
offs_d
[
None
,
:]
*
stride_q_d
,
mask
=
offs_h
[:,
None
]
<
gqa_group_size
)
q
=
tl
.
load
(
q_ptr
+
offs_h
[:,
None
]
*
stride_q_h
+
offs_d
[
None
,
:]
*
stride_q_d
,
mask
=
offs_h
[:,
None
]
<
gqa_group_size
)
start
=
blocks_per_split
*
split_idx
+
tl
.
minimum
(
split_idx
,
remaining_blocks
)
for
block_idx
in
range
(
loop_range
):
start_n
=
(
start
+
block_idx
)
*
BLOCK_N
...
...
@@ -117,23 +108,18 @@ def _split_kernel(
acc
=
acc
*
l_recip
acc
=
acc
.
to
(
o_partial_ptr
.
dtype
.
element_ty
)
lse_partial_ptr
+=
batch_idx
*
stride_lse_b
+
(
head_idx_q
+
offs_h
)
*
stride_lse_h
+
split_idx
*
stride_lse_split
lse_partial_ptr
+=
batch_idx
*
stride_lse_b
+
(
head_idx_q
+
offs_h
)
*
stride_lse_h
+
split_idx
*
stride_lse_split
tl
.
store
(
lse_partial_ptr
,
m_i
,
mask
=
offs_h
<
gqa_group_size
)
o_partial_ptr
+=
batch_idx
*
stride_o_b
+
(
head_idx_q
+
offs_h
[:,
None
])
*
stride_o_h
+
split_idx
*
stride_o_split
+
offs_d
[
None
,
:]
*
stride_o_d
o_partial_ptr
+=
(
batch_idx
*
stride_o_b
+
(
head_idx_q
+
offs_h
[:,
None
])
*
stride_o_h
+
split_idx
*
stride_o_split
+
offs_d
[
None
,
:]
*
stride_o_d
)
tl
.
store
(
o_partial_ptr
,
acc
,
mask
=
offs_h
[:,
None
]
<
gqa_group_size
)
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
num_warps
in
[
1
,
2
,
4
]
\
for
num_stages
in
[
1
,
2
,
3
,
4
,
7
]
],
key
=
[
'BLOCK_D'
],
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
num_warps
in
[
1
,
2
,
4
]
for
num_stages
in
[
1
,
2
,
3
,
4
,
7
]],
key
=
[
"BLOCK_D"
],
)
@
triton
.
jit
def
_merge_kernel
(
...
...
@@ -161,18 +147,15 @@ def _merge_kernel(
offs_d
=
tl
.
arange
(
0
,
BLOCK_D
)
lse_offsets
=
lse_partial_ptr
+
batch_idx
*
lse_partial_stride_b
+
head_idx
*
lse_partial_stride_h
lse
=
tl
.
load
(
lse_offsets
+
offs_splits
*
lse_partial_stride_split
,
mask
=
offs_splits
<
num_splits
,
other
=
float
(
"-inf"
))
lse
=
tl
.
load
(
lse_offsets
+
offs_splits
*
lse_partial_stride_split
,
mask
=
offs_splits
<
num_splits
,
other
=
float
(
"-inf"
))
lse_max
=
tl
.
max
(
lse
)
o_offsets
=
o_partial_ptr
+
batch_idx
*
o_partial_stride_b
+
head_idx
*
o_partial_stride_h
o_partial
=
tl
.
load
(
o_offsets
+
offs_splits
[:,
None
]
*
o_partial_stride_split
+
offs_d
[
None
,
:]
*
o_partial_stride_d
,
mask
=
offs_splits
[:,
None
]
<
num_splits
)
o_offsets
+
offs_splits
[:,
None
]
*
o_partial_stride_split
+
offs_d
[
None
,
:]
*
o_partial_stride_d
,
mask
=
offs_splits
[:,
None
]
<
num_splits
,
)
sumexp_normalized_splitk
=
tl
.
exp
(
lse
-
lse_max
)
sumexp_normalized
=
tl
.
sum
(
sumexp_normalized_splitk
,
axis
=
0
)
numerator_normalized
=
tl
.
sum
(
o_partial
*
sumexp_normalized_splitk
[:,
None
],
axis
=
0
)
...
...
@@ -207,19 +190,13 @@ def block_sparse_flash_decode_gqa_mask_triton(
num_m_blocks
=
1
*
(
heads
//
heads_kv
+
block_H
-
1
)
//
block_H
num_n_blocks
=
max_selected_blocks
size_one_kv_head
=
max_selected_blocks
*
block_size
*
(
dim
+
dim_v
)
*
2
#kv_seqlen * (dim + dim_v) * 2
size_one_kv_head
=
max_selected_blocks
*
block_size
*
(
dim
+
dim_v
)
*
2
# kv_seqlen * (dim + dim_v) * 2
total_mblocks
=
batch
*
heads_kv
*
num_m_blocks
num_sm
=
64
# num_sm = self.num_sm
num_splits
=
num_splits_heuristic
(
total_mblocks
,
num_sm
,
num_n_blocks
,
num_m_blocks
,
size_one_kv_head
,
is_causal_or_local
=
True
,
max_splits
=
128
)
total_mblocks
,
num_sm
,
num_n_blocks
,
num_m_blocks
,
size_one_kv_head
,
is_causal_or_local
=
True
,
max_splits
=
128
)
# print("num_splits:", num_splits, "num_blocks:", num_n_blocks)
...
...
@@ -292,24 +269,18 @@ def block_sparse_flash_decode_gqa_mask_triton(
return
output
def
ref_program_torch
(
query
,
key
,
value
,
block_mask
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
block_size
):
def
ref_program_torch
(
query
,
key
,
value
,
block_mask
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
block_size
):
batch
,
heads
,
dim
=
query
.
shape
heads_kv
=
key
.
shape
[
2
]
num_head_groups
=
query
.
shape
[
1
]
//
key
.
shape
[
2
]
scale
=
dim
**
0.5
key
=
rearrange
(
key
,
'
b n h d -> b h n d
'
)
# [batch_size, heads_kv, seqlen_kv, dim]
value
=
rearrange
(
value
,
'
b n h d -> b h n d
'
)
# [batch_size, heads_kv, seqlen_kv, dim]
key
=
rearrange
(
key
,
"
b n h d -> b h n d
"
)
# [batch_size, heads_kv, seqlen_kv, dim]
value
=
rearrange
(
value
,
"
b n h d -> b h n d
"
)
# [batch_size, heads_kv, seqlen_kv, dim]
query
=
rearrange
(
query
,
'b (h g) d -> b g h d'
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, heads_kv, dim]
query
=
rearrange
(
query
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, heads_kv, dim]
scores
=
einsum
(
query
,
key
,
'b g h d, b h s d -> b g h s'
)
# [batch_size, num_head_groups, heads_kv, seqlen_kv]
scores
=
einsum
(
query
,
key
,
"b g h d, b h s d -> b g h s"
)
# [batch_size, num_head_groups, heads_kv, seqlen_kv]
sparse_mask
=
torch
.
zeros_like
(
scores
)
# Assign mask values
...
...
@@ -317,43 +288,34 @@ def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_se
for
h
in
range
(
heads_kv
):
for
idx
in
range
(
num_blocks
):
if
block_mask
[
b
,
h
,
idx
]:
sparse_mask
[
b
,
:,
h
,
idx
*
block_size
:
(
idx
+
1
)
*
block_size
]
=
1
sparse_mask
[
b
,
:,
h
,
idx
*
block_size
:
(
idx
+
1
)
*
block_size
]
=
1
scores
=
scores
.
masked_fill
(
sparse_mask
==
0
,
float
(
'
-inf
'
))
scores
=
scores
.
masked_fill
(
sparse_mask
==
0
,
float
(
"
-inf
"
))
range_len
=
torch
.
arange
(
scores
.
shape
[
-
1
],
device
=
'
cuda
'
).
unsqueeze
(
0
)
range_len
=
torch
.
arange
(
scores
.
shape
[
-
1
],
device
=
"
cuda
"
).
unsqueeze
(
0
)
cache_seqlens_expanded
=
cache_seqlens
.
unsqueeze
(
1
)
pad_mask
=
range_len
>=
cache_seqlens_expanded
pad_mask
=
pad_mask
[:,
None
,
None
,
:]
scores
=
scores
.
masked_fill
(
pad_mask
,
float
(
'-inf'
))
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, heads_kv, seqlen_kv]
scores
=
scores
.
masked_fill
(
pad_mask
,
float
(
"-inf"
))
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, heads_kv, seqlen_kv]
out
=
einsum
(
attention
,
value
,
'b g h s, b h s d -> b g h d'
)
# [batch_size, num_head_groups, heads_kv, dim]
out
=
rearrange
(
out
,
'b g h d -> b (h g) d'
)
# [batch_size, heads, dim]
out
=
einsum
(
attention
,
value
,
"b g h s, b h s d -> b g h d"
)
# [batch_size, num_head_groups, heads_kv, dim]
out
=
rearrange
(
out
,
"b g h d -> b (h g) d"
)
# [batch_size, heads, dim]
return
out
def
ref_program_fa
(
query
,
key
,
value
,
cache_seqlens
):
# latency reference
# from flash_attn_interface import flash_attn_with_kvcache # fa3
from
flash_attn
import
flash_attn_with_kvcache
#fa2
from
flash_attn
import
flash_attn_with_kvcache
# fa2
query
=
query
.
unsqueeze
(
1
)
output
=
flash_attn_with_kvcache
(
query
,
key
,
value
,
cache_seqlens
=
cache_seqlens
)
output
=
output
.
squeeze
(
1
)
return
output
def
main
(
batch
=
64
,
heads
=
32
,
heads_kv
=
8
,
max_cache_seqlen
=
8192
,
dim
=
128
,
dim_v
=
128
,
sparse_ratio
=
0.8
,
block_size
=
32
):
def
main
(
batch
=
64
,
heads
=
32
,
heads_kv
=
8
,
max_cache_seqlen
=
8192
,
dim
=
128
,
dim_v
=
128
,
sparse_ratio
=
0.8
,
block_size
=
32
):
batch
,
heads
,
heads_kv
,
max_cache_seqlen
,
dim
,
dim_v
=
batch
,
heads
,
heads_kv
,
max_cache_seqlen
,
dim
,
dim_v
block_size
=
block_size
sparse_ratio
=
sparse_ratio
...
...
@@ -363,14 +325,13 @@ def main(batch=64,
dtype
=
torch
.
float16
Q
=
torch
.
randn
((
batch
,
heads
,
dim
),
dtype
=
dtype
,
device
=
'
cuda
'
)
K
=
torch
.
randn
((
batch
,
max_cache_seqlen
,
heads_kv
,
dim
),
dtype
=
dtype
,
device
=
'
cuda
'
)
V
=
torch
.
randn
((
batch
,
max_cache_seqlen
,
heads_kv
,
dim_v
),
dtype
=
dtype
,
device
=
'
cuda
'
)
cache_seqlens
=
torch
.
randint
(
1
,
max_cache_seqlen
,
(
batch
,),
dtype
=
torch
.
int32
,
device
=
'
cuda
'
)
Q
=
torch
.
randn
((
batch
,
heads
,
dim
),
dtype
=
dtype
,
device
=
"
cuda
"
)
K
=
torch
.
randn
((
batch
,
max_cache_seqlen
,
heads_kv
,
dim
),
dtype
=
dtype
,
device
=
"
cuda
"
)
V
=
torch
.
randn
((
batch
,
max_cache_seqlen
,
heads_kv
,
dim_v
),
dtype
=
dtype
,
device
=
"
cuda
"
)
cache_seqlens
=
torch
.
randint
(
1
,
max_cache_seqlen
,
(
batch
,),
dtype
=
torch
.
int32
,
device
=
"
cuda
"
)
# Ensure at least one element equals cache_seqlen
random_index
=
torch
.
randint
(
0
,
batch
,
(
1
,),
device
=
'cuda'
).
item
()
# Select a random index
cache_seqlens
[
random_index
]
=
max_cache_seqlen
# Assign cache_seqlen to ensure at least one occurrence
random_index
=
torch
.
randint
(
0
,
batch
,
(
1
,),
device
=
"cuda"
).
item
()
# Select a random index
cache_seqlens
[
random_index
]
=
max_cache_seqlen
# Assign cache_seqlen to ensure at least one occurrence
num_blocks
=
(
max_cache_seqlen
+
block_size
-
1
)
//
block_size
...
...
@@ -379,7 +340,7 @@ def main(batch=64,
max_valid_num_blocks
=
torch
.
ceil
(
cache_seqlens
/
block_size
).
int
()
print
(
"max_valid_num_blocks: "
,
max_valid_num_blocks
)
# Initialize block_mask with false (for padding blocks)
block_mask
=
torch
.
zeros
((
batch
,
heads_kv
,
num_blocks
),
dtype
=
torch
.
bool
,
device
=
'
cuda
'
)
block_mask
=
torch
.
zeros
((
batch
,
heads_kv
,
num_blocks
),
dtype
=
torch
.
bool
,
device
=
"
cuda
"
)
# Assign valid indices while ensuring no duplicates within each batch-group
for
b
in
range
(
batch
):
...
...
@@ -387,11 +348,10 @@ def main(batch=64,
valid_num_block
=
valid_num_blocks
[
b
].
item
()
# Valid blocks for this batch
if
valid_num_block
>
0
:
# Ensure there's at least one valid block
for
h
in
range
(
heads_kv
):
perm
=
torch
.
randperm
(
max_valid_block
,
device
=
'
cuda
'
)[:
valid_num_block
]
perm
=
torch
.
randperm
(
max_valid_block
,
device
=
"
cuda
"
)[:
valid_num_block
]
block_mask
[
b
,
h
,
perm
]
=
True
ref
=
ref_program_torch
(
Q
,
K
,
V
,
block_mask
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
block_size
)
ref
=
ref_program_torch
(
Q
,
K
,
V
,
block_mask
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
block_size
)
triton_out
=
block_sparse_flash_decode_gqa_mask_triton
(
Q
,
...
...
@@ -404,8 +364,7 @@ def main(batch=64,
)
# print("max difference: ", torch.max(torch.abs(ref - triton_out)))
assert
torch
.
allclose
(
ref
,
triton_out
,
atol
=
1e-2
),
"Output mismatch between Triton and reference implementation"
assert
torch
.
allclose
(
ref
,
triton_out
,
atol
=
1e-2
),
"Output mismatch between Triton and reference implementation"
print
(
"Passed the ref test!"
)
# Measure performance
...
...
@@ -448,15 +407,13 @@ def main(batch=64,
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--batch'
,
type
=
int
,
default
=
64
,
help
=
'batch size'
)
parser
.
add_argument
(
'--heads'
,
type
=
int
,
default
=
32
,
help
=
'heads'
)
parser
.
add_argument
(
'--heads_kv'
,
type
=
int
,
default
=
8
,
help
=
'heads_kv'
)
parser
.
add_argument
(
'--max_cache_seqlen'
,
type
=
int
,
default
=
8192
,
help
=
'kvcache sequence length'
)
parser
.
add_argument
(
'--dim'
,
type
=
int
,
default
=
128
,
help
=
'dim'
)
parser
.
add_argument
(
'--dim_v'
,
type
=
int
,
default
=
128
,
help
=
'dim_v'
)
parser
.
add_argument
(
'--sparse_ratio'
,
type
=
float
,
default
=
0.8
,
help
=
'sparse ratio'
)
parser
.
add_argument
(
'--block_size'
,
type
=
int
,
default
=
32
,
help
=
'block_size'
)
parser
.
add_argument
(
"--batch"
,
type
=
int
,
default
=
64
,
help
=
"batch size"
)
parser
.
add_argument
(
"--heads"
,
type
=
int
,
default
=
32
,
help
=
"heads"
)
parser
.
add_argument
(
"--heads_kv"
,
type
=
int
,
default
=
8
,
help
=
"heads_kv"
)
parser
.
add_argument
(
"--max_cache_seqlen"
,
type
=
int
,
default
=
8192
,
help
=
"kvcache sequence length"
)
parser
.
add_argument
(
"--dim"
,
type
=
int
,
default
=
128
,
help
=
"dim"
)
parser
.
add_argument
(
"--dim_v"
,
type
=
int
,
default
=
128
,
help
=
"dim_v"
)
parser
.
add_argument
(
"--sparse_ratio"
,
type
=
float
,
default
=
0.8
,
help
=
"sparse ratio"
)
parser
.
add_argument
(
"--block_size"
,
type
=
int
,
default
=
32
,
help
=
"block_size"
)
args
=
parser
.
parse_args
()
main
(
args
.
batch
,
args
.
heads
,
args
.
heads_kv
,
args
.
max_cache_seqlen
,
args
.
dim
,
args
.
dim_v
,
args
.
sparse_ratio
,
args
.
block_size
)
main
(
args
.
batch
,
args
.
heads
,
args
.
heads_kv
,
args
.
max_cache_seqlen
,
args
.
dim
,
args
.
dim_v
,
args
.
sparse_ratio
,
args
.
block_size
)
examples/blocksparse_attention/heuristic.py
View file @
29051439
import
math
def
num_splits_heuristic
(
total_mblocks
,
num_SMs
,
num_n_blocks
,
num_m_blocks
,
size_one_kv_head
,
is_causal_or_local
,
max_splits
):
def
num_splits_heuristic
(
total_mblocks
,
num_SMs
,
num_n_blocks
,
num_m_blocks
,
size_one_kv_head
,
is_causal_or_local
,
max_splits
):
"""
Determines the optimal number of splits for maximizing GPU occupancy while balancing memory efficiency.
...
...
examples/blocksparse_attention/test_example_blocksparse_attention.py
View file @
29051439
...
...
@@ -25,26 +25,14 @@ def test_example_tilelang_sparse_gqa_decode_varlen_mask():
def
test_example_triton_sparse_gqa_decode_varlen_indice
():
example_triton_sparse_gqa_decode_varlen_indice
.
main
(
batch
=
8
,
heads
=
8
,
heads_kv
=
4
,
max_cache_seqlen
=
2048
,
dim
=
128
,
dim_v
=
128
,
sparse_ratio
=
0.8
,
block_size
=
32
)
batch
=
8
,
heads
=
8
,
heads_kv
=
4
,
max_cache_seqlen
=
2048
,
dim
=
128
,
dim_v
=
128
,
sparse_ratio
=
0.8
,
block_size
=
32
)
def
test_example_triton_sparse_gqa_decode_varlen_mask
():
example_triton_sparse_gqa_decode_varlen_mask
.
main
(
batch
=
16
,
heads
=
16
,
heads_kv
=
8
,
max_cache_seqlen
=
1024
,
dim
=
128
,
dim_v
=
128
,
sparse_ratio
=
0.8
,
block_size
=
32
)
batch
=
16
,
heads
=
16
,
heads_kv
=
8
,
max_cache_seqlen
=
1024
,
dim
=
128
,
dim_v
=
128
,
sparse_ratio
=
0.8
,
block_size
=
32
)
if
__name__
==
"__main__"
:
...
...
examples/blocksparse_gemm/example_blocksparse_gemm.py
View file @
29051439
...
...
@@ -19,8 +19,7 @@ parser.add_argument("--m", type=int, default=1024, help="Matrix dimension M")
parser
.
add_argument
(
"--n"
,
type
=
int
,
default
=
1024
,
help
=
"Matrix dimension N"
)
parser
.
add_argument
(
"--k"
,
type
=
int
,
default
=
1024
,
help
=
"Matrix dimension K"
)
parser
.
add_argument
(
"--sparsity"
,
type
=
float
,
default
=
0.5
,
help
=
"Sparsity ratio (0-1)"
)
parser
.
add_argument
(
"--use_autotune"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to use autotune"
)
parser
.
add_argument
(
"--use_autotune"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to use autotune"
)
args
,
_
=
parser
.
parse_known_args
()
M
,
N
,
K
=
args
.
m
,
args
.
n
,
args
.
k
...
...
@@ -41,17 +40,19 @@ def get_configs():
thread_num
=
[
128
,
256
]
enable_rasterization
=
[
True
,
False
]
_configs
=
list
(
itertools
.
product
(
block_M
,
block_N
,
block_K
,
num_stages
,
thread_num
,
enable_rasterization
))
_configs
=
list
(
itertools
.
product
(
block_M
,
block_N
,
block_K
,
num_stages
,
thread_num
,
enable_rasterization
))
return
[{
"block_M"
:
c
[
0
],
"block_N"
:
c
[
1
],
"block_K"
:
c
[
2
],
"num_stages"
:
c
[
3
],
"thread_num"
:
c
[
4
],
"enable_rasteration"
:
c
[
5
],
}
for
c
in
_configs
]
return
[
{
"block_M"
:
c
[
0
],
"block_N"
:
c
[
1
],
"block_K"
:
c
[
2
],
"num_stages"
:
c
[
3
],
"thread_num"
:
c
[
4
],
"enable_rasteration"
:
c
[
5
],
}
for
c
in
_configs
]
def
ref_program
(
A
,
B
,
BlockMask
,
block_M
,
block_N
,
block_K
):
...
...
@@ -61,12 +62,10 @@ def ref_program(A, B, BlockMask, block_M, block_N, block_K):
accu
=
torch
.
zeros
((
block_M
,
block_N
),
dtype
=
torch
.
float32
,
device
=
A
.
device
)
for
k
in
range
(
K
//
block_K
):
if
BlockMask
[
i
,
j
,
k
]:
accu
+=
(
A
[
i
*
block_M
:(
i
+
1
)
*
block_M
,
k
*
block_K
:(
k
+
1
)
*
block_K
].
to
(
torch
.
float32
)
@
B
[
k
*
block_K
:(
k
+
1
)
*
block_K
,
j
*
block_N
:(
j
+
1
)
*
block_N
].
to
(
torch
.
float32
))
ref_c
[
i
*
block_M
:(
i
+
1
)
*
block_M
,
j
*
block_N
:(
j
+
1
)
*
block_N
]
=
accu
.
to
(
torch
.
float16
)
accu
+=
A
[
i
*
block_M
:
(
i
+
1
)
*
block_M
,
k
*
block_K
:
(
k
+
1
)
*
block_K
].
to
(
torch
.
float32
)
@
B
[
k
*
block_K
:
(
k
+
1
)
*
block_K
,
j
*
block_N
:
(
j
+
1
)
*
block_N
].
to
(
torch
.
float32
)
ref_c
[
i
*
block_M
:
(
i
+
1
)
*
block_M
,
j
*
block_N
:
(
j
+
1
)
*
block_N
]
=
accu
.
to
(
torch
.
float16
)
return
ref_c
...
...
@@ -89,28 +88,21 @@ def supply_program(params: List[KernelParam]):
return
input_tensors
@
tilelang
.
autotune
(
configs
=
get_configs
(),)
@
tilelang
.
autotune
(
configs
=
get_configs
(),
)
@
tilelang
.
jit
(
out_idx
=
[
-
1
])
def
blocksparse_matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
num_stages
,
thread_num
,
enable_rasteration
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
def
blocksparse_matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
num_stages
,
thread_num
,
enable_rasteration
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
block_mask_shape
=
(
M
//
block_M
,
N
//
block_N
,
K
//
block_K
)
@
T
.
prim_func
def
block_sparse_matmul
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
BlockMask
:
T
.
Tensor
(
block_mask_shape
,
"bool"
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
BlockMask
:
T
.
Tensor
(
block_mask_shape
,
"bool"
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
...
...
@@ -134,7 +126,6 @@ def blocksparse_matmul(M,
def
main
():
# Initialize input matrices A and B on the GPU with half precision
a
=
torch
.
randn
(
M
,
K
).
cuda
().
half
()
b
=
torch
.
randn
(
K
,
N
).
cuda
().
half
()
...
...
@@ -147,8 +138,7 @@ def main():
best_config
=
kernel
.
config
best_latency
=
kernel
.
latency
block_M
,
block_N
,
block_K
=
best_config
[
"block_M"
],
best_config
[
"block_N"
],
best_config
[
"block_K"
]
block_M
,
block_N
,
block_K
=
best_config
[
"block_M"
],
best_config
[
"block_N"
],
best_config
[
"block_K"
]
print
(
f
"Best Config:
{
best_config
}
"
)
print
(
f
"Sparsity Ratio:
{
sparsity
}
"
)
...
...
@@ -163,7 +153,8 @@ def main():
block_K
=
DEFAULT_BLOCK_K
,
num_stages
=
DEFAULT_NUM_STAGES
,
thread_num
=
DEFAULT_THREAD_NUM
,
enable_rasteration
=
DEFAULT_ENABLE_RASTERIZATION
)
enable_rasteration
=
DEFAULT_ENABLE_RASTERIZATION
,
)
block_M
,
block_N
,
block_K
=
DEFAULT_BLOCK_M
,
DEFAULT_BLOCK_N
,
DEFAULT_BLOCK_K
print
(
f
"Using default kernel with block size (
{
block_M
}
,
{
block_N
}
,
{
block_K
}
)"
)
# Create block mask with desired sparsity
...
...
examples/cast/example_group_per_split_token_cast_to_fp8.py
View file @
29051439
...
...
@@ -16,11 +16,13 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
fp8_max
=
448.0
@
T
.
prim_func
def
group_per_split_token_cast
(
X
:
T
.
Tensor
((
M
,
N
),
dtype
),
batch_sizes
:
T
.
Tensor
(
(
BG
,),
"int32"
),
X_fp8
:
T
.
Tensor
((
BG
,
M_max
,
N
),
"float8_e4m3"
),
X_amax
:
T
.
Tensor
(
(
BG
,
M_max
,
T
.
ceildiv
(
N
,
group_size
)),
accum_dtype
)):
with
T
.
Kernel
(
T
.
ceildiv
(
M_max
,
blk_m
),
T
.
ceildiv
(
N
,
group_size
),
BG
,
threads
=
128
)
as
(
bx
,
by
,
bz
):
def
group_per_split_token_cast
(
X
:
T
.
Tensor
((
M
,
N
),
dtype
),
batch_sizes
:
T
.
Tensor
((
BG
,),
"int32"
),
X_fp8
:
T
.
Tensor
((
BG
,
M_max
,
N
),
"float8_e4m3"
),
X_amax
:
T
.
Tensor
((
BG
,
M_max
,
T
.
ceildiv
(
N
,
group_size
)),
accum_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
M_max
,
blk_m
),
T
.
ceildiv
(
N
,
group_size
),
BG
,
threads
=
128
)
as
(
bx
,
by
,
bz
):
row
=
bx
row_g_id
=
by
bg
=
bz
...
...
@@ -31,36 +33,32 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
y_q_local_fp8
=
T
.
alloc_fragment
((
blk_m
,
group_size
),
"float8_e4m3"
)
row_offset
=
T
.
alloc_fragment
((
1
,),
"int32"
)
T
.
annotate_layout
({
y_local
:
T
.
Fragment
(
y_local
.
shape
,
forward_thread_fn
=
lambda
i
,
j
:
(
i
//
(
blk_m
//
4
))
*
32
+
j
%
32
),
})
T
.
annotate_layout
(
{
y_local
:
T
.
Fragment
(
y_local
.
shape
,
forward_thread_fn
=
lambda
i
,
j
:
(
i
//
(
blk_m
//
4
))
*
32
+
j
%
32
),
}
)
row_offset
[
0
]
=
0
for
i
in
T
.
serial
(
bg
):
row_offset
[
0
]
+=
batch_sizes
[
i
]
T
.
copy
(
X
[
row_offset
[
0
]
+
row
*
blk_m
:
row_offset
[
0
]
+
(
row
+
1
)
*
blk_m
,
row_g_id
*
group_size
:(
row_g_id
+
1
)
*
group_size
],
y_local
)
X
[
row_offset
[
0
]
+
row
*
blk_m
:
row_offset
[
0
]
+
(
row
+
1
)
*
blk_m
,
row_g_id
*
group_size
:
(
row_g_id
+
1
)
*
group_size
],
y_local
,
)
T
.
reduce_absmax
(
y_local
,
y_amax_local
,
dim
=
1
)
for
i
in
T
.
Parallel
(
blk_m
):
y_amax_local
[
i
]
=
T
.
max
(
y_amax_local
[
i
],
1e-4
)
y_s_local
[
i
]
=
T
.
if_then_else
(
row
*
blk_m
+
i
<
batch_sizes
[
bg
],
y_amax_local
[
i
]
/
fp8_max
,
0
)
y_s_local
[
i
]
=
T
.
if_then_else
(
row
*
blk_m
+
i
<
batch_sizes
[
bg
],
y_amax_local
[
i
]
/
fp8_max
,
0
)
for
i
,
j
in
T
.
Parallel
(
blk_m
,
group_size
):
y_q_local
[
i
,
j
]
=
T
.
clamp
(
y_local
[
i
,
j
]
/
y_s_local
[
i
],
fp8_min
,
fp8_max
)
T
.
copy
(
y_q_local
,
y_q_local_fp8
)
for
i
,
j
in
T
.
Parallel
(
blk_m
,
group_size
):
y_q_local_fp8
[
i
,
j
]
=
T
.
if_then_else
(
row
*
blk_m
+
i
<
batch_sizes
[
bg
],
y_q_local
[
i
,
j
],
0
)
y_q_local_fp8
[
i
,
j
]
=
T
.
if_then_else
(
row
*
blk_m
+
i
<
batch_sizes
[
bg
],
y_q_local
[
i
,
j
],
0
)
for
i
in
T
.
Parallel
(
blk_m
):
X_amax
[
bg
,
row
*
blk_m
+
i
,
row_g_id
]
=
y_s_local
[
i
]
T
.
copy
(
y_q_local_fp8
,
X_fp8
[
bg
,
row
*
blk_m
:(
row
+
1
)
*
blk_m
,
row_g_id
*
group_size
:(
row_g_id
+
1
)
*
group_size
])
T
.
copy
(
y_q_local_fp8
,
X_fp8
[
bg
,
row
*
blk_m
:
(
row
+
1
)
*
blk_m
,
row_g_id
*
group_size
:
(
row_g_id
+
1
)
*
group_size
])
return
group_per_split_token_cast
...
...
@@ -127,8 +125,7 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
return
x
.
squeeze
(
0
)
if
remove_dim
else
x
# Normal layout requires transposing
aligned_x
=
torch
.
transpose
(
torch
.
empty
((
b
,
n
,
aligned_m
),
device
=
x
.
device
,
dtype
=
x
.
dtype
),
1
,
2
)
aligned_x
=
torch
.
transpose
(
torch
.
empty
((
b
,
n
,
aligned_m
),
device
=
x
.
device
,
dtype
=
x
.
dtype
),
1
,
2
)
aligned_x
[:,
:
m
,
:]
=
x
aligned_x
=
aligned_x
[:,
:
m
,
:]
return
aligned_x
.
squeeze
(
0
)
if
remove_dim
else
aligned_x
...
...
@@ -146,15 +143,17 @@ def ref_per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens
x_fp8
=
x_fp8
.
view
(
m
,
-
1
)[:,
:
n
].
contiguous
()
return
x_fp8
,
(
x_amax
/
448.0
).
view
(
m
,
-
1
)
def
ref_program
(
x
:
torch
.
Tensor
,
batch_sizes
:
torch
.
Tensor
)
->
\
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
def
ref_program
(
x
:
torch
.
Tensor
,
batch_sizes
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# assert x.shape[0] == batch_sizes.sum()
M_max
=
ceil_div
(
batch_sizes
.
max
(),
128
)
*
128
split_x
=
torch
.
split
(
x
,
batch_sizes
.
tolist
(),
dim
=
0
)
padded_x
=
[
torch
.
nn
.
functional
.
pad
(
t
,
(
0
,
0
,
0
,
M_max
-
t
.
shape
[
0
]))
for
t
in
split_x
]
num_groups
,
m
,
n
=
batch_sizes
.
shape
[
0
],
M_max
,
x
.
shape
[
1
]
x_fp8
=
(
torch
.
empty
((
num_groups
,
m
,
n
),
device
=
'cuda'
,
dtype
=
torch
.
float8_e4m3fn
),
torch
.
empty
((
num_groups
,
m
,
n
//
128
),
device
=
'cuda'
,
dtype
=
torch
.
float
))
x_fp8
=
(
torch
.
empty
((
num_groups
,
m
,
n
),
device
=
"cuda"
,
dtype
=
torch
.
float8_e4m3fn
),
torch
.
empty
((
num_groups
,
m
,
n
//
128
),
device
=
"cuda"
,
dtype
=
torch
.
float
),
)
for
i
in
range
(
num_groups
):
x_fp8
[
0
][
i
],
x_fp8
[
1
][
i
]
=
ref_per_token_cast_to_fp8
(
padded_x
[
i
])
x_fp8
=
(
x_fp8
[
0
],
get_col_major_tma_aligned_tensor
(
x_fp8
[
1
]))
...
...
examples/cast/example_per_token_cast_to_fp8.py
View file @
29051439
...
...
@@ -13,8 +13,9 @@ def per_token_cast_to_fp8(M, N, blk_m):
fp8_max
=
448.0
@
T
.
prim_func
def
per_token_cast
(
X
:
T
.
Tensor
((
M
,
N
),
dtype
),
X_fp8
:
T
.
Tensor
((
M
,
N
),
"float8_e4m3"
),
X_amax
:
T
.
Tensor
((
M
,
T
.
ceildiv
(
N
,
group_size
)),
dtype
)):
def
per_token_cast
(
X
:
T
.
Tensor
((
M
,
N
),
dtype
),
X_fp8
:
T
.
Tensor
((
M
,
N
),
"float8_e4m3"
),
X_amax
:
T
.
Tensor
((
M
,
T
.
ceildiv
(
N
,
group_size
)),
dtype
)
):
with
T
.
Kernel
(
T
.
ceildiv
(
M
,
blk_m
),
T
.
ceildiv
(
N
,
group_size
),
threads
=
128
)
as
(
bx
,
by
):
row
=
bx
row_g_id
=
by
...
...
@@ -24,16 +25,13 @@ def per_token_cast_to_fp8(M, N, blk_m):
y_q_local
=
T
.
alloc_fragment
((
blk_m
,
group_size
),
dtype
)
y_q_local_fp8
=
T
.
alloc_fragment
((
blk_m
,
group_size
),
"float8_e4m3"
)
T
.
annotate_layout
({
y_local
:
T
.
Fragment
(
y_local
.
shape
,
forward_thread_fn
=
lambda
i
,
j
:
(
i
//
(
blk_m
//
4
))
*
32
+
j
%
32
),
})
T
.
annotate_layout
(
{
y_local
:
T
.
Fragment
(
y_local
.
shape
,
forward_thread_fn
=
lambda
i
,
j
:
(
i
//
(
blk_m
//
4
))
*
32
+
j
%
32
),
}
)
T
.
copy
(
X
[
row
*
blk_m
:(
row
+
1
)
*
blk_m
,
row_g_id
*
group_size
:(
row_g_id
+
1
)
*
group_size
],
y_local
)
T
.
copy
(
X
[
row
*
blk_m
:
(
row
+
1
)
*
blk_m
,
row_g_id
*
group_size
:
(
row_g_id
+
1
)
*
group_size
],
y_local
)
T
.
reduce_absmax
(
y_local
,
y_amax_local
,
dim
=
1
)
for
i
in
T
.
Parallel
(
blk_m
):
y_amax_local
[
i
]
=
T
.
max
(
y_amax_local
[
i
],
1e-4
)
...
...
@@ -43,9 +41,7 @@ def per_token_cast_to_fp8(M, N, blk_m):
T
.
copy
(
y_q_local
,
y_q_local_fp8
)
for
i
in
T
.
Parallel
(
blk_m
):
X_amax
[
row
*
blk_m
+
i
,
row_g_id
]
=
y_s_local
[
i
]
T
.
copy
(
y_q_local_fp8
,
X_fp8
[
row
*
blk_m
:(
row
+
1
)
*
blk_m
,
row_g_id
*
group_size
:(
row_g_id
+
1
)
*
group_size
])
T
.
copy
(
y_q_local_fp8
,
X_fp8
[
row
*
blk_m
:
(
row
+
1
)
*
blk_m
,
row_g_id
*
group_size
:
(
row_g_id
+
1
)
*
group_size
])
return
per_token_cast
...
...
@@ -105,8 +101,7 @@ def main(M=8192, N=8192, blk_m=8):
from
example_triton_cast_to_fp8
import
per_token_group_quant_fp8
def
run_triton
():
x_fp8_triton_
,
x_amax_triton_
=
per_token_group_quant_fp8
(
x
,
128
,
1e-4
,
dtype
=
torch
.
float8_e4m3fn
,
column_major_scales
=
False
)
x_fp8_triton_
,
x_amax_triton_
=
per_token_group_quant_fp8
(
x
,
128
,
1e-4
,
dtype
=
torch
.
float8_e4m3fn
,
column_major_scales
=
False
)
return
x_fp8_triton_
,
x_amax_triton_
x_fp8_triton
,
x_amax_triton
=
run_triton
()
...
...
examples/cast/example_triton_cast_to_fp8.py
View file @
29051439
...
...
@@ -128,9 +128,7 @@ def per_token_group_quant_fp8(
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization.
"""
assert
(
x
.
shape
[
-
1
]
%
group_size
==
0
),
(
f
"the last dimension of `x`
{
x
.
shape
[
-
1
]
}
must be divisible "
f
"by `group_size`
{
group_size
}
"
)
assert
x
.
shape
[
-
1
]
%
group_size
==
0
,
f
"the last dimension of `x`
{
x
.
shape
[
-
1
]
}
must be divisible by `group_size`
{
group_size
}
"
assert
x
.
stride
(
-
1
)
==
1
,
"`x` groups must be contiguous"
finfo
=
torch
.
finfo
(
dtype
)
...
...
examples/cast/test_example_cast.py
View file @
29051439
...
...
@@ -4,8 +4,7 @@ import example_per_token_cast_to_fp8
def
test_example_group_per_split_token_cast_to_fp8
():
example_group_per_split_token_cast_to_fp8
.
main
(
M
=
1024
,
N
=
1024
,
BG
=
2
,
blk_m
=
4
,
batch_sizes
=
[
128
,
896
])
example_group_per_split_token_cast_to_fp8
.
main
(
M
=
1024
,
N
=
1024
,
BG
=
2
,
blk_m
=
4
,
batch_sizes
=
[
128
,
896
])
def
test_example_per_token_cast_to_fp8
():
...
...
examples/compile_flags/usecase.py
View file @
29051439
...
...
@@ -4,12 +4,11 @@ import tilelang.language as T
# @tilelang.jit(compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"])
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
...
...
@@ -36,8 +35,7 @@ block_K = 32
func
=
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
)
jit_kernel
=
tilelang
.
compile
(
func
,
out_idx
=
[
2
],
target
=
"cuda"
,
compile_flags
=
"-O3 --use_fast_math --expt-relaxed-constexpr"
)
jit_kernel
=
tilelang
.
compile
(
func
,
out_idx
=
[
2
],
target
=
"cuda"
,
compile_flags
=
"-O3 --use_fast_math --expt-relaxed-constexpr"
)
# or jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"])
# or jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags=["-O3 --use_fast_math --expt-relaxed-constexpr"])
...
...
examples/conftest.py
View file @
29051439
...
...
@@ -33,12 +33,9 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config):
"warnings"
,
"error"
,
}
if
(
sum
(
len
(
terminalreporter
.
stats
.
get
(
k
,
[]))
for
k
in
known_types
.
difference
({
"skipped"
,
"deselected"
}))
==
0
):
if
sum
(
len
(
terminalreporter
.
stats
.
get
(
k
,
[]))
for
k
in
known_types
.
difference
({
"skipped"
,
"deselected"
}))
==
0
:
terminalreporter
.
write_sep
(
"!"
,
(
f
"Error: No tests were collected. "
f
"
{
dict
(
sorted
((
k
,
len
(
v
))
for
k
,
v
in
terminalreporter
.
stats
.
items
()))
}
"
),
(
f
"Error: No tests were collected.
{
dict
(
sorted
((
k
,
len
(
v
))
for
k
,
v
in
terminalreporter
.
stats
.
items
()))
}
"
),
)
pytest
.
exit
(
"No tests were collected."
,
returncode
=
5
)
examples/convolution/example_convolution.py
View file @
29051439
...
...
@@ -14,7 +14,6 @@ def check_hopper():
def
ref_program
(
stride
,
padding
,
dilation
):
def
main
(
A
,
B
):
A
=
A
.
permute
(
0
,
3
,
1
,
2
)
# N, H, W, C -> N, C, H, W
B
=
B
.
permute
(
3
,
2
,
0
,
1
)
# H, W, C, F -> F, C, H, W
...
...
@@ -26,22 +25,7 @@ def ref_program(stride, padding, dilation):
@
tilelang
.
jit
(
out_idx
=
[
2
])
def
convolution
(
N
,
C
,
H
,
W
,
F
,
K
,
S
,
D
,
P
,
block_M
,
block_N
,
block_K
,
num_stages
,
threads
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
def
convolution
(
N
,
C
,
H
,
W
,
F
,
K
,
S
,
D
,
P
,
block_M
,
block_N
,
block_K
,
num_stages
,
threads
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
KH
,
KW
=
K
,
K
OH
=
(
H
+
2
*
P
-
D
*
(
K
-
1
)
-
1
)
//
S
+
1
OW
=
(
W
+
2
*
P
-
D
*
(
K
-
1
)
-
1
)
//
S
+
1
...
...
@@ -51,13 +35,11 @@ def convolution(N,
@
T
.
prim_func
def
main
(
data
:
T
.
Tensor
((
N
,
H
,
W
,
C
),
dtype
),
kernel
:
T
.
Tensor
((
KH
,
KW
,
C
,
F
),
dtype
),
out
:
T
.
Tensor
((
N
,
OH
,
OW
,
F
),
dtype
),
data
:
T
.
Tensor
((
N
,
H
,
W
,
C
),
dtype
),
kernel
:
T
.
Tensor
((
KH
,
KW
,
C
,
F
),
dtype
),
out
:
T
.
Tensor
((
N
,
OH
,
OW
,
F
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
F
,
block_N
),
T
.
ceildiv
(
N
*
OH
*
OW
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
F
,
block_N
),
T
.
ceildiv
(
N
*
OH
*
OW
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
data_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
kernel_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
out_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
...
...
@@ -66,11 +48,13 @@ def convolution(N,
kernel_flat
=
T
.
Tensor
((
KH
*
KW
*
C
,
F
),
dtype
,
kernel
.
data
)
out_flat
=
T
.
Tensor
((
N
*
OH
*
OW
,
F
),
dtype
,
out
.
data
)
T
.
annotate_layout
({
out_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
out_shared
),
data_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
data_shared
),
kernel_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
kernel_shared
),
})
T
.
annotate_layout
(
{
out_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
out_shared
),
data_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
data_shared
),
kernel_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
kernel_shared
),
}
)
T
.
clear
(
out_local
)
for
k_iter
in
T
.
Pipelined
(
T
.
ceildiv
(
KH
*
KW
*
C
,
block_K
),
num_stages
=
num_stages
):
...
...
@@ -82,10 +66,8 @@ def convolution(N,
m
=
by
*
block_M
+
i
access_h
=
m
%
(
OH
*
OW
)
//
OW
*
S
+
k
//
(
KW
*
C
)
*
D
-
P
access_w
=
m
%
OW
*
S
+
k
//
C
%
KW
*
D
-
P
in_bound
=
((
access_h
>=
0
)
and
(
access_w
>=
0
)
and
(
access_h
<
H
)
and
(
access_w
<
W
))
data_shared
[
i
,
j
]
=
T
.
if_then_else
(
in_bound
,
data
[
m
//
(
OH
*
OW
),
access_h
,
access_w
,
k
%
C
],
0
)
in_bound
=
(
access_h
>=
0
)
and
(
access_w
>=
0
)
and
(
access_h
<
H
)
and
(
access_w
<
W
)
data_shared
[
i
,
j
]
=
T
.
if_then_else
(
in_bound
,
data
[
m
//
(
OH
*
OW
),
access_h
,
access_w
,
k
%
C
],
0
)
T
.
copy
(
kernel_flat
[
k_iter
*
block_K
,
bx
*
block_N
],
kernel_shared
)
T
.
gemm
(
data_shared
,
kernel_shared
,
out_local
)
...
...
@@ -97,15 +79,15 @@ def convolution(N,
def
main
(
argv
=
None
):
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--n
'
,
type
=
int
,
default
=
128
,
help
=
'n'
)
parser
.
add_argument
(
'
--c
'
,
type
=
int
,
default
=
128
,
help
=
'c'
)
parser
.
add_argument
(
'
--h
'
,
type
=
int
,
default
=
64
,
help
=
'h'
)
parser
.
add_argument
(
'
--w
'
,
type
=
int
,
default
=
64
,
help
=
'w'
)
parser
.
add_argument
(
'
--f
'
,
type
=
int
,
default
=
128
,
help
=
'f'
)
parser
.
add_argument
(
'
--k
'
,
type
=
int
,
default
=
3
,
help
=
'k'
)
parser
.
add_argument
(
'
--s
'
,
type
=
int
,
default
=
1
,
help
=
's'
)
parser
.
add_argument
(
'
--d
'
,
type
=
int
,
default
=
1
,
help
=
'd'
)
parser
.
add_argument
(
'
--p
'
,
type
=
int
,
default
=
1
,
help
=
'p'
)
parser
.
add_argument
(
"
--n
"
,
type
=
int
,
default
=
128
,
help
=
"n"
)
parser
.
add_argument
(
"
--c
"
,
type
=
int
,
default
=
128
,
help
=
"c"
)
parser
.
add_argument
(
"
--h
"
,
type
=
int
,
default
=
64
,
help
=
"h"
)
parser
.
add_argument
(
"
--w
"
,
type
=
int
,
default
=
64
,
help
=
"w"
)
parser
.
add_argument
(
"
--f
"
,
type
=
int
,
default
=
128
,
help
=
"f"
)
parser
.
add_argument
(
"
--k
"
,
type
=
int
,
default
=
3
,
help
=
"k"
)
parser
.
add_argument
(
"
--s
"
,
type
=
int
,
default
=
1
,
help
=
"s"
)
parser
.
add_argument
(
"
--d
"
,
type
=
int
,
default
=
1
,
help
=
"d"
)
parser
.
add_argument
(
"
--p
"
,
type
=
int
,
default
=
1
,
help
=
"p"
)
args
=
parser
.
parse_args
(
argv
)
N
,
C
,
H
,
W
,
F
,
K
,
S
,
D
,
P
=
args
.
n
,
args
.
c
,
args
.
h
,
args
.
w
,
args
.
f
,
args
.
k
,
args
.
s
,
args
.
d
,
args
.
p
...
...
examples/convolution/example_convolution_autotune.py
View file @
29051439
...
...
@@ -14,7 +14,6 @@ def check_hopper():
def
ref_program
(
stride
,
padding
,
dilation
):
def
main
(
A
,
B
):
A
=
A
.
permute
(
0
,
3
,
1
,
2
)
# N, H, W, C -> N, C, H, W
B
=
B
.
permute
(
3
,
2
,
0
,
1
)
# H, W, C, F -> F, C, H, W
...
...
@@ -40,7 +39,8 @@ def get_configs():
num_stages
,
thread_num
,
enable_rasterization
,
))
)
)
configs
=
[
{
...
...
@@ -50,7 +50,8 @@ def get_configs():
"num_stages"
:
c
[
3
],
"thread_num"
:
c
[
4
],
"enable_rasteration"
:
c
[
5
],
# keep param name for backward-compat
}
for
c
in
_configs
}
for
c
in
_configs
]
return
configs
...
...
@@ -64,53 +65,18 @@ def get_heuristic_config() -> dict:
sm_version
=
sm_major
*
10
+
sm_minor
print
(
f
"CUDA device capability:
{
sm_version
}
"
)
if
sm_version
in
{
80
}:
return
{
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
32
,
"num_stages"
:
2
,
"thread_num"
:
128
,
"enable_rasteration"
:
True
}
return
{
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
32
,
"num_stages"
:
2
,
"thread_num"
:
128
,
"enable_rasteration"
:
True
}
elif
sm_version
in
{
90
}:
return
{
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
64
,
"num_stages"
:
3
,
"thread_num"
:
256
,
"enable_rasteration"
:
True
}
return
{
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
64
,
"num_stages"
:
3
,
"thread_num"
:
256
,
"enable_rasteration"
:
True
}
else
:
return
{
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
32
,
"num_stages"
:
0
,
"thread_num"
:
128
,
"enable_rasteration"
:
True
}
return
{
"block_M"
:
128
,
"block_N"
:
256
,
"block_K"
:
32
,
"num_stages"
:
0
,
"thread_num"
:
128
,
"enable_rasteration"
:
True
}
@
tilelang
.
autotune
(
configs
=
get_configs
())
@
tilelang
.
jit
(
out_idx
=
[
2
])
def
convolution
(
N
,
C
,
H
,
W
,
F
,
K
,
S
,
D
,
P
,
block_M
,
block_N
,
block_K
,
num_stages
,
thread_num
,
enable_rasteration
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
def
convolution
(
N
,
C
,
H
,
W
,
F
,
K
,
S
,
D
,
P
,
block_M
,
block_N
,
block_K
,
num_stages
,
thread_num
,
enable_rasteration
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
KH
,
KW
=
K
,
K
OH
=
(
H
+
2
*
P
-
D
*
(
K
-
1
)
-
1
)
//
S
+
1
OW
=
(
W
+
2
*
P
-
D
*
(
K
-
1
)
-
1
)
//
S
+
1
...
...
@@ -120,13 +86,11 @@ def convolution(N,
@
T
.
prim_func
def
main
(
data
:
T
.
Tensor
((
N
,
H
,
W
,
C
),
dtype
),
kernel
:
T
.
Tensor
((
KH
,
KW
,
C
,
F
),
dtype
),
out
:
T
.
Tensor
((
N
,
OH
,
OW
,
F
),
dtype
),
data
:
T
.
Tensor
((
N
,
H
,
W
,
C
),
dtype
),
kernel
:
T
.
Tensor
((
KH
,
KW
,
C
,
F
),
dtype
),
out
:
T
.
Tensor
((
N
,
OH
,
OW
,
F
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
F
,
block_N
),
T
.
ceildiv
(
N
*
OH
*
OW
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
F
,
block_N
),
T
.
ceildiv
(
N
*
OH
*
OW
,
block_M
),
threads
=
thread_num
)
as
(
bx
,
by
):
data_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
kernel_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
out_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
...
...
@@ -136,9 +100,11 @@ def convolution(N,
out_flat
=
T
.
Tensor
((
N
*
OH
*
OW
,
F
),
dtype
,
out
.
data
)
if
is_hopper
:
T
.
annotate_layout
({
out_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
out_shared
),
})
T
.
annotate_layout
(
{
out_shared
:
tilelang
.
layout
.
make_swizzled_layout
(
out_shared
),
}
)
T
.
clear
(
out_local
)
for
k_iter
in
T
.
Pipelined
(
T
.
ceildiv
(
KH
*
KW
*
C
,
block_K
),
num_stages
=
num_stages
):
...
...
@@ -150,10 +116,8 @@ def convolution(N,
m
=
by
*
block_M
+
i
access_h
=
m
%
(
OH
*
OW
)
//
OW
*
S
+
k
//
(
KW
*
C
)
*
D
-
P
access_w
=
m
%
OW
*
S
+
k
//
C
%
KW
*
D
-
P
in_bound
=
((
access_h
>=
0
)
and
(
access_w
>=
0
)
and
(
access_h
<
H
)
and
(
access_w
<
W
))
data_shared
[
i
,
j
]
=
T
.
if_then_else
(
in_bound
,
data
[
m
//
(
OH
*
OW
),
access_h
,
access_w
,
k
%
C
],
0
)
in_bound
=
(
access_h
>=
0
)
and
(
access_w
>=
0
)
and
(
access_h
<
H
)
and
(
access_w
<
W
)
data_shared
[
i
,
j
]
=
T
.
if_then_else
(
in_bound
,
data
[
m
//
(
OH
*
OW
),
access_h
,
access_w
,
k
%
C
],
0
)
T
.
copy
(
kernel_flat
[
k_iter
*
block_K
,
bx
*
block_N
],
kernel_shared
)
T
.
gemm
(
data_shared
,
kernel_shared
,
out_local
)
...
...
@@ -166,17 +130,19 @@ def convolution(N,
return
main
def
main
(
n
:
int
=
128
,
c
:
int
=
128
,
h
:
int
=
64
,
w
:
int
=
64
,
f
:
int
=
128
,
k
:
int
=
3
,
s
:
int
=
1
,
d
:
int
=
1
,
p
:
int
=
1
,
use_autotune
:
bool
=
False
,
with_roller
:
bool
=
True
):
def
main
(
n
:
int
=
128
,
c
:
int
=
128
,
h
:
int
=
64
,
w
:
int
=
64
,
f
:
int
=
128
,
k
:
int
=
3
,
s
:
int
=
1
,
d
:
int
=
1
,
p
:
int
=
1
,
use_autotune
:
bool
=
False
,
with_roller
:
bool
=
True
,
):
N
,
C
,
H
,
W
,
F
,
K
,
S
,
D
,
P
=
n
,
c
,
h
,
w
,
f
,
k
,
s
,
d
,
p
ref_prog
=
ref_program
(
S
,
P
,
D
)
...
...
@@ -196,25 +162,16 @@ def main(n: int = 128,
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
(
description
=
"Autotuned MatMul Benchmark"
)
parser
.
add_argument
(
'--n'
,
type
=
int
,
default
=
128
,
help
=
'n'
)
parser
.
add_argument
(
'--c'
,
type
=
int
,
default
=
128
,
help
=
'c'
)
parser
.
add_argument
(
'--h'
,
type
=
int
,
default
=
64
,
help
=
'h'
)
parser
.
add_argument
(
'--w'
,
type
=
int
,
default
=
64
,
help
=
'w'
)
parser
.
add_argument
(
'--f'
,
type
=
int
,
default
=
128
,
help
=
'f'
)
parser
.
add_argument
(
'--k'
,
type
=
int
,
default
=
3
,
help
=
'k'
)
parser
.
add_argument
(
'--s'
,
type
=
int
,
default
=
1
,
help
=
's'
)
parser
.
add_argument
(
'--d'
,
type
=
int
,
default
=
1
,
help
=
'd'
)
parser
.
add_argument
(
'--p'
,
type
=
int
,
default
=
1
,
help
=
'p'
)
parser
.
add_argument
(
"--use_autotune"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to use autotune for matmul configs"
)
parser
.
add_argument
(
"--with_roller"
,
action
=
"store_true"
,
default
=
True
,
help
=
"Whether to enable BitBLAS roller for search space"
)
parser
.
add_argument
(
"--n"
,
type
=
int
,
default
=
128
,
help
=
"n"
)
parser
.
add_argument
(
"--c"
,
type
=
int
,
default
=
128
,
help
=
"c"
)
parser
.
add_argument
(
"--h"
,
type
=
int
,
default
=
64
,
help
=
"h"
)
parser
.
add_argument
(
"--w"
,
type
=
int
,
default
=
64
,
help
=
"w"
)
parser
.
add_argument
(
"--f"
,
type
=
int
,
default
=
128
,
help
=
"f"
)
parser
.
add_argument
(
"--k"
,
type
=
int
,
default
=
3
,
help
=
"k"
)
parser
.
add_argument
(
"--s"
,
type
=
int
,
default
=
1
,
help
=
"s"
)
parser
.
add_argument
(
"--d"
,
type
=
int
,
default
=
1
,
help
=
"d"
)
parser
.
add_argument
(
"--p"
,
type
=
int
,
default
=
1
,
help
=
"p"
)
parser
.
add_argument
(
"--use_autotune"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to use autotune for matmul configs"
)
parser
.
add_argument
(
"--with_roller"
,
action
=
"store_true"
,
default
=
True
,
help
=
"Whether to enable BitBLAS roller for search space"
)
args
=
parser
.
parse_args
()
main
(
args
.
n
,
args
.
c
,
args
.
h
,
args
.
w
,
args
.
f
,
args
.
k
,
args
.
s
,
args
.
d
,
args
.
p
,
args
.
use_autotune
,
args
.
with_roller
)
main
(
args
.
n
,
args
.
c
,
args
.
h
,
args
.
w
,
args
.
f
,
args
.
k
,
args
.
s
,
args
.
d
,
args
.
p
,
args
.
use_autotune
,
args
.
with_roller
)
examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py
View file @
29051439
...
...
@@ -41,14 +41,13 @@ def tl_gemm(
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
scales_a
:
T
.
Tensor
(
Scales_A_shape
,
"float32"
),
scales_b
:
T
.
Tensor
(
Scales_B_shape
,
"float32"
),
A
:
T
.
Tensor
(
A_shape
,
in_dtype
),
B
:
T
.
Tensor
(
B_shape
,
in_dtype
),
C
:
T
.
Tensor
((
M
,
N
),
out_dtype
),
scales_a
:
T
.
Tensor
(
Scales_A_shape
,
"float32"
),
scales_b
:
T
.
Tensor
(
Scales_B_shape
,
"float32"
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
C_shared
=
T
.
alloc_shared
(
C_shared_shape
,
out_dtype
)
...
...
@@ -93,21 +92,18 @@ def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
m
,
n
=
x
.
shape
x_view
=
x
.
view
(
m
,
-
1
,
128
)
x_amax
=
x_view
.
abs
().
float
().
amax
(
dim
=
2
).
view
(
m
,
-
1
).
clamp
(
1e-4
)
return
(
x_view
*
(
448.0
/
x_amax
.
unsqueeze
(
2
))).
to
(
torch
.
float8_e4m3fn
).
view
(
m
,
n
),
(
x_amax
/
448.0
).
view
(
m
,
-
1
)
return
(
x_view
*
(
448.0
/
x_amax
.
unsqueeze
(
2
))).
to
(
torch
.
float8_e4m3fn
).
view
(
m
,
n
),
(
x_amax
/
448.0
).
view
(
m
,
-
1
)
def
per_block_cast_to_fp8
(
x
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
x
.
dim
()
==
2
m
,
n
=
x
.
shape
x_padded
=
torch
.
zeros
(
ceildiv
(
m
,
128
)
*
128
,
ceildiv
(
n
,
128
)
*
128
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
x_padded
=
torch
.
zeros
(
ceildiv
(
m
,
128
)
*
128
,
ceildiv
(
n
,
128
)
*
128
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
x_padded
[:
m
,
:
n
]
=
x
x_view
=
x_padded
.
view
(
-
1
,
128
,
x_padded
.
size
(
1
)
//
128
,
128
)
x_amax
=
x_view
.
abs
().
float
().
amax
(
dim
=
(
1
,
3
),
keepdim
=
True
).
clamp
(
1e-4
)
x_scaled
=
(
x_view
*
(
448.0
/
x_amax
)).
to
(
torch
.
float8_e4m3fn
)
return
x_scaled
.
view_as
(
x_padded
)[:
m
,
:
n
].
contiguous
(),
(
x_amax
/
448.0
).
view
(
x_view
.
size
(
0
),
x_view
.
size
(
2
))
return
x_scaled
.
view_as
(
x_padded
)[:
m
,
:
n
].
contiguous
(),
(
x_amax
/
448.0
).
view
(
x_view
.
size
(
0
),
x_view
.
size
(
2
))
def
ref_deepgemm_fp8
(
A_fp8
,
B_fp8
,
A_scale
,
B_scale
,
out_dtype
):
...
...
@@ -127,13 +123,14 @@ def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype):
c_acc
.
zero_
()
for
k
in
range
(
ceildiv
(
K
,
128
)):
c
=
torch
.
_scaled_mm
(
A_fp8
[
i
*
128
:
(
i
+
1
)
*
128
,
k
*
128
:
(
k
+
1
)
*
128
],
B_fp8
[
j
*
128
:
(
j
+
1
)
*
128
,
k
*
128
:
(
k
+
1
)
*
128
].
T
,
A_fp8
[
i
*
128
:
(
i
+
1
)
*
128
,
k
*
128
:
(
k
+
1
)
*
128
],
B_fp8
[
j
*
128
:
(
j
+
1
)
*
128
,
k
*
128
:
(
k
+
1
)
*
128
].
T
,
scale_a
=
A_scales
[
i
,
k
].
view
(
128
,
1
).
contiguous
(),
scale_b
=
B_scales
[
j
,
k
].
view
(
1
,
128
).
contiguous
(),
out_dtype
=
torch
.
bfloat16
)
out_dtype
=
torch
.
bfloat16
,
)
c_acc
+=
c
.
to
(
torch
.
float32
)
C
[
i
*
128
:
(
i
+
1
)
*
128
,
j
*
128
:
(
j
+
1
)
*
128
]
=
c_acc
.
to
(
out_dtype
)
C
[
i
*
128
:
(
i
+
1
)
*
128
,
j
*
128
:
(
j
+
1
)
*
128
]
=
c_acc
.
to
(
out_dtype
)
return
C
...
...
Prev
1
2
3
4
5
6
7
…
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment