Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f710fb52
Unverified
Commit
f710fb52
authored
Aug 19, 2024
by
Peng Guanwen
Committed by
GitHub
Aug 19, 2024
Browse files
[Core] Use flashinfer sampling kernel when available (#7137)
Co-authored-by:
Michael Goin
<
michael@neuralmagic.com
>
parent
ff7ec82c
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
130 additions
and
28 deletions
+130
-28
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+3
-1
Dockerfile
Dockerfile
+1
-1
tests/samplers/test_sampler.py
tests/samplers/test_sampler.py
+36
-1
vllm/envs.py
vllm/envs.py
+5
-0
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+85
-25
No files found.
.buildkite/test-pipeline.yaml
View file @
f710fb52
...
...
@@ -192,7 +192,9 @@ steps:
-
vllm/model_executor/layers
-
vllm/sampling_metadata.py
-
tests/samplers
command
:
pytest -v -s samplers
commands
:
-
pytest -v -s samplers
-
VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers
-
label
:
LogitsProcessor Test
# 5min
mirror_hardwares
:
[
amd
]
...
...
Dockerfile
View file @
f710fb52
...
...
@@ -194,7 +194,7 @@ RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamb
python3
-m
pip
install
/usr/src/mamba/
*
.whl
--no-cache-dir
RUN
--mount
=
type
=
cache,target
=
/root/.cache/pip
\
python3
-m
pip
install
https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.
3
/flashinfer-0.1.
3
+cu121torch2.4-cp310-cp310-linux_x86_64.whl
python3
-m
pip
install
https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.
4
/flashinfer-0.1.
4
+cu121torch2.4-cp310-cp310-linux_x86_64.whl
#################### vLLM installation IMAGE ####################
...
...
tests/samplers/test_sampler.py
View file @
f710fb52
...
...
@@ -8,6 +8,7 @@ import pytest
import
torch
from
transformers
import
GenerationConfig
,
GenerationMixin
import
vllm.envs
as
envs
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_random_seed
...
...
@@ -634,7 +635,10 @@ def test_sampler_top_k_top_p(seed: int, device: str):
return
([[
prob
.
topk
(
1
,
dim
=-
1
).
indices
.
tolist
(),
[
0
]]
for
prob
in
probs
],
None
)
with
patch
(
"vllm.model_executor.layers.sampler._sample"
,
mock_sample
):
# top-k and top-p is only calculated when flashinfer kernel is not available
with
patch
(
"vllm.model_executor.layers.sampler._sample"
,
mock_sample
),
\
patch
(
"vllm.model_executor.layers.sampler."
"flashinfer_top_k_top_p_sampling"
,
None
):
sampler
(
logits
=
fake_logits
,
sampling_metadata
=
sampling_metadata
)
assert
sample_probs
is
not
None
...
...
@@ -645,6 +649,37 @@ def test_sampler_top_k_top_p(seed: int, device: str):
assert
torch
.
equal
(
hf_probs
.
eq
(
0
),
sample_probs
.
eq
(
0
))
@
pytest
.
mark
.
parametrize
(
"seed"
,
RANDOM_SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_flashinfer_fallback
(
seed
:
int
,
device
:
str
):
if
not
envs
.
VLLM_USE_FLASHINFER_SAMPLER
:
pytest
.
skip
(
"Flashinfer sampler is disabled"
)
set_random_seed
(
seed
)
torch
.
set_default_device
(
device
)
batch_size
=
random
.
randint
(
1
,
256
)
_
,
fake_logits
,
sampler
=
_prepare_test
(
batch_size
)
def
failing_flashinfer_sampling
(
*
_args
,
**
_kwargs
):
return
None
,
torch
.
zeros
(
batch_size
,
device
=
device
,
dtype
=
torch
.
int32
)
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
n
=
random
.
randint
(
1
,
10
),
seed
=
random
.
randint
(
0
,
10000
),
)
sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
sampling_params
,
device
)
with
patch
(
"vllm.model_executor.layers.sampler."
"flashinfer_top_k_top_p_sampling"
,
failing_flashinfer_sampling
):
fallback_sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
sampling_params
,
device
)
assert
sampler_output
==
fallback_sampler_output
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_sampler_repetition_penalty_mixed
(
device
:
str
):
...
...
vllm/envs.py
View file @
f710fb52
...
...
@@ -30,6 +30,7 @@ if TYPE_CHECKING:
VLLM_LOGGING_CONFIG_PATH
:
Optional
[
str
]
=
None
VLLM_TRACE_FUNCTION
:
int
=
0
VLLM_ATTENTION_BACKEND
:
Optional
[
str
]
=
None
VLLM_USE_FLASHINFER_SAMPLER
:
bool
=
False
VLLM_PP_LAYER_PARTITION
:
Optional
[
str
]
=
None
VLLM_CPU_KVCACHE_SPACE
:
int
=
0
VLLM_CPU_OMP_THREADS_BIND
:
str
=
""
...
...
@@ -256,6 +257,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_ATTENTION_BACKEND"
:
lambda
:
os
.
getenv
(
"VLLM_ATTENTION_BACKEND"
,
None
),
# If set, vllm will use flashinfer sampler
"VLLM_USE_FLASHINFER_SAMPLER"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_FLASHINFER_SAMPLER"
,
"0"
))),
# Pipeline stage partition strategy
"VLLM_PP_LAYER_PARTITION"
:
lambda
:
os
.
getenv
(
"VLLM_PP_LAYER_PARTITION"
,
None
),
...
...
vllm/model_executor/layers/sampler.py
View file @
f710fb52
"""A layer that samples the next tokens from the model's outputs."""
import
itertools
import
warnings
from
importlib.util
import
find_spec
from
math
import
inf
from
typing
import
Dict
,
List
,
Optional
,
Tuple
...
...
@@ -11,6 +13,7 @@ from vllm.triton_utils import HAS_TRITON
if
HAS_TRITON
:
from
vllm.model_executor.layers.ops.sample
import
sample
as
sample_triton
import
vllm.envs
as
envs
from
vllm.model_executor.sampling_metadata
import
(
SamplingMetadata
,
SamplingTensors
,
SequenceGroupToSample
)
...
...
@@ -19,6 +22,16 @@ from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
PromptLogprobs
,
SampleLogprobs
,
SamplerOutput
,
SequenceOutput
)
if
envs
.
VLLM_USE_FLASHINFER_SAMPLER
and
find_spec
(
"flashinfer"
):
import
flashinfer.sampling
# yapf: disable
from
flashinfer.sampling
import
(
top_k_top_p_sampling_from_probs
as
flashinfer_top_k_top_p_sampling
)
# yapf: enable
else
:
flashinfer_top_k_top_p_sampling
=
None
# (num_token_ids, num_parent_ids) per sequence group.
SampleResultType
=
List
[
Tuple
[
List
[
int
],
List
[
int
]]]
...
...
@@ -123,7 +136,7 @@ class Sampler(nn.Module):
logits
=
logits
.
to
(
torch
.
float
)
logits
.
div_
(
sampling_tensors
.
temperatures
.
unsqueeze
(
dim
=
1
))
if
do_top_p_top_k
:
if
do_top_p_top_k
and
flashinfer_top_k_top_p_sampling
is
None
:
logits
=
_apply_top_k_top_p
(
logits
,
sampling_tensors
.
top_ps
,
sampling_tensors
.
top_ks
)
...
...
@@ -476,14 +489,7 @@ def _multinomial(
seq_groups
:
Optional
[
List
[
SequenceGroupToSample
]]
=
None
,
)
->
torch
.
Tensor
:
if
num_samples
>
1
:
# This is equivalent to torch.repeat_interleaved (which also
# forces a GPU<->CPU sync).
# This allows us to do sampling with replacement by creating
# num_samples copies of each row in the tensor, and then
# batch sampling the resulting tensor.
probs
=
probs
[:,
None
,
:].
expand
(
probs
.
shape
[
0
],
num_samples
,
probs
.
shape
[
1
]).
contiguous
().
view
(
-
1
,
probs
.
shape
[
1
])
probs
=
probs
.
repeat_interleave
(
num_samples
,
dim
=
0
)
q
=
torch
.
empty_like
(
probs
)
if
seq_groups
is
None
:
q
.
exponential_
()
...
...
@@ -491,17 +497,57 @@ def _multinomial(
sample_idx
=
0
for
seq_group
in
seq_groups
:
seq_ids
=
seq_group
.
seq_ids
next_sample_idx
=
sample_idx
+
len
(
seq_ids
)
*
num_samples
q
[
sample_idx
:
next_sample_idx
].
exponential_
(
generator
=
seq_group
.
generator
)
sample_idx
=
next_sample_idx
stride
=
len
(
seq_ids
)
*
num_samples
assert
seq_group
.
generator
is
not
None
q
[
sample_idx
:
sample_idx
+
stride
].
exponential_
(
generator
=
seq_group
.
generator
)
sample_idx
+=
stride
return
probs
.
div_
(
q
).
argmax
(
dim
=
1
).
view
(
-
1
,
num_samples
)
def
_top_k_top_p_multinomial_with_flashinfer
(
probs
:
torch
.
Tensor
,
top_ks
:
torch
.
Tensor
,
top_ps
:
torch
.
Tensor
,
num_samples
:
int
,
seq_groups
:
Optional
[
List
[
SequenceGroupToSample
]]):
max_top_k_round
=
32
if
num_samples
>
1
:
probs
=
probs
.
repeat_interleave
(
num_samples
,
dim
=
0
)
top_ks
=
top_ks
.
repeat_interleave
(
num_samples
)
top_ps
=
top_ps
.
repeat_interleave
(
num_samples
)
batch_size
=
probs
.
shape
[
0
]
uniform_samples
=
torch
.
empty
((
max_top_k_round
,
batch_size
),
device
=
probs
.
device
)
if
seq_groups
is
None
:
uniform_samples
.
uniform_
()
else
:
sample_idx
=
0
for
seq_group
in
seq_groups
:
seq_ids
=
seq_group
.
seq_ids
stride
=
len
(
seq_ids
)
*
num_samples
assert
seq_group
.
generator
is
not
None
uniform_samples
[:,
sample_idx
:
sample_idx
+
stride
].
uniform_
(
generator
=
seq_group
.
generator
)
sample_idx
+=
stride
batch_next_token_ids
,
success
=
flashinfer_top_k_top_p_sampling
(
probs
,
uniform_samples
,
top_ks
,
top_ps
,
)
if
not
success
.
all
():
warnings
.
warn
(
"FlashInfer rejection sampling failed, fallback."
,
stacklevel
=
1
)
probs
=
flashinfer
.
sampling
.
top_k_renorm_prob
(
probs
,
top_ks
)
probs
=
flashinfer
.
sampling
.
top_p_renorm_prob
(
probs
,
top_ps
)
batch_next_token_ids
=
flashinfer
.
sampling
.
sampling_from_probs
(
probs
,
uniform_samples
[
0
])
return
batch_next_token_ids
.
view
(
-
1
,
num_samples
)
def
_sample_with_torch
(
probs
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_tensors
:
SamplingTensors
,
include_gpu_probs_tensor
:
bool
,
modify_greedy_probs
:
bool
,
)
->
Tuple
[
SampleResultType
,
Optional
[
torch
.
Tensor
]]:
...
...
@@ -564,18 +610,28 @@ def _sample_with_torch(
sampling_params
=
seq_group
.
sampling_params
max_best_of_in_batch
=
max
(
max_best_of_in_batch
,
sampling_params
.
best_of
)
seeded_args
=
{}
if
sampling_type
==
SamplingType
.
RANDOM
else
{
"seq_groups"
:
seq_groups
,
}
seq_groups_arg
=
(
None
if
sampling_type
==
SamplingType
.
RANDOM
else
seq_groups
)
if
flashinfer_top_k_top_p_sampling
is
not
None
:
multinomial_samples
[
sampling_type
]
=
_top_k_top_p_multinomial_with_flashinfer
(
probs
[
long_sample_indices
],
sampling_tensors
.
top_ks
[
long_sample_indices
],
sampling_tensors
.
top_ps
[
long_sample_indices
],
max_best_of_in_batch
,
seq_groups_arg
,
)
else
:
multinomial_samples
[
sampling_type
]
=
_multinomial
(
probs
[
long_sample_indices
],
max_best_of_in_batch
,
**
seeded_args
)
probs
[
long_sample_indices
],
max_best_of_in_batch
,
seq_groups
=
seq_groups_arg
)
if
sampled_token_ids_tensor
is
not
None
:
# Store sampled tokens in output tensor.
sampled_token_ids_tensor
[
long_sample_indices
]
=
multinomial_samples
[
sampling_type
]
sampled_token_ids_tensor
[
long_sample_indices
]
=
\
multinomial_samples
[
sampling_type
]
.
to
(
torch
.
long
)
elif
sampling_type
==
SamplingType
.
BEAM
:
beam_search_logprobs
=
logprobs
[
sample_indices
]
...
...
@@ -693,9 +749,12 @@ def _sample_with_triton_kernel(
def
_sample
(
probs
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_tensors
:
SamplingTensors
,
include_gpu_probs_tensor
:
bool
,
modify_greedy_probs
:
bool
probs
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_tensors
:
SamplingTensors
,
include_gpu_probs_tensor
:
bool
,
modify_greedy_probs
:
bool
,
)
->
Tuple
[
SampleResultType
,
Optional
[
torch
.
Tensor
]]:
"""
Args:
...
...
@@ -713,6 +772,7 @@ def _sample(
probs
,
logprobs
,
sampling_metadata
,
sampling_tensors
,
include_gpu_probs_tensor
=
include_gpu_probs_tensor
,
modify_greedy_probs
=
modify_greedy_probs
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment