Commit cf6e11c9 authored by qisan

feat: merge dcu branch features

parents 3f27f85a d0436b7b
lm_eval==0.3.0
flash_attn
transformers==4.53.0
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for LLaMA."""
import os
from shutil import copyfile
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import sentencepiece as spm
from transformers.convert_slow_tokenizer import import_protobuf
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
from transformers.utils import logging
if TYPE_CHECKING:
from transformers.tokenization_utils_base import TextInput
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model",
},
"tokenizer_file": {
"hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json",
},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"hf-internal-testing/llama-tokenizer": 2048,
}
SPIECE_UNDERLINE = "▁"
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
# fmt: off
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure \
that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
correct. If you don't know the answer to a question, please don't share false information."""
# fmt: on
class BitnetTokenizer(PreTrainedTokenizer):
"""
Construct a Bitnet tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is
no padding token in the original model.
Args:
vocab_file (`str`):
Path to the vocabulary file.
unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<s>"`):
The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"</s>"`):
The end of sequence token.
pad_token (`str` or `tokenizers.AddedToken`, *optional*):
A special token used to make arrays of tokens the same size for batching purposes. Will then be ignored by
attention mechanisms or loss computation.
sp_model_kwargs (`Dict[str, Any]`, *optional*):
Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
to set:
- `enable_sampling`: Enable subword regularization.
- `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
- `nbest_size = {0,1}`: No sampling is performed.
- `nbest_size > 1`: samples from the nbest_size results.
- `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
using forward-filtering-and-backward-sampling algorithm.
- `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
BPE-dropout.
add_bos_token (`bool`, *optional*, defaults to `True`):
Whether or not to add a `bos_token` at the start of sequences.
add_eos_token (`bool`, *optional*, defaults to `False`):
Whether or not to add an `eos_token` at the end of sequences.
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
Whether or not to clean up spaces after decoding; cleanup consists of removing potential artifacts like
extra spaces.
use_default_system_prompt (`bool`, *optional*, defaults to `False`):
Whether or not the default system prompt for Bitnet should be used.
spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not to add spaces between special tokens.
legacy (`bool`, *optional*):
Whether or not the `legacy` behavior of the tokenizer should be used. Legacy is before the merge of #24622
and #25224 which includes fixes to properly handle tokens that appear after special tokens. A simple
example:
- `legacy=True`:
```python
>>> from transformers import T5Tokenizer
>>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=True)
>>> tokenizer.encode("Hello <extra_id_0>.")
[8774, 32099, 3, 5, 1]
```
- `legacy=False`:
```python
>>> from transformers import T5Tokenizer
>>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=False)
>>> tokenizer.encode("Hello <extra_id_0>.") # the extra space `[3]` is no longer here
[8774, 32099, 5, 1]
```
Check out the [pull request](https://github.com/huggingface/transformers/pull/24565) for more details.
add_prefix_space (`bool`, *optional*, defaults to `True`):
Whether or not to add an initial space to the input. This allows treating the leading word just like any
other word.
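    Example (a minimal usage sketch; assumes a local SentencePiece `tokenizer.model` file):
    ```python
    >>> tokenizer = BitnetTokenizer("tokenizer.model", legacy=False)
    >>> ids = tokenizer.encode("Hello world")
    >>> text = tokenizer.decode(ids)
    ```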
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names = ["input_ids", "attention_mask"]
def __init__(
self,
vocab_file,
unk_token="<unk>",
bos_token="<s>",
eos_token="</s>",
pad_token=None,
sp_model_kwargs: Optional[Dict[str, Any]] = None,
add_bos_token=True,
add_eos_token=False,
clean_up_tokenization_spaces=False,
use_default_system_prompt=False,
spaces_between_special_tokens=False,
legacy=None,
add_prefix_space=True,
**kwargs,
):
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token
if legacy is None:
logger.warning_once(
f"You are using the default legacy behavior of the {self.__class__}. This is"
" expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you."
" If you want to use the new behavior, set `legacy=False`. This should only be set if you understand what it"
" means, and thoroughly read the reason why this was added as explained in"
" https://github.com/huggingface/transformers/pull/24565"
)
legacy = True
self.legacy = legacy
self.vocab_file = vocab_file
self.add_bos_token = add_bos_token
self.add_eos_token = add_eos_token
self.use_default_system_prompt = use_default_system_prompt
self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False))
self.add_prefix_space = add_prefix_space
super().__init__(
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
pad_token=pad_token,
add_bos_token=add_bos_token,
add_eos_token=add_eos_token,
sp_model_kwargs=self.sp_model_kwargs,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
use_default_system_prompt=use_default_system_prompt,
spaces_between_special_tokens=spaces_between_special_tokens,
legacy=legacy,
add_prefix_space=add_prefix_space,
**kwargs,
)
@property
def unk_token_length(self):
return len(self.sp_model.encode(str(self.unk_token)))
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor
def get_spm_processor(self, from_slow=False):
tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
if self.legacy or from_slow: # no dependency on protobuf
tokenizer.Load(self.vocab_file)
return tokenizer
with open(self.vocab_file, "rb") as f:
sp_model = f.read()
model_pb2 = import_protobuf(f"The new behavior of {self.__class__.__name__} (with `self.legacy = False`)")
model = model_pb2.ModelProto.FromString(sp_model)
normalizer_spec = model_pb2.NormalizerSpec()
normalizer_spec.add_dummy_prefix = False
model.normalizer_spec.MergeFrom(normalizer_spec)
sp_model = model.SerializeToString()
tokenizer.LoadFromSerializedProto(sp_model)
return tokenizer
def __getstate__(self):
state = self.__dict__.copy()
state["sp_model"] = None
state["sp_model_proto"] = self.sp_model.serialized_model_proto()
return state
def __setstate__(self, d):
self.__dict__ = d
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
@property
def vocab_size(self):
"""Returns vocab size"""
return self.sp_model.get_piece_size()
def get_vocab(self):
"""Returns vocab as a dict"""
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
return vocab
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize
def tokenize(self, text: "TextInput", **kwargs) -> List[str]:
"""
Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the
first token is special.
"""
if self.legacy or len(text) == 0:
return super().tokenize(text, **kwargs)
text = text.replace(SPIECE_UNDERLINE, " ")
if self.add_prefix_space:
text = SPIECE_UNDERLINE + text
tokens = super().tokenize(text, **kwargs)
if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
tokens = tokens[1:]
return tokens
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize
def _tokenize(self, text, **kwargs):
"""
Returns a tokenized string.
We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give
`['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the
`unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
`self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
"""
tokens = self.sp_model.encode(text, out_type=str)
if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")):
return tokens
# 1. Encode string + prefix ex: "<unk> Hey"
tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
# 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""
return self.sp_model.piece_to_id(token)
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
token = self.sp_model.IdToPiece(index)
return token
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (string) in a single string."""
# since we manually add the prefix space, we have to remove it when decoding
if tokens[0].startswith(SPIECE_UNDERLINE) and self.add_prefix_space:
tokens[0] = tokens[0][1:]
current_sub_tokens = []
out_string = ""
prev_is_special = False
for i, token in enumerate(tokens):
# make sure that special tokens are not decoded using sentencepiece model
if token in self.all_special_tokens:
if not prev_is_special and i != 0 and self.legacy:
out_string += " "
out_string += self.sp_model.decode(current_sub_tokens) + token
prev_is_special = True
current_sub_tokens = []
else:
current_sub_tokens.append(token)
prev_is_special = False
out_string += self.sp_model.decode(current_sub_tokens)
return out_string
def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
"""
Save the vocabulary and special tokens file to a directory.
Args:
save_directory (`str`):
The directory in which to save the vocabulary.
Returns:
`Tuple(str)`: Paths to the files saved.
"""
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return
out_vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"])
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
copyfile(self.vocab_file, out_vocab_file)
elif not os.path.isfile(self.vocab_file):
with open(out_vocab_file, "wb") as fi:
content_spiece_model = self.sp_model.serialized_model_proto()
fi.write(content_spiece_model)
return (out_vocab_file,)
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
output = bos_token_id + token_ids_0 + eos_token_id
if token_ids_1 is not None:
output = output + bos_token_id + token_ids_1 + eos_token_id
return output
def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
"""
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer `prepare_for_model` method.
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not the token list is already formatted with special tokens for the model.
Returns:
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""
if already_has_special_tokens:
return super().get_special_tokens_mask(token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True)
bos_token_id = [1] if self.add_bos_token else []
eos_token_id = [1] if self.add_eos_token else []
if token_ids_1 is None:
return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + bos_token_id + ([0] * len(token_ids_1)) + eos_token_id
def create_token_type_ids_from_sequences(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]:
"""
Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
sequence pair mask has the following format:
```
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
| first sequence | second sequence |
```
if token_ids_1 is None, only returns the first portion of the mask (0s).
Args:
token_ids_0 (`List[int]`):
List of ids.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
"""
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
if token_ids_1 is not None:
output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
return output
@property
def default_chat_template(self):
"""
LLaMA uses [INST] and [/INST] to indicate user messages, and <<SYS>> and <</SYS>> to indicate system messages.
Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict
user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering
rather than needing special tokens. The system message is partly 'embedded' in the first user message, which
results in an unusual token ordering when it is present. This template should definitely be changed if you wish
to fine-tune a model with more flexible role ordering!
The output should look something like:
<bos>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer <eos><bos>[INST] Prompt [/INST] Answer <eos>
<bos>[INST] Prompt [/INST]
The reference for this chat template is [this code
snippet](https://github.com/facebookresearch/llama/blob/556949fdfb72da27c2f4a40b7f0e4cf0b8153a28/llama/generation.py#L320-L362)
in the original repository.
"""
logger.warning_once(
"\nNo chat template is defined for this tokenizer - using the default template "
f"for the {self.__class__.__name__} class. If the default is not appropriate for "
"your model, please set `tokenizer.chat_template` to an appropriate template. "
"See https://huggingface.co/docs/transformers/main/chat_templating for more information.\n"
)
template = (
"{% if messages[0]['role'] == 'system' %}"
"{% set loop_messages = messages[1:] %}" # Extract system message if it's present
"{% set system_message = messages[0]['content'] %}"
"{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}"
"{% set loop_messages = messages %}" # Or use the default system message if the flag is set
"{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}"
"{% else %}"
"{% set loop_messages = messages %}"
"{% set system_message = false %}"
"{% endif %}"
"{% for message in loop_messages %}" # Loop over all non-system messages
"{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
"{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
"{% endif %}"
"{% if loop.index0 == 0 and system_message != false %}" # Embed system message in first message
"{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}"
"{% else %}"
"{% set content = message['content'] %}"
"{% endif %}"
"{% if message['role'] == 'user' %}" # After all of that, handle messages/roles in a fairly normal way
"{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}"
"{% elif message['role'] == 'system' %}"
"{{ '<<SYS>>\\n' + content.strip() + '\\n<</SYS>>\\n\\n' }}"
"{% elif message['role'] == 'assistant' %}"
"{{ ' ' + content.strip() + ' ' + eos_token }}"
"{% endif %}"
"{% endfor %}"
)
template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false")
default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'")
template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)
return template
# pylint: disable=missing-docstring, invalid-name
"""This is modified from https://huggingface.co/1bitLLM/bitnet_b1_58-3B/blob/main/utils_quant.py to work with BitBLAS."""
import torch
from torch import nn
from bitblas.cache import global_operator_cache, get_database_path
from bitblas import Matmul, MatmulConfig
from bitblas import auto_detect_nvidia_target
from logging import getLogger
logger = getLogger(__name__)
BITBLAS_TARGET = auto_detect_nvidia_target()
BITBLAS_DATABASE_PATH = get_database_path()
def weight_quant(weight, num_bits=1):
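    """
    BitNet b1.58-style ternary fake quantization: scale by s = 1 / mean(|W|), round and
    clamp to {-1, 0, +1}, then divide by s so the result stays in the original value range.
    `num_bits` is kept for interface symmetry and is not used here.
    """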
dtype = weight.dtype
weight = weight.float()
s = 1 / weight.abs().mean().clamp(min=1e-5)
result = (weight * s).round().clamp(-1, 1) / s
return result.type(dtype)
def activation_quant(x, num_bits=8):
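    """
    Per-token absmax fake quantization of activations to `num_bits` (int8 by default):
    each row is scaled so its largest magnitude maps to Qp = 2**(num_bits - 1) - 1,
    rounded and clamped to [Qn, Qp], then rescaled back to the original range.
    """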
dtype = x.dtype
x = x.float()
Qn = -(2 ** (num_bits - 1))
Qp = 2 ** (num_bits - 1) - 1
s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
result = (x * s).round().clamp(Qn, Qp) / s
return result.type(dtype)
class BitLinearBitBLAS(nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
weight_bits=1,
input_bits=8,
**kwargs,
):
super().__init__()
"""
RMSNorm is placed outside BitLinear
"""
self.in_features = in_features
self.out_features = out_features
self.weight_bits = weight_bits
self.input_bits = input_bits
matmul_config = MatmulConfig(
N=self.out_features, # N dimension
K=self.in_features, # K dimension
A_dtype="int8", # activation A dtype
W_dtype="int2", # weight W dtype
accum_dtype="int32", # accumulation dtype
out_dtype="float32", # output dtype
layout="nt", # matrix layout, "nt" indicates the layout of A is non-transpose and the layout of W is transpose
with_bias=False, # bias
# configs for weight only quantization
group_size=None, # setting for grouped quantization
with_scaling=False, # setting for scaling factor
with_zeros=False, # setting for zeros
zeros_mode=None, # setting for how to calculating zeros
)
ENABLE_TUNING = True
self.bitblas_matmul = self._get_or_create_bitblas_operator(matmul_config, ENABLE_TUNING)
self.format = "bitnet"
self.Qp = 2 ** (self.input_bits - 1) - 1
def _get_or_create_bitblas_operator(self, config, enable_tuning):
if global_operator_cache.size() == 0:
global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH, BITBLAS_TARGET)
logger.info(f"Loaded {global_operator_cache.size()} operators from database.")
bitblas_matmul = global_operator_cache.get(config)
if bitblas_matmul is None:
# should disable tuning for the first time because we may require loading bitblas operator from database.
bitblas_matmul = Matmul(config, target=BITBLAS_TARGET, enable_tuning=False)
if enable_tuning:
bitblas_matmul.hardware_aware_finetune(topk=20)
global_operator_cache.add(config, bitblas_matmul)
global_operator_cache.save_into_database(BITBLAS_DATABASE_PATH, BITBLAS_TARGET)
print("BitBLAS Tuning done, appended operator to global_operator_cache.")
else:
print("BitBLAS Operator created.")
else:
print("BitBLAS Operator found in global_operator_cache.")
return bitblas_matmul
def replace_weight_param_with_qweight(self):
if hasattr(self, "weight"):
del self.weight
quant_weight = torch.empty(self.bitblas_matmul.retrieve_weight_shape())
self.qweight = nn.Parameter(quant_weight, requires_grad=False)
self.format = "bitblas"
@classmethod
def from_bit_linear(cls, bitlinear, weight_group=1):
bitblas_linear = cls(bitlinear.in_features, bitlinear.out_features, weight_bits=1, input_bits=8)
sw, qweight = bitblas_linear.create_bitblas_weights(bitlinear.weight, weight_group)
bitblas_linear.register_buffer("qweight", qweight)
bitblas_linear.register_buffer("sw", sw)
if bitlinear.bias is not None:
bitblas_linear.register_buffer("bias", bitlinear.bias)
else:
bitblas_linear.bias = None
return bitblas_linear
def create_bitblas_weights(self, weight, weight_group=1):
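        # When `weight_group` is non-zero, dim 0 of the weight is split into that many
        # groups and each group gets its own scale sw = 1 / mean(|W_group|); otherwise a
        # single per-tensor scale is used before packing the ternary weight for BitBLAS.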
if weight_group:
hidden_size = weight.size(0)
group_size = hidden_size // weight_group
sw_list = []
qweight_list = []
for i in range(weight_group):
start_idx = i * group_size
end_idx = (i + 1) * group_size
sw = 1 / weight[start_idx:end_idx].abs().mean().clamp(min=1e-5)
sw_list.append(sw.repeat(group_size))
qweight = self.weight_quant(weight[start_idx:end_idx]).detach()
qweight_list.append(qweight)
sw = torch.cat(sw_list, dim=0)
qweight = torch.cat(qweight_list, dim=0)
else:
sw = 1 / weight.abs().mean().clamp(min=1e-5)
qweight = self.weight_quant(weight).detach()
qweight = self.bitblas_matmul.transform_weight(qweight)
qweight = nn.Parameter(qweight, requires_grad=False)
return sw, qweight
def post_process_weights(self):
sw = 1 / self.weight.abs().mean().clamp(min=1e-5)
self.sw = sw
quant_weight = self.weight_quant(self.weight).detach()
quant_weight = self.bitblas_matmul.transform_weight(quant_weight)
# remove self.weight and replace it with quant_weight
if hasattr(self, "weight"):
del self.weight
self.qweight = nn.Parameter(quant_weight, requires_grad=False)
self.format = "bitblas"
@staticmethod
def weight_quant(weight):
weight = weight.float()
s = 1 / weight.abs().mean().clamp(min=1e-5)
result = (weight * s).round().clamp(-1, 1)
return result.type(torch.int8)
@torch.compile
def activation_quant(self, x, num_bits=8):
x = x.float()
Qn = -(2 ** (num_bits - 1))
Qp = 2 ** (num_bits - 1) - 1
s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
result = (x * s).round().clamp(Qn, Qp)
return result.type(torch.int8), s
@torch.compile
def post_quant_process(self, input, si, sw):
out = input / si
out = out / sw
out = out.half()
return out
# for the correctness evaluation.
def native_forward(self, input):
quant_input = input + (activation_quant(input, self.input_bits) - input).detach()
quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) - self.weight).detach()
out = nn.functional.linear(quant_input, quant_weight)
if self.bias is not None:
out += self.bias.view(1, -1).expand_as(out)
return out
def forward_fp32_simulated(self, input):
        quant_input, si = self.activation_quant(input, self.input_bits)
quant_weight = self.weight_quant(self.weight).detach()
fp32_simulated_input = quant_input.float()
fp32_simulated_weight = quant_weight.float()
fp32_simulated_out = nn.functional.linear(fp32_simulated_input, fp32_simulated_weight)
sw = 1 / self.weight.abs().mean().clamp(min=1e-5)
        # dividing by (si * sw) in a single step can overflow to inf in some cases, so divide sequentially
out = fp32_simulated_out / si
out = out / sw
out = out.half()
if self.bias is not None:
out += self.bias.view(1, -1).expand_as(out)
return out
def forward(self, input):
# return self.forward_fp32_simulated(input)
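        # Quantize activations per token to int8 (scale si), run the int8 x int2 BitBLAS
        # matmul against the packed ternary weights, then dequantize the accumulator by
        # si and the weight scale sw.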
quant_input, si = self.activation_quant(input, self.input_bits)
fp32_out = self.bitblas_matmul(quant_input, self.qweight)
sw = self.sw
        # dividing by (si * sw) in a single step can overflow to inf in some cases, so divide sequentially
out = self.post_quant_process(fp32_out, si, sw)
if self.bias is not None:
out += self.bias.view(1, -1).expand_as(out)
return out
# Naive BitLinear from HuggingFace
class BitLinear(nn.Linear):
def __init__(self, *kargs, weight_bits=1, input_bits=8, **kwargs):
super(BitLinear, self).__init__(*kargs, **kwargs)
"""
RMSNorm is placed outside BitLinear
"""
self.weight_bits = weight_bits
self.input_bits = input_bits
def forward(self, input):
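        # Straight-through estimator: `x + (quant(x) - x).detach()` equals quant(x) in the
        # forward pass while letting gradients flow through the unquantized x and weight.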
quant_input = input + (activation_quant(input, self.input_bits) - input).detach()
quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) - self.weight).detach()
out = nn.functional.linear(quant_input, quant_weight)
if self.bias is not None:
out += self.bias.view(1, -1).expand_as(out)
return out
import contextlib
import gc
import os
import sys
from collections import UserList
from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import (
AutoModelForCausalLM,
AutoModelForVision2Seq,
AutoTokenizer,
BatchEncoding,
)
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.config import TokenizerPoolConfig
from vllm.distributed import destroy_distributed_environment, destroy_model_parallel
from vllm.inputs import TextPrompt
from vllm.logger import init_logger
from vllm.sequence import SampleLogprobs
from vllm.utils import cuda_device_count_stateless, is_cpu
logger = init_logger(__name__)
_TEST_DIR = os.path.dirname(__file__)
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
def _read_prompts(filename: str) -> List[str]:
with open(filename, "r") as f:
prompts = f.readlines()
return prompts
class _ImageAssetPrompts(TypedDict):
stop_sign: str
cherry_blossom: str
if sys.version_info < (3, 9):
# UserList cannot be subscripted
class _ImageAssetsBase(UserList):
pass
else:
class _ImageAssetsBase(UserList[ImageAsset]):
pass
class _ImageAssets(_ImageAssetsBase):
def __init__(self) -> None:
super().__init__(
[
ImageAsset("stop_sign"),
ImageAsset("cherry_blossom"),
]
)
def prompts(self, prompts: _ImageAssetPrompts) -> List[str]:
"""
Convenience method to define the prompt for each test image.
The order of the returned prompts matches the order of the
assets when iterating through this object.
"""
return [prompts["stop_sign"], prompts["cherry_blossom"]]
IMAGE_ASSETS = _ImageAssets()
"""Singleton instance of :class:`_ImageAssets`."""
def cleanup():
destroy_model_parallel()
destroy_distributed_environment()
with contextlib.suppress(AssertionError):
torch.distributed.destroy_process_group()
gc.collect()
if not is_cpu():
torch.cuda.empty_cache()
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
"""Allow subdirectories to skip global cleanup by overriding this fixture.
This can provide a ~10x speedup for non-GPU unit tests since they don't need
to initialize torch.
"""
    if request.node.get_closest_marker("skip_global_cleanup"):
        return False
    return True
@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
yield
if should_do_global_cleanup_after_test:
cleanup()
@pytest.fixture
def example_prompts() -> List[str]:
prompts = []
for filename in _TEST_PROMPTS:
prompts += _read_prompts(filename)
return prompts
@pytest.fixture
def example_long_prompts() -> List[str]:
prompts = []
for filename in _LONG_PROMPTS:
prompts += _read_prompts(filename)
return prompts
@pytest.fixture(scope="session")
def image_assets() -> _ImageAssets:
return IMAGE_ASSETS
_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.half,
"bfloat16": torch.bfloat16,
"float": torch.float,
}
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding)
class HfRunner:
def wrap_device(self, input: _T) -> _T:
if not is_cpu():
return input.to("cuda")
else:
return input.to("cpu")
def __init__(
self,
model_name: str,
dtype: str = "half",
*,
model_kwargs: Optional[Dict[str, Any]] = None,
is_embedding_model: bool = False,
is_vision_model: bool = False,
is_sparseml_model: bool = False,
) -> None:
assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
self.model_name = model_name
if is_embedding_model:
# Lazy init required for AMD CI
from sentence_transformers import SentenceTransformer
self.model = self.wrap_device(
SentenceTransformer(
model_name,
device="cpu",
).to(dtype=torch_dtype)
)
else:
if is_vision_model:
auto_cls = AutoModelForVision2Seq
elif is_sparseml_model:
from sparseml.transformers import SparseAutoModelForCausalLM
auto_cls = SparseAutoModelForCausalLM
else:
auto_cls = AutoModelForCausalLM
model_kwargs = model_kwargs if model_kwargs is not None else {}
self.model = self.wrap_device(
auto_cls.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=True,
**model_kwargs,
)
)
self.tokenizer = AutoTokenizer.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=True,
)
try:
# don't put this import at the top level
# it will call torch.cuda.device_count()
from transformers import AutoProcessor # noqa: F401
self.processor = AutoProcessor.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=True,
)
except Exception:
logger.warning(
"Unable to auto-load processor from HuggingFace for model %s. Using tokenizer instead.",
model_name,
)
self.processor = self.tokenizer
def generate(
self,
prompts: List[str],
images: Optional[List[Image.Image]] = None,
**kwargs: Any,
) -> List[Tuple[List[List[int]], List[str]]]:
if images:
assert len(prompts) == len(images)
outputs: List[Tuple[List[List[int]], List[str]]] = []
for i, prompt in enumerate(prompts):
processor_kwargs: Dict[str, Any] = {
"text": prompt,
"return_tensors": "pt",
}
if images is not None and images[i] is not None:
processor_kwargs["images"] = images[i]
inputs = self.processor(**processor_kwargs)
output_ids = self.model.generate(
**self.wrap_device(inputs),
use_cache=True,
**kwargs,
)
output_str = self.processor.batch_decode(
output_ids,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)
output_ids = output_ids.cpu().tolist()
outputs.append((output_ids, output_str))
return outputs
def generate_greedy(
self,
prompts: List[str],
max_tokens: int,
images: Optional[List[Image.Image]] = None,
**kwargs: Any,
) -> List[Tuple[List[int], str]]:
outputs = self.generate(
prompts,
do_sample=False,
max_new_tokens=max_tokens,
images=images,
**kwargs,
)
return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs]
def generate_beam_search(
self,
prompts: List[str],
beam_width: int,
max_tokens: int,
) -> List[Tuple[List[List[int]], List[str]]]:
outputs = self.generate(
prompts,
do_sample=False,
max_new_tokens=max_tokens,
num_beams=beam_width,
num_return_sequences=beam_width,
)
for i in range(len(outputs)):
output_ids, output_str = outputs[i]
for j in range(len(output_ids)):
output_ids[j] = [x for x in output_ids[j] if x != self.tokenizer.pad_token_id]
outputs[i] = (output_ids, output_str)
return outputs
def generate_greedy_logprobs(
self,
prompts: List[str],
max_tokens: int,
images: Optional[List[Image.Image]] = None,
**kwargs: Any,
) -> List[List[torch.Tensor]]:
all_logprobs: List[List[torch.Tensor]] = []
for i, prompt in enumerate(prompts):
processor_kwargs: Dict[str, Any] = {
"text": prompt,
"return_tensors": "pt",
}
if images is not None and images[i] is not None:
processor_kwargs["images"] = images[i]
inputs = self.processor(**processor_kwargs)
output = self.model.generate(
**self.wrap_device(inputs),
use_cache=True,
do_sample=False,
max_new_tokens=max_tokens,
output_hidden_states=True,
return_dict_in_generate=True,
**kwargs,
)
seq_logprobs: List[torch.Tensor] = []
for hidden_states in output.hidden_states:
last_hidden_states = hidden_states[-1][0]
logits = torch.matmul(
last_hidden_states,
self.model.get_output_embeddings().weight.t(),
)
if self.model.get_output_embeddings().bias is not None:
logits += self.model.get_output_embeddings().bias.unsqueeze(0)
logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
seq_logprobs.append(logprobs)
all_logprobs.append(seq_logprobs)
return all_logprobs
def generate_greedy_logprobs_limit(
self,
prompts: List[str],
max_tokens: int,
num_logprobs: int,
images: Optional[List[Image.Image]] = None,
**kwargs: Any,
) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
all_logprobs: List[List[Dict[int, float]]] = []
all_output_ids: List[List[int]] = []
all_output_strs: List[str] = []
for i, prompt in enumerate(prompts):
processor_kwargs: Dict[str, Any] = {
"text": prompt,
"return_tensors": "pt",
}
if images is not None and images[i] is not None:
processor_kwargs["images"] = images[i]
inputs = self.processor(**processor_kwargs)
input_ids = inputs.input_ids
output = self.model.generate(
**self.wrap_device(inputs),
use_cache=True,
do_sample=False,
max_new_tokens=max_tokens,
output_hidden_states=True,
return_dict_in_generate=True,
**kwargs,
)
seq_logprobs: List[torch.Tensor] = []
for _, hidden_states in enumerate(output.hidden_states):
last_hidden_states = hidden_states[-1][0]
logits = torch.matmul(
last_hidden_states,
self.model.get_output_embeddings().weight.t(),
)
if getattr(self.model.get_output_embeddings(), "bias", None) is not None:
logits += self.model.get_output_embeddings().bias.unsqueeze(0)
logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
seq_logprobs.append(logprobs)
# convert to dict
seq_logprobs_lst: List[Dict[int, float]] = []
for tok_idx, tok_logprobs in enumerate(seq_logprobs):
# drop prompt logprobs
if tok_idx == 0:
tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
topk = tok_logprobs.topk(num_logprobs)
tok_logprobs_dct = {}
for token_id, logprob in zip(topk.indices[0], topk.values[0]):
tok_logprobs_dct[token_id.item()] = logprob.item()
seq_logprobs_lst.append(tok_logprobs_dct)
all_logprobs.append(seq_logprobs_lst)
seq_ids = output.sequences[0]
output_len = seq_ids.shape[0] - input_ids.shape[1]
output_ids = seq_ids[-output_len:]
all_output_ids.append(output_ids.tolist())
all_output_strs.append(self.tokenizer.decode(output_ids))
outputs = zip(all_output_ids, all_output_strs, all_logprobs)
return [(output_ids, output_str, output_logprobs) for output_ids, output_str, output_logprobs in outputs]
def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
return self.model.encode(prompts)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
del self.model
cleanup()
@pytest.fixture(scope="session")
def hf_runner():
return HfRunner
class VllmRunner:
def __init__(
self,
model_name: str,
tokenizer_name: Optional[str] = None,
        # Use a smaller max model length; otherwise larger models cannot run due
        # to the KV cache size limit.
max_model_len: int = 1024,
dtype: str = "half",
disable_log_stats: bool = True,
tensor_parallel_size: int = 1,
block_size: int = 16,
enable_chunked_prefill: bool = False,
swap_space: int = 4,
enforce_eager: bool = False,
**kwargs,
) -> None:
self.model = LLM(
model=model_name,
tokenizer=tokenizer_name,
trust_remote_code=True,
dtype=dtype,
swap_space=swap_space,
enforce_eager=enforce_eager,
disable_log_stats=disable_log_stats,
tensor_parallel_size=tensor_parallel_size,
max_model_len=max_model_len,
block_size=block_size,
enable_chunked_prefill=enable_chunked_prefill,
**kwargs,
)
def generate(
self,
prompts: List[str],
sampling_params: SamplingParams,
images: Optional[List[Image.Image]] = None,
) -> List[Tuple[List[List[int]], List[str]]]:
if images is not None:
assert len(prompts) == len(images)
inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
if images is not None:
for i, image in enumerate(images):
inputs[i]["multi_modal_data"] = {"image": image}
req_outputs = self.model.generate(inputs, sampling_params=sampling_params)
outputs: List[Tuple[List[List[int]], List[str]]] = []
for req_output in req_outputs:
prompt_str = req_output.prompt
prompt_ids = req_output.prompt_token_ids
req_sample_output_ids: List[List[int]] = []
req_sample_output_strs: List[str] = []
for sample in req_output.outputs:
output_str = sample.text
output_ids = list(sample.token_ids)
req_sample_output_ids.append(prompt_ids + output_ids)
req_sample_output_strs.append(prompt_str + output_str)
outputs.append((req_sample_output_ids, req_sample_output_strs))
return outputs
def generate_w_logprobs(
self,
prompts: List[str],
sampling_params: SamplingParams,
images: Optional[List[Image.Image]] = None,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
assert sampling_params.logprobs is not None
if images is not None:
assert len(prompts) == len(images)
inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
if images is not None:
for i, image in enumerate(images):
inputs[i]["multi_modal_data"] = {"image": image}
req_outputs = self.model.generate(inputs, sampling_params=sampling_params)
outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
for req_output in req_outputs:
for sample in req_output.outputs:
output_str = sample.text
output_ids = sample.token_ids
output_logprobs = sample.logprobs
outputs.append((output_ids, output_str, output_logprobs))
return outputs
def generate_greedy(
self,
prompts: List[str],
max_tokens: int,
images: Optional[List[Image.Image]] = None,
) -> List[Tuple[List[int], str]]:
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
outputs = self.generate(prompts, greedy_params, images=images)
return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs]
def generate_greedy_logprobs(
self,
prompts: List[str],
max_tokens: int,
num_logprobs: int,
images: Optional[List[Image.Image]] = None,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
greedy_logprobs_params = SamplingParams(temperature=0.0, max_tokens=max_tokens, logprobs=num_logprobs)
outputs = self.generate_w_logprobs(prompts, greedy_logprobs_params, images=images)
return [(output_ids, output_str, output_logprobs) for output_ids, output_str, output_logprobs in outputs]
def generate_beam_search(
self,
prompts: List[str],
beam_width: int,
max_tokens: int,
) -> List[Tuple[List[List[int]], List[str]]]:
beam_search_params = SamplingParams(
n=beam_width,
use_beam_search=True,
temperature=0.0,
max_tokens=max_tokens,
)
outputs = self.generate(prompts, beam_search_params)
return outputs
def encode(self, prompts: List[str]) -> List[List[float]]:
req_outputs = self.model.encode(prompts)
outputs = []
for req_output in req_outputs:
embedding = req_output.outputs.embedding
outputs.append(embedding)
return outputs
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
del self.model
cleanup()
@pytest.fixture(scope="session")
def vllm_runner():
return VllmRunner
def get_tokenizer_pool_config(tokenizer_group_type):
if tokenizer_group_type is None:
return None
if tokenizer_group_type == "ray":
return TokenizerPoolConfig(pool_size=1, pool_type="ray", extra_config={})
raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}")
@pytest.fixture()
def temporary_enable_log_propagate():
import logging
logger = logging.getLogger("vllm")
logger.propagate = True
yield
logger.propagate = False
@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
# To capture vllm log, we should enable propagate=True temporarily
# because caplog depends on logs propagated to the root logger.
yield caplog
@pytest.fixture(scope="session")
def num_gpus_available():
"""Get number of GPUs without initializing the CUDA context
in current process."""
return cuda_device_count_stateless()
"""Compare the outputs of a GPTQ model to a Marlin model.
Note: GPTQ and Marlin do not have bitwise correctness.
As a result, in this test, we just confirm that the top selected tokens of the
Marlin/GPTQ models are in the top 3 selections of each other.
Note: Marlin internally uses locks to synchronize the threads. This can
result in very slight nondeterminism for Marlin. As a result, we re-run the test
up to 3 times to see if we pass.
Run `pytest tests/models/test_marlin.py`.
"""
from conftest import VllmRunner
import os
import argparse
# get the path of the current file
current_file_path = os.path.realpath(__file__)
current_dir = os.path.dirname(current_file_path)
ckpt_path = os.path.join(current_dir, "../models/ckpt_bitnet_b1_58-3B_bitblas")
parser = argparse.ArgumentParser(description="Inference with BitNet")
parser.add_argument(
"--ckpt_path",
type=str,
default=ckpt_path,
help="Path to the checkpoint",
)
args = parser.parse_args()
ckpt_path = args.ckpt_path
with VllmRunner(
ckpt_path,
dtype="half",
quantization="bitblas",
# set enforce_eager = False to enable cuda graph
# set enforce_eager = True to disable cuda graph
enforce_eager=False,
) as bitnet_model:
    bitnet_outputs = bitnet_model.generate_greedy(["Hi, tell me about microsoft?"], max_tokens=1024)
    print("bitnet inference:")
    print(bitnet_outputs[0][0])
    print(bitnet_outputs[0][1])
"""Compare the outputs of a GPTQ model to a Marlin model.
Note: GPTQ and Marlin do not have bitwise correctness.
As a result, in this test, we just confirm that the top selected tokens of the
Marlin/GPTQ models are in the top 3 selections of each other.
Note: Marlin internally uses locks to synchronize the threads. This can
result in very slight nondeterminism for Marlin. As a result, we re-run the test
up to 3 times to see if we pass.
Run `pytest tests/models/test_marlin.py`.
"""
from conftest import VllmRunner
import os
import argparse
# get the path of the current file
current_file_path = os.path.realpath(__file__)
current_dir = os.path.dirname(current_file_path)
ckpt_path = os.path.join(current_dir, "../models/ckpt_bitnet_b1_58-3B")
parser = argparse.ArgumentParser(description="Inference with BitNet")
parser.add_argument(
"--ckpt_path",
type=str,
default=ckpt_path,
help="Path to the checkpoint",
)
args = parser.parse_args()
ckpt_path = args.ckpt_path
with VllmRunner(
ckpt_path,
dtype="half",
quantization="bitnet_bitblas",
gpu_memory_utilization=0.5,
# set enforce_eager = False to enable cuda graph
# set enforce_eager = True to disable cuda graph
enforce_eager=False,
) as bitnet_model:
    bitnet_outputs = bitnet_model.generate_greedy(["Hi, tell me about microsoft?"], max_tokens=128)
    print("bitnet inference output:")
    print(bitnet_outputs[0][0])
    print(bitnet_outputs[0][1])
from typing import Dict, List, Tuple
TokensText = Tuple[List[int], str]
def check_outputs_equal(outputs_0_lst: List[TokensText], outputs_1_lst: List[TokensText], name_0: str, name_1: str):
"""
Compare the two sequences generated by different models,
which should be equal.
"""
assert len(outputs_0_lst) == len(outputs_1_lst)
for prompt_idx, (outputs_0, outputs_1) in enumerate(zip(outputs_0_lst, outputs_1_lst)):
output_ids_0, output_str_0 = outputs_0
output_ids_1, output_str_1 = outputs_1
assert output_str_0 == output_str_1, f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}"
assert output_ids_0 == output_ids_1, f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}"
TokensTextLogprobs = Tuple[List[int], str, List[Dict[int, float]]]
def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs], outputs_1_lst: List[TokensTextLogprobs], name_0: str, name_1: str):
"""
Compare the logprobs of two sequences generated by different models,
which should be similar but not necessarily equal.
"""
assert len(outputs_0_lst) == len(outputs_1_lst)
# Loop through responses to each prompt.
for prompt_idx, (outputs_0, outputs_1) in enumerate(zip(outputs_0_lst, outputs_1_lst)):
output_ids_0, output_str_0, logprobs_0 = outputs_0
output_ids_1, output_str_1, logprobs_1 = outputs_1
# Loop through generated tokens.
for idx, (output_id_0, output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):
# If generated tokens don't match, then
if output_id_0 != output_id_1:
# Each predicted token must be in top N logprobs of the other
assert output_id_0 in logprobs_1[idx], f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}"
assert output_id_1 in logprobs_0[idx], f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}"
# Break out since sequences will now diverge.
break
# Block-Sparse Flash-Attention
Tilelang implementation of block-sparse flash-attention kernels.
The kernels have been used in [Rectified Sparse Attention](https://arxiv.org/abs/2506.04108) and [SeerAttention-R](https://arxiv.org/abs/2506.08889).
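A minimal usage sketch (illustrative; `get_sparse_attn_mask_from_topk` and `block_sparse_triton_fn` come from the Triton implementation in this repository, and the shapes/top-k below are arbitrary):

```python
import torch

# q/k/v: [batch, heads, seq_len, head_dim], on GPU
q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.bfloat16)

# Block-level importance scores at 64x64 granularity: 1024 / 64 = 16 blocks per side.
scores = torch.randn(1, 8, 16, 16, device="cuda")
block_mask = get_sparse_attn_mask_from_topk(scores, topk=4)

# Positional arguments only: the wrapper is an autograd.Function.apply.
out = block_sparse_triton_fn(q, k, v, block_mask, 1.0 / 64 ** 0.5)
```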
# ruff: noqa: E712
import math
import torch
import triton
import triton.language as tl
import torch.nn.functional as F
def is_hip():
return triton.runtime.driver.active.get_current_target().backend == "hip"
def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
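    """
    Build a boolean block-level mask of shape [bsz, num_head, downsample_len, downsample_len]
    that keeps the `topk` highest-scoring key blocks per query-block row, optionally forces the
    last two query-block rows to be fully dense, and applies a causal lower-triangular constraint.
    """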
bsz, num_head, downsample_len, _ = x.shape
# N_CTX = downsample_len * BLOCK
sparse_index = torch.topk(x, topk, dim=-1).indices
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
dense_mask.scatter_(-1, sparse_index, True)
if use_dense_for_last_block:
dense_mask[:, :, -2:, :] = True
dense_mask.tril_()
return dense_mask
def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False):
dense_mask = x > threshold
if use_dense_for_last_block:
dense_mask[:, :, -2:, :] = True
dense_mask.tril_()
return dense_mask
@triton.jit
def _fwd_kernel_inner(
acc,
l_i,
m_i,
q,
k_block_col_idx,
block_mask_ptr,
k_ptrs,
v_ptrs,
offs_m,
offs_n,
stride_kt,
stride_vt,
stride_bmask_n,
sm_scale,
seqlen_k,
past_len,
LAST_K_BLOCK: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n)
    # skip this key/value block entirely unless its mask bit is set
    if mask_val == True:
start_n = k_block_col_idx * BLOCK_N
# -- compute qk ----
k = tl.load(k_ptrs + start_n * stride_kt)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
# the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
if LAST_K_BLOCK:
qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf"))
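        # Online-softmax update: track the running row max m_i and rescale the running
        # denominator l_i and the accumulator acc by alpha = exp(m_old - m_new) so that
        # previously accumulated blocks remain consistent with the new maximum.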
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk -= m_ij[:, None]
p = tl.exp(qk)
l_ij = tl.sum(p, 1)
alpha = tl.exp(m_i - m_ij)
l_i = l_i * alpha + l_ij
acc = acc * alpha[:, None]
# update acc
v = tl.load(v_ptrs + start_n * stride_vt)
p = p.to(v.type.element_ty)
acc += tl.dot(p, v)
# update m_i and l_i
m_i = m_ij
return acc, l_i, m_i
@triton.jit
def _fwd_kernel(
Q,
K,
V,
sm_scale,
block_mask_ptr,
Out,
stride_qz,
stride_qh,
stride_qm,
stride_qd,
stride_kz,
stride_kh,
stride_kn,
stride_kd,
stride_vz,
stride_vh,
stride_vn,
stride_vd,
stride_bmz,
stride_bmh,
stride_bmm,
stride_bmn,
stride_oz,
stride_oh,
stride_om,
stride_od,
H,
N_CTX,
PAST_LEN,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
):
Q_LEN = N_CTX - PAST_LEN
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
off_h = off_hz % H
off_z = off_hz // H
Q += off_z * stride_qz + off_h * stride_qh
K += off_z * stride_kz + off_h * stride_kh
V += off_z * stride_vz + off_h * stride_vh
block_mask_ptr += off_z * stride_bmz + off_h * stride_bmh
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
off_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd
# off_k = offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd
off_k = offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kd
off_v = offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd
# Initialize pointers to Q, K, V
q_ptrs = Q + off_q
k_ptrs = K + off_k
v_ptrs = V + off_v
mask_ptrs = block_mask_ptr + start_m * stride_bmm
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
q = tl.load(q_ptrs, mask=offs_m[:, None] < Q_LEN)
k_block_start = 0
k_block_end = tl.cdiv((start_m + 1) * BLOCK_M, BLOCK_N)
# loop over k, v and update accumulator
for col_idx in range(k_block_start, k_block_end):
acc, l_i, m_i = _fwd_kernel_inner(
acc,
l_i,
m_i,
q,
col_idx,
mask_ptrs,
k_ptrs,
v_ptrs,
offs_m,
offs_n,
stride_kn,
stride_vn,
stride_bmn,
sm_scale,
N_CTX,
PAST_LEN,
col_idx == k_block_end - 1,
BLOCK_M,
BLOCK_N,
)
m_i += tl.math.log(l_i)
l_recip = 1 / l_i[:, None]
acc = acc * l_recip
acc = acc.to(Out.dtype.element_ty)
off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od
out_ptrs = Out + off_o
    tl.store(out_ptrs, acc, mask=offs_m[:, None] < Q_LEN)
def _forward(ctx, q, k, v, block_sparse_mask, sm_scale, BLOCK_M=64, BLOCK_N=64, num_warps=None, num_stages=1, out=None):
assert q.shape[-1] == k.shape[-1] == v.shape[-1]
assert k.shape[2] == v.shape[2]
o = out if out is not None else torch.empty_like(q).contiguous()
grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1])
assert q.shape[-1] in [64, 128]
BLOCK_DMODEL = q.shape[-1]
if is_hip():
num_warps, num_stages = 8, 1
else:
num_warps, num_stages = 4, 2
N_CTX = k.shape[2]
PAST_LEN = N_CTX - q.shape[2]
H = q.shape[1]
_fwd_kernel[grid](
q,
k,
v,
sm_scale,
block_sparse_mask,
o,
*q.stride(),
*k.stride(),
*v.stride(),
*block_sparse_mask.stride(),
*o.stride(),
H,
N_CTX,
PAST_LEN,
BLOCK_M,
BLOCK_N,
BLOCK_DMODEL,
num_warps=num_warps,
num_stages=num_stages,
)
return o
class _sparse_attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, block_sparse_dense, sm_scale):
# shape constraints
return _forward(ctx, q, k, v, block_sparse_dense, sm_scale)
@staticmethod
def backward(ctx, do):
# No gradient propagation.
        raise NotImplementedError("Gradient propagation is not supported yet")
block_sparse_triton_fn = _sparse_attention.apply
def test_topk_sparse_attention():
# Config
BATCH, N_HEADS, SEQ_LEN, D_HEAD = 1, 1, 256, 64
    TOPK = 2  # keep the top 2 key blocks per query-block row
BLOCK = 64
torch.manual_seed(0)
# Create inputs
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
sm_scale = 1.0 / (D_HEAD**0.5)
# Create sparse mask (downsampled to block level)
downsample_factor = BLOCK
downsample_len = math.ceil(SEQ_LEN / downsample_factor)
print("downsample_len", downsample_len)
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16)
x_ds[:, :, :, 0] = 100
print("x_ds.shape", x_ds.shape)
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
# print("block_mask", block_mask)
print("block_mask.shape", block_mask.shape)
# Run Triton kernel
triton_output = block_sparse_triton_fn(q, k, v, block_mask, sm_scale)
# Compute reference
# Expand block mask to full attention matrix
full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda"))
full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal
# PyTorch reference implementation
attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
attn = attn.masked_fill(~full_mask, float("-inf"))
attn = F.softmax(attn, dim=-1)
ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)
# print("ref_output", ref_output)
# print("triton_output", triton_output)
# Verify accuracy
assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference"
print("Pass topk sparse attention test with qlen == klen")
def test_topk_sparse_attention_qlt_kl():
BATCH, N_HEADS = 2, 4
Q_LEN, K_LEN, D_HEAD = 128, 256, 64 # qlen < klen; here, past_len = 256 - 128 = 128.
TOPK = 1
BLOCK = 64 # block size used in downsampling
torch.manual_seed(0)
# Create inputs.
q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device="cuda", dtype=torch.bfloat16)
# softmax scale
sm_scale = 1.0 / (D_HEAD**0.5)
downsample_factor = BLOCK
print("downsample_factor", downsample_factor)
downsample_len = math.ceil(K_LEN / downsample_factor) # number of blocks along one dimension
print("downsample_len", downsample_len)
x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, device="cuda", dtype=torch.bfloat16)
# Force the first column to be high so that the first block is always selected.
x_ds[:, :, :, 0] = 100
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
print("block_mask", block_mask)
print("block_mask.shape", block_mask.shape)
# Run Triton kernel.
triton_output = block_sparse_triton_fn(q, k, v, block_mask, sm_scale)
past_len = K_LEN - Q_LEN
attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")).bool()
full_mask_full = full_mask_full[..., :K_LEN, :K_LEN]
effective_mask = full_mask_full[..., past_len:K_LEN, :] # shape: (B, H, Q_LEN, K_LEN)
i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1) # shape: (Q_LEN, 1)
j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0) # shape: (1, K_LEN)
causal_mask = j_global <= i_global # shape: (Q_LEN, K_LEN)
final_mask = effective_mask & causal_mask # shape: (B, H, Q_LEN, K_LEN)
attn = attn.masked_fill(~final_mask, float("-inf"))
attn = F.softmax(attn, dim=-1)
ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)
# Verify accuracy.
assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), "Triton output doesn't match reference when qlen < klen"
print("Pass topk sparse attention test with qlen < klen")
def main():
test_topk_sparse_attention()
test_topk_sparse_attention_qlt_kl()
if __name__ == "__main__":
main()
import math
import torch
import tilelang
import tilelang.language as T
import torch.nn.functional as F
def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz, num_head, downsample_len, _ = x.shape
# N_CTX = downsample_len * BLOCK
sparse_index = torch.topk(x, topk, dim=-1).indices
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
dense_mask.scatter_(-1, sparse_index, True)
if use_dense_for_last_block:
dense_mask[:, :, -2:, :] = True
dense_mask.tril_()
return dense_mask
def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False):
dense_mask = x > threshold
if use_dense_for_last_block:
dense_mask[:, :, -2:, :] = True
dense_mask.tril_()
return dense_mask
@tilelang.jit(
out_idx=[4],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal):
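    # Builds a block-sparse FlashAttention-style Tilelang kernel: each program instance
    # handles one [block_M, dim] query tile and iterates over [block_N, dim] key/value
    # tiles selected by BlockSparseMask, using the exp2-based online softmax below.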
block_M = 64
block_N = 64
num_stages = 1
threads = 128
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
shape = [batch, heads, seq_len, dim]
block_mask_shape = [batch, heads, downsample_len, downsample_len]
dtype = T.float16
accum_dtype = T.float32
block_mask_dtype = T.bool
def kernel_func(block_M, block_N, num_stages, threads):
@T.macro
def MMA0(
K: T.Tensor(shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(shape, dtype),
V_shared: T.SharedBuffer([block_M, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# To do causal softmax, we need to set scores_max to 0 if it is -inf.
# This step is called Check_inf in the FlashAttention3 code, and it only needs to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)) This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def blocksparse_flashattn(
Q: T.Tensor(shape, dtype),
K: T.Tensor(shape, dtype),
V: T.Tensor(shape, dtype),
BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype),
Output: T.Tensor(shape, dtype),
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
block_mask = T.alloc_local([downsample_len], block_mask_dtype)
T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
for vj in T.serial(downsample_len):
block_mask[vj] = BlockSparseMask[bz, by, bx, vj]
loop_range = (
T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)
)
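# For causal attention, only key blocks that start at or before the end of this query block are
# visited; block_mask additionally prunes unselected blocks inside that range.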
for k in T.Pipelined(loop_range, num_stages=num_stages):
if block_mask[k] != 0:
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
return blocksparse_flashattn
return kernel_func(block_M, block_N, num_stages, threads)
def test_topk_sparse_attention():
# Config
BATCH, N_HEADS, SEQ_LEN, D_HEAD = 1, 1, 256, 64
TOPK = 2  # Keep the top-2 key blocks per query block row
BLOCK = 64
torch.manual_seed(0)
# Create inputs
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
sm_scale = 1.0 / (D_HEAD**0.5)
# Create sparse mask (downsampled to block level)
downsample_factor = BLOCK
downsample_len = math.ceil(SEQ_LEN / downsample_factor)
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16)
x_ds[:, :, :, 0] = 100
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
# Run tilelang kernel
kernel = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
tilelang_output = kernel(q, k, v, block_mask)
# Compute reference
# Expand block mask to full attention matrix
full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda"))
full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal
# PyTorch reference implementation
attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
attn = attn.masked_fill(~full_mask, float("-inf"))
attn = F.softmax(attn, dim=-1)
ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)
print("ref_output", ref_output)
print("tilelang_output", tilelang_output)
# Verify accuracy
torch.testing.assert_close(tilelang_output, ref_output, atol=1e-2, rtol=1e-2)
print("Pass topk sparse attention test with qlen == klen")
def main():
test_topk_sparse_attention()
if __name__ == "__main__":
main()
# ruff: noqa
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
from einops import rearrange, einsum
import argparse
import time
import math
from heuristic import num_splits_heuristic
def flashattn(batch, heads, heads_kv, dim, dim_v):
scale = (1.0 / dim) ** 0.5 * 1.44269504  # 1/sqrt(dim) folded with log2(e) so softmax can use exp2
dtype = T.float16
accum_dtype = T.float32
kv_group_num = heads // heads_kv
@tilelang.jit(
out_idx=[-1],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def kernel_func(
block_N, block_H, page_block_size, num_split, num_stages, threads, num_pages, max_num_blocks_per_seq, max_selected_blocks
):
shape_q = [batch, heads, dim]
shape_k = [num_pages, page_block_size, heads_kv, dim]
shape_v = [num_pages, page_block_size, heads_kv, dim_v]
shape_indices = [batch, heads_kv, max_selected_blocks]
shape_block_table = [batch, max_num_blocks_per_seq]
shape_o = [batch, heads, dim_v]
part_shape = [batch, heads, num_split, dim_v]
valid_block_H = min(block_H, kv_group_num)
assert block_N <= page_block_size and page_block_size % block_N == 0
block_ratio = page_block_size // block_N
@T.macro
def flash_attn_split(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
block_indices: T.Tensor(shape_indices, T.int32),
cache_seqlens: T.Tensor([batch], T.int32),
block_table: T.Tensor(shape_block_table, T.int32),
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
):
with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim_v], dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype)
scores_max = T.alloc_fragment([block_H], accum_dtype)
scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
scores_scale = T.alloc_fragment([block_H], accum_dtype)
scores_sum = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_fragment([block_H], accum_dtype)
has_valid_block = T.alloc_var("bool")
bid = bx
hid = by
sid = bz
cur_kv_head = hid // (kv_group_num // valid_block_H)
T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
num_blocks = max_selected_blocks
blocks_per_split = T.floordiv(num_blocks, num_split)
remaining_blocks = T.floormod(num_blocks, num_split)
loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0)
start = blocks_per_split * sid + T.min(sid, remaining_blocks)
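# Worked example (illustrative): with num_blocks = 10 and num_split = 4, blocks_per_split = 2 and
# remaining_blocks = 2, so splits 0..3 process 3, 3, 2, 2 blocks starting at offsets 0, 3, 6, 8.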
has_valid_block = False
for k in T.Pipelined(loop_range, num_stages=num_stages):
logical_block_idx = block_indices[bid, cur_kv_head, start + k]
if logical_block_idx >= 0:
has_valid_block = True
block_table_idx = T.floordiv(logical_block_idx, block_ratio)
block_tile_idx = T.floormod(logical_block_idx, block_ratio)
physical_block_idx = block_table[bid, block_table_idx]
T.copy(K[physical_block_idx, block_tile_idx * block_N : (block_tile_idx + 1) * block_N, cur_kv_head, :], K_shared)
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
if k == 0:  # block_indices is assumed to be sorted in descending order; otherwise, remove this if-condition
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else(
logical_block_idx * block_N + j >= cache_seqlens[bid], -T.infinity(accum_dtype), acc_s[i, j]
)
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_H):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
for i, j in T.Parallel(block_H, dim_v):
acc_o[i, j] *= scores_scale[i]
T.copy(V[physical_block_idx, block_tile_idx * block_N : (block_tile_idx + 1) * block_N, cur_kv_head, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
if has_valid_block:
for i, j in T.Parallel(block_H, dim_v):
acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
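# logsum now holds this split's log2-sum-exp of the scaled scores; combine() uses it to merge
# Output_partial across splits in a numerically stable way.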
for i in T.Parallel(block_H):
if i < valid_block_H:
glse[bid, hid * valid_block_H + i, sid] = logsum[i]
for i, j in T.Parallel(block_H, dim_v):
if i < valid_block_H:
Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j]
@T.macro
def combine(
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
Output: T.Tensor(shape_o, dtype),
):
with T.Kernel(heads, batch, threads=128) as (by, bz):
po_local = T.alloc_fragment([dim_v], accum_dtype)
o_accum_local = T.alloc_fragment([dim_v], accum_dtype)
lse_local_split = T.alloc_local([1], accum_dtype)
lse_logsum_local = T.alloc_local([1], accum_dtype)
lse_max_local = T.alloc_local([1], accum_dtype)
scale_local = T.alloc_local([1], accum_dtype)
max_split = T.alloc_local([1], T.int32)
T.annotate_layout(
{
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
}
)
T.clear(lse_logsum_local)
T.clear(o_accum_local)
lse_max_local[0] = -T.infinity(accum_dtype)
for k in T.serial(num_split):
lse_local_split[0] = glse[bz, by, k]
if lse_local_split[0] != 0:
max_split[0] = k
lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k])
for k in T.Pipelined(num_split, num_stages=1):
if k <= max_split[0]:
lse_local_split[0] = glse[bz, by, k]
lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
for k in T.serial(num_split):
if k <= max_split[0]:
for i in T.Parallel(dim_v):
po_local[i] = Output_partial[bz, by, k, i]
lse_local_split[0] = glse[bz, by, k]
scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
for i in T.Parallel(dim_v):
o_accum_local[i] += po_local[i] * scale_local[0]
for i in T.Parallel(dim_v):
Output[bz, by, i] = o_accum_local[i]
@T.prim_func
def main(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
block_indices: T.Tensor(shape_indices, T.int32),
cache_seqlens: T.Tensor([batch], T.int32),
block_table: T.Tensor(shape_block_table, T.int32),
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
Output: T.Tensor(shape_o, dtype),
):
flash_attn_split(Q, K, V, block_indices, cache_seqlens, block_table, glse, Output_partial)
combine(glse, Output_partial, Output)
return main
return kernel_func
class SparseFlashAttn(torch.nn.Module):
def __init__(self, batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, num_pages):
super(SparseFlashAttn, self).__init__()
self.batch = batch
self.heads = heads
self.heads_kv = heads_kv
self.dim = dim
self.dim_v = dim_v
self.block_N = block_N
self.page_block_size = page_block_size
self.num_pages = num_pages
self.block_H = 64
self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
block_N=block_N,
block_H=self.block_H,
page_block_size=page_block_size,
num_split=T.dynamic("num_split"),
num_stages=2,
threads=128,
num_pages=num_pages,
max_num_blocks_per_seq=T.dynamic("max_num_blocks_per_seq"),
max_selected_blocks=T.dynamic("max_selected_blocks"),
)
props = torch.cuda.get_device_properties(torch.device("cuda:0"))
self.num_sm = props.multi_processor_count
def forward(self, query, key, value, block_indices, cache_seqlens, block_table):
batch = self.batch
heads = self.heads
heads_kv = self.heads_kv
dim_v = self.dim_v
dim = self.dim
block_size = self.block_N
max_selected_blocks = block_indices.shape[-1]
# Compute static scheduling parameters
num_m_blocks = 1 * (heads // heads_kv + self.block_H - 1) // self.block_H
num_n_blocks = max_selected_blocks
size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2  # approx. bytes of K+V visited per KV head (2 bytes per fp16 element)
total_mblocks = batch * heads_kv * num_m_blocks
num_sm = self.num_sm
num_split = num_splits_heuristic(
total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128
)
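# num_split decides how many thread blocks cooperate on each (batch, head-group) query; glse and
# output_partial store the per-split log-sum-exp and partial outputs that combine() merges.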
glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda")
output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda")
output = self.kernel(
query,
key,
value,
block_indices,
cache_seqlens,
block_table,
glse,
output_partial,
)
return output
def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_seqlens, block_table, page_block_size, block_size):
"""
Paged version of sparse attention reference implementation.
Args:
query: [batch, heads, dim]
key_cache: [num_pages, page_block_size, heads_kv, dim]
value_cache: [num_pages, page_block_size, heads_kv, dim_v]
block_indices: [batch, heads_kv, max_selected_blocks] - logical block indices
cache_seqlens: [batch] - actual sequence lengths
block_table: [batch, max_num_blocks_per_seq] - maps logical to physical blocks
page_block_size: size of each page block
block_size: size of attention blocks (block_N)
"""
batch, heads, dim = query.shape
heads_kv = key_cache.shape[2]
dim_v = value_cache.shape[3]
num_head_groups = heads // heads_kv
scale = dim**0.5
# Reconstruct the full key and value tensors from paged cache
max_cache_seqlen = max(cache_seqlens).item()
key_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim), dtype=key_cache.dtype, device=key_cache.device)
value_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim_v), dtype=value_cache.dtype, device=value_cache.device)
# Reconstruct full tensors from paged cache using block_table
for b in range(batch):
seq_len = cache_seqlens[b].item()
num_blocks_needed = int(math.ceil(seq_len / page_block_size))
for block_idx in range(num_blocks_needed):
physical_block_idx = block_table[b, block_idx].item()
# Calculate the range of tokens for this block
start_token = block_idx * page_block_size
end_token = min(start_token + page_block_size, seq_len)
actual_block_size = end_token - start_token
# Copy from paged cache to full tensors
key_full[b, :, start_token:end_token, :] = key_cache[physical_block_idx, :actual_block_size, :, :].transpose(0, 1)
value_full[b, :, start_token:end_token, :] = value_cache[physical_block_idx, :actual_block_size, :, :].transpose(0, 1)
# Reshape query for grouped attention
query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim]
# Compute attention scores
scores = einsum(query, key_full, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv]
# Create sparse mask based on block_indices
sparse_mask = torch.zeros_like(scores)
# Apply sparse mask based on selected blocks
for b in range(batch):
for h in range(heads_kv):
valid_indices = block_indices[b, h] # Extract indices for this batch and head
for idx in valid_indices:
if idx >= 0: # Valid block index
start_pos = idx * block_size
end_pos = min(start_pos + block_size, max_cache_seqlen)
sparse_mask[b, :, h, start_pos:end_pos] = 1
# Apply sparse mask
scores = scores.masked_fill(sparse_mask == 0, float("-inf"))
# Apply causal mask based on actual sequence lengths
range_len = torch.arange(scores.shape[-1], device=scores.device).unsqueeze(0)
cache_seqlens_expanded = cache_seqlens.unsqueeze(1)
pad_mask = range_len >= cache_seqlens_expanded
pad_mask = pad_mask[:, None, None, :]
scores = scores.masked_fill(pad_mask, float("-inf"))
# Compute attention weights
attention = F.softmax(scores / scale, dim=-1)
# Apply attention to values
out = einsum(attention, value_full, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim]
# Reshape output back to original format
out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim]
return out
def ref_program_fa(query, kcache, vcache, cache_seqlens, block_table):
# latency reference
# from flash_attn_interface import flash_attn_with_kvcache # fa3
from flash_attn import flash_attn_with_kvcache # fa2
query = query.unsqueeze(1)
output = flash_attn_with_kvcache(query, kcache, vcache, cache_seqlens=cache_seqlens, block_table=block_table)
output = output.squeeze(1)
return output
def main(args):
batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = (
args.batch,
args.heads,
args.heads_kv,
args.max_cache_seqlen,
args.dim,
args.dim_v,
)
sparse_ratio = args.sparse_ratio
block_N = args.block_N
page_block_size = args.page_block_size
num_blocks = args.num_pages # Use num_pages from args
# For dense case verification, set sparse_ratio to 0 to select all blocks
max_selected_blocks = int(math.ceil(max_cache_seqlen / block_N))
print("max_selected_blocks: ", max_selected_blocks)
dtype = torch.float16
# Generate random inputs
Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda")
cache_seqlens = torch.randint(max_cache_seqlen // 2, max_cache_seqlen + 1, (batch,), dtype=torch.int32, device="cuda")
print("cache_seqlens: ", cache_seqlens)
K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda")
V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda")
# Create paged KV cache
K_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim), dtype=dtype, device="cuda")
V_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim_v), dtype=dtype, device="cuda")
# Create block table and block indices for dense case (all blocks selected)
max_num_blocks_per_seq = int(math.ceil(max_cache_seqlen / page_block_size))
print("max_num_blocks_per_seq: ", max_num_blocks_per_seq)
block_table = torch.zeros((batch, max_num_blocks_per_seq), dtype=torch.int32, device="cuda")
block_indices = torch.zeros((batch, heads_kv, max_selected_blocks), dtype=torch.int32, device="cuda")
# Fill block table and block indices and cache
# Create a pool of available physical blocks
total_blocks_needed = sum(int(math.ceil(cache_seqlens[seq_idx].item() / page_block_size)) for seq_idx in range(batch))
available_blocks = list(range(total_blocks_needed))
import random
random.seed(42) # For reproducibility
random.shuffle(available_blocks)
# Fill block table with random physical block indices
block_assignment = {} # Map (seq_idx, block_idx) -> physical_block_idx
block_idx_counter = 0
for seq_idx in range(batch):
seq_len = cache_seqlens[seq_idx].item()
num_blocks_needed = int(math.ceil(seq_len / page_block_size))
# Assign random physical blocks for each sequence
for block_idx in range(num_blocks_needed):
physical_block_idx = available_blocks[block_idx_counter]
block_table[seq_idx, block_idx] = physical_block_idx
block_assignment[(seq_idx, block_idx)] = physical_block_idx
block_idx_counter += 1
print(f"Block table: {block_table}")
# Fill K_cache and V_cache with data from original K and V tensors using random block assignment
for seq_idx in range(batch):
seq_len = cache_seqlens[seq_idx].item()
num_blocks_needed = int(math.ceil(seq_len / page_block_size))
for block_idx in range(num_blocks_needed):
physical_block_idx = block_assignment[(seq_idx, block_idx)]
# Calculate the range of tokens for this block
start_token = block_idx * page_block_size
end_token = min(start_token + page_block_size, seq_len)
actual_block_size = end_token - start_token
# Copy K and V data to the paged cache
K_cache[physical_block_idx, :actual_block_size, :, :] = K[seq_idx, start_token:end_token, :, :]
V_cache[physical_block_idx, :actual_block_size, :, :] = V[seq_idx, start_token:end_token, :, :]
# Fill block_indices for sparse attention
# For dense case (verification), we select all blocks in reverse order
# For sparse case, we select a subset of blocks based on sparse_ratio
for seq_idx in range(batch):
seq_len = cache_seqlens[seq_idx].item()
num_tile = int(math.ceil(seq_len / block_N))
if sparse_ratio == 0.0:
# Dense case: select all blocks in reverse order
selected_blocks = min(num_tile, max_selected_blocks)
for head_idx in range(heads_kv):
for i in range(selected_blocks):
# Select blocks in reverse order (most recent first)
block_indices[seq_idx, head_idx, i] = num_tile - 1 - i
# Fill remaining slots with -1 (invalid)
for i in range(selected_blocks, max_selected_blocks):
block_indices[seq_idx, head_idx, i] = -1
else:
# Fill block_indices for all KV heads
num_selected = int(num_tile * (1.0 - sparse_ratio))
num_selected = max(1, min(num_selected, max_selected_blocks))
all_blocks = list(range(num_tile))
for head_idx in range(heads_kv):
selected_blocks = []
# Always include the most recent blocks
recent_blocks = 1
selected_blocks.append(num_tile - 1)
# Randomly select some earlier blocks
if num_selected > recent_blocks:
remaining_blocks = [b for b in all_blocks if b not in selected_blocks]
if remaining_blocks:
import random
random.seed(42)  # For reproducibility; note that re-seeding inside the head loop makes every KV head sample the same extra blocks
additional_blocks = random.sample(remaining_blocks, min(num_selected - recent_blocks, len(remaining_blocks)))
selected_blocks.extend(additional_blocks)
# Sort selected blocks in reverse order (most recent first)
selected_blocks.sort(reverse=True)
for i in range(len(selected_blocks)):
block_indices[seq_idx, head_idx, i] = selected_blocks[i]
# Fill remaining slots with -1 (invalid)
for i in range(len(selected_blocks), max_selected_blocks):
block_indices[seq_idx, head_idx, i] = -1
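# Illustrative layout: with num_tile = 4 and max_selected_blocks = 6, a dense-case row of
# block_indices is [3, 2, 1, 0, -1, -1] (most recent block first, -1 marks padding).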
# Initialize sparse attention module
sparse_attn = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, num_blocks)
output_sparse = sparse_attn.forward(Q, K_cache, V_cache, block_indices, cache_seqlens, block_table)
import flash_attn # noqa: F401
output_ref_torch = ref_program_torch_paged(Q, K_cache, V_cache, block_indices, cache_seqlens, block_table, page_block_size, block_N)
output_ref_fa = ref_program_fa(Q, K_cache, V_cache, cache_seqlens, block_table)
# Check correctness
if sparse_ratio == 0.0:
max_diff = torch.max(torch.abs(output_sparse - output_ref_fa)).item()
mean_diff = torch.mean(torch.abs(output_sparse - output_ref_fa)).item()
assert torch.allclose(output_ref_fa, output_ref_torch, atol=1e-2), "Reference outputs do not match!"
else:
max_diff = torch.max(torch.abs(output_sparse - output_ref_torch)).item()
mean_diff = torch.mean(torch.abs(output_sparse - output_ref_torch)).item()
print(f"Max difference: {max_diff:.6f}")
print(f"Mean difference: {mean_diff:.6f}")
if max_diff < 1e-2:
print("✓ Verification PASSED: Results match within tolerance")
else:
print("✗ Verification FAILED: Results differ significantly")
# Performance measurement
for _ in range(10): # Warm-up
sparse_attn.forward(Q, K_cache, V_cache, block_indices, cache_seqlens, block_table)
torch.cuda.synchronize()
start_time = time.time()
for _ in range(100): # Run multiple times for averaging
sparse_attn.forward(Q, K_cache, V_cache, block_indices, cache_seqlens, block_table)
torch.cuda.synchronize()
end_time = time.time()
kernel_time = (end_time - start_time) / 100 * 1000 # Convert to ms
print(f"Kernel execution time: {kernel_time:.2f} ms")
# FA performance measurement
for _ in range(10): # Warm-up
ref_program_fa(Q, K_cache, V_cache, cache_seqlens, block_table)
torch.cuda.synchronize()
start_time_fa = time.time()
for _ in range(100): # Run multiple times for averaging
ref_program_fa(Q, K_cache, V_cache, cache_seqlens, block_table)
torch.cuda.synchronize()
end_time_fa = time.time()
kernel_time_fa = (end_time_fa - start_time_fa) / 100 * 1000 # Convert to ms
print(f"FA kernel execution time: {kernel_time_fa:.2f} ms")
print(f"Speedup: {kernel_time_fa / kernel_time:.2f}x")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv")
parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--dim_v", type=int, default=128, help="dim_v")
parser.add_argument("--sparse_ratio", type=float, default=0.0, help="sparse ratio")
parser.add_argument("--block_N", type=int, default=64, help="block_N")
parser.add_argument("--page_block_size", type=int, default=256, help="block size of pages")
parser.add_argument("--num_pages", type=int, default=1024, help="total number of pages")
args = parser.parse_args()
main(args)
import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
from einops import rearrange, einsum
import argparse
import time
import math
from heuristic import num_splits_heuristic
def flashattn(batch, heads, heads_kv, dim, dim_v):
scale = (1.0 / dim) ** 0.5 * 1.44269504  # 1/sqrt(dim) folded with log2(e) so softmax can use exp2
dtype = T.float16
accum_dtype = T.float32
kv_group_num = heads // heads_kv
@tilelang.jit(
out_idx=[-1],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, max_selected_blocks):
shape_q = [batch, heads, dim]
shape_k = [batch, max_cache_seqlen, heads_kv, dim]
shape_v = [batch, max_cache_seqlen, heads_kv, dim_v]
shape_indices = [batch, heads_kv, max_selected_blocks]
shape_o = [batch, heads, dim_v]
part_shape = [batch, heads, num_split, dim_v]
valid_block_H = min(block_H, kv_group_num)
@T.macro
def flash_attn_split(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
block_indices: T.Tensor(shape_indices, T.int32),
cache_seqlens: T.Tensor([batch], T.int32),
# actual_num_blocks: T.Tensor([batch], T.int32),
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
):
with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim_v], dtype)
# O_shared = T.alloc_shared([valid_block_H, dim_v], dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype)
scores_max = T.alloc_fragment([block_H], accum_dtype)
scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
scores_scale = T.alloc_fragment([block_H], accum_dtype)
scores_sum = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_fragment([block_H], accum_dtype)
has_valid_block = T.alloc_var("bool")
bid = bx
hid = by
sid = bz
cur_kv_head = hid // (kv_group_num // valid_block_H)
T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
num_blocks = max_selected_blocks
blocks_per_split = T.floordiv(num_blocks, num_split)
remaining_blocks = T.floormod(num_blocks, num_split)
loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0)
start = blocks_per_split * sid + T.min(sid, remaining_blocks)
has_valid_block = False
for k in T.Pipelined(loop_range, num_stages=num_stages):
i_s = block_indices[bid, cur_kv_head, start + k]
if i_s >= 0:
has_valid_block = True
T.copy(K[bid, i_s * block_N : (i_s + 1) * block_N, cur_kv_head, :], K_shared)
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
if k == 0:  # block_indices is assumed to be sorted in descending order; otherwise, remove this if-condition
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else(i_s * block_N + j >= cache_seqlens[bid], -T.infinity(accum_dtype), acc_s[i, j])
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_H):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
for i, j in T.Parallel(block_H, dim_v):
acc_o[i, j] *= scores_scale[i]
T.copy(V[bid, i_s * block_N : (i_s + 1) * block_N, cur_kv_head, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
if has_valid_block:
for i, j in T.Parallel(block_H, dim_v):
acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
for i in T.Parallel(block_H):
if i < valid_block_H:
glse[bid, hid * valid_block_H + i, sid] = logsum[i]
for i, j in T.Parallel(block_H, dim_v):
if i < valid_block_H:
Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j]
@T.macro
def combine(
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
Output: T.Tensor(shape_o, dtype),
):
with T.Kernel(heads, batch, threads=128) as (by, bz):
po_local = T.alloc_fragment([dim_v], accum_dtype)
o_accum_local = T.alloc_fragment([dim_v], accum_dtype)
lse_local_split = T.alloc_local([1], accum_dtype)
lse_logsum_local = T.alloc_local([1], accum_dtype)
lse_max_local = T.alloc_local([1], accum_dtype)
scale_local = T.alloc_local([1], accum_dtype)
max_split = T.alloc_local([1], T.int32)
T.annotate_layout(
{
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
}
)
T.clear(lse_logsum_local)
T.clear(o_accum_local)
lse_max_local[0] = -T.infinity(accum_dtype)
for k in T.serial(num_split):
lse_local_split[0] = glse[bz, by, k]
if lse_local_split[0] != 0:
max_split[0] = k
lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k])
for k in T.Pipelined(num_split, num_stages=1):
if k <= max_split[0]:
lse_local_split[0] = glse[bz, by, k]
lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
for k in T.serial(num_split):
if k <= max_split[0]:
for i in T.Parallel(dim_v):
po_local[i] = Output_partial[bz, by, k, i]
lse_local_split[0] = glse[bz, by, k]
scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
for i in T.Parallel(dim_v):
o_accum_local[i] += po_local[i] * scale_local[0]
for i in T.Parallel(dim_v):
Output[bz, by, i] = o_accum_local[i]
@T.prim_func
def main(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
block_indices: T.Tensor(shape_indices, T.int32),
cache_seqlens: T.Tensor([batch], T.int32),
# actual_num_blocks: T.Tensor([batch], T.int32),
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
Output: T.Tensor(shape_o, dtype),
):
# flash_attn_split(Q, K, V, block_indices, cache_seqlens, actual_num_blocks, glse, Output_partial)
flash_attn_split(Q, K, V, block_indices, cache_seqlens, glse, Output_partial)
combine(glse, Output_partial, Output)
return main
return kernel_func
class SparseFlashAttn(torch.nn.Module):
def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size):
super(SparseFlashAttn, self).__init__()
self.batch = batch
self.heads = heads
self.heads_kv = heads_kv
self.dim = dim
self.dim_v = dim_v
self.block_size = block_size
self.block_H = 64
self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
block_N=block_size,
block_H=self.block_H,
num_split=T.dynamic("num_split"),
num_stages=2,
threads=128,
max_cache_seqlen=T.dynamic("max_cache_seqlen"),
max_selected_blocks=T.dynamic("max_selected_blocks"),
)
props = torch.cuda.get_device_properties(torch.device("cuda:0"))
self.num_sm = props.multi_processor_count
def forward(self, query, key, value, block_indices, cache_seqlens):
batch = self.batch
heads = self.heads
heads_kv = self.heads_kv
dim_v = self.dim_v
dim = self.dim
block_size = self.block_size
max_selected_blocks = block_indices.shape[-1]
# Compute static scheduling parameters
num_m_blocks = 1 * (heads // heads_kv + self.block_H - 1) // self.block_H
num_n_blocks = max_selected_blocks
size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2  # approx. bytes of K+V visited per KV head (2 bytes per fp16 element)
total_mblocks = batch * heads_kv * num_m_blocks
num_sm = self.num_sm
num_split = num_splits_heuristic(
total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128
)
glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda")
output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda")
output = self.kernel(query, key, value, block_indices, cache_seqlens, glse, output_partial)
return output
def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, block_size):
"""
Args:
query: [batch, heads, dim]
key: [batch, max_cache_seqlen, heads_kv, dim]
value: [batch, max_cache_seqlen, heads_kv, dim_v]
block_indices: [batch, heads_kv, max_selected_blocks], indices of selected blocks, -1 for padding
cache_seqlens: [batch], sequence lengths of the kvcache
max_cache_seqlen: maximum sequence length of kvcache
block_size: block size
Returns:
output: [batch, heads, dim_v]
"""
batch, heads, dim = query.shape
heads_kv = key.shape[2]
dim_v = value.shape[-1]
max_selected_blocks = block_indices.shape[-1]
block_H = 64
actual_num_blocks = torch.sum(block_indices != -1, dim=-1).to(torch.int32)
actual_num_blocks = actual_num_blocks[
:, 0
] # [batch], number of valid blocks, assume all groups in the same batch have the same number of blocks
# get num_split
num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
num_n_blocks = max_selected_blocks # (kv_seqlen + block_size - 1 ) // block_size
# num_n_blocks = torch.sum(actual_num_blocks, dim=-1).item() * heads_kv # total number of blocks
size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2
total_mblocks = batch * heads_kv * num_m_blocks
num_sm = 132  # hard-coded SM count (H100-class GPU); SparseFlashAttn queries the device instead
num_split = num_splits_heuristic(
total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128
)
glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda")
Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda")
kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
block_N=block_size,
block_H=block_H,
num_split=T.dynamic("num_split"),
num_stages=2,
threads=128,
max_cache_seqlen=T.dynamic("max_cache_seqlen"),
max_selected_blocks=T.dynamic("max_selected_blocks"),
)
output = kernel(query, key, value, block_indices, cache_seqlens, glse, Output_partial)
return output
def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size):
batch, heads, dim = query.shape
heads_kv = key.shape[2]
num_head_groups = query.shape[1] // key.shape[2]
scale = dim**0.5
key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim]
value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim]
query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim]
scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv]
sparse_mask = torch.zeros_like(scores)
# Assign mask values based on block_indices
for b in range(batch):
for h in range(heads_kv):
valid_indices = block_indices[b, h] # Extract indices for this batch and head
for idx in valid_indices:
if idx >= 0:
sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1
scores = scores.masked_fill(sparse_mask == 0, float("-inf"))
range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0)
cache_seqlens_expanded = cache_seqlens.unsqueeze(1)
pad_mask = range_len >= cache_seqlens_expanded
pad_mask = pad_mask[:, None, None, :]
scores = scores.masked_fill(pad_mask, float("-inf"))
attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv]
out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim]
out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim]
return out
def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size):
# latency reference
# from flash_attn_interface import flash_attn_with_kvcache # fa3
from flash_attn import flash_attn_with_kvcache # fa2
query = query.unsqueeze(1)
output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens)
output = output.squeeze(1)
return output
def debug(name, expect, actual, atol=1e-3, rtol=1e-3):
all_close = torch.allclose(expect, actual, atol=atol, rtol=rtol)
print(name + " all_close={}".format(all_close))
if not all_close:
diff = (expect - actual).abs()
print("all_close={}, max={}, min={}, mean={}".format(all_close, diff.max().item(), diff.min().item(), diff.mean().item()))
max_indices = torch.nonzero(diff == diff.max().item())
first_index = tuple(max_indices[0].tolist())
print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}")
def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32):
max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size))
print("max_selected_blocks: ", max_selected_blocks)
dtype = torch.float16
Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda")
K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda")
V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda")
cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda")
# cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda')
# # Ensure at least one element equals cache_seqlen
# random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index
# # cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence
print("cache_seqlens: ", cache_seqlens)
max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int()
print("max_valid_num_blocks: ", max_valid_num_blocks)
# Initialize block_indices with -1 (for padding blocks)
block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32, device="cuda")
# max_num_blocks = int((max_cache_seqlen + block_size - 1)/ block_size)
# block_indices = torch.full((batch, heads_kv, max_num_blocks), -1, dtype=torch.int32, device='cuda')
# Assign valid indices while ensuring no duplicates within each batch-group
for b in range(batch):
max_valid_block = max_valid_num_blocks[b].item() # Max valid blocks for this batch
if max_valid_block > 0: # Ensure there's at least one valid block
for h in range(heads_kv):
valid_indices = torch.randperm(max_valid_block, device="cuda", dtype=torch.int32)[:max_selected_blocks]
# valid_indices = torch.randperm(max_valid_block, device='cuda', dtype=torch.int32)[:max_num_blocks]
block_indices[b, h, : len(valid_indices)] = valid_indices
# Sort indices within each batch-group for consistency
block_indices, _ = block_indices.sort(dim=-1, descending=True)
# print("block_indices: ", block_indices)
actual_num_blocks = torch.sum(block_indices != -1, dim=-1).to(torch.int32)[:, 0]
print("actual_num_blocks: ", actual_num_blocks)
# print(block_indices.shape, actual_num_blocks.shape)
max_num_blocks = torch.max(max_valid_num_blocks).item()
print("max_num_blocks: ", max_num_blocks)
# parity reference
ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size)
sparse_kernel = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size)
out = sparse_kernel(Q, K, V, block_indices, cache_seqlens)
debug("output", ref, out, atol=1e-3, rtol=1e-3)
import flash_attn # noqa: F401
## latency reference
for _ in range(10):
ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size)
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size)
torch.cuda.synchronize()
print("dense time: ", (time.time() - start) / 100 * 1000)
for _ in range(10):
# out = sparse_gqa_decode_varlen_indice(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, block_size)
out = sparse_kernel(Q, K, V, block_indices, cache_seqlens)
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
# out = sparse_gqa_decode_varlen_indice(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, block_size)
out = sparse_kernel(Q, K, V, block_indices, cache_seqlens)
torch.cuda.synchronize()
print("sparse time: ", (time.time() - start) / 100 * 1000)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv")
parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--dim_v", type=int, default=128, help="dim_v")
parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio")
parser.add_argument("--block_size", type=int, default=32, help="block_size")
args = parser.parse_args()
main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size)
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
from einops import rearrange, einsum
import argparse
import time
import math
from heuristic import num_splits_heuristic
def flashattn(batch, heads, heads_kv, dim, dim_v):
scale = (1.0 / dim) ** 0.5 * 1.44269504  # 1/sqrt(dim) folded with log2(e) so softmax can use exp2
dtype = T.float16
accum_dtype = T.float32
kv_group_num = heads // heads_kv
@tilelang.jit(
out_idx=[-1],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, num_blocks):
shape_q = [batch, heads, dim]
shape_k = [batch, max_cache_seqlen, heads_kv, dim]
shape_v = [batch, max_cache_seqlen, heads_kv, dim_v]
shape_mask = [batch, heads_kv, num_blocks]
shape_o = [batch, heads, dim_v]
part_shape = [batch, heads, num_split, dim_v]
valid_block_H = min(block_H, kv_group_num)
@T.macro
def flash_attn_split(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
block_mask: T.Tensor(shape_mask, T.bool),
cache_seqlens: T.Tensor([batch], T.int32),
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
):
with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim_v], dtype)
# O_shared = T.alloc_shared([valid_block_H, dim_v], dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype)
scores_max = T.alloc_fragment([block_H], accum_dtype)
scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
scores_scale = T.alloc_fragment([block_H], accum_dtype)
scores_sum = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_fragment([block_H], accum_dtype)
has_valid_block = T.alloc_var("bool")
bid = bx
hid = by
sid = bz
cur_kv_head = hid // (kv_group_num // valid_block_H)
T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
blocks_per_split = T.floordiv(num_blocks, num_split)
remaining_blocks = T.floormod(num_blocks, num_split)
loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0)
start = blocks_per_split * sid + T.min(sid, remaining_blocks)
has_valid_block = False
for k in T.Pipelined(loop_range, num_stages=num_stages):
if block_mask[bid, hid, start + k]:
has_valid_block = True
T.copy(K[bid, (start + k) * block_N : (start + k + 1) * block_N, cur_kv_head, :], K_shared)
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else(
(start + k) * block_N + j >= cache_seqlens[bx], -T.infinity(accum_dtype), acc_s[i, j]
)
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_H):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
for i, j in T.Parallel(block_H, dim_v):
acc_o[i, j] *= scores_scale[i]
T.copy(V[bid, (start + k) * block_N : (start + k + 1) * block_N, cur_kv_head, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
if has_valid_block:
for i, j in T.Parallel(block_H, dim_v):
acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
for i in T.Parallel(block_H):
if i < valid_block_H:
glse[bid, hid * valid_block_H + i, sid] = logsum[i]
for i, j in T.Parallel(block_H, dim_v):
if i < valid_block_H:
Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j]
@T.macro
def combine(
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
Output: T.Tensor(shape_o, dtype),
):
with T.Kernel(heads, batch, threads=128) as (by, bz):
po_local = T.alloc_fragment([dim_v], accum_dtype)
o_accum_local = T.alloc_fragment([dim_v], accum_dtype)
lse_local_split = T.alloc_local([1], accum_dtype)
lse_logsum_local = T.alloc_local([1], accum_dtype)
lse_max_local = T.alloc_local([1], accum_dtype)
scale_local = T.alloc_local([1], accum_dtype)
T.annotate_layout(
{
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
}
)
T.clear(lse_logsum_local)
T.clear(o_accum_local)
lse_max_local[0] = -T.infinity(accum_dtype)
for k in T.serial(num_split):
lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k])
for k in T.Pipelined(num_split, num_stages=1):
lse_local_split[0] = glse[bz, by, k]
lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
for k in T.serial(num_split):
for i in T.Parallel(dim_v):
po_local[i] = Output_partial[bz, by, k, i]
lse_local_split[0] = glse[bz, by, k]
scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
for i in T.Parallel(dim_v):
o_accum_local[i] += po_local[i] * scale_local[0]
for i in T.Parallel(dim_v):
Output[bz, by, i] = o_accum_local[i]
@T.prim_func
def main(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
block_mask: T.Tensor(shape_mask, T.bool),
cache_seqlens: T.Tensor([batch], T.int32),
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
Output: T.Tensor(shape_o, dtype),
):
flash_attn_split(Q, K, V, block_mask, cache_seqlens, glse, Output_partial)
combine(glse, Output_partial, Output)
return main
return kernel_func
class SparseFlashAttn(torch.nn.Module):
def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size):
super(SparseFlashAttn, self).__init__()
self.batch = batch
self.heads = heads
self.heads_kv = heads_kv
self.dim = dim
self.dim_v = dim_v
self.block_size = block_size
self.block_H = 64
self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
block_N=block_size,
block_H=self.block_H,
num_split=T.dynamic("num_split"),
num_stages=2,
threads=128,
max_cache_seqlen=T.dynamic("max_cache_seqlen"),
num_blocks=T.dynamic("num_blocks"),
)
props = torch.cuda.get_device_properties(torch.device("cuda:0"))
self.num_sm = props.multi_processor_count
def forward(self, query, key, value, block_mask, cache_seqlens):
batch = self.batch
heads = self.heads
heads_kv = self.heads_kv
dim_v = self.dim_v
dim = self.dim
block_size = self.block_size
block_H = self.block_H
max_cache_seqlen = key.shape[1]
# get num_split
max_selected_blocks = (max_cache_seqlen + block_size - 1) // block_size
num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
num_n_blocks = max_selected_blocks
size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2
total_mblocks = batch * heads_kv * num_m_blocks
# num_sm = 132
num_sm = self.num_sm
num_split = num_splits_heuristic(
total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128
)
# print("num_split: ", num_split)
glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda")
Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda")
output = self.kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial)
return output
def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens, block_size):
"""
Args:
query: [batch, heads, dim]
key: [batch, max_cache_seqlen, heads_kv, dim]
value: [batch, max_cache_seqlen, heads_kv, dim_v]
block_mask: [batch, heads_kv, num_blocks], mask for valid blocks
cache_seqlens: [batch], sequence lengths of the kvcache
block_size: block size
Returns:
output: [batch, heads, dim_v]
"""
batch, heads, dim = query.shape
heads_kv = key.shape[2]
dim_v = value.shape[-1]
block_H = 64
actual_num_blocks = torch.sum(block_mask, dim=-1).to(torch.int32)
actual_num_blocks = actual_num_blocks[
:, 0
] # [batch], number of valid blocks, assume all groups in the same batch have the same number of blocks
max_selected_blocks = actual_num_blocks.max().item()
# get num_split
num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
num_n_blocks = max_selected_blocks # (kv_seqlen + block_size - 1 ) // block_size
# num_n_blocks = torch.sum(actual_num_blocks, dim=-1).item() * heads_kv # total number of blocks
size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2
total_mblocks = batch * heads_kv * num_m_blocks
num_sm = 132  # hard-coded SM count (H100-class GPU); SparseFlashAttn queries the device instead
num_split = num_splits_heuristic(
total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128
)
kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
block_N=block_size,
block_H=block_H,
num_split=T.dynamic("num_split"),
num_stages=2,
threads=128,
max_cache_seqlen=T.dynamic("max_cache_seqlen"),
num_blocks=T.dynamic("num_blocks"),
)
glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda")
Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda")
# print(kernel.get_kernel_source())
output = kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial)
return output
def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size):
batch, heads, dim = query.shape
heads_kv = key.shape[2]
num_head_groups = query.shape[1] // key.shape[2]
scale = dim**0.5
key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim]
value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim]
query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim]
scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv]
sparse_mask = torch.zeros_like(scores)
# Assign mask values
for b in range(batch):
for h in range(heads_kv):
for idx in range(num_blocks):
if block_mask[b, h, idx]:
sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1
scores = scores.masked_fill(sparse_mask == 0, float("-inf"))
range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0)
cache_seqlens_expanded = cache_seqlens.unsqueeze(1)
pad_mask = range_len >= cache_seqlens_expanded
pad_mask = pad_mask[:, None, None, :]
scores = scores.masked_fill(pad_mask, float("-inf"))
attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv]
out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim]
out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim]
return out
def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size):
# latency reference
# from flash_attn_interface import flash_attn_with_kvcache # fa3
from flash_attn import flash_attn_with_kvcache # fa2
query = query.unsqueeze(1)
output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens)
output = output.squeeze(1)
return output
def debug(name, expect, actual, atol=1e-3, rtol=1e-3):
all_close = torch.allclose(expect, actual, atol=atol, rtol=rtol)
print(name + " all_close={}".format(all_close))
if not all_close:
# print(expect[3, 28])
# print(actual[3, 28])
diff = (expect - actual).abs()
print("all_close={}, max={}, min={}, mean={}".format(all_close, diff.max().item(), diff.min().item(), diff.mean().item()))
max_indices = torch.nonzero(diff == diff.max().item())
first_index = tuple(max_indices[0].tolist())
print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}")
def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32):
max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size))
print("max_selected_blocks: ", max_selected_blocks)
dtype = torch.float16
Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda")
K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda")
V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda")
cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda")
    # Ensure at least one sequence uses the full max_cache_seqlen
    random_index = torch.randint(0, batch, (1,), device="cuda").item()  # Select a random index
    cache_seqlens[random_index] = max_cache_seqlen  # Force at least one full-length sequence
# cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda')
print("cache_seqlens: ", cache_seqlens)
num_blocks = (max_cache_seqlen + block_size - 1) // block_size
valid_num_blocks = torch.ceil(cache_seqlens * (1 - sparse_ratio) / block_size).int()
print("valid_num_blocks: ", valid_num_blocks)
max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int()
print("max_valid_num_blocks: ", max_valid_num_blocks)
# Initialize block_mask with false (for padding blocks)
block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device="cuda")
# Assign valid indices while ensuring no duplicates within each batch-group
for b in range(batch):
max_valid_block = max_valid_num_blocks[b].item() # Max valid blocks for this batch
valid_num_block = valid_num_blocks[b].item() # Valid blocks for this batch
if valid_num_block > 0: # Ensure there's at least one valid block
for h in range(heads_kv):
perm = torch.randperm(max_valid_block, device="cuda")[:valid_num_block]
block_mask[b, h, perm] = True
# print("block_mask: ", block_mask)
# parity reference
ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size)
# out = sparse_gqa_decode_varlen_mask(Q, K, V, block_mask, cache_seqlens, block_size)
model = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size)
out = model(Q, K, V, block_mask, cache_seqlens)
debug("output", ref, out, atol=1e-3, rtol=1e-3)
import flash_attn # noqa: F401
## latency reference
for _ in range(10):
ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size)
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size)
torch.cuda.synchronize()
print("dense time: ", (time.time() - start) / 100 * 1000)
for _ in range(10):
# out = sparse_gqa_decode_varlen_mask(Q, K, V, block_mask, cache_seqlens, block_size)
out = model(Q, K, V, block_mask, cache_seqlens)
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
# out = sparse_gqa_decode_varlen_mask(Q, K, V, block_mask, cache_seqlens, block_size)
out = model(Q, K, V, block_mask, cache_seqlens)
torch.cuda.synchronize()
print("sparse time: ", (time.time() - start) / 100 * 1000)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv")
parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--dim_v", type=int, default=128, help="dim_v")
parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio")
parser.add_argument("--block_size", type=int, default=32, help="block_size")
args = parser.parse_args()
main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size)
# ruff: noqa
import torch
import triton
import triton.language as tl
import argparse
from einops import rearrange, einsum
import torch.nn.functional as F
import math
import time
from heuristic import num_splits_heuristic
@triton.autotune(
configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]],
key=["BLOCK_H", "BLOCK_N", "BLOCK_D"],
)
@triton.jit
def _split_kernel(
q_ptr,
k_cache_ptr,
v_cache_ptr,
cache_seqlens_ptr,
o_partial_ptr,
lse_partial_ptr,
mask_ptr,
sm_scale,
num_splits,
gqa_group_size,
max_selected_blocks,
stride_q_b,
stride_q_h,
stride_q_d,
stride_k_b,
stride_k_s,
stride_k_h,
stride_k_d,
stride_v_b,
stride_v_s,
stride_v_h,
stride_v_d,
stride_o_b,
stride_o_h,
stride_o_split,
stride_o_d,
stride_lse_b,
stride_lse_h,
stride_lse_split,
stride_mask_b,
stride_mask_h,
stride_mask_s,
BLOCK_H: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_D: tl.constexpr,
):
batch_idx = tl.program_id(0)
head_idx_kv = tl.program_id(1)
split_idx = tl.program_id(2)
head_idx_q = head_idx_kv * gqa_group_size
offs_h = tl.arange(0, BLOCK_H)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_D)
m_i = tl.full([BLOCK_H], float("-inf"), dtype=tl.float32)
l_i = tl.full([BLOCK_H], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_H, BLOCK_D], dtype=tl.float32)
cache_seqlens = tl.load(cache_seqlens_ptr + batch_idx)
num_blocks = max_selected_blocks
blocks_per_split = tl.floor(num_blocks / num_splits).to(tl.int32)
remaining_blocks = num_blocks % num_splits
if split_idx < remaining_blocks:
loop_range = blocks_per_split + 1
else:
loop_range = blocks_per_split
q_ptr += batch_idx * stride_q_b + head_idx_q * stride_q_h
k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[None, :] * stride_k_s + offs_d[:, None] * stride_k_d
v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:, None] * stride_v_s + offs_d[None, :] * stride_v_d
mask_ptr += batch_idx * stride_mask_b + head_idx_kv * stride_mask_h
q = tl.load(q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d, mask=offs_h[:, None] < gqa_group_size)
start = blocks_per_split * split_idx + tl.minimum(split_idx, remaining_blocks)
for i in range(loop_range):
block_idx = tl.load(mask_ptr + (start + i) * stride_mask_s)
if block_idx >= 0:
start_n = block_idx * BLOCK_N
k_ptr = k_cache_ptr + start_n * stride_k_s
v_ptr = v_cache_ptr + start_n * stride_v_s
k = tl.load(k_ptr, mask=start_n + offs_n[None, :] < cache_seqlens, other=0.0)
v = tl.load(v_ptr, mask=start_n + offs_n[:, None] < cache_seqlens, other=0.0)
qk = tl.dot(q, k)
qk = qk * sm_scale
qk = tl.where(start_n + offs_n[None, :] < cache_seqlens, qk, float("-inf"))
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk -= m_ij[:, None]
p = tl.exp(qk)
l_ij = tl.sum(p, 1)
alpha = tl.exp(m_i - m_ij)
l_i = l_i * alpha + l_ij
acc = acc * alpha[:, None]
p = p.to(v.type.element_ty)
acc += tl.dot(p, v)
m_i = m_ij
m_i += tl.math.log(l_i)
l_recip = 1 / l_i[:, None]
acc = acc * l_recip
acc = acc.to(o_partial_ptr.dtype.element_ty)
lse_partial_ptr += batch_idx * stride_lse_b + (head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split
tl.store(lse_partial_ptr, m_i, mask=offs_h < gqa_group_size)
o_partial_ptr += (
batch_idx * stride_o_b + (head_idx_q + offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d
)
tl.store(o_partial_ptr, acc, mask=offs_h[:, None] < gqa_group_size)
@triton.autotune(
configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]],
key=["BLOCK_D"],
)
@triton.jit
def _merge_kernel(
o_partial_ptr,
lse_partial_ptr,
o_ptr,
lse_partial_stride_b,
lse_partial_stride_h,
lse_partial_stride_split,
o_partial_stride_b,
o_partial_stride_h,
o_partial_stride_split,
o_partial_stride_d,
o_stride_b,
o_stride_h,
o_stride_d,
BLOCK_D: tl.constexpr,
num_splits: tl.constexpr,
num_splits_pow2: tl.constexpr,
):
batch_idx = tl.program_id(0)
head_idx = tl.program_id(1)
offs_splits = tl.arange(0, num_splits_pow2)
offs_d = tl.arange(0, BLOCK_D)
lse_offsets = lse_partial_ptr + batch_idx * lse_partial_stride_b + head_idx * lse_partial_stride_h
lse = tl.load(lse_offsets + offs_splits * lse_partial_stride_split, mask=offs_splits < num_splits, other=float("-inf"))
lse_max = tl.max(lse)
o_offsets = o_partial_ptr + batch_idx * o_partial_stride_b + head_idx * o_partial_stride_h
o_partial = tl.load(
o_offsets + offs_splits[:, None] * o_partial_stride_split + offs_d[None, :] * o_partial_stride_d,
mask=offs_splits[:, None] < num_splits,
)
sumexp_normalized_splitk = tl.exp(lse - lse_max)
sumexp_normalized = tl.sum(sumexp_normalized_splitk, axis=0)
numerator_normalized = tl.sum(o_partial * sumexp_normalized_splitk[:, None], axis=0)
acc = numerator_normalized / sumexp_normalized
acc = acc.to(o_ptr.dtype.element_ty)
o_ptr += batch_idx * o_stride_b + head_idx * o_stride_h
tl.store(o_ptr + offs_d * o_stride_d, acc)
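# The merge above is the standard split-softmax combine. A rough PyTorch sketch of the
# same arithmetic for a single (batch, head), with o_partial of shape [num_splits, dim_v]
# and lse of shape [num_splits], would be:
#     w = torch.exp(lse - lse.max())
#     out = (o_partial * w[:, None]).sum(0) / w.sum()
# i.e. each split's already-normalized partial output is reweighted by its
# log-sum-exp relative to the global maximum and then renormalized.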
def block_sparse_flash_decode_gqa_indice_triton(
q,
k_cache,
v_cache,
cache_seqlens,
max_cache_seqlen,
max_selected_blocks,
block_indices,
block_size,
sm_scale=None,
):
batch, heads, dim = q.shape
if sm_scale is None:
sm_scale = 1 / math.sqrt(dim)
_, max_cache_seqlen_cache, heads_kv, dim_v = v_cache.shape
assert max_cache_seqlen == max_cache_seqlen_cache, "max_cache_seqlen mismatch"
group_size = heads // heads_kv
block_H = 16
num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
num_n_blocks = max_selected_blocks
size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2
total_mblocks = batch * heads_kv * num_m_blocks
    num_sm = 64  # assumed SM count; for the actual value use torch.cuda.get_device_properties(0).multi_processor_count
num_splits = num_splits_heuristic(
total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128
)
# print("num_splits:", num_splits, "num_blocks:", num_n_blocks)
num_splits_pow2 = triton.next_power_of_2(num_splits)
o_partial = torch.empty((batch, heads, num_splits, dim_v), device=q.device, dtype=q.dtype)
lse_partial = torch.empty((batch, heads, num_splits), device=q.device, dtype=torch.float32)
BLOCK_D = dim
BLOCK_H = group_size if group_size > 16 else 16
grid = (batch, heads_kv, num_splits)
_split_kernel[grid](
q,
k_cache,
v_cache,
cache_seqlens,
o_partial,
lse_partial,
block_indices,
sm_scale,
num_splits,
group_size,
max_selected_blocks,
q.stride(0),
q.stride(1),
q.stride(2),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
v_cache.stride(0),
v_cache.stride(1),
v_cache.stride(2),
v_cache.stride(3),
o_partial.stride(0),
o_partial.stride(1),
o_partial.stride(2),
o_partial.stride(3),
lse_partial.stride(0),
lse_partial.stride(1),
lse_partial.stride(2),
block_indices.stride(0),
block_indices.stride(1),
block_indices.stride(2),
BLOCK_H=BLOCK_H,
BLOCK_N=block_size,
BLOCK_D=BLOCK_D,
)
output = torch.zeros((batch, heads, dim_v), device=q.device, dtype=q.dtype)
grid = (batch, heads)
_merge_kernel[grid](
o_partial,
lse_partial,
output,
lse_partial.stride(0),
lse_partial.stride(1),
lse_partial.stride(2),
o_partial.stride(0),
o_partial.stride(1),
o_partial.stride(2),
o_partial.stride(3),
output.stride(0),
output.stride(1),
output.stride(2),
BLOCK_D=dim_v,
num_splits=num_splits,
num_splits_pow2=num_splits_pow2,
)
return output
def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size):
batch, heads, dim = query.shape
heads_kv = key.shape[2]
dim_v = value.shape[-1]
num_head_groups = query.shape[1] // key.shape[2]
scale = dim**0.5
key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim]
value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim]
query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim]
scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv]
sparse_mask = torch.zeros_like(scores)
# Assign mask values based on block_indices
for b in range(batch):
for h in range(heads_kv):
valid_indices = block_indices[b, h] # Extract indices for this batch and head
for idx in valid_indices:
if idx >= 0:
sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1
scores = scores.masked_fill(sparse_mask == 0, float("-inf"))
range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0)
cache_seqlens_expanded = cache_seqlens.unsqueeze(1)
pad_mask = range_len >= cache_seqlens_expanded
pad_mask = pad_mask[:, None, None, :]
scores = scores.masked_fill(pad_mask, float("-inf"))
attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv]
out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim]
out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim]
return out
def ref_program_fa(query, key, value, cache_seqlens):
# latency reference
# from flash_attn_interface import flash_attn_with_kvcache # fa3
from flash_attn import flash_attn_with_kvcache # fa2
query = query.unsqueeze(1)
output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens)
output = output.squeeze(1)
return output
def main(batch=64, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32):
qk_flops = 2 * batch * heads * max_cache_seqlen * dim
pv_flops = 2 * batch * heads * max_cache_seqlen * dim_v
total_flops = qk_flops + pv_flops
max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size))
print("max_selected_blocks: ", max_selected_blocks)
dtype = torch.float16
block_H = 64
Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda")
K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda")
V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda")
cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda")
# cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda')
    # Ensure at least one sequence uses the full max_cache_seqlen
    random_index = torch.randint(0, batch, (1,), device="cuda").item()  # Select a random index
    cache_seqlens[random_index] = max_cache_seqlen  # Force at least one full-length sequence
print("cache_seqlens: ", cache_seqlens)
max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int()
print("max_valid_num_blocks: ", max_valid_num_blocks)
# Initialize block_indices with -1 (for padding blocks)
block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32, device="cuda")
# Assign valid indices while ensuring no duplicates within each batch-group
for b in range(batch):
max_valid_block = max_valid_num_blocks[b].item() # Max valid blocks for this batch
if max_valid_block > 0: # Ensure there's at least one valid block
for h in range(heads_kv):
valid_indices = torch.randperm(max_valid_block, device="cuda", dtype=torch.int32)[:max_selected_blocks]
block_indices[b, h, : len(valid_indices)] = valid_indices
# Sort indices within each batch-group for consistency
block_indices, _ = block_indices.sort(dim=-1, descending=True)
# print("block_indices: ", block_indices)
actual_num_blocks = torch.sum(block_indices != -1, dim=-1).to(torch.int32)[:, 0]
print("actual_num_blocks: ", actual_num_blocks)
# print(block_indices.shape, actual_num_blocks.shape)
max_num_blocks = torch.max(max_valid_num_blocks).item()
print("max_num_blocks: ", max_num_blocks)
ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size)
triton_out = block_sparse_flash_decode_gqa_indice_triton(
Q,
K,
V,
cache_seqlens,
max_cache_seqlen,
max_selected_blocks,
block_indices,
block_size,
)
print("max difference: ", torch.max(torch.abs(ref - triton_out)))
assert torch.allclose(ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation"
print("Passed the ref test!")
# Measure performance
torch.cuda.synchronize()
start = time.time()
for _ in range(1000):
block_sparse_flash_decode_gqa_indice_triton(
Q,
K,
V,
cache_seqlens,
max_cache_seqlen,
max_selected_blocks,
block_indices,
block_size,
)
torch.cuda.synchronize()
end = time.time()
elapsed_time = end - start
avg_time = elapsed_time / 1000
avg_flops = total_flops / avg_time
print(f"Average time: {avg_time:.6f} seconds")
# Measure performance of reference implementation
    import flash_attn  # noqa: F401
    # Warm up the reference once so library load / first-call overhead is not timed
    for _ in range(10):
        ref_program_fa(Q, K, V, cache_seqlens)
    torch.cuda.synchronize()
    start = time.time()
for _ in range(1000):
ref_program_fa(Q, K, V, cache_seqlens)
torch.cuda.synchronize()
end = time.time()
elapsed_time_ref = end - start
avg_time_ref = elapsed_time_ref / 1000
avg_flops_ref = total_flops / avg_time_ref
print(f"Average time of ref: {avg_time_ref:.6f} seconds")
print(f"Speedup: {avg_time_ref / avg_time:.2f}x")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=64, help="batch size")
parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv")
parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--dim_v", type=int, default=128, help="dim_v")
parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio")
parser.add_argument("--block_size", type=int, default=32, help="block_size")
args = parser.parse_args()
main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size)
import torch
import triton
import triton.language as tl
import argparse
from einops import rearrange, einsum
import torch.nn.functional as F
import math
import time
from heuristic import num_splits_heuristic
@triton.autotune(
configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]],
key=["BLOCK_H", "BLOCK_N", "BLOCK_D"],
)
@triton.jit
def _split_kernel(
q_ptr,
k_cache_ptr,
v_cache_ptr,
cache_seqlens_ptr,
o_partial_ptr,
lse_partial_ptr,
mask_ptr,
sm_scale,
num_splits,
gqa_group_size,
stride_q_b,
stride_q_h,
stride_q_d,
stride_k_b,
stride_k_s,
stride_k_h,
stride_k_d,
stride_v_b,
stride_v_s,
stride_v_h,
stride_v_d,
stride_o_b,
stride_o_h,
stride_o_split,
stride_o_d,
stride_lse_b,
stride_lse_h,
stride_lse_split,
stride_mask_b,
stride_mask_h,
stride_mask_s,
BLOCK_H: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_D: tl.constexpr,
):
batch_idx = tl.program_id(0)
head_idx_kv = tl.program_id(1)
split_idx = tl.program_id(2)
head_idx_q = head_idx_kv * gqa_group_size
offs_h = tl.arange(0, BLOCK_H)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_D)
m_i = tl.full([BLOCK_H], float("-inf"), dtype=tl.float32)
l_i = tl.full([BLOCK_H], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_H, BLOCK_D], dtype=tl.float32)
cache_seqlens = tl.load(cache_seqlens_ptr + batch_idx)
num_blocks = (cache_seqlens + BLOCK_N - 1) // BLOCK_N
blocks_per_split = tl.floor(num_blocks / num_splits).to(tl.int32)
remaining_blocks = num_blocks % num_splits
if split_idx < remaining_blocks:
loop_range = blocks_per_split + 1
else:
loop_range = blocks_per_split
q_ptr += batch_idx * stride_q_b + head_idx_q * stride_q_h
k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[None, :] * stride_k_s + offs_d[:, None] * stride_k_d
v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:, None] * stride_v_s + offs_d[None, :] * stride_v_d
mask_ptr += batch_idx * stride_mask_b + head_idx_kv * stride_mask_h
q = tl.load(q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d, mask=offs_h[:, None] < gqa_group_size)
start = blocks_per_split * split_idx + tl.minimum(split_idx, remaining_blocks)
for block_idx in range(loop_range):
start_n = (start + block_idx) * BLOCK_N
mask_val = tl.load(mask_ptr + (start + block_idx) * stride_mask_s)
if mask_val == 1:
k_ptr = k_cache_ptr + start_n * stride_k_s
v_ptr = v_cache_ptr + start_n * stride_v_s
k = tl.load(k_ptr, mask=start_n + offs_n[None, :] < cache_seqlens, other=0.0)
v = tl.load(v_ptr, mask=start_n + offs_n[:, None] < cache_seqlens, other=0.0)
qk = tl.dot(q, k)
qk = qk * sm_scale
qk = tl.where(start_n + offs_n[None, :] < cache_seqlens, qk, float("-inf"))
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk -= m_ij[:, None]
p = tl.exp(qk)
l_ij = tl.sum(p, 1)
alpha = tl.exp(m_i - m_ij)
l_i = l_i * alpha + l_ij
acc = acc * alpha[:, None]
p = p.to(v.type.element_ty)
acc += tl.dot(p, v)
m_i = m_ij
m_i += tl.math.log(l_i)
l_recip = 1 / l_i[:, None]
acc = acc * l_recip
acc = acc.to(o_partial_ptr.dtype.element_ty)
lse_partial_ptr += batch_idx * stride_lse_b + (head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split
tl.store(lse_partial_ptr, m_i, mask=offs_h < gqa_group_size)
o_partial_ptr += (
batch_idx * stride_o_b + (head_idx_q + offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d
)
tl.store(o_partial_ptr, acc, mask=offs_h[:, None] < gqa_group_size)
@triton.autotune(
configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4] for num_stages in [1, 2, 3, 4, 7]],
key=["BLOCK_D"],
)
@triton.jit
def _merge_kernel(
o_partial_ptr,
lse_partial_ptr,
o_ptr,
lse_partial_stride_b,
lse_partial_stride_h,
lse_partial_stride_split,
o_partial_stride_b,
o_partial_stride_h,
o_partial_stride_split,
o_partial_stride_d,
o_stride_b,
o_stride_h,
o_stride_d,
BLOCK_D: tl.constexpr,
num_splits: tl.constexpr,
num_splits_pow2: tl.constexpr,
):
batch_idx = tl.program_id(0)
head_idx = tl.program_id(1)
offs_splits = tl.arange(0, num_splits_pow2)
offs_d = tl.arange(0, BLOCK_D)
lse_offsets = lse_partial_ptr + batch_idx * lse_partial_stride_b + head_idx * lse_partial_stride_h
lse = tl.load(lse_offsets + offs_splits * lse_partial_stride_split, mask=offs_splits < num_splits, other=float("-inf"))
lse_max = tl.max(lse)
o_offsets = o_partial_ptr + batch_idx * o_partial_stride_b + head_idx * o_partial_stride_h
o_partial = tl.load(
o_offsets + offs_splits[:, None] * o_partial_stride_split + offs_d[None, :] * o_partial_stride_d,
mask=offs_splits[:, None] < num_splits,
)
sumexp_normalized_splitk = tl.exp(lse - lse_max)
sumexp_normalized = tl.sum(sumexp_normalized_splitk, axis=0)
numerator_normalized = tl.sum(o_partial * sumexp_normalized_splitk[:, None], axis=0)
acc = numerator_normalized / sumexp_normalized
acc = acc.to(o_ptr.dtype.element_ty)
o_ptr += batch_idx * o_stride_b + head_idx * o_stride_h
tl.store(o_ptr + offs_d * o_stride_d, acc)
def block_sparse_flash_decode_gqa_mask_triton(
q,
k_cache,
v_cache,
cache_seqlens,
max_cache_seqlen,
block_mask,
block_size,
sm_scale=None,
):
batch, heads, dim = q.shape
if sm_scale is None:
sm_scale = 1 / math.sqrt(dim)
_, max_cache_seqlen_cache, heads_kv, dim_v = v_cache.shape
assert max_cache_seqlen == max_cache_seqlen_cache, "max_cache_seqlen mismatch"
group_size = heads // heads_kv
block_H = 16
max_selected_blocks = (max_cache_seqlen + block_size - 1) // block_size
num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
num_n_blocks = max_selected_blocks
size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2
total_mblocks = batch * heads_kv * num_m_blocks
    num_sm = 64  # assumed SM count; for the actual value use torch.cuda.get_device_properties(0).multi_processor_count
num_splits = num_splits_heuristic(
total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128
)
# print("num_splits:", num_splits, "num_blocks:", num_n_blocks)
num_splits_pow2 = triton.next_power_of_2(num_splits)
o_partial = torch.empty((batch, heads, num_splits, dim_v), device=q.device, dtype=q.dtype)
lse_partial = torch.empty((batch, heads, num_splits), device=q.device, dtype=torch.float32)
BLOCK_D = dim
BLOCK_H = group_size if group_size > 16 else 16
grid = (batch, heads_kv, num_splits)
_split_kernel[grid](
q,
k_cache,
v_cache,
cache_seqlens,
o_partial,
lse_partial,
block_mask,
sm_scale,
num_splits,
group_size,
q.stride(0),
q.stride(1),
q.stride(2),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
v_cache.stride(0),
v_cache.stride(1),
v_cache.stride(2),
v_cache.stride(3),
o_partial.stride(0),
o_partial.stride(1),
o_partial.stride(2),
o_partial.stride(3),
lse_partial.stride(0),
lse_partial.stride(1),
lse_partial.stride(2),
block_mask.stride(0),
block_mask.stride(1),
block_mask.stride(2),
BLOCK_H=BLOCK_H,
BLOCK_N=block_size,
BLOCK_D=BLOCK_D,
)
output = torch.zeros((batch, heads, dim_v), device=q.device, dtype=q.dtype)
grid = (batch, heads)
_merge_kernel[grid](
o_partial,
lse_partial,
output,
lse_partial.stride(0),
lse_partial.stride(1),
lse_partial.stride(2),
o_partial.stride(0),
o_partial.stride(1),
o_partial.stride(2),
o_partial.stride(3),
output.stride(0),
output.stride(1),
output.stride(2),
BLOCK_D=dim_v,
num_splits=num_splits,
num_splits_pow2=num_splits_pow2,
)
return output
def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size):
batch, heads, dim = query.shape
heads_kv = key.shape[2]
num_head_groups = query.shape[1] // key.shape[2]
scale = dim**0.5
key = rearrange(key, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim]
value = rearrange(value, "b n h d -> b h n d") # [batch_size, heads_kv, seqlen_kv, dim]
query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim]
scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, heads_kv, seqlen_kv]
sparse_mask = torch.zeros_like(scores)
# Assign mask values
for b in range(batch):
for h in range(heads_kv):
for idx in range(num_blocks):
if block_mask[b, h, idx]:
sparse_mask[b, :, h, idx * block_size : (idx + 1) * block_size] = 1
scores = scores.masked_fill(sparse_mask == 0, float("-inf"))
range_len = torch.arange(scores.shape[-1], device="cuda").unsqueeze(0)
cache_seqlens_expanded = cache_seqlens.unsqueeze(1)
pad_mask = range_len >= cache_seqlens_expanded
pad_mask = pad_mask[:, None, None, :]
scores = scores.masked_fill(pad_mask, float("-inf"))
attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv]
out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, heads_kv, dim]
out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim]
return out
def ref_program_fa(query, key, value, cache_seqlens):
# latency reference
# from flash_attn_interface import flash_attn_with_kvcache # fa3
from flash_attn import flash_attn_with_kvcache # fa2
query = query.unsqueeze(1)
output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens)
output = output.squeeze(1)
return output
def main(batch=64, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32):
qk_flops = 2 * batch * heads * max_cache_seqlen * dim
pv_flops = 2 * batch * heads * max_cache_seqlen * dim_v
total_flops = qk_flops + pv_flops
dtype = torch.float16
Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda")
K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda")
V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda")
cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda")
    # Ensure at least one sequence uses the full max_cache_seqlen
    random_index = torch.randint(0, batch, (1,), device="cuda").item()  # Select a random index
    cache_seqlens[random_index] = max_cache_seqlen  # Force at least one full-length sequence
num_blocks = (max_cache_seqlen + block_size - 1) // block_size
valid_num_blocks = torch.ceil(cache_seqlens * (1 - sparse_ratio) / block_size).int()
print("valid_num_blocks: ", valid_num_blocks)
max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int()
print("max_valid_num_blocks: ", max_valid_num_blocks)
# Initialize block_mask with false (for padding blocks)
block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device="cuda")
# Assign valid indices while ensuring no duplicates within each batch-group
for b in range(batch):
max_valid_block = max_valid_num_blocks[b].item() # Max valid blocks for this batch
valid_num_block = valid_num_blocks[b].item() # Valid blocks for this batch
if valid_num_block > 0: # Ensure there's at least one valid block
for h in range(heads_kv):
perm = torch.randperm(max_valid_block, device="cuda")[:valid_num_block]
block_mask[b, h, perm] = True
ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size)
triton_out = block_sparse_flash_decode_gqa_mask_triton(
Q,
K,
V,
cache_seqlens,
max_cache_seqlen,
block_mask,
block_size,
)
# print("max difference: ", torch.max(torch.abs(ref - triton_out)))
assert torch.allclose(ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation"
print("Passed the ref test!")
# Measure performance
torch.cuda.synchronize()
start = time.time()
for _ in range(1000):
block_sparse_flash_decode_gqa_mask_triton(
Q,
K,
V,
cache_seqlens,
max_cache_seqlen,
block_mask,
block_size,
)
torch.cuda.synchronize()
end = time.time()
elapsed_time = end - start
avg_time = elapsed_time / 1000
avg_flops = total_flops / avg_time
print(f"Average time: {avg_time:.6f} seconds")
print(f"Average flops: {avg_flops:.2f} GFLOPS")
    import flash_attn  # noqa: F401
    # Warm up the reference once so library load / first-call overhead is not timed
    for _ in range(10):
        ref_program_fa(Q, K, V, cache_seqlens)
    torch.cuda.synchronize()
    start = time.time()
for _ in range(1000):
ref_program_fa(Q, K, V, cache_seqlens)
torch.cuda.synchronize()
end = time.time()
elapsed_time_ref = end - start
avg_time_ref = elapsed_time_ref / 1000
avg_flops_ref = total_flops / avg_time_ref
print(f"Average time of ref: {avg_time_ref:.6f} seconds")
print(f"Average flops of ref: {avg_flops_ref:.2f} GFLOPS")
print(f"Speedup: {avg_time_ref / avg_time:.2f}x")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=64, help="batch size")
parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument("--heads_kv", type=int, default=8, help="heads_kv")
parser.add_argument("--max_cache_seqlen", type=int, default=8192, help="kvcache sequence length")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--dim_v", type=int, default=128, help="dim_v")
parser.add_argument("--sparse_ratio", type=float, default=0.8, help="sparse ratio")
parser.add_argument("--block_size", type=int, default=32, help="block_size")
args = parser.parse_args()
main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, args.sparse_ratio, args.block_size)
import math
def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local, max_splits):
"""
Determines the optimal number of splits for maximizing GPU occupancy while balancing memory efficiency.
Parameters:
- total_mblocks (int): Total number of m_blocks.
- num_SMs (int): Number of Streaming Multiprocessors (SMs) in the GPU.
- num_n_blocks (int): Number of n_blocks.
- num_m_blocks (int): Number of m_blocks.
- size_one_kv_head (int): Size of one KV head in bytes.
- is_causal_or_local (bool): Indicates whether the operation is causal or local.
- max_splits (int): Maximum number of allowed splits.
Returns:
- int: The optimal number of splits.
"""
# If we have enough m_blocks to almost fill the SMs, prefer 1 split unless memory constraints apply.
if total_mblocks >= 0.8 * num_SMs:
size_l2 = 50 * 1024 * 1024 # L2 cache size assumption (50MB)
# Only split if each KV head is too large for L2 and there are enough m_blocks
if size_one_kv_head > size_l2 and num_m_blocks >= num_SMs * 2 and not is_causal_or_local:
return min((size_one_kv_head + size_l2 - 1) // size_l2, max_splits)
else:
return 1
# If num_n_blocks is too small, we don't split
if num_n_blocks <= 4:
return 1
# Limit max_splits to a reasonable range
max_splits = min(max_splits, num_SMs, num_n_blocks)
max_efficiency = 0.0
efficiency = []
# Compute efficiency for different splits
for num_splits in range(1, max_splits + 1):
n_waves = (total_mblocks * num_splits) / num_SMs
eff = n_waves / math.ceil(n_waves)
# Track max efficiency
if eff > max_efficiency:
max_efficiency = eff
efficiency.append(eff)
# Find the smallest number of splits that achieves at least 85% of max efficiency
for num_splits in range(1, max_splits + 1):
if efficiency[num_splits - 1] >= 0.85 * max_efficiency:
return num_splits
return 1
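# Usage sketch (illustrative numbers only, not taken from the examples above): a single
# decode query over a long KV cache yields few m-blocks, so the heuristic splits the KV
# dimension across SMs, whereas a workload that already fills the SMs keeps num_splits == 1.
if __name__ == "__main__":
    print(num_splits_heuristic(total_mblocks=8, num_SMs=64, num_n_blocks=256, num_m_blocks=1,
                               size_one_kv_head=8192 * 256 * 2, is_causal_or_local=True, max_splits=128))
    print(num_splits_heuristic(total_mblocks=512, num_SMs=64, num_n_blocks=256, num_m_blocks=1,
                               size_one_kv_head=8192 * 256 * 2, is_causal_or_local=True, max_splits=128))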
import tilelang.testing
import block_sparse_attn_triton
import example_tilelang_block_sparse_attn
import example_tilelang_sparse_gqa_decode_varlen_indice
import example_tilelang_sparse_gqa_decode_varlen_mask
import example_triton_sparse_gqa_decode_varlen_indice
import example_triton_sparse_gqa_decode_varlen_mask
def test_block_sparse_attn_triton():
block_sparse_attn_triton.main()
def test_example_tilelang_block_sparse_attn():
example_tilelang_block_sparse_attn.main()
def test_example_tilelang_sparse_gqa_decode_varlen_indice():
example_tilelang_sparse_gqa_decode_varlen_indice.main(batch=1, max_cache_seqlen=2048)
def test_example_tilelang_sparse_gqa_decode_varlen_mask():
example_tilelang_sparse_gqa_decode_varlen_mask.main(batch=1, max_cache_seqlen=2048)
def test_example_triton_sparse_gqa_decode_varlen_indice():
example_triton_sparse_gqa_decode_varlen_indice.main(
batch=8, heads=8, heads_kv=4, max_cache_seqlen=2048, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32
)
def test_example_triton_sparse_gqa_decode_varlen_mask():
example_triton_sparse_gqa_decode_varlen_mask.main(
batch=16, heads=16, heads_kv=8, max_cache_seqlen=1024, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32
)
if __name__ == "__main__":
tilelang.testing.main()
import argparse
import itertools
import tilelang
import tilelang.language as T
from tilelang.engine.param import KernelParam
from tilelang.utils.tensor import get_tensor_supply, TensorSupplyType
import torch
from typing import List
DEFAULT_BLOCK_M = 128
DEFAULT_BLOCK_N = 128
DEFAULT_BLOCK_K = 32
DEFAULT_NUM_STAGES = 2
DEFAULT_THREAD_NUM = 128
DEFAULT_ENABLE_RASTERIZATION = True
parser = argparse.ArgumentParser(description="Autotuned BlockSparse MatMul Benchmark")
parser.add_argument("--m", type=int, default=1024, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K")
parser.add_argument("--sparsity", type=float, default=0.5, help="Sparsity ratio (0-1)")
parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune")
args, _ = parser.parse_known_args()
M, N, K = args.m, args.n, args.k
sparsity = args.sparsity
use_autotune = args.use_autotune
default_tensor_supply = get_tensor_supply(TensorSupplyType.Auto)
print(f"Running BlockSparse MatMul Benchmark for M={M}, N={N}, K={K}")
print(f"Target Block Sparsity: {sparsity}")
print(f"Using Autotuner: {use_autotune}\n")
def get_configs():
block_M = [64, 128, 256]
block_N = [64, 128, 256]
block_K = [32, 64]
num_stages = [1, 2, 3]
thread_num = [128, 256]
enable_rasterization = [True, False]
_configs = list(itertools.product(block_M, block_N, block_K, num_stages, thread_num, enable_rasterization))
return [
{
"block_M": c[0],
"block_N": c[1],
"block_K": c[2],
"num_stages": c[3],
"thread_num": c[4],
"enable_rasteration": c[5],
}
for c in _configs
]
def ref_program(A, B, BlockMask, block_M, block_N, block_K):
ref_c = torch.zeros((M, N), dtype=torch.float16, device=A.device)
for i in range(M // block_M):
for j in range(N // block_N):
accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device)
for k in range(K // block_K):
if BlockMask[i, j, k]:
accu += A[i * block_M : (i + 1) * block_M, k * block_K : (k + 1) * block_K].to(torch.float32) @ B[
k * block_K : (k + 1) * block_K, j * block_N : (j + 1) * block_N
].to(torch.float32)
ref_c[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = accu.to(torch.float16)
return ref_c
def supply_program(params: List[KernelParam]):
input_tensors = []
for p in params:
# Check if the kernel parameter is BlockMask tensor.
# Here, BlockMask is uniquely identified by having 3 dimensions.
if len(p.shape) != 3:
# For non-BlockMask tensors, use the default tensor generation logic.
input_tensors.append(default_tensor_supply(p))
else:
# For BlockMask tensor, randomly set elements to True based on desired
# sparsity level.
block_mask = torch.zeros(p.shape, dtype=torch.bool, device=torch.cuda.current_device())
block_mask[:, :, :] = torch.rand(p.shape) > sparsity
input_tensors.append(block_mask)
return input_tensors
@tilelang.autotune(
configs=get_configs(),
)
@tilelang.jit(out_idx=[-1])
def blocksparse_matmul(
M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32
):
block_mask_shape = (M // block_M, N // block_N, K // block_K)
@T.prim_func
def block_sparse_matmul(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
BlockMask: T.Tensor(block_mask_shape, "bool"),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), dtype)
T.use_swizzle(panel_size=10, enable=enable_rasteration)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if BlockMask[by, bx, k]:
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return block_sparse_matmul
def main():
# Initialize input matrices A and B on the GPU with half precision
a = torch.randn(M, K).cuda().half()
b = torch.randn(K, N).cuda().half()
if args.use_autotune:
# Run the autotuner to find the best kernel configuration and performance
# get_best_config is expected to return an object containing the compiled kernel,
# the best configuration found, latency, and reference latency.
kernel = blocksparse_matmul(M, N, K)
best_config = kernel.config
best_latency = kernel.latency
block_M, block_N, block_K = best_config["block_M"], best_config["block_N"], best_config["block_K"]
print(f"Best Config: {best_config}")
print(f"Sparsity Ratio: {sparsity}")
print(f"Best Kernel Latency: {best_latency:.6f} ms")
else:
kernel = blocksparse_matmul(
M,
N,
K,
block_M=DEFAULT_BLOCK_M,
block_N=DEFAULT_BLOCK_N,
block_K=DEFAULT_BLOCK_K,
num_stages=DEFAULT_NUM_STAGES,
thread_num=DEFAULT_THREAD_NUM,
enable_rasteration=DEFAULT_ENABLE_RASTERIZATION,
)
block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K
print(f"Using default kernel with block size ({block_M}, {block_N}, {block_K})")
# Create block mask with desired sparsity
mask_shape = (M // block_M, N // block_N, K // block_K)
block_mask = torch.rand(mask_shape).cuda() > sparsity
# Run the compiled kernel (either tuned or default) with the inputs
c = kernel(a, b, block_mask)
# Compute the reference result using the naive PyTorch implementation
ref_c = ref_program(a, b, block_mask, block_M, block_N, block_K)
try:
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("✅ Results are close! Verification successful.")
except AssertionError as e:
print("❌ Verification FAILED: Results differ significantly.")
print(e)
if __name__ == "__main__":
main()
import tilelang.testing
import example_blocksparse_gemm
def test_example_blocksparse_gemm():
example_blocksparse_gemm.main()
if __name__ == "__main__":
tilelang.testing.main()
import torch
import tilelang
import tilelang.language as T
from typing import Tuple
from tilelang.utils.tensor import torch_assert_close
# supported dtypes: bfloat16, float (float32), float16
dtype = T.bfloat16
accum_dtype = T.float32
@tilelang.jit(out_idx=[2, 3])
def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
group_size = 128
fp8_min = -448.0
fp8_max = 448.0
@T.prim_func
def group_per_split_token_cast(
X: T.Tensor((M, N), dtype),
batch_sizes: T.Tensor((BG,), T.int32),
X_fp8: T.Tensor((BG, M_max, N), T.float8_e4m3fn),
X_amax: T.Tensor((BG, M_max, T.ceildiv(N, group_size)), accum_dtype),
):
with T.Kernel(T.ceildiv(M_max, blk_m), T.ceildiv(N, group_size), BG, threads=128) as (bx, by, bz):
row = bx
row_g_id = by
bg = bz
y_local = T.alloc_fragment((blk_m, group_size), accum_dtype)
y_amax_local = T.alloc_fragment((blk_m,), accum_dtype)
y_s_local = T.alloc_fragment((blk_m,), accum_dtype)
y_q_local = T.alloc_fragment((blk_m, group_size), accum_dtype)
y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), T.float8_e4m3fn)
row_offset = T.alloc_fragment((1,), T.int32)
T.annotate_layout(
{
y_local: T.Fragment(y_local.shape, forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32),
}
)
row_offset[0] = 0
for i in T.serial(bg):
row_offset[0] += batch_sizes[i]
T.copy(
X[row_offset[0] + row * blk_m : row_offset[0] + (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size],
y_local,
)
T.reduce_absmax(y_local, y_amax_local, dim=1)
for i in T.Parallel(blk_m):
y_amax_local[i] = T.max(y_amax_local[i], 1e-4)
y_s_local[i] = T.if_then_else(row * blk_m + i < batch_sizes[bg], y_amax_local[i] / fp8_max, 0)
for i, j in T.Parallel(blk_m, group_size):
y_q_local[i, j] = T.clamp(y_local[i, j] / y_s_local[i], fp8_min, fp8_max)
T.copy(y_q_local, y_q_local_fp8)
for i, j in T.Parallel(blk_m, group_size):
y_q_local_fp8[i, j] = T.if_then_else(row * blk_m + i < batch_sizes[bg], y_q_local[i, j], 0)
for i in T.Parallel(blk_m):
X_amax[bg, row * blk_m + i, row_g_id] = y_s_local[i]
T.copy(y_q_local_fp8, X_fp8[bg, row * blk_m : (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size])
return group_per_split_token_cast
def ceil_div(x: int, y: int) -> int:
"""
Perform ceiling division of two integers.
Args:
x: the dividend.
y: the divisor.
Returns:
The result of the ceiling division.
"""
return (x + y - 1) // y
def get_tma_aligned_size(x: int, element_size: int) -> int:
"""
Global memory address of TMA must be 16-byte aligned.
Since we use column-major layout for the LHS scaling tensor,
the M-axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes.
Arguments:
x: original M-axis shape of the LHS scaling tensor.
element_size: element size of the LHS scaling tensor.
Returns:
M-axis shape of the LHS scaling tensor after padding.
"""
tma_alignment_bytes = 16
assert tma_alignment_bytes % element_size == 0
alignment = tma_alignment_bytes // element_size
return ceil_div(x, alignment) * alignment
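# Quick sanity check (illustrative): float32 scales give element_size == 4, so the row
# alignment is 16 // 4 = 4; an M of 130 pads up to 132 while an aligned 128 stays unchanged.
assert get_tma_aligned_size(130, 4) == 132
assert get_tma_aligned_size(128, 4) == 128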
def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
"""
Returns TMA-aligned transposed format of the input tensor. `torch.transpose` will be called if necessary.
If the input tensor is already column-major layout and 16-byte aligned along the M axis
(thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing.
Arguments:
x: usually the LHS scaling tensor in GEMM.
Returns:
The LHS scaling tensor of TMA-aligned transposed format.
"""
    # NOTE: for maximum performance, you may want to rewrite/fuse this function in CUDA
assert x.dim() in (2, 3)
remove_dim = False
m, n = x.shape[-2], x.shape[-1]
aligned_m = get_tma_aligned_size(m, x.element_size())
if x.dim() == 2:
if x.stride(0) == 1 and x.stride(1) == aligned_m:
return x
x, remove_dim = x.unsqueeze(0), True
b = x.shape[0]
# The last kernel gives a column-major TMA aligned layout
if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m:
return x.squeeze(0) if remove_dim else x
# Normal layout requires transposing
aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
aligned_x[:, :m, :] = x
aligned_x = aligned_x[:, :m, :]
return aligned_x.squeeze(0) if remove_dim else aligned_x
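# For example (assuming float32 scales), a (130, 4) input comes back with shape (130, 4)
# but strides (1, 132): column-major with the M axis padded to 132 rows so the 16-byte
# TMA alignment requirement is satisfied.
#     s = get_col_major_tma_aligned_tensor(torch.randn(130, 4))
#     assert s.stride() == (1, 132)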
def ref_per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # this function does not support CPU tensors
assert x.dim() == 2
m, n = x.shape
new_n = ceil_div(n, 128) * 128
x_padded = torch.nn.functional.pad(x, (0, new_n - n))
x_view = x_padded.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
x_fp8 = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
x_fp8 = x_fp8.view(m, -1)[:, :n].contiguous()
return x_fp8, (x_amax / 448.0).view(m, -1)
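# Round-trip sketch (assumes a CUDA device with float8_e4m3fn support): dequantizing with
# the per-128-column scales should approximately reconstruct the input.
#     x = torch.randn(4, 256, device="cuda")
#     x_q, x_s = ref_per_token_cast_to_fp8(x)
#     x_dq = x_q.float().view(4, -1, 128) * x_s.unsqueeze(2)
#     assert torch.allclose(x_dq.view(4, 256), x, rtol=0.1, atol=0.1)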
def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# assert x.shape[0] == batch_sizes.sum()
M_max = ceil_div(batch_sizes.max(), 128) * 128
split_x = torch.split(x, batch_sizes.tolist(), dim=0)
padded_x = [torch.nn.functional.pad(t, (0, 0, 0, M_max - t.shape[0])) for t in split_x]
num_groups, m, n = batch_sizes.shape[0], M_max, x.shape[1]
x_fp8 = (
torch.empty((num_groups, m, n), device="cuda", dtype=torch.float8_e4m3fn),
torch.empty((num_groups, m, n // 128), device="cuda", dtype=torch.float),
)
for i in range(num_groups):
x_fp8[0][i], x_fp8[1][i] = ref_per_token_cast_to_fp8(padded_x[i])
x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
return x_fp8
def main(M=8192, N=8192, BG=2, blk_m=8, batch_sizes=None):
if batch_sizes is None:
batch_sizes = [2048, 6144]
if dtype == T.float:
x = torch.randn(M, N, device="cuda", dtype=torch.float32)
elif dtype == T.float16:
x = torch.randn(M, N, device="cuda", dtype=torch.float16)
elif dtype == T.bfloat16:
x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)
else:
raise ValueError(f"Unsupported dtype: {dtype}")
batch_sizes = torch.tensor(batch_sizes, device="cuda", dtype=torch.int32)
M_max = int(ceil_div(batch_sizes.max(), 128) * 128)
print("batch_sizes:", batch_sizes)
print("M_max:", M_max)
kernel = group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m)
print(kernel.get_kernel_source())
# profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
x_fp8, x_amax = kernel(x, batch_sizes)
x_fp8_ref, x_amax_ref = ref_program(x, batch_sizes)
torch_assert_close(x_fp8.to(torch.float32), x_fp8_ref.to(torch.float32), rtol=0.01, atol=0.01)
torch_assert_close(x_amax, x_amax_ref, rtol=0.01, atol=0.01)
print("All checks pass.")
from tilelang.profiler import do_bench
def run_tilelang():
x_fp8_tilelang_, x_amax_tilelang_ = kernel(x, batch_sizes)
return x_fp8_tilelang_, x_amax_tilelang_
def run_torch():
x_fp8_torch_, x_amax_torch_ = ref_program(x, batch_sizes)
return x_fp8_torch_, x_amax_torch_
latency = do_bench(run_tilelang)
print("Tile-lang: {:.2f} ms".format(latency))
latency = do_bench(run_torch)
print("Torch: {:.2f} ms".format(latency))
if __name__ == "__main__":
main()