Unverified Commit e3eb146f authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Model Runner V2] Add ModelStateInterface [4/N] (#35621)


Signed-off-by: default avatarWoosuk Kwon <woosuk@inferact.ai>
parent 95a395db
......@@ -22,7 +22,7 @@ from vllm.v1.worker.gpu.attn_utils import (
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
from vllm.v1.worker.gpu.input_batch import InputBuffers
from vllm.v1.worker.gpu.model_states import ModelState
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.utils import AttentionGroup
......
......@@ -78,7 +78,7 @@ from vllm.v1.worker.gpu.kv_connector import (
)
from vllm.v1.worker.gpu.lora_utils import LoraState
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
from vllm.v1.worker.gpu.model_states import ModelState
from vllm.v1.worker.gpu.model_states import init_model_state
from vllm.v1.worker.gpu.pool.pooling_runner import PoolingRunner
from vllm.v1.worker.gpu.pp_utils import pp_broadcast, pp_receive
from vllm.v1.worker.gpu.sample.output import SamplerOutput
......@@ -267,7 +267,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
prepare_communication_buffer_for_model(self.speculator)
# Initialize the components that require the model.
self.model_state = ModelState(
self.model_state = init_model_state(
self.vllm_config, self.model, self.encoder_cache, self.device
)
if self.is_pooling_model:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
def init_model_state(
vllm_config: VllmConfig,
model: nn.Module,
encoder_cache: EncoderCache | None,
device: torch.device,
):
from vllm.v1.worker.gpu.model_states.default import DefaultModelState
return DefaultModelState(vllm_config, model, encoder_cache, device)
......@@ -13,11 +13,12 @@ from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.gpu.states import RequestState
from vllm.v1.worker.utils import AttentionGroup
class ModelState:
class DefaultModelState(ModelState):
def __init__(
self,
vllm_config: VllmConfig,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Any
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.v1.core.sched.output import NewRequestData
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
from vllm.v1.worker.gpu.states import RequestState
from vllm.v1.worker.utils import AttentionGroup
class ModelState(ABC):
@abstractmethod
def __init__(
self,
vllm_config: VllmConfig,
model: nn.Module,
encoder_cache: EncoderCache | None,
device: torch.device,
) -> None:
raise NotImplementedError
@abstractmethod
def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
raise NotImplementedError
@abstractmethod
def apply_staged_writes(self) -> None:
raise NotImplementedError
@abstractmethod
def get_mm_embeddings(
self,
scheduled_encoder_inputs: dict[str, list[int]],
input_batch: InputBatch,
req_states: RequestState,
) -> torch.Tensor:
raise NotImplementedError
@abstractmethod
def prepare_inputs(
self, input_batch: InputBatch, req_states: RequestState
) -> dict[str, torch.Tensor | None]:
raise NotImplementedError
@abstractmethod
def prepare_dummy_inputs(
self, num_reqs: int, num_tokens: int
) -> dict[str, torch.Tensor | None]:
raise NotImplementedError
@abstractmethod
def prepare_attn(
self,
input_batch: InputBatch,
block_tables: tuple[torch.Tensor, ...],
slot_mappings: torch.Tensor,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
) -> dict[str, Any]:
raise NotImplementedError
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment