Unverified Commit 6d86fde0 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Model Runner V2] Remove UvaBufferPool for cpu->gpu copy (#33055)


Signed-off-by: default avatarWoosuk Kwon <woosuk@inferact.ai>
Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: default avatarNick Hill <nhill@redhat.com>
parent 510ed1e8
...@@ -11,6 +11,26 @@ from vllm.utils.platform_utils import is_uva_available ...@@ -11,6 +11,26 @@ from vllm.utils.platform_utils import is_uva_available
from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor
def async_copy_to_gpu(
x: torch.Tensor | np.ndarray,
out: torch.Tensor | None = None,
device: torch.device | None = None,
) -> torch.Tensor:
if isinstance(x, np.ndarray):
x = torch.from_numpy(x)
assert x.is_cpu
assert not x.is_pinned()
if out is None:
assert device is not None
out = torch.empty_like(x, device=device)
# CPU-to-CPU copy
tmp = x.pin_memory()
# CPU-to-GPU copy
return out.copy_(tmp, non_blocking=True)
class UvaBuffer: class UvaBuffer:
def __init__(self, size: int | Sequence[int], dtype: torch.dtype): def __init__(self, size: int | Sequence[int], dtype: torch.dtype):
if not is_uva_available(): if not is_uva_available():
......
...@@ -6,7 +6,6 @@ import torch ...@@ -6,7 +6,6 @@ import torch
from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItem from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItem
from vllm.multimodal.utils import group_mm_kwargs_by_modality from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.v1.worker.gpu.buffer_utils import UvaBufferPool
from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs
...@@ -32,8 +31,6 @@ class EncoderRunner: ...@@ -32,8 +31,6 @@ class EncoderRunner:
self.req_id_to_mm_features: dict[str, list[MultiModalFeatureSpec]] = {} self.req_id_to_mm_features: dict[str, list[MultiModalFeatureSpec]] = {}
self.encoder_cache: dict[str, torch.Tensor] = {} self.encoder_cache: dict[str, torch.Tensor] = {}
self.tmp_is_mm_embed = UvaBufferPool(max_num_tokens, torch.bool)
def add_request(self, req_id: str, mm_features: list[MultiModalFeatureSpec]): def add_request(self, req_id: str, mm_features: list[MultiModalFeatureSpec]):
self.req_id_to_mm_features[req_id] = mm_features self.req_id_to_mm_features[req_id] = mm_features
...@@ -114,7 +111,7 @@ class EncoderRunner: ...@@ -114,7 +111,7 @@ class EncoderRunner:
total_num_scheduled_tokens, total_num_scheduled_tokens,
dtype=torch.bool, dtype=torch.bool,
device="cpu", device="cpu",
pin_memory=False, pin_memory=True,
) )
for i, req_id in enumerate(req_ids): for i, req_id in enumerate(req_ids):
if not is_prefilling[i]: if not is_prefilling[i]:
...@@ -163,7 +160,7 @@ class EncoderRunner: ...@@ -163,7 +160,7 @@ class EncoderRunner:
mm_embeds.append(mm_embeds_item) mm_embeds.append(mm_embeds_item)
# Copy the is_mm_embed tensor to the GPU. # Copy the is_mm_embed tensor to the GPU.
is_mm_embed = self.tmp_is_mm_embed.copy_to_gpu(is_mm_embed) is_mm_embed = is_mm_embed.to(device=self.device, non_blocking=True)
return mm_embeds, is_mm_embed return mm_embeds, is_mm_embed
@torch.inference_mode() @torch.inference_mode()
......
...@@ -30,7 +30,7 @@ from vllm.v1.worker.gpu.attn_utils import ( ...@@ -30,7 +30,7 @@ from vllm.v1.worker.gpu.attn_utils import (
init_kv_cache, init_kv_cache,
) )
from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.buffer_utils import UvaBufferPool from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
from vllm.v1.worker.gpu.dp_utils import ( from vllm.v1.worker.gpu.dp_utils import (
get_cudagraph_and_dp_padding, get_cudagraph_and_dp_padding,
...@@ -172,11 +172,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -172,11 +172,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# LoRA-related workers. # LoRA-related workers.
self.lora_state = LoraState(max_num_reqs=self.max_num_reqs) self.lora_state = LoraState(max_num_reqs=self.max_num_reqs)
# Buffers for CPU-to-GPU copies.
self.tmp_idx_mapping = UvaBufferPool(self.max_num_reqs, torch.int32)
self.tmp_cu_num_logits = UvaBufferPool(self.max_num_reqs + 1, torch.int32)
self.tmp_query_start_loc = UvaBufferPool(self.max_num_reqs + 1, torch.int32)
self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR
def update_max_model_len(self, max_model_len: int) -> None: def update_max_model_len(self, max_model_len: int) -> None:
...@@ -518,7 +513,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -518,7 +513,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.req_states.req_id_to_index[req_id] for req_id in req_ids self.req_states.req_id_to_index[req_id] for req_id in req_ids
] ]
idx_mapping_np = np.array(idx_mapping_list, dtype=np.int32) idx_mapping_np = np.array(idx_mapping_list, dtype=np.int32)
idx_mapping = self.tmp_idx_mapping.copy_to_gpu(idx_mapping_np) idx_mapping = async_copy_to_gpu(idx_mapping_np, device=self.device)
# Get the number of draft tokens for each request. # Get the number of draft tokens for each request.
if not scheduler_output.scheduled_spec_decode_tokens: if not scheduler_output.scheduled_spec_decode_tokens:
...@@ -546,7 +541,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -546,7 +541,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
cu_num_logits_np = np.empty(num_reqs + 1, dtype=np.int32) cu_num_logits_np = np.empty(num_reqs + 1, dtype=np.int32)
cu_num_logits_np[0] = 0 cu_num_logits_np[0] = 0
np.cumsum(num_logits, out=cu_num_logits_np[1:]) np.cumsum(num_logits, out=cu_num_logits_np[1:])
cu_num_logits = self.tmp_cu_num_logits.copy_to_gpu(cu_num_logits_np) cu_num_logits = async_copy_to_gpu(cu_num_logits_np, device=self.device)
expanded_idx_mapping = expand_idx_mapping( expanded_idx_mapping = expand_idx_mapping(
idx_mapping, idx_mapping,
...@@ -565,10 +560,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -565,10 +560,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Pad for full CUDA graph mode. # Pad for full CUDA graph mode.
# Some attention backends like FA3 require query_start_loc to be non-decreasing. # Some attention backends like FA3 require query_start_loc to be non-decreasing.
query_start_loc_np[num_reqs + 1 :] = num_tokens query_start_loc_np[num_reqs + 1 :] = num_tokens
self.tmp_query_start_loc.copy_to_gpu( async_copy_to_gpu(query_start_loc_np, out=self.input_buffers.query_start_loc)
query_start_loc_np,
out=self.input_buffers.query_start_loc,
)
query_start_loc_np = query_start_loc_np[: num_reqs + 1] query_start_loc_np = query_start_loc_np[: num_reqs + 1]
query_start_loc_cpu = torch.from_numpy(query_start_loc_np) query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1] query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment