OpenDAS / tilelang · Commit 29051439 (unverified)

Authored Dec 12, 2025 by Lei Wang; committed by GitHub on Dec 12, 2025

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

Parent: e84b24bc
Changes: 426
Showing 20 changed files with 623 additions and 986 deletions (+623, -986)
examples/bitnet-1.58b/vllm_workspace/utils.py  +6 -17
examples/blocksparse_attention/block_sparse_attn_triton.py  +24 -47
examples/blocksparse_attention/example_tilelang_block_sparse_attn.py  +37 -44
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py  +101 -139
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py  +96 -144
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py  +93 -139
examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py  +53 -100
examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py  +52 -95
examples/blocksparse_attention/heuristic.py  +1 -2
examples/blocksparse_attention/test_example_blocksparse_attention.py  +4 -16
examples/blocksparse_gemm/example_blocksparse_gemm.py  +30 -39
examples/cast/example_group_per_split_token_cast_to_fp8.py  +25 -26
examples/cast/example_per_token_cast_to_fp8.py  +11 -16
examples/cast/example_triton_cast_to_fp8.py  +1 -3
examples/cast/test_example_cast.py  +1 -2
examples/compile_flags/usecase.py  +4 -6
examples/conftest.py  +2 -5
examples/convolution/example_convolution.py  +23 -41
examples/convolution/example_convolution_autotune.py  +46 -89
examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py  +13 -16
Too many changes to show. To preserve performance only 426 of 426+ files are displayed.
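Most of the line churn in the diffs below is mechanical: ruff format rewrites single-quoted strings to double quotes, replaces yapf-style aligned hanging indents and backslash continuations with black-style parenthesized wrapping, and honors the magic trailing comma. A minimal, hypothetical sketch of the two styles (illustrative only; the function and values are made up, not taken from this repository):

    # Hypothetical example illustrating the convention change; not code from the
    # tilelang repository itself.

    def make_config(name, block_m, block_n, num_stages, threads):
        # Bundle kernel launch parameters into a plain dict.
        return dict(name=name, block_m=block_m, block_n=block_n, num_stages=num_stages, threads=threads)

    # yapf-era call style, typical of the removed lines:
    #   cfg = make_config('gemm', block_m=64, block_n=64,
    #                     num_stages=1, threads=128)

    # ruff-format call style, typical of the added lines: double quotes,
    # parenthesized wrapping, and a trailing comma that keeps one argument per line.
    cfg = make_config(
        "gemm",
        block_m=64,
        block_n=64,
        num_stages=1,
        threads=128,
    )
    print(cfg)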
examples/bitnet-1.58b/vllm_workspace/utils.py

@@ -3,8 +3,7 @@ from typing import Dict, List, Tuple
 TokensText = Tuple[List[int], str]

 def check_outputs_equal(outputs_0_lst: List[TokensText], outputs_1_lst: List[TokensText],
                         name_0: str, name_1: str):
     """
     Compare the two sequences generated by different models,
     which should be equal.

@@ -15,19 +14,14 @@ def check_outputs_equal(outputs_0_lst: List[TokensText], outputs_1_lst: List[Tok
         output_ids_0, output_str_0 = outputs_0
         output_ids_1, output_str_1 = outputs_1

-        assert output_str_0 == output_str_1, (f"Test{prompt_idx}:"
-                                              f"\n{name_0}:\t{output_str_0!r}"
-                                              f"\n{name_1}:\t{output_str_1!r}")
-        assert output_ids_0 == output_ids_1, (f"Test{prompt_idx}:"
-                                              f"\n{name_0}:\t{output_str_0!r}"
-                                              f"\n{name_1}:\t{output_str_1!r}")
+        assert output_str_0 == output_str_1, f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}"
+        assert output_ids_0 == output_ids_1, f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}"

 TokensTextLogprobs = Tuple[List[int], str, List[Dict[int, float]]]

 def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs],
                          outputs_1_lst: List[TokensTextLogprobs], name_0: str, name_1: str):
     """
     Compare the logprobs of two sequences generated by different models,
     which should be similar but not necessarily equal.

@@ -41,16 +35,11 @@ def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs],
         # Loop through generated tokens.
         for idx, (output_id_0, output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):
             # If generated tokens don't match, then
             if output_id_0 != output_id_1:
                 # Each predicted token must be in top N logprobs of the other
-                assert output_id_0 in logprobs_1[idx], (f"Test{prompt_idx}:"
-                                                        f"\n{name_0}:\t{output_str_0!r}"
-                                                        f"\n{name_1}:\t{output_str_1!r}")
-                assert output_id_1 in logprobs_0[idx], (f"Test{prompt_idx}:"
-                                                        f"\n{name_0}:\t{output_str_0!r}"
-                                                        f"\n{name_1}:\t{output_str_1!r}")
+                assert output_id_0 in logprobs_1[idx], f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}"
+                assert output_id_1 in logprobs_0[idx], f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}"

                 # Break out since sequences will now diverge.
                 break
examples/blocksparse_attention/block_sparse_attn_triton.py

@@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
     bsz, num_head, downsample_len, _ = x.shape
     # N_CTX = downsample_len * BLOCK
     sparse_index = torch.topk(x, topk, dim=-1).indices
     dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
                             False, dtype=torch.bool, device=x.device)
     dense_mask.scatter_(-1, sparse_index, True)
     if use_dense_for_last_block:
         dense_mask[:, :, -2:, :] = True

@@ -56,7 +53,6 @@ def _fwd_kernel_inner(
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
     mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n)
     # print

@@ -73,8 +69,7 @@ def _fwd_kernel_inner(
     # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
     if LAST_K_BLOCK:
         qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0,
-                       float('-inf'))
+                       float("-inf"))
     m_ij = tl.maximum(m_i, tl.max(qk, 1))
     qk -= m_ij[:, None]

@@ -154,7 +149,7 @@ def _fwd_kernel(
     v_ptrs = V + off_v
     mask_ptrs = block_mask_ptr + start_m * stride_bmm
-    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
+    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
     l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
     acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

@@ -192,24 +187,12 @@ def _fwd_kernel(
     acc = acc * l_recip
     acc = acc.to(Out.dtype.element_ty)
     off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[
         None, :] * stride_od
     out_ptrs = Out + off_o
     tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX)


 def _forward(ctx,
              q, k, v, block_sparse_mask, sm_scale, BLOCK_M=64, BLOCK_N=64, num_warps=None, num_stages=1, out=None):
     assert q.shape[-1] == k.shape[-1] == v.shape[-1]
     assert k.shape[2] == v.shape[2]
     o = out if out is not None else torch.empty_like(q).contiguous()

@@ -254,7 +237,6 @@ def _forward(ctx,
 class _sparse_attention(torch.autograd.Function):

     @staticmethod
     def forward(ctx, q, k, v, block_sparse_dense, sm_scale):
         # shape constraints

@@ -278,9 +260,9 @@ def test_topk_sparse_attention():
     torch.manual_seed(0)
     # Create inputs
-    q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
-    k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
-    v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
+    q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
+    k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
+    v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
     sm_scale = 1.0 / (D_HEAD**0.5)
     # Create sparse mask (downsampled to block level)

@@ -288,9 +270,7 @@ def test_topk_sparse_attention():
     downsample_len = math.ceil(SEQ_LEN / downsample_factor)
     print("downsample_len", downsample_len)
     x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
-                       device='cuda', dtype=torch.bfloat16)
+                       device="cuda", dtype=torch.bfloat16)
     x_ds[:, :, :, 0] = 100
     print("x_ds.shape", x_ds.shape)
     block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)

@@ -302,22 +282,21 @@ def test_topk_sparse_attention():
     # Compute reference
     # Expand block mask to full attention matrix
-    full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda'))
+    full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda"))
     full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
     full_mask = full_mask & torch.tril(torch.ones_like(full_mask))  # Apply causal
     # PyTorch reference implementation
-    attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
-    attn = attn.masked_fill(~full_mask, float('-inf'))
+    attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
+    attn = attn.masked_fill(~full_mask, float("-inf"))
     attn = F.softmax(attn, dim=-1)
-    ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)
+    ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)
     # print("ref_output", ref_output)
     # print("triton_output", triton_output)
     # Verify accuracy
     assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \
         "Triton output doesn't match reference"
     print("Pass topk sparse attention test with qlen == klen")

@@ -329,9 +308,9 @@ def test_topk_sparse_attention_qlt_kl():
     torch.manual_seed(0)
     # Create inputs.
-    q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
-    k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
-    v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
+    q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
+    k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
+    v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
     # softmax scale
     sm_scale = 1.0 / (D_HEAD**0.5)

@@ -339,8 +318,7 @@ def test_topk_sparse_attention_qlt_kl():
     print("downsample_factor", downsample_factor)
     downsample_len = math.ceil(K_LEN / downsample_factor)  # number of blocks along one dimension
     print("downsample_len", downsample_len)
     x_ds = torch.randn(
-        BATCH, N_HEADS, downsample_len, downsample_len, device='cuda', dtype=torch.bfloat16)
+        BATCH, N_HEADS, downsample_len, downsample_len, device="cuda", dtype=torch.bfloat16)
     # Force the first column to be high so that the first block is always selected.
     x_ds[:, :, :, 0] = 100
     block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)

@@ -351,26 +329,25 @@ def test_topk_sparse_attention_qlt_kl():
     past_len = K_LEN - Q_LEN
-    attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
-    full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')).bool()
+    attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
+    full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")).bool()
     full_mask_full = full_mask_full[..., :K_LEN, :K_LEN]
     effective_mask = full_mask_full[..., past_len:K_LEN, :]  # shape: (B, H, Q_LEN, K_LEN)
     i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1)  # shape: (Q_LEN, 1)
     j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0)  # shape: (1, K_LEN)
-    causal_mask = (j_global <= i_global)  # shape: (Q_LEN, K_LEN)
+    causal_mask = j_global <= i_global  # shape: (Q_LEN, K_LEN)
     final_mask = effective_mask & causal_mask  # shape: (B, H, Q_LEN, K_LEN)
-    attn = attn.masked_fill(~final_mask, float('-inf'))
+    attn = attn.masked_fill(~final_mask, float("-inf"))
     attn = F.softmax(attn, dim=-1)
-    ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)
+    ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)
     # Verify accuracy.
     assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \
         "Triton output doesn't match reference when qlen < klen"
     print("Pass topk sparse attention test with qlen < klen")
examples/blocksparse_attention/example_tilelang_block_sparse_attn.py

@@ -10,10 +10,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
     bsz, num_head, downsample_len, _ = x.shape
     # N_CTX = downsample_len * BLOCK
     sparse_index = torch.topk(x, topk, dim=-1).indices
     dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
                             False, dtype=torch.bool, device=x.device)
     dense_mask.scatter_(-1, sparse_index, True)
     if use_dense_for_last_block:
         dense_mask[:, :, -2:, :] = True

@@ -30,15 +27,17 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
 @tilelang.jit(
     out_idx=[4], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    })
+    },
+)
 def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal):
     block_M = 64
     block_N = 64
     num_stages = 1
     threads = 128
     scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
     shape = [batch, heads, seq_len, dim]
     block_mask_shape = [batch, heads, downsample_len, downsample_len]

@@ -47,7 +46,6 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
     block_mask_dtype = "bool"

     def kernel_func(block_M, block_N, num_stages, threads):

         @T.macro
         def MMA0(
             K: T.Tensor(shape, dtype),

@@ -59,11 +57,10 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
             by: T.int32,
             bz: T.int32,
         ):
             T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
             if is_causal:
                 for i, j in T.Parallel(block_M, block_N):
                     acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                                  -T.infinity(acc_s.dtype))
             else:
                 T.clear(acc_s)
             T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

@@ -78,18 +75,18 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
             by: T.int32,
             bz: T.int32,
         ):
             T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
             T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

         @T.macro
         def Softmax(
             acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
             acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
             scores_max: T.FragmentBuffer([block_M], accum_dtype),
             scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
             scores_scale: T.FragmentBuffer([block_M], accum_dtype),
             scores_sum: T.FragmentBuffer([block_M], accum_dtype),
             logsum: T.FragmentBuffer([block_M], accum_dtype),
         ):
             T.copy(scores_max, scores_max_prev)
             T.fill(scores_max, -T.infinity(accum_dtype))

@@ -113,22 +110,21 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
         @T.macro
         def Rescale(
             acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
             scores_scale: T.FragmentBuffer([block_M], accum_dtype),
         ):
             for i, j in T.Parallel(block_M, dim):
                 acc_o[i, j] *= scores_scale[i]

         @T.prim_func
         def blocksparse_flashattn(
             Q: T.Tensor(shape, dtype),
             K: T.Tensor(shape, dtype),
             V: T.Tensor(shape, dtype),
             BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype),
             Output: T.Tensor(shape, dtype),
         ):
             with T.Kernel(
                 T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
                 Q_shared = T.alloc_shared([block_M, dim], dtype)
                 K_shared = T.alloc_shared([block_N, dim], dtype)
                 V_shared = T.alloc_shared([block_N, dim], dtype)

@@ -143,7 +139,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
                 logsum = T.alloc_fragment([block_M], accum_dtype)
                 block_mask = T.alloc_local([downsample_len], block_mask_dtype)

                 T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
                 T.fill(acc_o, 0)
                 T.fill(logsum, 0)
                 T.fill(scores_max, -T.infinity(accum_dtype))

@@ -152,20 +148,19 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
                     block_mask[vj] = BlockSparseMask[bz, by, bx, vj]

                 loop_range = (
                     T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
                         (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)
                 )
                 for k in T.Pipelined(loop_range, num_stages=num_stages):
                     if block_mask[k] != 0:
                         MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
                         Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
                                 scores_sum, logsum)
                         Rescale(acc_o, scores_scale)
                         MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
                 for i, j in T.Parallel(block_M, dim):
                     acc_o[i, j] /= logsum[i]
                 T.copy(acc_o, O_shared)
                 T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])

     return blocksparse_flashattn

@@ -180,18 +175,16 @@ def test_topk_sparse_attention():
     torch.manual_seed(0)
     # Create inputs
-    q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
-    k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
-    v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
+    q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
+    k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
+    v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
     sm_scale = 1.0 / (D_HEAD**0.5)

     # Create sparse mask (downsampled to block level)
     downsample_factor = BLOCK
     downsample_len = math.ceil(SEQ_LEN / downsample_factor)
     x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
-                       device='cuda', dtype=torch.bfloat16)
+                       device="cuda", dtype=torch.bfloat16)
     x_ds[:, :, :, 0] = 100
     block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)

@@ -202,15 +195,15 @@ def test_topk_sparse_attention():
     # Compute reference
     # Expand block mask to full attention matrix
-    full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda'))
+    full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda"))
     full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
     full_mask = full_mask & torch.tril(torch.ones_like(full_mask))  # Apply causal
     # PyTorch reference implementation
-    attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
-    attn = attn.masked_fill(~full_mask, float('-inf'))
+    attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
+    attn = attn.masked_fill(~full_mask, float("-inf"))
     attn = F.softmax(attn, dim=-1)
-    ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)
+    ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)
     print("ref_output", ref_output)
     print("tilelang_output", tilelang_output)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py

@@ -13,17 +13,20 @@ from heuristic import num_splits_heuristic
 def flashattn(batch, heads, heads_kv, dim, dim_v):
     scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
     dtype = "float16"
     accum_dtype = "float"
     kv_group_num = heads // heads_kv

     @tilelang.jit(
         out_idx=[-1], pass_configs={
             tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-        })
+        },
+    )
     def kernel_func(block_N, block_H, page_block_size, num_split, num_stages, threads, num_pages,
                     max_num_blocks_per_seq, max_selected_blocks):
         shape_q = [batch, heads, dim]
         shape_k = [num_pages, page_block_size, heads_kv, dim]
         shape_v = [num_pages, page_block_size, heads_kv, dim_v]

@@ -37,17 +40,16 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
         @T.macro
         def flash_attn_split(
             Q: T.Tensor(shape_q, dtype),
             K: T.Tensor(shape_k, dtype),
             V: T.Tensor(shape_v, dtype),
             block_indices: T.Tensor(shape_indices, "int32"),
             cache_seqlens: T.Tensor([batch], "int32"),
             block_table: T.Tensor(shape_block_table, "int32"),
             glse: T.Tensor([batch, heads, num_split], accum_dtype),
             Output_partial: T.Tensor(part_shape, accum_dtype),
         ):
             with T.Kernel(
                 batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
                 Q_shared = T.alloc_shared([block_H, dim], dtype)
                 K_shared = T.alloc_shared([block_N, dim], dtype)
                 V_shared = T.alloc_shared([block_N, dim_v], dtype)

@@ -67,7 +69,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
                 sid = bz
                 cur_kv_head = hid // (kv_group_num // valid_block_H)

                 T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared)
                 T.fill(acc_o, 0)
                 T.fill(logsum, 0)
                 T.fill(scores_max, -T.infinity(accum_dtype))

@@ -75,7 +77,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
                 num_blocks = max_selected_blocks
                 blocks_per_split = T.floordiv(num_blocks, num_split)
                 remaining_blocks = T.floormod(num_blocks, num_split)
-                loop_range = (blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0))
+                loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0)
                 start = blocks_per_split * sid + T.min(sid, remaining_blocks)
                 has_valid_block = False
                 for k in T.Pipelined(loop_range, num_stages=num_stages):

@@ -85,29 +87,20 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
                     block_table_idx = T.floordiv(logical_block_idx, block_ratio)
                     block_tile_idx = T.floormod(logical_block_idx, block_ratio)
                     physical_block_idx = block_table[bid, block_table_idx]
                     T.copy(
-                        K[physical_block_idx, block_tile_idx * block_N:(block_tile_idx + 1) * block_N, cur_kv_head, :], K_shared)
+                        K[physical_block_idx, block_tile_idx * block_N : (block_tile_idx + 1) * block_N, cur_kv_head, :], K_shared)
                     T.clear(acc_s)
                     T.gemm(
                         Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                     if k == 0:  # assume block_indices is sorted in reverse order, otherwise, remove this if condition
                         for i, j in T.Parallel(block_H, block_N):
                             acc_s[i, j] = T.if_then_else(
                                 logical_block_idx * block_N + j >= cache_seqlens[bid],
                                 -T.infinity(accum_dtype), acc_s[i, j]
                             )
                     T.copy(scores_max, scores_max_prev)
                     T.fill(scores_max, -T.infinity(accum_dtype))
                     T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                     for i in T.Parallel(block_H):
                         scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                         scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                     for i, j in T.Parallel(block_H, block_N):
                         acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                     T.reduce_sum(acc_s, scores_sum, dim=1)

@@ -116,10 +109,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
                     T.copy(acc_s, acc_s_cast)
                     for i, j in T.Parallel(block_H, dim_v):
                         acc_o[i, j] *= scores_scale[i]
                     T.copy(
-                        V[physical_block_idx, block_tile_idx * block_N:(block_tile_idx + 1) * block_N, cur_kv_head, :], V_shared)
+                        V[physical_block_idx, block_tile_idx * block_N : (block_tile_idx + 1) * block_N, cur_kv_head, :], V_shared)
                     T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
                 if has_valid_block:
                     for i, j in T.Parallel(block_H, dim_v):

@@ -138,9 +128,9 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
         @T.macro
         def combine(
             glse: T.Tensor([batch, heads, num_split], accum_dtype),
             Output_partial: T.Tensor(part_shape, accum_dtype),
             Output: T.Tensor(shape_o, dtype),
         ):
             with T.Kernel(heads, batch, threads=128) as (by, bz):
                 po_local = T.alloc_fragment([dim_v], accum_dtype)

@@ -151,17 +141,18 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
                 scale_local = T.alloc_local([1], accum_dtype)
                 max_split = T.alloc_local([1], "int32")
-                T.annotate_layout({
-                    lse_logsum_local:
-                        T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
-                })
+                T.annotate_layout(
+                    {
+                        lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
+                    }
+                )
                 T.clear(lse_logsum_local)
                 T.clear(o_accum_local)
                 lse_max_local[0] = -T.infinity(accum_dtype)
                 for k in T.serial(num_split):
                     lse_local_split[0] = glse[bz, by, k]
-                    if (lse_local_split[0] != 0):
+                    if lse_local_split[0] != 0:
                         max_split[0] = k
                         lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k])

@@ -183,18 +174,17 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
         @T.prim_func
         def main(
             Q: T.Tensor(shape_q, dtype),
             K: T.Tensor(shape_k, dtype),
             V: T.Tensor(shape_v, dtype),
             block_indices: T.Tensor(shape_indices, "int32"),
             cache_seqlens: T.Tensor([batch], "int32"),
             block_table: T.Tensor(shape_block_table, "int32"),
             glse: T.Tensor([batch, heads, num_split], accum_dtype),
             Output_partial: T.Tensor(part_shape, accum_dtype),
             Output: T.Tensor(shape_o, dtype),
         ):
             flash_attn_split(Q, K, V, block_indices, cache_seqlens, block_table, glse,
                              Output_partial)
             combine(glse, Output_partial, Output)

         return main

@@ -203,7 +193,6 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
 class SparseFlashAttn(torch.nn.Module):
     def __init__(self, batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, num_pages):
         super(SparseFlashAttn, self).__init__()
         self.batch = batch

@@ -249,18 +238,11 @@ class SparseFlashAttn(torch.nn.Module):
         num_sm = self.num_sm
-        num_split = num_splits_heuristic(
-            total_mblocks,
-            num_sm,
-            num_n_blocks,
-            num_m_blocks,
-            size_one_kv_head,
-            is_causal_or_local=True,
-            max_splits=128)
-        glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda')
-        output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device='cuda')
+        num_split = num_splits_heuristic(
+            total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128
+        )
+        glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda")
+        output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda")
         output = self.kernel(
             query,

@@ -275,14 +257,13 @@ class SparseFlashAttn(torch.nn.Module):
         return output


 def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_seqlens,
                             block_table, page_block_size, block_size):
     """
     Paged version of sparse attention reference implementation.
     Args:
         query: [batch, heads, dim]
         key_cache: [num_pages, page_block_size, heads_kv, dim]
         value_cache: [num_pages, page_block_size, heads_kv, dim]
         block_indices: [batch, heads_kv, max_selected_blocks] - logical block indices
         cache_seqlens: [batch] - actual sequence lengths

@@ -298,12 +279,8 @@ def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_
     # Reconstruct the full key and value tensors from paged cache
     max_cache_seqlen = max(cache_seqlens).item()
-    key_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim),
-                           dtype=key_cache.dtype,
-                           device=key_cache.device)
-    value_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim_v),
-                             dtype=value_cache.dtype,
-                             device=value_cache.device)
+    key_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim), dtype=key_cache.dtype, device=key_cache.device)
+    value_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim_v), dtype=value_cache.dtype, device=value_cache.device)

     # Reconstruct full tensors from paged cache using block_table
     for b in range(batch):

@@ -319,20 +296,14 @@ def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_
             actual_block_size = end_token - start_token

             # Copy from paged cache to full tensors
             key_full[b, :, start_token:end_token, :] = key_cache[
                 physical_block_idx, :actual_block_size, :, :].transpose(0, 1)
             value_full[b, :, start_token:end_token, :] = value_cache[physical_block_idx, :actual_block_size, :, :].transpose(0, 1)

     # Reshape query for grouped attention
     query = rearrange(
-        query, 'b (h g) d -> b g h d', g=num_head_groups)  # [batch_size, num_head_groups, heads_kv, dim]
+        query, "b (h g) d -> b g h d", g=num_head_groups)  # [batch_size, num_head_groups, heads_kv, dim]

     # Compute attention scores
     scores = einsum(
-        query, key_full, 'b g h d, b h s d -> b g h s')  # [batch_size, num_head_groups, heads_kv, seqlen_kv]
+        query, key_full, "b g h d, b h s d -> b g h s")  # [batch_size, num_head_groups, heads_kv, seqlen_kv]

     # Create sparse mask based on block_indices
     sparse_mask = torch.zeros_like(scores)

@@ -348,24 +319,23 @@ def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_
                 sparse_mask[b, :, h, start_pos:end_pos] = 1

     # Apply sparse mask
-    scores = scores.masked_fill(sparse_mask == 0, float('-inf'))
+    scores = scores.masked_fill(sparse_mask == 0, float("-inf"))

     # Apply causal mask based on actual sequence lengths
     range_len = torch.arange(scores.shape[-1], device=scores.device).unsqueeze(0)
     cache_seqlens_expanded = cache_seqlens.unsqueeze(1)
     pad_mask = range_len >= cache_seqlens_expanded
     pad_mask = pad_mask[:, None, None, :]
-    scores = scores.masked_fill(pad_mask, float('-inf'))
+    scores = scores.masked_fill(pad_mask, float("-inf"))

     # Compute attention weights
     attention = F.softmax(scores / scale, dim=-1)

     # Apply attention to values
     out = einsum(attention, value_full,
-                 'b g h s, b h s d -> b g h d')  # [batch_size, num_head_groups, heads_kv, dim]
+                 "b g h s, b h s d -> b g h d")  # [batch_size, num_head_groups, heads_kv, dim]

     # Reshape output back to original format
-    out = rearrange(out, 'b g h d -> b (h g) d')  # [batch_size, heads, dim]
+    out = rearrange(out, "b g h d -> b (h g) d")  # [batch_size, heads, dim]
     return out

@@ -373,17 +343,23 @@ def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_
 def ref_program_fa(query, kcache, vcache, cache_seqlens, block_table):
     # latency reference
     # from flash_attn_interface import flash_attn_with_kvcache # fa3
-    from flash_attn import flash_attn_with_kvcache  #fa2
+    from flash_attn import flash_attn_with_kvcache  # fa2
     query = query.unsqueeze(1)
     output = flash_attn_with_kvcache(
         query, kcache, vcache, cache_seqlens=cache_seqlens, block_table=block_table)
     output = output.squeeze(1)
     return output


 def main(args):
-    batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = (
-        args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v)
+    batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = (
+        args.batch,
+        args.heads,
+        args.heads_kv,
+        args.max_cache_seqlen,
+        args.dim,
+        args.dim_v,
+    )
     sparse_ratio = args.sparse_ratio
     block_N = args.block_N
     page_block_size = args.page_block_size

@@ -395,35 +371,30 @@ def main(args):
     dtype = torch.float16

     # Generate random inputs
-    Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda')
+    Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda")
     cache_seqlens = torch.randint(
-        max_cache_seqlen // 2, max_cache_seqlen + 1, (batch,), dtype=torch.int32, device='cuda')
+        max_cache_seqlen // 2, max_cache_seqlen + 1, (batch,), dtype=torch.int32, device="cuda")
     print("cache_seqlens: ", cache_seqlens)

-    K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda')
-    V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda')
+    K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda")
+    V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda")

     # Create paged KV cache
-    K_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim), dtype=dtype, device='cuda')
+    K_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim), dtype=dtype, device="cuda")
     V_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim_v),
-                          dtype=dtype, device='cuda')
+                          dtype=dtype, device="cuda")

     # Create block table and block indices for dense case (all blocks selected)
     max_num_blocks_per_seq = int(math.ceil(max_cache_seqlen / page_block_size))
     print("max_num_blocks_per_seq: ", max_num_blocks_per_seq)
-    block_table = torch.zeros((batch, max_num_blocks_per_seq), dtype=torch.int32, device='cuda')
+    block_table = torch.zeros((batch, max_num_blocks_per_seq), dtype=torch.int32, device="cuda")
     block_indices = torch.zeros((batch, heads_kv, max_selected_blocks),
-                                dtype=torch.int32, device='cuda')
+                                dtype=torch.int32, device="cuda")

     # Fill block table and block indices and cache
     # Create a pool of available physical blocks
     total_blocks_needed = sum(
         int(math.ceil(cache_seqlens[seq_idx].item() / page_block_size)) for seq_idx in range(batch))
     available_blocks = list(range(total_blocks_needed))
     import random
     random.seed(42)  # For reproducibility
     random.shuffle(available_blocks)

@@ -458,10 +429,8 @@ def main(args):
             actual_block_size = end_token - start_token

             # Copy K and V data to the paged cache
             K_cache[physical_block_idx, :actual_block_size, :, :] = K[seq_idx,
                                                                       start_token:end_token, :, :]
             V_cache[physical_block_idx, :actual_block_size, :, :] = V[seq_idx, start_token:end_token, :, :]

     # Fill block_indices for sparse attention
     # For dense case (verification), we select all blocks in reverse order

@@ -496,10 +465,9 @@ def main(args):
             remaining_blocks = [b for b in all_blocks if b not in selected_blocks]
             if remaining_blocks:
                 import random
                 random.seed(42)  # For reproducibility
                 additional_blocks = random.sample(
                     remaining_blocks, min(num_selected - recent_blocks, len(remaining_blocks)))
                 selected_blocks.extend(additional_blocks)

             # Sort selected blocks in reverse order (most recent first)

@@ -512,25 +480,20 @@ def main(args):
                 block_indices[seq_idx, head_idx, i] = -1

     # Initialize sparse attention module
     sparse_attn = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, page_block_size, block_N,
                                   num_blocks)
     output_sparse = sparse_attn.forward(Q, K_cache, V_cache, block_indices, cache_seqlens, block_table)

     import flash_attn  # noqa: F401

     output_ref_torch = ref_program_torch_paged(Q, K_cache, V_cache, block_indices, cache_seqlens,
                                                block_table, page_block_size, block_N)
     output_ref_fa = ref_program_fa(Q, K_cache, V_cache, cache_seqlens, block_table)

     # Check correctness
     if sparse_ratio == 0.0:
         max_diff = torch.max(torch.abs(output_sparse - output_ref_fa)).item()
         mean_diff = torch.mean(torch.abs(output_sparse - output_ref_fa)).item()
         assert torch.allclose(
             output_ref_fa, output_ref_torch, atol=1e-2), "Reference outputs do not match!"
     else:
         max_diff = torch.max(torch.abs(output_sparse - output_ref_torch)).item()
         mean_diff = torch.mean(torch.abs(output_sparse - output_ref_torch)).item()

@@ -574,16 +537,15 @@ def main(args):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=8, help='batch size')
-    parser.add_argument('--heads', type=int, default=32, help='heads')
+    parser.add_argument("--batch", type=int, default=8, help="batch size")
+    parser.add_argument("--heads", type=int, default=32, help="heads")
-    parser.add_argument('--heads_kv', type=int
,
default
=
8
,
help
=
'heads_kv'
)
parser
.
add_argument
(
"--heads_kv"
,
type
=
int
,
default
=
8
,
help
=
"heads_kv"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--max_cache_seqlen"
,
type
=
int
,
default
=
8192
,
help
=
"kvcache sequence length"
)
'--max_cache_seqlen'
,
type
=
int
,
default
=
8192
,
help
=
'kvcache sequence length'
)
parser
.
add_argument
(
"--dim"
,
type
=
int
,
default
=
128
,
help
=
"dim"
)
parser
.
add_argument
(
'--dim'
,
type
=
int
,
default
=
128
,
help
=
'dim'
)
parser
.
add_argument
(
"--dim_v"
,
type
=
int
,
default
=
128
,
help
=
"dim_v"
)
parser
.
add_argument
(
'--dim_v'
,
type
=
int
,
default
=
128
,
help
=
'dim_v'
)
parser
.
add_argument
(
"--sparse_ratio"
,
type
=
float
,
default
=
0.0
,
help
=
"sparse ratio"
)
parser
.
add_argument
(
'--sparse_ratio'
,
type
=
float
,
default
=
0.0
,
help
=
'sparse ratio'
)
parser
.
add_argument
(
"--block_N"
,
type
=
int
,
default
=
64
,
help
=
"block_N"
)
parser
.
add_argument
(
'--block_N'
,
type
=
int
,
default
=
64
,
help
=
'block_N'
)
parser
.
add_argument
(
"--page_block_size"
,
type
=
int
,
default
=
256
,
help
=
"block size of pages"
)
parser
.
add_argument
(
'--page_block_size'
,
type
=
int
,
default
=
256
,
help
=
'block size of pages'
)
parser
.
add_argument
(
"--num_pages"
,
type
=
int
,
default
=
1024
,
help
=
"total number of pages"
)
parser
.
add_argument
(
'--num_pages'
,
type
=
int
,
default
=
1024
,
help
=
'total number of pages'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
main
(
args
)
main
(
args
)
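The hunks above build the paged KV cache used by this test: each sequence's keys and values are copied into fixed-size physical blocks drawn from a shuffled pool, and block_table records which physical block backs each logical block of each sequence. The standalone sketch below (our illustration, with a hypothetical fill_paged_cache helper rather than code from the example) shows the same bookkeeping in plain PyTorch; a kernel can then gather any sequence's KV blocks through block_table regardless of where they physically sit in the cache.

# Illustrative sketch of the paged-KV bookkeeping above (hypothetical helper, not
# part of the tilelang example). Each sequence's tokens are copied into fixed-size
# physical blocks; block_table[b, i] stores the physical block that holds the i-th
# logical block of sequence b.
import math
import torch

def fill_paged_cache(K, V, cache_seqlens, page_block_size):
    batch, max_seqlen, heads_kv, dim = K.shape
    dim_v = V.shape[-1]
    max_blocks_per_seq = math.ceil(max_seqlen / page_block_size)
    total_blocks = sum(math.ceil(int(s) / page_block_size) for s in cache_seqlens)
    K_cache = torch.zeros((total_blocks, page_block_size, heads_kv, dim), dtype=K.dtype, device=K.device)
    V_cache = torch.zeros((total_blocks, page_block_size, heads_kv, dim_v), dtype=V.dtype, device=V.device)
    block_table = torch.full((batch, max_blocks_per_seq), -1, dtype=torch.int32, device=K.device)
    next_free = 0  # next unused physical block
    for b in range(batch):
        seqlen = int(cache_seqlens[b])
        for logical_blk in range(math.ceil(seqlen / page_block_size)):
            start = logical_blk * page_block_size
            end = min(start + page_block_size, seqlen)
            K_cache[next_free, :end - start] = K[b, start:end]
            V_cache[next_free, :end - start] = V[b, start:end]
            block_table[b, logical_blk] = next_free
            next_free += 1
    return K_cache, V_cache, block_table

if __name__ == "__main__":
    K = torch.randn(2, 10, 1, 4)
    V = torch.randn(2, 10, 1, 4)
    seqlens = torch.tensor([7, 10], dtype=torch.int32)
    K_cache, V_cache, table = fill_paged_cache(K, V, seqlens, page_block_size=4)
    print(table)  # physical block id per logical block, -1 where unused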
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py View file @ 29051439
...
@@ -10,17 +10,18 @@ from heuristic import num_splits_heuristic

 def flashattn(batch, heads, heads_kv, dim, dim_v):
     scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
     dtype = "float16"
     accum_dtype = "float"
     kv_group_num = heads // heads_kv

     @tilelang.jit(
         out_idx=[-1],
         pass_configs={
             tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-        })
+        },
+    )
     def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, max_selected_blocks):
         shape_q = [batch, heads, dim]
         shape_k = [batch, max_cache_seqlen, heads_kv, dim]
         shape_v = [batch, max_cache_seqlen, heads_kv, dim_v]
...
@@ -31,17 +32,16 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
         @T.macro
         def flash_attn_split(
             Q: T.Tensor(shape_q, dtype),
             K: T.Tensor(shape_k, dtype),
             V: T.Tensor(shape_v, dtype),
             block_indices: T.Tensor(shape_indices, "int32"),
             cache_seqlens: T.Tensor([batch], "int32"),
             # actual_num_blocks: T.Tensor([batch], "int32"),
             glse: T.Tensor([batch, heads, num_split], accum_dtype),
             Output_partial: T.Tensor(part_shape, accum_dtype),
         ):
             with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
                 Q_shared = T.alloc_shared([block_H, dim], dtype)
                 K_shared = T.alloc_shared([block_N, dim], dtype)
                 V_shared = T.alloc_shared([block_N, dim_v], dtype)
...
@@ -62,7 +62,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
                 sid = bz
                 cur_kv_head = hid // (kv_group_num // valid_block_H)
                 T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared)
                 T.fill(acc_o, 0)
                 T.fill(logsum, 0)
                 T.fill(scores_max, -T.infinity(accum_dtype))
...
@@ -70,7 +70,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
                 num_blocks = max_selected_blocks
                 blocks_per_split = T.floordiv(num_blocks, num_split)
                 remaining_blocks = T.floormod(num_blocks, num_split)
-                loop_range = (blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0))
+                loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0)
                 start = blocks_per_split * sid + T.min(sid, remaining_blocks)
                 has_valid_block = False
...
@@ -78,26 +78,18 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
                     i_s = block_indices[bid, cur_kv_head, start + k]
                     if i_s >= 0:
                         has_valid_block = True
                         T.copy(K[bid, i_s * block_N:(i_s + 1) * block_N, cur_kv_head, :], K_shared)
                         T.clear(acc_s)
                         T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                         if k == 0:  # assume block_indices is sorted in reverse order, otherwise, remove this if condition
                             for i, j in T.Parallel(block_H, block_N):
                                 acc_s[i, j] = T.if_then_else(i_s * block_N + j >= cache_seqlens[bid], -T.infinity(accum_dtype), acc_s[i, j])
                         T.copy(scores_max, scores_max_prev)
                         T.fill(scores_max, -T.infinity(accum_dtype))
                         T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                         for i in T.Parallel(block_H):
                             scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                             scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                         for i, j in T.Parallel(block_H, block_N):
                             acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                         T.reduce_sum(acc_s, scores_sum, dim=1)
...
@@ -106,7 +98,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
                         T.copy(acc_s, acc_s_cast)
                         for i, j in T.Parallel(block_H, dim_v):
                             acc_o[i, j] *= scores_scale[i]
                         T.copy(V[bid, i_s * block_N:(i_s + 1) * block_N, cur_kv_head, :], V_shared)
                         T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
                 if has_valid_block:
                     for i, j in T.Parallel(block_H, dim_v):
...
@@ -125,9 +117,9 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
         @T.macro
         def combine(
             glse: T.Tensor([batch, heads, num_split], accum_dtype),
             Output_partial: T.Tensor(part_shape, accum_dtype),
             Output: T.Tensor(shape_o, dtype),
         ):
             with T.Kernel(heads, batch, threads=128) as (by, bz):
                 po_local = T.alloc_fragment([dim_v], accum_dtype)
...
@@ -138,17 +130,18 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
                 scale_local = T.alloc_local([1], accum_dtype)
                 max_split = T.alloc_local([1], "int32")
-                T.annotate_layout({
-                    lse_logsum_local:
-                        T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
-                })
+                T.annotate_layout(
+                    {
+                        lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
+                    }
+                )
                 T.clear(lse_logsum_local)
                 T.clear(o_accum_local)
                 lse_max_local[0] = -T.infinity(accum_dtype)
                 for k in T.serial(num_split):
                     lse_local_split[0] = glse[bz, by, k]
-                    if (lse_local_split[0] != 0):
+                    if lse_local_split[0] != 0:
                         max_split[0] = k
                         lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k])
...
@@ -170,15 +163,15 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
         @T.prim_func
         def main(
             Q: T.Tensor(shape_q, dtype),
             K: T.Tensor(shape_k, dtype),
             V: T.Tensor(shape_v, dtype),
             block_indices: T.Tensor(shape_indices, "int32"),
             cache_seqlens: T.Tensor([batch], "int32"),
             # actual_num_blocks: T.Tensor([batch], "int32"),
             glse: T.Tensor([batch, heads, num_split], accum_dtype),
             Output_partial: T.Tensor(part_shape, accum_dtype),
             Output: T.Tensor(shape_o, dtype),
         ):
             # flash_attn_split(Q, K, V, block_indices, cache_seqlens, actual_num_blocks, glse, Output_partial)
             flash_attn_split(Q, K, V, block_indices, cache_seqlens, glse, Output_partial)
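flash_attn_split above writes one partial output per split into Output_partial and the corresponding log-sum-exp into glse; the combine macro then merges them. The merge is mathematically a softmax-weighted sum across splits, as in the small PyTorch sketch below (our illustration only; it ignores the kernel's base-2 exponentials and its skipping of empty splits).

# Sketch of the split-combine step (illustration, not the kernel's code): weight each
# split's partial output by exp(lse_k - max_lse), normalise across splits, and sum.
import torch

def combine_splits(glse, output_partial):
    # glse:           [batch, heads, num_split]         per-split log-sum-exp
    # output_partial: [batch, heads, num_split, dim_v]  per-split partial outputs
    lse_max = glse.max(dim=-1, keepdim=True).values  # numerical stabiliser
    weights = torch.exp(glse - lse_max)
    weights = weights / weights.sum(dim=-1, keepdim=True)  # normalise across splits
    return (weights.unsqueeze(-1) * output_partial).sum(dim=2)  # [batch, heads, dim_v]

if __name__ == "__main__":
    glse = torch.randn(2, 4, 3)
    partial = torch.randn(2, 4, 3, 8)
    print(combine_splits(glse, partial).shape)  # torch.Size([2, 4, 8])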
...
@@ -190,7 +183,6 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):

 class SparseFlashAttn(torch.nn.Module):
     def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size):
         super(SparseFlashAttn, self).__init__()
         self.batch = batch
...
@@ -209,7 +201,8 @@ class SparseFlashAttn(torch.nn.Module):
             num_stages=2,
             threads=128,
             max_cache_seqlen=T.dynamic("max_cache_seqlen"),
-            max_selected_blocks=T.dynamic("max_selected_blocks"))
+            max_selected_blocks=T.dynamic("max_selected_blocks"),
+        )
         props = torch.cuda.get_device_properties(torch.device("cuda:0"))
         self.num_sm = props.multi_processor_count
...
@@ -232,25 +225,17 @@ class SparseFlashAttn(torch.nn.Module):
         num_sm = self.num_sm
         num_split = num_splits_heuristic(
             total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128)
-        glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda')
-        output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device='cuda')
+        glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda")
+        output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda")
         output = self.kernel(query, key, value, block_indices, cache_seqlens, glse, output_partial)
         return output


 def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, block_size):
     """
     Args:
         query: [batch, heads, dim]
...
@@ -272,31 +257,24 @@ def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seql
     block_H = 64
     actual_num_blocks = torch.sum(block_indices != -1, dim=-1).to(torch.int32)
-    actual_num_blocks = actual_num_blocks[:, 0]  #[batch], number of valid blocks, assume all groups in the same batch have the same number of blocks
+    actual_num_blocks = actual_num_blocks[:, 0]  # [batch], number of valid blocks, assume all groups in the same batch have the same number of blocks
     # get num_split
     num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
-    num_n_blocks = max_selected_blocks  #(kv_seqlen + block_size - 1 ) // block_size
+    num_n_blocks = max_selected_blocks  # (kv_seqlen + block_size - 1 ) // block_size
     # num_n_blocks = torch.sum(actual_num_blocks, dim=-1).item() * heads_kv # total number of blocks
-    size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2  #kv_seqlen * (dim + dim_v) * 2
+    size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2  # kv_seqlen * (dim + dim_v) * 2
     total_mblocks = batch * heads_kv * num_m_blocks
     num_sm = 132
     num_split = num_splits_heuristic(
         total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128)
-    glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda')
-    Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device='cuda')
+    glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda")
+    Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda")
     kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
         block_N=block_size,
         block_H=block_H,
...
@@ -304,29 +282,24 @@ def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seql
         num_stages=2,
         threads=128,
         max_cache_seqlen=T.dynamic("max_cache_seqlen"),
-        max_selected_blocks=T.dynamic("max_selected_blocks"))
+        max_selected_blocks=T.dynamic("max_selected_blocks"),
+    )
     output = kernel(query, key, value, block_indices, cache_seqlens, glse, Output_partial)
     return output


 def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size):
     batch, heads, dim = query.shape
     heads_kv = key.shape[2]
     num_head_groups = query.shape[1] // key.shape[2]
     scale = dim**0.5
-    key = rearrange(key, 'b n h d -> b h n d')  # [batch_size, heads_kv, seqlen_kv, dim]
-    value = rearrange(value, 'b n h d -> b h n d')  # [batch_size, heads_kv, seqlen_kv, dim]
-    query = rearrange(query, 'b (h g) d -> b g h d', g=num_head_groups)  # [batch_size, num_head_groups, heads_kv, dim]
-    scores = einsum(query, key, 'b g h d, b h s d -> b g h s')  # [batch_size, num_head_groups, heads_kv, seqlen_kv]
+    key = rearrange(key, "b n h d -> b h n d")  # [batch_size, heads_kv, seqlen_kv, dim]
+    value = rearrange(value, "b n h d -> b h n d")  # [batch_size, heads_kv, seqlen_kv, dim]
+    query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups)  # [batch_size, num_head_groups, heads_kv, dim]
+    scores = einsum(query, key, "b g h d, b h s d -> b g h s")  # [batch_size, num_head_groups, heads_kv, seqlen_kv]
     sparse_mask = torch.zeros_like(scores)
     # Assign mask values based on block_indices
...
@@ -335,28 +308,26 @@ def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache
             valid_indices = block_indices[b, h]  # Extract indices for this batch and head
             for idx in valid_indices:
                 if idx >= 0:
                     sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1
-    scores = scores.masked_fill(sparse_mask == 0, float('-inf'))
-    range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0)
+    scores = scores.masked_fill(sparse_mask == 0, float("-inf"))
+    range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0)
     cache_seqlens_expanded = cache_seqlens.unsqueeze(1)
     pad_mask = range_len >= cache_seqlens_expanded
     pad_mask = pad_mask[:, None, None, :]
-    scores = scores.masked_fill(pad_mask, float('-inf'))
+    scores = scores.masked_fill(pad_mask, float("-inf"))
     attention = F.softmax(scores / scale, dim=-1)  # [batch_size, num_head_groups, heads_kv, seqlen_kv]
-    out = einsum(attention, value, 'b g h s, b h s d -> b g h d')  # [batch_size, num_head_groups, heads_kv, dim]
-    out = rearrange(out, 'b g h d -> b (h g) d')  # [batch_size, heads, dim]
+    out = einsum(attention, value, "b g h s, b h s d -> b g h d")  # [batch_size, num_head_groups, heads_kv, dim]
+    out = rearrange(out, "b g h d -> b (h g) d")  # [batch_size, heads, dim]
     return out


 def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size):
     # latency reference
     # from flash_attn_interface import flash_attn_with_kvcache # fa3
-    from flash_attn import flash_attn_with_kvcache  #fa2
+    from flash_attn import flash_attn_with_kvcache  # fa2
     query = query.unsqueeze(1)
     output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens)
     output = output.squeeze(1)
...
@@ -368,23 +339,13 @@ def debug(name, expect, actual, atol=1e-3, rtol=1e-3):
     print(name + " all_close={}".format(all_close))
     if not all_close:
         diff = (expect - actual).abs()
         print("all_close={}, max={}, min={}, mean={}".format(all_close, diff.max().item(), diff.min().item(), diff.mean().item()))
         max_indices = torch.nonzero(diff == diff.max().item())
         first_index = tuple(max_indices[0].tolist())
         print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}")


 def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32):
     batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v
     sparse_ratio = sparse_ratio
     block_size = block_size
...
@@ -392,10 +353,10 @@ def main(batch=8,
     print("max_selected_blocks: ", max_selected_blocks)
     dtype = torch.float16
-    Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda')
-    K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda')
-    V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda')
-    cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda')
+    Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda")
+    K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda")
+    V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda")
+    cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda")
     # cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda')
     # # Ensure at least one element equals cache_seqlen
     # random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index
...
@@ -406,10 +367,7 @@ def main(batch=8,
     max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int()
     print("max_valid_num_blocks: ", max_valid_num_blocks)
     # Initialize block_indices with -1 (for padding blocks)
-    block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32, device='cuda')
+    block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32, device="cuda")
     # max_num_blocks = int((max_cache_seqlen + block_size - 1)/ block_size)
     # block_indices = torch.full((batch, heads_kv, max_num_blocks), -1, dtype=torch.int32, device='cuda')
...
@@ -418,10 +376,9 @@ def main(batch=8,
         max_valid_block = max_valid_num_blocks[b].item()  # Max valid blocks for this batch
         if max_valid_block > 0:  # Ensure there's at least one valid block
             for h in range(heads_kv):
-                valid_indices = torch.randperm(max_valid_block, device='cuda', dtype=torch.int32)[:max_selected_blocks]
+                valid_indices = torch.randperm(max_valid_block, device="cuda", dtype=torch.int32)[:max_selected_blocks]
                 # valid_indices = torch.randperm(max_valid_block, device='cuda', dtype=torch.int32)[:max_num_blocks]
                 block_indices[b, h, :len(valid_indices)] = valid_indices
     # Sort indices within each batch-group for consistency
     block_indices, _ = block_indices.sort(dim=-1, descending=True)
...
@@ -434,8 +391,7 @@ def main(batch=8,
     print("max_num_blocks: ", max_num_blocks)
     # parity reference
     ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size)
     sparse_kernel = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size)
     out = sparse_kernel(Q, K, V, block_indices, cache_seqlens)
...
@@ -445,13 +401,11 @@ def main(batch=8,
     ## latency reference
     for _ in range(10):
         ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size)
     torch.cuda.synchronize()
     start = time.time()
     for _ in range(100):
         ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size)
     torch.cuda.synchronize()
     print("dense time: ", (time.time() - start) / 100 * 1000)
...
@@ -469,15 +423,13 @@ def main(batch=8,

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=8, help='batch size')
+    parser.add_argument("--batch", type=int, default=8, help="batch size")
-    parser.add_argument('--heads', type=int, default=32, help='heads')
+    parser.add_argument("--heads", type=int, default=32, help="heads")
-    parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv')
+    parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv")
-    parser.add_argument('--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length')
+    parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length")
-    parser.add_argument('--dim', type=int, default=128, help='dim')
+    parser.add_argument("--dim", type=int, default=128, help="dim")
-    parser.add_argument('--dim_v', type=int, default=128, help='dim_v')
+    parser.add_argument("--dim_v", type=int, default=128, help="dim_v")
-    parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio')
+    parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio")
-    parser.add_argument('--block_size', type=int, default=32, help='block_size')
+    parser.add_argument("--block_size", type=int, default=32, help="block_size")
     args = parser.parse_args()
     main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size)
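Throughout this example, each (batch, kv head) row of block_indices lists the selected KV blocks in descending order and is padded with -1, and ref_program_torch expands that selection into a token-level mask before running dense attention. A compact sketch of the convention (our illustration, mirroring the reference code rather than the kernel):

# Expand -1-padded block indices into a per-token boolean mask (illustration).
import torch

def block_indices_to_token_mask(block_indices, seqlen, block_size):
    # block_indices: [batch, heads_kv, max_selected_blocks], -1 marks padding
    batch, heads_kv, _ = block_indices.shape
    mask = torch.zeros(batch, heads_kv, seqlen, dtype=torch.bool, device=block_indices.device)
    for b in range(batch):
        for h in range(heads_kv):
            for idx in block_indices[b, h].tolist():
                if idx >= 0:
                    mask[b, h, idx * block_size:(idx + 1) * block_size] = True
    return mask

if __name__ == "__main__":
    idx = torch.tensor([[[3, 1, -1, -1]]], dtype=torch.int32)  # one batch, one kv head
    print(block_indices_to_token_mask(idx, seqlen=16, block_size=4).int())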
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py View file @ 29051439
...
@@ -12,15 +12,17 @@ from heuristic import num_splits_heuristic

 def flashattn(batch, heads, heads_kv, dim, dim_v):
     scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
     dtype = "float16"
     accum_dtype = "float"
     kv_group_num = heads // heads_kv

     @tilelang.jit(
         out_idx=[-1],
         pass_configs={
             tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-        })
+        },
+    )
     def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, num_blocks):
         shape_q = [batch, heads, dim]
         shape_k = [batch, max_cache_seqlen, heads_kv, dim]
...
@@ -32,16 +34,15 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
         @T.macro
         def flash_attn_split(
             Q: T.Tensor(shape_q, dtype),
             K: T.Tensor(shape_k, dtype),
             V: T.Tensor(shape_v, dtype),
             block_mask: T.Tensor(shape_mask, "bool"),
             cache_seqlens: T.Tensor([batch], "int32"),
             glse: T.Tensor([batch, heads, num_split], accum_dtype),
             Output_partial: T.Tensor(part_shape, accum_dtype),
         ):
             with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
                 Q_shared = T.alloc_shared([block_H, dim], dtype)
                 K_shared = T.alloc_shared([block_N, dim], dtype)
                 V_shared = T.alloc_shared([block_N, dim_v], dtype)
...
@@ -62,39 +63,31 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
                 sid = bz
                 cur_kv_head = hid // (kv_group_num // valid_block_H)
                 T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared)
                 T.fill(acc_o, 0)
                 T.fill(logsum, 0)
                 T.fill(scores_max, -T.infinity(accum_dtype))
                 blocks_per_split = T.floordiv(num_blocks, num_split)
                 remaining_blocks = T.floormod(num_blocks, num_split)
-                loop_range = (blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0))
+                loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0)
                 start = blocks_per_split * sid + T.min(sid, remaining_blocks)
                 has_valid_block = False
                 for k in T.Pipelined(loop_range, num_stages=num_stages):
                     if block_mask[bid, hid, start + k]:
                         has_valid_block = True
                         T.copy(K[bid, (start + k) * block_N:(start + k + 1) * block_N, cur_kv_head, :], K_shared)
                         T.clear(acc_s)
                         T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                         for i, j in T.Parallel(block_H, block_N):
                             acc_s[i, j] = T.if_then_else((start + k) * block_N + j >= cache_seqlens[bx], -T.infinity(accum_dtype), acc_s[i, j])
                         T.copy(scores_max, scores_max_prev)
                         T.fill(scores_max, -T.infinity(accum_dtype))
                         T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                         for i in T.Parallel(block_H):
                             scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                             scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                         for i, j in T.Parallel(block_H, block_N):
                             acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                         T.reduce_sum(acc_s, scores_sum, dim=1)
...
@@ -103,9 +96,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
                         T.copy(acc_s, acc_s_cast)
                         for i, j in T.Parallel(block_H, dim_v):
                             acc_o[i, j] *= scores_scale[i]
                         T.copy(V[bid, (start + k) * block_N:(start + k + 1) * block_N, cur_kv_head, :], V_shared)
                         T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
                 if has_valid_block:
                     for i, j in T.Parallel(block_H, dim_v):
...
@@ -123,9 +114,9 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
         @T.macro
         def combine(
             glse: T.Tensor([batch, heads, num_split], accum_dtype),
             Output_partial: T.Tensor(part_shape, accum_dtype),
             Output: T.Tensor(shape_o, dtype),
         ):
             with T.Kernel(heads, batch, threads=128) as (by, bz):
                 po_local = T.alloc_fragment([dim_v], accum_dtype)
...
@@ -135,10 +126,11 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
                 lse_max_local = T.alloc_local([1], accum_dtype)
                 scale_local = T.alloc_local([1], accum_dtype)
-                T.annotate_layout({
-                    lse_logsum_local:
-                        T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
-                })
+                T.annotate_layout(
+                    {
+                        lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
+                    }
+                )
                 T.clear(lse_logsum_local)
                 T.clear(o_accum_local)
...
@@ -161,14 +153,14 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
         @T.prim_func
         def main(
             Q: T.Tensor(shape_q, dtype),
             K: T.Tensor(shape_k, dtype),
             V: T.Tensor(shape_v, dtype),
             block_mask: T.Tensor(shape_mask, "bool"),
             cache_seqlens: T.Tensor([batch], "int32"),
             glse: T.Tensor([batch, heads, num_split], accum_dtype),
             Output_partial: T.Tensor(part_shape, accum_dtype),
             Output: T.Tensor(shape_o, dtype),
         ):
             flash_attn_split(Q, K, V, block_mask, cache_seqlens, glse, Output_partial)
             combine(glse, Output_partial, Output)
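Both the indices and the mask variants schedule work across splits the same way as shown above: blocks_per_split = num_blocks // num_split, the first num_blocks % num_split splits take one extra block, and split sid scans a contiguous range starting at blocks_per_split * sid + min(sid, remaining_blocks). A host-side sketch of that schedule (our illustration):

# Reproduce the per-split block ranges used by the kernels above (illustration).
def split_ranges(num_blocks, num_split):
    per_split, remaining = divmod(num_blocks, num_split)
    ranges = []
    for sid in range(num_split):
        count = per_split + (1 if sid < remaining else 0)  # loop_range in the kernel
        start = per_split * sid + min(sid, remaining)      # start in the kernel
        ranges.append((start, start + count))
    return ranges

if __name__ == "__main__":
    print(split_ranges(10, 4))  # [(0, 3), (3, 6), (6, 8), (8, 10)]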
...
@@ -179,7 +171,6 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):

 class SparseFlashAttn(torch.nn.Module):
     def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size):
         super(SparseFlashAttn, self).__init__()
         self.batch = batch
...
@@ -198,7 +189,8 @@ class SparseFlashAttn(torch.nn.Module):
             num_stages=2,
             threads=128,
             max_cache_seqlen=T.dynamic("max_cache_seqlen"),
-            num_blocks=T.dynamic("num_blocks"))
+            num_blocks=T.dynamic("num_blocks"),
+        )
         props = torch.cuda.get_device_properties(torch.device("cuda:0"))
         self.num_sm = props.multi_processor_count
...
@@ -217,24 +209,16 @@ class SparseFlashAttn(torch.nn.Module):
         num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
         num_n_blocks = max_selected_blocks
-        size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2  #kv_seqlen * (dim + dim_v) * 2
+        size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2  # kv_seqlen * (dim + dim_v) * 2
         total_mblocks = batch * heads_kv * num_m_blocks
         # num_sm = 132
         num_sm = self.num_sm
         num_split = num_splits_heuristic(
             total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128)
         # print("num_split: ", num_split)
-        glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda')
-        Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device='cuda')
+        glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda")
+        Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda")
         output = self.kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial)
         return output
...
@@ -259,26 +243,21 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens,
     block_H = 64
     actual_num_blocks = torch.sum(block_mask, dim=-1).to(torch.int32)
-    actual_num_blocks = actual_num_blocks[:, 0]  #[batch], number of valid blocks, assume all groups in the same batch have the same number of blocks
+    actual_num_blocks = actual_num_blocks[:, 0]  # [batch], number of valid blocks, assume all groups in the same batch have the same number of blocks
     max_selected_blocks = actual_num_blocks.max().item()
     # get num_split
     num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
-    num_n_blocks = max_selected_blocks  #(kv_seqlen + block_size - 1 ) // block_size
+    num_n_blocks = max_selected_blocks  # (kv_seqlen + block_size - 1 ) // block_size
     # num_n_blocks = torch.sum(actual_num_blocks, dim=-1).item() * heads_kv # total number of blocks
-    size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2  #kv_seqlen * (dim + dim_v) * 2
+    size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2  # kv_seqlen * (dim + dim_v) * 2
     total_mblocks = batch * heads_kv * num_m_blocks
     num_sm = 132
     num_split = num_splits_heuristic(
         total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128)
     kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
         block_N=block_size,
...
@@ -287,11 +266,10 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens,
         num_stages=2,
         threads=128,
         max_cache_seqlen=T.dynamic("max_cache_seqlen"),
-        num_blocks=T.dynamic("num_blocks"))
-    glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda')
-    Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device='cuda')
+        num_blocks=T.dynamic("num_blocks"),
+    )
+    glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda")
+    Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda")
     # print(kernel.get_kernel_source())
     output = kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial)
...
@@ -299,24 +277,18 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens,
     return output


 def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size):
     batch, heads, dim = query.shape
     heads_kv = key.shape[2]
     num_head_groups = query.shape[1] // key.shape[2]
     scale = dim**0.5
-    key = rearrange(key, 'b n h d -> b h n d')  # [batch_size, heads_kv, seqlen_kv, dim]
-    value = rearrange(value, 'b n h d -> b h n d')  # [batch_size, heads_kv, seqlen_kv, dim]
-    query = rearrange(query, 'b (h g) d -> b g h d', g=num_head_groups)  # [batch_size, num_head_groups, heads_kv, dim]
-    scores = einsum(query, key, 'b g h d, b h s d -> b g h s')  # [batch_size, num_head_groups, heads_kv, seqlen_kv]
+    key = rearrange(key, "b n h d -> b h n d")  # [batch_size, heads_kv, seqlen_kv, dim]
+    value = rearrange(value, "b n h d -> b h n d")  # [batch_size, heads_kv, seqlen_kv, dim]
+    query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups)  # [batch_size, num_head_groups, heads_kv, dim]
+    scores = einsum(query, key, "b g h d, b h s d -> b g h s")  # [batch_size, num_head_groups, heads_kv, seqlen_kv]
     sparse_mask = torch.zeros_like(scores)
     # Assign mask values
...
@@ -324,29 +296,27 @@ def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_se
         for h in range(heads_kv):
             for idx in range(num_blocks):
                 if block_mask[b, h, idx]:
                     sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1
-    scores = scores.masked_fill(sparse_mask == 0, float('-inf'))
-    range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0)
+    scores = scores.masked_fill(sparse_mask == 0, float("-inf"))
+    range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0)
     cache_seqlens_expanded = cache_seqlens.unsqueeze(1)
     pad_mask = range_len >= cache_seqlens_expanded
     pad_mask = pad_mask[:, None, None, :]
-    scores = scores.masked_fill(pad_mask, float('-inf'))
+    scores = scores.masked_fill(pad_mask, float("-inf"))
     attention = F.softmax(scores / scale, dim=-1)  # [batch_size, num_head_groups, heads_kv, seqlen_kv]
-    out = einsum(attention, value, 'b g h s, b h s d -> b g h d')  # [batch_size, num_head_groups, heads_kv, dim]
-    out = rearrange(out, 'b g h d -> b (h g) d')  # [batch_size, heads, dim]
+    out = einsum(attention, value, "b g h s, b h s d -> b g h d")  # [batch_size, num_head_groups, heads_kv, dim]
+    out = rearrange(out, "b g h d -> b (h g) d")  # [batch_size, heads, dim]
     return out


 def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size):
     # latency reference
     # from flash_attn_interface import flash_attn_with_kvcache # fa3
-    from flash_attn import flash_attn_with_kvcache  #fa2
+    from flash_attn import flash_attn_with_kvcache  # fa2
     query = query.unsqueeze(1)
     output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens)
     output = output.squeeze(1)
...
@@ -360,23 +330,13 @@ def debug(name, expect, actual, atol=1e-3, rtol=1e-3):
     # print(expect[3, 28])
     # print(actual[3, 28])
     diff = (expect - actual).abs()
     print("all_close={}, max={}, min={}, mean={}".format(all_close, diff.max().item(), diff.min().item(), diff.mean().item()))
     max_indices = torch.nonzero(diff == diff.max().item())
     first_index = tuple(max_indices[0].tolist())
     print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}")


 def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32):
     batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v
     sparse_ratio = sparse_ratio
     block_size = block_size
...
@@ -384,14 +344,13 @@ def main(batch=8,
     print("max_selected_blocks: ", max_selected_blocks)
     dtype = torch.float16
-    Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda')
-    K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda')
-    V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda')
-    cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda')
+    Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda")
+    K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda")
+    V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda")
+    cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda")
     # Ensure at least one element equals cache_seqlen
-    random_index = torch.randint(0, batch, (1,), device='cuda').item()  # Select a random index
+    random_index = torch.randint(0, batch, (1,), device="cuda").item()  # Select a random index
     cache_seqlens[random_index] = max_cache_seqlen  # Assign cache_seqlen to ensure at least one occurrence
     # cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda')
     print("cache_seqlens: ", cache_seqlens)
...
@@ -403,7 +362,7 @@ def main(batch=8,
     max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int()
     print("max_valid_num_blocks: ", max_valid_num_blocks)
     # Initialize block_mask with false (for padding blocks)
-    block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device='cuda')
+    block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device="cuda")
     # Assign valid indices while ensuring no duplicates within each batch-group
# Assign valid indices while ensuring no duplicates within each batch-group
for
b
in
range
(
batch
):
for
b
in
range
(
batch
):
...
@@ -411,13 +370,12 @@ def main(batch=8,
...
@@ -411,13 +370,12 @@ def main(batch=8,
valid_num_block
=
valid_num_blocks
[
b
].
item
()
# Valid blocks for this batch
valid_num_block
=
valid_num_blocks
[
b
].
item
()
# Valid blocks for this batch
if
valid_num_block
>
0
:
# Ensure there's at least one valid block
if
valid_num_block
>
0
:
# Ensure there's at least one valid block
for
h
in
range
(
heads_kv
):
for
h
in
range
(
heads_kv
):
perm
=
torch
.
randperm
(
max_valid_block
,
device
=
'
cuda
'
)[:
valid_num_block
]
perm
=
torch
.
randperm
(
max_valid_block
,
device
=
"
cuda
"
)[:
valid_num_block
]
block_mask
[
b
,
h
,
perm
]
=
True
block_mask
[
b
,
h
,
perm
]
=
True
# print("block_mask: ", block_mask)
# print("block_mask: ", block_mask)
# parity reference
# parity reference
ref
=
ref_program_torch
(
Q
,
K
,
V
,
block_mask
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
ref
=
ref_program_torch
(
Q
,
K
,
V
,
block_mask
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
block_size
)
block_size
)
# out = sparse_gqa_decode_varlen_mask(Q, K, V, block_mask, cache_seqlens, block_size)
# out = sparse_gqa_decode_varlen_mask(Q, K, V, block_mask, cache_seqlens, block_size)
model
=
SparseFlashAttn
(
batch
,
heads
,
heads_kv
,
dim
,
dim_v
,
block_size
)
model
=
SparseFlashAttn
(
batch
,
heads
,
heads_kv
,
dim
,
dim_v
,
block_size
)
out
=
model
(
Q
,
K
,
V
,
block_mask
,
cache_seqlens
)
out
=
model
(
Q
,
K
,
V
,
block_mask
,
cache_seqlens
)
...
@@ -427,13 +385,11 @@ def main(batch=8,
...
@@ -427,13 +385,11 @@ def main(batch=8,
## latency reference
## latency reference
for
_
in
range
(
10
):
for
_
in
range
(
10
):
ref
=
ref_program_fa
(
Q
,
K
,
V
,
block_mask
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
ref
=
ref_program_fa
(
Q
,
K
,
V
,
block_mask
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
block_size
)
block_size
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
start
=
time
.
time
()
start
=
time
.
time
()
for
_
in
range
(
100
):
for
_
in
range
(
100
):
ref
=
ref_program_fa
(
Q
,
K
,
V
,
block_mask
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
ref
=
ref_program_fa
(
Q
,
K
,
V
,
block_mask
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
block_size
)
block_size
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
print
(
"dense time: "
,
(
time
.
time
()
-
start
)
/
100
*
1000
)
print
(
"dense time: "
,
(
time
.
time
()
-
start
)
/
100
*
1000
)
...
@@ -452,15 +408,13 @@ def main(batch=8,
...
@@ -452,15 +408,13 @@ def main(batch=8,
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--batch'
,
type
=
int
,
default
=
8
,
help
=
'batch size'
)
parser
.
add_argument
(
"--batch"
,
type
=
int
,
default
=
8
,
help
=
"batch size"
)
parser
.
add_argument
(
'--heads'
,
type
=
int
,
default
=
32
,
help
=
'heads'
)
parser
.
add_argument
(
"--heads"
,
type
=
int
,
default
=
32
,
help
=
"heads"
)
parser
.
add_argument
(
'--heads_kv'
,
type
=
int
,
default
=
8
,
help
=
'heads_kv'
)
parser
.
add_argument
(
"--heads_kv"
,
type
=
int
,
default
=
8
,
help
=
"heads_kv"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--max_cache_seqlen"
,
type
=
int
,
default
=
8192
,
help
=
"kvcache sequence length"
)
'--max_cache_seqlen'
,
type
=
int
,
default
=
8192
,
help
=
'kvcache sequence length'
)
parser
.
add_argument
(
"--dim"
,
type
=
int
,
default
=
128
,
help
=
"dim"
)
parser
.
add_argument
(
'--dim'
,
type
=
int
,
default
=
128
,
help
=
'dim'
)
parser
.
add_argument
(
"--dim_v"
,
type
=
int
,
default
=
128
,
help
=
"dim_v"
)
parser
.
add_argument
(
'--dim_v'
,
type
=
int
,
default
=
128
,
help
=
'dim_v'
)
parser
.
add_argument
(
"--sparse_ratio"
,
type
=
float
,
default
=
0.8
,
help
=
"sparse ratio"
)
parser
.
add_argument
(
'--sparse_ratio'
,
type
=
float
,
default
=
0.8
,
help
=
'sparse ratio'
)
parser
.
add_argument
(
"--block_size"
,
type
=
int
,
default
=
32
,
help
=
"block_size"
)
parser
.
add_argument
(
'--block_size'
,
type
=
int
,
default
=
32
,
help
=
'block_size'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
main
(
args
.
batch
,
args
.
heads
,
args
.
heads_kv
,
args
.
max_cache_seqlen
,
args
.
dim
,
args
.
dim_v
,
main
(
args
.
batch
,
args
.
heads
,
args
.
heads_kv
,
args
.
max_cache_seqlen
,
args
.
dim
,
args
.
dim_v
,
args
.
sparse_ratio
,
args
.
block_size
)
args
.
sparse_ratio
,
args
.
block_size
)
examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py  View file @ 29051439
...
@@ -12,12 +12,8 @@ from heuristic import num_splits_heuristic
@triton.autotune(
    configs=[
-        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
-        for num_warps in [1, 2, 4] \
-        for num_stages in [1, 2, 3, 4, 7]
-    ],
-    key=['BLOCK_H', 'BLOCK_N', 'BLOCK_D'],
+        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
+        for num_warps in [1, 2, 4]
+        for num_stages in [1, 2, 3, 4, 7]],
+    key=["BLOCK_H", "BLOCK_N", "BLOCK_D"],
)
@triton.jit
def _split_kernel(
...
@@ -79,16 +75,11 @@ def _split_kernel(
    loop_range = blocks_per_split
    q_ptr += batch_idx * stride_q_b + head_idx_q * stride_q_h
    k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[
        None, :] * stride_k_s + offs_d[:, None] * stride_k_d
    v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:, None] * stride_v_s + offs_d[None, :] * stride_v_d
    mask_ptr += batch_idx * stride_mask_b + head_idx_kv * stride_mask_h
    q = tl.load(
        q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d, mask=offs_h[:, None] < gqa_group_size)
    start = blocks_per_split * split_idx + tl.minimum(split_idx, remaining_blocks)
    for i in range(loop_range):
        block_idx = tl.load(mask_ptr + (start + i) * stride_mask_s)
...
@@ -119,23 +110,18 @@ def _split_kernel(
    acc = acc * l_recip
    acc = acc.to(o_partial_ptr.dtype.element_ty)
    lse_partial_ptr += batch_idx * stride_lse_b + (
        head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split
    tl.store(lse_partial_ptr, m_i, mask=offs_h < gqa_group_size)
-    o_partial_ptr += batch_idx * stride_o_b + (
-        head_idx_q + offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d
+    o_partial_ptr += (
+        batch_idx * stride_o_b + (head_idx_q + offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d
+    )
    tl.store(o_partial_ptr, acc, mask=offs_h[:, None] < gqa_group_size)


@triton.autotune(
    configs=[
-        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
-        for num_warps in [1, 2, 4] \
-        for num_stages in [1, 2, 3, 4, 7]
-    ],
-    key=['BLOCK_D'],
+        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
+        for num_warps in [1, 2, 4]
+        for num_stages in [1, 2, 3, 4, 7]],
+    key=["BLOCK_D"],
)
@triton.jit
def _merge_kernel(
...
@@ -163,18 +149,15 @@ def _merge_kernel(
    offs_d = tl.arange(0, BLOCK_D)
    lse_offsets = lse_partial_ptr + batch_idx * lse_partial_stride_b + head_idx * lse_partial_stride_h
    lse = tl.load(
        lse_offsets + offs_splits * lse_partial_stride_split, mask=offs_splits < num_splits, other=float("-inf"))
    lse_max = tl.max(lse)
    o_offsets = o_partial_ptr + batch_idx * o_partial_stride_b + head_idx * o_partial_stride_h
    o_partial = tl.load(
        o_offsets + offs_splits[:, None] * o_partial_stride_split +
        offs_d[None, :] * o_partial_stride_d,
        mask=offs_splits[:, None] < num_splits,
    )
    sumexp_normalized_splitk = tl.exp(lse - lse_max)
    sumexp_normalized = tl.sum(sumexp_normalized_splitk, axis=0)
    numerator_normalized = tl.sum(o_partial * sumexp_normalized_splitk[:, None], axis=0)
...
@@ -209,19 +192,13 @@ def block_sparse_flash_decode_gqa_indice_triton(
    num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
    num_n_blocks = max_selected_blocks
    size_one_kv_head = max_selected_blocks * block_size * (
-        dim + dim_v) * 2  #kv_seqlen * (dim + dim_v) * 2
+        dim + dim_v) * 2  # kv_seqlen * (dim + dim_v) * 2
    total_mblocks = batch * heads_kv * num_m_blocks
    num_sm = 64
    # num_sm = self.num_sm
    num_splits = num_splits_heuristic(
        total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128
    )
    # print("num_splits:", num_splits, "num_blocks:", num_n_blocks)
...
@@ -295,24 +272,18 @@ def block_sparse_flash_decode_gqa_indice_triton(
    return output


def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks,
                      block_size):
    batch, heads, dim = query.shape
    heads_kv = key.shape[2]
    dim_v = value.shape[-1]
    num_head_groups = query.shape[1] // key.shape[2]
    scale = dim**0.5
-    key = rearrange(key, 'b n h d -> b h n d')  # [batch_size, heads_kv, seqlen_kv, dim]
-    value = rearrange(value, 'b n h d -> b h n d')  # [batch_size, heads_kv, seqlen_kv, dim]
+    key = rearrange(key, "b n h d -> b h n d")  # [batch_size, heads_kv, seqlen_kv, dim]
+    value = rearrange(value, "b n h d -> b h n d")  # [batch_size, heads_kv, seqlen_kv, dim]
    query = rearrange(
-        query, 'b (h g) d -> b g h d', g=num_head_groups)  # [batch_size, num_head_groups, heads_kv, dim]
+        query, "b (h g) d -> b g h d", g=num_head_groups)  # [batch_size, num_head_groups, heads_kv, dim]
    scores = einsum(
-        query, key, 'b g h d, b h s d -> b g h s')  # [batch_size, num_head_groups, heads_kv, seqlen_kv]
+        query, key, "b g h d, b h s d -> b g h s")  # [batch_size, num_head_groups, heads_kv, seqlen_kv]
    sparse_mask = torch.zeros_like(scores)
    # Assign mask values based on block_indices
...
@@ -321,42 +292,33 @@ def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache
        valid_indices = block_indices[b, h]  # Extract indices for this batch and head
        for idx in valid_indices:
            if idx >= 0:
                sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1
-    scores = scores.masked_fill(sparse_mask == 0, float('-inf'))
-    range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0)
+    scores = scores.masked_fill(sparse_mask == 0, float("-inf"))
+    range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0)
    cache_seqlens_expanded = cache_seqlens.unsqueeze(1)
    pad_mask = range_len >= cache_seqlens_expanded
    pad_mask = pad_mask[:, None, None, :]
-    scores = scores.masked_fill(pad_mask, float('-inf'))
+    scores = scores.masked_fill(pad_mask, float("-inf"))
    attention = F.softmax(
        scores / scale, dim=-1)  # [batch_size, num_head_groups, heads_kv, seqlen_kv]
-    out = einsum(attention, value,
-                 'b g h s, b h s d -> b g h d')  # [batch_size, num_head_groups, heads_kv, dim]
-    out = rearrange(out, 'b g h d -> b (h g) d')  # [batch_size, heads, dim]
+    out = einsum(attention, value,
+                 "b g h s, b h s d -> b g h d")  # [batch_size, num_head_groups, heads_kv, dim]
+    out = rearrange(out, "b g h d -> b (h g) d")  # [batch_size, heads, dim]
    return out


def ref_program_fa(query, key, value, cache_seqlens):
    # latency reference
    # from flash_attn_interface import flash_attn_with_kvcache # fa3
-    from flash_attn import flash_attn_with_kvcache  #fa2
+    from flash_attn import flash_attn_with_kvcache  # fa2
    query = query.unsqueeze(1)
    output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens)
    output = output.squeeze(1)
    return output


def main(batch=64,
         heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32):
    batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v
    sparse_ratio = sparse_ratio
    block_size = block_size
...
@@ -369,34 +331,29 @@ def main(batch=64,
    dtype = torch.float16
    block_H = 64
-    Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda')
-    K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda')
-    V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda')
-    cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda')
+    Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda")
+    K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda")
+    V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda")
+    cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda")
    # cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda')
    # Ensure at least one element equals cache_seqlen
-    random_index = torch.randint(0, batch, (1,), device='cuda').item()  # Select a random index
+    random_index = torch.randint(0, batch, (1,), device="cuda").item()  # Select a random index
    cache_seqlens[
        random_index] = max_cache_seqlen  # Assign cache_seqlen to ensure at least one occurrence
    print("cache_seqlens: ", cache_seqlens)
    max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int()
    print("max_valid_num_blocks: ", max_valid_num_blocks)
    # Initialize block_indices with -1 (for padding blocks)
    block_indices = torch.full((batch, heads_kv, max_selected_blocks),
-                               -1, dtype=torch.int32, device='cuda')
+                               -1, dtype=torch.int32, device="cuda")
    # Assign valid indices while ensuring no duplicates within each batch-group
    for b in range(batch):
        max_valid_block = max_valid_num_blocks[b].item()  # Max valid blocks for this batch
        if max_valid_block > 0:  # Ensure there's at least one valid block
            for h in range(heads_kv):
                valid_indices = torch.randperm(
-                    max_valid_block, device='cuda', dtype=torch.int32)[:max_selected_blocks]
+                    max_valid_block, device="cuda", dtype=torch.int32)[:max_selected_blocks]
                block_indices[b, h, :len(valid_indices)] = valid_indices
    # Sort indices within each batch-group for consistency
    block_indices, _ = block_indices.sort(dim=-1, descending=True)
...
@@ -408,8 +365,7 @@ def main(batch=64,
    max_num_blocks = torch.max(max_valid_num_blocks).item()
    print("max_num_blocks: ", max_num_blocks)
    ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks,
                            block_size)
    triton_out = block_sparse_flash_decode_gqa_indice_triton(
        Q,
...
@@ -423,8 +379,7 @@ def main(batch=64,
    )
    print("max difference: ", torch.max(torch.abs(ref - triton_out)))
    assert torch.allclose(
        ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation"
    print("Passed the ref test!")
    # Measure performance
...
@@ -466,15 +421,13 @@ def main(batch=64,
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=64, help='batch size')
-    parser.add_argument('--heads', type=int, default=32, help='heads')
-    parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv')
-    parser.add_argument(
-        '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length')
-    parser.add_argument('--dim', type=int, default=128, help='dim')
-    parser.add_argument('--dim_v', type=int, default=128, help='dim_v')
-    parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio')
-    parser.add_argument('--block_size', type=int, default=32, help='block_size')
+    parser.add_argument("--batch", type=int, default=64, help="batch size")
+    parser.add_argument("--heads", type=int, default=32, help="heads")
+    parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv")
+    parser.add_argument(
+        "--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length")
+    parser.add_argument("--dim", type=int, default=128, help="dim")
+    parser.add_argument("--dim_v", type=int, default=128, help="dim_v")
+    parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio")
+    parser.add_argument("--block_size", type=int, default=32, help="block_size")
    args = parser.parse_args()
    main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v,
         args.sparse_ratio, args.block_size)
examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py  View file @ 29051439
...
@@ -11,12 +11,8 @@ from heuristic import num_splits_heuristic
@triton.autotune(
    configs=[
-        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
-        for num_warps in [1, 2, 4] \
-        for num_stages in [1, 2, 3, 4, 7]
-    ],
-    key=['BLOCK_H', 'BLOCK_N', 'BLOCK_D'],
+        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
+        for num_warps in [1, 2, 4]
+        for num_stages in [1, 2, 3, 4, 7]],
+    key=["BLOCK_H", "BLOCK_N", "BLOCK_D"],
)
@triton.jit
def _split_kernel(
...
@@ -77,16 +73,11 @@ def _split_kernel(
    loop_range = blocks_per_split
    q_ptr += batch_idx * stride_q_b + head_idx_q * stride_q_h
    k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[
        None, :] * stride_k_s + offs_d[:, None] * stride_k_d
    v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:, None] * stride_v_s + offs_d[None, :] * stride_v_d
    mask_ptr += batch_idx * stride_mask_b + head_idx_kv * stride_mask_h
    q = tl.load(
        q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d, mask=offs_h[:, None] < gqa_group_size)
    start = blocks_per_split * split_idx + tl.minimum(split_idx, remaining_blocks)
    for block_idx in range(loop_range):
        start_n = (start + block_idx) * BLOCK_N
...
@@ -117,23 +108,18 @@ def _split_kernel(
    acc = acc * l_recip
    acc = acc.to(o_partial_ptr.dtype.element_ty)
    lse_partial_ptr += batch_idx * stride_lse_b + (
        head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split
    tl.store(lse_partial_ptr, m_i, mask=offs_h < gqa_group_size)
-    o_partial_ptr += batch_idx * stride_o_b + (
-        head_idx_q + offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d
+    o_partial_ptr += (
+        batch_idx * stride_o_b + (head_idx_q + offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d
+    )
    tl.store(o_partial_ptr, acc, mask=offs_h[:, None] < gqa_group_size)


@triton.autotune(
    configs=[
-        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
-        for num_warps in [1, 2, 4] \
-        for num_stages in [1, 2, 3, 4, 7]
-    ],
-    key=['BLOCK_D'],
+        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
+        for num_warps in [1, 2, 4]
+        for num_stages in [1, 2, 3, 4, 7]],
+    key=["BLOCK_D"],
)
@triton.jit
def _merge_kernel(
...
@@ -161,18 +147,15 @@ def _merge_kernel(
    offs_d = tl.arange(0, BLOCK_D)
    lse_offsets = lse_partial_ptr + batch_idx * lse_partial_stride_b + head_idx * lse_partial_stride_h
    lse = tl.load(
        lse_offsets + offs_splits * lse_partial_stride_split, mask=offs_splits < num_splits, other=float("-inf"))
    lse_max = tl.max(lse)
    o_offsets = o_partial_ptr + batch_idx * o_partial_stride_b + head_idx * o_partial_stride_h
    o_partial = tl.load(
        o_offsets + offs_splits[:, None] * o_partial_stride_split +
        offs_d[None, :] * o_partial_stride_d,
        mask=offs_splits[:, None] < num_splits,
    )
    sumexp_normalized_splitk = tl.exp(lse - lse_max)
    sumexp_normalized = tl.sum(sumexp_normalized_splitk, axis=0)
    numerator_normalized = tl.sum(o_partial * sumexp_normalized_splitk[:, None], axis=0)
...
@@ -207,19 +190,13 @@ def block_sparse_flash_decode_gqa_mask_triton(
    num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
    num_n_blocks = max_selected_blocks
    size_one_kv_head = max_selected_blocks * block_size * (
-        dim + dim_v) * 2  #kv_seqlen * (dim + dim_v) * 2
+        dim + dim_v) * 2  # kv_seqlen * (dim + dim_v) * 2
    total_mblocks = batch * heads_kv * num_m_blocks
    num_sm = 64
    # num_sm = self.num_sm
    num_splits = num_splits_heuristic(
        total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128
    )
    # print("num_splits:", num_splits, "num_blocks:", num_n_blocks)
...
@@ -292,24 +269,18 @@ def block_sparse_flash_decode_gqa_mask_triton(
    return output


def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks,
                      block_size):
    batch, heads, dim = query.shape
    heads_kv = key.shape[2]
    num_head_groups = query.shape[1] // key.shape[2]
    scale = dim**0.5
-    key = rearrange(key, 'b n h d -> b h n d')  # [batch_size, heads_kv, seqlen_kv, dim]
-    value = rearrange(value, 'b n h d -> b h n d')  # [batch_size, heads_kv, seqlen_kv, dim]
+    key = rearrange(key, "b n h d -> b h n d")  # [batch_size, heads_kv, seqlen_kv, dim]
+    value = rearrange(value, "b n h d -> b h n d")  # [batch_size, heads_kv, seqlen_kv, dim]
    query = rearrange(
-        query, 'b (h g) d -> b g h d', g=num_head_groups)  # [batch_size, num_head_groups, heads_kv, dim]
+        query, "b (h g) d -> b g h d", g=num_head_groups)  # [batch_size, num_head_groups, heads_kv, dim]
    scores = einsum(
-        query, key, 'b g h d, b h s d -> b g h s')  # [batch_size, num_head_groups, heads_kv, seqlen_kv]
+        query, key, "b g h d, b h s d -> b g h s")  # [batch_size, num_head_groups, heads_kv, seqlen_kv]
    sparse_mask = torch.zeros_like(scores)
    # Assign mask values
...
@@ -317,43 +288,34 @@ def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_se
        for h in range(heads_kv):
            for idx in range(num_blocks):
                if block_mask[b, h, idx]:
                    sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1
-    scores = scores.masked_fill(sparse_mask == 0, float('-inf'))
-    range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0)
+    scores = scores.masked_fill(sparse_mask == 0, float("-inf"))
+    range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0)
    cache_seqlens_expanded = cache_seqlens.unsqueeze(1)
    pad_mask = range_len >= cache_seqlens_expanded
    pad_mask = pad_mask[:, None, None, :]
-    scores = scores.masked_fill(pad_mask, float('-inf'))
+    scores = scores.masked_fill(pad_mask, float("-inf"))
    attention = F.softmax(
        scores / scale, dim=-1)  # [batch_size, num_head_groups, heads_kv, seqlen_kv]
-    out = einsum(attention, value,
-                 'b g h s, b h s d -> b g h d')  # [batch_size, num_head_groups, heads_kv, dim]
-    out = rearrange(out, 'b g h d -> b (h g) d')  # [batch_size, heads, dim]
+    out = einsum(attention, value,
+                 "b g h s, b h s d -> b g h d")  # [batch_size, num_head_groups, heads_kv, dim]
+    out = rearrange(out, "b g h d -> b (h g) d")  # [batch_size, heads, dim]
    return out


def ref_program_fa(query, key, value, cache_seqlens):
    # latency reference
    # from flash_attn_interface import flash_attn_with_kvcache # fa3
-    from flash_attn import flash_attn_with_kvcache  #fa2
+    from flash_attn import flash_attn_with_kvcache  # fa2
    query = query.unsqueeze(1)
    output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens)
    output = output.squeeze(1)
    return output


def main(batch=64,
         heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32):
    batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v
    block_size = block_size
    sparse_ratio = sparse_ratio
...
@@ -363,14 +325,13 @@ def main(batch=64,
    dtype = torch.float16
-    Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda')
-    K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda')
-    V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda')
-    cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda')
+    Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda")
+    K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda")
+    V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda")
+    cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda")
    # Ensure at least one element equals cache_seqlen
-    random_index = torch.randint(0, batch, (1,), device='cuda').item()  # Select a random index
+    random_index = torch.randint(0, batch, (1,), device="cuda").item()  # Select a random index
    cache_seqlens[
        random_index] = max_cache_seqlen  # Assign cache_seqlen to ensure at least one occurrence
    num_blocks = (max_cache_seqlen + block_size - 1) // block_size
...
@@ -379,7 +340,7 @@ def main(batch=64,
    max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int()
    print("max_valid_num_blocks: ", max_valid_num_blocks)
    # Initialize block_mask with false (for padding blocks)
-    block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device='cuda')
+    block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device="cuda")
    # Assign valid indices while ensuring no duplicates within each batch-group
    for b in range(batch):
...
@@ -387,11 +348,10 @@ def main(batch=64,
        valid_num_block = valid_num_blocks[b].item()  # Valid blocks for this batch
        if valid_num_block > 0:  # Ensure there's at least one valid block
            for h in range(heads_kv):
-                perm = torch.randperm(max_valid_block, device='cuda')[:valid_num_block]
+                perm = torch.randperm(max_valid_block, device="cuda")[:valid_num_block]
                block_mask[b, h, perm] = True

    ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks,
                            block_size)
    triton_out = block_sparse_flash_decode_gqa_mask_triton(
        Q,
...
@@ -404,8 +364,7 @@ def main(batch=64,
    )
    # print("max difference: ", torch.max(torch.abs(ref - triton_out)))
    assert torch.allclose(
        ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation"
    print("Passed the ref test!")
    # Measure performance
...
@@ -448,15 +407,13 @@ def main(batch=64,
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=64, help='batch size')
-    parser.add_argument('--heads', type=int, default=32, help='heads')
-    parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv')
-    parser.add_argument(
-        '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length')
-    parser.add_argument('--dim', type=int, default=128, help='dim')
-    parser.add_argument('--dim_v', type=int, default=128, help='dim_v')
-    parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio')
-    parser.add_argument('--block_size', type=int, default=32, help='block_size')
+    parser.add_argument("--batch", type=int, default=64, help="batch size")
+    parser.add_argument("--heads", type=int, default=32, help="heads")
+    parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv")
+    parser.add_argument(
+        "--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length")
+    parser.add_argument("--dim", type=int, default=128, help="dim")
+    parser.add_argument("--dim_v", type=int, default=128, help="dim_v")
+    parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio")
+    parser.add_argument("--block_size", type=int, default=32, help="block_size")
    args = parser.parse_args()
    main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v,
         args.sparse_ratio, args.block_size)
examples/blocksparse_attention/heuristic.py  View file @ 29051439
import math


def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, num_m_blocks, size_one_kv_head,
                         is_causal_or_local, max_splits):
    """
    Determines the optimal number of splits for maximizing GPU occupancy while balancing memory efficiency.
...
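The docstring above states what the heuristic optimizes; the blocksparse decode examples touched by this commit derive its inputs from the GQA problem shape and pass them straight through. A minimal sketch of that wiring follows; the concrete sizes are illustrative assumptions, not values taken from the diff.

# Sketch only: mirrors how the decode examples in this commit call num_splits_heuristic.
from heuristic import num_splits_heuristic

batch, heads, heads_kv = 8, 32, 8              # assumed GQA shape
dim, dim_v, block_size, block_H = 128, 128, 32, 64
max_selected_blocks = 64                       # assumed number of selected KV blocks per head
num_sm = 64                                    # the examples hard-code 64 SMs

num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
num_n_blocks = max_selected_blocks
size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2  # kv_seqlen * (dim + dim_v) * 2 bytes (fp16)
total_mblocks = batch * heads_kv * num_m_blocks

num_splits = num_splits_heuristic(
    total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128
)
print("num_splits:", num_splits)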
examples/blocksparse_attention/test_example_blocksparse_attention.py  View file @ 29051439
...
@@ -25,26 +25,14 @@ def test_example_tilelang_sparse_gqa_decode_varlen_mask():
def test_example_triton_sparse_gqa_decode_varlen_indice():
    example_triton_sparse_gqa_decode_varlen_indice.main(
-        batch=8,
-        heads=8,
-        heads_kv=4,
-        max_cache_seqlen=2048,
-        dim=128,
-        dim_v=128,
-        sparse_ratio=0.8,
-        block_size=32)
+        batch=8, heads=8, heads_kv=4, max_cache_seqlen=2048, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32
+    )


def test_example_triton_sparse_gqa_decode_varlen_mask():
    example_triton_sparse_gqa_decode_varlen_mask.main(
-        batch=16,
-        heads=16,
-        heads_kv=8,
-        max_cache_seqlen=1024,
-        dim=128,
-        dim_v=128,
-        sparse_ratio=0.8,
-        block_size=32)
+        batch=16, heads=16, heads_kv=8, max_cache_seqlen=1024, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32
+    )


if __name__ == "__main__":
...
examples/blocksparse_gemm/example_blocksparse_gemm.py  View file @ 29051439
...
@@ -19,8 +19,7 @@ parser.add_argument("--m", type=int, default=1024, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K")
parser.add_argument("--sparsity", type=float, default=0.5, help="Sparsity ratio (0-1)")
parser.add_argument(
    "--use_autotune", action="store_true", default=False, help="Whether to use autotune")
args, _ = parser.parse_known_args()

M, N, K = args.m, args.n, args.k
...
@@ -41,17 +40,19 @@ def get_configs():
    thread_num = [128, 256]
    enable_rasterization = [True, False]
    _configs = list(
        itertools.product(block_M, block_N, block_K, num_stages, thread_num, enable_rasterization))

-    return [{
-        "block_M": c[0],
-        "block_N": c[1],
-        "block_K": c[2],
-        "num_stages": c[3],
-        "thread_num": c[4],
-        "enable_rasteration": c[5],
-    } for c in _configs]
+    return [
+        {
+            "block_M": c[0],
+            "block_N": c[1],
+            "block_K": c[2],
+            "num_stages": c[3],
+            "thread_num": c[4],
+            "enable_rasteration": c[5],
+        }
+        for c in _configs
+    ]


def ref_program(A, B, BlockMask, block_M, block_N, block_K):
...
@@ -61,12 +62,10 @@ def ref_program(A, B, BlockMask, block_M, block_N, block_K):
            accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device)
            for k in range(K // block_K):
                if BlockMask[i, j, k]:
-                    accu += A[i * block_M:(i + 1) * block_M, k * block_K:(k + 1) * block_K].to(
-                        torch.float32) @ B[k * block_K:(k + 1) * block_K,
-                                           j * block_N:(j + 1) * block_N].to(torch.float32)
+                    accu += (
+                        A[i * block_M : (i + 1) * block_M, k * block_K : (k + 1) * block_K].to(torch.float32)
+                        @ B[k * block_K : (k + 1) * block_K, j * block_N : (j + 1) * block_N].to(torch.float32)
+                    )
-            ref_c[i * block_M:(i + 1) * block_M, j * block_N:(j + 1) * block_N] = accu.to(torch.float16)
+            ref_c[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = accu.to(torch.float16)
    return ref_c
...
@@ -89,28 +88,21 @@ def supply_program(params: List[KernelParam]):
    return input_tensors


-@tilelang.autotune(configs=get_configs(),)
+@tilelang.autotune(
+    configs=get_configs(),
+)
@tilelang.jit(out_idx=[-1])
-def blocksparse_matmul(M,
-                       N,
-                       K,
-                       block_M,
-                       block_N,
-                       block_K,
-                       num_stages,
-                       thread_num,
-                       enable_rasteration,
-                       dtype="float16",
-                       accum_dtype="float"):
+def blocksparse_matmul(
+    M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype="float16", accum_dtype="float"
+):
    block_mask_shape = (M // block_M, N // block_N, K // block_K)

    @T.prim_func
    def block_sparse_matmul(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        BlockMask: T.Tensor(block_mask_shape, "bool"),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
...
@@ -134,7 +126,6 @@ def blocksparse_matmul(M,
def main():
    # Initialize input matrices A and B on the GPU with half precision
    a = torch.randn(M, K).cuda().half()
    b = torch.randn(K, N).cuda().half()
...
@@ -147,8 +138,7 @@ def main():
        best_config = kernel.config
        best_latency = kernel.latency
-        block_M, block_N, block_K = best_config["block_M"], best_config["block_N"], best_config[
-            "block_K"]
+        block_M, block_N, block_K = best_config["block_M"], best_config["block_N"], best_config["block_K"]
        print(f"Best Config: {best_config}")
        print(f"Sparsity Ratio: {sparsity}")
...
@@ -163,7 +153,8 @@ def main():
            block_K=DEFAULT_BLOCK_K,
            num_stages=DEFAULT_NUM_STAGES,
            thread_num=DEFAULT_THREAD_NUM,
-            enable_rasteration=DEFAULT_ENABLE_RASTERIZATION)
+            enable_rasteration=DEFAULT_ENABLE_RASTERIZATION,
+        )
        block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K
        print(f"Using default kernel with block size ({block_M}, {block_N}, {block_K})")

    # Create block mask with desired sparsity
...
examples/cast/example_group_per_split_token_cast_to_fp8.py
View file @
29051439
...
@@ -16,11 +16,13 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
...
@@ -16,11 +16,13 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
     fp8_max = 448.0

     @T.prim_func
-    def group_per_split_token_cast(X: T.Tensor((M, N), dtype), batch_sizes: T.Tensor(
-        (BG,), "int32"), X_fp8: T.Tensor((BG, M_max, N), "float8_e4m3"), X_amax: T.Tensor(
-            (BG, M_max, T.ceildiv(N, group_size)), accum_dtype)):
-        with T.Kernel(
-                T.ceildiv(M_max, blk_m), T.ceildiv(N, group_size), BG, threads=128) as (bx, by, bz):
+    def group_per_split_token_cast(
+        X: T.Tensor((M, N), dtype),
+        batch_sizes: T.Tensor((BG,), "int32"),
+        X_fp8: T.Tensor((BG, M_max, N), "float8_e4m3"),
+        X_amax: T.Tensor((BG, M_max, T.ceildiv(N, group_size)), accum_dtype),
+    ):
+        with T.Kernel(T.ceildiv(M_max, blk_m), T.ceildiv(N, group_size), BG, threads=128) as (bx, by, bz):
             row = bx
             row_g_id = by
             bg = bz
...
@@ -31,36 +33,32 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
             y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3")
             row_offset = T.alloc_fragment((1,), "int32")

-            T.annotate_layout({
-                y_local:
-                    T.Fragment(
-                        y_local.shape,
-                        forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32),
-            })
+            T.annotate_layout(
+                {y_local: T.Fragment(y_local.shape, forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32)}
+            )
             row_offset[0] = 0
             for i in T.serial(bg):
                 row_offset[0] += batch_sizes[i]
-            T.copy(
-                X[row_offset[0] + row * blk_m:row_offset[0] + (row + 1) * blk_m,
-                  row_g_id * group_size:(row_g_id + 1) * group_size], y_local)
+            T.copy(
+                X[row_offset[0] + row * blk_m : row_offset[0] + (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size],
+                y_local,
+            )
             T.reduce_absmax(y_local, y_amax_local, dim=1)
             for i in T.Parallel(blk_m):
                 y_amax_local[i] = T.max(y_amax_local[i], 1e-4)
                 y_s_local[i] = T.if_then_else(row * blk_m + i < batch_sizes[bg],
                                               y_amax_local[i] / fp8_max, 0)
             for i, j in T.Parallel(blk_m, group_size):
                 y_q_local[i, j] = T.clamp(y_local[i, j] / y_s_local[i], fp8_min, fp8_max)
             T.copy(y_q_local, y_q_local_fp8)
             for i, j in T.Parallel(blk_m, group_size):
                 y_q_local_fp8[i, j] = T.if_then_else(row * blk_m + i < batch_sizes[bg],
                                                      y_q_local[i, j], 0)
             for i in T.Parallel(blk_m):
                 X_amax[bg, row * blk_m + i, row_g_id] = y_s_local[i]
-            T.copy(
-                y_q_local_fp8, X_fp8[bg, row * blk_m:(row + 1) * blk_m,
-                                     row_g_id * group_size:(row_g_id + 1) * group_size])
+            T.copy(
+                y_q_local_fp8,
+                X_fp8[bg, row * blk_m : (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size],
+            )

     return group_per_split_token_cast
...
@@ -127,8 +125,7 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
         return x.squeeze(0) if remove_dim else x

     # Normal layout requires transposing
-    aligned_x = torch.transpose(
-        torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
+    aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
     aligned_x[:, :m, :] = x
     aligned_x = aligned_x[:, :m, :]
     return aligned_x.squeeze(0) if remove_dim else aligned_x
...
@@ -146,15 +143,17 @@ def ref_per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens
     x_fp8 = x_fp8.view(m, -1)[:, :n].contiguous()
     return x_fp8, (x_amax / 448.0).view(m, -1)


-def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> \
-        Tuple[torch.Tensor, torch.Tensor]:
+def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     # assert x.shape[0] == batch_sizes.sum()
     M_max = ceil_div(batch_sizes.max(), 128) * 128
     split_x = torch.split(x, batch_sizes.tolist(), dim=0)
     padded_x = [torch.nn.functional.pad(t, (0, 0, 0, M_max - t.shape[0])) for t in split_x]
     num_groups, m, n = batch_sizes.shape[0], M_max, x.shape[1]
-    x_fp8 = (torch.empty((num_groups, m, n), device='cuda', dtype=torch.float8_e4m3fn),
-             torch.empty((num_groups, m, n // 128), device='cuda', dtype=torch.float))
+    x_fp8 = (
+        torch.empty((num_groups, m, n), device="cuda", dtype=torch.float8_e4m3fn),
+        torch.empty((num_groups, m, n // 128), device="cuda", dtype=torch.float),
+    )
     for i in range(num_groups):
         x_fp8[0][i], x_fp8[1][i] = ref_per_token_cast_to_fp8(padded_x[i])
     x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
...
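The reference path in this example splits the input row-wise by batch_sizes and zero-pads every group up to a shared M_max (rounded to a multiple of 128) before quantizing it per token. A minimal standalone sketch of just that padding step in plain PyTorch; the helper names below are illustrative, not taken from the repository:

import torch


def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def pad_groups(x: torch.Tensor, batch_sizes: torch.Tensor) -> torch.Tensor:
    # Split x row-wise into one chunk per group, then zero-pad every chunk to M_max rows.
    m_max = ceil_div(int(batch_sizes.max()), 128) * 128
    groups = torch.split(x, batch_sizes.tolist(), dim=0)
    padded = [torch.nn.functional.pad(g, (0, 0, 0, m_max - g.shape[0])) for g in groups]
    return torch.stack(padded)  # (num_groups, M_max, N)


x = torch.randn(1024, 256)
sizes = torch.tensor([128, 896])
print(pad_groups(x, sizes).shape)  # torch.Size([2, 896, 256])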
examples/cast/example_per_token_cast_to_fp8.py View file @ 29051439
...
@@ -13,8 +13,9 @@ def per_token_cast_to_fp8(M, N, blk_m):
     fp8_max = 448.0

     @T.prim_func
-    def per_token_cast(X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), "float8_e4m3"),
-                       X_amax: T.Tensor((M, T.ceildiv(N, group_size)), dtype)):
+    def per_token_cast(
+        X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), "float8_e4m3"), X_amax: T.Tensor((M, T.ceildiv(N, group_size)), dtype)
+    ):
         with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (bx, by):
             row = bx
             row_g_id = by
...
@@ -24,16 +25,13 @@ def per_token_cast_to_fp8(M, N, blk_m):
             y_q_local = T.alloc_fragment((blk_m, group_size), dtype)
             y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3")

-            T.annotate_layout({
-                y_local:
-                    T.Fragment(
-                        y_local.shape,
-                        forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32),
-            })
-            T.copy(
-                X[row * blk_m:(row + 1) * blk_m,
-                  row_g_id * group_size:(row_g_id + 1) * group_size], y_local)
+            T.annotate_layout(
+                {y_local: T.Fragment(y_local.shape, forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32)}
+            )
+            T.copy(
+                X[row * blk_m : (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size], y_local
+            )
             T.reduce_absmax(y_local, y_amax_local, dim=1)
             for i in T.Parallel(blk_m):
                 y_amax_local[i] = T.max(y_amax_local[i], 1e-4)
...
@@ -43,9 +41,7 @@ def per_token_cast_to_fp8(M, N, blk_m):
             T.copy(y_q_local, y_q_local_fp8)
             for i in T.Parallel(blk_m):
                 X_amax[row * blk_m + i, row_g_id] = y_s_local[i]
-            T.copy(
-                y_q_local_fp8, X_fp8[row * blk_m:(row + 1) * blk_m,
-                                     row_g_id * group_size:(row_g_id + 1) * group_size])
+            T.copy(y_q_local_fp8, X_fp8[row * blk_m : (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size])

     return per_token_cast
...
@@ -105,8 +101,7 @@ def main(M=8192, N=8192, blk_m=8):
     from example_triton_cast_to_fp8 import per_token_group_quant_fp8

     def run_triton():
-        x_fp8_triton_, x_amax_triton_ = per_token_group_quant_fp8(
-            x, 128, 1e-4, dtype=torch.float8_e4m3fn, column_major_scales=False)
+        x_fp8_triton_, x_amax_triton_ = per_token_group_quant_fp8(
+            x, 128, 1e-4, dtype=torch.float8_e4m3fn, column_major_scales=False
+        )
         return x_fp8_triton_, x_amax_triton_

     x_fp8_triton, x_amax_triton = run_triton()
...
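The kernel in this file can be checked against a few lines of eager PyTorch: quantize every 128-wide group of a row so that its absolute maximum (clamped at 1e-4) maps onto 448, the float8_e4m3 limit, and keep amax / 448 as the dequantization scale. This is a hedged reference sketch of that math, not the tilelang kernel itself; the function name is made up:

import torch


def per_token_cast_to_fp8_ref(x: torch.Tensor, group_size: int = 128):
    # Per 128-wide group: scale so the group absmax lands on 448, cast to float8_e4m3,
    # and return amax / 448 as the per-group scale for dequantization.
    m, n = x.shape
    assert n % group_size == 0
    x_view = x.view(m, -1, group_size)
    amax = x_view.abs().float().amax(dim=2).clamp(1e-4)
    x_fp8 = (x_view * (448.0 / amax.unsqueeze(2))).to(torch.float8_e4m3fn)
    return x_fp8.view(m, n), amax / 448.0


x = torch.randn(8, 256, dtype=torch.float16, device="cuda")
x_fp8, scales = per_token_cast_to_fp8_ref(x)
print(x_fp8.dtype, scales.shape)  # torch.float8_e4m3fn torch.Size([8, 2])

The returned scales have shape (M, N // 128), which matches the X_amax buffer the kernel writes.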
examples/cast/example_triton_cast_to_fp8.py View file @ 29051439
...
@@ -128,9 +128,7 @@ def per_token_group_quant_fp8(
         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
         scaling factor for quantization.
     """
-    assert (x.shape[-1] % group_size == 0), (
-        f"the last dimension of `x` {x.shape[-1]} must be divisible "
-        f"by `group_size` {group_size}")
+    assert x.shape[-1] % group_size == 0, f"the last dimension of `x` {x.shape[-1]} must be divisible by `group_size` {group_size}"
     assert x.stride(-1) == 1, "`x` groups must be contiguous"

     finfo = torch.finfo(dtype)
...
examples/cast/test_example_cast.py View file @ 29051439
...
@@ -4,8 +4,7 @@ import example_per_token_cast_to_fp8

 def test_example_group_per_split_token_cast_to_fp8():
-    example_group_per_split_token_cast_to_fp8.main(
-        M=1024, N=1024, BG=2, blk_m=4, batch_sizes=[128, 896])
+    example_group_per_split_token_cast_to_fp8.main(M=1024, N=1024, BG=2, blk_m=4, batch_sizes=[128, 896])


 def test_example_per_token_cast_to_fp8():
...
View file @
29051439
...
@@ -4,12 +4,11 @@ import tilelang.language as T
...
@@ -4,12 +4,11 @@ import tilelang.language as T
# @tilelang.jit(compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"])
# @tilelang.jit(compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"])
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
def
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
,
dtype
=
"float16"
,
accum_dtype
=
"float"
):
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
A
:
T
.
Tensor
((
M
,
K
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
# Initialize Kernel Context
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
...
@@ -36,8 +35,7 @@ block_K = 32
...
@@ -36,8 +35,7 @@ block_K = 32
func
=
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
)
func
=
matmul
(
M
,
N
,
K
,
block_M
,
block_N
,
block_K
)
jit_kernel
=
tilelang
.
compile
(
jit_kernel
=
tilelang
.
compile
(
func
,
out_idx
=
[
2
],
target
=
"cuda"
,
compile_flags
=
"-O3 --use_fast_math --expt-relaxed-constexpr"
)
func
,
out_idx
=
[
2
],
target
=
"cuda"
,
compile_flags
=
"-O3 --use_fast_math --expt-relaxed-constexpr"
)
# or jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"])
# or jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"])
# or jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags=["-O3 --use_fast_math --expt-relaxed-constexpr"])
# or jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags=["-O3 --use_fast_math --expt-relaxed-constexpr"])
...
...
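As the two comments in the hunk above note, compile_flags accepts either one string of flags or a list of flags, and the forms are interchangeable. A short usage sketch, assuming matmul is the kernel factory defined in this file:

import tilelang

# `matmul` is the factory defined above in usecase.py; both calls below build the same
# kernel and differ only in how the NVCC flags are spelled.
func = matmul(1024, 1024, 1024, 128, 128, 32)

kernel_from_string = tilelang.compile(
    func, out_idx=[2], target="cuda",
    compile_flags="-O3 --use_fast_math --expt-relaxed-constexpr",
)

kernel_from_list = tilelang.compile(
    func, out_idx=[2], target="cuda",
    compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"],
)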
examples/conftest.py View file @ 29051439
...
@@ -33,12 +33,9 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config):
         "warnings",
         "error",
     }
-    if (sum(
-            len(terminalreporter.stats.get(k, []))
-            for k in known_types.difference({"skipped", "deselected"})) == 0):
+    if sum(len(terminalreporter.stats.get(k, [])) for k in known_types.difference({"skipped", "deselected"})) == 0:
         terminalreporter.write_sep(
             "!",
-            (f"Error: No tests were collected. "
-             f"{dict(sorted((k, len(v)) for k, v in terminalreporter.stats.items()))}"),
+            (f"Error: No tests were collected. {dict(sorted((k, len(v)) for k, v in terminalreporter.stats.items()))}"),
         )
         pytest.exit("No tests were collected.", returncode=5)
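For reference, the guard in this hook reads as a small standalone conftest: if nothing other than skips and deselections was recorded, the run is treated as "no tests collected" and terminated with pytest's exit code 5. The full known_types set is only partly visible in the hunk, so the set below is an assumption:

import pytest


def pytest_terminal_summary(terminalreporter, exitstatus, config):
    # Outcome categories tracked by pytest's terminal reporter; only "warnings" and
    # "error" are visible in the hunk above, the rest of this set is assumed.
    known_types = {"failed", "passed", "xfailed", "xpassed", "error", "warnings", "skipped", "deselected"}
    # Count every recorded outcome except skips/deselections; zero means nothing ran.
    counted = sum(
        len(terminalreporter.stats.get(k, []))
        for k in known_types.difference({"skipped", "deselected"})
    )
    if counted == 0:
        terminalreporter.write_sep("!", "Error: No tests were collected.")
        pytest.exit("No tests were collected.", returncode=5)  # 5 == pytest "no tests collected"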
examples/convolution/example_convolution.py View file @ 29051439
...
@@ -14,7 +14,6 @@ def check_hopper():

 def ref_program(stride, padding, dilation):
     def main(A, B):
         A = A.permute(0, 3, 1, 2)  # N, H, W, C -> N, C, H, W
         B = B.permute(3, 2, 0, 1)  # H, W, C, F -> F, C, H, W
...
@@ -26,22 +25,7 @@ def ref_program(stride, padding, dilation):

 @tilelang.jit(out_idx=[2])
-def convolution(N,
-                C,
-                H,
-                W,
-                F,
-                K,
-                S,
-                D,
-                P,
-                block_M,
-                block_N,
-                block_K,
-                num_stages,
-                threads,
-                dtype="float16",
-                accum_dtype="float"):
+def convolution(
+    N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype="float16", accum_dtype="float"
+):
     KH, KW = K, K
     OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
     OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
...
@@ -51,13 +35,11 @@ def convolution(N,
     @T.prim_func
     def main(
         data: T.Tensor((N, H, W, C), dtype),
         kernel: T.Tensor((KH, KW, C, F), dtype),
         out: T.Tensor((N, OH, OW, F), dtype),
     ):
-        with T.Kernel(
-                T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
-                threads=threads) as (bx, by):
+        with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) as (bx, by):
             data_shared = T.alloc_shared((block_M, block_K), dtype)
             kernel_shared = T.alloc_shared((block_K, block_N), dtype)
             out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
...
@@ -66,11 +48,13 @@ def convolution(N,
             kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
             out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)

-            T.annotate_layout({
-                out_shared: tilelang.layout.make_swizzled_layout(out_shared),
-                data_shared: tilelang.layout.make_swizzled_layout(data_shared),
-                kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
-            })
+            T.annotate_layout(
+                {
+                    out_shared: tilelang.layout.make_swizzled_layout(out_shared),
+                    data_shared: tilelang.layout.make_swizzled_layout(data_shared),
+                    kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
+                }
+            )

             T.clear(out_local)
             for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
...
@@ -82,10 +66,8 @@ def convolution(N,
                     m = by * block_M + i
                     access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
                     access_w = m % OW * S + k // C % KW * D - P
-                    in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and
-                                (access_w < W))
-                    data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h,
-                                                                      access_w, k % C], 0)
+                    in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W)
+                    data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
                 T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
                 T.gemm(data_shared, kernel_shared, out_local)
...
@@ -97,15 +79,15 @@ def convolution(N,

 def main(argv=None):
     parser = argparse.ArgumentParser()
-    parser.add_argument('--n', type=int, default=128, help='n')
-    parser.add_argument('--c', type=int, default=128, help='c')
-    parser.add_argument('--h', type=int, default=64, help='h')
-    parser.add_argument('--w', type=int, default=64, help='w')
-    parser.add_argument('--f', type=int, default=128, help='f')
-    parser.add_argument('--k', type=int, default=3, help='k')
-    parser.add_argument('--s', type=int, default=1, help='s')
-    parser.add_argument('--d', type=int, default=1, help='d')
-    parser.add_argument('--p', type=int, default=1, help='p')
+    parser.add_argument("--n", type=int, default=128, help="n")
+    parser.add_argument("--c", type=int, default=128, help="c")
+    parser.add_argument("--h", type=int, default=64, help="h")
+    parser.add_argument("--w", type=int, default=64, help="w")
+    parser.add_argument("--f", type=int, default=128, help="f")
+    parser.add_argument("--k", type=int, default=3, help="k")
+    parser.add_argument("--s", type=int, default=1, help="s")
+    parser.add_argument("--d", type=int, default=1, help="d")
+    parser.add_argument("--p", type=int, default=1, help="p")
     args = parser.parse_args(argv)
     N, C, H, W, F, K, S, D, P = args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p
...
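The OH/OW expressions in this example are the standard output-size formula for a strided, padded, dilated convolution. A quick self-contained check of that formula against torch.nn.functional.conv2d; shapes are illustrative, and Fc stands in for the example's F to avoid clashing with the functional alias:

import torch
import torch.nn.functional as F

N, C, H, W, Fc, K, S, D, P = 2, 8, 64, 64, 16, 3, 1, 1, 1

# Same formula as in the kernel above.
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1

out = F.conv2d(torch.randn(N, C, H, W), torch.randn(Fc, C, K, K), stride=S, padding=P, dilation=D)
assert out.shape == (N, Fc, OH, OW)  # 2 x 16 x 64 x 64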
examples/convolution/example_convolution_autotune.py View file @ 29051439
...
@@ -14,7 +14,6 @@ def check_hopper():

 def ref_program(stride, padding, dilation):
     def main(A, B):
         A = A.permute(0, 3, 1, 2)  # N, H, W, C -> N, C, H, W
         B = B.permute(3, 2, 0, 1)  # H, W, C, F -> F, C, H, W
...
@@ -40,7 +39,8 @@ def get_configs():
             num_stages,
             thread_num,
             enable_rasterization,
-        ))
+        )
+    )
     configs = [
         {
...
@@ -50,7 +50,8 @@ def get_configs():
             "num_stages": c[3],
             "thread_num": c[4],
             "enable_rasteration": c[5],  # keep param name for backward-compat
         }
         for c in _configs
     ]
     return configs
...
@@ -64,53 +65,18 @@ def get_heuristic_config() -> dict:
     sm_version = sm_major * 10 + sm_minor
     print(f"CUDA device capability: {sm_version}")
     if sm_version in {80}:
-        return {
-            "block_M": 128,
-            "block_N": 256,
-            "block_K": 32,
-            "num_stages": 2,
-            "thread_num": 128,
-            "enable_rasteration": True
-        }
+        return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 2, "thread_num": 128, "enable_rasteration": True}
     elif sm_version in {90}:
-        return {
-            "block_M": 128,
-            "block_N": 256,
-            "block_K": 64,
-            "num_stages": 3,
-            "thread_num": 256,
-            "enable_rasteration": True
-        }
+        return {"block_M": 128, "block_N": 256, "block_K": 64, "num_stages": 3, "thread_num": 256, "enable_rasteration": True}
     else:
-        return {
-            "block_M": 128,
-            "block_N": 256,
-            "block_K": 32,
-            "num_stages": 0,
-            "thread_num": 128,
-            "enable_rasteration": True
-        }
+        return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 0, "thread_num": 128, "enable_rasteration": True}


 @tilelang.autotune(configs=get_configs())
 @tilelang.jit(out_idx=[2])
-def convolution(N,
-                C,
-                H,
-                W,
-                F,
-                K,
-                S,
-                D,
-                P,
-                block_M,
-                block_N,
-                block_K,
-                num_stages,
-                thread_num,
-                enable_rasteration,
-                dtype="float16",
-                accum_dtype="float"):
+def convolution(
+    N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype="float16", accum_dtype="float"
+):
     KH, KW = K, K
     OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
     OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
...
@@ -120,13 +86,11 @@ def convolution(N,
     @T.prim_func
     def main(
         data: T.Tensor((N, H, W, C), dtype),
         kernel: T.Tensor((KH, KW, C, F), dtype),
         out: T.Tensor((N, OH, OW, F), dtype),
     ):
-        with T.Kernel(
-                T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
-                threads=thread_num) as (bx, by):
+        with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=thread_num) as (bx, by):
             data_shared = T.alloc_shared((block_M, block_K), dtype)
             kernel_shared = T.alloc_shared((block_K, block_N), dtype)
             out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
...
@@ -136,9 +100,11 @@ def convolution(N,
             out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)

             if is_hopper:
-                T.annotate_layout({
-                    out_shared: tilelang.layout.make_swizzled_layout(out_shared),
-                })
+                T.annotate_layout(
+                    {
+                        out_shared: tilelang.layout.make_swizzled_layout(out_shared),
+                    }
+                )

             T.clear(out_local)
             for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
...
@@ -150,10 +116,8 @@ def convolution(N,
                     m = by * block_M + i
                     access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
                     access_w = m % OW * S + k // C % KW * D - P
-                    in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and
-                                (access_w < W))
-                    data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h,
-                                                                      access_w, k % C], 0)
+                    in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W)
+                    data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
                 T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
                 T.gemm(data_shared, kernel_shared, out_local)
...
@@ -166,17 +130,19 @@ def convolution(N,
     return main


-def main(n: int = 128,
-         c: int = 128,
-         h: int = 64,
-         w: int = 64,
-         f: int = 128,
-         k: int = 3,
-         s: int = 1,
-         d: int = 1,
-         p: int = 1,
-         use_autotune: bool = False,
-         with_roller: bool = True):
+def main(
+    n: int = 128,
+    c: int = 128,
+    h: int = 64,
+    w: int = 64,
+    f: int = 128,
+    k: int = 3,
+    s: int = 1,
+    d: int = 1,
+    p: int = 1,
+    use_autotune: bool = False,
+    with_roller: bool = True,
+):
     N, C, H, W, F, K, S, D, P = n, c, h, w, f, k, s, d, p
     ref_prog = ref_program(S, P, D)
...
@@ -196,25 +162,16 @@ def main(n: int = 128,

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
-    parser.add_argument('--n', type=int, default=128, help='n')
-    parser.add_argument('--c', type=int, default=128, help='c')
-    parser.add_argument('--h', type=int, default=64, help='h')
-    parser.add_argument('--w', type=int, default=64, help='w')
-    parser.add_argument('--f', type=int, default=128, help='f')
-    parser.add_argument('--k', type=int, default=3, help='k')
-    parser.add_argument('--s', type=int, default=1, help='s')
-    parser.add_argument('--d', type=int, default=1, help='d')
-    parser.add_argument('--p', type=int, default=1, help='p')
-    parser.add_argument(
-        "--use_autotune", action="store_true", default=False, help="Whether to use autotune for matmul configs")
-    parser.add_argument(
-        "--with_roller", action="store_true", default=True, help="Whether to enable BitBLAS roller for search space")
+    parser.add_argument("--n", type=int, default=128, help="n")
+    parser.add_argument("--c", type=int, default=128, help="c")
+    parser.add_argument("--h", type=int, default=64, help="h")
+    parser.add_argument("--w", type=int, default=64, help="w")
+    parser.add_argument("--f", type=int, default=128, help="f")
+    parser.add_argument("--k", type=int, default=3, help="k")
+    parser.add_argument("--s", type=int, default=1, help="s")
+    parser.add_argument("--d", type=int, default=1, help="d")
+    parser.add_argument("--p", type=int, default=1, help="p")
+    parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune for matmul configs")
+    parser.add_argument("--with_roller", action="store_true", default=True, help="Whether to enable BitBLAS roller for search space")
     args = parser.parse_args()
     main(args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p, args.use_autotune,
          args.with_roller)
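The heuristic path in this file keys its default tile configuration off the GPU's compute capability (sm80 for A100-class parts, sm90 for Hopper). A trimmed standalone restatement of that dispatch; the function name is made up, and the config values are the ones shown in the hunk above:

import torch


def pick_heuristic_config() -> dict:
    # Mirrors get_heuristic_config(): tuned tile shapes for sm80/sm90, a conservative default otherwise.
    sm_major, sm_minor = torch.cuda.get_device_capability()
    sm_version = sm_major * 10 + sm_minor
    if sm_version == 80:   # Ampere (A100)
        return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 2, "thread_num": 128, "enable_rasteration": True}
    if sm_version == 90:   # Hopper (H100)
        return {"block_M": 128, "block_N": 256, "block_K": 64, "num_stages": 3, "thread_num": 256, "enable_rasteration": True}
    return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 0, "thread_num": 128, "enable_rasteration": True}


print(pick_heuristic_config())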
examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py View file @ 29051439
...
@@ -41,14 +41,13 @@ def tl_gemm(
     @T.prim_func
     def main(
         A: T.Tensor(A_shape, in_dtype),
         B: T.Tensor(B_shape, in_dtype),
         C: T.Tensor((M, N), out_dtype),
         scales_a: T.Tensor(Scales_A_shape, "float32"),
         scales_b: T.Tensor(Scales_B_shape, "float32"),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
             A_shared = T.alloc_shared(A_shared_shape, in_dtype)
             B_shared = T.alloc_shared(B_shared_shape, in_dtype)
             C_shared = T.alloc_shared(C_shared_shape, out_dtype)
...
@@ -93,21 +92,18 @@ def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     m, n = x.shape
     x_view = x.view(m, -1, 128)
     x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
-    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(
-        m, n), (x_amax / 448.0).view(m, -1)
+    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)


 def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     assert x.dim() == 2
     m, n = x.shape
-    x_padded = torch.zeros(
-        ceildiv(m, 128) * 128, ceildiv(n, 128) * 128, dtype=x.dtype, device=x.device)
+    x_padded = torch.zeros(ceildiv(m, 128) * 128, ceildiv(n, 128) * 128, dtype=x.dtype, device=x.device)
     x_padded[:m, :n] = x
     x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
     x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
     x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
-    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(
-        x_view.size(0), x_view.size(2))
+    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))


 def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype):
...
@@ -127,13 +123,14 @@ def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype):
             c_acc.zero_()
             for k in range(ceildiv(K, 128)):
                 c = torch._scaled_mm(
                     A_fp8[i * 128:(i + 1) * 128, k * 128:(k + 1) * 128],
                     B_fp8[j * 128:(j + 1) * 128, k * 128:(k + 1) * 128].T,
                     scale_a=A_scales[i, k].view(128, 1).contiguous(),
                     scale_b=B_scales[j, k].view(1, 128).contiguous(),
-                    out_dtype=torch.bfloat16)
+                    out_dtype=torch.bfloat16,
+                )
                 c_acc += c.to(torch.float32)
             C[i * 128:(i + 1) * 128, j * 128:(j + 1) * 128] = c_acc.to(out_dtype)
     return C
...
1
2
3
4
5
6
7
…
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment