Unverified commit b12e87f9, authored by youkaichao, committed by GitHub
Browse files

[platforms] enable platform plugins (#11602)


Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 5dbf8545
...@@ -12,7 +12,6 @@ from torch import is_tensor ...@@ -12,7 +12,6 @@ from torch import is_tensor
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -265,13 +264,13 @@ class ModelRunnerBase(ABC, Generic[T]): ...@@ -265,13 +264,13 @@ class ModelRunnerBase(ABC, Generic[T]):
""" """
raise NotImplementedError raise NotImplementedError
@current_platform.inference_mode()
def execute_model( def execute_model(
self, self,
model_input: T, model_input: T,
kv_caches: Optional[List[torch.Tensor]], kv_caches: Optional[List[torch.Tensor]],
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1, num_steps: int = 1,
**kwargs,
) -> Optional[List[SamplerOutput]]: ) -> Optional[List[SamplerOutput]]:
""" """
Execute the model on the given input. Execute the model on the given input.
......
...@@ -544,6 +544,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]): ...@@ -544,6 +544,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
model_input.record_step_event(current_stream) model_input.record_step_event(current_stream)
if get_pp_group().is_last_rank and self.is_driver_worker: if get_pp_group().is_last_rank and self.is_driver_worker:
assert isinstance(output, list)
assert len( assert len(
output output
) == 1, "MultiStepModelRunner requires single-step base_models" ) == 1, "MultiStepModelRunner requires single-step base_models"
......
...@@ -11,7 +11,6 @@ from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group ...@@ -11,7 +11,6 @@ from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.utils import (enable_trace_function_call_for_thread, from vllm.utils import (enable_trace_function_call_for_thread,
resolve_obj_by_qualname, update_environment_variables) resolve_obj_by_qualname, update_environment_variables)
...@@ -44,6 +43,8 @@ class WorkerBase(ABC): ...@@ -44,6 +43,8 @@ class WorkerBase(ABC):
self.prompt_adapter_config = vllm_config.prompt_adapter_config self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config self.observability_config = vllm_config.observability_config
self.kv_transfer_config = vllm_config.kv_transfer_config self.kv_transfer_config = vllm_config.kv_transfer_config
from vllm.platforms import current_platform
self.current_platform = current_platform
@abstractmethod @abstractmethod
def init_device(self) -> None: def init_device(self) -> None:
...@@ -74,17 +75,17 @@ class WorkerBase(ABC): ...@@ -74,17 +75,17 @@ class WorkerBase(ABC):
""" """
raise NotImplementedError raise NotImplementedError
@current_platform.inference_mode()
def start_worker_execution_loop(self) -> None: def start_worker_execution_loop(self) -> None:
"""Execute model loop in parallel worker. """Execute model loop in parallel worker.
You can stop the loop by executing a driver worker with an empty output. You can stop the loop by executing a driver worker with an empty output.
See `stop_remote_worker_execution_loop` for more details. See `stop_remote_worker_execution_loop` for more details.
""" """
while True: with self.current_platform.inference_mode():
output = self.execute_model(execute_model_req=None) while True:
if output is None: output = self.execute_model(execute_model_req=None)
return None if output is None:
return None
@abstractmethod @abstractmethod
def execute_model( def execute_model(
...@@ -352,6 +353,7 @@ class LocalOrDistributedWorkerBase(WorkerBase): ...@@ -352,6 +353,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
model_execute_time = time.perf_counter() - start_time model_execute_time = time.perf_counter() - start_time
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
# output is IntermediateTensors # output is IntermediateTensors
assert isinstance(output, IntermediateTensors)
if (self.observability_config is not None if (self.observability_config is not None
and self.observability_config.collect_model_execute_time): and self.observability_config.collect_model_execute_time):
output.tensors["model_execute_time"] = torch.tensor( output.tensors["model_execute_time"] = torch.tensor(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment