Commit 9ef98d52 (unverified), authored by Gerald, committed by GitHub
Browse files

[Model][MiniMaxText01] Support MiniMaxText01 model inference (#13454)


Signed-off-by: qscqesze <475517977@qq.com>
Co-authored-by: qingjun <qingjun@minimaxi.com>
Co-authored-by: qscqesze <475517977@qq.com>
parent 93491aef
...@@ -503,6 +503,11 @@ See [this page](#generative-models) for more information on how to use generativ ...@@ -503,6 +503,11 @@ See [this page](#generative-models) for more information on how to use generativ
* `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc. * `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.
* ✅︎ * ✅︎
* ✅︎ * ✅︎
- * `MiniMaxText01ForCausalLM`
* MiniMax-Text
* `MiniMaxAI/MiniMax-Text-01`, etc.
*
* ✅︎
- * `Zamba2ForCausalLM` - * `Zamba2ForCausalLM`
* Zamba2 * Zamba2
* `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. * `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc.
......
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
from vllm.model_executor.layers.lightning_attn import (
linear_decode_forward_triton)
from vllm.platforms import current_platform
# Value lists consumed by the @pytest.mark.parametrize decorators below;
# pytest runs each test once per combination of the lists it references.
NUM_HEADS = [4, 8]  # number of attention heads
HEAD_SIZES = [64]  # size of the last dimension of q/k/v
BATCH_SIZES = [1, 2]  # number of sequences per call
SEQ_LENGTHS = [16]  # tokens per sequence (used by the prefill-style test)
DTYPES = [torch.float32]  # dtype of the randomly generated q/k/v
def reference_lightning_attention(q, k, v, ed, block_size, kv_history):
    """Sequential reference for the lightning attention recurrence.

    For every batch entry, position and head the running state is updated as
    ``state = exp(-ed[h]) * state + outer(k, v)`` and the output for that
    position is ``q @ state``. Nothing is parallelised, which keeps the code
    easy to compare against the Triton kernels.

    Args:
        q: Query tensor of shape [B, H, S, D].
        k: Key tensor of shape [B, H, S, D].
        v: Value tensor of shape [B, H, S, E].
        ed: Per-head decay rates, either 1-D ([H]) or already 4-D with the
            head on dimension 1.
        block_size: Accepted for signature parity; not used here.
        kv_history: Optional initial state of shape [B, H, D, E]. It is
            cloned, so the caller's tensor is left untouched.

    Returns:
        Tuple ``(output, kv)`` where ``output`` is [B, H, S, E] and ``kv`` is
        the final state duplicated along a new dimension 2, i.e.
        [B, H, 2, D, E].
    """
    batch, heads, seq_len, qk_dim = q.shape
    v_dim = v.shape[-1]
    out = torch.zeros((batch, heads, seq_len, v_dim),
                      dtype=q.dtype,
                      device=q.device)

    # Work on a private copy of the history (or start from zeros).
    if kv_history is not None:
        state = kv_history.clone()
    else:
        state = torch.zeros((batch, heads, qk_dim, v_dim),
                            dtype=q.dtype,
                            device=q.device)

    # Bring the decay factors into a [1, H, 1, 1] layout.
    decay = torch.exp(-ed)
    if ed.dim() == 1:
        decay = decay.view(1, -1, 1, 1)

    for b in range(batch):
        for t in range(seq_len):
            for h in range(heads):
                # Decay the old state first, then add the new outer product.
                state[b, h] = (decay[0, h, 0, 0] * state[b, h] +
                               torch.outer(k[b, h, t], v[b, h, t]))
                out[b, h, t] = torch.matmul(q[b, h, t], state[b, h])

    # Duplicate the state along dim 2 so the result is [B, H, 2, D, E].
    return out, torch.stack([state, state], dim=2)
def reference_linear_decode(q, k, v, kv_caches, slope_rate, slot_idx):
    """Reference implementation of single-token linear attention decode.

    For every non-padding batch entry ``b`` and head ``h`` the cache slot
    selected by ``slot_idx[b]`` is updated in place as
    ``cache = outer(k, v) + exp(-slope_rate[h]) * cache`` and the output is
    ``q @ cache``.

    Args:
        q: Query tensor of shape [B, H, 1, D].
        k: Key tensor of shape [B, H, 1, D].
        v: Value tensor of shape [B, H, 1, D].
        kv_caches: Cache tensor indexed as [slot, head, D, D]; modified in
            place.
        slope_rate: Per-head decay rates, H elements.
        slot_idx: Integer tensor with one cache slot per batch entry; -1
            marks a padding entry that is skipped.

    Returns:
        Tensor of shape [B, H * D]. Rows of padding entries stay zero.
    """
    B, H, _, D = q.shape
    output = torch.zeros(B, H * D, dtype=q.dtype, device=q.device)

    # One decay factor per head, computed once.
    decay = torch.exp(-slope_rate).view(-1, 1, 1)  # [H, 1, 1]

    for b in range(B):
        slot_id = slot_idx[b].item()

        # Skip padding positions
        if slot_id == -1:
            continue

        q_b = q[b, :, 0]  # [H, D]
        k_b = k[b, :, 0]  # [H, D]
        v_b = v[b, :, 0]  # [H, D]

        for h in range(H):
            # The cache is addressed by slot id, not by batch position, so
            # that this reference agrees with the Triton kernel when
            # slot_idx is not the identity mapping.
            kv_cache_old = kv_caches[slot_id, h]

            # New outer product plus the decayed previous state.
            kv_new = torch.outer(k_b[h], v_b[h]) + decay[h, 0, 0] * kv_cache_old

            output[b, h * D:(h + 1) * D] = torch.matmul(q_b[h], kv_new)
            kv_caches[slot_id, h] = kv_new

    return output
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_linear_decode_forward_triton(
    batch_size: int,
    num_heads: int,
    head_size: int,
    dtype: torch.dtype,
):
    """Compare the Triton decode path with the Python reference.

    Both implementations receive identical inputs and their own copy of the
    cache; the outputs and the updated caches must agree within tolerance.
    """
    torch.set_default_device("cuda")
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    current_platform.seed_everything(42)

    # Small-magnitude random inputs, generated in a fixed order.
    scale = 0.01
    token_shape = (batch_size, num_heads, 1, head_size)
    q = scale * torch.randn(*token_shape, dtype=dtype)
    k = scale * torch.randn(*token_shape, dtype=dtype)
    v = scale * torch.randn(*token_shape, dtype=dtype)
    kv_caches = scale * torch.randn(batch_size,
                                    num_heads,
                                    head_size,
                                    head_size,
                                    dtype=dtype,
                                    device="cuda")
    # The reference mutates its cache, so give it an independent copy.
    kv_caches_copy = kv_caches.clone()

    # Head h gets decay rate 0.1 * (h + 1).
    slope_rate = torch.tensor([0.1 * (h + 1) for h in range(num_heads)],
                              device="cuda")
    slot_idx = torch.arange(batch_size, device="cuda")

    triton_output = linear_decode_forward_triton(q, k, v, kv_caches,
                                                 slope_rate, slot_idx)
    reference_output = reference_linear_decode(q, k, v, kv_caches_copy,
                                               slope_rate, slot_idx)

    tolerances = dict(rtol=1e-1, atol=1e-1)
    torch.testing.assert_close(triton_output, reference_output, **tolerances)
    torch.testing.assert_close(kv_caches, kv_caches_copy, **tolerances)
    assert triton_output.shape == (batch_size, num_heads * head_size)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_linear_decode_forward_triton_with_padding(
    num_heads: int,
    head_size: int,
    dtype: torch.dtype,
):
    """Decode comparison where one batch entry is padding (slot -1).

    Only the non-padding rows of the output and of the caches are compared
    between the Triton path and the Python reference.
    """
    torch.set_default_device("cuda")
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    current_platform.seed_everything(42)

    batch_size = 4
    scale = 0.01
    token_shape = (batch_size, num_heads, 1, head_size)
    q = scale * torch.randn(*token_shape, dtype=dtype)
    k = scale * torch.randn(*token_shape, dtype=dtype)
    v = scale * torch.randn(*token_shape, dtype=dtype)
    kv_caches = scale * torch.randn(batch_size,
                                    num_heads,
                                    head_size,
                                    head_size,
                                    dtype=dtype,
                                    device="cuda")
    # The reference mutates its cache, so give it an independent copy.
    kv_caches_copy = kv_caches.clone()

    # Head h gets decay rate 0.1 * (h + 1).
    slope_rate = torch.tensor([0.1 * (h + 1) for h in range(num_heads)],
                              device="cuda")

    # The third entry is padding.
    slot_idx = torch.tensor([0, 1, -1, 2], device="cuda")

    triton_output = linear_decode_forward_triton(q, k, v, kv_caches,
                                                 slope_rate, slot_idx)
    reference_output = reference_linear_decode(q, k, v, kv_caches_copy,
                                               slope_rate, slot_idx)

    # Boolean mask selecting every output element of a non-padding row.
    is_valid = slot_idx != -1
    flat_mask = is_valid.unsqueeze(1).expand(-1, num_heads * head_size)

    atol, rtol = 1.5e-1, 1.5e-1

    for row in range(batch_size):
        if not is_valid[row]:
            continue
        torch.testing.assert_close(kv_caches[row],
                                   kv_caches_copy[row],
                                   rtol=rtol,
                                   atol=atol)

    torch.testing.assert_close(triton_output[flat_mask],
                               reference_output[flat_mask],
                               rtol=rtol,
                               atol=atol)
    assert triton_output.shape == (batch_size, num_heads * head_size)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENGTHS)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_lightning_attention_reference(
    batch_size: int,
    num_heads: int,
    head_size: int,
    seq_len: int,
    dtype: torch.dtype,
):
    """Compare ``lightning_attention`` with the sequential reference.

    Both are run on the same q/k/v/decay inputs, each with its own copy of
    the KV history; outputs and returned KV tensors must agree within
    tolerance.
    """
    torch.set_default_device("cuda")
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    current_platform.seed_everything(42)

    scale = 0.01
    seq_shape = (batch_size, num_heads, seq_len, head_size)
    q = scale * torch.randn(*seq_shape, dtype=dtype)
    k = scale * torch.randn(*seq_shape, dtype=dtype)
    v = scale * torch.randn(*seq_shape, dtype=dtype)

    # Head h gets decay rate 0.1 * (h + 1).
    ed = torch.tensor([0.1 * (h + 1) for h in range(num_heads)],
                      device="cuda")

    kv_history = scale * torch.randn(batch_size,
                                     num_heads,
                                     head_size,
                                     head_size,
                                     dtype=dtype,
                                     device="cuda")
    kv_history_clone = kv_history.clone()

    ref_output, ref_kv_cache = reference_lightning_attention(
        q, k, v, ed, 256, kv_history)

    from vllm.model_executor.layers.lightning_attn import lightning_attention
    actual_output, actual_kv_cache = lightning_attention(
        q, k, v, ed, 256, kv_history_clone)

    atol, rtol = 1.5e-1, 1.5e-1
    torch.testing.assert_close(ref_output, actual_output, rtol=rtol, atol=atol)
    torch.testing.assert_close(ref_kv_cache,
                               actual_kv_cache,
                               rtol=rtol,
                               atol=atol)

    assert ref_output.shape == (batch_size, num_heads, seq_len, head_size)
    assert ref_kv_cache.shape == actual_kv_cache.shape
...@@ -176,6 +176,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -176,6 +176,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True), trust_remote_code=True),
"MiniCPM3ForCausalLM": _HfExamplesInfo("openbmb/MiniCPM3-4B", "MiniCPM3ForCausalLM": _HfExamplesInfo("openbmb/MiniCPM3-4B",
trust_remote_code=True), trust_remote_code=True),
"MiniMaxText01ForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01",
trust_remote_code=True),
"MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"), "MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"),
"MixtralForCausalLM": _HfExamplesInfo("mistralai/Mixtral-8x7B-Instruct-v0.1"), # noqa: E501 "MixtralForCausalLM": _HfExamplesInfo("mistralai/Mixtral-8x7B-Instruct-v0.1"), # noqa: E501
"QuantMixtralForCausalLM": _HfExamplesInfo("mistral-community/Mixtral-8x22B-v0.1-AWQ"), # noqa: E501 "QuantMixtralForCausalLM": _HfExamplesInfo("mistral-community/Mixtral-8x22B-v0.1-AWQ"), # noqa: E501
......
...@@ -971,26 +971,34 @@ class ModelConfig: ...@@ -971,26 +971,34 @@ class ModelConfig:
return sum(not bc.attention.no_op return sum(not bc.attention.no_op
for bc in block_configs[start:end]) for bc in block_configs[start:end])
else: else:
# Hybrid model # Hybrid model Jamba
layers_block_type_value = getattr(self.hf_config, layers_block_type_value = getattr(self.hf_config,
"layers_block_type", None) "layers_block_type", None)
if layers_block_type_value is None: if layers_block_type_value is not None:
raise ValueError("The model is an hybrid without a " if hasattr(self.hf_text_config,
"layers_block_type in the hf_config, " "model_type") and (self.hf_text_config.model_type
"cannot determine the num of " == "zamba2"):
f"{block_type.value} layers") if attn_block_type:
return sum(t == "hybrid"
if hasattr(self.hf_text_config, for t in layers_block_type_value[start:end])
"model_type") and (self.hf_text_config.model_type else:
== "zamba2"): return self.get_num_layers(parallel_config)
if attn_block_type: return sum(t == block_type.value
return sum(t == "hybrid" for t in layers_block_type_value[start:end])
for t in layers_block_type_value[start:end])
else: # Hybrid model Minimax
return self.get_num_layers(parallel_config) attn_type_list = getattr(self.hf_config, "attn_type_list", None)
if attn_type_list:
return sum(t == 1 for t in attn_type_list[start:end])
if layers_block_type_value is None and attn_type_list is None:
raise ValueError(
"The model is an hybrid without a"
"layers_block_type or an attn_type_list in the hf_config,"
"cannot determine the num of "
f"{block_type.value} layers")
return sum(t == block_type.value return sum(t == 1 for t in attn_type_list[start:end])
for t in layers_block_type_value[start:end])
def get_multimodal_config(self) -> "MultiModalConfig": def get_multimodal_config(self) -> "MultiModalConfig":
""" """
......
...@@ -303,8 +303,11 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -303,8 +303,11 @@ class _AsyncLLMEngine(LLMEngine):
ctx.seq_group_metadata_list = seq_group_metadata_list ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs ctx.scheduler_outputs = scheduler_outputs
finished_requests_ids = self.scheduler[ if not scheduler_outputs.is_empty():
virtual_engine].get_and_reset_finished_requests_ids() # this will cause mamba_cache/minimax_cache failed
# to release finished_requests_ids of the last steps
finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()
# Maybe switch from async mode to sync mode # Maybe switch from async mode to sync mode
if not allow_async_output_proc and len(ctx.output_queue) > 0: if not allow_async_output_proc and len(ctx.output_queue) > 0:
......
# SPDX-License-Identifier: Apache-2.0
import torch
import triton
import triton.language as tl
from einops import rearrange
@triton.jit
def _fwd_diag_kernel(Q, K, V, Out, S, b: tl.constexpr, h: tl.constexpr, n,
                     d: tl.constexpr, e: tl.constexpr, BLOCK: tl.constexpr,
                     NUM_BLOCK, CBLOCK: tl.constexpr):
    # Intra-block ("diagonal") part of lightning attention: each query
    # attends, with exponential decay, to the keys at or before its own
    # position inside the same BLOCK-sized block.
    #
    # Launch layout as decoded below: program axis 0 enumerates
    # (batch-head, block) pairs, program axis 1 enumerates the CBLOCK-row
    # sub-blocks of a block. The pointer arithmetic addresses Q/K as
    # row-major [n, d] slabs and V/Out as row-major [n, e] slabs per
    # batch-head. `b` is not referenced in the body.
    off = tl.program_id(0)
    off_bh = off // NUM_BLOCK  # batch-head index
    off_block = off % NUM_BLOCK  # block index within the sequence
    off_cblock = tl.program_id(1)  # sub-block index within a block

    off_h = off_bh % h  # head index

    # Calculate base offsets for the current batch and head
    qk_offset = off_bh * n * d
    v_offset = off_bh * n * e
    o_offset = off_bh * n * e

    # Calculate offsets for the current block
    block_offset = off_block * BLOCK
    qk_block_offset = block_offset * d
    v_block_offset = block_offset * e
    o_block_offset = block_offset * e

    # Calculate offsets for the current sub-block
    cblock_offset = off_cblock * CBLOCK
    q_cblock_offset = cblock_offset * d
    o_cblock_offset = cblock_offset * e

    # Calculate pointers to the query, key, value, and output tensors.
    # Q and Out point at this program's own sub-block; K and V start at the
    # first sub-block of the block and are advanced inside the loop. K is
    # addressed transposed ([d, CBLOCK]) so it can be the right operand of
    # tl.dot.
    Q_block_ptr = (Q + qk_offset + qk_block_offset + q_cblock_offset +
                   tl.arange(0, CBLOCK)[:, None] * d +
                   tl.arange(0, d)[None, :])
    K_trans_block_ptr = (K + qk_offset + qk_block_offset +
                         tl.arange(0, CBLOCK)[None, :] * d +
                         tl.arange(0, d)[:, None])
    V_block_ptr = (V + v_offset + v_block_offset +
                   tl.arange(0, CBLOCK)[:, None] * e +
                   tl.arange(0, e)[None, :])
    O_block_ptr = (Out + o_offset + o_block_offset + o_cblock_offset +
                   tl.arange(0, CBLOCK)[:, None] * e +
                   tl.arange(0, e)[None, :])

    # Load the decay rate for the current head
    S_block_ptr = S + off_h
    s = tl.load(S_block_ptr)

    i = off_cblock
    # Row positions of this sub-block's queries, relative to the block start
    q_index = tl.arange(0, CBLOCK) + i * CBLOCK

    # Load query values; rows at or past sequence position n read as 0
    q = tl.load(Q_block_ptr,
                mask=block_offset + q_index[:, None] < n,
                other=0.0).to(tl.float32)

    # Initialize output accumulator
    qkv = tl.zeros([CBLOCK, e], dtype=tl.float32)

    # Process all sub-blocks up to and
    # including the current one (causal attention)
    for j in range(i + 1):
        kv_index = tl.arange(0, CBLOCK) + j * CBLOCK
        # Query position minus key position, per (query, key) pair
        diff = q_index[:, None] - kv_index[None, :]
        s_index = s * diff
        # Causal mask: keys at or before the query get weight
        # exp(-s * diff); later keys get exp(-inf) = 0.
        s_index = tl.where(diff >= 0, -s_index, float("-inf"))
        decay = tl.exp(s_index)

        # Load key and value; rows at or past position n read as 0
        k_trans = tl.load(
            K_trans_block_ptr,
            mask=block_offset + kv_index[None, :] < n,
            other=0.0,
        ).to(tl.float32)
        v = tl.load(
            V_block_ptr,
            mask=block_offset + kv_index[:, None] < n,
            other=0.0,
        ).to(tl.float32)

        # Compute attention scores and apply decay
        qk = tl.dot(q, k_trans) * decay

        # Compute weighted values and accumulate
        qkv += tl.dot(qk, v)

        # Move to the next sub-block
        K_trans_block_ptr += CBLOCK * d
        V_block_ptr += CBLOCK * e

    # Store the result, cast to the output element type; rows at or past
    # position n are not written.
    tl.store(
        O_block_ptr,
        qkv.to(O_block_ptr.dtype.element_ty),
        mask=block_offset + q_index[:, None] < n,
    )
@triton.jit
def _fwd_kv_parallel(
    K,
    V,
    K_decay,
    KV,
    b: tl.constexpr,
    h: tl.constexpr,
    n,
    d: tl.constexpr,
    e: tl.constexpr,
    BLOCK: tl.constexpr,
    NUM_BLOCK,
    D_FBLOCK: tl.constexpr,
    E_FBLOCK: tl.constexpr,
    NUM_FBLOCK: tl.constexpr,
    CBLOCK: tl.constexpr,
    NUM_CBLOCK: tl.constexpr,
):
    # This kernel computes the key-value outer
    # products for each block in parallel.
    #
    # One program per (batch-head, block): it accumulates
    # sum_t k_decay[t] * k_t^T v_t over the rows of its block and writes the
    # [D_FBLOCK, E_FBLOCK] result into that block's slot of KV.
    # `b` and `NUM_FBLOCK` are not referenced in the body.
    off_bh = tl.program_id(0)  # batch-head index
    off_block = tl.program_id(1)  # block index

    off_h = off_bh % h  # head index

    block_offset = off_block * BLOCK

    # Calculate offsets for the current block
    k_block_offset = block_offset * d
    v_block_offset = block_offset * e
    kv_block_offset = off_block * d * e

    # Calculate base offsets for the current batch and head
    k_offset = off_bh * n * d
    v_offset = off_bh * n * e
    kv_offset = off_bh * NUM_BLOCK * d * e

    # Calculate pointers to the key, value, and key-value tensors.
    # K is addressed transposed ([D_FBLOCK, CBLOCK]) so that
    # dot(k_trans, v) yields a [D_FBLOCK, E_FBLOCK] tile.
    K_trans_block_ptr = (K + k_offset + k_block_offset +
                         tl.arange(0, CBLOCK)[None, :] * d +
                         tl.arange(0, D_FBLOCK)[:, None])
    V_block_ptr = (V + v_offset + v_block_offset +
                   tl.arange(0, CBLOCK)[:, None] * e +
                   tl.arange(0, E_FBLOCK)[None, :])
    KV_block_ptr = (KV + kv_offset + kv_block_offset +
                    tl.arange(0, D_FBLOCK)[:, None] * e +
                    tl.arange(0, E_FBLOCK)[None, :])

    # Pointer into the per-head decay table: K_decay is read as BLOCK
    # entries per head, CBLOCK entries at a time.
    k_decay_ptr = (K_decay + off_h * BLOCK + tl.arange(0, CBLOCK)[None, :])

    kv_index = tl.arange(0, CBLOCK)

    # Initialize the key-value outer product accumulator
    kv = tl.zeros([D_FBLOCK, E_FBLOCK], dtype=tl.float32)

    # Handle the last block which might be smaller than BLOCK
    if off_block == NUM_BLOCK - 1:
        split_n = n - (NUM_BLOCK - 1) * BLOCK
    else:
        split_n = BLOCK
    # Number of rows by which split_n falls short of a whole number of
    # CBLOCK tiles. All loads below are shifted back by this amount so the
    # block's last row lands on the last row of the last tile.
    left_shift = tl.cdiv(split_n, CBLOCK) * CBLOCK - split_n
    num_blocks = min(tl.cdiv(split_n, CBLOCK), NUM_CBLOCK)
    # Skip the leading part of the decay table so that the tiles used are
    # the last `num_blocks` tiles of the table.
    k_decay_ptr += (NUM_CBLOCK - num_blocks) * CBLOCK

    # Process all sub-blocks in the current block
    for j in range(num_blocks):
        # For j == 0 this equals left_shift, masking out the rows that the
        # backward shift pulled in from before the block; for j >= 1 it is
        # <= 0, so no row is masked.
        left_bound = (1 - j) * left_shift
        # Load key and value, handling boundary conditions
        k_trans = tl.load(K_trans_block_ptr - left_shift * d,
                          mask=kv_index[None, :] >= left_bound,
                          other=0.0)
        v = tl.load(V_block_ptr - left_shift * e,
                    mask=kv_index[:, None] >= left_bound,
                    other=0.0)

        # Load decay factor and compute weighted key-value outer product
        k_decay = tl.load(k_decay_ptr)
        kv += tl.dot(k_trans * k_decay, v)

        # Move to the next sub-block
        K_trans_block_ptr += CBLOCK * d
        V_block_ptr += CBLOCK * e
        k_decay_ptr += CBLOCK

    # Store the result, cast to KV's element type
    tl.store(KV_block_ptr, kv.to(KV_block_ptr.dtype.element_ty))
@triton.jit
def _fwd_kv_reduce(S, KV, KV_HISTORY, b: tl.constexpr, h: tl.constexpr, n,
                   d: tl.constexpr, e: tl.constexpr, BLOCK: tl.constexpr,
                   NUM_BLOCK, D_FBLOCK: tl.constexpr, E_FBLOCK: tl.constexpr):
    # This kernel reduces the key-value outer products
    # across blocks and updates the KV history.
    #
    # One program per batch-head (program axis 0). It walks the blocks from
    # first to last; slot i of KV is overwritten with the state accumulated
    # from the history and blocks 0..i-1, and KV_HISTORY is overwritten with
    # the state after the last block. `b` is not referenced in the body.
    off_bh = tl.program_id(0)  # batch-head index
    off_h = off_bh % h  # head index

    kv_offset = off_bh * NUM_BLOCK * d * e

    # Calculate pointer to the key-value tensor (slot of block 0)
    KV_block_ptr = (KV + kv_offset + tl.arange(0, D_FBLOCK)[:, None] * e +
                    tl.arange(0, E_FBLOCK)[None, :])

    # Load the decay rate for the current head
    s_ptrs = S + off_h
    s = tl.load(s_ptrs)

    # Calculate pointer to the key-value history tensor; it is addressed as
    # one [d, e] slab per batch-head.
    kv_history_offset = off_bh * d * e
    KV_HISTORY_block_ptr = (KV_HISTORY + kv_history_offset +
                            tl.arange(0, D_FBLOCK)[:, None] * e +
                            tl.arange(0, E_FBLOCK)[None, :])

    # Load the previous key-value history
    kv_pre = tl.load(KV_HISTORY_block_ptr).to(tl.float32)

    # Process the blocks in forward order, carrying the running state
    for i in range(NUM_BLOCK):
        # Number of rows in block i (the last block may be shorter)
        block_size = min(n - i * BLOCK, BLOCK)
        # Compute decay factor for the current block
        block_decay = tl.exp(-s.to(tl.float32) * block_size)

        # Load the current key-value outer product
        kv_cur = tl.load(KV_block_ptr).to(tl.float32)
        # Replace it with the state that precedes this block. The load above
        # must stay before this store because both use the same address.
        tl.store(KV_block_ptr, kv_pre.to(KV_block_ptr.dtype.element_ty))

        # Update the key-value history with the current block
        kv_pre = block_decay * kv_pre + kv_cur
        KV_block_ptr += d * e

    # Store the updated key-value history
    tl.store(KV_HISTORY_block_ptr, kv_pre)
@triton.jit
def _fwd_none_diag_kernel(
    Q,
    Out,
    S,
    KV,
    b: tl.constexpr,
    h: tl.constexpr,
    n,
    d: tl.constexpr,
    e: tl.constexpr,
    BLOCK: tl.constexpr,
    NUM_BLOCK,
    E_FBLOCK: tl.constexpr,
    CBLOCK: tl.constexpr,
    NUM_CBLOCK: tl.constexpr,
):
    # This kernel computes the non-diagonal blocks of the attention matrix
    # Each non-diagonal block represents attention
    # where queries attend to keys in different blocks.
    #
    # It multiplies each query sub-block with the [d, E_FBLOCK] state stored
    # in KV for its block, scales the product by a per-row decay, and adds
    # it to the values already present in Out. `b` is not referenced in the
    # body.
    off_bh = tl.program_id(0)  # batch-head index
    off_h = off_bh % h  # head index

    # Program axis 1 is a flattened (block, sub-block) index
    off_nc = tl.program_id(1)
    off_n = off_nc // NUM_CBLOCK  # block index
    off_c = off_nc % NUM_CBLOCK  # sub-block index
    off_e = tl.program_id(2)  # output feature block index

    n_offset = off_n * BLOCK
    c_offset = off_c * CBLOCK
    e_offset = off_e * E_FBLOCK
    # Sequence position of this sub-block's first row
    block_offset = n_offset + c_offset

    # Calculate offsets for the current batch, head, and block
    q_offset = off_bh * n * d + (n_offset + c_offset) * d
    o_offset = off_bh * n * e + (n_offset + c_offset) * e + e_offset
    kv_offset = off_bh * NUM_BLOCK * d * e + off_n * d * e + e_offset

    # Calculate pointers to the query, output, and key-value tensors
    Q_block_ptr = (Q + q_offset + tl.arange(0, CBLOCK)[:, None] * d +
                   tl.arange(0, d)[None, :])
    O_block_ptr = (Out + o_offset + tl.arange(0, CBLOCK)[:, None] * e +
                   tl.arange(0, E_FBLOCK)[None, :])
    KV_block_ptr = (KV + kv_offset + tl.arange(0, d)[:, None] * e +
                    tl.arange(0, E_FBLOCK)[None, :])

    # Load the decay rate for the current head
    S_block_ptr = S + off_h
    s = tl.load(S_block_ptr)

    c_array = tl.arange(0, CBLOCK)

    # Load the key-value state stored for the current block
    kv = tl.load(KV_block_ptr).to(tl.float32)
    q_index = block_offset + tl.arange(0, CBLOCK)

    # Load query values; rows at or past sequence position n read as 0
    q = tl.load(Q_block_ptr, mask=q_index[:, None] < n,
                other=0.).to(tl.float32)

    # Per-row decay exp(-s * r), where r is the row's offset from the start
    # of its block.
    # NOTE(review): r is 0 for the first row of a block, so the carried-in
    # state is applied undecayed there, whereas _linear_attn_decode_kernel
    # multiplies the cached state by exp(-slope) before using it. Confirm
    # that the two conventions agree.
    q_decay = tl.exp(-s.to(tl.float32) * (off_c * CBLOCK + c_array[:, None]))

    # Compute non-diagonal attention output
    qkv_none_diag = tl.dot(q, kv) * q_decay

    # Load the values already stored in Out for these rows (the diagonal
    # contribution when _fwd_diag_kernel has run first)
    qkv_diag = tl.load(O_block_ptr, mask=q_index[:, None] < n,
                       other=0.).to(tl.float32)

    # Combine diagonal and non-diagonal attention outputs
    qkv = qkv_diag + qkv_none_diag

    # Store the result, cast to the output element type
    tl.store(O_block_ptr,
             qkv.to(O_block_ptr.dtype.element_ty),
             mask=q_index[:, None] < n)
class _attention(torch.autograd.Function):
    """Forward pass of lightning attention as a sequence of Triton launches.

    Only ``forward`` is defined in this class.
    """

    @staticmethod
    def forward(ctx, q, k, v, s, kv_history):
        """Run the four-stage lightning attention forward pass.

        Args:
            ctx: Autograd context; the inputs and the per-block KV tensor
                are saved on it.
            q: Query tensor of shape [b, h, n, d].
            k: Key tensor, read with the same [b, h, n, d] layout as ``q``.
            v: Value tensor whose last dimension is ``e``.
            s: Per-head decay rates; the kernels read one scalar per head.
            kv_history: KV state tensor, updated in place by the reduce
                kernel. Presumably [b, h, d, e] - the kernel addresses it as
                one d*e slab per batch-head.

        Returns:
            Tuple ``(o, kv_all)``: ``o`` has shape [b, h, n, e] and
            ``kv_all`` is the per-block KV tensor concatenated with
            ``kv_history`` along dimension 2.

        Raises:
            RuntimeError: If the CUDA device's major compute capability is
                below 8.
        """
        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        s = s.contiguous()

        # Check CUDA compute capability
        capability = torch.cuda.get_device_capability()
        if capability[0] < 8:
            # A single message string: passing two positional strings would
            # make the exception text render as a tuple.
            raise RuntimeError("Lightning attention currently only supported "
                               "for compute capability >= 80")

        # Get input dimensions
        b, h, n, d = q.shape
        e = v.shape[-1]

        # Output buffer; written by the diagonal kernel and then accumulated
        # into by the non-diagonal kernel.
        o = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device)

        # Set block sizes
        BLOCK = 256
        NUM_BLOCK = triton.cdiv(n, BLOCK)

        CBLOCK = 32
        NUM_CBLOCK = BLOCK // CBLOCK
        assert BLOCK % CBLOCK == 0, "BLOCK must be a multiple of CBLOCK"

        # Decay factors for keys: entry t (0-based) of a block is
        # exp(-s * (BLOCK - 1 - t)).
        array = torch.arange(0, BLOCK, device=q.device) + 1
        k_decay = torch.exp(-s * (BLOCK - array.reshape(1, -1)))

        # Step 1: Compute diagonal blocks of attention
        grid = (b * h * NUM_BLOCK, NUM_CBLOCK)
        _fwd_diag_kernel[grid](q,
                               k,
                               v,
                               o,
                               s,
                               b,
                               h,
                               n,
                               d,
                               e,
                               BLOCK=BLOCK,
                               NUM_BLOCK=NUM_BLOCK,
                               CBLOCK=CBLOCK)

        # Set feature block sizes (a single feature block covers all of d/e)
        NUM_FBLOCK = 1
        D_FBLOCK = d // NUM_FBLOCK
        assert d % NUM_FBLOCK == 0
        E_FBLOCK = e // NUM_FBLOCK
        assert e % NUM_FBLOCK == 0

        # The remaining kernels use a larger sub-block size than step 1.
        CBLOCK = 64
        NUM_CBLOCK = BLOCK // CBLOCK
        assert BLOCK % CBLOCK == 0, "BLOCK must be a multiple of CBLOCK"

        # Step 2: Compute key-value outer products for each block in parallel
        kv = torch.empty((b, h, NUM_BLOCK, d, e),
                         dtype=torch.float32,
                         device=q.device)
        grid = (b * h, NUM_BLOCK)
        _fwd_kv_parallel[grid](
            k,
            v,
            k_decay,
            kv,
            b,
            h,
            n,
            d,
            e,
            BLOCK=BLOCK,
            NUM_BLOCK=NUM_BLOCK,
            D_FBLOCK=D_FBLOCK,
            E_FBLOCK=E_FBLOCK,
            NUM_FBLOCK=NUM_FBLOCK,
            CBLOCK=CBLOCK,
            NUM_CBLOCK=NUM_CBLOCK,
        )

        # Step 3: Reduce key-value outer products
        # across blocks and update KV history
        grid = (b * h, NUM_FBLOCK)
        _fwd_kv_reduce[grid](s,
                             kv,
                             kv_history,
                             b,
                             h,
                             n,
                             d,
                             e,
                             BLOCK=BLOCK,
                             NUM_BLOCK=NUM_BLOCK,
                             D_FBLOCK=D_FBLOCK,
                             E_FBLOCK=E_FBLOCK)

        # Step 4: Compute non-diagonal blocks of attention
        grid = (b * h, NUM_BLOCK * NUM_CBLOCK)
        _fwd_none_diag_kernel[grid](
            q,
            o,
            s,
            kv,
            b,
            h,
            n,
            d,
            e,
            BLOCK=BLOCK,
            NUM_BLOCK=NUM_BLOCK,
            E_FBLOCK=E_FBLOCK,
            CBLOCK=CBLOCK,
            NUM_CBLOCK=NUM_CBLOCK,
        )

        # Save tensors on the autograd context
        ctx.save_for_backward(q, k, v, s, kv)
        ctx.BLOCK = BLOCK

        return o, torch.cat([kv, kv_history.unsqueeze(2)], dim=2)


# Functional entry point: lightning_attention_(q, k, v, s, kv_history)
lightning_attention_ = _attention.apply
def lightning_attention(q, k, v, ed, block_size=256, kv_history=None):
    """
    Apply the lightning attention algorithm, splitting the q/k feature
    dimension into chunks and summing the per-chunk outputs.

    Args:
        q: Query tensor of shape [batch, heads, seq_len, dim]
        k: Key tensor of shape [batch, heads, seq_len, dim]
        v: Value tensor of shape [batch, heads, seq_len, dim_v]
        ed: Decay rate tensor of shape [heads]; a 1-D tensor is viewed as
            [1, heads, 1, 1]
        block_size: Not referenced by this function
        kv_history: Optional key-value history from previous computations;
            it is cloned, so the caller's tensor is not modified. When
            omitted, a float32 zero history is used.

    Returns:
        output: Attention output (sum over the feature chunks)
        kv: The KV tensor returned for the last chunk processed
    """
    d = q.shape[-1]
    value_dim = v.shape[-1]

    if ed.dim() == 1:
        ed = ed.view(1, -1, 1, 1)

    # Chunk width along the q/k feature dimension
    m = 64 if d < 128 else 128
    assert d % m == 0, f"Dimension d ({d}) must be divisible by m ({m})"

    # Chunk boundaries: 0, m, 2m, ..., with d appended if not already last
    bounds = list(range(0, m * (d // m) + 1, m))
    if bounds[-1] != d:
        bounds.append(d)

    # Initialize or clone key-value history
    if kv_history is not None:
        kv_history = kv_history.clone().contiguous()
    else:
        kv_history = torch.zeros((q.shape[0], q.shape[1], d, value_dim),
                                 dtype=torch.float32,
                                 device=q.device)

    # Process each chunk and accumulate results.
    # NOTE(review): every chunk is given the same full kv_history tensor;
    # confirm the kernels behave as intended when d > m (more than one
    # chunk).
    output = 0
    for lo, hi in zip(bounds[:-1], bounds[1:]):
        o, kv = lightning_attention_(q[..., lo:hi], k[..., lo:hi], v, ed,
                                     kv_history)
        output = output + o
    return output, kv
@triton.jit
def _linear_attn_decode_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    kv_cache_ptr,
    slope_rate,
    slot_idx,
    output_ptr,
    D: tl.constexpr,
    qkv_b_stride,
    qkv_h_stride,
    cache_b_stride,
    cache_h_stride,
    cache_d0_stride,
    cache_d1_stride,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Kernel for linear attention decoding with KV cache.

    This kernel computes attention for a single token using the KV cache.
    Each program handles one (batch entry, head, BLOCK_SIZE-wide slice of
    the value dimension): it updates that slice of the cache in place as
    ``k^T v + exp(-slope) * cache`` and writes ``q @ cache`` for the slice.
    Programs whose slot index is -1 return without reading or writing.
    """
    pid_b = tl.program_id(0)  # batch index
    pid_h = tl.program_id(1)  # head index
    pid_d = tl.program_id(2)  # dimension block index

    # Load slot index for the current batch
    slot_id = tl.load(slot_idx + pid_b)

    # Skip if slot_id is -1 (padding)
    if slot_id == -1:
        return

    batch_id = pid_b
    head_id = pid_h

    # Load decay rate for the current head
    ratio = tl.load(slope_rate + pid_h)

    # Calculate offsets for dimensions: the full q/k dimension, and this
    # program's slice of the value dimension
    qk_d_offsets = tl.arange(0, D)
    v_d_offsets = tl.arange(0, BLOCK_SIZE) + pid_d * BLOCK_SIZE
    cache_d_offsets = qk_d_offsets[:, None] * cache_d0_stride + v_d_offsets[
        None, :] * cache_d1_stride

    # Calculate offsets for the current batch and head. q, k and v are all
    # addressed with the same two strides; the cache is addressed by slot
    # id rather than by batch position.
    q_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride
    k_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride
    v_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride

    cache_offset = slot_id * cache_b_stride + head_id * cache_h_stride

    # Create masks for loading tensors (v_mask guards slices that extend
    # past D)
    qk_mask = qk_d_offsets < D
    v_mask = v_d_offsets < D

    # Load query, key, and value tensors
    q = tl.load(q_ptr + q_offset + qk_d_offsets, mask=qk_mask, other=0.0)
    k = tl.load(k_ptr + k_offset + qk_d_offsets, mask=qk_mask, other=0.0)
    v = tl.load(v_ptr + v_offset + v_d_offsets, mask=v_mask, other=0.0)

    # Compute key-value outer product
    kv_outer = k[:, None] * v[None, :]
    kv_mask = qk_mask[:, None] & v_mask[None, :]

    # Apply decay to previous KV cache: new = k^T v + exp(-slope) * old
    ratio = tl.exp(-ratio)
    kv_ptr = kv_cache_ptr + cache_offset + cache_d_offsets
    kv_cache_old = tl.load(kv_ptr, mask=kv_mask, other=0.0)
    kv_outer = kv_outer + ratio * kv_cache_old

    # Compute attention output: sum over the q/k dimension of q[i] * kv[i, :]
    output = q[:, None].to(tl.float32) * kv_outer
    output = tl.sum(output, axis=0)

    # Update KV cache and store output. The output is addressed with the
    # same batch/head offset as q.
    tl.store(kv_ptr, kv_outer, mask=kv_mask)
    tl.store(output_ptr + q_offset + v_d_offsets, output, mask=v_mask)
def linear_decode_forward_triton(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kv_caches: torch.Tensor,
    slope_rate: torch.Tensor,
    slot_idx: torch.Tensor,
    BLOCK_SIZE: int = 32,
) -> torch.Tensor:
    """
    Perform linear attention decoding using Triton kernels.

    Args:
        q: Query tensor of shape [B, H, 1, D]
        k: Key tensor of shape [B, H, 1, D]
        v: Value tensor of shape [B, H, 1, D]
        kv_caches: Key-value cache tensor, indexed as [slot, head, D, D];
            updated in place for every non-padding entry
        slope_rate: Decay rate tensor, one value per head
        slot_idx: Slot indices for batches; -1 marks a padding entry
        BLOCK_SIZE: Width of the value-dimension slice handled per program

    Returns:
        output: Attention output tensor of shape [B, H * D]. Rows that
            belong to padding entries are never written by the kernel and
            therefore hold uninitialised values.
    """
    B, H, _, D = q.shape
    assert k.shape == (B, H, 1, D)
    assert v.shape == (B, H, 1, D)

    # Initialize output tensor
    output = torch.empty_like(q)

    # One program per (batch entry, head, value slice). Round the number of
    # slices up: with floor division a head size smaller than BLOCK_SIZE
    # would launch no programs at all and leave `output` uninitialised. The
    # kernel masks the part of the last slice that lies beyond D.
    grid = (B, H, triton.cdiv(D, BLOCK_SIZE))

    # Calculate strides for tensors
    qkv_b_stride = q.stride(0)
    qkv_h_stride = q.stride(1)

    cache_b_stride = kv_caches.stride(0)
    cache_h_stride = kv_caches.stride(1)
    cache_d0_stride = kv_caches.stride(2)
    cache_d1_stride = kv_caches.stride(3)

    # Launch the kernel
    _linear_attn_decode_kernel[grid](
        q,
        k,
        v,
        kv_caches,
        slope_rate,
        slot_idx,
        output,
        D,
        qkv_b_stride,
        qkv_h_stride,
        cache_b_stride,
        cache_h_stride,
        cache_d0_stride,
        cache_d1_stride,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    # Merge the head and feature dimensions: [B, H, 1, D] -> [B, H * D]
    output = rearrange(output, "b h n d -> b n (h d)")
    return output.squeeze(1).contiguous()
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple
import torch
from vllm.attention.backends.utils import PAD_SLOT_ID
class ConstantSizeCache(ABC):
"""
Abstract base class for managing constant size caches
like Mamba and Minimax.
"""
def __init__(self, max_batch_size: int):
# Maps between the request id and a dict that maps between the seq_id
# and its index inside the cache
self.cache_indices_mapping: Dict[str, Dict[int, int]] = {}
self.free_cache_indices = list(range(max_batch_size))
@property
@abstractmethod
def cache(self) -> Any:
"""Return the underlying cache tensor(s)"""
pass
@abstractmethod
def _copy_cache(self, from_index: int, to_index: int):
"""Copy cache data from one index to another"""
pass
def current_run_tensors(self, **kwargs) -> Tuple:
"""
Return the tensors for the current run's conv and ssm state.
"""
if "seqlen_agnostic_capture_inputs" not in kwargs:
# We get here only on Prefill/Eager mode runs
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
finished_requests_ids = kwargs["finished_requests_ids"]
self._release_finished_requests(finished_requests_ids)
state_indices = self._prepare_current_run_cache(
request_ids_to_seq_ids, finished_requests_ids)
state_indices_tensor = torch.as_tensor(state_indices,
dtype=torch.int32,
device="cuda")
cache_tensors = self.cache
else:
# CUDA graph capturing runs
cache_tensors, state_indices_tensor = kwargs[
"seqlen_agnostic_capture_inputs"]
return (cache_tensors, state_indices_tensor)
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
"""
Copy the relevant state_indices into the CUDA graph input buffer
"""
assert all(
key in kwargs
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
finished_requests_ids = kwargs["finished_requests_ids"]
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
assert "seqlen_agnostic_capture_inputs" in input_buffers
_, input_state_indices_buffer = input_buffers[
"seqlen_agnostic_capture_inputs"]
self._release_finished_requests(finished_requests_ids)
state_indices = self._prepare_current_run_cache(
request_ids_to_seq_ids, finished_requests_ids)
cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len(
state_indices)
state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len)
input_state_indices_buffer.copy_(
torch.as_tensor(state_indices, dtype=torch.int32, device="cuda"))
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
"""
Provide the CUDA graph capture runs with a buffer in adjusted size.
The buffer is used to maintain the Cache during the CUDA graph replay
runs.
"""
state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size,
dtype=torch.int32,
device="cuda")
return (self.cache, state_indices_tensor)
def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int,
                                  finished_requests_ids) -> int:
    """
    Resolve the cache slot for a (request id, sequence id) pair.

    * Finished requests get ``PAD_SLOT_ID`` and no slot is allocated.
    * An unknown request takes a slot from ``self.free_cache_indices``.
    * A new sequence of a known request (parallel sampling, n > 1) takes
      a free slot and receives a copy of the cache held by one of the
      request's existing sequences.
    * A known pair returns the slot it already owns.
    """
    if cur_rid in finished_requests_ids:
        # set as pad, do not allocate destination index
        return PAD_SLOT_ID

    if cur_rid not in self.cache_indices_mapping:
        new_slot = self.free_cache_indices.pop()
        self.cache_indices_mapping[cur_rid] = {seq_id: new_slot}
        return new_slot

    slots_by_seq = self.cache_indices_mapping[cur_rid]
    if seq_id in slots_by_seq:
        return slots_by_seq[seq_id]

    # Parallel sampling with n > 1: prefill is assumed to have happened
    # already, so the sibling sequence starts from a copy of an existing
    # sequence's cache.
    source_slot = next(iter(slots_by_seq.values()))
    new_slot = self.free_cache_indices.pop()
    self._copy_cache(from_index=source_slot, to_index=new_slot)
    self.cache_indices_mapping[cur_rid][seq_id] = new_slot
    return new_slot
def _prepare_current_run_cache(
        self, request_ids_to_seq_ids: Dict[str, list[int]],
        finished_requests_ids: List[str]) -> List[int]:
    """
    Return one cache slot per sequence of the current run.

    Requests are visited in the iteration order of
    ``request_ids_to_seq_ids`` and, within a request, sequences in list
    order; each pair is resolved through
    ``_assign_seq_id_to_cache_index``.
    """
    slots: List[int] = []
    for req_id, seq_ids in request_ids_to_seq_ids.items():
        for seq_id in seq_ids:
            slots.append(
                self._assign_seq_id_to_cache_index(req_id, seq_id,
                                                   finished_requests_ids))
    return slots
def _release_finished_requests(self,
                               finished_seq_groups_req_ids: List[str]):
    """
    Return the cache slots of finished requests to the free list.

    For every request id that is still tracked in
    ``self.cache_indices_mapping`` the slots of all its sequences are
    appended to ``self.free_cache_indices`` and the request is dropped
    from the mapping. Unknown request ids are ignored.
    """
    for req_id in finished_seq_groups_req_ids:
        if req_id not in self.cache_indices_mapping:
            continue
        released = self.cache_indices_mapping.pop(req_id)
        self.free_cache_indices.extend(released.values())
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Tuple from typing import Tuple
import torch import torch
from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.models.constant_size_cache import ConstantSizeCache
@dataclass @dataclass
...@@ -21,7 +22,7 @@ class MambaCacheParams: ...@@ -21,7 +22,7 @@ class MambaCacheParams:
self.state_indices_tensor) self.state_indices_tensor)
class MambaCacheManager: class MambaCacheManager(ConstantSizeCache):
def __init__(self, vllm_config: VllmConfig, dtype: torch.dtype, def __init__(self, vllm_config: VllmConfig, dtype: torch.dtype,
num_mamba_layers: int, conv_state_shape: Tuple[int, int], num_mamba_layers: int, conv_state_shape: Tuple[int, int],
...@@ -32,6 +33,9 @@ class MambaCacheManager: ...@@ -32,6 +33,9 @@ class MambaCacheManager:
if not vllm_config.model_config.enforce_eager: if not vllm_config.model_config.enforce_eager:
max_batch_size = vllm_config.pad_for_cudagraph(max_batch_size) max_batch_size = vllm_config.pad_for_cudagraph(max_batch_size)
# Initialize parent class
super().__init__(max_batch_size)
conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) + conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
conv_state_shape, conv_state_shape,
dtype=dtype, dtype=dtype,
...@@ -41,126 +45,32 @@ class MambaCacheManager: ...@@ -41,126 +45,32 @@ class MambaCacheManager:
dtype=dtype, dtype=dtype,
device="cuda") device="cuda")
self.mamba_cache = (conv_state, temporal_state) self._mamba_cache = (conv_state, temporal_state)
@property
def cache(self):
return self._mamba_cache
# Maps between the request id and a dict that maps between the seq_id def _copy_cache(self, from_index: int, to_index: int):
# and its index inside the self.mamba_cache for cache_t in self.cache:
self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {} cache_t[:, to_index].copy_(cache_t[:, from_index],
self.free_cache_indices = list(range(max_batch_size)) non_blocking=True)
def current_run_tensors(self, **kwargs) -> MambaCacheParams: def current_run_tensors(self, **kwargs) -> MambaCacheParams:
""" """
Return the tensors for the current run's conv and ssm state. Return the tensors for the current run's conv and ssm state.
""" """
if "seqlen_agnostic_capture_inputs" not in kwargs: cache_tensors, state_indices_tensor = super().current_run_tensors(
# We get here only on Prefill/Eager mode runs **kwargs)
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] return MambaCacheParams(cache_tensors[0], cache_tensors[1],
finished_requests_ids = kwargs["finished_requests_ids"]
self._release_finished_requests(finished_requests_ids)
state_indices = self._prepare_current_run_mamba_cache(
request_ids_to_seq_ids, finished_requests_ids)
state_indices_tensor = torch.as_tensor(state_indices,
dtype=torch.int32,
device="cuda")
mamba_cache_tensors = self.mamba_cache
else:
# CUDA graph capturing runs
(mamba_cache_tensors,
state_indices_tensor) = kwargs["seqlen_agnostic_capture_inputs"]
return MambaCacheParams(mamba_cache_tensors[0], mamba_cache_tensors[1],
state_indices_tensor) state_indices_tensor)
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
"""
Copy the relevant state_indices into the CUDA graph input buffer
"""
assert all(
key in kwargs
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
finished_requests_ids = kwargs["finished_requests_ids"]
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
assert "seqlen_agnostic_capture_inputs" in input_buffers
_, input_state_indices_buffer = input_buffers[
"seqlen_agnostic_capture_inputs"]
self._release_finished_requests(finished_requests_ids)
state_indices = self._prepare_current_run_mamba_cache(
request_ids_to_seq_ids, finished_requests_ids)
cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len(
state_indices)
state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len)
input_state_indices_buffer.copy_(
torch.as_tensor(state_indices, dtype=torch.int32, device="cuda"))
def get_seqlen_agnostic_capture_inputs(self, batch_size: int): def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
""" """
Provide the CUDA graph capture runs with a buffer in adjusted size. Provide the CUDA graph capture runs with a buffer in adjusted size.
The buffer is used to maintain the Mamba Cache during the CUDA graph The buffer is used to maintain the Mamba Cache during the CUDA graph
replay runs. replay runs.
""" """
state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size, return self._mamba_cache, torch.as_tensor([PAD_SLOT_ID] * batch_size,
dtype=torch.int32, dtype=torch.int32,
device="cuda") device="cuda")
return (self.mamba_cache, state_indices_tensor)
def _copy_mamba_cache(self, from_index: int, to_index: int):
assert len(self.mamba_cache) > 0
for cache_t in self.mamba_cache:
cache_t[:, to_index].copy_(cache_t[:, from_index],
non_blocking=True)
def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int,
finished_requests_ids) -> int:
"""
Assign (req_id,seq_id) pair to a `destination_index` index, if
already occupied, move the occupying index to a free index.
"""
if cur_rid in finished_requests_ids:
# set as pad, do not allocate destination index
return PAD_SLOT_ID
elif cur_rid not in self.mamba_cache_indices_mapping:
destination_index = self.free_cache_indices.pop()
self.mamba_cache_indices_mapping[cur_rid] = {
seq_id: destination_index
}
return destination_index
elif seq_id not in (seq_ids2indices :=
self.mamba_cache_indices_mapping[cur_rid]):
# parallel sampling , where n > 1, assume prefill have
# already happened, so we copy the
# existing cache into the siblings seq_ids caches
index_exists = next(iter(seq_ids2indices.values()))
# case of decoding n>1, copy prefill cache to decoding indices
destination_index = self.free_cache_indices.pop()
self._copy_mamba_cache(from_index=index_exists,
to_index=destination_index)
self.mamba_cache_indices_mapping[cur_rid][
seq_id] = destination_index
return destination_index
else:
# already exists
return self.mamba_cache_indices_mapping[cur_rid][seq_id]
def _prepare_current_run_mamba_cache(
self, request_ids_to_seq_ids: Dict[str, list[int]],
finished_requests_ids: List[str]) -> List[int]:
return [
self._assign_seq_id_to_cache_index(req_id, seq_id,
finished_requests_ids)
for req_id, seq_ids in request_ids_to_seq_ids.items()
for seq_id in seq_ids
]
def _release_finished_requests(self,
finished_seq_groups_req_ids: List[str]):
for req_id in finished_seq_groups_req_ids:
if req_id in self.mamba_cache_indices_mapping:
for seq_id in self.mamba_cache_indices_mapping[req_id]:
self.free_cache_indices.append(
self.mamba_cache_indices_mapping[req_id][seq_id])
self.mamba_cache_indices_mapping.pop(req_id)
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
import torch
from vllm.model_executor.models.constant_size_cache import ConstantSizeCache
@dataclass
class MinimaxCacheParams:
    """Bundle of a MiniMax cache tensor and its state indices tensor."""

    # Cache tensor; `at_layer_idx` indexes its first dimension.
    minimax_cache: torch.Tensor = torch.Tensor()
    # State indices passed along unchanged by `at_layer_idx`.
    state_indices_tensor: torch.Tensor = torch.Tensor()

    def at_layer_idx(self, layer_idx):
        """Return params whose cache is the `layer_idx` entry of this one.

        The state indices tensor is shared with the returned object.
        """
        layer_cache = self.minimax_cache[layer_idx, ...]
        return MinimaxCacheParams(layer_cache, self.state_indices_tensor)
class MinimaxCacheManager(ConstantSizeCache):
    """Constant-size cache backed by a single tensor.

    The tensor is allocated with shape ``cache_shape``; dimension 1 of
    that shape is the slot (batch) dimension and is what the parent class
    receives as its maximum batch size.
    """

    def __init__(self, dtype, cache_shape):
        """Allocate the cache tensor on the "cuda" device.

        Args:
            dtype: dtype of the cache tensor.
            cache_shape: full shape of the cache tensor; ``cache_shape[1]``
                is the number of slots.
        """
        super().__init__(cache_shape[1])  # max_batch_size is cache_shape[1]
        self._minimax_cache = torch.empty(size=cache_shape,
                                          dtype=dtype,
                                          device="cuda")

    @property
    def cache(self):
        """The cache tensor."""
        return self._minimax_cache

    def _copy_cache(self, from_index: int, to_index: int):
        """Copy slot ``from_index`` into slot ``to_index``.

        The cache is one tensor, so the slot dimension (dim 1) is indexed
        on the whole tensor. Iterating over the tensor would yield slices
        along dim 0 whose dim 1 is no longer the slot dimension, so
        indexing those slices with a slot index would copy the wrong data.
        """
        cache = self.cache
        assert len(cache) > 0
        cache[:, to_index].copy_(cache[:, from_index], non_blocking=True)
# SPDX-License-Identifier: Apache-2.0
"""Inference-only MiniMaxText01 model."""
import copy
import math
import re
from typing import Dict, Iterable, List, Optional, Tuple, Union
import torch
import torch.distributed
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.forward_context import get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.lightning_attn import (
lightning_attention, linear_decode_forward_triton)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import HasInnerState, IsHybrid, SupportsV0Only
from .minimax_cache import MinimaxCacheManager, MinimaxCacheParams
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
def replace_weight_name(name: str,
                        key: str = None,
                        to: str = None,
                        count: int = None,
                        prefix: str = None) -> str:
    """Return `name` with occurrences of `key` replaced by `to`.

    When `count` is given, at most `count` occurrences are replaced;
    otherwise all of them are. `prefix` is accepted but not used.
    """
    if count is None:
        return name.replace(key, to)
    return name.replace(key, to, count)
def weight_loader_with_alias(alias: str):
    """Decorator factory for weight loaders that may receive `prefix`.

    The decorated loader is called with the same positional and keyword
    arguments, except that a `prefix` keyword argument is accepted by the
    wrapper and not forwarded. `alias` itself is not used.
    """

    def wrapper(func: callable):

        def inner_func(param: torch.Tensor,
                       loaded_weight: torch.Tensor,
                       *args,
                       prefix: str = None,
                       **kwargs):
            # `prefix` is consumed by this signature and deliberately
            # left out of the forwarded call.
            return func(param, loaded_weight, *args, **kwargs)

        return inner_func

    return wrapper
class MiniMaxText01RMSNormTP(CustomOp):
    """RMS normalization whose weight is split across tensor-parallel ranks.

    Each rank holds ``hidden_size / tp_world`` weight elements. The mean
    of squares is computed on the local input and, when more than one
    rank is present, all-reduced and divided by the number of ranks.
    """
    name = "MiniMaxText01RMSNormTP"

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.tp_world = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        local_size = int(hidden_size / self.tp_world)
        self.weight = nn.Parameter(torch.ones(local_size))
        # Let the model loader shard the checkpoint weight per rank.
        self.weight.weight_loader = self.weight_loader
        self.variance_epsilon = eps

    @staticmethod
    def weight_loader(
        param: nn.Parameter,
        loaded_weight: torch.Tensor,
    ) -> None:
        """Copy this rank's contiguous slice of `loaded_weight` to `param`."""
        world_size = get_tensor_model_parallel_world_size()
        rank = get_tensor_model_parallel_rank()
        rows_per_rank = loaded_weight.shape[0] // world_size
        begin = rank * rows_per_rank
        param.data.copy_(loaded_weight[begin:begin + rows_per_rank])

    def _forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Normalize `x` over its last dimension and scale by the weight."""
        input_dtype = x.dtype
        x_fp32 = x.to(torch.float32)
        mean_sq = x_fp32.pow(2).mean(dim=-1,
                                     keepdim=True,
                                     dtype=torch.float32)
        if self.tp_world > 1:
            # Average the per-rank means to cover the full hidden size.
            mean_sq = tensor_model_parallel_all_reduce(
                mean_sq) / self.tp_world
        normed = x_fp32 * torch.rsqrt(mean_sq + self.variance_epsilon)
        return normed.to(input_dtype) * self.weight

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Apply the normalization; `residual` must be None."""
        assert residual is None, "RMSNorm does not support residual connection."
        return self._forward(x)
class MiniMaxText01RotaryEmbedding(CustomOp):
    """Rotary embedding with a precomputed cos/sin cache.

    The cache is built once at construction in `cache_dtype`; `forward`
    casts query/key to that dtype, applies the rotary kernel to the
    casted tensors and casts the results back.
    """
    name = "MiniMaxText01RotaryEmbedding"

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position: int,
        base: int,
        is_neox_style: bool,
        cache_dtype: torch.dtype,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position
        self.base = base
        self.is_neox_style = is_neox_style
        self.cache_dtype = cache_dtype
        cos_sin = self._compute_cos_sin_cache().to(cache_dtype)
        # Non-persistent: the cache is not written to the state dict.
        self.register_buffer("cos_sin_cache", cos_sin, persistent=False)

    def _compute_inv_freq(
        self,
        base: Union[int, float],
    ) -> torch.Tensor:
        """Compute the inverse frequency."""
        exponents = torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
        return 1.0 / (base**exponents)

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos and sin cache."""
        inv_freq = self._compute_inv_freq(self.base)
        position_ids = torch.arange(self.max_position_embeddings,
                                    dtype=torch.float)
        angles = torch.einsum("i,j -> ij", position_ids, inv_freq)
        # Cosines first, sines second, along the last dimension.
        return torch.cat((angles.cos(), angles.sin()), dim=-1)

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the rotary embedding and return `(query, key)`."""
        from vllm import _custom_ops as ops
        # Keep the cache on the same device as the positions.
        self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
        rotated_query = query.to(self.cache_dtype)
        rotated_key = key.to(self.cache_dtype)
        ops.rotary_embedding(positions, rotated_query, rotated_key,
                             self.head_size, self.cos_sin_cache,
                             self.is_neox_style)
        return rotated_query.to(query.dtype), rotated_key.to(key.dtype)
class MiniMaxText01MLP(nn.Module):
    """Gated MLP: merged gate/up projection, SiluAndMul, down projection."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        quant_config: Optional[QuantizationConfig] = None,
        layer_idx: int = None,
        prefix: str = "mlp",
    ) -> None:
        super().__init__()
        self.layer_idx = layer_idx
        # Gate and up projections share one merged column-parallel layer.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the MLP on `x` and return the down-projected result."""
        merged, _ = self.gate_up_proj(x)
        activated = self.act_fn(merged)
        projected, _ = self.down_proj(activated)
        return projected
class MiniMaxText01MoE(nn.Module):
    """Mixture-of-experts block: float32 router gate plus `FusedMoE`."""

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        layer_idx: int = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "moe",
    ) -> None:
        super().__init__()

        self.layer_idx = layer_idx
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_total_experts = num_experts
        self.top_k = top_k
        self.hidden_size = hidden_size
        # Stored per tensor-parallel rank.
        self.intermediate_size = intermediate_size // self.tp_size
        self.quant_config = quant_config

        self.params_dtype = (torch.get_default_dtype()
                             if params_dtype is None else params_dtype)

        # The router is replicated, unquantized and kept in float32.
        self.gate = ReplicatedLinear(
            self.hidden_size,
            self.num_total_experts,
            bias=False,
            params_dtype=torch.float32,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )
        self.gate.weight.weight_loader = MiniMaxText01MoE.gate_weight_loader

        self.experts = FusedMoE(
            num_experts=self.num_total_experts,
            top_k=self.top_k,
            hidden_size=self.hidden_size,
            intermediate_size=self.intermediate_size * self.tp_size,
            params_dtype=self.params_dtype,
            reduce_results=True,
            renormalize=True,
            quant_config=self.quant_config,
            tp_size=self.tp_size,
            prefix=f"{prefix}.experts",
        )

    @staticmethod
    def gate_weight_loader(param: nn.Parameter,
                           loaded_weight: torch.Tensor) -> None:
        """Copy `loaded_weight` into `param` as float32; sizes must match."""
        assert param.size() == loaded_weight.size()
        param.data.copy_(loaded_weight.to(torch.float32))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route `hidden_states` through the experts.

        The input is unpacked as `(num_tokens, hidden_size)` and the
        result is returned in that shape.
        """
        num_tokens, hidden_size = hidden_states.shape
        flat_states = hidden_states.view(-1, self.hidden_size)
        # Router logits are computed in float32, then cast back.
        router_logits_fp32, _ = self.gate(flat_states.to(torch.float32))
        router_logits = router_logits_fp32.to(flat_states.dtype)
        expert_output = self.experts(flat_states, router_logits)
        return expert_output.view(num_tokens, hidden_size)
class MiniMaxText01LinearKernel:
    """Namespace for the prefill entry point of lightning attention."""

    @staticmethod
    def jit_linear_forward_prefix(q: torch.Tensor,
                                  k: torch.Tensor,
                                  v: torch.Tensor,
                                  kv_caches: torch.Tensor,
                                  slope_rate: torch.Tensor,
                                  block_size: int,
                                  layer_idx: int = None,
                                  **kwargs) -> torch.Tensor:
        """Run `lightning_attention` for one sequence and update its cache.

        `q`, `k`, `v` may be 3-D, in which case a leading batch dimension
        of size 1 is added. `kv_caches` is read as the initial KV history
        and overwritten in place with the last history entry returned by
        the kernel. The output is rearranged with "h n d -> n (h d)".
        `layer_idx` and extra keyword arguments are not used.
        """
        slope_rate = slope_rate.to(torch.float32)

        if q.dim() == 3:
            # Add the batch dimension the kernel works with.
            q = q.unsqueeze(0)
            k = k.unsqueeze(0)
            v = v.unsqueeze(0)

        _, h, _, d = q.shape
        e = d
        initial_history = kv_caches.reshape(1, h, d, e).contiguous()
        output, kv_history = lightning_attention(q,
                                                 k,
                                                 v,
                                                 slope_rate,
                                                 block_size=block_size,
                                                 kv_history=initial_history)
        # Persist the final history entry back into the caller's cache.
        final_history = kv_history[:, :, -1, :, :].reshape(h, d, e)
        kv_caches.copy_(final_history)

        assert output.shape[0] == 1, "batch size must be 1"
        return rearrange(output.squeeze(0), "h n d -> n (h d)")
class MiniMaxText01LinearAttention(nn.Module):
    """Linear (lightning) attention layer with a per-slot KV state cache.

    Prefill sequences are processed one at a time through
    `MiniMaxText01LinearKernel.jit_linear_forward_prefix`; decode tokens
    go through `linear_decode_forward_triton`. The attention output is
    normalized, gated by a sigmoid of `output_gate` and projected by
    `out_proj`.
    """

    def __init__(
        self,
        hidden_size: int,
        hidden_inner_size: int,
        num_heads: int,
        head_dim: int,
        max_position: int,
        block_size: int,
        num_hidden_layer: int,
        quant_config: Optional[QuantizationConfig] = None,
        layer_idx: int = 0,
        linear_layer_idx: int = 0,
        prefix: str = "linear_attn",
    ) -> None:
        super().__init__()

        self.layer_idx = layer_idx
        self.BLOCK = block_size
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.total_num_heads = num_heads
        self.hidden_inner_size = hidden_inner_size
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        assert self.total_num_heads % self.tp_size == 0
        self.tp_heads = self.total_num_heads // self.tp_size
        self.qkv_size = self.num_heads * self.head_dim
        self.tp_hidden = self.head_dim * self.tp_heads

        # Query, key and value come out of one column-parallel projection.
        self.qkv_proj = ColumnParallelLinear(
            hidden_size,
            self.hidden_inner_size * 3,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.output_gate = ColumnParallelLinear(
            hidden_size,
            self.hidden_inner_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.output_gate",
        )
        self.out_proj = RowParallelLinear(
            self.hidden_inner_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )
        self.norm = MiniMaxText01RMSNormTP(
            self.hidden_inner_size,
            eps=1e-5,
        )

        base_slopes = MiniMaxText01LinearAttention._build_slope_tensor(
            self.num_heads)
        # The slope is scaled by a factor that depends on the layer index.
        if num_hidden_layer <= 1:
            self.slope_rate = base_slopes * (1 + 1e-5)
        else:
            self.slope_rate = base_slopes * (1 - layer_idx /
                                             (num_hidden_layer - 1) + 1e-5)
        # Slopes of the heads owned by this tensor-parallel rank.
        first_head = self.tp_rank * self.tp_heads
        last_head = (self.tp_rank + 1) * self.tp_heads
        self.tp_slope = self.slope_rate[first_head:last_head].contiguous()

    @staticmethod
    def weight_direct_load(param: torch.Tensor,
                           loaded_weight: torch.Tensor) -> None:
        """Copy `loaded_weight` into `param`; sizes must match."""
        assert param.size() == loaded_weight.size()
        param.data.copy_(loaded_weight)

    @staticmethod
    def _build_slope_tensor(n_attention_heads: int):
        """Return per-head slopes as a float32 tensor of shape (n, 1, 1)."""

        def get_slopes(n):

            def get_slopes_power_of_2(n):
                start = 2**(-(2**-(math.log2(n) - 3)))
                ratio = start
                return [start * ratio**i for i in range(n)]

            if math.log2(n).is_integer():
                return get_slopes_power_of_2(n)

            # Not a power of two: take the slopes of the closest lower
            # power of two and fill up with every other slope of the next
            # power of two.
            closest_power_of_2 = 2**math.floor(math.log2(n))
            head_slopes = get_slopes_power_of_2(closest_power_of_2)
            extra_slopes = get_slopes(
                2 * closest_power_of_2)[0::2][:n - closest_power_of_2]
            return head_slopes + extra_slopes

        slope_values = get_slopes(n_attention_heads)
        return torch.tensor(slope_values,
                            dtype=torch.float32).reshape(
                                n_attention_heads, 1, 1)

    def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
                               attn_metadata):
        """Process all prefill sequences, then any decode tokens.

        Outputs are concatenated along dim 0 in that order.
        """
        outputs = []
        num_prefills = getattr(attn_metadata, "num_prefills", 0)
        for prefill_idx in range(num_prefills):
            seq_start = attn_metadata.query_start_loc[prefill_idx]
            seq_end = attn_metadata.query_start_loc[prefill_idx + 1]
            slot_id = state_indices_tensor[prefill_idx]
            qs = q[seq_start:seq_end].transpose(0, 1).contiguous()
            ks = k[seq_start:seq_end].transpose(0, 1).contiguous()
            vs = v[seq_start:seq_end].transpose(0, 1).contiguous()
            # Cache entry of the slot assigned to this sequence.
            slot_cache = kv_cache[slot_id, ...]

            seq_output = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
                qs,
                ks,
                vs,
                slot_cache,
                self.tp_slope,
                self.BLOCK,
                layer_idx=self.layer_idx)
            outputs.append(seq_output.contiguous())

        if attn_metadata.num_decode_tokens > 0:
            outputs.append(
                self._decode_infer(q, k, v, kv_cache, state_indices_tensor,
                                   attn_metadata))
        return torch.concat(outputs, dim=0).contiguous()

    def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
                      attn_metadata):
        """Run the decode kernel on the tokens after the prefill tokens."""
        decode_start = attn_metadata.num_prefill_tokens
        q = q[decode_start:].unsqueeze(2).contiguous()
        k = k[decode_start:].unsqueeze(2).contiguous()
        v = v[decode_start:].unsqueeze(2).contiguous()
        # State indices of decode sequences follow those of the prefills.
        num_prefills = getattr(attn_metadata, "num_prefills", 0)
        slot_id = state_indices_tensor[num_prefills:]
        return linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope,
                                            slot_id, 32)

    def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
                kv_caches: MinimaxCacheParams, **kwargs) -> torch.Tensor:
        """Apply linear attention to `hidden_states`.

        `positions` and extra keyword arguments are not used; attention
        metadata is taken from the forward context.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # SiLU is applied in float32 before splitting into q, k and v.
        qkvact = torch.nn.functional.silu(qkv.to(torch.float32))
        qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
        q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)

        attn_metadata = get_forward_context().attn_metadata
        kv_cache = kv_caches.minimax_cache
        state_indices_tensor = kv_caches.state_indices_tensor

        if getattr(attn_metadata, "num_prefills", 0) == 0:
            # Decode-only batch.
            hidden = self._decode_infer(q, k, v, kv_cache,
                                        state_indices_tensor, attn_metadata)
        else:
            hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
                                                 state_indices_tensor,
                                                 attn_metadata)

        hidden = self.norm._forward(hidden)
        gate, _ = self.output_gate(hidden_states)
        hidden = F.sigmoid(gate) * hidden
        hidden = hidden.to(hidden_states.dtype)
        hidden, _ = self.out_proj(hidden)
        return hidden
class MiniMaxText01Attention(nn.Module):
    """Full attention layer: QKV projection, rotary embedding, attention
    and output projection.

    The rotary embedding is not owned by this layer; `forward` calls
    `rotary_emb` on the attention metadata of the forward context.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        head_dim: int,
        num_kv_heads: int,
        rotary_dim: int,
        max_position: int = 4096 * 32,
        rope_theta: float = 10000,
        sliding_window: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
        layer_idx: int = None,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "mha",
    ) -> None:
        super().__init__()
        self.layer_idx = layer_idx
        self.hidden_size = hidden_size

        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = num_kv_heads
        # The KV head count and the TP size must divide one another.
        if self.total_num_kv_heads >= tp_size:
            assert self.total_num_kv_heads % tp_size == 0
        else:
            assert tp_size % self.total_num_kv_heads == 0
        # Every rank holds at least one KV head.
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.head_dim = head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.sliding_window = sliding_window

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
                **kwargs) -> torch.Tensor:
        """Apply attention to `hidden_states`; extra kwargs are ignored."""
        attn_metadata = get_forward_context().attn_metadata
        qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        q, k, v = qkv.split(split_sizes, dim=-1)
        q, k = attn_metadata.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output
class MiniMaxText01DecoderLayer(nn.Module):
    """One decoder layer: attention (linear or full) followed by an MLP
    or a mixture-of-experts block, each combined with a scaled residual.

    `config.attention_type` selects the attention module: 0 builds
    `MiniMaxText01LinearAttention`, 1 builds `MiniMaxText01Attention`.
    `expert_num == 1` builds a dense MLP, any other value builds
    `MiniMaxText01MoE`. When `config.shared_intermediate_size > 0` a
    shared MLP and a float32 coefficient layer are also created and used
    to blend with the MoE output.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        expert_num: int = 1,
        layer_id: int = None,
        linear_layer_id: Optional[int] = None,
        prefix: str = "decoder",
    ) -> None:
        """Build the layer.

        Raises:
            ValueError: if `config.attention_type` is neither 0 nor 1.
        """
        super().__init__()
        self._ilayer = layer_id
        self._irank = get_tensor_model_parallel_rank()

        self.hidden_size = config.hidden_size
        self.expert_num = expert_num

        rope_theta = getattr(config, "rope_theta", 10000)

        head_dim = getattr(config, "head_dim",
                           config.hidden_size // config.num_attention_heads)
        # Start from the config value so the name is always bound, and
        # clamp it only when an integer `max_model_len` is available.
        max_position_embeddings = config.max_position_embeddings
        if hasattr(config, "max_model_len") and isinstance(
                config.max_model_len, int):
            max_position_embeddings = min(config.max_position_embeddings,
                                          config.max_model_len)

        if config.attention_type == 0:
            hidden_inner = head_dim * config.num_attention_heads
            self.self_attn = MiniMaxText01LinearAttention(
                hidden_size=self.hidden_size,
                hidden_inner_size=hidden_inner,
                num_heads=config.num_attention_heads,
                head_dim=head_dim,
                max_position=max_position_embeddings,
                block_size=config.block if hasattr(config, "block") else 256,
                num_hidden_layer=config.num_hidden_layers,
                quant_config=quant_config,
                layer_idx=self._ilayer,
                linear_layer_idx=linear_layer_id,
                prefix=prefix)
        elif config.attention_type == 1:
            self.self_attn = MiniMaxText01Attention(
                hidden_size=self.hidden_size,
                num_heads=config.num_attention_heads,
                head_dim=head_dim,
                rotary_dim=config.rotary_dim
                if hasattr(config, "rotary_dim") else head_dim,
                num_kv_heads=config.num_key_value_heads,
                max_position=max_position_embeddings,
                rope_theta=rope_theta,
                sliding_window=config.sliding_window,
                quant_config=quant_config,
                layer_idx=self._ilayer,
                cache_config=cache_config,
                prefix=prefix)
        else:
            # Read the value from `config`: this layer has no
            # `self.config` attribute.
            raise ValueError(
                f"Unsupported attention type: {config.attention_type}")

        if expert_num == 1:
            self.mlp = MiniMaxText01MLP(
                hidden_size=self.hidden_size,
                intermediate_size=config.intermediate_size,
                quant_config=quant_config,
                layer_idx=self._ilayer,
                prefix=prefix)
        else:
            self.block_sparse_moe = MiniMaxText01MoE(
                num_experts=expert_num,
                top_k=config.num_experts_per_tok,
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                layer_idx=self._ilayer,
                quant_config=quant_config,
                prefix=prefix)

        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

        # Residual (alpha) and branch (beta) scales for the attention
        # sub-block; the config keys differ per attention type.
        if config.attention_type == 0:
            self.layernorm_attention_alpha = getattr(
                config, 'layernorm_linear_attention_alpha', 1)
            self.layernorm_attention_beta = getattr(
                config, 'layernorm_linear_attention_beta', 1)
        else:
            self.layernorm_attention_alpha = getattr(
                config, 'layernorm_full_attention_alpha', 1)
            self.layernorm_attention_beta = getattr(
                config, 'layernorm_full_attention_beta', 1)
        self.layernorm_mlp_alpha = getattr(config, 'layernorm_mlp_alpha', 1)
        self.layernorm_mlp_beta = getattr(config, 'layernorm_mlp_beta', 1)
        self.postnorm = getattr(config, 'postnorm', False)

        self.shared_moe = False
        shared_intermediate = getattr(config, 'shared_intermediate_size', 0)
        if shared_intermediate > 0:
            self.shared_moe = True
            self.shared_mlp = MiniMaxText01MLP(
                hidden_size=self.hidden_size,
                intermediate_size=shared_intermediate,
                quant_config=quant_config,
                layer_idx=self._ilayer,
                prefix=prefix)
            self.coefficient = ReplicatedLinear(
                self.hidden_size,
                1,
                bias=False,
                quant_config=quant_config,
                params_dtype=torch.float32,
            )
            self.coefficient.weight.weight_loader = (
                self.shared_moe_coefficient_loader)
            self.shared_moe_mode = getattr(config, 'shared_moe_mode',
                                           'softmax')

    def forward(self,
                hidden_states: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: Union[List[Dict], Optional[torch.Tensor]],
                attn_metadata: AttentionMetadata,
                residual: Optional[torch.Tensor],
                is_warmup: bool = False,
                **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return `(hidden_states, None)`.

        The `attn_metadata` and `residual` arguments are overwritten
        inside the method: attention metadata is taken from the forward
        context and the residual is recomputed from the layer input.
        `is_warmup` and extra keyword arguments are not used.

        Raises:
            ValueError: if the shared MoE path is active and
                `shared_moe_mode` is neither 'softmax' nor 'sigmoid'.
        """
        forward_context = get_forward_context()
        attn_metadata = forward_context.attn_metadata

        layernorm_input = hidden_states
        layernorm_output = self.input_layernorm(layernorm_input)
        # With postnorm the residual branches off after the norm.
        residual = layernorm_output if self.postnorm else layernorm_input
        self_attention_output = self.self_attn(
            hidden_states=layernorm_output,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
        )

        residual = residual * self.layernorm_attention_alpha
        self_attention_output = (self_attention_output *
                                 self.layernorm_attention_beta)

        layernorm_input = residual + self_attention_output
        layernorm_output = self.post_attention_layernorm(layernorm_input)
        residual = layernorm_output if self.postnorm else layernorm_input

        if self.expert_num == 1:
            hidden_states = self.mlp(layernorm_output)
        else:
            # The MoE block receives its own copy of the normalized
            # input; `layernorm_output` is reused below.
            moe_hidden_states = self.block_sparse_moe(
                copy.deepcopy(layernorm_output))
            if self.shared_moe:
                before_moe_dtype = layernorm_output.dtype
                moe_hidden_fp32 = moe_hidden_states.to(torch.float32)
                output_mlp = self.shared_mlp(layernorm_output).to(
                    torch.float32)

                coef, _ = self.coefficient(layernorm_output.to(torch.float32))

                if self.shared_moe_mode == 'softmax':
                    coef = torch.nn.functional.softmax(coef, dim=-1)
                elif self.shared_moe_mode == 'sigmoid':
                    coef = torch.nn.functional.sigmoid(coef)
                else:
                    # Fail with a clear message instead of leaving
                    # `hidden_states` unassigned.
                    raise ValueError(
                        f"Unsupported shared_moe_mode: "
                        f"{self.shared_moe_mode}")
                # Blend the routed experts with the shared MLP in float32.
                hidden_states = moe_hidden_fp32 * (
                    1 - coef) + output_mlp * coef
                hidden_states = hidden_states.to(before_moe_dtype)
            else:
                hidden_states = moe_hidden_states

        residual = residual * self.layernorm_mlp_alpha
        hidden_states = hidden_states * self.layernorm_mlp_beta

        hidden_states = residual + hidden_states

        return hidden_states, None

    @staticmethod
    def shared_moe_coefficient_loader(param: torch.Tensor,
                                      loaded_weight: torch.Tensor) -> None:
        """Copy `loaded_weight` into `param` as float32; sizes must match."""
        assert param.size() == loaded_weight.size()

        param.data.copy_(loaded_weight.to(torch.float32))
        return
class MiniMaxText01Model(nn.Module):
    """Decoder stack of the MiniMax-Text-01 model.

    Owns the token embedding (first pipeline rank only), the decoder layers
    assigned to this pipeline rank, the rotary embedding handed to the layers
    through ``attn_metadata``, the final norm (last pipeline rank only) and
    the ``MinimaxCacheManager`` that stores the linear-attention state.

    ``decoder_attention_types`` holds one entry per layer. Layers whose entry
    is ``0`` receive a ``linear_layer_id`` and are counted when sizing the
    linear-attention cache.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        scheduler_config=None,
        prefix: str = "",
    ) -> None:
        """Build embeddings, decoder layers, state cache and rotary embedding.

        Args:
            config: Hugging Face model config. Note that it is mutated: the
                per-layer ``attention_type`` / ``layer_idx`` attributes are
                written onto this very object while the layers are built.
            quant_config: Forwarded to every decoder layer.
            cache_config: Forwarded to every decoder layer.
            scheduler_config: Must expose ``max_num_seqs`` (used as the number
                of cache slots); the ``None`` default is not usable.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__()

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Prefer ``attn_type_list``, then ``decoder_attention_types``; if
        # neither is set (or both are empty) every layer gets type 1.
        self.decoder_attention_types = getattr(
            config, "attn_type_list", False) or getattr(
                config, "decoder_attention_types", False)
        if not self.decoder_attention_types:
            self.decoder_attention_types = [1] * config.num_hidden_layers
        self.num_layers = config.num_hidden_layers

        self._layer_barrier = False
        if get_pp_group().is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=self.vocab_size,
            )
        else:
            self.embed_tokens = PPMissingLayer()

        num_local_experts = getattr(config, "num_local_experts", None)

        def layer_fn(prefix):
            # The layer index is the last dotted component of the prefix.
            layer_idx = int(prefix.split('.')[-1])
            # NOTE: this aliases (does not copy) the shared config, so the
            # attributes below are overwritten for every layer built.
            layer_config = config
            layer_config.attention_type = self.decoder_attention_types[
                layer_idx]
            layer_config.layer_idx = layer_idx

            decoder_kwargs = {
                "quant_config": quant_config,
                "layer_id": layer_idx,
                "cache_config": cache_config
            }

            if layer_config.attention_type == 0:
                # Rank of this layer among the type-0 layers preceding it.
                decoder_kwargs["linear_layer_id"] = sum(
                    1 for i in range(layer_idx)
                    if self.decoder_attention_types[i] == 0)
            else:
                decoder_kwargs["linear_layer_id"] = None

            # ``num_local_experts`` may be a per-layer list, a single int, or
            # absent (treated as one expert).
            if isinstance(num_local_experts, list):
                decoder_kwargs["expert_num"] = num_local_experts[layer_idx]
            elif isinstance(num_local_experts, int):
                decoder_kwargs["expert_num"] = num_local_experts
            else:
                decoder_kwargs["expert_num"] = 1

            return MiniMaxText01DecoderLayer(layer_config,
                                             **decoder_kwargs,
                                             prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers, layer_fn, prefix=f"{prefix}.layers")

        # Resolve the head size once so the state cache and the rotary
        # embedding agree, including for configs without ``head_dim``.
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = config.hidden_size // config.num_attention_heads

        linear_layer_nums = sum(1 for i in range(config.num_hidden_layers)
                                if self.decoder_attention_types[i] == 0)
        max_slots_number = scheduler_config.max_num_seqs
        self.cache_shape = (linear_layer_nums, max_slots_number,
                            config.num_attention_heads //
                            get_tensor_model_parallel_world_size(), head_dim,
                            head_dim)
        # Same dtype a freshly created ``torch.zeros(1)`` would have.
        self._dtype = torch.get_default_dtype()

        self.minimax_cache = MinimaxCacheManager(dtype=self._dtype,
                                                 cache_shape=self.cache_shape)

        rope_theta = getattr(config, "rope_theta", 10000)
        # Always bind ``max_position_embeddings``; clamp it only when an
        # integer ``max_model_len`` was attached to the config.
        max_position_embeddings = config.max_position_embeddings
        max_model_len = getattr(config, "max_model_len", None)
        if isinstance(max_model_len, int):
            max_position_embeddings = min(max_position_embeddings,
                                          max_model_len)
        self.rotary_emb = MiniMaxText01RotaryEmbedding(
            head_dim,
            rotary_dim=getattr(config, "rotary_dim", head_dim),
            max_position=max_position_embeddings,
            base=int(rope_theta),
            is_neox_style=True,
            cache_dtype=torch.float32,
        )

        norm_kwargs = {}
        if hasattr(config, "rms_norm_eps"):
            norm_kwargs["eps"] = config.rms_norm_eps
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, **norm_kwargs)
        else:
            self.norm = PPMissingLayer()
        self.embed_scale = 1.0

    def _clear_prefill_cache(self, attn_metadata,
                             minimax_cache_tensors: torch.Tensor, **kwargs):
        """Zero the state-cache slots of prefills that start from scratch.

        A slot is cleared when the prefill's ``context_lens_tensor`` entry is
        0 and its sequence id has a slot in the cache manager's mapping.

        Args:
            attn_metadata: Provides ``num_prefills`` and
                ``context_lens_tensor``.
            minimax_cache_tensors: Cache tensor indexed as
                ``[layer, slot, ...]``; modified in place.
            **kwargs: Must contain ``request_ids_to_seq_ids`` (request id ->
                list of sequence ids).
        """
        # Flatten the per-request sequence-id lists.
        # NOTE(review): assumes this flattened order matches the order of the
        # prefill entries in ``attn_metadata`` -- confirm against the caller.
        seq_id_map = [
            seq_id for seq_ids in kwargs["request_ids_to_seq_ids"].values()
            for seq_id in seq_ids
        ]

        seq_to_slot_maps = {}
        for seq_to_slot_map in (
                self.minimax_cache.cache_indices_mapping.values()):
            seq_to_slot_maps.update(seq_to_slot_map)

        # Never index past the known sequence ids: ``forward`` may supply an
        # empty ``request_ids_to_seq_ids`` while prefills are present.
        num_prefills = min(getattr(attn_metadata, "num_prefills", 0),
                           len(seq_id_map))

        slots_to_clear = []
        for prefill_idx in range(num_prefills):
            seq_id = seq_id_map[prefill_idx]
            if attn_metadata.context_lens_tensor[
                    prefill_idx] == 0 and seq_id in seq_to_slot_maps:
                slots_to_clear.append(seq_to_slot_maps[seq_id])

        if slots_to_clear:
            slots_tensor = torch.tensor(slots_to_clear,
                                        device=minimax_cache_tensors.device,
                                        dtype=torch.long)
            minimax_cache_tensors[:, slots_tensor, ...] = 0

    def forward(self,
                input_ids: Optional[torch.Tensor],
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                intermediate_tensors=None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs) -> torch.Tensor:
        """Run the decoder layers owned by this pipeline rank.

        Returns ``None`` when the forward context carries no attention
        metadata, an ``IntermediateTensors`` holding ``hidden_states`` and
        ``residual`` on every rank but the last, and the normalized hidden
        states on the last rank.
        """
        forward_context = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        if attn_metadata is None:
            return None

        # The cache manager expects both keys to be present.
        kwargs.setdefault("request_ids_to_seq_ids", {})
        kwargs.setdefault("finished_requests_ids", [])

        (
            minimax_cache_tensors,
            state_indices_tensor,
        ) = self.minimax_cache.current_run_tensors(**kwargs)
        if getattr(attn_metadata, "num_prefills", 0) > 0:
            self._clear_prefill_cache(attn_metadata, minimax_cache_tensors,
                                      **kwargs)

        minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors,
                                                  state_indices_tensor)

        if get_pp_group().is_first_rank:
            if inputs_embeds is None:
                hidden_states = self.embed_scale * self.embed_tokens(input_ids)
            else:
                hidden_states = inputs_embeds
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # Separate running indices: regular-attention layers consume entries
        # of ``kv_caches``; linear-attention layers consume per-layer views
        # of the state cache.
        kv_cache_index = 0
        minimax_cache_index = 0
        # The rotary embedding is passed to the layers via the metadata.
        attn_metadata.rotary_emb = self.rotary_emb
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            _caches = None
            if isinstance(layer.self_attn, MiniMaxText01Attention):
                _caches = kv_caches[kv_cache_index]
                kv_cache_index += 1
            if isinstance(layer.self_attn, MiniMaxText01LinearAttention):
                current_state_layer = minimax_cache_index
                _caches = minimax_cache_params.at_layer_idx(
                    current_state_layer)
                minimax_cache_index += 1
            hidden_states, residual = layer(
                hidden_states=hidden_states,
                positions=positions,
                kv_caches=_caches,
                attn_metadata=attn_metadata,
                residual=residual,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        if residual is not None:
            hidden_states, _ = self.norm(hidden_states, residual)
        else:
            hidden_states = self.norm(hidden_states)

        return hidden_states
class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
                               SupportsV0Only):
    """MiniMax-Text-01 causal language model.

    Wraps ``MiniMaxText01Model`` and, on the last pipeline rank, adds the LM
    head, the logits processor and the sampler. Also implements checkpoint
    loading, which renames checkpoint tensors onto the fused parameters used
    by this implementation.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the inner model and the last-rank output modules.

        Args:
            vllm_config: Supplies the HF config, quantization, LoRA, cache
                and scheduler configuration.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config

        if not hasattr(config, "sliding_window"):
            config.sliding_window = None

        # When true, shared-MLP gate/up projections are loaded into one
        # fused ``gate_up_proj`` parameter (see ``load_weights``).
        self.CONCAT_FFN = True

        self.unpadded_vocab_size = self.config.vocab_size
        # Expose the engine's max_model_len on the HF config; the inner
        # model reads it from there.
        if hasattr(vllm_config.model_config, "max_model_len"):
            self.config.max_model_len = vllm_config.model_config.max_model_len
        self.model = MiniMaxText01Model(
            self.config,
            quant_config,
            cache_config=vllm_config.cache_config,
            scheduler_config=vllm_config.scheduler_config,
            prefix=maybe_prefix(prefix, "model"))
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                self.config.hidden_size,
                org_num_embeddings=self.config.vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            )

            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                    self.config.vocab_size)

            self.sampler = Sampler()
        else:
            self.lm_head = PPMissingLayer()

        # Use the attention types resolved by the inner model so configs
        # without ``attn_type_list`` get the same fallback in both places.
        flash_layer_count = sum(
            1 for attn_type in self.model.decoder_attention_types
            if attn_type == 1)
        # One empty placeholder tensor per type-1 layer, handed to the inner
        # model as ``kv_caches``.
        self.kv_cache = [torch.tensor([]) for _ in range(flash_layer_count)]

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        """Delegate to the state-cache manager of the inner model."""
        return self.model.minimax_cache.copy_inputs_before_cuda_graphs(
            input_buffers, **kwargs)

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        """Delegate to the state-cache manager of the inner model."""
        return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs(
            batch_size)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs) -> torch.Tensor:
        """Run the inner model with this module's placeholder KV caches."""
        hidden_states = self.model(input_ids, positions, self.kv_cache,
                                   intermediate_tensors, inputs_embeds,
                                   **kwargs)

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to logits through the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)

        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ):
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)

        return next_tokens

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        """Return zero ``hidden_states``/``residual`` tensors of shape
        ``(batch_size, hidden_size)``."""
        return IntermediateTensors({
            "hidden_states":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
            "residual":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
        })

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> None:
        """Load checkpoint tensors into this module's parameters.

        Each ``(name, tensor)`` pair is routed by substring to a loader for
        norm, attention, sparse-MoE, shared-MLP or plain weights. Weights of
        layers beyond the configured layer count and ``rotary_emb.inv_freq``
        entries are skipped, as are parameters that live on another
        pipeline rank.
        """
        params_dict = dict(self.named_parameters())
        # Same per-layer attention types the inner model was built with.
        attn_types = self.model.decoder_attention_types
        expert_id_pattern = re.compile(
            r'model\.layers\.\d+\.block_sparse_moe\.experts\.(\d+)\.')
        # Built lazily on the first MoE weight, then reused for all others.
        expert_mapping_cache = []

        def which_layer(name: str) -> Optional[int]:
            """Return the layer index encoded in ``name``, if any."""
            if "layers" in name:
                after_layer = name.split("layers")[-1]
                return int(after_layer.split(".")[1])
            return None

        def is_linear_attn_layer(layer_idx: Optional[int]) -> bool:
            if layer_idx is None:
                return False
            return attn_types[layer_idx] == 0

        def is_moe_weight(name: str) -> bool:
            return "block_sparse_moe" in name and not name.endswith(".bias")

        def is_shared_mlp_weight(name: str) -> bool:
            return "shared_mlp" in name and not name.endswith(".bias")

        def is_mha_weight(name: str) -> bool:
            return "self_attn" in name and not name.endswith(".bias")

        def is_layer_norm_weight(name: str) -> bool:
            return "norm" in name and not name.endswith(
                ".bias") and name in params_dict

        def get_expert_id(param_name: str) -> Optional[int]:
            match = expert_id_pattern.search(param_name)
            return int(match.group(1)) if match else None

        def load_plain(name: str,
                       loaded_weight: torch.Tensor,
                       fallback_loader=default_weight_loader) -> None:
            """Load ``loaded_weight`` with the parameter's own loader, or
            ``fallback_loader`` when the parameter defines none."""
            if is_pp_missing_parameter(name, self):
                return
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", fallback_loader)
            weight_loader = weight_loader_with_alias(name)(weight_loader)
            weight_loader(param, loaded_weight)

        def expert_params_mapping():
            """Return ``(param_name, pattern, expert_id, shard_id)`` tuples.

            Scale patterns come first: ``"<id>.w1.weight"`` is a substring
            of ``"<id>.w1.weight_scale"``, so the order matters.
            """
            if not expert_mapping_cache:
                num_local_experts = self.config.num_local_experts
                # A per-layer list is covered by its largest entry. Both
                # config forms share one mapping: the former list-only
                # mapping produced 3-tuples that the 4-way unpacking below
                # could not consume.
                # NOTE(review): assumes the fused parameter layout is the
                # same for list-valued and int-valued configs -- confirm.
                if isinstance(num_local_experts, list):
                    num_experts = max(num_local_experts)
                else:
                    num_experts = num_local_experts
                for suffix, kind in (("weight_scale", "scale"), ("weight",
                                                                 "weight")):
                    for expert_id in range(num_experts):
                        for weight_name in ("w1", "w2", "w3"):
                            fused = ("w13" if weight_name in ("w1", "w3") else
                                     "w2")
                            expert_mapping_cache.append(
                                (f"{fused}_{kind}",
                                 f"{expert_id}.{weight_name}.{suffix}",
                                 expert_id, weight_name))
            return expert_mapping_cache

        def load_sparse_moe_weight(name: str,
                                   loaded_weight: torch.Tensor) -> None:
            name_expert_id = get_expert_id(name)
            for (param_name, weight_name, expert_id,
                 shard_id) in expert_params_mapping():
                # The substring test alone would let "1.w1" match expert 11,
                # so compare against the parsed expert id first.
                if name_expert_id is not None and name_expert_id != expert_id:
                    continue
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                if is_pp_missing_parameter(name, self):
                    return
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader = weight_loader_with_alias(name)(weight_loader)
                weight_loader(param,
                              loaded_weight,
                              weight_name,
                              expert_id=expert_id,
                              shard_id=shard_id)
                break
            else:
                # Not a per-expert tensor: load it under its own name.
                load_plain(name, loaded_weight)

        def load_shared_mlp_weight(name: str,
                                   loaded_weight: torch.Tensor) -> None:
            loaded_shard_id = None
            if not self.CONCAT_FFN:
                if "gate_proj" in name:
                    name = name.replace("gate_proj", "w1", 1)
                elif "up_proj" in name:
                    name = name.replace("up_proj", "w3", 1)
                elif "down_proj" in name:
                    name = name.replace("down_proj", "w2", 1)
            else:
                # gate/up become shards 0/1 of the fused ``gate_up_proj``.
                if "gate_proj" in name:
                    name = name.replace("gate_proj", "gate_up_proj", 1)
                    loaded_shard_id = 0
                elif "up_proj" in name:
                    name = name.replace("up_proj", "gate_up_proj", 1)
                    loaded_shard_id = 1
            if is_pp_missing_parameter(name, self):
                return
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader = weight_loader_with_alias(name)(weight_loader)
            if not self.CONCAT_FFN:
                weight_loader(param, loaded_weight)
            elif "gate_up_proj" in name:
                weight_loader(param, loaded_weight, loaded_shard_id)
            elif "down_proj" in name:
                weight_loader(param, loaded_weight)
            else:
                raise AssertionError(
                    "MLP weight not in [gate_up_proj, down_proj]")

        def load_linear_attn_weight(name: str,
                                    loaded_weight: torch.Tensor) -> None:
            load_plain(name, loaded_weight,
                       MiniMaxText01LinearAttention.weight_direct_load)

        def load_flash_attn_weight(name: str,
                                   loaded_weight: torch.Tensor) -> None:
            flash_mha_params_mapping = [
                ("qkv_proj", "q_proj", "q"),
                ("qkv_proj", "k_proj", "k"),
                ("qkv_proj", "v_proj", "v"),
                ("gate_up_proj", "gate_proj", 0),
                ("gate_up_proj", "up_proj", 1),
            ]
            for (param_name, weight_name,
                 shard_id) in flash_mha_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                if is_pp_missing_parameter(name, self):
                    return
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader = weight_loader_with_alias(name)(weight_loader)
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                load_plain(name, loaded_weight)

        for name, loaded_weight in weights:
            # Checked first so it also covers entries under ``self_attn``,
            # which would otherwise be routed to the attention loaders.
            if "rotary_emb.inv_freq" in name:
                continue

            weight_at_layer = which_layer(name)
            # ``is not None`` rather than truthiness: layer 0 is a real
            # layer index.
            if weight_at_layer is not None and weight_at_layer >= len(
                    attn_types):
                continue

            if is_layer_norm_weight(name):
                load_plain(name, loaded_weight)
            elif is_mha_weight(name):
                if is_linear_attn_layer(weight_at_layer):
                    load_linear_attn_weight(name, loaded_weight)
                else:
                    load_flash_attn_weight(name, loaded_weight)
            elif is_moe_weight(name):
                load_sparse_moe_weight(name, loaded_weight)
            elif is_shared_mlp_weight(name):
                load_shared_mlp_weight(name, loaded_weight)
            else:
                load_plain(name, loaded_weight)
...@@ -35,6 +35,7 @@ _TEXT_GENERATION_MODELS = { ...@@ -35,6 +35,7 @@ _TEXT_GENERATION_MODELS = {
"AquilaModel": ("llama", "LlamaForCausalLM"), "AquilaModel": ("llama", "LlamaForCausalLM"),
"AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2 "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2
"ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
"MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
# baichuan-7b, upper case 'C' in the class name # baichuan-7b, upper case 'C' in the class name
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
# baichuan-13b, lower case 'c' in the class name # baichuan-13b, lower case 'c' in the class name
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment