Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f710fb52
Unverified
Commit
f710fb52
authored
Aug 19, 2024
by
Peng Guanwen
Committed by
GitHub
Aug 19, 2024
Browse files
[Core] Use flashinfer sampling kernel when available (#7137)
Co-authored-by:
Michael Goin
<
michael@neuralmagic.com
>
parent
ff7ec82c
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
130 additions
and
28 deletions
+130
-28
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+3
-1
Dockerfile
Dockerfile
+1
-1
tests/samplers/test_sampler.py
tests/samplers/test_sampler.py
+36
-1
vllm/envs.py
vllm/envs.py
+5
-0
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+85
-25
No files found.
.buildkite/test-pipeline.yaml
View file @
f710fb52
...
@@ -192,7 +192,9 @@ steps:
...
@@ -192,7 +192,9 @@ steps:
-
vllm/model_executor/layers
-
vllm/model_executor/layers
-
vllm/sampling_metadata.py
-
vllm/sampling_metadata.py
-
tests/samplers
-
tests/samplers
command
:
pytest -v -s samplers
commands
:
-
pytest -v -s samplers
-
VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers
-
label
:
LogitsProcessor Test
# 5min
-
label
:
LogitsProcessor Test
# 5min
mirror_hardwares
:
[
amd
]
mirror_hardwares
:
[
amd
]
...
...
Dockerfile
View file @
f710fb52
...
@@ -194,7 +194,7 @@ RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamb
...
@@ -194,7 +194,7 @@ RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamb
python3
-m
pip
install
/usr/src/mamba/
*
.whl
--no-cache-dir
python3
-m
pip
install
/usr/src/mamba/
*
.whl
--no-cache-dir
RUN
--mount
=
type
=
cache,target
=
/root/.cache/pip
\
RUN
--mount
=
type
=
cache,target
=
/root/.cache/pip
\
python3
-m
pip
install
https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.
3
/flashinfer-0.1.
3
+cu121torch2.4-cp310-cp310-linux_x86_64.whl
python3
-m
pip
install
https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.
4
/flashinfer-0.1.
4
+cu121torch2.4-cp310-cp310-linux_x86_64.whl
#################### vLLM installation IMAGE ####################
#################### vLLM installation IMAGE ####################
...
...
tests/samplers/test_sampler.py
View file @
f710fb52
...
@@ -8,6 +8,7 @@ import pytest
...
@@ -8,6 +8,7 @@ import pytest
import
torch
import
torch
from
transformers
import
GenerationConfig
,
GenerationMixin
from
transformers
import
GenerationConfig
,
GenerationMixin
import
vllm.envs
as
envs
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
...
@@ -634,7 +635,10 @@ def test_sampler_top_k_top_p(seed: int, device: str):
...
@@ -634,7 +635,10 @@ def test_sampler_top_k_top_p(seed: int, device: str):
return
([[
prob
.
topk
(
1
,
dim
=-
1
).
indices
.
tolist
(),
[
0
]]
return
([[
prob
.
topk
(
1
,
dim
=-
1
).
indices
.
tolist
(),
[
0
]]
for
prob
in
probs
],
None
)
for
prob
in
probs
],
None
)
with
patch
(
"vllm.model_executor.layers.sampler._sample"
,
mock_sample
):
# top-k and top-p is only calculated when flashinfer kernel is not available
with
patch
(
"vllm.model_executor.layers.sampler._sample"
,
mock_sample
),
\
patch
(
"vllm.model_executor.layers.sampler."
"flashinfer_top_k_top_p_sampling"
,
None
):
sampler
(
logits
=
fake_logits
,
sampling_metadata
=
sampling_metadata
)
sampler
(
logits
=
fake_logits
,
sampling_metadata
=
sampling_metadata
)
assert
sample_probs
is
not
None
assert
sample_probs
is
not
None
...
@@ -645,6 +649,37 @@ def test_sampler_top_k_top_p(seed: int, device: str):
...
@@ -645,6 +649,37 @@ def test_sampler_top_k_top_p(seed: int, device: str):
assert
torch
.
equal
(
hf_probs
.
eq
(
0
),
sample_probs
.
eq
(
0
))
assert
torch
.
equal
(
hf_probs
.
eq
(
0
),
sample_probs
.
eq
(
0
))
@
pytest
.
mark
.
parametrize
(
"seed"
,
RANDOM_SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_flashinfer_fallback
(
seed
:
int
,
device
:
str
):
if
not
envs
.
VLLM_USE_FLASHINFER_SAMPLER
:
pytest
.
skip
(
"Flashinfer sampler is disabled"
)
set_random_seed
(
seed
)
torch
.
set_default_device
(
device
)
batch_size
=
random
.
randint
(
1
,
256
)
_
,
fake_logits
,
sampler
=
_prepare_test
(
batch_size
)
def
failing_flashinfer_sampling
(
*
_args
,
**
_kwargs
):
return
None
,
torch
.
zeros
(
batch_size
,
device
=
device
,
dtype
=
torch
.
int32
)
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
n
=
random
.
randint
(
1
,
10
),
seed
=
random
.
randint
(
0
,
10000
),
)
sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
sampling_params
,
device
)
with
patch
(
"vllm.model_executor.layers.sampler."
"flashinfer_top_k_top_p_sampling"
,
failing_flashinfer_sampling
):
fallback_sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
sampling_params
,
device
)
assert
sampler_output
==
fallback_sampler_output
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_sampler_repetition_penalty_mixed
(
device
:
str
):
def
test_sampler_repetition_penalty_mixed
(
device
:
str
):
...
...
vllm/envs.py
View file @
f710fb52
...
@@ -30,6 +30,7 @@ if TYPE_CHECKING:
...
@@ -30,6 +30,7 @@ if TYPE_CHECKING:
VLLM_LOGGING_CONFIG_PATH
:
Optional
[
str
]
=
None
VLLM_LOGGING_CONFIG_PATH
:
Optional
[
str
]
=
None
VLLM_TRACE_FUNCTION
:
int
=
0
VLLM_TRACE_FUNCTION
:
int
=
0
VLLM_ATTENTION_BACKEND
:
Optional
[
str
]
=
None
VLLM_ATTENTION_BACKEND
:
Optional
[
str
]
=
None
VLLM_USE_FLASHINFER_SAMPLER
:
bool
=
False
VLLM_PP_LAYER_PARTITION
:
Optional
[
str
]
=
None
VLLM_PP_LAYER_PARTITION
:
Optional
[
str
]
=
None
VLLM_CPU_KVCACHE_SPACE
:
int
=
0
VLLM_CPU_KVCACHE_SPACE
:
int
=
0
VLLM_CPU_OMP_THREADS_BIND
:
str
=
""
VLLM_CPU_OMP_THREADS_BIND
:
str
=
""
...
@@ -256,6 +257,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -256,6 +257,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_ATTENTION_BACKEND"
:
"VLLM_ATTENTION_BACKEND"
:
lambda
:
os
.
getenv
(
"VLLM_ATTENTION_BACKEND"
,
None
),
lambda
:
os
.
getenv
(
"VLLM_ATTENTION_BACKEND"
,
None
),
# If set, vllm will use flashinfer sampler
"VLLM_USE_FLASHINFER_SAMPLER"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_FLASHINFER_SAMPLER"
,
"0"
))),
# Pipeline stage partition strategy
# Pipeline stage partition strategy
"VLLM_PP_LAYER_PARTITION"
:
"VLLM_PP_LAYER_PARTITION"
:
lambda
:
os
.
getenv
(
"VLLM_PP_LAYER_PARTITION"
,
None
),
lambda
:
os
.
getenv
(
"VLLM_PP_LAYER_PARTITION"
,
None
),
...
...
vllm/model_executor/layers/sampler.py
View file @
f710fb52
"""A layer that samples the next tokens from the model's outputs."""
"""A layer that samples the next tokens from the model's outputs."""
import
itertools
import
itertools
import
warnings
from
importlib.util
import
find_spec
from
math
import
inf
from
math
import
inf
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
...
@@ -11,6 +13,7 @@ from vllm.triton_utils import HAS_TRITON
...
@@ -11,6 +13,7 @@ from vllm.triton_utils import HAS_TRITON
if
HAS_TRITON
:
if
HAS_TRITON
:
from
vllm.model_executor.layers.ops.sample
import
sample
as
sample_triton
from
vllm.model_executor.layers.ops.sample
import
sample
as
sample_triton
import
vllm.envs
as
envs
from
vllm.model_executor.sampling_metadata
import
(
SamplingMetadata
,
from
vllm.model_executor.sampling_metadata
import
(
SamplingMetadata
,
SamplingTensors
,
SamplingTensors
,
SequenceGroupToSample
)
SequenceGroupToSample
)
...
@@ -19,6 +22,16 @@ from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
...
@@ -19,6 +22,16 @@ from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
PromptLogprobs
,
SampleLogprobs
,
SamplerOutput
,
PromptLogprobs
,
SampleLogprobs
,
SamplerOutput
,
SequenceOutput
)
SequenceOutput
)
if
envs
.
VLLM_USE_FLASHINFER_SAMPLER
and
find_spec
(
"flashinfer"
):
import
flashinfer.sampling
# yapf: disable
from
flashinfer.sampling
import
(
top_k_top_p_sampling_from_probs
as
flashinfer_top_k_top_p_sampling
)
# yapf: enable
else
:
flashinfer_top_k_top_p_sampling
=
None
# (num_token_ids, num_parent_ids) per sequence group.
# (num_token_ids, num_parent_ids) per sequence group.
SampleResultType
=
List
[
Tuple
[
List
[
int
],
List
[
int
]]]
SampleResultType
=
List
[
Tuple
[
List
[
int
],
List
[
int
]]]
...
@@ -123,7 +136,7 @@ class Sampler(nn.Module):
...
@@ -123,7 +136,7 @@ class Sampler(nn.Module):
logits
=
logits
.
to
(
torch
.
float
)
logits
=
logits
.
to
(
torch
.
float
)
logits
.
div_
(
sampling_tensors
.
temperatures
.
unsqueeze
(
dim
=
1
))
logits
.
div_
(
sampling_tensors
.
temperatures
.
unsqueeze
(
dim
=
1
))
if
do_top_p_top_k
:
if
do_top_p_top_k
and
flashinfer_top_k_top_p_sampling
is
None
:
logits
=
_apply_top_k_top_p
(
logits
,
sampling_tensors
.
top_ps
,
logits
=
_apply_top_k_top_p
(
logits
,
sampling_tensors
.
top_ps
,
sampling_tensors
.
top_ks
)
sampling_tensors
.
top_ks
)
...
@@ -476,14 +489,7 @@ def _multinomial(
...
@@ -476,14 +489,7 @@ def _multinomial(
seq_groups
:
Optional
[
List
[
SequenceGroupToSample
]]
=
None
,
seq_groups
:
Optional
[
List
[
SequenceGroupToSample
]]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
num_samples
>
1
:
if
num_samples
>
1
:
# This is equivalent to torch.repeat_interleaved (which also
probs
=
probs
.
repeat_interleave
(
num_samples
,
dim
=
0
)
# forces a GPU<->CPU sync).
# This allows us to do sampling with replacement by creating
# num_samples copies of each row in the tensor, and then
# batch sampling the resulting tensor.
probs
=
probs
[:,
None
,
:].
expand
(
probs
.
shape
[
0
],
num_samples
,
probs
.
shape
[
1
]).
contiguous
().
view
(
-
1
,
probs
.
shape
[
1
])
q
=
torch
.
empty_like
(
probs
)
q
=
torch
.
empty_like
(
probs
)
if
seq_groups
is
None
:
if
seq_groups
is
None
:
q
.
exponential_
()
q
.
exponential_
()
...
@@ -491,17 +497,57 @@ def _multinomial(
...
@@ -491,17 +497,57 @@ def _multinomial(
sample_idx
=
0
sample_idx
=
0
for
seq_group
in
seq_groups
:
for
seq_group
in
seq_groups
:
seq_ids
=
seq_group
.
seq_ids
seq_ids
=
seq_group
.
seq_ids
next_sample_idx
=
sample_idx
+
len
(
seq_ids
)
*
num_samples
stride
=
len
(
seq_ids
)
*
num_samples
q
[
sample_idx
:
next_sample_idx
].
exponential_
(
assert
seq_group
.
generator
is
not
None
generator
=
seq_group
.
generator
)
q
[
sample_idx
:
sample_idx
+
sample_idx
=
next_sample_idx
stride
].
exponential_
(
generator
=
seq_group
.
generator
)
sample_idx
+=
stride
return
probs
.
div_
(
q
).
argmax
(
dim
=
1
).
view
(
-
1
,
num_samples
)
return
probs
.
div_
(
q
).
argmax
(
dim
=
1
).
view
(
-
1
,
num_samples
)
def
_top_k_top_p_multinomial_with_flashinfer
(
probs
:
torch
.
Tensor
,
top_ks
:
torch
.
Tensor
,
top_ps
:
torch
.
Tensor
,
num_samples
:
int
,
seq_groups
:
Optional
[
List
[
SequenceGroupToSample
]]):
max_top_k_round
=
32
if
num_samples
>
1
:
probs
=
probs
.
repeat_interleave
(
num_samples
,
dim
=
0
)
top_ks
=
top_ks
.
repeat_interleave
(
num_samples
)
top_ps
=
top_ps
.
repeat_interleave
(
num_samples
)
batch_size
=
probs
.
shape
[
0
]
uniform_samples
=
torch
.
empty
((
max_top_k_round
,
batch_size
),
device
=
probs
.
device
)
if
seq_groups
is
None
:
uniform_samples
.
uniform_
()
else
:
sample_idx
=
0
for
seq_group
in
seq_groups
:
seq_ids
=
seq_group
.
seq_ids
stride
=
len
(
seq_ids
)
*
num_samples
assert
seq_group
.
generator
is
not
None
uniform_samples
[:,
sample_idx
:
sample_idx
+
stride
].
uniform_
(
generator
=
seq_group
.
generator
)
sample_idx
+=
stride
batch_next_token_ids
,
success
=
flashinfer_top_k_top_p_sampling
(
probs
,
uniform_samples
,
top_ks
,
top_ps
,
)
if
not
success
.
all
():
warnings
.
warn
(
"FlashInfer rejection sampling failed, fallback."
,
stacklevel
=
1
)
probs
=
flashinfer
.
sampling
.
top_k_renorm_prob
(
probs
,
top_ks
)
probs
=
flashinfer
.
sampling
.
top_p_renorm_prob
(
probs
,
top_ps
)
batch_next_token_ids
=
flashinfer
.
sampling
.
sampling_from_probs
(
probs
,
uniform_samples
[
0
])
return
batch_next_token_ids
.
view
(
-
1
,
num_samples
)
def
_sample_with_torch
(
def
_sample_with_torch
(
probs
:
torch
.
Tensor
,
probs
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
sampling_tensors
:
SamplingTensors
,
include_gpu_probs_tensor
:
bool
,
include_gpu_probs_tensor
:
bool
,
modify_greedy_probs
:
bool
,
modify_greedy_probs
:
bool
,
)
->
Tuple
[
SampleResultType
,
Optional
[
torch
.
Tensor
]]:
)
->
Tuple
[
SampleResultType
,
Optional
[
torch
.
Tensor
]]:
...
@@ -564,18 +610,28 @@ def _sample_with_torch(
...
@@ -564,18 +610,28 @@ def _sample_with_torch(
sampling_params
=
seq_group
.
sampling_params
sampling_params
=
seq_group
.
sampling_params
max_best_of_in_batch
=
max
(
max_best_of_in_batch
,
max_best_of_in_batch
=
max
(
max_best_of_in_batch
,
sampling_params
.
best_of
)
sampling_params
.
best_of
)
seeded_args
=
{}
if
sampling_type
==
SamplingType
.
RANDOM
else
{
seq_groups_arg
=
(
None
if
sampling_type
==
SamplingType
.
RANDOM
else
"seq_groups"
:
seq_groups
,
seq_groups
)
}
if
flashinfer_top_k_top_p_sampling
is
not
None
:
multinomial_samples
[
sampling_type
]
=
_multinomial
(
multinomial_samples
[
probs
[
long_sample_indices
],
max_best_of_in_batch
,
sampling_type
]
=
_top_k_top_p_multinomial_with_flashinfer
(
**
seeded_args
)
probs
[
long_sample_indices
],
sampling_tensors
.
top_ks
[
long_sample_indices
],
sampling_tensors
.
top_ps
[
long_sample_indices
],
max_best_of_in_batch
,
seq_groups_arg
,
)
else
:
multinomial_samples
[
sampling_type
]
=
_multinomial
(
probs
[
long_sample_indices
],
max_best_of_in_batch
,
seq_groups
=
seq_groups_arg
)
if
sampled_token_ids_tensor
is
not
None
:
if
sampled_token_ids_tensor
is
not
None
:
# Store sampled tokens in output tensor.
# Store sampled tokens in output tensor.
sampled_token_ids_tensor
[
sampled_token_ids_tensor
[
long_sample_indices
]
=
\
long_sample_indices
]
=
multinomial_samples
[
sampling_type
]
multinomial_samples
[
sampling_type
]
.
to
(
torch
.
long
)
elif
sampling_type
==
SamplingType
.
BEAM
:
elif
sampling_type
==
SamplingType
.
BEAM
:
beam_search_logprobs
=
logprobs
[
sample_indices
]
beam_search_logprobs
=
logprobs
[
sample_indices
]
...
@@ -693,9 +749,12 @@ def _sample_with_triton_kernel(
...
@@ -693,9 +749,12 @@ def _sample_with_triton_kernel(
def
_sample
(
def
_sample
(
probs
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
probs
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_tensors
:
SamplingTensors
,
logprobs
:
torch
.
Tensor
,
include_gpu_probs_tensor
:
bool
,
modify_greedy_probs
:
bool
sampling_metadata
:
SamplingMetadata
,
sampling_tensors
:
SamplingTensors
,
include_gpu_probs_tensor
:
bool
,
modify_greedy_probs
:
bool
,
)
->
Tuple
[
SampleResultType
,
Optional
[
torch
.
Tensor
]]:
)
->
Tuple
[
SampleResultType
,
Optional
[
torch
.
Tensor
]]:
"""
"""
Args:
Args:
...
@@ -713,6 +772,7 @@ def _sample(
...
@@ -713,6 +772,7 @@ def _sample(
probs
,
probs
,
logprobs
,
logprobs
,
sampling_metadata
,
sampling_metadata
,
sampling_tensors
,
include_gpu_probs_tensor
=
include_gpu_probs_tensor
,
include_gpu_probs_tensor
=
include_gpu_probs_tensor
,
modify_greedy_probs
=
modify_greedy_probs
,
modify_greedy_probs
=
modify_greedy_probs
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment