"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "82e2339b0632a4c787915210b5b57da13de26bf6"
Unverified Commit 18dd5e01 authored by Chih-Chieh Yang's avatar Chih-Chieh Yang Committed by GitHub
Browse files

[Model] Mamba2 causal conv1d Refactor to Split Prefill and Decode Requests for...


[Model] Mamba2 causal conv1d Refactor to Split Prefill and Decode Requests for Corresponding Kernels (#17146)
Signed-off-by: default avatarChih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
parent 6de3e134
...@@ -6,7 +6,7 @@ import torch.nn.functional as F ...@@ -6,7 +6,7 @@ import torch.nn.functional as F
from einops import rearrange, repeat from einops import rearrange, repeat
from vllm.model_executor.layers.mamba.mamba2_metadata import ( from vllm.model_executor.layers.mamba.mamba2_metadata import (
_seq_idx_to_chunk_indices_offsets) _query_start_loc_to_chunk_indices_offsets)
from vllm.model_executor.layers.mamba.ops.ssd_combined import ( from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined) mamba_chunk_scan_combined)
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -274,8 +274,9 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, ...@@ -274,8 +274,9 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
last_taken, exhausted, n_heads, last_taken, exhausted, n_heads,
d_head, itype): d_head, itype):
chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets( chunk_indices, chunk_offsets = \
seq_idx, chunk_size) _query_start_loc_to_chunk_indices_offsets(
cu_seqlens, chunk_size, cu_seqlens[-1])
Y, new_states = mamba_chunk_scan_combined( Y, new_states = mamba_chunk_scan_combined(
X, X,
......
...@@ -13,7 +13,6 @@ from vllm.attention.backends.xformers import XFormersMetadata ...@@ -13,7 +13,6 @@ from vllm.attention.backends.xformers import XFormersMetadata
@dataclass @dataclass
class Mamba2Metadata: class Mamba2Metadata:
has_prefill: bool
has_initial_states: torch.Tensor has_initial_states: torch.Tensor
prep_initial_states: bool prep_initial_states: bool
...@@ -24,21 +23,23 @@ class Mamba2Metadata: ...@@ -24,21 +23,23 @@ class Mamba2Metadata:
chunk_offsets: torch.Tensor chunk_offsets: torch.Tensor
def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int): def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
chunk_size: int,
total_seqlens: int):
# convert seq_idx to chunk indices and offsets cu_seqlens = query_start_loc[1:] # remove prepended 0
# - derive the cu_seqlens
_, cu_seqlens = torch.where(seq_idx.diff())
cu_seqlens += 1
# outputs will have length expansion of chunks that do not divide # outputs will have length expansion of chunks that do not divide
# chunk_size # chunk_size
N = math.ceil(seq_idx.shape[-1] / chunk_size) + (cu_seqlens % chunk_size N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size
> 0).sum() > 0).sum()
chunk_indices = torch.arange(N, dtype=torch.int, device=seq_idx.device) chunk_indices = torch.arange(N,
chunk_offsets = torch.zeros((N, ), dtype=torch.int, device=seq_idx.device) dtype=torch.int,
device=query_start_loc.device)
chunk_offsets = torch.zeros((N, ),
dtype=torch.int,
device=query_start_loc.device)
cu_seqlens = cu_seqlens.tolist() + [seq_idx.shape[-1]]
p = 0 # num of insertions p = 0 # num of insertions
for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]): for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
...@@ -60,48 +61,49 @@ def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int): ...@@ -60,48 +61,49 @@ def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
def prepare_mamba2_metadata( def prepare_mamba2_metadata(
chunk_size: int, chunk_size: int,
input_ids: torch.Tensor,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> Mamba2Metadata: ) -> Mamba2Metadata:
# compute number of prefill and decode requests
# NOTE: in V0 we assume prefills are before decodes
num_prefills = attn_metadata.num_prefills
num_prefill_tokens = attn_metadata.num_prefill_tokens
seq_idx = None
chunk_indices, chunk_offsets = None, None
# Need flags to indicate if there are initial states # Need flags to indicate if there are initial states
# currently we really only support the FlashAttention backend # currently we really only support the FlashAttention backend
has_initial_states = None has_initial_states = None
prep_initial_states = False prep_initial_states = False
if (isinstance(attn_metadata, (FlashAttentionMetadata, XFormersMetadata,
PlaceholderAttentionMetadata))
and attn_metadata.context_lens_tensor is not None):
has_initial_states = attn_metadata.context_lens_tensor > 0
# precompute flag to avoid device syncs later in mamba2 forwards
prep_initial_states = torch.any(has_initial_states).item()
has_prefill = attn_metadata.num_prefills > 0
seq_idx = None # Compute seq_idx, chunk_indices and chunk_offsets for prefill only
chunk_indices, chunk_offsets = None, None if num_prefills > 0:
if has_prefill: if (isinstance(attn_metadata,
seq_idx = torch.zeros_like(input_ids, dtype=torch.int32) (FlashAttentionMetadata, XFormersMetadata,
for i, (srt, end) in enumerate( PlaceholderAttentionMetadata))
zip( and attn_metadata.context_lens_tensor is not None):
attn_metadata.query_start_loc, has_initial_states = \
attn_metadata.query_start_loc[1:], attn_metadata.context_lens_tensor[:num_prefills] > 0 #[batch,]
)): # precompute flag to avoid device syncs in mamba2 layer forwards
seq_idx[srt:end] = i # prep is only needed for mamba2 ssd prefill processing
prep_initial_states = torch.any(has_initial_states).item()
query_start_loc = attn_metadata.query_start_loc[:num_prefills + 1]
seq_idx = torch.repeat_interleave(torch.arange(
num_prefills, dtype=torch.int32, device=query_start_loc.device),
query_start_loc.diff(),
output_size=num_prefill_tokens)
seq_idx.unsqueeze_(0) seq_idx.unsqueeze_(0)
# compute metadata for chunked prefill. # We compute metadata for chunked prefill once at the top level model
# actually this is only needed if there are initial states, # forward and reuse them in mamba layers. If not needed, they will be
# but this is determinable only from attention metadata yet # ignored inside mamba kernels.
# unavailable from the top-level model forward. Rather than if prep_initial_states:
# complicating things to extract said metadata, we simply just chunk_indices, chunk_offsets = \
# compute them once at the top level model forward and reuse _query_start_loc_to_chunk_indices_offsets(
# them in mamba layers. If not needed, they will be ignored query_start_loc, chunk_size, num_prefill_tokens)
# inside mamba kernels.
chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets( return Mamba2Metadata(has_initial_states=has_initial_states,
seq_idx, chunk_size)
return Mamba2Metadata(has_prefill=has_prefill,
has_initial_states=has_initial_states,
prep_initial_states=prep_initial_states, prep_initial_states=prep_initial_states,
chunk_size=chunk_size, chunk_size=chunk_size,
seq_idx=seq_idx, seq_idx=seq_idx,
......
...@@ -388,10 +388,15 @@ class MambaMixer2(CustomOp): ...@@ -388,10 +388,15 @@ class MambaMixer2(CustomOp):
# mamba2_metadata contains metadata necessary for the mamba2 triton # mamba2_metadata contains metadata necessary for the mamba2 triton
# kernels to operate in continuous batching and in chunked prefill # kernels to operate in continuous batching and in chunked prefill
# modes; they are computed at top-level model forward since they # modes; they are computed at top-level model forward since they
# are the same and reused for all mamba layers in the same iteration # stay the same and reused for all mamba layers in the same iteration
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
seq_len, _ = hidden_states.shape num_prefills = attn_metadata.num_prefills # request count
num_decodes = attn_metadata.num_decode_tokens # token count (=request)
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
has_prefill = num_prefills > 0
has_decode = num_decodes > 0
groups_time_state_size = self.n_groups * self.ssm_state_size groups_time_state_size = self.n_groups * self.ssm_state_size
# 1. Gated MLP's linear projection # 1. Gated MLP's linear projection
...@@ -406,44 +411,32 @@ class MambaMixer2(CustomOp): ...@@ -406,44 +411,32 @@ class MambaMixer2(CustomOp):
dim=-1, dim=-1,
) )
# 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2)) self.conv1d.weight.size(2))
if mamba2_metadata.has_prefill: # Separate prefill and decode by splitting varlen input
# |---------- N-1 iteration --------| # Split along token dimension
# |---------------- N iteration ---------------------| hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
# |- tokenA -|......................|-- newTokens ---| hidden_states_B_C,
# |---------- context_len ----------| [num_prefill_tokens, num_decodes],
# |-------------------- seq_len ---------------------| dim=0,
# |-- query_len ---| )
dt_p, dt_d = torch.split(
# - "cache_indices" updates the conv_state cache in positions dt,
# pointed to by "mamba_cache_params.state_indices_tensor" [num_prefill_tokens, num_decodes],
hidden_states_B_C = causal_conv1d_fn( dim=0,
hidden_states_B_C.transpose(0, 1), )
conv_weights, # Split along batch dimension
self.conv1d.bias, state_indices_tensor_p, state_indices_tensor_d = torch.split(
activation=self.activation, mamba_cache_params.state_indices_tensor,
conv_states=mamba_cache_params.conv_state, [num_prefills, num_decodes],
has_initial_state=mamba2_metadata.has_initial_states, dim=0,
cache_indices=mamba_cache_params.state_indices_tensor, )
query_start_loc=attn_metadata.query_start_loc).transpose( query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1]
0, 1)[:seq_len] if has_prefill else None)
# TODO: Why is this needed?
hidden_states_B_C = hidden_states_B_C.contiguous()
else:
hidden_states_B_C = causal_conv1d_update(
hidden_states_B_C,
mamba_cache_params.conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=mamba_cache_params.state_indices_tensor)
# - get hidden_states, B and C after depthwise convolution. # - get hidden_states, B and C after depthwise convolution.
hidden_states, B, C = torch.split( split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
hidden_states_B_C, hidden_states_B_C,
[ [
self.intermediate_size // self.tp_size, self.intermediate_size // self.tp_size,
...@@ -453,24 +446,48 @@ class MambaMixer2(CustomOp): ...@@ -453,24 +446,48 @@ class MambaMixer2(CustomOp):
dim=-1, dim=-1,
) )
# 3. State Space Model sequence transformation ssd_output_list = []
if mamba2_metadata.has_prefill:
# Process prefill requests
if has_prefill:
# 2. Convolution sequence transformation
# - "cache_indices" updates the conv_state cache in positions
# pointed to by "mamba_cache_params.state_indices_tensor"
hidden_states_B_C_p = causal_conv1d_fn(
hidden_states_B_C_p.transpose(0, 1),
conv_weights,
self.conv1d.bias,
activation=self.activation,
conv_states=mamba_cache_params.conv_state,
has_initial_state=mamba2_metadata.has_initial_states,
cache_indices=state_indices_tensor_p,
query_start_loc=query_start_loc_p).transpose(
0, 1)[:num_prefill_tokens]
# TODO: Why is this needed?
hidden_states_B_C_p = hidden_states_B_C_p.contiguous()
hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(
hidden_states_B_C_p)
# 3. State Space Model sequence transformation
initial_states = None initial_states = None
if (mamba2_metadata.has_initial_states is not None if (mamba2_metadata.has_initial_states is not None
and mamba2_metadata.prep_initial_states): and mamba2_metadata.prep_initial_states):
# making a copy of the states # making a copy of the states
initial_states = torch.where( initial_states = torch.where(
mamba2_metadata.has_initial_states[:, None, None, None], mamba2_metadata.has_initial_states[:, None, None, None],
mamba_cache_params.ssm_state[ mamba_cache_params.ssm_state[state_indices_tensor_p], 0)
mamba_cache_params.state_indices_tensor], 0)
scan_output, varlen_state = mamba_chunk_scan_combined( scan_output, varlen_state = mamba_chunk_scan_combined(
hidden_states.view(1, seq_len, self.num_heads // self.tp_size, hidden_states_p.view(1, num_prefill_tokens,
self.head_dim), self.num_heads // self.tp_size,
dt.unsqueeze(0), self.head_dim),
dt_p.unsqueeze(0),
self.A, self.A,
B.view(1, seq_len, self.n_groups // self.tp_size, -1), B_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
C.view(1, seq_len, self.n_groups // self.tp_size, -1), -1),
C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
-1),
chunk_size=mamba2_metadata.chunk_size, chunk_size=mamba2_metadata.chunk_size,
D=self.D, D=self.D,
z=None, z=None,
...@@ -478,7 +495,7 @@ class MambaMixer2(CustomOp): ...@@ -478,7 +495,7 @@ class MambaMixer2(CustomOp):
seq_idx=mamba2_metadata.seq_idx, seq_idx=mamba2_metadata.seq_idx,
chunk_indices=mamba2_metadata.chunk_indices, chunk_indices=mamba2_metadata.chunk_indices,
chunk_offsets=mamba2_metadata.chunk_offsets, chunk_offsets=mamba2_metadata.chunk_offsets,
cu_seqlens=attn_metadata.query_start_loc, cu_seqlens=attn_metadata.query_start_loc[:num_prefills + 1],
initial_states=initial_states, initial_states=initial_states,
return_varlen_states=True, return_varlen_states=True,
return_final_states=False, return_final_states=False,
...@@ -487,52 +504,65 @@ class MambaMixer2(CustomOp): ...@@ -487,52 +504,65 @@ class MambaMixer2(CustomOp):
) )
# update ssm states # update ssm states
# - varlen state is a (batch, nheads, headdim, dstate) tensor # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
mamba_cache_params.ssm_state[ mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state
mamba_cache_params.state_indices_tensor] = varlen_state
# - reshape # - reshape
hidden_states = scan_output.view(seq_len, -1) ssd_output_list.append(scan_output.view(num_prefill_tokens, -1))
else:
# Process decode requests
if has_decode:
# 2. Convolution sequence transformation
hidden_states_B_C_d = causal_conv1d_update(
hidden_states_B_C_d,
mamba_cache_params.conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=state_indices_tensor_d)
hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(
hidden_states_B_C_d)
# 3. State Space Model sequence transformation
n_groups = self.n_groups // self.tp_size n_groups = self.n_groups // self.tp_size
A = self.A[:, None, ...][:, :, None].expand( A_d = self.A[:, None, ...][:, :, None].expand(
-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) -1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
dt = dt[:, :, None].expand(-1, -1, self.head_dim) dt_d = dt_d[:, :, None].expand(-1, -1, self.head_dim)
dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
D = self.D[:, None, ...].expand(-1, self.head_dim) D_d = self.D[:, None, ...].expand(-1, self.head_dim)
B = B.view(-1, n_groups, B.shape[1] // n_groups) B_d = B_d.view(-1, n_groups, B_d.shape[1] // n_groups)
C = C.view(-1, n_groups, C.shape[1] // n_groups) C_d = C_d.view(-1, n_groups, C_d.shape[1] // n_groups)
hidden_states_reshaped = hidden_states.view( hidden_states_d = hidden_states_d.view(
-1, self.num_heads // self.tp_size, self.head_dim) -1, self.num_heads // self.tp_size, self.head_dim)
# - the hidden is reshaped into number of current batches # - the hidden is reshaped into (bs, num_heads, head_dim)
# - in this case there is no more prefill, so the batches gen
# 1 token at a time
# - thus hidden will be (bs, num_heads, head_dim)
# - mamba_cache_params.ssm_state's slots will be selected # - mamba_cache_params.ssm_state's slots will be selected
# using "mamba_cache_params.state_indices_tensor", just as # using state_indices_tensor_d
# above in the prefill case
hidden_states = selective_state_update( hidden_states_d = selective_state_update(
mamba_cache_params.ssm_state, mamba_cache_params.ssm_state,
hidden_states_reshaped, hidden_states_d,
dt, dt_d,
A, A_d,
B, B_d,
C, C_d,
D, D_d,
z=None, z=None,
dt_bias=dt_bias, dt_bias=dt_bias,
dt_softplus=True, dt_softplus=True,
state_batch_indices=mamba_cache_params.state_indices_tensor, state_batch_indices=state_indices_tensor_d,
) )
hidden_states = hidden_states.view( ssd_output_list.append(
-1, (self.num_heads // self.tp_size) * self.head_dim) hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
self.head_dim))
# Merge prefill and decode outputs before passing to gated MLP
hidden_states = torch.vstack(ssd_output_list)
# # 4. gated MLP # 4. gated MLP
hidden_states = self.norm(hidden_states, gate) hidden_states = self.norm(hidden_states, gate)
# # 5. Final linear projection # 5. Final linear projection
out, _ = self.out_proj(hidden_states) out, _ = self.out_proj(hidden_states)
return out return out
...@@ -40,7 +40,6 @@ def _mamba_chunk_scan_combined_fwd(x, ...@@ -40,7 +40,6 @@ def _mamba_chunk_scan_combined_fwd(x,
_, _, ngroups, dstate = B.shape _, _, ngroups, dstate = B.shape
assert nheads % ngroups == 0 assert nheads % ngroups == 0
assert B.shape == (batch, seqlen, ngroups, dstate) assert B.shape == (batch, seqlen, ngroups, dstate)
assert x.shape == (batch, seqlen, nheads, headdim)
assert dt.shape == (batch, seqlen, nheads) assert dt.shape == (batch, seqlen, nheads)
assert A.shape == (nheads, ) assert A.shape == (nheads, )
assert C.shape == B.shape assert C.shape == B.shape
......
...@@ -313,7 +313,6 @@ class BambaModel(nn.Module): ...@@ -313,7 +313,6 @@ class BambaModel(nn.Module):
mamba2_metadata = prepare_mamba2_metadata( mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size, chunk_size=self.config.mamba_chunk_size,
input_ids=input_ids,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
) )
......
...@@ -338,7 +338,6 @@ class GraniteMoeHybridModel(nn.Module): ...@@ -338,7 +338,6 @@ class GraniteMoeHybridModel(nn.Module):
attn_metadata = get_forward_context().attn_metadata attn_metadata = get_forward_context().attn_metadata
mamba2_metadata = prepare_mamba2_metadata( mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size, chunk_size=self.config.mamba_chunk_size,
input_ids=input_ids,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
) )
......
...@@ -142,7 +142,6 @@ class Mamba2Model(nn.Module): ...@@ -142,7 +142,6 @@ class Mamba2Model(nn.Module):
mamba2_metadata = prepare_mamba2_metadata( mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.chunk_size, chunk_size=self.config.chunk_size,
input_ids=input_ids,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
) )
......
...@@ -751,7 +751,6 @@ class Zamba2Model(nn.Module): ...@@ -751,7 +751,6 @@ class Zamba2Model(nn.Module):
mamba2_metadata = prepare_mamba2_metadata( mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.chunk_size, chunk_size=self.config.chunk_size,
input_ids=input_ids,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment