Unverified Commit 9ef98d52 authored by Gerald's avatar Gerald Committed by GitHub
Browse files

[Model][MiniMaxText01] Support MiniMaxText01 model inference (#13454)


Signed-off-by: default avatarqscqesze <475517977@qq.com>
Co-authored-by: default avatarqingjun <qingjun@minimaxi.com>
Co-authored-by: default avatarqscqesze <475517977@qq.com>
parent 93491aef
......@@ -503,6 +503,11 @@ See [this page](#generative-models) for more information on how to use generativ
* `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.
* ✅︎
* ✅︎
- * `MiniMaxText01ForCausalLM`
* MiniMax-Text
* `MiniMaxAI/MiniMax-Text-01`, etc.
*
* ✅︎
- * `Zamba2ForCausalLM`
* Zamba2
* `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc.
......
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
from vllm.model_executor.layers.lightning_attn import (
linear_decode_forward_triton)
from vllm.platforms import current_platform
NUM_HEADS = [4, 8]
HEAD_SIZES = [64]
BATCH_SIZES = [1, 2]
SEQ_LENGTHS = [16]
DTYPES = [torch.float32]
def reference_lightning_attention(q, k, v, ed, block_size, kv_history):
"""Reference implementation of lightning attention core algorithm
The difference from the main implementation is that this processes
each step sequentially, instead of using parallelized triton kernels
"""
B, H, S, D = q.shape
E = v.shape[-1]
dtype = q.dtype
output = torch.zeros((B, H, S, E), dtype=dtype, device=q.device)
# Use clone() to ensure an independent copy
if kv_history is None:
kv_cache = torch.zeros((B, H, D, E), dtype=dtype, device=q.device)
else:
kv_cache = kv_history.clone()
# More efficient implementation
# Convert decay factors to matrix form
if ed.dim() == 1:
decay = torch.exp(-ed).view(1, -1, 1, 1)
else:
decay = torch.exp(-ed)
for b in range(B):
for step in range(S):
# Process all heads at once for this position
q_bs = q[b, :, step] # [H, D]
k_bs = k[b, :, step] # [H, D]
v_bs = v[b, :, step] # [H, E]
# Calculate KV outer products for all heads
for h in range(H):
# Calculate KV outer product
kv_outer = torch.outer(k_bs[h], v_bs[h])
# Update KV cache with decay
# Note: Using the same order as in the Triton kernel
kv_cache[b, h] = decay[0, h, 0, 0] * kv_cache[b, h] + kv_outer
# Calculate attention output
output[b, h, step] = torch.matmul(q_bs[h], kv_cache[b, h])
# Match the shape returned by the actual implementation
# The actual implementation returns a tensor of shape [B, H, 2, D, E]
# where dimension 2 contains both KV and KV history
kv_reshaped = kv_cache.unsqueeze(2) # [B, H, 1, D, E]
final_kv_cache = torch.cat([kv_reshaped, kv_reshaped],
dim=2) # [B, H, 2, D, E]
return output, final_kv_cache
def reference_linear_decode(q, k, v, kv_caches, slope_rate, slot_idx):
"""Reference implementation: linear attention decode function"""
B, H, _, D = q.shape
output = torch.zeros(B, H * D, dtype=q.dtype, device=q.device)
# Calculate decay factors once (more efficient)
decay = torch.exp(-slope_rate).view(-1, 1, 1) # [H, 1, 1]
# Process each batch
for b in range(B):
slot_id = slot_idx[b].item()
# Skip padding positions
if slot_id == -1:
continue
# Process all heads at once for this batch
q_b = q[b, :, 0] # [H, D]
k_b = k[b, :, 0] # [H, D]
v_b = v[b, :, 0] # [H, D]
# Process each attention head
for h in range(H):
# Get current query, key and value
q_bh = q_b[h]
k_bh = k_b[h]
v_bh = v_b[h]
# Get cache
kv_cache_old = kv_caches[b, h]
# Calculate new key-value outer product
kv_outer = torch.outer(k_bh, v_bh)
# Apply decay and update cache
kv_new = kv_outer + decay[h, 0, 0] * kv_cache_old
# Calculate output
out_h = torch.matmul(q_bh, kv_new)
# Update output and cache
output[b, h * D:(h + 1) * D] = out_h
kv_caches[b, h] = kv_new
return output
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_linear_decode_forward_triton(
batch_size: int,
num_heads: int,
head_size: int,
dtype: torch.dtype,
):
torch.set_default_device("cuda")
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
current_platform.seed_everything(42)
base = 0.01
q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
kv_caches = base * torch.randn(batch_size,
num_heads,
head_size,
head_size,
dtype=dtype,
device="cuda")
kv_caches_copy = kv_caches.clone()
slope_rate = torch.zeros(num_heads, device="cuda")
for h in range(num_heads):
slope_rate[h] = 0.1 * (h + 1)
slot_idx = torch.arange(batch_size, device="cuda")
triton_output = linear_decode_forward_triton(q, k, v, kv_caches,
slope_rate, slot_idx)
reference_output = reference_linear_decode(q, k, v, kv_caches_copy,
slope_rate, slot_idx)
torch.testing.assert_close(triton_output,
reference_output,
rtol=1e-1,
atol=1e-1)
torch.testing.assert_close(kv_caches, kv_caches_copy, rtol=1e-1, atol=1e-1)
assert triton_output.shape == (batch_size, num_heads * head_size)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_linear_decode_forward_triton_with_padding(
num_heads: int,
head_size: int,
dtype: torch.dtype,
):
torch.set_default_device("cuda")
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
current_platform.seed_everything(42)
batch_size = 4
base = 0.01
q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
kv_caches = base * torch.randn(batch_size,
num_heads,
head_size,
head_size,
dtype=dtype,
device="cuda")
kv_caches_copy = kv_caches.clone()
slope_rate = torch.zeros(num_heads, device="cuda")
for h in range(num_heads):
slope_rate[h] = 0.1 * (h + 1)
slot_idx = torch.tensor([0, 1, -1, 2], device="cuda")
triton_output = linear_decode_forward_triton(q, k, v, kv_caches,
slope_rate, slot_idx)
reference_output = reference_linear_decode(q, k, v, kv_caches_copy,
slope_rate, slot_idx)
padding_mask = (slot_idx
!= -1).unsqueeze(1).expand(-1, num_heads * head_size)
triton_masked = triton_output[padding_mask]
reference_masked = reference_output[padding_mask]
atol, rtol = 1.5e-1, 1.5e-1
valid_indices = slot_idx != -1
for i in range(batch_size):
if valid_indices[i] > 0:
torch.testing.assert_close(kv_caches[i],
kv_caches_copy[i],
rtol=rtol,
atol=atol)
torch.testing.assert_close(triton_masked,
reference_masked,
rtol=rtol,
atol=atol)
assert triton_output.shape == (batch_size, num_heads * head_size)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENGTHS)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_lightning_attention_reference(
batch_size: int,
num_heads: int,
head_size: int,
seq_len: int,
dtype: torch.dtype,
):
torch.set_default_device("cuda")
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
current_platform.seed_everything(42)
base = 0.01
q = base * torch.randn(
batch_size, num_heads, seq_len, head_size, dtype=dtype)
k = base * torch.randn(
batch_size, num_heads, seq_len, head_size, dtype=dtype)
v = base * torch.randn(
batch_size, num_heads, seq_len, head_size, dtype=dtype)
ed = torch.zeros(num_heads, device="cuda")
for h in range(num_heads):
ed[h] = 0.1 * (h + 1)
kv_history = base * torch.randn(batch_size,
num_heads,
head_size,
head_size,
dtype=dtype,
device="cuda")
kv_history_clone = kv_history.clone()
ref_output, ref_kv_cache = reference_lightning_attention(
q, k, v, ed, 256, kv_history)
from vllm.model_executor.layers.lightning_attn import lightning_attention
actual_output, actual_kv_cache = lightning_attention(
q, k, v, ed, 256, kv_history_clone)
atol, rtol = 1.5e-1, 1.5e-1
torch.testing.assert_close(ref_output, actual_output, rtol=rtol, atol=atol)
torch.testing.assert_close(ref_kv_cache,
actual_kv_cache,
rtol=rtol,
atol=atol)
assert ref_output.shape == (batch_size, num_heads, seq_len, head_size)
assert ref_kv_cache.shape == actual_kv_cache.shape
......@@ -176,6 +176,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True),
"MiniCPM3ForCausalLM": _HfExamplesInfo("openbmb/MiniCPM3-4B",
trust_remote_code=True),
"MiniMaxText01ForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01",
trust_remote_code=True),
"MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"),
"MixtralForCausalLM": _HfExamplesInfo("mistralai/Mixtral-8x7B-Instruct-v0.1"), # noqa: E501
"QuantMixtralForCausalLM": _HfExamplesInfo("mistral-community/Mixtral-8x22B-v0.1-AWQ"), # noqa: E501
......
......@@ -971,26 +971,34 @@ class ModelConfig:
return sum(not bc.attention.no_op
for bc in block_configs[start:end])
else:
# Hybrid model
# Hybrid model Jamba
layers_block_type_value = getattr(self.hf_config,
"layers_block_type", None)
if layers_block_type_value is None:
raise ValueError("The model is an hybrid without a "
"layers_block_type in the hf_config, "
"cannot determine the num of "
f"{block_type.value} layers")
if hasattr(self.hf_text_config,
"model_type") and (self.hf_text_config.model_type
== "zamba2"):
if attn_block_type:
return sum(t == "hybrid"
for t in layers_block_type_value[start:end])
else:
return self.get_num_layers(parallel_config)
if layers_block_type_value is not None:
if hasattr(self.hf_text_config,
"model_type") and (self.hf_text_config.model_type
== "zamba2"):
if attn_block_type:
return sum(t == "hybrid"
for t in layers_block_type_value[start:end])
else:
return self.get_num_layers(parallel_config)
return sum(t == block_type.value
for t in layers_block_type_value[start:end])
# Hybrid model Minimax
attn_type_list = getattr(self.hf_config, "attn_type_list", None)
if attn_type_list:
return sum(t == 1 for t in attn_type_list[start:end])
if layers_block_type_value is None and attn_type_list is None:
raise ValueError(
"The model is an hybrid without a"
"layers_block_type or an attn_type_list in the hf_config,"
"cannot determine the num of "
f"{block_type.value} layers")
return sum(t == block_type.value
for t in layers_block_type_value[start:end])
return sum(t == 1 for t in attn_type_list[start:end])
def get_multimodal_config(self) -> "MultiModalConfig":
"""
......
......@@ -303,8 +303,11 @@ class _AsyncLLMEngine(LLMEngine):
ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs
finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()
if not scheduler_outputs.is_empty():
# this will cause mamba_cache/minimax_cache failed
# to release finished_requests_ids of the last steps
finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()
# Maybe switch from async mode to sync mode
if not allow_async_output_proc and len(ctx.output_queue) > 0:
......
# SPDX-License-Identifier: Apache-2.0
import torch
import triton
import triton.language as tl
from einops import rearrange
@triton.jit
def _fwd_diag_kernel(Q, K, V, Out, S, b: tl.constexpr, h: tl.constexpr, n,
d: tl.constexpr, e: tl.constexpr, BLOCK: tl.constexpr,
NUM_BLOCK, CBLOCK: tl.constexpr):
# This kernel computes the diagonal blocks of the attention matrix
# Each diagonal block represents attention
# where queries attend to keys in the same block
off = tl.program_id(0)
off_bh = off // NUM_BLOCK # batch-head index
off_block = off % NUM_BLOCK # block index within the sequence
off_cblock = tl.program_id(1) # sub-block index within a block
off_h = off_bh % h # head index
# Calculate base offsets for the current batch and head
qk_offset = off_bh * n * d
v_offset = off_bh * n * e
o_offset = off_bh * n * e
# Calculate offsets for the current block
block_offset = off_block * BLOCK
qk_block_offset = block_offset * d
v_block_offset = block_offset * e
o_block_offset = block_offset * e
# Calculate offsets for the current sub-block
cblock_offset = off_cblock * CBLOCK
q_cblock_offset = cblock_offset * d
o_cblock_offset = cblock_offset * e
# Calculate pointers to the query, key, value, and output tensors
Q_block_ptr = (Q + qk_offset + qk_block_offset + q_cblock_offset +
tl.arange(0, CBLOCK)[:, None] * d +
tl.arange(0, d)[None, :])
K_trans_block_ptr = (K + qk_offset + qk_block_offset +
tl.arange(0, CBLOCK)[None, :] * d +
tl.arange(0, d)[:, None])
V_block_ptr = (V + v_offset + v_block_offset +
tl.arange(0, CBLOCK)[:, None] * e +
tl.arange(0, e)[None, :])
O_block_ptr = (Out + o_offset + o_block_offset + o_cblock_offset +
tl.arange(0, CBLOCK)[:, None] * e +
tl.arange(0, e)[None, :])
# Load the decay rate for the current head
S_block_ptr = S + off_h
s = tl.load(S_block_ptr)
i = off_cblock
q_index = tl.arange(0, CBLOCK) + i * CBLOCK
# Load query values
q = tl.load(Q_block_ptr,
mask=block_offset + q_index[:, None] < n,
other=0.0).to(tl.float32)
# Initialize output accumulator
qkv = tl.zeros([CBLOCK, e], dtype=tl.float32)
# Process all sub-blocks up to and
# including the current one (causal attention)
for j in range(i + 1):
kv_index = tl.arange(0, CBLOCK) + j * CBLOCK
diff = q_index[:, None] - kv_index[None, :]
s_index = s * diff
# Apply causal mask: only attend to positions before the current one
s_index = tl.where(diff >= 0, -s_index, float("-inf"))
decay = tl.exp(s_index)
# Load key and value
k_trans = tl.load(
K_trans_block_ptr,
mask=block_offset + kv_index[None, :] < n,
other=0.0,
).to(tl.float32)
v = tl.load(
V_block_ptr,
mask=block_offset + kv_index[:, None] < n,
other=0.0,
).to(tl.float32)
# Compute attention scores and apply decay
qk = tl.dot(q, k_trans) * decay
# Compute weighted values and accumulate
qkv += tl.dot(qk, v)
# Move to the next sub-block
K_trans_block_ptr += CBLOCK * d
V_block_ptr += CBLOCK * e
# Store the result
tl.store(
O_block_ptr,
qkv.to(O_block_ptr.dtype.element_ty),
mask=block_offset + q_index[:, None] < n,
)
@triton.jit
def _fwd_kv_parallel(
K,
V,
K_decay,
KV,
b: tl.constexpr,
h: tl.constexpr,
n,
d: tl.constexpr,
e: tl.constexpr,
BLOCK: tl.constexpr,
NUM_BLOCK,
D_FBLOCK: tl.constexpr,
E_FBLOCK: tl.constexpr,
NUM_FBLOCK: tl.constexpr,
CBLOCK: tl.constexpr,
NUM_CBLOCK: tl.constexpr,
):
# This kernel computes the key-value outer
# products for each block in parallel
off_bh = tl.program_id(0) # batch-head index
off_block = tl.program_id(1) # block index
off_h = off_bh % h # head index
block_offset = off_block * BLOCK
# Calculate offsets for the current block
k_block_offset = block_offset * d
v_block_offset = block_offset * e
kv_block_offset = off_block * d * e
# Calculate base offsets for the current batch and head
k_offset = off_bh * n * d
v_offset = off_bh * n * e
kv_offset = off_bh * NUM_BLOCK * d * e
# Calculate pointers to the key, value, and key-value tensors
K_trans_block_ptr = (K + k_offset + k_block_offset +
tl.arange(0, CBLOCK)[None, :] * d +
tl.arange(0, D_FBLOCK)[:, None])
V_block_ptr = (V + v_offset + v_block_offset +
tl.arange(0, CBLOCK)[:, None] * e +
tl.arange(0, E_FBLOCK)[None, :])
KV_block_ptr = (KV + kv_offset + kv_block_offset +
tl.arange(0, D_FBLOCK)[:, None] * e +
tl.arange(0, E_FBLOCK)[None, :])
# Load the decay factors for the current head and block
k_decay_ptr = (K_decay + off_h * BLOCK + tl.arange(0, CBLOCK)[None, :])
kv_index = tl.arange(0, CBLOCK)
# Initialize the key-value outer product accumulator
kv = tl.zeros([D_FBLOCK, E_FBLOCK], dtype=tl.float32)
# Handle the last block which might be smaller than BLOCK
if off_block == NUM_BLOCK - 1:
split_n = n - (NUM_BLOCK - 1) * BLOCK
else:
split_n = BLOCK
left_shift = tl.cdiv(split_n, CBLOCK) * CBLOCK - split_n
num_blocks = min(tl.cdiv(split_n, CBLOCK), NUM_CBLOCK)
k_decay_ptr += (NUM_CBLOCK - num_blocks) * CBLOCK
# Process all sub-blocks in the current block
for j in range(num_blocks):
left_bound = (1 - j) * left_shift
# Load key and value, handling boundary conditions
k_trans = tl.load(K_trans_block_ptr - left_shift * d,
mask=kv_index[None, :] >= left_bound,
other=0.0)
v = tl.load(V_block_ptr - left_shift * e,
mask=kv_index[:, None] >= left_bound,
other=0.0)
# Load decay factor and compute weighted key-value outer product
k_decay = tl.load(k_decay_ptr)
kv += tl.dot(k_trans * k_decay, v)
# Move to the next sub-block
K_trans_block_ptr += CBLOCK * d
V_block_ptr += CBLOCK * e
k_decay_ptr += CBLOCK
# Store the result
tl.store(KV_block_ptr, kv.to(KV_block_ptr.dtype.element_ty))
@triton.jit
def _fwd_kv_reduce(S, KV, KV_HISTORY, b: tl.constexpr, h: tl.constexpr, n,
d: tl.constexpr, e: tl.constexpr, BLOCK: tl.constexpr,
NUM_BLOCK, D_FBLOCK: tl.constexpr, E_FBLOCK: tl.constexpr):
# This kernel reduces the key-value outer products
# across blocks and updates the KV history
off_bh = tl.program_id(0) # batch-head index
off_h = off_bh % h # head index
kv_offset = off_bh * NUM_BLOCK * d * e
# Calculate pointer to the key-value tensor
KV_block_ptr = (KV + kv_offset + tl.arange(0, D_FBLOCK)[:, None] * e +
tl.arange(0, E_FBLOCK)[None, :])
# Load the decay rate for the current head
s_ptrs = S + off_h
s = tl.load(s_ptrs)
# Calculate pointer to the key-value history tensor
kv_history_offset = off_bh * d * e
KV_HISTORY_block_ptr = (KV_HISTORY + kv_history_offset +
tl.arange(0, D_FBLOCK)[:, None] * e +
tl.arange(0, E_FBLOCK)[None, :])
# Load the previous key-value history
kv_pre = tl.load(KV_HISTORY_block_ptr).to(tl.float32)
# Process all blocks in reverse order to compute the prefix sum
for i in range(NUM_BLOCK):
block_size = min(n - i * BLOCK, BLOCK)
# Compute decay factor for the current block
block_decay = tl.exp(-s.to(tl.float32) * block_size)
# Load the current key-value outer product
kv_cur = tl.load(KV_block_ptr).to(tl.float32)
# Store the previous key-value history to the current block
tl.store(KV_block_ptr, kv_pre.to(KV_block_ptr.dtype.element_ty))
# Update the key-value history with the current block
kv_pre = block_decay * kv_pre + kv_cur
KV_block_ptr += d * e
# Store the updated key-value history
tl.store(KV_HISTORY_block_ptr, kv_pre)
@triton.jit
def _fwd_none_diag_kernel(
Q,
Out,
S,
KV,
b: tl.constexpr,
h: tl.constexpr,
n,
d: tl.constexpr,
e: tl.constexpr,
BLOCK: tl.constexpr,
NUM_BLOCK,
E_FBLOCK: tl.constexpr,
CBLOCK: tl.constexpr,
NUM_CBLOCK: tl.constexpr,
):
# This kernel computes the non-diagonal blocks of the attention matrix
# Each non-diagonal block represents attention
# where queries attend to keys in different blocks
off_bh = tl.program_id(0) # batch-head index
off_h = off_bh % h # head index
off_nc = tl.program_id(1)
off_n = off_nc // NUM_CBLOCK # block index
off_c = off_nc % NUM_CBLOCK # sub-block index
off_e = tl.program_id(2) # output feature block index
n_offset = off_n * BLOCK
c_offset = off_c * CBLOCK
e_offset = off_e * E_FBLOCK
block_offset = n_offset + c_offset
# Calculate offsets for the current batch, head, and block
q_offset = off_bh * n * d + (n_offset + c_offset) * d
o_offset = off_bh * n * e + (n_offset + c_offset) * e + e_offset
kv_offset = off_bh * NUM_BLOCK * d * e + off_n * d * e + e_offset
# Calculate pointers to the query, output, and key-value tensors
Q_block_ptr = (Q + q_offset + tl.arange(0, CBLOCK)[:, None] * d +
tl.arange(0, d)[None, :])
O_block_ptr = (Out + o_offset + tl.arange(0, CBLOCK)[:, None] * e +
tl.arange(0, E_FBLOCK)[None, :])
KV_block_ptr = (KV + kv_offset + tl.arange(0, d)[:, None] * e +
tl.arange(0, E_FBLOCK)[None, :])
# Load the decay rate for the current head
S_block_ptr = S + off_h
s = tl.load(S_block_ptr)
c_array = tl.arange(0, CBLOCK)
# Load the key-value outer product for the current block
kv = tl.load(KV_block_ptr).to(tl.float32)
q_index = block_offset + tl.arange(0, CBLOCK)
# Load query values
q = tl.load(Q_block_ptr, mask=q_index[:, None] < n,
other=0.).to(tl.float32)
# Compute decay factors for the current sub-block
q_decay = tl.exp(-s.to(tl.float32) * (off_c * CBLOCK + c_array[:, None]))
# Compute non-diagonal attention output
qkv_none_diag = tl.dot(q, kv) * q_decay
# Load diagonal attention output (computed by _fwd_diag_kernel)
qkv_diag = tl.load(O_block_ptr, mask=q_index[:, None] < n,
other=0.).to(tl.float32)
# Combine diagonal and non-diagonal attention outputs
qkv = qkv_diag + qkv_none_diag
# Store the result
tl.store(O_block_ptr,
qkv.to(O_block_ptr.dtype.element_ty),
mask=q_index[:, None] < n)
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, s, kv_history):
# Forward pass of the lightning attention algorithm
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
s = s.contiguous()
# Check CUDA compute capability
capability = torch.cuda.get_device_capability()
if capability[0] < 8:
raise RuntimeError("Flash attention currently only supported",
"for compute capability >= 80")
# Get input dimensions
b, h, n, d = q.shape
e = v.shape[-1]
# Initialize output tensor
o = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device)
# Set block sizes
BLOCK = 256
NUM_BLOCK = triton.cdiv(n, BLOCK)
CBLOCK = 32
NUM_CBLOCK = BLOCK // CBLOCK
assert BLOCK % CBLOCK == 0, "BLOCK must be a multiple of CBLOCK"
# Compute decay factors for keys
array = torch.arange(0, BLOCK, device=q.device) + 1
k_decay = torch.exp(-s * (BLOCK - array.reshape(1, -1)))
# Step 1: Compute diagonal blocks of attention
grid = (b * h * NUM_BLOCK, NUM_CBLOCK)
_fwd_diag_kernel[grid](q,
k,
v,
o,
s,
b,
h,
n,
d,
e,
BLOCK=BLOCK,
NUM_BLOCK=NUM_BLOCK,
CBLOCK=CBLOCK)
# Set feature block sizes
NUM_FBLOCK = 1
D_FBLOCK = d // NUM_FBLOCK
assert d % NUM_FBLOCK == 0
E_FBLOCK = e // NUM_FBLOCK
assert e % NUM_FBLOCK == 0
CBLOCK = 64
NUM_CBLOCK = BLOCK // CBLOCK
assert BLOCK % CBLOCK == 0, "BLOCK must be a multiple of CBLOCK"
# Step 2: Compute key-value outer products for each block in parallel
kv = torch.empty((b, h, NUM_BLOCK, d, e),
dtype=torch.float32,
device=q.device)
grid = (b * h, NUM_BLOCK)
_fwd_kv_parallel[grid](
k,
v,
k_decay,
kv,
b,
h,
n,
d,
e,
BLOCK=BLOCK,
NUM_BLOCK=NUM_BLOCK,
D_FBLOCK=D_FBLOCK,
E_FBLOCK=E_FBLOCK,
NUM_FBLOCK=NUM_FBLOCK,
CBLOCK=CBLOCK,
NUM_CBLOCK=NUM_CBLOCK,
)
# Step 3: Reduce key-value outer products
# across blocks and update KV history
grid = (b * h, NUM_FBLOCK)
_fwd_kv_reduce[grid](s,
kv,
kv_history,
b,
h,
n,
d,
e,
BLOCK=BLOCK,
NUM_BLOCK=NUM_BLOCK,
D_FBLOCK=D_FBLOCK,
E_FBLOCK=E_FBLOCK)
# Step 4: Compute non-diagonal blocks of attention
grid = (b * h, NUM_BLOCK * NUM_CBLOCK)
_fwd_none_diag_kernel[grid](
q,
o,
s,
kv,
b,
h,
n,
d,
e,
BLOCK=BLOCK,
NUM_BLOCK=NUM_BLOCK,
E_FBLOCK=E_FBLOCK,
CBLOCK=CBLOCK,
NUM_CBLOCK=NUM_CBLOCK,
)
# Save tensors for backward pass
ctx.save_for_backward(q, k, v, s, kv)
ctx.BLOCK = BLOCK
return o, torch.cat([kv, kv_history.unsqueeze(2)], dim=2)
# Apply the lightning attention function
lightning_attention_ = _attention.apply
def lightning_attention(q, k, v, ed, block_size=256, kv_history=None):
"""
Apply lightning attention algorithm
to compute attention efficiently.
Args:
q: Query tensor of shape [batch, heads, seq_len, dim]
k: Key tensor of shape [batch, heads, seq_len, dim]
v: Value tensor of shape [batch, heads, seq_len, dim_v]
ed: Decay rate tensor of shape [heads]
block_size: Size of blocks for block-sparse attention
kv_history: Optional key-value history from previous computations
Returns:
output: Attention output
kv: Updated key-value history
"""
d = q.shape[-1]
e = v.shape[-1]
if ed.dim() == 1:
ed = ed.view(1, -1, 1, 1)
# Split the computation into chunks for better parallelism
m = 128 if d >= 128 else 64
assert d % m == 0, f"Dimension d ({d}) must be divisible by m ({m})"
arr = [m * i for i in range(d // m + 1)]
if arr[-1] != d:
arr.append(d)
n = len(arr)
output = 0
# Initialize or clone key-value history
if kv_history is None:
kv_history = torch.zeros((q.shape[0], q.shape[1], d, e),
dtype=torch.float32,
device=q.device)
else:
kv_history = kv_history.clone().contiguous()
# Process each chunk and accumulate results
for i in range(n - 1):
s = arr[i]
e = arr[i + 1]
q1 = q[..., s:e]
k1 = k[..., s:e]
o, kv = lightning_attention_(q1, k1, v, ed, kv_history)
output = output + o
return output, kv
@triton.jit
def _linear_attn_decode_kernel(
q_ptr,
k_ptr,
v_ptr,
kv_cache_ptr,
slope_rate,
slot_idx,
output_ptr,
D: tl.constexpr,
qkv_b_stride,
qkv_h_stride,
cache_b_stride,
cache_h_stride,
cache_d0_stride,
cache_d1_stride,
BLOCK_SIZE: tl.constexpr,
):
"""
Kernel for linear attention decoding with KV cache.
This kernel computes attention for a single token using the KV cache.
"""
pid_b = tl.program_id(0) # batch index
pid_h = tl.program_id(1) # head index
pid_d = tl.program_id(2) # dimension block index
# Load slot index for the current batch
slot_id = tl.load(slot_idx + pid_b)
# Skip if slot_id is -1 (padding)
if slot_id == -1:
return
batch_id = pid_b
head_id = pid_h
# Load decay rate for the current head
ratio = tl.load(slope_rate + pid_h)
# Calculate offsets for dimensions
qk_d_offsets = tl.arange(0, D)
v_d_offsets = tl.arange(0, BLOCK_SIZE) + pid_d * BLOCK_SIZE
cache_d_offsets = qk_d_offsets[:, None] * cache_d0_stride + v_d_offsets[
None, :] * cache_d1_stride
# Calculate offsets for the current batch and head
q_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride
k_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride
v_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride
cache_offset = slot_id * cache_b_stride + head_id * cache_h_stride
# Create masks for loading tensors
qk_mask = qk_d_offsets < D
v_mask = v_d_offsets < D
# Load query, key, and value tensors
q = tl.load(q_ptr + q_offset + qk_d_offsets, mask=qk_mask, other=0.0)
k = tl.load(k_ptr + k_offset + qk_d_offsets, mask=qk_mask, other=0.0)
v = tl.load(v_ptr + v_offset + v_d_offsets, mask=v_mask, other=0.0)
# Compute key-value outer product
kv_outer = k[:, None] * v[None, :]
kv_mask = qk_mask[:, None] & v_mask[None, :]
# Apply decay to previous KV cache
ratio = tl.exp(-ratio)
kv_ptr = kv_cache_ptr + cache_offset + cache_d_offsets
kv_cache_old = tl.load(kv_ptr, mask=kv_mask, other=0.0)
kv_outer = kv_outer + ratio * kv_cache_old
# Compute attention output
output = q[:, None].to(tl.float32) * kv_outer
output = tl.sum(output, axis=0)
# Update KV cache and store output
tl.store(kv_ptr, kv_outer, mask=kv_mask)
tl.store(output_ptr + q_offset + v_d_offsets, output, mask=v_mask)
def linear_decode_forward_triton(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
kv_caches: torch.Tensor,
slope_rate: torch.Tensor,
slot_idx: torch.Tensor,
BLOCK_SIZE: int = 32,
) -> torch.Tensor:
"""
Perform linear attention decoding using Triton kernels.
Args:
q: Query tensor of shape [B, H, 1, D]
k: Key tensor of shape [B, H, 1, D]
v: Value tensor of shape [B, H, 1, D]
kv_caches: Key-value cache tensor
slope_rate: Decay rate tensor
slot_idx: Slot indices for batches
BLOCK_SIZE: Size of blocks for processing
Returns:
output: Attention output tensor
"""
B, H, _, D = q.shape
assert k.shape == (B, H, 1, D)
assert v.shape == (B, H, 1, D)
# Initialize output tensor
output = torch.empty_like(q)
# Set grid dimensions for the kernel
grid = (B, H, D // BLOCK_SIZE)
# Calculate strides for tensors
qkv_b_stride = q.stride(0)
qkv_h_stride = q.stride(1)
cache_b_stride = kv_caches.stride(0)
cache_h_stride = kv_caches.stride(1)
cache_d0_stride = kv_caches.stride(2)
cache_d1_stride = kv_caches.stride(3)
# Launch the kernel
_linear_attn_decode_kernel[grid](
q,
k,
v,
kv_caches,
slope_rate,
slot_idx,
output,
D,
qkv_b_stride,
qkv_h_stride,
cache_b_stride,
cache_h_stride,
cache_d0_stride,
cache_d1_stride,
BLOCK_SIZE=BLOCK_SIZE,
)
# Reshape output and return
output = rearrange(output, "b h n d -> b n (h d)")
return output.squeeze(1).contiguous()
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple
import torch
from vllm.attention.backends.utils import PAD_SLOT_ID
class ConstantSizeCache(ABC):
"""
Abstract base class for managing constant size caches
like Mamba and Minimax.
"""
def __init__(self, max_batch_size: int):
# Maps between the request id and a dict that maps between the seq_id
# and its index inside the cache
self.cache_indices_mapping: Dict[str, Dict[int, int]] = {}
self.free_cache_indices = list(range(max_batch_size))
@property
@abstractmethod
def cache(self) -> Any:
"""Return the underlying cache tensor(s)"""
pass
@abstractmethod
def _copy_cache(self, from_index: int, to_index: int):
"""Copy cache data from one index to another"""
pass
def current_run_tensors(self, **kwargs) -> Tuple:
"""
Return the tensors for the current run's conv and ssm state.
"""
if "seqlen_agnostic_capture_inputs" not in kwargs:
# We get here only on Prefill/Eager mode runs
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
finished_requests_ids = kwargs["finished_requests_ids"]
self._release_finished_requests(finished_requests_ids)
state_indices = self._prepare_current_run_cache(
request_ids_to_seq_ids, finished_requests_ids)
state_indices_tensor = torch.as_tensor(state_indices,
dtype=torch.int32,
device="cuda")
cache_tensors = self.cache
else:
# CUDA graph capturing runs
cache_tensors, state_indices_tensor = kwargs[
"seqlen_agnostic_capture_inputs"]
return (cache_tensors, state_indices_tensor)
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
"""
Copy the relevant state_indices into the CUDA graph input buffer
"""
assert all(
key in kwargs
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
finished_requests_ids = kwargs["finished_requests_ids"]
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
assert "seqlen_agnostic_capture_inputs" in input_buffers
_, input_state_indices_buffer = input_buffers[
"seqlen_agnostic_capture_inputs"]
self._release_finished_requests(finished_requests_ids)
state_indices = self._prepare_current_run_cache(
request_ids_to_seq_ids, finished_requests_ids)
cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len(
state_indices)
state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len)
input_state_indices_buffer.copy_(
torch.as_tensor(state_indices, dtype=torch.int32, device="cuda"))
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
"""
Provide the CUDA graph capture runs with a buffer in adjusted size.
The buffer is used to maintain the Cache during the CUDA graph replay
runs.
"""
state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size,
dtype=torch.int32,
device="cuda")
return (self.cache, state_indices_tensor)
def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int,
finished_requests_ids) -> int:
"""
Assign (req_id,seq_id) pair to a `destination_index` index, if
already occupied, move the occupying index to a free index.
"""
if cur_rid in finished_requests_ids:
# set as pad, do not allocate destination index
return PAD_SLOT_ID
elif cur_rid not in self.cache_indices_mapping:
destination_index = self.free_cache_indices.pop()
self.cache_indices_mapping[cur_rid] = {seq_id: destination_index}
return destination_index
elif seq_id not in (seq_ids2indices :=
self.cache_indices_mapping[cur_rid]):
# parallel sampling , where n > 1, assume prefill have
# already happened, so we copy the
# existing cache into the siblings seq_ids caches
index_exists = next(iter(seq_ids2indices.values()))
# case of decoding n>1, copy prefill cache to decoding indices
destination_index = self.free_cache_indices.pop()
self._copy_cache(from_index=index_exists,
to_index=destination_index)
self.cache_indices_mapping[cur_rid][seq_id] = destination_index
return destination_index
else:
return self.cache_indices_mapping[cur_rid][seq_id]
def _prepare_current_run_cache(
self, request_ids_to_seq_ids: Dict[str, list[int]],
finished_requests_ids: List[str]) -> List[int]:
return [
self._assign_seq_id_to_cache_index(req_id, seq_id,
finished_requests_ids)
for req_id, seq_ids in request_ids_to_seq_ids.items()
for seq_id in seq_ids
]
def _release_finished_requests(self,
finished_seq_groups_req_ids: List[str]):
for req_id in finished_seq_groups_req_ids:
if req_id in self.cache_indices_mapping:
for seq_id in self.cache_indices_mapping[req_id]:
self.free_cache_indices.append(
self.cache_indices_mapping[req_id][seq_id])
self.cache_indices_mapping.pop(req_id)
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import Dict, List, Tuple
from typing import Tuple
import torch
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
from vllm.model_executor.models.constant_size_cache import ConstantSizeCache
@dataclass
......@@ -21,7 +22,7 @@ class MambaCacheParams:
self.state_indices_tensor)
class MambaCacheManager:
class MambaCacheManager(ConstantSizeCache):
def __init__(self, vllm_config: VllmConfig, dtype: torch.dtype,
num_mamba_layers: int, conv_state_shape: Tuple[int, int],
......@@ -32,6 +33,9 @@ class MambaCacheManager:
if not vllm_config.model_config.enforce_eager:
max_batch_size = vllm_config.pad_for_cudagraph(max_batch_size)
# Initialize parent class
super().__init__(max_batch_size)
conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
conv_state_shape,
dtype=dtype,
......@@ -41,126 +45,32 @@ class MambaCacheManager:
dtype=dtype,
device="cuda")
self.mamba_cache = (conv_state, temporal_state)
self._mamba_cache = (conv_state, temporal_state)
@property
def cache(self):
return self._mamba_cache
# Maps between the request id and a dict that maps between the seq_id
# and its index inside the self.mamba_cache
self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
self.free_cache_indices = list(range(max_batch_size))
def _copy_cache(self, from_index: int, to_index: int):
for cache_t in self.cache:
cache_t[:, to_index].copy_(cache_t[:, from_index],
non_blocking=True)
def current_run_tensors(self, **kwargs) -> MambaCacheParams:
"""
Return the tensors for the current run's conv and ssm state.
"""
if "seqlen_agnostic_capture_inputs" not in kwargs:
# We get here only on Prefill/Eager mode runs
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
finished_requests_ids = kwargs["finished_requests_ids"]
self._release_finished_requests(finished_requests_ids)
state_indices = self._prepare_current_run_mamba_cache(
request_ids_to_seq_ids, finished_requests_ids)
state_indices_tensor = torch.as_tensor(state_indices,
dtype=torch.int32,
device="cuda")
mamba_cache_tensors = self.mamba_cache
else:
# CUDA graph capturing runs
(mamba_cache_tensors,
state_indices_tensor) = kwargs["seqlen_agnostic_capture_inputs"]
return MambaCacheParams(mamba_cache_tensors[0], mamba_cache_tensors[1],
cache_tensors, state_indices_tensor = super().current_run_tensors(
**kwargs)
return MambaCacheParams(cache_tensors[0], cache_tensors[1],
state_indices_tensor)
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
"""
Copy the relevant state_indices into the CUDA graph input buffer
"""
assert all(
key in kwargs
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
finished_requests_ids = kwargs["finished_requests_ids"]
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
assert "seqlen_agnostic_capture_inputs" in input_buffers
_, input_state_indices_buffer = input_buffers[
"seqlen_agnostic_capture_inputs"]
self._release_finished_requests(finished_requests_ids)
state_indices = self._prepare_current_run_mamba_cache(
request_ids_to_seq_ids, finished_requests_ids)
cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len(
state_indices)
state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len)
input_state_indices_buffer.copy_(
torch.as_tensor(state_indices, dtype=torch.int32, device="cuda"))
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
"""
Provide the CUDA graph capture runs with a buffer in adjusted size.
The buffer is used to maintain the Mamba Cache during the CUDA graph
replay runs.
"""
state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size,
dtype=torch.int32,
device="cuda")
return (self.mamba_cache, state_indices_tensor)
def _copy_mamba_cache(self, from_index: int, to_index: int):
assert len(self.mamba_cache) > 0
for cache_t in self.mamba_cache:
cache_t[:, to_index].copy_(cache_t[:, from_index],
non_blocking=True)
def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int,
finished_requests_ids) -> int:
"""
Assign (req_id,seq_id) pair to a `destination_index` index, if
already occupied, move the occupying index to a free index.
"""
if cur_rid in finished_requests_ids:
# set as pad, do not allocate destination index
return PAD_SLOT_ID
elif cur_rid not in self.mamba_cache_indices_mapping:
destination_index = self.free_cache_indices.pop()
self.mamba_cache_indices_mapping[cur_rid] = {
seq_id: destination_index
}
return destination_index
elif seq_id not in (seq_ids2indices :=
self.mamba_cache_indices_mapping[cur_rid]):
# parallel sampling , where n > 1, assume prefill have
# already happened, so we copy the
# existing cache into the siblings seq_ids caches
index_exists = next(iter(seq_ids2indices.values()))
# case of decoding n>1, copy prefill cache to decoding indices
destination_index = self.free_cache_indices.pop()
self._copy_mamba_cache(from_index=index_exists,
to_index=destination_index)
self.mamba_cache_indices_mapping[cur_rid][
seq_id] = destination_index
return destination_index
else:
# already exists
return self.mamba_cache_indices_mapping[cur_rid][seq_id]
def _prepare_current_run_mamba_cache(
self, request_ids_to_seq_ids: Dict[str, list[int]],
finished_requests_ids: List[str]) -> List[int]:
return [
self._assign_seq_id_to_cache_index(req_id, seq_id,
finished_requests_ids)
for req_id, seq_ids in request_ids_to_seq_ids.items()
for seq_id in seq_ids
]
def _release_finished_requests(self,
finished_seq_groups_req_ids: List[str]):
for req_id in finished_seq_groups_req_ids:
if req_id in self.mamba_cache_indices_mapping:
for seq_id in self.mamba_cache_indices_mapping[req_id]:
self.free_cache_indices.append(
self.mamba_cache_indices_mapping[req_id][seq_id])
self.mamba_cache_indices_mapping.pop(req_id)
return self._mamba_cache, torch.as_tensor([PAD_SLOT_ID] * batch_size,
dtype=torch.int32,
device="cuda")
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
import torch
from vllm.model_executor.models.constant_size_cache import ConstantSizeCache
@dataclass
class MinimaxCacheParams:
minimax_cache: torch.Tensor = torch.Tensor()
state_indices_tensor: torch.Tensor = torch.Tensor()
def at_layer_idx(self, layer_idx):
return MinimaxCacheParams(self.minimax_cache[layer_idx, ...],
self.state_indices_tensor)
class MinimaxCacheManager(ConstantSizeCache):
def __init__(self, dtype, cache_shape):
super().__init__(cache_shape[1]) # max_batch_size is cache_shape[1]
self._minimax_cache = torch.empty(size=cache_shape,
dtype=dtype,
device="cuda")
@property
def cache(self):
return self._minimax_cache
def _copy_cache(self, from_index: int, to_index: int):
assert len(self.cache) > 0
for cache_t in self.cache:
cache_t[:, to_index].copy_(cache_t[:, from_index],
non_blocking=True)
# SPDX-License-Identifier: Apache-2.0
"""Inference-only MiniMaxText01 model."""
import copy
import math
import re
from typing import Dict, Iterable, List, Optional, Tuple, Union
import torch
import torch.distributed
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.forward_context import get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.lightning_attn import (
lightning_attention, linear_decode_forward_triton)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import HasInnerState, IsHybrid, SupportsV0Only
from .minimax_cache import MinimaxCacheManager, MinimaxCacheParams
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
def replace_weight_name(name: str,
key: str = None,
to: str = None,
count: int = None,
prefix: str = None) -> str:
name = name.replace(key, to) if count is None else \
name.replace(key, to, count)
return name
def weight_loader_with_alias(alias: str):
def wrapper(func: callable):
def inner_func(param: torch.Tensor,
loaded_weight: torch.Tensor,
*args,
prefix: str = None,
**kwargs):
value = func(param, loaded_weight, *args, **kwargs)
return value
return inner_func
return wrapper
class MiniMaxText01RMSNormTP(CustomOp):
name = "MiniMaxText01RMSNormTP"
def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
super().__init__()
self.tp_world = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.weight = nn.Parameter(torch.ones(int(hidden_size /
self.tp_world)))
self.weight.weight_loader = self.weight_loader
self.variance_epsilon = eps
return
@staticmethod
def weight_loader(
param: nn.Parameter,
loaded_weight: torch.Tensor,
) -> None:
tp_world = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
shard_size = loaded_weight.shape[0] // tp_world
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
param.data.copy_(loaded_weight[shard])
return
def _forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
orig_dtype = x.dtype
x = x.to(torch.float32)
variance = x.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32)
if self.tp_world > 1:
variance = tensor_model_parallel_all_reduce(
variance) / self.tp_world
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype) * self.weight
return x
def forward(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
assert residual is None, "RMSNorm does not support residual connection."
return self._forward(x)
class MiniMaxText01RotaryEmbedding(CustomOp):
name = "MiniMaxText01RotaryEmbedding"
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position: int,
base: int,
is_neox_style: bool,
cache_dtype: torch.dtype,
) -> None:
super().__init__()
self.head_size = head_size
self.rotary_dim = rotary_dim
self.max_position_embeddings = max_position
self.base = base
self.is_neox_style = is_neox_style
self.cache_dtype = cache_dtype
cache = self._compute_cos_sin_cache().to(cache_dtype)
self.register_buffer("cos_sin_cache", cache, persistent=False)
def _compute_inv_freq(
self,
base: Union[int, float],
) -> torch.Tensor:
"""Compute the inverse frequency."""
inv_freq = 1.0 / (base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
"""Compute the cos and sin cache."""
inv_freq = self._compute_inv_freq(self.base)
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
from vllm import _custom_ops as ops
self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
query_cast = query.to(self.cache_dtype)
key_cast = key.to(self.cache_dtype)
ops.rotary_embedding(positions, query_cast, key_cast, self.head_size,
self.cos_sin_cache, self.is_neox_style)
query = query_cast.to(query.dtype)
key = key_cast.to(key.dtype)
return query, key
class MiniMaxText01MLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
quant_config: Optional[QuantizationConfig] = None,
layer_idx: int = None,
prefix: str = "mlp",
) -> None:
super().__init__()
self.layer_idx = layer_idx
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
self.act_fn = SiluAndMul()
return
def forward(self, x: torch.Tensor) -> torch.Tensor:
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class MiniMaxText01MoE(nn.Module):
def __init__(
self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
layer_idx: int = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "moe",
) -> None:
super().__init__()
self.layer_idx = layer_idx
self.tp_size = get_tensor_model_parallel_world_size()
self.num_total_experts = num_experts
self.top_k = top_k
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size // self.tp_size
self.quant_config = quant_config
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
self.gate = ReplicatedLinear(
self.hidden_size,
self.num_total_experts,
bias=False,
params_dtype=torch.float32,
quant_config=None,
prefix=f"{prefix}.gate",
)
self.gate.weight.weight_loader = MiniMaxText01MoE.gate_weight_loader
self.experts = FusedMoE(
num_experts=self.num_total_experts,
top_k=self.top_k,
hidden_size=self.hidden_size,
intermediate_size=self.intermediate_size * self.tp_size,
params_dtype=self.params_dtype,
reduce_results=True,
renormalize=True,
quant_config=self.quant_config,
tp_size=self.tp_size,
prefix=f"{prefix}.experts",
)
return
@staticmethod
def gate_weight_loader(param: nn.Parameter,
loaded_weight: torch.Tensor) -> None:
assert param.size() == loaded_weight.size()
param.data.copy_(loaded_weight.to(torch.float32))
return
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
router_logits_fp32, _ = self.gate(hidden_states.to(torch.float32))
final_hidden_states = self.experts(
hidden_states, router_logits_fp32.to(hidden_states.dtype))
final_hidden = final_hidden_states.view(num_tokens, hidden_size)
return final_hidden
class MiniMaxText01LinearKernel:
@staticmethod
def jit_linear_forward_prefix(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
kv_caches: torch.Tensor,
slope_rate: torch.Tensor,
block_size: int,
layer_idx: int = None,
**kwargs) -> torch.Tensor:
slope_rate = slope_rate.to(torch.float32)
should_pad_dim = q.dim() == 3
if should_pad_dim:
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
b, h, n, d = q.shape
e = d
kv_history = kv_caches.reshape(1, h, d, e).contiguous()
output, kv_history = lightning_attention(q,
k,
v,
slope_rate,
block_size=block_size,
kv_history=kv_history)
kv_caches.copy_(kv_history[:, :, -1, :, :].reshape(h, d, e))
assert output.shape[0] == 1, "batch size must be 1"
return rearrange(output.squeeze(0), "h n d -> n (h d)")
class MiniMaxText01LinearAttention(nn.Module):
def __init__(
self,
hidden_size: int,
hidden_inner_size: int,
num_heads: int,
head_dim: int,
max_position: int,
block_size: int,
num_hidden_layer: int,
quant_config: Optional[QuantizationConfig] = None,
layer_idx: int = 0,
linear_layer_idx: int = 0,
prefix: str = "linear_attn",
) -> None:
super().__init__()
self.layer_idx = layer_idx
self.BLOCK = block_size
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = head_dim
self.total_num_heads = num_heads
self.hidden_inner_size = hidden_inner_size
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
assert self.total_num_heads % self.tp_size == 0
self.tp_heads = self.total_num_heads // self.tp_size
self.qkv_size = self.num_heads * self.head_dim
self.tp_hidden = self.head_dim * self.tp_heads
self.qkv_proj = ColumnParallelLinear(
hidden_size,
self.hidden_inner_size * 3,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.output_gate = ColumnParallelLinear(
hidden_size,
self.hidden_inner_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.output_gate",
)
self.out_proj = RowParallelLinear(
self.hidden_inner_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
self.norm = MiniMaxText01RMSNormTP(
self.hidden_inner_size,
eps=1e-5,
)
slope_rate = MiniMaxText01LinearAttention._build_slope_tensor(
self.num_heads)
if num_hidden_layer <= 1:
self.slope_rate = slope_rate * (1 + 1e-5)
else:
self.slope_rate = slope_rate * (1 - layer_idx /
(num_hidden_layer - 1) + 1e-5)
self.tp_slope = self.slope_rate[self.tp_rank *
self.tp_heads:(self.tp_rank + 1) *
self.tp_heads].contiguous()
@staticmethod
def weight_direct_load(param: torch.Tensor,
loaded_weight: torch.Tensor) -> None:
assert param.size() == loaded_weight.size()
param.data.copy_(loaded_weight)
return
@staticmethod
def _build_slope_tensor(n_attention_heads: int):
def get_slopes(n):
def get_slopes_power_of_2(n):
start = 2**(-(2**-(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2**math.floor(math.log2(n))
return (get_slopes_power_of_2(closest_power_of_2) + get_slopes(
2 * closest_power_of_2)[0::2][:n - closest_power_of_2])
slopes = torch.tensor(get_slopes(n_attention_heads),
dtype=torch.float32).reshape(
n_attention_heads, 1, 1)
return slopes
def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
attn_metadata):
hidden = []
for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
_start = attn_metadata.query_start_loc[_prefill_idx]
_end = attn_metadata.query_start_loc[_prefill_idx + 1]
slot_id = state_indices_tensor[_prefill_idx]
qs = q[_start:_end].transpose(0, 1).contiguous()
ks = k[_start:_end].transpose(0, 1).contiguous()
vs = v[_start:_end].transpose(0, 1).contiguous()
slot_id = state_indices_tensor[_prefill_idx]
slice_layer_cache = kv_cache[slot_id, ...]
out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
qs,
ks,
vs,
slice_layer_cache,
self.tp_slope,
self.BLOCK,
layer_idx=self.layer_idx)
hidden.append(out_slice.contiguous())
if attn_metadata.num_decode_tokens > 0:
hidden.append(
self._decode_infer(q, k, v, kv_cache, state_indices_tensor,
attn_metadata))
hidden = torch.concat(hidden, dim=0).contiguous()
return hidden
def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
attn_metadata):
q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
slot_id = state_indices_tensor[getattr(attn_metadata, "num_prefills", 0
):]
hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope,
slot_id, 32)
return hidden
def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
kv_caches: MinimaxCacheParams, **kwargs) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
qkv32 = qkv.to(torch.float32)
qkvact = torch.nn.functional.silu(qkv32)
qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
kv_cache = kv_caches.minimax_cache
state_indices_tensor = kv_caches.state_indices_tensor
decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
if not decode_only:
hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
state_indices_tensor,
attn_metadata)
else:
hidden = self._decode_infer(q, k, v, kv_cache,
state_indices_tensor, attn_metadata)
hidden = self.norm._forward(hidden)
gate, _ = self.output_gate(hidden_states)
hidden = F.sigmoid(gate) * hidden
hidden = hidden.to(hidden_states.dtype)
hidden, _ = self.out_proj(hidden)
return hidden
class MiniMaxText01Attention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
head_dim: int,
num_kv_heads: int,
rotary_dim: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
sliding_window: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
layer_idx: int = None,
cache_config: Optional[CacheConfig] = None,
prefix: str = "mha",
) -> None:
super().__init__()
self.layer_idx = layer_idx
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
assert self.total_num_kv_heads % tp_size == 0
else:
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.sliding_window = sliding_window
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
return
def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
**kwargs) -> torch.Tensor:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = attn_metadata.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
class MiniMaxText01DecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
expert_num: int = 1,
layer_id: int = None,
linear_layer_id: Optional[int] = None,
prefix: str = "decoder",
) -> None:
self._ilayer = layer_id
self._irank = get_tensor_model_parallel_rank()
super().__init__()
self.hidden_size = config.hidden_size
self.expert_num = expert_num
rope_theta = getattr(config, "rope_theta", 10000)
head_dim = getattr(config, "head_dim",
config.hidden_size // config.num_attention_heads)
if hasattr(config, "max_model_len") and isinstance(
config.max_model_len, int):
max_position_embeddings = min(config.max_position_embeddings,
config.max_model_len)
if config.attention_type == 0:
use_headxdim = True
hidden_inner = (head_dim * config.num_attention_heads
if use_headxdim else config.hidden_size)
self.self_attn = MiniMaxText01LinearAttention(
hidden_size=self.hidden_size,
hidden_inner_size=hidden_inner,
num_heads=config.num_attention_heads,
head_dim=head_dim,
max_position=max_position_embeddings,
block_size=config.block if hasattr(config, "block") else 256,
num_hidden_layer=config.num_hidden_layers,
quant_config=quant_config,
layer_idx=self._ilayer,
linear_layer_idx=linear_layer_id,
prefix=prefix)
elif config.attention_type == 1:
self.self_attn = MiniMaxText01Attention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
head_dim=head_dim,
rotary_dim=config.rotary_dim
if hasattr(config, "rotary_dim") else head_dim,
num_kv_heads=config.num_key_value_heads,
max_position=max_position_embeddings,
rope_theta=rope_theta,
sliding_window=config.sliding_window,
quant_config=quant_config,
layer_idx=self._ilayer,
cache_config=cache_config,
prefix=prefix)
else:
raise ValueError(
f"Unsupported attention type: {self.config.attention_type}")
if expert_num == 1:
self.mlp = MiniMaxText01MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
quant_config=quant_config,
layer_idx=self._ilayer,
prefix=prefix)
else:
self.block_sparse_moe = MiniMaxText01MoE(
num_experts=expert_num,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
layer_idx=self._ilayer,
quant_config=quant_config,
prefix=prefix)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
if config.attention_type == 0:
self.layernorm_attention_alpha = getattr(
config, 'layernorm_linear_attention_alpha', 1)
self.layernorm_attention_beta = getattr(
config, 'layernorm_linear_attention_beta', 1)
else:
self.layernorm_attention_alpha = getattr(
config, 'layernorm_full_attention_alpha', 1)
self.layernorm_attention_beta = getattr(
config, 'layernorm_full_attention_beta', 1)
self.layernorm_mlp_alpha = getattr(config, 'layernorm_mlp_alpha', 1)
self.layernorm_mlp_beta = getattr(config, 'layernorm_mlp_beta', 1)
self.postnorm = getattr(config, 'postnorm', False)
self.shared_moe = False
shared_intermediate = getattr(config, 'shared_intermediate_size', 0)
if shared_intermediate > 0:
self.shared_moe = True
self.shared_mlp = MiniMaxText01MLP(
hidden_size=self.hidden_size,
intermediate_size=shared_intermediate,
quant_config=quant_config,
layer_idx=self._ilayer,
prefix=prefix)
self.coefficient = ReplicatedLinear(
self.hidden_size,
1,
bias=False,
quant_config=quant_config,
params_dtype=torch.float32,
)
self.coefficient.weight.weight_loader = (
self.shared_moe_coefficient_loader)
self.shared_moe_mode = getattr(config, 'shared_moe_mode',
'softmax')
return
def forward(self,
hidden_states: torch.Tensor,
positions: torch.Tensor,
kv_caches: Union[List[Dict], Optional[torch.Tensor]],
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
is_warmup: bool = False,
**kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
layernorm_input = hidden_states
layernorm_output = self.input_layernorm(layernorm_input)
residual = layernorm_output if self.postnorm else layernorm_input
self_attention_output = self.self_attn(
hidden_states=layernorm_output,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
)
residual = residual * self.layernorm_attention_alpha
self_attention_output = (self_attention_output *
self.layernorm_attention_beta)
layernorm_input = residual + self_attention_output
layernorm_output = self.post_attention_layernorm(layernorm_input)
residual = layernorm_output if self.postnorm else layernorm_input
if self.expert_num == 1:
hidden_states = self.mlp(layernorm_output)
else:
moe_hidden_states = self.block_sparse_moe(
copy.deepcopy(layernorm_output))
if self.shared_moe:
before_moe_dtype = layernorm_output.dtype
moe_hidden_fp32 = moe_hidden_states.to(torch.float32)
output_mlp = self.shared_mlp(layernorm_output).to(
torch.float32)
coef, _ = self.coefficient(layernorm_output.to(torch.float32))
if self.shared_moe_mode == 'softmax':
coef = torch.nn.functional.softmax(coef, dim=-1)
hidden_states = moe_hidden_fp32 * (
1 - coef) + output_mlp * coef
elif self.shared_moe_mode == 'sigmoid':
coef = torch.nn.functional.sigmoid(coef)
hidden_states = moe_hidden_fp32 * (
1 - coef) + output_mlp * coef
hidden_states = hidden_states.to(before_moe_dtype)
else:
hidden_states = moe_hidden_states
residual = residual * self.layernorm_mlp_alpha
hidden_states = hidden_states * self.layernorm_mlp_beta
hidden_states = residual + hidden_states
return hidden_states, None
@staticmethod
def shared_moe_coefficient_loader(param: torch.Tensor,
loaded_weight: torch.Tensor) -> None:
assert param.size() == loaded_weight.size()
param.data.copy_(loaded_weight.to(torch.float32))
return
class MiniMaxText01Model(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config: Optional[CacheConfig] = None,
scheduler_config=None,
prefix: str = "",
) -> None:
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.decoder_attention_types = getattr(
config, "attn_type_list", False) or getattr(
config, "decoder_attention_types", False)
if not self.decoder_attention_types:
self.decoder_attention_types = [1] * config.num_hidden_layers
self.num_layers = config.num_hidden_layers
self._layer_barrier = False
if get_pp_group().is_first_rank:
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=self.vocab_size,
)
else:
self.embed_tokens = PPMissingLayer()
def layer_fn(prefix):
layer_idx = int(prefix.split('.')[-1])
layer_config = config
layer_config.attention_type = self.decoder_attention_types[
layer_idx]
layer_config.layer_idx = layer_idx
decoder_kwargs = {
"quant_config": quant_config,
"layer_id": layer_idx,
"cache_config": cache_config
}
if layer_config.attention_type == 0:
decoder_kwargs["linear_layer_id"] = sum(
1 for i in range(layer_idx)
if self.decoder_attention_types[i] == 0)
else:
decoder_kwargs["linear_layer_id"] = None
if hasattr(config, "num_local_experts") and isinstance(
config.num_local_experts, list):
decoder_kwargs["expert_num"] = config.num_local_experts[
layer_idx]
elif hasattr(config, "num_local_experts") and isinstance(
config.num_local_experts, int):
decoder_kwargs["expert_num"] = config.num_local_experts
else:
decoder_kwargs["expert_num"] = 1
return MiniMaxText01DecoderLayer(layer_config,
**decoder_kwargs,
prefix=prefix)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, layer_fn, prefix=f"{prefix}.layers")
linear_layer_nums = sum(1 for i in range(config.num_hidden_layers)
if self.decoder_attention_types[i] == 0)
max_slots_number = scheduler_config.max_num_seqs
self.cache_shape = (linear_layer_nums, max_slots_number,
config.num_attention_heads //
get_tensor_model_parallel_world_size(),
config.head_dim, config.head_dim)
_dummy = torch.zeros(1)
self._dtype = _dummy.dtype
del _dummy
self.minimax_cache = MinimaxCacheManager(dtype=self._dtype,
cache_shape=self.cache_shape)
rope_theta = getattr(config, "rope_theta", 10000)
head_dim = getattr(config, "head_dim",
config.hidden_size // config.num_attention_heads)
if hasattr(config, "max_model_len") and isinstance(
config.max_model_len, int):
max_position_embeddings = min(config.max_position_embeddings,
config.max_model_len)
self.rotary_emb = MiniMaxText01RotaryEmbedding(
head_dim,
rotary_dim=config.rotary_dim
if hasattr(config, "rotary_dim") else head_dim,
max_position=max_position_embeddings,
base=int(rope_theta),
is_neox_style=True,
cache_dtype=torch.float32,
)
norm_kwargs = {}
if hasattr(config, "rms_norm_eps"):
norm_kwargs["eps"] = config.rms_norm_eps
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, **norm_kwargs)
else:
self.norm = PPMissingLayer()
self.embed_scale = 1.0
return
def _clear_prefill_cache(self, attn_metadata,
minimax_cache_tensors: torch.Tensor, **kwargs):
seq_to_slot_maps = {}
seq_id_map = sum(list(kwargs["request_ids_to_seq_ids"].values()), [])
for _, seq_to_slot_map in (
self.minimax_cache.cache_indices_mapping.items()):
seq_to_slot_maps.update(seq_to_slot_map)
slots_to_clear = []
for _prefill_id in range(getattr(attn_metadata, "num_prefills", 0)):
seq_id = seq_id_map[_prefill_id]
if attn_metadata.context_lens_tensor[
_prefill_id] == 0 and seq_id in seq_to_slot_maps:
slots_to_clear.append(seq_to_slot_maps[seq_id])
if slots_to_clear:
slots_tensor = torch.tensor(slots_to_clear,
device=minimax_cache_tensors.device,
dtype=torch.long)
minimax_cache_tensors[:, slots_tensor, ...] = 0
def forward(self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
intermediate_tensors=None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs) -> torch.Tensor:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
return None
if "request_ids_to_seq_ids" not in kwargs:
kwargs["request_ids_to_seq_ids"] = {}
if "finished_requests_ids" not in kwargs:
kwargs["finished_requests_ids"] = []
(
minimax_cache_tensors,
state_indices_tensor,
) = self.minimax_cache.current_run_tensors(**kwargs)
if getattr(attn_metadata, "num_prefills", 0) > 0:
self._clear_prefill_cache(attn_metadata, minimax_cache_tensors,
**kwargs)
minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors,
state_indices_tensor)
if get_pp_group().is_first_rank:
if inputs_embeds is None:
hidden_states = self.embed_scale * self.embed_tokens(input_ids)
else:
hidden_states = inputs_embeds
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
kv_cache_index = 0
minimax_cache_index = 0
attn_metadata.rotary_emb = self.rotary_emb
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
_caches = None
if isinstance(layer.self_attn, MiniMaxText01Attention):
_caches = kv_caches[kv_cache_index]
kv_cache_index += 1
if isinstance(layer.self_attn, MiniMaxText01LinearAttention):
current_state_layer = minimax_cache_index
_caches = minimax_cache_params.at_layer_idx(
current_state_layer)
minimax_cache_index += 1
hidden_states, residual = layer(
hidden_states=hidden_states,
positions=positions,
kv_caches=_caches,
attn_metadata=attn_metadata,
residual=residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
if residual is not None:
hidden_states, _ = self.norm(hidden_states, residual)
else:
hidden_states = self.norm(hidden_states)
return hidden_states
class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
SupportsV0Only):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
if not hasattr(config, "sliding_window"):
config.sliding_window = None
self.CONCAT_FFN = True
self.unpadded_vocab_size = self.config.vocab_size
if hasattr(vllm_config.model_config, "max_model_len"):
self.config.max_model_len = vllm_config.model_config.max_model_len
self.model = MiniMaxText01Model(
self.config,
quant_config,
cache_config=vllm_config.cache_config,
scheduler_config=vllm_config.scheduler_config,
prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank:
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
self.config.hidden_size,
org_num_embeddings=self.config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
self.config.vocab_size)
self.sampler = Sampler()
else:
self.lm_head = PPMissingLayer()
flash_layer_count = sum(1 for attn_type in self.config.attn_type_list
if attn_type == 1)
self.kv_cache = [torch.tensor([]) for _ in range(flash_layer_count)]
return
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.model.minimax_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs(
batch_size)
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, self.kv_cache,
intermediate_tensors, inputs_embeds,
**kwargs)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
"residual":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> None:
params_dict = dict(self.named_parameters())
def which_layer(name: str) -> int:
if "layers" in name:
after_layer = name.split("layers")[-1]
return int(after_layer.split(".")[1])
return None
def is_linear_attn_layer(layer_idx: int) -> bool:
if layer_idx is None or not hasattr(self.config, "attn_type_list"):
return False
return self.config.attn_type_list[layer_idx] == 0
def is_moe_weight(name: str) -> bool:
return "block_sparse_moe" in name and not name.endswith(".bias")
def get_expert_id(param_name):
pattern = r'model\.layers\.\d+\.block_sparse_moe\.experts\.(\d+)\.'
match = re.search(pattern, param_name)
if match:
return match.group(1)
return None
def load_sparse_moe_weight(name: str, loaded_weight: torch.Tensor,
self) -> None:
if isinstance(self.config.num_local_experts, list):
expert_params_mapping = [
("w13_weight"
if weight_name in ["w1", "w3"] else "w2_weight",
f"experts.{expert_id}.{weight_name}.weight", expert_id)
for expert_id in range(max(self.config.num_local_experts))
for weight_name in ["w1", "w2", "w3"]
]
else:
expert_params_mapping = [
("w13_scale" if weight_name in ["w1", "w3"] else
"w2_scale", f"{expert_id}.{weight_name}.weight_scale",
expert_id, weight_name)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
] + [("w13_weight" if weight_name in ["w1", "w3"] else
"w2_weight", f"{expert_id}.{weight_name}.weight",
expert_id, weight_name)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]]
for (param_name, weight_name, expert_id,
shard_id) in expert_params_mapping:
name_expert_id = get_expert_id(name)
if name_expert_id is not None and int(name_expert_id) != int(
expert_id):
continue
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
return
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader = weight_loader_with_alias(name)(weight_loader)
weight_loader(param,
loaded_weight,
weight_name,
expert_id=expert_id,
shard_id=shard_id)
break
else:
if is_pp_missing_parameter(name, self):
return
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader = weight_loader_with_alias(name)(weight_loader)
weight_loader(param, loaded_weight)
return
def is_shared_mlp_weight(name: str) -> bool:
return "shared_mlp" in name and not name.endswith(".bias")
def load_shared_mlp_weight(name: str, loaded_weight: torch.Tensor,
self) -> None:
if not self.CONCAT_FFN:
if "gate_proj" in name:
name = name.replace("gate_proj", "w1", 1)
elif "up_proj" in name:
name = name.replace("up_proj", "w3", 1)
elif "down_proj" in name:
name = name.replace("down_proj", "w2", 1)
else:
if "gate_proj" in name:
name = name.replace("gate_proj", "gate_up_proj", 1)
loaded_shard_id = 0
elif "up_proj" in name:
name = name.replace("up_proj", "gate_up_proj", 1)
loaded_shard_id = 1
if is_pp_missing_parameter(name, self):
return
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader = weight_loader_with_alias(name)(weight_loader)
if not self.CONCAT_FFN:
weight_loader(param, loaded_weight)
else:
if "gate_up_proj" in name:
weight_loader(param, loaded_weight, loaded_shard_id)
elif "down_proj" in name:
weight_loader(param, loaded_weight)
else:
raise AssertionError(
"MLP weight not in [gate_up_proj, down_proj]")
return
def is_mha_weight(name: str) -> bool:
return "self_attn" in name and not name.endswith(".bias")
def load_linear_attn_weight(name: str, loaded_weight: torch.Tensor,
self) -> None:
if is_pp_missing_parameter(name, self):
return
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader",
MiniMaxText01LinearAttention.weight_direct_load)
weight_loader = weight_loader_with_alias(name)(weight_loader)
weight_loader(param, loaded_weight)
return
def load_flash_attn_weight(name: str, loaded_weight: torch.Tensor,
self) -> None:
flash_mha_params_mapping = [
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
for (param_name, weight_name,
shard_id) in flash_mha_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
return
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader = weight_loader_with_alias(name)(weight_loader)
weight_loader(param, loaded_weight, shard_id)
break
else:
if is_pp_missing_parameter(name, self):
return
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader = weight_loader_with_alias(name)(weight_loader)
weight_loader(param, loaded_weight)
return
def is_layer_norm_weight(name: str) -> bool:
return "norm" in name and not name.endswith(
".bias") and name in params_dict
def load_layer_norm_weight(name: str, loaded_weight: torch.Tensor,
self) -> None:
if is_pp_missing_parameter(name, self):
return
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader = weight_loader_with_alias(name)(weight_loader)
weight_loader(param, loaded_weight)
return
def load_basic_weight(name: str, loaded_weight: torch.Tensor,
self) -> None:
if is_pp_missing_parameter(name, self):
return
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader = weight_loader_with_alias(name)(weight_loader)
weight_loader(param, loaded_weight)
return
for name, loaded_weight in weights:
weight_at_layer = which_layer(name)
if weight_at_layer and weight_at_layer >= len(
self.config.attn_type_list):
continue
if is_layer_norm_weight(name):
load_layer_norm_weight(name, loaded_weight, self)
continue
if is_mha_weight(name):
if is_linear_attn_layer(weight_at_layer):
load_linear_attn_weight(name, loaded_weight, self)
else:
load_flash_attn_weight(name, loaded_weight, self)
continue
if is_moe_weight(name):
load_sparse_moe_weight(name, loaded_weight, self)
continue
if is_shared_mlp_weight(name):
load_shared_mlp_weight(name, loaded_weight, self)
continue
if "rotary_emb.inv_freq" in name:
continue
load_basic_weight(name, loaded_weight, self)
return
......@@ -35,6 +35,7 @@ _TEXT_GENERATION_MODELS = {
"AquilaModel": ("llama", "LlamaForCausalLM"),
"AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2
"ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
"MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
# baichuan-7b, upper case 'C' in the class name
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
# baichuan-13b, lower case 'c' in the class name
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment