Commit 025a32f9 (unverified), authored by Woosuk Kwon, committed by GitHub
Browse files

[Model Runner V2] Remove async barrier (#32083)


Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent 19504ac0
...@@ -6,9 +6,8 @@ import torch ...@@ -6,9 +6,8 @@ import torch
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import is_uva_available
from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor
from vllm.v1.attention.backends.utils import PAD_SLOT_ID from vllm.v1.attention.backends.utils import PAD_SLOT_ID
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
class BlockTables: class BlockTables:
...@@ -26,19 +25,16 @@ class BlockTables: ...@@ -26,19 +25,16 @@ class BlockTables:
self.max_model_len = max_model_len self.max_model_len = max_model_len
self.device = device self.device = device
if not is_uva_available():
raise RuntimeError("UVA is not available")
self.num_kv_cache_groups = len(self.block_sizes) self.num_kv_cache_groups = len(self.block_sizes)
# num_kv_cache_groups x [max_num_reqs, max_num_blocks] # num_kv_cache_groups x [max_num_reqs, max_num_blocks]
self.block_tables: list[UvaBuffer] = [] self.block_tables: list[StagedWriteTensor] = []
for i in range(self.num_kv_cache_groups): for i in range(self.num_kv_cache_groups):
block_size = self.block_sizes[i] block_size = self.block_sizes[i]
max_num_blocks = cdiv(self.max_model_len, block_size) max_num_blocks = cdiv(self.max_model_len, block_size)
block_table = UvaBuffer( block_table = StagedWriteTensor(
self.max_num_reqs, (self.max_num_reqs, max_num_blocks),
max_num_blocks,
dtype=torch.int32, dtype=torch.int32,
device=device,
) )
self.block_tables.append(block_table) self.block_tables.append(block_table)
self.block_table_ptrs = self._make_ptr_tensor( self.block_table_ptrs = self._make_ptr_tensor(
...@@ -53,9 +49,8 @@ class BlockTables: ...@@ -53,9 +49,8 @@ class BlockTables:
self.block_sizes_tensor = torch.tensor( self.block_sizes_tensor = torch.tensor(
self.block_sizes, dtype=torch.int32, device=self.device self.block_sizes, dtype=torch.int32, device=self.device
) )
self.num_blocks = UvaBuffer( self.num_blocks = UvaBackedTensor(
self.num_kv_cache_groups, (self.num_kv_cache_groups, self.max_num_reqs),
self.max_num_reqs,
dtype=torch.int32, dtype=torch.int32,
) )
...@@ -75,13 +70,11 @@ class BlockTables: ...@@ -75,13 +70,11 @@ class BlockTables:
def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor: def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
# NOTE(woosuk): Use uint64 instead of int64 to cover all possible addresses. # NOTE(woosuk): Use uint64 instead of int64 to cover all possible addresses.
ptrs_tensor_cpu = torch.tensor( return torch.tensor(
[t.data_ptr() for t in x], [t.data_ptr() for t in x],
dtype=torch.uint64, dtype=torch.uint64,
device="cpu", device=self.device,
pin_memory=True,
) )
return ptrs_tensor_cpu.to(self.device, non_blocking=True)
def append_block_ids( def append_block_ids(
self, self,
...@@ -90,19 +83,17 @@ class BlockTables: ...@@ -90,19 +83,17 @@ class BlockTables:
overwrite: bool, overwrite: bool,
) -> None: ) -> None:
for i in range(self.num_kv_cache_groups): for i in range(self.num_kv_cache_groups):
start = self.num_blocks.np[i, req_index] if not overwrite else 0
block_ids = new_block_ids[i] block_ids = new_block_ids[i]
num_new_blocks = len(block_ids) self.block_tables[i].stage_write(req_index, start, block_ids)
if num_new_blocks == 0: self.num_blocks.np[i, req_index] = start + len(block_ids)
continue
# TODO(woosuk): Too many Numpy invocations. Optimize this. def apply_staged_writes(self) -> None:
start = self.num_blocks.np[i, req_index] if not overwrite else 0 # TODO(woosuk): This can be inefficient since it launches one kernel per
end = start + num_new_blocks # block table. Implement a kernel to handle all block tables at once.
if num_new_blocks == 1: for block_table in self.block_tables:
self.block_tables[i].np[req_index, start] = block_ids[0] block_table.apply_write()
else: self.num_blocks.copy_to_uva()
self.block_tables[i].np[req_index, start:end] = block_ids
self.num_blocks.np[i, req_index] = end
def gather_block_tables( def gather_block_tables(
self, self,
...@@ -229,10 +220,3 @@ def _load_ptr(ptr_to_ptr, elem_dtype): ...@@ -229,10 +220,3 @@ def _load_ptr(ptr_to_ptr, elem_dtype):
ptr = tl.load(ptr_to_ptr) ptr = tl.load(ptr_to_ptr)
ptr = tl.cast(ptr, tl.pointer_type(elem_dtype)) ptr = tl.cast(ptr, tl.pointer_type(elem_dtype))
return tl.multiple_of(ptr, 16) return tl.multiple_of(ptr, 16)
class UvaBuffer:
    """Zero-initialized pinned CPU tensor with NumPy and CUDA-side handles.

    Attributes (all set in ``__init__``):
        cpu: the pinned CPU tensor of the requested size and dtype.
        np: ``cpu.numpy()``.
        gpu: the result of ``get_cuda_view_from_cpu_tensor(cpu)``.
    """

    def __init__(self, *size, dtype: torch.dtype):
        pinned = torch.zeros(*size, dtype=dtype, device="cpu", pin_memory=True)
        self.cpu = pinned
        self.np = pinned.numpy()
        self.gpu = get_cuda_view_from_cpu_tensor(pinned)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
import numpy as np
import torch
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import next_power_of_2
from vllm.utils.platform_utils import is_uva_available
from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor
class UvaBuffer:
    """Zero-initialized pinned CPU tensor with NumPy and UVA handles.

    Attributes (all set in ``__init__``):
        cpu: the pinned CPU tensor of the requested size and dtype.
        np: ``cpu.numpy()``.
        uva: the result of ``get_cuda_view_from_cpu_tensor(cpu)``.

    Raises:
        RuntimeError: if ``is_uva_available()`` returns a falsy value.
    """

    def __init__(self, size: int | Sequence[int], dtype: torch.dtype):
        # Fail before allocating anything when UVA cannot be used.
        if not is_uva_available():
            raise RuntimeError("UVA is not available")

        pinned = torch.zeros(size, dtype=dtype, device="cpu", pin_memory=True)
        self.cpu = pinned
        self.np = pinned.numpy()
        self.uva = get_cuda_view_from_cpu_tensor(pinned)
class UvaBufferPool:
def __init__(
self,
size: int | Sequence[int],
dtype: torch.dtype,
max_concurrency: int = 2,
):
self.size = size
self.dtype = dtype
self.max_concurrency = max_concurrency
# UVA buffers for concurrency
self._uva_bufs = [UvaBuffer(size, dtype) for _ in range(max_concurrency)]
# Current buffer index
self._curr = 0
def copy_to_uva(self, x: torch.Tensor | np.ndarray | list) -> torch.Tensor:
# Round robin to the next buffer.
self._curr = (self._curr + 1) % self.max_concurrency
buf = self._uva_bufs[self._curr]
# CPU-to-CPU copy
dst = buf.cpu if isinstance(x, torch.Tensor) else buf.np
n = len(x)
dst[:n] = x
return buf.uva[:n]
def copy_to_gpu(
self,
x: torch.Tensor | np.ndarray,
out: torch.Tensor | None = None,
) -> torch.Tensor:
uva = self.copy_to_uva(x)
if out is None:
# CPU-to-GPU copy
return uva.clone()
# CPU-to-GPU copy
return out.copy_(uva, non_blocking=True)
class UvaBackedTensor:
def __init__(
self,
size: int | Sequence[int],
dtype: torch.dtype,
max_concurrency: int = 2,
):
self.dtype = dtype
self.max_concurrency = max_concurrency
# Source of truth
self.cpu = torch.zeros(size, dtype=dtype, device="cpu", pin_memory=False)
self.np = self.cpu.numpy()
# Buffers for concurrency
self.pool = UvaBufferPool(size, dtype, max_concurrency)
self.gpu = self.pool.copy_to_uva(self.np)
def copy_to_uva(self, n: int | None = None) -> torch.Tensor:
# CPU-to-CPU copy
self.gpu = self.pool.copy_to_uva(self.np[:n] if n is not None else self.np)
return self.gpu
class StagedWriteTensor:
def __init__(
self,
size: int | Sequence[int],
dtype: torch.dtype,
device: torch.device,
max_concurrency: int = 2,
uva_instead_of_gpu: bool = False,
):
if dtype not in [torch.int32, torch.int64]:
raise ValueError(
f"Unsupported dtype {dtype}: should be either int32 or int64"
)
self.num_rows = size if isinstance(size, int) else size[0]
self.dtype = dtype
self.max_concurrency = max_concurrency
if not uva_instead_of_gpu:
# Create a GPU tensor (default)
self.gpu = torch.zeros(size, dtype=dtype, device=device)
else:
# For a large but not-frequently-accessed tensor, we can use UVA instead of
# GPU to save GPU memory
self._uva_buf = UvaBuffer(size, dtype)
self.gpu = self._uva_buf.uva
self._staged_write_indices: list[int] = []
self._staged_write_starts: list[int] = []
self._staged_write_contents: list[int] = []
self._staged_write_cu_lens: list[int] = []
self.write_indices = UvaBufferPool(
self.num_rows, dtype=torch.int32, max_concurrency=max_concurrency
)
self.write_starts = UvaBufferPool(
self.num_rows, dtype=torch.int32, max_concurrency=max_concurrency
)
init_size = next_power_of_2(self.num_rows)
self.write_contents = UvaBufferPool(
init_size, dtype=dtype, max_concurrency=max_concurrency
)
self.write_cu_lens = UvaBufferPool(
self.num_rows, dtype=torch.int32, max_concurrency=max_concurrency
)
def stage_write(self, index: int, start: int, x: list[int]) -> None:
assert index >= 0
assert start >= 0
if not x:
return
self._staged_write_indices.append(index)
self._staged_write_starts.append(start)
self._staged_write_contents.extend(x)
self._staged_write_cu_lens.append(len(self._staged_write_contents))
def stage_write_elem(self, index: int, x: int) -> None:
assert index >= 0
self._staged_write_indices.append(index)
self._staged_write_starts.append(0)
self._staged_write_contents.append(x)
self._staged_write_cu_lens.append(len(self._staged_write_contents))
def apply_write(self) -> None:
n = len(self._staged_write_indices)
if n == 0:
return
indices_uva = self.write_indices.copy_to_uva(self._staged_write_indices)
starts_uva = self.write_starts.copy_to_uva(self._staged_write_starts)
cu_lens_uva = self.write_cu_lens.copy_to_uva(self._staged_write_cu_lens)
# Special handling for write_contents
diff_len = len(self._staged_write_contents)
assert isinstance(self.write_contents.size, int)
if diff_len > self.write_contents.size:
# Re-allocate a larger buffer for the write_contents
new_size = next_power_of_2(diff_len)
self.write_contents = UvaBufferPool(
new_size, dtype=self.dtype, max_concurrency=self.max_concurrency
)
# NOTE(woosuk): Since the previous write_contents buffer is released,
# we perform a synchronization here to ensure that all data transfers
# involving the old buffer have finished before allocating a new one.
# This prevents potential race conditions. The slight overhead is
# negligible because the reallocations are infrequent in practice.
torch.cuda.synchronize()
contents_uva = self.write_contents.copy_to_uva(self._staged_write_contents)
# Write diffs to the GPU buffer
_apply_write_kernel[(n,)](
self.gpu,
self.gpu.stride(0),
indices_uva,
starts_uva,
contents_uva,
cu_lens_uva,
BLOCK_SIZE=1024,
)
# Clear the staged writes
self.clear_staged_writes()
def clear_staged_writes(self) -> None:
self._staged_write_indices.clear()
self._staged_write_starts.clear()
self._staged_write_contents.clear()
self._staged_write_cu_lens.clear()
@triton.jit
def _apply_write_kernel(
    output_ptr,
    output_stride,
    write_indices_ptr,
    write_starts_ptr,
    write_contents_ptr,
    write_cu_lens_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program handles one staged write: it copies that write's slice of
    # the flat contents buffer into output[row, start : start + length].
    write_id = tl.program_id(0)
    dst_row = tl.load(write_indices_ptr + write_id)
    dst_col = tl.load(write_starts_ptr + write_id)

    # cu_lens holds exclusive end offsets, so write 0 begins at offset 0 and
    # write i begins where write i - 1 ended.
    src_begin = tl.load(write_cu_lens_ptr + write_id - 1) if write_id > 0 else 0
    src_end = tl.load(write_cu_lens_ptr + write_id)
    num_elems = src_end - src_begin

    src_base = write_contents_ptr + src_begin
    dst_base = output_ptr + dst_row * output_stride + dst_col
    # Copy in BLOCK_SIZE-wide chunks; the mask trims the final partial chunk.
    for offset in range(0, num_elems, BLOCK_SIZE):
        lane = offset + tl.arange(0, BLOCK_SIZE)
        in_range = lane < num_elems
        values = tl.load(src_base + lane, mask=in_range)
        tl.store(dst_base + lane, values, mask=in_range)
...@@ -228,10 +228,13 @@ def prepare_inputs_to_capture( ...@@ -228,10 +228,13 @@ def prepare_inputs_to_capture(
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
) -> dict[str, Any]: ) -> dict[str, Any]:
num_tokens_per_req = num_tokens // num_reqs num_tokens_per_req = num_tokens // num_reqs
query_start_loc = input_buffers.query_start_loc
query_start_loc.np[: num_reqs + 1] = np.arange(num_reqs + 1) * num_tokens_per_req query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
query_start_loc.np[num_reqs:] = num_tokens query_start_loc_np[-1] = num_tokens
query_start_loc.copy_to_gpu() query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
input_buffers.query_start_loc[: num_reqs + 1] = query_start_loc_cpu
input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens
query_start_loc = input_buffers.query_start_loc[: num_reqs + 1]
# HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens # HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens
# rather than max_model_len. # rather than max_model_len.
...@@ -245,8 +248,8 @@ def prepare_inputs_to_capture( ...@@ -245,8 +248,8 @@ def prepare_inputs_to_capture(
attn_metadata_builders=attn_metadata_builders, attn_metadata_builders=attn_metadata_builders,
num_reqs=num_reqs, num_reqs=num_reqs,
num_tokens=num_tokens, num_tokens=num_tokens,
query_start_loc_gpu=query_start_loc.gpu[: num_reqs + 1], query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc.cpu[: num_reqs + 1], query_start_loc_cpu=query_start_loc_cpu,
seq_lens=input_buffers.seq_lens, seq_lens=input_buffers.seq_lens,
max_seq_len=max_model_len, max_seq_len=max_model_len,
block_tables=input_block_tables, block_tables=input_block_tables,
......
...@@ -8,8 +8,6 @@ import torch ...@@ -8,8 +8,6 @@ import torch
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils import random_uuid from vllm.utils import random_uuid
from vllm.utils.math_utils import cdiv
from vllm.v1.utils import CpuGpuBuffer
class InputBuffers: class InputBuffers:
...@@ -21,30 +19,17 @@ class InputBuffers: ...@@ -21,30 +19,17 @@ class InputBuffers:
vocab_size: int, vocab_size: int,
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
pin_memory: bool,
): ):
self.max_num_reqs = max_num_reqs self.max_num_reqs = max_num_reqs
self.max_num_tokens = max_num_tokens self.max_num_tokens = max_num_tokens
self.device = device self.device = device
self.pin_memory = pin_memory
self.idx_mapping = self._make_buffer(max_num_reqs, dtype=torch.int32)
self.input_ids = torch.zeros(max_num_tokens, dtype=torch.int32, device=device) self.input_ids = torch.zeros(max_num_tokens, dtype=torch.int32, device=device)
self.positions = torch.zeros(max_num_tokens, dtype=torch.int64, device=device) self.positions = torch.zeros(max_num_tokens, dtype=torch.int64, device=device)
self.query_start_loc = self._make_buffer(max_num_reqs + 1, dtype=torch.int32) self.query_start_loc = torch.zeros(
self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device) max_num_reqs + 1, dtype=torch.int32, device=device
self.cu_num_logits = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
# Structured outputs.
self.bitmask_indices = self._make_buffer(max_num_reqs, dtype=torch.int32)
self.grammar_bitmask = self._make_buffer(
max_num_reqs, cdiv(vocab_size, 32), dtype=torch.int32
)
def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(
*args, dtype=dtype, pin_memory=self.pin_memory, device=self.device
) )
self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)
@dataclass @dataclass
...@@ -56,6 +41,8 @@ class InputBatch: ...@@ -56,6 +41,8 @@ class InputBatch:
# batch_idx -> req_state_idx # batch_idx -> req_state_idx
idx_mapping: torch.Tensor idx_mapping: torch.Tensor
idx_mapping_np: np.ndarray idx_mapping_np: np.ndarray
# Identical to idx_mapping except for spec decoding.
expanded_idx_mapping: torch.Tensor
# [num_reqs] # [num_reqs]
# batch_idx -> num_scheduled_tokens # batch_idx -> num_scheduled_tokens
...@@ -83,6 +70,7 @@ class InputBatch: ...@@ -83,6 +70,7 @@ class InputBatch:
logits_indices: torch.Tensor logits_indices: torch.Tensor
# [num_reqs + 1] # [num_reqs + 1]
cu_num_logits: torch.Tensor cu_num_logits: torch.Tensor
cu_num_logits_np: np.ndarray
@classmethod @classmethod
def make_dummy( def make_dummy(
...@@ -96,33 +84,41 @@ class InputBatch: ...@@ -96,33 +84,41 @@ class InputBatch:
req_ids = [f"req_{i}_{random_uuid()}" for i in range(num_reqs)] req_ids = [f"req_{i}_{random_uuid()}" for i in range(num_reqs)]
idx_mapping_np = np.arange(num_reqs, dtype=np.int32) idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device) idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
expanded_idx_mapping = idx_mapping
num_scheduled_tokens = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32) num_scheduled_tokens = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
num_scheduled_tokens[-1] += num_tokens % num_reqs num_scheduled_tokens[-1] += num_tokens % num_reqs
assert int(num_scheduled_tokens.sum()) == num_tokens assert int(num_scheduled_tokens.sum()) == num_tokens
input_buffers.query_start_loc.np[0] = 0
input_buffers.query_start_loc.np[1 : num_reqs + 1] = np.cumsum(
num_scheduled_tokens
)
input_buffers.query_start_loc.np[num_reqs + 1 :] = num_tokens
query_start_loc_np = input_buffers.query_start_loc.np[: num_reqs + 1]
query_start_loc = input_buffers.query_start_loc.copy_to_gpu()[: num_reqs + 1]
# seq_len equals to query_len # seq_len equals to query_len
input_buffers.seq_lens[:num_reqs] = num_tokens // num_reqs input_buffers.seq_lens[:num_reqs] = num_tokens // num_reqs
input_buffers.seq_lens[num_reqs - 1] += num_tokens % num_reqs input_buffers.seq_lens[num_reqs - 1] += num_tokens % num_reqs
# Pad for full CUDA graph mode.
input_buffers.seq_lens[num_reqs:] = 0 input_buffers.seq_lens[num_reqs:] = 0
seq_lens = input_buffers.seq_lens[:num_reqs] seq_lens = input_buffers.seq_lens[:num_reqs]
query_start_loc_np = np.empty(num_reqs + 1, dtype=np.int32)
query_start_loc_np[0] = 0
np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1:])
input_buffers.query_start_loc[0] = 0
torch.cumsum(
seq_lens, dim=0, out=input_buffers.query_start_loc[1 : num_reqs + 1]
)
# Pad for full CUDA graph mode.
input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens
query_start_loc = input_buffers.query_start_loc[: num_reqs + 1]
input_ids = input_buffers.input_ids[:num_tokens] input_ids = input_buffers.input_ids[:num_tokens]
positions = input_buffers.positions[:num_tokens] positions = input_buffers.positions[:num_tokens]
# attn_metadata = defaultdict(lambda: None) # attn_metadata = defaultdict(lambda: None)
logits_indices = query_start_loc[1:] - 1 logits_indices = query_start_loc[1:] - 1
cu_num_logits = torch.arange(num_reqs + 1, device=device, dtype=torch.int32) cu_num_logits = torch.arange(num_reqs + 1, device=device, dtype=torch.int32)
cu_num_logits_np = np.arange(num_reqs + 1, dtype=np.int32)
return cls( return cls(
req_ids=req_ids, req_ids=req_ids,
num_reqs=num_reqs, num_reqs=num_reqs,
idx_mapping=idx_mapping, idx_mapping=idx_mapping,
idx_mapping_np=idx_mapping_np, idx_mapping_np=idx_mapping_np,
expanded_idx_mapping=expanded_idx_mapping,
num_scheduled_tokens=num_scheduled_tokens, num_scheduled_tokens=num_scheduled_tokens,
num_tokens=num_tokens, num_tokens=num_tokens,
num_tokens_after_padding=num_tokens, num_tokens_after_padding=num_tokens,
...@@ -135,6 +131,7 @@ class InputBatch: ...@@ -135,6 +131,7 @@ class InputBatch:
attn_metadata=None, # type: ignore attn_metadata=None, # type: ignore
logits_indices=logits_indices, logits_indices=logits_indices,
cu_num_logits=cu_num_logits, cu_num_logits=cu_num_logits,
cu_num_logits_np=cu_num_logits_np,
) )
...@@ -473,3 +470,38 @@ def post_update( ...@@ -473,3 +470,38 @@ def post_update(
query_start_loc, query_start_loc,
num_warps=1, num_warps=1,
) )
@triton.jit
def _expand_idx_mapping_kernel(
    idx_mapping_ptr,
    expanded_idx_mapping_ptr,
    cu_num_logits_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program broadcasts idx_mapping[i] over the output range
    # [cu_num_logits[i], cu_num_logits[i + 1]).
    batch_idx = tl.program_id(0)
    out_begin = tl.load(cu_num_logits_ptr + batch_idx)
    out_end = tl.load(cu_num_logits_ptr + batch_idx + 1)
    span = out_end - out_begin

    # A single BLOCK_SIZE-wide store covers the range; lanes past the span are
    # masked off.
    lane = tl.arange(0, BLOCK_SIZE)
    in_range = lane < span
    state_idx = tl.load(idx_mapping_ptr + batch_idx)
    tl.store(expanded_idx_mapping_ptr + out_begin + lane, state_idx, mask=in_range)
def expand_idx_mapping(
    idx_mapping: torch.Tensor,
    total_num_logits: int,
    cu_num_logits: torch.Tensor,
    max_expand_len: int,
) -> torch.Tensor:
    """Repeat each ``idx_mapping`` entry over its ``cu_num_logits`` range.

    Entry ``i`` is written to output positions
    ``cu_num_logits[i] .. cu_num_logits[i + 1] - 1``.

    Args:
        idx_mapping: 1-D tensor; one kernel program is launched per entry.
        total_num_logits: length of the returned tensor.
        cu_num_logits: range boundaries read by the kernel at ``i`` and
            ``i + 1``.
        max_expand_len: rounded up to a power of two and used as the kernel's
            ``BLOCK_SIZE``.

    Returns:
        A new tensor of length ``total_num_logits`` with the same dtype and
        device as ``idx_mapping``. It is allocated uninitialized, so positions
        not covered by any range keep arbitrary values.
    """
    grid = (idx_mapping.shape[0],)
    out = idx_mapping.new_empty(total_num_logits)
    block_size = triton.next_power_of_2(max_expand_len)
    _expand_idx_mapping_kernel[grid](
        idx_mapping,
        out,
        cu_num_logits,
        BLOCK_SIZE=block_size,
    )
    return out
This diff is collapsed.
...@@ -13,6 +13,7 @@ def _gumbel_sample_kernel( ...@@ -13,6 +13,7 @@ def _gumbel_sample_kernel(
local_max_stride, local_max_stride,
logits_ptr, logits_ptr,
logits_stride, logits_stride,
idx_mapping_ptr,
seeds_ptr, seeds_ptr,
pos_ptr, pos_ptr,
temp_ptr, temp_ptr,
...@@ -20,22 +21,24 @@ def _gumbel_sample_kernel( ...@@ -20,22 +21,24 @@ def _gumbel_sample_kernel(
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
APPLY_TEMPERATURE: tl.constexpr, APPLY_TEMPERATURE: tl.constexpr,
): ):
req_idx = tl.program_id(0) batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
block_idx = tl.program_id(1) block_idx = tl.program_id(1)
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size mask = block < vocab_size
logits = tl.load( logits = tl.load(
logits_ptr + req_idx * logits_stride + block, logits_ptr + batch_idx * logits_stride + block,
mask=mask, mask=mask,
other=float("-inf"), other=float("-inf"),
) )
logits = logits.to(tl.float32) logits = logits.to(tl.float32)
temp = tl.load(temp_ptr + req_idx).to(tl.float32) temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
if temp != 0.0: if temp != 0.0:
# Calculate the seed for gumbel noise. # Calculate the seed for gumbel noise.
seed = tl.load(seeds_ptr + req_idx) seed = tl.load(seeds_ptr + req_state_idx)
pos = tl.load(pos_ptr + req_idx) pos = tl.load(pos_ptr + batch_idx)
gumbel_seed = tl.randint(seed, pos) gumbel_seed = tl.randint(seed, pos)
# Generate gumbel noise. # Generate gumbel noise.
...@@ -55,12 +58,13 @@ def _gumbel_sample_kernel( ...@@ -55,12 +58,13 @@ def _gumbel_sample_kernel(
idx = tl.argmax(logits, axis=0) idx = tl.argmax(logits, axis=0)
token_id = block_idx * BLOCK_SIZE + idx token_id = block_idx * BLOCK_SIZE + idx
value = tl.max(logits, axis=0) value = tl.max(logits, axis=0)
tl.store(local_argmax_ptr + req_idx * local_argmax_stride + block_idx, token_id) tl.store(local_argmax_ptr + batch_idx * local_argmax_stride + block_idx, token_id)
tl.store(local_max_ptr + req_idx * local_max_stride + block_idx, value) tl.store(local_max_ptr + batch_idx * local_max_stride + block_idx, value)
def gumbel_sample( def gumbel_sample(
logits: torch.Tensor, # [num_reqs, vocab_size] logits: torch.Tensor, # [num_reqs, vocab_size]
idx_mapping: torch.Tensor, # [num_reqs]
temperature: torch.Tensor, # [num_reqs] temperature: torch.Tensor, # [num_reqs]
seed: torch.Tensor, # [num_reqs] seed: torch.Tensor, # [num_reqs]
pos: torch.Tensor, # [num_reqs] pos: torch.Tensor, # [num_reqs]
...@@ -88,6 +92,7 @@ def gumbel_sample( ...@@ -88,6 +92,7 @@ def gumbel_sample(
local_max.stride(0), local_max.stride(0),
logits, logits,
logits.stride(0), logits.stride(0),
idx_mapping,
seed, seed,
pos, pos,
temperature, temperature,
......
...@@ -4,20 +4,23 @@ from dataclasses import dataclass ...@@ -4,20 +4,23 @@ from dataclasses import dataclass
import torch import torch
from vllm.triton_utils import tl, triton
@dataclass @dataclass
class SamplingMetadata: class SamplingMetadata:
idx_mapping: torch.Tensor
temperature: torch.Tensor temperature: torch.Tensor
top_p: torch.Tensor | None top_p: torch.Tensor | None
top_k: torch.Tensor | None top_k: torch.Tensor | None
min_p: torch.Tensor | None min_p: torch.Tensor | None
# For penalties
repetition_penalty: torch.Tensor repetition_penalty: torch.Tensor
frequency_penalty: torch.Tensor frequency_penalty: torch.Tensor
presence_penalty: torch.Tensor presence_penalty: torch.Tensor
prompt_bin_mask: torch.Tensor
output_bin_counts: torch.Tensor
seeds: torch.Tensor seeds: torch.Tensor
pos: torch.Tensor pos: torch.Tensor
...@@ -25,11 +28,6 @@ class SamplingMetadata: ...@@ -25,11 +28,6 @@ class SamplingMetadata:
# None means no logprobs, 0 means sampled token logprobs only # None means no logprobs, 0 means sampled token logprobs only
max_num_logprobs: int | None max_num_logprobs: int | None
# For penalties
idx_mapping: torch.Tensor
prompt_bin_mask: torch.Tensor
output_bin_counts: torch.Tensor
@classmethod @classmethod
def make_dummy( def make_dummy(
cls, cls,
...@@ -37,6 +35,8 @@ class SamplingMetadata: ...@@ -37,6 +35,8 @@ class SamplingMetadata:
device: torch.device, device: torch.device,
) -> "SamplingMetadata": ) -> "SamplingMetadata":
assert num_reqs > 0 assert num_reqs > 0
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
temperature = torch.zeros(num_reqs, dtype=torch.float32, device=device) temperature = torch.zeros(num_reqs, dtype=torch.float32, device=device)
temperature[0] = 0.5 temperature[0] = 0.5
# TODO(woosuk): Use top-p and top-k for dummy sampler. # TODO(woosuk): Use top-p and top-k for dummy sampler.
...@@ -51,18 +51,19 @@ class SamplingMetadata: ...@@ -51,18 +51,19 @@ class SamplingMetadata:
repetition_penalty = torch.ones(num_reqs, dtype=torch.float32, device=device) repetition_penalty = torch.ones(num_reqs, dtype=torch.float32, device=device)
frequency_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device) frequency_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device)
presence_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device) presence_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device)
seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device)
pos = torch.zeros(num_reqs, dtype=torch.int64, device=device)
max_num_logprobs = 20
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
# NOTE(woosuk): These are placeholder tensors to avoid None checks in the # NOTE(woosuk): These are placeholder tensors to avoid None checks in the
# penalties kernel. We use 2 instead of 1 as vocab_size to avoid Triton # penalties kernel. We use 2 instead of 1 as vocab_size to avoid Triton
# specialization and re-compilation at runtime. # specialization and re-compilation at runtime.
prompt_bin_mask = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device) prompt_bin_mask = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device)
output_bin_counts = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device) output_bin_counts = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device)
seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device)
pos = torch.zeros(num_reqs, dtype=torch.int64, device=device)
max_num_logprobs = 20
return cls( return cls(
idx_mapping=idx_mapping,
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
top_k=top_k, top_k=top_k,
...@@ -70,123 +71,9 @@ class SamplingMetadata: ...@@ -70,123 +71,9 @@ class SamplingMetadata:
repetition_penalty=repetition_penalty, repetition_penalty=repetition_penalty,
frequency_penalty=frequency_penalty, frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty, presence_penalty=presence_penalty,
prompt_bin_mask=prompt_bin_mask,
output_bin_counts=output_bin_counts,
seeds=seeds, seeds=seeds,
pos=pos, pos=pos,
max_num_logprobs=max_num_logprobs, max_num_logprobs=max_num_logprobs,
idx_mapping=idx_mapping,
prompt_bin_mask=prompt_bin_mask,
output_bin_counts=output_bin_counts,
) )
# NOTE(woosuk): Re-compilation can happen at runtime since top_p and top_k can be None.
@triton.jit
def _expand_sampling_metadata_kernel(
    temp_ptr,
    expanded_temp_ptr,
    top_p_ptr,
    expanded_top_p_ptr,
    top_k_ptr,
    expanded_top_k_ptr,
    min_p_ptr,
    expanded_min_p_ptr,
    rep_penalty_ptr,
    expanded_rep_penalty_ptr,
    freq_penalty_ptr,
    expanded_freq_penalty_ptr,
    pres_penalty_ptr,
    expanded_pres_penalty_ptr,
    seeds_ptr,
    expanded_seeds_ptr,
    cu_num_logits_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program broadcasts one request's scalar sampling parameters over
    # the output range [cu_num_logits[i], cu_num_logits[i + 1]).
    batch_idx = tl.program_id(0)
    out_begin = tl.load(cu_num_logits_ptr + batch_idx)
    out_end = tl.load(cu_num_logits_ptr + batch_idx + 1)
    span = out_end - out_begin

    lane = tl.arange(0, BLOCK_SIZE)
    in_range = lane < span
    dst = out_begin + lane

    temp_val = tl.load(temp_ptr + batch_idx)
    tl.store(expanded_temp_ptr + dst, temp_val, mask=in_range)

    # top_p / top_k / min_p are optional and skipped when passed as None.
    if top_p_ptr is not None:
        top_p_val = tl.load(top_p_ptr + batch_idx)
        tl.store(expanded_top_p_ptr + dst, top_p_val, mask=in_range)

    if top_k_ptr is not None:
        top_k_val = tl.load(top_k_ptr + batch_idx)
        tl.store(expanded_top_k_ptr + dst, top_k_val, mask=in_range)

    if min_p_ptr is not None:
        min_p_val = tl.load(min_p_ptr + batch_idx)
        tl.store(expanded_min_p_ptr + dst, min_p_val, mask=in_range)

    rep_val = tl.load(rep_penalty_ptr + batch_idx)
    tl.store(expanded_rep_penalty_ptr + dst, rep_val, mask=in_range)

    freq_val = tl.load(freq_penalty_ptr + batch_idx)
    tl.store(expanded_freq_penalty_ptr + dst, freq_val, mask=in_range)

    pres_val = tl.load(pres_penalty_ptr + batch_idx)
    tl.store(expanded_pres_penalty_ptr + dst, pres_val, mask=in_range)

    seed_val = tl.load(seeds_ptr + batch_idx)
    tl.store(expanded_seeds_ptr + dst, seed_val, mask=in_range)
def expand_sampling_metadata(
    sampling_metadata: SamplingMetadata,
    cu_num_logits: torch.Tensor,
    max_expand_len: int,
) -> SamplingMetadata:
    """Build a ``SamplingMetadata`` with per-request fields expanded per logit.

    Temperature, top_p, top_k, min_p, the three penalties and the seeds are
    broadcast by ``_expand_sampling_metadata_kernel`` according to
    ``cu_num_logits``. ``pos``, ``max_num_logprobs``, ``idx_mapping``,
    ``prompt_bin_mask`` and ``output_bin_counts`` are passed through unchanged.

    Args:
        sampling_metadata: source metadata; ``pos.shape[0]`` is used as the
            length of every expanded tensor.
        cu_num_logits: range boundaries; its length minus one is the number
            of kernel programs launched.
        max_expand_len: rounded up to a power of two and used as the kernel's
            ``BLOCK_SIZE``.
    """
    src = sampling_metadata
    total_num_logits = src.pos.shape[0]

    def _alloc_like(field):
        # Optional fields stay None; others get an uninitialized tensor of the
        # expanded length with the same dtype/device.
        if field is None:
            return None
        return field.new_empty(total_num_logits)

    temp_out = _alloc_like(src.temperature)
    top_p_out = _alloc_like(src.top_p)
    top_k_out = _alloc_like(src.top_k)
    min_p_out = _alloc_like(src.min_p)
    rep_out = _alloc_like(src.repetition_penalty)
    freq_out = _alloc_like(src.frequency_penalty)
    pres_out = _alloc_like(src.presence_penalty)
    seeds_out = _alloc_like(src.seeds)

    grid = (cu_num_logits.shape[0] - 1,)
    _expand_sampling_metadata_kernel[grid](
        src.temperature,
        temp_out,
        src.top_p,
        top_p_out,
        src.top_k,
        top_k_out,
        src.min_p,
        min_p_out,
        src.repetition_penalty,
        rep_out,
        src.frequency_penalty,
        freq_out,
        src.presence_penalty,
        pres_out,
        src.seeds,
        seeds_out,
        cu_num_logits,
        BLOCK_SIZE=triton.next_power_of_2(max_expand_len),
    )
    return SamplingMetadata(
        temperature=temp_out,
        top_p=top_p_out,
        top_k=top_k_out,
        min_p=min_p_out,
        seeds=seeds_out,
        repetition_penalty=rep_out,
        frequency_penalty=freq_out,
        presence_penalty=pres_out,
        pos=src.pos,
        max_num_logprobs=src.max_num_logprobs,
        # TODO(woosuk): Support penalties with spec decoding.
        idx_mapping=src.idx_mapping,
        prompt_bin_mask=src.prompt_bin_mask,
        output_bin_counts=src.output_bin_counts,
    )
...@@ -9,12 +9,14 @@ from vllm.triton_utils import tl, triton ...@@ -9,12 +9,14 @@ from vllm.triton_utils import tl, triton
def _min_p_kernel( def _min_p_kernel(
logits_ptr, logits_ptr,
logits_stride, logits_stride,
idx_mapping_ptr,
min_p_ptr, min_p_ptr,
vocab_size, vocab_size,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
): ):
req_idx = tl.program_id(0) req_idx = tl.program_id(0)
min_p = tl.load(min_p_ptr + req_idx).to(tl.float32) req_state_idx = tl.load(idx_mapping_ptr + req_idx)
min_p = tl.load(min_p_ptr + req_state_idx).to(tl.float32)
if min_p == 0.0: if min_p == 0.0:
return return
...@@ -39,12 +41,17 @@ def _min_p_kernel( ...@@ -39,12 +41,17 @@ def _min_p_kernel(
tl.store(logits_ptr + req_idx * logits_stride + block, logits, mask=mask) tl.store(logits_ptr + req_idx * logits_stride + block, logits, mask=mask)
def apply_min_p(logits: torch.Tensor, min_p: torch.Tensor) -> None: def apply_min_p(
logits: torch.Tensor,
idx_mapping: torch.Tensor,
min_p: torch.Tensor,
) -> None:
num_reqs, vocab_size = logits.shape num_reqs, vocab_size = logits.shape
BLOCK_SIZE = 1024 BLOCK_SIZE = 1024
_min_p_kernel[(num_reqs,)]( _min_p_kernel[(num_reqs,)](
logits, logits,
logits.stride(0), logits.stride(0),
idx_mapping,
min_p, min_p,
vocab_size, vocab_size,
BLOCK_SIZE=BLOCK_SIZE, BLOCK_SIZE=BLOCK_SIZE,
......
...@@ -10,11 +10,11 @@ from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata ...@@ -10,11 +10,11 @@ from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
def _penalties_and_temperature_kernel( def _penalties_and_temperature_kernel(
logits_ptr, logits_ptr,
logits_stride, logits_stride,
idx_mapping_ptr,
repetition_penalty_ptr, repetition_penalty_ptr,
frequency_penalty_ptr, frequency_penalty_ptr,
presence_penalty_ptr, presence_penalty_ptr,
temperature_ptr, temperature_ptr,
idx_mapping_ptr,
prompt_bin_mask_ptr, prompt_bin_mask_ptr,
prompt_bin_mask_stride, prompt_bin_mask_stride,
output_bin_counts_ptr, output_bin_counts_ptr,
...@@ -23,10 +23,11 @@ def _penalties_and_temperature_kernel( ...@@ -23,10 +23,11 @@ def _penalties_and_temperature_kernel(
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
): ):
batch_idx = tl.program_id(0) batch_idx = tl.program_id(0)
rep_penalty = tl.load(repetition_penalty_ptr + batch_idx) req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
freq_penalty = tl.load(frequency_penalty_ptr + batch_idx) rep_penalty = tl.load(repetition_penalty_ptr + req_state_idx)
pres_penalty = tl.load(presence_penalty_ptr + batch_idx) freq_penalty = tl.load(frequency_penalty_ptr + req_state_idx)
temperature = tl.load(temperature_ptr + batch_idx) pres_penalty = tl.load(presence_penalty_ptr + req_state_idx)
temperature = tl.load(temperature_ptr + req_state_idx)
temperature = tl.where(temperature == 0.0, 1.0, temperature) temperature = tl.where(temperature == 0.0, 1.0, temperature)
use_rep_penalty = rep_penalty != 1.0 use_rep_penalty = rep_penalty != 1.0
...@@ -45,7 +46,6 @@ def _penalties_and_temperature_kernel( ...@@ -45,7 +46,6 @@ def _penalties_and_temperature_kernel(
logits = logits.to(tl.float32) logits = logits.to(tl.float32)
if use_penalty: if use_penalty:
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
output_bin_counts = tl.load( output_bin_counts = tl.load(
output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block, output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block,
mask=mask, mask=mask,
...@@ -92,11 +92,11 @@ def apply_penalties_and_temperature( ...@@ -92,11 +92,11 @@ def apply_penalties_and_temperature(
_penalties_and_temperature_kernel[(num_reqs, num_blocks)]( _penalties_and_temperature_kernel[(num_reqs, num_blocks)](
logits, logits,
logits.stride(0), logits.stride(0),
sampling_metadata.idx_mapping,
sampling_metadata.repetition_penalty, sampling_metadata.repetition_penalty,
sampling_metadata.frequency_penalty, sampling_metadata.frequency_penalty,
sampling_metadata.presence_penalty, sampling_metadata.presence_penalty,
sampling_metadata.temperature, sampling_metadata.temperature,
sampling_metadata.idx_mapping,
sampling_metadata.prompt_bin_mask, sampling_metadata.prompt_bin_mask,
sampling_metadata.prompt_bin_mask.stride(0), sampling_metadata.prompt_bin_mask.stride(0),
sampling_metadata.output_bin_counts, sampling_metadata.output_bin_counts,
......
...@@ -71,7 +71,7 @@ class Sampler: ...@@ -71,7 +71,7 @@ class Sampler:
apply_penalties_and_temperature(logits, sampling_metadata) apply_penalties_and_temperature(logits, sampling_metadata)
# Apply min_p in place. # Apply min_p in place.
if sampling_metadata.min_p is not None: if sampling_metadata.min_p is not None:
apply_min_p(logits, sampling_metadata.min_p) apply_min_p(logits, sampling_metadata.idx_mapping, sampling_metadata.min_p)
# Apply top_k and/or top_p. This might return a new tensor. # Apply top_k and/or top_p. This might return a new tensor.
logits = apply_top_k_top_p( logits = apply_top_k_top_p(
logits, sampling_metadata.top_k, sampling_metadata.top_p logits, sampling_metadata.top_k, sampling_metadata.top_p
...@@ -79,6 +79,7 @@ class Sampler: ...@@ -79,6 +79,7 @@ class Sampler:
sampled = gumbel_sample( sampled = gumbel_sample(
logits, logits,
sampling_metadata.idx_mapping,
sampling_metadata.temperature, sampling_metadata.temperature,
sampling_metadata.seeds, sampling_metadata.seeds,
sampling_metadata.pos, sampling_metadata.pos,
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any from typing import Any
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -12,7 +11,6 @@ from vllm.forward_context import set_forward_context ...@@ -12,7 +11,6 @@ from vllm.forward_context import set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
...@@ -46,7 +44,6 @@ class EagleSpeculator: ...@@ -46,7 +44,6 @@ class EagleSpeculator:
self.hidden_size = self.draft_model_config.get_hidden_size() self.hidden_size = self.draft_model_config.get_hidden_size()
self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size() self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size()
self.vocab_size = self.draft_model_config.get_vocab_size() self.vocab_size = self.draft_model_config.get_vocab_size()
self.pin_memory = is_pin_memory_available()
self.dtype = vllm_config.model_config.dtype self.dtype = vllm_config.model_config.dtype
self.input_buffers = InputBuffers( self.input_buffers = InputBuffers(
...@@ -56,7 +53,6 @@ class EagleSpeculator: ...@@ -56,7 +53,6 @@ class EagleSpeculator:
vocab_size=self.vocab_size, vocab_size=self.vocab_size,
dtype=self.dtype, dtype=self.dtype,
device=device, device=device,
pin_memory=self.pin_memory,
) )
self.hidden_states = torch.zeros( self.hidden_states = torch.zeros(
self.max_num_tokens, self.max_num_tokens,
...@@ -64,6 +60,11 @@ class EagleSpeculator: ...@@ -64,6 +60,11 @@ class EagleSpeculator:
dtype=self.dtype, dtype=self.dtype,
device=device, device=device,
) )
self.idx_mapping = torch.zeros(
self.max_num_reqs,
dtype=torch.int32,
device=device,
)
self.temperature = torch.zeros( self.temperature = torch.zeros(
self.max_num_reqs, self.max_num_reqs,
dtype=torch.float32, dtype=torch.float32,
...@@ -140,7 +141,7 @@ class EagleSpeculator: ...@@ -140,7 +141,7 @@ class EagleSpeculator:
num_tokens_across_dp: torch.Tensor | None, num_tokens_across_dp: torch.Tensor | None,
) -> None: ) -> None:
pos = self.input_buffers.positions[:num_reqs] pos = self.input_buffers.positions[:num_reqs]
query_start_loc = self.input_buffers.query_start_loc.gpu[: num_reqs + 1] query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
for step in range(1, self.num_speculative_steps): for step in range(1, self.num_speculative_steps):
# Run the eagle model. # Run the eagle model.
last_hidden_states, hidden_states = self.run_model( last_hidden_states, hidden_states = self.run_model(
...@@ -152,8 +153,9 @@ class EagleSpeculator: ...@@ -152,8 +153,9 @@ class EagleSpeculator:
# used for draft and target sampling. # used for draft and target sampling.
draft_tokens = gumbel_sample( draft_tokens = gumbel_sample(
logits, logits,
self.temperature[:num_reqs], self.idx_mapping[:num_reqs],
self.seeds[:num_reqs], self.temperature,
self.seeds,
pos + 1, pos + 1,
apply_temperature=True, apply_temperature=True,
) )
...@@ -237,23 +239,27 @@ class EagleSpeculator: ...@@ -237,23 +239,27 @@ class EagleSpeculator:
logits = self.model.compute_logits(sample_hidden_states) logits = self.model.compute_logits(sample_hidden_states)
num_reqs = input_batch.num_reqs num_reqs = input_batch.num_reqs
cu_num_logits = input_batch.cu_num_logits[:num_reqs]
# NOTE(woosuk): For draft sampling, we only consider the temperature # NOTE(woosuk): For draft sampling, we only consider the temperature
# and ignore the other sampling parameters such as top_k and top_p, # and ignore the other sampling parameters such as top_k and top_p,
# for simplicity and performance. # for simplicity and performance.
# While this may slightly degrade the acceptance rate, it does not # While this may slightly degrade the acceptance rate, it does not
# affect the output distribution after rejection sampling. # affect the output distribution after rejection sampling.
temperature = self.temperature[:num_reqs] idx_mapping = self.idx_mapping[:num_reqs]
seeds = self.seeds[:num_reqs] idx_mapping.copy_(input_batch.idx_mapping)
pos = self.input_buffers.positions[:num_reqs] self.temperature.copy_(sampling_metadata.temperature)
self.seeds.copy_(sampling_metadata.seeds)
# Gather the values and copy them to the pre-allocated buffers. # Gather the values and copy them to the pre-allocated buffers.
torch.gather(sampling_metadata.temperature, 0, cu_num_logits, out=temperature) pos = self.input_buffers.positions[:num_reqs]
torch.gather(sampling_metadata.seeds, 0, cu_num_logits, out=seeds)
torch.gather(input_batch.positions, 0, last_token_indices, out=pos) torch.gather(input_batch.positions, 0, last_token_indices, out=pos)
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise # NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
# used for draft and target sampling. # used for draft and target sampling.
draft_tokens = gumbel_sample( draft_tokens = gumbel_sample(
logits, temperature, seeds, pos + 1, apply_temperature=True logits,
idx_mapping,
self.temperature,
self.seeds,
pos + 1,
apply_temperature=True,
) )
if self.num_speculative_steps == 1: if self.num_speculative_steps == 1:
# Early exit. # Early exit.
...@@ -273,11 +279,8 @@ class EagleSpeculator: ...@@ -273,11 +279,8 @@ class EagleSpeculator:
self.max_model_len, self.max_model_len,
self.max_num_reqs, self.max_num_reqs,
) )
query_start_loc = self.input_buffers.query_start_loc query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1] slot_mappings = self.block_tables.compute_slot_mappings(query_start_loc, pos)
slot_mappings = self.block_tables.compute_slot_mappings(
query_start_loc_gpu, pos
)
cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs) cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs)
if cudagraph_size is not None: if cudagraph_size is not None:
...@@ -286,8 +289,9 @@ class EagleSpeculator: ...@@ -286,8 +289,9 @@ class EagleSpeculator:
return self.draft_tokens[:num_reqs] return self.draft_tokens[:num_reqs]
# Run eager mode. # Run eager mode.
query_start_loc.np[: num_reqs + 1] = np.arange(num_reqs + 1) query_start_loc_cpu = torch.arange(
query_start_loc_cpu = query_start_loc.cpu[: num_reqs + 1] num_reqs + 1, dtype=torch.int32, device="cpu"
)
block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables] block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables]
# FIXME(woosuk): This is UNSAFE!! # FIXME(woosuk): This is UNSAFE!!
...@@ -295,7 +299,7 @@ class EagleSpeculator: ...@@ -295,7 +299,7 @@ class EagleSpeculator:
attn_metadata_builders=self.attn_metadata_builders, attn_metadata_builders=self.attn_metadata_builders,
num_reqs=num_reqs, num_reqs=num_reqs,
num_tokens=num_reqs, num_tokens=num_reqs,
query_start_loc_gpu=query_start_loc_gpu, query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu, query_start_loc_cpu=query_start_loc_cpu,
seq_lens=self.input_buffers.seq_lens[:num_reqs], seq_lens=self.input_buffers.seq_lens[:num_reqs],
max_seq_len=self.max_model_len, max_seq_len=self.max_model_len,
...@@ -484,7 +488,7 @@ def prepare_eagle_decode( ...@@ -484,7 +488,7 @@ def prepare_eagle_decode(
input_buffers.positions, input_buffers.positions,
input_hidden_states, input_hidden_states,
input_hidden_states.stride(0), input_hidden_states.stride(0),
input_buffers.query_start_loc.gpu, input_buffers.query_start_loc,
input_buffers.seq_lens, input_buffers.seq_lens,
hidden_size, hidden_size,
max_model_len, max_model_len,
......
...@@ -8,10 +8,8 @@ import torch ...@@ -8,10 +8,8 @@ import torch
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import is_uva_available
from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor
from vllm.v1.outputs import LogprobsTensors from vllm.v1.outputs import LogprobsTensors
from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu.sample.penalties import bincount from vllm.v1.worker.gpu.sample.penalties import bincount
...@@ -29,7 +27,6 @@ class RequestState: ...@@ -29,7 +27,6 @@ class RequestState:
num_speculative_steps: int, num_speculative_steps: int,
vocab_size: int, vocab_size: int,
device: torch.device, device: torch.device,
pin_memory: bool,
): ):
self.max_num_reqs = max_num_reqs self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len self.max_model_len = max_model_len
...@@ -37,7 +34,6 @@ class RequestState: ...@@ -37,7 +34,6 @@ class RequestState:
self.num_speculative_steps = num_speculative_steps self.num_speculative_steps = num_speculative_steps
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.device = device self.device = device
self.pin_memory = pin_memory
self.req_id_to_index: dict[str, int] = {} self.req_id_to_index: dict[str, int] = {}
self.index_to_req_id: dict[int, str] = {} self.index_to_req_id: dict[int, str] = {}
...@@ -47,16 +43,18 @@ class RequestState: ...@@ -47,16 +43,18 @@ class RequestState:
self.prompt_len = np.zeros(self.max_num_reqs, dtype=np.int32) self.prompt_len = np.zeros(self.max_num_reqs, dtype=np.int32)
# NOTE(woosuk): This tensor can be extremely large (e.g., several GBs) # NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
# depending on the configured max_num_reqs and max_model_len. # depending on the configured max_num_reqs and max_model_len.
self.prefill_token_ids = UvaBuffer( # To save GPU memory, we use UVA instead of GPU for this tensor.
self.max_num_reqs, self.max_model_len, dtype=torch.int32 self.prefill_token_ids = StagedWriteTensor(
(self.max_num_reqs, self.max_model_len),
dtype=torch.int32,
device=device,
uva_instead_of_gpu=True,
) )
# NOTE(woosuk): We don't use UVA for prefill_len because its GPU view self.prefill_len = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
# can be used outside of update_states and prepare_inputs.
# Without async barrier, using UVA can cause race conditions.
self.prefill_len = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
# Number of computed tokens. # Number of computed tokens.
self.num_computed_prefill_tokens = np.zeros(self.max_num_reqs, dtype=np.int32) self.num_computed_prefill_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
self.num_computed_tokens = torch.zeros( self.num_computed_tokens = StagedWriteTensor(
self.max_num_reqs, dtype=torch.int32, device=device self.max_num_reqs, dtype=torch.int32, device=device
) )
...@@ -84,14 +82,16 @@ class RequestState: ...@@ -84,14 +82,16 @@ class RequestState:
self.lora_ids.fill(NO_LORA_ID) self.lora_ids.fill(NO_LORA_ID)
# Sampling parameters. # Sampling parameters.
self.temperature = self._make_param(self.max_num_reqs, torch.float32) self.temperature = UvaBackedTensor(self.max_num_reqs, dtype=torch.float32)
self.top_p = self._make_param(self.max_num_reqs, torch.float32) self.top_p = UvaBackedTensor(self.max_num_reqs, dtype=torch.float32)
self.top_k = self._make_param(self.max_num_reqs, torch.int32) self.top_k = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
self.min_p = self._make_param(self.max_num_reqs, torch.float32) self.min_p = UvaBackedTensor(self.max_num_reqs, dtype=torch.float32)
self.repetition_penalty = self._make_param(self.max_num_reqs, torch.float32) self.repetition_penalty = UvaBackedTensor(
self.frequency_penalty = self._make_param(self.max_num_reqs, torch.float32) self.max_num_reqs, dtype=torch.float32
self.presence_penalty = self._make_param(self.max_num_reqs, torch.float32) )
self.seeds = self._make_param(self.max_num_reqs, torch.int64) self.frequency_penalty = UvaBackedTensor(self.max_num_reqs, dtype=torch.float32)
self.presence_penalty = UvaBackedTensor(self.max_num_reqs, dtype=torch.float32)
self.seeds = UvaBackedTensor(self.max_num_reqs, dtype=torch.int64)
self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32) self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32)
# -1 means no logprobs are requested. # -1 means no logprobs are requested.
...@@ -111,13 +111,7 @@ class RequestState: ...@@ -111,13 +111,7 @@ class RequestState:
self.max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device self.max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device
) )
def _make_param(self, size: int, dtype: torch.dtype) -> "Param": self._penalties_reqs: list[int] = []
return Param(size, dtype=dtype, device=self.device, pin_memory=self.pin_memory)
def _make_buffer(self, size: int, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(
size, dtype=dtype, device=self.device, pin_memory=self.pin_memory
)
@property @property
def num_reqs(self) -> int: def num_reqs(self) -> int:
...@@ -144,12 +138,9 @@ class RequestState: ...@@ -144,12 +138,9 @@ class RequestState:
f"prefill_len {prefill_len} < prompt_len {prompt_len}" f"prefill_len {prefill_len} < prompt_len {prompt_len}"
) )
self.prefill_len.np[req_idx] = prefill_len self.prefill_len.np[req_idx] = prefill_len
self.prefill_token_ids.np[req_idx, :prefill_len] = prefill_token_ids self.prefill_token_ids.stage_write(req_idx, 0, prefill_token_ids)
self.num_computed_prefill_tokens[req_idx] = num_computed_tokens self.num_computed_prefill_tokens[req_idx] = num_computed_tokens
# FIXME(woosuk): This triggers a GPU operation whenever adding a new request. self.num_computed_tokens.stage_write_elem(req_idx, num_computed_tokens)
# Optimize this.
self.num_computed_tokens[req_idx] = num_computed_tokens
if lora_request is not None: if lora_request is not None:
self.lora_ids[req_idx] = lora_request.lora_int_id self.lora_ids[req_idx] = lora_request.lora_int_id
...@@ -169,13 +160,7 @@ class RequestState: ...@@ -169,13 +160,7 @@ class RequestState:
self.presence_penalty.np[req_idx] = sampling_params.presence_penalty self.presence_penalty.np[req_idx] = sampling_params.presence_penalty
if use_penalty(sampling_params): if use_penalty(sampling_params):
bincount( self._penalties_reqs.append(req_idx)
self.prefill_token_ids.gpu[req_idx],
prefill_len,
prompt_len,
self.prompt_bin_mask[req_idx],
self.output_bin_counts[req_idx],
)
if sampling_params.seed is not None: if sampling_params.seed is not None:
seed = sampling_params.seed seed = sampling_params.seed
...@@ -193,6 +178,22 @@ class RequestState: ...@@ -193,6 +178,22 @@ class RequestState:
needs_prompt_logprobs = sampling_params.prompt_logprobs is not None needs_prompt_logprobs = sampling_params.prompt_logprobs is not None
self.needs_prompt_logprobs[req_idx] = needs_prompt_logprobs self.needs_prompt_logprobs[req_idx] = needs_prompt_logprobs
def apply_staged_writes(self) -> None:
self.prefill_len.copy_to_uva()
self.prefill_token_ids.apply_write()
self.num_computed_tokens.apply_write()
# TODO(woosuk): Optimize this.
for req_idx in self._penalties_reqs:
bincount(
self.prefill_token_ids.gpu[req_idx],
int(self.prefill_len.np[req_idx]),
int(self.prompt_len[req_idx]),
self.prompt_bin_mask[req_idx],
self.output_bin_counts[req_idx],
)
self._penalties_reqs.clear()
def remove_request(self, req_id: str) -> None: def remove_request(self, req_id: str) -> None:
self.extra_data.pop(req_id, None) self.extra_data.pop(req_id, None)
req_idx = self.req_id_to_index.pop(req_id, None) req_idx = self.req_id_to_index.pop(req_id, None)
...@@ -208,30 +209,25 @@ class RequestState: ...@@ -208,30 +209,25 @@ class RequestState:
idx_mapping_np: np.ndarray, idx_mapping_np: np.ndarray,
pos: torch.Tensor, pos: torch.Tensor,
) -> SamplingMetadata: ) -> SamplingMetadata:
temperature = self.temperature.np[idx_mapping_np] temperature = self.temperature.copy_to_uva()
temperature = self.temperature.copy_np_to_gpu(temperature)
top_p = self.top_p.np[idx_mapping_np] top_p = self.top_p.np[idx_mapping_np]
no_top_p = np.all(top_p == 1.0) no_top_p = np.all(top_p == 1.0)
top_p = self.top_p.copy_np_to_gpu(top_p) if not no_top_p else None top_p = self.top_p.copy_to_uva()[idx_mapping] if not no_top_p else None
top_k = self.top_k.np[idx_mapping_np] top_k = self.top_k.np[idx_mapping_np]
no_top_k = np.all(top_k == self.vocab_size) no_top_k = np.all(top_k == self.vocab_size)
top_k = self.top_k.copy_np_to_gpu(top_k) if not no_top_k else None top_k = self.top_k.copy_to_uva()[idx_mapping] if not no_top_k else None
min_p = self.min_p.np[idx_mapping_np] min_p = self.min_p.np[idx_mapping_np]
no_min_p = np.all(min_p == 0.0) no_min_p = np.all(min_p == 0.0)
min_p = self.min_p.copy_np_to_gpu(min_p) if not no_min_p else None min_p = self.min_p.copy_to_uva() if not no_min_p else None
rep_penalty = self.repetition_penalty.np[idx_mapping_np] rep_penalty = self.repetition_penalty.copy_to_uva()
rep_penalty = self.repetition_penalty.copy_np_to_gpu(rep_penalty) freq_penalty = self.frequency_penalty.copy_to_uva()
freq_penalty = self.frequency_penalty.np[idx_mapping_np] pres_penalty = self.presence_penalty.copy_to_uva()
freq_penalty = self.frequency_penalty.copy_np_to_gpu(freq_penalty)
pres_penalty = self.presence_penalty.np[idx_mapping_np]
pres_penalty = self.presence_penalty.copy_np_to_gpu(pres_penalty)
seeds = self.seeds.np[idx_mapping_np] seeds = self.seeds.copy_to_uva()
seeds = self.seeds.copy_np_to_gpu(seeds)
num_logprobs = self.num_logprobs[idx_mapping_np] num_logprobs = self.num_logprobs[idx_mapping_np]
max_num_logprobs: int | None = int(np.max(num_logprobs)) max_num_logprobs: int | None = int(np.max(num_logprobs))
...@@ -239,6 +235,7 @@ class RequestState: ...@@ -239,6 +235,7 @@ class RequestState:
max_num_logprobs = None max_num_logprobs = None
return SamplingMetadata( return SamplingMetadata(
idx_mapping=idx_mapping,
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
top_k=top_k, top_k=top_k,
...@@ -246,12 +243,11 @@ class RequestState: ...@@ -246,12 +243,11 @@ class RequestState:
repetition_penalty=rep_penalty, repetition_penalty=rep_penalty,
frequency_penalty=freq_penalty, frequency_penalty=freq_penalty,
presence_penalty=pres_penalty, presence_penalty=pres_penalty,
prompt_bin_mask=self.prompt_bin_mask,
output_bin_counts=self.output_bin_counts,
seeds=seeds, seeds=seeds,
pos=pos, pos=pos,
max_num_logprobs=max_num_logprobs, max_num_logprobs=max_num_logprobs,
idx_mapping=idx_mapping,
prompt_bin_mask=self.prompt_bin_mask,
output_bin_counts=self.output_bin_counts,
) )
def make_lora_inputs( def make_lora_inputs(
...@@ -272,42 +268,12 @@ class RequestState: ...@@ -272,42 +268,12 @@ class RequestState:
return prompt_lora_mapping, token_lora_mapping, active_lora_requests return prompt_lora_mapping, token_lora_mapping, active_lora_requests
class Param:
def __init__(
self,
size: int,
dtype: torch.dtype,
device: torch.device,
pin_memory: bool,
):
self.buffer = CpuGpuBuffer(
size,
dtype=dtype,
device=device,
pin_memory=pin_memory,
)
self.np = np.zeros_like(self.buffer.np)
def copy_np_to_gpu(self, x: np.ndarray) -> torch.Tensor:
n = x.shape[0]
self.buffer.np[:n] = x
return self.buffer.copy_to_gpu(n)
@dataclass @dataclass
class ExtraData: class ExtraData:
lora_request: LoRARequest | None lora_request: LoRARequest | None
in_progress_prompt_logprobs: list[LogprobsTensors] = field(default_factory=list) in_progress_prompt_logprobs: list[LogprobsTensors] = field(default_factory=list)
class UvaBuffer:
def __init__(self, *size: int | torch.SymInt, dtype: torch.dtype):
assert is_uva_available()
self.cpu = torch.zeros(*size, dtype=dtype, device="cpu", pin_memory=True)
self.np = self.cpu.numpy()
self.gpu = get_cuda_view_from_cpu_tensor(self.cpu)
def use_penalty(sampling_params: SamplingParams) -> bool: def use_penalty(sampling_params: SamplingParams) -> bool:
return ( return (
sampling_params.repetition_penalty != 1.0 sampling_params.repetition_penalty != 1.0
......
...@@ -4,38 +4,65 @@ import numpy as np ...@@ -4,38 +4,65 @@ import numpy as np
import torch import torch
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.input_batch import InputBuffers from vllm.utils.math_utils import cdiv
from vllm.v1.worker.gpu.buffer_utils import UvaBufferPool
from vllm.v1.worker.gpu.input_batch import InputBatch
def apply_grammar_bitmask( class StructuredOutputsWorker:
logits: torch.Tensor, def __init__(
req_ids: list[str], self,
grammar_req_ids: list[str], max_num_logits: int,
grammar_bitmask: np.ndarray, vocab_size: int,
input_buffers: InputBuffers, ):
) -> None: # NOTE(woosuk): Here, we use UvaBufferPool instead of UvaBackedTensor
input_buffers.grammar_bitmask.np[: grammar_bitmask.shape[0]] = grammar_bitmask # to save an unnecessary CPU-to-CPU copy.
input_buffers.grammar_bitmask.copy_to_gpu(grammar_bitmask.shape[0]) self.logits_indices = UvaBufferPool(max_num_logits, torch.int32)
self.grammar_bitmask = UvaBufferPool(
(max_num_logits, cdiv(vocab_size, 32)), torch.int32
)
batch_size = logits.shape[0] def apply_grammar_bitmask(
grammar_req_id_to_idx = {req_id: i for i, req_id in enumerate(grammar_req_ids)} self,
# logits -> bitmask mapping logits: torch.Tensor,
mapping = [grammar_req_id_to_idx.get(req_id, -1) for req_id in req_ids] input_batch: InputBatch,
input_buffers.bitmask_indices.np[:batch_size] = mapping grammar_req_ids: list[str],
input_buffers.bitmask_indices.copy_to_gpu(batch_size) grammar_bitmask: np.ndarray,
) -> None:
if not grammar_req_ids:
return
vocab_size = logits.shape[-1] # Construct bitmask -> logits mapping
BLOCK_SIZE = 8192 mapping: list[int] = []
grid = (batch_size, triton.cdiv(vocab_size, BLOCK_SIZE)) req_ids = input_batch.req_ids
_apply_grammar_bitmask_kernel[grid]( cu_num_logits = input_batch.cu_num_logits_np.tolist()
logits, req_id_to_idx = {req_id: i for i, req_id in enumerate(req_ids)}
logits.stride(0), for grammar_req_id in grammar_req_ids:
input_buffers.grammar_bitmask.gpu, req_idx = req_id_to_idx[grammar_req_id]
input_buffers.grammar_bitmask.gpu.stride(0), logits_start_idx = cu_num_logits[req_idx]
input_buffers.bitmask_indices.gpu, logits_end_idx = cu_num_logits[req_idx + 1]
vocab_size, mapping.extend(range(logits_start_idx, logits_end_idx))
BLOCK_SIZE=BLOCK_SIZE, # Copy the mapping.
) mapping_np = np.array(mapping, dtype=np.int32)
logits_indices = self.logits_indices.copy_to_uva(mapping_np)
# Copy the bitmask.
bitmask = self.grammar_bitmask.copy_to_uva(grammar_bitmask)
num_masks = bitmask.shape[0]
assert num_masks == len(mapping)
vocab_size = logits.shape[-1]
BLOCK_SIZE = 8192
grid = (num_masks, triton.cdiv(vocab_size, BLOCK_SIZE))
_apply_grammar_bitmask_kernel[grid](
logits,
logits.stride(0),
logits_indices,
bitmask,
bitmask.stride(0),
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
)
# Adapted from # Adapted from
...@@ -44,17 +71,14 @@ def apply_grammar_bitmask( ...@@ -44,17 +71,14 @@ def apply_grammar_bitmask(
def _apply_grammar_bitmask_kernel( def _apply_grammar_bitmask_kernel(
logits_ptr, logits_ptr,
logits_stride, logits_stride,
logits_indices_ptr,
bitmask_ptr, bitmask_ptr,
bitmask_stride, bitmask_stride,
bitmask_indices_ptr,
vocab_size, vocab_size,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
): ):
logits_idx = tl.program_id(0) bitmask_idx = tl.program_id(0)
bitmask_idx = tl.load(bitmask_indices_ptr + logits_idx) logits_idx = tl.load(logits_indices_ptr + bitmask_idx)
if bitmask_idx == -1:
# No bitmask to apply.
return
# Load the bitmask. # Load the bitmask.
block_id = tl.program_id(1) block_id = tl.program_id(1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment