Commit bc2d5632 authored by root

init
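# usage: bash <this convert script> MODEL_DIR [SAVED_MODEL_DIR]
#   MODEL_DIR        directory holding the original BitNet checkpoint to convert
#   SAVED_MODEL_DIR  optional output directory for the converted BitBLAS checkpoint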
# retrieve the native model input and saved model directory
MODEL_DIR=$1
SAVED_MODEL_DIR=$2
# check if the model directory exists
if [ ! -d "$MODEL_DIR" ]; then
echo "Model directory does not exist!"
exit 1
fi
# if the saved model directory does not exist, create it
# if SAVED_MODEL_DIR is not provided, we do not pass it to the script
if [ -z "$SAVED_MODEL_DIR" ]; then
python ./maint/create_bitblas_ckpt.py --model_name_or_path $MODEL_DIR
else
if [ ! -d "$SAVED_MODEL_DIR" ]; then
mkdir -p $SAVED_MODEL_DIR
fi
python ./maint/create_bitblas_ckpt.py --model_name_or_path $MODEL_DIR --saved_model_path $SAVED_MODEL_DIR
fi
# get the realpath of the saved model directory
SAVED_MODEL_DIR=$(realpath $SAVED_MODEL_DIR)
# cp files
cp $MODEL_DIR/quantize_config.json $SAVED_MODEL_DIR/
cp $MODEL_DIR/tokenizer.json $SAVED_MODEL_DIR/
cp $MODEL_DIR/tokenizer.model $SAVED_MODEL_DIR/
cp $MODEL_DIR/tokenizer_config.json $SAVED_MODEL_DIR/
echo "Model has been converted and save to $SAVED_MODEL_DIR"
# require git lfs
if ! command -v git-lfs &> /dev/null; then
echo "Please install git-lfs first by running 'sudo apt install git-lfs'"
exit 1
fi
mkdir -p models
cd models
# download the model
git clone https://huggingface.co/1bitLLM/bitnet_b1_58-3B ckpt_bitnet_b1_58-3B --depth 1
# copy quantized config into the model directory
cp ../maint/quantize_config.json ckpt_bitnet_b1_58-3B
# copy README.md into the model directory
cp ../maint/README.md ckpt_bitnet_b1_58-3B
# get the realpath of the model directory
MODEL_DIR=$(realpath ckpt_bitnet_b1_58-3B)
cd ..
echo "Model has been converted and save to $MODEL_DIR"
{
"bits": 2,
"desc_act": false,
"static_groups": false,
"sym": true,
"lm_head": false,
"model_name_or_path": "1bitLLM/bitnet_b1_58-3B",
"quant_method": "bitnet",
"checkpoint_format": "bitnet"
}
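# usage: bash <this upload script> MODEL_DIR REMOTE_DIR
#   MODEL_DIR   local checkpoint directory to publish
#   REMOTE_DIR  git remote URL of the target Hugging Face model repository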
MODEL_DIR=$1
REMOTE_DIR=$2
if [ ! -d "$MODEL_DIR" ]; then
echo "Model directory does not exist!"
exit 1
fi
cd "$MODEL_DIR"
# remove any existing git history so we can re-initialize a clean repository
if [ -d ".git" ]; then
rm -rf .git
fi
git init
git checkout -b main
git lfs install
git lfs track "*.bin"
git lfs track "*.safetensors"
git add .
git commit -m "Initial commit"
git remote add origin "$REMOTE_DIR"
huggingface-cli lfs-enable-largefiles .
git fetch origin
git push -f --set-upstream origin main
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch LLaMA model."""
import math
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
from configuration_bitnet import BitnetConfig
from utils_quant import BitLinear, BitLinearBitBLAS
from transformers.utils.hub import cached_file
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa: F401
def find_layers(module, layers=None, name=""):
if not layers:
layers = [nn.Linear]
for layer in layers:
if isinstance(module, layer):
return {name: module}
res = {}
for name1, child in module.named_children():
res.update(
find_layers(child, layers=layers, name=name + "." + name1 if name != "" else name1))
return res
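# Illustrative example: for a BitnetModel `m`, find_layers(m, layers=[BitLinear]) returns a dict
# mapping dotted module paths to the matching sub-modules, e.g.
# {"layers.0.self_attn.q_proj": BitLinear(...), "layers.0.mlp.gate_proj": BitLinear(...), ...}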
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "BitnetConfig"
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
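# Worked example: for attention_mask = [[1, 1, 0], [1, 1, 1]] this returns
#   indices = [0, 1, 3, 4, 5]  (flat positions of the non-padding tokens),
#   cu_seqlens = [0, 2, 5]     (cumulative sequence lengths, prefixed with 0),
#   max_seqlen_in_batch = 3.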
class BitnetRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
BitnetRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
ALL_LAYERNORM_LAYERS.append(BitnetRMSNorm)
class BitnetRotaryEmbedding(nn.Module):
def __init__(self,
dim,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0):
super().__init__()
self.scaling_factor = scaling_factor
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (
self.base
**(torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq)
# For BC we register cos and sin cached
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(
self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
t = t / self.scaling_factor
freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer(
"_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
self.register_buffer(
"_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)
@property
def sin_cached(self):
logger.warning_once(
"The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
"the forward method of RoPE from now on instead. It is not used in the `BitnetAttention` class"
)
return self._sin_cached
@property
def cos_cached(self):
logger.warning_once(
"The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
"the forward method of RoPE from now on instead. It is not used in the `BitnetAttention` class"
)
return self._cos_cached
@torch.no_grad()
def forward(self, x, position_ids):
# x: [bs, num_attention_heads, seq_len, head_size]
inv_freq_expanded = self.inv_freq[None, :,
None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type,
str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
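# Shape example: with q, k of shape [batch, heads, seq_len, head_dim] and cos, sin of shape
# [batch, seq_len, head_dim] (as produced by BitnetRotaryEmbedding.forward), unsqueeze_dim=1
# turns cos/sin into [batch, 1, seq_len, head_dim] so they broadcast over the heads dimension.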
class BitnetMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = BitLinear(
self.hidden_size,
self.intermediate_size,
bias=False,
weight_bits=config.weight_bits,
input_bits=config.input_bits,
)
self.up_proj = BitLinear(
self.hidden_size,
self.intermediate_size,
bias=False,
weight_bits=config.weight_bits,
input_bits=config.input_bits,
)
self.down_proj = BitLinear(
self.intermediate_size,
self.hidden_size,
bias=False,
weight_bits=config.weight_bits,
input_bits=config.input_bits,
)
self.act_fn = ACT2FN[config.hidden_act]
self.ffn_layernorm = BitnetRMSNorm(self.intermediate_size, eps=config.rms_norm_eps)
def forward(self, x):
x = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
x = self.ffn_layernorm(x)
x = self.down_proj(x)
return x
class BitnetMLPFuseGateUp(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_up_proj = BitLinear(
self.hidden_size,
self.intermediate_size * 2,
bias=False,
weight_bits=config.weight_bits,
input_bits=config.input_bits,
)
self.down_proj = BitLinear(
self.intermediate_size,
self.hidden_size,
bias=False,
weight_bits=config.weight_bits,
input_bits=config.input_bits,
)
self.act_fn = ACT2FN[config.hidden_act]
self.ffn_layernorm = BitnetRMSNorm(self.intermediate_size, eps=config.rms_norm_eps)
@classmethod
def from_bit_mlp(cls, bit_mlp: BitnetMLP):
module = cls(bit_mlp.config)
# assign the weights
module.gate_up_proj.weight = nn.Parameter(
torch.cat([bit_mlp.gate_proj.weight, bit_mlp.up_proj.weight], dim=0))
module.down_proj = bit_mlp.down_proj
module.ffn_layernorm = bit_mlp.ffn_layernorm
return module
def forward(self, x):
gate_up = self.gate_up_proj(x)
gate, up = torch.chunk(gate_up, chunks=2, dim=-1)
x = self.act_fn(gate) * up
x = self.ffn_layernorm(x)
x = self.down_proj(x)
return x
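# Note on the fusion above: gate_up_proj concatenates the gate and up weights along the output
# dimension, so a single BitLinear matmul produces [gate | up]; forward() splits the result with
# torch.chunk before applying the activation, which matches BitnetMLP.forward numerically.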
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen,
head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
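# Shape example: repeat_kv on a tensor of shape (2, 8, 16, 128) with n_rep=4 yields
# (2, 32, 16, 128); each of the 8 key/value heads is repeated 4 times so the grouped-query
# key/value states line up with 32 query heads.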
class BitnetAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: BitnetConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class.")
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads}).")
self.q_proj = BitLinear(
self.hidden_size,
self.num_heads * self.head_dim,
bias=config.attention_bias,
weight_bits=config.weight_bits,
input_bits=config.input_bits,
)
self.k_proj = BitLinear(
self.hidden_size,
self.num_key_value_heads * self.head_dim,
bias=config.attention_bias,
weight_bits=config.weight_bits,
input_bits=config.input_bits,
)
self.v_proj = BitLinear(
self.hidden_size,
self.num_key_value_heads * self.head_dim,
bias=config.attention_bias,
weight_bits=config.weight_bits,
input_bits=config.input_bits,
)
self.o_proj = BitLinear(
self.hidden_size,
self.hidden_size,
bias=config.attention_bias,
weight_bits=config.weight_bits,
input_bits=config.input_bits,
)
self._init_rope()
self.inner_attn_ln = BitnetRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = BitnetRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
else:
raise NotImplementedError
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
self.head_dim).transpose(1, 2)
past_key_value = getattr(self, "past_key_value", past_key_value)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
self.head_dim)
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, :key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(
attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}")
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.inner_attn_ln(attn_output)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class BitnetAttentionQKVFused(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: BitnetConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class.")
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads}).")
self.qkv_proj = BitLinear(
self.hidden_size,
self.num_heads * self.head_dim + (self.num_key_value_heads * self.head_dim) * 2,
bias=config.attention_bias,
weight_bits=config.weight_bits,
input_bits=config.input_bits,
)
self.o_proj = BitLinear(
self.hidden_size,
self.hidden_size,
bias=config.attention_bias,
weight_bits=config.weight_bits,
input_bits=config.input_bits,
)
self._init_rope()
self.inner_attn_ln = BitnetRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = BitnetRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
else:
raise NotImplementedError
@classmethod
def from_bit_attention(cls, bit_attention: BitnetAttention):
module = cls(bit_attention.config, bit_attention.layer_idx)
# assign the weights
module.qkv_proj.weight = nn.Parameter(
torch.cat([
bit_attention.q_proj.weight, bit_attention.k_proj.weight,
bit_attention.v_proj.weight
],
dim=0))
if bit_attention.q_proj.bias is not None and bit_attention.k_proj.bias is not None and bit_attention.v_proj.bias is not None:
module.qkv_proj.bias = nn.Parameter(
torch.cat([
bit_attention.q_proj.bias, bit_attention.k_proj.bias, bit_attention.v_proj.bias
],
dim=0))
module.o_proj = bit_attention.o_proj
module.inner_attn_ln = bit_attention.inner_attn_ln
if bit_attention.config.rope_scaling is None:
module.rotary_emb = bit_attention.rotary_emb
return module
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
qkv_states = self.qkv_proj(hidden_states)
query_states, key_states, value_states = torch.split(
qkv_states, [
self.num_heads * self.head_dim, self.num_key_value_heads * self.head_dim,
self.num_key_value_heads * self.head_dim
],
dim=-1)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
self.head_dim).transpose(1, 2)
past_key_value = getattr(self, "past_key_value", past_key_value)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
self.head_dim)
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, :key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(
attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}")
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.inner_attn_ln(attn_output)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class BitnetFlashAttention2(BitnetAttention):
"""
Bitnet flash attention module. This module inherits from `BitnetAttention` as the weights of the module stay
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash Attention expects inputs of shape
# batch_size x seq_length x num_heads x head_dim,
# but RoPE and the KV cache work on [batch, heads, seq, head_dim], so we transpose here and back below.
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
past_key_value = getattr(self, "past_key_value", past_key_value)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, cache_kwargs)
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.attention_dropout if self.training else 0.0
# In PEFT, the layer norms are usually cast to float32 for training stability,
# so the input hidden states may get silently upcast to float32. Hence, we need to
# cast them back to the correct dtype just to be sure everything works as expected.
# This might slow down training & inference, so it is recommended not to cast the LayerNorms
# to fp32. (BitnetRMSNorm handles it correctly.)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
logger.warning_once(
f"The input hidden states seem to have been silently cast to float32; this might be related to"
f" the fact that you have upcast embedding or layer norm layers to float32. We will cast the input back to"
f" {target_dtype}.")
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.inner_attn_ln(attn_output)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
def _flash_attention_forward(self,
query_states,
key_states,
value_states,
attention_mask,
query_length,
dropout=0.0,
softmax_scale=None):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
we first unpad the input, then compute the attention scores, and finally re-pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
"""
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in BitnetFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states,
key_states,
value_states,
dropout,
softmax_scale=softmax_scale,
causal=causal)
return attn_output
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32,
device=query_layer.device) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
LLAMA_ATTENTION_CLASSES = {
"eager": BitnetAttention,
"flash_attention_2": BitnetFlashAttention2,
}
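# Selection sketch (standard transformers behavior, stated here as an assumption):
# config._attn_implementation is set at load time, e.g.
#   model = BitnetForCausalLM.from_pretrained(path, attn_implementation="flash_attention_2")
# and BitnetDecoderLayer below looks the attention class up in this mapping ("eager" otherwise).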
class BitnetDecoderLayer(nn.Module):
def __init__(self, config: BitnetConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](
config=config, layer_idx=layer_idx)
self.mlp = BitnetMLP(config)
self.input_layernorm = BitnetRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = BitnetRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*):
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
query_sequence_length, key_sequence_length)` if default attention is used.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`",
stacklevel=2)
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
LLAMA_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
etc.).
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
and behavior.
Parameters:
config ([`BitnetConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
LLAMA_START_DOCSTRING,
)
class BitnetPreTrainedModel(PreTrainedModel):
config_class = BitnetConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["BitnetDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
_supports_sdpa = False
_supports_cache_class = True
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None):
if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
raise ValueError(
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
)
for layer in self.model.layers:
device = layer.input_layernorm.weight.device
if hasattr(self.config, "_pre_quantization_dtype"):
dtype = self.config._pre_quantization_dtype
else:
dtype = layer.self_attn.o_proj.weight.dtype
layer.self_attn.past_key_value = cache_cls(
self.config, max_batch_size, max_cache_len, device=device, dtype=dtype)
def _reset_cache(self):
for layer in self.model.layers:
layer.self_attn.past_key_value = None
LLAMA_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`BitnetTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`BitnetTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
- a [`~cache_utils.Cache`] instance;
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
LLAMA_START_DOCSTRING,
)
class BitnetModel(BitnetPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`BitnetDecoderLayer`]
Args:
config: BitnetConfig
"""
def __init__(self, config: BitnetConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList([
BitnetDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)
])
self.norm = BitnetRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
past_seen_tokens = 0
if use_cache and not isinstance(past_key_values, StaticCache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()
if cache_position is None:
if isinstance(past_key_values, StaticCache):
raise ValueError("cache_position is a required argument when using StaticCache.")
cache_position = torch.arange(
past_seen_tokens,
past_seen_tokens + inputs_embeds.shape[1],
device=inputs_embeds.device)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
# embed positions
hidden_states = inputs_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = None
if use_cache:
next_cache = (
next_decoder_cache.to_legacy_cache()
if isinstance(next_decoder_cache, Cache) else next_decoder_cache)
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if hasattr(self.layers[0].self_attn, "past_key_value"): # static cache
target_length = self.config.max_position_embeddings
else: # dynamic cache
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1)
causal_mask = torch.full((sequence_length, target_length),
fill_value=min_dtype,
dtype=dtype,
device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.dim() == 2:
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[..., :mask_length].eq(
0.0) * attention_mask[:, None, None, :].eq(0.0)
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(
padding_mask, min_dtype)
elif attention_mask.dim() == 4:
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
# cache. In that case, the 4D attention mask attends to the newest tokens only.
if attention_mask.shape[-2] < cache_position[0] + sequence_length:
offset = cache_position[0]
else:
offset = 0
mask_shape = attention_mask.shape
mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
causal_mask[:mask_shape[0], :mask_shape[1],
offset:mask_shape[2] + offset, :mask_shape[3]] = mask_slice
return causal_mask
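# Worked example: with sequence_length = 3, target_length = 3 and no padding, the float mask
# built above is (min_dtype = torch.finfo(dtype).min):
#   [[0, min_dtype, min_dtype],
#    [0, 0,         min_dtype],
#    [0, 0,         0        ]]
# expanded to [batch, 1, 3, 3]; padding positions from `attention_mask` are additionally filled
# with min_dtype.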
class BitnetForCausalLM(BitnetPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = BitnetModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.quantized = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer
>>> from modeling_bitnet import BitnetForCausalLM
>>> model = BitnetForCausalLM.from_pretrained("1bitLLM/bitnet_b1_58-3B")
>>> tokenizer = AutoTokenizer.from_pretrained("1bitLLM/bitnet_b1_58-3B")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
**kwargs):
# With static cache, the `past_key_values` is None
# TODO joao: standardize interface for the different Cache classes and remove of this if
has_static_cache = False
if past_key_values is None:
past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None)
has_static_cache = past_key_values is not None
past_length = 0
if past_key_values is not None:
if isinstance(past_key_values, Cache):
past_length = cache_position[
0] if cache_position is not None else past_key_values.get_seq_length()
max_cache_length = (
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
if past_key_values.get_max_length() is not None else None)
cache_length = past_length if max_cache_length is None else torch.min(
max_cache_length, past_length)
# TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
# input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
if (max_cache_length is not None and attention_mask is not None and
cache_length + input_ids.shape[1] > max_cache_length):
attention_mask = attention_mask[:, -max_cache_length:]
position_ids = kwargs.get("position_ids")
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1]:]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
# recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
# TODO: use `next_tokens` directly instead.
model_inputs = {"input_ids": input_ids.contiguous()}
input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
if cache_position is None:
cache_position = torch.arange(
past_length, past_length + input_length, device=input_ids.device)
else:
cache_position = cache_position[-input_length:]
if has_static_cache:
past_key_values = None
model_inputs.update({
"position_ids": position_ids,
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
})
return model_inputs
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (tuple(
past_state.index_select(0, beam_idx.to(past_state.device))
for past_state in layer_past),)
return reordered_past
@staticmethod
def recursive_set(model, name, attr):
'''
Set a nested sub-module by its dotted name, e.g. recursive_set(model, "layers.25.mlp.up_proj", attr).
'''
names = name.split('.')
obj = model
for n in names[:-1]:
obj = getattr(obj, n)
setattr(obj, names[-1], attr)
def quantize(self, fuse_qkv=True, fuse_gateup=True):
for name, module in self.model.named_modules():
# if is bitnet layer
if fuse_qkv and isinstance(module, BitnetAttention):
# create quantized version of the layer
print("Replacing BitnetAttention", name)
bitnet_attention_qkv_fused = BitnetAttentionQKVFused.from_bit_attention(module)
self.recursive_set(self.model, name, bitnet_attention_qkv_fused)
if fuse_gateup and isinstance(module, BitnetMLP):
# create quantized version of the layer
print("Replacing BitnetMLP", name)
bitnet_mlp_fused = BitnetMLPFuseGateUp.from_bit_mlp(module)
self.recursive_set(self.model, name, bitnet_mlp_fused)
for name, module in self.model.named_modules():
# if is bitnet layer
if isinstance(module, BitLinear):
# create quantized version of the layer
print("Quantizing module", name)
if name.endswith(".qkv_proj"):
bitblas_linear = BitLinearBitBLAS.from_bit_linear(module, weight_group=3)
elif name.endswith(".gate_up_proj"):
bitblas_linear = BitLinearBitBLAS.from_bit_linear(module, weight_group=2)
else:
bitblas_linear = BitLinearBitBLAS.from_bit_linear(module)
print("Replacing module", name, "with a quantized version")
self.recursive_set(self.model, name, bitblas_linear)
self.quantized = True
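# Offline conversion sketch (an assumption about how ./maint/create_bitblas_ckpt.py uses this API):
#   model = BitnetForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float16)
#   model.quantize()               # fuse qkv / gate_up and swap BitLinear -> BitLinearBitBLAS
#   model._post_process_weights()  # finalize the packed BitBLAS weights before saving
#   model.save_pretrained(saved_model_dir)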
def _post_process_weights(self):
for name, module in self.model.named_modules():
if hasattr(module, "post_process_weights"):
print("Post processing weights for module", name)
module.post_process_weights()
def _replace_weight_param_with_qweight(self):
for name, module in self.model.named_modules():
if hasattr(module, "replace_weight_param_with_qweight"):
print("Replacing weight param with qweight for module", name)
module.replace_weight_param_with_qweight()
@classmethod
def from_quantized(
cls,
model_name_or_path: Optional[str],
trust_remote_code: bool = False,
**kwargs,
):
"""load quantized model from local disk"""
# Parameters related to loading from Hugging Face Hub
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)
cached_file_kwargs = {
"cache_dir": cache_dir,
"force_download": force_download,
"proxies": proxies,
"resume_download": resume_download,
"local_files_only": local_files_only,
"use_auth_token": use_auth_token,
"revision": revision,
"subfolder": subfolder,
"_raise_exceptions_for_missing_entries": False,
"_commit_hash": commit_hash,
}
# == step1: prepare configs and file names == #
config: BitnetConfig = BitnetConfig.from_pretrained(
model_name_or_path,
trust_remote_code=trust_remote_code,
**cached_file_kwargs,
)
# load quantize config
quantize_file = cached_file(model_name_or_path, "quantize_config.json")
assert quantize_file is not None, "quantize config file not found"
import json
# get quantize format
with open(quantize_file, "r") as f:
quant_config = json.load(f)
checkpoint_format = quant_config["checkpoint_format"]
assert checkpoint_format in ["bitblas"], "quantize format not supported"
fuse_qkv = quant_config.get("fuse_qkv", True)
fuse_gateup = quant_config.get("fuse_gateup", True)
import accelerate
if checkpoint_format == "bitblas":
model = cls(config)
for name, module in model.named_modules():
# if is bitnet layer
if fuse_qkv and isinstance(module, BitnetAttention):
# fuse the q/k/v projections to match the fused checkpoint layout
print("Replacing BitnetAttention", name)
bitnet_attention_qkv_fused = BitnetAttentionQKVFused.from_bit_attention(module)
model.recursive_set(model, name, bitnet_attention_qkv_fused)
if fuse_gateup and isinstance(module, BitnetMLP):
# fuse the gate and up projections to match the fused checkpoint layout
print("Replacing BitnetMLP", name)
bitnet_mlp_fused = BitnetMLPFuseGateUp.from_bit_mlp(module)
model.recursive_set(model, name, bitnet_mlp_fused)
for name, module in model.named_modules():
if isinstance(module, BitLinear):
# create quantized version of the layer
print("Quantizing module", name)
bitblas_linear = BitLinearBitBLAS.from_bit_linear(module)
print("Replacing module", name, "with a quantized version")
model.recursive_set(model, name, bitblas_linear)
accelerate.utils.modeling.load_checkpoint_in_model(
model,
checkpoint=model_name_or_path,
offload_state_dict=True,
offload_buffers=True,
)
return model
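# ---------------------------------------------------------------------------
# Illustrative sketch (not used by the library): the intended round trip for the
# two methods above.  `quantize()` fuses the q/k/v and gate/up projections and
# swaps every BitLinear for a BitBLAS-backed module; `from_quantized()` rebuilds
# the same module layout and loads the converted checkpoint into it.
# Assumptions: the enclosing class is exported as `BitnetForCausalLM`, and the
# paths below are placeholders.
def _example_quantize_roundtrip():
    native_ckpt = "path/to/native_bitnet_ckpt"    # hypothetical path
    bitblas_ckpt = "path/to/bitblas_bitnet_ckpt"  # hypothetical path
    # 1) load the fp16 model, quantize it in place, and save the BitBLAS checkpoint
    model = BitnetForCausalLM.from_pretrained(native_ckpt, torch_dtype=torch.float16)
    model.quantize(fuse_qkv=True, fuse_gateup=True)
    model.save_pretrained(bitblas_ckpt)
    # 2) later, reload the quantized checkpoint directly; quantize_config.json in the
    #    saved directory must declare "checkpoint_format": "bitblas"
    return BitnetForCausalLM.from_quantized(bitblas_ckpt)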
@add_start_docstrings(
"""
The LLaMa Model transformer with a sequence classification head on top (linear layer).
[`BitnetForSequenceClassification`] uses the last token in order to do the classification, as other causal models
(e.g. GPT-2) do.
Since it does classification on the last token, it needs to know the position of the last token. If a
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
each row of the batch).
""",
LLAMA_START_DOCSTRING,
)
class BitnetForSequenceClassification(BitnetPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = BitnetModel(config)
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
if input_ids is not None:
batch_size = input_ids.shape[0]
else:
batch_size = inputs_embeds.shape[0]
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids,
self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
loss = None
if labels is not None:
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or
labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@add_start_docstrings(
"""
The Bitnet Model transformer with a span classification head on top for extractive question-answering tasks like
SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
LLAMA_START_DOCSTRING,
)
class BitnetForQuestionAnswering(BitnetPreTrainedModel):
base_model_prefix = "transformer"
# Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Bitnet
def __init__(self, config):
super().__init__(config)
self.transformer = BitnetModel(config)
self.qa_outputs = nn.Linear(config.hidden_size, 2)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.transformer.embed_tokens
def set_input_embeddings(self, value):
self.transformer.embed_tokens = value
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labeled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
are not taken into account for computing the loss.
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the end of the labeled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1).to(start_logits.device)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1).to(end_logits.device)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
if not return_dict:
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
nvidia-smi --query-gpu=memory.used --format=csv -lms 500
lm_eval==0.3.0
flash_attn
transformers==4.53.0
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for LLaMA."""
import os
from shutil import copyfile
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import sentencepiece as spm
from transformers.convert_slow_tokenizer import import_protobuf
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
from transformers.utils import logging
if TYPE_CHECKING:
from transformers.tokenization_utils_base import TextInput
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"hf-internal-testing/llama-tokenizer":
"https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model",
},
"tokenizer_file": {
"hf-internal-testing/llama-tokenizer":
"https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json",
},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"hf-internal-testing/llama-tokenizer": 2048,
}
SPIECE_UNDERLINE = "▁"
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
# fmt: off
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
correct. If you don't know the answer to a question, please don't share false information."""
# fmt: on
class BitnetTokenizer(PreTrainedTokenizer):
"""
Construct a Bitnet tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is
no padding token in the original model.
Args:
vocab_file (`str`):
Path to the vocabulary file.
unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<s>"`):
The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"</s>"`):
The end of sequence token.
pad_token (`str` or `tokenizers.AddedToken`, *optional*):
A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by
attention mechanisms or loss computation.
sp_model_kwargs (`Dict[str, Any]`, `Optional`, *optional*):
Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
to set:
- `enable_sampling`: Enable subword regularization.
- `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
- `nbest_size = {0,1}`: No sampling is performed.
- `nbest_size > 1`: samples from the nbest_size results.
- `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
using forward-filtering-and-backward-sampling algorithm.
- `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
BPE-dropout.
add_bos_token (`bool`, *optional*, defaults to `True`):
Whether or not to add a `bos_token` at the start of sequences.
add_eos_token (`bool`, *optional*, defaults to `False`):
Whether or not to add an `eos_token` at the end of sequences.
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
Whether or not to clean up spaces after decoding; cleanup consists of removing potential artifacts like
extra spaces.
use_default_system_prompt (`bool`, *optional*, defaults to `False`):
Whether or not the default system prompt for Bitnet should be used.
spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not to add spaces between special tokens.
legacy (`bool`, *optional*):
Whether or not the `legacy` behavior of the tokenizer should be used. Legacy is before the merge of #24622
and #25224 which includes fixes to properly handle tokens that appear after special tokens. A simple
example:
- `legacy=True`:
```python
>>> from transformers import T5Tokenizer
>>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=True)
>>> tokenizer.encode("Hello <extra_id_0>.")
[8774, 32099, 3, 5, 1]
```
- `legacy=False`:
```python
>>> from transformers import T5Tokenizer
>>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=False)
>>> tokenizer.encode("Hello <extra_id_0>.") # the extra space `[3]` is no longer here
[8774, 32099, 5, 1]
```
Check out the [pull request](https://github.com/huggingface/transformers/pull/24565) for more details.
add_prefix_space (`bool`, *optional*, defaults to `True`):
Whether or not to add an initial space to the input. This allows treating the leading word just like any
other word.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names = ["input_ids", "attention_mask"]
def __init__(
self,
vocab_file,
unk_token="<unk>",
bos_token="<s>",
eos_token="</s>",
pad_token=None,
sp_model_kwargs: Optional[Dict[str, Any]] = None,
add_bos_token=True,
add_eos_token=False,
clean_up_tokenization_spaces=False,
use_default_system_prompt=False,
spaces_between_special_tokens=False,
legacy=None,
add_prefix_space=True,
**kwargs,
):
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
bos_token = AddedToken(
bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
eos_token = AddedToken(
eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
unk_token = AddedToken(
unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
pad_token = AddedToken(
pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token
if legacy is None:
logger.warning_once(
f"You are using the default legacy behavior of the {self.__class__}. This is"
" expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you."
" If you want to use the new behavior, set `legacy=False`. This should only be set if you understand what it"
" means, and thoroughly read the reason why this was added as explained in"
" https://github.com/huggingface/transformers/pull/24565")
legacy = True
self.legacy = legacy
self.vocab_file = vocab_file
self.add_bos_token = add_bos_token
self.add_eos_token = add_eos_token
self.use_default_system_prompt = use_default_system_prompt
self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False))
self.add_prefix_space = add_prefix_space
super().__init__(
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
pad_token=pad_token,
add_bos_token=add_bos_token,
add_eos_token=add_eos_token,
sp_model_kwargs=self.sp_model_kwargs,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
use_default_system_prompt=use_default_system_prompt,
spaces_between_special_tokens=spaces_between_special_tokens,
legacy=legacy,
add_prefix_space=add_prefix_space,
**kwargs,
)
@property
def unk_token_length(self):
return len(self.sp_model.encode(str(self.unk_token)))
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor
def get_spm_processor(self, from_slow=False):
tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
if self.legacy or from_slow: # no dependency on protobuf
tokenizer.Load(self.vocab_file)
return tokenizer
with open(self.vocab_file, "rb") as f:
sp_model = f.read()
model_pb2 = import_protobuf(
f"The new behavior of {self.__class__.__name__} (with `self.legacy = False`)")
model = model_pb2.ModelProto.FromString(sp_model)
normalizer_spec = model_pb2.NormalizerSpec()
normalizer_spec.add_dummy_prefix = False
model.normalizer_spec.MergeFrom(normalizer_spec)
sp_model = model.SerializeToString()
tokenizer.LoadFromSerializedProto(sp_model)
return tokenizer
def __getstate__(self):
state = self.__dict__.copy()
state["sp_model"] = None
state["sp_model_proto"] = self.sp_model.serialized_model_proto()
return state
def __setstate__(self, d):
self.__dict__ = d
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
@property
def vocab_size(self):
"""Returns vocab size"""
return self.sp_model.get_piece_size()
def get_vocab(self):
"""Returns vocab as a dict"""
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
return vocab
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize
def tokenize(self, text: "TextInput", **kwargs) -> List[str]:
"""
Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the
first token is special.
"""
if self.legacy or len(text) == 0:
return super().tokenize(text, **kwargs)
text = text.replace(SPIECE_UNDERLINE, " ")
if self.add_prefix_space:
text = SPIECE_UNDERLINE + text
tokens = super().tokenize(text, **kwargs)
if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
tokens = tokens[1:]
return tokens
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize
def _tokenize(self, text, **kwargs):
"""
Returns a tokenized string.
We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give
`['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the
`unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
`self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
"""
tokens = self.sp_model.encode(text, out_type=str)
if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")):
return tokens
# 1. Encode string + prefix ex: "<unk> Hey"
tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
# 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
return tokens[self.unk_token_length:] if len(tokens) >= self.unk_token_length else tokens
def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""
return self.sp_model.piece_to_id(token)
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
token = self.sp_model.IdToPiece(index)
return token
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (string) in a single string."""
# since we manually add the prefix space, we have to remove it when decoding
if tokens[0].startswith(SPIECE_UNDERLINE) and self.add_prefix_space:
tokens[0] = tokens[0][1:]
current_sub_tokens = []
out_string = ""
prev_is_special = False
for i, token in enumerate(tokens):
# make sure that special tokens are not decoded using sentencepiece model
if token in self.all_special_tokens:
if not prev_is_special and i != 0 and self.legacy:
out_string += " "
out_string += self.sp_model.decode(current_sub_tokens) + token
prev_is_special = True
current_sub_tokens = []
else:
current_sub_tokens.append(token)
prev_is_special = False
out_string += self.sp_model.decode(current_sub_tokens)
return out_string
def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
"""
Save the vocabulary and special tokens file to a directory.
Args:
save_directory (`str`):
The directory in which to save the vocabulary.
Returns:
`Tuple(str)`: Paths to the files saved.
"""
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return
out_vocab_file = os.path.join(save_directory,
(filename_prefix + "-" if filename_prefix else "") +
VOCAB_FILES_NAMES["vocab_file"])
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(
self.vocab_file):
copyfile(self.vocab_file, out_vocab_file)
elif not os.path.isfile(self.vocab_file):
with open(out_vocab_file, "wb") as fi:
content_spiece_model = self.sp_model.serialized_model_proto()
fi.write(content_spiece_model)
return (out_vocab_file,)
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
output = bos_token_id + token_ids_0 + eos_token_id
if token_ids_1 is not None:
output = output + bos_token_id + token_ids_1 + eos_token_id
return output
def get_special_tokens_mask(self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None,
already_has_special_tokens: bool = False) -> List[int]:
"""
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer `prepare_for_model` method.
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not the token list is already formatted with special tokens for the model.
Returns:
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""
if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True)
bos_token_id = [1] if self.add_bos_token else []
eos_token_id = [1] if self.add_eos_token else []
if token_ids_1 is None:
return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
return (bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + bos_token_id +
([0] * len(token_ids_1)) + eos_token_id)
def create_token_type_ids_from_sequences(self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None) -> List[int]:
"""
Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
sequence pair mask has the following format:
```
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
| first sequence | second sequence |
```
if token_ids_1 is None, only returns the first portion of the mask (0s).
Args:
token_ids_0 (`List[int]`):
List of ids.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
"""
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
if token_ids_1 is not None:
output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
return output
@property
def default_chat_template(self):
"""
LLaMA uses [INST] and [/INST] to indicate user messages, and <<SYS>> and <</SYS>> to indicate system messages.
Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict
user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering
rather than needing special tokens. The system message is partly 'embedded' in the first user message, which
results in an unusual token ordering when it is present. This template should definitely be changed if you wish
to fine-tune a model with more flexible role ordering!
The output should look something like:
<bos>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer <eos><bos>[INST] Prompt [/INST] Answer <eos>
<bos>[INST] Prompt [/INST]
The reference for this chat template is [this code
snippet](https://github.com/facebookresearch/llama/blob/556949fdfb72da27c2f4a40b7f0e4cf0b8153a28/llama/generation.py#L320-L362)
in the original repository.
"""
logger.warning_once(
"\nNo chat template is defined for this tokenizer - using the default template "
f"for the {self.__class__.__name__} class. If the default is not appropriate for "
"your model, please set `tokenizer.chat_template` to an appropriate template. "
"See https://huggingface.co/docs/transformers/main/chat_templating for more information.\n"
)
template = (
"{% if messages[0]['role'] == 'system' %}"
"{% set loop_messages = messages[1:] %}" # Extract system message if it's present
"{% set system_message = messages[0]['content'] %}"
"{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}"
"{% set loop_messages = messages %}" # Or use the default system message if the flag is set
"{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}"
"{% else %}"
"{% set loop_messages = messages %}"
"{% set system_message = false %}"
"{% endif %}"
"{% for message in loop_messages %}" # Loop over all non-system messages
"{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
"{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
"{% endif %}"
"{% if loop.index0 == 0 and system_message != false %}" # Embed system message in first message
"{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}"
"{% else %}"
"{% set content = message['content'] %}"
"{% endif %}"
"{% if message['role'] == 'user' %}" # After all of that, handle messages/roles in a fairly normal way
"{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}"
"{% elif message['role'] == 'system' %}"
"{{ '<<SYS>>\\n' + content.strip() + '\\n<</SYS>>\\n\\n' }}"
"{% elif message['role'] == 'assistant' %}"
"{{ ' ' + content.strip() + ' ' + eos_token }}"
"{% endif %}"
"{% endfor %}")
template = template.replace("USE_DEFAULT_PROMPT",
"true" if self.use_default_system_prompt else "false")
default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'")
template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)
return template
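# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the library): rendering the default chat
# template defined above through `apply_chat_template`.  The vocab path is a
# placeholder; any SentencePiece `tokenizer.model` compatible with the model will work.
def _example_render_chat_template(vocab_file="path/to/tokenizer.model"):
    tokenizer = BitnetTokenizer(vocab_file=vocab_file, use_default_system_prompt=False)
    # newer transformers releases no longer fall back to `default_chat_template`
    # automatically, so set it explicitly before rendering
    tokenizer.chat_template = tokenizer.default_chat_template
    messages = [
        {"role": "user", "content": "Hi, tell me about microsoft?"},
        {"role": "assistant", "content": "Microsoft is a technology company."},
        {"role": "user", "content": "When was it founded?"},
    ]
    # expected shape of the result (see the docstring of `default_chat_template`), roughly:
    # <s>[INST] Hi, tell me about microsoft? [/INST] Microsoft is a technology company. </s><s>[INST] When was it founded? [/INST]
    return tokenizer.apply_chat_template(messages, tokenize=False)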
# pylint: disable=missing-docstring, invalid-name
"""This is modified from https://huggingface.co/1bitLLM/bitnet_b1_58-3B/blob/main/utils_quant.py to work with BitBLAS."""
import torch
from torch import nn
from bitblas.cache import global_operator_cache, get_database_path
from bitblas import Matmul, MatmulConfig
from bitblas import auto_detect_nvidia_target
from logging import getLogger
logger = getLogger(__name__)
BITBLAS_TARGET = auto_detect_nvidia_target()
BITBLAS_DATABASE_PATH = get_database_path()
def weight_quant(weight, num_bits=1):
dtype = weight.dtype
weight = weight.float()
s = 1 / weight.abs().mean().clamp(min=1e-5)
result = (weight * s).round().clamp(-1, 1) / s
return result.type(dtype)
def activation_quant(x, num_bits=8):
dtype = x.dtype
x = x.float()
Qn = -(2**(num_bits - 1))
Qp = 2**(num_bits - 1) - 1
s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
result = (x * s).round().clamp(Qn, Qp) / s
return result.type(dtype)
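# Note on the two helpers above (descriptive comment): `weight_quant` implements the
# BitNet b1.58 ternary scheme -- the weight is scaled by s = 1 / mean(|W|), rounded and
# clamped to {-1, 0, +1}, then divided by s again so the fake-quantized tensor keeps the
# original scale and dtype.  `activation_quant` performs per-token symmetric int8 fake
# quantization: each row is scaled by s = 127 / max(|x|), rounded into [-128, 127], and
# rescaled back.  The BitBLAS path below stores the integer values and scales separately
# instead of rescaling in place.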
class BitLinearBitBLAS(nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
weight_bits=1,
input_bits=8,
**kwargs,
):
super().__init__()
"""
RMSNorm is placed outside BitLinear
"""
self.in_features = in_features
self.out_features = out_features
self.weight_bits = weight_bits
self.input_bits = input_bits
matmul_config = MatmulConfig(
N=self.out_features, # N dimension
K=self.in_features, # K dimension
A_dtype="int8", # activation A dtype
W_dtype="int2", # weight W dtype
accum_dtype="int32", # accumulation dtype
out_dtype="float32", # output dtype
layout="nt", # matrix layout, "nt" indicates the layout of A is non-transpose and the layout of W is transpose
with_bias=False, # bias
# configs for weight only quantization
group_size=None, # setting for grouped quantization
with_scaling=False, # setting for scaling factor
with_zeros=False, # setting for zeros
zeros_mode=None, # setting for how to calculating zeros
)
ENABLE_TUNING = True
self.bitblas_matmul = self._get_or_create_bitblas_operator(matmul_config, ENABLE_TUNING)
self.format = "bitnet"
self.Qp = 2**(self.input_bits - 1) - 1
def _get_or_create_bitblas_operator(self, config, enable_tuning):
if global_operator_cache.size() == 0:
global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH, BITBLAS_TARGET)
logger.info(f"Loaded {global_operator_cache.size()} operators from database.")
bitblas_matmul = global_operator_cache.get(config)
if bitblas_matmul is None:
# construct with tuning disabled; hardware-aware tuning is run explicitly below if enabled
bitblas_matmul = Matmul(config, target=BITBLAS_TARGET, enable_tuning=False)
if enable_tuning:
bitblas_matmul.hardware_aware_finetune(topk=20)
global_operator_cache.add(config, bitblas_matmul)
global_operator_cache.save_into_database(BITBLAS_DATABASE_PATH, BITBLAS_TARGET)
print("BitBLAS Tuning done, appended operator to global_operator_cache.")
else:
print("BitBLAS Operator created.")
else:
print("BitBLAS Operator found in global_operator_cache.")
return bitblas_matmul
def replace_weight_param_with_qweight(self):
if hasattr(self, "weight"):
del self.weight
quant_weight = torch.empty(self.bitblas_matmul.retrieve_weight_shape())
self.qweight = nn.Parameter(quant_weight, requires_grad=False)
self.format = "bitblas"
@classmethod
def from_bit_linear(cls, bitlinear, weight_group=1):
bitblas_linear = cls(
bitlinear.in_features, bitlinear.out_features, weight_bits=1, input_bits=8)
sw, qweight = bitblas_linear.create_bitblas_weights(bitlinear.weight, weight_group)
bitblas_linear.register_buffer("qweight", qweight)
bitblas_linear.register_buffer("sw", sw)
if bitlinear.bias is not None:
bitblas_linear.register_buffer("bias", bitlinear.bias)
else:
bitblas_linear.bias = None
return bitblas_linear
def create_bitblas_weights(self, weight, weight_group=1):
if weight_group:
hidden_size = weight.size(0)
group_size = hidden_size // weight_group
sw_list = []
qweight_list = []
for i in range(weight_group):
start_idx = i * group_size
end_idx = (i + 1) * group_size
sw = 1 / weight[start_idx:end_idx].abs().mean().clamp(min=1e-5)
sw_list.append(sw.repeat(group_size))
qweight = self.weight_quant(weight[start_idx:end_idx]).detach()
qweight_list.append(qweight)
sw = torch.cat(sw_list, dim=0)
qweight = torch.cat(qweight_list, dim=0)
else:
sw = 1 / weight.abs().mean().clamp(min=1e-5)
qweight = self.weight_quant(weight).detach()
qweight = self.bitblas_matmul.transform_weight(qweight)
qweight = nn.Parameter(qweight, requires_grad=False)
return sw, qweight
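# Descriptive note: `weight_group` mirrors how the fused layers are built in the modeling
# code -- a fused qkv projection concatenates 3 sub-matrices (weight_group=3) and a fused
# gate/up projection concatenates 2 (weight_group=2), so each sub-matrix gets its own
# scale `sw` (repeated along the output dimension) rather than a single scale for the
# whole fused weight.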
def post_process_weights(self):
sw = 1 / self.weight.abs().mean().clamp(min=1e-5)
self.sw = sw
quant_weight = self.weight_quant(self.weight).detach()
quant_weight = self.bitblas_matmul.transform_weight(quant_weight)
# remove self.weight and replace it with quant_weight
if hasattr(self, "weight"):
del self.weight
self.qweight = nn.Parameter(quant_weight, requires_grad=False)
self.format = "bitblas"
@staticmethod
def weight_quant(weight):
weight = weight.float()
s = 1 / weight.abs().mean().clamp(min=1e-5)
result = (weight * s).round().clamp(-1, 1)
return result.type(torch.int8)
@torch.compile
def activation_quant(self, x, num_bits=8):
x = x.float()
Qn = -(2**(num_bits - 1))
Qp = 2**(num_bits - 1) - 1
s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
result = (x * s).round().clamp(Qn, Qp)
return result.type(torch.int8), s
@torch.compile
def post_quant_process(self, input, si, sw):
out = input / si
out = out / sw
out = out.half()
return out
# for the correctness evaluation.
def native_forward(self, input):
quant_input = (input + (activation_quant(input, self.input_bits) - input).detach())
quant_weight = (
self.weight + (weight_quant(self.weight, self.weight_bits) - self.weight).detach())
out = nn.functional.linear(quant_input, quant_weight)
if self.bias is not None:
out += self.bias.view(1, -1).expand_as(out)
return out
def forward_fp32_simulated(self, input):
quant_input, si = self.activation_quant(input, self.input_bits)
quant_weight = self.weight_quant(self.weight).detach()
fp32_simulated_input = quant_input.float()
fp32_simulated_weight = quant_weight.float()
fp32_simulated_out = nn.functional.linear(fp32_simulated_input, fp32_simulated_weight)
sw = 1 / self.weight.abs().mean().clamp(min=1e-5)
# dividing by (si * sw) in a single step can overflow to inf in some cases, so divide sequentially
out = fp32_simulated_out / si
out = out / sw
out = out.half()
if self.bias is not None:
out += self.bias.view(1, -1).expand_as(out)
return out
def forward(self, input):
# return self.forward_fp32_simulated(input)
quant_input, si = self.activation_quant(input, self.input_bits)
fp32_out = self.bitblas_matmul(quant_input, self.qweight)
sw = self.sw
# dividing by (si * sw) in a single step can overflow to inf in some cases, so divide sequentially
out = self.post_quant_process(fp32_out, si, sw)
if self.bias is not None:
out += self.bias.view(1, -1).expand_as(out)
return out
# Naive BitLinear from HuggingFace
class BitLinear(nn.Linear):
def __init__(self, *kargs, weight_bits=1, input_bits=8, **kwargs):
super(BitLinear, self).__init__(*kargs, **kwargs)
"""
RMSNorm is placed outside BitLinear
"""
self.weight_bits = weight_bits
self.input_bits = input_bits
def forward(self, input):
quant_input = input + (activation_quant(input, self.input_bits) - input).detach()
quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) -
self.weight).detach()
out = nn.functional.linear(quant_input, quant_weight)
if self.bias is not None:
out += self.bias.view(1, -1).expand_as(out)
return out
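# ---------------------------------------------------------------------------
# Illustrative sketch (not used by the library): a quick numerical sanity check of the
# fake-quantized `BitLinear` above.  It builds a small layer, runs the quantized forward,
# and reports how far the ternary/int8 fake quantization drifts from the full-precision
# matmul.  Shapes and the seed are arbitrary.
def _example_bitlinear_sanity_check():
    torch.manual_seed(0)
    layer = BitLinear(64, 32, bias=False, weight_bits=1, input_bits=8)
    x = torch.randn(4, 64)
    quantized_out = layer(x)                               # fake-quantized weights/activations
    reference_out = nn.functional.linear(x, layer.weight)  # full-precision reference
    max_abs_err = (quantized_out - reference_out).abs().max().item()
    print(f"max |quantized - fp32| = {max_abs_err:.4f}")
    return max_abs_err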
import contextlib
import gc
import os
import sys
from collections import UserList
from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import (
AutoModelForCausalLM,
AutoModelForVision2Seq,
AutoTokenizer,
BatchEncoding,
)
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.config import TokenizerPoolConfig
from vllm.distributed import (destroy_distributed_environment, destroy_model_parallel)
from vllm.inputs import TextPrompt
from vllm.logger import init_logger
from vllm.sequence import SampleLogprobs
from vllm.utils import cuda_device_count_stateless, is_cpu
logger = init_logger(__name__)
_TEST_DIR = os.path.dirname(__file__)
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
def _read_prompts(filename: str) -> List[str]:
with open(filename, "r") as f:
prompts = f.readlines()
return prompts
class _ImageAssetPrompts(TypedDict):
stop_sign: str
cherry_blossom: str
if sys.version_info < (3, 9):
# UserList cannot be subscripted
class _ImageAssetsBase(UserList):
pass
else:
class _ImageAssetsBase(UserList[ImageAsset]):
pass
class _ImageAssets(_ImageAssetsBase):
def __init__(self) -> None:
super().__init__([
ImageAsset("stop_sign"),
ImageAsset("cherry_blossom"),
])
def prompts(self, prompts: _ImageAssetPrompts) -> List[str]:
"""
Convenience method to define the prompt for each test image.
The order of the returned prompts matches the order of the
assets when iterating through this object.
"""
return [prompts["stop_sign"], prompts["cherry_blossom"]]
IMAGE_ASSETS = _ImageAssets()
"""Singleton instance of :class:`_ImageAssets`."""
def cleanup():
destroy_model_parallel()
destroy_distributed_environment()
with contextlib.suppress(AssertionError):
torch.distributed.destroy_process_group()
gc.collect()
if not is_cpu():
torch.cuda.empty_cache()
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
"""Allow subdirectories to skip global cleanup by overriding this fixture.
This can provide a ~10x speedup for non-GPU unit tests since they don't need
to initialize torch.
"""
if request.node.get_closest_marker("skip_global_cleanup"):
    return False
return True
@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
yield
if should_do_global_cleanup_after_test:
cleanup()
@pytest.fixture
def example_prompts() -> List[str]:
prompts = []
for filename in _TEST_PROMPTS:
prompts += _read_prompts(filename)
return prompts
@pytest.fixture
def example_long_prompts() -> List[str]:
prompts = []
for filename in _LONG_PROMPTS:
prompts += _read_prompts(filename)
return prompts
@pytest.fixture(scope="session")
def image_assets() -> _ImageAssets:
return IMAGE_ASSETS
_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.half,
"bfloat16": torch.bfloat16,
"float": torch.float,
}
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding)
class HfRunner:
def wrap_device(self, input: _T) -> _T:
if not is_cpu():
return input.to("cuda")
else:
return input.to("cpu")
def __init__(
self,
model_name: str,
dtype: str = "half",
*,
model_kwargs: Optional[Dict[str, Any]] = None,
is_embedding_model: bool = False,
is_vision_model: bool = False,
is_sparseml_model: bool = False,
) -> None:
assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
self.model_name = model_name
if is_embedding_model:
# Lazy init required for AMD CI
from sentence_transformers import SentenceTransformer
self.model = self.wrap_device(
SentenceTransformer(
model_name,
device="cpu",
).to(dtype=torch_dtype))
else:
if is_vision_model:
auto_cls = AutoModelForVision2Seq
elif is_sparseml_model:
from sparseml.transformers import SparseAutoModelForCausalLM
auto_cls = SparseAutoModelForCausalLM
else:
auto_cls = AutoModelForCausalLM
model_kwargs = model_kwargs if model_kwargs is not None else {}
self.model = self.wrap_device(
auto_cls.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=True,
**model_kwargs,
))
self.tokenizer = AutoTokenizer.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=True,
)
try:
# don't put this import at the top level
# it will call torch.cuda.device_count()
from transformers import AutoProcessor # noqa: F401
self.processor = AutoProcessor.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=True,
)
except Exception:
logger.warning(
"Unable to auto-load processor from HuggingFace for "
"model %s. Using tokenizer instead.",
model_name,
)
self.processor = self.tokenizer
def generate(
self,
prompts: List[str],
images: Optional[List[Image.Image]] = None,
**kwargs: Any,
) -> List[Tuple[List[List[int]], List[str]]]:
if images:
assert len(prompts) == len(images)
outputs: List[Tuple[List[List[int]], List[str]]] = []
for i, prompt in enumerate(prompts):
processor_kwargs: Dict[str, Any] = {
"text": prompt,
"return_tensors": "pt",
}
if images is not None and images[i] is not None:
processor_kwargs["images"] = images[i]
inputs = self.processor(**processor_kwargs)
output_ids = self.model.generate(
**self.wrap_device(inputs),
use_cache=True,
**kwargs,
)
output_str = self.processor.batch_decode(
output_ids,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)
output_ids = output_ids.cpu().tolist()
outputs.append((output_ids, output_str))
return outputs
def generate_greedy(
self,
prompts: List[str],
max_tokens: int,
images: Optional[List[Image.Image]] = None,
**kwargs: Any,
) -> List[Tuple[List[int], str]]:
outputs = self.generate(
prompts,
do_sample=False,
max_new_tokens=max_tokens,
images=images,
**kwargs,
)
return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs]
def generate_beam_search(
self,
prompts: List[str],
beam_width: int,
max_tokens: int,
) -> List[Tuple[List[List[int]], List[str]]]:
outputs = self.generate(
prompts,
do_sample=False,
max_new_tokens=max_tokens,
num_beams=beam_width,
num_return_sequences=beam_width,
)
for i in range(len(outputs)):
output_ids, output_str = outputs[i]
for j in range(len(output_ids)):
output_ids[j] = [x for x in output_ids[j] if x != self.tokenizer.pad_token_id]
outputs[i] = (output_ids, output_str)
return outputs
def generate_greedy_logprobs(
self,
prompts: List[str],
max_tokens: int,
images: Optional[List[Image.Image]] = None,
**kwargs: Any,
) -> List[List[torch.Tensor]]:
all_logprobs: List[List[torch.Tensor]] = []
for i, prompt in enumerate(prompts):
processor_kwargs: Dict[str, Any] = {
"text": prompt,
"return_tensors": "pt",
}
if images is not None and images[i] is not None:
processor_kwargs["images"] = images[i]
inputs = self.processor(**processor_kwargs)
output = self.model.generate(
**self.wrap_device(inputs),
use_cache=True,
do_sample=False,
max_new_tokens=max_tokens,
output_hidden_states=True,
return_dict_in_generate=True,
**kwargs,
)
seq_logprobs: List[torch.Tensor] = []
for hidden_states in output.hidden_states:
last_hidden_states = hidden_states[-1][0]
logits = torch.matmul(
last_hidden_states,
self.model.get_output_embeddings().weight.t(),
)
if self.model.get_output_embeddings().bias is not None:
logits += self.model.get_output_embeddings().bias.unsqueeze(0)
logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
seq_logprobs.append(logprobs)
all_logprobs.append(seq_logprobs)
return all_logprobs
def generate_greedy_logprobs_limit(
self,
prompts: List[str],
max_tokens: int,
num_logprobs: int,
images: Optional[List[Image.Image]] = None,
**kwargs: Any,
) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
all_logprobs: List[List[Dict[int, float]]] = []
all_output_ids: List[List[int]] = []
all_output_strs: List[str] = []
for i, prompt in enumerate(prompts):
processor_kwargs: Dict[str, Any] = {
"text": prompt,
"return_tensors": "pt",
}
if images is not None and images[i] is not None:
processor_kwargs["images"] = images[i]
inputs = self.processor(**processor_kwargs)
input_ids = inputs.input_ids
output = self.model.generate(
**self.wrap_device(inputs),
use_cache=True,
do_sample=False,
max_new_tokens=max_tokens,
output_hidden_states=True,
return_dict_in_generate=True,
**kwargs,
)
seq_logprobs: List[torch.Tensor] = []
for _, hidden_states in enumerate(output.hidden_states):
last_hidden_states = hidden_states[-1][0]
logits = torch.matmul(
last_hidden_states,
self.model.get_output_embeddings().weight.t(),
)
if (getattr(self.model.get_output_embeddings(), "bias", None) is not None):
logits += self.model.get_output_embeddings().bias.unsqueeze(0)
logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
seq_logprobs.append(logprobs)
# convert to dict
seq_logprobs_lst: List[Dict[int, float]] = []
for tok_idx, tok_logprobs in enumerate(seq_logprobs):
# drop prompt logprobs
if tok_idx == 0:
tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
topk = tok_logprobs.topk(num_logprobs)
tok_logprobs_dct = {}
for token_id, logprob in zip(topk.indices[0], topk.values[0]):
tok_logprobs_dct[token_id.item()] = logprob.item()
seq_logprobs_lst.append(tok_logprobs_dct)
all_logprobs.append(seq_logprobs_lst)
seq_ids = output.sequences[0]
output_len = seq_ids.shape[0] - input_ids.shape[1]
output_ids = seq_ids[-output_len:]
all_output_ids.append(output_ids.tolist())
all_output_strs.append(self.tokenizer.decode(output_ids))
outputs = zip(all_output_ids, all_output_strs, all_logprobs)
return [(output_ids, output_str, output_logprobs)
for output_ids, output_str, output_logprobs in outputs]
def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
return self.model.encode(prompts)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
del self.model
cleanup()
@pytest.fixture(scope="session")
def hf_runner():
return HfRunner
class VllmRunner:
def __init__(
self,
model_name: str,
tokenizer_name: Optional[str] = None,
# Use smaller max model length, otherwise bigger model cannot run due
# to kv cache size limit.
max_model_len: int = 1024,
dtype: str = "half",
disable_log_stats: bool = True,
tensor_parallel_size: int = 1,
block_size: int = 16,
enable_chunked_prefill: bool = False,
swap_space: int = 4,
enforce_eager: bool = False,
**kwargs,
) -> None:
self.model = LLM(
model=model_name,
tokenizer=tokenizer_name,
trust_remote_code=True,
dtype=dtype,
swap_space=swap_space,
enforce_eager=enforce_eager,
disable_log_stats=disable_log_stats,
tensor_parallel_size=tensor_parallel_size,
max_model_len=max_model_len,
block_size=block_size,
enable_chunked_prefill=enable_chunked_prefill,
**kwargs,
)
def generate(
self,
prompts: List[str],
sampling_params: SamplingParams,
images: Optional[List[Image.Image]] = None,
) -> List[Tuple[List[List[int]], List[str]]]:
if images is not None:
assert len(prompts) == len(images)
inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
if images is not None:
for i, image in enumerate(images):
inputs[i]["multi_modal_data"] = {"image": image}
req_outputs = self.model.generate(inputs, sampling_params=sampling_params)
outputs: List[Tuple[List[List[int]], List[str]]] = []
for req_output in req_outputs:
prompt_str = req_output.prompt
prompt_ids = req_output.prompt_token_ids
req_sample_output_ids: List[List[int]] = []
req_sample_output_strs: List[str] = []
for sample in req_output.outputs:
output_str = sample.text
output_ids = list(sample.token_ids)
req_sample_output_ids.append(prompt_ids + output_ids)
req_sample_output_strs.append(prompt_str + output_str)
outputs.append((req_sample_output_ids, req_sample_output_strs))
return outputs
def generate_w_logprobs(
self,
prompts: List[str],
sampling_params: SamplingParams,
images: Optional[List[Image.Image]] = None,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
assert sampling_params.logprobs is not None
if images is not None:
assert len(prompts) == len(images)
inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
if images is not None:
for i, image in enumerate(images):
inputs[i]["multi_modal_data"] = {"image": image}
req_outputs = self.model.generate(inputs, sampling_params=sampling_params)
outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
for req_output in req_outputs:
for sample in req_output.outputs:
output_str = sample.text
output_ids = sample.token_ids
output_logprobs = sample.logprobs
outputs.append((output_ids, output_str, output_logprobs))
return outputs
def generate_greedy(
self,
prompts: List[str],
max_tokens: int,
images: Optional[List[Image.Image]] = None,
) -> List[Tuple[List[int], str]]:
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
outputs = self.generate(prompts, greedy_params, images=images)
return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs]
def generate_greedy_logprobs(
self,
prompts: List[str],
max_tokens: int,
num_logprobs: int,
images: Optional[List[Image.Image]] = None,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
greedy_logprobs_params = SamplingParams(
temperature=0.0, max_tokens=max_tokens, logprobs=num_logprobs)
outputs = self.generate_w_logprobs(prompts, greedy_logprobs_params, images=images)
return [(output_ids, output_str, output_logprobs)
for output_ids, output_str, output_logprobs in outputs]
def generate_beam_search(
self,
prompts: List[str],
beam_width: int,
max_tokens: int,
) -> List[Tuple[List[List[int]], List[str]]]:
beam_search_params = SamplingParams(
n=beam_width,
use_beam_search=True,
temperature=0.0,
max_tokens=max_tokens,
)
outputs = self.generate(prompts, beam_search_params)
return outputs
def encode(self, prompts: List[str]) -> List[List[float]]:
req_outputs = self.model.encode(prompts)
outputs = []
for req_output in req_outputs:
embedding = req_output.outputs.embedding
outputs.append(embedding)
return outputs
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
del self.model
cleanup()
@pytest.fixture(scope="session")
def vllm_runner():
return VllmRunner
def get_tokenizer_pool_config(tokenizer_group_type):
if tokenizer_group_type is None:
return None
if tokenizer_group_type == "ray":
return TokenizerPoolConfig(pool_size=1, pool_type="ray", extra_config={})
raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}")
@pytest.fixture()
def temporary_enable_log_propagate():
import logging
logger = logging.getLogger("vllm")
logger.propagate = True
yield
logger.propagate = False
@pytest.fixture()
def caplog_vllm(temporary_enable_log_propagate, caplog):
# To capture vllm log, we should enable propagate=True temporarily
# because caplog depends on logs propagated to the root logger.
yield caplog
@pytest.fixture(scope="session")
def num_gpus_available():
"""Get number of GPUs without initializing the CUDA context
in current process."""
return cuda_device_count_stateless()
"""Compare the outputs of a GPTQ model to a Marlin model.
Note: GPTQ and Marlin do not have bitwise correctness.
As a result, in this test, we just confirm that the top selected tokens of the
Marlin/GPTQ models are in the top 3 selections of each other.
Note: Marlin internally uses locks to synchronize the threads. This can
result in very slight nondeterminism for Marlin. As a result, we re-run the test
up to 3 times to see if we pass.
Run `pytest tests/models/test_marlin.py`.
"""
from conftest import VllmRunner
import os
import argparse
# get the path of the current file
current_file_path = os.path.realpath(__file__)
current_dir = os.path.dirname(current_file_path)
ckpt_path = os.path.join(current_dir, "../models/ckpt_bitnet_b1_58-3B_bitblas")
parser = argparse.ArgumentParser(description="Inference with BitNet")
parser.add_argument(
"--ckpt_path",
type=str,
default=ckpt_path,
help="Path to the checkpoint",
)
args = parser.parse_args()
ckpt_path = args.ckpt_path
with VllmRunner(
ckpt_path,
dtype="half",
quantization="bitblas",
# set enforce_eager = False to enable cuda graph
# set enforce_eager = True to disable cuda graph
enforce_eager=False,
) as bitnet_model:
bitnet_outputs = bitnet_model.generate_greedy(["Hi, tell me about microsoft?"],
                                              max_tokens=1024)
print("bitnet inference:")
print(bitnet_outputs[0][0])  # generated token ids
print(bitnet_outputs[0][1])  # generated text
"""Compare the outputs of a GPTQ model to a Marlin model.
Note: GPTQ and Marlin do not have bitwise correctness.
As a result, in this test, we just confirm that the top selected tokens of the
Marlin/GPTQ models are in the top 3 selections of each other.
Note: Marlin internally uses locks to synchronize the threads. This can
result in very slight nondeterminism for Marlin. As a result, we re-run the test
up to 3 times to see if we pass.
Run `pytest tests/models/test_marlin.py`.
"""
from conftest import VllmRunner
import os
import argparse
# get the path of the current file
current_file_path = os.path.realpath(__file__)
current_dir = os.path.dirname(current_file_path)
ckpt_path = os.path.join(current_dir, "../models/ckpt_bitnet_b1_58-3B")
parser = argparse.ArgumentParser(description="Inference with BitNet")
parser.add_argument(
"--ckpt_path",
type=str,
default=ckpt_path,
help="Path to the checkpoint",
)
args = parser.parse_args()
ckpt_path = args.ckpt_path
with VllmRunner(
ckpt_path,
dtype="half",
quantization="bitnet_bitblas",
gpu_memory_utilization=0.5,
# set enforce_eager = False to enable cuda graph
# set enforce_eager = True to disable cuda graph
enforce_eager=False,
) as bitnet_model:
bitnet_outputs = bitnet_model.generate_greedy(["Hi, tell me about microsoft?"], max_tokens=128)
print("bitnet inference output:")
print(bitnet_outputs[0][0])  # generated token ids
print(bitnet_outputs[0][1])  # generated text
from typing import Dict, List, Tuple
TokensText = Tuple[List[int], str]
def check_outputs_equal(outputs_0_lst: List[TokensText], outputs_1_lst: List[TokensText],
name_0: str, name_1: str):
"""
Compare the two sequences generated by different models,
which should be equal.
"""
assert len(outputs_0_lst) == len(outputs_1_lst)
for prompt_idx, (outputs_0, outputs_1) in enumerate(zip(outputs_0_lst, outputs_1_lst)):
output_ids_0, output_str_0 = outputs_0
output_ids_1, output_str_1 = outputs_1
assert output_str_0 == output_str_1, (f"Test{prompt_idx}:"
f"\n{name_0}:\t{output_str_0!r}"
f"\n{name_1}:\t{output_str_1!r}")
assert output_ids_0 == output_ids_1, (f"Test{prompt_idx}:"
f"\n{name_0}:\t{output_str_0!r}"
f"\n{name_1}:\t{output_str_1!r}")
TokensTextLogprobs = Tuple[List[int], str, List[Dict[int, float]]]
def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs],
outputs_1_lst: List[TokensTextLogprobs], name_0: str, name_1: str):
"""
Compare the logprobs of two sequences generated by different models,
which should be similar but not necessarily equal.
"""
assert len(outputs_0_lst) == len(outputs_1_lst)
# Loop through responses to each prompt.
for prompt_idx, (outputs_0, outputs_1) in enumerate(zip(outputs_0_lst, outputs_1_lst)):
output_ids_0, output_str_0, logprobs_0 = outputs_0
output_ids_1, output_str_1, logprobs_1 = outputs_1
# Loop through generated tokens.
for idx, (output_id_0, output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):
# If generated tokens don't match, then
if output_id_0 != output_id_1:
# Each predicted token must be in top N logprobs of the other
assert output_id_0 in logprobs_1[idx], (f"Test{prompt_idx}:"
f"\n{name_0}:\t{output_str_0!r}"
f"\n{name_1}:\t{output_str_1!r}")
assert output_id_1 in logprobs_0[idx], (f"Test{prompt_idx}:"
f"\n{name_0}:\t{output_str_0!r}"
f"\n{name_1}:\t{output_str_1!r}")
# Break out since sequences will now diverge.
break
# Block-Sparse Flash-Attention
Tilelang implementation of block-sparse flash-attention kernels.
The kernels have been used in [Rectified Sparse Attention](https://arxiv.org/abs/2506.04108) and [SeerAttention-R](https://arxiv.org/abs/2506.08889).
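A minimal sketch of how a block-level sparse mask is typically constructed before calling the kernels: attention scores are pooled to one value per (query block, key block), the top-k key blocks are kept per query block, and the mask is restricted to the causal region. Names and shapes below are illustrative only, not the exact kernel API.

```python
import torch

def topk_block_mask(block_scores: torch.Tensor, topk: int) -> torch.Tensor:
    """block_scores: [batch, heads, n_blocks, n_blocks] block-pooled attention scores."""
    idx = torch.topk(block_scores, topk, dim=-1).indices
    mask = torch.zeros_like(block_scores, dtype=torch.bool)
    mask.scatter_(-1, idx, True)   # keep the top-k key blocks for each query block
    return mask.tril_()            # restrict to the causal (lower-triangular) blocks

scores = torch.rand(1, 8, 16, 16)  # 8 heads, a 16 x 16 grid of attention blocks
mask = topk_block_mask(scores, topk=4)
```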
# ruff: noqa: E712
import math
import torch
import triton
import triton.language as tl
import torch.nn.functional as F
def is_hip():
return triton.runtime.driver.active.get_current_target().backend == "hip"
def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz, num_head, downsample_len, _ = x.shape
# N_CTX = downsample_len * BLOCK
sparse_index = torch.topk(x, topk, dim=-1).indices
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
False,
dtype=torch.bool,
device=x.device)
dense_mask.scatter_(-1, sparse_index, True)
if use_dense_for_last_block:
dense_mask[:, :, -2:, :] = True
dense_mask.tril_()
return dense_mask
def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False):
dense_mask = x > threshold
if use_dense_for_last_block:
dense_mask[:, :, -2:, :] = True
dense_mask.tril_()
return dense_mask
@triton.jit
def _fwd_kernel_inner(
acc,
l_i,
m_i,
q,
k_block_col_idx,
block_mask_ptr,
k_ptrs,
v_ptrs,
offs_m,
offs_n,
stride_kt,
stride_vt,
stride_bmask_n,
sm_scale,
seqlen_k,
past_len,
LAST_K_BLOCK: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n)
if mask_val == True:
start_n = k_block_col_idx * BLOCK_N
# -- compute qk ----
k = tl.load(k_ptrs + start_n * stride_kt)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
# the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
if LAST_K_BLOCK:
qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0,
float('-inf'))
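# Online-softmax update below (the standard flash-attention recurrence):
#   m_new = max(m_old, rowmax(qk))
#   p     = exp(qk - m_new)
#   l_new = l_old * exp(m_old - m_new) + rowsum(p)
#   acc   = acc * exp(m_old - m_new) + p @ v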
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk -= m_ij[:, None]
p = tl.exp(qk)
l_ij = tl.sum(p, 1)
alpha = tl.exp(m_i - m_ij)
l_i = l_i * alpha + l_ij
acc = acc * alpha[:, None]
# update acc
v = tl.load(v_ptrs + start_n * stride_vt)
p = p.to(v.type.element_ty)
acc += tl.dot(p, v)
# update m_i and l_i
m_i = m_ij
return acc, l_i, m_i
@triton.jit
def _fwd_kernel(
Q,
K,
V,
sm_scale,
block_mask_ptr,
Out,
stride_qz,
stride_qh,
stride_qm,
stride_qd,
stride_kz,
stride_kh,
stride_kn,
stride_kd,
stride_vz,
stride_vh,
stride_vn,
stride_vd,
stride_bmz,
stride_bmh,
stride_bmm,
stride_bmn,
stride_oz,
stride_oh,
stride_om,
stride_od,
H,
N_CTX,
PAST_LEN,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
):
Q_LEN = N_CTX - PAST_LEN
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
off_h = off_hz % H
off_z = off_hz // H
Q += off_z * stride_qz + off_h * stride_qh
K += off_z * stride_kz + off_h * stride_kh
V += off_z * stride_vz + off_h * stride_vh
block_mask_ptr += off_z * stride_bmz + off_h * stride_bmh
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
off_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd
# off_k = offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd
off_k = offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kd
off_v = offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd
# Initialize pointers to Q, K, V
q_ptrs = Q + off_q
k_ptrs = K + off_k
v_ptrs = V + off_v
mask_ptrs = block_mask_ptr + start_m * stride_bmm
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
q = tl.load(q_ptrs, mask=offs_m[:, None] < Q_LEN)
k_block_start = 0
k_block_end = tl.cdiv((start_m + 1) * BLOCK_M, BLOCK_N)
# loop over k, v and update accumulator
for col_idx in range(k_block_start, k_block_end):
acc, l_i, m_i = _fwd_kernel_inner(
acc,
l_i,
m_i,
q,
col_idx,
mask_ptrs,
k_ptrs,
v_ptrs,
offs_m,
offs_n,
stride_kn,
stride_vn,
stride_bmn,
sm_scale,
N_CTX,
PAST_LEN,
col_idx == k_block_end - 1,
BLOCK_M,
BLOCK_N,
)
m_i += tl.math.log(l_i)
l_recip = 1 / l_i[:, None]
acc = acc * l_recip
acc = acc.to(Out.dtype.element_ty)
off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[
None, :] * stride_od
out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=offs_m[:, None] < Q_LEN)
def _forward(ctx,
q,
k,
v,
block_sparse_mask,
sm_scale,
BLOCK_M=64,
BLOCK_N=64,
num_warps=None,
num_stages=1,
out=None):
assert q.shape[-1] == k.shape[-1] == v.shape[-1]
assert k.shape[2] == v.shape[2]
o = out if out is not None else torch.empty_like(q).contiguous()
grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1])
assert q.shape[-1] in [64, 128]
BLOCK_DMODEL = q.shape[-1]
# Choose launch parameters only when the caller did not override them.
if num_warps is None:
    num_warps, num_stages = (8, 1) if is_hip() else (4, 2)
N_CTX = k.shape[2]
PAST_LEN = N_CTX - q.shape[2]
H = q.shape[1]
_fwd_kernel[grid](
q,
k,
v,
sm_scale,
block_sparse_mask,
o,
*q.stride(),
*k.stride(),
*v.stride(),
*block_sparse_mask.stride(),
*o.stride(),
H,
N_CTX,
PAST_LEN,
BLOCK_M,
BLOCK_N,
BLOCK_DMODEL,
num_warps=num_warps,
num_stages=num_stages,
)
return o
class _sparse_attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, block_sparse_dense, sm_scale):
# shape constraints
return _forward(ctx, q, k, v, block_sparse_dense, sm_scale)
@staticmethod
def backward(ctx, do):
# Backward (gradient propagation) is not supported.
raise NotImplementedError("It does not support gradient propagation yet")
block_sparse_triton_fn = _sparse_attention.apply
def test_topk_sparse_attention():
# Config
BATCH, N_HEADS, SEQ_LEN, D_HEAD = 1, 1, 256, 64
TOPK = 2  # Keep the top-2 blocks per query block row
BLOCK = 64
torch.manual_seed(0)
# Create inputs
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
sm_scale = 1.0 / (D_HEAD**0.5)
# Create sparse mask (downsampled to block level)
downsample_factor = BLOCK
downsample_len = math.ceil(SEQ_LEN / downsample_factor)
print("downsample_len", downsample_len)
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
device='cuda',
dtype=torch.bfloat16)
x_ds[:, :, :, 0] = 100
print("x_ds.shape", x_ds.shape)
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
# print("block_mask", block_mask)
print("block_mask.shape", block_mask.shape)
# Run Triton kernel
triton_output = block_sparse_triton_fn(q, k, v, block_mask, sm_scale)
# Compute reference
# Expand block mask to full attention matrix
full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda'))
full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal
# PyTorch reference implementation
attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
attn = attn.masked_fill(~full_mask, float('-inf'))
attn = F.softmax(attn, dim=-1)
ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)
# print("ref_output", ref_output)
# print("triton_output", triton_output)
# Verify accuracy
assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \
"Triton output doesn't match reference"
print("Pass topk sparse attention test with qlen == klen")
def test_topk_sparse_attention_qlt_kl():
BATCH, N_HEADS = 2, 4
Q_LEN, K_LEN, D_HEAD = 128, 256, 64 # qlen < klen; here, past_len = 256 - 128 = 128.
TOPK = 1
BLOCK = 64 # block size used in downsampling
torch.manual_seed(0)
# Create inputs.
q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16)
# softmax scale
sm_scale = 1.0 / (D_HEAD**0.5)
downsample_factor = BLOCK
print("downsample_factor", downsample_factor)
downsample_len = math.ceil(K_LEN / downsample_factor) # number of blocks along one dimension
print("downsample_len", downsample_len)
x_ds = torch.randn(
BATCH, N_HEADS, downsample_len, downsample_len, device='cuda', dtype=torch.bfloat16)
# Force the first column to be high so that the first block is always selected.
x_ds[:, :, :, 0] = 100
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
print("block_mask", block_mask)
print("block_mask.shape", block_mask.shape)
# Run Triton kernel.
triton_output = block_sparse_triton_fn(q, k, v, block_mask, sm_scale)
past_len = K_LEN - Q_LEN
attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')).bool()
full_mask_full = full_mask_full[..., :K_LEN, :K_LEN]
effective_mask = full_mask_full[..., past_len:K_LEN, :] # shape: (B, H, Q_LEN, K_LEN)
i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1) # shape: (Q_LEN, 1)
j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0) # shape: (1, K_LEN)
causal_mask = (j_global <= i_global) # shape: (Q_LEN, K_LEN)
final_mask = effective_mask & causal_mask # shape: (B, H, Q_LEN, K_LEN)
attn = attn.masked_fill(~final_mask, float('-inf'))
attn = F.softmax(attn, dim=-1)
ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)
# Verify accuracy.
assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \
"Triton output doesn't match reference when qlen < klen"
print("Pass topk sparse attention test with qlen < klen")
def main():
test_topk_sparse_attention()
test_topk_sparse_attention_qlt_kl()
if __name__ == "__main__":
main()
import math
import torch
import tilelang
import tilelang.language as T
import torch.nn.functional as F
def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz, num_head, downsample_len, _ = x.shape
# N_CTX = downsample_len * BLOCK
sparse_index = torch.topk(x, topk, dim=-1).indices
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
False,
dtype=torch.bool,
device=x.device)
dense_mask.scatter_(-1, sparse_index, True)
if use_dense_for_last_block:
dense_mask[:, :, -2:, :] = True
dense_mask.tril_()
return dense_mask
def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False):
dense_mask = x > threshold
if use_dense_for_last_block:
dense_mask[:, :, -2:, :] = True
dense_mask.tril_()
return dense_mask
@tilelang.jit(
out_idx=[4], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal):
block_M = 64
block_N = 64
num_stages = 1
threads = 128
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, heads, seq_len, dim]
block_mask_shape = [batch, heads, downsample_len, downsample_len]
dtype = "float16"
accum_dtype = "float"
block_mask_dtype = "bool"
def kernel_func(block_M, block_N, num_stages, threads):
@T.macro
def MMA0(
K: T.Tensor(shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(shape, dtype),
V_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# To do causal softmax, we need to set scores_max to 0 where it is -inf.
# This process is called Check_inf in the FlashAttention-3 code, and it only needs to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)) This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
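# Concretely: exp(x - m) == exp2(x * log2(e) - m * log2(e)), and `scale` above already
# folds log2(e) together with the 1/sqrt(dim) softmax scaling.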
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def blocksparse_flashattn(
Q: T.Tensor(shape, dtype),
K: T.Tensor(shape, dtype),
V: T.Tensor(shape, dtype),
BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype),
Output: T.Tensor(shape, dtype),
):
with T.Kernel(
T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
block_mask = T.alloc_local([downsample_len], block_mask_dtype)
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
for vj in T.serial(downsample_len):
block_mask[vj] = BlockSparseMask[bz, by, bx, vj]
loop_range = (
T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
(bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))
for k in T.Pipelined(loop_range, num_stages=num_stages):
if block_mask[k] != 0:
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
return blocksparse_flashattn
return kernel_func(block_M, block_N, num_stages, threads)
def test_topk_sparse_attention():
# Config
BATCH, N_HEADS, SEQ_LEN, D_HEAD = 1, 1, 256, 64
TOPK = 2  # Keep the top-2 blocks per query block row
BLOCK = 64
torch.manual_seed(0)
# Create inputs
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
sm_scale = 1.0 / (D_HEAD**0.5)
# Create sparse mask (downsampled to block level)
downsample_factor = BLOCK
downsample_len = math.ceil(SEQ_LEN / downsample_factor)
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
device='cuda',
dtype=torch.bfloat16)
x_ds[:, :, :, 0] = 100
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
# Run tilelang kernel
kernel = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
tilelang_output = kernel(q, k, v, block_mask)
# Compute reference
# Expand block mask to full attention matrix
full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda'))
full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal
# PyTorch reference implementation
attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
attn = attn.masked_fill(~full_mask, float('-inf'))
attn = F.softmax(attn, dim=-1)
ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)
print("ref_output", ref_output)
print("tilelang_output", tilelang_output)
# Verify accuracy
torch.testing.assert_close(tilelang_output, ref_output, atol=1e-2, rtol=1e-2)
print("Pass topk sparse attention test with qlen == klen")
def main():
test_topk_sparse_attention()
if __name__ == "__main__":
main()
# ruff: noqa
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
from einops import rearrange, einsum
import argparse
import time
import math
from heuristic import num_splits_heuristic
def flashattn(batch, heads, heads_kv, dim, dim_v):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
dtype = "float16"
accum_dtype = "float"
kv_group_num = heads // heads_kv
@tilelang.jit(
out_idx=[-1], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def kernel_func(block_N, block_H, page_block_size, num_split, num_stages, threads, num_pages,
max_num_blocks_per_seq, max_selected_blocks):
shape_q = [batch, heads, dim]
shape_k = [num_pages, page_block_size, heads_kv, dim]
shape_v = [num_pages, page_block_size, heads_kv, dim_v]
shape_indices = [batch, heads_kv, max_selected_blocks]
shape_block_table = [batch, max_num_blocks_per_seq]
shape_o = [batch, heads, dim_v]
part_shape = [batch, heads, num_split, dim_v]
valid_block_H = min(block_H, kv_group_num)
assert block_N <= page_block_size and page_block_size % block_N == 0
block_ratio = page_block_size // block_N
@T.macro
def flash_attn_split(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
block_indices: T.Tensor(shape_indices, "int32"),
cache_seqlens: T.Tensor([batch], "int32"),
block_table: T.Tensor(shape_block_table, "int32"),
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
):
with T.Kernel(
batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim_v], dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype)
scores_max = T.alloc_fragment([block_H], accum_dtype)
scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
scores_scale = T.alloc_fragment([block_H], accum_dtype)
scores_sum = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_fragment([block_H], accum_dtype)
has_valid_block = T.alloc_var("bool")
bid = bx
hid = by
sid = bz
cur_kv_head = hid // (kv_group_num // valid_block_H)
T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
num_blocks = max_selected_blocks
blocks_per_split = T.floordiv(num_blocks, num_split)
remaining_blocks = T.floormod(num_blocks, num_split)
loop_range = (blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0))
start = blocks_per_split * sid + T.min(sid, remaining_blocks)
has_valid_block = False
for k in T.Pipelined(loop_range, num_stages=num_stages):
logical_block_idx = block_indices[bid, cur_kv_head, start + k]
if logical_block_idx >= 0:
has_valid_block = True
block_table_idx = T.floordiv(logical_block_idx, block_ratio)
block_tile_idx = T.floormod(logical_block_idx, block_ratio)
physical_block_idx = block_table[bid, block_table_idx]
T.copy(
K[physical_block_idx,
block_tile_idx * block_N:(block_tile_idx + 1) * block_N,
cur_kv_head, :], K_shared)
T.clear(acc_s)
T.gemm(
Q_shared,
K_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow)
if k == 0:  # assumes block_indices is sorted in descending order; otherwise, remove this condition
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else(
logical_block_idx * block_N + j >= cache_seqlens[bid],
-T.infinity(accum_dtype), acc_s[i, j])
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.if_then_else(scores_max[i] > scores_max_prev[i],
scores_max[i], scores_max_prev[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale -
scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_H):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
for i, j in T.Parallel(block_H, dim_v):
acc_o[i, j] *= scores_scale[i]
T.copy(
V[physical_block_idx,
block_tile_idx * block_N:(block_tile_idx + 1) * block_N,
cur_kv_head, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
if has_valid_block:
for i, j in T.Parallel(block_H, dim_v):
acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
for i in T.Parallel(block_H):
if i < valid_block_H:
glse[bid, hid * valid_block_H + i, sid] = logsum[i]
for i, j in T.Parallel(block_H, dim_v):
if i < valid_block_H:
Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j]
@T.macro
def combine(
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
Output: T.Tensor(shape_o, dtype),
):
with T.Kernel(heads, batch, threads=128) as (by, bz):
po_local = T.alloc_fragment([dim_v], accum_dtype)
o_accum_local = T.alloc_fragment([dim_v], accum_dtype)
lse_local_split = T.alloc_local([1], accum_dtype)
lse_logsum_local = T.alloc_local([1], accum_dtype)
lse_max_local = T.alloc_local([1], accum_dtype)
scale_local = T.alloc_local([1], accum_dtype)
max_split = T.alloc_local([1], "int32")
T.annotate_layout({
lse_logsum_local:
T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
})
T.clear(lse_logsum_local)
T.clear(o_accum_local)
lse_max_local[0] = -T.infinity(accum_dtype)
for k in T.serial(num_split):
lse_local_split[0] = glse[bz, by, k]
if (lse_local_split[0] != 0):
max_split[0] = k
lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k])
for k in T.Pipelined(num_split, num_stages=1):
if k <= max_split[0]:
lse_local_split[0] = glse[bz, by, k]
lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
for k in T.serial(num_split):
if k <= max_split[0]:
for i in T.Parallel(dim_v):
po_local[i] = Output_partial[bz, by, k, i]
lse_local_split[0] = glse[bz, by, k]
scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
for i in T.Parallel(dim_v):
o_accum_local[i] += po_local[i] * scale_local[0]
for i in T.Parallel(dim_v):
Output[bz, by, i] = o_accum_local[i]
@T.prim_func
def main(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
block_indices: T.Tensor(shape_indices, "int32"),
cache_seqlens: T.Tensor([batch], "int32"),
block_table: T.Tensor(shape_block_table, "int32"),
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
Output: T.Tensor(shape_o, dtype),
):
flash_attn_split(Q, K, V, block_indices, cache_seqlens, block_table, glse,
Output_partial)
combine(glse, Output_partial, Output)
return main
return kernel_func
class SparseFlashAttn(torch.nn.Module):
def __init__(self, batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, num_pages):
super(SparseFlashAttn, self).__init__()
self.batch = batch
self.heads = heads
self.heads_kv = heads_kv
self.dim = dim
self.dim_v = dim_v
self.block_N = block_N
self.page_block_size = page_block_size
self.num_pages = num_pages
self.block_H = 64
self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
block_N=block_N,
block_H=self.block_H,
page_block_size=page_block_size,
num_split=T.dynamic("num_split"),
num_stages=2,
threads=128,
num_pages=num_pages,
max_num_blocks_per_seq=T.dynamic("max_num_blocks_per_seq"),
max_selected_blocks=T.dynamic("max_selected_blocks"),
)
props = torch.cuda.get_device_properties(torch.device("cuda:0"))
self.num_sm = props.multi_processor_count
def forward(self, query, key, value, block_indices, cache_seqlens, block_table):
batch = self.batch
heads = self.heads
heads_kv = self.heads_kv
dim_v = self.dim_v
dim = self.dim
block_size = self.block_N
max_selected_blocks = block_indices.shape[-1]
# Compute static scheduling parameters
num_m_blocks = 1 * (heads // heads_kv + self.block_H - 1) // self.block_H
num_n_blocks = max_selected_blocks
size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2
total_mblocks = batch * heads_kv * num_m_blocks
num_sm = self.num_sm
num_split = num_splits_heuristic(
total_mblocks,
num_sm,
num_n_blocks,
num_m_blocks,
size_one_kv_head,
is_causal_or_local=True,
max_splits=128)
glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda')
output_partial = torch.empty((batch, heads, num_split, dim_v),
dtype=torch.float32,
device='cuda')
output = self.kernel(
query,
key,
value,
block_indices,
cache_seqlens,
block_table,
glse,
output_partial,
)
return output
def ref_program_torch_paged(query, key_cache, value_cache, block_indices, cache_seqlens,
block_table, page_block_size, block_size):
"""
Paged version of sparse attention reference implementation.
Args:
query: [batch, heads, dim]
key_cache: [num_pages, page_block_size, heads_kv, dim]
value_cache: [num_pages, page_block_size, heads_kv, dim_v]
block_indices: [batch, heads_kv, max_selected_blocks] - logical block indices
cache_seqlens: [batch] - actual sequence lengths
block_table: [batch, max_num_blocks_per_seq] - maps logical to physical blocks
page_block_size: size of each page block
block_size: size of attention blocks (block_N)
"""
batch, heads, dim = query.shape
heads_kv = key_cache.shape[2]
dim_v = value_cache.shape[3]
num_head_groups = heads // heads_kv
scale = dim**0.5
# Reconstruct the full key and value tensors from paged cache
max_cache_seqlen = max(cache_seqlens).item()
key_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim),
dtype=key_cache.dtype,
device=key_cache.device)
value_full = torch.zeros((batch, heads_kv, max_cache_seqlen, dim_v),
dtype=value_cache.dtype,
device=value_cache.device)
# Reconstruct full tensors from paged cache using block_table
for b in range(batch):
seq_len = cache_seqlens[b].item()
num_blocks_needed = int(math.ceil(seq_len / page_block_size))
for block_idx in range(num_blocks_needed):
physical_block_idx = block_table[b, block_idx].item()
# Calculate the range of tokens for this block
start_token = block_idx * page_block_size
end_token = min(start_token + page_block_size, seq_len)
actual_block_size = end_token - start_token
# Copy from paged cache to full tensors
key_full[b, :, start_token:end_token, :] = key_cache[
physical_block_idx, :actual_block_size, :, :].transpose(0, 1)
value_full[b, :, start_token:end_token, :] = value_cache[
physical_block_idx, :actual_block_size, :, :].transpose(0, 1)
# Reshape query for grouped attention
query = rearrange(
query, 'b (h g) d -> b g h d',
g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim]
# Compute attention scores
scores = einsum(
query, key_full,
'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, heads_kv, seqlen_kv]
# Create sparse mask based on block_indices
sparse_mask = torch.zeros_like(scores)
# Apply sparse mask based on selected blocks
for b in range(batch):
for h in range(heads_kv):
valid_indices = block_indices[b, h] # Extract indices for this batch and head
for idx in valid_indices:
if idx >= 0: # Valid block index
start_pos = idx * block_size
end_pos = min(start_pos + block_size, max_cache_seqlen)
sparse_mask[b, :, h, start_pos:end_pos] = 1
# Apply sparse mask
scores = scores.masked_fill(sparse_mask == 0, float('-inf'))
# Apply causal mask based on actual sequence lengths
range_len = torch.arange(scores.shape[-1], device=scores.device).unsqueeze(0)
cache_seqlens_expanded = cache_seqlens.unsqueeze(1)
pad_mask = range_len >= cache_seqlens_expanded
pad_mask = pad_mask[:, None, None, :]
scores = scores.masked_fill(pad_mask, float('-inf'))
# Compute attention weights
attention = F.softmax(scores / scale, dim=-1)
# Apply attention to values
out = einsum(attention, value_full,
'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, heads_kv, dim]
# Reshape output back to original format
out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim]
return out
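# A minimal sketch (hypothetical helper) of the logical -> physical page lookup the
# reference above relies on: block_table maps each sequence's logical page index to a
# physical page in key_cache / value_cache.
def _example_gather_page(key_cache, block_table, seq_idx, logical_page_idx):
    physical_page_idx = block_table[seq_idx, logical_page_idx].item()
    # Returns the [page_block_size, heads_kv, dim] slice backing this logical page.
    return key_cache[physical_page_idx]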
def ref_program_fa(query, kcache, vcache, cache_seqlens, block_table):
# latency reference
# from flash_attn_interface import flash_attn_with_kvcache # fa3
from flash_attn import flash_attn_with_kvcache #fa2
query = query.unsqueeze(1)
output = flash_attn_with_kvcache(
query, kcache, vcache, cache_seqlens=cache_seqlens, block_table=block_table)
output = output.squeeze(1)
return output
def main(args):
batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v
sparse_ratio = args.sparse_ratio
block_N = args.block_N
page_block_size = args.page_block_size
num_blocks = args.num_pages # Use num_pages from args
# For dense case verification, set sparse_ratio to 0 to select all blocks
max_selected_blocks = int(math.ceil(max_cache_seqlen / block_N))
print("max_selected_blocks: ", max_selected_blocks)
dtype = torch.float16
# Generate random inputs
Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda')
cache_seqlens = torch.randint(
max_cache_seqlen // 2, max_cache_seqlen + 1, (batch,), dtype=torch.int32, device='cuda')
print("cache_seqlens: ", cache_seqlens)
K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda')
V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda')
# Create paged KV cache
K_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim), dtype=dtype, device='cuda')
V_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim_v),
dtype=dtype,
device='cuda')
# Create block table and block indices for dense case (all blocks selected)
max_num_blocks_per_seq = int(math.ceil(max_cache_seqlen / page_block_size))
print("max_num_blocks_per_seq: ", max_num_blocks_per_seq)
block_table = torch.zeros((batch, max_num_blocks_per_seq), dtype=torch.int32, device='cuda')
block_indices = torch.zeros((batch, heads_kv, max_selected_blocks),
dtype=torch.int32,
device='cuda')
# Fill block table and block indices and cache
# Create a pool of available physical blocks
total_blocks_needed = sum(
int(math.ceil(cache_seqlens[seq_idx].item() / page_block_size)) for seq_idx in range(batch))
available_blocks = list(range(total_blocks_needed))
import random
random.seed(42) # For reproducibility
random.shuffle(available_blocks)
# Fill block table with random physical block indices
block_assignment = {} # Map (seq_idx, block_idx) -> physical_block_idx
block_idx_counter = 0
for seq_idx in range(batch):
seq_len = cache_seqlens[seq_idx].item()
num_blocks_needed = int(math.ceil(seq_len / page_block_size))
# Assign random physical blocks for each sequence
for block_idx in range(num_blocks_needed):
physical_block_idx = available_blocks[block_idx_counter]
block_table[seq_idx, block_idx] = physical_block_idx
block_assignment[(seq_idx, block_idx)] = physical_block_idx
block_idx_counter += 1
print(f"Block table: {block_table}")
# Fill K_cache and V_cache with data from original K and V tensors using random block assignment
for seq_idx in range(batch):
seq_len = cache_seqlens[seq_idx].item()
num_blocks_needed = int(math.ceil(seq_len / page_block_size))
for block_idx in range(num_blocks_needed):
physical_block_idx = block_assignment[(seq_idx, block_idx)]
# Calculate the range of tokens for this block
start_token = block_idx * page_block_size
end_token = min(start_token + page_block_size, seq_len)
actual_block_size = end_token - start_token
# Copy K and V data to the paged cache
K_cache[physical_block_idx, :actual_block_size, :, :] = K[seq_idx,
start_token:end_token, :, :]
V_cache[physical_block_idx, :actual_block_size, :, :] = V[seq_idx,
start_token:end_token, :, :]
# Fill block_indices for sparse attention
# For dense case (verification), we select all blocks in reverse order
# For sparse case, we select a subset of blocks based on sparse_ratio
for seq_idx in range(batch):
seq_len = cache_seqlens[seq_idx].item()
num_tile = int(math.ceil(seq_len / block_N))
if sparse_ratio == 0.0:
# Dense case: select all blocks in reverse order
selected_blocks = min(num_tile, max_selected_blocks)
for head_idx in range(heads_kv):
for i in range(selected_blocks):
# Select blocks in reverse order (most recent first)
block_indices[seq_idx, head_idx, i] = num_tile - 1 - i
# Fill remaining slots with -1 (invalid)
for i in range(selected_blocks, max_selected_blocks):
block_indices[seq_idx, head_idx, i] = -1
else:
# Fill block_indices for all KV heads
num_selected = int(num_tile * (1.0 - sparse_ratio))
num_selected = max(1, min(num_selected, max_selected_blocks))
all_blocks = list(range(num_tile))
for head_idx in range(heads_kv):
selected_blocks = []
# Always include the most recent blocks
recent_blocks = 1
selected_blocks.append(num_tile - 1)
# Randomly select some earlier blocks
if num_selected > recent_blocks:
remaining_blocks = [b for b in all_blocks if b not in selected_blocks]
if remaining_blocks:
import random
random.seed(42) # For reproducibility
additional_blocks = random.sample(
remaining_blocks,
min(num_selected - recent_blocks, len(remaining_blocks)))
selected_blocks.extend(additional_blocks)
# Sort selected blocks in reverse order (most recent first)
selected_blocks.sort(reverse=True)
for i in range(len(selected_blocks)):
block_indices[seq_idx, head_idx, i] = selected_blocks[i]
# Fill remaining slots with -1 (invalid)
for i in range(len(selected_blocks), max_selected_blocks):
block_indices[seq_idx, head_idx, i] = -1
# Initialize sparse attention module
sparse_attn = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, page_block_size, block_N,
num_blocks)
output_sparse = sparse_attn.forward(Q, K_cache, V_cache, block_indices, cache_seqlens,
block_table)
import flash_attn # noqa: F401
output_ref_torch = ref_program_torch_paged(Q, K_cache, V_cache, block_indices, cache_seqlens,
block_table, page_block_size, block_N)
output_ref_fa = ref_program_fa(Q, K_cache, V_cache, cache_seqlens, block_table)
# Check correctness
if sparse_ratio == 0.0:
max_diff = torch.max(torch.abs(output_sparse - output_ref_fa)).item()
mean_diff = torch.mean(torch.abs(output_sparse - output_ref_fa)).item()
assert torch.allclose(
output_ref_fa, output_ref_torch, atol=1e-2), "Reference outputs do not match!"
else:
max_diff = torch.max(torch.abs(output_sparse - output_ref_torch)).item()
mean_diff = torch.mean(torch.abs(output_sparse - output_ref_torch)).item()
print(f"Max difference: {max_diff:.6f}")
print(f"Mean difference: {mean_diff:.6f}")
if max_diff < 1e-2:
print("✓ Verification PASSED: Results match within tolerance")
else:
print("✗ Verification FAILED: Results differ significantly")
# Performance measurement
for _ in range(10): # Warm-up
sparse_attn.forward(Q, K_cache, V_cache, block_indices, cache_seqlens, block_table)
torch.cuda.synchronize()
start_time = time.time()
for _ in range(100): # Run multiple times for averaging
sparse_attn.forward(Q, K_cache, V_cache, block_indices, cache_seqlens, block_table)
torch.cuda.synchronize()
end_time = time.time()
kernel_time = (end_time - start_time) / 100 * 1000 # Convert to ms
print(f"Kernel execution time: {kernel_time:.2f} ms")
# FA performance measurement
for _ in range(10): # Warm-up
ref_program_fa(Q, K_cache, V_cache, cache_seqlens, block_table)
torch.cuda.synchronize()
start_time_fa = time.time()
for _ in range(100): # Run multiple times for averaging
ref_program_fa(Q, K_cache, V_cache, cache_seqlens, block_table)
torch.cuda.synchronize()
end_time_fa = time.time()
kernel_time_fa = (end_time_fa - start_time_fa) / 100 * 1000 # Convert to ms
print(f"FA kernel execution time: {kernel_time_fa:.2f} ms")
print(f"Speedup: {kernel_time_fa / kernel_time:.2f}x")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--heads', type=int, default=32, help='heads')
parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv')
parser.add_argument(
'--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--dim_v', type=int, default=128, help='dim_v')
parser.add_argument('--sparse_ratio', type=float, default=0.0, help='sparse ratio')
parser.add_argument('--block_N', type=int, default=64, help='block_N')
parser.add_argument('--page_block_size', type=int, default=256, help='block size of pages')
parser.add_argument('--num_pages', type=int, default=1024, help='total number of pages')
args = parser.parse_args()
main(args)
import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
from einops import rearrange, einsum
import argparse
import time
import math
from heuristic import num_splits_heuristic
def flashattn(batch, heads, heads_kv, dim, dim_v):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
dtype = "float16"
accum_dtype = "float"
kv_group_num = heads // heads_kv
@tilelang.jit(
out_idx=[-1], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen,
max_selected_blocks):
shape_q = [batch, heads, dim]
shape_k = [batch, max_cache_seqlen, heads_kv, dim]
shape_v = [batch, max_cache_seqlen, heads_kv, dim_v]
shape_indices = [batch, heads_kv, max_selected_blocks]
shape_o = [batch, heads, dim_v]
part_shape = [batch, heads, num_split, dim_v]
valid_block_H = min(block_H, kv_group_num)
@T.macro
def flash_attn_split(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
block_indices: T.Tensor(shape_indices, "int32"),
cache_seqlens: T.Tensor([batch], "int32"),
# actual_num_blocks: T.Tensor([batch], "int32"),
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
):
with T.Kernel(
batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim_v], dtype)
# O_shared = T.alloc_shared([valid_block_H, dim_v], dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype)
scores_max = T.alloc_fragment([block_H], accum_dtype)
scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
scores_scale = T.alloc_fragment([block_H], accum_dtype)
scores_sum = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_fragment([block_H], accum_dtype)
has_valid_block = T.alloc_var("bool")
bid = bx
hid = by
sid = bz
cur_kv_head = hid // (kv_group_num // valid_block_H)
T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
num_blocks = max_selected_blocks
blocks_per_split = T.floordiv(num_blocks, num_split)
remaining_blocks = T.floormod(num_blocks, num_split)
loop_range = (blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0))
start = blocks_per_split * sid + T.min(sid, remaining_blocks)
has_valid_block = False
for k in T.Pipelined(loop_range, num_stages=num_stages):
i_s = block_indices[bid, cur_kv_head, start + k]
if i_s >= 0:
has_valid_block = True
T.copy(K[bid, i_s * block_N:(i_s + 1) * block_N, cur_kv_head, :], K_shared)
T.clear(acc_s)
T.gemm(
Q_shared,
K_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow)
if k == 0:  # assumes block_indices is sorted in descending order; otherwise, remove this condition
for i, j in T.Parallel(block_H, block_N):
acc_s[i,
j] = T.if_then_else(i_s * block_N + j >= cache_seqlens[bid],
-T.infinity(accum_dtype), acc_s[i, j])
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.if_then_else(scores_max[i] > scores_max_prev[i],
scores_max[i], scores_max_prev[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale -
scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_H):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
for i, j in T.Parallel(block_H, dim_v):
acc_o[i, j] *= scores_scale[i]
T.copy(V[bid, i_s * block_N:(i_s + 1) * block_N, cur_kv_head, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
if has_valid_block:
for i, j in T.Parallel(block_H, dim_v):
acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
for i in T.Parallel(block_H):
if i < valid_block_H:
glse[bid, hid * valid_block_H + i, sid] = logsum[i]
for i, j in T.Parallel(block_H, dim_v):
if i < valid_block_H:
Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j]
@T.macro
def combine(
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
Output: T.Tensor(shape_o, dtype),
):
with T.Kernel(heads, batch, threads=128) as (by, bz):
po_local = T.alloc_fragment([dim_v], accum_dtype)
o_accum_local = T.alloc_fragment([dim_v], accum_dtype)
lse_local_split = T.alloc_local([1], accum_dtype)
lse_logsum_local = T.alloc_local([1], accum_dtype)
lse_max_local = T.alloc_local([1], accum_dtype)
scale_local = T.alloc_local([1], accum_dtype)
max_split = T.alloc_local([1], "int32")
T.annotate_layout({
lse_logsum_local:
T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
})
T.clear(lse_logsum_local)
T.clear(o_accum_local)
lse_max_local[0] = -T.infinity(accum_dtype)
for k in T.serial(num_split):
lse_local_split[0] = glse[bz, by, k]
if (lse_local_split[0] != 0):
max_split[0] = k
lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k])
for k in T.Pipelined(num_split, num_stages=1):
if k <= max_split[0]:
lse_local_split[0] = glse[bz, by, k]
lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
for k in T.serial(num_split):
if k <= max_split[0]:
for i in T.Parallel(dim_v):
po_local[i] = Output_partial[bz, by, k, i]
lse_local_split[0] = glse[bz, by, k]
scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
for i in T.Parallel(dim_v):
o_accum_local[i] += po_local[i] * scale_local[0]
for i in T.Parallel(dim_v):
Output[bz, by, i] = o_accum_local[i]
@T.prim_func
def main(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
block_indices: T.Tensor(shape_indices, "int32"),
cache_seqlens: T.Tensor([batch], "int32"),
# actual_num_blocks: T.Tensor([batch], "int32"),
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
Output: T.Tensor(shape_o, dtype),
):
# flash_attn_split(Q, K, V, block_indices, cache_seqlens, actual_num_blocks, glse, Output_partial)
flash_attn_split(Q, K, V, block_indices, cache_seqlens, glse, Output_partial)
combine(glse, Output_partial, Output)
return main
return kernel_func
class SparseFlashAttn(torch.nn.Module):
def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size):
super(SparseFlashAttn, self).__init__()
self.batch = batch
self.heads = heads
self.heads_kv = heads_kv
self.dim = dim
self.dim_v = dim_v
self.block_size = block_size
self.block_H = 64
self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
block_N=block_size,
block_H=self.block_H,
num_split=T.dynamic("num_split"),
num_stages=2,
threads=128,
max_cache_seqlen=T.dynamic("max_cache_seqlen"),
max_selected_blocks=T.dynamic("max_selected_blocks"))
props = torch.cuda.get_device_properties(torch.device("cuda:0"))
self.num_sm = props.multi_processor_count
def forward(self, query, key, value, block_indices, cache_seqlens):
batch = self.batch
heads = self.heads
heads_kv = self.heads_kv
dim_v = self.dim_v
dim = self.dim
block_size = self.block_size
max_selected_blocks = block_indices.shape[-1]
# Compute static scheduling parameters
num_m_blocks = 1 * (heads // heads_kv + self.block_H - 1) // self.block_H
num_n_blocks = max_selected_blocks
size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2
total_mblocks = batch * heads_kv * num_m_blocks
num_sm = self.num_sm
num_split = num_splits_heuristic(
total_mblocks,
num_sm,
num_n_blocks,
num_m_blocks,
size_one_kv_head,
is_causal_or_local=True,
max_splits=128)
glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda')
output_partial = torch.empty((batch, heads, num_split, dim_v),
dtype=torch.float32,
device='cuda')
output = self.kernel(query, key, value, block_indices, cache_seqlens, glse, output_partial)
return output
def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seqlens,
max_cache_seqlen, block_size):
"""
Args:
query: [batch, heads, dim]
key: [batch, max_cache_seqlen, heads_kv, dim]
value: [batch, max_cache_seqlen, heads_kv, dim_v]
block_indices: [batch, heads_kv, max_selected_blocks], indices of selected blocks, -1 for padding
cache_seqlens: [batch], sequence lengths of the kvcache
max_cache_seqlen: maximum sequence length of kvcache
block_size: block size
Returns:
output: [batch, heads, dim_v]
"""
batch, heads, dim = query.shape
heads_kv = key.shape[2]
dim_v = value.shape[-1]
max_selected_blocks = block_indices.shape[-1]
block_H = 64
actual_num_blocks = torch.sum(block_indices != -1, dim=-1).to(torch.int32)
# [batch]: number of valid blocks; assumes all groups in the same batch have the same count
actual_num_blocks = actual_num_blocks[:, 0]
# get num_split
num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
num_n_blocks = max_selected_blocks #(kv_seqlen + block_size - 1 ) // block_size
# num_n_blocks = torch.sum(actual_num_blocks, dim=-1).item() * heads_kv # total number of blocks
size_one_kv_head = max_selected_blocks * block_size * (
dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2
total_mblocks = batch * heads_kv * num_m_blocks
num_sm = 132  # assumed SM count (e.g., NVIDIA H100); adjust or query the device on other GPUs
num_split = num_splits_heuristic(
total_mblocks,
num_sm,
num_n_blocks,
num_m_blocks,
size_one_kv_head,
is_causal_or_local=True,
max_splits=128)
glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda')
Output_partial = torch.empty((batch, heads, num_split, dim_v),
dtype=torch.float32,
device='cuda')
kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
block_N=block_size,
block_H=block_H,
num_split=T.dynamic("num_split"),
num_stages=2,
threads=128,
max_cache_seqlen=T.dynamic("max_cache_seqlen"),
max_selected_blocks=T.dynamic("max_selected_blocks"))
output = kernel(query, key, value, block_indices, cache_seqlens, glse, Output_partial)
return output
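# A minimal usage sketch for the wrapper above; the tensors here are hypothetical
# placeholders that only illustrate the expected shapes and dtypes.
def _example_sparse_decode():
    batch, heads, heads_kv, dim, dim_v = 2, 32, 8, 128, 128
    max_cache_seqlen, block_size, max_selected_blocks = 4096, 32, 16
    q = torch.randn(batch, heads, dim, dtype=torch.float16, device='cuda')
    k = torch.randn(batch, max_cache_seqlen, heads_kv, dim, dtype=torch.float16, device='cuda')
    v = torch.randn(batch, max_cache_seqlen, heads_kv, dim_v, dtype=torch.float16, device='cuda')
    cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda')
    # -1 marks padding; valid entries index kv blocks and are sorted in descending order.
    block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1,
                               dtype=torch.int32, device='cuda')
    block_indices[:, :, 0] = max_cache_seqlen // block_size - 1  # keep the most recent block
    return sparse_gqa_decode_varlen_indice(q, k, v, block_indices, cache_seqlens,
                                           max_cache_seqlen, block_size)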
def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks,
block_size):
batch, heads, dim = query.shape
heads_kv = key.shape[2]
num_head_groups = query.shape[1] // key.shape[2]
scale = dim**0.5
key = rearrange(key, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim]
value = rearrange(value, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim]
query = rearrange(
query, 'b (h g) d -> b g h d',
g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim]
scores = einsum(
query, key,
'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, heads_kv, seqlen_kv]
sparse_mask = torch.zeros_like(scores)
# Assign mask values based on block_indices
for b in range(batch):
for h in range(heads_kv):
valid_indices = block_indices[b, h] # Extract indices for this batch and head
for idx in valid_indices:
if idx >= 0:
sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1
scores = scores.masked_fill(sparse_mask == 0, float('-inf'))
range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0)
cache_seqlens_expanded = cache_seqlens.unsqueeze(1)
pad_mask = range_len >= cache_seqlens_expanded
pad_mask = pad_mask[:, None, None, :]
scores = scores.masked_fill(pad_mask, float('-inf'))
attention = F.softmax(
scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv]
out = einsum(attention, value,
'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, heads_kv, dim]
out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim]
return out
def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks,
block_size):
# latency reference
# from flash_attn_interface import flash_attn_with_kvcache # fa3
from flash_attn import flash_attn_with_kvcache #fa2
query = query.unsqueeze(1)
output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens)
output = output.squeeze(1)
return output
def debug(name, expect, actual, atol=1e-3, rtol=1e-3):
all_close = torch.allclose(expect, actual, atol=atol, rtol=rtol)
print(name + " all_close={}".format(all_close))
if not all_close:
diff = (expect - actual).abs()
print("all_close={}, max={}, min={}, mean={}".format(all_close,
diff.max().item(),
diff.min().item(),
diff.mean().item()))
max_indices = torch.nonzero(diff == diff.max().item())
first_index = tuple(max_indices[0].tolist())
print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}")
def main(batch=8,
heads=32,
heads_kv=8,
max_cache_seqlen=8192,
dim=128,
dim_v=128,
sparse_ratio=0.8,
block_size=32):
max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size))
print("max_selected_blocks: ", max_selected_blocks)
dtype = torch.float16
Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda')
K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda')
V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda')
cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda')
# cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda')
# # Ensure at least one element equals cache_seqlen
# random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index
# # cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence
print("cache_seqlens: ", cache_seqlens)
max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int()
print("max_valid_num_blocks: ", max_valid_num_blocks)
# Initialize block_indices with -1 (for padding blocks)
block_indices = torch.full((batch, heads_kv, max_selected_blocks),
-1,
dtype=torch.int32,
device='cuda')
# max_num_blocks = int((max_cache_seqlen + block_size - 1)/ block_size)
# block_indices = torch.full((batch, heads_kv, max_num_blocks), -1, dtype=torch.int32, device='cuda')
# Assign valid indices while ensuring no duplicates within each batch-group
for b in range(batch):
max_valid_block = max_valid_num_blocks[b].item() # Max valid blocks for this batch
if max_valid_block > 0: # Ensure there's at least one valid block
for h in range(heads_kv):
valid_indices = torch.randperm(
max_valid_block, device='cuda', dtype=torch.int32)[:max_selected_blocks]
# valid_indices = torch.randperm(max_valid_block, device='cuda', dtype=torch.int32)[:max_num_blocks]
block_indices[b, h, :len(valid_indices)] = valid_indices
# Sort indices within each batch-group for consistency
block_indices, _ = block_indices.sort(dim=-1, descending=True)
# print("block_indices: ", block_indices)
actual_num_blocks = torch.sum(block_indices != -1, dim=-1).to(torch.int32)[:, 0]
print("actual_num_blocks: ", actual_num_blocks)
# print(block_indices.shape, actual_num_blocks.shape)
max_num_blocks = torch.max(max_valid_num_blocks).item()
print("max_num_blocks: ", max_num_blocks)
# parity reference
ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks,
block_size)
sparse_kernel = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size)
out = sparse_kernel(Q, K, V, block_indices, cache_seqlens)
debug("output", ref, out, atol=1e-3, rtol=1e-3)
import flash_attn # noqa: F401
## latency reference
for _ in range(10):
ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen,
max_num_blocks, block_size)
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen,
max_num_blocks, block_size)
torch.cuda.synchronize()
print("dense time: ", (time.time() - start) / 100 * 1000)
for _ in range(10):
# out = sparse_gqa_decode_varlen_indice(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, block_size)
out = sparse_kernel(Q, K, V, block_indices, cache_seqlens)
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
# out = sparse_gqa_decode_varlen_indice(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, block_size)
out = sparse_kernel(Q, K, V, block_indices, cache_seqlens)
torch.cuda.synchronize()
print("sparse time: ", (time.time() - start) / 100 * 1000)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--heads', type=int, default=32, help='heads')
parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv')
parser.add_argument(
'--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--dim_v', type=int, default=128, help='dim_v')
parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio')
parser.add_argument('--block_size', type=int, default=32, help='block_size')
args = parser.parse_args()
main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v,
args.sparse_ratio, args.block_size)
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
from einops import rearrange, einsum
import argparse
import time
import math
from heuristic import num_splits_heuristic
def flashattn(batch, heads, heads_kv, dim, dim_v):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
dtype = "float16"
accum_dtype = "float"
kv_group_num = heads // heads_kv
@tilelang.jit(
out_idx=[-1], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, num_blocks):
shape_q = [batch, heads, dim]
shape_k = [batch, max_cache_seqlen, heads_kv, dim]
shape_v = [batch, max_cache_seqlen, heads_kv, dim_v]
shape_mask = [batch, heads_kv, num_blocks]
shape_o = [batch, heads, dim_v]
part_shape = [batch, heads, num_split, dim_v]
valid_block_H = min(block_H, kv_group_num)
@T.macro
def flash_attn_split(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
block_mask: T.Tensor(shape_mask, "bool"),
cache_seqlens: T.Tensor([batch], "int32"),
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
):
with T.Kernel(
batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim_v], dtype)
# O_shared = T.alloc_shared([valid_block_H, dim_v], dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype)
scores_max = T.alloc_fragment([block_H], accum_dtype)
scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
scores_scale = T.alloc_fragment([block_H], accum_dtype)
scores_sum = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_fragment([block_H], accum_dtype)
has_valid_block = T.alloc_var("bool")
bid = bx
hid = by
sid = bz
cur_kv_head = hid // (kv_group_num // valid_block_H)
T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
blocks_per_split = T.floordiv(num_blocks, num_split)
remaining_blocks = T.floormod(num_blocks, num_split)
loop_range = (blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0))
start = blocks_per_split * sid + T.min(sid, remaining_blocks)
has_valid_block = False
for k in T.Pipelined(loop_range, num_stages=num_stages):
if block_mask[bid, hid, start + k]:
has_valid_block = True
T.copy(
K[bid, (start + k) * block_N:(start + k + 1) * block_N, cur_kv_head, :],
K_shared)
T.clear(acc_s)
T.gemm(
Q_shared,
K_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else((start + k) * block_N + j
>= cache_seqlens[bx],
-T.infinity(accum_dtype), acc_s[i, j])
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale -
scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_H):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
for i, j in T.Parallel(block_H, dim_v):
acc_o[i, j] *= scores_scale[i]
T.copy(
V[bid, (start + k) * block_N:(start + k + 1) * block_N, cur_kv_head, :],
V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
if has_valid_block:
for i, j in T.Parallel(block_H, dim_v):
acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
for i in T.Parallel(block_H):
if i < valid_block_H:
glse[bid, hid * valid_block_H + i, sid] = logsum[i]
for i, j in T.Parallel(block_H, dim_v):
if i < valid_block_H:
Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j]
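        # combine merges the per-split partial outputs using their log-sum-exp values,
        # so the final result matches a single-pass softmax over all selected blocks.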
@T.macro
def combine(
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
Output: T.Tensor(shape_o, dtype),
):
with T.Kernel(heads, batch, threads=128) as (by, bz):
po_local = T.alloc_fragment([dim_v], accum_dtype)
o_accum_local = T.alloc_fragment([dim_v], accum_dtype)
lse_local_split = T.alloc_local([1], accum_dtype)
lse_logsum_local = T.alloc_local([1], accum_dtype)
lse_max_local = T.alloc_local([1], accum_dtype)
scale_local = T.alloc_local([1], accum_dtype)
T.annotate_layout({
lse_logsum_local:
T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
})
T.clear(lse_logsum_local)
T.clear(o_accum_local)
lse_max_local[0] = -T.infinity(accum_dtype)
for k in T.serial(num_split):
lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k])
for k in T.Pipelined(num_split, num_stages=1):
lse_local_split[0] = glse[bz, by, k]
lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
for k in T.serial(num_split):
for i in T.Parallel(dim_v):
po_local[i] = Output_partial[bz, by, k, i]
lse_local_split[0] = glse[bz, by, k]
scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
for i in T.Parallel(dim_v):
o_accum_local[i] += po_local[i] * scale_local[0]
for i in T.Parallel(dim_v):
Output[bz, by, i] = o_accum_local[i]
@T.prim_func
def main(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
block_mask: T.Tensor(shape_mask, "bool"),
cache_seqlens: T.Tensor([batch], "int32"),
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
Output: T.Tensor(shape_o, dtype),
):
flash_attn_split(Q, K, V, block_mask, cache_seqlens, glse, Output_partial)
combine(glse, Output_partial, Output)
return main
return kernel_func
class SparseFlashAttn(torch.nn.Module):
def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size):
super(SparseFlashAttn, self).__init__()
self.batch = batch
self.heads = heads
self.heads_kv = heads_kv
self.dim = dim
self.dim_v = dim_v
self.block_size = block_size
self.block_H = 64
self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
block_N=block_size,
block_H=self.block_H,
num_split=T.dynamic("num_split"),
num_stages=2,
threads=128,
max_cache_seqlen=T.dynamic("max_cache_seqlen"),
num_blocks=T.dynamic("num_blocks"))
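        # num_split, max_cache_seqlen and num_blocks are declared dynamic, so the kernel
        # compiled here can be reused across cache lengths and split counts at call time.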
props = torch.cuda.get_device_properties(torch.device("cuda:0"))
self.num_sm = props.multi_processor_count
def forward(self, query, key, value, block_mask, cache_seqlens):
batch = self.batch
heads = self.heads
heads_kv = self.heads_kv
dim_v = self.dim_v
dim = self.dim
block_size = self.block_size
block_H = self.block_H
max_cache_seqlen = key.shape[1]
# get num_split
max_selected_blocks = (max_cache_seqlen + block_size - 1) // block_size
num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
num_n_blocks = max_selected_blocks
size_one_kv_head = max_selected_blocks * block_size * (
dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2
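        # size_one_kv_head approximates the bytes of K and V read per KV head
        # (fp16 -> 2 bytes/element); the heuristic below weighs this against the
        # cost of the extra combine pass when choosing num_split.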
total_mblocks = batch * heads_kv * num_m_blocks
# num_sm = 132
num_sm = self.num_sm
num_split = num_splits_heuristic(
total_mblocks,
num_sm,
num_n_blocks,
num_m_blocks,
size_one_kv_head,
is_causal_or_local=True,
max_splits=128)
# print("num_split: ", num_split)
glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda')
Output_partial = torch.empty((batch, heads, num_split, dim_v),
dtype=torch.float32,
device='cuda')
output = self.kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial)
return output
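# Minimal usage sketch for the module above (hypothetical shapes; CUDA and float16 assumed):
#   attn = SparseFlashAttn(batch=8, heads=32, heads_kv=8, dim=128, dim_v=128, block_size=32)
#   out = attn(query, key, value, block_mask, cache_seqlens)   # -> [batch, heads, dim_v]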
def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens, block_size):
"""
Args:
query: [batch, heads, dim]
key: [batch, max_cache_seqlen, heads_kv, dim]
value: [batch, max_cache_seqlen, heads_kv, dim_v]
block_mask: [batch, heads_kv, num_blocks], mask for valid blocks
cache_seqlens: [batch], sequence lengths of the kvcache
block_size: block size
Returns:
output: [batch, heads, dim_v]
"""
batch, heads, dim = query.shape
heads_kv = key.shape[2]
dim_v = value.shape[-1]
block_H = 64
actual_num_blocks = torch.sum(block_mask, dim=-1).to(torch.int32)
    # [batch]: number of valid blocks; assumes every KV-head group in a batch selects the same count
    actual_num_blocks = actual_num_blocks[:, 0]
max_selected_blocks = actual_num_blocks.max().item()
# get num_split
num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
num_n_blocks = max_selected_blocks #(kv_seqlen + block_size - 1 ) // block_size
# num_n_blocks = torch.sum(actual_num_blocks, dim=-1).item() * heads_kv # total number of blocks
size_one_kv_head = max_selected_blocks * block_size * (
dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2
total_mblocks = batch * heads_kv * num_m_blocks
    # number of SMs on the current device (e.g. 132 on an H100); feeds the split heuristic
    num_sm = torch.cuda.get_device_properties(query.device).multi_processor_count
num_split = num_splits_heuristic(
total_mblocks,
num_sm,
num_n_blocks,
num_m_blocks,
size_one_kv_head,
is_causal_or_local=True,
max_splits=128)
kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
block_N=block_size,
block_H=block_H,
num_split=T.dynamic("num_split"),
num_stages=2,
threads=128,
max_cache_seqlen=T.dynamic("max_cache_seqlen"),
num_blocks=T.dynamic("num_blocks"))
glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda')
Output_partial = torch.empty((batch, heads, num_split, dim_v),
dtype=torch.float32,
device='cuda')
# print(kernel.get_kernel_source())
output = kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial)
return output
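# Note: this functional wrapper rebuilds the tilelang kernel wrapper on every call (any compile
# caching is left to tilelang.jit); SparseFlashAttn above constructs it once in __init__, which
# is the cheaper entry point for repeated decoding steps.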
def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks,
block_size):
batch, heads, dim = query.shape
heads_kv = key.shape[2]
num_head_groups = query.shape[1] // key.shape[2]
scale = dim**0.5
key = rearrange(key, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim]
    value = rearrange(value, 'b n h d -> b h n d')  # [batch_size, heads_kv, seqlen_kv, dim_v]
query = rearrange(
query, 'b (h g) d -> b g h d',
g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim]
scores = einsum(
query, key,
'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, heads_kv, seqlen_kv]
sparse_mask = torch.zeros_like(scores)
# Assign mask values
for b in range(batch):
for h in range(heads_kv):
for idx in range(num_blocks):
if block_mask[b, h, idx]:
sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1
scores = scores.masked_fill(sparse_mask == 0, float('-inf'))
range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0)
cache_seqlens_expanded = cache_seqlens.unsqueeze(1)
pad_mask = range_len >= cache_seqlens_expanded
pad_mask = pad_mask[:, None, None, :]
scores = scores.masked_fill(pad_mask, float('-inf'))
attention = F.softmax(
scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv]
out = einsum(attention, value,
'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, heads_kv, dim]
out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim]
return out
def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks,
block_size):
# latency reference
# from flash_attn_interface import flash_attn_with_kvcache # fa3
from flash_attn import flash_attn_with_kvcache #fa2
query = query.unsqueeze(1)
output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens)
output = output.squeeze(1)
return output
def debug(name, expect, actual, atol=1e-3, rtol=1e-3):
all_close = torch.allclose(expect, actual, atol=atol, rtol=rtol)
print(name + " all_close={}".format(all_close))
if not all_close:
# print(expect[3, 28])
# print(actual[3, 28])
diff = (expect - actual).abs()
print("all_close={}, max={}, min={}, mean={}".format(all_close,
diff.max().item(),
diff.min().item(),
diff.mean().item()))
max_indices = torch.nonzero(diff == diff.max().item())
first_index = tuple(max_indices[0].tolist())
print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}")
def main(batch=8,
heads=32,
heads_kv=8,
max_cache_seqlen=8192,
dim=128,
dim_v=128,
sparse_ratio=0.8,
block_size=32):
max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size))
print("max_selected_blocks: ", max_selected_blocks)
dtype = torch.float16
Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda')
K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda')
V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda')
cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda')
# Ensure at least one element equals cache_seqlen
random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index
cache_seqlens[
random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence
# cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda')
print("cache_seqlens: ", cache_seqlens)
num_blocks = (max_cache_seqlen + block_size - 1) // block_size
valid_num_blocks = torch.ceil(cache_seqlens * (1 - sparse_ratio) / block_size).int()
print("valid_num_blocks: ", valid_num_blocks)
max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int()
print("max_valid_num_blocks: ", max_valid_num_blocks)
# Initialize block_mask with false (for padding blocks)
block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device='cuda')
# Assign valid indices while ensuring no duplicates within each batch-group
for b in range(batch):
max_valid_block = max_valid_num_blocks[b].item() # Max valid blocks for this batch
valid_num_block = valid_num_blocks[b].item() # Valid blocks for this batch
if valid_num_block > 0: # Ensure there's at least one valid block
for h in range(heads_kv):
perm = torch.randperm(max_valid_block, device='cuda')[:valid_num_block]
block_mask[b, h, perm] = True
# print("block_mask: ", block_mask)
# parity reference
ref = ref_program_torch(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks,
block_size)
# out = sparse_gqa_decode_varlen_mask(Q, K, V, block_mask, cache_seqlens, block_size)
model = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size)
out = model(Q, K, V, block_mask, cache_seqlens)
debug("output", ref, out, atol=1e-3, rtol=1e-3)
import flash_attn # noqa: F401
## latency reference
for _ in range(10):
ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks,
block_size)
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks,
block_size)
torch.cuda.synchronize()
print("dense time: ", (time.time() - start) / 100 * 1000)
for _ in range(10):
# out = sparse_gqa_decode_varlen_mask(Q, K, V, block_mask, cache_seqlens, block_size)
out = model(Q, K, V, block_mask, cache_seqlens)
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
# out = sparse_gqa_decode_varlen_mask(Q, K, V, block_mask, cache_seqlens, block_size)
out = model(Q, K, V, block_mask, cache_seqlens)
torch.cuda.synchronize()
print("sparse time: ", (time.time() - start) / 100 * 1000)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--heads', type=int, default=32, help='heads')
parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv')
parser.add_argument(
'--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--dim_v', type=int, default=128, help='dim_v')
parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio')
parser.add_argument('--block_size', type=int, default=32, help='block_size')
args = parser.parse_args()
main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v,
args.sparse_ratio, args.block_size)
# ruff: noqa
import torch
import triton
import triton.language as tl
import argparse
from einops import rearrange, einsum
import torch.nn.functional as F
import math
import time
from heuristic import num_splits_heuristic
@triton.autotune(
configs=[
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
for num_warps in [1, 2, 4]\
for num_stages in [1, 2, 3, 4, 7]
],
key=['BLOCK_H', 'BLOCK_N', 'BLOCK_D'],
)
@triton.jit
def _split_kernel(
q_ptr,
k_cache_ptr,
v_cache_ptr,
cache_seqlens_ptr,
o_partial_ptr,
lse_partial_ptr,
mask_ptr,
sm_scale,
num_splits,
gqa_group_size,
max_selected_blocks,
stride_q_b,
stride_q_h,
stride_q_d,
stride_k_b,
stride_k_s,
stride_k_h,
stride_k_d,
stride_v_b,
stride_v_s,
stride_v_h,
stride_v_d,
stride_o_b,
stride_o_h,
stride_o_split,
stride_o_d,
stride_lse_b,
stride_lse_h,
stride_lse_split,
stride_mask_b,
stride_mask_h,
stride_mask_s,
BLOCK_H: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_D: tl.constexpr,
):
batch_idx = tl.program_id(0)
head_idx_kv = tl.program_id(1)
split_idx = tl.program_id(2)
head_idx_q = head_idx_kv * gqa_group_size
offs_h = tl.arange(0, BLOCK_H)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_D)
m_i = tl.full([BLOCK_H], float("-inf"), dtype=tl.float32)
l_i = tl.full([BLOCK_H], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_H, BLOCK_D], dtype=tl.float32)
cache_seqlens = tl.load(cache_seqlens_ptr + batch_idx)
num_blocks = max_selected_blocks
blocks_per_split = tl.floor(num_blocks / num_splits).to(tl.int32)
remaining_blocks = num_blocks % num_splits
if split_idx < remaining_blocks:
loop_range = blocks_per_split + 1
else:
loop_range = blocks_per_split
q_ptr += batch_idx * stride_q_b + head_idx_q * stride_q_h
k_cache_ptr += batch_idx * stride_k_b + head_idx_kv * stride_k_h + offs_n[
None, :] * stride_k_s + offs_d[:, None] * stride_k_d
v_cache_ptr += batch_idx * stride_v_b + head_idx_kv * stride_v_h + offs_n[:,
None] * stride_v_s + offs_d[
None, :] * stride_v_d
mask_ptr += batch_idx * stride_mask_b + head_idx_kv * stride_mask_h
q = tl.load(
q_ptr + offs_h[:, None] * stride_q_h + offs_d[None, :] * stride_q_d,
mask=offs_h[:, None] < gqa_group_size)
start = blocks_per_split * split_idx + tl.minimum(split_idx, remaining_blocks)
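    # Each split walks a contiguous slice of the selected KV blocks; the first
    # `remaining_blocks` splits take one extra block so the work stays balanced.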
for i in range(loop_range):
block_idx = tl.load(mask_ptr + (start + i) * stride_mask_s)
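        # block_indices are padded with -1 for unused slots; skip those entries.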
if block_idx >= 0:
start_n = block_idx * BLOCK_N
k_ptr = k_cache_ptr + start_n * stride_k_s
v_ptr = v_cache_ptr + start_n * stride_v_s
k = tl.load(k_ptr, mask=start_n + offs_n[None, :] < cache_seqlens, other=0.0)
v = tl.load(v_ptr, mask=start_n + offs_n[:, None] < cache_seqlens, other=0.0)
qk = tl.dot(q, k)
qk = qk * sm_scale
qk = tl.where(start_n + offs_n[None, :] < cache_seqlens, qk, float("-inf"))
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk -= m_ij[:, None]
p = tl.exp(qk)
l_ij = tl.sum(p, 1)
alpha = tl.exp(m_i - m_ij)
l_i = l_i * alpha + l_ij
acc = acc * alpha[:, None]
p = p.to(v.type.element_ty)
acc += tl.dot(p, v)
m_i = m_ij
m_i += tl.math.log(l_i)
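    # m_i now holds this split's per-row log-sum-exp (natural log); the merge kernel
    # uses it to weight the per-split partial outputs.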
l_recip = 1 / l_i[:, None]
acc = acc * l_recip
acc = acc.to(o_partial_ptr.dtype.element_ty)
lse_partial_ptr += batch_idx * stride_lse_b + (
head_idx_q + offs_h) * stride_lse_h + split_idx * stride_lse_split
tl.store(lse_partial_ptr, m_i, mask=offs_h < gqa_group_size)
o_partial_ptr += batch_idx * stride_o_b + (
head_idx_q +
offs_h[:, None]) * stride_o_h + split_idx * stride_o_split + offs_d[None, :] * stride_o_d
tl.store(o_partial_ptr, acc, mask=offs_h[:, None] < gqa_group_size)
@triton.autotune(
configs=[
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
for num_warps in [1, 2, 4]\
for num_stages in [1, 2, 3, 4, 7]
],
key=['BLOCK_D'],
)
@triton.jit
def _merge_kernel(
o_partial_ptr,
lse_partial_ptr,
o_ptr,
lse_partial_stride_b,
lse_partial_stride_h,
lse_partial_stride_split,
o_partial_stride_b,
o_partial_stride_h,
o_partial_stride_split,
o_partial_stride_d,
o_stride_b,
o_stride_h,
o_stride_d,
BLOCK_D: tl.constexpr,
num_splits: tl.constexpr,
num_splits_pow2: tl.constexpr,
):
batch_idx = tl.program_id(0)
head_idx = tl.program_id(1)
offs_splits = tl.arange(0, num_splits_pow2)
offs_d = tl.arange(0, BLOCK_D)
lse_offsets = lse_partial_ptr + batch_idx * lse_partial_stride_b + head_idx * lse_partial_stride_h
lse = tl.load(
lse_offsets + offs_splits * lse_partial_stride_split,
mask=offs_splits < num_splits,
other=float("-inf"))
lse_max = tl.max(lse)
o_offsets = o_partial_ptr + batch_idx * o_partial_stride_b + head_idx * o_partial_stride_h
o_partial = tl.load(
o_offsets + offs_splits[:, None] * o_partial_stride_split +
offs_d[None, :] * o_partial_stride_d,
mask=offs_splits[:, None] < num_splits)
sumexp_normalized_splitk = tl.exp(lse - lse_max)
sumexp_normalized = tl.sum(sumexp_normalized_splitk, axis=0)
numerator_normalized = tl.sum(o_partial * sumexp_normalized_splitk[:, None], axis=0)
acc = numerator_normalized / sumexp_normalized
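    # acc is the softmax-consistent combination of the per-split partials, computed
    # relative to lse_max for numerical stability.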
acc = acc.to(o_ptr.dtype.element_ty)
o_ptr += batch_idx * o_stride_b + head_idx * o_stride_h
tl.store(o_ptr + offs_d * o_stride_d, acc)
def block_sparse_flash_decode_gqa_indice_triton(
q,
k_cache,
v_cache,
cache_seqlens,
max_cache_seqlen,
max_selected_blocks,
block_indices,
block_size,
sm_scale=None,
):
batch, heads, dim = q.shape
if sm_scale is None:
sm_scale = 1 / math.sqrt(dim)
_, max_cache_seqlen_cache, heads_kv, dim_v = v_cache.shape
assert max_cache_seqlen == max_cache_seqlen_cache, "max_cache_seqlen mismatch"
group_size = heads // heads_kv
block_H = 16
num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H
num_n_blocks = max_selected_blocks
size_one_kv_head = max_selected_blocks * block_size * (
dim + dim_v) * 2 #kv_seqlen * (dim + dim_v) * 2
total_mblocks = batch * heads_kv * num_m_blocks
    # number of SMs on the current device; feeds the split heuristic
    num_sm = torch.cuda.get_device_properties(q.device).multi_processor_count
num_splits = num_splits_heuristic(
total_mblocks,
num_sm,
num_n_blocks,
num_m_blocks,
size_one_kv_head,
is_causal_or_local=True,
max_splits=128)
# print("num_splits:", num_splits, "num_blocks:", num_n_blocks)
num_splits_pow2 = triton.next_power_of_2(num_splits)
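    # tl.arange requires a power-of-two extent, so the merge kernel pads the split axis
    # to num_splits_pow2 and masks out the unused slots.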
o_partial = torch.empty((batch, heads, num_splits, dim_v), device=q.device, dtype=q.dtype)
lse_partial = torch.empty((batch, heads, num_splits), device=q.device, dtype=torch.float32)
BLOCK_D = dim
BLOCK_H = group_size if group_size > 16 else 16
grid = (batch, heads_kv, num_splits)
_split_kernel[grid](
q,
k_cache,
v_cache,
cache_seqlens,
o_partial,
lse_partial,
block_indices,
sm_scale,
num_splits,
group_size,
max_selected_blocks,
q.stride(0),
q.stride(1),
q.stride(2),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
v_cache.stride(0),
v_cache.stride(1),
v_cache.stride(2),
v_cache.stride(3),
o_partial.stride(0),
o_partial.stride(1),
o_partial.stride(2),
o_partial.stride(3),
lse_partial.stride(0),
lse_partial.stride(1),
lse_partial.stride(2),
block_indices.stride(0),
block_indices.stride(1),
block_indices.stride(2),
BLOCK_H=BLOCK_H,
BLOCK_N=block_size,
BLOCK_D=BLOCK_D,
)
output = torch.zeros((batch, heads, dim_v), device=q.device, dtype=q.dtype)
grid = (batch, heads)
_merge_kernel[grid](
o_partial,
lse_partial,
output,
lse_partial.stride(0),
lse_partial.stride(1),
lse_partial.stride(2),
o_partial.stride(0),
o_partial.stride(1),
o_partial.stride(2),
o_partial.stride(3),
output.stride(0),
output.stride(1),
output.stride(2),
BLOCK_D=dim_v,
num_splits=num_splits,
num_splits_pow2=num_splits_pow2,
)
return output
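# Minimal usage sketch (hypothetical shapes; CUDA and float16 assumed):
#   out = block_sparse_flash_decode_gqa_indice_triton(
#       q, k_cache, v_cache, cache_seqlens, max_cache_seqlen,
#       max_selected_blocks, block_indices, block_size)   # -> [batch, heads, dim_v]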
def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks,
block_size):
batch, heads, dim = query.shape
heads_kv = key.shape[2]
dim_v = value.shape[-1]
num_head_groups = query.shape[1] // key.shape[2]
scale = dim**0.5
key = rearrange(key, 'b n h d -> b h n d') # [batch_size, heads_kv, seqlen_kv, dim]
    value = rearrange(value, 'b n h d -> b h n d')  # [batch_size, heads_kv, seqlen_kv, dim_v]
query = rearrange(
query, 'b (h g) d -> b g h d',
g=num_head_groups) # [batch_size, num_head_groups, heads_kv, dim]
scores = einsum(
query, key,
'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, heads_kv, seqlen_kv]
sparse_mask = torch.zeros_like(scores)
# Assign mask values based on block_indices
for b in range(batch):
for h in range(heads_kv):
valid_indices = block_indices[b, h] # Extract indices for this batch and head
for idx in valid_indices:
if idx >= 0:
sparse_mask[b, :, h, idx * block_size:(idx + 1) * block_size] = 1
scores = scores.masked_fill(sparse_mask == 0, float('-inf'))
range_len = torch.arange(scores.shape[-1], device='cuda').unsqueeze(0)
cache_seqlens_expanded = cache_seqlens.unsqueeze(1)
pad_mask = range_len >= cache_seqlens_expanded
pad_mask = pad_mask[:, None, None, :]
scores = scores.masked_fill(pad_mask, float('-inf'))
attention = F.softmax(
scores / scale, dim=-1) # [batch_size, num_head_groups, heads_kv, seqlen_kv]
out = einsum(attention, value,
'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, heads_kv, dim]
out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim]
return out
def ref_program_fa(query, key, value, cache_seqlens):
# latency reference
# from flash_attn_interface import flash_attn_with_kvcache # fa3
from flash_attn import flash_attn_with_kvcache #fa2
query = query.unsqueeze(1)
output = flash_attn_with_kvcache(query, key, value, cache_seqlens=cache_seqlens)
output = output.squeeze(1)
return output
def main(batch=64,
heads=32,
heads_kv=8,
max_cache_seqlen=8192,
dim=128,
dim_v=128,
sparse_ratio=0.8,
block_size=32):
qk_flops = 2 * batch * heads * max_cache_seqlen * dim
pv_flops = 2 * batch * heads * max_cache_seqlen * dim_v
total_flops = qk_flops + pv_flops
max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size))
print("max_selected_blocks: ", max_selected_blocks)
dtype = torch.float16
block_H = 64
Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda')
K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda')
V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda')
cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda')
# cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda')
# Ensure at least one element equals cache_seqlen
random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index
cache_seqlens[
random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence
print("cache_seqlens: ", cache_seqlens)
max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int()
print("max_valid_num_blocks: ", max_valid_num_blocks)
# Initialize block_indices with -1 (for padding blocks)
block_indices = torch.full((batch, heads_kv, max_selected_blocks),
-1,
dtype=torch.int32,
device='cuda')
# Assign valid indices while ensuring no duplicates within each batch-group
for b in range(batch):
max_valid_block = max_valid_num_blocks[b].item() # Max valid blocks for this batch
if max_valid_block > 0: # Ensure there's at least one valid block
for h in range(heads_kv):
valid_indices = torch.randperm(
max_valid_block, device='cuda', dtype=torch.int32)[:max_selected_blocks]
block_indices[b, h, :len(valid_indices)] = valid_indices
# Sort indices within each batch-group for consistency
block_indices, _ = block_indices.sort(dim=-1, descending=True)
# print("block_indices: ", block_indices)
actual_num_blocks = torch.sum(block_indices != -1, dim=-1).to(torch.int32)[:, 0]
print("actual_num_blocks: ", actual_num_blocks)
# print(block_indices.shape, actual_num_blocks.shape)
max_num_blocks = torch.max(max_valid_num_blocks).item()
print("max_num_blocks: ", max_num_blocks)
ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks,
block_size)
triton_out = block_sparse_flash_decode_gqa_indice_triton(
Q,
K,
V,
cache_seqlens,
max_cache_seqlen,
max_selected_blocks,
block_indices,
block_size,
)
print("max difference: ", torch.max(torch.abs(ref - triton_out)))
assert torch.allclose(
ref, triton_out, atol=1e-2), "Output mismatch between Triton and reference implementation"
print("Passed the ref test!")
# Measure performance
torch.cuda.synchronize()
start = time.time()
for _ in range(1000):
block_sparse_flash_decode_gqa_indice_triton(
Q,
K,
V,
cache_seqlens,
max_cache_seqlen,
max_selected_blocks,
block_indices,
block_size,
)
torch.cuda.synchronize()
end = time.time()
elapsed_time = end - start
avg_time = elapsed_time / 1000
avg_flops = total_flops / avg_time
print(f"Average time: {avg_time:.6f} seconds")
    # Measure performance of reference implementation
    import flash_attn  # noqa: F401
    for _ in range(10):  # warm up so one-time setup is not timed
        ref_program_fa(Q, K, V, cache_seqlens)
    torch.cuda.synchronize()
    start = time.time()
for _ in range(1000):
ref_program_fa(Q, K, V, cache_seqlens)
torch.cuda.synchronize()
end = time.time()
elapsed_time_ref = end - start
avg_time_ref = elapsed_time_ref / 1000
avg_flops_ref = total_flops / avg_time_ref
print(f"Average time of ref: {avg_time_ref:.6f} seconds")
print(f"Speedup: {avg_time_ref / avg_time:.2f}x")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=64, help='batch size')
parser.add_argument('--heads', type=int, default=32, help='heads')
parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv')
parser.add_argument(
'--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--dim_v', type=int, default=128, help='dim_v')
parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio')
parser.add_argument('--block_size', type=int, default=32, help='block_size')
args = parser.parse_args()
main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v,
args.sparse_ratio, args.block_size)