Unverified Commit 9d6a8daa authored by Mor Zusman's avatar Mor Zusman Committed by GitHub
Browse files
parent ee93f4f9
...@@ -23,4 +23,4 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py" ...@@ -23,4 +23,4 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py"
docker exec cpu-test bash -c "cd tests; docker exec cpu-test bash -c "cd tests;
pip install pytest Pillow protobuf pip install pytest Pillow protobuf
cd ../ cd ../
pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py" pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py" # Mamba on CPU is not supported
...@@ -43,6 +43,10 @@ COPY requirements-cuda.txt requirements-cuda.txt ...@@ -43,6 +43,10 @@ COPY requirements-cuda.txt requirements-cuda.txt
RUN --mount=type=cache,target=/root/.cache/pip \ RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install -r requirements-cuda.txt python3 -m pip install -r requirements-cuda.txt
COPY requirements-mamba.txt requirements-mamba.txt
RUN python3 -m pip install packaging
RUN python3 -m pip install -r requirements-mamba.txt
# cuda arch list used by torch # cuda arch list used by torch
# can be useful for both `dev` and `test` # can be useful for both `dev` and `test`
# explicitly set the list to avoid issues with torch 2.2 # explicitly set the list to avoid issues with torch 2.2
...@@ -123,6 +127,21 @@ RUN --mount=type=cache,target=/root/.cache/pip \ ...@@ -123,6 +127,21 @@ RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install -r requirements-dev.txt python3 -m pip install -r requirements-dev.txt
#################### DEV IMAGE #################### #################### DEV IMAGE ####################
#################### MAMBA Build IMAGE ####################
# Separate stage, based on the `dev` stage, that produces wheels for the
# packages listed in requirements-mamba.txt. The vLLM installation stage
# bind-mounts /usr/src/mamba from this stage and installs the *.whl files
# it finds there.
FROM dev as mamba-builder
# max jobs used for build
# (exported as MAX_JOBS; presumably read by the extension build to cap
# parallel compile jobs - TODO confirm)
ARG max_jobs=2
ENV MAX_JOBS=${max_jobs}
# The wheel command below is run from this directory.
WORKDIR /usr/src/mamba
COPY requirements-mamba.txt requirements-mamba.txt
# Download the wheel or build it if a pre-compiled release doesn't exist
# --no-build-isolation: build against the packages already installed in this
#                       image instead of a fresh isolated environment.
# --no-deps:            only produce wheels for the listed requirements.
RUN pip --verbose wheel -r requirements-mamba.txt \
--no-build-isolation --no-deps --no-cache-dir
#################### MAMBA Build IMAGE ####################
#################### vLLM installation IMAGE #################### #################### vLLM installation IMAGE ####################
# image with vLLM installed # image with vLLM installed
...@@ -143,6 +162,10 @@ RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/ ...@@ -143,6 +162,10 @@ RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
--mount=type=cache,target=/root/.cache/pip \ --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install dist/*.whl --verbose python3 -m pip install dist/*.whl --verbose
RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamba \
--mount=type=cache,target=/root/.cache/pip \
python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir
#################### vLLM installation IMAGE #################### #################### vLLM installation IMAGE ####################
......
...@@ -87,6 +87,10 @@ Alongside each architecture, we include some popular models that use it. ...@@ -87,6 +87,10 @@ Alongside each architecture, we include some popular models that use it.
- Jais - Jais
- :code:`core42/jais-13b`, :code:`core42/jais-13b-chat`, :code:`core42/jais-30b-v3`, :code:`core42/jais-30b-chat-v3`, etc. - :code:`core42/jais-13b`, :code:`core42/jais-13b-chat`, :code:`core42/jais-30b-v3`, :code:`core42/jais-30b-chat-v3`, etc.
- -
* - :code:`JambaForCausalLM`
- Jamba
- :code:`ai21labs/Jamba-v0.1`, etc.
- ✅︎
* - :code:`LlamaForCausalLM` * - :code:`LlamaForCausalLM`
- LLaMA, Llama 2, Meta Llama 3, Vicuna, Alpaca, Yi - LLaMA, Llama 2, Meta Llama 3, Vicuna, Alpaca, Yi
- :code:`meta-llama/Meta-Llama-3-8B-Instruct`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc. - :code:`meta-llama/Meta-Llama-3-8B-Instruct`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc.
......
# Mamba dependencies
mamba-ssm>=1.2.2
causal-conv1d>=1.2.0
import pytest
# Model names that every test in this module is parametrized over.
MODELS = ["ai21labs/Jamba-tiny-random"]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [20])
def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    """Compare greedy generations of the HF runner and the vLLM runner.

    Both runners generate ``max_tokens`` tokens greedily for every entry of
    ``example_prompts``; the test passes only if the generated strings and
    the generated token ids are identical for each prompt.

    Args:
        hf_runner: fixture; called as ``hf_runner(model, dtype=...)`` and
            used as a context manager exposing ``generate_greedy``.
        vllm_runner: fixture with the same calling convention as
            ``hf_runner``.
        example_prompts: fixture providing the prompts to generate from.
        model: model name to load in both runners.
        dtype: dtype string handed to both runners; must be ``"float"``.
        max_tokens: number of tokens to generate per prompt.
    """
    # To pass the small model tests, we need full precision.
    assert dtype == "float"

    with hf_runner(model, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

    with vllm_runner(model, dtype=dtype) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

    # zip() stops at the shortest input, so make a missing output an explicit
    # failure instead of letting it go unnoticed.
    assert len(hf_outputs) == len(vllm_outputs) == len(example_prompts), (
        f"Expected {len(example_prompts)} outputs, got "
        f"HF: {len(hf_outputs)}, vLLM: {len(vllm_outputs)}")

    for i, (hf_output, vllm_output) in enumerate(zip(hf_outputs,
                                                     vllm_outputs)):
        hf_output_ids, hf_output_str = hf_output
        vllm_output_ids, vllm_output_str = vllm_output
        assert hf_output_str == vllm_output_str, (
            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
        assert hf_output_ids == vllm_output_ids, (
            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_state_cleanup(
    vllm_runner,
    model: str,
    dtype: str,
    example_prompts,
) -> None:
    """Verify that the Jamba state is cleaned up between steps.

    Runs 10 rounds of greedy generation, each over 100 copies of the first
    example prompt with a single generated token. If the state is not
    cleaned up, a ``ValueError`` is expected from generation and the test is
    failed with an explanatory message.

    Args:
        vllm_runner: fixture; called as ``vllm_runner(model, dtype=...)`` and
            used as a context manager exposing ``generate_greedy``.
        model: model name to load.
        dtype: dtype string handed to the runner.
        example_prompts: fixture providing prompts; only the first is used.
    """
    with vllm_runner(model, dtype=dtype) as vllm_model:
        # Only generation is guarded: a ValueError raised while constructing
        # the runner is a different problem and must surface unchanged
        # rather than be reported as a state-cleanup failure.
        try:
            for _ in range(10):
                vllm_model.generate_greedy([example_prompts[0]] * 100, 1)
        except ValueError:
            pytest.fail("Jamba inner state wasn't cleaned up between states, "
                        "could be related to finished_requests_ids")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_model_print(
    vllm_runner,
    model: str,
    dtype: str,
) -> None:
    """Check that the loaded model object can be printed.

    The test passes when printing the model (which exercises its
    ``extra_repr``) does not raise.
    """
    with vllm_runner(model, dtype=dtype) as vllm_model:
        # Walk from the runner wrapper down to the model object held by the
        # driver worker's model runner.
        engine = vllm_model.model.llm_engine
        driver_worker = engine.model_executor.driver_worker
        loaded_model = driver_worker.model_runner.model
        print(loaded_model)
...@@ -386,9 +386,36 @@ class ModelConfig: ...@@ -386,9 +386,36 @@ class ModelConfig:
return num_heads // parallel_config.tensor_parallel_size return num_heads // parallel_config.tensor_parallel_size
def get_num_layers(self, parallel_config: "ParallelConfig") -> int: def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
total_num_hidden_layers = self.hf_text_config.num_hidden_layers total_num_hidden_layers = getattr(self.hf_text_config,
"num_hidden_layers", 0)
return total_num_hidden_layers // parallel_config.pipeline_parallel_size return total_num_hidden_layers // parallel_config.pipeline_parallel_size
def contains_seqlen_agnostic_layers(
        self, parallel_config: "ParallelConfig") -> bool:
    """Return whether the model has at least one non-attention layer.

    This is the case for Mamba/SSM models (Jamba).
    """
    non_attention_layers = self._get_num_seqlen_agnostic_layers(
        parallel_config)
    return non_attention_layers > 0
def get_layers_block_type(self,
                          parallel_config: "ParallelConfig") -> List[str]:
    """Return the block type of each layer.

    Uses ``hf_config.layers_block_type`` when the config provides it;
    otherwise every layer counted by ``get_num_layers`` is reported as
    ``"attention"``.
    """
    # Fallback for configs without the attribute: all layers are attention.
    all_attention = ["attention"] * self.get_num_layers(parallel_config)
    # Transformers supports layers_block_type @property
    # NOTE(review): the fallback length follows get_num_layers(), whereas a
    # config-provided list is returned unmodified - confirm both describe
    # the same set of layers when pipeline parallelism is used.
    return getattr(self.hf_config, "layers_block_type", all_attention)
def get_num_attention_layers(self,
                             parallel_config: "ParallelConfig") -> int:
    """Count the layers whose block type is ``"attention"``."""
    block_types = self.get_layers_block_type(parallel_config)
    return sum(1 for block_type in block_types if block_type == "attention")
def _get_num_seqlen_agnostic_layers(
        self, parallel_config: "ParallelConfig") -> int:
    """Count the layers whose block type is anything but ``"attention"``."""
    block_types = self.get_layers_block_type(parallel_config)
    return sum(1 for block_type in block_types if block_type != "attention")
class CacheConfig: class CacheConfig:
"""Configuration for the KV cache. """Configuration for the KV cache.
......
...@@ -299,7 +299,10 @@ class Scheduler: ...@@ -299,7 +299,10 @@ class Scheduler:
# Sequence groups in the SWAPPED state. # Sequence groups in the SWAPPED state.
# Contain decode requests that are swapped out. # Contain decode requests that are swapped out.
self.swapped: Deque[SequenceGroup] = deque() self.swapped: Deque[SequenceGroup] = deque()
# Sequence groups finished requests ids since last step iteration.
# It lets the model know that any state associated with these requests
# can and must be released after the current step.
self._finished_requests_ids: List[str] = list()
# Time at previous scheduling step # Time at previous scheduling step
self.prev_time = 0.0 self.prev_time = 0.0
# Did we schedule a prompt at previous step? # Did we schedule a prompt at previous step?
...@@ -373,6 +376,12 @@ class Scheduler: ...@@ -373,6 +376,12 @@ class Scheduler:
def get_num_unfinished_seq_groups(self) -> int: def get_num_unfinished_seq_groups(self) -> int:
return len(self.waiting) + len(self.running) + len(self.swapped) return len(self.waiting) + len(self.running) + len(self.swapped)
def get_and_reset_finished_requests_ids(self) -> List[str]:
    """Return the request ids collected so far and start a fresh list.

    The returned list is the very object that was stored on the scheduler;
    the scheduler is left holding a new, empty list.
    """
    # Hand out the current list and swap in an empty one in a single step.
    flushed, self._finished_requests_ids = self._finished_requests_ids, []
    return flushed
def _schedule_running( def _schedule_running(
self, self,
running_queue: deque, running_queue: deque,
...@@ -1036,6 +1045,11 @@ class Scheduler: ...@@ -1036,6 +1045,11 @@ class Scheduler:
self.block_manager.free(seq) self.block_manager.free(seq)
def free_finished_seq_groups(self) -> None: def free_finished_seq_groups(self) -> None:
for queue in [self.running, self.swapped, self.waiting]:
self._finished_requests_ids += [
seq_group.request_id for seq_group in queue
if seq_group.is_finished()
]
self.running = deque(seq_group for seq_group in self.running self.running = deque(seq_group for seq_group in self.running
if not seq_group.is_finished()) if not seq_group.is_finished())
......
...@@ -224,6 +224,8 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -224,6 +224,8 @@ class _AsyncLLMEngine(LLMEngine):
""" """
seq_group_metadata_list, scheduler_outputs = self.scheduler[ seq_group_metadata_list, scheduler_outputs = self.scheduler[
virtual_engine].schedule() virtual_engine].schedule()
finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()
if not scheduler_outputs.is_empty(): if not scheduler_outputs.is_empty():
# Execute the model. # Execute the model.
...@@ -235,7 +237,7 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -235,7 +237,7 @@ class _AsyncLLMEngine(LLMEngine):
virtual_engine=virtual_engine, virtual_engine=virtual_engine,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots, num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size, running_queue_size=scheduler_outputs.running_queue_size,
) finished_requests_ids=finished_requests_ids)
output = await self.model_executor.execute_model_async( output = await self.model_executor.execute_model_async(
execute_model_req) execute_model_req)
else: else:
......
...@@ -846,6 +846,8 @@ class LLMEngine: ...@@ -846,6 +846,8 @@ class LLMEngine:
"as performance will be severely degraded otherwise.") "as performance will be severely degraded otherwise.")
seq_group_metadata_list, scheduler_outputs = self.scheduler[ seq_group_metadata_list, scheduler_outputs = self.scheduler[
0].schedule() 0].schedule()
finished_requests_ids = self.scheduler[
0].get_and_reset_finished_requests_ids()
if not scheduler_outputs.is_empty(): if not scheduler_outputs.is_empty():
execute_model_req = ExecuteModelRequest( execute_model_req = ExecuteModelRequest(
...@@ -855,7 +857,7 @@ class LLMEngine: ...@@ -855,7 +857,7 @@ class LLMEngine:
blocks_to_copy=scheduler_outputs.blocks_to_copy, blocks_to_copy=scheduler_outputs.blocks_to_copy,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots, num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size, running_queue_size=scheduler_outputs.running_queue_size,
) finished_requests_ids=finished_requests_ids)
output = self.model_executor.execute_model( output = self.model_executor.execute_model(
execute_model_req=execute_model_req) execute_model_req=execute_model_req)
else: else:
......
...@@ -63,6 +63,7 @@ _GENERATION_MODELS = { ...@@ -63,6 +63,7 @@ _GENERATION_MODELS = {
"XverseForCausalLM": ("xverse", "XverseForCausalLM"), "XverseForCausalLM": ("xverse", "XverseForCausalLM"),
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
"JambaForCausalLM": ("jamba", "JambaForCausalLM")
} }
_EMBEDDING_MODELS = { _EMBEDDING_MODELS = {
......
This diff is collapsed.
...@@ -934,6 +934,8 @@ class ExecuteModelRequest: ...@@ -934,6 +934,8 @@ class ExecuteModelRequest:
previous_hidden_states: Optional[HiddenStates] = None previous_hidden_states: Optional[HiddenStates] = None
# The number of forward steps to run. # The number of forward steps to run.
num_steps: int = 1 num_steps: int = 1
# Finished request ids since last step.
finished_requests_ids: List[str] = field(default_factory=list)
def clone( def clone(
self, seq_group_metadata_list: List[SequenceGroupMetadata] self, seq_group_metadata_list: List[SequenceGroupMetadata]
...@@ -949,4 +951,4 @@ class ExecuteModelRequest: ...@@ -949,4 +951,4 @@ class ExecuteModelRequest:
running_queue_size=self.running_queue_size, running_queue_size=self.running_queue_size,
previous_hidden_states=self.previous_hidden_states, previous_hidden_states=self.previous_hidden_states,
num_steps=self.num_steps, num_steps=self.num_steps,
) finished_requests_ids=self.finished_requests_ids)
...@@ -77,13 +77,17 @@ class TP1DraftModelRunner(ModelRunner): ...@@ -77,13 +77,17 @@ class TP1DraftModelRunner(ModelRunner):
def prepare_model_input( def prepare_model_input(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0) -> ModelInputForGPUWithSamplingMetadata: virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForGPUWithSamplingMetadata:
"""A temporary solution that caches the seq_group_metadata_list """A temporary solution that caches the seq_group_metadata_list
for multi-step execution. for multi-step execution.
TODO: In-place update model_input and remove this function. TODO: In-place update model_input and remove this function.
""" """
self.cached_seq_group_metadata_list = seq_group_metadata_list self.cached_seq_group_metadata_list = seq_group_metadata_list
return super().prepare_model_input(seq_group_metadata_list) return super().prepare_model_input(
seq_group_metadata_list,
finished_requests_ids=finished_requests_ids)
def update_model_input( def update_model_input(
self, model_input: ModelInputForGPUWithSamplingMetadata, self, model_input: ModelInputForGPUWithSamplingMetadata,
......
...@@ -33,7 +33,9 @@ class CacheEngine: ...@@ -33,7 +33,9 @@ class CacheEngine:
self.device_config = device_config self.device_config = device_config
self.head_size = model_config.get_head_size() self.head_size = model_config.get_head_size()
self.num_layers = model_config.get_num_layers(parallel_config) # Models like Jamba have mixed-type layers, e.g. Mamba
self.num_attention_layers = model_config.get_num_attention_layers(
parallel_config)
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
self.block_size = cache_config.block_size self.block_size = cache_config.block_size
...@@ -75,7 +77,7 @@ class CacheEngine: ...@@ -75,7 +77,7 @@ class CacheEngine:
num_blocks, self.block_size, self.num_kv_heads, self.head_size) num_blocks, self.block_size, self.num_kv_heads, self.head_size)
pin_memory = is_pin_memory_available() if device == "cpu" else False pin_memory = is_pin_memory_available() if device == "cpu" else False
kv_cache: List[torch.Tensor] = [] kv_cache: List[torch.Tensor] = []
for _ in range(self.num_layers): for _ in range(self.num_attention_layers):
# null block in CpuGpuBlockAllocator requires at least that # null block in CpuGpuBlockAllocator requires at least that
# block to be zeroed-out. # block to be zeroed-out.
# We zero-out everything for simplicity. # We zero-out everything for simplicity.
...@@ -87,12 +89,12 @@ class CacheEngine: ...@@ -87,12 +89,12 @@ class CacheEngine:
return kv_cache return kv_cache
def swap_in(self, src_to_dst: torch.Tensor) -> None: def swap_in(self, src_to_dst: torch.Tensor) -> None:
for i in range(self.num_layers): for i in range(self.num_attention_layers):
self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i], self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i],
src_to_dst) src_to_dst)
def swap_out(self, src_to_dst: torch.Tensor) -> None: def swap_out(self, src_to_dst: torch.Tensor) -> None:
for i in range(self.num_layers): for i in range(self.num_attention_layers):
self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i], self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i],
src_to_dst) src_to_dst)
...@@ -107,11 +109,12 @@ class CacheEngine: ...@@ -107,11 +109,12 @@ class CacheEngine:
) -> int: ) -> int:
head_size = model_config.get_head_size() head_size = model_config.get_head_size()
num_heads = model_config.get_num_kv_heads(parallel_config) num_heads = model_config.get_num_kv_heads(parallel_config)
num_layers = model_config.get_num_layers(parallel_config) num_attention_layers = model_config.get_num_attention_layers(
parallel_config)
key_cache_block = cache_config.block_size * num_heads * head_size key_cache_block = cache_config.block_size * num_heads * head_size
value_cache_block = key_cache_block value_cache_block = key_cache_block
total = num_layers * (key_cache_block + value_cache_block) total = num_attention_layers * (key_cache_block + value_cache_block)
if cache_config.cache_dtype == "auto": if cache_config.cache_dtype == "auto":
dtype = model_config.dtype dtype = model_config.dtype
else: else:
......
...@@ -317,6 +317,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]): ...@@ -317,6 +317,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0, virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> CPUModelInput: ) -> CPUModelInput:
multi_modal_kwargs = None multi_modal_kwargs = None
# NOTE: We assume that all sequences in the group are all prompts or # NOTE: We assume that all sequences in the group are all prompts or
......
...@@ -120,10 +120,11 @@ class EmbeddingModelRunner( ...@@ -120,10 +120,11 @@ class EmbeddingModelRunner(
self, self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
virtual_engine: int = 0, virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForGPUWithPoolingMetadata: ) -> ModelInputForGPUWithPoolingMetadata:
assert seq_group_metadata_list is not None assert seq_group_metadata_list is not None
model_input = self._prepare_model_input_tensors( model_input = self._prepare_model_input_tensors(
seq_group_metadata_list) seq_group_metadata_list, finished_requests_ids)
# Prepare PoolingMetadata. # Prepare PoolingMetadata.
assert model_input.seq_lens is not None assert model_input.seq_lens is not None
pooling_metadata = self._prepare_pooling(seq_group_metadata_list, pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
......
...@@ -84,6 +84,8 @@ class ModelInputForGPU(ModelRunnerInputBase): ...@@ -84,6 +84,8 @@ class ModelInputForGPU(ModelRunnerInputBase):
lora_requests: Optional[Set[LoRARequest]] = None lora_requests: Optional[Set[LoRARequest]] = None
attn_metadata: Optional["AttentionMetadata"] = None attn_metadata: Optional["AttentionMetadata"] = None
multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None
request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
finished_requests_ids: Optional[List[str]] = None
virtual_engine: int = 0 virtual_engine: int = 0
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
...@@ -94,6 +96,8 @@ class ModelInputForGPU(ModelRunnerInputBase): ...@@ -94,6 +96,8 @@ class ModelInputForGPU(ModelRunnerInputBase):
"lora_mapping": self.lora_mapping, "lora_mapping": self.lora_mapping,
"multi_modal_kwargs": self.multi_modal_kwargs, "multi_modal_kwargs": self.multi_modal_kwargs,
"virtual_engine": self.virtual_engine, "virtual_engine": self.virtual_engine,
"request_ids_to_seq_ids": self.request_ids_to_seq_ids,
"finished_requests_ids": self.finished_requests_ids,
} }
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
return tensor_dict return tensor_dict
...@@ -128,6 +132,8 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU): ...@@ -128,6 +132,8 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
"lora_mapping": self.lora_mapping, "lora_mapping": self.lora_mapping,
"multi_modal_kwargs": self.multi_modal_kwargs, "multi_modal_kwargs": self.multi_modal_kwargs,
"virtual_engine": self.virtual_engine, "virtual_engine": self.virtual_engine,
"request_ids_to_seq_ids": self.request_ids_to_seq_ids,
"finished_requests_ids": self.finished_requests_ids,
} }
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict, _add_sampling_metadata_broadcastable_dict(tensor_dict,
...@@ -191,6 +197,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -191,6 +197,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
] ]
self.graph_memory_pool: Optional[Tuple[ self.graph_memory_pool: Optional[Tuple[
int, int]] = None # Set during graph capture. int, int]] = None # Set during graph capture.
self.has_seqlen_agnostic = model_config.contains_seqlen_agnostic_layers(
parallel_config)
# When using CUDA graph, the input block tables must be padded to # When using CUDA graph, the input block tables must be padded to
# max_seq_len_to_capture. However, creating the block table in # max_seq_len_to_capture. However, creating the block table in
# Python can be expensive. To optimize this, we cache the block table # Python can be expensive. To optimize this, we cache the block table
...@@ -317,6 +327,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -317,6 +327,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
def _prepare_model_input_tensors( def _prepare_model_input_tensors(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
finished_requests_ids: Optional[List[str]] = None
) -> TModelInputForGPU: ) -> TModelInputForGPU:
"""Helper method to prepare the model input based on a given sequence """Helper method to prepare the model input based on a given sequence
group. Prepares metadata needed for the base model forward pass but not group. Prepares metadata needed for the base model forward pass but not
...@@ -347,6 +358,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -347,6 +358,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
block_tables: List[List[int]] = [] block_tables: List[List[int]] = []
multi_modal_kwargs_list: Dict[str, multi_modal_kwargs_list: Dict[str,
List[torch.Tensor]] = defaultdict(list) List[torch.Tensor]] = defaultdict(list)
request_ids_to_seq_ids: Dict[str, List[int]] = defaultdict(list)
decode_only = True decode_only = True
num_prefills = 0 num_prefills = 0
num_prefill_tokens = 0 num_prefill_tokens = 0
...@@ -738,7 +750,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -738,7 +750,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
k: torch.cat(v, dim=0).to(self.device) k: torch.cat(v, dim=0).to(self.device)
for k, v in multi_modal_kwargs_list.items() for k, v in multi_modal_kwargs_list.items()
} }
request_ids_to_seq_ids = {
seq_group_metadata.request_id:
list(seq_group_metadata.seq_data.keys())
for seq_group_metadata in seq_group_metadata_list
}
return self._model_input_cls( return self._model_input_cls(
input_tokens=input_tokens_tensor, input_tokens=input_tokens_tensor,
input_positions=input_positions_tensor, input_positions=input_positions_tensor,
...@@ -748,7 +764,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -748,7 +764,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
lora_mapping=lora_mapping, lora_mapping=lora_mapping,
lora_requests=lora_requests, lora_requests=lora_requests,
multi_modal_kwargs=multi_modal_kwargs, multi_modal_kwargs=multi_modal_kwargs,
) request_ids_to_seq_ids=request_ids_to_seq_ids,
finished_requests_ids=finished_requests_ids)
@torch.inference_mode() @torch.inference_mode()
def profile_run(self) -> None: def profile_run(self) -> None:
...@@ -821,7 +838,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -821,7 +838,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
# Run the model with the dummy inputs. # Run the model with the dummy inputs.
num_layers = self.model_config.get_num_layers(self.parallel_config) num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [None] * num_layers kv_caches = [None] * num_layers
model_input = self.prepare_model_input(seqs) finished_requests_ids = [seq.request_id for seq in seqs]
model_input = self.prepare_model_input(
seqs, finished_requests_ids=finished_requests_ids)
intermediate_tensors = None intermediate_tensors = None
if not get_pp_group().is_first_rank: if not get_pp_group().is_first_rank:
intermediate_tensors = self.model.make_empty_intermediate_tensors( intermediate_tensors = self.model.make_empty_intermediate_tensors(
...@@ -1033,21 +1052,37 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1033,21 +1052,37 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
graph_runner.flashinfer_decode_wrapper = \ graph_runner.flashinfer_decode_wrapper = \
decode_wrapper decode_wrapper
graph_runner.capture( capture_inputs = {
"input_ids":
input_tokens[:batch_size], input_tokens[:batch_size],
"positions":
input_positions[:batch_size], input_positions[:batch_size],
"hidden_or_intermediate_states":
hidden_or_intermediate_states[ hidden_or_intermediate_states[
virtual_engine] # type: ignore virtual_engine] # type: ignore
[:batch_size] [:batch_size]
if hidden_or_intermediate_states[virtual_engine] if hidden_or_intermediate_states[virtual_engine]
is not None else None, is not None else None,
"intermediate_inputs":
intermediate_inputs[:batch_size] intermediate_inputs[:batch_size]
if intermediate_inputs is not None else None, if intermediate_inputs is not None else None,
"kv_caches":
kv_caches[virtual_engine], kv_caches[virtual_engine],
"attn_metadata":
attn_metadata, attn_metadata,
memory_pool=self.graph_memory_pool, "memory_pool":
stream=graph_capture_context.stream, self.graph_memory_pool,
) "stream":
graph_capture_context.stream
}
if self.has_seqlen_agnostic:
# Only used by Mamba-based models CUDA graph atm (Jamba)
capture_inputs.update({
"seqlen_agnostic_capture_inputs":
self.model.get_seqlen_agnostic_capture_inputs(
batch_size)
})
graph_runner.capture(**capture_inputs)
self.graph_memory_pool = graph_runner.graph.pool() self.graph_memory_pool = graph_runner.graph.pool()
self.graph_runners[virtual_engine][batch_size] = ( self.graph_runners[virtual_engine][batch_size] = (
graph_runner) graph_runner)
...@@ -1084,6 +1119,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1084,6 +1119,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0, virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForGPUWithSamplingMetadata: ) -> ModelInputForGPUWithSamplingMetadata:
"""Prepare the model input based on a given sequence group, including """Prepare the model input based on a given sequence group, including
metadata for the sampling step. metadata for the sampling step.
...@@ -1099,7 +1135,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1099,7 +1135,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
If cuda graph is required, this API automatically pads inputs. If cuda graph is required, this API automatically pads inputs.
""" """
model_input = self._prepare_model_input_tensors( model_input = self._prepare_model_input_tensors(
seq_group_metadata_list) seq_group_metadata_list, finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
model_input.seq_lens, model_input.seq_lens,
model_input.query_lens, model_input.query_lens,
...@@ -1175,6 +1211,10 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1175,6 +1211,10 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
model_executable = self.model model_executable = self.model
multi_modal_kwargs = model_input.multi_modal_kwargs or {} multi_modal_kwargs = model_input.multi_modal_kwargs or {}
seqlen_agnostic_kwargs = {
"finished_requests_ids": model_input.finished_requests_ids,
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
} if self.has_seqlen_agnostic else {}
hidden_or_intermediate_states = model_executable( hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens, input_ids=model_input.input_tokens,
positions=model_input.input_positions, positions=model_input.input_positions,
...@@ -1182,7 +1222,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1182,7 +1222,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
attn_metadata=model_input.attn_metadata, attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
**multi_modal_kwargs, **multi_modal_kwargs,
) **seqlen_agnostic_kwargs)
# Compute the logits in the last pipeline stage. # Compute the logits in the last pipeline stage.
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
...@@ -1305,6 +1345,7 @@ class CUDAGraphRunner: ...@@ -1305,6 +1345,7 @@ class CUDAGraphRunner:
"positions": positions, "positions": positions,
"kv_caches": kv_caches, "kv_caches": kv_caches,
"slot_mapping": attn_metadata.slot_mapping, "slot_mapping": attn_metadata.slot_mapping,
**kwargs,
} }
else: else:
self.input_buffers = { self.input_buffers = {
...@@ -1315,6 +1356,7 @@ class CUDAGraphRunner: ...@@ -1315,6 +1356,7 @@ class CUDAGraphRunner:
"seq_lens_tensor": "seq_lens_tensor":
attn_metadata.decode_metadata.seq_lens_tensor, attn_metadata.decode_metadata.seq_lens_tensor,
"block_tables": attn_metadata.decode_metadata.block_tables, "block_tables": attn_metadata.decode_metadata.block_tables,
**kwargs,
} }
if intermediate_inputs is not None: if intermediate_inputs is not None:
self.input_buffers.update(intermediate_inputs.tensors) self.input_buffers.update(intermediate_inputs.tensors)
...@@ -1349,13 +1391,18 @@ class CUDAGraphRunner: ...@@ -1349,13 +1391,18 @@ class CUDAGraphRunner:
non_blocking=True) non_blocking=True)
self.input_buffers["block_tables"].copy_( self.input_buffers["block_tables"].copy_(
attn_metadata.decode_metadata.block_tables, non_blocking=True) attn_metadata.decode_metadata.block_tables, non_blocking=True)
if "seqlen_agnostic_capture_inputs" in self.input_buffers:
self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
**kwargs)
if intermediate_tensors is not None: if intermediate_tensors is not None:
for key in intermediate_tensors.tensors: for key in intermediate_tensors.tensors:
self.input_buffers[key].copy_(intermediate_tensors[key], self.input_buffers[key].copy_(intermediate_tensors[key],
non_blocking=True) non_blocking=True)
# Run the graph. # Run the graph.
self.graph.replay() self.graph.replay()
if "seqlen_agnostic_capture_inputs" in self.input_buffers:
self.model.copy_outputs_after_cuda_graphs(self.input_buffers,
**kwargs)
# Return the output tensor. # Return the output tensor.
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
return self.output_buffers["hidden_states"] return self.output_buffers["hidden_states"]
......
...@@ -139,6 +139,7 @@ class ModelRunnerBase(ABC, Generic[T]): ...@@ -139,6 +139,7 @@ class ModelRunnerBase(ABC, Generic[T]):
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0, virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None,
) -> T: ) -> T:
""" """
Prepare the inputs to ModelRunnerBase.execute_model from an execution Prepare the inputs to ModelRunnerBase.execute_model from an execution
......
...@@ -177,6 +177,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): ...@@ -177,6 +177,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0, virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForNeuron: ) -> ModelInputForNeuron:
# NOTE: We assume that all sequences in the group are all prompts or # NOTE: We assume that all sequences in the group are all prompts or
# all decodes. # all decodes.
......
...@@ -234,7 +234,8 @@ class LocalOrDistributedWorkerBase(WorkerBase): ...@@ -234,7 +234,8 @@ class LocalOrDistributedWorkerBase(WorkerBase):
model_input: ModelRunnerInputBase = ( model_input: ModelRunnerInputBase = (
self.model_runner.prepare_model_input( self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list, execute_model_req.seq_group_metadata_list,
execute_model_req.virtual_engine)) execute_model_req.virtual_engine,
execute_model_req.finished_requests_ids))
num_steps = execute_model_req.num_steps num_steps = execute_model_req.num_steps
if self.do_metadata_broadcast: if self.do_metadata_broadcast:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment