Unverified Commit bc8ad684 authored by Cody Yu's avatar Cody Yu Committed by GitHub
Browse files

[Misc][Refactor] Introduce ExecuteModelData (#4540)

parent 344bf7cd
...@@ -13,7 +13,7 @@ from vllm.distributed import (broadcast_tensor_dict, ...@@ -13,7 +13,7 @@ from vllm.distributed import (broadcast_tensor_dict,
init_distributed_environment) init_distributed_environment)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.worker.cpu_model_runner import CPUModelRunner from vllm.worker.cpu_model_runner import CPUModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase from vllm.worker.worker_base import LoraNotSupportedWorkerBase
...@@ -256,22 +256,24 @@ class CPUWorker(LoraNotSupportedWorkerBase): ...@@ -256,22 +256,24 @@ class CPUWorker(LoraNotSupportedWorkerBase):
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
self, self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None, execute_model_req: Optional[ExecuteModelRequest] = None,
blocks_to_swap_in: Optional[Dict[int, int]] = None,
blocks_to_swap_out: Optional[Dict[int, int]] = None,
blocks_to_copy: Optional[Dict[int, List[int]]] = None,
) -> List[SamplerOutput]: ) -> List[SamplerOutput]:
if execute_model_req is None:
seq_group_metadata_list = None
else:
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
if self.is_driver_worker: if self.is_driver_worker:
assert seq_group_metadata_list is not None assert seq_group_metadata_list is not None
num_seq_groups: int = len(seq_group_metadata_list) num_seq_groups: int = len(seq_group_metadata_list)
assert blocks_to_swap_in is not None assert execute_model_req is not None
assert blocks_to_swap_out is not None blocks_to_copy = execute_model_req.blocks_to_copy
assert blocks_to_copy is not None assert len(execute_model_req.blocks_to_swap_in) == 0
assert len(blocks_to_swap_in) == 0 assert len(execute_model_req.blocks_to_swap_out) == 0
assert len(blocks_to_swap_out) == 0
data: Dict[str, Any] = { data: Dict[str, Any] = {
"num_seq_groups": num_seq_groups, "num_seq_groups": num_seq_groups,
"blocks_to_copy": blocks_to_copy, "blocks_to_copy": execute_model_req.blocks_to_copy,
} }
broadcast_tensor_dict(data, src=0) broadcast_tensor_dict(data, src=0)
else: else:
...@@ -279,7 +281,6 @@ class CPUWorker(LoraNotSupportedWorkerBase): ...@@ -279,7 +281,6 @@ class CPUWorker(LoraNotSupportedWorkerBase):
num_seq_groups = data["num_seq_groups"] num_seq_groups = data["num_seq_groups"]
blocks_to_copy = data["blocks_to_copy"] blocks_to_copy = data["blocks_to_copy"]
assert blocks_to_copy is not None
self.cache_copy(blocks_to_copy) self.cache_copy(blocks_to_copy)
# If there is no input, we don't need to execute the model. # If there is no input, we don't need to execute the model.
......
...@@ -18,7 +18,7 @@ from vllm.distributed.device_communicators.custom_all_reduce import ( ...@@ -18,7 +18,7 @@ from vllm.distributed.device_communicators.custom_all_reduce import (
init_custom_ar) init_custom_ar)
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner from vllm.worker.model_runner import ModelRunner
from vllm.worker.worker_base import WorkerBase from vllm.worker.worker_base import WorkerBase
...@@ -211,19 +211,21 @@ class Worker(WorkerBase): ...@@ -211,19 +211,21 @@ class Worker(WorkerBase):
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
self, self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None, execute_model_req: Optional[ExecuteModelRequest] = None
blocks_to_swap_in: Optional[Dict[int, int]] = None,
blocks_to_swap_out: Optional[Dict[int, int]] = None,
blocks_to_copy: Optional[Dict[int, List[int]]] = None,
num_lookahead_slots: int = 0,
) -> List[SamplerOutput]: ) -> List[SamplerOutput]:
if execute_model_req is None:
seq_group_metadata_list = None
else:
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
if self.is_driver_worker: if self.is_driver_worker:
assert seq_group_metadata_list is not None assert seq_group_metadata_list is not None
assert execute_model_req is not None
num_seq_groups = len(seq_group_metadata_list) num_seq_groups = len(seq_group_metadata_list)
assert blocks_to_swap_in is not None blocks_to_swap_in = execute_model_req.blocks_to_swap_in
assert blocks_to_swap_out is not None blocks_to_swap_out = execute_model_req.blocks_to_swap_out
assert blocks_to_copy is not None blocks_to_copy = execute_model_req.blocks_to_copy
data: Dict[str, Any] = { data: Dict[str, Any] = {
"num_seq_groups": num_seq_groups, "num_seq_groups": num_seq_groups,
"blocks_to_swap_in": blocks_to_swap_in, "blocks_to_swap_in": blocks_to_swap_in,
...@@ -238,9 +240,6 @@ class Worker(WorkerBase): ...@@ -238,9 +240,6 @@ class Worker(WorkerBase):
blocks_to_swap_out = data["blocks_to_swap_out"] blocks_to_swap_out = data["blocks_to_swap_out"]
blocks_to_copy = data["blocks_to_copy"] blocks_to_copy = data["blocks_to_copy"]
assert blocks_to_swap_in is not None
assert blocks_to_swap_out is not None
assert blocks_to_copy is not None
self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy) self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
# If there is no input, we don't need to execute the model. # If there is no input, we don't need to execute the model.
......
...@@ -5,7 +5,7 @@ from typing import Dict, List, Set, Tuple ...@@ -5,7 +5,7 @@ from typing import Dict, List, Set, Tuple
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (enable_trace_function_call_for_thread, from vllm.utils import (enable_trace_function_call_for_thread,
update_environment_variables) update_environment_variables)
...@@ -48,10 +48,8 @@ class WorkerBase(ABC): ...@@ -48,10 +48,8 @@ class WorkerBase(ABC):
@abstractmethod @abstractmethod
def execute_model( def execute_model(
self, seq_group_metadata_list: List[SequenceGroupMetadata], self,
blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
int],
blocks_to_copy: Dict[int, List[int]]) -> List[SamplerOutput]:
"""Executes at least one model step on the given sequences, unless no """Executes at least one model step on the given sequences, unless no
sequences are provided.""" sequences are provided."""
raise NotImplementedError raise NotImplementedError
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment