Unverified commit 546034b4, authored by Simon Mo and committed by GitHub
Browse files

[refactor] remove triton based sampler (#8524)

parent cca61642
import random
import pytest
import torch
from vllm.model_executor.layers.ops.rand import seeded_uniform
from vllm.model_executor.utils import set_random_seed
@pytest.mark.parametrize("dtype",
                         [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("use_3d", [True, False])
def test_seeded_uniform(dtype: torch.dtype, use_3d: bool):
    """Check that seeded_uniform is reproducible for fixed per-row seeds.

    For each of 512 global seeds a random 2D or 3D shape is drawn. The
    output for one seed tensor is then compared against two repeated calls
    and against a call that writes into a pre-filled ``out`` tensor.
    """
    device = "cuda"
    long_info = torch.iinfo(torch.long)

    for global_seed in range(512):
        set_random_seed(global_seed)
        # Draw order (rows, cols, optional middle dim) is kept so that the
        # shapes generated for a given global seed do not change.
        n_rows = random.randint(1, 512)
        n_cols = random.randint(1, 64000)
        if use_3d:
            shape = [n_rows, random.randint(2, 10), n_cols]
        else:
            shape = [n_rows, n_cols]
        row_seeds = torch.randint(long_info.min,
                                  long_info.max, (n_rows, ),
                                  device=device)

        reference = seeded_uniform(*shape,
                                   seeds=row_seeds,
                                   dtype=dtype,
                                   device=device)

        # The same seeds must reproduce the same values. Each repeat is
        # dropped right after the comparison to save memory.
        for _ in range(2):
            repeat = seeded_uniform(*shape,
                                    seeds=row_seeds,
                                    dtype=dtype,
                                    device=device)
            torch.testing.assert_close(reference, repeat)
            del repeat

        # Pre-fill the out tensor with garbage to ensure it is overwritten
        garbage = torch.full((*shape, ), -1, dtype=dtype, device=device)
        written = seeded_uniform(*shape,
                                 out=garbage,
                                 seeds=row_seeds,
                                 dtype=dtype)
        torch.testing.assert_close(reference, written)
import gc
from unittest.mock import patch
import pytest
import torch
import triton
import triton.language as tl
from vllm.model_executor.layers.ops.sample import (_sample_triton,
_uniform_to_exponential,
sample)
from vllm.model_executor.sampling_metadata import SamplingTensors
from vllm.model_executor.utils import set_random_seed
from vllm.triton_utils.libentry import LibEntry
from vllm.triton_utils.sample import (MAX_TRITON_N_COLS,
get_num_triton_sampler_splits)
# Vocabulary sizes used to parametrize the sampling tests below.
SINGLE_SPLIT_VOCAB_SIZE = 32000 # llama/mistral/mixtral vocab size
# 100 columns more than MAX_TRITON_N_COLS; presumably this makes
# get_num_triton_sampler_splits return more than one split - TODO confirm.
MULTI_SPLIT_VOCAB_SIZE = MAX_TRITON_N_COLS + 100
@pytest.fixture(autouse=True)
def _cleanup():
    """Autouse fixture that frees memory once a test has finished."""
    # Nothing to set up; the statements after the yield run as teardown.
    yield
    gc.collect()
    torch.cuda.empty_cache()
@triton.jit
def _uniform_to_exponential_kernel(input, output, n: tl.constexpr):
    # Test-only wrapper: applies _uniform_to_exponential to the first n
    # elements of `input` and writes the results to `output`.
    offsets = tl.arange(0, n)
    uniform = tl.load(input + offsets)
    exponential = _uniform_to_exponential(uniform)
    tl.store(output + offsets, exponential)
def test_uniform_to_exponential():
    """Test that we can convert uniform to exponential without div by 0."""
    # The two extremes of the uniform range: exactly 0 and just below 1.
    boundary_values = [0.0, 1.0 - torch.finfo(torch.float32).eps]
    uniform = torch.tensor(boundary_values,
                           dtype=torch.float32,
                           device="cuda")
    exponential = torch.zeros_like(uniform)

    _uniform_to_exponential_kernel[(1, )](uniform, exponential, 2)

    assert torch.all(torch.isfinite(exponential))
    assert torch.all(exponential > 0)
    # Dividing by the result must stay finite as well.
    reciprocal = torch.full_like(exponential, 1.0) / exponential
    assert torch.all(torch.isfinite(reciprocal))
@pytest.mark.parametrize("random_sampling", [True, False, "mixed"])
@pytest.mark.parametrize("max_best_of", [1, 2, 3, 4, 5])
@pytest.mark.parametrize("modify_greedy_probs", [True, False])
@pytest.mark.parametrize("seed", [1337])
@pytest.mark.parametrize("vocab_size",
                         [SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE])
@pytest.mark.parametrize("save_logprobs", [True, False])
def test_sample_decoding_only(random_sampling, max_best_of,
                              modify_greedy_probs, seed, vocab_size,
                              save_logprobs):
    """Sample from one-hot rows where every row is a sampled sequence.

    Each row puts all probability on a single known token, so both greedy
    and random sampling must return that token. The test also checks the
    modified probabilities and the optional saved logprobs.
    """
    set_random_seed(seed)
    bs = 8
    hot_stride = vocab_size // bs
    probs = torch.zeros((bs, vocab_size), dtype=torch.float32, device="cuda")
    for row in range(bs):
        probs[row, row * hot_stride] = 1.0
    logprobs = torch.rand_like(probs)
    sample_indices = torch.arange(bs, dtype=torch.long, device="cuda")
    n_splits = get_num_triton_sampler_splits(probs.shape[1])

    mask_shape = (n_splits, bs)
    if random_sampling == "mixed":
        # One random choice per request, shared by all splits
        per_request = torch.rand((1, bs), device="cuda") < 0.5
        random_sampling_mask = per_request.expand(n_splits, bs)
    else:
        random_sampling_mask = torch.full(mask_shape,
                                          bool(random_sampling),
                                          dtype=torch.bool,
                                          device="cuda")

    # Multiplying by the mask zeroes the seed of every greedy request
    seeds = torch.randint(1,
                          torch.iinfo(torch.long).max,
                          mask_shape,
                          device="cuda").mul_(random_sampling_mask)

    # The current _sample_triton does not utilize the
    # libentry decoration. The purpose of adding this patch is to test
    # the correctness of libentry.
    with patch("vllm.model_executor.layers.ops.sample._sample_triton",
               LibEntry(_sample_triton)):
        sampled_tokens, sampled_logprobs, sampled_modified_probs = sample(
            probs=probs,
            logprobs=logprobs,
            sample_indices=sample_indices,
            seeds=seeds,
            max_best_of=max_best_of,
            modify_greedy_probs=modify_greedy_probs,
            save_logprobs=save_logprobs,
            _save_modified_probs=True)

    expected_shape = (bs, max_best_of)
    assert sampled_tokens.shape == expected_shape
    for row in range(bs):
        assert torch.all(sampled_tokens[row] == row * hot_stride)
        request_is_random = bool(random_sampling_mask[0, row])
        if request_is_random:
            # If the request is random, we want to make sure
            # sampled_modified_probs tensor has noise added
            # (and thus is different from probs tensor)
            assert not torch.allclose(sampled_modified_probs[row][0],
                                      probs[row][sampled_tokens[row]])
        elif modify_greedy_probs:
            # If we are modifying greedy probs and the request is greedy,
            # we want to make sure the probs tensor is modified in place
            picked = probs[row][sampled_tokens[row]]
            torch.testing.assert_close(picked, torch.full_like(picked, 1.0))
            assert torch.sum(probs[row]) == 1.0
            first_modified = sampled_modified_probs[row][0]
            torch.testing.assert_close(
                first_modified, torch.full_like(first_modified, 1.0))
        else:
            # If the request is greedy and we are not modifying greedy probs,
            # we want to make sure sampled_modified_probs tensor is the same
            # as the probs tensor.
            torch.testing.assert_close(sampled_modified_probs[row],
                                       probs[row][sampled_tokens[row]])

    if not save_logprobs:
        assert sampled_logprobs is None
    else:
        assert sampled_logprobs.shape == expected_shape
        for row in range(bs):
            for best_of in range(max_best_of):
                expected_logprob = logprobs[row][sampled_tokens[row, best_of]]
                assert torch.all(sampled_logprobs[row] == expected_logprob)
@pytest.mark.parametrize("random_sampling", [True, False, "mixed"])
@pytest.mark.parametrize("max_best_of", [1, 2, 3, 4, 5])
@pytest.mark.parametrize("modify_greedy_probs", [True, False])
@pytest.mark.parametrize("seed", [1337])
@pytest.mark.parametrize("vocab_size",
                         [SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE])
def test_sample_prompt_logprobs(random_sampling, max_best_of,
                                modify_greedy_probs, seed, vocab_size):
    """Sample from a subset of rows selected through sample_indices.

    The batch has 8 sampled rows plus prompt rows, and only the rows at the
    cumulative prompt offsets are sampled. Every row is one-hot, so the
    expected token and logprob of each sampled row are known.
    """
    set_random_seed(seed)
    prompt_sizes = [16, 32, 64, 128] * 2
    samples = 8
    bs = samples + sum(prompt_sizes)
    hot_stride = vocab_size // bs
    probs = torch.zeros((bs, vocab_size), dtype=torch.float32, device="cuda")
    for row in range(bs):
        probs[row, row * hot_stride] = 1.0
    logprobs = torch.rand_like(probs)
    sample_indices = torch.tensor(prompt_sizes,
                                  dtype=torch.long,
                                  device="cuda").cumsum_(0)
    n_splits = get_num_triton_sampler_splits(probs.shape[1])

    mask_shape = (n_splits, samples)
    if random_sampling == "mixed":
        random_sampling_mask = torch.rand(mask_shape, device="cuda") < 0.5
    else:
        random_sampling_mask = torch.full(mask_shape,
                                          bool(random_sampling),
                                          dtype=torch.bool,
                                          device="cuda")

    # Multiplying by the mask zeroes the seed wherever the mask is False
    seeds = torch.randint(1,
                          torch.iinfo(torch.long).max,
                          mask_shape,
                          device="cuda").mul_(random_sampling_mask)

    # _sample_triton is wrapped in LibEntry for the duration of the call so
    # that the LibEntry wrapper is exercised as well.
    with patch("vllm.model_executor.layers.ops.sample._sample_triton",
               LibEntry(_sample_triton)):
        sampled_tokens, sampled_logprobs, _ = sample(
            probs=probs,
            logprobs=logprobs,
            sample_indices=sample_indices,
            seeds=seeds,
            max_best_of=max_best_of,
            modify_greedy_probs=modify_greedy_probs,
            save_logprobs=True)

    expected_shape = (samples, max_best_of)
    assert sampled_tokens.shape == expected_shape
    assert sampled_logprobs.shape == expected_shape
    for slot, batch_row in enumerate(sample_indices):
        assert torch.all(sampled_tokens[slot] == batch_row * hot_stride)
        for best_of in range(max_best_of):
            expected_logprob = logprobs[batch_row][sampled_tokens[slot,
                                                                  best_of]]
            assert torch.all(sampled_logprobs[slot] == expected_logprob)
@pytest.mark.parametrize("seed", list(range(16)))
def test_get_sequence_seeds(seed):
    """Ensure that we get a different child seed from base
    seed + extra entropy"""
    extra_entropy = 1
    previous_seed = None
    for step in range(512):
        plain_seed = SamplingTensors._get_sequence_seeds(
            seed, step, seeds_to_generate=1, is_greedy=False)[0]
        entropy_seed = SamplingTensors._get_sequence_seeds(
            seed,
            step,
            extra_entropy,
            seeds_to_generate=1,
            is_greedy=False)[0]
        # Extra entropy must change the derived seed ...
        assert entropy_seed != plain_seed
        # ... and consecutive steps must not repeat the previous seed.
        assert previous_seed != plain_seed
        previous_seed = plain_seed
from typing import Optional, Union
import torch
import triton
import triton.language as tl
def seeded_uniform(
    *size,
    seeds: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    dtype: Optional[torch.dtype] = None,
    device: Optional[Union[torch.device, str]] = None,
    pin_memory: Optional[bool] = False,
) -> torch.Tensor:
    """Similar to torch.rand, but allows for seeds to be set per row.

    seeds must be a 1d tensor. The output tensor may be 1d, 2d, or 3d.
    If it is 3d, the additional seeds needed will be derived automatically
    in a deterministic fashion:
    [
        row 0: [columns_with_seed_0], [columns_with_seed_0^1], ...
    ]

    Args:
        *size: Output shape (1 to 3 dimensions), given either as separate
            ints or as a single tuple/list, as torch.rand accepts.
        seeds: 1D tensor with one seed per output row (a single seed for
            a 1D output).
        out: Optional tensor to write into. Its shape must equal ``size``
            and its last dimension must have stride 1.
        dtype: dtype of the output; only used when ``out`` is None.
        device: Device of the output; only used when ``out`` is None.
        pin_memory: Forwarded to torch.empty when ``out`` is None.

    Returns:
        The filled tensor (``out`` itself when it was provided).

    Raises:
        ValueError: if the shape has 0 or more than 3 dimensions, ``out``
            does not match ``size`` or is not contiguous in its last
            dimension, or ``seeds`` is not 1D with one element per row.
    """
    # Accept a single shape sequence, like torch.rand does.
    if len(size) == 1 and isinstance(size[0], (tuple, list)):
        size = tuple(size[0])

    n_dims = len(size)
    if n_dims == 0:
        raise ValueError("seeded_uniform requires at least one dimension")
    if n_dims > 3:
        raise ValueError("seeded_uniform only supports up to 3D tensors")

    if out is None:
        out = torch.empty(*size,
                          dtype=dtype,
                          device=device,
                          pin_memory=pin_memory)
    elif out.shape != size:
        raise ValueError("shape of out and size must be the same")

    if n_dims == 3:
        n_rows, n_3d, n_cols = out.shape
        stride_row = out.stride(0)
        stride_3d = out.stride(1)
    elif n_dims == 2:
        n_rows, n_cols = out.shape
        n_3d = 1
        stride_row = out.stride(0)
        stride_3d = 1
    else:
        n_cols = out.shape[0]
        n_rows = 1
        n_3d = 1
        stride_row = 1
        stride_3d = 1

    # The kernel addresses columns as `row_start + column`, i.e. it assumes
    # a unit stride in the last dimension. Reject anything else instead of
    # silently writing to the wrong elements.
    if n_cols > 1 and out.stride()[-1] != 1:
        raise ValueError("out must be contiguous in its last dimension")

    if seeds.ndim != 1:
        raise ValueError("seeds must be a 1D tensor")

    if seeds.numel() != n_rows:
        raise ValueError(
            "seeds must have the same number of elements as out has rows")

    # Nothing to generate; also avoids launching an empty/invalid grid.
    if out.numel() == 0:
        return out

    # The philox PRNG Triton uses generates 4 random numbers at once.
    # Therefore, the most efficient use of it is to divide the
    # block size by 4, and then save the generated random numbers to
    # each of the 4 slices of the tensor.
    full_block_size = triton.next_power_of_2(n_cols)
    philox_block_size = max(full_block_size // 4, 1)
    n_slices = full_block_size // philox_block_size
    num_warps = 4
    # Manual tuning. This seems to give best performance on A100 for
    # simple kernels like this.
    if philox_block_size >= 8192:
        num_warps = 32
    elif philox_block_size >= 4096:
        num_warps = 16
    elif philox_block_size >= 2048:
        num_warps = 8

    # One program per (row, 3D-slice) pair.
    _seeded_uniform_triton[(n_rows, n_3d)](
        out,
        seeds,
        stride_row,
        stride_3d,
        seeds.stride(0),
        n_rows,
        n_3d,
        n_cols,
        n_slices=n_slices,
        num_warps=num_warps,
        block_size=philox_block_size,
    )
    return out
@triton.jit
def _seeded_uniform_triton(
    out_ptr: torch.Tensor,
    seed_ptr: torch.Tensor,
    out_row_stride: int,
    out_3d_stride: int,
    seed_row_stride: int,
    n_rows: int,
    n_3d: int,
    n_cols: int,
    n_slices: tl.constexpr,
    block_size: tl.constexpr,
):
    """
    Fill one line of the output tensor with random numbers in [0, 1).

    Each program handles one (row, 3D-slice) pair. A single philox call
    yields four vectors of `block_size` values; the first `n_slices` of
    them are written to consecutive `block_size`-wide column ranges.

    Args:
        out_ptr: The output tensor.
        seed_ptr: The per-row seeds to use for random number generation.
        out_row_stride: The stride between rows of the output tensor.
        out_3d_stride: The stride between 3D slices of the output tensor.
        seed_row_stride: The stride between rows of the seed tensor.
        n_rows: The number of rows in the output tensor.
        n_3d: The size of second dimension of the output tensor,
            if output tensor is 3D.
        n_cols: The number of columns in the output tensor.
        n_slices: The number of philox outputs to use.
        block_size: The number of values in each philox output.
    """
    tl.static_assert(n_slices > 0 and n_slices <= 4, "0 < n_slices <= 4")

    # Which row / 3D slice this program is responsible for.
    row_idx = tl.program_id(axis=0)
    three_d_idx = tl.program_id(axis=1)

    slice0_offsets = tl.arange(0, block_size)

    # The row seed is used as-is for slice 0 and XOR-ed with the slice
    # index for every further slice.
    seed = tl.load(seed_ptr + row_idx * seed_row_stride)
    if three_d_idx > 0:
        seed ^= three_d_idx

    # Four vectors of random numbers in [0, 1).
    rand0, rand1, rand2, rand3 = tl.rand4x(seed, slice0_offsets)

    line_start_ptr = (out_ptr + row_idx * out_row_stride +
                      three_d_idx * out_3d_stride)

    # Columns past n_cols are masked out in every store.
    tl.store(line_start_ptr + slice0_offsets,
             rand0,
             mask=slice0_offsets < n_cols)
    if n_slices > 1:
        slice1_offsets = tl.arange(block_size, block_size * 2)
        tl.store(line_start_ptr + slice1_offsets,
                 rand1,
                 mask=slice1_offsets < n_cols)
    if n_slices > 2:
        slice2_offsets = tl.arange(block_size * 2, block_size * 3)
        tl.store(line_start_ptr + slice2_offsets,
                 rand2,
                 mask=slice2_offsets < n_cols)
    if n_slices > 3:
        slice3_offsets = tl.arange(block_size * 3, block_size * 4)
        tl.store(line_start_ptr + slice3_offsets,
                 rand3,
                 mask=slice3_offsets < n_cols)
from typing import Optional, Tuple
import torch
import triton
import triton.language as tl
from vllm.model_executor.layers.ops.rand import seeded_uniform
from vllm.triton_utils.sample import get_num_triton_sampler_splits
# Lower clamp applied to uniform noise in _uniform_to_exponential so that
# the logarithm is never taken of 0.
_EPS: tl.constexpr = 1e-6
def _multi_split_sample(
    probs: torch.Tensor,
    seeds: torch.Tensor,
    n_splits: int,
    sampled_tokens_size: Tuple[int, int],
    sampled_logprobs_size: Tuple[int, int],
    sample_indices: torch.Tensor,
    logprobs: torch.Tensor,
    *,
    modify_greedy_probs: bool = False,
    save_logprobs: bool = False,
):
    """Sample tokens where vocab size is split into multiple parts
    (too large for Triton otherwise).

    Every vocab split is sampled on its own. For each output slot the split
    whose candidate has the largest (noise-modified) probability wins.

    Args:
        probs: Probabilities to sample from, [batch_size, vocab_size].
        seeds: Seeds, one row per split: [n_splits, n]. A zero seed means
            greedy sampling.
        n_splits: Number of parts the vocab dimension is split into.
        sampled_tokens_size: Shape of the sampled-token output.
        sampled_logprobs_size: Shape of the sampled-logprob buffers.
        sample_indices: Rows of probs to sample from, [n].
        logprobs: Log-probabilities, same shape as probs.
        modify_greedy_probs: If True, the probs row of every greedy request
            is replaced in place by a one-hot row at the sampled token.
        save_logprobs: Whether to collect the logprobs of sampled tokens.

    Returns:
        (sampled_tokens, sampled_logprobs or None, sampled_modified_probs)
    """
    assert seeds.ndim == 2 and seeds.shape[0] == n_splits
    split_probs = probs.tensor_split(n_splits, 1)
    split_logprobs = logprobs.tensor_split(n_splits, 1)
    sampled_tokens_tmp = [
        torch.empty(sampled_tokens_size, dtype=torch.long, device=probs.device)
        for _ in range(n_splits)
    ]
    sampled_logprobs_tmp = [
        torch.empty(sampled_logprobs_size,
                    dtype=probs.dtype,
                    device=probs.device) for _ in range(n_splits)
    ]
    # We are purposefully using sampled_tokens_size as we need to always
    # save modified probs in this case.
    sampled_modified_probs_tmp = [
        torch.empty(sampled_tokens_size,
                    dtype=probs.dtype,
                    device=probs.device) for _ in range(n_splits)
    ]

    n_samples = sample_indices.shape[0]
    # First vocab column covered by the current split. tensor_split may
    # return splits of unequal width, so the offset has to be accumulated
    # rather than derived from the width of a single split.
    col_offset = 0
    for i in range(n_splits):
        n_cols = split_probs[i].shape[1]
        n_best = sampled_tokens_tmp[i].shape[1]
        uniform_noise = seeded_uniform(n_samples,
                                       n_best,
                                       n_cols,
                                       seeds=seeds[i].flatten(),
                                       device=split_probs[i].device,
                                       dtype=split_probs[i].dtype)
        # TODO(yard1): See if we can remove the contiguous() calls.
        # Will need kernel support.
        _sample(
            split_probs[i].contiguous(),
            split_logprobs[i].contiguous(),
            sample_indices,
            sampled_tokens_tmp[i],
            sampled_logprobs_tmp[i],
            sampled_modified_probs_tmp[i],
            seeds[i],
            uniform_noise,
            # Greedy probs are rewritten below, once the winning split
            # is known.
            modify_greedy_probs=False,
            save_logprobs=save_logprobs,
            save_modified_probs=True,
        )
        if col_offset > 0:
            # Turn split-local token ids into vocab-wide token ids
            sampled_tokens_tmp[i].add_(col_offset)
        col_offset += n_cols

    sampled_tokens = torch.stack(sampled_tokens_tmp)
    sampled_modified_probs = torch.stack(sampled_modified_probs_tmp)
    # Reduce the results from the splits.
    sampled_modified_probs, indices = torch.max(sampled_modified_probs,
                                                dim=0,
                                                keepdim=True)
    sampled_tokens = sampled_tokens.gather(0, indices).squeeze(0)
    if save_logprobs:
        sampled_logprobs = torch.stack(sampled_logprobs_tmp)
        sampled_logprobs = sampled_logprobs.gather(0, indices).squeeze(0)
    else:
        sampled_logprobs = None
    sampled_modified_probs = sampled_modified_probs.squeeze(0)

    if modify_greedy_probs:
        # We need to modify the greedy probs for the sampled tokens.
        # We can't do this in the kernel as we need to know the
        # sampled tokens.
        # Mirror the single-split kernel: only rows addressed through
        # sample_indices whose seed is 0 (greedy) become one-hot; rows of
        # random-sampling requests and unsampled rows are left untouched.
        is_greedy = (seeds[0] == 0).unsqueeze(1)
        sampled_rows = probs[sample_indices]
        one_hot_rows = torch.zeros_like(sampled_rows).scatter_(
            1, sampled_tokens, 1.0)
        probs[sample_indices] = torch.where(is_greedy, one_hot_rows,
                                            sampled_rows)

    return (sampled_tokens, sampled_logprobs, sampled_modified_probs)
def sample(
    probs: torch.Tensor,
    seeds: torch.Tensor,
    *,
    max_best_of: int = 1,
    sample_indices: Optional[torch.Tensor] = None,
    logprobs: Optional[torch.Tensor] = None,
    modify_greedy_probs: bool = False,
    save_logprobs: bool = False,
    _save_modified_probs: bool = False,  # pylint: disable=invalid-name
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Sample tokens from probs. with per-sequence seeds.

    Can sample from a subset of sequences through sample_indices.

    Args:
        probs: Probabilities to sample from.
            shape = [batch_size, vocab_size]
        seeds: Per-sequence seed values; a seed of 0 selects greedy
            sampling. When the vocab is sampled in several splits there
            must be one row of n seeds per split
            (shape = [n_splits, n]); otherwise the tensor is flattened
            and must hold n seeds.
        max_best_of: Number of samples to generate per sequence. The
            noise for each sample is generated by seeded_uniform from
            the sequence seed.
        sample_indices: Indices of sequences to sample from.
            If not provided, will sample from all sequences.
            shape = [n]
        logprobs: Log-probabilities of the sampled tokens.
            Only used for saving the logprobs if save_logprobs is True.
            shape = [batch_size, vocab_size]
        modify_greedy_probs: Whether to modify the greedy probabilities
            for speculative sampling (sampled token = 1.0,
            everything else = 0.0).
        save_logprobs: Whether to save the log-probabilities of the
            sampled tokens to a tensor.
        _save_modified_probs: Whether to save the modified probabilities
            (including gumbel noise) of the sampled tokens to a tensor.
            DOES NOT include the modification done by modify_greedy_probs
            (because we want to use the unmodified probs to pick the best
            split in case of multi-split sampling).
            This is exposed only for testing.

    Returns:
        sampled_tokens: shape = [n, max_best_of]
        sampled_logprobs: shape = [n, max_best_of] if save_logprobs else None
        sampled_modified_probs: shape = [n, max_best_of]
            if save_modified_probs else None

    Raises:
        ValueError: if save_logprobs is True but logprobs is None.
    """
    if sample_indices is None:
        sample_indices = torch.arange(0, probs.shape[0], device=probs.device)

    n_samples = sample_indices.size(0)
    tokens_shape = (n_samples, max_best_of)
    # Outputs that are not requested still have to be passed to the kernel,
    # so they are given as empty tensors.
    unused_shape = (0, 0)

    if save_logprobs and logprobs is None:
        raise ValueError(
            "logprobs tensor must be provided if save_logprobs is True")
    if save_logprobs:
        logprobs_shape = tokens_shape
    else:
        logprobs_shape = unused_shape
        logprobs = probs
    assert logprobs is not None

    modified_probs_shape = (tokens_shape
                            if _save_modified_probs else unused_shape)

    # If the number of columns in probs is too large for Triton to handle,
    # we split the tensor and sample from each split separately, and then
    # do an argmax+gather to combine the results.
    n_splits = get_num_triton_sampler_splits(probs.shape[1])
    if n_splits > 1:
        split_result = _multi_split_sample(
            probs,
            seeds,
            n_splits,
            tokens_shape,
            logprobs_shape,
            sample_indices,
            logprobs=logprobs,
            modify_greedy_probs=modify_greedy_probs,
            save_logprobs=save_logprobs)
        sampled_tokens, sampled_logprobs, sampled_modified_probs = (
            split_result)
    else:
        out_device = probs.device
        sampled_tokens = torch.empty(tokens_shape,
                                     dtype=torch.long,
                                     device=out_device)
        sampled_logprobs = torch.empty(logprobs_shape,
                                       dtype=probs.dtype,
                                       device=out_device)
        sampled_modified_probs = torch.empty(modified_probs_shape,
                                             dtype=probs.dtype,
                                             device=out_device)
        uniform_noise = seeded_uniform(n_samples,
                                       max_best_of,
                                       probs.shape[1],
                                       seeds=seeds.flatten(),
                                       device=out_device,
                                       dtype=probs.dtype)
        _sample(
            probs,
            logprobs,
            sample_indices,
            sampled_tokens,
            sampled_logprobs,
            sampled_modified_probs,
            seeds,
            uniform_noise,
            modify_greedy_probs=modify_greedy_probs,
            save_logprobs=save_logprobs,
            save_modified_probs=_save_modified_probs,
        )

    # Outputs that were not requested are reported as None.
    return (sampled_tokens, sampled_logprobs if save_logprobs else None,
            sampled_modified_probs if _save_modified_probs else None)
def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sample_indices: torch.Tensor,
    output_samples: torch.Tensor,
    output_logprobs: torch.Tensor,
    output_modified_probs: torch.Tensor,
    seeds: torch.Tensor,
    uniform_noise: torch.Tensor,
    *,
    modify_greedy_probs: bool = False,
    save_logprobs: bool = True,
    save_modified_probs: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Sample tokens from probs.

    Args:
        probs [batch_size, vocab_size]: probs to sample from.
        logprobs [batch_size, vocab_size]: logprobs (used when
            save_logprobs is True).
        sample_indices [n]: Indices of the samples to use for each row of probs.
        output_samples [n, n_best]: Output tensor to store samples in.
        output_logprobs [n, n_best]: Output tensor to store logprobs in.
        output_modified_probs [n, n_best]: Output tensor to store
            probs of chosen tokens in (modified with noise).
        seeds [n]: Seeds to use for sampling. If the seed is 0, we use
            greedy sampling. Note this is ONLY used for determining
            whether to use random sampling or not. The actual random
            noise should be passed as uniform_noise.
        uniform_noise [n, n_best, vocab_size]: Uniform
            noise to use for random sampling (will be converted
            to exponential gumbel noise by the kernel). It is indexed by
            sample position, not by batch row.
        modify_greedy_probs: If True, we modify the probs tensor in-place
            to encode the sampling method used for each row. This is used
            in speculative decoding. Only applies in greedy decoding.
        save_logprobs: If True, we save the logprobs of the sampled tokens
            in the output_logprobs tensor.
        save_modified_probs: If True, we save the modified probs (with noise)
            of the sampled tokens in the output_modified_probs tensor.
            DOES NOT include the modification done by modify_greedy_probs
            (because we want to use the unmodified probs to pick the best
            split in case of multi-split sampling).

    Returns:
        The three output tensors that were passed in, after the kernel has
        filled them: (output_samples, output_logprobs,
        output_modified_probs).
    """
    n_samples = sample_indices.shape[0]
    n_cols = probs.shape[1]
    n_best = output_samples.shape[1] if len(output_samples.shape) > 1 else 1

    # The block size is the smallest power of two greater than the number of
    # columns in probs
    block_size = triton.next_power_of_2(n_cols)
    num_warps = 4
    # Manual tuning. This seems to give best performance on A100 for
    # simple kernels like this.
    if block_size >= 8192:
        num_warps = 32
    elif block_size >= 4096:
        num_warps = 16
    elif block_size >= 2048:
        num_warps = 8

    # Enqueue kernel. The launch grid is 2D: one kernel instance per
    # (sample, best_of) pair.
    _sample_triton[(n_samples, n_best)](
        sample_indices,
        output_samples,
        output_logprobs,
        output_modified_probs,
        probs,
        logprobs,
        seeds,
        uniform_noise,
        # The kernel uses this one row stride for all three output tensors.
        output_samples.stride(0),
        probs.stride(0),
        uniform_noise.stride(0),
        uniform_noise.stride(1) if n_best > 1 else 1,
        n_samples,
        n_cols,
        n_best,
        num_warps=num_warps,
        block_size=block_size,
        modify_greedy_probs=modify_greedy_probs,
        save_logprobs=save_logprobs,
        save_modified_probs=save_modified_probs,
    )
    return output_samples, output_logprobs, output_modified_probs
@triton.jit
def _uniform_to_exponential(uniform_noise):
    """Convert uniform samples to exponential samples."""
    # The uniform samples may contain 0, so clamp them from below to _EPS;
    # this avoids log(0) and thus a division by 0 later on.
    floor = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype)
    clamped = tl.maximum(uniform_noise, floor)
    # Inversion method: -log(u) turns uniform samples into
    # exponential samples.
    return -tl.log(clamped)
@triton.jit
def _sample_triton(
        sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor,
        output_logprobs_ptr: torch.Tensor,
        output_modified_probs_ptr: torch.Tensor, probs_ptr: torch.Tensor,
        logprobs_ptr: torch.Tensor, seeds_ptr: torch.Tensor,
        uniform_noise_ptr: torch.Tensor, output_row_stride: int,
        probs_row_stride: int, uniform_noise_row_stride: int,
        uniform_noise_best_stride: int, n_samples: int, n_cols: int,
        n_best: int, block_size: tl.constexpr,
        modify_greedy_probs: tl.constexpr, save_logprobs: tl.constexpr,
        save_modified_probs: tl.constexpr):
    # Each program handles one (sample, best_of) pair; the pairs are
    # independent of each other.
    sample_idx = tl.program_id(0)
    best_idx = tl.program_id(1)

    # Look up which row of probs this sample refers to, and its seed.
    # A seed of 0 means greedy sampling.
    row_idx = tl.load(sample_indices_ptr + sample_idx)
    seed = tl.load(seeds_ptr + sample_idx)
    uses_random_sampling = seed != 0

    # Start of the probs row for this sample
    row_start_ptr = probs_ptr + row_idx * probs_row_stride

    # block_size is a power of two >= n_cols, so a whole row fits into one
    # block; columns past n_cols are masked out.
    col_offsets = tl.arange(0, block_size)
    in_bounds = col_offsets < n_cols

    # Padding is -inf so that it can never win the max below
    row = tl.load(row_start_ptr + col_offsets,
                  mask=in_bounds,
                  other=float("-inf"))

    if uses_random_sampling:
        # Noise is indexed by sample position and best_of index
        uniform_noise_start_ptr = (uniform_noise_ptr +
                                   sample_idx * uniform_noise_row_stride +
                                   best_idx * uniform_noise_best_stride)
        uniform_noise = tl.load(uniform_noise_start_ptr + col_offsets,
                                mask=in_bounds,
                                other=0.5)
        exponential_noise = _uniform_to_exponential(uniform_noise)
        row /= exponential_noise

    sampled_value, sampled_token = tl.max(row, axis=0, return_indices=True)
    # Clamp the sampled token to n_cols - 1. This should not be necessary,
    # but we do it just in case.
    if sampled_token >= n_cols:
        sampled_token = n_cols - 1

    # All three outputs share the same [sample, best_of] element offset
    out_offset = sample_idx * output_row_stride + best_idx

    # Write the sampled token back to DRAM
    tl.store(output_ptr + out_offset, sampled_token)

    if modify_greedy_probs:  # noqa
        if not uses_random_sampling:
            # Set the probability of the sampled token to 1, all other
            # tokens to zero. This is used in speculative decoding where
            # the sampling method must be encoded within the sampled
            # probability distributions.
            row = tl.where(col_offsets == sampled_token, 1.0, 0.0)
            tl.store(row_start_ptr + col_offsets, row, mask=in_bounds)

    if save_modified_probs:
        tl.store(output_modified_probs_ptr + out_offset, sampled_value)

    if save_logprobs:
        # Only the logprob of the sampled token is read. The logprobs row
        # is addressed with the row stride of probs.
        sampled_logprob = tl.load(logprobs_ptr + row_idx * probs_row_stride +
                                  sampled_token)
        tl.store(output_logprobs_ptr + out_offset, sampled_logprob)
...@@ -10,12 +10,6 @@ import msgspec ...@@ -10,12 +10,6 @@ import msgspec
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
from vllm.triton_utils import HAS_TRITON
if HAS_TRITON:
from vllm.model_executor.layers.ops.sample import sample as sample_triton
import vllm.envs as envs import vllm.envs as envs
from vllm.model_executor.sampling_metadata import (SamplingMetadata, from vllm.model_executor.sampling_metadata import (SamplingMetadata,
SamplingTensors, SamplingTensors,
...@@ -23,6 +17,7 @@ from vllm.model_executor.sampling_metadata import (SamplingMetadata, ...@@ -23,6 +17,7 @@ from vllm.model_executor.sampling_metadata import (SamplingMetadata,
from vllm.sampling_params import SamplingType from vllm.sampling_params import SamplingType
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
PromptLogprobs, SampleLogprobs, SequenceOutput) PromptLogprobs, SampleLogprobs, SequenceOutput)
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
import flashinfer.sampling import flashinfer.sampling
...@@ -777,7 +772,7 @@ def _sample_with_torch( ...@@ -777,7 +772,7 @@ def _sample_with_torch(
# Counterintiutively, having two loops here is actually faster. # Counterintiutively, having two loops here is actually faster.
# The first loop can run without waiting on GPU<->CPU sync. # The first loop can run without waiting on GPU<->CPU sync.
for sampling_type in SamplingType: for sampling_type in SamplingType:
sample_indices = categorized_sample_indices[sampling_type][:, 0] sample_indices = categorized_sample_indices[sampling_type]
num_tokens = len(sample_indices) num_tokens = len(sample_indices)
if num_tokens == 0: if num_tokens == 0:
continue continue
...@@ -863,88 +858,6 @@ def _sample_with_torch( ...@@ -863,88 +858,6 @@ def _sample_with_torch(
) )
def _sample_with_triton_kernel(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
) -> SampleResultType:
    """Sample all sequence groups with a single call to the Triton sampler.

    Sequence groups are bucketed by sampling type, tokens for the whole
    batch are drawn with one sample_triton call, and the per-type helpers
    then turn the drawn tokens into per-group results.

    Returns:
        One result per sequence group, in the order of
        sampling_metadata.seq_groups; groups without a result get ([], []).
    """
    all_seq_groups = sampling_metadata.seq_groups
    categorized_sample_indices = sampling_metadata.categorized_sample_indices

    # Sequence-group ids bucketed by their sampling type
    group_ids_by_type: Dict[SamplingType, List[int]] = {
        t: []
        for t in SamplingType
    }
    for group_id, group in enumerate(all_seq_groups):
        group_ids_by_type[group.sampling_params.sampling_type].append(
            group_id)

    results_by_group: Dict[int, Tuple[List[int], List[int]]] = {}
    pending: Dict[SamplingType,
                  Tuple[List[int], List[SequenceGroupToSample],
                        torch.Tensor, torch.Tensor]] = {}
    max_best_of_in_batch = 1
    token_sampling_types = (SamplingType.GREEDY, SamplingType.RANDOM,
                            SamplingType.RANDOM_SEED)

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type in SamplingType:
        type_indices = categorized_sample_indices[sampling_type]
        sample_indices = type_indices[:, 0]
        sampled_token_indices = type_indices[:, 1]
        if len(sample_indices) == 0:
            continue

        seq_group_id = group_ids_by_type[sampling_type]
        seq_groups = [all_seq_groups[i] for i in seq_group_id]
        pending[sampling_type] = (seq_group_id, seq_groups, sample_indices,
                                  sampled_token_indices)

        if sampling_type in token_sampling_types:
            # Only prompt groups contribute to the batch-wide best_of
            for group in seq_groups:
                if group.is_prompt:
                    max_best_of_in_batch = max(
                        max_best_of_in_batch, group.sampling_params.best_of)
        elif sampling_type == SamplingType.BEAM:
            beam_search_logprobs = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

    sampled_tokens, _, _ = sample_triton(
        probs=probs,
        seeds=sampling_tensors.sampling_seeds,
        max_best_of=max_best_of_in_batch,
        sample_indices=sampling_tensors.sample_indices,
        logprobs=logprobs,
        # don't save logprobs because we have logic for that below
        # TODO: use this instead of the CPU-based logic below
        save_logprobs=False,
    )

    # GPU<->CPU sync happens in the loop below.
    # `pending` was filled in SamplingType order, so iterating it visits
    # the types in that same order.
    for sampling_type, entry in pending.items():
        seq_group_id, seq_groups, sample_indices, sampled_token_indices = (
            entry)
        if sampling_type == SamplingType.GREEDY:
            sample_results = _greedy_sample(
                seq_groups, sampled_tokens[sampled_token_indices][:, 0])
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            sample_results = _random_sample(
                seq_groups, sampled_tokens[sampled_token_indices])
        elif sampling_type == SamplingType.BEAM:
            sample_results = _beam_search_sample(seq_groups,
                                                 beam_search_logprobs)
        results_by_group.update(zip(seq_group_id, sample_results))

    return [
        results_by_group.get(i, ([], []))
        for i in range(len(sampling_metadata.seq_groups))
    ]
def _sample( def _sample(
probs: torch.Tensor, probs: torch.Tensor,
logprobs: torch.Tensor, logprobs: torch.Tensor,
...@@ -974,10 +887,6 @@ def _sample( ...@@ -974,10 +887,6 @@ def _sample(
modify_greedy_probs=modify_greedy_probs, modify_greedy_probs=modify_greedy_probs,
) )
# TODO: Enable once Triton kernel & associated code is faster.
# return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
# sampling_tensors)
def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
""" """
......
import random
from array import array from array import array
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
...@@ -8,15 +7,10 @@ import torch ...@@ -8,15 +7,10 @@ import torch
from vllm.sampling_params import SamplingParams, SamplingType from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData, from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData,
SequenceGroupMetadata) SequenceGroupMetadata)
from vllm.triton_utils.sample import get_num_triton_sampler_splits
from vllm.utils import (PyObjectCache, async_tensor_h2d, from vllm.utils import (PyObjectCache, async_tensor_h2d,
is_pin_memory_available, make_tensor_with_pad, is_pin_memory_available, make_tensor_with_pad)
maybe_expand_dim)
_SAMPLING_EPS = 1e-5 _SAMPLING_EPS = 1e-5
_SEED_0_REPLACEMENT = 3403598558
# Some triton sampler related code is guarded before it is ready.
_USE_TRITON_SAMPLER = False
@dataclass @dataclass
...@@ -74,12 +68,12 @@ def gen_seq_group_to_sample_builder(num_seqs: int): ...@@ -74,12 +68,12 @@ def gen_seq_group_to_sample_builder(num_seqs: int):
generator=None, generator=None,
is_prompt=True, is_prompt=True,
prompt_logprob_indices=[], prompt_logprob_indices=[],
sample_indices=[]) sample_indices=[],
)
class SamplingMetadataCache: class SamplingMetadataCache:
"""Used to cache SamplingMetadata objects between scheduler iterations """Used to cache SamplingMetadata objects between scheduler iterations"""
"""
def __init__(self): def __init__(self):
self._seq_group_to_sample_cache: Dict[int, PyObjectCache] = {} self._seq_group_to_sample_cache: Dict[int, PyObjectCache] = {}
...@@ -165,16 +159,19 @@ class SamplingMetadata: ...@@ -165,16 +159,19 @@ class SamplingMetadata:
num_prompts, num_prompts,
) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens, ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
device, generators, cache) device, generators, cache)
selected_token_indices = async_tensor_h2d(selected_token_indices, selected_token_indices = async_tensor_h2d(
selected_token_indices,
dtype=torch.long, dtype=torch.long,
target_device=device, target_device=device,
pin_memory=pin_memory) pin_memory=pin_memory,
)
categorized_sample_indices = { categorized_sample_indices = {
t: maybe_expand_dim( t: async_tensor_h2d(
async_tensor_h2d(seq_ids, seq_ids,
dtype=torch.int, dtype=torch.int,
target_device=device, target_device=device,
pin_memory=pin_memory), 2, 2) pin_memory=pin_memory,
)
for t, seq_ids in categorized_sample_indices.items() for t, seq_ids in categorized_sample_indices.items()
} }
...@@ -201,8 +198,8 @@ def _prepare_seq_groups( ...@@ -201,8 +198,8 @@ def _prepare_seq_groups(
device: str, device: str,
generators: Optional[Dict[str, torch.Generator]] = None, generators: Optional[Dict[str, torch.Generator]] = None,
cache: Optional[SamplingMetadataCache] = None, cache: Optional[SamplingMetadataCache] = None,
) -> Tuple[List[SequenceGroupToSample], List[int], Dict[ ) -> Tuple[List[SequenceGroupToSample], List[int], Dict[SamplingType,
SamplingType, List[Tuple[int, int]]], int]: List[int]], int, ]:
"""Prepare sequence groups and indices for sampling. """Prepare sequence groups and indices for sampling.
Args: Args:
...@@ -233,16 +230,13 @@ def _prepare_seq_groups( ...@@ -233,16 +230,13 @@ def _prepare_seq_groups(
# Sampling type -> ( # Sampling type -> (
# indices to sample/prompt logprob within pruned output logits, # indices to sample/prompt logprob within pruned output logits,
# indices to sample within pruned logits) # indices to sample within pruned logits)
categorized_sample_indices: Dict[SamplingType, List[Tuple[int, int]]] = { categorized_sample_indices: Dict[SamplingType, List[int]] = {
t: [] t: []
for t in SamplingType for t in SamplingType
} }
# Index of logits to compute logprob. Logits include both prompt logprob # Index of logits to compute logprob. Logits include both prompt logprob
# and sample logprob indices. # and sample logprob indices.
logit_idx = 0 logit_idx = 0
# Index to sample from a sample tensor. It is used by triton sample kernel.
# See `_sample_with_triton_kernel` for more details.
sample_idx = 0
# Total number of prompts from given sequence groups. # Total number of prompts from given sequence groups.
num_prompts = 0 num_prompts = 0
...@@ -264,10 +258,10 @@ def _prepare_seq_groups( ...@@ -264,10 +258,10 @@ def _prepare_seq_groups(
# If the current seq group is in decode stage, it is None. # If the current seq group is in decode stage, it is None.
seq_len: Optional[int] = None seq_len: Optional[int] = None
query_len: Optional[int] = None query_len: Optional[int] = None
prompt_logprob_indices: List[int] = \ prompt_logprob_indices: List[int] = (sample_obj.prompt_logprob_indices
sample_obj.prompt_logprob_indices if cache is not None else [] if cache is not None else [])
sample_indices: List[int] = \ sample_indices: List[int] = (sample_obj.sample_indices
sample_obj.sample_indices if cache is not None else [] if cache is not None else [])
do_sample = seq_group_metadata.do_sample do_sample = seq_group_metadata.do_sample
if seq_group_metadata.is_prompt: if seq_group_metadata.is_prompt:
...@@ -333,11 +327,8 @@ def _prepare_seq_groups( ...@@ -333,11 +327,8 @@ def _prepare_seq_groups(
if do_sample: if do_sample:
sample_indices.extend(range(logit_idx, logit_idx + sample_len)) sample_indices.extend(range(logit_idx, logit_idx + sample_len))
categorized_sample_indices[sampling_params.sampling_type].extend( categorized_sample_indices[sampling_params.sampling_type].extend(
list( list(range(logit_idx, logit_idx + sample_len)))
zip(range(logit_idx, logit_idx + sample_len),
range(sample_idx, sample_idx + sample_len))))
logit_idx += sample_len logit_idx += sample_len
sample_idx += sample_len
if cache is not None: if cache is not None:
sample_obj.sampling_params = sampling_params sample_obj.sampling_params = sampling_params
...@@ -356,7 +347,8 @@ def _prepare_seq_groups( ...@@ -356,7 +347,8 @@ def _prepare_seq_groups(
generator=generator, generator=generator,
is_prompt=is_prompt, is_prompt=is_prompt,
prompt_logprob_indices=list(prompt_logprob_indices), prompt_logprob_indices=list(prompt_logprob_indices),
sample_indices=list(sample_indices)) sample_indices=list(sample_indices),
)
seq_groups.append(sample_obj) seq_groups.append(sample_obj)
...@@ -378,9 +370,6 @@ class SamplingTensors: ...@@ -378,9 +370,6 @@ class SamplingTensors:
presence_penalties: torch.Tensor presence_penalties: torch.Tensor
frequency_penalties: torch.Tensor frequency_penalties: torch.Tensor
repetition_penalties: torch.Tensor repetition_penalties: torch.Tensor
sampling_seeds: torch.Tensor
sample_indices: torch.Tensor
extra_seeds: Optional[torch.Tensor]
prompt_tokens: torch.Tensor prompt_tokens: torch.Tensor
output_tokens: torch.Tensor output_tokens: torch.Tensor
...@@ -391,15 +380,7 @@ class SamplingTensors: ...@@ -391,15 +380,7 @@ class SamplingTensors:
vocab_size: int, vocab_size: int,
device: torch.device, device: torch.device,
dtype: torch.dtype, dtype: torch.dtype,
*,
extra_seeds_to_generate: int = 0,
extra_entropy: Optional[Tuple[int, ...]] = None
) -> Tuple["SamplingTensors", bool, bool, bool]: ) -> Tuple["SamplingTensors", bool, bool, bool]:
"""
extra_seeds_to_generate: extra seeds to generate using the
user-defined seed for each sequence.
extra_entropy: extra entropy to use when generating seeds.
"""
prompt_tokens: List[array] = [] prompt_tokens: List[array] = []
output_tokens: List[array] = [] output_tokens: List[array] = []
top_ks: List[int] = [] top_ks: List[int] = []
...@@ -409,19 +390,10 @@ class SamplingTensors: ...@@ -409,19 +390,10 @@ class SamplingTensors:
presence_penalties: List[float] = [] presence_penalties: List[float] = []
frequency_penalties: List[float] = [] frequency_penalties: List[float] = []
repetition_penalties: List[float] = [] repetition_penalties: List[float] = []
sampling_seeds: List[int] = []
sample_indices: List[int] = []
do_penalties = False do_penalties = False
do_top_p_top_k = False do_top_p_top_k = False
do_min_p = False do_min_p = False
if _USE_TRITON_SAMPLER:
prompt_best_of: List[int] = []
# We need one base seed per Triton slice.
seeds_to_generate = (extra_seeds_to_generate +
get_num_triton_sampler_splits(vocab_size))
assert sampling_metadata.seq_groups is not None assert sampling_metadata.seq_groups is not None
for seq_group in sampling_metadata.seq_groups: for seq_group in sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids seq_ids = seq_group.seq_ids
...@@ -452,7 +424,7 @@ class SamplingTensors: ...@@ -452,7 +424,7 @@ class SamplingTensors:
do_penalties = True do_penalties = True
is_prompt = seq_group.is_prompt is_prompt = seq_group.is_prompt
if (is_prompt and sampling_params.prompt_logprobs is not None): if is_prompt and sampling_params.prompt_logprobs is not None:
# For tokens in the prompt that we only need to get # For tokens in the prompt that we only need to get
# their logprobs # their logprobs
query_len = seq_group.query_len query_len = seq_group.query_len
...@@ -477,28 +449,6 @@ class SamplingTensors: ...@@ -477,28 +449,6 @@ class SamplingTensors:
frequency_penalties += [f] * len(seq_ids) frequency_penalties += [f] * len(seq_ids)
repetition_penalties += [r] * len(seq_ids) repetition_penalties += [r] * len(seq_ids)
if _USE_TRITON_SAMPLER:
if is_prompt:
prompt_best_of.append(sampling_params.best_of)
query_len = seq_group.query_len
assert query_len is not None
seed = sampling_params.seed
is_greedy = sampling_params.sampling_type == SamplingType.GREEDY
for seq_id in seq_ids:
seq_data = seq_group.seq_data[seq_id]
extra_entropy = extra_entropy or ()
seq_seeds = cls._get_sequence_seeds(
seed,
seq_data.get_len(),
*extra_entropy,
seq_id,
seeds_to_generate=seeds_to_generate,
is_greedy=is_greedy)
sampling_seeds.append(seq_seeds)
sample_indices.extend(seq_group.sample_indices)
if do_penalties: if do_penalties:
for seq_group in sampling_metadata.seq_groups: for seq_group in sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids seq_ids = seq_group.seq_ids
...@@ -518,23 +468,37 @@ class SamplingTensors: ...@@ -518,23 +468,37 @@ class SamplingTensors:
output_tokens.append(seq_data.output_token_ids_array) output_tokens.append(seq_data.output_token_ids_array)
sampling_tensors = SamplingTensors.from_lists( sampling_tensors = SamplingTensors.from_lists(
temperatures, top_ps, top_ks, min_ps, presence_penalties, temperatures,
frequency_penalties, repetition_penalties, sampling_seeds, top_ps,
sample_indices, prompt_tokens, output_tokens, vocab_size, top_ks,
extra_seeds_to_generate, device, dtype) min_ps,
presence_penalties,
frequency_penalties,
repetition_penalties,
prompt_tokens,
output_tokens,
vocab_size,
device,
dtype,
)
return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p) return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p)
@classmethod @classmethod
def from_lists(cls, temperatures: List[float], top_ps: List[float], def from_lists(
top_ks: List[int], min_ps: List[float], cls,
temperatures: List[float],
top_ps: List[float],
top_ks: List[int],
min_ps: List[float],
presence_penalties: List[float], presence_penalties: List[float],
frequency_penalties: List[float], frequency_penalties: List[float],
repetition_penalties: List[float], repetition_penalties: List[float],
sampling_seeds: List[int], sample_indices: List[int], prompt_tokens: List[array],
prompt_tokens: List[array], output_tokens: List[array], output_tokens: List[array],
vocab_size: int, extra_seeds_to_generate: int, vocab_size: int,
device: torch.device, device: torch.device,
dtype: torch.dtype) -> "SamplingTensors": dtype: torch.dtype,
) -> "SamplingTensors":
# Note that the performance will be very bad without # Note that the performance will be very bad without
# pinned memory. # pinned memory.
pin_memory = is_pin_memory_available() pin_memory = is_pin_memory_available()
...@@ -603,34 +567,9 @@ class SamplingTensors: ...@@ -603,34 +567,9 @@ class SamplingTensors:
dtype=torch.int, dtype=torch.int,
pin_memory=pin_memory, pin_memory=pin_memory,
) )
sample_indices_t = torch.tensor(
sample_indices,
device="cpu",
dtype=torch.long,
pin_memory=pin_memory,
)
# need to transpose and make contiguous to
# copy the tensor correctly.
# [batch_size, n_seeds] -> [n_seeds, batch_size]
sampling_seeds_t = torch.tensor(
sampling_seeds,
device="cpu",
dtype=torch.long,
pin_memory=pin_memory,
).t().contiguous()
# Because the memory is pinned, we can do non-blocking # Because the memory is pinned, we can do non-blocking
# transfer to device. # transfer to device.
# How many seeds the sample operation itself will need.
num_base_seeds = sampling_seeds_t.shape[0] - extra_seeds_to_generate
sampling_seeds_gpu = sampling_seeds_t.to(device=device,
non_blocking=True)
extra_seeds_gpu = sampling_seeds_gpu[num_base_seeds:]
if not extra_seeds_gpu.numel():
extra_seeds_gpu = None
sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds]
return cls( return cls(
temperatures=temperatures_t.to(device=device, non_blocking=True), temperatures=temperatures_t.to(device=device, non_blocking=True),
top_ps=top_ps_t.to(device=device, non_blocking=True), top_ps=top_ps_t.to(device=device, non_blocking=True),
...@@ -644,38 +583,4 @@ class SamplingTensors: ...@@ -644,38 +583,4 @@ class SamplingTensors:
non_blocking=True), non_blocking=True),
prompt_tokens=prompt_t.to(device=device, non_blocking=True), prompt_tokens=prompt_t.to(device=device, non_blocking=True),
output_tokens=output_t.to(device=device, non_blocking=True), output_tokens=output_t.to(device=device, non_blocking=True),
sampling_seeds=sampling_seeds_gpu,
sample_indices=sample_indices_t.to(device=device,
non_blocking=True),
extra_seeds=extra_seeds_gpu,
) )
@staticmethod
def _get_sequence_seeds(
seed: int,
*extra_entropy: int,
seeds_to_generate: int,
is_greedy: bool,
):
"""Get `seeds_to_generate` child seeds from `seed` and extra entropy."""
if not is_greedy:
if seed is None:
randint_fn = random.randint
else:
generator = random.Random(str((seed, ) + extra_entropy))
randint_fn = generator.randint
lo, hi = torch.iinfo(torch.long).min, torch.iinfo(torch.long).max
# If the user/random sets seed = 0 but request should
# have sampling, we need to change it to something
# else. We use a constant in that case.
# This way we don't need to create and load a bool
# matrix in the sampling kernel, which reduces CPU
# overhead and latency.
seq_seeds = [
randint_fn(lo, hi) or _SEED_0_REPLACEMENT
for _ in range(seeds_to_generate)
]
else:
# For the kernel, seed == 0 means greedy decoding.
seq_seeds = [0] * seeds_to_generate
return seq_seeds
import math
# This is a hardcoded limit in Triton (max block size).
MAX_TRITON_N_COLS = 131072


def get_num_triton_sampler_splits(n_cols: int) -> int:
    """Get the number of splits to use for Triton sampling.

    Triton has a limit on the number of columns it can handle, so we need to
    split the tensor and call the kernel multiple times if it's too large.

    Args:
        n_cols: Number of columns in the tensor to sample from.

    Returns:
        `n_cols / MAX_TRITON_N_COLS` rounded up to the next integer.
    """
    # Integer ceiling division: exact for arbitrarily large ints, unlike
    # math.ceil(a / b), which rounds through a float first.
    return -(-n_cols // MAX_TRITON_N_COLS)
...@@ -837,15 +837,6 @@ def async_tensor_h2d( ...@@ -837,15 +837,6 @@ def async_tensor_h2d(
return t.to(device=target_device, non_blocking=True) return t.to(device=target_device, non_blocking=True)
def maybe_expand_dim(tensor: torch.Tensor,
                     target_dims: int,
                     size: int = 1) -> torch.Tensor:
    """Expand the tensor to the target_dims.

    Trailing dimensions of length `size` are split off the last dimension
    until the tensor has `target_dims` dimensions. Leading dimensions are
    kept as they are. A tensor that already has at least `target_dims`
    dimensions is returned unchanged.

    Args:
        tensor: Tensor to expand.
        target_dims: Desired number of dimensions.
        size: Length of every added trailing dimension.

    Returns:
        The input tensor itself, or a view of it with `target_dims`
        dimensions.
    """
    missing = target_dims - tensor.ndim
    if missing <= 0:
        return tensor
    if tensor.ndim == 0:
        # A scalar has no last dimension to split; give it one first.
        tensor = tensor.view(1)
        missing -= 1
    # Keep the leading dims so the result really has `target_dims` dims;
    # for 1-D input this reduces to view(-1, size, ...).
    return tensor.view(*tensor.shape[:-1], -1, *([size] * missing))
def get_dtype_size(dtype: torch.dtype) -> int: def get_dtype_size(dtype: torch.dtype) -> int:
"""Get the size of the data type in bytes.""" """Get the size of the data type in bytes."""
return torch.tensor([], dtype=dtype).element_size() return torch.tensor([], dtype=dtype).element_size()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment