Unverified Commit 35d44b45 authored by Xinyu Chen's avatar Xinyu Chen Committed by GitHub
Browse files

[XPU]Support CUDAGraph on XPU Platform (#34482)


Signed-off-by: default avatarXinyu Chen <xinyu1.chen@intel.com>
Co-authored-by: default avatarchzhang <chaojun.zhang@intel.com>
Co-authored-by: default avatarzhenwei-intel <zhenwei.liu@intel.com>
Co-authored-by: default avatarKunshang Ji <kunshang.ji@intel.com>
parent 8ad54a99
...@@ -13,6 +13,7 @@ import vllm_xpu_kernels._moe_C # noqa ...@@ -13,6 +13,7 @@ import vllm_xpu_kernels._moe_C # noqa
import vllm_xpu_kernels._xpu_C # noqa import vllm_xpu_kernels._xpu_C # noqa
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.torch_utils import supports_xpu_graph
from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum
from .interface import DeviceCapability, Platform, PlatformEnum from .interface import DeviceCapability, Platform, PlatformEnum
...@@ -151,10 +152,15 @@ class XPUPlatform(Platform): ...@@ -151,10 +152,15 @@ class XPUPlatform(Platform):
def inference_mode(cls): def inference_mode(cls):
return torch.no_grad() return torch.no_grad()
@classmethod
def get_static_graph_wrapper_cls(cls) -> str:
return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
@classmethod @classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None: def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
model_config = vllm_config.model_config model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
# in V1(or with chunked prefill) block_size is 64 # in V1(or with chunked prefill) block_size is 64
if cache_config and cache_config.block_size is None: if cache_config and cache_config.block_size is None:
cache_config.block_size = 64 cache_config.block_size = 64
...@@ -166,8 +172,31 @@ class XPUPlatform(Platform): ...@@ -166,8 +172,31 @@ class XPUPlatform(Platform):
if compilation_config.compile_sizes is None: if compilation_config.compile_sizes is None:
compilation_config.compile_sizes = [] compilation_config.compile_sizes = []
assert compilation_config.cudagraph_mode == CUDAGraphMode.NONE, ( attention_config = vllm_config.attention_config
"CUDA graph mode should be NONE on XPU" if attention_config.backend is None:
attention_config.backend = AttentionBackendEnum.FLASH_ATTN
if not supports_xpu_graph():
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
logger.warning(
"XPU Graph is not supported in the current PyTorch version, "
"disabling cudagraph_mode."
)
elif parallel_config.world_size_across_dp > 1:
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
logger.warning(
"XPU Graph doesn't support capture communication ops, "
"disabling cudagraph_mode."
)
else:
if (
attention_config.backend == AttentionBackendEnum.FLASH_ATTN
and compilation_config.cudagraph_mode
not in {CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE}
):
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
logger.warning(
"FMHA sycl-tla kernels cannot be captured with XPU graphs, "
"falling back to PIECEWISE graph mode on XPU platform."
) )
if vllm_config.lora_config is not None: if vllm_config.lora_config is not None:
...@@ -201,7 +230,7 @@ class XPUPlatform(Platform): ...@@ -201,7 +230,7 @@ class XPUPlatform(Platform):
@classmethod @classmethod
def support_static_graph_mode(cls) -> bool: def support_static_graph_mode(cls) -> bool:
return False return True
@classmethod @classmethod
def is_pin_memory_available(cls): def is_pin_memory_available(cls):
......
...@@ -745,6 +745,11 @@ def supports_xccl() -> bool: ...@@ -745,6 +745,11 @@ def supports_xccl() -> bool:
return torch.distributed.is_xccl_available() return torch.distributed.is_xccl_available()
# Supports XPU Graph with PyTorch versions >= 2.11.0.dev for XPU platform
def supports_xpu_graph() -> bool:
return is_torch_equal_or_newer("2.11.0.dev")
# create a library to hold the custom op # create a library to hold the custom op
vllm_lib = Library("vllm", "FRAGMENT") # noqa vllm_lib = Library("vllm", "FRAGMENT") # noqa
......
...@@ -7,6 +7,7 @@ import torch ...@@ -7,6 +7,7 @@ import torch
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.torch_utils import supports_xpu_graph
from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.gpu_model_runner import GPUModelRunner
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -40,6 +41,12 @@ def _torch_cuda_wrapper(): ...@@ -40,6 +41,12 @@ def _torch_cuda_wrapper():
torch.cuda.default_stream = torch.xpu.current_stream torch.cuda.default_stream = torch.xpu.current_stream
torch.cuda.current_stream = torch.xpu.current_stream torch.cuda.current_stream = torch.xpu.current_stream
torch.cuda.stream = torch.xpu.stream torch.cuda.stream = torch.xpu.stream
torch.cuda.mem_get_info = torch.xpu.mem_get_info
torch.cuda.synchronize = torch.xpu.synchronize
if supports_xpu_graph():
torch.cuda.graph = torch.xpu.graph
torch.cuda.CUDAGraph = torch.xpu.XPUGraph
torch.cuda.empty_cache = torch.xpu.empty_cache
yield yield
finally: finally:
pass pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment