Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
546034b4
Unverified
Commit
546034b4
authored
Sep 16, 2024
by
Simon Mo
Committed by
GitHub
Sep 16, 2024
Browse files
[refactor] remove triton based sampler (#8524)
parent
cca61642
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
75 additions
and
1095 deletions
+75
-1095
tests/kernels/test_rand.py
tests/kernels/test_rand.py
+0
-52
tests/kernels/test_sampler.py
tests/kernels/test_sampler.py
+0
-209
vllm/model_executor/layers/ops/__init__.py
vllm/model_executor/layers/ops/__init__.py
+0
-0
vllm/model_executor/layers/ops/rand.py
vllm/model_executor/layers/ops/rand.py
+0
-157
vllm/model_executor/layers/ops/sample.py
vllm/model_executor/layers/ops/sample.py
+0
-394
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+3
-94
vllm/model_executor/sampling_metadata.py
vllm/model_executor/sampling_metadata.py
+58
-153
vllm/triton_utils/sample.py
vllm/triton_utils/sample.py
+0
-13
vllm/utils.py
vllm/utils.py
+14
-23
No files found.
tests/kernels/test_rand.py
deleted
100644 → 0
View file @
cca61642
import
random
import
pytest
import
torch
from
vllm.model_executor.layers.ops.rand
import
seeded_uniform
from
vllm.model_executor.utils
import
set_random_seed
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"use_3d"
,
[
True
,
False
])
def
test_seeded_uniform
(
dtype
:
torch
.
dtype
,
use_3d
:
bool
):
device
=
"cuda"
for
seed
in
range
(
512
):
set_random_seed
(
seed
)
rows
=
random
.
randint
(
1
,
512
)
cols
=
random
.
randint
(
1
,
64000
)
if
use_3d
:
third_dim
=
random
.
randint
(
2
,
10
)
dims
=
[
rows
,
third_dim
,
cols
]
else
:
dims
=
[
rows
,
cols
]
seeds
=
torch
.
randint
(
torch
.
iinfo
(
torch
.
long
).
min
,
torch
.
iinfo
(
torch
.
long
).
max
,
(
rows
,
),
device
=
device
)
# Test that the same seed produces the same output
out
=
seeded_uniform
(
*
dims
,
seeds
=
seeds
,
dtype
=
dtype
,
device
=
device
)
out2
=
seeded_uniform
(
*
dims
,
seeds
=
seeds
,
dtype
=
dtype
,
device
=
device
)
torch
.
testing
.
assert_close
(
out
,
out2
)
# del to save memory
del
out2
out3
=
seeded_uniform
(
*
dims
,
seeds
=
seeds
,
dtype
=
dtype
,
device
=
device
)
torch
.
testing
.
assert_close
(
out
,
out3
)
# del to save memory
del
out3
# Initialize out tensor with garbage to ensure that it is overwritten
out_with_tensor
=
seeded_uniform
(
*
dims
,
out
=
torch
.
full
(
(
*
dims
,
),
-
1
,
dtype
=
dtype
,
device
=
device
,
),
seeds
=
seeds
,
dtype
=
dtype
,
)
torch
.
testing
.
assert_close
(
out
,
out_with_tensor
)
tests/kernels/test_sampler.py
deleted
100644 → 0
View file @
cca61642
import
gc
from
unittest.mock
import
patch
import
pytest
import
torch
import
triton
import
triton.language
as
tl
from
vllm.model_executor.layers.ops.sample
import
(
_sample_triton
,
_uniform_to_exponential
,
sample
)
from
vllm.model_executor.sampling_metadata
import
SamplingTensors
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.triton_utils.libentry
import
LibEntry
from
vllm.triton_utils.sample
import
(
MAX_TRITON_N_COLS
,
get_num_triton_sampler_splits
)
SINGLE_SPLIT_VOCAB_SIZE
=
32000
# llama/mistral/mixtral vocab size
MULTI_SPLIT_VOCAB_SIZE
=
MAX_TRITON_N_COLS
+
100
@
pytest
.
fixture
(
autouse
=
True
)
def
_cleanup
():
yield
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
@
triton
.
jit
def
_uniform_to_exponential_kernel
(
input
,
output
,
n
:
tl
.
constexpr
):
idx
=
tl
.
arange
(
0
,
n
)
x
=
tl
.
load
(
input
+
idx
)
y
=
_uniform_to_exponential
(
x
)
tl
.
store
(
output
+
idx
,
y
)
def
test_uniform_to_exponential
():
"""Test that we can convert uniform to exponential without div by 0."""
input
=
torch
.
tensor
([
0.0
,
1.0
-
torch
.
finfo
(
torch
.
float32
).
eps
],
dtype
=
torch
.
float32
,
device
=
"cuda"
)
output
=
torch
.
zeros
(
input
.
shape
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
_uniform_to_exponential_kernel
[(
1
,
)](
input
,
output
,
2
)
assert
torch
.
all
(
torch
.
isfinite
(
output
))
assert
torch
.
all
(
output
>
0
)
assert
torch
.
all
(
torch
.
isfinite
(
torch
.
full_like
(
output
,
1.0
)
/
output
))
@
pytest
.
mark
.
parametrize
(
"random_sampling"
,
[
True
,
False
,
"mixed"
])
@
pytest
.
mark
.
parametrize
(
"max_best_of"
,
[
1
,
2
,
3
,
4
,
5
])
@
pytest
.
mark
.
parametrize
(
"modify_greedy_probs"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1337
])
@
pytest
.
mark
.
parametrize
(
"vocab_size"
,
[
SINGLE_SPLIT_VOCAB_SIZE
,
MULTI_SPLIT_VOCAB_SIZE
])
@
pytest
.
mark
.
parametrize
(
"save_logprobs"
,
[
True
,
False
])
def
test_sample_decoding_only
(
random_sampling
,
max_best_of
,
modify_greedy_probs
,
seed
,
vocab_size
,
save_logprobs
):
set_random_seed
(
seed
)
bs
=
8
probs
=
torch
.
zeros
((
bs
,
vocab_size
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
for
i
in
range
(
bs
):
probs
[
i
,
i
*
(
vocab_size
//
bs
)]
=
1.0
logprobs
=
torch
.
rand_like
(
probs
)
sample_indices
=
torch
.
arange
(
bs
,
dtype
=
torch
.
long
,
device
=
"cuda"
)
n_splits
=
get_num_triton_sampler_splits
(
probs
.
shape
[
1
])
if
random_sampling
==
"mixed"
:
random_sampling_mask
=
(
torch
.
rand
(
(
1
,
bs
),
device
=
"cuda"
)
<
0.5
).
expand
(
n_splits
,
bs
)
elif
random_sampling
:
random_sampling_mask
=
torch
.
ones
((
n_splits
,
bs
),
dtype
=
torch
.
bool
,
device
=
"cuda"
)
else
:
random_sampling_mask
=
torch
.
zeros
((
n_splits
,
bs
),
dtype
=
torch
.
bool
,
device
=
"cuda"
)
seeds
=
torch
.
randint
(
1
,
torch
.
iinfo
(
torch
.
long
).
max
,
(
n_splits
,
bs
),
device
=
"cuda"
).
mul_
(
random_sampling_mask
)
#The current _sample_triton does not utilize the
# libentry decoration. The purpose of adding this patch is to test
# the correctness of libentry.
with
patch
(
"vllm.model_executor.layers.ops.sample._sample_triton"
,
LibEntry
(
_sample_triton
)):
sampled_tokens
,
sampled_logprobs
,
sampled_modified_probs
=
sample
(
probs
=
probs
,
logprobs
=
logprobs
,
sample_indices
=
sample_indices
,
seeds
=
seeds
,
max_best_of
=
max_best_of
,
modify_greedy_probs
=
modify_greedy_probs
,
save_logprobs
=
save_logprobs
,
_save_modified_probs
=
True
)
assert
sampled_tokens
.
shape
==
(
bs
,
max_best_of
)
for
i
in
range
(
bs
):
assert
torch
.
all
(
sampled_tokens
[
i
]
==
i
*
(
vocab_size
//
bs
))
request_uses_random_sampling
=
random_sampling_mask
[
0
,
i
]
if
modify_greedy_probs
and
not
request_uses_random_sampling
:
# If we are modifying greedy probs and the request is greedy,
# we want to make sure the probs tensor is modified in place
torch
.
testing
.
assert_close
(
probs
[
i
][
sampled_tokens
[
i
]],
torch
.
full_like
(
probs
[
i
][
sampled_tokens
[
i
]],
1.0
))
assert
torch
.
sum
(
probs
[
i
])
==
1.0
torch
.
testing
.
assert_close
(
sampled_modified_probs
[
i
][
0
],
torch
.
full_like
(
sampled_modified_probs
[
i
][
0
],
1.0
))
elif
request_uses_random_sampling
:
# If the request is random, we want to make sure
# sampled_modified_probs tensor has noise added
# (and thus is different from probs tensor)
assert
not
torch
.
allclose
(
sampled_modified_probs
[
i
][
0
],
probs
[
i
][
sampled_tokens
[
i
]])
elif
not
request_uses_random_sampling
:
# If the request is greedy and we are not modifying greedy probs,
# we want to make sure sampled_modified_probs tensor is the same as
# the probs tensor.
torch
.
testing
.
assert_close
(
sampled_modified_probs
[
i
],
probs
[
i
][
sampled_tokens
[
i
]])
if
save_logprobs
:
assert
sampled_logprobs
.
shape
==
(
bs
,
max_best_of
)
for
i
in
range
(
bs
):
for
best_of
in
range
(
max_best_of
):
assert
torch
.
all
(
sampled_logprobs
[
i
]
==
logprobs
[
i
][
sampled_tokens
[
i
,
best_of
]])
else
:
assert
sampled_logprobs
is
None
@
pytest
.
mark
.
parametrize
(
"random_sampling"
,
[
True
,
False
,
"mixed"
])
@
pytest
.
mark
.
parametrize
(
"max_best_of"
,
[
1
,
2
,
3
,
4
,
5
])
@
pytest
.
mark
.
parametrize
(
"modify_greedy_probs"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1337
])
@
pytest
.
mark
.
parametrize
(
"vocab_size"
,
[
SINGLE_SPLIT_VOCAB_SIZE
,
MULTI_SPLIT_VOCAB_SIZE
])
def
test_sample_prompt_logprobs
(
random_sampling
,
max_best_of
,
modify_greedy_probs
,
seed
,
vocab_size
):
set_random_seed
(
seed
)
prompt_sizes
=
[
16
,
32
,
64
,
128
]
*
2
samples
=
8
bs
=
samples
+
sum
(
prompt_sizes
)
probs
=
torch
.
zeros
((
bs
,
vocab_size
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
for
i
in
range
(
bs
):
probs
[
i
,
i
*
(
vocab_size
//
bs
)]
=
1.0
logprobs
=
torch
.
rand_like
(
probs
)
sample_indices
=
torch
.
tensor
(
prompt_sizes
,
dtype
=
torch
.
long
,
device
=
"cuda"
).
cumsum_
(
0
)
n_splits
=
get_num_triton_sampler_splits
(
probs
.
shape
[
1
])
if
random_sampling
==
"mixed"
:
random_sampling_mask
=
torch
.
rand
(
(
n_splits
,
samples
),
device
=
"cuda"
)
<
0.5
elif
random_sampling
:
random_sampling_mask
=
torch
.
ones
((
n_splits
,
samples
),
dtype
=
torch
.
bool
,
device
=
"cuda"
)
else
:
random_sampling_mask
=
torch
.
zeros
((
n_splits
,
samples
),
dtype
=
torch
.
bool
,
device
=
"cuda"
)
seeds
=
torch
.
randint
(
1
,
torch
.
iinfo
(
torch
.
long
).
max
,
(
n_splits
,
samples
),
device
=
"cuda"
).
mul_
(
random_sampling_mask
)
#ditto
with
patch
(
"vllm.model_executor.layers.ops.sample._sample_triton"
,
LibEntry
(
_sample_triton
)):
sampled_tokens
,
sampled_logprobs
,
_
=
sample
(
probs
=
probs
,
logprobs
=
logprobs
,
sample_indices
=
sample_indices
,
seeds
=
seeds
,
max_best_of
=
max_best_of
,
modify_greedy_probs
=
modify_greedy_probs
,
save_logprobs
=
True
)
assert
sampled_tokens
.
shape
==
(
samples
,
max_best_of
)
assert
sampled_logprobs
.
shape
==
(
samples
,
max_best_of
)
for
i
,
t
in
enumerate
(
sample_indices
):
assert
torch
.
all
(
sampled_tokens
[
i
]
==
t
*
(
vocab_size
//
bs
))
for
best_of
in
range
(
max_best_of
):
assert
torch
.
all
(
sampled_logprobs
[
i
]
==
logprobs
[
sample_indices
[
i
]]
[
sampled_tokens
[
i
,
best_of
]])
@
pytest
.
mark
.
parametrize
(
"seed"
,
list
(
range
(
16
)))
def
test_get_sequence_seeds
(
seed
):
"""Ensure that we get a different child seed from base
seed + extra entropy"""
starting_seed
=
seed
seq_seed
=
None
extra_entropy
=
1
for
i
in
range
(
512
):
new_seq_seed
=
SamplingTensors
.
_get_sequence_seeds
(
starting_seed
,
i
,
seeds_to_generate
=
1
,
is_greedy
=
False
)[
0
]
new_seq_seed_extra_entropy
=
SamplingTensors
.
_get_sequence_seeds
(
starting_seed
,
i
,
extra_entropy
,
seeds_to_generate
=
1
,
is_greedy
=
False
)[
0
]
assert
new_seq_seed_extra_entropy
!=
new_seq_seed
assert
seq_seed
!=
new_seq_seed
seq_seed
=
new_seq_seed
vllm/model_executor/layers/ops/__init__.py
deleted
100644 → 0
View file @
cca61642
vllm/model_executor/layers/ops/rand.py
deleted
100644 → 0
View file @
cca61642
from
typing
import
Optional
,
Union
import
torch
import
triton
import
triton.language
as
tl
def
seeded_uniform
(
*
size
,
seeds
:
torch
.
Tensor
,
out
:
Optional
[
torch
.
Tensor
]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
device
:
Optional
[
Union
[
torch
.
device
,
str
]]
=
None
,
pin_memory
:
Optional
[
bool
]
=
False
,
)
->
torch
.
Tensor
:
"""Similar to torch.rand, but allows for seeds to be set per row.
seeds must be a 1d tensor. The output tensor may be 1d, 2d, or 3d.
If it is 3d, the additional seeds needed will be derived automatically
in a deterministic fashion:
[
row 0: [columns_with_seed_0], [columns_with_seed0^1], ...
]
"""
n_dims
=
len
(
size
)
if
n_dims
>
3
:
raise
ValueError
(
"seeded_uniform only supports up to 3D tensors"
)
if
out
is
None
:
out
=
torch
.
empty
(
*
size
,
dtype
=
dtype
,
device
=
device
,
pin_memory
=
pin_memory
)
elif
out
.
shape
!=
size
:
raise
ValueError
(
"shape of out and size must be the same"
)
if
n_dims
==
3
:
n_rows
,
n_3d
,
n_cols
=
out
.
shape
stride_row
=
out
.
stride
(
0
)
stride_3d
=
out
.
stride
(
1
)
elif
n_dims
==
2
:
n_rows
,
n_cols
=
out
.
shape
n_3d
=
1
stride_row
=
out
.
stride
(
0
)
stride_3d
=
1
else
:
n_cols
=
out
.
shape
[
0
]
n_rows
=
1
n_3d
=
1
stride_row
=
1
stride_3d
=
1
if
seeds
.
ndim
!=
1
:
raise
ValueError
(
"seeds must be a 1D tensor"
)
if
seeds
.
numel
()
!=
n_rows
:
raise
ValueError
(
"seeds must have the same number of elements as out has rows"
)
# The philox PRNG Triton uses generates 4 random numbers at once.
# Therefore, the most efficient use of it is to divide the
# block size by 4, and then save the generated random numbers to
# each of the 4 slices of the tensor.
full_block_size
=
triton
.
next_power_of_2
(
n_cols
)
philox_block_size
=
max
(
full_block_size
//
4
,
1
)
n_slices
=
full_block_size
//
philox_block_size
num_warps
=
4
# Manual tuning. This seems to give best performance on A100 for
# simple kernels like this.
if
philox_block_size
>=
8192
:
num_warps
=
32
elif
philox_block_size
>=
4096
:
num_warps
=
16
elif
philox_block_size
>=
2048
:
num_warps
=
8
_seeded_uniform_triton
[(
n_rows
,
n_3d
)](
out
,
seeds
,
stride_row
,
stride_3d
,
seeds
.
stride
(
0
),
n_rows
,
n_3d
,
n_cols
,
n_slices
=
n_slices
,
num_warps
=
num_warps
,
block_size
=
philox_block_size
,
)
return
out
@
triton
.
jit
def
_seeded_uniform_triton
(
out_ptr
:
torch
.
Tensor
,
seed_ptr
:
torch
.
Tensor
,
out_row_stride
:
int
,
out_3d_stride
:
int
,
seed_row_stride
:
int
,
n_rows
:
int
,
n_3d
:
int
,
n_cols
:
int
,
n_slices
:
tl
.
constexpr
,
block_size
:
tl
.
constexpr
,
):
"""
Generate a random float32 number in [0, 1) for each element in the output
tensor. The random numbers in a row generated using the seed for that row.
Args:
out_ptr: The output tensor.
seed_ptr: The per-row seeds to use for random number generation.
out_row_stride: The stride between rows of the output tensor.
out_3d_stride: The stride between 3D slices of the output tensor.
seed_row_stride: The stride between rows of the seed tensor.
n_rows: The number of rows in the output tensor.
n_3d: The size of second dimension of the output tensor,
if output tensor is 3D.
n_cols: The number of columns in the output tensor.
n_slices: The number of philox outputs to use.
"""
tl
.
static_assert
(
n_slices
>
0
and
n_slices
<=
4
,
"0 < n_slices <= 4"
)
# Get the row index.
row_idx
=
tl
.
program_id
(
axis
=
0
)
three_d_idx
=
tl
.
program_id
(
axis
=
1
)
philox_offsets
=
tl
.
arange
(
0
,
block_size
)
# Get the seed for the current element.
seed
=
tl
.
load
(
seed_ptr
+
row_idx
*
seed_row_stride
)
if
three_d_idx
>
0
:
seed
^=
three_d_idx
# Generate random numbers in [0, 1).
out1
,
out2
,
out3
,
out4
=
tl
.
rand4x
(
seed
,
philox_offsets
)
output_row_start_ptr
=
(
out_ptr
+
row_idx
*
out_row_stride
+
three_d_idx
*
out_3d_stride
)
out1_offsets
=
philox_offsets
tl
.
store
(
output_row_start_ptr
+
out1_offsets
,
out1
,
mask
=
out1_offsets
<
n_cols
)
if
n_slices
>
1
:
out2_offsets
=
tl
.
arange
(
block_size
,
block_size
*
2
)
tl
.
store
(
output_row_start_ptr
+
out2_offsets
,
out2
,
mask
=
out2_offsets
<
n_cols
)
if
n_slices
>
2
:
out3_offsets
=
tl
.
arange
(
block_size
*
2
,
block_size
*
3
)
tl
.
store
(
output_row_start_ptr
+
out3_offsets
,
out3
,
mask
=
out3_offsets
<
n_cols
)
if
n_slices
>
3
:
out4_offsets
=
tl
.
arange
(
block_size
*
3
,
block_size
*
4
)
tl
.
store
(
output_row_start_ptr
+
out4_offsets
,
out4
,
mask
=
out4_offsets
<
n_cols
)
vllm/model_executor/layers/ops/sample.py
deleted
100644 → 0
View file @
cca61642
from
typing
import
Optional
,
Tuple
import
torch
import
triton
import
triton.language
as
tl
from
vllm.model_executor.layers.ops.rand
import
seeded_uniform
from
vllm.triton_utils.sample
import
get_num_triton_sampler_splits
_EPS
:
tl
.
constexpr
=
1e-6
def
_multi_split_sample
(
probs
:
torch
.
Tensor
,
seeds
:
torch
.
Tensor
,
n_splits
:
int
,
sampled_tokens_size
:
Tuple
[
int
,
int
],
sampled_logprobs_size
:
Tuple
[
int
,
int
],
sample_indices
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
*
,
modify_greedy_probs
:
bool
=
False
,
save_logprobs
:
bool
=
False
,
):
"""Sample tokens where vocab size is split into multiple parts
(too large for Triton otherwise)."""
assert
seeds
.
ndim
==
2
and
seeds
.
shape
[
0
]
==
n_splits
split_probs
=
probs
.
tensor_split
(
n_splits
,
1
)
split_logprobs
=
logprobs
.
tensor_split
(
n_splits
,
1
)
sampled_tokens_tmp
=
[
torch
.
empty
(
sampled_tokens_size
,
dtype
=
torch
.
long
,
device
=
probs
.
device
)
for
_
in
range
(
n_splits
)
]
sampled_logprobs_tmp
=
[
torch
.
empty
(
sampled_logprobs_size
,
dtype
=
probs
.
dtype
,
device
=
probs
.
device
)
for
_
in
range
(
n_splits
)
]
# We are purposefuly using sampled_tokens_size as we need to always
# save modified probs in this case.
sampled_modified_probs_tmp
=
[
torch
.
empty
(
sampled_tokens_size
,
dtype
=
probs
.
dtype
,
device
=
probs
.
device
)
for
_
in
range
(
n_splits
)
]
for
i
in
range
(
n_splits
):
n_samples
=
sample_indices
.
shape
[
0
]
n_cols
=
split_probs
[
i
].
shape
[
1
]
n_best
=
sampled_tokens_tmp
[
i
].
shape
[
1
]
uniform_noise
=
seeded_uniform
(
n_samples
,
n_best
,
n_cols
,
seeds
=
seeds
[
i
].
flatten
(),
device
=
split_probs
[
i
].
device
,
dtype
=
split_probs
[
i
].
dtype
)
# TODO(yard1): See if we can remove the contiguous() calls.
# Will need kernel support.
_sample
(
split_probs
[
i
].
contiguous
(),
split_logprobs
[
i
].
contiguous
(),
sample_indices
,
sampled_tokens_tmp
[
i
],
sampled_logprobs_tmp
[
i
],
sampled_modified_probs_tmp
[
i
],
seeds
[
i
],
uniform_noise
,
modify_greedy_probs
=
False
,
save_logprobs
=
save_logprobs
,
save_modified_probs
=
True
,
)
if
i
>
0
:
# Add offset to sampled tokens
sampled_tokens_tmp
[
i
].
add_
(
i
*
split_probs
[
i
-
1
].
shape
[
1
])
sampled_tokens
=
torch
.
stack
(
sampled_tokens_tmp
)
sampled_modified_probs
=
torch
.
stack
(
sampled_modified_probs_tmp
)
# Reduce the results from the splits.
sampled_modified_probs
,
indices
=
torch
.
max
(
sampled_modified_probs
,
dim
=
0
,
keepdim
=
True
)
sampled_tokens
=
sampled_tokens
.
gather
(
0
,
indices
).
squeeze
(
0
)
if
save_logprobs
:
sampled_logprobs
=
torch
.
stack
(
sampled_logprobs_tmp
)
sampled_logprobs
=
sampled_logprobs
.
gather
(
0
,
indices
).
squeeze
(
0
)
else
:
sampled_logprobs
=
None
sampled_modified_probs
=
sampled_modified_probs
.
squeeze
(
0
)
if
modify_greedy_probs
:
# We need to modify the greedy probs for the sampled tokens.
# We can't do this in the kernel as we need to know the
# sampled tokens.
probs
.
fill_
(
0.0
)
probs
.
scatter_
(
1
,
sampled_tokens
,
1.0
)
return
(
sampled_tokens
,
sampled_logprobs
,
sampled_modified_probs
)
def
sample
(
probs
:
torch
.
Tensor
,
seeds
:
torch
.
Tensor
,
*
,
max_best_of
:
int
=
1
,
sample_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
logprobs
:
Optional
[
torch
.
Tensor
]
=
None
,
modify_greedy_probs
:
bool
=
False
,
save_logprobs
:
bool
=
False
,
_save_modified_probs
:
bool
=
False
,
# pylint: disable=invalid-name
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]]:
"""Sample tokens from probs. with per-sequence seeds.
Can sample from a subset of sequences through sample_indices.
Args:
probs: Probabilities to sample from.
shape = [batch_size, vocab_size]
seeds: Per-sequence seed values.
shape = [n, math.ceil(vocab_size / MAX_TRITON_N_COLS)]
max_best_of: Number of samples to generate per sequence.
Sequence seed will be incremented by 1 each time.
sample_indices: Indices of sequences to sample from.
If not provided, will sample from all sequences.
shape = [n]
logprobs: Log-probabilities of the sampled tokens.
Only used for saving the logprobs if save_logprobs is True.
shape = [batch_size, vocab_size]
modify_greedy_probs: Whether to modify the greedy probabilities
for speculative sampling (sampled token = 1.0,
everything else = 0.0).
save_logprobs: Whether to save the log-probabilities of the
sampled tokens to a tensor.
_save_modified_probs: Whether to save the modified probabilities
(including gumbel noise) of the sampled tokens to a tensor.
DOES NOT include the modification done by modify_greedy_probs
(because we want to use the unmodified probs to pick the best
split in case of multi-split sampling).
This is exposed only for testing.
Returns:
sampled_tokens: shape = [n, max_best_of]
sampled_logprobs: shape = [n, max_best_of] if save_logprobs else None
sampled_modified_probs: shape = [n, max_best_of]
if save_modified_probs else None
"""
if
sample_indices
is
None
:
sample_indices
=
torch
.
arange
(
0
,
probs
.
shape
[
0
],
device
=
probs
.
device
)
sampled_tokens_size
=
(
sample_indices
.
size
(
0
),
max_best_of
)
if
save_logprobs
:
if
logprobs
is
None
:
raise
ValueError
(
"logprobs tensor must be provided if save_logprobs is True"
)
sampled_logprobs_size
=
sampled_tokens_size
else
:
# Empty tensors to invoke the kernel
sampled_logprobs_size
=
(
0
,
0
)
logprobs
=
probs
assert
logprobs
is
not
None
if
_save_modified_probs
:
sampled_modified_probs_size
=
sampled_tokens_size
else
:
# Empty tensors to invoke the kernel
sampled_modified_probs_size
=
(
0
,
0
)
# If the number of columns in probs is too large for Triton to handle,
# we split the tensor and sample from each split separately, and then
# do an argmax+gather to combine the results.
n_splits
=
get_num_triton_sampler_splits
(
probs
.
shape
[
1
])
if
n_splits
>
1
:
(
sampled_tokens
,
sampled_logprobs
,
sampled_modified_probs
)
=
_multi_split_sample
(
probs
,
seeds
,
n_splits
,
sampled_tokens_size
,
sampled_logprobs_size
,
sample_indices
,
logprobs
=
logprobs
,
modify_greedy_probs
=
modify_greedy_probs
,
save_logprobs
=
save_logprobs
)
else
:
sampled_tokens
=
torch
.
empty
(
sampled_tokens_size
,
dtype
=
torch
.
long
,
device
=
probs
.
device
)
sampled_logprobs
=
torch
.
empty
(
sampled_logprobs_size
,
dtype
=
probs
.
dtype
,
device
=
probs
.
device
)
sampled_modified_probs
=
torch
.
empty
(
sampled_modified_probs_size
,
dtype
=
probs
.
dtype
,
device
=
probs
.
device
)
n_samples
=
sample_indices
.
shape
[
0
]
n_cols
=
probs
.
shape
[
1
]
uniform_noise
=
seeded_uniform
(
n_samples
,
max_best_of
,
n_cols
,
seeds
=
seeds
.
flatten
(),
device
=
probs
.
device
,
dtype
=
probs
.
dtype
)
_sample
(
probs
,
logprobs
,
sample_indices
,
sampled_tokens
,
sampled_logprobs
,
sampled_modified_probs
,
seeds
,
uniform_noise
,
modify_greedy_probs
=
modify_greedy_probs
,
save_logprobs
=
save_logprobs
,
save_modified_probs
=
_save_modified_probs
,
)
return
(
sampled_tokens
,
sampled_logprobs
if
save_logprobs
else
None
,
sampled_modified_probs
if
_save_modified_probs
else
None
)
def
_sample
(
probs
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
sample_indices
:
torch
.
Tensor
,
output_samples
:
torch
.
Tensor
,
output_logprobs
:
torch
.
Tensor
,
output_modified_probs
:
torch
.
Tensor
,
seeds
:
torch
.
Tensor
,
uniform_noise
:
torch
.
Tensor
,
*
,
modify_greedy_probs
:
bool
=
False
,
save_logprobs
:
bool
=
True
,
save_modified_probs
:
bool
=
False
)
->
torch
.
Tensor
:
"""Sample tokens from probs.
Args:
probs [batch_size, vocab_size]: probs to sample from.
logprobs [batch_size, vocab_size]: logprobs (used when
save_logprobsis True).
sample_indices [n]: Indices of the samples to use for each row of probs.
output_samples [n, n_best]: Output tensor to store samples in.
output_logprobs [n, n_best]: Output tensor to store logprobs in.
output_modified_probs [n, n_best]: Output tensor to store
probs of chosen tokens in (modified with noise).
seeds [n]: Seeds to use for sampling. If the seed is 0, we use
greedy sampling. Note this is ONLY used for determining
whether to use random sampling or not. The actual random
noise should be passed as uniform_noise.
uniform_noise [batch_size, n_best, vocab_size]: Uniform
noise to use for random sampling (will be converted
to exponential gumbel noise by the kernel).
modify_greedy_probs: If True, we modify the probs tensor in-place
to encode the sampling method used for each row. This is used
in speculative decoding. Only applies in greedy decoding.
save_logprobs: If True, we save the logprobs of the sampled tokens
in the output_logprobs tensor.
save_modified_probs: If True, we save the modified probs (with noise)
of the sampled tokens in the output_modified_probs tensor.
DOES NOT include the modification done by modify_greedy_probs
(because we want to use the unmodified probs to pick the best
split in case of multi-split sampling).
"""
n_samples
=
sample_indices
.
shape
[
0
]
n_cols
=
probs
.
shape
[
1
]
n_best
=
output_samples
.
shape
[
1
]
if
len
(
output_samples
.
shape
)
>
1
else
1
# The block size is the smallest power of two greater than the number of
# columns in probs
block_size
=
triton
.
next_power_of_2
(
n_cols
)
num_warps
=
4
# Manual tuning. This seems to give best performance on A100 for
# simple kernels like this.
if
block_size
>=
8192
:
num_warps
=
32
elif
block_size
>=
4096
:
num_warps
=
16
elif
block_size
>=
2048
:
num_warps
=
8
# Enqueue kernel. The 1D launch grid is simple: we have one kernel
# instance per row of the probs matrix
_sample_triton
[(
n_samples
,
n_best
)](
sample_indices
,
output_samples
,
output_logprobs
,
output_modified_probs
,
probs
,
logprobs
,
seeds
,
uniform_noise
,
output_samples
.
stride
(
0
),
probs
.
stride
(
0
),
uniform_noise
.
stride
(
0
),
uniform_noise
.
stride
(
1
)
if
n_best
>
1
else
1
,
n_samples
,
n_cols
,
n_best
,
num_warps
=
num_warps
,
block_size
=
block_size
,
modify_greedy_probs
=
modify_greedy_probs
,
save_logprobs
=
save_logprobs
,
save_modified_probs
=
save_modified_probs
,
)
return
output_samples
,
output_logprobs
,
output_modified_probs
@
triton
.
jit
def
_uniform_to_exponential
(
uniform_noise
):
"""Convert uniform samples to exponential samples."""
# tl.rand returns values in [0, 1), so we clamp lower bound
# to _EPS to avoid log(0) and thus division by 0 later
lb
=
tl
.
full
(
uniform_noise
.
shape
,
_EPS
,
uniform_noise
.
dtype
)
uniform_noise
=
tl
.
maximum
(
uniform_noise
,
lb
)
# Use the inversion method to turn uniform samples
# into exponential samples
exponential_noise
=
-
tl
.
log
(
uniform_noise
)
return
exponential_noise
@
triton
.
jit
def
_sample_triton
(
sample_indices_ptr
:
torch
.
Tensor
,
output_ptr
:
torch
.
Tensor
,
output_logprobs_ptr
:
torch
.
Tensor
,
output_modified_probs_ptr
:
torch
.
Tensor
,
probs_ptr
:
torch
.
Tensor
,
logprobs_ptr
:
torch
.
Tensor
,
seeds_ptr
:
torch
.
Tensor
,
uniform_noise_ptr
:
torch
.
Tensor
,
output_row_stride
:
int
,
probs_row_stride
:
int
,
uniform_noise_row_stride
:
int
,
uniform_noise_best_stride
:
int
,
n_samples
:
int
,
n_cols
:
int
,
n_best
:
int
,
block_size
:
tl
.
constexpr
,
modify_greedy_probs
:
tl
.
constexpr
,
save_logprobs
:
tl
.
constexpr
,
save_modified_probs
:
tl
.
constexpr
):
# The rows are independent, so we parallelize across those
sample_idx
=
tl
.
program_id
(
0
)
best_idx
=
tl
.
program_id
(
1
)
# Load the row index from DRAM
row_idx
=
tl
.
load
(
sample_indices_ptr
+
sample_idx
)
seed
=
tl
.
load
(
seeds_ptr
+
sample_idx
)
uses_random_sampling
=
seed
!=
0
# The stride represents how much we need to increase the
# pointer to advance 1 row
row_start_ptr
=
probs_ptr
+
row_idx
*
probs_row_stride
# The block size is the next power of two greater than n_cols,
# so we can fit each row in a single block
col_offsets
=
tl
.
arange
(
0
,
block_size
)
# Load the row into SRAM, using a mask since block_size may be > than n_cols
row
=
tl
.
load
(
row_start_ptr
+
col_offsets
,
mask
=
col_offsets
<
n_cols
,
other
=
float
(
"-inf"
))
if
uses_random_sampling
:
uniform_noise_start_ptr
=
(
uniform_noise_ptr
+
sample_idx
*
uniform_noise_row_stride
+
best_idx
*
uniform_noise_best_stride
)
uniform_noise
=
tl
.
load
(
uniform_noise_start_ptr
+
col_offsets
,
mask
=
col_offsets
<
n_cols
,
other
=
0.5
)
exponential_noise
=
_uniform_to_exponential
(
uniform_noise
)
row
/=
exponential_noise
sampled_value
,
sampled_token
=
tl
.
max
(
row
,
axis
=
0
,
return_indices
=
True
)
# clamp sampled token to n_cols - 1
# this should not be necessary, but we do it
# just in case
if
sampled_token
>=
n_cols
:
sampled_token
=
n_cols
-
1
# Write back output to DRAM
output_row_start_ptr
=
(
output_ptr
+
sample_idx
*
output_row_stride
+
best_idx
)
tl
.
store
(
output_row_start_ptr
,
sampled_token
)
if
modify_greedy_probs
:
# noqa
if
not
uses_random_sampling
:
# Set the probability of the sampled token to 1, all other
# tokens to zero. This is used in speculative decoding where
# the sampling method must be encoded within the sampled
# probability distributions.
row
=
tl
.
where
(
col_offsets
==
sampled_token
,
1.0
,
0.0
)
tl
.
store
(
row_start_ptr
+
col_offsets
,
row
,
mask
=
col_offsets
<
n_cols
)
if
save_modified_probs
:
output_row_start_ptr
=
(
output_modified_probs_ptr
+
sample_idx
*
output_row_stride
+
best_idx
)
tl
.
store
(
output_row_start_ptr
,
sampled_value
)
if
save_logprobs
:
# Load the row into SRAM, using a mask since block_size
# may be > than n_cols
sampled_logprob
=
tl
.
load
(
logprobs_ptr
+
row_idx
*
probs_row_stride
+
sampled_token
)
# Write back output to DRAM
output_row_start_ptr
=
(
output_logprobs_ptr
+
sample_idx
*
output_row_stride
+
best_idx
)
tl
.
store
(
output_row_start_ptr
,
sampled_logprob
)
vllm/model_executor/layers/sampler.py
View file @
546034b4
...
@@ -10,12 +10,6 @@ import msgspec
...
@@ -10,12 +10,6 @@ import msgspec
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.spec_decode.metrics
import
SpecDecodeWorkerMetrics
from
vllm.triton_utils
import
HAS_TRITON
if
HAS_TRITON
:
from
vllm.model_executor.layers.ops.sample
import
sample
as
sample_triton
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.model_executor.sampling_metadata
import
(
SamplingMetadata
,
from
vllm.model_executor.sampling_metadata
import
(
SamplingMetadata
,
SamplingTensors
,
SamplingTensors
,
...
@@ -23,6 +17,7 @@ from vllm.model_executor.sampling_metadata import (SamplingMetadata,
...
@@ -23,6 +17,7 @@ from vllm.model_executor.sampling_metadata import (SamplingMetadata,
from
vllm.sampling_params
import
SamplingType
from
vllm.sampling_params
import
SamplingType
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
PromptLogprobs
,
SampleLogprobs
,
SequenceOutput
)
PromptLogprobs
,
SampleLogprobs
,
SequenceOutput
)
from
vllm.spec_decode.metrics
import
SpecDecodeWorkerMetrics
if
envs
.
VLLM_USE_FLASHINFER_SAMPLER
and
find_spec
(
"flashinfer"
):
if
envs
.
VLLM_USE_FLASHINFER_SAMPLER
and
find_spec
(
"flashinfer"
):
import
flashinfer.sampling
import
flashinfer.sampling
...
@@ -777,7 +772,7 @@ def _sample_with_torch(
...
@@ -777,7 +772,7 @@ def _sample_with_torch(
# Counterintiutively, having two loops here is actually faster.
# Counterintiutively, having two loops here is actually faster.
# The first loop can run without waiting on GPU<->CPU sync.
# The first loop can run without waiting on GPU<->CPU sync.
for
sampling_type
in
SamplingType
:
for
sampling_type
in
SamplingType
:
sample_indices
=
categorized_sample_indices
[
sampling_type
]
[:,
0
]
sample_indices
=
categorized_sample_indices
[
sampling_type
]
num_tokens
=
len
(
sample_indices
)
num_tokens
=
len
(
sample_indices
)
if
num_tokens
==
0
:
if
num_tokens
==
0
:
continue
continue
...
@@ -863,88 +858,6 @@ def _sample_with_torch(
...
@@ -863,88 +858,6 @@ def _sample_with_torch(
)
)
def
_sample_with_triton_kernel
(
probs
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_tensors
:
SamplingTensors
,
)
->
SampleResultType
:
categorized_seq_group_ids
:
Dict
[
SamplingType
,
List
[
int
]]
=
{
t
:
[]
for
t
in
SamplingType
}
categorized_sample_indices
=
sampling_metadata
.
categorized_sample_indices
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
sampling_params
=
seq_group
.
sampling_params
sampling_type
=
sampling_params
.
sampling_type
categorized_seq_group_ids
[
sampling_type
].
append
(
i
)
sample_results_dict
:
Dict
[
int
,
Tuple
[
List
[
int
],
List
[
int
]]]
=
{}
sample_metadata
:
Dict
[
SamplingType
,
Tuple
[
List
[
int
],
List
[
SequenceGroupToSample
],
torch
.
Tensor
,
torch
.
Tensor
]]
=
{}
max_best_of_in_batch
=
1
# Counterintiutively, having two loops here is actually faster.
# The first loop can run without waiting on GPU<->CPU sync.
for
sampling_type
in
SamplingType
:
sample_indices
=
categorized_sample_indices
[
sampling_type
][:,
0
]
sampled_token_indices
=
categorized_sample_indices
[
sampling_type
][:,
1
]
num_tokens
=
len
(
sample_indices
)
if
num_tokens
==
0
:
continue
seq_group_id
=
categorized_seq_group_ids
[
sampling_type
]
seq_groups
=
[
sampling_metadata
.
seq_groups
[
i
]
for
i
in
seq_group_id
]
sample_metadata
[
sampling_type
]
=
(
seq_group_id
,
seq_groups
,
sample_indices
,
sampled_token_indices
)
if
sampling_type
in
(
SamplingType
.
GREEDY
,
SamplingType
.
RANDOM
,
SamplingType
.
RANDOM_SEED
):
for
seq_group
in
seq_groups
:
if
seq_group
.
is_prompt
:
sampling_params
=
seq_group
.
sampling_params
max_best_of_in_batch
=
max
(
max_best_of_in_batch
,
sampling_params
.
best_of
)
elif
sampling_type
==
SamplingType
.
BEAM
:
beam_search_logprobs
=
logprobs
[
sample_indices
]
else
:
raise
ValueError
(
f
"Unsupported sampling type:
{
sampling_type
}
"
)
sampled_tokens
,
_
,
_
=
sample_triton
(
probs
=
probs
,
seeds
=
sampling_tensors
.
sampling_seeds
,
max_best_of
=
max_best_of_in_batch
,
sample_indices
=
sampling_tensors
.
sample_indices
,
logprobs
=
logprobs
,
# don't save logprobs because we have logic for that below
# TODO: use this instead of the CPU-based logic below
save_logprobs
=
False
,
)
# GPU<->CPU sync happens in the loop below.
for
sampling_type
in
SamplingType
:
if
sampling_type
not
in
sample_metadata
:
continue
(
seq_group_id
,
seq_groups
,
sample_indices
,
sampled_token_indices
)
=
sample_metadata
[
sampling_type
]
if
sampling_type
==
SamplingType
.
GREEDY
:
sample_results
=
_greedy_sample
(
seq_groups
,
sampled_tokens
[
sampled_token_indices
][:,
0
])
elif
sampling_type
in
(
SamplingType
.
RANDOM
,
SamplingType
.
RANDOM_SEED
):
sample_results
=
_random_sample
(
seq_groups
,
sampled_tokens
[
sampled_token_indices
])
elif
sampling_type
==
SamplingType
.
BEAM
:
sample_results
=
_beam_search_sample
(
seq_groups
,
beam_search_logprobs
)
sample_results_dict
.
update
(
zip
(
seq_group_id
,
sample_results
))
sample_results
=
[
sample_results_dict
.
get
(
i
,
([],
[]))
for
i
in
range
(
len
(
sampling_metadata
.
seq_groups
))
]
return
sample_results
def
_sample
(
def
_sample
(
probs
:
torch
.
Tensor
,
probs
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
...
@@ -974,10 +887,6 @@ def _sample(
...
@@ -974,10 +887,6 @@ def _sample(
modify_greedy_probs
=
modify_greedy_probs
,
modify_greedy_probs
=
modify_greedy_probs
,
)
)
# TODO: Enable once Triton kernel & associated code is faster.
# return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
# sampling_tensors)
def
_get_ranks
(
x
:
torch
.
Tensor
,
indices
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
_get_ranks
(
x
:
torch
.
Tensor
,
indices
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
"""
...
...
vllm/model_executor/sampling_metadata.py
View file @
546034b4
import
random
from
array
import
array
from
array
import
array
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
...
@@ -8,15 +7,10 @@ import torch
...
@@ -8,15 +7,10 @@ import torch
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
SequenceData
,
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
SequenceData
,
SequenceGroupMetadata
)
SequenceGroupMetadata
)
from
vllm.triton_utils.sample
import
get_num_triton_sampler_splits
from
vllm.utils
import
(
PyObjectCache
,
async_tensor_h2d
,
from
vllm.utils
import
(
PyObjectCache
,
async_tensor_h2d
,
is_pin_memory_available
,
make_tensor_with_pad
,
is_pin_memory_available
,
make_tensor_with_pad
)
maybe_expand_dim
)
_SAMPLING_EPS
=
1e-5
_SAMPLING_EPS
=
1e-5
_SEED_0_REPLACEMENT
=
3403598558
# Some triton sampler related code is guarded before it is ready.
_USE_TRITON_SAMPLER
=
False
@
dataclass
@
dataclass
...
@@ -74,12 +68,12 @@ def gen_seq_group_to_sample_builder(num_seqs: int):
...
@@ -74,12 +68,12 @@ def gen_seq_group_to_sample_builder(num_seqs: int):
generator
=
None
,
generator
=
None
,
is_prompt
=
True
,
is_prompt
=
True
,
prompt_logprob_indices
=
[],
prompt_logprob_indices
=
[],
sample_indices
=
[])
sample_indices
=
[],
)
class
SamplingMetadataCache
:
class
SamplingMetadataCache
:
"""Used to cache SamplingMetadata objects between scheduler iterations
"""Used to cache SamplingMetadata objects between scheduler iterations"""
"""
def
__init__
(
self
):
def
__init__
(
self
):
self
.
_seq_group_to_sample_cache
:
Dict
[
int
,
PyObjectCache
]
=
{}
self
.
_seq_group_to_sample_cache
:
Dict
[
int
,
PyObjectCache
]
=
{}
...
@@ -165,16 +159,19 @@ class SamplingMetadata:
...
@@ -165,16 +159,19 @@ class SamplingMetadata:
num_prompts
,
num_prompts
,
)
=
_prepare_seq_groups
(
seq_group_metadata_list
,
seq_lens
,
query_lens
,
)
=
_prepare_seq_groups
(
seq_group_metadata_list
,
seq_lens
,
query_lens
,
device
,
generators
,
cache
)
device
,
generators
,
cache
)
selected_token_indices
=
async_tensor_h2d
(
selected_token_indices
,
selected_token_indices
=
async_tensor_h2d
(
selected_token_indices
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
target_device
=
device
,
target_device
=
device
,
pin_memory
=
pin_memory
)
pin_memory
=
pin_memory
,
)
categorized_sample_indices
=
{
categorized_sample_indices
=
{
t
:
maybe_expand_dim
(
t
:
async_tensor_h2d
(
async_tensor_h2d
(
seq_ids
,
seq_ids
,
dtype
=
torch
.
int
,
dtype
=
torch
.
int
,
target_device
=
device
,
target_device
=
device
,
pin_memory
=
pin_memory
),
2
,
2
)
pin_memory
=
pin_memory
,
)
for
t
,
seq_ids
in
categorized_sample_indices
.
items
()
for
t
,
seq_ids
in
categorized_sample_indices
.
items
()
}
}
...
@@ -201,8 +198,8 @@ def _prepare_seq_groups(
...
@@ -201,8 +198,8 @@ def _prepare_seq_groups(
device
:
str
,
device
:
str
,
generators
:
Optional
[
Dict
[
str
,
torch
.
Generator
]]
=
None
,
generators
:
Optional
[
Dict
[
str
,
torch
.
Generator
]]
=
None
,
cache
:
Optional
[
SamplingMetadataCache
]
=
None
,
cache
:
Optional
[
SamplingMetadataCache
]
=
None
,
)
->
Tuple
[
List
[
SequenceGroupToSample
],
List
[
int
],
Dict
[
)
->
Tuple
[
List
[
SequenceGroupToSample
],
List
[
int
],
Dict
[
SamplingType
,
SamplingType
,
List
[
Tuple
[
int
,
int
]]],
int
]:
List
[
int
]]
,
int
,
]:
"""Prepare sequence groups and indices for sampling.
"""Prepare sequence groups and indices for sampling.
Args:
Args:
...
@@ -233,16 +230,13 @@ def _prepare_seq_groups(
...
@@ -233,16 +230,13 @@ def _prepare_seq_groups(
# Sampling type -> (
# Sampling type -> (
# indices to sample/prompt logprob within pruned output logits,
# indices to sample/prompt logprob within pruned output logits,
# indices to sample within pruned logits)
# indices to sample within pruned logits)
categorized_sample_indices
:
Dict
[
SamplingType
,
List
[
Tuple
[
int
,
int
]
]]
=
{
categorized_sample_indices
:
Dict
[
SamplingType
,
List
[
int
]]
=
{
t
:
[]
t
:
[]
for
t
in
SamplingType
for
t
in
SamplingType
}
}
# Index of logits to compute logprob. Logits include both prompt logprob
# Index of logits to compute logprob. Logits include both prompt logprob
# and sample logprob indices.
# and sample logprob indices.
logit_idx
=
0
logit_idx
=
0
# Index to sample from a sample tensor. It is used by triton sample kernel.
# See `_sample_with_triton_kernel` for more details.
sample_idx
=
0
# Total number of prompts from given sequence groups.
# Total number of prompts from given sequence groups.
num_prompts
=
0
num_prompts
=
0
...
@@ -264,10 +258,10 @@ def _prepare_seq_groups(
...
@@ -264,10 +258,10 @@ def _prepare_seq_groups(
# If the current seq group is in decode stage, it is None.
# If the current seq group is in decode stage, it is None.
seq_len
:
Optional
[
int
]
=
None
seq_len
:
Optional
[
int
]
=
None
query_len
:
Optional
[
int
]
=
None
query_len
:
Optional
[
int
]
=
None
prompt_logprob_indices
:
List
[
int
]
=
\
prompt_logprob_indices
:
List
[
int
]
=
(
sample_obj
.
prompt_logprob_indices
sample_obj
.
prompt_logprob_indices
if
cache
is
not
None
else
[]
if
cache
is
not
None
else
[]
)
sample_indices
:
List
[
int
]
=
\
sample_indices
:
List
[
int
]
=
(
sample_obj
.
sample_indices
sample_obj
.
sample_indices
if
cache
is
not
None
else
[]
if
cache
is
not
None
else
[]
)
do_sample
=
seq_group_metadata
.
do_sample
do_sample
=
seq_group_metadata
.
do_sample
if
seq_group_metadata
.
is_prompt
:
if
seq_group_metadata
.
is_prompt
:
...
@@ -333,11 +327,8 @@ def _prepare_seq_groups(
...
@@ -333,11 +327,8 @@ def _prepare_seq_groups(
if
do_sample
:
if
do_sample
:
sample_indices
.
extend
(
range
(
logit_idx
,
logit_idx
+
sample_len
))
sample_indices
.
extend
(
range
(
logit_idx
,
logit_idx
+
sample_len
))
categorized_sample_indices
[
sampling_params
.
sampling_type
].
extend
(
categorized_sample_indices
[
sampling_params
.
sampling_type
].
extend
(
list
(
list
(
range
(
logit_idx
,
logit_idx
+
sample_len
)))
zip
(
range
(
logit_idx
,
logit_idx
+
sample_len
),
range
(
sample_idx
,
sample_idx
+
sample_len
))))
logit_idx
+=
sample_len
logit_idx
+=
sample_len
sample_idx
+=
sample_len
if
cache
is
not
None
:
if
cache
is
not
None
:
sample_obj
.
sampling_params
=
sampling_params
sample_obj
.
sampling_params
=
sampling_params
...
@@ -356,7 +347,8 @@ def _prepare_seq_groups(
...
@@ -356,7 +347,8 @@ def _prepare_seq_groups(
generator
=
generator
,
generator
=
generator
,
is_prompt
=
is_prompt
,
is_prompt
=
is_prompt
,
prompt_logprob_indices
=
list
(
prompt_logprob_indices
),
prompt_logprob_indices
=
list
(
prompt_logprob_indices
),
sample_indices
=
list
(
sample_indices
))
sample_indices
=
list
(
sample_indices
),
)
seq_groups
.
append
(
sample_obj
)
seq_groups
.
append
(
sample_obj
)
...
@@ -378,9 +370,6 @@ class SamplingTensors:
...
@@ -378,9 +370,6 @@ class SamplingTensors:
presence_penalties
:
torch
.
Tensor
presence_penalties
:
torch
.
Tensor
frequency_penalties
:
torch
.
Tensor
frequency_penalties
:
torch
.
Tensor
repetition_penalties
:
torch
.
Tensor
repetition_penalties
:
torch
.
Tensor
sampling_seeds
:
torch
.
Tensor
sample_indices
:
torch
.
Tensor
extra_seeds
:
Optional
[
torch
.
Tensor
]
prompt_tokens
:
torch
.
Tensor
prompt_tokens
:
torch
.
Tensor
output_tokens
:
torch
.
Tensor
output_tokens
:
torch
.
Tensor
...
@@ -391,15 +380,7 @@ class SamplingTensors:
...
@@ -391,15 +380,7 @@ class SamplingTensors:
vocab_size
:
int
,
vocab_size
:
int
,
device
:
torch
.
device
,
device
:
torch
.
device
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
*
,
extra_seeds_to_generate
:
int
=
0
,
extra_entropy
:
Optional
[
Tuple
[
int
,
...]]
=
None
)
->
Tuple
[
"SamplingTensors"
,
bool
,
bool
,
bool
]:
)
->
Tuple
[
"SamplingTensors"
,
bool
,
bool
,
bool
]:
"""
extra_seeds_to_generate: extra seeds to generate using the
user-defined seed for each sequence.
extra_entropy: extra entropy to use when generating seeds.
"""
prompt_tokens
:
List
[
array
]
=
[]
prompt_tokens
:
List
[
array
]
=
[]
output_tokens
:
List
[
array
]
=
[]
output_tokens
:
List
[
array
]
=
[]
top_ks
:
List
[
int
]
=
[]
top_ks
:
List
[
int
]
=
[]
...
@@ -409,19 +390,10 @@ class SamplingTensors:
...
@@ -409,19 +390,10 @@ class SamplingTensors:
presence_penalties
:
List
[
float
]
=
[]
presence_penalties
:
List
[
float
]
=
[]
frequency_penalties
:
List
[
float
]
=
[]
frequency_penalties
:
List
[
float
]
=
[]
repetition_penalties
:
List
[
float
]
=
[]
repetition_penalties
:
List
[
float
]
=
[]
sampling_seeds
:
List
[
int
]
=
[]
sample_indices
:
List
[
int
]
=
[]
do_penalties
=
False
do_penalties
=
False
do_top_p_top_k
=
False
do_top_p_top_k
=
False
do_min_p
=
False
do_min_p
=
False
if
_USE_TRITON_SAMPLER
:
prompt_best_of
:
List
[
int
]
=
[]
# We need one base seed per Triton slice.
seeds_to_generate
=
(
extra_seeds_to_generate
+
get_num_triton_sampler_splits
(
vocab_size
))
assert
sampling_metadata
.
seq_groups
is
not
None
assert
sampling_metadata
.
seq_groups
is
not
None
for
seq_group
in
sampling_metadata
.
seq_groups
:
for
seq_group
in
sampling_metadata
.
seq_groups
:
seq_ids
=
seq_group
.
seq_ids
seq_ids
=
seq_group
.
seq_ids
...
@@ -452,7 +424,7 @@ class SamplingTensors:
...
@@ -452,7 +424,7 @@ class SamplingTensors:
do_penalties
=
True
do_penalties
=
True
is_prompt
=
seq_group
.
is_prompt
is_prompt
=
seq_group
.
is_prompt
if
(
is_prompt
and
sampling_params
.
prompt_logprobs
is
not
None
)
:
if
is_prompt
and
sampling_params
.
prompt_logprobs
is
not
None
:
# For tokens in the prompt that we only need to get
# For tokens in the prompt that we only need to get
# their logprobs
# their logprobs
query_len
=
seq_group
.
query_len
query_len
=
seq_group
.
query_len
...
@@ -477,28 +449,6 @@ class SamplingTensors:
...
@@ -477,28 +449,6 @@ class SamplingTensors:
frequency_penalties
+=
[
f
]
*
len
(
seq_ids
)
frequency_penalties
+=
[
f
]
*
len
(
seq_ids
)
repetition_penalties
+=
[
r
]
*
len
(
seq_ids
)
repetition_penalties
+=
[
r
]
*
len
(
seq_ids
)
if
_USE_TRITON_SAMPLER
:
if
is_prompt
:
prompt_best_of
.
append
(
sampling_params
.
best_of
)
query_len
=
seq_group
.
query_len
assert
query_len
is
not
None
seed
=
sampling_params
.
seed
is_greedy
=
sampling_params
.
sampling_type
==
SamplingType
.
GREEDY
for
seq_id
in
seq_ids
:
seq_data
=
seq_group
.
seq_data
[
seq_id
]
extra_entropy
=
extra_entropy
or
()
seq_seeds
=
cls
.
_get_sequence_seeds
(
seed
,
seq_data
.
get_len
(),
*
extra_entropy
,
seq_id
,
seeds_to_generate
=
seeds_to_generate
,
is_greedy
=
is_greedy
)
sampling_seeds
.
append
(
seq_seeds
)
sample_indices
.
extend
(
seq_group
.
sample_indices
)
if
do_penalties
:
if
do_penalties
:
for
seq_group
in
sampling_metadata
.
seq_groups
:
for
seq_group
in
sampling_metadata
.
seq_groups
:
seq_ids
=
seq_group
.
seq_ids
seq_ids
=
seq_group
.
seq_ids
...
@@ -518,23 +468,37 @@ class SamplingTensors:
...
@@ -518,23 +468,37 @@ class SamplingTensors:
output_tokens
.
append
(
seq_data
.
output_token_ids_array
)
output_tokens
.
append
(
seq_data
.
output_token_ids_array
)
sampling_tensors
=
SamplingTensors
.
from_lists
(
sampling_tensors
=
SamplingTensors
.
from_lists
(
temperatures
,
top_ps
,
top_ks
,
min_ps
,
presence_penalties
,
temperatures
,
frequency_penalties
,
repetition_penalties
,
sampling_seeds
,
top_ps
,
sample_indices
,
prompt_tokens
,
output_tokens
,
vocab_size
,
top_ks
,
extra_seeds_to_generate
,
device
,
dtype
)
min_ps
,
presence_penalties
,
frequency_penalties
,
repetition_penalties
,
prompt_tokens
,
output_tokens
,
vocab_size
,
device
,
dtype
,
)
return
(
sampling_tensors
,
do_penalties
,
do_top_p_top_k
,
do_min_p
)
return
(
sampling_tensors
,
do_penalties
,
do_top_p_top_k
,
do_min_p
)
@
classmethod
@
classmethod
def
from_lists
(
cls
,
temperatures
:
List
[
float
],
top_ps
:
List
[
float
],
def
from_lists
(
top_ks
:
List
[
int
],
min_ps
:
List
[
float
],
cls
,
temperatures
:
List
[
float
],
top_ps
:
List
[
float
],
top_ks
:
List
[
int
],
min_ps
:
List
[
float
],
presence_penalties
:
List
[
float
],
presence_penalties
:
List
[
float
],
frequency_penalties
:
List
[
float
],
frequency_penalties
:
List
[
float
],
repetition_penalties
:
List
[
float
],
repetition_penalties
:
List
[
float
],
sampling_seeds
:
List
[
int
],
sample_indice
s
:
List
[
int
],
prompt_token
s
:
List
[
array
],
prompt_tokens
:
List
[
array
],
output_tokens
:
List
[
array
],
output_tokens
:
List
[
array
],
vocab_size
:
int
,
extra_seeds_to_generate
:
int
,
vocab_size
:
int
,
device
:
torch
.
device
,
device
:
torch
.
device
,
dtype
:
torch
.
dtype
)
->
"SamplingTensors"
:
dtype
:
torch
.
dtype
,
)
->
"SamplingTensors"
:
# Note that the performance will be very bad without
# Note that the performance will be very bad without
# pinned memory.
# pinned memory.
pin_memory
=
is_pin_memory_available
()
pin_memory
=
is_pin_memory_available
()
...
@@ -603,34 +567,9 @@ class SamplingTensors:
...
@@ -603,34 +567,9 @@ class SamplingTensors:
dtype
=
torch
.
int
,
dtype
=
torch
.
int
,
pin_memory
=
pin_memory
,
pin_memory
=
pin_memory
,
)
)
sample_indices_t
=
torch
.
tensor
(
sample_indices
,
device
=
"cpu"
,
dtype
=
torch
.
long
,
pin_memory
=
pin_memory
,
)
# need to transpose and make contiguous to
# copy the tensor correctly.
# [batch_size, n_seeds] -> [n_seeds, batch_size]
sampling_seeds_t
=
torch
.
tensor
(
sampling_seeds
,
device
=
"cpu"
,
dtype
=
torch
.
long
,
pin_memory
=
pin_memory
,
).
t
().
contiguous
()
# Because the memory is pinned, we can do non-blocking
# Because the memory is pinned, we can do non-blocking
# transfer to device.
# transfer to device.
# How many seeds the sample operation itself will need.
num_base_seeds
=
sampling_seeds_t
.
shape
[
0
]
-
extra_seeds_to_generate
sampling_seeds_gpu
=
sampling_seeds_t
.
to
(
device
=
device
,
non_blocking
=
True
)
extra_seeds_gpu
=
sampling_seeds_gpu
[
num_base_seeds
:]
if
not
extra_seeds_gpu
.
numel
():
extra_seeds_gpu
=
None
sampling_seeds_gpu
=
sampling_seeds_gpu
[:
num_base_seeds
]
return
cls
(
return
cls
(
temperatures
=
temperatures_t
.
to
(
device
=
device
,
non_blocking
=
True
),
temperatures
=
temperatures_t
.
to
(
device
=
device
,
non_blocking
=
True
),
top_ps
=
top_ps_t
.
to
(
device
=
device
,
non_blocking
=
True
),
top_ps
=
top_ps_t
.
to
(
device
=
device
,
non_blocking
=
True
),
...
@@ -644,38 +583,4 @@ class SamplingTensors:
...
@@ -644,38 +583,4 @@ class SamplingTensors:
non_blocking
=
True
),
non_blocking
=
True
),
prompt_tokens
=
prompt_t
.
to
(
device
=
device
,
non_blocking
=
True
),
prompt_tokens
=
prompt_t
.
to
(
device
=
device
,
non_blocking
=
True
),
output_tokens
=
output_t
.
to
(
device
=
device
,
non_blocking
=
True
),
output_tokens
=
output_t
.
to
(
device
=
device
,
non_blocking
=
True
),
sampling_seeds
=
sampling_seeds_gpu
,
sample_indices
=
sample_indices_t
.
to
(
device
=
device
,
non_blocking
=
True
),
extra_seeds
=
extra_seeds_gpu
,
)
)
@
staticmethod
def
_get_sequence_seeds
(
seed
:
int
,
*
extra_entropy
:
int
,
seeds_to_generate
:
int
,
is_greedy
:
bool
,
):
"""Get `seeds_to_generate` child seeds from `seed` and extra entropy."""
if
not
is_greedy
:
if
seed
is
None
:
randint_fn
=
random
.
randint
else
:
generator
=
random
.
Random
(
str
((
seed
,
)
+
extra_entropy
))
randint_fn
=
generator
.
randint
lo
,
hi
=
torch
.
iinfo
(
torch
.
long
).
min
,
torch
.
iinfo
(
torch
.
long
).
max
# If the user/random sets seed = 0 but request should
# have sampling, we need to change it to something
# else. We use a constant in that case.
# This way we don't need to create and load a bool
# matrix in the sampling kernel, which reduces CPU
# overhead and latency.
seq_seeds
=
[
randint_fn
(
lo
,
hi
)
or
_SEED_0_REPLACEMENT
for
_
in
range
(
seeds_to_generate
)
]
else
:
# For the kernel, seed == 0 means greedy decoding.
seq_seeds
=
[
0
]
*
seeds_to_generate
return
seq_seeds
vllm/triton_utils/sample.py
deleted
100644 → 0
View file @
cca61642
import
math
# This is a hardcoded limit in Triton (max block size).
MAX_TRITON_N_COLS
=
131072
def
get_num_triton_sampler_splits
(
n_cols
:
int
)
->
int
:
"""Get the number of splits to use for Triton sampling.
Triton has a limit on the number of columns it can handle, so we need to
split the tensor and call the kernel multiple times if it's too large.
"""
return
math
.
ceil
(
n_cols
/
MAX_TRITON_N_COLS
)
vllm/utils.py
View file @
546034b4
...
@@ -837,15 +837,6 @@ def async_tensor_h2d(
...
@@ -837,15 +837,6 @@ def async_tensor_h2d(
return
t
.
to
(
device
=
target_device
,
non_blocking
=
True
)
return
t
.
to
(
device
=
target_device
,
non_blocking
=
True
)
def
maybe_expand_dim
(
tensor
:
torch
.
Tensor
,
target_dims
:
int
,
size
:
int
=
1
)
->
torch
.
Tensor
:
"""Expand the tensor to the target_dims."""
if
tensor
.
ndim
<
target_dims
:
tensor
=
tensor
.
view
(
-
1
,
*
([
size
]
*
(
target_dims
-
tensor
.
ndim
)))
return
tensor
def
get_dtype_size
(
dtype
:
torch
.
dtype
)
->
int
:
def
get_dtype_size
(
dtype
:
torch
.
dtype
)
->
int
:
"""Get the size of the data type in bytes."""
"""Get the size of the data type in bytes."""
return
torch
.
tensor
([],
dtype
=
dtype
).
element_size
()
return
torch
.
tensor
([],
dtype
=
dtype
).
element_size
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment