Unverified Commit aff40457 authored by Yu Chin Fabian Lim's avatar Yu Chin Fabian Lim Committed by GitHub
Browse files

Add Bamba Model (#10909)


Signed-off-by: default avatarYu Chin Fabian Lim <flim@sg.ibm.com>
Signed-off-by: default avatarTyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: default avatarTyler Michael Smith <tyler@neuralmagic.com>
parent 467a96a5
# SPDX-License-Identifier: Apache-2.0
import unittest
from typing import Tuple
import pytest
import torch
from tests.utils import multi_gpu_test
from vllm.distributed.parallel_state import (init_distributed_environment,
initialize_model_parallel)
from vllm.model_executor.layers.mamba.mamba_mixer2 import Mixer2RMSNormGated
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [128])
@pytest.mark.parametrize(
"hidden_size_n_groups",
[
(64, 1),
(64, 2),
(64, 4), # hidden_size be divisible by num_gpus
(100, 5), # and n_groups must divide hidden_size
])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_mixer2_gated_norm_multi_gpu(
batch_size: int,
seq_len: int,
hidden_size_n_groups: Tuple[int, int],
dtype: torch.dtype,
device: str = 'cuda',
):
hidden_size, n_groups = hidden_size_n_groups
num_processes = 2
def run_torch_spawn(fn, nprocs):
# need to use torch.mp.spawn otherwise will have problems with
# torch.distributed and cuda
torch.multiprocessing.spawn(fn,
args=(
num_processes,
batch_size,
seq_len,
hidden_size,
n_groups,
dtype,
device,
),
nprocs=nprocs)
run_torch_spawn(mixer2_gated_norm_tensor_parallel, 2)
def mixer2_gated_norm_tensor_parallel(
local_rank: int,
world_size: int,
batch_size: int,
seq_len: int,
hidden_size: int,
n_groups: int,
dtype: torch.dtype,
device: str,
):
current_platform.seed_everything(0)
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
torch.set_default_device(device)
torch.set_default_dtype(dtype)
update_environment_variables({
'RANK': str(local_rank),
'LOCAL_RANK': str(local_rank),
'WORLD_SIZE': str(world_size),
'MASTER_ADDR': 'localhost',
'MASTER_PORT': '12345',
})
# initialize distributed
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
# create random weights an inputs
weight = torch.rand((hidden_size, ), dtype=dtype, device=device)
hidden_states = torch.randn(batch_size, seq_len, hidden_size)
gate_states = torch.randn(batch_size, seq_len, hidden_size)
# create gated-norm with TP
mixer = Mixer2RMSNormGated(
full_hidden_size=hidden_size,
full_n_groups=n_groups,
)
mixer.weight.weight_loader(mixer.weight, weight) # load
# create gated-norm without TP to compute reference
# - utilize mock patching to disable TP when
with (unittest.mock.patch(
"vllm.model_executor.layers.mamba.mamba_mixer2."
"get_tensor_model_parallel_world_size",
return_value=1),
unittest.mock.patch(
"vllm.model_executor.layers.mamba.mamba_mixer2."
"get_tensor_model_parallel_rank",
return_value=0)):
mixer_single_gpu = Mixer2RMSNormGated(
full_hidden_size=hidden_size,
full_n_groups=n_groups,
)
# assign weight to single-gpu mixer
mixer_single_gpu.weight.data = weight
# generate and compare
N = hidden_size // world_size
output = mixer(
hidden_states[..., local_rank * N:(local_rank + 1) * N],
gate_states[..., local_rank * N:(local_rank + 1) * N],
)
ref_output = mixer_single_gpu(hidden_states, gate_states)
torch.allclose(output,
ref_output[..., local_rank * N:(local_rank + 1) * N],
atol=1e-3,
rtol=1e-3)
# SPDX-License-Identifier: Apache-2.0
from typing import Dict, Tuple
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined)
from vllm.platforms import current_platform
# Added by the IBM Team, 2024
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/modules/ssd_minimal.py
# this is the segsum implementation taken from above
def segsum(x):
"""Calculates segment sum."""
T = x.size(-1)
x = repeat(x, "... d -> ... d e", e=T)
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool),
diagonal=-1)
x = x.masked_fill(~mask, 0)
x_segsum = torch.cumsum(x, dim=-2)
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool),
diagonal=0)
x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
return x_segsum
def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
"""
Arguments:
X: (batch, length, n_heads, d_head)
A: (batch, length, n_heads)
B: (batch, length, n_heads, d_state)
C: (batch, length, n_heads, d_state)
Return:
Y: (batch, length, n_heads, d_head)
"""
assert X.dtype == A.dtype == B.dtype == C.dtype
assert X.shape[1] % block_len == 0
# Rearrange into blocks/chunks
X, A, B, C = (rearrange(x, "b (c l) ... -> b c l ...", l=block_len)
for x in (X, A, B, C))
A = rearrange(A, "b c l h -> b h c l")
A_cumsum = torch.cumsum(A, dim=-1)
# 1. Compute the output for each intra-chunk (diagonal blocks)
L = torch.exp(segsum(A))
Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)
# 2. Compute the state for each intra-chunk
# (right term of low-rank factorization of off-diagonal blocks; B terms)
decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at
# chunk boundaries
# (middle term of factorization of off-diag blocks; A terms)
if initial_states is None:
initial_states = torch.zeros_like(states[:, :1])
states = torch.cat([initial_states, states], dim=1)
decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
states, final_state = new_states[:, :-1], new_states[:, -1]
# 4. Compute state -> output conversion per chunk
# (left term of low-rank factorization of off-diagonal blocks; C terms)
state_decay_out = torch.exp(A_cumsum)
Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out)
# Add output of intra-chunk and inter-chunk terms
# (diagonal and off-diagonal blocks)
Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
return Y, final_state
def generate_random_inputs(batch_size,
seqlen,
n_heads,
d_head,
itype,
device='cuda'):
current_platform.seed_everything(0)
A = (-torch.exp(torch.rand(n_heads, dtype=itype, device=device)))
dt = F.softplus(
torch.randn(batch_size, seqlen, n_heads, dtype=itype, device=device) -
4)
X = torch.randn((batch_size, seqlen, n_heads, d_head),
dtype=itype,
device=device)
B = torch.randn((batch_size, seqlen, n_heads, d_head),
dtype=itype,
device=device)
C = torch.randn((batch_size, seqlen, n_heads, d_head),
dtype=itype,
device=device)
return A, dt, X, B, C
def generate_continous_batched_examples(example_lens_by_batch,
num_examples,
full_length,
last_taken,
exhausted,
n_heads,
d_head,
itype,
device='cuda'):
# this function generates a random examples of certain length
# and then cut according to "example_lens_by_batch" and feed
# them in continuous batches to the kernels
# generate the full-length example
A, dt, X, B, C = generate_random_inputs(num_examples, full_length, n_heads,
d_head, itype)
Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1),
A * dt,
B,
C,
block_len=full_length // 4)
# internal function that outputs a cont batch of examples
# given a tuple of lengths for each example in the batch
# e.g., example_lens=(8, 4) means take 8 samples from first eg,
# 4 examples from second eg, etc
def get_continuous_batch(example_lens: Tuple[int, ...]):
indices = []
for i, x in enumerate(example_lens):
c = last_taken.get(i, 0)
indices.append((c, c + x))
last_taken[i] = (c + x) % full_length
exhausted[i] = last_taken[i] == 0
return (torch.concat([x[i, s:e] for i, (s, e) in enumerate(indices)
]).unsqueeze(0) for x in (dt, X, B, C))
# internal function that maps "n" to the appropriate right boundary
# value when forming continuous batches from examples of length given
# by "full_length".
# - e.g., when n > full_length, returns n % full_length
# when n == full_length, returns full_length
def end_boundary(n: int):
return n - ((n - 1) // full_length) * full_length
IND_E = None
for spec in example_lens_by_batch:
# get the (maybe partial) example seen in this cont batch
dt2, X2, B2, C2 = get_continuous_batch(spec)
# get the metadata
cu_seqlens = torch.tensor((0, ) + spec, device=device).cumsum(dim=0)
sed_idx = torch.zeros(cu_seqlens[-1],
dtype=torch.int32,
device=cu_seqlens.device)
for i, (srt, end) in enumerate(zip(
cu_seqlens,
cu_seqlens[1:],
)):
sed_idx[srt:end] = i
# for cont batch
if IND_E is None:
IND_S = [0 for _ in range(len(spec))]
else:
IND_S = [x % full_length for x in IND_E]
IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)]
yield ([Y_min[s, IND_S[s]:IND_E[s]] for s in range(num_examples)],
cu_seqlens, sed_idx.unsqueeze(0), (A, dt2, X2, B2, C2))
@pytest.mark.parametrize("itype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32])
@pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128])
@pytest.mark.parametrize("seq_len_chunk_size", [(119, 17), (128, 32)])
def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
itype):
# this tests the kernels on a single example (no batching)
# set seed
batch_size = 1 # batch_size
# ssd_minimal_discrete requires chunk_size divide seqlen
# - this is only required for generating the reference seqs,
# it is not an operational limitation.
seqlen, chunk_size = seq_len_chunk_size
A, dt, X, B, C = generate_random_inputs(batch_size, seqlen, n_heads,
d_head, itype)
Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt,
B, C, chunk_size)
Y, final_state = mamba_chunk_scan_combined(X,
dt,
A,
B,
C,
chunk_size,
D=None,
return_final_states=True)
# just test the last in sequence
torch.allclose(Y[:, -1], Y_min[:, -1], atol=1e-3, rtol=1e-3)
# just test the last head
# NOTE, in the kernel we always cast states to fp32
torch.allclose(final_state[:, -1],
final_state_min[:, -1].to(torch.float32),
atol=1e-3,
rtol=1e-3)
@pytest.mark.parametrize("itype", [torch.float32, torch.float16])
@pytest.mark.parametrize("n_heads", [4, 8, 13])
@pytest.mark.parametrize("d_head", [5, 16, 21, 32])
@pytest.mark.parametrize(
"seq_len_chunk_size_cases",
[
# small-ish chunk_size (8)
(64, 8, 2, [(64, 32), (64, 32)]),
(64, 8, 2, [(32, 32), (32, 32), (32, 32)]),
(64, 8, 2, [(8, 8), (8, 8), (8, 8)]), # chunk size boundary
(64, 8, 2, [(4, 4), (4, 4), (4, 4),
(4, 4)]), # chunk_size larger than cont batches
(64, 8, 5, [
(64, 32, 16, 8, 8),
(8, 16, 32, 16, 8),
(8, 8, 16, 32, 16),
]), # mode examples with varied lengths
# odd chunk_size
(64, 29, 2, [(11, 4), (13, 23), (19, 22),
(21, 15)]), # irregular sizes
# large-ish chunk_size (256)
(64, 256, 1, [(5, ), (1, ), (1, ),
(1, )]), # irregular sizes with small sequences
(64, 256, 2, [(5, 30), (1, 2), (1, 2),
(1, 2)]), # irregular sizes with small sequences
])
def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
itype):
# this test with multiple examples in a continuous batch
# (i.e. chunked prefill)
seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases
# hold state during the cutting process so we know if an
# example has been exhausted and needs to cycle
last_taken: Dict = {} # map: eg -> pointer to last taken sample
exhausted: Dict = {} # map: eg -> boolean indicating example is exhausted
states = None
for Y_min, cu_seqlens, sed_idx, (A, dt, X, B,
C) in generate_continous_batched_examples(
cases, num_examples, seqlen,
last_taken, exhausted, n_heads,
d_head, itype):
Y, new_states = mamba_chunk_scan_combined(
X,
dt,
A,
B,
C,
chunk_size,
D=None,
cu_seqlens=cu_seqlens,
seq_idx=sed_idx,
return_varlen_states=True,
initial_states=states,
)
# just test the last in sequence
for i in range(num_examples):
# just test one dim and dstate
Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
Y_min_eg = Y_min[i][:, 0, 0]
torch.allclose(Y_eg, Y_min_eg, atol=1e-3, rtol=1e-3)
# update states
states = new_states
for i, clear in exhausted.items():
if clear:
states[i].fill_(0.)
exhausted[i] = False
...@@ -8,7 +8,8 @@ from vllm.sampling_params import SamplingParams ...@@ -8,7 +8,8 @@ from vllm.sampling_params import SamplingParams
from ...utils import check_outputs_equal from ...utils import check_outputs_equal
MODELS = ["ai21labs/Jamba-tiny-dev"] # This test is for the hybrid models
MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-ai-platform/Bamba-9B"]
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
...@@ -23,6 +24,10 @@ def test_models( ...@@ -23,6 +24,10 @@ def test_models(
max_tokens: int, max_tokens: int,
) -> None: ) -> None:
# numeric error produces different generation
if 'Bamba' in model:
example_prompts.pop(3)
with hf_runner( with hf_runner(
model, model,
dtype=dtype, dtype=dtype,
...@@ -108,15 +113,21 @@ def test_mamba_prefill_chunking_with_parallel_sampling( ...@@ -108,15 +113,21 @@ def test_mamba_prefill_chunking_with_parallel_sampling(
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [10]) @pytest.mark.parametrize("max_tokens", [7])
def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts, def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts,
model: str, dtype: str, model: str, dtype: str,
max_tokens: int) -> None: max_tokens: int) -> None:
# numeric error during prefill chucking produces different generation # numeric error during prefill chucking produces different generation
# compared to w/o prefill chunking for those examples, removed them for now # compared to w/o prefill chunking for those examples, removed them for now
if 'Jamba' in model:
example_prompts.pop(7) example_prompts.pop(7)
example_prompts.pop(2) example_prompts.pop(2)
example_prompts.pop(1) example_prompts.pop(1)
elif 'Bamba' in model:
example_prompts.pop(6)
example_prompts.pop(3)
example_prompts.pop(2)
dtype = "half" # use a different dtype for Bamba
with hf_runner( with hf_runner(
model, model,
...@@ -145,7 +156,7 @@ def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts, ...@@ -145,7 +156,7 @@ def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts,
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [15]) @pytest.mark.parametrize("max_tokens", [15])
def test_parallel_sampling( def test_parallel_sampling(
vllm_runner, vllm_runner,
...@@ -249,17 +260,17 @@ def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks( ...@@ -249,17 +260,17 @@ def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks(
dtype: str, dtype: str,
example_prompts, example_prompts,
) -> None: ) -> None:
# This test is for verifying that the Jamba inner state management doesn't # This test is for verifying that the hybrid inner state management doesn't
# collapse in case where the number of incoming requests and # collapse in case where the number of incoming requests and
# finished_requests_ids is larger than the maximum mamba block capacity. # finished_requests_ids is larger than the maximum mamba block capacity.
# This could generally happen due to the fact that Jamba does support # This could generally happen due to the fact that hybrid does support
# statelessness mechanism where it can cleanup new incoming requests in # statelessness mechanism where it can cleanup new incoming requests in
# a single step. # a single step.
try: try:
with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model: with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model:
vllm_model.generate_greedy([example_prompts[0]] * 100, 10) vllm_model.generate_greedy([example_prompts[0]] * 100, 10)
except ValueError: except ValueError:
pytest.fail("Jamba inner state wasn't cleaned up properly between" pytest.fail("Hybrid inner state wasn't cleaned up properly between"
"steps finished requests registered unnecessarily ") "steps finished requests registered unnecessarily ")
...@@ -271,14 +282,14 @@ def test_state_cleanup( ...@@ -271,14 +282,14 @@ def test_state_cleanup(
dtype: str, dtype: str,
example_prompts, example_prompts,
) -> None: ) -> None:
# This test is for verifying that the Jamba state is cleaned up between # This test is for verifying that the Hybrid state is cleaned up between
# steps, If its not cleaned, an error would be expected. # steps, If its not cleaned, an error would be expected.
try: try:
with vllm_runner(model, dtype=dtype) as vllm_model: with vllm_runner(model, dtype=dtype) as vllm_model:
for _ in range(10): for _ in range(10):
vllm_model.generate_greedy([example_prompts[0]] * 100, 1) vllm_model.generate_greedy([example_prompts[0]] * 100, 1)
except ValueError: except ValueError:
pytest.fail("Jamba inner state wasn't cleaned up between states, " pytest.fail("Hybrid inner state wasn't cleaned up between states, "
"could be related to finished_requests_ids") "could be related to finished_requests_ids")
...@@ -324,7 +335,7 @@ def test_multistep_correctness(vllm_runner, model: str, dtype: str, ...@@ -324,7 +335,7 @@ def test_multistep_correctness(vllm_runner, model: str, dtype: str,
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("max_tokens", [64])
def test_jamba_distributed_produces_identical_generation( def test_hybrid_distributed_produces_identical_generation(
vllm_runner, model: str, dtype: str, max_tokens: int, vllm_runner, model: str, dtype: str, max_tokens: int,
example_prompts) -> None: example_prompts) -> None:
......
...@@ -102,6 +102,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -102,6 +102,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True), trust_remote_code=True),
"BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat", "BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat",
trust_remote_code=True), trust_remote_code=True),
"BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B"),
"BloomForCausalLM": _HfExamplesInfo("bigscience/bloomz-1b1"), "BloomForCausalLM": _HfExamplesInfo("bigscience/bloomz-1b1"),
# ChatGLMModel supports multimodal # ChatGLMModel supports multimodal
"CohereForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r-v01", "CohereForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r-v01",
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from itertools import accumulate
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type
import torch import torch
...@@ -15,6 +16,7 @@ from vllm.multimodal import MultiModalPlaceholderMap ...@@ -15,6 +16,7 @@ from vllm.multimodal import MultiModalPlaceholderMap
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder, from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata) ModelInputForGPUWithSamplingMetadata)
from vllm.utils import async_tensor_h2d
# Placeholder attention backend for models like Mamba and pooling models that # Placeholder attention backend for models like Mamba and pooling models that
# lack attention. # lack attention.
...@@ -77,43 +79,39 @@ class PlaceholderAttentionMetadata(AttentionMetadata): ...@@ -77,43 +79,39 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
# seq_lens stored as a tensor. # seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor] seq_lens_tensor: Optional[torch.Tensor]
# Maximum query length in the batch.
max_query_len: Optional[int]
# Max number of query tokens among request in the batch.
max_decode_query_len: Optional[int]
# Maximum sequence length among prefill batch. 0 if there are decoding # Maximum sequence length among prefill batch. 0 if there are decoding
# requests only. # requests only.
max_prefill_seq_len: int max_prefill_seq_len: int
# Maximum sequence length among decode batch. 0 if there are prefill # Maximum sequence length among decode batch. 0 if there are prefill
# requests only. # requests only.
max_decode_seq_len: int max_decode_seq_len: int
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc: Optional[torch.Tensor]
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor]
# (batch_size,) A tensor of context lengths (tokens that are computed # (batch_size,) A tensor of context lengths (tokens that are computed
# so far). # so far).
context_lens_tensor: Optional[torch.Tensor] context_lens_tensor: Optional[torch.Tensor]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables: Optional[torch.Tensor]
# Whether or not if cuda graph is enabled. # Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only. # Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool use_cuda_graph: bool
# Maximum query length in the batch.
max_query_len: Optional[int]
# Max number of query tokens among request in the batch.
max_decode_query_len: Optional[int]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc: Optional[torch.Tensor] = None
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor] = None
# Placeholder.
block_tables: Optional[torch.Tensor] = None
_cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None
_cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None
...@@ -125,11 +123,17 @@ class PlaceholderAttentionMetadata(AttentionMetadata): ...@@ -125,11 +123,17 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
if self._cached_prefill_metadata is not None: if self._cached_prefill_metadata is not None:
return self._cached_prefill_metadata return self._cached_prefill_metadata
assert self.seq_lens is not None # Compute some attn_metadata fields which default to None
assert self.seq_lens_tensor is not None query_start_loc = (None if self.query_start_loc is None else
assert self.query_start_loc is not None self.query_start_loc[:self.num_prefills + 1])
assert self.context_lens_tensor is not None seq_lens = (None if self.seq_lens is None else
assert self.seq_start_loc is not None self.seq_lens[:self.num_prefills])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[:self.num_prefills])
seq_start_loc = (None if self.seq_start_loc is None else
self.seq_start_loc[:self.num_prefills + 1])
context_lens_tensor = (None if self.context_lens_tensor is None else
self.context_lens_tensor[:self.num_prefills])
# Placeholders # Placeholders
slot_mapping = torch.empty(0) slot_mapping = torch.empty(0)
...@@ -143,15 +147,15 @@ class PlaceholderAttentionMetadata(AttentionMetadata): ...@@ -143,15 +147,15 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
multi_modal_placeholder_index_maps=self. multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps, multi_modal_placeholder_index_maps,
enable_kv_scales_calculation=self.enable_kv_scales_calculation, enable_kv_scales_calculation=self.enable_kv_scales_calculation,
seq_lens=self.seq_lens[:self.num_prefills], seq_lens=seq_lens,
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], seq_lens_tensor=seq_lens_tensor,
max_decode_query_len=0, max_decode_query_len=0,
max_query_len=self.max_query_len, max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len, max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0, max_decode_seq_len=0,
query_start_loc=self.query_start_loc[:self.num_prefills + 1], query_start_loc=query_start_loc,
seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], seq_start_loc=seq_start_loc,
context_lens_tensor=self.context_lens_tensor[:self.num_prefills], context_lens_tensor=context_lens_tensor,
block_tables=block_tables, block_tables=block_tables,
use_cuda_graph=False, use_cuda_graph=False,
) )
...@@ -169,6 +173,8 @@ class PlaceholderAttentionMetadata(AttentionMetadata): ...@@ -169,6 +173,8 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
# Placeholders # Placeholders
slot_mapping = torch.empty(0) slot_mapping = torch.empty(0)
block_tables = torch.empty(0) block_tables = torch.empty(0)
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[self.num_prefills:])
self._cached_decode_metadata = PlaceholderAttentionMetadata( self._cached_decode_metadata = PlaceholderAttentionMetadata(
num_prefills=0, num_prefills=0,
...@@ -178,13 +184,16 @@ class PlaceholderAttentionMetadata(AttentionMetadata): ...@@ -178,13 +184,16 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
multi_modal_placeholder_index_maps=None, multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True, enable_kv_scales_calculation=True,
seq_lens=None, seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], seq_lens_tensor=seq_lens_tensor,
max_decode_query_len=self.max_decode_query_len, max_decode_query_len=self.max_decode_query_len,
max_query_len=None, max_query_len=None,
max_prefill_seq_len=0, max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len, max_decode_seq_len=self.max_decode_seq_len,
query_start_loc=None, query_start_loc=(self.query_start_loc[self.num_prefills:] -
seq_start_loc=None, self.query_start_loc[self.num_prefills])
if self.query_start_loc is not None else None,
seq_start_loc=self.seq_start_loc[self.num_prefills:]
if self.seq_start_loc is not None else None,
context_lens_tensor=None, context_lens_tensor=None,
block_tables=block_tables, block_tables=block_tables,
use_cuda_graph=self.use_cuda_graph, use_cuda_graph=self.use_cuda_graph,
...@@ -235,8 +244,6 @@ class PlaceholderAttentionMetadata(AttentionMetadata): ...@@ -235,8 +244,6 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
assert self.context_lens_tensor is not None assert self.context_lens_tensor is not None
assert self.context_lens_tensor.shape == (num_queries, ) assert self.context_lens_tensor.shape == (num_queries, )
assert self.block_tables is not None
# Update query lengths. Note that we update only queries and not seqs, # Update query lengths. Note that we update only queries and not seqs,
# since tensors may be padded due to captured cuda graph batch size # since tensors may be padded due to captured cuda graph batch size
for i in range(num_queries): for i in range(num_queries):
...@@ -299,9 +306,6 @@ class PlaceholderAttentionMetadataBuilder( ...@@ -299,9 +306,6 @@ class PlaceholderAttentionMetadataBuilder(
self.num_prefill_tokens += token_len self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len) self.prefill_seq_lens.append(seq_len)
else: else:
assert query_len == 1, (
"seq_len: {}, context_len: {}, query_len: {}".format(
seq_len, context_len, query_len))
self.num_decode_tokens += query_len self.num_decode_tokens += query_len
self.curr_seq_lens.append(curr_seq_len) self.curr_seq_lens.append(curr_seq_len)
...@@ -323,15 +327,6 @@ class PlaceholderAttentionMetadataBuilder( ...@@ -323,15 +327,6 @@ class PlaceholderAttentionMetadataBuilder(
device = self.runner.device device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1 use_captured_graph = cuda_graph_pad_size != -1
logits_soft_cap = getattr(self.runner.model_config.hf_config,
"attn_logit_softcapping", None)
if logits_soft_cap is not None:
raise ValueError(
"Please use Flashinfer backend for models with logits_soft_cap"
" (i.e., Gemma-2). Otherwise, the output might be wrong."
" Set Flashinfer backend by "
"export VLLM_ATTENTION_BACKEND=FLASHINFER.")
max_query_len = max(query_lens) max_query_len = max(query_lens)
decode_query_lens = query_lens[self.num_prefills:] decode_query_lens = query_lens[self.num_prefills:]
if len(decode_query_lens) > 0: if len(decode_query_lens) > 0:
...@@ -341,48 +336,37 @@ class PlaceholderAttentionMetadataBuilder( ...@@ -341,48 +336,37 @@ class PlaceholderAttentionMetadataBuilder(
max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0) max_decode_seq_len = max(self.curr_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens num_decode_tokens = self.num_decode_tokens
query_start_loc = list(accumulate(query_lens, initial=0))
seq_start_loc = list(accumulate(seq_lens, initial=0))
if use_captured_graph: if use_captured_graph:
num_decode_tokens = batch_size num_decode_tokens = batch_size - self.num_prefill_tokens
assert max_query_len > 0, ("query_lens: {}".format(query_lens)) assert max_query_len > 0, ("query_lens: {}".format(query_lens))
context_lens_tensor = torch.tensor(self.context_lens, assert device is not None
dtype=torch.int, context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
device=device) device, self.runner.pin_memory)
seq_lens_tensor = torch.tensor(seq_lens, seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
dtype=torch.int, self.runner.pin_memory)
device=device) query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
query_lens_tensor = torch.tensor(query_lens, device,
dtype=torch.long, self.runner.pin_memory)
device=device) seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, device, self.runner.pin_memory)
dtype=torch.int32,
device=device)
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=device)
placeholder_index_maps = { placeholder_index_maps = {
modality: placeholder_map.index_map() modality: placeholder_map.index_map()
for modality, placeholder_map in for modality, placeholder_map in
self.multimodal_placeholder_maps.items() self.multimodal_placeholder_maps.items()
} }
torch.cumsum(seq_lens_tensor,
dim=0,
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])
torch.cumsum(query_lens_tensor,
dim=0,
dtype=query_start_loc.dtype,
out=query_start_loc[1:])
# Placeholders # Placeholders
slot_mapping = torch.empty(0) slot_mapping_tensor = torch.empty(0)
block_tables = torch.empty(0) block_tables = torch.empty(0)
return PlaceholderAttentionMetadata( return PlaceholderAttentionMetadata(
num_prefills=self.num_prefills, num_prefills=self.num_prefills,
slot_mapping=slot_mapping, slot_mapping=slot_mapping_tensor,
multi_modal_placeholder_index_maps=placeholder_index_maps, multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=True, enable_kv_scales_calculation=True,
num_prefill_tokens=self.num_prefill_tokens, num_prefill_tokens=self.num_prefill_tokens,
...@@ -393,8 +377,8 @@ class PlaceholderAttentionMetadataBuilder( ...@@ -393,8 +377,8 @@ class PlaceholderAttentionMetadataBuilder(
max_decode_query_len=max_decode_query_len, max_decode_query_len=max_decode_query_len,
max_prefill_seq_len=max_prefill_seq_len, max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len, max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc, query_start_loc=query_start_loc_tensor,
seq_start_loc=seq_start_loc, seq_start_loc=seq_start_loc_tensor,
context_lens_tensor=context_lens_tensor, context_lens_tensor=context_lens_tensor,
block_tables=block_tables, block_tables=block_tables,
use_cuda_graph=use_captured_graph, use_cuda_graph=use_captured_graph,
......
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, Tri Dao, Albert Gu. # Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py # Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py
import torch import torch
import triton import triton
......
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_bmm.py
# ruff: noqa: E501,SIM102
import math
import torch
import triton
import triton.language as tl
@triton.autotune(
configs=[
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 64
},
num_stages=3,
num_warps=8),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 32
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 32
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 32
},
num_stages=5,
num_warps=2),
triton.Config(
{
'BLOCK_SIZE_M': 32,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32
},
num_stages=5,
num_warps=2),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32
},
num_stages=4,
num_warps=2),
],
key=['chunk_size', 'K', 'IS_CAUSAL'],
)
@triton.jit
def _bmm_chunk_fwd_kernel(
# Pointers to matrices
a_ptr,
b_ptr,
out_ptr,
seq_idx_ptr,
# Matrix dimensions
seqlen,
chunk_size,
K,
ngroups,
stride_a_batch,
stride_a_seqlen,
stride_a_head,
stride_ak,
stride_b_batch,
stride_b_seqlen,
stride_b_head,
stride_bk,
stride_out_batch,
stride_out_chunk,
stride_out_head,
stride_outm,
stride_outn,
stride_seq_idx_batch,
stride_seq_idx_seqlen,
# Meta-parameters
IS_CAUSAL: tl.constexpr,
dot_dtype: tl.constexpr,
HAS_SEQ_IDX: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
):
pid_b = tl.program_id(axis=1)
pid_ch = tl.program_id(axis=2).to(tl.int64)
pid_c = pid_ch // ngroups
pid_h = pid_ch - pid_c * ngroups
num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
pid_m = tl.program_id(axis=0) // num_pid_n
pid_n = tl.program_id(axis=0) % num_pid_n
if IS_CAUSAL:
if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
return
a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head
if HAS_SEQ_IDX:
seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen +
offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk +
offs_n[None, :] * stride_b_seqlen)
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs,
mask=(offs_m[:, None] < chunk_size_limit) &
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0).to(dot_dtype)
b = tl.load(b_ptrs,
mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) &
(offs_n[None, :] < chunk_size_limit),
other=0.0).to(dot_dtype)
acc += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
if HAS_SEQ_IDX:
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
mask=offs_m < chunk_size_limit,
other=-1)
seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen,
mask=offs_n < chunk_size_limit,
other=-2)
acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
out = acc.to(out_ptr.dtype.element_ty)
out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head
out_ptrs = out_ptr + (stride_outm * offs_m[:, None] +
offs_n[None, :] * stride_outn)
tl.store(out_ptrs,
out,
mask=(offs_m[:, None] < chunk_size) &
(offs_n[None, :] < chunk_size))
def _bmm_chunk_fwd(a,
b,
chunk_size,
seq_idx=None,
causal=False,
output_dtype=None):
"""
Argument:
a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
b: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out.
causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are
guaranteed to be correct.
Return:
out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size)
"""
# Check constraints.
has_groups = a.dim() == 4
if not has_groups:
batch, seqlen, k = a.shape
else:
batch, seqlen, ngroups, k = a.shape
assert b.shape == a.shape
if seq_idx is not None:
assert seq_idx.shape == (batch, seqlen)
if a.stride(-1) != 1 and a.stride(1) != 1:
a = a.contiguous()
if b.stride(-1) != 1 and b.stride(1) != 1:
b = b.contiguous()
nchunks = math.ceil(seqlen / chunk_size)
# Allocates output.
out_dtype = a.dtype if output_dtype is None else output_dtype
out = torch.empty(
(batch, nchunks, chunk_size, chunk_size) if not has_groups else
(batch, nchunks, ngroups, chunk_size, chunk_size),
device=a.device,
dtype=out_dtype)
dot_dtype = (tl.bfloat16
if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else
(tl.float16 if a.dtype == torch.float16
or b.dtype == torch.float16 else tl.float32))
grid = lambda META: (triton.cdiv(
chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(
chunk_size, META['BLOCK_SIZE_N']), batch, nchunks
if not has_groups else nchunks * ngroups)
with torch.cuda.device(a.device.index):
_bmm_chunk_fwd_kernel[grid](
a,
b,
out,
seq_idx,
seqlen,
chunk_size,
k,
ngroups if has_groups else 1,
a.stride(0),
a.stride(1),
0 if not has_groups else a.stride(2),
a.stride(-1),
b.stride(0),
b.stride(1),
0 if not has_groups else b.stride(2),
b.stride(-1),
out.stride(0),
out.stride(1),
0 if not has_groups else out.stride(2),
out.stride(-2),
out.stride(-1),
*((seq_idx.stride(0),
seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
causal,
dot_dtype,
HAS_SEQ_IDX=seq_idx is not None,
)
return out
This diff is collapsed.
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_combined.py
# ruff: noqa: E501
import torch
import triton
from einops import rearrange
from packaging import version
from .ssd_bmm import _bmm_chunk_fwd
from .ssd_chunk_scan import _chunk_scan_fwd
from .ssd_chunk_state import (_chunk_cumsum_fwd, _chunk_state_fwd,
chunk_state_varlen)
from .ssd_state_passing import _state_passing_fwd
TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
def _mamba_chunk_scan_combined_fwd(x,
dt,
A,
B,
C,
chunk_size,
D=None,
z=None,
dt_bias=None,
initial_states=None,
seq_idx=None,
cu_seqlens=None,
dt_softplus=False,
dt_limit=(0.0, float("inf"))):
batch, seqlen, nheads, headdim = x.shape
_, _, ngroups, dstate = B.shape
assert nheads % ngroups == 0
assert B.shape == (batch, seqlen, ngroups, dstate)
assert x.shape == (batch, seqlen, nheads, headdim)
assert dt.shape == (batch, seqlen, nheads)
assert A.shape == (nheads, )
assert C.shape == B.shape
if z is not None:
assert z.shape == x.shape
if D is not None:
assert D.shape == (nheads, headdim) or D.shape == (nheads, )
if seq_idx is not None:
assert seq_idx.shape == (batch, seqlen)
if B.stride(-1) != 1:
B = B.contiguous()
if C.stride(-1) != 1:
C = C.contiguous()
if x.stride(-1) != 1 and x.stride(
1) != 1: # Either M or K dimension should be contiguous
x = x.contiguous()
if z is not None and z.stride(-1) != 1 and z.stride(
1) != 1: # Either M or K dimension should be contiguous
z = z.contiguous()
if D is not None and D.stride(-1) != 1:
D = D.contiguous()
if initial_states is not None:
if cu_seqlens is None:
assert initial_states.shape == (batch, nheads, headdim, dstate)
else:
assert initial_states.shape == (len(cu_seqlens) - 1, nheads,
headdim, dstate)
# This function executes 5 sub-functions for computing mamba
# - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/
# which has a minimal implementation to understand the below operations
# - as explained by the blog, mamba is a special case of causal attention
# - the idea is to chunk the attention matrix and compute each
# submatrix separately using different optimizations.
# - see the blog and paper for a visualization of the submatrices
# which we refer to in the comments below
# 1. Compute chunked cumsum of A * dt
# - here dt may go through a softplus activation
dA_cumsum, dt = _chunk_cumsum_fwd(dt,
A,
chunk_size,
dt_bias=dt_bias,
dt_softplus=dt_softplus,
dt_limit=dt_limit)
# 2. Compute the state for each intra-chunk
# (right term of low-rank factorization of off-diagonal blocks; B terms)
states = _chunk_state_fwd(B,
x,
dt,
dA_cumsum,
seq_idx=seq_idx,
states_in_fp32=True)
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
# (middle term of factorization of off-diag blocks; A terms)
# - for handling chunked prefill, this requires i) initial_states
# ii) seq_idx and iii) has_cu_seqlens to be all specified.
# - When a new seq_idx is detected, we will stop passing the prev_state
# and switch accordingly to the init_state corresponding to the new seq_idx.
# - this will ensure that states will be updated with the rightmost flushed seq_idx
# of the previous chunk. This implies that the first chunk of states is either 0
# or equal to init_states of the first example.
states, final_states = _state_passing_fwd(
rearrange(states, "... p n -> ... (p n)"),
dA_cumsum[:, :, :, -1],
initial_states=rearrange(initial_states, "... p n -> ... (p n)")
if initial_states is not None else None,
seq_idx=seq_idx,
chunk_size=chunk_size,
out_dtype=C.dtype,
is_cont_batched=cu_seqlens is not None)
states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate)
for t in [states, final_states])
# 4. Compute batched matrix multiply for C_j^T B_i terms
CB = _bmm_chunk_fwd(C,
B,
chunk_size,
seq_idx=seq_idx,
output_dtype=torch.float32)
# 5. Scan and compute the diagonal blocks, taking into
# account past causal states.
# - if initial states are provided, then states information will be
# augmented with initial_states.
# - to do this properly, we need to account for example changes in
# the continuous batch, therefore we introduce pseudo chunks, which is
# a chunk that is split up each time an example changes.
# - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had
# a seq_idx change, in which case we take states information from
# init_states.
out, out_x = _chunk_scan_fwd(
CB,
x,
dt,
dA_cumsum,
C,
states,
D=D,
z=z,
seq_idx=seq_idx,
initial_states=initial_states,
)
if cu_seqlens is None:
return out, out_x, dt, dA_cumsum, states, final_states
else:
assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1"
varlen_states = chunk_state_varlen(
B.squeeze(0),
x.squeeze(0),
dt.squeeze(0),
dA_cumsum.squeeze(0),
cu_seqlens,
states.squeeze(0),
initial_states=initial_states,
)
return out, out_x, dt, dA_cumsum, states, final_states, varlen_states
def mamba_chunk_scan_combined(x,
dt,
A,
B,
C,
chunk_size,
D=None,
z=None,
dt_bias=None,
initial_states=None,
seq_idx=None,
cu_seqlens=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
return_final_states=False,
return_varlen_states=False):
"""
Argument:
x: (batch, seqlen, nheads, headdim)
dt: (batch, seqlen, nheads)
A: (nheads)
B: (batch, seqlen, ngroups, dstate)
C: (batch, seqlen, ngroups, dstate)
chunk_size: int
D: (nheads, headdim) or (nheads,)
z: (batch, seqlen, nheads, headdim)
dt_bias: (nheads,)
initial_states: (batch, nheads, headdim, dstate)
seq_idx: (batch, seqlen)
cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
dt_softplus: Whether to apply softplus to dt
Return:
out: (batch, seqlen, nheads, headdim)
"""
if not return_varlen_states:
cu_seqlens = None
else:
assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True"
out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(
x,
dt,
A,
B,
C,
chunk_size,
D=D,
z=z,
dt_bias=dt_bias,
initial_states=initial_states,
seq_idx=seq_idx,
cu_seqlens=cu_seqlens,
dt_softplus=dt_softplus,
dt_limit=dt_limit)
if not return_varlen_states:
return out if not return_final_states else (out, final_states)
else:
varlen_states = rest[0]
return (out,
varlen_states) if not return_final_states else (out,
final_states,
varlen_states)
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_state_passing.py
# ruff: noqa: E501
import torch
import triton
import triton.language as tl
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE': 64}),
triton.Config({'BLOCK_SIZE': 128}),
triton.Config({'BLOCK_SIZE': 256}),
triton.Config({'BLOCK_SIZE': 512}),
triton.Config({'BLOCK_SIZE': 1024}),
triton.Config({'BLOCK_SIZE': 2048}),
],
key=['dim'],
)
@triton.jit
def _state_passing_fwd_kernel(
# Pointers to matrices
states_ptr,
out_ptr,
final_states_ptr,
dA_cs_ptr,
initstates_ptr,
seq_idx_ptr,
# Matrix dimensions
dim,
nchunks,
seqlen,
chunk_size,
# Strides
stride_states_batch,
stride_states_chunk,
stride_states_head,
stride_states_dim,
stride_out_batch,
stride_out_chunk,
stride_out_head,
stride_out_dim,
stride_final_states_batch,
stride_final_states_head,
stride_final_states_dim,
stride_dA_cs_batch,
stride_dA_cs_chunk,
stride_dA_cs_head,
stride_initstates_batch,
stride_initstates_head,
stride_initstates_dim,
stride_seq_idx_batch,
stride_seq_idx_seqlen,
# Meta-parameters
HAS_INITSTATES: tl.constexpr,
HAS_SEQ_IDX: tl.constexpr,
IS_CONT_BATCHED: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid_b = tl.program_id(axis=1)
pid_h = tl.program_id(axis=2)
pid_m = tl.program_id(axis=0)
states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head
out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head
if HAS_INITSTATES:
initstates_ptr += pid_h * stride_initstates_head
if not IS_CONT_BATCHED:
initstates_ptr += pid_b * stride_initstates_batch
if HAS_SEQ_IDX:
seq_idx_ptr += pid_b * stride_seq_idx_batch
offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
states_ptrs = states_ptr + offs_m * stride_states_dim
out_ptrs = out_ptr + offs_m * stride_out_dim
final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim
# - states will be the past state of the sequence that continues on the current check
if not HAS_INITSTATES:
states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
else:
initstates_ptr += offs_m * stride_initstates_dim
initstates_ptrs = initstates_ptr
# - for cont batches, for the first chunk mean it will be the first batch's
# init state
states = tl.load(initstates_ptrs, mask=offs_m < dim,
other=0.0).to(tl.float32)
tl.store(out_ptrs, states, mask=offs_m < dim)
out_ptrs += stride_out_chunk
seq_idx = 0
for c in range(nchunks):
new_states = tl.load(states_ptrs, mask=offs_m < dim,
other=0.0).to(tl.float32)
dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
scale = tl.exp(dA_cs)
if HAS_SEQ_IDX:
# - the seq to pass forward is the one that is flushed to the right
# boundary.
# - that is given by seq_idx_new below.
seq_idx_new = tl.load(seq_idx_ptr +
(min((c + 1) * chunk_size, seqlen) - 1) *
stride_seq_idx_seqlen)
if HAS_INITSTATES:
if IS_CONT_BATCHED and seq_idx != seq_idx_new:
# this means in the current chunk the rightmost flushed seq
# has changed.
# - so we do not propagate the state from previous chunk
# - but rather we load that sequence's init state
initstates_ptrs = initstates_ptr + seq_idx_new * stride_initstates_batch
# - update state with seq_idx_new's init state
states = tl.load(initstates_ptrs,
mask=offs_m < dim,
other=0.0).to(tl.float32)
else:
scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)
seq_idx = seq_idx_new
states = scale * states + new_states
if c < nchunks - 1:
tl.store(out_ptrs, states, mask=offs_m < dim)
else:
tl.store(final_states_ptrs, states, mask=offs_m < dim)
states_ptrs += stride_states_chunk
dA_cs_ptr += stride_dA_cs_chunk
out_ptrs += stride_out_chunk
def _state_passing_fwd(
states,
dA_chunk_cumsum,
initial_states=None,
seq_idx=None,
chunk_size=None,
out_dtype=None,
is_cont_batched=False,
):
batch, nchunks, nheads, dim = states.shape
assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
if initial_states is not None:
if is_cont_batched:
# - if cu_seqlens is provided, then the initial states
# are used for continuous batching. In which case we
# require seq_idx to be provided
assert seq_idx is not None, ""
assert initial_states.shape == (seq_idx.max().item() + 1, nheads,
dim)
else:
# - this is the regular batching case, where initial
# states are used are for each example of the batch.
assert initial_states.shape == (batch, nheads, dim)
if seq_idx is not None:
assert chunk_size is not None
seqlen = seq_idx.shape[-1]
assert seq_idx.shape == (batch, seqlen)
out_dtype = states.dtype if out_dtype is None else out_dtype
out = torch.empty((batch, nchunks, nheads, dim),
device=states.device,
dtype=out_dtype)
final_states = torch.empty((batch, nheads, dim),
device=states.device,
dtype=torch.float32)
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)
with torch.cuda.device(states.device.index):
_state_passing_fwd_kernel[grid](
states,
out,
final_states,
dA_chunk_cumsum,
initial_states,
seq_idx,
dim,
nchunks,
seqlen if seq_idx is not None else 0,
chunk_size if seq_idx is not None else 0,
states.stride(0),
states.stride(1),
states.stride(2),
states.stride(3),
out.stride(0),
out.stride(1),
out.stride(2),
out.stride(3),
final_states.stride(0),
final_states.stride(1),
final_states.stride(2),
dA_chunk_cumsum.stride(0),
dA_chunk_cumsum.stride(2),
dA_chunk_cumsum.stride(1),
*((initial_states.stride(0), initial_states.stride(1),
initial_states.stride(2)) if initial_states is not None else
(0, 0, 0)),
*((seq_idx.stride(0),
seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
HAS_INITSTATES=initial_states is not None,
HAS_SEQ_IDX=seq_idx is not None,
IS_CONT_BATCHED=is_cont_batched,
)
return out, final_states
This diff is collapsed.
...@@ -455,14 +455,9 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -455,14 +455,9 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
self.mamba_cache = MambaCacheManager( self.mamba_cache = MambaCacheManager(
self.lm_head.weight.dtype, num_mamba_layers, self.lm_head.weight.dtype, num_mamba_layers,
self.max_batch_size, *self._get_mamba_cache_shape()) self.max_batch_size, *self._get_mamba_cache_shape())
(
mamba_cache_tensors, mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
state_indices_tensor,
) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata,
**kwargs)
mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0],
mamba_cache_tensors[1],
state_indices_tensor)
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, mamba_cache_params, attn_metadata, mamba_cache_params,
intermediate_tensors, inputs_embeds) intermediate_tensors, inputs_embeds)
......
...@@ -232,15 +232,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP): ...@@ -232,15 +232,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
self.lm_head.weight.dtype, num_mamba_layers, self.lm_head.weight.dtype, num_mamba_layers,
self.max_batch_size, *self._get_mamba_cache_shape()) self.max_batch_size, *self._get_mamba_cache_shape())
( mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
mamba_cache_tensors,
state_indices_tensor,
) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata,
**kwargs)
mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0],
mamba_cache_tensors[1],
state_indices_tensor)
hidden_states = self.backbone(input_ids, positions, attn_metadata, hidden_states = self.backbone(input_ids, positions, attn_metadata,
mamba_cache_params, intermediate_tensors, mamba_cache_params, intermediate_tensors,
......
...@@ -5,7 +5,6 @@ from typing import Dict, List ...@@ -5,7 +5,6 @@ from typing import Dict, List
import torch import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.attention.backends.utils import PAD_SLOT_ID
...@@ -42,8 +41,7 @@ class MambaCacheManager: ...@@ -42,8 +41,7 @@ class MambaCacheManager:
self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {} self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
self.free_cache_indices = list(range(max_batch_size)) self.free_cache_indices = list(range(max_batch_size))
def current_run_tensors(self, input_ids: torch.Tensor, def current_run_tensors(self, **kwargs) -> MambaCacheParams:
attn_metadata: AttentionMetadata, **kwargs):
""" """
Return the tensors for the current run's conv and ssm state. Return the tensors for the current run's conv and ssm state.
""" """
...@@ -66,7 +64,8 @@ class MambaCacheManager: ...@@ -66,7 +64,8 @@ class MambaCacheManager:
(mamba_cache_tensors, (mamba_cache_tensors,
state_indices_tensor) = kwargs["seqlen_agnostic_capture_inputs"] state_indices_tensor) = kwargs["seqlen_agnostic_capture_inputs"]
return (mamba_cache_tensors, state_indices_tensor) return MambaCacheParams(mamba_cache_tensors[0], mamba_cache_tensors[1],
state_indices_tensor)
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
""" """
......
...@@ -37,6 +37,7 @@ _TEXT_GENERATION_MODELS = { ...@@ -37,6 +37,7 @@ _TEXT_GENERATION_MODELS = {
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
# baichuan-13b, lower case 'c' in the class name # baichuan-13b, lower case 'c' in the class name
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
"BambaForCausalLM": ("bamba", "BambaForCausalLM"),
"BloomForCausalLM": ("bloom", "BloomForCausalLM"), "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
# ChatGLMModel supports multimodal # ChatGLMModel supports multimodal
"CohereForCausalLM": ("commandr", "CohereForCausalLM"), "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment