OpenDAS / tilelang · Commits · Commit 29051439 (Unverified)
Authored Dec 12, 2025 by Lei Wang; committed by GitHub, Dec 12, 2025

[Lint] Phaseout Yapf format and embrace ruff format (#1417)
parent e84b24bc

Showing 20 changed files with 623 additions and 986 deletions (+623 -986)
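Because this commit drops yapf in favor of ruff's formatter, the tree can be re-verified with `ruff format`. A minimal sketch of such a check (an assumption for illustration, not part of the commit; it presumes the `ruff` CLI is installed and targets the `examples/` path touched by the hunks below):

# Hypothetical helper, not from this commit: re-check the examples/ tree with
# ruff's formatter after yapf has been removed. Requires `pip install ruff`.
import subprocess
import sys


def check_ruff_format(path: str = "examples") -> int:
    # "ruff format --check --diff" prints the would-be changes and returns a
    # non-zero exit code if any file is not already ruff-formatted.
    return subprocess.run(["ruff", "format", "--check", "--diff", path]).returncode


if __name__ == "__main__":
    sys.exit(check_ruff_format())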
Too many changes to show. To preserve performance only 426 of 426+ files are displayed.

examples/bitnet-1.58b/vllm_workspace/utils.py  +6 -17
examples/blocksparse_attention/block_sparse_attn_triton.py  +24 -47
examples/blocksparse_attention/example_tilelang_block_sparse_attn.py  +37 -44
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py  +101 -139
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py  +96 -144
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py  +93 -139
examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py  +53 -100
examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py  +52 -95
examples/blocksparse_attention/heuristic.py  +1 -2
examples/blocksparse_attention/test_example_blocksparse_attention.py  +4 -16
examples/blocksparse_gemm/example_blocksparse_gemm.py  +30 -39
examples/cast/example_group_per_split_token_cast_to_fp8.py  +25 -26
examples/cast/example_per_token_cast_to_fp8.py  +11 -16
examples/cast/example_triton_cast_to_fp8.py  +1 -3
examples/cast/test_example_cast.py  +1 -2
examples/compile_flags/usecase.py  +4 -6
examples/conftest.py  +2 -5
examples/convolution/example_convolution.py  +23 -41
examples/convolution/example_convolution_autotune.py  +46 -89
examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py  +13 -16
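Every hunk below is mechanical: ruff format joins yapf's hanging-indent continuations where the line fits, uses its own trailing-comma wrapping where it does not, normalizes string literals to double quotes, and spaces out operators such as `**` and slice colons in compound expressions. As a small, self-contained illustration of the two styles (distilled from the hunks that follow; this snippet itself is not part of the commit, and it uses device="cpu" so it runs without a GPU):

import torch

BATCH, HEADS, N = 1, 2, 4

# yapf-era style, as on the removed (-) lines below: hanging indent, single quotes.
x_old_style = torch.full([BATCH, HEADS, N, N],
                         False,
                         dtype=torch.bool,
                         device='cpu')

# ruff format style, as on the added (+) lines below: one line while it fits, double quotes.
x_new_style = torch.full([BATCH, HEADS, N, N], False, dtype=torch.bool, device="cpu")

assert torch.equal(x_old_style, x_new_style)  # formatting only; behavior unchanged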
examples/bitnet-1.58b/vllm_workspace/utils.py

@@ -3,8 +3,7 @@ from typing import Dict, List, Tuple
 TokensText = Tuple[List[int], str]

-def check_outputs_equal(outputs_0_lst: List[TokensText],
-                        outputs_1_lst: List[TokensText], name_0: str,
-                        name_1: str):
+def check_outputs_equal(outputs_0_lst: List[TokensText], outputs_1_lst: List[TokensText], name_0: str, name_1: str):
     """
     Compare the two sequences generated by different models,
     which should be equal.
@@ -15,19 +14,14 @@ def check_outputs_equal(outputs_0_lst: List[TokensText], outputs_1_lst: List[Tok
         output_ids_0, output_str_0 = outputs_0
         output_ids_1, output_str_1 = outputs_1

-        assert output_str_0 == output_str_1, (f"Test{prompt_idx}:"
-                                              f"\n{name_0}:\t{output_str_0!r}"
-                                              f"\n{name_1}:\t{output_str_1!r}")
-        assert output_ids_0 == output_ids_1, (f"Test{prompt_idx}:"
-                                              f"\n{name_0}:\t{output_str_0!r}"
-                                              f"\n{name_1}:\t{output_str_1!r}")
+        assert output_str_0 == output_str_1, f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}"
+        assert output_ids_0 == output_ids_1, f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}"

 TokensTextLogprobs = Tuple[List[int], str, List[Dict[int, float]]]

-def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs],
-                         outputs_1_lst: List[TokensTextLogprobs], name_0: str,
-                         name_1: str):
+def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs], outputs_1_lst: List[TokensTextLogprobs], name_0: str, name_1: str):
     """
     Compare the logprobs of two sequences generated by different models,
     which should be similar but not necessarily equal.
@@ -41,16 +35,11 @@ def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs],
         # Loop through generated tokens.
         for idx, (output_id_0, output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):
             # If generated tokens don't match, then
             if output_id_0 != output_id_1:
                 # Each predicted token must be in top N logprobs of the other
-                assert output_id_0 in logprobs_1[idx], (f"Test{prompt_idx}:"
-                                                        f"\n{name_0}:\t{output_str_0!r}"
-                                                        f"\n{name_1}:\t{output_str_1!r}")
-                assert output_id_1 in logprobs_0[idx], (f"Test{prompt_idx}:"
-                                                        f"\n{name_0}:\t{output_str_0!r}"
-                                                        f"\n{name_1}:\t{output_str_1!r}")
+                assert output_id_0 in logprobs_1[idx], f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}"
+                assert output_id_1 in logprobs_0[idx], f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}"
                 # Break out since sequences will now diverge.
                 break
examples/blocksparse_attention/block_sparse_attn_triton.py

@@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
     bsz, num_head, downsample_len, _ = x.shape
     # N_CTX = downsample_len * BLOCK
     sparse_index = torch.topk(x, topk, dim=-1).indices
-    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
-                            False,
-                            dtype=torch.bool,
-                            device=x.device)
+    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
     dense_mask.scatter_(-1, sparse_index, True)
     if use_dense_for_last_block:
         dense_mask[:, :, -2:, :] = True
@@ -56,7 +53,6 @@ def _fwd_kernel_inner(
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
-
     mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n)
     # print
@@ -73,8 +69,7 @@ def _fwd_kernel_inner(
         # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
         if LAST_K_BLOCK:
-            qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0,
-                           float('-inf'))
+            qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf"))

         m_ij = tl.maximum(m_i, tl.max(qk, 1))
         qk -= m_ij[:, None]
@@ -154,7 +149,7 @@ def _fwd_kernel(
     v_ptrs = V + off_v
     mask_ptrs = block_mask_ptr + start_m * stride_bmm

-    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
+    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
     l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
     acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
@@ -192,24 +187,12 @@ def _fwd_kernel(
     acc = acc * l_recip
     acc = acc.to(Out.dtype.element_ty)
-    off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[
-        None, :] * stride_od
+    off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od
     out_ptrs = Out + off_o
     tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX)

-def _forward(ctx,
-             q,
-             k,
-             v,
-             block_sparse_mask,
-             sm_scale,
-             BLOCK_M=64,
-             BLOCK_N=64,
-             num_warps=None,
-             num_stages=1,
-             out=None):
+def _forward(ctx, q, k, v, block_sparse_mask, sm_scale, BLOCK_M=64, BLOCK_N=64, num_warps=None, num_stages=1, out=None):
     assert q.shape[-1] == k.shape[-1] == v.shape[-1]
     assert k.shape[2] == v.shape[2]
     o = out if out is not None else torch.empty_like(q).contiguous()
@@ -254,7 +237,6 @@ def _forward(ctx,
 class _sparse_attention(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, q, k, v, block_sparse_dense, sm_scale):
         # shape constraints
@@ -278,9 +260,9 @@ def test_topk_sparse_attention():
     torch.manual_seed(0)

     # Create inputs
-    q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
-    k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
-    v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
+    q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
+    k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
+    v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
     sm_scale = 1.0 / (D_HEAD**0.5)

     # Create sparse mask (downsampled to block level)
@@ -288,9 +270,7 @@ def test_topk_sparse_attention():
     downsample_len = math.ceil(SEQ_LEN / downsample_factor)
     print("downsample_len", downsample_len)
-    x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
-                       device='cuda',
-                       dtype=torch.bfloat16)
+    x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16)
     x_ds[:, :, :, 0] = 100
     print("x_ds.shape", x_ds.shape)
     block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
@@ -302,22 +282,21 @@ def test_topk_sparse_attention():
     # Compute reference
     # Expand block mask to full attention matrix
-    full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda'))
+    full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda"))
     full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
     full_mask = full_mask & torch.tril(torch.ones_like(full_mask))  # Apply causal

     # PyTorch reference implementation
-    attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
-    attn = attn.masked_fill(~full_mask, float('-inf'))
+    attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
+    attn = attn.masked_fill(~full_mask, float("-inf"))
     attn = F.softmax(attn, dim=-1)
-    ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)
+    ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)
     # print("ref_output", ref_output)
     # print("triton_output", triton_output)
     # Verify accuracy
-    assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \
-        "Triton output doesn't match reference"
+    assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference"
     print("Pass topk sparse attention test with qlen == klen")
@@ -329,9 +308,9 @@ def test_topk_sparse_attention_qlt_kl():
     torch.manual_seed(0)

     # Create inputs.
-    q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
-    k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
-    v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
+    q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
+    k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
+    v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
     # softmax scale
     sm_scale = 1.0 / (D_HEAD**0.5)
@@ -339,8 +318,7 @@ def test_topk_sparse_attention_qlt_kl():
     print("downsample_factor", downsample_factor)
     downsample_len = math.ceil(K_LEN / downsample_factor)  # number of blocks along one dimension
     print("downsample_len", downsample_len)
-    x_ds = torch.randn(
-        BATCH, N_HEADS, downsample_len, downsample_len, device='cuda', dtype=torch.bfloat16)
+    x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, device="cuda", dtype=torch.bfloat16)
     # Force the first column to be high so that the first block is always selected.
     x_ds[:, :, :, 0] = 100
     block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
@@ -351,26 +329,25 @@ def test_topk_sparse_attention_qlt_kl():
     past_len = K_LEN - Q_LEN
-    attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
+    attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale

-    full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')).bool()
+    full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")).bool()
     full_mask_full = full_mask_full[..., :K_LEN, :K_LEN]
     effective_mask = full_mask_full[..., past_len:K_LEN, :]  # shape: (B, H, Q_LEN, K_LEN)

     i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1)  # shape: (Q_LEN, 1)
     j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0)  # shape: (1, K_LEN)
-    causal_mask = (j_global <= i_global)  # shape: (Q_LEN, K_LEN)
+    causal_mask = j_global <= i_global  # shape: (Q_LEN, K_LEN)

     final_mask = effective_mask & causal_mask  # shape: (B, H, Q_LEN, K_LEN)
-    attn = attn.masked_fill(~final_mask, float('-inf'))
+    attn = attn.masked_fill(~final_mask, float("-inf"))
     attn = F.softmax(attn, dim=-1)
-    ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)
+    ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)
     # Verify accuracy.
-    assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \
-        "Triton output doesn't match reference when qlen < klen"
+    assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference when qlen < klen"
     print("Pass topk sparse attention test with qlen < klen")
examples/blocksparse_attention/example_tilelang_block_sparse_attn.py

@@ -10,10 +10,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
     bsz, num_head, downsample_len, _ = x.shape
     # N_CTX = downsample_len * BLOCK
     sparse_index = torch.topk(x, topk, dim=-1).indices
-    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
-                            False,
-                            dtype=torch.bool,
-                            device=x.device)
+    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
     dense_mask.scatter_(-1, sparse_index, True)
     if use_dense_for_last_block:
         dense_mask[:, :, -2:, :] = True
@@ -30,15 +27,17 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
 @tilelang.jit(
-    out_idx=[4], pass_configs={
+    out_idx=[4],
+    pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    })
+    },
+)
 def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal):
     block_M = 64
     block_N = 64
     num_stages = 1
     threads = 128
-    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
+    scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
     shape = [batch, heads, seq_len, dim]
     block_mask_shape = [batch, heads, downsample_len, downsample_len]
@@ -47,7 +46,6 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
         block_mask_dtype = "bool"

     def kernel_func(block_M, block_N, num_stages, threads):
-
         @T.macro
         def MMA0(
             K: T.Tensor(shape, dtype),
@@ -59,11 +57,10 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
             by: T.int32,
             bz: T.int32,
         ):
-            T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
+            T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
             if is_causal:
                 for i, j in T.Parallel(block_M, block_N):
-                    acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-                                                 -T.infinity(acc_s.dtype))
+                    acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
             else:
                 T.clear(acc_s)
             T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@@ -78,18 +75,18 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
             by: T.int32,
             bz: T.int32,
         ):
-            T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
+            T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
             T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

         @T.macro
         def Softmax(
-                acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
-                acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
-                scores_max: T.FragmentBuffer([block_M], accum_dtype),
-                scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
-                scores_scale: T.FragmentBuffer([block_M], accum_dtype),
-                scores_sum: T.FragmentBuffer([block_M], accum_dtype),
-                logsum: T.FragmentBuffer([block_M], accum_dtype),
+            acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
+            acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
+            scores_max: T.FragmentBuffer([block_M], accum_dtype),
+            scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
+            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
+            scores_sum: T.FragmentBuffer([block_M], accum_dtype),
+            logsum: T.FragmentBuffer([block_M], accum_dtype),
         ):
             T.copy(scores_max, scores_max_prev)
             T.fill(scores_max, -T.infinity(accum_dtype))
@@ -113,22 +110,21 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
         @T.macro
         def Rescale(
-                acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
-                scores_scale: T.FragmentBuffer([block_M], accum_dtype),
+            acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
+            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
         ):
             for i, j in T.Parallel(block_M, dim):
                 acc_o[i, j] *= scores_scale[i]

         @T.prim_func
         def blocksparse_flashattn(
-                Q: T.Tensor(shape, dtype),
-                K: T.Tensor(shape, dtype),
-                V: T.Tensor(shape, dtype),
-                BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype),
-                Output: T.Tensor(shape, dtype),
+            Q: T.Tensor(shape, dtype),
+            K: T.Tensor(shape, dtype),
+            V: T.Tensor(shape, dtype),
+            BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype),
+            Output: T.Tensor(shape, dtype),
         ):
-            with T.Kernel(
-                    T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
+            with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
                 Q_shared = T.alloc_shared([block_M, dim], dtype)
                 K_shared = T.alloc_shared([block_N, dim], dtype)
                 V_shared = T.alloc_shared([block_N, dim], dtype)
@@ -143,7 +139,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
                 logsum = T.alloc_fragment([block_M], accum_dtype)
                 block_mask = T.alloc_local([downsample_len], block_mask_dtype)

-                T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
+                T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
                 T.fill(acc_o, 0)
                 T.fill(logsum, 0)
                 T.fill(scores_max, -T.infinity(accum_dtype))
@@ -152,20 +148,19 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
                     block_mask[vj] = BlockSparseMask[bz, by, bx, vj]

                 loop_range = (
-                    T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
-                        (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))
+                    T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)
+                )
                 for k in T.Pipelined(loop_range, num_stages=num_stages):
                     if block_mask[k] != 0:
                         MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
-                        Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
-                                scores_sum, logsum)
+                        Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
                         Rescale(acc_o, scores_scale)
                         MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
                 for i, j in T.Parallel(block_M, dim):
                     acc_o[i, j] /= logsum[i]
                 T.copy(acc_o, O_shared)
-                T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
+                T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])

     return blocksparse_flashattn
@@ -180,18 +175,16 @@ def test_topk_sparse_attention():
     torch.manual_seed(0)

     # Create inputs
-    q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
-    k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
-    v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
+    q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
+    k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
+    v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
     sm_scale = 1.0 / (D_HEAD**0.5)

     # Create sparse mask (downsampled to block level)
     downsample_factor = BLOCK
     downsample_len = math.ceil(SEQ_LEN / downsample_factor)
-    x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
-                       device='cuda',
-                       dtype=torch.bfloat16)
+    x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16)
     x_ds[:, :, :, 0] = 100
     block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
@@ -202,15 +195,15 @@ def test_topk_sparse_attention():
     # Compute reference
     # Expand block mask to full attention matrix
-    full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda'))
+    full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda"))
     full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
     full_mask = full_mask & torch.tril(torch.ones_like(full_mask))  # Apply causal

     # PyTorch reference implementation
-    attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
-    attn = attn.masked_fill(~full_mask, float('-inf'))
+    attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
+    attn = attn.masked_fill(~full_mask, float("-inf"))
     attn = F.softmax(attn, dim=-1)
-    ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)
+    ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)
     print("ref_output", ref_output)
     print("tilelang_output", tilelang_output)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py

@@ -13,17 +13,20 @@ from heuristic import num_splits_heuristic
 def flashattn(batch, heads, heads_kv, dim, dim_v):
-    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
+    scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
     dtype = "float16"
     accum_dtype = "float"
     kv_group_num = heads // heads_kv

     @tilelang.jit(
-        out_idx=[-1], pass_configs={
+        out_idx=[-1],
+        pass_configs={
             tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-        })
-    def kernel_func(block_N, block_H, page_block_size, num_split, num_stages, threads, num_pages,
-                    max_num_blocks_per_seq, max_selected_blocks):
+        },
+    )
+    def kernel_func(
+        block_N, block_H, page_block_size, num_split, num_stages, threads, num_pages, max_num_blocks_per_seq, max_selected_blocks
+    ):
         shape_q = [batch, heads, dim]
         shape_k = [num_pages, page_block_size, heads_kv, dim]
         shape_v = [num_pages, page_block_size, heads_kv, dim_v]
@@ -37,17 +40,16 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
         @T.macro
         def flash_attn_split(
-                Q: T.Tensor(shape_q, dtype),
-                K: T.Tensor(shape_k, dtype),
-                V: T.Tensor(shape_v, dtype),
-                block_indices: T.Tensor(shape_indices, "int32"),
-                cache_seqlens: T.Tensor([batch], "int32"),
-                block_table: T.Tensor(shape_block_table, "int32"),
-                glse: T.Tensor([batch, heads, num_split], accum_dtype),
-                Output_partial: T.Tensor(part_shape, accum_dtype),
+            Q: T.Tensor(shape_q, dtype),
+            K: T.Tensor(shape_k, dtype),
+            V: T.Tensor(shape_v, dtype),
+            block_indices: T.Tensor(shape_indices, "int32"),
+            cache_seqlens: T.Tensor([batch], "int32"),
+            block_table: T.Tensor(shape_block_table, "int32"),
+            glse: T.Tensor([batch, heads, num_split], accum_dtype),
+            Output_partial: T.Tensor(part_shape, accum_dtype),
         ):
-            with T.Kernel(
-                    batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
+            with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
                 Q_shared = T.alloc_shared([block_H, dim], dtype)
                 K_shared = T.alloc_shared([block_N, dim], dtype)
                 V_shared = T.alloc_shared([block_N, dim_v], dtype)
@@ -67,7 +69,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
                 sid = bz
                 cur_kv_head = hid // (kv_group_num // valid_block_H)

-                T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared)
+                T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared)
                 T.fill(acc_o, 0)
                 T.fill(logsum, 0)
                 T.fill(scores_max, -T.infinity(accum_dtype))
@@ -75,7 +77,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
                 num_blocks = max_selected_blocks
                 blocks_per_split = T.floordiv(num_blocks, num_split)
                 remaining_blocks = T.floormod(num_blocks, num_split)
-                loop_range = (blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0))
+                loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0)
                 start = blocks_per_split * sid + T.min(sid, remaining_blocks)
                 has_valid_block = False
                 for k in T.Pipelined(loop_range, num_stages=num_stages):
@@ -85,29 +87,20 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
                         block_table_idx = T.floordiv(logical_block_idx, block_ratio)
                         block_tile_idx = T.floormod(logical_block_idx, block_ratio)
                         physical_block_idx = block_table[bid, block_table_idx]
-                        T.copy(
-                            K[physical_block_idx,
-                              block_tile_idx * block_N:(block_tile_idx + 1) * block_N, cur_kv_head, :],
-                            K_shared)
+                        T.copy(K[physical_block_idx, block_tile_idx * block_N : (block_tile_idx + 1) * block_N, cur_kv_head, :], K_shared)
                         T.clear(acc_s)
-                        T.gemm(
-                            Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
+                        T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                         if k == 0:  # assume block_indices is sorted in reverse order, otherwise, remove this if condition
                             for i, j in T.Parallel(block_H, block_N):
                                 acc_s[i, j] = T.if_then_else(
-                                    logical_block_idx * block_N + j >= cache_seqlens[bid],
-                                    -T.infinity(accum_dtype), acc_s[i, j])
+                                    logical_block_idx * block_N + j >= cache_seqlens[bid], -T.infinity(accum_dtype), acc_s[i, j]
+                                )
                         T.copy(scores_max, scores_max_prev)
                         T.fill(scores_max, -T.infinity(accum_dtype))
                         T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                         for i in T.Parallel(block_H):
                             scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
-                            scores_scale[i] = T.exp2(scores_max_prev[i] * scale -
-                                                     scores_max[i] * scale)
+                            scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                         for i, j in T.Parallel(block_H, block_N):
                             acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                         T.reduce_sum(acc_s, scores_sum, dim=1)
@@ -116,10 +109,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
                         T.copy(acc_s, acc_s_cast)
                         for i, j in T.Parallel(block_H, dim_v):
                             acc_o[i, j] *= scores_scale[i]
-                        T.copy(
-                            V[physical_block_idx,
-                              block_tile_idx * block_N:(block_tile_idx + 1) * block_N, cur_kv_head, :],
-                            V_shared)
+                        T.copy(V[physical_block_idx, block_tile_idx * block_N : (block_tile_idx + 1) * block_N, cur_kv_head, :], V_shared)
                         T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
                 if has_valid_block:
                     for i, j in T.Parallel(block_H, dim_v):
@@ -138,9 +128,9 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
         @T.macro
         def combine(
-                glse: T.Tensor([batch, heads, num_split], accum_dtype),
-                Output_partial: T.Tensor(part_shape, accum_dtype),
-                Output: T.Tensor(shape_o, dtype),
+            glse: T.Tensor([batch, heads, num_split], accum_dtype),
+            Output_partial: T.Tensor(part_shape, accum_dtype),
+            Output: T.Tensor(shape_o, dtype),
         ):
             with T.Kernel(heads, batch, threads=128) as (by, bz):
                 po_local = T.alloc_fragment([dim_v], accum_dtype)
@@ -151,17 +141,18 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
                 scale_local = T.alloc_local([1], accum_dtype)
                 max_split = T.alloc_local([1], "int32")

-                T.annotate_layout({
-                    lse_logsum_local:
-                        T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
-                })
+                T.annotate_layout(
+                    {
+                        lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
+                    }
+                )
                 T.clear(lse_logsum_local)
                 T.clear(o_accum_local)
                 lse_max_local[0] = -T.infinity(accum_dtype)
                 for k in T.serial(num_split):
                     lse_local_split[0] = glse[bz, by, k]
-                    if (lse_local_split[0] != 0):
+                    if lse_local_split[0] != 0:
                         max_split[0] = k
                         lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k])
@@ -183,18 +174,17 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
         @T.prim_func
         def main(
-                Q: T.Tensor(shape_q, dtype),
-                K: T.Tensor(shape_k, dtype),
-                V: T.Tensor(shape_v, dtype),
-                block_indices: T.Tensor(shape_indices, "int32"),
-                cache_seqlens: T.Tensor([batch], "int32"),
-                block_table: T.Tensor(shape_block_table, "int32"),
-                glse: T.Tensor([batch, heads, num_split], accum_dtype),
-                Output_partial: T.Tensor(part_shape, accum_dtype),
-                Output: T.Tensor(shape_o, dtype),
+            Q: T.Tensor(shape_q, dtype),
+            K: T.Tensor(shape_k, dtype),
+            V: T.Tensor(shape_v, dtype),
+            block_indices: T.Tensor(shape_indices, "int32"),
+            cache_seqlens: T.Tensor([batch], "int32"),
+            block_table: T.Tensor(shape_block_table, "int32"),
+            glse: T.Tensor([batch, heads, num_split], accum_dtype),
+            Output_partial: T.Tensor(part_shape, accum_dtype),
+            Output: T.Tensor(shape_o, dtype),
         ):
-            flash_attn_split(Q, K, V, block_indices, cache_seqlens, block_table, glse,
-                             Output_partial)
+            flash_attn_split(Q, K, V, block_indices, cache_seqlens, block_table, glse, Output_partial)
             combine(glse, Output_partial, Output)

         return main
@@ -203,7 +193,6 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
 class SparseFlashAttn(torch.nn.Module):
-
     def __init__(self, batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, num_pages):
         super(SparseFlashAttn, self).__init__()
         self.batch = batch
@@ -249,18 +238,11 @@ class SparseFlashAttn(torch.nn.Module):
         num_sm = self.num_sm
         num_split = num_splits_heuristic(
-            total_mblocks,
-            num_sm,
-            num_n_blocks,
-            num_m_blocks,
-            size_one_kv_head,
-            is_causal_or_local=True,
-            max_splits=128)
-        glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda')
-        output_partial = torch.empty((batch, heads, num_split, dim_v),
-                                     dtype=torch.float32,
-                                     device='cuda')
+            total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128
+        )
+        glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda")
+        output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda")
         output = self.kernel(
             query,
@@ -275,14 +257,13 @@ class SparseFlashAttn(torch.nn.Module):
         return output

-def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_seqlens,
-                            block_table, page_block_size, block_size):
+def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_seqlens, block_table, page_block_size, block_size):
     """
     Paged version of sparse attention reference implementation.
     Args:
         query: [batch, heads, dim]
-        key_cache: [num_pages, page_block_size, heads_kv, dim]
+        key_cache: [num_pages, page_block_size, heads_kv, dim]
         value_cache: [num_pages, page_block_size, heads_kv, dim]
         block_indices: [batch, heads_kv, max_selected_blocks] - logical block indices
         cache_seqlens: [batch] - actual sequence lengths
@@ -298,12 +279,8 @@ def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_
     # Reconstruct the full key and value tensors from paged cache
     max_cache_seqlen = max(cache_seqlens).item()
-    key_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim),
-                           dtype=key_cache.dtype,
-                           device=key_cache.device)
-    value_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim_v),
-                             dtype=value_cache.dtype,
-                             device=value_cache.device)
+    key_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim), dtype=key_cache.dtype, device=key_cache.device)
+    value_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim_v), dtype=value_cache.dtype, device=value_cache.device)
     # Reconstruct full tensors from paged cache using block_table
     for b in range(batch):
@@ -319,20 +296,14 @@ def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_
             actual_block_size = end_token - start_token
             # Copy from paged cache to full tensors
-            key_full[b, :, start_token:end_token, :] = key_cache[
-                physical_block_idx, :actual_block_size, :, :].transpose(0, 1)
-            value_full[b, :, start_token:end_token, :] = value_cache[
-                physical_block_idx, :actual_block_size, :, :].transpose(0, 1)
+            key_full[b, :, start_token:end_token, :] = key_cache[physical_block_idx, :actual_block_size, :, :].transpose(0, 1)
+            value_full[b, :, start_token:end_token, :] = value_cache[physical_block_idx, :actual_block_size, :, :].transpose(0, 1)
     # Reshape query for grouped attention
-    query = rearrange(query, 'b (h g) d -> b g h d',
-                      g=num_head_groups)  # [batch_size, num_head_groups, heads_kv, dim]
+    query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups)  # [batch_size, num_head_groups, heads_kv, dim]
     # Compute attention scores
-    scores = einsum(query, key_full,
-                    'b g h d, b h s d -> b g h s')  # [batch_size, num_head_groups, heads_kv, seqlen_kv]
+    scores = einsum(query, key_full, "b g h d, b h s d -> b g h s")  # [batch_size, num_head_groups, heads_kv, seqlen_kv]
     # Create sparse mask based on block_indices
     sparse_mask = torch.zeros_like(scores)
@@ -348,24 +319,23 @@ def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_
                 sparse_mask[b, :, h, start_pos:end_pos] = 1
     # Apply sparse mask
-    scores = scores.masked_fill(sparse_mask == 0, float('-inf'))
+    scores = scores.masked_fill(sparse_mask == 0, float("-inf"))
     # Apply causal mask based on actual sequence lengths
     range_len = torch.arange(scores.shape[-1], device=scores.device).unsqueeze(0)
     cache_seqlens_expanded = cache_seqlens.unsqueeze(1)
     pad_mask = range_len >= cache_seqlens_expanded
     pad_mask = pad_mask[:, None, None, :]
-    scores = scores.masked_fill(pad_mask, float('-inf'))
+    scores = scores.masked_fill(pad_mask, float("-inf"))
     # Compute attention weights
     attention = F.softmax(scores / scale, dim=-1)
     # Apply attention to values
-    out = einsum(attention, value_full,
-                 'b g h s, b h s d -> b g h d')  # [batch_size, num_head_groups, heads_kv, dim]
+    out = einsum(attention, value_full, "b g h s, b h s d -> b g h d")  # [batch_size, num_head_groups, heads_kv, dim]
     # Reshape output back to original format
-    out = rearrange(out, 'b g h d -> b (h g) d')  # [batch_size, heads, dim]
+    out = rearrange(out, "b g h d -> b (h g) d")  # [batch_size, heads, dim]
     return out
@@ -373,17 +343,23 @@ def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_
 def ref_program_fa(query, kcache, vcache, cache_seqlens, block_table):
     # latency reference
     # from flash_attn_interface import flash_attn_with_kvcache # fa3
-    from flash_attn import flash_attn_with_kvcache  #fa2
+    from flash_attn import flash_attn_with_kvcache  # fa2
     query = query.unsqueeze(1)
-    output = flash_attn_with_kvcache(
-        query, kcache, vcache, cache_seqlens=cache_seqlens, block_table=block_table)
+    output = flash_attn_with_kvcache(query, kcache, vcache, cache_seqlens=cache_seqlens, block_table=block_table)
     output = output.squeeze(1)
     return output

 def main(args):
-    batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v
+    batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = (
+        args.batch,
+        args.heads,
+        args.heads_kv,
+        args.max_cache_seqlen,
+        args.dim,
+        args.dim_v,
+    )
     sparse_ratio = args.sparse_ratio
     block_N = args.block_N
     page_block_size = args.page_block_size
@@ -395,35 +371,30 @@ def main(args):
     dtype = torch.float16
     # Generate random inputs
-    Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda')
-    cache_seqlens = torch.randint(
-        max_cache_seqlen // 2, max_cache_seqlen + 1, (batch,), dtype=torch.int32, device='cuda')
+    Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda")
+    cache_seqlens = torch.randint(max_cache_seqlen // 2, max_cache_seqlen + 1, (batch,), dtype=torch.int32, device="cuda")
     print("cache_seqlens: ", cache_seqlens)
-    K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda')
-    V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda')
+    K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda")
+    V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda")
     # Create paged KV cache
-    K_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim), dtype=dtype, device='cuda')
-    V_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim_v),
-                          dtype=dtype,
-                          device='cuda')
+    K_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim), dtype=dtype, device="cuda")
+    V_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim_v), dtype=dtype, device="cuda")
     # Create block table and block indices for dense case (all blocks selected)
     max_num_blocks_per_seq = int(math.ceil(max_cache_seqlen / page_block_size))
     print("max_num_blocks_per_seq: ", max_num_blocks_per_seq)
-    block_table = torch.zeros((batch, max_num_blocks_per_seq), dtype=torch.int32, device='cuda')
-    block_indices = torch.zeros((batch, heads_kv, max_selected_blocks),
-                                dtype=torch.int32,
-                                device='cuda')
+    block_table = torch.zeros((batch, max_num_blocks_per_seq), dtype=torch.int32, device="cuda")
+    block_indices = torch.zeros((batch, heads_kv, max_selected_blocks), dtype=torch.int32, device="cuda")
     # Fill block table and block indices and cache
     # Create a pool of available physical blocks
-    total_blocks_needed = sum(
-        int(math.ceil(cache_seqlens[seq_idx].item() / page_block_size)) for seq_idx in range(batch))
+    total_blocks_needed = sum(int(math.ceil(cache_seqlens[seq_idx].item() / page_block_size)) for seq_idx in range(batch))
     available_blocks = list(range(total_blocks_needed))
     import random
     random.seed(42)  # For reproducibility
     random.shuffle(available_blocks)
@@ -458,10 +429,8 @@ def main(args):
             actual_block_size = end_token - start_token
             # Copy K and V data to the paged cache
-            K_cache[physical_block_idx, :actual_block_size, :, :] = K[seq_idx,
-                                                                      start_token:end_token, :, :]
-            V_cache[physical_block_idx, :actual_block_size, :, :] = V[seq_idx,
-                                                                      start_token:end_token, :, :]
+            K_cache[physical_block_idx, :actual_block_size, :, :] = K[seq_idx, start_token:end_token, :, :]
+            V_cache[physical_block_idx, :actual_block_size, :, :] = V[seq_idx, start_token:end_token, :, :]
     # Fill block_indices for sparse attention
     # For dense case (verification), we select all blocks in reverse order
@@ -496,10 +465,9 @@ def main(args):
             remaining_blocks = [b for b in all_blocks if b not in selected_blocks]
             if remaining_blocks:
                 import random
                 random.seed(42)  # For reproducibility
-                additional_blocks = random.sample(remaining_blocks,
-                                                  min(num_selected - recent_blocks, len(remaining_blocks)))
+                additional_blocks = random.sample(remaining_blocks, min(num_selected - recent_blocks, len(remaining_blocks)))
                 selected_blocks.extend(additional_blocks)
             # Sort selected blocks in reverse order (most recent first)
@@ -512,25 +480,20 @@ def main(args):
                 block_indices[seq_idx, head_idx, i] = -1
     # Initialize sparse attention module
-    sparse_attn = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, page_block_size, block_N,
-                                  num_blocks)
-    output_sparse = sparse_attn.forward(Q, K_cache, V_cache, block_indices, cache_seqlens,
-                                        block_table)
+    sparse_attn = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, num_blocks)
+    output_sparse = sparse_attn.forward(Q, K_cache, V_cache, block_indices, cache_seqlens, block_table)
     import flash_attn  # noqa: F401
-    output_ref_torch = ref_program_torch_paged(Q, K_cache, V_cache, block_indices, cache_seqlens,
-                                               block_table, page_block_size, block_N)
+    output_ref_torch = ref_program_torch_paged(Q, K_cache, V_cache, block_indices, cache_seqlens, block_table, page_block_size, block_N)
     output_ref_fa = ref_program_fa(Q, K_cache, V_cache, cache_seqlens, block_table)
     # Check correctness
     if sparse_ratio == 0.0:
         max_diff = torch.max(torch.abs(output_sparse - output_ref_fa)).item()
         mean_diff = torch.mean(torch.abs(output_sparse - output_ref_fa)).item()
-        assert torch.allclose(
-            output_ref_fa, output_ref_torch, atol=1e-2), "Reference outputs do not match!"
+        assert torch.allclose(output_ref_fa, output_ref_torch, atol=1e-2), "Reference outputs do not match!"
     else:
         max_diff = torch.max(torch.abs(output_sparse - output_ref_torch)).item()
         mean_diff = torch.mean(torch.abs(output_sparse - output_ref_torch)).item()
@@ -574,16 +537,15 @@ def main(args):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=8, help='batch size')
-    parser.add_argument('--heads', type=int, default=32, help='heads')
-    parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv')
-    parser.add_argument('--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length')
-    parser.add_argument('--dim', type=int, default=128, help='dim')
-    parser.add_argument('--dim_v', type=int, default=128, help='dim_v')
-    parser.add_argument('--sparse_ratio', type=float, default=0.0, help='sparse ratio')
-    parser.add_argument('--block_N', type=int, default=64, help='block_N')
-    parser.add_argument('--page_block_size', type=int, default=256, help='block size of pages')
-    parser.add_argument('--num_pages', type=int, default=1024, help='total number of pages')
+    parser.add_argument("--batch", type=int, default=8, help="batch size")
+    parser.add_argument("--heads", type=int, default=32, help="heads")
+    parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv")
+    parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length")
+    parser.add_argument("--dim", type=int, default=128, help="dim")
+    parser.add_argument("--dim_v", type=int, default=128, help="dim_v")
+    parser.add_argument("--sparse_ratio", type=float, default=0.0, help="sparse ratio")
+    parser.add_argument("--block_N", type=int, default=64, help="block_N")
+    parser.add_argument("--page_block_size", type=int, default=256, help="block size of pages")
+    parser.add_argument("--num_pages", type=int, default=1024, help="total number of pages")
     args = parser.parse_args()
     main(args)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py
View file @
29051439
...
...
@@ -10,17 +10,18 @@ from heuristic import num_splits_heuristic
def
flashattn
(
batch
,
heads
,
heads_kv
,
dim
,
dim_v
):
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
dtype
=
"float16"
accum_dtype
=
"float"
kv_group_num
=
heads
//
heads_kv
@
tilelang
.
jit
(
out_idx
=
[
-
1
],
pass_configs
=
{
out_idx
=
[
-
1
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
}
)
def
kernel_func
(
block_N
,
block_H
,
num_split
,
num_stages
,
threads
,
max_cache_seqlen
,
max_selected_blocks
):
}
,
)
def
kernel_func
(
block_N
,
block_H
,
num_split
,
num_stages
,
threads
,
max_cache_seqlen
,
max_selected_blocks
):
shape_q
=
[
batch
,
heads
,
dim
]
shape_k
=
[
batch
,
max_cache_seqlen
,
heads_kv
,
dim
]
shape_v
=
[
batch
,
max_cache_seqlen
,
heads_kv
,
dim_v
]
...
...
@@ -31,17 +32,16 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
@
T
.
macro
def
flash_attn_split
(
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
block_indices
:
T
.
Tensor
(
shape_indices
,
"int32"
),
cache_seqlens
:
T
.
Tensor
([
batch
],
"int32"
),
# actual_num_blocks: T.Tensor([batch], "int32"),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
accum_dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
accum_dtype
),
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
block_indices
:
T
.
Tensor
(
shape_indices
,
"int32"
),
cache_seqlens
:
T
.
Tensor
([
batch
],
"int32"
),
# actual_num_blocks: T.Tensor([batch], "int32"),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
accum_dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
accum_dtype
),
):
with
T
.
Kernel
(
batch
,
heads
//
valid_block_H
,
num_split
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
batch
,
heads
//
valid_block_H
,
num_split
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
Q_shared
=
T
.
alloc_shared
([
block_H
,
dim
],
dtype
)
K_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
V_shared
=
T
.
alloc_shared
([
block_N
,
dim_v
],
dtype
)
...
...
@@ -62,7 +62,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
sid
=
bz
cur_kv_head
=
hid
//
(
kv_group_num
//
valid_block_H
)
T
.
copy
(
Q
[
bid
,
hid
*
valid_block_H
:
hid
*
valid_block_H
+
block_H
,
:],
Q_shared
)
T
.
copy
(
Q
[
bid
,
hid
*
valid_block_H
:
hid
*
valid_block_H
+
block_H
,
:],
Q_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
...
...
@@ -70,7 +70,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
num_blocks
=
max_selected_blocks
blocks_per_split
=
T
.
floordiv
(
num_blocks
,
num_split
)
remaining_blocks
=
T
.
floormod
(
num_blocks
,
num_split
)
loop_range
=
(
blocks_per_split
+
T
.
if_then_else
(
sid
<
remaining_blocks
,
1
,
0
)
)
loop_range
=
blocks_per_split
+
T
.
if_then_else
(
sid
<
remaining_blocks
,
1
,
0
)
start
=
blocks_per_split
*
sid
+
T
.
min
(
sid
,
remaining_blocks
)
has_valid_block
=
False
...
...
@@ -78,26 +78,18 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
i_s
=
block_indices
[
bid
,
cur_kv_head
,
start
+
k
]
if
i_s
>=
0
:
has_valid_block
=
True
T
.
copy
(
K
[
bid
,
i_s
*
block_N
:
(
i_s
+
1
)
*
block_N
,
cur_kv_head
,
:],
K_shared
)
T
.
copy
(
K
[
bid
,
i_s
*
block_N
:
(
i_s
+
1
)
*
block_N
,
cur_kv_head
,
:],
K_shared
)
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
if
k
==
0
:
# assume block_indices is sorted in reverse order, otherwise, remove this if condition
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
i_s
*
block_N
+
j
>=
cache_seqlens
[
bid
],
-
T
.
infinity
(
accum_dtype
),
acc_s
[
i
,
j
])
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
i_s
*
block_N
+
j
>=
cache_seqlens
[
bid
],
-
T
.
infinity
(
accum_dtype
),
acc_s
[
i
,
j
])
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
for
i
in
T
.
Parallel
(
block_H
):
scores_max
[
i
]
=
T
.
max
(
scores_max
[
i
],
scores_max_prev
[
i
])
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
exp2
(
acc_s
[
i
,
j
]
*
scale
-
scores_max
[
i
]
*
scale
)
T
.
reduce_sum
(
acc_s
,
scores_sum
,
dim
=
1
)
...
...
@@ -106,7 +98,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
T
.
copy
(
acc_s
,
acc_s_cast
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim_v
):
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
T
.
copy
(
V
[
bid
,
i_s
*
block_N
:
(
i_s
+
1
)
*
block_N
,
cur_kv_head
,
:],
V_shared
)
T
.
copy
(
V
[
bid
,
i_s
*
block_N
:
(
i_s
+
1
)
*
block_N
,
cur_kv_head
,
:],
V_shared
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
if
has_valid_block
:
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim_v
):
...
...
@@ -125,9 +117,9 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
@
T
.
macro
def
combine
(
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
accum_dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
accum_dtype
),
Output
:
T
.
Tensor
(
shape_o
,
dtype
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
accum_dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
accum_dtype
),
Output
:
T
.
Tensor
(
shape_o
,
dtype
),
):
with
T
.
Kernel
(
heads
,
batch
,
threads
=
128
)
as
(
by
,
bz
):
po_local
=
T
.
alloc_fragment
([
dim_v
],
accum_dtype
)
...
...
@@ -138,17 +130,18 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
scale_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
max_split
=
T
.
alloc_local
([
1
],
"int32"
)
T
.
annotate_layout
({
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
})
T
.
annotate_layout
(
{
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
}
)
T
.
clear
(
lse_logsum_local
)
T
.
clear
(
o_accum_local
)
lse_max_local
[
0
]
=
-
T
.
infinity
(
accum_dtype
)
for
k
in
T
.
serial
(
num_split
):
lse_local_split
[
0
]
=
glse
[
bz
,
by
,
k
]
if
(
lse_local_split
[
0
]
!=
0
)
:
if
lse_local_split
[
0
]
!=
0
:
max_split
[
0
]
=
k
lse_max_local
[
0
]
=
T
.
max
(
lse_max_local
[
0
],
glse
[
bz
,
by
,
k
])
...
...
@@ -170,15 +163,15 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
@
T
.
prim_func
def
main
(
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
block_indices
:
T
.
Tensor
(
shape_indices
,
"int32"
),
cache_seqlens
:
T
.
Tensor
([
batch
],
"int32"
),
# actual_num_blocks: T.Tensor([batch], "int32"),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
accum_dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
accum_dtype
),
Output
:
T
.
Tensor
(
shape_o
,
dtype
),
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
block_indices
:
T
.
Tensor
(
shape_indices
,
"int32"
),
cache_seqlens
:
T
.
Tensor
([
batch
],
"int32"
),
# actual_num_blocks: T.Tensor([batch], "int32"),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
accum_dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
accum_dtype
),
Output
:
T
.
Tensor
(
shape_o
,
dtype
),
):
# flash_attn_split(Q, K, V, block_indices, cache_seqlens, actual_num_blocks, glse, Output_partial)
flash_attn_split
(
Q
,
K
,
V
,
block_indices
,
cache_seqlens
,
glse
,
Output_partial
)
...
...
@@ -190,7 +183,6 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
class
SparseFlashAttn
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
batch
,
heads
,
heads_kv
,
dim
,
dim_v
,
block_size
):
super
(
SparseFlashAttn
,
self
).
__init__
()
self
.
batch
=
batch
...
...
@@ -209,7 +201,8 @@ class SparseFlashAttn(torch.nn.Module):
num_stages
=
2
,
threads
=
128
,
max_cache_seqlen
=
T
.
dynamic
(
"max_cache_seqlen"
),
max_selected_blocks
=
T
.
dynamic
(
"max_selected_blocks"
))
max_selected_blocks
=
T
.
dynamic
(
"max_selected_blocks"
),
)
props
=
torch
.
cuda
.
get_device_properties
(
torch
.
device
(
"cuda:0"
))
self
.
num_sm
=
props
.
multi_processor_count
...
...
@@ -232,25 +225,17 @@ class SparseFlashAttn(torch.nn.Module):
num_sm
=
self
.
num_sm
num_split
=
num_splits_heuristic
(
total_mblocks
,
num_sm
,
num_n_blocks
,
num_m_blocks
,
size_one_kv_head
,
is_causal_or_local
=
True
,
max_splits
=
128
)
glse
=
torch
.
empty
((
batch
,
heads
,
num_split
),
dtype
=
torch
.
float32
,
device
=
'cuda'
)
output_partial
=
torch
.
empty
((
batch
,
heads
,
num_split
,
dim_v
),
dtype
=
torch
.
float32
,
device
=
'cuda'
)
total_mblocks
,
num_sm
,
num_n_blocks
,
num_m_blocks
,
size_one_kv_head
,
is_causal_or_local
=
True
,
max_splits
=
128
)
glse
=
torch
.
empty
((
batch
,
heads
,
num_split
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
output_partial
=
torch
.
empty
((
batch
,
heads
,
num_split
,
dim_v
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
output
=
self
.
kernel
(
query
,
key
,
value
,
block_indices
,
cache_seqlens
,
glse
,
output_partial
)
return
output
def
sparse_gqa_decode_varlen_indice
(
query
,
key
,
value
,
block_indices
,
cache_seqlens
,
max_cache_seqlen
,
block_size
):
def
sparse_gqa_decode_varlen_indice
(
query
,
key
,
value
,
block_indices
,
cache_seqlens
,
max_cache_seqlen
,
block_size
):
"""
Args:
query: [batch, heads, dim]
...
...
@@ -272,31 +257,24 @@ def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seql
    block_H = 64
    actual_num_blocks = torch.sum(block_indices != -1, dim=-1).to(torch.int32)
-   actual_num_blocks = actual_num_blocks[:, 0]  #[batch], number of valid blocks, assume all groups in the same batch have the same number of blocks
+   actual_num_blocks = actual_num_blocks[
+       :, 0
+   ]  # [batch], number of valid blocks, assume all groups in the same batch have the same number of blocks
    # get num_split
    num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
-   num_n_blocks = max_selected_blocks  #(kv_seqlen + block_size - 1 ) // block_size
+   num_n_blocks = max_selected_blocks  # (kv_seqlen + block_size - 1 ) // block_size
    # num_n_blocks = torch.sum(actual_num_blocks, dim=-1).item() * heads_kv # total number of blocks
-   size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2  #kv_seqlen * (dim + dim_v) * 2
+   size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2  # kv_seqlen * (dim + dim_v) * 2
    total_mblocks = batch * heads_kv * num_m_blocks
    num_sm = 132
    num_split = num_splits_heuristic(
-       total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128)
-   glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda')
-   Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device='cuda')
+       total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128
+   )
+   glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda")
+   Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda")
    kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
        block_N=block_size,
        block_H=block_H,
...
...
@@ -304,29 +282,24 @@ def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seql
        num_stages=2,
        threads=128,
        max_cache_seqlen=T.dynamic("max_cache_seqlen"),
-       max_selected_blocks=T.dynamic("max_selected_blocks"))
+       max_selected_blocks=T.dynamic("max_selected_blocks"),
+   )
    output = kernel(query, key, value, block_indices, cache_seqlens, glse, Output_partial)
    return output


def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size):
    batch, heads, dim = query.shape
    heads_kv = key.shape[2]
    num_head_groups = query.shape[1] // key.shape[2]
    scale = dim**0.5
-   key = rearrange(key, 'b n h d -> b h n d')  # [batch_size, heads_kv, seqlen_kv, dim]
-   value = rearrange(value, 'b n h d -> b h n d')  # [batch_size, heads_kv, seqlen_kv, dim]
+   key = rearrange(key, "b n h d -> b h n d")  # [batch_size, heads_kv, seqlen_kv, dim]
+   value = rearrange(value, "b n h d -> b h n d")  # [batch_size, heads_kv, seqlen_kv, dim]
-   query = rearrange(query, 'b (h g) d -> b g h d', g=num_head_groups)  # [batch_size, num_head_groups, heads_kv, dim]
+   query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups)  # [batch_size, num_head_groups, heads_kv, dim]
-   scores = einsum(query, key, 'b g h d, b h s d -> b g h s')  # [batch_size, num_head_groups, heads_kv, seqlen_kv]
+   scores = einsum(query, key, "b g h d, b h s d -> b g h s")  # [batch_size, num_head_groups, heads_kv, seqlen_kv]
    sparse_mask = torch.zeros_like(scores)
    # Assign mask values based on block_indices
...
...
@@ -335,28 +308,26 @@ def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache
            valid_indices = block_indices[b, h]  # Extract indices for this batch and head
            for idx in valid_indices:
                if idx >= 0:
-                   sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1
+                   sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1
-   scores = scores.masked_fill(sparse_mask == 0, float('-inf'))
+   scores = scores.masked_fill(sparse_mask == 0, float("-inf"))
-   range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0)
+   range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0)
    cache_seqlens_expanded = cache_seqlens.unsqueeze(1)
    pad_mask = range_len >= cache_seqlens_expanded
    pad_mask = pad_mask[:, None, None, :]
-   scores = scores.masked_fill(pad_mask, float('-inf'))
-   attention = F.softmax(scores / scale, dim=-1)  # [batch_size, num_head_groups, heads_kv, seqlen_kv]
+   scores = scores.masked_fill(pad_mask, float("-inf"))
+   attention = F.softmax(scores / scale, dim=-1)  # [batch_size, num_head_groups, heads_kv, seqlen_kv]
-   out = einsum(attention, value, 'b g h s, b h s d -> b g h d')  # [batch_size, num_head_groups, heads_kv, dim]
-   out = rearrange(out, 'b g h d -> b (h g) d')  # [batch_size, heads, dim]
+   out = einsum(attention, value, "b g h s, b h s d -> b g h d")  # [batch_size, num_head_groups, heads_kv, dim]
+   out = rearrange(out, "b g h d -> b (h g) d")  # [batch_size, heads, dim]
    return out


def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size):
    # latency reference
    # from flash_attn_interface import flash_attn_with_kvcache # fa3
-   from flash_attn import flash_attn_with_kvcache  #fa2
+   from flash_attn import flash_attn_with_kvcache  # fa2
    query = query.unsqueeze(1)
    output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens)
    output = output.squeeze(1)
...
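The torch reference above leans on einops to broadcast grouped-query attention across the kv-head groups. A compact dense version of the same rearrange/einsum pattern (no block mask or padding), shown only as an illustration of the layout handling:

import torch
import torch.nn.functional as F
from einops import rearrange, einsum

def dense_gqa_reference(q, k, v):
    # q: [batch, heads, dim]; k, v: [batch, seqlen, heads_kv, dim]
    groups = q.shape[1] // k.shape[2]
    k = rearrange(k, "b n h d -> b h n d")
    v = rearrange(v, "b n h d -> b h n d")
    q = rearrange(q, "b (h g) d -> b g h d", g=groups)
    scores = einsum(q, k, "b g h d, b h s d -> b g h s") / q.shape[-1] ** 0.5
    attn = F.softmax(scores, dim=-1)
    out = einsum(attn, v, "b g h s, b h s d -> b g h d")
    return rearrange(out, "b g h d -> b (h g) d")  # back to [batch, heads, dim]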
...
@@ -368,23 +339,13 @@ def debug(name, expect, actual, atol=1e-3, rtol=1e-3):
    print(name + " all_close={}".format(all_close))
    if not all_close:
        diff = (expect - actual).abs()
        print("all_close={}, max={}, min={}, mean={}".format(all_close, diff.max().item(), diff.min().item(), diff.mean().item()))
        max_indices = torch.nonzero(diff == diff.max().item())
        first_index = tuple(max_indices[0].tolist())
        print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}")


def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32):
    batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v
    sparse_ratio = sparse_ratio
    block_size = block_size
...
...
@@ -392,10 +353,10 @@ def main(batch=8,
    print("max_selected_blocks: ", max_selected_blocks)
    dtype = torch.float16
-   Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda')
-   K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda')
-   V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda')
-   cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda')
+   Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda")
+   K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda")
+   V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda")
+   cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda")
# cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda')
# # Ensure at least one element equals cache_seqlen
# random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index
...
...
@@ -406,10 +367,7 @@ def main(batch=8,
    max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int()
    print("max_valid_num_blocks: ", max_valid_num_blocks)
    # Initialize block_indices with -1 (for padding blocks)
-   block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32, device='cuda')
+   block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32, device="cuda")
# max_num_blocks = int((max_cache_seqlen + block_size - 1)/ block_size)
# block_indices = torch.full((batch, heads_kv, max_num_blocks), -1, dtype=torch.int32, device='cuda')
...
...
@@ -418,10 +376,9 @@ def main(batch=8,
        max_valid_block = max_valid_num_blocks[b].item()  # Max valid blocks for this batch
        if max_valid_block > 0:  # Ensure there's at least one valid block
            for h in range(heads_kv):
-               valid_indices = torch.randperm(max_valid_block, device='cuda', dtype=torch.int32)[:max_selected_blocks]
+               valid_indices = torch.randperm(max_valid_block, device="cuda", dtype=torch.int32)[:max_selected_blocks]
                # valid_indices = torch.randperm(max_valid_block, device='cuda', dtype=torch.int32)[:max_num_blocks]
-               block_indices[b, h, :len(valid_indices)] = valid_indices
+               block_indices[b, h, : len(valid_indices)] = valid_indices
    # Sort indices within each batch-group for consistency
    block_indices, _ = block_indices.sort(dim=-1, descending=True)
...
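This file describes sparsity with per-head block indices padded by -1, while the *_mask.py sibling below uses a boolean block mask; the two encodings are interchangeable. A small conversion helper is sketched here purely as an illustration (it is not part of the example):

import torch

def indices_to_mask(block_indices, num_blocks):
    # block_indices: [batch, heads_kv, max_selected_blocks], padded with -1
    batch, heads_kv, _ = block_indices.shape
    mask = torch.zeros(batch, heads_kv, num_blocks, dtype=torch.bool, device=block_indices.device)
    valid = block_indices >= 0
    b, h, _ = torch.nonzero(valid, as_tuple=True)
    mask[b, h, block_indices[valid]] = True
    return mask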
...
@@ -434,8 +391,7 @@ def main(batch=8,
    print("max_num_blocks: ", max_num_blocks)
    # parity reference
    ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size)
    sparse_kernel = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size)
    out = sparse_kernel(Q, K, V, block_indices, cache_seqlens)
...
...
@@ -445,13 +401,11 @@ def main(batch=8,
    ## latency reference
    for _ in range(10):
        ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(100):
        ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size)
    torch.cuda.synchronize()
    print("dense time: ", (time.time() - start) / 100 * 1000)
...
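The wall-clock timing above is only valid because of the surrounding torch.cuda.synchronize() calls; CUDA events give an equivalent, slightly tidier measurement. A sketch for comparison (not part of the example scripts):

import torch

def time_cuda_ms(fn, iters=100, warmup=10):
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average milliseconds per call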
...
@@ -469,15 +423,13 @@ def main(batch=8,
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
-   parser.add_argument('--batch', type=int, default=8, help='batch size')
-   parser.add_argument('--heads', type=int, default=32, help='heads')
-   parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv')
-   parser.add_argument('--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length')
-   parser.add_argument('--dim', type=int, default=128, help='dim')
-   parser.add_argument('--dim_v', type=int, default=128, help='dim_v')
-   parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio')
-   parser.add_argument('--block_size', type=int, default=32, help='block_size')
+   parser.add_argument("--batch", type=int, default=8, help="batch size")
+   parser.add_argument("--heads", type=int, default=32, help="heads")
+   parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv")
+   parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length")
+   parser.add_argument("--dim", type=int, default=128, help="dim")
+   parser.add_argument("--dim_v", type=int, default=128, help="dim_v")
+   parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio")
+   parser.add_argument("--block_size", type=int, default=32, help="block_size")
    args = parser.parse_args()
    main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py View file @ 29051439
...
...
@@ -12,15 +12,17 @@ from heuristic import num_splits_heuristic
def flashattn(batch, heads, heads_kv, dim, dim_v):
-   scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
+   scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
    dtype = "float16"
    accum_dtype = "float"
    kv_group_num = heads // heads_kv

    @tilelang.jit(
        out_idx=[-1],
        pass_configs={
            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-       })
+       },
+   )
    def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, num_blocks):
        shape_q = [batch, heads, dim]
        shape_k = [batch, max_cache_seqlen, heads_kv, dim]
...
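The 1.44269504 folded into scale above is log2(e): scaling the logits by it lets the kernel evaluate exp2 (cheap on GPUs) while still computing an exact base-e softmax. A quick numerical check of that identity in plain PyTorch, included only for reference:

import torch

dim = 128
scale = (1.0 / dim) ** 0.5 * 1.44269504  # 1/sqrt(dim) * log2(e)
logits = torch.randn(4, 16)
ref = torch.softmax(logits / dim**0.5, dim=-1)
p = torch.exp2(logits * scale)  # exp2(x * log2(e)) == exp(x)
out = p / p.sum(dim=-1, keepdim=True)
assert torch.allclose(ref, out, atol=1e-6)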
@@ -32,16 +34,15 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
@
T
.
macro
def
flash_attn_split
(
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
block_mask
:
T
.
Tensor
(
shape_mask
,
"bool"
),
cache_seqlens
:
T
.
Tensor
([
batch
],
"int32"
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
accum_dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
accum_dtype
),
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
block_mask
:
T
.
Tensor
(
shape_mask
,
"bool"
),
cache_seqlens
:
T
.
Tensor
([
batch
],
"int32"
),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
accum_dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
accum_dtype
),
):
with
T
.
Kernel
(
batch
,
heads
//
valid_block_H
,
num_split
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
with
T
.
Kernel
(
batch
,
heads
//
valid_block_H
,
num_split
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
Q_shared
=
T
.
alloc_shared
([
block_H
,
dim
],
dtype
)
K_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
V_shared
=
T
.
alloc_shared
([
block_N
,
dim_v
],
dtype
)
...
...
@@ -62,39 +63,31 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
            sid = bz
            cur_kv_head = hid // (kv_group_num // valid_block_H)
-           T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared)
+           T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))
            blocks_per_split = T.floordiv(num_blocks, num_split)
            remaining_blocks = T.floormod(num_blocks, num_split)
-           loop_range = (blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0))
+           loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0)
            start = blocks_per_split * sid + T.min(sid, remaining_blocks)
            has_valid_block = False
            for k in T.Pipelined(loop_range, num_stages=num_stages):
                if block_mask[bid, hid, start + k]:
                    has_valid_block = True
-                   T.copy(K[bid, (start + k) * block_N:(start + k + 1) * block_N, cur_kv_head, :], K_shared)
+                   T.copy(K[bid, (start + k) * block_N : (start + k + 1) * block_N, cur_kv_head, :], K_shared)
                    T.clear(acc_s)
                    T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                    for i, j in T.Parallel(block_H, block_N):
                        acc_s[i, j] = T.if_then_else(
                            (start + k) * block_N + j >= cache_seqlens[bx], -T.infinity(accum_dtype), acc_s[i, j]
                        )
                    T.copy(scores_max, scores_max_prev)
                    T.fill(scores_max, -T.infinity(accum_dtype))
                    T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                    for i in T.Parallel(block_H):
                        scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                        scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                    for i, j in T.Parallel(block_H, block_N):
                        acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                    T.reduce_sum(acc_s, scores_sum, dim=1)
...
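The scores_max / scores_scale bookkeeping above is the online-softmax update: whenever the running row maximum grows, everything accumulated so far is rescaled before the new tile's probabilities are added. The same update in plain PyTorch, with hypothetical names and base-e exponentials for clarity (a sketch, not the kernel code):

import torch

def online_logsumexp(score_tiles):
    # score_tiles: list of [rows, block_n] logit tiles streamed one at a time
    rows = score_tiles[0].shape[0]
    m = torch.full((rows,), float("-inf"))
    l = torch.zeros(rows)
    for s in score_tiles:
        m_new = torch.maximum(m, s.max(dim=-1).values)
        rescale = torch.exp(m - m_new)  # shrink what was accumulated under the old max
        l = l * rescale + torch.exp(s - m_new[:, None]).sum(dim=-1)
        m = m_new
    return m + torch.log(l)  # equals torch.logsumexp(torch.cat(score_tiles, dim=-1), dim=-1)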
...
@@ -103,9 +96,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
                    T.copy(acc_s, acc_s_cast)
                    for i, j in T.Parallel(block_H, dim_v):
                        acc_o[i, j] *= scores_scale[i]
-                   T.copy(V[bid, (start + k) * block_N:(start + k + 1) * block_N, cur_kv_head, :], V_shared)
+                   T.copy(V[bid, (start + k) * block_N : (start + k + 1) * block_N, cur_kv_head, :], V_shared)
                    T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
            if has_valid_block:
                for i, j in T.Parallel(block_H, dim_v):
...
...
@@ -123,9 +114,9 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
    @T.macro
    def combine(
        glse: T.Tensor([batch, heads, num_split], accum_dtype),
        Output_partial: T.Tensor(part_shape, accum_dtype),
        Output: T.Tensor(shape_o, dtype),
    ):
        with T.Kernel(heads, batch, threads=128) as (by, bz):
            po_local = T.alloc_fragment([dim_v], accum_dtype)
...
@@ -135,10 +126,11 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
lse_max_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
scale_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
T
.
annotate_layout
({
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
})
T
.
annotate_layout
(
{
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
}
)
T
.
clear
(
lse_logsum_local
)
T
.
clear
(
o_accum_local
)
...
...
@@ -161,14 +153,14 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
    @T.prim_func
    def main(
        Q: T.Tensor(shape_q, dtype),
        K: T.Tensor(shape_k, dtype),
        V: T.Tensor(shape_v, dtype),
        block_mask: T.Tensor(shape_mask, "bool"),
        cache_seqlens: T.Tensor([batch], "int32"),
        glse: T.Tensor([batch, heads, num_split], accum_dtype),
        Output_partial: T.Tensor(part_shape, accum_dtype),
        Output: T.Tensor(shape_o, dtype),
    ):
        flash_attn_split(Q, K, V, block_mask, cache_seqlens, glse, Output_partial)
        combine(glse, Output_partial, Output)
...
...
@@ -179,7 +171,6 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
class SparseFlashAttn(torch.nn.Module):
    def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size):
        super(SparseFlashAttn, self).__init__()
        self.batch = batch
...
...
@@ -198,7 +189,8 @@ class SparseFlashAttn(torch.nn.Module):
            num_stages=2,
            threads=128,
            max_cache_seqlen=T.dynamic("max_cache_seqlen"),
-           num_blocks=T.dynamic("num_blocks"))
+           num_blocks=T.dynamic("num_blocks"),
+       )
        props = torch.cuda.get_device_properties(torch.device("cuda:0"))
        self.num_sm = props.multi_processor_count
...
...
@@ -217,24 +209,16 @@ class SparseFlashAttn(torch.nn.Module):
        num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
        num_n_blocks = max_selected_blocks
-       size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2  #kv_seqlen * (dim + dim_v) * 2
+       size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2  # kv_seqlen * (dim + dim_v) * 2
        total_mblocks = batch * heads_kv * num_m_blocks
        # num_sm = 132
        num_sm = self.num_sm
        num_split = num_splits_heuristic(
-           total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128)
+           total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128
+       )
        # print("num_split: ", num_split)
-       glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda')
-       Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device='cuda')
+       glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda")
+       Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda")
        output = self.kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial)
        return output
...
...
@@ -259,26 +243,21 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens,
    block_H = 64
    actual_num_blocks = torch.sum(block_mask, dim=-1).to(torch.int32)
-   actual_num_blocks = actual_num_blocks[:, 0]  #[batch], number of valid blocks, assume all groups in the same batch have the same number of blocks
+   actual_num_blocks = actual_num_blocks[
+       :, 0
+   ]  # [batch], number of valid blocks, assume all groups in the same batch have the same number of blocks
    max_selected_blocks = actual_num_blocks.max().item()
    # get num_split
    num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
-   num_n_blocks = max_selected_blocks  #(kv_seqlen + block_size - 1 ) // block_size
+   num_n_blocks = max_selected_blocks  # (kv_seqlen + block_size - 1 ) // block_size
    # num_n_blocks = torch.sum(actual_num_blocks, dim=-1).item() * heads_kv # total number of blocks
-   size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2  #kv_seqlen * (dim + dim_v) * 2
+   size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2  # kv_seqlen * (dim + dim_v) * 2
    total_mblocks = batch * heads_kv * num_m_blocks
    num_sm = 132
    num_split = num_splits_heuristic(
-       total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128)
+       total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128
+   )
    kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
        block_N=block_size,
...
...
@@ -287,11 +266,10 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens,
        num_stages=2,
        threads=128,
        max_cache_seqlen=T.dynamic("max_cache_seqlen"),
-       num_blocks=T.dynamic("num_blocks"))
-   glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda')
-   Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device='cuda')
+       num_blocks=T.dynamic("num_blocks"),
+   )
+   glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda")
+   Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda")
    # print(kernel.get_kernel_source())
    output = kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial)
...
...
@@ -299,24 +277,18 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens,
    return output


def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size):
    batch, heads, dim = query.shape
    heads_kv = key.shape[2]
    num_head_groups = query.shape[1] // key.shape[2]
    scale = dim**0.5
-   key = rearrange(key, 'b n h d -> b h n d')  # [batch_size, heads_kv, seqlen_kv, dim]
-   value = rearrange(value, 'b n h d -> b h n d')  # [batch_size, heads_kv, seqlen_kv, dim]
+   key = rearrange(key, "b n h d -> b h n d")  # [batch_size, heads_kv, seqlen_kv, dim]
+   value = rearrange(value, "b n h d -> b h n d")  # [batch_size, heads_kv, seqlen_kv, dim]
-   query = rearrange(query, 'b (h g) d -> b g h d', g=num_head_groups)  # [batch_size, num_head_groups, heads_kv, dim]
+   query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups)  # [batch_size, num_head_groups, heads_kv, dim]
-   scores = einsum(query, key, 'b g h d, b h s d -> b g h s')  # [batch_size, num_head_groups, heads_kv, seqlen_kv]
+   scores = einsum(query, key, "b g h d, b h s d -> b g h s")  # [batch_size, num_head_groups, heads_kv, seqlen_kv]
    sparse_mask = torch.zeros_like(scores)
# Assign mask values
...
...
@@ -324,29 +296,27 @@ def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_se
        for h in range(heads_kv):
            for idx in range(num_blocks):
                if block_mask[b, h, idx]:
-                   sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1
+                   sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1
-   scores = scores.masked_fill(sparse_mask == 0, float('-inf'))
+   scores = scores.masked_fill(sparse_mask == 0, float("-inf"))
-   range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0)
+   range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0)
    cache_seqlens_expanded = cache_seqlens.unsqueeze(1)
    pad_mask = range_len >= cache_seqlens_expanded
    pad_mask = pad_mask[:, None, None, :]
-   scores = scores.masked_fill(pad_mask, float('-inf'))
-   attention = F.softmax(scores / scale, dim=-1)  # [batch_size, num_head_groups, heads_kv, seqlen_kv]
+   scores = scores.masked_fill(pad_mask, float("-inf"))
+   attention = F.softmax(scores / scale, dim=-1)  # [batch_size, num_head_groups, heads_kv, seqlen_kv]
-   out = einsum(attention, value, 'b g h s, b h s d -> b g h d')  # [batch_size, num_head_groups, heads_kv, dim]
-   out = rearrange(out, 'b g h d -> b (h g) d')  # [batch_size, heads, dim]
+   out = einsum(attention, value, "b g h s, b h s d -> b g h d")  # [batch_size, num_head_groups, heads_kv, dim]
+   out = rearrange(out, "b g h d -> b (h g) d")  # [batch_size, heads, dim]
    return out


def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size):
    # latency reference
    # from flash_attn_interface import flash_attn_with_kvcache # fa3
-   from flash_attn import flash_attn_with_kvcache  #fa2
+   from flash_attn import flash_attn_with_kvcache  # fa2
    query = query.unsqueeze(1)
    output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens)
    output = output.squeeze(1)
...
...
@@ -360,23 +330,13 @@ def debug(name, expect, actual, atol=1e-3, rtol=1e-3):
# print(expect[3, 28])
# print(actual[3, 28])
        diff = (expect - actual).abs()
        print("all_close={}, max={}, min={}, mean={}".format(all_close, diff.max().item(), diff.min().item(), diff.mean().item()))
        max_indices = torch.nonzero(diff == diff.max().item())
        first_index = tuple(max_indices[0].tolist())
        print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}")


def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32):
    batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v
    sparse_ratio = sparse_ratio
    block_size = block_size
...
...
@@ -384,14 +344,13 @@ def main(batch=8,
    print("max_selected_blocks: ", max_selected_blocks)
    dtype = torch.float16
-   Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda')
-   K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda')
-   V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda')
-   cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda')
+   Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda")
+   K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda")
+   V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda")
+   cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda")
    # Ensure at least one element equals cache_seqlen
-   random_index = torch.randint(0, batch, (1,), device='cuda').item()  # Select a random index
-   cache_seqlens[random_index] = max_cache_seqlen  # Assign cache_seqlen to ensure at least one occurrence
+   random_index = torch.randint(0, batch, (1,), device="cuda").item()  # Select a random index
+   cache_seqlens[random_index] = max_cache_seqlen  # Assign cache_seqlen to ensure at least one occurrence
    # cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda')
    print("cache_seqlens: ", cache_seqlens)
...
...
@@ -403,7 +362,7 @@ def main(batch=8,
    max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int()
    print("max_valid_num_blocks: ", max_valid_num_blocks)
    # Initialize block_mask with false (for padding blocks)
-   block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device='cuda')
+   block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device="cuda")
    # Assign valid indices while ensuring no duplicates within each batch-group
    for b in range(batch):
...
...
@@ -411,13 +370,12 @@ def main(batch=8,
        valid_num_block = valid_num_blocks[b].item()  # Valid blocks for this batch
        if valid_num_block > 0:  # Ensure there's at least one valid block
            for h in range(heads_kv):
-               perm = torch.randperm(max_valid_block, device='cuda')[:valid_num_block]
+               perm = torch.randperm(max_valid_block, device="cuda")[:valid_num_block]
                block_mask[b, h, perm] = True
    # print("block_mask: ", block_mask)
    # parity reference
    ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size)
    # out = sparse_gqa_decode_varlen_mask(Q, K, V, block_mask, cache_seqlens, block_size)
    model = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size)
    out = model(Q, K, V, block_mask, cache_seqlens)
...
...
@@ -427,13 +385,11 @@ def main(batch=8,
    ## latency reference
    for _ in range(10):
        ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(100):
        ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size)
    torch.cuda.synchronize()
    print("dense time: ", (time.time() - start) / 100 * 1000)
...
...
@@ -452,15 +408,13 @@ def main(batch=8,
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
-   parser.add_argument('--batch', type=int, default=8, help='batch size')
-   parser.add_argument('--heads', type=int, default=32, help='heads')
-   parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv')
-   parser.add_argument('--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length')
-   parser.add_argument('--dim', type=int, default=128, help='dim')
-   parser.add_argument('--dim_v', type=int, default=128, help='dim_v')
-   parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio')
-   parser.add_argument('--block_size', type=int, default=32, help='block_size')
+   parser.add_argument("--batch", type=int, default=8, help="batch size")
+   parser.add_argument("--heads", type=int, default=32, help="heads")
+   parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv")
+   parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length")
+   parser.add_argument("--dim", type=int, default=128, help="dim")
+   parser.add_argument("--dim_v", type=int, default=128, help="dim_v")
+   parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio")
+   parser.add_argument("--block_size", type=int, default=32, help="block_size")
    args = parser.parse_args()
    main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size)
examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py View file @ 29051439
...
...
@@ -12,12 +12,8 @@ from heuristic import num_splits_heuristic
@triton.autotune(
-   configs=[
-       triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] \
-       for num_stages in [1, 2, 3, 4, 7]
-   ],
-   key=['BLOCK_H', 'BLOCK_N', 'BLOCK_D'],
+   configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]],
+   key=["BLOCK_H", "BLOCK_N", "BLOCK_D"],
)
@triton.jit
def _split_kernel(
...
...
@@ -79,16 +75,11 @@ def _split_kernel(
    loop_range = blocks_per_split
    q_ptr += batch_idx * stride_q_b + head_idx_q * stride_q_h
    k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[None, :] * stride_k_s + offs_d[:, None] * stride_k_d
    v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:, None] * stride_v_s + offs_d[None, :] * stride_v_d
    mask_ptr += batch_idx * stride_mask_b + head_idx_kv * stride_mask_h
    q = tl.load(q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d, mask=offs_h[:, None] < gqa_group_size)
    start = blocks_per_split * split_idx + tl.minimum(split_idx, remaining_blocks)
    for i in range(loop_range):
        block_idx = tl.load(mask_ptr + (start + i) * stride_mask_s)
...
...
@@ -119,23 +110,18 @@ def _split_kernel(
    acc = acc * l_recip
    acc = acc.to(o_partial_ptr.dtype.element_ty)
    lse_partial_ptr += batch_idx * stride_lse_b + (head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split
    tl.store(lse_partial_ptr, m_i, mask=offs_h < gqa_group_size)
-   o_partial_ptr += batch_idx * stride_o_b + (head_idx_q + offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d
+   o_partial_ptr += (
+       batch_idx * stride_o_b + (head_idx_q + offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d
+   )
    tl.store(o_partial_ptr, acc, mask=offs_h[:, None] < gqa_group_size)


@triton.autotune(
-   configs=[
-       triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] \
-       for num_stages in [1, 2, 3, 4, 7]
-   ],
-   key=['BLOCK_D'],
+   configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]],
+   key=["BLOCK_D"],
)
@triton.jit
def _merge_kernel(
...
@@ -163,18 +149,15 @@ def _merge_kernel(
offs_d
=
tl
.
arange
(
0
,
BLOCK_D
)
lse_offsets
=
lse_partial_ptr
+
batch_idx
*
lse_partial_stride_b
+
head_idx
*
lse_partial_stride_h
lse
=
tl
.
load
(
lse_offsets
+
offs_splits
*
lse_partial_stride_split
,
mask
=
offs_splits
<
num_splits
,
other
=
float
(
"-inf"
))
lse
=
tl
.
load
(
lse_offsets
+
offs_splits
*
lse_partial_stride_split
,
mask
=
offs_splits
<
num_splits
,
other
=
float
(
"-inf"
))
lse_max
=
tl
.
max
(
lse
)
o_offsets
=
o_partial_ptr
+
batch_idx
*
o_partial_stride_b
+
head_idx
*
o_partial_stride_h
o_partial
=
tl
.
load
(
o_offsets
+
offs_splits
[:,
None
]
*
o_partial_stride_split
+
offs_d
[
None
,
:]
*
o_partial_stride_d
,
mask
=
offs_splits
[:,
None
]
<
num_splits
)
o_offsets
+
offs_splits
[:,
None
]
*
o_partial_stride_split
+
offs_d
[
None
,
:]
*
o_partial_stride_d
,
mask
=
offs_splits
[:,
None
]
<
num_splits
,
)
sumexp_normalized_splitk
=
tl
.
exp
(
lse
-
lse_max
)
sumexp_normalized
=
tl
.
sum
(
sumexp_normalized_splitk
,
axis
=
0
)
numerator_normalized
=
tl
.
sum
(
o_partial
*
sumexp_normalized_splitk
[:,
None
],
axis
=
0
)
...
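The three lines above implement the standard log-sum-exp merge of split-KV partial results: each split's output is weighted by exp(lse_i - lse_max) and the weights are renormalized. The same arithmetic for a single (batch, head) in plain PyTorch, offered only as a reference sketch:

import torch

def merge_partials(lse, o_partial):
    # lse: [num_splits], o_partial: [num_splits, dim_v], both for one (batch, head)
    lse_max = lse.max()
    w = torch.exp(lse - lse_max)               # per-split sum-of-exp, normalized by the global max
    out = (o_partial * w[:, None]).sum(dim=0) / w.sum()
    lse_merged = lse_max + torch.log(w.sum())  # combined log-sum-exp, if a caller needs it
    return out, lse_merged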
...
@@ -209,19 +192,13 @@ def block_sparse_flash_decode_gqa_indice_triton(
    num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
    num_n_blocks = max_selected_blocks
-   size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2  #kv_seqlen * (dim + dim_v) * 2
+   size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2  # kv_seqlen * (dim + dim_v) * 2
    total_mblocks = batch * heads_kv * num_m_blocks
    num_sm = 64
    # num_sm = self.num_sm
    num_splits = num_splits_heuristic(
-       total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128)
+       total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128
+   )
# print("num_splits:", num_splits, "num_blocks:", num_n_blocks)
...
...
@@ -295,24 +272,18 @@ def block_sparse_flash_decode_gqa_indice_triton(
    return output


def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size):
    batch, heads, dim = query.shape
    heads_kv = key.shape[2]
    dim_v = value.shape[-1]
    num_head_groups = query.shape[1] // key.shape[2]
    scale = dim**0.5
-   key = rearrange(key, 'b n h d -> b h n d')  # [batch_size, heads_kv, seqlen_kv, dim]
-   value = rearrange(value, 'b n h d -> b h n d')  # [batch_size, heads_kv, seqlen_kv, dim]
+   key = rearrange(key, "b n h d -> b h n d")  # [batch_size, heads_kv, seqlen_kv, dim]
+   value = rearrange(value, "b n h d -> b h n d")  # [batch_size, heads_kv, seqlen_kv, dim]
-   query = rearrange(query, 'b (h g) d -> b g h d', g=num_head_groups)  # [batch_size, num_head_groups, heads_kv, dim]
+   query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups)  # [batch_size, num_head_groups, heads_kv, dim]
-   scores = einsum(query, key, 'b g h d, b h s d -> b g h s')  # [batch_size, num_head_groups, heads_kv, seqlen_kv]
+   scores = einsum(query, key, "b g h d, b h s d -> b g h s")  # [batch_size, num_head_groups, heads_kv, seqlen_kv]
    sparse_mask = torch.zeros_like(scores)
# Assign mask values based on block_indices
...
...
@@ -321,42 +292,33 @@ def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache
            valid_indices = block_indices[b, h]  # Extract indices for this batch and head
            for idx in valid_indices:
                if idx >= 0:
-                   sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1
+                   sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1
-   scores = scores.masked_fill(sparse_mask == 0, float('-inf'))
+   scores = scores.masked_fill(sparse_mask == 0, float("-inf"))
-   range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0)
+   range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0)
    cache_seqlens_expanded = cache_seqlens.unsqueeze(1)
    pad_mask = range_len >= cache_seqlens_expanded
    pad_mask = pad_mask[:, None, None, :]
-   scores = scores.masked_fill(pad_mask, float('-inf'))
-   attention = F.softmax(scores / scale, dim=-1)  # [batch_size, num_head_groups, heads_kv, seqlen_kv]
+   scores = scores.masked_fill(pad_mask, float("-inf"))
+   attention = F.softmax(scores / scale, dim=-1)  # [batch_size, num_head_groups, heads_kv, seqlen_kv]
-   out = einsum(attention, value, 'b g h s, b h s d -> b g h d')  # [batch_size, num_head_groups, heads_kv, dim]
-   out = rearrange(out, 'b g h d -> b (h g) d')  # [batch_size, heads, dim]
+   out = einsum(attention, value, "b g h s, b h s d -> b g h d")  # [batch_size, num_head_groups, heads_kv, dim]
+   out = rearrange(out, "b g h d -> b (h g) d")  # [batch_size, heads, dim]
    return out


def ref_program_fa(query, key, value, cache_seqlens):
    # latency reference
    # from flash_attn_interface import flash_attn_with_kvcache # fa3
-   from flash_attn import flash_attn_with_kvcache  #fa2
+   from flash_attn import flash_attn_with_kvcache  # fa2
    query = query.unsqueeze(1)
    output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens)
    output = output.squeeze(1)
    return output


def main(batch=64, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32):
    batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v
    sparse_ratio = sparse_ratio
    block_size = block_size
...
...
@@ -369,34 +331,29 @@ def main(batch=64,
    dtype = torch.float16
    block_H = 64
-   Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda')
-   K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda')
-   V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda')
-   cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda')
+   Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda")
+   K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda")
+   V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda")
+   cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda")
    # cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda')
    # Ensure at least one element equals cache_seqlen
-   random_index = torch.randint(0, batch, (1,), device='cuda').item()  # Select a random index
-   cache_seqlens[random_index] = max_cache_seqlen  # Assign cache_seqlen to ensure at least one occurrence
+   random_index = torch.randint(0, batch, (1,), device="cuda").item()  # Select a random index
+   cache_seqlens[random_index] = max_cache_seqlen  # Assign cache_seqlen to ensure at least one occurrence
    print("cache_seqlens: ", cache_seqlens)
    max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int()
    print("max_valid_num_blocks: ", max_valid_num_blocks)
    # Initialize block_indices with -1 (for padding blocks)
-   block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32, device='cuda')
+   block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32, device="cuda")
    # Assign valid indices while ensuring no duplicates within each batch-group
    for b in range(batch):
        max_valid_block = max_valid_num_blocks[b].item()  # Max valid blocks for this batch
        if max_valid_block > 0:  # Ensure there's at least one valid block
            for h in range(heads_kv):
-               valid_indices = torch.randperm(max_valid_block, device='cuda', dtype=torch.int32)[:max_selected_blocks]
-               block_indices[b, h, :len(valid_indices)] = valid_indices
+               valid_indices = torch.randperm(max_valid_block, device="cuda", dtype=torch.int32)[:max_selected_blocks]
+               block_indices[b, h, : len(valid_indices)] = valid_indices
    # Sort indices within each batch-group for consistency
    block_indices, _ = block_indices.sort(dim=-1, descending=True)
...
...
@@ -408,8 +365,7 @@ def main(batch=64,
    max_num_blocks = torch.max(max_valid_num_blocks).item()
    print("max_num_blocks: ", max_num_blocks)
    ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size)
    triton_out = block_sparse_flash_decode_gqa_indice_triton(
        Q,
...
@@ -423,8 +379,7 @@ def main(batch=64,
)
print
(
"max difference: "
,
torch
.
max
(
torch
.
abs
(
ref
-
triton_out
)))
assert
torch
.
allclose
(
ref
,
triton_out
,
atol
=
1e-2
),
"Output mismatch between Triton and reference implementation"
assert
torch
.
allclose
(
ref
,
triton_out
,
atol
=
1e-2
),
"Output mismatch between Triton and reference implementation"
print
(
"Passed the ref test!"
)
# Measure performance
...
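A side note on the assertion above: torch.allclose with only an absolute tolerance can hide relative error on large magnitudes; torch.testing.assert_close reports both kinds of mismatch with a readable diff. A possible tightening, offered as a suggestion rather than as part of this change:

import torch

torch.testing.assert_close(triton_out, ref, atol=1e-2, rtol=1e-2)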
...
@@ -466,15 +421,13 @@ def main(batch=64,
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
-   parser.add_argument('--batch', type=int, default=64, help='batch size')
-   parser.add_argument('--heads', type=int, default=32, help='heads')
-   parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv')
-   parser.add_argument('--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length')
-   parser.add_argument('--dim', type=int, default=128, help='dim')
-   parser.add_argument('--dim_v', type=int, default=128, help='dim_v')
-   parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio')
-   parser.add_argument('--block_size', type=int, default=32, help='block_size')
+   parser.add_argument("--batch", type=int, default=64, help="batch size")
+   parser.add_argument("--heads", type=int, default=32, help="heads")
+   parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv")
+   parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length")
+   parser.add_argument("--dim", type=int, default=128, help="dim")
+   parser.add_argument("--dim_v", type=int, default=128, help="dim_v")
+   parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio")
+   parser.add_argument("--block_size", type=int, default=32, help="block_size")
    args = parser.parse_args()
    main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size)
examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py View file @ 29051439
...
...
@@ -11,12 +11,8 @@ from heuristic import num_splits_heuristic
@triton.autotune(
-   configs=[
-       triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] \
-       for num_stages in [1, 2, 3, 4, 7]
-   ],
-   key=['BLOCK_H', 'BLOCK_N', 'BLOCK_D'],
+   configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]],
+   key=["BLOCK_H", "BLOCK_N", "BLOCK_D"],
)
@triton.jit
def _split_kernel(
...
...
@@ -77,16 +73,11 @@ def _split_kernel(
    loop_range = blocks_per_split
    q_ptr += batch_idx * stride_q_b + head_idx_q * stride_q_h
    k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[None, :] * stride_k_s + offs_d[:, None] * stride_k_d
    v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:, None] * stride_v_s + offs_d[None, :] * stride_v_d
    mask_ptr += batch_idx * stride_mask_b + head_idx_kv * stride_mask_h
    q = tl.load(q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d, mask=offs_h[:, None] < gqa_group_size)
    start = blocks_per_split * split_idx + tl.minimum(split_idx, remaining_blocks)
    for block_idx in range(loop_range):
        start_n = (start + block_idx) * BLOCK_N
...
...
@@ -117,23 +108,18 @@ def _split_kernel(
    acc = acc * l_recip
    acc = acc.to(o_partial_ptr.dtype.element_ty)
    lse_partial_ptr += batch_idx * stride_lse_b + (head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split
    tl.store(lse_partial_ptr, m_i, mask=offs_h < gqa_group_size)
-   o_partial_ptr += batch_idx * stride_o_b + (head_idx_q + offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d
+   o_partial_ptr += (
+       batch_idx * stride_o_b + (head_idx_q + offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d
+   )
    tl.store(o_partial_ptr, acc, mask=offs_h[:, None] < gqa_group_size)


@triton.autotune(
-   configs=[
-       triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] \
-       for num_stages in [1, 2, 3, 4, 7]
-   ],
-   key=['BLOCK_D'],
+   configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]],
+   key=["BLOCK_D"],
)
@triton.jit
def _merge_kernel(
...
@@ -161,18 +147,15 @@ def _merge_kernel(
offs_d
=
tl
.
arange
(
0
,
BLOCK_D
)
lse_offsets
=
lse_partial_ptr
+
batch_idx
*
lse_partial_stride_b
+
head_idx
*
lse_partial_stride_h
lse
=
tl
.
load
(
lse_offsets
+
offs_splits
*
lse_partial_stride_split
,
mask
=
offs_splits
<
num_splits
,
other
=
float
(
"-inf"
))
lse
=
tl
.
load
(
lse_offsets
+
offs_splits
*
lse_partial_stride_split
,
mask
=
offs_splits
<
num_splits
,
other
=
float
(
"-inf"
))
lse_max
=
tl
.
max
(
lse
)
o_offsets
=
o_partial_ptr
+
batch_idx
*
o_partial_stride_b
+
head_idx
*
o_partial_stride_h
o_partial
=
tl
.
load
(
o_offsets
+
offs_splits
[:,
None
]
*
o_partial_stride_split
+
offs_d
[
None
,
:]
*
o_partial_stride_d
,
mask
=
offs_splits
[:,
None
]
<
num_splits
)
o_offsets
+
offs_splits
[:,
None
]
*
o_partial_stride_split
+
offs_d
[
None
,
:]
*
o_partial_stride_d
,
mask
=
offs_splits
[:,
None
]
<
num_splits
,
)
sumexp_normalized_splitk
=
tl
.
exp
(
lse
-
lse_max
)
sumexp_normalized
=
tl
.
sum
(
sumexp_normalized_splitk
,
axis
=
0
)
numerator_normalized
=
tl
.
sum
(
o_partial
*
sumexp_normalized_splitk
[:,
None
],
axis
=
0
)
...
...
@@ -207,19 +190,13 @@ def block_sparse_flash_decode_gqa_mask_triton(
    num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
    num_n_blocks = max_selected_blocks
-   size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2  #kv_seqlen * (dim + dim_v) * 2
+   size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2  # kv_seqlen * (dim + dim_v) * 2
    total_mblocks = batch * heads_kv * num_m_blocks
    num_sm = 64
    # num_sm = self.num_sm
    num_splits = num_splits_heuristic(
-       total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128)
+       total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128
+   )
# print("num_splits:", num_splits, "num_blocks:", num_n_blocks)
...
...
@@ -292,24 +269,18 @@ def block_sparse_flash_decode_gqa_mask_triton(
    return output


def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size):
    batch, heads, dim = query.shape
    heads_kv = key.shape[2]
    num_head_groups = query.shape[1] // key.shape[2]
    scale = dim**0.5
-   key = rearrange(key, 'b n h d -> b h n d')  # [batch_size, heads_kv, seqlen_kv, dim]
-   value = rearrange(value, 'b n h d -> b h n d')  # [batch_size, heads_kv, seqlen_kv, dim]
+   key = rearrange(key, "b n h d -> b h n d")  # [batch_size, heads_kv, seqlen_kv, dim]
+   value = rearrange(value, "b n h d -> b h n d")  # [batch_size, heads_kv, seqlen_kv, dim]
-   query = rearrange(query, 'b (h g) d -> b g h d', g=num_head_groups)  # [batch_size, num_head_groups, heads_kv, dim]
+   query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups)  # [batch_size, num_head_groups, heads_kv, dim]
-   scores = einsum(query, key, 'b g h d, b h s d -> b g h s')  # [batch_size, num_head_groups, heads_kv, seqlen_kv]
+   scores = einsum(query, key, "b g h d, b h s d -> b g h s")  # [batch_size, num_head_groups, heads_kv, seqlen_kv]
    sparse_mask = torch.zeros_like(scores)
# Assign mask values
...
...
@@ -317,43 +288,34 @@ def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_se
        for h in range(heads_kv):
            for idx in range(num_blocks):
                if block_mask[b, h, idx]:
-                   sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1
+                   sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1
-   scores = scores.masked_fill(sparse_mask == 0, float('-inf'))
+   scores = scores.masked_fill(sparse_mask == 0, float("-inf"))
-   range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0)
+   range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0)
    cache_seqlens_expanded = cache_seqlens.unsqueeze(1)
    pad_mask = range_len >= cache_seqlens_expanded
    pad_mask = pad_mask[:, None, None, :]
-   scores = scores.masked_fill(pad_mask, float('-inf'))
-   attention = F.softmax(scores / scale, dim=-1)  # [batch_size, num_head_groups, heads_kv, seqlen_kv]
+   scores = scores.masked_fill(pad_mask, float("-inf"))
+   attention = F.softmax(scores / scale, dim=-1)  # [batch_size, num_head_groups, heads_kv, seqlen_kv]
-   out = einsum(attention, value, 'b g h s, b h s d -> b g h d')  # [batch_size, num_head_groups, heads_kv, dim]
-   out = rearrange(out, 'b g h d -> b (h g) d')  # [batch_size, heads, dim]
+   out = einsum(attention, value, "b g h s, b h s d -> b g h d")  # [batch_size, num_head_groups, heads_kv, dim]
+   out = rearrange(out, "b g h d -> b (h g) d")  # [batch_size, heads, dim]
    return out


def ref_program_fa(query, key, value, cache_seqlens):
    # latency reference
    # from flash_attn_interface import flash_attn_with_kvcache # fa3
-   from flash_attn import flash_attn_with_kvcache  #fa2
+   from flash_attn import flash_attn_with_kvcache  # fa2
    query = query.unsqueeze(1)
    output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens)
    output = output.squeeze(1)
    return output


def main(batch=64, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32):
    batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v
    block_size = block_size
    sparse_ratio = sparse_ratio
...
...
@@ -363,14 +325,13 @@ def main(batch=64,
    dtype = torch.float16
-   Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda')
-   K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda')
-   V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda')
-   cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda')
+   Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda")
+   K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda")
+   V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda")
+   cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda")
    # Ensure at least one element equals cache_seqlen
-   random_index = torch.randint(0, batch, (1,), device='cuda').item()  # Select a random index
-   cache_seqlens[random_index] = max_cache_seqlen  # Assign cache_seqlen to ensure at least one occurrence
+   random_index = torch.randint(0, batch, (1,), device="cuda").item()  # Select a random index
+   cache_seqlens[random_index] = max_cache_seqlen  # Assign cache_seqlen to ensure at least one occurrence
    num_blocks = (max_cache_seqlen + block_size - 1) // block_size
...
...
@@ -379,7 +340,7 @@ def main(batch=64,
max_valid_num_blocks
=
torch
.
ceil
(
cache_seqlens
/
block_size
).
int
()
print
(
"max_valid_num_blocks: "
,
max_valid_num_blocks
)
# Initialize block_mask with false (for padding blocks)
block_mask
=
torch
.
zeros
((
batch
,
heads_kv
,
num_blocks
),
dtype
=
torch
.
bool
,
device
=
'
cuda
'
)
block_mask
=
torch
.
zeros
((
batch
,
heads_kv
,
num_blocks
),
dtype
=
torch
.
bool
,
device
=
"
cuda
"
)
# Assign valid indices while ensuring no duplicates within each batch-group
for
b
in
range
(
batch
):
...
...
@@ -387,11 +348,10 @@ def main(batch=64,
valid_num_block
=
valid_num_blocks
[
b
].
item
()
# Valid blocks for this batch
if
valid_num_block
>
0
:
# Ensure there's at least one valid block
for
h
in
range
(
heads_kv
):
perm
=
torch
.
randperm
(
max_valid_block
,
device
=
'
cuda
'
)[:
valid_num_block
]
perm
=
torch
.
randperm
(
max_valid_block
,
device
=
"
cuda
"
)[:
valid_num_block
]
block_mask
[
b
,
h
,
perm
]
=
True
ref
=
ref_program_torch
(
Q
,
K
,
V
,
block_mask
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
block_size
)
ref
=
ref_program_torch
(
Q
,
K
,
V
,
block_mask
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
block_size
)
triton_out
=
block_sparse_flash_decode_gqa_mask_triton
(
Q
,
...
...
@@ -404,8 +364,7 @@ def main(batch=64,
)
# print("max difference: ", torch.max(torch.abs(ref - triton_out)))
assert
torch
.
allclose
(
ref
,
triton_out
,
atol
=
1e-2
),
"Output mismatch between Triton and reference implementation"
assert
torch
.
allclose
(
ref
,
triton_out
,
atol
=
1e-2
),
"Output mismatch between Triton and reference implementation"
print
(
"Passed the ref test!"
)
# Measure performance
...
...
@@ -448,15 +407,13 @@ def main(batch=64,
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--batch'
,
type
=
int
,
default
=
64
,
help
=
'batch size'
)
parser
.
add_argument
(
'--heads'
,
type
=
int
,
default
=
32
,
help
=
'heads'
)
parser
.
add_argument
(
'--heads_kv'
,
type
=
int
,
default
=
8
,
help
=
'heads_kv'
)
parser
.
add_argument
(
'--max_cache_seqlen'
,
type
=
int
,
default
=
8192
,
help
=
'kvcache sequence length'
)
parser
.
add_argument
(
'--dim'
,
type
=
int
,
default
=
128
,
help
=
'dim'
)
parser
.
add_argument
(
'--dim_v'
,
type
=
int
,
default
=
128
,
help
=
'dim_v'
)
parser
.
add_argument
(
'--sparse_ratio'
,
type
=
float
,
default
=
0.8
,
help
=
'sparse ratio'
)
parser
.
add_argument
(
'--block_size'
,
type
=
int
,
default
=
32
,
help
=
'block_size'
)
parser
.
add_argument
(
"--batch"
,
type
=
int
,
default
=
64
,
help
=
"batch size"
)
parser
.
add_argument
(
"--heads"
,
type
=
int
,
default
=
32
,
help
=
"heads"
)
parser
.
add_argument
(
"--heads_kv"
,
type
=
int
,
default
=
8
,
help
=
"heads_kv"
)
parser
.
add_argument
(
"--max_cache_seqlen"
,
type
=
int
,
default
=
8192
,
help
=
"kvcache sequence length"
)
parser
.
add_argument
(
"--dim"
,
type
=
int
,
default
=
128
,
help
=
"dim"
)
parser
.
add_argument
(
"--dim_v"
,
type
=
int
,
default
=
128
,
help
=
"dim_v"
)
parser
.
add_argument
(
"--sparse_ratio"
,
type
=
float
,
default
=
0.8
,
help
=
"sparse ratio"
)
parser
.
add_argument
(
"--block_size"
,
type
=
int
,
default
=
32
,
help
=
"block_size"
)
args
=
parser
.
parse_args
()
main
(
args
.
batch
,
args
.
heads
,
args
.
heads_kv
,
args
.
max_cache_seqlen
,
args
.
dim
,
args
.
dim_v
,
args
.
sparse_ratio
,
args
.
block_size
)
main
(
args
.
batch
,
args
.
heads
,
args
.
heads_kv
,
args
.
max_cache_seqlen
,
args
.
dim
,
args
.
dim_v
,
args
.
sparse_ratio
,
args
.
block_size
)
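For reference, the mask expansion at the top of this hunk can be reproduced standalone in plain PyTorch. The sketch below uses hypothetical small sizes and drops the head-group dimension for brevity; it is an illustration of the scheme, not part of the commit.

import torch

# Hypothetical sizes chosen only for illustration.
batch, heads_kv, num_blocks, block_size = 2, 2, 4, 8
seqlen_kv = num_blocks * block_size

block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool)
block_mask[0, 0, 1] = True  # keep only the second key block for batch 0, head 0

# Expand every selected block into block_size consecutive key positions.
sparse_mask = torch.zeros((batch, heads_kv, seqlen_kv))
for b in range(batch):
    for h in range(heads_kv):
        for idx in range(num_blocks):
            if block_mask[b, h, idx]:
                sparse_mask[b, h, idx * block_size : (idx + 1) * block_size] = 1

scores = torch.randn(batch, heads_kv, seqlen_kv)
scores = scores.masked_fill(sparse_mask == 0, float("-inf"))  # masked scores, as in the example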
examples/blocksparse_attention/heuristic.py
View file @
29051439
import math


def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local, max_splits):
def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local, max_splits):
    """
    Determines the optimal number of splits for maximizing GPU occupancy while balancing memory efficiency.
...
...
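The docstring above only states the goal. As a rough illustration of the idea (a hypothetical toy, not the repository's actual heuristic), a split-count chooser of this kind grows the number of splits until the launched blocks roughly cover the available SMs, capped at max_splits:

def toy_num_splits(total_mblocks: int, num_SMs: int, max_splits: int) -> int:
    """Hypothetical illustration only: smallest split count whose total
    block count covers the SMs, never exceeding max_splits."""
    for num_splits in range(1, max_splits + 1):
        if total_mblocks * num_splits >= num_SMs:
            return num_splits
    return max_splits


print(toy_num_splits(total_mblocks=8, num_SMs=108, max_splits=16))  # -> 14 on a 108-SM GPU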
examples/blocksparse_attention/test_example_blocksparse_attention.py
View file @
29051439
...
...
@@ -25,26 +25,14 @@ def test_example_tilelang_sparse_gqa_decode_varlen_mask():
def test_example_triton_sparse_gqa_decode_varlen_indice():
    example_triton_sparse_gqa_decode_varlen_indice.main(
        batch=8, heads=8, heads_kv=4, max_cache_seqlen=2048, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32)
        batch=8, heads=8, heads_kv=4, max_cache_seqlen=2048, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32)


def test_example_triton_sparse_gqa_decode_varlen_mask():
    example_triton_sparse_gqa_decode_varlen_mask.main(
        batch=16, heads=16, heads_kv=8, max_cache_seqlen=1024, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32)
        batch=16, heads=16, heads_kv=8, max_cache_seqlen=1024, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32)


if __name__ == "__main__":
...
...
examples/blocksparse_gemm/example_blocksparse_gemm.py
View file @
29051439
...
...
@@ -19,8 +19,7 @@ parser.add_argument("--m", type=int, default=1024, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K")
parser.add_argument("--sparsity", type=float, default=0.5, help="Sparsity ratio (0-1)")
parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune")
parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune")
args, _ = parser.parse_known_args()

M, N, K = args.m, args.n, args.k
...
...
@@ -41,17 +40,19 @@ def get_configs():
    thread_num = [128, 256]
    enable_rasterization = [True, False]
    _configs = list(itertools.product(block_M, block_N, block_K, num_stages, thread_num, enable_rasterization))
    _configs = list(itertools.product(block_M, block_N, block_K, num_stages, thread_num, enable_rasterization))

    return [{
        "block_M": c[0],
        "block_N": c[1],
        "block_K": c[2],
        "num_stages": c[3],
        "thread_num": c[4],
        "enable_rasteration": c[5],
    } for c in _configs]
    return [
        {
            "block_M": c[0],
            "block_N": c[1],
            "block_K": c[2],
            "num_stages": c[3],
            "thread_num": c[4],
            "enable_rasteration": c[5],
        }
        for c in _configs
    ]


def ref_program(A, B, BlockMask, block_M, block_N, block_K):
...
...
@@ -61,12 +62,10 @@ def ref_program(A, B, BlockMask, block_M, block_N, block_K):
            accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device)
            for k in range(K // block_K):
                if BlockMask[i, j, k]:
                    accu += (A[i * block_M:(i + 1) * block_M, k * block_K:(k + 1) * block_K].to(torch.float32) @ B[k * block_K:(k + 1) * block_K, j * block_N:(j + 1) * block_N].to(torch.float32))
            ref_c[i * block_M:(i + 1) * block_M, j * block_N:(j + 1) * block_N] = accu.to(torch.float16)
                    accu += A[i * block_M : (i + 1) * block_M, k * block_K : (k + 1) * block_K].to(torch.float32) @ B[k * block_K : (k + 1) * block_K, j * block_N : (j + 1) * block_N].to(torch.float32)
            ref_c[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = accu.to(torch.float16)
    return ref_c
...
...
@@ -89,28 +88,21 @@ def supply_program(params: List[KernelParam]):
    return input_tensors


@tilelang.autotune(configs=get_configs(),)
@tilelang.autotune(
    configs=get_configs(),
)
@tilelang.jit(out_idx=[-1])
def blocksparse_matmul(M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype="float16", accum_dtype="float"):
def blocksparse_matmul(M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype="float16", accum_dtype="float"):
    block_mask_shape = (M // block_M, N // block_N, K // block_K)

    @T.prim_func
    def block_sparse_matmul(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            BlockMask: T.Tensor(block_mask_shape, "bool"),
            C: T.Tensor((M, N), dtype),
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        BlockMask: T.Tensor(block_mask_shape, "bool"),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
...
...
@@ -134,7 +126,6 @@ def blocksparse_matmul(M,
def main():
    # Initialize input matrices A and B on the GPU with half precision
    a = torch.randn(M, K).cuda().half()
    b = torch.randn(K, N).cuda().half()
...
...
@@ -147,8 +138,7 @@ def main():
        best_config = kernel.config
        best_latency = kernel.latency
        block_M, block_N, block_K = best_config["block_M"], best_config["block_N"], best_config["block_K"]
        block_M, block_N, block_K = best_config["block_M"], best_config["block_N"], best_config["block_K"]
        print(f"Best Config: {best_config}")
        print(f"Sparsity Ratio: {sparsity}")
...
...
@@ -163,7 +153,8 @@ def main():
            block_K=DEFAULT_BLOCK_K,
            num_stages=DEFAULT_NUM_STAGES,
            thread_num=DEFAULT_THREAD_NUM,
            enable_rasteration=DEFAULT_ENABLE_RASTERIZATION)
            enable_rasteration=DEFAULT_ENABLE_RASTERIZATION,
        )
        block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K
        print(f"Using default kernel with block size ({block_M}, {block_N}, {block_K})")

    # Create block mask with desired sparsity
...
...
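The "Create block mask with desired sparsity" step above can be sketched independently. The shapes and the Bernoulli construction below are assumptions for illustration, matching the (M // block_M, N // block_N, K // block_K) mask layout used by the kernel:

import torch

M, N, K = 1024, 1024, 1024            # matches the script defaults
block_M, block_N, block_K = 128, 128, 32
sparsity = 0.5

mask_shape = (M // block_M, N // block_N, K // block_K)
# Each (i, j, k) tile is kept with probability 1 - sparsity.
block_mask = torch.rand(mask_shape) > sparsity
print(block_mask.float().mean())      # roughly 1 - sparsity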
examples/cast/example_group_per_split_token_cast_to_fp8.py
View file @
29051439
...
...
@@ -16,11 +16,13 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
    fp8_max = 448.0

    @T.prim_func
    def group_per_split_token_cast(X: T.Tensor((M, N), dtype), batch_sizes: T.Tensor((BG,), "int32"), X_fp8: T.Tensor((BG, M_max, N), "float8_e4m3"), X_amax: T.Tensor((BG, M_max, T.ceildiv(N, group_size)), accum_dtype)):
        with T.Kernel(T.ceildiv(M_max, blk_m), T.ceildiv(N, group_size), BG, threads=128) as (bx, by, bz):
    def group_per_split_token_cast(
        X: T.Tensor((M, N), dtype),
        batch_sizes: T.Tensor((BG,), "int32"),
        X_fp8: T.Tensor((BG, M_max, N), "float8_e4m3"),
        X_amax: T.Tensor((BG, M_max, T.ceildiv(N, group_size)), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(M_max, blk_m), T.ceildiv(N, group_size), BG, threads=128) as (bx, by, bz):
            row = bx
            row_g_id = by
            bg = bz
...
...
@@ -31,36 +33,32 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
            y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3")
            row_offset = T.alloc_fragment((1,), "int32")
            T.annotate_layout({
                y_local: T.Fragment(y_local.shape, forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32),
            })
            T.annotate_layout(
                {
                    y_local: T.Fragment(y_local.shape, forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32),
                }
            )
            row_offset[0] = 0
            for i in T.serial(bg):
                row_offset[0] += batch_sizes[i]
            T.copy(X[row_offset[0] + row * blk_m:row_offset[0] + (row + 1) * blk_m, row_g_id * group_size:(row_g_id + 1) * group_size], y_local)
            T.copy(
                X[row_offset[0] + row * blk_m : row_offset[0] + (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size],
                y_local,
            )
            T.reduce_absmax(y_local, y_amax_local, dim=1)
            for i in T.Parallel(blk_m):
                y_amax_local[i] = T.max(y_amax_local[i], 1e-4)
                y_s_local[i] = T.if_then_else(row * blk_m + i < batch_sizes[bg], y_amax_local[i] / fp8_max, 0)
                y_s_local[i] = T.if_then_else(row * blk_m + i < batch_sizes[bg], y_amax_local[i] / fp8_max, 0)
            for i, j in T.Parallel(blk_m, group_size):
                y_q_local[i, j] = T.clamp(y_local[i, j] / y_s_local[i], fp8_min, fp8_max)
            T.copy(y_q_local, y_q_local_fp8)
            for i, j in T.Parallel(blk_m, group_size):
                y_q_local_fp8[i, j] = T.if_then_else(row * blk_m + i < batch_sizes[bg], y_q_local[i, j], 0)
                y_q_local_fp8[i, j] = T.if_then_else(row * blk_m + i < batch_sizes[bg], y_q_local[i, j], 0)
            for i in T.Parallel(blk_m):
                X_amax[bg, row * blk_m + i, row_g_id] = y_s_local[i]
            T.copy(y_q_local_fp8, X_fp8[bg, row * blk_m:(row + 1) * blk_m, row_g_id * group_size:(row_g_id + 1) * group_size])
            T.copy(y_q_local_fp8, X_fp8[bg, row * blk_m : (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size])

    return group_per_split_token_cast
...
...
@@ -127,8 +125,7 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
        return x.squeeze(0) if remove_dim else x

    # Normal layout requires transposing
    aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
    aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
    aligned_x[:, :m, :] = x
    aligned_x = aligned_x[:, :m, :]
    return aligned_x.squeeze(0) if remove_dim else aligned_x
...
...
@@ -146,15 +143,17 @@ def ref_per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens
    x_fp8 = x_fp8.view(m, -1)[:, :n].contiguous()
    return x_fp8, (x_amax / 448.0).view(m, -1)


def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> \
        Tuple[torch.Tensor, torch.Tensor]:
def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # assert x.shape[0] == batch_sizes.sum()
    M_max = ceil_div(batch_sizes.max(), 128) * 128
    split_x = torch.split(x, batch_sizes.tolist(), dim=0)
    padded_x = [torch.nn.functional.pad(t, (0, 0, 0, M_max - t.shape[0])) for t in split_x]
    num_groups, m, n = batch_sizes.shape[0], M_max, x.shape[1]
    x_fp8 = (torch.empty((num_groups, m, n), device='cuda', dtype=torch.float8_e4m3fn),
             torch.empty((num_groups, m, n // 128), device='cuda', dtype=torch.float))
    x_fp8 = (
        torch.empty((num_groups, m, n), device="cuda", dtype=torch.float8_e4m3fn),
        torch.empty((num_groups, m, n // 128), device="cuda", dtype=torch.float),
    )
    for i in range(num_groups):
        x_fp8[0][i], x_fp8[1][i] = ref_per_token_cast_to_fp8(padded_x[i])
    x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
...
...
examples/cast/example_per_token_cast_to_fp8.py
View file @
29051439
...
...
@@ -13,8 +13,9 @@ def per_token_cast_to_fp8(M, N, blk_m):
    fp8_max = 448.0

    @T.prim_func
    def per_token_cast(X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), "float8_e4m3"),
                       X_amax: T.Tensor((M, T.ceildiv(N, group_size)), dtype)):
    def per_token_cast(
        X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), "float8_e4m3"), X_amax: T.Tensor((M, T.ceildiv(N, group_size)), dtype)
    ):
        with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (bx, by):
            row = bx
            row_g_id = by
...
...
@@ -24,16 +25,13 @@ def per_token_cast_to_fp8(M, N, blk_m):
            y_q_local = T.alloc_fragment((blk_m, group_size), dtype)
            y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3")
            T.annotate_layout({
                y_local: T.Fragment(y_local.shape, forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32),
            })
            T.annotate_layout(
                {
                    y_local: T.Fragment(y_local.shape, forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32),
                }
            )
            T.copy(X[row * blk_m:(row + 1) * blk_m, row_g_id * group_size:(row_g_id + 1) * group_size], y_local)
            T.copy(X[row * blk_m : (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size], y_local)
            T.reduce_absmax(y_local, y_amax_local, dim=1)
            for i in T.Parallel(blk_m):
                y_amax_local[i] = T.max(y_amax_local[i], 1e-4)
...
...
@@ -43,9 +41,7 @@ def per_token_cast_to_fp8(M, N, blk_m):
            T.copy(y_q_local, y_q_local_fp8)
            for i in T.Parallel(blk_m):
                X_amax[row * blk_m + i, row_g_id] = y_s_local[i]
            T.copy(y_q_local_fp8, X_fp8[row * blk_m:(row + 1) * blk_m, row_g_id * group_size:(row_g_id + 1) * group_size])
            T.copy(y_q_local_fp8, X_fp8[row * blk_m : (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size])

    return per_token_cast
...
...
@@ -105,8 +101,7 @@ def main(M=8192, N=8192, blk_m=8):
    from example_triton_cast_to_fp8 import per_token_group_quant_fp8

    def run_triton():
        x_fp8_triton_, x_amax_triton_ = per_token_group_quant_fp8(x, 128, 1e-4, dtype=torch.float8_e4m3fn, column_major_scales=False)
        x_fp8_triton_, x_amax_triton_ = per_token_group_quant_fp8(x, 128, 1e-4, dtype=torch.float8_e4m3fn, column_major_scales=False)
        return x_fp8_triton_, x_amax_triton_

    x_fp8_triton, x_amax_triton = run_triton()
...
...
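Both cast kernels above implement the same per-token, per-128-column-group FP8 scheme (scale = amax / 448, values clamped to the e4m3 range). A plain-PyTorch sketch of that scheme, with hypothetical sizes, shows how the quantized tensor and the scales recompose; it is an illustration, not the example's own reference:

import torch

m, n, group = 4, 256, 128            # hypothetical sizes
x = torch.randn(m, n)

x_grouped = x.view(m, -1, group)
amax = x_grouped.abs().amax(dim=2, keepdim=True).clamp(min=1e-4)
scale = amax / 448.0                 # one scale per (row, group)
x_fp8 = (x_grouped / scale).to(torch.float8_e4m3fn)

# Dequantizing recovers the input up to FP8 rounding error.
x_rec = x_fp8.float() * scale
print((x_rec - x_grouped).abs().max())   # small relative to the inputs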
examples/cast/example_triton_cast_to_fp8.py
View file @
29051439
...
...
@@ -128,9 +128,7 @@ def per_token_group_quant_fp8(
        Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
        scaling factor for quantization.
    """
    assert (x.shape[-1] % group_size == 0), (f"the last dimension of `x` {x.shape[-1]} must be divisible "
                                             f"by `group_size` {group_size}")
    assert x.shape[-1] % group_size == 0, f"the last dimension of `x` {x.shape[-1]} must be divisible by `group_size` {group_size}"
    assert x.stride(-1) == 1, "`x` groups must be contiguous"

    finfo = torch.finfo(dtype)
...
...
examples/cast/test_example_cast.py
View file @
29051439
...
...
@@ -4,8 +4,7 @@ import example_per_token_cast_to_fp8
def test_example_group_per_split_token_cast_to_fp8():
    example_group_per_split_token_cast_to_fp8.main(M=1024, N=1024, BG=2, blk_m=4, batch_sizes=[128, 896])
    example_group_per_split_token_cast_to_fp8.main(M=1024, N=1024, BG=2, blk_m=4, batch_sizes=[128, 896])


def test_example_per_token_cast_to_fp8():
...
...
examples/compile_flags/usecase.py
View file @
29051439
...
...
@@ -4,12 +4,11 @@ import tilelang.language as T
# @tilelang.jit(compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"])
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
...
...
@@ -36,8 +35,7 @@ block_K = 32
func = matmul(M, N, K, block_M, block_N, block_K)

jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags="-O3 --use_fast_math --expt-relaxed-constexpr")
jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags="-O3 --use_fast_math --expt-relaxed-constexpr")
# or jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"])
# or jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags=["-O3 --use_fast_math --expt-relaxed-constexpr"])
...
...
examples/conftest.py
View file @
29051439
...
...
@@ -33,12 +33,9 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config):
        "warnings",
        "error",
    }
    if (sum(len(terminalreporter.stats.get(k, [])) for k in known_types.difference({"skipped", "deselected"})) == 0):
    if sum(len(terminalreporter.stats.get(k, [])) for k in known_types.difference({"skipped", "deselected"})) == 0:
        terminalreporter.write_sep(
            "!",
            (f"Error: No tests were collected. "
             f"{dict(sorted((k, len(v)) for k, v in terminalreporter.stats.items()))}"),
            (f"Error: No tests were collected. {dict(sorted((k, len(v)) for k, v in terminalreporter.stats.items()))}"),
        )
        pytest.exit("No tests were collected.", returncode=5)
examples/convolution/example_convolution.py
View file @
29051439
...
...
@@ -14,7 +14,6 @@ def check_hopper():
def ref_program(stride, padding, dilation):
    def main(A, B):
        A = A.permute(0, 3, 1, 2)  # N, H, W, C -> N, C, H, W
        B = B.permute(3, 2, 0, 1)  # H, W, C, F -> F, C, H, W
...
...
@@ -26,22 +25,7 @@ def ref_program(stride, padding, dilation):
@tilelang.jit(out_idx=[2])
def convolution(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype="float16", accum_dtype="float"):
def convolution(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype="float16", accum_dtype="float"):
    KH, KW = K, K
    OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
    OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
...
...
@@ -51,13 +35,11 @@ def convolution(N,
    @T.prim_func
    def main(
            data: T.Tensor((N, H, W, C), dtype),
            kernel: T.Tensor((KH, KW, C, F), dtype),
            out: T.Tensor((N, OH, OW, F), dtype),
        data: T.Tensor((N, H, W, C), dtype),
        kernel: T.Tensor((KH, KW, C, F), dtype),
        out: T.Tensor((N, OH, OW, F), dtype),
    ):
        with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) as (bx, by):
        with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) as (bx, by):
            data_shared = T.alloc_shared((block_M, block_K), dtype)
            kernel_shared = T.alloc_shared((block_K, block_N), dtype)
            out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
...
...
@@ -66,11 +48,13 @@ def convolution(N,
            kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
            out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)

            T.annotate_layout({
                out_shared: tilelang.layout.make_swizzled_layout(out_shared),
                data_shared: tilelang.layout.make_swizzled_layout(data_shared),
                kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
            })
            T.annotate_layout(
                {
                    out_shared: tilelang.layout.make_swizzled_layout(out_shared),
                    data_shared: tilelang.layout.make_swizzled_layout(data_shared),
                    kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
                }
            )
            T.clear(out_local)
            for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
...
...
@@ -82,10 +66,8 @@ def convolution(N,
                    m = by * block_M + i
                    access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
                    access_w = m % OW * S + k // C % KW * D - P
                    in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W))
                    data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
                    in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W)
                    data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
                T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
                T.gemm(data_shared, kernel_shared, out_local)
...
...
@@ -97,15 +79,15 @@ def convolution(N,
def main(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument('--n', type=int, default=128, help='n')
    parser.add_argument('--c', type=int, default=128, help='c')
    parser.add_argument('--h', type=int, default=64, help='h')
    parser.add_argument('--w', type=int, default=64, help='w')
    parser.add_argument('--f', type=int, default=128, help='f')
    parser.add_argument('--k', type=int, default=3, help='k')
    parser.add_argument('--s', type=int, default=1, help='s')
    parser.add_argument('--d', type=int, default=1, help='d')
    parser.add_argument('--p', type=int, default=1, help='p')
    parser.add_argument("--n", type=int, default=128, help="n")
    parser.add_argument("--c", type=int, default=128, help="c")
    parser.add_argument("--h", type=int, default=64, help="h")
    parser.add_argument("--w", type=int, default=64, help="w")
    parser.add_argument("--f", type=int, default=128, help="f")
    parser.add_argument("--k", type=int, default=3, help="k")
    parser.add_argument("--s", type=int, default=1, help="s")
    parser.add_argument("--d", type=int, default=1, help="d")
    parser.add_argument("--p", type=int, default=1, help="p")
    args = parser.parse_args(argv)
    N, C, H, W, F, K, S, D, P = args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p
...
...
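The kernel above sizes its output with the usual convolution formula, OH = (H + 2*P - D*(K - 1) - 1) // S + 1; with the script's default arguments this is a same-size convolution, as a quick check shows:

H, W, K, S, D, P = 64, 64, 3, 1, 1, 1   # the script's defaults
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
print(OH, OW)  # 64 64 -> a 3x3 kernel with stride 1 and pad 1 keeps the spatial size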
examples/convolution/example_convolution_autotune.py
View file @
29051439
...
...
@@ -14,7 +14,6 @@ def check_hopper():
def ref_program(stride, padding, dilation):
    def main(A, B):
        A = A.permute(0, 3, 1, 2)  # N, H, W, C -> N, C, H, W
        B = B.permute(3, 2, 0, 1)  # H, W, C, F -> F, C, H, W
...
...
@@ -40,7 +39,8 @@ def get_configs():
            num_stages,
            thread_num,
            enable_rasterization,
        ))
        )
    )
    configs = [
        {
...
...
@@ -50,7 +50,8 @@ def get_configs():
            "num_stages": c[3],
            "thread_num": c[4],
            "enable_rasteration": c[5],  # keep param name for backward-compat
        } for c in _configs
        }
        for c in _configs
    ]
    return configs
...
...
@@ -64,53 +65,18 @@ def get_heuristic_config() -> dict:
    sm_version = sm_major * 10 + sm_minor
    print(f"CUDA device capability: {sm_version}")
    if sm_version in {80}:
        return {
            "block_M": 128,
            "block_N": 256,
            "block_K": 32,
            "num_stages": 2,
            "thread_num": 128,
            "enable_rasteration": True
        }
        return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 2, "thread_num": 128, "enable_rasteration": True}
    elif sm_version in {90}:
        return {
            "block_M": 128,
            "block_N": 256,
            "block_K": 64,
            "num_stages": 3,
            "thread_num": 256,
            "enable_rasteration": True
        }
        return {"block_M": 128, "block_N": 256, "block_K": 64, "num_stages": 3, "thread_num": 256, "enable_rasteration": True}
    else:
        return {
            "block_M": 128,
            "block_N": 256,
            "block_K": 32,
            "num_stages": 0,
            "thread_num": 128,
            "enable_rasteration": True
        }
        return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 0, "thread_num": 128, "enable_rasteration": True}


@tilelang.autotune(configs=get_configs())
@tilelang.jit(out_idx=[2])
def convolution(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype="float16", accum_dtype="float"):
def convolution(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype="float16", accum_dtype="float"):
    KH, KW = K, K
    OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
    OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
...
...
@@ -120,13 +86,11 @@ def convolution(N,
    @T.prim_func
    def main(
            data: T.Tensor((N, H, W, C), dtype),
            kernel: T.Tensor((KH, KW, C, F), dtype),
            out: T.Tensor((N, OH, OW, F), dtype),
        data: T.Tensor((N, H, W, C), dtype),
        kernel: T.Tensor((KH, KW, C, F), dtype),
        out: T.Tensor((N, OH, OW, F), dtype),
    ):
        with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=thread_num) as (bx, by):
        with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=thread_num) as (bx, by):
            data_shared = T.alloc_shared((block_M, block_K), dtype)
            kernel_shared = T.alloc_shared((block_K, block_N), dtype)
            out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
...
...
@@ -136,9 +100,11 @@ def convolution(N,
            out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)

            if is_hopper:
                T.annotate_layout({
                    out_shared: tilelang.layout.make_swizzled_layout(out_shared),
                })
                T.annotate_layout(
                    {
                        out_shared: tilelang.layout.make_swizzled_layout(out_shared),
                    }
                )
            T.clear(out_local)
            for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
...
...
@@ -150,10 +116,8 @@ def convolution(N,
                    m = by * block_M + i
                    access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
                    access_w = m % OW * S + k // C % KW * D - P
                    in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W))
                    data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
                    in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W)
                    data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
                T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
                T.gemm(data_shared, kernel_shared, out_local)
...
...
@@ -166,17 +130,19 @@ def convolution(N,
    return main


def main(n: int = 128, c: int = 128, h: int = 64, w: int = 64, f: int = 128, k: int = 3, s: int = 1, d: int = 1, p: int = 1, use_autotune: bool = False, with_roller: bool = True):
def main(
    n: int = 128,
    c: int = 128,
    h: int = 64,
    w: int = 64,
    f: int = 128,
    k: int = 3,
    s: int = 1,
    d: int = 1,
    p: int = 1,
    use_autotune: bool = False,
    with_roller: bool = True,
):
    N, C, H, W, F, K, S, D, P = n, c, h, w, f, k, s, d, p
    ref_prog = ref_program(S, P, D)
...
...
@@ -196,25 +162,16 @@ def main(n: int = 128,
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
    parser.add_argument('--n', type=int, default=128, help='n')
    parser.add_argument('--c', type=int, default=128, help='c')
    parser.add_argument('--h', type=int, default=64, help='h')
    parser.add_argument('--w', type=int, default=64, help='w')
    parser.add_argument('--f', type=int, default=128, help='f')
    parser.add_argument('--k', type=int, default=3, help='k')
    parser.add_argument('--s', type=int, default=1, help='s')
    parser.add_argument('--d', type=int, default=1, help='d')
    parser.add_argument('--p', type=int, default=1, help='p')
    parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune for matmul configs")
    parser.add_argument("--with_roller", action="store_true", default=True, help="Whether to enable BitBLAS roller for search space")
    parser.add_argument("--n", type=int, default=128, help="n")
    parser.add_argument("--c", type=int, default=128, help="c")
    parser.add_argument("--h", type=int, default=64, help="h")
    parser.add_argument("--w", type=int, default=64, help="w")
    parser.add_argument("--f", type=int, default=128, help="f")
    parser.add_argument("--k", type=int, default=3, help="k")
    parser.add_argument("--s", type=int, default=1, help="s")
    parser.add_argument("--d", type=int, default=1, help="d")
    parser.add_argument("--p", type=int, default=1, help="p")
    parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune for matmul configs")
    parser.add_argument("--with_roller", action="store_true", default=True, help="Whether to enable BitBLAS roller for search space")
    args = parser.parse_args()
    main(args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p, args.use_autotune, args.with_roller)
    main(args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p, args.use_autotune, args.with_roller)
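The get_configs() functions in these convolution and block-sparse GEMM examples follow the same pattern: enumerate the Cartesian product of tuning knobs with itertools.product and wrap each combination in a config dict for tilelang.autotune. A minimal sketch of the pattern (the knob values here are placeholders, not the examples' tuned search space):

import itertools

block_M = [64, 128]
block_N = [64, 128]
num_stages = [2, 3]

configs = [
    {"block_M": m, "block_N": n, "num_stages": s}
    for m, n, s in itertools.product(block_M, block_N, num_stages)
]
print(len(configs))  # 8 candidate configurations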
examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py
View file @
29051439
...
...
@@ -41,14 +41,13 @@ def tl_gemm(
    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
            scales_a: T.Tensor(Scales_A_shape, "float32"),
            scales_b: T.Tensor(Scales_B_shape, "float32"),
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
        scales_a: T.Tensor(Scales_A_shape, "float32"),
        scales_b: T.Tensor(Scales_B_shape, "float32"),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_shared = T.alloc_shared(C_shared_shape, out_dtype)
...
...
@@ -93,21 +92,18 @@ def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    m, n = x.shape
    x_view = x.view(m, -1, 128)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)


def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2
    m, n = x.shape
    x_padded = torch.zeros(ceildiv(m, 128) * 128, ceildiv(n, 128) * 128, dtype=x.dtype, device=x.device)
    x_padded = torch.zeros(ceildiv(m, 128) * 128, ceildiv(n, 128) * 128, dtype=x.dtype, device=x.device)
    x_padded[:m, :n] = x
    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))


def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype):
...
...
@@ -127,13 +123,14 @@ def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype):
            c_acc.zero_()
            for k in range(ceildiv(K, 128)):
                c = torch._scaled_mm(
                    A_fp8[i * 128:(i + 1) * 128, k * 128:(k + 1) * 128],
                    B_fp8[j * 128:(j + 1) * 128, k * 128:(k + 1) * 128].T,
                    A_fp8[i * 128 : (i + 1) * 128, k * 128 : (k + 1) * 128],
                    B_fp8[j * 128 : (j + 1) * 128, k * 128 : (k + 1) * 128].T,
                    scale_a=A_scales[i, k].view(128, 1).contiguous(),
                    scale_b=B_scales[j, k].view(1, 128).contiguous(),
                    out_dtype=torch.bfloat16)
                    out_dtype=torch.bfloat16,
                )
                c_acc += c.to(torch.float32)
            C[i * 128:(i + 1) * 128, j * 128:(j + 1) * 128] = c_acc.to(out_dtype)
            C[i * 128 : (i + 1) * 128, j * 128 : (j + 1) * 128] = c_acc.to(out_dtype)
    return C
...
...
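ref_deepgemm_fp8 above accumulates the per-128-column-slice partial products in float32 (the "2xAcc" idea: a low-precision GEMM per K block, higher-precision accumulation across blocks). The blocked-accumulation structure can be sketched without FP8 or a GPU by standing in a plain matmul for the scaled FP8 GEMM; this is an illustration of the structure only, not the reference itself:

import torch

M, N, K = 128, 128, 512              # hypothetical sizes
A = torch.randn(M, K)
B = torch.randn(N, K)

c_acc = torch.zeros(M, N, dtype=torch.float32)
for k in range(K // 128):
    a_blk = A[:, k * 128 : (k + 1) * 128]
    b_blk = B[:, k * 128 : (k + 1) * 128]
    # Stand-in for the per-block low-precision GEMM; the accumulation
    # across K blocks happens in float32, as in the reference above.
    c_acc += a_blk @ b_blk.T

print((c_acc - A @ B.T).abs().max())  # matches up to float32 rounding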