Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e6a26ed0
Unverified
Commit
e6a26ed0
authored
Sep 01, 2024
by
Lily Liu
Committed by
GitHub
Sep 01, 2024
Browse files
[SpecDecode][Kernel] Flashinfer Rejection Sampling (#7244)
parent
f8d60145
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
306 additions
and
109 deletions
+306
-109
Dockerfile
Dockerfile
+1
-1
tests/samplers/test_rejection_sampler.py
tests/samplers/test_rejection_sampler.py
+97
-19
tests/samplers/test_typical_acceptance_sampler.py
tests/samplers/test_typical_acceptance_sampler.py
+32
-18
tests/spec_decode/test_spec_decode_worker.py
tests/spec_decode/test_spec_decode_worker.py
+2
-3
vllm/envs.py
vllm/envs.py
+1
-0
vllm/model_executor/layers/rejection_sampler.py
vllm/model_executor/layers/rejection_sampler.py
+141
-43
vllm/model_executor/layers/spec_decode_base_sampler.py
vllm/model_executor/layers/spec_decode_base_sampler.py
+25
-18
vllm/model_executor/layers/typical_acceptance_sampler.py
vllm/model_executor/layers/typical_acceptance_sampler.py
+4
-3
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+3
-4
No files found.
Dockerfile
View file @
e6a26ed0
...
@@ -162,7 +162,7 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
...
@@ -162,7 +162,7 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
RUN
--mount
=
type
=
cache,target
=
/root/.cache/pip
\
RUN
--mount
=
type
=
cache,target
=
/root/.cache/pip
\
.
/etc/environment
&&
\
.
/etc/environment
&&
\
python3
-m
pip
install
https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.
4
/flashinfer-0.1.
4
+cu121torch2.4-cp
${
PYTHON_VERSION_STR
}
-cp
${
PYTHON_VERSION_STR
}
-linux_x86_64
.whl
python3
-m
pip
install
https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.
6
/flashinfer-0.1.
6
+cu121torch2.4-cp
${
PYTHON_VERSION_STR
}
-cp
${
PYTHON_VERSION_STR
}
-linux_x86_64
.whl
#################### vLLM installation IMAGE ####################
#################### vLLM installation IMAGE ####################
...
...
tests/samplers/test_rejection_sampler.py
View file @
e6a26ed0
...
@@ -44,12 +44,16 @@ def mock_causal_accepted_tensor(
...
@@ -44,12 +44,16 @@ def mock_causal_accepted_tensor(
[
"all_tokens_accepted"
,
"no_tokens_accepted"
,
"some_tokens_accepted"
])
[
"all_tokens_accepted"
,
"no_tokens_accepted"
,
"some_tokens_accepted"
])
@
pytest
.
mark
.
parametrize
(
"disable_bonus_tokens"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"disable_bonus_tokens"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"use_flashinfer"
,
[
True
,
False
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_correct_output_format
(
which_tokens_accepted
:
str
,
def
test_correct_output_format
(
which_tokens_accepted
:
str
,
seed
:
int
,
disable_bonus_tokens
:
bool
,
seed
:
int
,
disable_bonus_tokens
:
bool
,
device
:
str
,
device
:
str
):
use_flashinfer
:
bool
):
"""Verify the output has correct format given predetermined accepted matrix.
"""Verify the output has correct format given predetermined accepted matrix.
"""
"""
if
use_flashinfer
and
disable_bonus_tokens
:
pytest
.
skip
(
"Flashinfer rejection sampler must enable bonus token."
)
set_random_seed
(
seed
)
set_random_seed
(
seed
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
...
@@ -85,7 +89,8 @@ def test_correct_output_format(which_tokens_accepted: str,
...
@@ -85,7 +89,8 @@ def test_correct_output_format(which_tokens_accepted: str,
dtype
=
torch
.
int64
)
dtype
=
torch
.
int64
)
rejection_sampler
=
RejectionSampler
(
rejection_sampler
=
RejectionSampler
(
disable_bonus_tokens
=
disable_bonus_tokens
)
disable_bonus_tokens
=
disable_bonus_tokens
,
use_flashinfer
=
use_flashinfer
)
rejection_sampler
.
init_gpu_tensors
(
device
=
device
)
rejection_sampler
.
init_gpu_tensors
(
device
=
device
)
output_token_ids
=
rejection_sampler
.
_create_output
(
# pylint: disable=protected-access
output_token_ids
=
rejection_sampler
.
_create_output
(
# pylint: disable=protected-access
accepted
,
accepted
,
...
@@ -133,15 +138,20 @@ def test_correct_output_format(which_tokens_accepted: str,
...
@@ -133,15 +138,20 @@ def test_correct_output_format(which_tokens_accepted: str,
@
pytest
.
mark
.
parametrize
(
"vocab_size"
,
[
30_000
,
50_000
])
@
pytest
.
mark
.
parametrize
(
"vocab_size"
,
[
30_000
,
50_000
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
list
(
range
(
1
,
32
)))
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
list
(
range
(
1
,
32
)))
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"use_flashinfer"
,
[
True
,
False
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_no_crash_with_varying_dims
(
k
:
int
,
vocab_size
:
int
,
batch_size
:
int
,
def
test_no_crash_with_varying_dims
(
k
:
int
,
vocab_size
:
int
,
batch_size
:
int
,
device
:
str
):
device
:
str
,
use_flashinfer
:
bool
):
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
rejection_sampler
=
RejectionSampler
()
rejection_sampler
=
RejectionSampler
(
disable_bonus_tokens
=
False
,
use_flashinfer
=
use_flashinfer
)
rejection_sampler
.
init_gpu_tensors
(
device
=
device
)
rejection_sampler
.
init_gpu_tensors
(
device
=
device
)
draft_probs
=
torch
.
rand
(
batch_size
,
k
,
vocab_size
,
dtype
=
torch
.
float32
)
draft_probs
=
torch
.
rand
(
batch_size
,
k
,
vocab_size
,
dtype
=
torch
.
float32
)
target_probs
=
torch
.
rand
(
batch_size
,
k
,
vocab_size
,
dtype
=
torch
.
float32
)
target_probs
=
torch
.
rand
(
batch_size
,
k
+
1
,
vocab_size
,
dtype
=
torch
.
float32
)
bonus_token_ids
=
torch
.
randint
(
low
=
0
,
bonus_token_ids
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
high
=
vocab_size
,
size
=
(
batch_size
,
1
),
size
=
(
batch_size
,
1
),
...
@@ -161,16 +171,21 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
...
@@ -161,16 +171,21 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
8
,
32
,
128
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
8
,
32
,
128
])
@
pytest
.
mark
.
parametrize
(
"n_rep"
,
[
100
])
@
pytest
.
mark
.
parametrize
(
"n_rep"
,
[
100
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"use_flashinfer"
,
[
True
,
False
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_deterministic_when_seeded
(
k
:
int
,
vocab_size
:
int
,
batch_size
:
int
,
def
test_deterministic_when_seeded
(
k
:
int
,
vocab_size
:
int
,
batch_size
:
int
,
frac_seeded
:
float
,
n_rep
:
int
,
frac_seeded
:
float
,
n_rep
:
int
,
device
:
str
,
device
:
str
):
use_flashinfer
:
bool
):
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
rejection_sampler
=
RejectionSampler
()
rejection_sampler
=
RejectionSampler
(
disable_bonus_tokens
=
False
,
use_flashinfer
=
use_flashinfer
)
rejection_sampler
.
init_gpu_tensors
(
device
=
device
)
rejection_sampler
.
init_gpu_tensors
(
device
=
device
)
draft_probs
=
torch
.
rand
(
batch_size
,
k
,
vocab_size
,
dtype
=
torch
.
float32
)
draft_probs
=
torch
.
rand
(
batch_size
,
k
,
vocab_size
,
dtype
=
torch
.
float32
)
target_probs
=
torch
.
rand
(
batch_size
,
k
,
vocab_size
,
dtype
=
torch
.
float32
)
target_probs
=
torch
.
rand
(
batch_size
,
k
+
1
,
vocab_size
,
dtype
=
torch
.
float32
)
bonus_token_ids
=
torch
.
randint
(
low
=
0
,
bonus_token_ids
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
high
=
vocab_size
,
size
=
(
batch_size
,
1
),
size
=
(
batch_size
,
1
),
...
@@ -198,23 +213,85 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
...
@@ -198,23 +213,85 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
assert
torch
.
equal
(
results
[
j
][
i
],
results
[
0
][
i
])
assert
torch
.
equal
(
results
[
j
][
i
],
results
[
0
][
i
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
1
,
3
,
6
])
@
pytest
.
mark
.
parametrize
(
"vocab_size"
,
[
30_000
,
50_000
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
8
,
32
,
128
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
def
test_compare_nonflashinfer_backend
(
k
:
int
,
vocab_size
:
int
,
batch_size
:
int
,
device
:
str
):
"""
Test the flashinfer and nonflashinfer backend generate
the same output metrics.
"""
torch
.
set_default_device
(
device
)
torch
.
manual_seed
(
0
)
draft_probs
=
torch
.
rand
(
batch_size
,
k
,
vocab_size
,
dtype
=
torch
.
float32
)
target_probs
=
torch
.
rand
(
batch_size
,
k
+
1
,
vocab_size
,
dtype
=
torch
.
float32
)
bonus_token_ids
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
size
=
(
batch_size
,
1
),
dtype
=
torch
.
int64
)
draft_token_ids
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
size
=
(
batch_size
,
k
),
dtype
=
torch
.
int64
)
num_accepted_tokens
=
[]
num_emitted_tokens
=
[]
num_draft_tokens
=
[]
def
get_seeded_seqs
():
return
{
i
:
torch
.
Generator
(
device
=
device
).
manual_seed
(
i
)
for
i
in
range
(
batch_size
)
}
for
use_flashinfer
in
[
True
,
False
]:
rejection_sampler
=
RejectionSampler
(
disable_bonus_tokens
=
False
,
use_flashinfer
=
use_flashinfer
)
rejection_sampler
.
init_gpu_tensors
(
device
=
device
)
# We use seeded sequences to ensure the same tokens are accepted
# for both flashinfer and nonflashinfer backends.
seeded_seqs
=
get_seeded_seqs
()
rejection_sampler
(
target_probs
,
bonus_token_ids
,
draft_probs
,
draft_token_ids
,
seeded_seqs
)
num_accepted_tokens
.
append
(
rejection_sampler
.
num_accepted_tokens
)
num_emitted_tokens
.
append
(
rejection_sampler
.
num_emitted_tokens
)
num_draft_tokens
.
append
(
rejection_sampler
.
num_draft_tokens
)
assert
num_accepted_tokens
[
0
]
==
num_accepted_tokens
[
1
]
assert
num_emitted_tokens
[
0
]
==
num_emitted_tokens
[
1
]
assert
num_draft_tokens
[
0
]
==
num_draft_tokens
[
1
]
@
pytest
.
mark
.
parametrize
(
"above_or_below_vocab_range"
,
[
"above"
,
"below"
])
@
pytest
.
mark
.
parametrize
(
"above_or_below_vocab_range"
,
[
"above"
,
"below"
])
@
pytest
.
mark
.
parametrize
(
"which_token_ids"
,
@
pytest
.
mark
.
parametrize
(
"which_token_ids"
,
[
"bonus_token_ids"
,
"draft_token_ids"
])
[
"bonus_token_ids"
,
"draft_token_ids"
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"use_flashinfer"
,
[
True
,
False
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_raises_when_vocab_oob
(
above_or_below_vocab_range
:
str
,
def
test_raises_when_vocab_oob
(
above_or_below_vocab_range
:
str
,
which_token_ids
:
str
,
device
:
str
):
which_token_ids
:
str
,
device
:
str
,
use_flashinfer
:
bool
):
k
=
3
k
=
3
batch_size
=
5
batch_size
=
5
vocab_size
=
30_000
vocab_size
=
30_000
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
rejection_sampler
=
RejectionSampler
(
strict_mode
=
True
)
rejection_sampler
=
RejectionSampler
(
disable_bonus_tokens
=
False
,
use_flashinfer
=
use_flashinfer
,
strict_mode
=
True
)
rejection_sampler
.
init_gpu_tensors
(
device
=
device
)
rejection_sampler
.
init_gpu_tensors
(
device
=
device
)
draft_probs
=
torch
.
rand
(
batch_size
,
k
,
vocab_size
,
dtype
=
torch
.
float32
)
draft_probs
=
torch
.
rand
(
batch_size
,
k
,
vocab_size
,
dtype
=
torch
.
float32
)
target_probs
=
torch
.
rand
(
batch_size
,
k
,
vocab_size
,
dtype
=
torch
.
float32
)
target_probs
=
torch
.
rand
(
batch_size
,
k
+
1
,
vocab_size
,
dtype
=
torch
.
float32
)
bonus_token_ids
=
torch
.
randint
(
low
=
0
,
bonus_token_ids
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
high
=
vocab_size
,
size
=
(
batch_size
,
1
),
size
=
(
batch_size
,
1
),
...
@@ -248,9 +325,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
...
@@ -248,9 +325,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
@
pytest
.
mark
.
parametrize
(
"draft_and_target_probs_equal"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"draft_and_target_probs_equal"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
list
(
range
(
5
)))
@
pytest
.
mark
.
parametrize
(
"seed"
,
list
(
range
(
5
)))
@
pytest
.
mark
.
parametrize
(
"use_flashinfer"
,
[
True
,
False
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_rejection_sampling_approximates_target_distribution
(
def
test_rejection_sampling_approximates_target_distribution
(
seed
:
int
,
draft_and_target_probs_equal
:
bool
):
seed
:
int
,
draft_and_target_probs_equal
:
bool
,
use_flashinfer
:
bool
):
"""Verify rejection sampling approximates target distribution,
"""Verify rejection sampling approximates target distribution,
despite sampling from a potentially distinct draft distribution.
despite sampling from a potentially distinct draft distribution.
...
@@ -279,10 +357,10 @@ def test_rejection_sampling_approximates_target_distribution(
...
@@ -279,10 +357,10 @@ def test_rejection_sampling_approximates_target_distribution(
"""
"""
torch
.
set_default_device
(
"cpu"
)
torch
.
set_default_device
(
"cpu"
)
set_random_seed
(
seed
)
set_random_seed
(
seed
)
helper
=
_CorrectnessTestHelper
(
helper
=
_CorrectnessTestHelper
(
vocab_size
=
10
,
vocab_size
=
10
,
rejection_sampler
=
RejectionSampler
(),
rejection_sampler
=
RejectionSampler
(
disable_bonus_tokens
=
False
,
use_flashinfer
=
use_flashinfer
),
)
)
draft_probs
,
target_probs
,
reference_probs
=
helper
.
generate_probs_for_test
(
draft_probs
,
target_probs
,
reference_probs
=
helper
.
generate_probs_for_test
(
...
@@ -398,10 +476,10 @@ class _CorrectnessTestHelper:
...
@@ -398,10 +476,10 @@ class _CorrectnessTestHelper:
draft_probs
=
draft_probs
.
reshape
(
1
,
self
.
k
,
self
.
vocab_size
).
repeat
(
draft_probs
=
draft_probs
.
reshape
(
1
,
self
.
k
,
self
.
vocab_size
).
repeat
(
num_samples
,
1
,
1
)
num_samples
,
1
,
1
)
# Repeat target probs num_samples *
k
times.
# Repeat target probs num_samples *
(k + 1)
times.
# Rejection sampler requires bonus token probs, but they aren't used.
# Rejection sampler requires bonus token probs, but they aren't used.
target_probs
=
target_probs
.
reshape
(
1
,
1
,
self
.
vocab_size
).
repeat
(
target_probs
=
target_probs
.
reshape
(
1
,
1
,
self
.
vocab_size
).
repeat
(
num_samples
,
self
.
k
,
1
)
num_samples
,
self
.
k
+
1
,
1
)
# Randomly sample draft token ids from draft probs.
# Randomly sample draft token ids from draft probs.
draft_token_ids
=
torch
.
multinomial
(
draft_probs
[:,
0
,
:],
draft_token_ids
=
torch
.
multinomial
(
draft_probs
[:,
0
,
:],
...
...
tests/samplers/test_typical_acceptance_sampler.py
View file @
e6a26ed0
...
@@ -79,7 +79,10 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
...
@@ -79,7 +79,10 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
typical_acceptance_sampler
=
get_acceptance_sampler
()
typical_acceptance_sampler
=
get_acceptance_sampler
()
typical_acceptance_sampler
.
init_gpu_tensors
(
device
=
device
)
typical_acceptance_sampler
.
init_gpu_tensors
(
device
=
device
)
target_probs
=
torch
.
rand
(
batch_size
,
k
,
vocab_size
,
dtype
=
torch
.
float32
)
target_with_bonus_probs
=
torch
.
rand
(
batch_size
,
k
+
1
,
vocab_size
,
dtype
=
torch
.
float32
)
bonus_token_ids
=
torch
.
randint
(
low
=
0
,
bonus_token_ids
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
high
=
vocab_size
,
size
=
(
batch_size
,
1
),
size
=
(
batch_size
,
1
),
...
@@ -89,7 +92,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
...
@@ -89,7 +92,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
size
=
(
batch_size
,
k
),
size
=
(
batch_size
,
k
),
dtype
=
torch
.
int64
)
dtype
=
torch
.
int64
)
# Verify that sampling succeeds for all cases.
# Verify that sampling succeeds for all cases.
typical_acceptance_sampler
(
target_probs
,
typical_acceptance_sampler
(
target_
with_bonus_
probs
,
bonus_token_ids
,
bonus_token_ids
,
draft_probs
=
None
,
draft_probs
=
None
,
draft_token_ids
=
draft_token_ids
)
draft_token_ids
=
draft_token_ids
)
...
@@ -112,7 +115,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
...
@@ -112,7 +115,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
typical_acceptance_sampler
=
get_acceptance_sampler
(
strict_mode
=
True
)
typical_acceptance_sampler
=
get_acceptance_sampler
(
strict_mode
=
True
)
typical_acceptance_sampler
.
init_gpu_tensors
(
device
=
device
)
typical_acceptance_sampler
.
init_gpu_tensors
(
device
=
device
)
target_probs
=
torch
.
rand
(
batch_size
,
k
,
vocab_size
,
dtype
=
torch
.
float32
)
target_with_bonus_probs
=
torch
.
rand
(
batch_size
,
k
+
1
,
vocab_size
,
dtype
=
torch
.
float32
)
bonus_token_ids
=
torch
.
randint
(
low
=
0
,
bonus_token_ids
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
high
=
vocab_size
,
size
=
(
batch_size
,
1
),
size
=
(
batch_size
,
1
),
...
@@ -141,7 +147,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
...
@@ -141,7 +147,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
oob_token_ids
[
0
][
0
]
=
rogue_token_id
oob_token_ids
[
0
][
0
]
=
rogue_token_id
with
pytest
.
raises
(
AssertionError
):
with
pytest
.
raises
(
AssertionError
):
typical_acceptance_sampler
(
target_probs
,
typical_acceptance_sampler
(
target_
with_bonus_
probs
,
bonus_token_ids
,
bonus_token_ids
,
draft_probs
=
None
,
draft_probs
=
None
,
draft_token_ids
=
draft_token_ids
)
draft_token_ids
=
draft_token_ids
)
...
@@ -172,7 +178,10 @@ def test_uniform_target_distribution_accepts_all_tokens(
...
@@ -172,7 +178,10 @@ def test_uniform_target_distribution_accepts_all_tokens(
typical_acceptance_sampler
=
get_acceptance_sampler
(
typical_acceptance_sampler
=
get_acceptance_sampler
(
strict_mode
=
True
,
disable_bonus_tokens
=
disable_bonus_tokens
)
strict_mode
=
True
,
disable_bonus_tokens
=
disable_bonus_tokens
)
typical_acceptance_sampler
.
init_gpu_tensors
(
device
=
device
)
typical_acceptance_sampler
.
init_gpu_tensors
(
device
=
device
)
target_probs
=
torch
.
rand
(
batch_size
,
k
,
vocab_size
,
dtype
=
torch
.
float32
)
target_with_bonus_probs
=
torch
.
rand
(
batch_size
,
k
+
1
,
vocab_size
,
dtype
=
torch
.
float32
)
draft_token_ids
=
torch
.
randint
(
low
=
0
,
draft_token_ids
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
high
=
vocab_size
,
size
=
(
batch_size
,
k
),
size
=
(
batch_size
,
k
),
...
@@ -182,7 +191,7 @@ def test_uniform_target_distribution_accepts_all_tokens(
...
@@ -182,7 +191,7 @@ def test_uniform_target_distribution_accepts_all_tokens(
size
=
(
batch_size
,
1
),
size
=
(
batch_size
,
1
),
dtype
=
torch
.
int64
)
dtype
=
torch
.
int64
)
output_token_ids
=
typical_acceptance_sampler
(
output_token_ids
=
typical_acceptance_sampler
(
target_probs
,
target_
with_bonus_
probs
,
bonus_token_ids
,
bonus_token_ids
,
draft_probs
=
None
,
draft_probs
=
None
,
draft_token_ids
=
draft_token_ids
)
draft_token_ids
=
draft_token_ids
)
...
@@ -229,8 +238,9 @@ def test_temperature_zero_target_distribution(seed: int,
...
@@ -229,8 +238,9 @@ def test_temperature_zero_target_distribution(seed: int,
# Simulate temperature 0 probability distribution for target probabilities
# Simulate temperature 0 probability distribution for target probabilities
# and create target probabilities such that only 1 token id has
# and create target probabilities such that only 1 token id has
# probability 1.0
# probability 1.0
target_probs
,
zero_temperature_token_ids
=
get_zero_temperature_prob_dist
(
target_with_bonus_probs
,
zero_temperature_token_ids
=
\
batch_size
,
k
,
vocab_size
)
get_zero_temperature_prob_dist
(
batch_size
,
k
+
1
,
vocab_size
)
zero_temperature_token_ids
=
zero_temperature_token_ids
[:,
:
-
1
]
# Populate draft_token_ids such that they exclude the token_ids
# Populate draft_token_ids such that they exclude the token_ids
# with probability = 1.0
# with probability = 1.0
draft_token_ids
=
get_draft_token_ids
(
batch_size
,
k
,
vocab_size
,
draft_token_ids
=
get_draft_token_ids
(
batch_size
,
k
,
vocab_size
,
...
@@ -245,7 +255,7 @@ def test_temperature_zero_target_distribution(seed: int,
...
@@ -245,7 +255,7 @@ def test_temperature_zero_target_distribution(seed: int,
# fallback to the greedy sampling for selecting 1 token for each sequence.
# fallback to the greedy sampling for selecting 1 token for each sequence.
# Verify the same.
# Verify the same.
output_token_ids
=
typical_acceptance_sampler
(
output_token_ids
=
typical_acceptance_sampler
(
target_probs
,
target_
with_bonus_
probs
,
bonus_token_ids
,
bonus_token_ids
,
draft_probs
=
None
,
draft_probs
=
None
,
draft_token_ids
=
draft_token_ids
)
draft_token_ids
=
draft_token_ids
)
...
@@ -289,8 +299,10 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
...
@@ -289,8 +299,10 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
# For sequences 0 and 2 set the distribution to a temperature
# For sequences 0 and 2 set the distribution to a temperature
# zero distribution. For sequences 1 and 3 set it to a uniform
# zero distribution. For sequences 1 and 3 set it to a uniform
# distribution.
# distribution.
target_probs
,
zero_temperature_token_ids
=
(
get_zero_temperature_prob_dist
(
target_with_bonus_probs
,
zero_temperature_token_ids
=
\
batch_size
,
k
,
vocab_size
))
get_zero_temperature_prob_dist
(
batch_size
,
k
+
1
,
vocab_size
)
zero_temperature_token_ids
=
zero_temperature_token_ids
[:,
:
-
1
]
target_probs
=
target_with_bonus_probs
[:,
:
-
1
]
draft_token_ids
=
get_draft_token_ids
(
batch_size
,
k
,
vocab_size
,
draft_token_ids
=
get_draft_token_ids
(
batch_size
,
k
,
vocab_size
,
zero_temperature_token_ids
)
zero_temperature_token_ids
)
uniform_probs
=
torch
.
rand
(
2
,
k
,
vocab_size
,
dtype
=
torch
.
float32
)
uniform_probs
=
torch
.
rand
(
2
,
k
,
vocab_size
,
dtype
=
torch
.
float32
)
...
@@ -300,7 +312,7 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
...
@@ -300,7 +312,7 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
size
=
(
batch_size
,
1
),
size
=
(
batch_size
,
1
),
dtype
=
torch
.
int64
)
dtype
=
torch
.
int64
)
output_token_ids
=
typical_acceptance_sampler
(
output_token_ids
=
typical_acceptance_sampler
(
target_probs
,
target_
with_bonus_
probs
,
bonus_token_ids
,
bonus_token_ids
,
draft_probs
=
None
,
draft_probs
=
None
,
draft_token_ids
=
draft_token_ids
)
draft_token_ids
=
draft_token_ids
)
...
@@ -356,15 +368,16 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
...
@@ -356,15 +368,16 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
# Create a temperature zero target probability distribution and ensure
# Create a temperature zero target probability distribution and ensure
# all draft token ids correspond to the tokens with 1.0 probability.
# all draft token ids correspond to the tokens with 1.0 probability.
# Verify that all of them are accepted.
# Verify that all of them are accepted.
target_probs
,
zero_temperature_token_ids
=
(
get_zero_temperature_prob_dist
(
target_with_bonus_probs
,
zero_temperature_token_ids
=
\
batch_size
,
k
,
vocab_size
))
get_zero_temperature_prob_dist
(
batch_size
,
k
+
1
,
vocab_size
)
zero_temperature_token_ids
=
zero_temperature_token_ids
[:,
:
-
1
]
draft_token_ids
=
zero_temperature_token_ids
draft_token_ids
=
zero_temperature_token_ids
bonus_token_ids
=
torch
.
randint
(
low
=
0
,
bonus_token_ids
=
torch
.
randint
(
low
=
0
,
high
=
vocab_size
,
high
=
vocab_size
,
size
=
(
batch_size
,
1
),
size
=
(
batch_size
,
1
),
dtype
=
torch
.
int64
)
dtype
=
torch
.
int64
)
output_token_ids
=
typical_acceptance_sampler
(
output_token_ids
=
typical_acceptance_sampler
(
target_probs
,
target_
with_bonus_
probs
,
bonus_token_ids
,
bonus_token_ids
,
draft_probs
=
None
,
draft_probs
=
None
,
draft_token_ids
=
draft_token_ids
)
draft_token_ids
=
draft_token_ids
)
...
@@ -384,7 +397,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
...
@@ -384,7 +397,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
draft_token_ids
=
torch
.
cat
(
draft_token_ids
=
torch
.
cat
(
(
draft_token_ids
[:,
:
2
],
draft_token_ids_to_replace
[:,
-
3
:]),
dim
=
1
)
(
draft_token_ids
[:,
:
2
],
draft_token_ids_to_replace
[:,
-
3
:]),
dim
=
1
)
output_token_ids
=
typical_acceptance_sampler
(
output_token_ids
=
typical_acceptance_sampler
(
target_probs
,
target_
with_bonus_
probs
,
bonus_token_ids
,
bonus_token_ids
,
draft_probs
=
None
,
draft_probs
=
None
,
draft_token_ids
=
draft_token_ids
)
draft_token_ids
=
draft_token_ids
)
...
@@ -421,8 +434,9 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
...
@@ -421,8 +434,9 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
# 0.00001. Populate draft_token_ids such that they exclude the token_ids
# 0.00001. Populate draft_token_ids such that they exclude the token_ids
# with probability = 1.0. Without any changes to the posterior thresholds
# with probability = 1.0. Without any changes to the posterior thresholds
# none of the draft tokens are accepted.
# none of the draft tokens are accepted.
target_probs
,
zero_temperature_token_ids
=
(
get_zero_temperature_prob_dist
(
target_probs
,
zero_temperature_token_ids
=
get_zero_temperature_prob_dist
(
batch_size
,
k
,
vocab_size
))
batch_size
,
k
+
1
,
vocab_size
)
zero_temperature_token_ids
=
zero_temperature_token_ids
[:,
:
-
1
]
target_probs
[
target_probs
==
0
]
=
0.00001
target_probs
[
target_probs
==
0
]
=
0.00001
draft_token_ids
=
get_draft_token_ids
(
batch_size
,
k
,
vocab_size
,
draft_token_ids
=
get_draft_token_ids
(
batch_size
,
k
,
vocab_size
,
zero_temperature_token_ids
)
zero_temperature_token_ids
)
...
...
tests/spec_decode/test_spec_decode_worker.py
View file @
e6a26ed0
...
@@ -230,9 +230,8 @@ def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,
...
@@ -230,9 +230,8 @@ def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,
assert
torch
.
equal
(
actual
.
bonus_token_ids
,
assert
torch
.
equal
(
actual
.
bonus_token_ids
,
target_token_ids
.
reshape
(
batch_size
,
k
+
1
)[:,
-
1
:])
target_token_ids
.
reshape
(
batch_size
,
k
+
1
)[:,
-
1
:])
assert
torch
.
equal
(
assert
torch
.
equal
(
actual
.
target_with_bonus_probs
,
actual
.
target_probs
,
target_token_probs
.
reshape
(
batch_size
,
k
+
1
,
-
1
))
target_token_probs
.
reshape
(
batch_size
,
k
+
1
,
-
1
)[:,
:
-
1
])
assert
torch
.
equal
(
actual
.
draft_token_ids
,
proposal_token_ids
)
assert
torch
.
equal
(
actual
.
draft_token_ids
,
proposal_token_ids
)
assert
torch
.
equal
(
actual
.
draft_probs
,
proposal_probs
)
assert
torch
.
equal
(
actual
.
draft_probs
,
proposal_probs
)
...
...
vllm/envs.py
View file @
e6a26ed0
...
@@ -31,6 +31,7 @@ if TYPE_CHECKING:
...
@@ -31,6 +31,7 @@ if TYPE_CHECKING:
VLLM_TRACE_FUNCTION
:
int
=
0
VLLM_TRACE_FUNCTION
:
int
=
0
VLLM_ATTENTION_BACKEND
:
Optional
[
str
]
=
None
VLLM_ATTENTION_BACKEND
:
Optional
[
str
]
=
None
VLLM_USE_FLASHINFER_SAMPLER
:
bool
=
False
VLLM_USE_FLASHINFER_SAMPLER
:
bool
=
False
VLLM_USE_FLASHINFER_REJECTION_SAMPLER
:
bool
=
False
VLLM_PP_LAYER_PARTITION
:
Optional
[
str
]
=
None
VLLM_PP_LAYER_PARTITION
:
Optional
[
str
]
=
None
VLLM_CPU_KVCACHE_SPACE
:
int
=
0
VLLM_CPU_KVCACHE_SPACE
:
int
=
0
VLLM_CPU_OMP_THREADS_BIND
:
str
=
""
VLLM_CPU_OMP_THREADS_BIND
:
str
=
""
...
...
vllm/model_executor/layers/rejection_sampler.py
View file @
e6a26ed0
from
functools
import
cached_property
from
functools
import
cached_property
from
importlib.util
import
find_spec
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch
import
torch.jit
import
torch.jit
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.spec_decode_base_sampler
import
(
from
vllm.model_executor.layers.spec_decode_base_sampler
import
(
SpecDecodeStochasticBaseSampler
)
SpecDecodeStochasticBaseSampler
)
logger
=
init_logger
(
__name__
)
if
find_spec
(
"flashinfer"
):
"""
Consider utilizing the FlashInfer rejection sampling kernel initially,
as it employs a dedicated kernel rather than relying on
Torch tensor operations. This design choice helps to fuse operations,
reduce memory I/O, and consequently enhances performance.
"""
from
flashinfer.sampling
import
chain_speculative_sampling
else
:
chain_speculative_sampling
=
None
class
RejectionSampler
(
SpecDecodeStochasticBaseSampler
):
class
RejectionSampler
(
SpecDecodeStochasticBaseSampler
):
"""Apply modified rejection sampling as described in "Accelerating Large
"""Apply modified rejection sampling as described in "Accelerating Large
...
@@ -16,7 +32,8 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
...
@@ -16,7 +32,8 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
def
__init__
(
self
,
def
__init__
(
self
,
disable_bonus_tokens
:
bool
=
True
,
disable_bonus_tokens
:
bool
=
True
,
strict_mode
:
bool
=
False
):
strict_mode
:
bool
=
False
,
use_flashinfer
:
Optional
[
bool
]
=
None
):
"""Create a rejection sampler.
"""Create a rejection sampler.
Args:
Args:
...
@@ -26,13 +43,29 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
...
@@ -26,13 +43,29 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
strict_mode: Whether or not to perform shape/device/dtype checks
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
during sampling. This catches correctness issues but adds
nontrivial latency.
nontrivial latency.
use_falshinfer: We will use this parameter to determine whether
to use the FlashInfer rejection sampling kernel or not. If it's
None, we will use the default value from the environment variable.
This parameter is only used for testing purposes.
"""
"""
super
().
__init__
(
disable_bonus_tokens
=
disable_bonus_tokens
,
super
().
__init__
(
disable_bonus_tokens
=
disable_bonus_tokens
,
strict_mode
=
strict_mode
)
strict_mode
=
strict_mode
)
if
use_flashinfer
is
None
:
self
.
use_flashinfer
=
envs
.
VLLM_USE_FLASHINFER_SAMPLER
and
(
chain_speculative_sampling
is
not
None
)
else
:
self
.
use_flashinfer
=
use_flashinfer
if
self
.
use_flashinfer
:
assert
not
disable_bonus_tokens
,
\
"flashinfer will enable bonus token by default"
logger
.
info
(
"Use flashinfer for rejection sampling."
)
else
:
logger
.
info
(
"Use pytorch for rejection sampling."
)
def
forward
(
def
forward
(
self
,
self
,
target_probs
:
torch
.
Tensor
,
target_
with_bonus_
probs
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
...
@@ -50,9 +83,9 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
...
@@ -50,9 +83,9 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
sequence.
sequence.
Args:
Args:
target_probs: The probability distribution
over token ids given
target_
with_bonus_
probs: The probability distribution
context according to the target model.
over token ids given
context according to the target model.
shape = [batch_size, num_speculative_tokens, vocab_size]
shape = [batch_size, num_speculative_tokens
+ 1
, vocab_size]
bonus_token_ids: The "bonus" token ids that are accepted iff all
bonus_token_ids: The "bonus" token ids that are accepted iff all
speculative tokens in a sequence are accepted.
speculative tokens in a sequence are accepted.
...
@@ -78,23 +111,52 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
...
@@ -78,23 +111,52 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
# Only perform shape/dtype/device checking in strict mode, as it adds
# Only perform shape/dtype/device checking in strict mode, as it adds
# overhead.
# overhead.
if
self
.
_strict_mode
:
if
self
.
_strict_mode
:
self
.
_raise_if_incorrect_input
(
target_probs
,
draft_token_ids
,
self
.
_raise_if_incorrect_input
(
target_with_bonus_probs
,
bonus_token_ids
,
draft_probs
)
draft_token_ids
,
bonus_token_ids
,
draft_probs
)
accepted
,
recovered_token_ids
=
(
batch_size
,
k
,
_
=
draft_probs
.
shape
self
.
_batch_modified_rejection_sampling
(
target_probs
,
draft_probs
,
draft_token_ids
,
seeded_seqs
,
))
output_token_ids
=
self
.
_create_output
(
# batch_size = 0 when all requests in the batch are
accepted
,
# non_spec requests. In this case, output_token_ids is
recovered_token_ids
,
# just an empty tensor.
draft_token_ids
,
if
batch_size
==
0
:
bonus_token_ids
,
return
torch
.
empty
(
0
,
k
+
1
,
device
=
draft_probs
.
device
,
dtype
=
int
)
)
# If use Flashinfer chain_speculative_sampling kernel
# for rejection sampling
if
self
.
use_flashinfer
:
batch_size
,
k
,
_
=
draft_probs
.
shape
uniform_samples
=
self
.
_create_uniform_samples
(
seeded_seqs
,
batch_size
,
k
,
draft_probs
.
device
)
output_token_ids
,
accepted_token_num
,
emitted_token_num
\
=
chain_speculative_sampling
(
draft_probs
,
draft_token_ids
,
uniform_samples
,
target_with_bonus_probs
)
# num_emitted_tokens returned by flashinfer
# does not include the bonus token
# Flashinfer stops at the first token that violates
# the condition p >= q and does not include recovery/bonus token.
# Therefore, we need to add batch_size here.
self
.
num_accepted_tokens
+=
accepted_token_num
.
sum
()
self
.
num_emitted_tokens
+=
emitted_token_num
.
sum
()
+
batch_size
self
.
num_draft_tokens
+=
batch_size
*
k
else
:
accepted
,
recovered_token_ids
=
(
self
.
_batch_modified_rejection_sampling
(
target_with_bonus_probs
[:,
:
-
1
],
draft_probs
,
draft_token_ids
,
seeded_seqs
,
))
output_token_ids
=
self
.
_create_output
(
accepted
,
recovered_token_ids
,
draft_token_ids
,
bonus_token_ids
,
)
return
output_token_ids
return
output_token_ids
...
@@ -135,6 +197,63 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
...
@@ -135,6 +197,63 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
return
accepted
,
recovered_token_ids
return
accepted
,
recovered_token_ids
def
_create_uniform_samples
(
self
,
seeded_seqs
:
Optional
[
Dict
[
int
,
torch
.
Generator
]],
batch_size
:
int
,
k
:
int
,
device
:
torch
.
device
)
->
torch
.
Tensor
:
"""
Generates a batch of uniform random samples, with optional seeding
for specific sequences.
This method creates a tensor of shape `(batch_size, k + 1)` filled
with uniform random values in the range [0, 1). If `seeded_seqs`
is provided, the sequences corresponding to specific indices
will be generated using the provided `torch.Generator` for
reproducibility. The other sequences will be generated without
a seed.
Args:
seeded_seqs : Optional[Dict[int, torch.Generator]]
A dictionary mapping indices in the batch to
`torch.Generator` objects. If `None`, all samples are
generated without a seed.
batch_size : int
The number of sequences to generate.
k : int
The number of random samples per sequence.
device : torch.device
The device on which to allocate the tensor.
Returns:
uniform_rand : torch.Tensor
A tensor of shape `(batch_size, k + 1)` containing uniform
random values in the range [0, 1).
"""
if
not
seeded_seqs
:
return
torch
.
rand
(
batch_size
,
k
+
1
,
device
=
device
)
uniform_rand
=
torch
.
empty
(
batch_size
,
k
+
1
,
device
=
device
)
non_seeded_indices
=
[]
for
idx
in
range
(
batch_size
):
generator
=
seeded_seqs
.
get
(
idx
)
if
generator
is
None
:
non_seeded_indices
.
append
(
idx
)
else
:
uniform_rand
[
idx
,
:]
=
torch
.
rand
(
1
,
k
+
1
,
dtype
=
self
.
probs_dtype
,
device
=
device
,
generator
=
generator
)
if
non_seeded_indices
:
uniform_rand
[
non_seeded_indices
,
:]
=
torch
.
rand
(
len
(
non_seeded_indices
),
k
+
1
,
dtype
=
self
.
probs_dtype
,
device
=
device
)
return
uniform_rand
def
_get_accepted
(
def
_get_accepted
(
self
,
self
,
target_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
target_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
...
@@ -175,29 +294,8 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
...
@@ -175,29 +294,8 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
selected_target_probs
=
target_probs
[
batch_indices
,
probs_indicies
,
selected_target_probs
=
target_probs
[
batch_indices
,
probs_indicies
,
draft_token_ids
]
draft_token_ids
]
if
not
seeded_seqs
:
uniform_rand
=
self
.
_create_uniform_samples
(
seeded_seqs
,
batch_size
,
uniform_rand
=
torch
.
rand_like
(
selected_target_probs
)
k
-
1
,
target_probs
.
device
)
else
:
uniform_rand
=
torch
.
empty_like
(
selected_target_probs
)
non_seeded_indices
=
[]
for
idx
in
range
(
batch_size
):
generator
=
seeded_seqs
.
get
(
idx
)
if
generator
is
None
:
non_seeded_indices
.
append
(
idx
)
else
:
uniform_rand
[
idx
,
:]
=
torch
.
rand
(
1
,
k
,
dtype
=
self
.
probs_dtype
,
device
=
target_probs
.
device
,
generator
=
generator
)
if
non_seeded_indices
:
uniform_rand
[
non_seeded_indices
,
:]
=
torch
.
rand
(
len
(
non_seeded_indices
),
k
,
dtype
=
self
.
probs_dtype
,
device
=
target_probs
.
device
)
capped_ratio
=
torch
.
minimum
(
capped_ratio
=
torch
.
minimum
(
selected_target_probs
/
selected_draft_probs
,
selected_target_probs
/
selected_draft_probs
,
...
...
vllm/model_executor/layers/spec_decode_base_sampler.py
View file @
e6a26ed0
...
@@ -130,29 +130,35 @@ class SpecDecodeBaseSampler(nn.Module):
...
@@ -130,29 +130,35 @@ class SpecDecodeBaseSampler(nn.Module):
def
_raise_if_incorrect_input
(
def
_raise_if_incorrect_input
(
self
,
self
,
target_probs
:
torch
.
Tensor
,
target_
with_bonus_
probs
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
draft_probs
:
Optional
[
torch
.
Tensor
]
=
None
,
draft_probs
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
None
:
)
->
None
:
self
.
_raise_if_incorrect_shape
(
target_probs
,
draft_token_ids
,
self
.
_raise_if_incorrect_shape
(
target_with_bonus_probs
,
bonus_token_ids
,
draft_probs
)
draft_token_ids
,
bonus_token_ids
,
self
.
_raise_if_incorrect_dtype
(
target_probs
,
draft_token_ids
,
draft_probs
)
bonus_token_ids
,
draft_probs
)
self
.
_raise_if_incorrect_dtype
(
target_with_bonus_probs
,
self
.
_raise_if_inconsistent_device
(
target_probs
,
draft_token_ids
,
draft_token_ids
,
bonus_token_ids
,
bonus_token_ids
,
draft_probs
)
draft_probs
)
self
.
_raise_if_out_of_bounds_vocab
(
target_probs
.
shape
[
-
1
],
self
.
_raise_if_inconsistent_device
(
target_with_bonus_probs
,
draft_token_ids
,
bonus_token_ids
,
draft_probs
)
self
.
_raise_if_out_of_bounds_vocab
(
target_with_bonus_probs
.
shape
[
-
1
],
draft_token_ids
,
bonus_token_ids
)
draft_token_ids
,
bonus_token_ids
)
def
_raise_if_incorrect_shape
(
def
_raise_if_incorrect_shape
(
self
,
self
,
target_probs
:
torch
.
Tensor
,
target_
with_bonus_
probs
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
draft_probs
:
Optional
[
torch
.
Tensor
]
=
None
,
draft_probs
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
None
:
)
->
None
:
(
target_batch_size
,
num_target_probs
,
(
target_batch_size
,
num_target_probs
,
target_vocab_size
)
=
target_probs
.
shape
target_vocab_size
)
=
target_with_bonus_probs
.
shape
# Does not count the extra token
num_target_probs
-=
1
# validate the shape of draft token ids.
# validate the shape of draft token ids.
draft_token_ids_batch_size
,
num_draft_token_ids
=
draft_token_ids
.
shape
draft_token_ids_batch_size
,
num_draft_token_ids
=
draft_token_ids
.
shape
...
@@ -175,12 +181,12 @@ class SpecDecodeBaseSampler(nn.Module):
...
@@ -175,12 +181,12 @@ class SpecDecodeBaseSampler(nn.Module):
def
_raise_if_incorrect_dtype
(
def
_raise_if_incorrect_dtype
(
self
,
self
,
target_probs
:
torch
.
Tensor
,
target_
with_bonus_
probs
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
draft_probs
:
Optional
[
torch
.
Tensor
]
=
None
,
draft_probs
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
None
:
)
->
None
:
assert
target_probs
.
dtype
==
self
.
probs_dtype
assert
target_
with_bonus_
probs
.
dtype
==
self
.
probs_dtype
assert
draft_token_ids
.
dtype
==
self
.
token_id_dtype
assert
draft_token_ids
.
dtype
==
self
.
token_id_dtype
assert
bonus_token_ids
.
dtype
==
self
.
token_id_dtype
assert
bonus_token_ids
.
dtype
==
self
.
token_id_dtype
if
draft_probs
is
not
None
:
if
draft_probs
is
not
None
:
...
@@ -188,15 +194,16 @@ class SpecDecodeBaseSampler(nn.Module):
...
@@ -188,15 +194,16 @@ class SpecDecodeBaseSampler(nn.Module):
def
_raise_if_inconsistent_device
(
def
_raise_if_inconsistent_device
(
self
,
self
,
target_probs
:
torch
.
Tensor
,
target_
with_bonus_
probs
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
draft_probs
:
Optional
[
torch
.
Tensor
]
=
None
,
draft_probs
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
None
:
)
->
None
:
devices
=
[
devices
=
[
t
.
device
for
t
in
t
.
device
for
t
in
[
[
target_probs
,
bonus_token_ids
,
draft_probs
,
draft_token_ids
]
target_with_bonus_probs
,
bonus_token_ids
,
draft_probs
,
if
t
is
not
None
draft_token_ids
]
if
t
is
not
None
]
]
assert
all
([
devices
[
0
]
==
device
for
device
in
devices
])
assert
all
([
devices
[
0
]
==
device
for
device
in
devices
])
...
@@ -220,7 +227,7 @@ class SpecDecodeDeterministicBaseSampler(SpecDecodeBaseSampler):
...
@@ -220,7 +227,7 @@ class SpecDecodeDeterministicBaseSampler(SpecDecodeBaseSampler):
@
abstractmethod
@
abstractmethod
def
forward
(
def
forward
(
self
,
self
,
target_probs
:
torch
.
Tensor
,
target_
with_bonus_
probs
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
...
@@ -236,7 +243,7 @@ class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
...
@@ -236,7 +243,7 @@ class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
@
abstractmethod
@
abstractmethod
def
forward
(
def
forward
(
self
,
self
,
target_probs
:
torch
.
Tensor
,
target_
with_bonus_
probs
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/typical_acceptance_sampler.py
View file @
e6a26ed0
...
@@ -41,7 +41,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
...
@@ -41,7 +41,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
def
forward
(
def
forward
(
self
,
self
,
target_probs
:
torch
.
Tensor
,
target_
with_bonus_
probs
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
...
@@ -80,8 +80,9 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
...
@@ -80,8 +80,9 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
# Only perform shape/dtype/device checking in strict mode, as it adds
# Only perform shape/dtype/device checking in strict mode, as it adds
# overhead.
# overhead.
if
self
.
_strict_mode
:
if
self
.
_strict_mode
:
self
.
_raise_if_incorrect_input
(
target_probs
,
draft_token_ids
,
self
.
_raise_if_incorrect_input
(
target_with_bonus_probs
,
bonus_token_ids
)
draft_token_ids
,
bonus_token_ids
)
target_probs
=
target_with_bonus_probs
[:,
:
-
1
]
accepted
=
self
.
_evaluate_accepted_tokens
(
target_probs
,
accepted
=
self
.
_evaluate_accepted_tokens
(
target_probs
,
draft_token_ids
)
draft_token_ids
)
recovered_token_ids
=
self
.
_replacement_token_ids
(
target_probs
)
recovered_token_ids
=
self
.
_replacement_token_ids
(
target_probs
)
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
e6a26ed0
...
@@ -625,8 +625,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -625,8 +625,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
seq_group_metadata_list
,
proposal_lens_list
)
seq_group_metadata_list
,
proposal_lens_list
)
original_indices
=
spec_indices
+
non_spec_indices
original_indices
=
spec_indices
+
non_spec_indices
# Get probabilities of target model,
ex
cluding bonus token.
# Get probabilities of target model,
in
cluding bonus token
s
.
proposal_verifier_probs
=
proposal_scores
.
probs
[
spec_indices
,
:
-
1
]
proposal_verifier_probs
=
proposal_scores
.
probs
[
spec_indices
]
# Get non-speculative sampled tokens from target model.
# Get non-speculative sampled tokens from target model.
non_spec_token_ids
=
proposal_scores
.
token_ids
[
non_spec_indices
]
non_spec_token_ids
=
proposal_scores
.
token_ids
[
non_spec_indices
]
...
@@ -651,13 +651,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -651,13 +651,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
}
}
accepted_token_ids
=
self
.
spec_decode_sampler
(
accepted_token_ids
=
self
.
spec_decode_sampler
(
target_probs
=
proposal_verifier_probs
,
target_
with_bonus_
probs
=
proposal_verifier_probs
,
bonus_token_ids
=
bonus_token_ids
,
bonus_token_ids
=
bonus_token_ids
,
draft_probs
=
proposal_probs
,
draft_probs
=
proposal_probs
,
draft_token_ids
=
proposal_token_ids
,
draft_token_ids
=
proposal_token_ids
,
**
sampler_extra_kwargs
,
**
sampler_extra_kwargs
,
)
)
# Append output tokens from non-speculative sequences to
# Append output tokens from non-speculative sequences to
# the accepted token ids tensor.
# the accepted token ids tensor.
non_spec_token_ids
=
non_spec_token_ids
.
expand
(
-
1
,
max_proposal_len
+
non_spec_token_ids
=
non_spec_token_ids
.
expand
(
-
1
,
max_proposal_len
+
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment