Unverified Commit b12e87f9 authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[platforms] enable platform plugins (#11602)


Signed-off-by: default avataryoukaichao <youkaichao@gmail.com>
parent 5dbf8545
......@@ -12,7 +12,6 @@ from torch import is_tensor
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
if TYPE_CHECKING:
......@@ -265,13 +264,13 @@ class ModelRunnerBase(ABC, Generic[T]):
"""
raise NotImplementedError
@current_platform.inference_mode()
def execute_model(
self,
model_input: T,
kv_caches: Optional[List[torch.Tensor]],
intermediate_tensors: Optional[IntermediateTensors],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
**kwargs,
) -> Optional[List[SamplerOutput]]:
"""
Execute the model on the given input.
......
......@@ -544,6 +544,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
model_input.record_step_event(current_stream)
if get_pp_group().is_last_rank and self.is_driver_worker:
assert isinstance(output, list)
assert len(
output
) == 1, "MultiStepModelRunner requires single-step base_models"
......
......@@ -11,7 +11,6 @@ from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.utils import (enable_trace_function_call_for_thread,
resolve_obj_by_qualname, update_environment_variables)
......@@ -44,6 +43,8 @@ class WorkerBase(ABC):
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
self.kv_transfer_config = vllm_config.kv_transfer_config
from vllm.platforms import current_platform
self.current_platform = current_platform
@abstractmethod
def init_device(self) -> None:
......@@ -74,17 +75,17 @@ class WorkerBase(ABC):
"""
raise NotImplementedError
@current_platform.inference_mode()
def start_worker_execution_loop(self) -> None:
"""Execute model loop in parallel worker.
You can stop the loop by executing a driver worker with an empty output.
See `stop_remote_worker_execution_loop` for more details.
"""
while True:
output = self.execute_model(execute_model_req=None)
if output is None:
return None
with self.current_platform.inference_mode():
while True:
output = self.execute_model(execute_model_req=None)
if output is None:
return None
@abstractmethod
def execute_model(
......@@ -352,6 +353,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
model_execute_time = time.perf_counter() - start_time
if not get_pp_group().is_last_rank:
# output is IntermediateTensors
assert isinstance(output, IntermediateTensors)
if (self.observability_config is not None
and self.observability_config.collect_model_execute_time):
output.tensors["model_execute_time"] = torch.tensor(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment