Unverified commit bc8ad684, authored by Cody Yu, committed by GitHub
Browse files

[Misc][Refactor] Introduce ExecuteModelData (#4540)

parent 344bf7cd
......@@ -13,7 +13,7 @@ from vllm.distributed import (broadcast_tensor_dict,
init_distributed_environment)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.worker.cpu_model_runner import CPUModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
......@@ -256,22 +256,24 @@ class CPUWorker(LoraNotSupportedWorkerBase):
@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
blocks_to_swap_in: Optional[Dict[int, int]] = None,
blocks_to_swap_out: Optional[Dict[int, int]] = None,
blocks_to_copy: Optional[Dict[int, List[int]]] = None,
execute_model_req: Optional[ExecuteModelRequest] = None,
) -> List[SamplerOutput]:
if execute_model_req is None:
seq_group_metadata_list = None
else:
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
if self.is_driver_worker:
assert seq_group_metadata_list is not None
num_seq_groups: int = len(seq_group_metadata_list)
assert blocks_to_swap_in is not None
assert blocks_to_swap_out is not None
assert blocks_to_copy is not None
assert len(blocks_to_swap_in) == 0
assert len(blocks_to_swap_out) == 0
assert execute_model_req is not None
blocks_to_copy = execute_model_req.blocks_to_copy
assert len(execute_model_req.blocks_to_swap_in) == 0
assert len(execute_model_req.blocks_to_swap_out) == 0
data: Dict[str, Any] = {
"num_seq_groups": num_seq_groups,
"blocks_to_copy": blocks_to_copy,
"blocks_to_copy": execute_model_req.blocks_to_copy,
}
broadcast_tensor_dict(data, src=0)
else:
......@@ -279,7 +281,6 @@ class CPUWorker(LoraNotSupportedWorkerBase):
num_seq_groups = data["num_seq_groups"]
blocks_to_copy = data["blocks_to_copy"]
assert blocks_to_copy is not None
self.cache_copy(blocks_to_copy)
# If there is no input, we don't need to execute the model.
......
......@@ -18,7 +18,7 @@ from vllm.distributed.device_communicators.custom_all_reduce import (
init_custom_ar)
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner
from vllm.worker.worker_base import WorkerBase
......@@ -211,19 +211,21 @@ class Worker(WorkerBase):
@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
blocks_to_swap_in: Optional[Dict[int, int]] = None,
blocks_to_swap_out: Optional[Dict[int, int]] = None,
blocks_to_copy: Optional[Dict[int, List[int]]] = None,
num_lookahead_slots: int = 0,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
if execute_model_req is None:
seq_group_metadata_list = None
else:
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
if self.is_driver_worker:
assert seq_group_metadata_list is not None
assert execute_model_req is not None
num_seq_groups = len(seq_group_metadata_list)
assert blocks_to_swap_in is not None
assert blocks_to_swap_out is not None
assert blocks_to_copy is not None
blocks_to_swap_in = execute_model_req.blocks_to_swap_in
blocks_to_swap_out = execute_model_req.blocks_to_swap_out
blocks_to_copy = execute_model_req.blocks_to_copy
data: Dict[str, Any] = {
"num_seq_groups": num_seq_groups,
"blocks_to_swap_in": blocks_to_swap_in,
......@@ -238,9 +240,6 @@ class Worker(WorkerBase):
blocks_to_swap_out = data["blocks_to_swap_out"]
blocks_to_copy = data["blocks_to_copy"]
assert blocks_to_swap_in is not None
assert blocks_to_swap_out is not None
assert blocks_to_copy is not None
self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
# If there is no input, we don't need to execute the model.
......
......@@ -5,7 +5,7 @@ from typing import Dict, List, Set, Tuple
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (enable_trace_function_call_for_thread,
update_environment_variables)
......@@ -48,10 +48,8 @@ class WorkerBase(ABC):
@abstractmethod
def execute_model(
self, seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int,
int],
blocks_to_copy: Dict[int, List[int]]) -> List[SamplerOutput]:
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
"""Executes at least one model step on the given sequences, unless no
sequences are provided."""
raise NotImplementedError
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment