Unverified commit 7342a7d7, authored by Tyler Michael Smith, committed by GitHub
Browse files

[Model] Support Mamba (#6484)

parent df3dcdf4
...@@ -418,13 +418,12 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]): ...@@ -418,13 +418,12 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
self.sliding_window = model_config.get_sliding_window() self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size self.block_size = cache_config.block_size
self.attn_backend = get_attn_backend( self.attn_backend = get_attn_backend(
self.model_config.get_num_attention_heads(self.parallel_config),
self.model_config.get_head_size(), self.model_config.get_head_size(),
self.model_config.get_num_kv_heads(self.parallel_config),
self.model_config.get_sliding_window(), self.model_config.get_sliding_window(),
self.model_config.dtype, self.model_config.dtype,
self.kv_cache_dtype, self.kv_cache_dtype,
self.block_size, self.block_size,
self.model_config.is_attention_free,
) )
# Multi-modal data support # Multi-modal data support
......
...@@ -56,13 +56,12 @@ class CPUCacheEngine: ...@@ -56,13 +56,12 @@ class CPUCacheEngine:
# Get attention backend. # Get attention backend.
self.attn_backend = get_attn_backend( self.attn_backend = get_attn_backend(
self.model_config.get_num_attention_heads(self.parallel_config),
self.model_config.get_head_size(), self.model_config.get_head_size(),
self.model_config.get_num_kv_heads(self.parallel_config),
self.model_config.get_sliding_window(), self.model_config.get_sliding_window(),
self.model_config.dtype, self.model_config.dtype,
cache_config.cache_dtype, cache_config.cache_dtype,
self.block_size, self.block_size,
self.model_config.is_attention_free,
) )
# Initialize the cache. # Initialize the cache.
......
...@@ -196,7 +196,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): ...@@ -196,7 +196,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
seqlen_agnostic_kwargs = { seqlen_agnostic_kwargs = {
"finished_requests_ids": model_input.finished_requests_ids, "finished_requests_ids": model_input.finished_requests_ids,
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
} if self.has_seqlen_agnostic else {} } if self.has_inner_state else {}
multi_modal_kwargs = model_input.multi_modal_kwargs or {} multi_modal_kwargs = model_input.multi_modal_kwargs or {}
with set_forward_context(model_input.attn_metadata): with set_forward_context(model_input.attn_metadata):
......
...@@ -17,7 +17,6 @@ import torch.nn as nn ...@@ -17,7 +17,6 @@ import torch.nn as nn
import vllm.envs as envs import vllm.envs as envs
from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.attention.backends.abstract import AttentionState from vllm.attention.backends.abstract import AttentionState
from vllm.attention.backends.utils import CommonAttentionState
from vllm.compilation.compile_context import set_compile_context from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.levels import CompilationLevel from vllm.compilation.levels import CompilationLevel
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
...@@ -991,8 +990,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -991,8 +990,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.graph_memory_pool: Optional[Tuple[ self.graph_memory_pool: Optional[Tuple[
int, int]] = None # Set during graph capture. int, int]] = None # Set during graph capture.
self.has_seqlen_agnostic = model_config.contains_seqlen_agnostic_layers( self.has_inner_state = model_config.has_inner_state
parallel_config)
# When using CUDA graph, the input block tables must be padded to # When using CUDA graph, the input block tables must be padded to
# max_seq_len_to_capture. However, creating the block table in # max_seq_len_to_capture. However, creating the block table in
...@@ -1003,22 +1001,16 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1003,22 +1001,16 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.graph_block_tables = np.zeros( self.graph_block_tables = np.zeros(
(self.max_batchsize_to_capture, self.get_max_block_per_batch()), (self.max_batchsize_to_capture, self.get_max_block_per_batch()),
dtype=np.int32) dtype=np.int32)
num_attn_heads = self.model_config.get_num_attention_heads(
self.parallel_config)
self.attn_backend = get_attn_backend( self.attn_backend = get_attn_backend(
num_attn_heads,
self.model_config.get_head_size(), self.model_config.get_head_size(),
self.model_config.get_num_kv_heads(self.parallel_config),
self.model_config.get_sliding_window(), self.model_config.get_sliding_window(),
self.model_config.dtype, self.model_config.dtype,
self.kv_cache_dtype, self.kv_cache_dtype,
self.block_size, self.block_size,
) if num_attn_heads else None self.model_config.is_attention_free,
if self.attn_backend: )
self.attn_state = self.attn_backend.get_state_cls()( self.attn_state = self.attn_backend.get_state_cls()(
weakref.proxy(self)) weakref.proxy(self))
else:
self.attn_state = CommonAttentionState(weakref.proxy(self))
# Multi-modal data support # Multi-modal data support
self.input_registry = input_registry self.input_registry = input_registry
...@@ -1498,7 +1490,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1498,7 +1490,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
"previous_hidden_states"] = previous_hidden_states[: "previous_hidden_states"] = previous_hidden_states[:
batch_size] batch_size]
if self.has_seqlen_agnostic: if self.has_inner_state:
# Only used by Mamba-based models CUDA graph atm (Jamba) # Only used by Mamba-based models CUDA graph atm (Jamba)
capture_inputs.update({ capture_inputs.update({
"seqlen_agnostic_capture_inputs": "seqlen_agnostic_capture_inputs":
...@@ -1647,7 +1639,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1647,7 +1639,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
seqlen_agnostic_kwargs = { seqlen_agnostic_kwargs = {
"finished_requests_ids": model_input.finished_requests_ids, "finished_requests_ids": model_input.finished_requests_ids,
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
} if self.has_seqlen_agnostic else {} } if self.has_inner_state else {}
if (self.observability_config is not None if (self.observability_config is not None
and self.observability_config.collect_model_forward_time): and self.observability_config.collect_model_forward_time):
model_forward_start = torch.cuda.Event(enable_timing=True) model_forward_start = torch.cuda.Event(enable_timing=True)
...@@ -1852,10 +1844,14 @@ class CUDAGraphRunner: ...@@ -1852,10 +1844,14 @@ class CUDAGraphRunner:
# Copy the input tensors to the input buffers. # Copy the input tensors to the input buffers.
self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True) self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
self.input_buffers["positions"].copy_(positions, non_blocking=True) self.input_buffers["positions"].copy_(positions, non_blocking=True)
self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
non_blocking=True) if self.backend_name != "placeholder-attn":
self.input_buffers["slot_mapping"].copy_(
attn_metadata.slot_mapping, non_blocking=True)
self.attn_state.prepare_graph_input_buffers( self.attn_state.prepare_graph_input_buffers(
self.input_buffers, attn_metadata, self._is_encoder_decoder_model) self.input_buffers, attn_metadata, self._is_encoder_decoder_model)
if "seqlen_agnostic_capture_inputs" in self.input_buffers: if "seqlen_agnostic_capture_inputs" in self.input_buffers:
self.model.copy_inputs_before_cuda_graphs(self.input_buffers, self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
**kwargs) **kwargs)
......
...@@ -74,13 +74,12 @@ class OpenVINOModelRunner: ...@@ -74,13 +74,12 @@ class OpenVINOModelRunner:
self.block_size = cache_config.block_size self.block_size = cache_config.block_size
self.attn_backend = get_attn_backend( self.attn_backend = get_attn_backend(
self.model_config.get_num_attention_heads(self.parallel_config),
self.model_config.get_head_size(), self.model_config.get_head_size(),
self.model_config.get_num_kv_heads(self.parallel_config),
self.model_config.get_sliding_window(), self.model_config.get_sliding_window(),
self.model_config.dtype, self.model_config.dtype,
self.kv_cache_dtype, self.kv_cache_dtype,
self.block_size, self.block_size,
self.model_config.is_attention_free,
) )
# Multi-modal data support # Multi-modal data support
......
...@@ -70,13 +70,12 @@ class OpenVINOCacheEngine: ...@@ -70,13 +70,12 @@ class OpenVINOCacheEngine:
# Get attention backend. # Get attention backend.
self.attn_backend = get_attn_backend( self.attn_backend = get_attn_backend(
self.model_config.get_num_attention_heads(self.parallel_config),
self.head_size, self.head_size,
self.model_config.get_num_kv_heads(self.parallel_config),
self.model_config.get_sliding_window(), self.model_config.get_sliding_window(),
self.model_config.dtype, self.model_config.dtype,
self.cache_config.cache_dtype, self.cache_config.cache_dtype,
self.block_size, self.block_size,
self.model_config.is_attention_free,
) )
# Initialize the cache. # Initialize the cache.
......
...@@ -113,13 +113,12 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -113,13 +113,12 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
(self.scheduler_config.max_num_seqs, self.max_num_blocks_per_seq), (self.scheduler_config.max_num_seqs, self.max_num_blocks_per_seq),
dtype=np.int32) dtype=np.int32)
self.attn_backend = get_attn_backend( self.attn_backend = get_attn_backend(
self.model_config.get_num_attention_heads(self.parallel_config),
self.model_config.get_head_size(), self.model_config.get_head_size(),
self.model_config.get_num_kv_heads(self.parallel_config),
self.model_config.get_sliding_window(), self.model_config.get_sliding_window(),
self.model_config.dtype, self.model_config.dtype,
self.cache_config.cache_dtype, self.cache_config.cache_dtype,
self.block_size, self.block_size,
self.model_config.is_attention_free,
False, False,
) )
self.cached_step_outputs: List[torch.Tensor] = [] self.cached_step_outputs: List[torch.Tensor] = []
......
...@@ -236,11 +236,15 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -236,11 +236,15 @@ class Worker(LocalOrDistributedWorkerBase):
"not properly cleaned up before initializing the vLLM instance.") "not properly cleaned up before initializing the vLLM instance.")
cache_block_size = self.get_cache_block_size_bytes() cache_block_size = self.get_cache_block_size_bytes()
num_gpu_blocks = int( if cache_block_size == 0:
(total_gpu_memory * self.cache_config.gpu_memory_utilization - num_gpu_blocks = 0
peak_memory) // cache_block_size) num_cpu_blocks = 0
num_cpu_blocks = int(self.cache_config.swap_space_bytes // else:
cache_block_size) num_gpu_blocks = int(
(total_gpu_memory * self.cache_config.gpu_memory_utilization -
peak_memory) // cache_block_size)
num_cpu_blocks = int(self.cache_config.swap_space_bytes //
cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0) num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0)
if self.model_runner.lora_manager: if self.model_runner.lora_manager:
...@@ -257,6 +261,7 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -257,6 +261,7 @@ class Worker(LocalOrDistributedWorkerBase):
""" """
raise_if_cache_size_invalid(num_gpu_blocks, raise_if_cache_size_invalid(num_gpu_blocks,
self.cache_config.block_size, self.cache_config.block_size,
self.cache_config.is_attention_free,
self.model_config.max_model_len) self.model_config.max_model_len)
self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_gpu_blocks = num_gpu_blocks
...@@ -472,14 +477,18 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): ...@@ -472,14 +477,18 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
"`dtype` flag in CLI, for example: --dtype=half.") "`dtype` flag in CLI, for example: --dtype=half.")
def raise_if_cache_size_invalid(num_gpu_blocks, block_size, def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free,
max_model_len) -> None: max_model_len) -> None:
if num_gpu_blocks <= 0: if is_attention_free and num_gpu_blocks != 0:
raise ValueError("No memory should be allocated for the cache blocks "
f"for an attention-free model, but {num_gpu_blocks}"
"blocks are allocated.")
if not is_attention_free and num_gpu_blocks <= 0:
raise ValueError("No available memory for the cache blocks. " raise ValueError("No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when " "Try increasing `gpu_memory_utilization` when "
"initializing the engine.") "initializing the engine.")
max_seq_len = block_size * num_gpu_blocks max_seq_len = block_size * num_gpu_blocks
if max_model_len > max_seq_len: if not is_attention_free and max_model_len > max_seq_len:
raise ValueError( raise ValueError(
f"The model's max seq len ({max_model_len}) " f"The model's max seq len ({max_model_len}) "
"is larger than the maximum number of tokens that can be " "is larger than the maximum number of tokens that can be "
......
...@@ -372,13 +372,12 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): ...@@ -372,13 +372,12 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
self.block_size = cache_config.block_size self.block_size = cache_config.block_size
self.attn_backend = get_attn_backend( self.attn_backend = get_attn_backend(
self.model_config.get_num_attention_heads(self.parallel_config),
self.model_config.get_head_size(), self.model_config.get_head_size(),
self.model_config.get_num_kv_heads(self.parallel_config),
self.model_config.get_sliding_window(), self.model_config.get_sliding_window(),
self.model_config.dtype, self.model_config.dtype,
self.kv_cache_dtype, self.kv_cache_dtype,
self.block_size, self.block_size,
self.model_config.is_attention_free,
) )
# Multi-modal data support # Multi-modal data support
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment