Unverified Commit d8839ef7 authored by Xinyu Chen's avatar Xinyu Chen Committed by GitHub
Browse files

[XPU] Enable ModelRunnerV2 on XPU (#36078)


Signed-off-by: default avatarXinyu Chen <xinyu1.chen@intel.com>
parent e998fa76
...@@ -8,6 +8,9 @@ import torch ...@@ -8,6 +8,9 @@ import torch
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.torch_utils import supports_xpu_graph from vllm.utils.torch_utils import supports_xpu_graph
from vllm.v1.worker.gpu.model_runner import (
GPUModelRunner as GPUModelRunnerV2,
)
from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.gpu_model_runner import GPUModelRunner
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -30,6 +33,18 @@ class XPUModelRunner(GPUModelRunner): ...@@ -30,6 +33,18 @@ class XPUModelRunner(GPUModelRunner):
self.cascade_attn_enabled = False self.cascade_attn_enabled = False
class XPUModelRunnerV2(GPUModelRunnerV2):
"""A model runner for XPU devices."""
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
):
with _torch_cuda_wrapper():
super().__init__(vllm_config, device)
@contextmanager @contextmanager
def _torch_cuda_wrapper(): def _torch_cuda_wrapper():
try: try:
...@@ -39,9 +54,12 @@ def _torch_cuda_wrapper(): ...@@ -39,9 +54,12 @@ def _torch_cuda_wrapper():
torch.cuda.current_stream = torch.xpu.current_stream torch.cuda.current_stream = torch.xpu.current_stream
torch.cuda.stream = torch.xpu.stream torch.cuda.stream = torch.xpu.stream
torch.cuda.mem_get_info = torch.xpu.mem_get_info torch.cuda.mem_get_info = torch.xpu.mem_get_info
torch.cuda.Event = torch.Event
torch.cuda.set_stream = torch.xpu.set_stream
if supports_xpu_graph(): if supports_xpu_graph():
torch.cuda.graph = torch.xpu.graph torch.cuda.graph = torch.xpu.graph
torch.cuda.CUDAGraph = torch.xpu.XPUGraph torch.cuda.CUDAGraph = torch.xpu.XPUGraph
torch.cuda.graph_pool_handle = torch.xpu.graph_pool_handle
yield yield
finally: finally:
pass pass
...@@ -15,7 +15,7 @@ from vllm.utils.torch_utils import set_random_seed ...@@ -15,7 +15,7 @@ from vllm.utils.torch_utils import set_random_seed
from vllm.v1.utils import report_usage_stats from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment
from vllm.v1.worker.workspace import init_workspace_manager from vllm.v1.worker.workspace import init_workspace_manager
from vllm.v1.worker.xpu_model_runner import XPUModelRunner from vllm.v1.worker.xpu_model_runner import XPUModelRunner, XPUModelRunnerV2
from .utils import request_memory from .utils import request_memory
...@@ -105,7 +105,8 @@ class XPUWorker(Worker): ...@@ -105,7 +105,8 @@ class XPUWorker(Worker):
init_workspace_manager(self.device, num_ubatches) init_workspace_manager(self.device, num_ubatches)
# Construct the model runner # Construct the model runner
self.model_runner = XPUModelRunner( # type: ignore model_runner = XPUModelRunnerV2 if self.use_v2_model_runner else XPUModelRunner
self.model_runner = model_runner( # type: ignore
self.vllm_config, self.device self.vllm_config, self.device
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment