OpenDAS / tilelang

Commit cf6e11c9, authored Feb 05, 2026 by qisan
feat: merge dcu branch features
Parents: 3f27f85a, d0436b7b
Pipeline #3369 failed with stages in 0 seconds
Changes: 266 | Pipelines: 1
Showing 20 changed files with 4762 additions and 0 deletions (+4762, -0).
examples/bitnet-1.58b/requirements.txt (+3, -0)
examples/bitnet-1.58b/tokenization_bitnet.py (+468, -0)
examples/bitnet-1.58b/utils_quant.py (+230, -0)
examples/bitnet-1.58b/vllm_workspace/conftest.py (+587, -0)
examples/bitnet-1.58b/vllm_workspace/inference_with_compress_format.py (+45, -0)
examples/bitnet-1.58b/vllm_workspace/inference_with_native_format.py (+47, -0)
examples/bitnet-1.58b/vllm_workspace/utils.py (+45, -0)
examples/blocksparse_attention/README.md (+6, -0)
examples/blocksparse_attention/block_sparse_attn_triton.py (+361, -0)
examples/blocksparse_attention/example_tilelang_block_sparse_attn.py (+221, -0)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py (+551, -0)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (+435, -0)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py (+420, -0)
examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py (+433, -0)
examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py (+419, -0)
examples/blocksparse_attention/heuristic.py (+54, -0)
examples/blocksparse_attention/test_example_blocksparse_attention.py (+39, -0)
examples/blocksparse_gemm/example_blocksparse_gemm.py (+179, -0)
examples/blocksparse_gemm/test_example_blocksparse_gemm.py (+10, -0)
examples/cast/example_group_per_split_token_cast_to_fp8.py (+209, -0)
Too many changes to show. To preserve performance, only 266 of 266+ files are displayed.

examples/bitnet-1.58b/requirements.txt (new file, mode 100644)
lm_eval==0.3.0
flash_attn
transformers==4.53.0

examples/bitnet-1.58b/tokenization_bitnet.py (new file, mode 100644)
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for LLaMA."""
import os
from shutil import copyfile
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

import sentencepiece as spm
from transformers.convert_slow_tokenizer import import_protobuf
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
from transformers.utils import logging

if TYPE_CHECKING:
    from transformers.tokenization_utils_base import TextInput

logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model",
    },
    "tokenizer_file": {
        "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json",
    },
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "hf-internal-testing/llama-tokenizer": 2048,
}
SPIECE_UNDERLINE = "▁"

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

# fmt: off
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure \
that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
correct. If you don't know the answer to a question, please don't share false information."""
# fmt: on
class BitnetTokenizer(PreTrainedTokenizer):
    """
    Construct a Bitnet tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is
    no padding token in the original model.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<s>"`):
            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
        eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"</s>"`):
            The end of sequence token.
        pad_token (`str` or `tokenizers.AddedToken`, *optional*):
            A special token used to make arrays of tokens the same size for batching purposes. Will then be ignored by
            attention mechanisms or loss computation.
        sp_model_kwargs (`Dict[str, Any]`, *optional*):
            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
            to set:

            - `enable_sampling`: Enable subword regularization.
            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.

              - `nbest_size = {0,1}`: No sampling is performed.
              - `nbest_size > 1`: samples from the nbest_size results.
              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
                using forward-filtering-and-backward-sampling algorithm.

            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
              BPE-dropout.
        add_bos_token (`bool`, *optional*, defaults to `True`):
            Whether or not to add a `bos_token` at the start of sequences.
        add_eos_token (`bool`, *optional*, defaults to `False`):
            Whether or not to add an `eos_token` at the end of sequences.
        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
            Whether or not to clean up spaces after decoding; cleanup consists in removing potential artifacts like
            extra spaces.
        use_default_system_prompt (`bool`, *optional*, defaults to `False`):
            Whether or not the default system prompt for Bitnet should be used.
        spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
            Whether or not to add spaces between special tokens.
        legacy (`bool`, *optional*):
            Whether or not the `legacy` behavior of the tokenizer should be used. Legacy is before the merge of #24622
            and #25224 which includes fixes to properly handle tokens that appear after special tokens. A simple
            example:

            - `legacy=True`:
            ```python
            >>> from transformers import T5Tokenizer

            >>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=True)
            >>> tokenizer.encode("Hello <extra_id_0>.")
            [8774, 32099, 3, 5, 1]
            ```
            - `legacy=False`:
            ```python
            >>> from transformers import T5Tokenizer

            >>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=False)
            >>> tokenizer.encode("Hello <extra_id_0>.")  # the extra space `[3]` is no longer here
            [8774, 32099, 5, 1]
            ```
            Checkout the [pull request](https://github.com/huggingface/transformers/pull/24565) for more details.
        add_prefix_space (`bool`, *optional*, defaults to `True`):
            Whether or not to add an initial space to the input. This allows to treat the leading word just as any
            other word.
    """
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        unk_token="<unk>",
        bos_token="<s>",
        eos_token="</s>",
        pad_token=None,
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        add_bos_token=True,
        add_eos_token=False,
        clean_up_tokenization_spaces=False,
        use_default_system_prompt=False,
        spaces_between_special_tokens=False,
        legacy=None,
        add_prefix_space=True,
        **kwargs,
    ):
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
        bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
        eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
        unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
        pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token

        if legacy is None:
            logger.warning_once(
                f"You are using the default legacy behavior of the {self.__class__}. This is"
                " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you."
                " If you want to use the new behavior, set `legacy=False`. This should only be set if you understand what it"
                " means, and thoroughly read the reason why this was added as explained in"
                " https://github.com/huggingface/transformers/pull/24565")
            legacy = True

        self.legacy = legacy
        self.vocab_file = vocab_file
        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token
        self.use_default_system_prompt = use_default_system_prompt
        self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False))
        self.add_prefix_space = add_prefix_space

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            add_bos_token=add_bos_token,
            add_eos_token=add_eos_token,
            sp_model_kwargs=self.sp_model_kwargs,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            use_default_system_prompt=use_default_system_prompt,
            spaces_between_special_tokens=spaces_between_special_tokens,
            legacy=legacy,
            add_prefix_space=add_prefix_space,
            **kwargs,
        )
    @property
    def unk_token_length(self):
        return len(self.sp_model.encode(str(self.unk_token)))

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor
    def get_spm_processor(self, from_slow=False):
        tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        if self.legacy or from_slow:  # no dependency on protobuf
            tokenizer.Load(self.vocab_file)
            return tokenizer

        with open(self.vocab_file, "rb") as f:
            sp_model = f.read()
            model_pb2 = import_protobuf(f"The new behavior of {self.__class__.__name__} (with `self.legacy = False`)")
            model = model_pb2.ModelProto.FromString(sp_model)
            normalizer_spec = model_pb2.NormalizerSpec()
            normalizer_spec.add_dummy_prefix = False
            model.normalizer_spec.MergeFrom(normalizer_spec)
            sp_model = model.SerializeToString()
            tokenizer.LoadFromSerializedProto(sp_model)
        return tokenizer
    def __getstate__(self):
        state = self.__dict__.copy()
        state["sp_model"] = None
        state["sp_model_proto"] = self.sp_model.serialized_model_proto()
        return state

    def __setstate__(self, d):
        self.__dict__ = d
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.LoadFromSerializedProto(self.sp_model_proto)

    @property
    def vocab_size(self):
        """Returns vocab size"""
        return self.sp_model.get_piece_size()

    def get_vocab(self):
        """Returns vocab as a dict"""
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab
    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize
    def tokenize(self, text: "TextInput", **kwargs) -> List[str]:
        """
        Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the
        first token is special.
        """
        if self.legacy or len(text) == 0:
            return super().tokenize(text, **kwargs)

        text = text.replace(SPIECE_UNDERLINE, " ")
        if self.add_prefix_space:
            text = SPIECE_UNDERLINE + text

        tokens = super().tokenize(text, **kwargs)

        if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
            tokens = tokens[1:]
        return tokens

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize
    def _tokenize(self, text, **kwargs):
        """
        Returns a tokenized string.

        We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
        SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give
        `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the
        `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
        `self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
        """
        tokens = self.sp_model.encode(text, out_type=str)
        if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")):
            return tokens

        # 1. Encode string + prefix ex: "<unk> Hey"
        tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
        # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
        return tokens[self.unk_token_length:] if len(tokens) >= self.unk_token_length else tokens
    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.sp_model.piece_to_id(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        token = self.sp_model.IdToPiece(index)
        return token

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        # since we manually add the prefix space, we have to remove it when decoding
        if tokens[0].startswith(SPIECE_UNDERLINE) and self.add_prefix_space:
            tokens[0] = tokens[0][1:]

        current_sub_tokens = []
        out_string = ""
        prev_is_special = False
        for i, token in enumerate(tokens):
            # make sure that special tokens are not decoded using sentencepiece model
            if token in self.all_special_tokens:
                if not prev_is_special and i != 0 and self.legacy:
                    out_string += " "
                out_string += self.sp_model.decode(current_sub_tokens) + token
                prev_is_special = True
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
                prev_is_special = False
        out_string += self.sp_model.decode(current_sub_tokens)
        return out_string
    def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Save the vocabulary and special tokens file to a directory.

        Args:
            save_directory (`str`):
                The directory in which to save the vocabulary.

        Returns:
            `Tuple(str)`: Paths to the files saved.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"])

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        return (out_vocab_file,)
    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
        eos_token_id = [self.eos_token_id] if self.add_eos_token else []

        output = bos_token_id + token_ids_0 + eos_token_id

        if token_ids_1 is not None:
            output = output + bos_token_id + token_ids_1 + eos_token_id

        return output

    def get_special_tokens_mask(self,
                                token_ids_0: List[int],
                                token_ids_1: Optional[List[int]] = None,
                                already_has_special_tokens: bool = False) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True)

        bos_token_id = [1] if self.add_bos_token else []
        eos_token_id = [1] if self.add_eos_token else []

        if token_ids_1 is None:
            return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
        return (bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + bos_token_id +
                ([0] * len(token_ids_1)) + eos_token_id)
    def create_token_type_ids_from_sequences(self,
                                             token_ids_0: List[int],
                                             token_ids_1: Optional[List[int]] = None) -> List[int]:
        """
        Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
        sequence pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        if token_ids_1 is None, only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of ids.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
        eos_token_id = [self.eos_token_id] if self.add_eos_token else []

        output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)

        if token_ids_1 is not None:
            output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)

        return output
    @property
    def default_chat_template(self):
        """
        LLaMA uses [INST] and [/INST] to indicate user messages, and <<SYS>> and <</SYS>> to indicate system messages.
        Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict
        user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering
        rather than needing special tokens. The system message is partly 'embedded' in the first user message, which
        results in an unusual token ordering when it is present. This template should definitely be changed if you wish
        to fine-tune a model with more flexible role ordering!

        The output should look something like:

        <bos>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer <eos><bos>[INST] Prompt [/INST] Answer <eos>
        <bos>[INST] Prompt [/INST]

        The reference for this chat template is [this code
        snippet](https://github.com/facebookresearch/llama/blob/556949fdfb72da27c2f4a40b7f0e4cf0b8153a28/llama/generation.py#L320-L362)
        in the original repository.
        """
        logger.warning_once(
            "\nNo chat template is defined for this tokenizer - using the default template "
            f"for the {self.__class__.__name__} class. If the default is not appropriate for "
            "your model, please set `tokenizer.chat_template` to an appropriate template. "
            "See https://huggingface.co/docs/transformers/main/chat_templating for more information.\n")
        template = (
            "{% if messages[0]['role'] == 'system' %}"
            "{% set loop_messages = messages[1:] %}"  # Extract system message if it's present
            "{% set system_message = messages[0]['content'] %}"
            "{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}"
            "{% set loop_messages = messages %}"  # Or use the default system message if the flag is set
            "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}"
            "{% else %}"
            "{% set loop_messages = messages %}"
            "{% set system_message = false %}"
            "{% endif %}"
            "{% for message in loop_messages %}"  # Loop over all non-system messages
            "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
            "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
            "{% endif %}"
            "{% if loop.index0 == 0 and system_message != false %}"  # Embed system message in first message
            "{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}"
            "{% else %}"
            "{% set content = message['content'] %}"
            "{% endif %}"
            "{% if message['role'] == 'user' %}"  # After all of that, handle messages/roles in a fairly normal way
            "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}"
            "{% elif message['role'] == 'system' %}"
            "{{ '<<SYS>>\\n' + content.strip() + '\\n<</SYS>>\\n\\n' }}"
            "{% elif message['role'] == 'assistant' %}"
            "{{ ' ' + content.strip() + ' ' + eos_token }}"
            "{% endif %}"
            "{% endfor %}")
        template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false")
        default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'")
        template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)

        return template
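
The class above is a standard SentencePiece LLaMA-style tokenizer with BitNet defaults. A minimal usage sketch, assuming the file is imported locally and `path/to/tokenizer.model` is a placeholder for a real SentencePiece vocabulary (not a path from this commit):

```python
# A minimal sketch (not part of the commit); "path/to/tokenizer.model" is a placeholder
# for a local SentencePiece vocabulary such as the one shipped with a BitNet checkpoint.
from tokenization_bitnet import BitnetTokenizer

tok = BitnetTokenizer(vocab_file="path/to/tokenizer.model", legacy=False)
ids = tok("Hello world").input_ids          # bos id prepended because add_bos_token=True
text = tok.decode(ids, skip_special_tokens=True)
print(ids, text)                            # round-trips back to "Hello world"
```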

examples/bitnet-1.58b/utils_quant.py (new file, mode 100644)
# pylint: disable=missing-docstring, invalid-name
"""This is modified from https://huggingface.co/1bitLLM/bitnet_b1_58-3B/blob/main/utils_quant.py to work with BitBLAS."""
import torch
from torch import nn
from bitblas.cache import global_operator_cache, get_database_path
from bitblas import Matmul, MatmulConfig
from bitblas import auto_detect_nvidia_target
from logging import getLogger

logger = getLogger(__name__)

BITBLAS_TARGET = auto_detect_nvidia_target()
BITBLAS_DATABASE_PATH = get_database_path()


def weight_quant(weight, num_bits=1):
    dtype = weight.dtype
    weight = weight.float()
    s = 1 / weight.abs().mean().clamp(min=1e-5)
    result = (weight * s).round().clamp(-1, 1) / s
    return result.type(dtype)


def activation_quant(x, num_bits=8):
    dtype = x.dtype
    x = x.float()
    Qn = -(2**(num_bits - 1))
    Qp = 2**(num_bits - 1) - 1
    s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    result = (x * s).round().clamp(Qn, Qp) / s
    return result.type(dtype)
class BitLinearBitBLAS(nn.Module):

    def __init__(
        self,
        in_features: int,
        out_features: int,
        weight_bits=1,
        input_bits=8,
        **kwargs,
    ):
        super().__init__()
        """
        RMSNorm is placed outside BitLinear
        """
        self.in_features = in_features
        self.out_features = out_features
        self.weight_bits = weight_bits
        self.input_bits = input_bits

        matmul_config = MatmulConfig(
            N=self.out_features,  # N dimension
            K=self.in_features,  # K dimension
            A_dtype="int8",  # activation A dtype
            W_dtype="int2",  # weight W dtype
            accum_dtype="int32",  # accumulation dtype
            out_dtype="float32",  # output dtype
            layout="nt",  # matrix layout, "nt" indicates the layout of A is non-transpose and the layout of W is transpose
            with_bias=False,  # bias
            # configs for weight only quantization
            group_size=None,  # setting for grouped quantization
            with_scaling=False,  # setting for scaling factor
            with_zeros=False,  # setting for zeros
            zeros_mode=None,  # setting for how to calculating zeros
        )
        ENABLE_TUNING = True
        self.bitblas_matmul = self._get_or_create_bitblas_operator(matmul_config, ENABLE_TUNING)

        self.format = "bitnet"
        self.Qp = 2**(self.input_bits - 1) - 1
    def _get_or_create_bitblas_operator(self, config, enable_tuning):
        if global_operator_cache.size() == 0:
            global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH, BITBLAS_TARGET)
            logger.info(f"Loaded {global_operator_cache.size()} operators from database.")

        bitblas_matmul = global_operator_cache.get(config)
        if bitblas_matmul is None:
            # should disable tuning for the first time because we may require loading bitblas operator from database.
            bitblas_matmul = Matmul(config, target=BITBLAS_TARGET, enable_tuning=False)
            if enable_tuning:
                bitblas_matmul.hardware_aware_finetune(topk=20)
                global_operator_cache.add(config, bitblas_matmul)
                global_operator_cache.save_into_database(BITBLAS_DATABASE_PATH, BITBLAS_TARGET)
                print("BitBLAS Tuning done, appended operator to global_operator_cache.")
            else:
                print("BitBLAS Operator created.")
        else:
            print("BitBLAS Operator found in global_operator_cache.")
        return bitblas_matmul
    def replace_weight_param_with_qweight(self):
        if hasattr(self, "weight"):
            del self.weight
        quant_weight = torch.empty(self.bitblas_matmul.retrieve_weight_shape())
        self.qweight = nn.Parameter(quant_weight, requires_grad=False)
        self.format = "bitblas"

    @classmethod
    def from_bit_linear(cls, bitlinear, weight_group=1):
        bitblas_linear = cls(bitlinear.in_features, bitlinear.out_features, weight_bits=1, input_bits=8)
        sw, qweight = bitblas_linear.create_bitblas_weights(bitlinear.weight, weight_group)
        bitblas_linear.register_buffer("qweight", qweight)
        bitblas_linear.register_buffer("sw", sw)
        if bitlinear.bias is not None:
            bitblas_linear.register_buffer("bias", bitlinear.bias)
        else:
            bitblas_linear.bias = None
        return bitblas_linear
    def create_bitblas_weights(self, weight, weight_group=1):
        if weight_group:
            hidden_size = weight.size(0)
            group_size = hidden_size // weight_group
            sw_list = []
            qweight_list = []
            for i in range(weight_group):
                start_idx = i * group_size
                end_idx = (i + 1) * group_size
                sw = 1 / weight[start_idx:end_idx].abs().mean().clamp(min=1e-5)
                sw_list.append(sw.repeat(group_size))
                qweight = self.weight_quant(weight[start_idx:end_idx]).detach()
                qweight_list.append(qweight)
            sw = torch.cat(sw_list, dim=0)
            qweight = torch.cat(qweight_list, dim=0)
        else:
            sw = 1 / weight.abs().mean().clamp(min=1e-5)
            qweight = self.weight_quant(weight).detach()
        qweight = self.bitblas_matmul.transform_weight(qweight)
        qweight = nn.Parameter(qweight, requires_grad=False)
        return sw, qweight
    def post_process_weights(self):
        sw = 1 / self.weight.abs().mean().clamp(min=1e-5)
        self.sw = sw
        quant_weight = self.weight_quant(self.weight).detach()
        quant_weight = self.bitblas_matmul.transform_weight(quant_weight)
        # remove self.weight and replace it with quant_weight
        if hasattr(self, "weight"):
            del self.weight
        self.qweight = nn.Parameter(quant_weight, requires_grad=False)
        self.format = "bitblas"

    @staticmethod
    def weight_quant(weight):
        weight = weight.float()
        s = 1 / weight.abs().mean().clamp(min=1e-5)
        result = (weight * s).round().clamp(-1, 1)
        return result.type(torch.int8)
    @torch.compile
    def activation_quant(self, x, num_bits=8):
        x = x.float()
        Qn = -(2**(num_bits - 1))
        Qp = 2**(num_bits - 1) - 1
        s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
        result = (x * s).round().clamp(Qn, Qp)
        return result.type(torch.int8), s

    @torch.compile
    def post_quant_process(self, input, si, sw):
        out = input / si
        out = out / sw
        out = out.half()
        return out
    # for the correctness evaluation.
    def native_forward(self, input):
        quant_input = input + (activation_quant(input, self.input_bits) - input).detach()
        quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) - self.weight).detach()

        out = nn.functional.linear(quant_input, quant_weight)
        if self.bias is not None:
            out += self.bias.view(1, -1).expand_as(out)

        return out

    def forward_fp32_simulated(self, input):
        quant_input, si = self.activation_quant(input, self.input_bits).detach()
        quant_weight = self.weight_quant(self.weight).detach()

        fp32_simulated_input = quant_input.float()
        fp32_simulated_weight = quant_weight.float()
        fp32_simulated_out = nn.functional.linear(fp32_simulated_input, fp32_simulated_weight)

        sw = 1 / self.weight.abs().mean().clamp(min=1e-5)
        # if / (si * sw) it will inf in some cases
        out = fp32_simulated_out / si
        out = out / sw
        out = out.half()
        if self.bias is not None:
            out += self.bias.view(1, -1).expand_as(out)
        return out
    def forward(self, input):
        # return self.forward_fp32_simulated(input)
        quant_input, si = self.activation_quant(input, self.input_bits)
        fp32_out = self.bitblas_matmul(quant_input, self.qweight)
        sw = self.sw
        # if / (si * sw) it will inf in some cases
        out = self.post_quant_process(fp32_out, si, sw)

        if self.bias is not None:
            out += self.bias.view(1, -1).expand_as(out)
        return out


# Naive BitLinear from HuggingFace
class BitLinear(nn.Linear):

    def __init__(self, *kargs, weight_bits=1, input_bits=8, **kwargs):
        super(BitLinear, self).__init__(*kargs, **kwargs)
        """
        RMSNorm is placed outside BitLinear
        """
        self.weight_bits = weight_bits
        self.input_bits = input_bits

    def forward(self, input):
        quant_input = input + (activation_quant(input, self.input_bits) - input).detach()
        quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) - self.weight).detach()

        out = nn.functional.linear(quant_input, quant_weight)
        if self.bias is not None:
            out += self.bias.view(1, -1).expand_as(out)
        return out
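
The module-level `weight_quant`/`activation_quant` helpers implement the BitNet-style fake quantization used by `BitLinear`: weights are scaled by the inverse of their mean absolute value, rounded to a ternary grid, and rescaled; activations are rounded to a per-row int8 grid. A minimal sketch (not part of the commit) that exercises just these two helpers on toy tensors:

```python
# A minimal sketch of the fake-quantization round trip performed by the helpers above.
import torch

torch.manual_seed(0)
w = torch.randn(4, 8)        # toy weight matrix
x = torch.randn(2, 8)        # toy activation batch

w_q = weight_quant(w)        # weights snapped to a ternary grid {-1, 0, +1} / s, same dtype/shape
x_q = activation_quant(x)    # activations snapped to a per-row int8 grid, same dtype/shape

s = 1 / w.abs().mean().clamp(min=1e-5)
print(torch.unique((w_q * s).round()))   # a subset of {-1., 0., 1.}
print((x - x_q).abs().max())             # small round-off error only
```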

examples/bitnet-1.58b/vllm_workspace/conftest.py (new file, mode 100644)
import contextlib
import gc
import os
import sys
from collections import UserList
from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoModelForVision2Seq,
    AutoTokenizer,
    BatchEncoding,
)

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.config import TokenizerPoolConfig
from vllm.distributed import destroy_distributed_environment, destroy_model_parallel
from vllm.inputs import TextPrompt
from vllm.logger import init_logger
from vllm.sequence import SampleLogprobs
from vllm.utils import cuda_device_count_stateless, is_cpu

logger = init_logger(__name__)

_TEST_DIR = os.path.dirname(__file__)
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
def _read_prompts(filename: str) -> List[str]:
    with open(filename, "r") as f:
        prompts = f.readlines()
        return prompts


class _ImageAssetPrompts(TypedDict):
    stop_sign: str
    cherry_blossom: str


if sys.version_info < (3, 9):
    # UserList cannot be subscripted
    class _ImageAssetsBase(UserList):
        pass
else:

    class _ImageAssetsBase(UserList[ImageAsset]):
        pass


class _ImageAssets(_ImageAssetsBase):

    def __init__(self) -> None:
        super().__init__([
            ImageAsset("stop_sign"),
            ImageAsset("cherry_blossom"),
        ])

    def prompts(self, prompts: _ImageAssetPrompts) -> List[str]:
        """
        Convenience method to define the prompt for each test image.

        The order of the returned prompts matches the order of the
        assets when iterating through this object.
        """
        return [prompts["stop_sign"], prompts["cherry_blossom"]]


IMAGE_ASSETS = _ImageAssets()
"""Singleton instance of :class:`_ImageAssets`."""
def cleanup():
    destroy_model_parallel()
    destroy_distributed_environment()
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    gc.collect()
    if not is_cpu():
        torch.cuda.empty_cache()


@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Allow subdirectories to skip global cleanup by overriding this fixture.
    This can provide a ~10x speedup for non-GPU unit tests since they don't need
    to initialize torch.
    """
    if not request.node.get_closest_marker("skip_global_cleanup"):
        return False


@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    yield
    if should_do_global_cleanup_after_test:
        cleanup()


@pytest.fixture
def example_prompts() -> List[str]:
    prompts = []
    for filename in _TEST_PROMPTS:
        prompts += _read_prompts(filename)
    return prompts


@pytest.fixture
def example_long_prompts() -> List[str]:
    prompts = []
    for filename in _LONG_PROMPTS:
        prompts += _read_prompts(filename)
    return prompts


@pytest.fixture(scope="session")
def image_assets() -> _ImageAssets:
    return IMAGE_ASSETS
_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
}

_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding)
class HfRunner:

    def wrap_device(self, input: _T) -> _T:
        if not is_cpu():
            return input.to("cuda")
        else:
            return input.to("cpu")

    def __init__(
        self,
        model_name: str,
        dtype: str = "half",
        *,
        model_kwargs: Optional[Dict[str, Any]] = None,
        is_embedding_model: bool = False,
        is_vision_model: bool = False,
        is_sparseml_model: bool = False,
    ) -> None:
        assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]

        self.model_name = model_name

        if is_embedding_model:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer
            self.model = self.wrap_device(
                SentenceTransformer(
                    model_name,
                    device="cpu",
                ).to(dtype=torch_dtype))
        else:
            if is_vision_model:
                auto_cls = AutoModelForVision2Seq
            elif is_sparseml_model:
                from sparseml.transformers import SparseAutoModelForCausalLM
                auto_cls = SparseAutoModelForCausalLM
            else:
                auto_cls = AutoModelForCausalLM

            model_kwargs = model_kwargs if model_kwargs is not None else {}
            self.model = self.wrap_device(
                auto_cls.from_pretrained(
                    model_name,
                    torch_dtype=torch_dtype,
                    trust_remote_code=True,
                    **model_kwargs,
                ))

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )

        try:
            # don't put this import at the top level
            # it will call torch.cuda.device_count()
            from transformers import AutoProcessor  # noqa: F401
            self.processor = AutoProcessor.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=True,
            )
        except Exception:
            logger.warning(
                "Unable to auto-load processor from HuggingFace for model %s. Using tokenizer instead.",
                model_name,
            )
            self.processor = self.tokenizer
    def generate(
        self,
        prompts: List[str],
        images: Optional[List[Image.Image]] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        if images:
            assert len(prompts) == len(images)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for i, prompt in enumerate(prompts):
            processor_kwargs: Dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }
            if images is not None and images[i] is not None:
                processor_kwargs["images"] = images[i]

            inputs = self.processor(**processor_kwargs)

            output_ids = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                **kwargs,
            )
            output_str = self.processor.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            output_ids = output_ids.cpu().tolist()
            outputs.append((output_ids, output_str))

        return outputs
    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[List[Image.Image]] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[int], str]]:
        outputs = self.generate(
            prompts,
            do_sample=False,
            max_new_tokens=max_tokens,
            images=images,
            **kwargs,
        )

        return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs]

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        outputs = self.generate(
            prompts,
            do_sample=False,
            max_new_tokens=max_tokens,
            num_beams=beam_width,
            num_return_sequences=beam_width,
        )
        for i in range(len(outputs)):
            output_ids, output_str = outputs[i]
            for j in range(len(output_ids)):
                output_ids[j] = [x for x in output_ids[j] if x != self.tokenizer.pad_token_id]
            outputs[i] = (output_ids, output_str)
        return outputs
    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[List[Image.Image]] = None,
        **kwargs: Any,
    ) -> List[List[torch.Tensor]]:
        all_logprobs: List[List[torch.Tensor]] = []
        for i, prompt in enumerate(prompts):
            processor_kwargs: Dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }
            if images is not None and images[i] is not None:
                processor_kwargs["images"] = images[i]

            inputs = self.processor(**processor_kwargs)

            output = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )
            seq_logprobs: List[torch.Tensor] = []
            for hidden_states in output.hidden_states:
                last_hidden_states = hidden_states[-1][0]
                logits = torch.matmul(
                    last_hidden_states,
                    self.model.get_output_embeddings().weight.t(),
                )
                if self.model.get_output_embeddings().bias is not None:
                    logits += self.model.get_output_embeddings().bias.unsqueeze(0)
                logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
                seq_logprobs.append(logprobs)
            all_logprobs.append(seq_logprobs)
        return all_logprobs
    def generate_greedy_logprobs_limit(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[List[Image.Image]] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
        all_logprobs: List[List[Dict[int, float]]] = []
        all_output_ids: List[List[int]] = []
        all_output_strs: List[str] = []

        for i, prompt in enumerate(prompts):
            processor_kwargs: Dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }
            if images is not None and images[i] is not None:
                processor_kwargs["images"] = images[i]

            inputs = self.processor(**processor_kwargs)
            input_ids = inputs.input_ids

            output = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            seq_logprobs: List[torch.Tensor] = []
            for _, hidden_states in enumerate(output.hidden_states):
                last_hidden_states = hidden_states[-1][0]
                logits = torch.matmul(
                    last_hidden_states,
                    self.model.get_output_embeddings().weight.t(),
                )
                if getattr(self.model.get_output_embeddings(), "bias", None) is not None:
                    logits += self.model.get_output_embeddings().bias.unsqueeze(0)
                logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
                seq_logprobs.append(logprobs)

            # convert to dict
            seq_logprobs_lst: List[Dict[int, float]] = []
            for tok_idx, tok_logprobs in enumerate(seq_logprobs):
                # drop prompt logprobs
                if tok_idx == 0:
                    tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
                topk = tok_logprobs.topk(num_logprobs)

                tok_logprobs_dct = {}
                for token_id, logprob in zip(topk.indices[0], topk.values[0]):
                    tok_logprobs_dct[token_id.item()] = logprob.item()

                seq_logprobs_lst.append(tok_logprobs_dct)

            all_logprobs.append(seq_logprobs_lst)
            seq_ids = output.sequences[0]
            output_len = seq_ids.shape[0] - input_ids.shape[1]
            output_ids = seq_ids[-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        outputs = zip(all_output_ids, all_output_strs, all_logprobs)
        return [(output_ids, output_str, output_logprobs)
                for output_ids, output_str, output_logprobs in outputs]
    def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
        return self.model.encode(prompts)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        del self.model
        cleanup()


@pytest.fixture(scope="session")
def hf_runner():
    return HfRunner
class VllmRunner:

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        # Use smaller max model length, otherwise bigger model cannot run due
        # to kv cache size limit.
        max_model_len: int = 1024,
        dtype: str = "half",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: bool = False,
        swap_space: int = 4,
        enforce_eager: bool = False,
        **kwargs,
    ) -> None:
        self.model = LLM(
            model=model_name,
            tokenizer=tokenizer_name,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
            **kwargs,
        )
    def generate(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[List[Image.Image]] = None,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        if images is not None:
            assert len(prompts) == len(images)

        inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
        if images is not None:
            for i, image in enumerate(images):
                inputs[i]["multi_modal_data"] = {"image": image}

        req_outputs = self.model.generate(inputs, sampling_params=sampling_params)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            req_sample_output_ids: List[List[int]] = []
            req_sample_output_strs: List[str] = []
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = list(sample.token_ids)
                req_sample_output_ids.append(prompt_ids + output_ids)
                req_sample_output_strs.append(prompt_str + output_str)
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs
    def generate_w_logprobs(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[List[Image.Image]] = None,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        assert sampling_params.logprobs is not None

        if images is not None:
            assert len(prompts) == len(images)

        inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
        if images is not None:
            for i, image in enumerate(images):
                inputs[i]["multi_modal_data"] = {"image": image}

        req_outputs = self.model.generate(inputs, sampling_params=sampling_params)
        outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
        for req_output in req_outputs:
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = sample.token_ids
                output_logprobs = sample.logprobs
            outputs.append((output_ids, output_str, output_logprobs))
        return outputs
    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[List[Image.Image]] = None,
    ) -> List[Tuple[List[int], str]]:
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = self.generate(prompts, greedy_params, images=images)
        return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[List[Image.Image]] = None,
    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
        greedy_logprobs_params = SamplingParams(
            temperature=0.0, max_tokens=max_tokens, logprobs=num_logprobs)
        outputs = self.generate_w_logprobs(prompts, greedy_logprobs_params, images=images)

        return [(output_ids, output_str, output_logprobs)
                for output_ids, output_str, output_logprobs in outputs]
    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        beam_search_params = SamplingParams(
            n=beam_width,
            use_beam_search=True,
            temperature=0.0,
            max_tokens=max_tokens,
        )
        outputs = self.generate(prompts, beam_search_params)
        return outputs

    def encode(self, prompts: List[str]) -> List[List[float]]:
        req_outputs = self.model.encode(prompts)
        outputs = []
        for req_output in req_outputs:
            embedding = req_output.outputs.embedding
            outputs.append(embedding)
        return outputs
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        del self.model
        cleanup()


@pytest.fixture(scope="session")
def vllm_runner():
    return VllmRunner


def get_tokenizer_pool_config(tokenizer_group_type):
    if tokenizer_group_type is None:
        return None
    if tokenizer_group_type == "ray":
        return TokenizerPoolConfig(pool_size=1, pool_type="ray", extra_config={})
    raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}")


@pytest.fixture()
def temporary_enable_log_propagate():
    import logging
    logger = logging.getLogger("vllm")
    logger.propagate = True
    yield
    logger.propagate = False


@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
    # To capture vllm log, we should enable propagate=True temporarily
    # because caplog depends on logs propagated to the root logger.
    yield caplog


@pytest.fixture(scope="session")
def num_gpus_available():
    """Get number of GPUs without initializing the CUDA context
    in current process."""
    return cuda_device_count_stateless()
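
The `hf_runner` and `vllm_runner` fixtures hand the test the runner classes themselves, which are then instantiated as context managers so `cleanup()` runs on exit. A minimal test sketch (not part of the commit); the checkpoint path is a placeholder:

```python
# A minimal sketch of consuming the vllm_runner fixture; "path/to/bitnet_checkpoint"
# is a placeholder, not a path shipped with this commit.
def test_greedy_generation(vllm_runner):
    prompts = ["Hi, tell me about microsoft?"]
    with vllm_runner("path/to/bitnet_checkpoint", dtype="half", enforce_eager=True) as model:
        outputs = model.generate_greedy(prompts, max_tokens=32)
    for token_ids, text in outputs:
        assert len(token_ids) > 0 and isinstance(text, str)
```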

examples/bitnet-1.58b/vllm_workspace/inference_with_compress_format.py (new file, mode 100644)
"""Compare the outputs of a GPTQ model to a Marlin model.
Note: GPTQ and Marlin do not have bitwise correctness.
As a result, in this test, we just confirm that the top selected tokens of the
Marlin/GPTQ models are in the top 3 selections of each other.
Note: Marlin internally uses locks to synchronize the threads. This can
result in very slight nondeterminism for Marlin. As a result, we re-run the test
up to 3 times to see if we pass.
Run `pytest tests/models/test_marlin.py`.
"""
from conftest import VllmRunner
import os
import argparse

# get the path of the current file
current_file_path = os.path.realpath(__file__)
current_dir = os.path.dirname(current_file_path)
ckpt_path = os.path.join(current_dir, "../models/ckpt_bitnet_b1_58-3B_bitblas")
parser = argparse.ArgumentParser(description="Inference with BitNet")
parser.add_argument(
    "--ckpt_path",
    type=str,
    default=ckpt_path,
    help="Path to the checkpoint",
)
args = parser.parse_args()

ckpt_path = args.ckpt_path

with VllmRunner(
        ckpt_path,
        dtype="half",
        quantization="bitblas",
        # set enforce_eager = False to enable cuda graph
        # set enforce_eager = True to disable cuda graph
        enforce_eager=False,
) as bitnet_model:
    bitbnet_outputs = bitnet_model.generate_greedy(["Hi, tell me about microsoft?"], max_tokens=1024)
    print("bitnet inference:")
    print(bitbnet_outputs[0][0])
    print(bitbnet_outputs[0][1])

examples/bitnet-1.58b/vllm_workspace/inference_with_native_format.py (new file, mode 100644)
"""Compare the outputs of a GPTQ model to a Marlin model.
Note: GPTQ and Marlin do not have bitwise correctness.
As a result, in this test, we just confirm that the top selected tokens of the
Marlin/GPTQ models are in the top 3 selections of each other.
Note: Marlin internally uses locks to synchronize the threads. This can
result in very slight nondeterminism for Marlin. As a result, we re-run the test
up to 3 times to see if we pass.
Run `pytest tests/models/test_marlin.py`.
"""
from conftest import VllmRunner
import os
import argparse

# get the path of the current file
current_file_path = os.path.realpath(__file__)
current_dir = os.path.dirname(current_file_path)
ckpt_path = os.path.join(current_dir, "../models/ckpt_bitnet_b1_58-3B")
parser = argparse.ArgumentParser(description="Inference with BitNet")
parser.add_argument(
    "--ckpt_path",
    type=str,
    default=ckpt_path,
    help="Path to the checkpoint",
)
args = parser.parse_args()

ckpt_path = args.ckpt_path

with VllmRunner(
        ckpt_path,
        dtype="half",
        quantization="bitnet_bitblas",
        gpu_memory_utilization=0.5,
        # set enforce_eager = False to enable cuda graph
        # set enforce_eager = True to disable cuda graph
        enforce_eager=False,
) as bitnet_model:
    bitbnet_outputs = bitnet_model.generate_greedy(["Hi, tell me about microsoft?"], max_tokens=128)
    print("bitnet inference output:")
    print(bitbnet_outputs[0][0])
    print(bitbnet_outputs[0][1])

examples/bitnet-1.58b/vllm_workspace/utils.py (new file, mode 100644)
from typing import Dict, List, Tuple

TokensText = Tuple[List[int], str]


def check_outputs_equal(outputs_0_lst: List[TokensText], outputs_1_lst: List[TokensText], name_0: str, name_1: str):
    """
    Compare the two sequences generated by different models,
    which should be equal.
    """
    assert len(outputs_0_lst) == len(outputs_1_lst)

    for prompt_idx, (outputs_0, outputs_1) in enumerate(zip(outputs_0_lst, outputs_1_lst)):
        output_ids_0, output_str_0 = outputs_0
        output_ids_1, output_str_1 = outputs_1

        assert output_str_0 == output_str_1, f"Test {prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}"
        assert output_ids_0 == output_ids_1, f"Test {prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}"


TokensTextLogprobs = Tuple[List[int], str, List[Dict[int, float]]]


def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs], outputs_1_lst: List[TokensTextLogprobs],
                         name_0: str, name_1: str):
    """
    Compare the logprobs of two sequences generated by different models,
    which should be similar but not necessarily equal.
    """
    assert len(outputs_0_lst) == len(outputs_1_lst)

    # Loop through responses to each prompt.
    for prompt_idx, (outputs_0, outputs_1) in enumerate(zip(outputs_0_lst, outputs_1_lst)):
        output_ids_0, output_str_0, logprobs_0 = outputs_0
        output_ids_1, output_str_1, logprobs_1 = outputs_1

        # Loop through generated tokens.
        for idx, (output_id_0, output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):
            # If generated tokens don't match, then
            if output_id_0 != output_id_1:
                # Each predicted token must be in top N logprobs of the other
                assert output_id_0 in logprobs_1[idx], f"Test {prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}"
                assert output_id_1 in logprobs_0[idx], f"Test {prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}"

                # Break out since sequences will now diverge.
                break
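
`check_logprobs_close` tolerates the first diverging token as long as each model's choice appears in the other's top-k logprobs. A minimal sketch (not part of the commit) on hand-written toy data:

```python
# Token 7 vs token 9 differ at position 1, but each appears in the other's top-k
# logprob dict, so the check passes; it would raise AssertionError otherwise.
ref = [([5, 7], "a b", [{5: -0.1, 9: -2.0}, {7: -0.3, 9: -1.1}])]
cand = [([5, 9], "a c", [{5: -0.2, 7: -1.9}, {9: -0.4, 7: -1.2}])]
check_logprobs_close(ref, cand, name_0="reference", name_1="candidate")
```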

examples/blocksparse_attention/README.md (new file, mode 100644)
# Block-Sparse Flash-Attention

Tilelang implementation of block-sparse flash-attention kernels.

The kernels have been used in [Rectified Sparse Attention](https://arxiv.org/abs/2506.04108) and [SeerAttention-R](https://arxiv.org/abs/2506.08889).
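
A minimal sketch (not part of the commit) of driving the Triton reference kernel defined in `block_sparse_attn_triton.py` below: score key/query blocks, keep the top-k blocks per row, then run attention only on the surviving blocks. The local import path is an assumption based on the files living in the same directory.

```python
import math
import torch
from block_sparse_attn_triton import get_sparse_attn_mask_from_topk, block_sparse_triton_fn

BATCH, HEADS, SEQ_LEN, D_HEAD, BLOCK, TOPK = 1, 4, 256, 64, 64, 2
q = torch.randn(BATCH, HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Block-level importance scores (random here); shape (B, H, ceil(S/BLOCK), ceil(S/BLOCK)).
n_blocks = math.ceil(SEQ_LEN / BLOCK)
scores = torch.randn(BATCH, HEADS, n_blocks, n_blocks, device="cuda", dtype=torch.bfloat16)
block_mask = get_sparse_attn_mask_from_topk(scores, topk=TOPK)  # boolean, lower-triangular

out = block_sparse_triton_fn(q, k, v, block_mask, 1.0 / math.sqrt(D_HEAD))
print(out.shape)  # (BATCH, HEADS, SEQ_LEN, D_HEAD)
```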

examples/blocksparse_attention/block_sparse_attn_triton.py (new file, mode 100644)
# ruff: noqa: E712
import math
import torch

import triton
import triton.language as tl
import torch.nn.functional as F


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
    bsz, num_head, downsample_len, _ = x.shape
    # N_CTX = downsample_len * BLOCK
    sparse_index = torch.topk(x, topk, dim=-1).indices
    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
                            False,
                            dtype=torch.bool,
                            device=x.device)
    dense_mask.scatter_(-1, sparse_index, True)
    if use_dense_for_last_block:
        dense_mask[:, :, -2:, :] = True
    dense_mask.tril_()
    return dense_mask


def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False):
    dense_mask = x > threshold
    if use_dense_for_last_block:
        dense_mask[:, :, -2:, :] = True
    dense_mask.tril_()
    return dense_mask
@triton.jit
def _fwd_kernel_inner(
    acc,
    l_i,
    m_i,
    q,
    k_block_col_idx,
    block_mask_ptr,
    k_ptrs,
    v_ptrs,
    offs_m,
    offs_n,
    stride_kt,
    stride_vt,
    stride_bmask_n,
    sm_scale,
    seqlen_k,
    past_len,
    LAST_K_BLOCK: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n)
    # print
    if mask_val == True:
        start_n = k_block_col_idx * BLOCK_N
        # -- compute qk ----
        k = tl.load(k_ptrs + start_n * stride_kt)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)

        qk *= sm_scale

        # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
        if LAST_K_BLOCK:
            qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf"))

        m_ij = tl.maximum(m_i, tl.max(qk, 1))
        qk -= m_ij[:, None]
        p = tl.exp(qk)
        l_ij = tl.sum(p, 1)
        alpha = tl.exp(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        acc = acc * alpha[:, None]
        # update acc
        v = tl.load(v_ptrs + start_n * stride_vt)

        p = p.to(v.type.element_ty)

        acc += tl.dot(p, v)
        # update m_i and l_i
        m_i = m_ij
    return acc, l_i, m_i
@triton.jit
def _fwd_kernel(
    Q,
    K,
    V,
    sm_scale,
    block_mask_ptr,
    Out,
    stride_qz,
    stride_qh,
    stride_qm,
    stride_qd,
    stride_kz,
    stride_kh,
    stride_kn,
    stride_kd,
    stride_vz,
    stride_vh,
    stride_vn,
    stride_vd,
    stride_bmz,
    stride_bmh,
    stride_bmm,
    stride_bmn,
    stride_oz,
    stride_oh,
    stride_om,
    stride_od,
    H,
    N_CTX,
    PAST_LEN,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
):
    Q_LEN = N_CTX - PAST_LEN
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_h = off_hz % H
    off_z = off_hz // H
    Q += off_z * stride_qz + off_h * stride_qh
    K += off_z * stride_kz + off_h * stride_kh
    V += off_z * stride_vz + off_h * stride_vh
    block_mask_ptr += off_z * stride_bmz + off_h * stride_bmh

    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)

    off_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd
    # off_k = offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd
    off_k = offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kd
    off_v = offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd

    # Initialize pointers to Q, K, V
    q_ptrs = Q + off_q
    k_ptrs = K + off_k
    v_ptrs = V + off_v

    mask_ptrs = block_mask_ptr + start_m * stride_bmm

    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    q = tl.load(q_ptrs, mask=offs_m[:, None] < Q_LEN)

    k_block_start = 0
    k_block_end = tl.cdiv((start_m + 1) * BLOCK_M, BLOCK_N)

    # loop over k, v and update accumulator
    for col_idx in range(k_block_start, k_block_end):
        acc, l_i, m_i = _fwd_kernel_inner(
            acc,
            l_i,
            m_i,
            q,
            col_idx,
            mask_ptrs,
            k_ptrs,
            v_ptrs,
            offs_m,
            offs_n,
            stride_kn,
            stride_vn,
            stride_bmn,
            sm_scale,
            N_CTX,
            PAST_LEN,
            col_idx == k_block_end - 1,
            BLOCK_M,
            BLOCK_N,
        )

    m_i += tl.math.log(l_i)
    l_recip = 1 / l_i[:, None]
    acc = acc * l_recip
    acc = acc.to(Out.dtype.element_ty)

    off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX)
def
_forward
(
ctx
,
q
,
k
,
v
,
block_sparse_mask
,
sm_scale
,
BLOCK_M
=
64
,
BLOCK_N
=
64
,
num_warps
=
None
,
num_stages
=
1
,
out
=
None
):
assert
q
.
shape
[
-
1
]
==
k
.
shape
[
-
1
]
==
v
.
shape
[
-
1
]
assert
k
.
shape
[
2
]
==
v
.
shape
[
2
]
o
=
out
if
out
is
not
None
else
torch
.
empty_like
(
q
).
contiguous
()
grid
=
(
triton
.
cdiv
(
q
.
shape
[
2
],
BLOCK_M
),
q
.
shape
[
0
]
*
q
.
shape
[
1
])
assert
q
.
shape
[
-
1
]
in
[
64
,
128
]
BLOCK_DMODEL
=
q
.
shape
[
-
1
]
if
is_hip
():
num_warps
,
num_stages
=
8
,
1
else
:
num_warps
,
num_stages
=
4
,
2
N_CTX
=
k
.
shape
[
2
]
PAST_LEN
=
N_CTX
-
q
.
shape
[
2
]
H
=
q
.
shape
[
1
]
_fwd_kernel
[
grid
](
q
,
k
,
v
,
sm_scale
,
block_sparse_mask
,
o
,
*
q
.
stride
(),
*
k
.
stride
(),
*
v
.
stride
(),
*
block_sparse_mask
.
stride
(),
*
o
.
stride
(),
H
,
N_CTX
,
PAST_LEN
,
BLOCK_M
,
BLOCK_N
,
BLOCK_DMODEL
,
num_warps
=
num_warps
,
num_stages
=
num_stages
,
)
return
o
class
_sparse_attention
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
q
,
k
,
v
,
block_sparse_dense
,
sm_scale
):
# shape constraints
return
_forward
(
ctx
,
q
,
k
,
v
,
block_sparse_dense
,
sm_scale
)
@
staticmethod
def
backward
(
ctx
,
do
):
# No gradient propagation.
raise
NotImplementedError
(
"It does not support gradient propagation yet"
)
return
None
,
None
,
None
,
None
,
None
block_sparse_triton_fn
=
_sparse_attention
.
apply
def
test_topk_sparse_attention
():
# Config
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
=
1
,
1
,
256
,
64
TOPK
=
2
# Keep top 8 elements per row
BLOCK
=
64
torch
.
manual_seed
(
0
)
# Create inputs
q
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
k
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
v
=
torch
.
randn
(
BATCH
,
N_HEADS
,
SEQ_LEN
,
D_HEAD
,
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
sm_scale
=
1.0
/
(
D_HEAD
**
0.5
)
# Create sparse mask (downsampled to block level)
downsample_factor
=
BLOCK
downsample_len
=
math
.
ceil
(
SEQ_LEN
/
downsample_factor
)
print
(
"downsample_len"
,
downsample_len
)
x_ds
=
torch
.
randn
([
BATCH
,
N_HEADS
,
downsample_len
,
downsample_len
],
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
x_ds
[:,
:,
:,
0
]
=
100
print
(
"x_ds.shape"
,
x_ds
.
shape
)
block_mask
=
get_sparse_attn_mask_from_topk
(
x_ds
,
topk
=
TOPK
)
# print("block_mask", block_mask)
print
(
"block_mask.shape"
,
block_mask
.
shape
)
# Run Triton kernel
triton_output
=
block_sparse_triton_fn
(
q
,
k
,
v
,
block_mask
,
sm_scale
)
# Compute reference
# Expand block mask to full attention matrix
full_mask
=
torch
.
kron
(
block_mask
.
float
(),
torch
.
ones
(
BLOCK
,
BLOCK
,
device
=
"cuda"
))
full_mask
=
full_mask
[...,
:
SEQ_LEN
,
:
SEQ_LEN
].
bool
()
full_mask
=
full_mask
&
torch
.
tril
(
torch
.
ones_like
(
full_mask
))
# Apply causal
# PyTorch reference implementation
attn
=
torch
.
einsum
(
"bhsd,bhtd->bhst"
,
q
,
k
)
*
sm_scale
attn
=
attn
.
masked_fill
(
~
full_mask
,
float
(
"-inf"
))
attn
=
F
.
softmax
(
attn
,
dim
=-
1
)
ref_output
=
torch
.
einsum
(
"bhst,bhtd->bhsd"
,
attn
,
v
)
# print("ref_output", ref_output)
# print("triton_output", triton_output)
# Verify accuracy
assert
torch
.
allclose
(
triton_output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
"Triton output doesn't match reference"
print
(
"Pass topk sparse attention test with qlen == klen"
)
def
test_topk_sparse_attention_qlt_kl
():
BATCH
,
N_HEADS
=
2
,
4
Q_LEN
,
K_LEN
,
D_HEAD
=
128
,
256
,
64
# qlen < klen; here, past_len = 256 - 128 = 128.
TOPK
=
1
BLOCK
=
64
# block size used in downsampling
torch
.
manual_seed
(
0
)
# Create inputs.
q
=
torch
.
randn
(
BATCH
,
N_HEADS
,
Q_LEN
,
D_HEAD
,
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
k
=
torch
.
randn
(
BATCH
,
N_HEADS
,
K_LEN
,
D_HEAD
,
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
v
=
torch
.
randn
(
BATCH
,
N_HEADS
,
K_LEN
,
D_HEAD
,
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
# softmax scale
sm_scale
=
1.0
/
(
D_HEAD
**
0.5
)
downsample_factor
=
BLOCK
print
(
"downsample_factor"
,
downsample_factor
)
downsample_len
=
math
.
ceil
(
K_LEN
/
downsample_factor
)
# number of blocks along one dimension
print
(
"downsample_len"
,
downsample_len
)
x_ds
=
torch
.
randn
(
BATCH
,
N_HEADS
,
downsample_len
,
downsample_len
,
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
# Force the first column to be high so that the first block is always selected.
x_ds
[:,
:,
:,
0
]
=
100
block_mask
=
get_sparse_attn_mask_from_topk
(
x_ds
,
topk
=
TOPK
)
print
(
"block_mask"
,
block_mask
)
print
(
"block_mask.shape"
,
block_mask
.
shape
)
# Run Triton kernel.
triton_output
=
block_sparse_triton_fn
(
q
,
k
,
v
,
block_mask
,
sm_scale
)
past_len
=
K_LEN
-
Q_LEN
attn
=
torch
.
einsum
(
"bhsd,bhtd->bhst"
,
q
,
k
)
*
sm_scale
full_mask_full
=
torch
.
kron
(
block_mask
.
float
(),
torch
.
ones
(
BLOCK
,
BLOCK
,
device
=
"cuda"
)).
bool
()
full_mask_full
=
full_mask_full
[...,
:
K_LEN
,
:
K_LEN
]
effective_mask
=
full_mask_full
[...,
past_len
:
K_LEN
,
:]
# shape: (B, H, Q_LEN, K_LEN)
i_global
=
torch
.
arange
(
past_len
,
K_LEN
,
device
=
k
.
device
).
unsqueeze
(
1
)
# shape: (Q_LEN, 1)
j_global
=
torch
.
arange
(
K_LEN
,
device
=
k
.
device
).
unsqueeze
(
0
)
# shape: (1, K_LEN)
causal_mask
=
j_global
<=
i_global
# shape: (Q_LEN, K_LEN)
final_mask
=
effective_mask
&
causal_mask
# shape: (B, H, Q_LEN, K_LEN)
attn
=
attn
.
masked_fill
(
~
final_mask
,
float
(
"-inf"
))
attn
=
F
.
softmax
(
attn
,
dim
=-
1
)
ref_output
=
torch
.
einsum
(
"bhst,bhtd->bhsd"
,
attn
,
v
)
# Verify accuracy.
assert
torch
.
allclose
(
triton_output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
"Triton output doesn't match reference when qlen < klen"
print
(
"Pass topk sparse attention test with qlen < klen"
)
def
main
():
test_topk_sparse_attention
()
test_topk_sparse_attention_qlt_kl
()
if
__name__
==
"__main__"
:
main
()
examples/blocksparse_attention/example_tilelang_block_sparse_attn.py
0 → 100644
View file @
cf6e11c9
import math
import torch
import tilelang
import tilelang.language as T
import torch.nn.functional as F


def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
    bsz, num_head, downsample_len, _ = x.shape
    # N_CTX = downsample_len * BLOCK
    sparse_index = torch.topk(x, topk, dim=-1).indices
    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
                            False,
                            dtype=torch.bool,
                            device=x.device)
    dense_mask.scatter_(-1, sparse_index, True)
    if use_dense_for_last_block:
        dense_mask[:, :, -2:, :] = True
    dense_mask.tril_()
    return dense_mask


def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False):
    dense_mask = x > threshold
    if use_dense_for_last_block:
        dense_mask[:, :, -2:, :] = True
    dense_mask.tril_()
    return dense_mask


@tilelang.jit(
    out_idx=[4],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal):
    block_M = 64
    block_N = 64
    num_stages = 1
    threads = 128
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    shape = [batch, heads, seq_len, dim]
    block_mask_shape = [batch, heads, downsample_len, downsample_len]

    dtype = T.float16
    accum_dtype = T.float32
    block_mask_dtype = T.bool

    def kernel_func(block_M, block_N, num_stages, threads):

        @T.macro
        def MMA0(
            K: T.Tensor(shape, dtype),
            Q_shared: T.SharedBuffer([block_M, dim], dtype),
            K_shared: T.SharedBuffer([block_N, dim], dtype),
            acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
            k: T.int32,
            bx: T.int32,
            by: T.int32,
            bz: T.int32,
        ):
            T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
            if is_causal:
                for i, j in T.Parallel(block_M, block_N):
                    acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                                 -T.infinity(acc_s.dtype))
            else:
                T.clear(acc_s)
            T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

        @T.macro
        def MMA1(
            V: T.Tensor(shape, dtype),
            V_shared: T.SharedBuffer([block_M, dim], dtype),
            acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
            acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
            k: T.int32,
            by: T.int32,
            bz: T.int32,
        ):
            T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
            T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

        @T.macro
        def Softmax(
            acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
            acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
            scores_max: T.FragmentBuffer([block_M], accum_dtype),
            scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
            scores_sum: T.FragmentBuffer([block_M], accum_dtype),
            logsum: T.FragmentBuffer([block_M], accum_dtype),
        ):
            T.copy(scores_max, scores_max_prev)
            T.fill(scores_max, -T.infinity(accum_dtype))
            T.reduce_max(acc_s, scores_max, dim=1, clear=False)
            # To do causal softmax, we need to set the scores_max to 0 if it is -inf
            # This process is called Check_inf in FlashAttention3 code, and it only need to be done
            # in the first ceil_div(kBlockM, kBlockN) steps.
            # for i in T.Parallel(block_M):
            #     scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
            for i in T.Parallel(block_M):
                scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
            for i, j in T.Parallel(block_M, block_N):
                # Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
                # max * log_2(e)) This allows the compiler to use the ffma
                # instruction instead of fadd and fmul separately.
                acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
            T.reduce_sum(acc_s, scores_sum, dim=1)
            for i in T.Parallel(block_M):
                logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
            T.copy(acc_s, acc_s_cast)

        @T.macro
        def Rescale(
            acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
        ):
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] *= scores_scale[i]

        @T.prim_func
        def blocksparse_flashattn(
            Q: T.Tensor(shape, dtype),
            K: T.Tensor(shape, dtype),
            V: T.Tensor(shape, dtype),
            BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype),
            Output: T.Tensor(shape, dtype),
        ):
            with T.Kernel(
                    T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
                Q_shared = T.alloc_shared([block_M, dim], dtype)
                K_shared = T.alloc_shared([block_N, dim], dtype)
                V_shared = T.alloc_shared([block_N, dim], dtype)
                O_shared = T.alloc_shared([block_M, dim], dtype)
                acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
                acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
                acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
                scores_max = T.alloc_fragment([block_M], accum_dtype)
                scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
                scores_scale = T.alloc_fragment([block_M], accum_dtype)
                scores_sum = T.alloc_fragment([block_M], accum_dtype)
                logsum = T.alloc_fragment([block_M], accum_dtype)
                block_mask = T.alloc_local([downsample_len], block_mask_dtype)

                T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
                T.fill(acc_o, 0)
                T.fill(logsum, 0)
                T.fill(scores_max, -T.infinity(accum_dtype))

                for vj in T.serial(downsample_len):
                    block_mask[vj] = BlockSparseMask[bz, by, bx, vj]

                loop_range = (
                    T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
                        (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))

                for k in T.Pipelined(loop_range, num_stages=num_stages):
                    if block_mask[k] != 0:
                        MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
                        Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
                                scores_sum, logsum)
                        Rescale(acc_o, scores_scale)
                        MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
                for i, j in T.Parallel(block_M, dim):
                    acc_o[i, j] /= logsum[i]
                T.copy(acc_o, O_shared)
                T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])

        return blocksparse_flashattn

    return kernel_func(block_M, block_N, num_stages, threads)


def test_topk_sparse_attention():
    # Config
    BATCH, N_HEADS, SEQ_LEN, D_HEAD = 1, 1, 256, 64
    TOPK = 2  # Keep top 8 elements per row
    BLOCK = 64
    torch.manual_seed(0)

    # Create inputs
    q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
    k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
    v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
    sm_scale = 1.0 / (D_HEAD**0.5)

    # Create sparse mask (downsampled to block level)
    downsample_factor = BLOCK
    downsample_len = math.ceil(SEQ_LEN / downsample_factor)
    x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
                       device="cuda",
                       dtype=torch.bfloat16)
    x_ds[:, :, :, 0] = 100
    block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)

    # Run tilelang kernel
    kernel = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
    tilelang_output = kernel(q, k, v, block_mask)

    # Compute reference
    # Expand block mask to full attention matrix
    full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda"))
    full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
    full_mask = full_mask & torch.tril(torch.ones_like(full_mask))  # Apply causal

    # PyTorch reference implementation
    attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
    attn = attn.masked_fill(~full_mask, float("-inf"))
    attn = F.softmax(attn, dim=-1)
    ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)

    print("ref_output", ref_output)
    print("tilelang_output", tilelang_output)

    # Verify accuracy
    torch.testing.assert_close(tilelang_output, ref_output, atol=1e-2, rtol=1e-2)
    print("Pass topk sparse attention test with qlen == klen")


def main():
    test_topk_sparse_attention()


if __name__ == "__main__":
    main()
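
A quick numerical check (illustrative only, not part of the commit) of the identity the tilelang kernel above relies on when it folds log2(e) into `scale`: exp(x) equals exp2(x * log2(e)), which is what lets the softmax use `T.exp2` and fused multiply-add instead of a separate exp:

```python
# exp(x) == exp2(x * log2(e)); 1.44269504 is the same constant used for `scale` above.
import torch

x = torch.randn(8)
log2_e = 1.44269504
assert torch.allclose(torch.exp(x), torch.exp2(x * log2_e), atol=1e-6)
```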
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py
0 → 100644
View file @
cf6e11c9
# ruff: noqa
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
from einops import rearrange, einsum
import argparse
import time
import math
from heuristic import num_splits_heuristic


def flashattn(batch, heads, heads_kv, dim, dim_v):
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    dtype = T.float16
    accum_dtype = T.float32
    kv_group_num = heads // heads_kv

    @tilelang.jit(
        out_idx=[-1],
        pass_configs={
            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
        },
    )
    def kernel_func(block_N, block_H, page_block_size, num_split, num_stages, threads, num_pages,
                    max_num_blocks_per_seq, max_selected_blocks):
        shape_q = [batch, heads, dim]
        shape_k = [num_pages, page_block_size, heads_kv, dim]
        shape_v = [num_pages, page_block_size, heads_kv, dim_v]
        shape_indices = [batch, heads_kv, max_selected_blocks]
        shape_block_table = [batch, max_num_blocks_per_seq]
        shape_o = [batch, heads, dim_v]
        part_shape = [batch, heads, num_split, dim_v]
        valid_block_H = min(block_H, kv_group_num)
        assert block_N <= page_block_size and page_block_size % block_N == 0
        block_ratio = page_block_size // block_N

        @T.macro
        def flash_attn_split(
                Q: T.Tensor(shape_q, dtype),
                K: T.Tensor(shape_k, dtype),
                V: T.Tensor(shape_v, dtype),
                block_indices: T.Tensor(shape_indices, T.int32),
                cache_seqlens: T.Tensor([batch], T.int32),
                block_table: T.Tensor(shape_block_table, T.int32),
                glse: T.Tensor([batch, heads, num_split], accum_dtype),
                Output_partial: T.Tensor(part_shape, accum_dtype),
        ):
            with T.Kernel(
                    batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
                Q_shared = T.alloc_shared([block_H, dim], dtype)
                K_shared = T.alloc_shared([block_N, dim], dtype)
                V_shared = T.alloc_shared([block_N, dim_v], dtype)
                acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
                acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
                acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype)
                scores_max = T.alloc_fragment([block_H], accum_dtype)
                scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
                scores_scale = T.alloc_fragment([block_H], accum_dtype)
                scores_sum = T.alloc_fragment([block_H], accum_dtype)
                logsum = T.alloc_fragment([block_H], accum_dtype)
                has_valid_block = T.alloc_var("bool")

                bid = bx
                hid = by
                sid = bz
                cur_kv_head = hid // (kv_group_num // valid_block_H)

                T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared)
                T.fill(acc_o, 0)
                T.fill(logsum, 0)
                T.fill(scores_max, -T.infinity(accum_dtype))

                num_blocks = max_selected_blocks
                blocks_per_split = T.floordiv(num_blocks, num_split)
                remaining_blocks = T.floormod(num_blocks, num_split)
                loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0)
                start = blocks_per_split * sid + T.min(sid, remaining_blocks)
                has_valid_block = False
                for k in T.Pipelined(loop_range, num_stages=num_stages):
                    logical_block_idx = block_indices[bid, cur_kv_head, start + k]
                    if logical_block_idx >= 0:
                        has_valid_block = True
                        block_table_idx = T.floordiv(logical_block_idx, block_ratio)
                        block_tile_idx = T.floormod(logical_block_idx, block_ratio)
                        physical_block_idx = block_table[bid, block_table_idx]
                        T.copy(
                            K[physical_block_idx,
                              block_tile_idx * block_N:(block_tile_idx + 1) * block_N,
                              cur_kv_head, :], K_shared)
                        T.clear(acc_s)
                        T.gemm(
                            Q_shared,
                            K_shared,
                            acc_s,
                            transpose_B=True,
                            policy=T.GemmWarpPolicy.FullRow)
                        if k == 0:  # assume block_indices is sorted in reverse order, otherwise, remove this if condition
                            for i, j in T.Parallel(block_H, block_N):
                                acc_s[i, j] = T.if_then_else(
                                    logical_block_idx * block_N + j >= cache_seqlens[bid],
                                    -T.infinity(accum_dtype), acc_s[i, j])
                        T.copy(scores_max, scores_max_prev)
                        T.fill(scores_max, -T.infinity(accum_dtype))
                        T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                        for i in T.Parallel(block_H):
                            scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                            scores_scale[i] = T.exp2(scores_max_prev[i] * scale -
                                                     scores_max[i] * scale)
                        for i, j in T.Parallel(block_H, block_N):
                            acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                        T.reduce_sum(acc_s, scores_sum, dim=1)
                        for i in T.Parallel(block_H):
                            logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                        T.copy(acc_s, acc_s_cast)
                        for i, j in T.Parallel(block_H, dim_v):
                            acc_o[i, j] *= scores_scale[i]
                        T.copy(
                            V[physical_block_idx,
                              block_tile_idx * block_N:(block_tile_idx + 1) * block_N,
                              cur_kv_head, :], V_shared)
                        T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
                if has_valid_block:
                    for i, j in T.Parallel(block_H, dim_v):
                        acc_o[i, j] /= logsum[i]
                    for i in T.Parallel(block_H):
                        logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale

                for i in T.Parallel(block_H):
                    if i < valid_block_H:
                        glse[bid, hid * valid_block_H + i, sid] = logsum[i]
                for i, j in T.Parallel(block_H, dim_v):
                    if i < valid_block_H:
                        Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j]

        @T.macro
        def combine(
                glse: T.Tensor([batch, heads, num_split], accum_dtype),
                Output_partial: T.Tensor(part_shape, accum_dtype),
                Output: T.Tensor(shape_o, dtype),
        ):
            with T.Kernel(heads, batch, threads=128) as (by, bz):
                po_local = T.alloc_fragment([dim_v], accum_dtype)
                o_accum_local = T.alloc_fragment([dim_v], accum_dtype)
                lse_local_split = T.alloc_local([1], accum_dtype)
                lse_logsum_local = T.alloc_local([1], accum_dtype)
                lse_max_local = T.alloc_local([1], accum_dtype)
                scale_local = T.alloc_local([1], accum_dtype)
                max_split = T.alloc_local([1], T.int32)

                T.annotate_layout({
                    lse_logsum_local:
                        T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
                })

                T.clear(lse_logsum_local)
                T.clear(o_accum_local)
                lse_max_local[0] = -T.infinity(accum_dtype)
                for k in T.serial(num_split):
                    lse_local_split[0] = glse[bz, by, k]
                    if lse_local_split[0] != 0:
                        max_split[0] = k
                        lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k])

                for k in T.Pipelined(num_split, num_stages=1):
                    if k <= max_split[0]:
                        lse_local_split[0] = glse[bz, by, k]
                        lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])

                lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
                for k in T.serial(num_split):
                    if k <= max_split[0]:
                        for i in T.Parallel(dim_v):
                            po_local[i] = Output_partial[bz, by, k, i]
                        lse_local_split[0] = glse[bz, by, k]
                        scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
                        for i in T.Parallel(dim_v):
                            o_accum_local[i] += po_local[i] * scale_local[0]
                for i in T.Parallel(dim_v):
                    Output[bz, by, i] = o_accum_local[i]

        @T.prim_func
        def main(
                Q: T.Tensor(shape_q, dtype),
                K: T.Tensor(shape_k, dtype),
                V: T.Tensor(shape_v, dtype),
                block_indices: T.Tensor(shape_indices, T.int32),
                cache_seqlens: T.Tensor([batch], T.int32),
                block_table: T.Tensor(shape_block_table, T.int32),
                glse: T.Tensor([batch, heads, num_split], accum_dtype),
                Output_partial: T.Tensor(part_shape, accum_dtype),
                Output: T.Tensor(shape_o, dtype),
        ):
            flash_attn_split(Q, K, V, block_indices, cache_seqlens, block_table, glse,
                             Output_partial)
            combine(glse, Output_partial, Output)

        return main

    return kernel_func


class SparseFlashAttn(torch.nn.Module):

    def __init__(self, batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, num_pages):
        super(SparseFlashAttn, self).__init__()
        self.batch = batch
        self.heads = heads
        self.heads_kv = heads_kv
        self.dim = dim
        self.dim_v = dim_v
        self.block_N = block_N
        self.page_block_size = page_block_size
        self.num_pages = num_pages
        self.block_H = 64

        self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
            block_N=block_N,
            block_H=self.block_H,
            page_block_size=page_block_size,
            num_split=T.dynamic("num_split"),
            num_stages=2,
            threads=128,
            num_pages=num_pages,
            max_num_blocks_per_seq=T.dynamic("max_num_blocks_per_seq"),
            max_selected_blocks=T.dynamic("max_selected_blocks"),
        )
        props = torch.cuda.get_device_properties(torch.device("cuda:0"))
        self.num_sm = props.multi_processor_count

    def forward(self, query, key, value, block_indices, cache_seqlens, block_table):
        batch = self.batch
        heads = self.heads
        heads_kv = self.heads_kv
        dim_v = self.dim_v
        dim = self.dim
        block_size = self.block_N
        max_selected_blocks = block_indices.shape[-1]

        # Compute static scheduling parameters
        num_m_blocks = 1 * (heads // heads_kv + self.block_H - 1) // self.block_H
        num_n_blocks = max_selected_blocks
        size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2
        total_mblocks = batch * heads_kv * num_m_blocks
        num_sm = self.num_sm
        num_split = num_splits_heuristic(
            total_mblocks,
            num_sm,
            num_n_blocks,
            num_m_blocks,
            size_one_kv_head,
            is_causal_or_local=True,
            max_splits=128)

        glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda")
        output_partial = torch.empty((batch, heads, num_split, dim_v),
                                     dtype=torch.float32,
                                     device="cuda")
        output = self.kernel(
            query,
            key,
            value,
            block_indices,
            cache_seqlens,
            block_table,
            glse,
            output_partial,
        )
        return output


def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_seqlens,
                            block_table, page_block_size, block_size):
    """
    Paged version of sparse attention reference implementation.

    Args:
        query: [batch, heads, dim]
        key_cache: [num_pages, page_block_size, heads_kv, dim]
        value_cache: [num_pages, page_block_size, heads_kv, dim]
        block_indices: [batch, heads_kv, max_selected_blocks] - logical block indices
        cache_seqlens: [batch] - actual sequence lengths
        block_table: [batch, max_num_blocks_per_seq] - maps logical to physical blocks
        page_block_size: size of each page block
        block_size: size of attention blocks (block_N)
    """
    batch, heads, dim = query.shape
    heads_kv = key_cache.shape[2]
    dim_v = value_cache.shape[3]
    num_head_groups = heads // heads_kv
    scale = dim**0.5

    # Reconstruct the full key and value tensors from paged cache
    max_cache_seqlen = max(cache_seqlens).item()
    key_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim),
                           dtype=key_cache.dtype,
                           device=key_cache.device)
    value_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim_v),
                             dtype=value_cache.dtype,
                             device=value_cache.device)

    # Reconstruct full tensors from paged cache using block_table
    for b in range(batch):
        seq_len = cache_seqlens[b].item()
        num_blocks_needed = int(math.ceil(seq_len / page_block_size))

        for block_idx in range(num_blocks_needed):
            physical_block_idx = block_table[b, block_idx].item()

            # Calculate the range of tokens for this block
            start_token = block_idx * page_block_size
            end_token = min(start_token + page_block_size, seq_len)
            actual_block_size = end_token - start_token

            # Copy from paged cache to full tensors
            key_full[b, :, start_token:end_token, :] = key_cache[
                physical_block_idx, :actual_block_size, :, :].transpose(0, 1)
            value_full[b, :, start_token:end_token, :] = value_cache[
                physical_block_idx, :actual_block_size, :, :].transpose(0, 1)

    # Reshape query for grouped attention
    query = rearrange(
        query, "b (h g) d -> b g h d",
        g=num_head_groups)  # [batch_size, num_head_groups, heads_kv, dim]

    # Compute attention scores
    scores = einsum(query, key_full,
                    "b g h d, b h s d -> b g h s")  # [batch_size, num_head_groups, heads_kv, seqlen_kv]

    # Create sparse mask based on block_indices
    sparse_mask = torch.zeros_like(scores)

    # Apply sparse mask based on selected blocks
    for b in range(batch):
        for h in range(heads_kv):
            valid_indices = block_indices[b, h]  # Extract indices for this batch and head
            for idx in valid_indices:
                if idx >= 0:  # Valid block index
                    start_pos = idx * block_size
                    end_pos = min(start_pos + block_size, max_cache_seqlen)
                    sparse_mask[b, :, h, start_pos:end_pos] = 1

    # Apply sparse mask
    scores = scores.masked_fill(sparse_mask == 0, float("-inf"))

    # Apply causal mask based on actual sequence lengths
    range_len = torch.arange(scores.shape[-1], device=scores.device).unsqueeze(0)
    cache_seqlens_expanded = cache_seqlens.unsqueeze(1)
    pad_mask = range_len >= cache_seqlens_expanded
    pad_mask = pad_mask[:, None, None, :]
    scores = scores.masked_fill(pad_mask, float("-inf"))

    # Compute attention weights
    attention = F.softmax(scores / scale, dim=-1)

    # Apply attention to values
    out = einsum(attention, value_full,
                 "b g h s, b h s d -> b g h d")  # [batch_size, num_head_groups, heads_kv, dim]

    # Reshape output back to original format
    out = rearrange(out, "b g h d -> b (h g) d")  # [batch_size, heads, dim]

    return out


def ref_program_fa(query, kcache, vcache, cache_seqlens, block_table):
    # latency reference
    # from flash_attn_interface import flash_attn_with_kvcache # fa3
    from flash_attn import flash_attn_with_kvcache  # fa2
    query = query.unsqueeze(1)
    output = flash_attn_with_kvcache(
        query, kcache, vcache, cache_seqlens=cache_seqlens, block_table=block_table)
    output = output.squeeze(1)
    return output


def main(args):
    batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = (
        args.batch,
        args.heads,
        args.heads_kv,
        args.max_cache_seqlen,
        args.dim,
        args.dim_v,
    )
    sparse_ratio = args.sparse_ratio
    block_N = args.block_N
    page_block_size = args.page_block_size
    num_blocks = args.num_pages  # Use num_pages from args

    # For dense case verification, set sparse_ratio to 0 to select all blocks
    max_selected_blocks = int(math.ceil(max_cache_seqlen / block_N))
    print("max_selected_blocks: ", max_selected_blocks)
    dtype = torch.float16

    # Generate random inputs
    Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda")
    cache_seqlens = torch.randint(
        max_cache_seqlen // 2, max_cache_seqlen + 1, (batch,), dtype=torch.int32, device="cuda")
    print("cache_seqlens: ", cache_seqlens)

    K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda")
    V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda")

    # Create paged KV cache
    K_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim), dtype=dtype, device="cuda")
    V_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim_v),
                          dtype=dtype,
                          device="cuda")

    # Create block table and block indices for dense case (all blocks selected)
    max_num_blocks_per_seq = int(math.ceil(max_cache_seqlen / page_block_size))
    print("max_num_blocks_per_seq: ", max_num_blocks_per_seq)
    block_table = torch.zeros((batch, max_num_blocks_per_seq), dtype=torch.int32, device="cuda")
    block_indices = torch.zeros((batch, heads_kv, max_selected_blocks),
                                dtype=torch.int32,
                                device="cuda")

    # Fill block table and block indices and cache
    # Create a pool of available physical blocks
    total_blocks_needed = sum(
        int(math.ceil(cache_seqlens[seq_idx].item() / page_block_size))
        for seq_idx in range(batch))
    available_blocks = list(range(total_blocks_needed))
    import random
    random.seed(42)  # For reproducibility
    random.shuffle(available_blocks)

    # Fill block table with random physical block indices
    block_assignment = {}  # Map (seq_idx, block_idx) -> physical_block_idx
    block_idx_counter = 0
    for seq_idx in range(batch):
        seq_len = cache_seqlens[seq_idx].item()
        num_blocks_needed = int(math.ceil(seq_len / page_block_size))

        # Assign random physical blocks for each sequence
        for block_idx in range(num_blocks_needed):
            physical_block_idx = available_blocks[block_idx_counter]
            block_table[seq_idx, block_idx] = physical_block_idx
            block_assignment[(seq_idx, block_idx)] = physical_block_idx
            block_idx_counter += 1

    print(f"Block table: {block_table}")

    # Fill K_cache and V_cache with data from original K and V tensors using random block assignment
    for seq_idx in range(batch):
        seq_len = cache_seqlens[seq_idx].item()
        num_blocks_needed = int(math.ceil(seq_len / page_block_size))

        for block_idx in range(num_blocks_needed):
            physical_block_idx = block_assignment[(seq_idx, block_idx)]

            # Calculate the range of tokens for this block
            start_token = block_idx * page_block_size
            end_token = min(start_token + page_block_size, seq_len)
            actual_block_size = end_token - start_token

            # Copy K and V data to the paged cache
            K_cache[physical_block_idx, :actual_block_size, :, :] = K[seq_idx,
                                                                      start_token:end_token, :, :]
            V_cache[physical_block_idx, :actual_block_size, :, :] = V[seq_idx,
                                                                      start_token:end_token, :, :]

    # Fill block_indices for sparse attention
    # For dense case (verification), we select all blocks in reverse order
    # For sparse case, we select a subset of blocks based on sparse_ratio
    for seq_idx in range(batch):
        seq_len = cache_seqlens[seq_idx].item()
        num_tile = int(math.ceil(seq_len / block_N))

        if sparse_ratio == 0.0:
            # Dense case: select all blocks in reverse order
            selected_blocks = min(num_tile, max_selected_blocks)
            for head_idx in range(heads_kv):
                for i in range(selected_blocks):
                    # Select blocks in reverse order (most recent first)
                    block_indices[seq_idx, head_idx, i] = num_tile - 1 - i
                # Fill remaining slots with -1 (invalid)
                for i in range(selected_blocks, max_selected_blocks):
                    block_indices[seq_idx, head_idx, i] = -1
        else:
            # Fill block_indices for all KV heads
            num_selected = int(num_tile * (1.0 - sparse_ratio))
            num_selected = max(1, min(num_selected, max_selected_blocks))

            all_blocks = list(range(num_tile))
            for head_idx in range(heads_kv):
                selected_blocks = []
                # Always include the most recent blocks
                recent_blocks = 1
                selected_blocks.append(num_tile - 1)

                # Randomly select some earlier blocks
                if num_selected > recent_blocks:
                    remaining_blocks = [b for b in all_blocks if b not in selected_blocks]
                    if remaining_blocks:
                        import random
                        random.seed(42)  # For reproducibility
                        additional_blocks = random.sample(
                            remaining_blocks,
                            min(num_selected - recent_blocks, len(remaining_blocks)))
                        selected_blocks.extend(additional_blocks)

                # Sort selected blocks in reverse order (most recent first)
                selected_blocks.sort(reverse=True)
                for i in range(len(selected_blocks)):
                    block_indices[seq_idx, head_idx, i] = selected_blocks[i]
                # Fill remaining slots with -1 (invalid)
                for i in range(len(selected_blocks), max_selected_blocks):
                    block_indices[seq_idx, head_idx, i] = -1

    # Initialize sparse attention module
    sparse_attn = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, page_block_size, block_N,
                                  num_blocks)
    output_sparse = sparse_attn.forward(Q, K_cache, V_cache, block_indices, cache_seqlens,
                                        block_table)

    import flash_attn  # noqa: F401
    output_ref_torch = ref_program_torch_paged(Q, K_cache, V_cache, block_indices, cache_seqlens,
                                               block_table, page_block_size, block_N)
    output_ref_fa = ref_program_fa(Q, K_cache, V_cache, cache_seqlens, block_table)

    # Check correctness
    if sparse_ratio == 0.0:
        max_diff = torch.max(torch.abs(output_sparse - output_ref_fa)).item()
        mean_diff = torch.mean(torch.abs(output_sparse - output_ref_fa)).item()
        assert torch.allclose(
            output_ref_fa, output_ref_torch, atol=1e-2), "Reference outputs do not match!"
    else:
        max_diff = torch.max(torch.abs(output_sparse - output_ref_torch)).item()
        mean_diff = torch.mean(torch.abs(output_sparse - output_ref_torch)).item()
    print(f"Max difference: {max_diff:.6f}")
    print(f"Mean difference: {mean_diff:.6f}")

    if max_diff < 1e-2:
        print("✓ Verification PASSED: Results match within tolerance")
    else:
        print("✗ Verification FAILED: Results differ significantly")

    # Performance measurement
    for _ in range(10):  # Warm-up
        sparse_attn.forward(Q, K_cache, V_cache, block_indices, cache_seqlens, block_table)
    torch.cuda.synchronize()
    start_time = time.time()
    for _ in range(100):  # Run multiple times for averaging
        sparse_attn.forward(Q, K_cache, V_cache, block_indices, cache_seqlens, block_table)
    torch.cuda.synchronize()
    end_time = time.time()
    kernel_time = (end_time - start_time) / 100 * 1000  # Convert to ms
    print(f"Kernel execution time: {kernel_time:.2f} ms")

    # FA performance measurement
    for _ in range(10):  # Warm-up
        ref_program_fa(Q, K_cache, V_cache, cache_seqlens, block_table)
    torch.cuda.synchronize()
    start_time_fa = time.time()
    for _ in range(100):  # Run multiple times for averaging
        ref_program_fa(Q, K_cache, V_cache, cache_seqlens, block_table)
    torch.cuda.synchronize()
    end_time_fa = time.time()
    kernel_time_fa = (end_time_fa - start_time_fa) / 100 * 1000  # Convert to ms
    print(f"FA kernel execution time: {kernel_time_fa:.2f} ms")
    print(f"Speedup: {kernel_time_fa / kernel_time:.2f}x")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch", type=int, default=8, help="batch size")
    parser.add_argument("--heads", type=int, default=32, help="heads")
    parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv")
    parser.add_argument(
        "--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length")
    parser.add_argument("--dim", type=int, default=128, help="dim")
    parser.add_argument("--dim_v", type=int, default=128, help="dim_v")
    parser.add_argument("--sparse_ratio", type=float, default=0.0, help="sparse ratio")
    parser.add_argument("--block_N", type=int, default=64, help="block_N")
    parser.add_argument("--page_block_size", type=int, default=256, help="block size of pages")
    parser.add_argument("--num_pages", type=int, default=1024, help="total number of pages")
    args = parser.parse_args()
    main(args)
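
A small sketch (toy numbers, not from the commit) of the indexing the paged kernel above performs with `T.floordiv`/`T.floormod` when `block_N` divides `page_block_size`, mapping a logical attention block to a tile inside a physical page:

```python
# Hypothetical sizes; the kernel does the same arithmetic per selected block.
page_block_size, block_N = 256, 64
block_ratio = page_block_size // block_N            # 4 attention tiles per page

logical_block_idx = 9                                # taken from block_indices
block_table_idx = logical_block_idx // block_ratio   # -> 2: third logical page of this sequence
block_tile_idx = logical_block_idx % block_ratio     # -> 1: second 64-token tile in that page

block_table = [7, 3, 11, 0]                          # per-sequence logical->physical page map
physical_block_idx = block_table[block_table_idx]    # -> 11
# The kernel then reads K[physical_block_idx, block_tile_idx*block_N:(block_tile_idx+1)*block_N, ...]
print(block_table_idx, block_tile_idx, physical_block_idx)
```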
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py
0 → 100644
View file @
cf6e11c9
import
torch
import
torch.nn.functional
as
F
import
tilelang
import
tilelang.language
as
T
from
einops
import
rearrange
,
einsum
import
argparse
import
time
import
math
from
heuristic
import
num_splits_heuristic
def
flashattn
(
batch
,
heads
,
heads_kv
,
dim
,
dim_v
):
scale
=
(
1.0
/
dim
)
**
0.5
*
1.44269504
# log2(e)
dtype
=
T
.
float16
accum_dtype
=
T
.
float32
kv_group_num
=
heads
//
heads_kv
@
tilelang
.
jit
(
out_idx
=
[
-
1
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
,
},
)
def
kernel_func
(
block_N
,
block_H
,
num_split
,
num_stages
,
threads
,
max_cache_seqlen
,
max_selected_blocks
):
shape_q
=
[
batch
,
heads
,
dim
]
shape_k
=
[
batch
,
max_cache_seqlen
,
heads_kv
,
dim
]
shape_v
=
[
batch
,
max_cache_seqlen
,
heads_kv
,
dim_v
]
shape_indices
=
[
batch
,
heads_kv
,
max_selected_blocks
]
shape_o
=
[
batch
,
heads
,
dim_v
]
part_shape
=
[
batch
,
heads
,
num_split
,
dim_v
]
valid_block_H
=
min
(
block_H
,
kv_group_num
)
@
T
.
macro
def
flash_attn_split
(
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
block_indices
:
T
.
Tensor
(
shape_indices
,
T
.
int32
),
cache_seqlens
:
T
.
Tensor
([
batch
],
T
.
int32
),
# actual_num_blocks: T.Tensor([batch], T.int32),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
accum_dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
accum_dtype
),
):
with
T
.
Kernel
(
batch
,
heads
//
valid_block_H
,
num_split
,
threads
=
threads
)
as
(
bx
,
by
,
bz
):
Q_shared
=
T
.
alloc_shared
([
block_H
,
dim
],
dtype
)
K_shared
=
T
.
alloc_shared
([
block_N
,
dim
],
dtype
)
V_shared
=
T
.
alloc_shared
([
block_N
,
dim_v
],
dtype
)
# O_shared = T.alloc_shared([valid_block_H, dim_v], dtype)
acc_s
=
T
.
alloc_fragment
([
block_H
,
block_N
],
accum_dtype
)
acc_s_cast
=
T
.
alloc_fragment
([
block_H
,
block_N
],
dtype
)
acc_o
=
T
.
alloc_fragment
([
block_H
,
dim_v
],
accum_dtype
)
scores_max
=
T
.
alloc_fragment
([
block_H
],
accum_dtype
)
scores_max_prev
=
T
.
alloc_fragment
([
block_H
],
accum_dtype
)
scores_scale
=
T
.
alloc_fragment
([
block_H
],
accum_dtype
)
scores_sum
=
T
.
alloc_fragment
([
block_H
],
accum_dtype
)
logsum
=
T
.
alloc_fragment
([
block_H
],
accum_dtype
)
has_valid_block
=
T
.
alloc_var
(
"bool"
)
bid
=
bx
hid
=
by
sid
=
bz
cur_kv_head
=
hid
//
(
kv_group_num
//
valid_block_H
)
T
.
copy
(
Q
[
bid
,
hid
*
valid_block_H
:
hid
*
valid_block_H
+
block_H
,
:],
Q_shared
)
T
.
fill
(
acc_o
,
0
)
T
.
fill
(
logsum
,
0
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
num_blocks
=
max_selected_blocks
blocks_per_split
=
T
.
floordiv
(
num_blocks
,
num_split
)
remaining_blocks
=
T
.
floormod
(
num_blocks
,
num_split
)
loop_range
=
blocks_per_split
+
T
.
if_then_else
(
sid
<
remaining_blocks
,
1
,
0
)
start
=
blocks_per_split
*
sid
+
T
.
min
(
sid
,
remaining_blocks
)
has_valid_block
=
False
for
k
in
T
.
Pipelined
(
loop_range
,
num_stages
=
num_stages
):
i_s
=
block_indices
[
bid
,
cur_kv_head
,
start
+
k
]
if
i_s
>=
0
:
has_valid_block
=
True
T
.
copy
(
K
[
bid
,
i_s
*
block_N
:
(
i_s
+
1
)
*
block_N
,
cur_kv_head
,
:],
K_shared
)
T
.
clear
(
acc_s
)
T
.
gemm
(
Q_shared
,
K_shared
,
acc_s
,
transpose_B
=
True
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
if
k
==
0
:
# assume block_indices is sorted in reverse order, otherwise, remove this if condition
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
if_then_else
(
i_s
*
block_N
+
j
>=
cache_seqlens
[
bid
],
-
T
.
infinity
(
accum_dtype
),
acc_s
[
i
,
j
])
T
.
copy
(
scores_max
,
scores_max_prev
)
T
.
fill
(
scores_max
,
-
T
.
infinity
(
accum_dtype
))
T
.
reduce_max
(
acc_s
,
scores_max
,
dim
=
1
,
clear
=
False
)
for
i
in
T
.
Parallel
(
block_H
):
scores_max
[
i
]
=
T
.
max
(
scores_max
[
i
],
scores_max_prev
[
i
])
scores_scale
[
i
]
=
T
.
exp2
(
scores_max_prev
[
i
]
*
scale
-
scores_max
[
i
]
*
scale
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
block_N
):
acc_s
[
i
,
j
]
=
T
.
exp2
(
acc_s
[
i
,
j
]
*
scale
-
scores_max
[
i
]
*
scale
)
T
.
reduce_sum
(
acc_s
,
scores_sum
,
dim
=
1
)
for
i
in
T
.
Parallel
(
block_H
):
logsum
[
i
]
=
logsum
[
i
]
*
scores_scale
[
i
]
+
scores_sum
[
i
]
T
.
copy
(
acc_s
,
acc_s_cast
)
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim_v
):
acc_o
[
i
,
j
]
*=
scores_scale
[
i
]
T
.
copy
(
V
[
bid
,
i_s
*
block_N
:
(
i_s
+
1
)
*
block_N
,
cur_kv_head
,
:],
V_shared
)
T
.
gemm
(
acc_s_cast
,
V_shared
,
acc_o
,
policy
=
T
.
GemmWarpPolicy
.
FullRow
)
if
has_valid_block
:
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim_v
):
acc_o
[
i
,
j
]
/=
logsum
[
i
]
for
i
in
T
.
Parallel
(
block_H
):
logsum
[
i
]
=
T
.
log2
(
logsum
[
i
])
+
scores_max
[
i
]
*
scale
for
i
in
T
.
Parallel
(
block_H
):
if
i
<
valid_block_H
:
glse
[
bid
,
hid
*
valid_block_H
+
i
,
sid
]
=
logsum
[
i
]
for
i
,
j
in
T
.
Parallel
(
block_H
,
dim_v
):
if
i
<
valid_block_H
:
Output_partial
[
bid
,
hid
*
valid_block_H
+
i
,
sid
,
j
]
=
acc_o
[
i
,
j
]
@
T
.
macro
def
combine
(
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
accum_dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
accum_dtype
),
Output
:
T
.
Tensor
(
shape_o
,
dtype
),
):
with
T
.
Kernel
(
heads
,
batch
,
threads
=
128
)
as
(
by
,
bz
):
po_local
=
T
.
alloc_fragment
([
dim_v
],
accum_dtype
)
o_accum_local
=
T
.
alloc_fragment
([
dim_v
],
accum_dtype
)
lse_local_split
=
T
.
alloc_local
([
1
],
accum_dtype
)
lse_logsum_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
lse_max_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
scale_local
=
T
.
alloc_local
([
1
],
accum_dtype
)
max_split
=
T
.
alloc_local
([
1
],
T
.
int32
)
T
.
annotate_layout
(
{
lse_logsum_local
:
T
.
Fragment
(
lse_logsum_local
.
shape
,
forward_thread_fn
=
lambda
i
:
i
),
}
)
T
.
clear
(
lse_logsum_local
)
T
.
clear
(
o_accum_local
)
lse_max_local
[
0
]
=
-
T
.
infinity
(
accum_dtype
)
for
k
in
T
.
serial
(
num_split
):
lse_local_split
[
0
]
=
glse
[
bz
,
by
,
k
]
if
lse_local_split
[
0
]
!=
0
:
max_split
[
0
]
=
k
lse_max_local
[
0
]
=
T
.
max
(
lse_max_local
[
0
],
glse
[
bz
,
by
,
k
])
for
k
in
T
.
Pipelined
(
num_split
,
num_stages
=
1
):
if
k
<=
max_split
[
0
]:
lse_local_split
[
0
]
=
glse
[
bz
,
by
,
k
]
lse_logsum_local
[
0
]
+=
T
.
exp2
(
lse_local_split
[
0
]
-
lse_max_local
[
0
])
lse_logsum_local
[
0
]
=
T
.
log2
(
lse_logsum_local
[
0
])
+
lse_max_local
[
0
]
for
k
in
T
.
serial
(
num_split
):
if
k
<=
max_split
[
0
]:
for
i
in
T
.
Parallel
(
dim_v
):
po_local
[
i
]
=
Output_partial
[
bz
,
by
,
k
,
i
]
lse_local_split
[
0
]
=
glse
[
bz
,
by
,
k
]
scale_local
[
0
]
=
T
.
exp2
(
lse_local_split
[
0
]
-
lse_logsum_local
[
0
])
for
i
in
T
.
Parallel
(
dim_v
):
o_accum_local
[
i
]
+=
po_local
[
i
]
*
scale_local
[
0
]
for
i
in
T
.
Parallel
(
dim_v
):
Output
[
bz
,
by
,
i
]
=
o_accum_local
[
i
]
@
T
.
prim_func
def
main
(
Q
:
T
.
Tensor
(
shape_q
,
dtype
),
K
:
T
.
Tensor
(
shape_k
,
dtype
),
V
:
T
.
Tensor
(
shape_v
,
dtype
),
block_indices
:
T
.
Tensor
(
shape_indices
,
T
.
int32
),
cache_seqlens
:
T
.
Tensor
([
batch
],
T
.
int32
),
# actual_num_blocks: T.Tensor([batch], T.int32),
glse
:
T
.
Tensor
([
batch
,
heads
,
num_split
],
accum_dtype
),
Output_partial
:
T
.
Tensor
(
part_shape
,
accum_dtype
),
Output
:
T
.
Tensor
(
shape_o
,
dtype
),
):
# flash_attn_split(Q, K, V, block_indices, cache_seqlens, actual_num_blocks, glse, Output_partial)
flash_attn_split
(
Q
,
K
,
V
,
block_indices
,
cache_seqlens
,
glse
,
Output_partial
)
combine
(
glse
,
Output_partial
,
Output
)
return
main
return
kernel_func
class
SparseFlashAttn
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
batch
,
heads
,
heads_kv
,
dim
,
dim_v
,
block_size
):
super
(
SparseFlashAttn
,
self
).
__init__
()
self
.
batch
=
batch
self
.
heads
=
heads
self
.
heads_kv
=
heads_kv
self
.
dim
=
dim
self
.
dim_v
=
dim_v
self
.
block_size
=
block_size
self
.
block_H
=
64
self
.
kernel
=
flashattn
(
batch
,
heads
,
heads_kv
,
dim
,
dim_v
)(
block_N
=
block_size
,
block_H
=
self
.
block_H
,
num_split
=
T
.
dynamic
(
"num_split"
),
num_stages
=
2
,
threads
=
128
,
max_cache_seqlen
=
T
.
dynamic
(
"max_cache_seqlen"
),
max_selected_blocks
=
T
.
dynamic
(
"max_selected_blocks"
),
)
props
=
torch
.
cuda
.
get_device_properties
(
torch
.
device
(
"cuda:0"
))
self
.
num_sm
=
props
.
multi_processor_count
def
forward
(
self
,
query
,
key
,
value
,
block_indices
,
cache_seqlens
):
batch
=
self
.
batch
heads
=
self
.
heads
heads_kv
=
self
.
heads_kv
dim_v
=
self
.
dim_v
dim
=
self
.
dim
block_size
=
self
.
block_size
max_selected_blocks
=
block_indices
.
shape
[
-
1
]
# Compute static scheduling parameters
num_m_blocks
=
1
*
(
heads
//
heads_kv
+
self
.
block_H
-
1
)
//
self
.
block_H
num_n_blocks
=
max_selected_blocks
size_one_kv_head
=
max_selected_blocks
*
block_size
*
(
dim
+
dim_v
)
*
2
total_mblocks
=
batch
*
heads_kv
*
num_m_blocks
num_sm
=
self
.
num_sm
num_split
=
num_splits_heuristic
(
total_mblocks
,
num_sm
,
num_n_blocks
,
num_m_blocks
,
size_one_kv_head
,
is_causal_or_local
=
True
,
max_splits
=
128
)
glse
=
torch
.
empty
((
batch
,
heads
,
num_split
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
output_partial
=
torch
.
empty
((
batch
,
heads
,
num_split
,
dim_v
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
output
=
self
.
kernel
(
query
,
key
,
value
,
block_indices
,
cache_seqlens
,
glse
,
output_partial
)
return
output
def
sparse_gqa_decode_varlen_indice
(
query
,
key
,
value
,
block_indices
,
cache_seqlens
,
max_cache_seqlen
,
block_size
):
"""
Args:
query: [batch, heads, dim]
key: [batch, max_cache_seqlen, heads_kv, dim]
value: [batch, max_cache_seqlen, heads_kv, dim_v]
block_indices: [batch, heads_kv, max_selected_blocks], indices of selected blocks, -1 for padding
cache_seqlens: [batch], sequence lengths of the kvcache
max_cache_seqlen: maximum sequence length of kvcache
block_size: block size
Returns:
output: [batch, heads, dim_v]
"""
batch
,
heads
,
dim
=
query
.
shape
heads_kv
=
key
.
shape
[
2
]
dim_v
=
value
.
shape
[
-
1
]
max_selected_blocks
=
block_indices
.
shape
[
-
1
]
block_H
=
64
actual_num_blocks
=
torch
.
sum
(
block_indices
!=
-
1
,
dim
=-
1
).
to
(
torch
.
int32
)
actual_num_blocks
=
actual_num_blocks
[
:,
0
]
# [batch], number of valid blocks, assume all groups in the same batch have the same number of blocks
# get num_split
num_m_blocks
=
1
*
(
heads
//
heads_kv
+
block_H
-
1
)
//
block_H
num_n_blocks
=
max_selected_blocks
# (kv_seqlen + block_size - 1 ) // block_size
# num_n_blocks = torch.sum(actual_num_blocks, dim=-1).item() * heads_kv # total number of blocks
size_one_kv_head
=
max_selected_blocks
*
block_size
*
(
dim
+
dim_v
)
*
2
# kv_seqlen * (dim + dim_v) * 2
total_mblocks
=
batch
*
heads_kv
*
num_m_blocks
num_sm
=
132
num_split
=
num_splits_heuristic
(
total_mblocks
,
num_sm
,
num_n_blocks
,
num_m_blocks
,
size_one_kv_head
,
is_causal_or_local
=
True
,
max_splits
=
128
)
glse
=
torch
.
empty
((
batch
,
heads
,
num_split
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
Output_partial
=
torch
.
empty
((
batch
,
heads
,
num_split
,
dim_v
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
kernel
=
flashattn
(
batch
,
heads
,
heads_kv
,
dim
,
dim_v
)(
block_N
=
block_size
,
block_H
=
block_H
,
num_split
=
T
.
dynamic
(
"num_split"
),
num_stages
=
2
,
threads
=
128
,
max_cache_seqlen
=
T
.
dynamic
(
"max_cache_seqlen"
),
max_selected_blocks
=
T
.
dynamic
(
"max_selected_blocks"
),
)
output
=
kernel
(
query
,
key
,
value
,
block_indices
,
cache_seqlens
,
glse
,
Output_partial
)
return
output
def
ref_program_torch
(
query
,
key
,
value
,
block_indices
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
block_size
):
batch
,
heads
,
dim
=
query
.
shape
heads_kv
=
key
.
shape
[
2
]
num_head_groups
=
query
.
shape
[
1
]
//
key
.
shape
[
2
]
scale
=
dim
**
0.5
key
=
rearrange
(
key
,
"b n h d -> b h n d"
)
# [batch_size, heads_kv, seqlen_kv, dim]
value
=
rearrange
(
value
,
"b n h d -> b h n d"
)
# [batch_size, heads_kv, seqlen_kv, dim]
query
=
rearrange
(
query
,
"b (h g) d -> b g h d"
,
g
=
num_head_groups
)
# [batch_size, num_head_groups, heads_kv, dim]
scores
=
einsum
(
query
,
key
,
"b g h d, b h s d -> b g h s"
)
# [batch_size, num_head_groups, heads_kv, seqlen_kv]
sparse_mask
=
torch
.
zeros_like
(
scores
)
# Assign mask values based on block_indices
for
b
in
range
(
batch
):
for
h
in
range
(
heads_kv
):
valid_indices
=
block_indices
[
b
,
h
]
# Extract indices for this batch and head
for
idx
in
valid_indices
:
if
idx
>=
0
:
sparse_mask
[
b
,
:,
h
,
idx
*
block_size
:
(
idx
+
1
)
*
block_size
]
=
1
scores
=
scores
.
masked_fill
(
sparse_mask
==
0
,
float
(
"-inf"
))
range_len
=
torch
.
arange
(
scores
.
shape
[
-
1
],
device
=
"cuda"
).
unsqueeze
(
0
)
cache_seqlens_expanded
=
cache_seqlens
.
unsqueeze
(
1
)
pad_mask
=
range_len
>=
cache_seqlens_expanded
pad_mask
=
pad_mask
[:,
None
,
None
,
:]
scores
=
scores
.
masked_fill
(
pad_mask
,
float
(
"-inf"
))
attention
=
F
.
softmax
(
scores
/
scale
,
dim
=-
1
)
# [batch_size, num_head_groups, heads_kv, seqlen_kv]
out
=
einsum
(
attention
,
value
,
"b g h s, b h s d -> b g h d"
)
# [batch_size, num_head_groups, heads_kv, dim]
out
=
rearrange
(
out
,
"b g h d -> b (h g) d"
)
# [batch_size, heads, dim]
return
out
def
ref_program_fa
(
query
,
key
,
value
,
block_indices
,
cache_seqlens
,
max_cache_seqlen
,
num_blocks
,
block_size
):
# latency reference
# from flash_attn_interface import flash_attn_with_kvcache # fa3
from
flash_attn
import
flash_attn_with_kvcache
# fa2
query
=
query
.
unsqueeze
(
1
)
output
=
flash_attn_with_kvcache
(
query
,
key
,
value
,
cache_seqlens
=
cache_seqlens
)
output
=
output
.
squeeze
(
1
)
return
output
def
debug
(
name
,
expect
,
actual
,
atol
=
1e-3
,
rtol
=
1e-3
):
all_close
=
torch
.
allclose
(
expect
,
actual
,
atol
=
atol
,
rtol
=
rtol
)
print
(
name
+
" all_close={}"
.
format
(
all_close
))
if
not
all_close
:
diff
=
(
expect
-
actual
).
abs
()
print
(
"all_close={}, max={}, min={}, mean={}"
.
format
(
all_close
,
diff
.
max
().
item
(),
diff
.
min
().
item
(),
diff
.
mean
().
item
()))
max_indices
=
torch
.
nonzero
(
diff
==
diff
.
max
().
item
())
first_index
=
tuple
(
max_indices
[
0
].
tolist
())
print
(
f
"Index:
{
first_index
}
, expect:
{
expect
[
first_index
]
}
, actual:
{
actual
[
first_index
]
}
"
)
def
main
(
batch
=
8
,
heads
=
32
,
heads_kv
=
8
,
max_cache_seqlen
=
8192
,
dim
=
128
,
dim_v
=
128
,
sparse_ratio
=
0.8
,
block_size
=
32
):
batch
,
heads
,
heads_kv
,
max_cache_seqlen
,
dim
,
dim_v
=
batch
,
heads
,
heads_kv
,
max_cache_seqlen
,
dim
,
dim_v
sparse_ratio
=
sparse_ratio
block_size
=
block_size
max_selected_blocks
=
int
(
math
.
ceil
(
max_cache_seqlen
*
(
1
-
sparse_ratio
)
/
block_size
))
print
(
"max_selected_blocks: "
,
max_selected_blocks
)
dtype
=
torch
.
float16
    Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda")
    K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda")
    V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda")
    cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda")
    # cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda')
    # # Ensure at least one element equals cache_seqlen
    # random_index = torch.randint(0, batch, (1,), device='cuda').item()  # Select a random index
    # # cache_seqlens[random_index] = max_cache_seqlen  # Assign cache_seqlen to ensure at least one occurrence
    print("cache_seqlens: ", cache_seqlens)

    max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int()
    print("max_valid_num_blocks: ", max_valid_num_blocks)
    # Initialize block_indices with -1 (for padding blocks)
    block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32, device="cuda")
    # max_num_blocks = int((max_cache_seqlen + block_size - 1) / block_size)
    # block_indices = torch.full((batch, heads_kv, max_num_blocks), -1, dtype=torch.int32, device='cuda')

    # Assign valid indices while ensuring no duplicates within each batch-group
    for b in range(batch):
        max_valid_block = max_valid_num_blocks[b].item()  # Max valid blocks for this batch
        if max_valid_block > 0:  # Ensure there's at least one valid block
            for h in range(heads_kv):
                valid_indices = torch.randperm(
                    max_valid_block, device="cuda", dtype=torch.int32)[:max_selected_blocks]
                # valid_indices = torch.randperm(max_valid_block, device='cuda', dtype=torch.int32)[:max_num_blocks]
                block_indices[b, h, :len(valid_indices)] = valid_indices

    # Sort indices within each batch-group for consistency
    block_indices, _ = block_indices.sort(dim=-1, descending=True)
    # print("block_indices: ", block_indices)
    actual_num_blocks = torch.sum(block_indices != -1, dim=-1).to(torch.int32)[:, 0]
    print("actual_num_blocks: ", actual_num_blocks)
    # print(block_indices.shape, actual_num_blocks.shape)
    max_num_blocks = torch.max(max_valid_num_blocks).item()
    print("max_num_blocks: ", max_num_blocks)

    # parity reference
    ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks,
                            block_size)

    sparse_kernel = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size)
    out = sparse_kernel(Q, K, V, block_indices, cache_seqlens)
    debug("output", ref, out, atol=1e-3, rtol=1e-3)

    import flash_attn  # noqa: F401

    ## latency reference
    for _ in range(10):
        ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks,
                             block_size)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(100):
        ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks,
                             block_size)
    torch.cuda.synchronize()
    print("dense time: ", (time.time() - start) / 100 * 1000)

    for _ in range(10):
        # out = sparse_gqa_decode_varlen_indice(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, block_size)
        out = sparse_kernel(Q, K, V, block_indices, cache_seqlens)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(100):
        # out = sparse_gqa_decode_varlen_indice(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, block_size)
        out = sparse_kernel(Q, K, V, block_indices, cache_seqlens)
    torch.cuda.synchronize()
    print("sparse time: ", (time.time() - start) / 100 * 1000)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch", type=int, default=8, help="batch size")
    parser.add_argument("--heads", type=int, default=32, help="heads")
    parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv")
    parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length")
    parser.add_argument("--dim", type=int, default=128, help="dim")
    parser.add_argument("--dim_v", type=int, default=128, help="dim_v")
    parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio")
    parser.add_argument("--block_size", type=int, default=32, help="block_size")
    args = parser.parse_args()
    main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v,
         args.sparse_ratio, args.block_size)
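
The benchmark above fills block_indices with a random permutation per (batch, kv-head) group. In a real decode path the indices would come from a block selector; the sketch below is illustrative only and not part of this commit, and shows one way to build the same padded layout from a hypothetical per-block score tensor of shape [batch, heads_kv, num_blocks].

import torch

def topk_block_indices(block_scores: torch.Tensor, max_selected_blocks: int) -> torch.Tensor:
    # block_scores: hypothetical importance scores, shape [batch, heads_kv, num_blocks]
    batch, heads_kv, num_blocks = block_scores.shape
    k = min(max_selected_blocks, num_blocks)
    topk = torch.topk(block_scores, k=k, dim=-1).indices.to(torch.int32)
    indices = torch.full((batch, heads_kv, max_selected_blocks), -1,
                         dtype=torch.int32, device=block_scores.device)
    indices[:, :, :k] = topk
    # Descending sort keeps valid block ids first and -1 padding last, matching the example.
    indices, _ = indices.sort(dim=-1, descending=True)
    return indices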
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py
0 → 100644
View file @
cf6e11c9
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
from einops import rearrange, einsum
import argparse
import time
import math
from heuristic import num_splits_heuristic


def flashattn(batch, heads, heads_kv, dim, dim_v):
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    dtype = T.float16
    accum_dtype = T.float32
    kv_group_num = heads // heads_kv

    @tilelang.jit(
        out_idx=[-1],
        pass_configs={
            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
        },
    )
    def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, num_blocks):
        shape_q = [batch, heads, dim]
        shape_k = [batch, max_cache_seqlen, heads_kv, dim]
        shape_v = [batch, max_cache_seqlen, heads_kv, dim_v]
        shape_mask = [batch, heads_kv, num_blocks]
        shape_o = [batch, heads, dim_v]
        part_shape = [batch, heads, num_split, dim_v]
        valid_block_H = min(block_H, kv_group_num)

        @T.macro
        def flash_attn_split(
                Q: T.Tensor(shape_q, dtype),
                K: T.Tensor(shape_k, dtype),
                V: T.Tensor(shape_v, dtype),
                block_mask: T.Tensor(shape_mask, T.bool),
                cache_seqlens: T.Tensor([batch], T.int32),
                glse: T.Tensor([batch, heads, num_split], accum_dtype),
                Output_partial: T.Tensor(part_shape, accum_dtype),
        ):
            with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
                Q_shared = T.alloc_shared([block_H, dim], dtype)
                K_shared = T.alloc_shared([block_N, dim], dtype)
                V_shared = T.alloc_shared([block_N, dim_v], dtype)
                # O_shared = T.alloc_shared([valid_block_H, dim_v], dtype)
                acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
                acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
                acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype)
                scores_max = T.alloc_fragment([block_H], accum_dtype)
                scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
                scores_scale = T.alloc_fragment([block_H], accum_dtype)
                scores_sum = T.alloc_fragment([block_H], accum_dtype)
                logsum = T.alloc_fragment([block_H], accum_dtype)
                has_valid_block = T.alloc_var("bool")

                bid = bx
                hid = by
                sid = bz
                cur_kv_head = hid // (kv_group_num // valid_block_H)

                T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared)
                T.fill(acc_o, 0)
                T.fill(logsum, 0)
                T.fill(scores_max, -T.infinity(accum_dtype))

                blocks_per_split = T.floordiv(num_blocks, num_split)
                remaining_blocks = T.floormod(num_blocks, num_split)
                loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0)
                start = blocks_per_split * sid + T.min(sid, remaining_blocks)
                has_valid_block = False
                for k in T.Pipelined(loop_range, num_stages=num_stages):
                    if block_mask[bid, hid, start + k]:
                        has_valid_block = True
                        T.copy(
                            K[bid, (start + k) * block_N:(start + k + 1) * block_N, cur_kv_head, :],
                            K_shared)
                        T.clear(acc_s)
                        T.gemm(
                            Q_shared,
                            K_shared,
                            acc_s,
                            transpose_B=True,
                            policy=T.GemmWarpPolicy.FullRow)
                        for i, j in T.Parallel(block_H, block_N):
                            acc_s[i, j] = T.if_then_else(
                                (start + k) * block_N + j >= cache_seqlens[bx],
                                -T.infinity(accum_dtype), acc_s[i, j])
                        T.copy(scores_max, scores_max_prev)
                        T.fill(scores_max, -T.infinity(accum_dtype))
                        T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                        for i in T.Parallel(block_H):
                            scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                            scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                        for i, j in T.Parallel(block_H, block_N):
                            acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                        T.reduce_sum(acc_s, scores_sum, dim=1)
                        for i in T.Parallel(block_H):
                            logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                        T.copy(acc_s, acc_s_cast)
                        for i, j in T.Parallel(block_H, dim_v):
                            acc_o[i, j] *= scores_scale[i]
                        T.copy(
                            V[bid, (start + k) * block_N:(start + k + 1) * block_N, cur_kv_head, :],
                            V_shared)
                        T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
                if has_valid_block:
                    for i, j in T.Parallel(block_H, dim_v):
                        acc_o[i, j] /= logsum[i]
                    for i in T.Parallel(block_H):
                        logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale

                for i in T.Parallel(block_H):
                    if i < valid_block_H:
                        glse[bid, hid * valid_block_H + i, sid] = logsum[i]
                for i, j in T.Parallel(block_H, dim_v):
                    if i < valid_block_H:
                        Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j]

        @T.macro
        def combine(
                glse: T.Tensor([batch, heads, num_split], accum_dtype),
                Output_partial: T.Tensor(part_shape, accum_dtype),
                Output: T.Tensor(shape_o, dtype),
        ):
            with T.Kernel(heads, batch, threads=128) as (by, bz):
                po_local = T.alloc_fragment([dim_v], accum_dtype)
                o_accum_local = T.alloc_fragment([dim_v], accum_dtype)
                lse_local_split = T.alloc_local([1], accum_dtype)
                lse_logsum_local = T.alloc_local([1], accum_dtype)
                lse_max_local = T.alloc_local([1], accum_dtype)
                scale_local = T.alloc_local([1], accum_dtype)

                T.annotate_layout({
                    lse_logsum_local:
                        T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
                })

                T.clear(lse_logsum_local)
                T.clear(o_accum_local)
                lse_max_local[0] = -T.infinity(accum_dtype)
                for k in T.serial(num_split):
                    lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k])
                for k in T.Pipelined(num_split, num_stages=1):
                    lse_local_split[0] = glse[bz, by, k]
                    lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
                lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
                for k in T.serial(num_split):
                    for i in T.Parallel(dim_v):
                        po_local[i] = Output_partial[bz, by, k, i]
                    lse_local_split[0] = glse[bz, by, k]
                    scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
                    for i in T.Parallel(dim_v):
                        o_accum_local[i] += po_local[i] * scale_local[0]
                for i in T.Parallel(dim_v):
                    Output[bz, by, i] = o_accum_local[i]

        @T.prim_func
        def main(
                Q: T.Tensor(shape_q, dtype),
                K: T.Tensor(shape_k, dtype),
                V: T.Tensor(shape_v, dtype),
                block_mask: T.Tensor(shape_mask, T.bool),
                cache_seqlens: T.Tensor([batch], T.int32),
                glse: T.Tensor([batch, heads, num_split], accum_dtype),
                Output_partial: T.Tensor(part_shape, accum_dtype),
                Output: T.Tensor(shape_o, dtype),
        ):
            flash_attn_split(Q, K, V, block_mask, cache_seqlens, glse, Output_partial)
            combine(glse, Output_partial, Output)

        return main

    return kernel_func


class SparseFlashAttn(torch.nn.Module):

    def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size):
        super(SparseFlashAttn, self).__init__()
        self.batch = batch
        self.heads = heads
        self.heads_kv = heads_kv
        self.dim = dim
        self.dim_v = dim_v
        self.block_size = block_size
        self.block_H = 64
        self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
            block_N=block_size,
            block_H=self.block_H,
            num_split=T.dynamic("num_split"),
            num_stages=2,
            threads=128,
            max_cache_seqlen=T.dynamic("max_cache_seqlen"),
            num_blocks=T.dynamic("num_blocks"),
        )
        props = torch.cuda.get_device_properties(torch.device("cuda:0"))
        self.num_sm = props.multi_processor_count

    def forward(self, query, key, value, block_mask, cache_seqlens):
        batch = self.batch
        heads = self.heads
        heads_kv = self.heads_kv
        dim_v = self.dim_v
        dim = self.dim
        block_size = self.block_size
        block_H = self.block_H
        max_cache_seqlen = key.shape[1]

        # get num_split
        max_selected_blocks = (max_cache_seqlen + block_size - 1) // block_size
        num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
        num_n_blocks = max_selected_blocks
        size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2  # kv_seqlen * (dim + dim_v) * 2
        total_mblocks = batch * heads_kv * num_m_blocks
        # num_sm = 132
        num_sm = self.num_sm
        num_split = num_splits_heuristic(
            total_mblocks,
            num_sm,
            num_n_blocks,
            num_m_blocks,
            size_one_kv_head,
            is_causal_or_local=True,
            max_splits=128)
        # print("num_split: ", num_split)

        glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda")
        Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda")
        output = self.kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial)
        return output


def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens, block_size):
    """
    Args:
        query: [batch, heads, dim]
        key: [batch, max_cache_seqlen, heads_kv, dim]
        value: [batch, max_cache_seqlen, heads_kv, dim_v]
        block_mask: [batch, heads_kv, num_blocks], mask for valid blocks
        cache_seqlens: [batch], sequence lengths of the kvcache
        block_size: block size
    Returns:
        output: [batch, heads, dim_v]
    """
    batch, heads, dim = query.shape
    heads_kv = key.shape[2]
    dim_v = value.shape[-1]
    block_H = 64

    actual_num_blocks = torch.sum(block_mask, dim=-1).to(torch.int32)
    # [batch], number of valid blocks, assume all groups in the same batch have the same number of blocks
    actual_num_blocks = actual_num_blocks[:, 0]
    max_selected_blocks = actual_num_blocks.max().item()

    # get num_split
    num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
    num_n_blocks = max_selected_blocks  # (kv_seqlen + block_size - 1 ) // block_size
    # num_n_blocks = torch.sum(actual_num_blocks, dim=-1).item() * heads_kv  # total number of blocks
    size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2  # kv_seqlen * (dim + dim_v) * 2
    total_mblocks = batch * heads_kv * num_m_blocks
    num_sm = 132
    num_split = num_splits_heuristic(
        total_mblocks,
        num_sm,
        num_n_blocks,
        num_m_blocks,
        size_one_kv_head,
        is_causal_or_local=True,
        max_splits=128)

    kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
        block_N=block_size,
        block_H=block_H,
        num_split=T.dynamic("num_split"),
        num_stages=2,
        threads=128,
        max_cache_seqlen=T.dynamic("max_cache_seqlen"),
        num_blocks=T.dynamic("num_blocks"),
    )
    glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda")
    Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda")
    # print(kernel.get_kernel_source())
    output = kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial)
    return output


def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks,
                      block_size):
    batch, heads, dim = query.shape
    heads_kv = key.shape[2]
    num_head_groups = query.shape[1] // key.shape[2]
    scale = dim**0.5
    key = rearrange(key, "b n h d -> b h n d")  # [batch_size, heads_kv, seqlen_kv, dim]
    value = rearrange(value, "b n h d -> b h n d")  # [batch_size, heads_kv, seqlen_kv, dim]

    query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups)  # [batch_size, num_head_groups, heads_kv, dim]

    scores = einsum(query, key, "b g h d, b h s d -> b g h s")  # [batch_size, num_head_groups, heads_kv, seqlen_kv]

    sparse_mask = torch.zeros_like(scores)
    # Assign mask values
    for b in range(batch):
        for h in range(heads_kv):
            for idx in range(num_blocks):
                if block_mask[b, h, idx]:
                    sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1
    scores = scores.masked_fill(sparse_mask == 0, float("-inf"))

    range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0)
    cache_seqlens_expanded = cache_seqlens.unsqueeze(1)
    pad_mask = range_len >= cache_seqlens_expanded
    pad_mask = pad_mask[:, None, None, :]
    scores = scores.masked_fill(pad_mask, float("-inf"))

    attention = F.softmax(scores / scale, dim=-1)  # [batch_size, num_head_groups, heads_kv, seqlen_kv]

    out = einsum(attention, value, "b g h s, b h s d -> b g h d")  # [batch_size, num_head_groups, heads_kv, dim]
    out = rearrange(out, "b g h d -> b (h g) d")  # [batch_size, heads, dim]
    return out


def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks,
                   block_size):
    # latency reference
    # from flash_attn_interface import flash_attn_with_kvcache  # fa3
    from flash_attn import flash_attn_with_kvcache  # fa2
    query = query.unsqueeze(1)
    output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens)
    output = output.squeeze(1)
    return output


def debug(name, expect, actual, atol=1e-3, rtol=1e-3):
    all_close = torch.allclose(expect, actual, atol=atol, rtol=rtol)
    print(name + " all_close={}".format(all_close))
    if not all_close:
        # print(expect[3, 28])
        # print(actual[3, 28])
        diff = (expect - actual).abs()
        print("all_close={}, max={}, min={}, mean={}".format(all_close,
                                                             diff.max().item(),
                                                             diff.min().item(),
                                                             diff.mean().item()))
        max_indices = torch.nonzero(diff == diff.max().item())
        first_index = tuple(max_indices[0].tolist())
        print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}")


def main(batch=8,
         heads=32,
         heads_kv=8,
         max_cache_seqlen=8192,
         dim=128,
         dim_v=128,
         sparse_ratio=0.8,
         block_size=32):
    batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v
    sparse_ratio = sparse_ratio
    block_size = block_size

    max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size))
    print("max_selected_blocks: ", max_selected_blocks)
    dtype = torch.float16

    Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda")
    K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda")
    V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda")

    cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda")
    # Ensure at least one element equals cache_seqlen
    random_index = torch.randint(0, batch, (1,), device="cuda").item()  # Select a random index
    cache_seqlens[random_index] = max_cache_seqlen  # Assign cache_seqlen to ensure at least one occurrence
    # cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda')
    print("cache_seqlens: ", cache_seqlens)

    num_blocks = (max_cache_seqlen + block_size - 1) // block_size
    valid_num_blocks = torch.ceil(cache_seqlens * (1 - sparse_ratio) / block_size).int()
    print("valid_num_blocks: ", valid_num_blocks)
    max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int()
    print("max_valid_num_blocks: ", max_valid_num_blocks)
    # Initialize block_mask with false (for padding blocks)
    block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device="cuda")

    # Assign valid indices while ensuring no duplicates within each batch-group
    for b in range(batch):
        max_valid_block = max_valid_num_blocks[b].item()  # Max valid blocks for this batch
        valid_num_block = valid_num_blocks[b].item()  # Valid blocks for this batch
        if valid_num_block > 0:  # Ensure there's at least one valid block
            for h in range(heads_kv):
                perm = torch.randperm(max_valid_block, device="cuda")[:valid_num_block]
                block_mask[b, h, perm] = True
    # print("block_mask: ", block_mask)

    # parity reference
    ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size)

    # out = sparse_gqa_decode_varlen_mask(Q, K, V, block_mask, cache_seqlens, block_size)
    model = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size)
    out = model(Q, K, V, block_mask, cache_seqlens)
    debug("output", ref, out, atol=1e-3, rtol=1e-3)

    import flash_attn  # noqa: F401

    ## latency reference
    for _ in range(10):
        ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(100):
        ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size)
    torch.cuda.synchronize()
    print("dense time: ", (time.time() - start) / 100 * 1000)

    for _ in range(10):
        # out = sparse_gqa_decode_varlen_mask(Q, K, V, block_mask, cache_seqlens, block_size)
        out = model(Q, K, V, block_mask, cache_seqlens)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(100):
        # out = sparse_gqa_decode_varlen_mask(Q, K, V, block_mask, cache_seqlens, block_size)
        out = model(Q, K, V, block_mask, cache_seqlens)
    torch.cuda.synchronize()
    print("sparse time: ", (time.time() - start) / 100 * 1000)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch", type=int, default=8, help="batch size")
    parser.add_argument("--heads", type=int, default=32, help="heads")
    parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv")
    parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length")
    parser.add_argument("--dim", type=int, default=128, help="dim")
    parser.add_argument("--dim_v", type=int, default=128, help="dim_v")
    parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio")
    parser.add_argument("--block_size", type=int, default=32, help="block_size")
    args = parser.parse_args()
    main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v,
         args.sparse_ratio, args.block_size)
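
The mask-based and indice-based examples encode the same block sparsity pattern. As a rough sketch (illustrative only, not part of this commit), a boolean block_mask of shape [batch, heads_kv, num_blocks] can be converted into the padded block_indices layout expected by the indice variant:

import torch

def mask_to_block_indices(block_mask: torch.Tensor) -> torch.Tensor:
    batch, heads_kv, num_blocks = block_mask.shape
    max_selected = int(block_mask.sum(dim=-1).max().item())
    indices = torch.full((batch, heads_kv, max_selected), -1,
                         dtype=torch.int32, device=block_mask.device)
    for b in range(batch):
        for h in range(heads_kv):
            sel = torch.nonzero(block_mask[b, h], as_tuple=False).flatten().to(torch.int32)
            indices[b, h, :sel.numel()] = sel
    # Same convention as the examples: valid block ids first, -1 padding last.
    indices, _ = indices.sort(dim=-1, descending=True)
    return indices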
examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py
0 → 100644
View file @
cf6e11c9
# ruff: noqa
import torch
import triton
import triton.language as tl
import argparse
from einops import rearrange, einsum
import torch.nn.functional as F
import math
import time
from heuristic import num_splits_heuristic


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4]
        for num_stages in [1, 2, 3, 4, 7]
    ],
    key=["BLOCK_H", "BLOCK_N", "BLOCK_D"],
)
@triton.jit
def _split_kernel(
    q_ptr,
    k_cache_ptr,
    v_cache_ptr,
    cache_seqlens_ptr,
    o_partial_ptr,
    lse_partial_ptr,
    mask_ptr,
    sm_scale,
    num_splits,
    gqa_group_size,
    max_selected_blocks,
    stride_q_b, stride_q_h, stride_q_d,
    stride_k_b, stride_k_s, stride_k_h, stride_k_d,
    stride_v_b, stride_v_s, stride_v_h, stride_v_d,
    stride_o_b, stride_o_h, stride_o_split, stride_o_d,
    stride_lse_b, stride_lse_h, stride_lse_split,
    stride_mask_b, stride_mask_h, stride_mask_s,
    BLOCK_H: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    head_idx_kv = tl.program_id(1)
    split_idx = tl.program_id(2)

    head_idx_q = head_idx_kv * gqa_group_size

    offs_h = tl.arange(0, BLOCK_H)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_D)

    m_i = tl.full([BLOCK_H], float("-inf"), dtype=tl.float32)
    l_i = tl.full([BLOCK_H], 1.0, dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, BLOCK_D], dtype=tl.float32)

    cache_seqlens = tl.load(cache_seqlens_ptr + batch_idx)
    num_blocks = max_selected_blocks
    blocks_per_split = tl.floor(num_blocks / num_splits).to(tl.int32)
    remaining_blocks = num_blocks % num_splits
    if split_idx < remaining_blocks:
        loop_range = blocks_per_split + 1
    else:
        loop_range = blocks_per_split

    q_ptr += batch_idx * stride_q_b + head_idx_q * stride_q_h
    k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[None, :] * stride_k_s + offs_d[:, None] * stride_k_d
    v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:, None] * stride_v_s + offs_d[None, :] * stride_v_d
    mask_ptr += batch_idx * stride_mask_b + head_idx_kv * stride_mask_h

    q = tl.load(
        q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d,
        mask=offs_h[:, None] < gqa_group_size)

    start = blocks_per_split * split_idx + tl.minimum(split_idx, remaining_blocks)
    for i in range(loop_range):
        block_idx = tl.load(mask_ptr + (start + i) * stride_mask_s)
        if block_idx >= 0:
            start_n = block_idx * BLOCK_N
            k_ptr = k_cache_ptr + start_n * stride_k_s
            v_ptr = v_cache_ptr + start_n * stride_v_s
            k = tl.load(k_ptr, mask=start_n + offs_n[None, :] < cache_seqlens, other=0.0)
            v = tl.load(v_ptr, mask=start_n + offs_n[:, None] < cache_seqlens, other=0.0)

            qk = tl.dot(q, k)
            qk = qk * sm_scale
            qk = tl.where(start_n + offs_n[None, :] < cache_seqlens, qk, float("-inf"))

            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk -= m_ij[:, None]
            p = tl.exp(qk)
            l_ij = tl.sum(p, 1)
            alpha = tl.exp(m_i - m_ij)
            l_i = l_i * alpha + l_ij
            acc = acc * alpha[:, None]
            p = p.to(v.type.element_ty)
            acc += tl.dot(p, v)
            m_i = m_ij

    m_i += tl.math.log(l_i)
    l_recip = 1 / l_i[:, None]
    acc = acc * l_recip
    acc = acc.to(o_partial_ptr.dtype.element_ty)

    lse_partial_ptr += batch_idx * stride_lse_b + (head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split
    tl.store(lse_partial_ptr, m_i, mask=offs_h < gqa_group_size)

    o_partial_ptr += (
        batch_idx * stride_o_b + (head_idx_q + offs_h[:, None]) * stride_o_h +
        split_idx * stride_o_split + offs_d[None, :] * stride_o_d)
    tl.store(o_partial_ptr, acc, mask=offs_h[:, None] < gqa_group_size)


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4]
        for num_stages in [1, 2, 3, 4, 7]
    ],
    key=["BLOCK_D"],
)
@triton.jit
def _merge_kernel(
    o_partial_ptr,
    lse_partial_ptr,
    o_ptr,
    lse_partial_stride_b, lse_partial_stride_h, lse_partial_stride_split,
    o_partial_stride_b, o_partial_stride_h, o_partial_stride_split, o_partial_stride_d,
    o_stride_b, o_stride_h, o_stride_d,
    BLOCK_D: tl.constexpr,
    num_splits: tl.constexpr,
    num_splits_pow2: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    head_idx = tl.program_id(1)

    offs_splits = tl.arange(0, num_splits_pow2)
    offs_d = tl.arange(0, BLOCK_D)

    lse_offsets = lse_partial_ptr + batch_idx * lse_partial_stride_b + head_idx * lse_partial_stride_h
    lse = tl.load(
        lse_offsets + offs_splits * lse_partial_stride_split,
        mask=offs_splits < num_splits,
        other=float("-inf"))
    lse_max = tl.max(lse)

    o_offsets = o_partial_ptr + batch_idx * o_partial_stride_b + head_idx * o_partial_stride_h
    o_partial = tl.load(
        o_offsets + offs_splits[:, None] * o_partial_stride_split + offs_d[None, :] * o_partial_stride_d,
        mask=offs_splits[:, None] < num_splits,
    )

    sumexp_normalized_splitk = tl.exp(lse - lse_max)
    sumexp_normalized = tl.sum(sumexp_normalized_splitk, axis=0)
    numerator_normalized = tl.sum(o_partial * sumexp_normalized_splitk[:, None], axis=0)
    acc = numerator_normalized / sumexp_normalized
    acc = acc.to(o_ptr.dtype.element_ty)

    o_ptr += batch_idx * o_stride_b + head_idx * o_stride_h
    tl.store(o_ptr + offs_d * o_stride_d, acc)


def block_sparse_flash_decode_gqa_indice_triton(
    q,
    k_cache,
    v_cache,
    cache_seqlens,
    max_cache_seqlen,
    max_selected_blocks,
    block_indices,
    block_size,
    sm_scale=None,
):
    batch, heads, dim = q.shape
    if sm_scale is None:
        sm_scale = 1 / math.sqrt(dim)
    _, max_cache_seqlen_cache, heads_kv, dim_v = v_cache.shape
    assert max_cache_seqlen == max_cache_seqlen_cache, "max_cache_seqlen mismatch"
    group_size = heads // heads_kv

    block_H = 16
    num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
    num_n_blocks = max_selected_blocks
    size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2  # kv_seqlen * (dim + dim_v) * 2
    total_mblocks = batch * heads_kv * num_m_blocks
    num_sm = 64
    # num_sm = self.num_sm
    num_splits = num_splits_heuristic(
        total_mblocks,
        num_sm,
        num_n_blocks,
        num_m_blocks,
        size_one_kv_head,
        is_causal_or_local=True,
        max_splits=128)
    # print("num_splits:", num_splits, "num_blocks:", num_n_blocks)
    num_splits_pow2 = triton.next_power_of_2(num_splits)

    o_partial = torch.empty((batch, heads, num_splits, dim_v), device=q.device, dtype=q.dtype)
    lse_partial = torch.empty((batch, heads, num_splits), device=q.device, dtype=torch.float32)

    BLOCK_D = dim
    BLOCK_H = group_size if group_size > 16 else 16

    grid = (batch, heads_kv, num_splits)
    _split_kernel[grid](
        q,
        k_cache,
        v_cache,
        cache_seqlens,
        o_partial,
        lse_partial,
        block_indices,
        sm_scale,
        num_splits,
        group_size,
        max_selected_blocks,
        q.stride(0), q.stride(1), q.stride(2),
        k_cache.stride(0), k_cache.stride(1), k_cache.stride(2), k_cache.stride(3),
        v_cache.stride(0), v_cache.stride(1), v_cache.stride(2), v_cache.stride(3),
        o_partial.stride(0), o_partial.stride(1), o_partial.stride(2), o_partial.stride(3),
        lse_partial.stride(0), lse_partial.stride(1), lse_partial.stride(2),
        block_indices.stride(0), block_indices.stride(1), block_indices.stride(2),
        BLOCK_H=BLOCK_H,
        BLOCK_N=block_size,
        BLOCK_D=BLOCK_D,
    )

    output = torch.zeros((batch, heads, dim_v), device=q.device, dtype=q.dtype)
    grid = (batch, heads)
    _merge_kernel[grid](
        o_partial,
        lse_partial,
        output,
        lse_partial.stride(0), lse_partial.stride(1), lse_partial.stride(2),
        o_partial.stride(0), o_partial.stride(1), o_partial.stride(2), o_partial.stride(3),
        output.stride(0), output.stride(1), output.stride(2),
        BLOCK_D=dim_v,
        num_splits=num_splits,
        num_splits_pow2=num_splits_pow2,
    )
    return output


def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks,
                      block_size):
    batch, heads, dim = query.shape
    heads_kv = key.shape[2]
    dim_v = value.shape[-1]
    num_head_groups = query.shape[1] // key.shape[2]
    scale = dim**0.5
    key = rearrange(key, "b n h d -> b h n d")  # [batch_size, heads_kv, seqlen_kv, dim]
    value = rearrange(value, "b n h d -> b h n d")  # [batch_size, heads_kv, seqlen_kv, dim]

    query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups)  # [batch_size, num_head_groups, heads_kv, dim]

    scores = einsum(query, key, "b g h d, b h s d -> b g h s")  # [batch_size, num_head_groups, heads_kv, seqlen_kv]

    sparse_mask = torch.zeros_like(scores)
    # Assign mask values based on block_indices
    for b in range(batch):
        for h in range(heads_kv):
            valid_indices = block_indices[b, h]  # Extract indices for this batch and head
            for idx in valid_indices:
                if idx >= 0:
                    sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1
    scores = scores.masked_fill(sparse_mask == 0, float("-inf"))

    range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0)
    cache_seqlens_expanded = cache_seqlens.unsqueeze(1)
    pad_mask = range_len >= cache_seqlens_expanded
    pad_mask = pad_mask[:, None, None, :]
    scores = scores.masked_fill(pad_mask, float("-inf"))

    attention = F.softmax(scores / scale, dim=-1)  # [batch_size, num_head_groups, heads_kv, seqlen_kv]

    out = einsum(attention, value, "b g h s, b h s d -> b g h d")  # [batch_size, num_head_groups, heads_kv, dim]
    out = rearrange(out, "b g h d -> b (h g) d")  # [batch_size, heads, dim]
    return out


def ref_program_fa(query, key, value, cache_seqlens):
    # latency reference
    # from flash_attn_interface import flash_attn_with_kvcache  # fa3
    from flash_attn import flash_attn_with_kvcache  # fa2
    query = query.unsqueeze(1)
    output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens)
    output = output.squeeze(1)
    return output


def main(batch=64,
         heads=32,
         heads_kv=8,
         max_cache_seqlen=8192,
         dim=128,
         dim_v=128,
         sparse_ratio=0.8,
         block_size=32):
    batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v
    sparse_ratio = sparse_ratio
    block_size = block_size
    qk_flops = 2 * batch * heads * max_cache_seqlen * dim
    pv_flops = 2 * batch * heads * max_cache_seqlen * dim_v
    total_flops = qk_flops + pv_flops

    max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size))
    print("max_selected_blocks: ", max_selected_blocks)
    dtype = torch.float16
    block_H = 64

    Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda")
    K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda")
    V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda")

    cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda")
    # cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda')
    # Ensure at least one element equals cache_seqlen
    random_index = torch.randint(0, batch, (1,), device="cuda").item()  # Select a random index
    cache_seqlens[random_index] = max_cache_seqlen  # Assign cache_seqlen to ensure at least one occurrence
    print("cache_seqlens: ", cache_seqlens)

    max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int()
    print("max_valid_num_blocks: ", max_valid_num_blocks)
    # Initialize block_indices with -1 (for padding blocks)
    block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32, device="cuda")

    # Assign valid indices while ensuring no duplicates within each batch-group
    for b in range(batch):
        max_valid_block = max_valid_num_blocks[b].item()  # Max valid blocks for this batch
        if max_valid_block > 0:  # Ensure there's at least one valid block
            for h in range(heads_kv):
                valid_indices = torch.randperm(
                    max_valid_block, device="cuda", dtype=torch.int32)[:max_selected_blocks]
                block_indices[b, h, :len(valid_indices)] = valid_indices

    # Sort indices within each batch-group for consistency
    block_indices, _ = block_indices.sort(dim=-1, descending=True)
    # print("block_indices: ", block_indices)
    actual_num_blocks = torch.sum(block_indices != -1, dim=-1).to(torch.int32)[:, 0]
    print("actual_num_blocks: ", actual_num_blocks)
    # print(block_indices.shape, actual_num_blocks.shape)
    max_num_blocks = torch.max(max_valid_num_blocks).item()
    print("max_num_blocks: ", max_num_blocks)

    ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks,
                            block_size)

    triton_out = block_sparse_flash_decode_gqa_indice_triton(
        Q,
        K,
        V,
        cache_seqlens,
        max_cache_seqlen,
        max_selected_blocks,
        block_indices,
        block_size,
    )

    print("max difference: ", torch.max(torch.abs(ref - triton_out)))
    assert torch.allclose(ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation"
    print("Passed the ref test!")

    # Measure performance
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(1000):
        block_sparse_flash_decode_gqa_indice_triton(
            Q,
            K,
            V,
            cache_seqlens,
            max_cache_seqlen,
            max_selected_blocks,
            block_indices,
            block_size,
        )
    torch.cuda.synchronize()
    end = time.time()
    elapsed_time = end - start
    avg_time = elapsed_time / 1000
    avg_flops = total_flops / avg_time
    print(f"Average time: {avg_time:.6f} seconds")

    # Measure performance of reference implementation
    import flash_attn  # noqa: F401
    start = time.time()
    for _ in range(1000):
        ref_program_fa(Q, K, V, cache_seqlens)
    torch.cuda.synchronize()
    end = time.time()
    elapsed_time_ref = end - start
    avg_time_ref = elapsed_time_ref / 1000
    avg_flops_ref = total_flops / avg_time_ref
    print(f"Average time of ref: {avg_time_ref:.6f} seconds")
    print(f"Speedup: {avg_time_ref / avg_time:.2f}x")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch", type=int, default=64, help="batch size")
    parser.add_argument("--heads", type=int, default=32, help="heads")
    parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv")
    parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length")
    parser.add_argument("--dim", type=int, default=128, help="dim")
    parser.add_argument("--dim_v", type=int, default=128, help="dim_v")
    parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio")
    parser.add_argument("--block_size", type=int, default=32, help="block_size")
    args = parser.parse_args()
    main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v,
         args.sparse_ratio, args.block_size)
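
Note that the Triton wrapper above hard-codes num_sm = 64 (the commented-out self.num_sm hints at the alternative used by the tilelang module). If desired, the SM count can be queried from the device instead; a minimal sketch, illustrative only and not part of this commit:

import torch

def device_sm_count(device: str = "cuda:0") -> int:
    # Multiprocessor count of the active GPU, e.g. 108 on A100-class parts.
    return torch.cuda.get_device_properties(torch.device(device)).multi_processor_count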
examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py
0 → 100644
View file @
cf6e11c9
import torch
import triton
import triton.language as tl
import argparse
from einops import rearrange, einsum
import torch.nn.functional as F
import math
import time
from heuristic import num_splits_heuristic


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4]
        for num_stages in [1, 2, 3, 4, 7]
    ],
    key=["BLOCK_H", "BLOCK_N", "BLOCK_D"],
)
@triton.jit
def _split_kernel(
    q_ptr,
    k_cache_ptr,
    v_cache_ptr,
    cache_seqlens_ptr,
    o_partial_ptr,
    lse_partial_ptr,
    mask_ptr,
    sm_scale,
    num_splits,
    gqa_group_size,
    stride_q_b, stride_q_h, stride_q_d,
    stride_k_b, stride_k_s, stride_k_h, stride_k_d,
    stride_v_b, stride_v_s, stride_v_h, stride_v_d,
    stride_o_b, stride_o_h, stride_o_split, stride_o_d,
    stride_lse_b, stride_lse_h, stride_lse_split,
    stride_mask_b, stride_mask_h, stride_mask_s,
    BLOCK_H: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    head_idx_kv = tl.program_id(1)
    split_idx = tl.program_id(2)

    head_idx_q = head_idx_kv * gqa_group_size

    offs_h = tl.arange(0, BLOCK_H)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_D)

    m_i = tl.full([BLOCK_H], float("-inf"), dtype=tl.float32)
    l_i = tl.full([BLOCK_H], 1.0, dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, BLOCK_D], dtype=tl.float32)

    cache_seqlens = tl.load(cache_seqlens_ptr + batch_idx)
    num_blocks = (cache_seqlens + BLOCK_N - 1) // BLOCK_N
    blocks_per_split = tl.floor(num_blocks / num_splits).to(tl.int32)
    remaining_blocks = num_blocks % num_splits
    if split_idx < remaining_blocks:
        loop_range = blocks_per_split + 1
    else:
        loop_range = blocks_per_split

    q_ptr += batch_idx * stride_q_b + head_idx_q * stride_q_h
    k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[None, :] * stride_k_s + offs_d[:, None] * stride_k_d
    v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:, None] * stride_v_s + offs_d[None, :] * stride_v_d
    mask_ptr += batch_idx * stride_mask_b + head_idx_kv * stride_mask_h

    q = tl.load(
        q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d,
        mask=offs_h[:, None] < gqa_group_size)

    start = blocks_per_split * split_idx + tl.minimum(split_idx, remaining_blocks)
    for block_idx in range(loop_range):
        start_n = (start + block_idx) * BLOCK_N
        mask_val = tl.load(mask_ptr + (start + block_idx) * stride_mask_s)
        if mask_val == 1:
            k_ptr = k_cache_ptr + start_n * stride_k_s
            v_ptr = v_cache_ptr + start_n * stride_v_s
            k = tl.load(k_ptr, mask=start_n + offs_n[None, :] < cache_seqlens, other=0.0)
            v = tl.load(v_ptr, mask=start_n + offs_n[:, None] < cache_seqlens, other=0.0)

            qk = tl.dot(q, k)
            qk = qk * sm_scale
            qk = tl.where(start_n + offs_n[None, :] < cache_seqlens, qk, float("-inf"))

            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk -= m_ij[:, None]
            p = tl.exp(qk)
            l_ij = tl.sum(p, 1)
            alpha = tl.exp(m_i - m_ij)
            l_i = l_i * alpha + l_ij
            acc = acc * alpha[:, None]
            p = p.to(v.type.element_ty)
            acc += tl.dot(p, v)
            m_i = m_ij

    m_i += tl.math.log(l_i)
    l_recip = 1 / l_i[:, None]
    acc = acc * l_recip
    acc = acc.to(o_partial_ptr.dtype.element_ty)

    lse_partial_ptr += batch_idx * stride_lse_b + (head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split
    tl.store(lse_partial_ptr, m_i, mask=offs_h < gqa_group_size)

    o_partial_ptr += (
        batch_idx * stride_o_b + (head_idx_q + offs_h[:, None]) * stride_o_h +
        split_idx * stride_o_split + offs_d[None, :] * stride_o_d)
    tl.store(o_partial_ptr, acc, mask=offs_h[:, None] < gqa_group_size)


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4]
        for num_stages in [1, 2, 3, 4, 7]
    ],
    key=["BLOCK_D"],
)
@triton.jit
def _merge_kernel(
    o_partial_ptr,
    lse_partial_ptr,
    o_ptr,
    lse_partial_stride_b, lse_partial_stride_h, lse_partial_stride_split,
    o_partial_stride_b, o_partial_stride_h, o_partial_stride_split, o_partial_stride_d,
    o_stride_b, o_stride_h, o_stride_d,
    BLOCK_D: tl.constexpr,
    num_splits: tl.constexpr,
    num_splits_pow2: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    head_idx = tl.program_id(1)

    offs_splits = tl.arange(0, num_splits_pow2)
    offs_d = tl.arange(0, BLOCK_D)

    lse_offsets = lse_partial_ptr + batch_idx * lse_partial_stride_b + head_idx * lse_partial_stride_h
    lse = tl.load(
        lse_offsets + offs_splits * lse_partial_stride_split,
        mask=offs_splits < num_splits,
        other=float("-inf"))
    lse_max = tl.max(lse)

    o_offsets = o_partial_ptr + batch_idx * o_partial_stride_b + head_idx * o_partial_stride_h
    o_partial = tl.load(
        o_offsets + offs_splits[:, None] * o_partial_stride_split + offs_d[None, :] * o_partial_stride_d,
        mask=offs_splits[:, None] < num_splits,
    )

    sumexp_normalized_splitk = tl.exp(lse - lse_max)
    sumexp_normalized = tl.sum(sumexp_normalized_splitk, axis=0)
    numerator_normalized = tl.sum(o_partial * sumexp_normalized_splitk[:, None], axis=0)
    acc = numerator_normalized / sumexp_normalized
    acc = acc.to(o_ptr.dtype.element_ty)

    o_ptr += batch_idx * o_stride_b + head_idx * o_stride_h
    tl.store(o_ptr + offs_d * o_stride_d, acc)


def block_sparse_flash_decode_gqa_mask_triton(
    q,
    k_cache,
    v_cache,
    cache_seqlens,
    max_cache_seqlen,
    block_mask,
    block_size,
    sm_scale=None,
):
    batch, heads, dim = q.shape
    if sm_scale is None:
        sm_scale = 1 / math.sqrt(dim)
    _, max_cache_seqlen_cache, heads_kv, dim_v = v_cache.shape
    assert max_cache_seqlen == max_cache_seqlen_cache, "max_cache_seqlen mismatch"
    group_size = heads // heads_kv

    block_H = 16
    max_selected_blocks = (max_cache_seqlen + block_size - 1) // block_size
    num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
    num_n_blocks = max_selected_blocks
    size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2  # kv_seqlen * (dim + dim_v) * 2
    total_mblocks = batch * heads_kv * num_m_blocks
    num_sm = 64
    # num_sm = self.num_sm
    num_splits = num_splits_heuristic(
        total_mblocks,
        num_sm,
        num_n_blocks,
        num_m_blocks,
        size_one_kv_head,
        is_causal_or_local=True,
        max_splits=128)
    # print("num_splits:", num_splits, "num_blocks:", num_n_blocks)
    num_splits_pow2 = triton.next_power_of_2(num_splits)

    o_partial = torch.empty((batch, heads, num_splits, dim_v), device=q.device, dtype=q.dtype)
    lse_partial = torch.empty((batch, heads, num_splits), device=q.device, dtype=torch.float32)

    BLOCK_D = dim
    BLOCK_H = group_size if group_size > 16 else 16

    grid = (batch, heads_kv, num_splits)
    _split_kernel[grid](
        q,
        k_cache,
        v_cache,
        cache_seqlens,
        o_partial,
        lse_partial,
        block_mask,
        sm_scale,
        num_splits,
        group_size,
        q.stride(0), q.stride(1), q.stride(2),
        k_cache.stride(0), k_cache.stride(1), k_cache.stride(2), k_cache.stride(3),
        v_cache.stride(0), v_cache.stride(1), v_cache.stride(2), v_cache.stride(3),
        o_partial.stride(0), o_partial.stride(1), o_partial.stride(2), o_partial.stride(3),
        lse_partial.stride(0), lse_partial.stride(1), lse_partial.stride(2),
        block_mask.stride(0), block_mask.stride(1), block_mask.stride(2),
        BLOCK_H=BLOCK_H,
        BLOCK_N=block_size,
        BLOCK_D=BLOCK_D,
    )

    output = torch.zeros((batch, heads, dim_v), device=q.device, dtype=q.dtype)
    grid = (batch, heads)
    _merge_kernel[grid](
        o_partial,
        lse_partial,
        output,
        lse_partial.stride(0), lse_partial.stride(1), lse_partial.stride(2),
        o_partial.stride(0), o_partial.stride(1), o_partial.stride(2), o_partial.stride(3),
        output.stride(0), output.stride(1), output.stride(2),
        BLOCK_D=dim_v,
        num_splits=num_splits,
        num_splits_pow2=num_splits_pow2,
    )
    return output


def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks,
                      block_size):
    batch, heads, dim = query.shape
    heads_kv = key.shape[2]
    num_head_groups = query.shape[1] // key.shape[2]
    scale = dim**0.5
    key = rearrange(key, "b n h d -> b h n d")  # [batch_size, heads_kv, seqlen_kv, dim]
    value = rearrange(value, "b n h d -> b h n d")  # [batch_size, heads_kv, seqlen_kv, dim]

    query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups)  # [batch_size, num_head_groups, heads_kv, dim]

    scores = einsum(query, key, "b g h d, b h s d -> b g h s")  # [batch_size, num_head_groups, heads_kv, seqlen_kv]

    sparse_mask = torch.zeros_like(scores)
    # Assign mask values
    for b in range(batch):
        for h in range(heads_kv):
            for idx in range(num_blocks):
                if block_mask[b, h, idx]:
                    sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1
    scores = scores.masked_fill(sparse_mask == 0, float("-inf"))

    range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0)
    cache_seqlens_expanded = cache_seqlens.unsqueeze(1)
    pad_mask = range_len >= cache_seqlens_expanded
    pad_mask = pad_mask[:, None, None, :]
    scores = scores.masked_fill(pad_mask, float("-inf"))

    attention = F.softmax(scores / scale, dim=-1)  # [batch_size, num_head_groups, heads_kv, seqlen_kv]

    out = einsum(attention, value, "b g h s, b h s d -> b g h d")  # [batch_size, num_head_groups, heads_kv, dim]
    out = rearrange(out, "b g h d -> b (h g) d")  # [batch_size, heads, dim]
    return out


def ref_program_fa(query, key, value, cache_seqlens):
    # latency reference
    # from flash_attn_interface import flash_attn_with_kvcache  # fa3
    from flash_attn import flash_attn_with_kvcache  # fa2
    query = query.unsqueeze(1)
    output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens)
    output = output.squeeze(1)
    return output


def main(batch=64,
         heads=32,
         heads_kv=8,
         max_cache_seqlen=8192,
         dim=128,
         dim_v=128,
         sparse_ratio=0.8,
         block_size=32):
    batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v
    block_size = block_size
    sparse_ratio = sparse_ratio
    qk_flops = 2 * batch * heads * max_cache_seqlen * dim
    pv_flops = 2 * batch * heads * max_cache_seqlen * dim_v
    total_flops = qk_flops + pv_flops
    dtype = torch.float16

    Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda")
    K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda")
    V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda")

    cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda")
    # Ensure at least one element equals cache_seqlen
    random_index = torch.randint(0, batch, (1,), device="cuda").item()  # Select a random index
    cache_seqlens[random_index] = max_cache_seqlen  # Assign cache_seqlen to ensure at least one occurrence

    num_blocks = (max_cache_seqlen + block_size - 1) // block_size
    valid_num_blocks = torch.ceil(cache_seqlens * (1 - sparse_ratio) / block_size).int()
    print("valid_num_blocks: ", valid_num_blocks)
    max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int()
    print("max_valid_num_blocks: ", max_valid_num_blocks)
    # Initialize block_mask with false (for padding blocks)
    block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device="cuda")

    # Assign valid indices while ensuring no duplicates within each batch-group
    for b in range(batch):
        max_valid_block = max_valid_num_blocks[b].item()  # Max valid blocks for this batch
        valid_num_block = valid_num_blocks[b].item()  # Valid blocks for this batch
        if valid_num_block > 0:  # Ensure there's at least one valid block
            for h in range(heads_kv):
                perm = torch.randperm(max_valid_block, device="cuda")[:valid_num_block]
                block_mask[b, h, perm] = True

    ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size)

    triton_out = block_sparse_flash_decode_gqa_mask_triton(
        Q,
        K,
        V,
        cache_seqlens,
        max_cache_seqlen,
        block_mask,
        block_size,
    )

    # print("max difference: ", torch.max(torch.abs(ref - triton_out)))
    assert torch.allclose(ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation"
    print("Passed the ref test!")

    # Measure performance
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(1000):
        block_sparse_flash_decode_gqa_mask_triton(
            Q,
            K,
            V,
            cache_seqlens,
            max_cache_seqlen,
            block_mask,
            block_size,
        )
    torch.cuda.synchronize()
    end = time.time()
    elapsed_time = end - start
    avg_time = elapsed_time / 1000
    avg_flops = total_flops / avg_time
    print(f"Average time: {avg_time:.6f} seconds")
    print(f"Average flops: {avg_flops:.2f} GFLOPS")

    import flash_attn  # noqa: F401
    start = time.time()
    for _ in range(1000):
        ref_program_fa(Q, K, V, cache_seqlens)
    torch.cuda.synchronize()
    end = time.time()
    elapsed_time_ref = end - start
    avg_time_ref = elapsed_time_ref / 1000
    avg_flops_ref = total_flops / avg_time_ref
    print(f"Average time of ref: {avg_time_ref:.6f} seconds")
    print(f"Average flops of ref: {avg_flops_ref:.2f} GFLOPS")
    print(f"Speedup: {avg_time_ref / avg_time:.2f}x")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch", type=int, default=64, help="batch size")
    parser.add_argument("--heads", type=int, default=32, help="heads")
    parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv")
    parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length")
    parser.add_argument("--dim", type=int, default=128, help="dim")
    parser.add_argument("--dim_v", type=int, default=128, help="dim_v")
    parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio")
    parser.add_argument("--block_size", type=int, default=32, help="block_size")
    args = parser.parse_args()
    main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v,
         args.sparse_ratio, args.block_size)
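
The GFLOPS numbers printed above use the dense FLOP count (full max_cache_seqlen) for both the sparse kernel and the flash-attn reference. A rough estimate of the work the sparse kernel actually performs, assuming each sequence keeps about (1 - sparse_ratio) of its blocks, could look like this (illustrative only, not part of this commit):

def sparse_flops_estimate(heads, cache_seqlens, dim, dim_v, sparse_ratio):
    # 2 * heads * (attended tokens summed over the batch) * (dim + dim_v)
    attended_tokens = float(cache_seqlens.float().sum().item()) * (1.0 - sparse_ratio)
    return 2 * heads * attended_tokens * (dim + dim_v)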
examples/blocksparse_attention/heuristic.py
0 → 100644
View file @
cf6e11c9
import math


def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, num_m_blocks, size_one_kv_head,
                         is_causal_or_local, max_splits):
    """
    Determines the optimal number of splits for maximizing GPU occupancy while balancing memory efficiency.

    Parameters:
    - total_mblocks (int): Total number of m_blocks.
    - num_SMs (int): Number of Streaming Multiprocessors (SMs) in the GPU.
    - num_n_blocks (int): Number of n_blocks.
    - num_m_blocks (int): Number of m_blocks.
    - size_one_kv_head (int): Size of one KV head in bytes.
    - is_causal_or_local (bool): Indicates whether the operation is causal or local.
    - max_splits (int): Maximum number of allowed splits.

    Returns:
    - int: The optimal number of splits.
    """
    # If we have enough m_blocks to almost fill the SMs, prefer 1 split unless memory constraints apply.
    if total_mblocks >= 0.8 * num_SMs:
        size_l2 = 50 * 1024 * 1024  # L2 cache size assumption (50MB)
        # Only split if each KV head is too large for L2 and there are enough m_blocks
        if size_one_kv_head > size_l2 and num_m_blocks >= num_SMs * 2 and not is_causal_or_local:
            return min((size_one_kv_head + size_l2 - 1) // size_l2, max_splits)
        else:
            return 1
    # If num_n_blocks is too small, we don't split
    if num_n_blocks <= 4:
        return 1
    # Limit max_splits to a reasonable range
    max_splits = min(max_splits, num_SMs, num_n_blocks)
    max_efficiency = 0.0
    efficiency = []
    # Compute efficiency for different splits
    for num_splits in range(1, max_splits + 1):
        n_waves = (total_mblocks * num_splits) / num_SMs
        eff = n_waves / math.ceil(n_waves)
        # Track max efficiency
        if eff > max_efficiency:
            max_efficiency = eff
        efficiency.append(eff)
    # Find the smallest number of splits that achieves at least 85% of max efficiency
    for num_splits in range(1, max_splits + 1):
        if efficiency[num_splits - 1] >= 0.85 * max_efficiency:
            return num_splits
    return 1
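
As a quick usage check (illustrative only, not part of this commit), the heuristic can be called with a decode shape similar to the attention examples: batch 8, 8 KV heads, an 8192-token cache split into 256 blocks of 32 tokens, dim = dim_v = 128 in fp16, on a 132-SM GPU.

num_n_blocks = 8192 // 32                      # 256 KV blocks per head
size_one_kv_head = 8192 * (128 + 128) * 2      # K + V bytes per head in fp16
print(num_splits_heuristic(
    total_mblocks=8 * 8 * 1,                   # batch * heads_kv * num_m_blocks
    num_SMs=132,
    num_n_blocks=num_n_blocks,
    num_m_blocks=1,
    size_one_kv_head=size_one_kv_head,
    is_causal_or_local=True,
    max_splits=128))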
examples/blocksparse_attention/test_example_blocksparse_attention.py
0 → 100644
View file @
cf6e11c9
import tilelang.testing

import block_sparse_attn_triton
import example_tilelang_block_sparse_attn
import example_tilelang_sparse_gqa_decode_varlen_indice
import example_tilelang_sparse_gqa_decode_varlen_mask
import example_triton_sparse_gqa_decode_varlen_indice
import example_triton_sparse_gqa_decode_varlen_mask


def test_block_sparse_attn_triton():
    block_sparse_attn_triton.main()


def test_example_tilelang_block_sparse_attn():
    example_tilelang_block_sparse_attn.main()


def test_example_tilelang_sparse_gqa_decode_varlen_indice():
    example_tilelang_sparse_gqa_decode_varlen_indice.main(batch=1, max_cache_seqlen=2048)


def test_example_tilelang_sparse_gqa_decode_varlen_mask():
    example_tilelang_sparse_gqa_decode_varlen_mask.main(batch=1, max_cache_seqlen=2048)


def test_example_triton_sparse_gqa_decode_varlen_indice():
    example_triton_sparse_gqa_decode_varlen_indice.main(
        batch=8, heads=8, heads_kv=4, max_cache_seqlen=2048, dim=128, dim_v=128, sparse_ratio=0.8,
        block_size=32)


def test_example_triton_sparse_gqa_decode_varlen_mask():
    example_triton_sparse_gqa_decode_varlen_mask.main(
        batch=16, heads=16, heads_kv=8, max_cache_seqlen=1024, dim=128, dim_v=128, sparse_ratio=0.8,
        block_size=32)


if __name__ == "__main__":
    tilelang.testing.main()
examples/blocksparse_gemm/example_blocksparse_gemm.py
0 → 100644
View file @
cf6e11c9
import
argparse
import
itertools
import
tilelang
import
tilelang.language
as
T
from
tilelang.engine.param
import
KernelParam
from
tilelang.utils.tensor
import
get_tensor_supply
,
TensorSupplyType
import
torch
from
typing
import
List
DEFAULT_BLOCK_M
=
128
DEFAULT_BLOCK_N
=
128
DEFAULT_BLOCK_K
=
32
DEFAULT_NUM_STAGES
=
2
DEFAULT_THREAD_NUM
=
128
DEFAULT_ENABLE_RASTERIZATION
=
True
parser
=
argparse
.
ArgumentParser
(
description
=
"Autotuned BlockSparse MatMul Benchmark"
)
parser
.
add_argument
(
"--m"
,
type
=
int
,
default
=
1024
,
help
=
"Matrix dimension M"
)
parser
.
add_argument
(
"--n"
,
type
=
int
,
default
=
1024
,
help
=
"Matrix dimension N"
)
parser
.
add_argument
(
"--k"
,
type
=
int
,
default
=
1024
,
help
=
"Matrix dimension K"
)
parser
.
add_argument
(
"--sparsity"
,
type
=
float
,
default
=
0.5
,
help
=
"Sparsity ratio (0-1)"
)
parser
.
add_argument
(
"--use_autotune"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to use autotune"
)
args
,
_
=
parser
.
parse_known_args
()
M
,
N
,
K
=
args
.
m
,
args
.
n
,
args
.
k
sparsity
=
args
.
sparsity
use_autotune
=
args
.
use_autotune
default_tensor_supply
=
get_tensor_supply
(
TensorSupplyType
.
Auto
)
print
(
f
"Running BlockSparse MatMul Benchmark for M=
{
M
}
, N=
{
N
}
, K=
{
K
}
"
)
print
(
f
"Target Block Sparsity:
{
sparsity
}
"
)
print
(
f
"Using Autotuner:
{
use_autotune
}
\n
"
)
def
get_configs
():
block_M
=
[
64
,
128
,
256
]
block_N
=
[
64
,
128
,
256
]
block_K
=
[
32
,
64
]
num_stages
=
[
1
,
2
,
3
]
thread_num
=
[
128
,
256
]
enable_rasterization
=
[
True
,
False
]
_configs
=
list
(
itertools
.
product
(
block_M
,
block_N
,
block_K
,
num_stages
,
thread_num
,
enable_rasterization
))
return
[
{
"block_M"
:
c
[
0
],
"block_N"
:
c
[
1
],
"block_K"
:
c
[
2
],
"num_stages"
:
c
[
3
],
"thread_num"
:
c
[
4
],
"enable_rasteration"
:
c
[
5
],
}
for
c
in
_configs
]
def
ref_program
(
A
,
B
,
BlockMask
,
block_M
,
block_N
,
block_K
):
ref_c
=
torch
.
zeros
((
M
,
N
),
dtype
=
torch
.
float16
,
device
=
A
.
device
)
for
i
in
range
(
M
//
block_M
):
for
j
in
range
(
N
//
block_N
):
accu
=
torch
.
zeros
((
block_M
,
block_N
),
dtype
=
torch
.
float32
,
device
=
A
.
device
)
for
k
in
range
(
K
//
block_K
):
if
BlockMask
[
i
,
j
,
k
]:
accu
+=
A
[
i
*
block_M
:
(
i
+
1
)
*
block_M
,
k
*
block_K
:
(
k
+
1
)
*
block_K
].
to
(
torch
.
float32
)
@
B
[
k
*
block_K
:
(
k
+
1
)
*
block_K
,
j
*
block_N
:
(
j
+
1
)
*
block_N
].
to
(
torch
.
float32
)
ref_c
[
i
*
block_M
:
(
i
+
1
)
*
block_M
,
j
*
block_N
:
(
j
+
1
)
*
block_N
]
=
accu
.
to
(
torch
.
float16
)
return
ref_c
def
supply_program
(
params
:
List
[
KernelParam
]):
input_tensors
=
[]
for
p
in
params
:
# Check if the kernel parameter is BlockMask tensor.
# Here, BlockMask is uniquely identified by having 3 dimensions.
if
len
(
p
.
shape
)
!=
3
:
# For non-BlockMask tensors, use the default tensor generation logic.
input_tensors
.
append
(
default_tensor_supply
(
p
))
else
:
# For BlockMask tensor, randomly set elements to True based on desired
# sparsity level.
block_mask
=
torch
.
zeros
(
p
.
shape
,
dtype
=
torch
.
bool
,
device
=
torch
.
cuda
.
current_device
())
block_mask
[:,
:,
:]
=
torch
.
rand
(
p
.
shape
)
>
sparsity
input_tensors
.
append
(
block_mask
)
return
input_tensors
@tilelang.autotune(
    configs=get_configs(),)
@tilelang.jit(out_idx=[-1])
def blocksparse_matmul(M,
                       N,
                       K,
                       block_M,
                       block_N,
                       block_K,
                       num_stages,
                       thread_num,
                       enable_rasteration,
                       dtype=T.float16,
                       accum_dtype=T.float32):
    block_mask_shape = (M // block_M, N // block_N, K // block_K)

    @T.prim_func
    def block_sparse_matmul(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            BlockMask: T.Tensor(block_mask_shape, "bool"),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_shared = T.alloc_shared((block_M, block_N), dtype)
            T.use_swizzle(panel_size=10, enable=enable_rasteration)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if BlockMask[by, bx, k]:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                    T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return block_sparse_matmul
def main():
    # Initialize input matrices A and B on the GPU with half precision
    a = torch.randn(M, K).cuda().half()
    b = torch.randn(K, N).cuda().half()

    if args.use_autotune:
        # Run the autotuner to find the best kernel configuration. The returned kernel
        # object carries the compiled kernel, the best configuration found, and its
        # measured latency.
        kernel = blocksparse_matmul(M, N, K)
        best_config = kernel.config
        best_latency = kernel.latency
        block_M, block_N, block_K = best_config["block_M"], best_config["block_N"], best_config[
            "block_K"]
        print(f"Best Config: {best_config}")
        print(f"Sparsity Ratio: {sparsity}")
        print(f"Best Kernel Latency: {best_latency:.6f} ms")
    else:
        kernel = blocksparse_matmul(
            M,
            N,
            K,
            block_M=DEFAULT_BLOCK_M,
            block_N=DEFAULT_BLOCK_N,
            block_K=DEFAULT_BLOCK_K,
            num_stages=DEFAULT_NUM_STAGES,
            thread_num=DEFAULT_THREAD_NUM,
            enable_rasteration=DEFAULT_ENABLE_RASTERIZATION,
        )
        block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K
        print(f"Using default kernel with block size ({block_M}, {block_N}, {block_K})")

    # Create the block mask with the desired sparsity
    mask_shape = (M // block_M, N // block_N, K // block_K)
    block_mask = torch.rand(mask_shape).cuda() > sparsity

    # Run the compiled kernel (either tuned or default) with the inputs
    c = kernel(a, b, block_mask)

    # Compute the reference result using the naive PyTorch implementation
    ref_c = ref_program(a, b, block_mask, block_M, block_N, block_K)

    try:
        torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
        print("✅ Results are close! Verification successful.")
    except AssertionError as e:
        print("❌ Verification FAILED: Results differ significantly.")
        print(e)


if __name__ == "__main__":
    main()
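# Example invocation (hypothetical values, assuming a CUDA-capable GPU), using the
# argparse flags defined above:
#   python example_blocksparse_gemm.py --m 4096 --n 4096 --k 4096 --sparsity 0.7 --use_autotune
# Without --use_autotune the script compiles the kernel with the DEFAULT_BLOCK_* configuration.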
examples/blocksparse_gemm/test_example_blocksparse_gemm.py
0 → 100644
View file @
cf6e11c9
import tilelang.testing

import example_blocksparse_gemm


def test_example_blocksparse_gemm():
    example_blocksparse_gemm.main()


if __name__ == "__main__":
    tilelang.testing.main()
examples/cast/example_group_per_split_token_cast_to_fp8.py
0 → 100644
View file @
cf6e11c9
import torch
import tilelang
import tilelang.language as T
from typing import Tuple
from tilelang.utils.tensor import torch_assert_close

# supported dtypes: bfloat16, float, float16
dtype = T.bfloat16
accum_dtype = T.float32


@tilelang.jit(out_idx=[2, 3])
def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
    group_size = 128
    fp8_min = -448.0
    fp8_max = 448.0

    @T.prim_func
    def group_per_split_token_cast(
            X: T.Tensor((M, N), dtype),
            batch_sizes: T.Tensor((BG,), T.int32),
            X_fp8: T.Tensor((BG, M_max, N), T.float8_e4m3fn),
            X_amax: T.Tensor((BG, M_max, T.ceildiv(N, group_size)), accum_dtype),
    ):
        with T.Kernel(
                T.ceildiv(M_max, blk_m), T.ceildiv(N, group_size), BG,
                threads=128) as (bx, by, bz):
            row = bx
            row_g_id = by
            bg = bz
            y_local = T.alloc_fragment((blk_m, group_size), accum_dtype)
            y_amax_local = T.alloc_fragment((blk_m,), accum_dtype)
            y_s_local = T.alloc_fragment((blk_m,), accum_dtype)
            y_q_local = T.alloc_fragment((blk_m, group_size), accum_dtype)
            y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), T.float8_e4m3fn)
            row_offset = T.alloc_fragment((1,), T.int32)

            T.annotate_layout({
                y_local:
                    T.Fragment(
                        y_local.shape,
                        forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32),
            })

            row_offset[0] = 0
            for i in T.serial(bg):
                row_offset[0] += batch_sizes[i]
            T.copy(
                X[row_offset[0] + row * blk_m:row_offset[0] + (row + 1) * blk_m,
                  row_g_id * group_size:(row_g_id + 1) * group_size],
                y_local,
            )
            T.reduce_absmax(y_local, y_amax_local, dim=1)
            for i in T.Parallel(blk_m):
                y_amax_local[i] = T.max(y_amax_local[i], 1e-4)
                y_s_local[i] = T.if_then_else(row * blk_m + i < batch_sizes[bg],
                                              y_amax_local[i] / fp8_max, 0)
            for i, j in T.Parallel(blk_m, group_size):
                y_q_local[i, j] = T.clamp(y_local[i, j] / y_s_local[i], fp8_min, fp8_max)
            T.copy(y_q_local, y_q_local_fp8)
            for i, j in T.Parallel(blk_m, group_size):
                y_q_local_fp8[i, j] = T.if_then_else(row * blk_m + i < batch_sizes[bg],
                                                     y_q_local[i, j], 0)
            for i in T.Parallel(blk_m):
                X_amax[bg, row * blk_m + i, row_g_id] = y_s_local[i]
            T.copy(
                y_q_local_fp8,
                X_fp8[bg, row * blk_m:(row + 1) * blk_m,
                      row_g_id * group_size:(row_g_id + 1) * group_size])

    return group_per_split_token_cast
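# A minimal dequantization sketch (not part of this file; the helper name is
# illustrative): multiplying each 128-wide group of X_fp8 by its per-group scale in
# X_amax approximately recovers the padded input rows, which is how the two kernel
# outputs are meant to be consumed downstream.
def dequantize_group(x_fp8: torch.Tensor, x_scale: torch.Tensor,
                     group_size: int = 128) -> torch.Tensor:
    # x_fp8: (BG, M_max, N) float8_e4m3fn; x_scale: (BG, M_max, N // group_size) float32
    bg, m, n = x_fp8.shape
    x = x_fp8.to(torch.float32).view(bg, m, n // group_size, group_size)
    return (x * x_scale.unsqueeze(-1)).view(bg, m, n)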
def ceil_div(x: int, y: int) -> int:
    """
    Perform ceiling division of two integers.

    Args:
        x: the dividend.
        y: the divisor.

    Returns:
        The result of the ceiling division.
    """
    return (x + y - 1) // y
def get_tma_aligned_size(x: int, element_size: int) -> int:
    """
    Global memory addresses used by TMA must be 16-byte aligned.
    Since we use a column-major layout for the LHS scaling tensor,
    the M axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes.

    Arguments:
        x: original M-axis shape of the LHS scaling tensor.
        element_size: element size (in bytes) of the LHS scaling tensor.

    Returns:
        M-axis shape of the LHS scaling tensor after padding.
    """
    tma_alignment_bytes = 16
    assert tma_alignment_bytes % element_size == 0
    alignment = tma_alignment_bytes // element_size
    return ceil_div(x, alignment) * alignment
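# Worked example (not in the original file): a float32 scaling tensor has
# element_size == 4, so alignment == 16 // 4 == 4 elements along M. Hence M == 2050 is
# padded to get_tma_aligned_size(2050, 4) == 2052, while M == 6144 is already aligned.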
def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
    """
    Returns the TMA-aligned transposed format of the input tensor. `torch.transpose` will be called if necessary.
    If the input tensor is already in column-major layout and 16-byte aligned along the M axis
    (thus meeting the requirement of the LHS scaling tensor in DeepGEMM), this function does nothing.

    Arguments:
        x: usually the LHS scaling tensor in GEMM.

    Returns:
        The LHS scaling tensor in TMA-aligned transposed format.
    """
    # NOTE: for extreme performance, you may rewrite/fuse this function in CUDA
    assert x.dim() in (2, 3)
    remove_dim = False
    m, n = x.shape[-2], x.shape[-1]
    aligned_m = get_tma_aligned_size(m, x.element_size())
    if x.dim() == 2:
        if x.stride(0) == 1 and x.stride(1) == aligned_m:
            return x
        x, remove_dim = x.unsqueeze(0), True

    b = x.shape[0]

    # The last kernel gives a column-major TMA aligned layout
    if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m:
        return x.squeeze(0) if remove_dim else x

    # Normal layout requires transposing
    aligned_x = torch.transpose(
        torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
    aligned_x[:, :m, :] = x
    aligned_x = aligned_x[:, :m, :]
    return aligned_x.squeeze(0) if remove_dim else aligned_x
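# Usage sketch (illustrative, not part of the file): for a freshly allocated float32
# scaling tensor the helper keeps the logical shape but returns column-major storage
# whose leading dimension is padded to the TMA-aligned M computed above.
#   scales = torch.rand(2050, 64, device="cuda", dtype=torch.float32)
#   aligned = get_col_major_tma_aligned_tensor(scales)
#   assert aligned.shape == (2050, 64) and aligned.stride() == (1, 2052)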
def ref_per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # this function doesn't support CPU tensors
    assert x.dim() == 2
    m, n = x.shape
    new_n = ceil_div(n, 128) * 128
    x_padded = torch.nn.functional.pad(x, (0, new_n - n))
    x_view = x_padded.view(m, -1, 128)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    x_fp8 = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
    x_fp8 = x_fp8.view(m, -1)[:, :n].contiguous()
    return x_fp8, (x_amax / 448.0).view(m, -1)
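# Roundtrip sanity sketch (illustrative; assumes a CUDA device and N divisible by 128):
#   x = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)
#   x_fp8, scales = ref_per_token_cast_to_fp8(x)
#   x_rec = (x_fp8.float().view(256, -1, 128) * scales.unsqueeze(2)).view(256, -1)
#   print((x_rec - x.float()).abs().max())  # error stays a small fraction of each group's amax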
def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # assert x.shape[0] == batch_sizes.sum()
    M_max = ceil_div(batch_sizes.max(), 128) * 128
    split_x = torch.split(x, batch_sizes.tolist(), dim=0)
    padded_x = [
        torch.nn.functional.pad(t, (0, 0, 0, M_max - t.shape[0])) for t in split_x
    ]
    num_groups, m, n = batch_sizes.shape[0], M_max, x.shape[1]
    x_fp8 = (
        torch.empty((num_groups, m, n), device="cuda", dtype=torch.float8_e4m3fn),
        torch.empty((num_groups, m, n // 128), device="cuda", dtype=torch.float),
    )
    for i in range(num_groups):
        x_fp8[0][i], x_fp8[1][i] = ref_per_token_cast_to_fp8(padded_x[i])
    x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
    return x_fp8
def main(M=8192, N=8192, BG=2, blk_m=8, batch_sizes=None):
    if batch_sizes is None:
        batch_sizes = [2048, 6144]
    if dtype == T.float:
        x = torch.randn(M, N, device="cuda", dtype=torch.float32)
    elif dtype == T.float16:
        x = torch.randn(M, N, device="cuda", dtype=torch.float16)
    elif dtype == T.bfloat16:
        x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)
    else:
        raise ValueError(f"Unsupported dtype: {dtype}")
    batch_sizes = torch.tensor(batch_sizes, device="cuda", dtype=torch.int32)
    M_max = int(ceil_div(batch_sizes.max(), 128) * 128)
    print("batch_sizes:", batch_sizes)
    print("M_max:", M_max)
    kernel = group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m)
    print(kernel.get_kernel_source())
    # profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
    x_fp8, x_amax = kernel(x, batch_sizes)
    x_fp8_ref, x_amax_ref = ref_program(x, batch_sizes)
    torch_assert_close(x_fp8.to(torch.float32), x_fp8_ref.to(torch.float32), rtol=0.01, atol=0.01)
    torch_assert_close(x_amax, x_amax_ref, rtol=0.01, atol=0.01)
    print("All checks pass.")

    from tilelang.profiler import do_bench

    def run_tilelang():
        x_fp8_tilelang_, x_amax_tilelang_ = kernel(x, batch_sizes)
        return x_fp8_tilelang_, x_amax_tilelang_

    def run_torch():
        x_fp8_torch_, x_amax_torch_ = ref_program(x, batch_sizes)
        return x_fp8_torch_, x_amax_torch_

    latency = do_bench(run_tilelang)
    print("Tile-lang: {:.2f} ms".format(latency))
    latency = do_bench(run_torch)
    print("Torch: {:.2f} ms".format(latency))


if __name__ == "__main__":
    main()
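# Example invocation with custom splits (hypothetical values; batch_sizes should sum to
# M and have BG entries so each group fits the padded M_max):
#   main(M=4096, N=4096, BG=2, blk_m=8, batch_sizes=[1024, 3072])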