Unverified Commit e969a169 authored by Xinyu Chen's avatar Xinyu Chen Committed by GitHub
Browse files

support view_from_cpu_tensor on XPU (#33868)


Signed-off-by: default avatarXinyu Chen <xinyu1.chen@intel.com>
parent 6d8d34be
......@@ -4,7 +4,7 @@ import pytest
import torch
from vllm.utils.platform_utils import is_uva_available
from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor
from vllm.utils.torch_utils import get_accelerator_view_from_cpu_tensor
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
......@@ -14,7 +14,7 @@ CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 e
def test_cpu_write(device):
torch.set_default_device(device)
cpu_tensor = torch.zeros(10, 10, device="cpu", pin_memory=True, dtype=torch.int32)
cuda_view = get_cuda_view_from_cpu_tensor(cpu_tensor)
cuda_view = get_accelerator_view_from_cpu_tensor(cpu_tensor)
assert cuda_view.device.type == "cuda"
assert cuda_view[0, 0] == 0
......@@ -36,7 +36,7 @@ def test_cpu_write(device):
def test_gpu_write(device):
torch.set_default_device(device)
cpu_tensor = torch.zeros(10, 10, device="cpu", pin_memory=True, dtype=torch.int32)
cuda_view = get_cuda_view_from_cpu_tensor(cpu_tensor)
cuda_view = get_accelerator_view_from_cpu_tensor(cpu_tensor)
assert cuda_view.device.type == "cuda"
assert cuda_view[0, 0] == 0
......
......@@ -36,7 +36,7 @@ from vllm.utils.platform_utils import (
)
from vllm.utils.torch_utils import (
direct_register_custom_op,
get_cuda_view_from_cpu_tensor,
get_accelerator_view_from_cpu_tensor,
)
logger = init_logger(__name__)
......@@ -663,7 +663,7 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
else:
# keep the cpu data alive
p._vllm_offloaded_cpu_data = cpu_data
p.data = get_cuda_view_from_cpu_tensor(cpu_data)
p.data = get_accelerator_view_from_cpu_tensor(cpu_data)
_CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
offloaded_parameters = True
......
......@@ -674,11 +674,15 @@ def weak_ref_tensors(
raise ValueError("Invalid type for tensors")
def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor:
def get_accelerator_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor:
"""
Get a CUDA view of a CPU tensor using Unified Virtual Addressing (UVA).
Get an accelerator view of a CPU tensor using Unified Virtual Addressing (UVA).
"""
assert cpu_tensor.is_pinned(), "CPU tensor must be pinned"
from vllm.platforms import current_platform
if current_platform.is_xpu():
return torch.ops._C.get_xpu_view_from_cpu_tensor(cpu_tensor)
return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor)
......
......@@ -9,7 +9,7 @@ import torch
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import next_power_of_2
from vllm.utils.platform_utils import is_uva_available
from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor
from vllm.utils.torch_utils import get_accelerator_view_from_cpu_tensor
def async_copy_to_gpu(
......@@ -38,7 +38,7 @@ class UvaBuffer:
raise RuntimeError("UVA is not available")
self.cpu = torch.zeros(size, dtype=dtype, device="cpu", pin_memory=True)
self.np = self.cpu.numpy()
self.uva = get_cuda_view_from_cpu_tensor(self.cpu)
self.uva = get_accelerator_view_from_cpu_tensor(self.cpu)
class UvaBufferPool:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment