Unverified Commit 9042d683 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Misc] Consolidate and optimize logic for building padded tensors (#6541)

parent 3f8d42c8
...@@ -21,7 +21,8 @@ from vllm.distributed import (destroy_distributed_environment, ...@@ -21,7 +21,8 @@ from vllm.distributed import (destroy_distributed_environment,
from vllm.inputs import TextPrompt from vllm.inputs import TextPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import SampleLogprobs from vllm.sequence import SampleLogprobs
from vllm.utils import cuda_device_count_stateless, is_cpu from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
is_cpu)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -124,12 +125,6 @@ def image_assets() -> _ImageAssets: ...@@ -124,12 +125,6 @@ def image_assets() -> _ImageAssets:
return IMAGE_ASSETS return IMAGE_ASSETS
_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.half,
"bfloat16": torch.bfloat16,
"float": torch.float,
}
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding) _T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding)
...@@ -151,8 +146,7 @@ class HfRunner: ...@@ -151,8 +146,7 @@ class HfRunner:
is_vision_model: bool = False, is_vision_model: bool = False,
is_sparseml_model: bool = False, is_sparseml_model: bool = False,
) -> None: ) -> None:
assert dtype in _STR_DTYPE_TO_TORCH_DTYPE torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
self.model_name = model_name self.model_name = model_name
......
...@@ -306,11 +306,8 @@ class FlashAttentionMetadataBuilder( ...@@ -306,11 +306,8 @@ class FlashAttentionMetadataBuilder(
input_block_tables[i, :len(block_table)] = block_table input_block_tables[i, :len(block_table)] = block_table
block_tables = torch.tensor(input_block_tables, device=device) block_tables = torch.tensor(input_block_tables, device=device)
else: else:
max_block_table_len = max(
len(block_table) for block_table in self.block_tables)
block_tables = make_tensor_with_pad( block_tables = make_tensor_with_pad(
self.block_tables, self.block_tables,
max_len=max_block_table_len,
pad=0, pad=0,
dtype=torch.int, dtype=torch.int,
device=device, device=device,
......
...@@ -344,11 +344,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -344,11 +344,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
cuda_graph_pad_size) cuda_graph_pad_size)
self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size) self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size)
else: else:
max_block_table_len = max(
len(block_table) for block_table in self.block_tables)
block_tables = make_tensor_with_pad( block_tables = make_tensor_with_pad(
self.block_tables, self.block_tables,
max_len=max_block_table_len,
pad=0, pad=0,
dtype=torch.int, dtype=torch.int,
device=device, device=device,
......
...@@ -182,11 +182,8 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): ...@@ -182,11 +182,8 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
input_block_tables[i, :len(block_table)] = block_table input_block_tables[i, :len(block_table)] = block_table
block_tables = torch.tensor(input_block_tables, device=device) block_tables = torch.tensor(input_block_tables, device=device)
else: else:
max_block_table_len = max(
len(block_table) for block_table in self.block_tables)
block_tables = make_tensor_with_pad( block_tables = make_tensor_with_pad(
self.block_tables, self.block_tables,
max_len=max_block_table_len,
pad=0, pad=0,
dtype=torch.int, dtype=torch.int,
device=device, device=device,
......
...@@ -2,14 +2,13 @@ import random ...@@ -2,14 +2,13 @@ import random
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import numpy as np
import torch import torch
from vllm.model_executor.layers.ops.sample import get_num_triton_sampler_splits from vllm.model_executor.layers.ops.sample import get_num_triton_sampler_splits
from vllm.sampling_params import SamplingParams, SamplingType from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData, SequenceGroupMetadata from vllm.sequence import SequenceData, SequenceGroupMetadata
from vllm.utils import (async_tensor_h2d, is_pin_memory_available, from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
maybe_expand_dim) make_tensor_with_pad, maybe_expand_dim)
_SAMPLING_EPS = 1e-5 _SAMPLING_EPS = 1e-5
_SEED_0_REPLACEMENT = 3403598558 _SEED_0_REPLACEMENT = 3403598558
...@@ -466,22 +465,24 @@ class SamplingTensors: ...@@ -466,22 +465,24 @@ class SamplingTensors:
do_penalties = prompt_tokens or output_tokens do_penalties = prompt_tokens or output_tokens
if do_penalties: if do_penalties:
prompt_max_len = max([len(tokens) for tokens in prompt_tokens], prompt_t = make_tensor_with_pad(
default=0) prompt_tokens,
prompt_padded_tokens = np.full(
(len(prompt_tokens), prompt_max_len),
vocab_size, vocab_size,
dtype=np.int64) device="cpu",
for i, tokens in enumerate(prompt_tokens): dtype=torch.int64,
prompt_padded_tokens[i, :len(tokens)] = tokens pin_memory=pin_memory,
output_max_len = max([len(tokens) for tokens in output_tokens], )
default=0) output_t = make_tensor_with_pad(
output_padded_tokens = np.full( output_tokens,
(len(output_tokens), output_max_len),
vocab_size, vocab_size,
dtype=np.int64) device="cpu",
for i, tokens in enumerate(output_tokens): dtype=torch.int64,
output_padded_tokens[i, :len(tokens)] = tokens pin_memory=pin_memory,
)
else:
empty_tensor = torch.empty(0, device=device, dtype=torch.long)
prompt_t = empty_tensor
output_t = empty_tensor
temperatures_t = torch.tensor( temperatures_t = torch.tensor(
temperatures, temperatures,
...@@ -531,15 +532,6 @@ class SamplingTensors: ...@@ -531,15 +532,6 @@ class SamplingTensors:
dtype=torch.long, dtype=torch.long,
pin_memory=pin_memory, pin_memory=pin_memory,
) )
if do_penalties:
prompt_tensor = torch.from_numpy(prompt_padded_tokens)
output_tensor = torch.from_numpy(output_padded_tokens)
if pin_memory:
prompt_tensor = prompt_tensor.pin_memory()
output_tensor = output_tensor.pin_memory()
else:
prompt_tensor = None
output_tensor = None
# need to transpose and make contiguous to # need to transpose and make contiguous to
# copy the tensor correctly. # copy the tensor correctly.
# [batch_size, n_seeds] -> [n_seeds, batch_size] # [batch_size, n_seeds] -> [n_seeds, batch_size]
...@@ -562,16 +554,6 @@ class SamplingTensors: ...@@ -562,16 +554,6 @@ class SamplingTensors:
extra_seeds_gpu = None extra_seeds_gpu = None
sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds] sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds]
if do_penalties:
prompt_tokens_gpu = prompt_tensor.to(device=device,
non_blocking=True)
output_tokens_gpu = output_tensor.to(device=device,
non_blocking=True)
else:
empty_tensor = torch.empty(0, device=device, dtype=torch.long)
prompt_tokens_gpu = empty_tensor
output_tokens_gpu = empty_tensor
return cls( return cls(
temperatures=temperatures_t.to(device=device, non_blocking=True), temperatures=temperatures_t.to(device=device, non_blocking=True),
top_ps=top_ps_t.to(device=device, non_blocking=True), top_ps=top_ps_t.to(device=device, non_blocking=True),
...@@ -583,8 +565,8 @@ class SamplingTensors: ...@@ -583,8 +565,8 @@ class SamplingTensors:
non_blocking=True), non_blocking=True),
repetition_penalties=repetition_penalties_t.to(device=device, repetition_penalties=repetition_penalties_t.to(device=device,
non_blocking=True), non_blocking=True),
prompt_tokens=prompt_tokens_gpu, prompt_tokens=prompt_t.to(device=device, non_blocking=True),
output_tokens=output_tokens_gpu, output_tokens=output_t.to(device=device, non_blocking=True),
sampling_seeds=sampling_seeds_gpu, sampling_seeds=sampling_seeds_gpu,
sample_indices=sample_indices_t.to(device=device, sample_indices=sample_indices_t.to(device=device,
non_blocking=True), non_blocking=True),
......
...@@ -20,6 +20,7 @@ from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic, ...@@ -20,6 +20,7 @@ from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
Union) Union)
import numpy as np import numpy as np
import numpy.typing as npt
import psutil import psutil
import torch import torch
import torch.types import torch.types
...@@ -40,6 +41,15 @@ STR_DTYPE_TO_TORCH_DTYPE = { ...@@ -40,6 +41,15 @@ STR_DTYPE_TO_TORCH_DTYPE = {
"fp8_e5m2": torch.uint8, "fp8_e5m2": torch.uint8,
} }
TORCH_DTYPE_TO_NUMPY_DTYPE = {
torch.float16: np.float16,
torch.float32: np.float32,
torch.float64: np.float64,
torch.uint8: np.uint8,
torch.int32: np.int32,
torch.int64: np.int64,
}
P = ParamSpec('P') P = ParamSpec('P')
K = TypeVar("K") K = TypeVar("K")
T = TypeVar("T") T = TypeVar("T")
...@@ -617,23 +627,54 @@ def str_to_int_tuple(s: str) -> Tuple[int, ...]: ...@@ -617,23 +627,54 @@ def str_to_int_tuple(s: str) -> Tuple[int, ...]:
f"(e.g., 1, 2, 3). Given input: {s}") from e f"(e.g., 1, 2, 3). Given input: {s}") from e
def make_tensor_with_pad( def make_ndarray_with_pad(
x: List[List[int]], x: List[List[T]],
max_len: int, pad: T,
pad: int, dtype: npt.DTypeLike,
dtype: torch.dtype, *,
device: Optional[Union[str, torch.device]], max_len: Optional[int] = None,
) -> torch.Tensor: ) -> npt.NDArray:
"""Make a padded tensor of a 2D inputs. """
Make a padded array from 2D inputs.
The padding is applied to the end of each inner list until it reaches The padding is applied to the end of each inner list until it reaches
`max_len`. `max_len`.
""" """
padded_x = np.zeros([len(x), max_len], dtype=np.int32) + pad if max_len is None:
# Unlike for most functions, map is faster than a genexpr over `len`
max_len = max(map(len, x), default=0)
padded_x = np.full((len(x), max_len), pad, dtype=dtype)
for ind, blocktb in enumerate(x): for ind, blocktb in enumerate(x):
assert len(blocktb) <= max_len assert len(blocktb) <= max_len
padded_x[ind, :len(blocktb)] = blocktb padded_x[ind, :len(blocktb)] = blocktb
return torch.tensor(padded_x, dtype=dtype, device=device)
return padded_x
def make_tensor_with_pad(
x: List[List[T]],
pad: T,
dtype: torch.dtype,
*,
max_len: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
pin_memory: bool = False,
) -> torch.Tensor:
"""
Make a padded tensor from 2D inputs.
The padding is applied to the end of each inner list until it reaches
`max_len`.
"""
np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype]
padded_x = make_ndarray_with_pad(x, pad, np_dtype, max_len=max_len)
tensor = torch.from_numpy(padded_x).to(device)
if pin_memory:
tensor = tensor.pin_memory()
return tensor
def async_tensor_h2d( def async_tensor_h2d(
......
...@@ -276,11 +276,8 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]): ...@@ -276,11 +276,8 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
dtype=torch.int, dtype=torch.int,
device=self.device) device=self.device)
max_block_table_len = max(
len(block_table) for block_table in block_tables)
block_tables = make_tensor_with_pad( block_tables = make_tensor_with_pad(
block_tables, block_tables,
max_len=max_block_table_len,
pad=0, pad=0,
dtype=torch.int, dtype=torch.int,
device=self.device, device=self.device,
......
...@@ -121,13 +121,13 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): ...@@ -121,13 +121,13 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
max_seq_len = max(seq_lens) max_seq_len = max(seq_lens)
assert max_seq_len > 0 assert max_seq_len > 0
input_tokens = make_tensor_with_pad(input_tokens, input_tokens = make_tensor_with_pad(input_tokens,
max_seq_len,
pad=0, pad=0,
max_len=max_seq_len,
dtype=torch.long, dtype=torch.long,
device=self.device) device=self.device)
input_positions = make_tensor_with_pad(input_positions, input_positions = make_tensor_with_pad(input_positions,
max_seq_len,
pad=0, pad=0,
max_len=max_seq_len,
dtype=torch.long, dtype=torch.long,
device=self.device) device=self.device)
input_block_ids = torch.tensor(input_block_ids, input_block_ids = torch.tensor(input_block_ids,
...@@ -171,13 +171,13 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): ...@@ -171,13 +171,13 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
input_block_ids.append(block_table[0]) input_block_ids.append(block_table[0])
input_tokens = make_tensor_with_pad(input_tokens, input_tokens = make_tensor_with_pad(input_tokens,
max_len=1,
pad=0, pad=0,
max_len=1,
dtype=torch.long, dtype=torch.long,
device=self.device) device=self.device)
input_positions = make_tensor_with_pad(input_positions, input_positions = make_tensor_with_pad(input_positions,
max_len=1,
pad=0, pad=0,
max_len=1,
dtype=torch.long, dtype=torch.long,
device=self.device) device=self.device)
context_lens = torch.tensor(context_lens, context_lens = torch.tensor(context_lens,
......
...@@ -335,11 +335,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]): ...@@ -335,11 +335,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
dtype=torch.int, dtype=torch.int,
device=self.device) device=self.device)
max_block_table_len = max(
len(block_table) for block_table in block_tables)
block_tables = make_tensor_with_pad( block_tables = make_tensor_with_pad(
block_tables, block_tables,
max_len=max_block_table_len,
pad=0, pad=0,
dtype=torch.int, dtype=torch.int,
device=self.device, device=self.device,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment