Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ca796e19
Commit
ca796e19
authored
Mar 21, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.1' into v0.8.1-ori
parents
e983c804
61c7a1b8
Changes
130
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
480 additions
and
247 deletions
+480
-247
vllm/v1/sample/sampler.py
vllm/v1/sample/sampler.py
+7
-0
vllm/v1/spec_decode/metadata.py
vllm/v1/spec_decode/metadata.py
+61
-0
vllm/v1/spec_decode/utils.py
vllm/v1/spec_decode/utils.py
+0
-1
vllm/v1/structured_output/__init__.py
vllm/v1/structured_output/__init__.py
+31
-92
vllm/v1/structured_output/backend_types.py
vllm/v1/structured_output/backend_types.py
+89
-0
vllm/v1/structured_output/backend_xgrammar.py
vllm/v1/structured_output/backend_xgrammar.py
+143
-0
vllm/v1/structured_output/grammar.py
vllm/v1/structured_output/grammar.py
+0
-77
vllm/v1/structured_output/request.py
vllm/v1/structured_output/request.py
+12
-6
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+134
-71
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+3
-0
No files found.
vllm/v1/sample/sampler.py
View file @
ca796e19
...
@@ -47,6 +47,11 @@ class Sampler(nn.Module):
...
@@ -47,6 +47,11 @@ class Sampler(nn.Module):
logits
=
self
.
apply_penalties
(
logits
,
sampling_metadata
)
logits
=
self
.
apply_penalties
(
logits
,
sampling_metadata
)
# Sample the next token.
# Sample the next token.
sampled
=
self
.
sample
(
logits
,
sampling_metadata
)
sampled
=
self
.
sample
(
logits
,
sampling_metadata
)
# Convert sampled token ids to int64 (long) type to ensure compatibility
# with subsequent operations that may use these values as indices.
# This conversion is necessary because FlashInfer sampling operations
# return int32 (while PyTorch argmax and topk return int64).
sampled
=
sampled
.
long
()
# Gather the logprobs of the topk and sampled token (if requested).
# Gather the logprobs of the topk and sampled token (if requested).
# Get logprobs and rank tensors (if requested)
# Get logprobs and rank tensors (if requested)
...
@@ -139,12 +144,14 @@ class Sampler(nn.Module):
...
@@ -139,12 +144,14 @@ class Sampler(nn.Module):
or sampled tokens (if sampled
or sampled tokens (if sampled
logprobs); 1D token ID tensor
logprobs); 1D token ID tensor
with (num tokens) elements
with (num tokens) elements
Must be int64.
Returns:
Returns:
Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
Sampled token rank tensor, (num tokens)
Sampled token rank tensor, (num tokens)
"""
"""
assert
token_ids
.
dtype
==
torch
.
int64
# Find the topK values.
# Find the topK values.
topk_logprobs
,
topk_indices
=
torch
.
topk
(
logprobs
,
topk_logprobs
,
topk_indices
=
torch
.
topk
(
logprobs
,
num_logprobs
,
num_logprobs
,
...
...
vllm/v1/spec_decode/metadata.py
0 → 100644
View file @
ca796e19
# SPDX-License-Identifier: Apache-2.0
from
dataclasses
import
dataclass
import
numpy
as
np
import
torch
@dataclass
class SpecDecodeMetadata:
    """Draft-token tensors and per-request counts for speculative decoding.

    In the shape comments below, ``num_tokens`` is the total number of
    draft tokens in the batch and ``batch_size`` is the number of requests.
    ``max_spec_len`` is derived in ``__post_init__`` and is not an
    ``__init__`` argument.
    """

    # [num_tokens]
    draft_token_ids: torch.Tensor
    # [batch_size]
    num_draft_tokens: list[int]
    # [batch_size]
    cu_num_draft_tokens: torch.Tensor
    # [num_tokens]
    target_logits_indices: torch.Tensor
    # [batch_size]
    bonus_logits_indices: torch.Tensor
    # [num_tokens + batch_size]
    logits_indices: torch.Tensor

    def __post_init__(self):
        # Longest draft among the requests. ``default=0`` keeps an empty
        # batch constructible instead of raising ValueError from max([]).
        self.max_spec_len = max(self.num_draft_tokens, default=0)

    @classmethod
    def make_dummy(
        cls,
        draft_token_ids: list[list[int]],
        device: torch.device,
    ) -> "SpecDecodeMetadata":
        """Build placeholder metadata from per-request draft token lists.

        The draft token ids and their cumulative counts are real; the three
        ``*_logits_indices`` tensors are zero-filled int32 placeholders of
        the appropriate length.

        Args:
            draft_token_ids: One list of draft token ids per request; a
                request without drafts contributes an empty list.
            device: Device on which every tensor is allocated.

        Returns:
            A ``SpecDecodeMetadata`` whose tensors live on ``device``.
        """
        batch_size = len(draft_token_ids)
        num_draft_tokens = [len(ids) for ids in draft_token_ids]

        # Single-pass flatten; ``sum(lists, [])`` re-copies the accumulator
        # on every step and is quadratic in the number of tokens.
        flattened_draft_token_ids = [
            token for ids in draft_token_ids for token in ids
        ]
        num_tokens = len(flattened_draft_token_ids)

        draft_token_ids_tensor = torch.tensor(flattened_draft_token_ids,
                                              dtype=torch.int32,
                                              device=device)

        # Inclusive prefix sum of the per-request draft counts.
        cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)
        cu_num_draft_tokens_tensor = torch.from_numpy(
            cu_num_draft_tokens).to(device)

        target_logits_indices = torch.zeros(num_tokens,
                                            dtype=torch.int32,
                                            device=device)
        bonus_logits_indices = torch.zeros(batch_size,
                                           dtype=torch.int32,
                                           device=device)
        logits_indices = torch.zeros(num_tokens + batch_size,
                                     dtype=torch.int32,
                                     device=device)
        return cls(
            draft_token_ids=draft_token_ids_tensor,
            num_draft_tokens=num_draft_tokens,
            cu_num_draft_tokens=cu_num_draft_tokens_tensor,
            target_logits_indices=target_logits_indices,
            bonus_logits_indices=bonus_logits_indices,
            logits_indices=logits_indices,
        )
vllm/v1/spec_decode/utils.py
View file @
ca796e19
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
vllm.v1.sample.ops.topk_topp_sampler
import
random_sample
# noqa
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
...
...
vllm/v1/structured_output/__init__.py
View file @
ca796e19
...
@@ -7,75 +7,27 @@ from typing import TYPE_CHECKING, Optional
...
@@ -7,75 +7,27 @@ from typing import TYPE_CHECKING, Optional
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.v1.structured_output.backend_types
import
(
StructuredOutputBackend
,
from
vllm.transformers_utils.tokenizers.mistral
import
MistralTokenizer
StructuredOutputGrammar
)
from
vllm.utils
import
LazyLoader
from
vllm.v1.structured_output.backend_xgrammar
import
XgrammarBackend
from
vllm.v1.structured_output.grammar
import
Grammar
,
StructuredOutputOptions
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
import
numpy
as
np
import
numpy
as
np
import
numpy.typing
as
npt
import
numpy.typing
as
npt
import
xgrammar
as
xgr
import
torch
from
vllm.v1.request
import
Request
from
vllm.v1.request
import
Request
else
:
xgr
=
LazyLoader
(
"xgr"
,
globals
(),
"xgrammar"
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
class
StructuredOutputManager
:
class
StructuredOutputManager
:
"""Engine-level manager for structured output requests."""
def
__init__
(
self
,
vllm_config
:
VllmConfig
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
):
self
.
backend
:
Optional
[
StructuredOutputBackend
]
=
None
self
.
vllm_config
=
vllm_config
self
.
vllm_config
=
vllm_config
self
.
init_complete
=
False
self
.
_grammar_bitmask
:
Optional
[
torch
.
Tensor
]
=
None
def
_delayed_init
(
self
):
"""Initialization delayed until we know it is needed."""
tokenizer_group
=
init_tokenizer_from_configs
(
model_config
=
self
.
vllm_config
.
model_config
,
scheduler_config
=
self
.
vllm_config
.
scheduler_config
,
parallel_config
=
self
.
vllm_config
.
parallel_config
,
lora_config
=
self
.
vllm_config
.
lora_config
)
# type: ignore[arg-type]
tokenizer_group
.
ping
()
tokenizer
=
tokenizer_group
.
get_lora_tokenizer
(
None
)
self
.
vocab_size
=
self
.
vllm_config
.
model_config
.
get_vocab_size
()
if
isinstance
(
tokenizer
,
MistralTokenizer
):
# NOTE: ideally, xgrammar should handle this accordingly.
# refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
try
:
encoded_vocab
=
[
token
for
token
,
_
in
sorted
(
tokenizer
.
get_vocab
().
items
(),
key
=
lambda
x
:
x
[
1
],
)
]
stop_token_ids
=
None
if
hasattr
(
tokenizer
,
"eos_token_id"
,
)
and
tokenizer
.
eos_token_id
is
not
None
:
stop_token_ids
=
[
tokenizer
.
eos_token_id
]
except
AttributeError
as
e
:
raise
ValueError
(
f
"Cannot get the vocabulary of the tokenizer "
f
"
{
type
(
tokenizer
)
}
. The tokenizer should have a "
"get_vocab method."
)
from
e
tokenizer_info
=
xgr
.
TokenizerInfo
(
encoded_vocab
=
encoded_vocab
,
# NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
vocab_type
=
xgr
.
VocabType
.
BYTE_FALLBACK
,
vocab_size
=
self
.
vocab_size
,
stop_token_ids
=
stop_token_ids
,
add_prefix_space
=
True
,
)
else
:
tokenizer_info
=
xgr
.
TokenizerInfo
.
from_huggingface
(
tokenizer
,
vocab_size
=
self
.
vocab_size
,
)
self
.
compiler
=
xgr
.
GrammarCompiler
(
tokenizer_info
,
max_threads
=
8
)
# The default max_workers if not specified is the number of CPUs * 5,
# The default max_workers if not specified is the number of CPUs * 5,
# which is way too high since these tasks are CPU-bound, not I/O bound.
# which is way too high since these tasks are CPU-bound, not I/O bound.
...
@@ -83,28 +35,30 @@ class StructuredOutputManager:
...
@@ -83,28 +35,30 @@ class StructuredOutputManager:
# compilation, so we set it to half the number of CPUs.
# compilation, so we set it to half the number of CPUs.
max_workers
=
max
(
1
,
(
multiprocessing
.
cpu_count
()
+
1
)
//
2
)
max_workers
=
max
(
1
,
(
multiprocessing
.
cpu_count
()
+
1
)
//
2
)
self
.
executor
=
ThreadPoolExecutor
(
max_workers
=
max_workers
)
self
.
executor
=
ThreadPoolExecutor
(
max_workers
=
max_workers
)
self
.
_grammar_bitmask
=
xgr
.
allocate_token_bitmask
(
self
.
vllm_config
.
scheduler_config
.
max_num_seqs
,
self
.
vocab_size
,
)
self
.
init_complete
=
True
def
grammar_init
(
self
,
request
:
Request
)
->
None
:
def
grammar_init
(
self
,
request
:
Request
)
->
None
:
if
request
.
structured_output_request
is
None
:
if
request
.
structured_output_request
is
None
:
return
return
# The first time this is called, we need to finish initialization
# Initialize the backend the first time it is needed.
# of xgrammar. We defer it to avoid the import of xgrammar and
#
# initialization cost if it is not going to be used.
# NOTE: We only support a single backend. We do NOT support different
if
not
self
.
init_complete
:
# backends on a per-request basis in V1 (for now, anyway...).
self
.
_delayed_init
()
if
self
.
backend
is
None
:
backend_name
=
request
.
sampling_params
.
guided_decoding
.
backend_name
if
backend_name
==
"xgrammar"
:
self
.
backend
=
XgrammarBackend
(
self
.
vllm_config
)
else
:
raise
ValueError
(
f
"Unsupported structured output backend:
{
backend_name
}
"
)
grammar
:
Future
[
Grammar
]
=
self
.
executor
.
submit
(
grammar
:
Future
[
StructuredOutput
Grammar
]
=
self
.
executor
.
submit
(
self
.
_async_create_grammar
,
request
)
self
.
_async_create_grammar
,
request
,
self
.
backend
)
request
.
structured_output_request
.
grammar
=
grammar
# type: ignore[assignment]
request
.
structured_output_request
.
grammar
=
grammar
# type: ignore[assignment]
def
_async_create_grammar
(
self
,
request
:
Request
)
->
Grammar
:
def
_async_create_grammar
(
self
,
request
:
Request
,
backend
:
StructuredOutputBackend
)
->
StructuredOutputGrammar
:
key
=
request
.
structured_output_request
.
structured_output_key
# type: ignore[union-attr]
key
=
request
.
structured_output_request
.
structured_output_key
# type: ignore[union-attr]
# Note that the request was validated in the engine core client,
# Note that the request was validated in the engine core client,
...
@@ -114,28 +68,8 @@ class StructuredOutputManager:
...
@@ -114,28 +68,8 @@ class StructuredOutputManager:
# though it should be unlikely as we test that up front as well.
# though it should be unlikely as we test that up front as well.
request_type
,
grammar_spec
=
key
request_type
,
grammar_spec
=
key
if
request_type
==
StructuredOutputOptions
.
JSON
:
assert
self
.
backend
is
not
None
# TODO -- allow any_whitespace to be configurable
return
self
.
backend
.
compile_grammar
(
request_type
,
grammar_spec
)
# pending merge of https://github.com/vllm-project/vllm/pull/12744
ctx
=
self
.
compiler
.
compile_json_schema
(
grammar_spec
,
any_whitespace
=
False
)
elif
request_type
==
StructuredOutputOptions
.
JSON_OBJECT
:
ctx
=
self
.
compiler
.
compile_builtin_json_grammar
()
elif
request_type
==
StructuredOutputOptions
.
GRAMMAR
:
ctx
=
self
.
compiler
.
compile_grammar
(
grammar_spec
)
elif
request_type
==
StructuredOutputOptions
.
REGEX
:
ctx
=
self
.
compiler
.
compile_regex
(
grammar_spec
)
else
:
logger
.
error
(
"Validation should have already occurred. "
"Please file an issue."
)
raise
ValueError
(
f
"grammar is not of valid supported types. (
{
request_type
!
s
}
)"
)
return
Grammar
(
matcher
=
xgr
.
GrammarMatcher
(
ctx
),
vocab_size
=
self
.
vocab_size
,
ctx
=
ctx
,
)
def
grammar_bitmask
(
def
grammar_bitmask
(
self
,
self
,
...
@@ -147,6 +81,11 @@ class StructuredOutputManager:
...
@@ -147,6 +81,11 @@ class StructuredOutputManager:
if
not
structured_output_request_ids
:
if
not
structured_output_request_ids
:
return
None
return
None
if
self
.
_grammar_bitmask
is
None
:
assert
self
.
backend
is
not
None
self
.
_grammar_bitmask
=
self
.
backend
.
allocate_token_bitmask
(
self
.
vllm_config
.
scheduler_config
.
max_num_seqs
)
# Fill the bitmask using the index of each request equal to its
# Fill the bitmask using the index of each request equal to its
# position in the batch. Resize the bitmask down to the size of
# position in the batch. Resize the bitmask down to the size of
# the batch.
# the batch.
...
@@ -154,7 +93,7 @@ class StructuredOutputManager:
...
@@ -154,7 +93,7 @@ class StructuredOutputManager:
for
req_id
,
batch_index
in
structured_output_request_ids
.
items
():
for
req_id
,
batch_index
in
structured_output_request_ids
.
items
():
request
=
requests
[
req_id
].
structured_output_request
request
=
requests
[
req_id
].
structured_output_request
assert
request
is
not
None
and
request
.
grammar
is
not
None
assert
request
is
not
None
and
request
.
grammar
is
not
None
if
not
request
.
grammar
.
matcher
.
is_terminated
():
if
not
request
.
grammar
.
is_terminated
():
request
.
grammar
.
fill_bitmask
(
bitmask_tensor
,
batch_index
)
request
.
grammar
.
fill_bitmask
(
bitmask_tensor
,
batch_index
)
if
batch_len
<
self
.
_grammar_bitmask
.
shape
[
0
]:
if
batch_len
<
self
.
_grammar_bitmask
.
shape
[
0
]:
bitmask_tensor
=
self
.
_grammar_bitmask
[:
batch_len
]
bitmask_tensor
=
self
.
_grammar_bitmask
[:
batch_len
]
...
...
vllm/v1/structured_output/backend_types.py
0 → 100644
View file @
ca796e19
# SPDX-License-Identifier: Apache-2.0
import
enum
from
abc
import
ABC
,
abstractmethod
import
torch
class StructuredOutputOptions(enum.Enum):
    """Kinds of structured-output constraint a request may specify."""

    JSON = enum.auto()
    JSON_OBJECT = enum.auto()
    REGEX = enum.auto()
    GRAMMAR = enum.auto()
    CHOICE = enum.auto()


# Pair of (constraint kind, specification string) describing one request.
StructuredOutputKey = tuple[StructuredOutputOptions, str]
class StructuredOutputGrammar(ABC):
    """Request-level backend for structured output requests."""

    @abstractmethod
    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """Report whether ``tokens`` are accepted for the given request.

        Args:
            request_id (str): Unique identifier of the request.
            tokens (list[int]): Token IDs to evaluate.

        Returns:
            bool: True when the tokens are accepted, False otherwise.
        """

    @abstractmethod
    def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None:
        """Fill the bitmask at one batch index.

        Args:
            bitmask (torch.Tensor): Bitmask to write into.
            batch_index (int): Index within the bitmask to fill.
        """

    @abstractmethod
    def is_terminated(self) -> bool:
        """Report whether the structured output process has finished.

        Returns:
            bool: True once terminated, False otherwise.
        """

    @abstractmethod
    def reset(self):
        """Reset the state of the structured output grammar."""
class StructuredOutputBackend(ABC):
    """Engine-level backend for structured output requests."""

    @abstractmethod
    def compile_grammar(self, request_type: StructuredOutputOptions,
                        grammar_spec: str) -> StructuredOutputGrammar:
        """Compile a grammar specification into a usable grammar object.

        Args:
            request_type (StructuredOutputOptions): Kind of structured
                output request.
            grammar_spec (str): Grammar specification to compile.

        Returns:
            StructuredOutputGrammar: The compiled grammar.
        """

    @abstractmethod
    def allocate_token_bitmask(self, max_num_seqs: int):
        """Allocate a token bitmask sized for ``max_num_seqs`` sequences.

        Args:
            max_num_seqs (int): Maximum number of sequences the bitmask
                must cover.
        """
vllm/v1/structured_output/backend_xgrammar.py
0 → 100644
View file @
ca796e19
# SPDX-License-Identifier: Apache-2.0
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
import
torch
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.transformers_utils.tokenizers.mistral
import
MistralTokenizer
from
vllm.utils
import
LazyLoader
from
vllm.v1.structured_output.backend_types
import
(
StructuredOutputBackend
,
StructuredOutputGrammar
,
StructuredOutputOptions
)
if
TYPE_CHECKING
:
import
xgrammar
as
xgr
else
:
xgr
=
LazyLoader
(
"xgr"
,
globals
(),
"xgrammar"
)
logger
=
init_logger
(
__name__
)
class XgrammarBackend(StructuredOutputBackend):
    """Structured-output backend that compiles grammars through xgrammar."""

    def __init__(self, vllm_config: VllmConfig):
        """Create the tokenizer info and grammar compiler for this engine.

        Args:
            vllm_config: Engine configuration; its model, scheduler,
                parallel and LoRA sub-configs are used to build the
                tokenizer, and its model config supplies the vocab size.

        Raises:
            ValueError: If a Mistral tokenizer raises AttributeError while
                its vocabulary is being read.
        """
        self.vllm_config = vllm_config

        tokenizer_group = init_tokenizer_from_configs(
            model_config=vllm_config.model_config,
            scheduler_config=vllm_config.scheduler_config,
            parallel_config=vllm_config.parallel_config,
            lora_config=vllm_config.lora_config)  # type: ignore[arg-type]
        tokenizer_group.ping()
        tokenizer = tokenizer_group.get_lora_tokenizer(None)

        self.vocab_size = vllm_config.model_config.get_vocab_size()

        if not isinstance(tokenizer, MistralTokenizer):
            tokenizer_info = xgr.TokenizerInfo.from_huggingface(
                tokenizer,
                vocab_size=self.vocab_size,
            )
        else:
            # NOTE: ideally, xgrammar should handle this accordingly.
            # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
            try:
                # Token strings ordered by their token id.
                vocab_by_id = sorted(
                    tokenizer.get_vocab().items(),
                    key=lambda item: item[1],
                )
                encoded_vocab = [token for token, _ in vocab_by_id]

                eos_token_id = getattr(tokenizer, "eos_token_id", None)
                stop_token_ids = (None if eos_token_id is None else
                                  [eos_token_id])
            except AttributeError as e:
                raise ValueError(
                    f"Cannot get the vocabulary of the tokenizer "
                    f"{type(tokenizer)}. The tokenizer should have a "
                    "get_vocab method.") from e

            tokenizer_info = xgr.TokenizerInfo(  # type: ignore
                encoded_vocab=encoded_vocab,
                # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
                vocab_type=xgr.VocabType.BYTE_FALLBACK,
                vocab_size=self.vocab_size,
                stop_token_ids=stop_token_ids,
                add_prefix_space=True,
            )

        self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)

    def compile_grammar(self, request_type: StructuredOutputOptions,
                        grammar_spec: str) -> StructuredOutputGrammar:
        """Compile ``grammar_spec`` according to ``request_type``.

        Args:
            request_type: JSON, JSON_OBJECT, GRAMMAR or REGEX; any other
                value is rejected.
            grammar_spec: Specification handed to the xgrammar compiler
                (ignored for JSON_OBJECT, which uses the builtin grammar).

        Returns:
            An ``XgrammarGrammar`` wrapping a fresh matcher for the
            compiled grammar.

        Raises:
            ValueError: If ``request_type`` is not one of the handled kinds.
        """
        compiler = self.compiler
        if request_type == StructuredOutputOptions.JSON:
            compiled = compiler.compile_json_schema(grammar_spec,
                                                    any_whitespace=False)
        elif request_type == StructuredOutputOptions.JSON_OBJECT:
            compiled = compiler.compile_builtin_json_grammar()
        elif request_type == StructuredOutputOptions.GRAMMAR:
            compiled = compiler.compile_grammar(grammar_spec)
        elif request_type == StructuredOutputOptions.REGEX:
            compiled = compiler.compile_regex(grammar_spec)
        else:
            logger.error(
                "Validation should have already occurred. Please file an issue."
            )
            raise ValueError(
                f"grammar is not of valid supported types. ({request_type!s})")

        return XgrammarGrammar(
            matcher=xgr.GrammarMatcher(compiled),
            vocab_size=self.vocab_size,
            ctx=compiled,
        )

    def allocate_token_bitmask(self, max_num_seqs: int):
        """Return an xgrammar token bitmask for ``max_num_seqs`` sequences.

        The bitmask is sized with this backend's ``vocab_size``.
        """
        return xgr.allocate_token_bitmask(max_num_seqs, self.vocab_size)
@dataclass
class XgrammarGrammar(StructuredOutputGrammar):
    """Request-level grammar state backed by an xgrammar matcher."""

    # NOTE: This would be a generic-enough class for
    # supporting different backends, in the future.
    # For now, just xgrammar.
    #
    # TODO: support max_rollback_tokens
    # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string
    # for jump-forward decoding

    vocab_size: int
    matcher: xgr.GrammarMatcher = field(hash=False)
    ctx: xgr.CompiledGrammar = field(hash=False)
    # Count of tokens the matcher has accepted so far; not an __init__
    # argument and omitted from repr.
    num_processed_tokens: int = field(default_factory=lambda: 0,
                                      repr=False,
                                      hash=False,
                                      init=False)

    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """Feed ``tokens`` to the matcher one at a time.

        Stops at the first token the matcher rejects, logs an error and
        returns False. Returns True when every token was accepted.
        """
        for token in tokens:
            if self.matcher.accept_token(token):
                self.num_processed_tokens += 1
                continue
            logger.error(
                "Failed to advance FSM for request %s "
                "for tokens %s. Please file an issue.", request_id, token)
            return False
        return True

    def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
        """Have the matcher fill the next-token bitmask at index ``idx``."""
        self.matcher.fill_next_token_bitmask(bitmask, idx)

    def is_terminated(self) -> bool:
        """Return the matcher's terminated status."""
        return self.matcher.is_terminated()

    def reset(self):
        """Zero the processed-token counter and reset the matcher."""
        self.num_processed_tokens = 0
        self.matcher.reset()
vllm/v1/structured_output/grammar.py
deleted
100644 → 0
View file @
e983c804
# SPDX-License-Identifier: Apache-2.0
from
__future__
import
annotations
import
enum
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
import
torch
from
vllm.logger
import
init_logger
from
vllm.utils
import
LazyLoader
if
TYPE_CHECKING
:
import
xgrammar
as
xgr
else
:
xgr
=
LazyLoader
(
"xgr"
,
globals
(),
"xgrammar"
)
logger
=
init_logger
(
__name__
)
class StructuredOutputOptions(enum.Enum):
    """Kinds of structured-output constraint a request may specify."""

    JSON = enum.auto()
    JSON_OBJECT = enum.auto()
    REGEX = enum.auto()
    GRAMMAR = enum.auto()
    CHOICE = enum.auto()


# Pair of (constraint kind, specification string) describing one request.
StructuredOutputKey = tuple[StructuredOutputOptions, str]
@dataclass
class Grammar:
    """Request-level grammar state backed by an xgrammar matcher."""

    # NOTE: This would be a generic-enough class for
    # supporting different backends, in the future.
    # For now, just xgrammar.
    #
    # TODO: support max_rollback_tokens
    # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string
    # for jump-forward decoding

    vocab_size: int
    matcher: xgr.GrammarMatcher = field(hash=False)
    ctx: xgr.CompiledGrammar = field(hash=False)
    # Count of tokens the matcher has accepted so far; not an __init__
    # argument and omitted from repr.
    num_processed_tokens: int = field(default_factory=lambda: 0,
                                      repr=False,
                                      hash=False,
                                      init=False)

    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """Feed ``tokens`` to the matcher one at a time.

        Stops at the first token the matcher rejects, logs an error and
        returns False. Returns True when every token was accepted.
        """
        for token in tokens:
            if self.matcher.accept_token(token):
                self.num_processed_tokens += 1
                continue
            logger.error(
                "Failed to advance FSM for request %s "
                "for tokens %s. Please file an issue.", request_id, token)
            return False
        return True

    def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> bool:
        """Fill the next-token bitmask at ``idx``; return the matcher's result."""
        filled = self.matcher.fill_next_token_bitmask(bitmask, idx)
        return filled

    def reset(self):
        """Zero the processed-token counter and reset the matcher."""
        self.num_processed_tokens = 0
        self.matcher.reset()

    def __copy__(self):
        """Return a new ``Grammar`` with a fresh matcher over the same ``ctx``."""
        fresh_matcher = xgr.GrammarMatcher(self.ctx)
        return Grammar(
            matcher=fresh_matcher,
            vocab_size=self.vocab_size,
            ctx=self.ctx,
        )
vllm/v1/structured_output/request.py
View file @
ca796e19
...
@@ -9,15 +9,17 @@ from concurrent.futures._base import TimeoutError
...
@@ -9,15 +9,17 @@ from concurrent.futures._base import TimeoutError
from
typing
import
Optional
,
Union
,
cast
from
typing
import
Optional
,
Union
,
cast
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.structured_output.grammar
import
(
Grammar
,
StructuredOutputKey
,
from
vllm.v1.structured_output.backend_types
import
(
StructuredOutputGrammar
,
StructuredOutputOptions
)
StructuredOutputKey
,
StructuredOutputOptions
)
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
StructuredOutputRequest
:
class
StructuredOutputRequest
:
sampling_params
:
SamplingParams
sampling_params
:
SamplingParams
_grammar
:
Optional
[
Union
[
Future
[
Grammar
],
Grammar
]]
=
None
_grammar
:
Optional
[
Union
[
Future
[
StructuredOutputGrammar
],
StructuredOutputGrammar
]]
=
None
def
_check_grammar_completion
(
self
)
->
bool
:
def
_check_grammar_completion
(
self
)
->
bool
:
# NOTE: We have to lazy import to gate circular imports
# NOTE: We have to lazy import to gate circular imports
...
@@ -37,12 +39,16 @@ class StructuredOutputRequest:
...
@@ -37,12 +39,16 @@ class StructuredOutputRequest:
return
self
.
_check_grammar_completion
()
return
self
.
_check_grammar_completion
()
@
property
@
property
def
grammar
(
self
)
->
Optional
[
Grammar
]:
def
grammar
(
self
)
->
Optional
[
StructuredOutput
Grammar
]:
completed
=
self
.
_check_grammar_completion
()
completed
=
self
.
_check_grammar_completion
()
return
cast
(
Optional
[
Grammar
],
self
.
_grammar
)
if
completed
else
None
return
cast
(
Optional
[
StructuredOutputGrammar
],
self
.
_grammar
)
if
completed
else
None
@
grammar
.
setter
@
grammar
.
setter
def
grammar
(
self
,
grammar
:
Union
[
Grammar
,
Future
[
Grammar
]])
->
None
:
def
grammar
(
self
,
grammar
:
Union
[
StructuredOutputGrammar
,
Future
[
StructuredOutputGrammar
]]
)
->
None
:
self
.
_grammar
=
grammar
self
.
_grammar
=
grammar
@
functools
.
cached_property
@
functools
.
cached_property
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
ca796e19
...
@@ -34,7 +34,8 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
...
@@ -34,7 +34,8 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
from
vllm.v1.outputs
import
(
EMPTY_MODEL_RUNNER_OUTPUT
,
LogprobsTensors
,
from
vllm.v1.outputs
import
(
EMPTY_MODEL_RUNNER_OUTPUT
,
LogprobsTensors
,
ModelRunnerOutput
)
ModelRunnerOutput
)
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.rejection_sampler
import
INVALID_TOKEN_ID
,
RejectionSampler
from
vllm.v1.sample.rejection_sampler
import
RejectionSampler
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.v1.spec_decode.utils
import
is_spec_decode_supported
from
vllm.v1.spec_decode.utils
import
is_spec_decode_supported
from
vllm.v1.utils
import
bind_kv_cache
from
vllm.v1.utils
import
bind_kv_cache
...
@@ -149,7 +150,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -149,7 +150,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
use_spec_decode
=
False
self
.
use_spec_decode
=
False
if
self
.
speculative_config
:
if
self
.
speculative_config
:
self
.
use_spec_decode
=
True
self
.
use_spec_decode
=
True
self
.
rejection_sampler
=
RejectionSampler
()
# TODO: find a better way to check if we are using ngram.
# TODO: find a better way to check if we are using ngram.
assert
self
.
speculative_config
.
ngram_prompt_lookup_min
,
\
assert
self
.
speculative_config
.
ngram_prompt_lookup_min
,
\
"Currently, only ngram spec decode is supported in V1."
"Currently, only ngram spec decode is supported in V1."
...
@@ -162,6 +162,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -162,6 +162,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
speculative_config
.
ngram_prompt_lookup_min
,
self
.
speculative_config
.
ngram_prompt_lookup_min
,
self
.
speculative_config
.
num_speculative_tokens
,
self
.
speculative_config
.
num_speculative_tokens
,
)
)
self
.
rejection_sampler
=
RejectionSampler
()
# Request states.
# Request states.
self
.
requests
:
dict
[
str
,
CachedRequestState
]
=
{}
self
.
requests
:
dict
[
str
,
CachedRequestState
]
=
{}
...
@@ -452,7 +453,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -452,7 +453,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def
_prepare_inputs
(
def
_prepare_inputs
(
self
,
self
,
scheduler_output
:
"SchedulerOutput"
,
scheduler_output
:
"SchedulerOutput"
,
)
->
tuple
[
FlashAttentionMetadata
,
torch
.
Tensor
]:
)
->
tuple
[
FlashAttentionMetadata
,
torch
.
Tensor
,
Optional
[
SpecDecodeMetadata
]]:
total_num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
total_num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
assert
total_num_scheduled_tokens
>
0
assert
total_num_scheduled_tokens
>
0
num_reqs
=
self
.
input_batch
.
num_reqs
num_reqs
=
self
.
input_batch
.
num_reqs
...
@@ -577,22 +579,33 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -577,22 +579,33 @@ class GPUModelRunner(LoRAModelRunnerMixin):
use_spec_decode
=
len
(
use_spec_decode
=
len
(
scheduler_output
.
scheduled_spec_decode_tokens
)
>
0
scheduler_output
.
scheduled_spec_decode_tokens
)
>
0
if
use_spec_decode
:
if
not
use_spec_decode
:
logits_indices
=
self
.
_calc_spec_decode_metadata
(
scheduler_output
,
cu_num_tokens
)
else
:
# NOTE(woosuk): Due to chunked prefills, the batch may contain
# NOTE(woosuk): Due to chunked prefills, the batch may contain
# partial requests. While we should not sample any token
# partial requests. While we should not sample any token
# from these partial requests, we do so for simplicity.
# from these partial requests, we do so for simplicity.
# We will ignore the sampled tokens from the partial requests.
# We will ignore the sampled tokens from the partial requests.
# TODO: Support prompt logprobs.
# TODO: Support prompt logprobs.
logits_indices
=
attn_metadata
.
query_start_loc
[
1
:]
-
1
logits_indices
=
attn_metadata
.
query_start_loc
[
1
:]
-
1
spec_decode_metadata
=
None
else
:
# Get the number of draft tokens for each request.
# Iterate over the dictionary rather than all requests since not all
# requests have draft tokens.
num_draft_tokens
=
np
.
zeros
(
num_reqs
,
dtype
=
np
.
int32
)
for
req_id
,
draft_token_ids
in
(
scheduler_output
.
scheduled_spec_decode_tokens
.
items
()):
req_idx
=
self
.
input_batch
.
req_id_to_index
[
req_id
]
num_draft_tokens
[
req_idx
]
=
len
(
draft_token_ids
)
spec_decode_metadata
=
self
.
_calc_spec_decode_metadata
(
num_draft_tokens
,
cu_num_tokens
)
logits_indices
=
spec_decode_metadata
.
logits_indices
# Hot-Swap lora model
# Hot-Swap lora model
if
self
.
lora_config
:
if
self
.
lora_config
:
self
.
set_active_loras
(
self
.
input_batch
,
num_scheduled_tokens
)
self
.
set_active_loras
(
self
.
input_batch
,
num_scheduled_tokens
)
return
attn_metadata
,
logits_indices
return
attn_metadata
,
logits_indices
,
spec_decode_metadata
def
_compute_cascade_attn_prefix_len
(
def
_compute_cascade_attn_prefix_len
(
self
,
self
,
...
@@ -732,50 +745,79 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -732,50 +745,79 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def
_calc_spec_decode_metadata
(
def
_calc_spec_decode_metadata
(
self
,
self
,
scheduler_output
:
"SchedulerOutput"
,
num_draft_tokens
:
np
.
ndarray
,
cu_num_tokens
:
np
.
ndarray
,
cu_num_scheduled_tokens
:
np
.
ndarray
,
)
->
torch
.
Tensor
:
)
->
SpecDecodeMetadata
:
# Get the number of spec decode tokens for each request.
# Inputs:
num_reqs
=
self
.
input_batch
.
num_reqs
# cu_num_scheduled_tokens: [ 4, 104, 107, 207, 209]
num_spec_decode_tokens
=
np
.
empty
(
num_reqs
,
dtype
=
np
.
int32
)
# num_draft_tokens: [ 3, 0, 2, 0, 1]
for
i
,
req_id
in
enumerate
(
self
.
input_batch
.
req_ids
):
# Outputs:
num_spec_decode_tokens
[
i
]
=
len
(
# cu_num_draft_tokens: [ 3, 3, 5, 5, 6]
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
req_id
,
()))
# logits_indices: [ 0, 1, 2, 3, 103, 104, 105, 106,
# 206, 207, 208]
# Get spec decode logits indices.
# target_logits_indices: [ 0, 1, 2, 5, 6, 9]
# E.g., num_scheduled_tokens: [4, 100, 3, 100, 2]
# bonus_logits_indices: [ 3, 4, 7, 8, 10]
# cu_num_tokens: [4, 104, 107, 207, 209]
# num_spec_tokens_list: [3, 0, 2, 0, 1]
# Compute the logits indices.
# num_sampled_tokens: [4, 1, 3, 1, 2]
# [4, 1, 3, 1, 2]
# spec_decode_logits_indices:
num_sampled_tokens
=
num_draft_tokens
+
1
# [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
# Step 1. [4, 5, 8, 9, 11]
num_sampled_tokens
=
num_spec_decode_tokens
+
1
cu_num_sampled_tokens
=
np
.
cumsum
(
num_sampled_tokens
,
dtype
=
np
.
int32
)
# logits_start_loc: [0, 103, 104, 206, 207]
total_num_sampled_tokens
=
cu_num_sampled_tokens
[
-
1
]
logits_start_loc
=
cu_num_tokens
-
num_sampled_tokens
# Step 2. [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
# [0, 103, 104, 206, 207] ->
cumsums_offsets
=
np
.
repeat
(
cu_num_sampled_tokens
-
num_sampled_tokens
,
# [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
num_sampled_tokens
)
logits_start_loc
=
np
.
repeat
(
logits_start_loc
,
num_sampled_tokens
)
# Step 3. [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
# The following three lines:
arange
=
self
.
arange_np
[:
total_num_sampled_tokens
]
-
cumsums_offsets
# [4, 1, 3, 1, 2] -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
# Step 4. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
# Step 1. [4, 1, 3, 1, 2] -> [4, 5, 8, 9, 11]
logits_indices
=
np
.
repeat
(
cu_num_sampled_tokens
=
np
.
cumsum
(
num_sampled_tokens
)
cu_num_scheduled_tokens
-
num_sampled_tokens
,
num_sampled_tokens
)
# Step 2. [4, 5, 8, 9, 11] -> [0, 4, 5, 8, 9]
# Step 5. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
# -> [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
logits_indices
+=
arange
cumsums_sampled_offsets
=
np
.
repeat
(
cu_num_sampled_tokens
-
num_sampled_tokens
,
num_sampled_tokens
)
# Compute the bonus logits indices.
# Step 3. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
bonus_logits_indices
=
cu_num_sampled_tokens
-
1
# - [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
# -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
# Compute the draft logits indices.
total_num_sampled_tokens
=
num_sampled_tokens
.
sum
()
# [3, 3, 5, 5, 6]
sampled_arange
=
(
self
.
arange_np
[:
total_num_sampled_tokens
]
-
cu_num_draft_tokens
=
np
.
cumsum
(
num_draft_tokens
,
dtype
=
np
.
int32
)
cumsums_sampled_offsets
)
total_num_draft_tokens
=
cu_num_draft_tokens
[
-
1
]
# [0, 0, 0, 3, 3, 5]
# [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] ->
cumsums_offsets
=
np
.
repeat
(
cu_num_draft_tokens
-
num_draft_tokens
,
# [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
num_draft_tokens
)
spec_decode_logits_indices
=
logits_start_loc
+
sampled_arange
# [0, 1, 2, 0, 1, 0]
return
torch
.
from_numpy
(
spec_decode_logits_indices
).
to
(
arange
=
self
.
arange_np
[:
total_num_draft_tokens
]
-
cumsums_offsets
# [0, 0, 0, 5, 5, 9]
target_logits_indices
=
np
.
repeat
(
cu_num_sampled_tokens
-
num_sampled_tokens
,
num_draft_tokens
)
# [0, 1, 2, 5, 6, 9]
target_logits_indices
+=
arange
# TODO: Optimize the CPU -> GPU copy.
cu_num_draft_tokens
=
torch
.
from_numpy
(
cu_num_draft_tokens
).
to
(
self
.
device
,
non_blocking
=
True
)
logits_indices
=
torch
.
from_numpy
(
logits_indices
).
to
(
self
.
device
,
non_blocking
=
True
)
target_logits_indices
=
torch
.
from_numpy
(
target_logits_indices
).
to
(
self
.
device
,
non_blocking
=
True
)
bonus_logits_indices
=
torch
.
from_numpy
(
bonus_logits_indices
).
to
(
self
.
device
,
non_blocking
=
True
)
self
.
device
,
non_blocking
=
True
)
# Compute the draft token ids.
# draft_token_indices: [ 1, 2, 3, 105, 106, 208]
draft_token_ids
=
self
.
input_ids
[
logits_indices
]
draft_token_ids
=
draft_token_ids
[
target_logits_indices
+
1
]
metadata
=
SpecDecodeMetadata
(
draft_token_ids
=
draft_token_ids
,
num_draft_tokens
=
num_draft_tokens
.
tolist
(),
cu_num_draft_tokens
=
cu_num_draft_tokens
,
target_logits_indices
=
target_logits_indices
,
bonus_logits_indices
=
bonus_logits_indices
,
logits_indices
=
logits_indices
,
)
return
metadata
def
_execute_encoder
(
self
,
scheduler_output
:
"SchedulerOutput"
):
def
_execute_encoder
(
self
,
scheduler_output
:
"SchedulerOutput"
):
scheduled_encoder_inputs
=
scheduler_output
.
scheduled_encoder_inputs
scheduled_encoder_inputs
=
scheduler_output
.
scheduled_encoder_inputs
if
not
scheduled_encoder_inputs
:
if
not
scheduled_encoder_inputs
:
...
@@ -931,7 +973,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -931,7 +973,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
encoder_outputs
=
[]
encoder_outputs
=
[]
# Prepare the decoder inputs.
# Prepare the decoder inputs.
attn_metadata
,
logits_indices
=
self
.
_prepare_inputs
(
scheduler_output
)
attn_metadata
,
logits_indices
,
spec_decode_metadata
=
(
self
.
_prepare_inputs
(
scheduler_output
))
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
if
(
self
.
use_cuda_graph
if
(
self
.
use_cuda_graph
and
num_scheduled_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
and
num_scheduled_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
...
@@ -1006,31 +1049,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1006,31 +1049,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Sample the next token and get logprobs if needed.
# Sample the next token and get logprobs if needed.
sampling_metadata
=
self
.
input_batch
.
sampling_metadata
sampling_metadata
=
self
.
input_batch
.
sampling_metadata
if
not
self
.
use_spec_decod
e
:
if
spec_decode_metadata
is
Non
e
:
sampler_output
=
self
.
model
.
sample
(
sampler_output
=
self
.
model
.
sample
(
logits
=
logits
,
logits
=
logits
,
sampling_metadata
=
sampling_metadata
,
sampling_metadata
=
sampling_metadata
,
)
)
else
:
else
:
draft_token_ids
=
[
# TODO(woosuk): Optimize the memory usage.
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
req_id
,
[])
bonus_logits
=
logits
[
spec_decode_metadata
.
bonus_logits_indices
]
for
req_id
in
self
.
input_batch
.
req_ids
]
sample_lens
=
[
len
(
tokens
)
+
1
for
tokens
in
draft_token_ids
]
recover_logits_idx
=
np
.
cumsum
(
sample_lens
)
-
1
target_probs
=
self
.
rejection_sampler
.
compute_probs
(
logits
,
sampling_metadata
,
sample_lens
)
sampler_output
=
self
.
model
.
sample
(
sampler_output
=
self
.
model
.
sample
(
logits
=
logits
[
recover_logits_idx
,
:]
,
logits
=
bonus_
logits
,
sampling_metadata
=
sampling_metadata
,
sampling_metadata
=
sampling_metadata
,
)
)
bonus_token_ids
=
sampler_output
.
sampled_token_ids
bonus_token_ids
=
sampler_output
.
sampled_token_ids
# TODO(woosuk): Optimize the memory usage.
target_logits
=
logits
[
spec_decode_metadata
.
target_logits_indices
]
output_token_ids
=
self
.
rejection_sampler
(
output_token_ids
=
self
.
rejection_sampler
(
draft_token_ids
,
spec_decode_metadata
,
None
,
# draft_probs
None
,
# draft_probs
target_logits
,
bonus_token_ids
,
bonus_token_ids
,
target_probs
,
sampling_metadata
,
sampling_metadata
)
)
sampler_output
.
sampled_token_ids
=
output_token_ids
sampler_output
.
sampled_token_ids
=
output_token_ids
# TODO(woosuk): The following loop can be slow since it iterates over
# TODO(woosuk): The following loop can be slow since it iterates over
...
@@ -1066,13 +1107,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1066,13 +1107,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
valid_sampled_token_ids
=
sampled_token_ids
.
tolist
()
valid_sampled_token_ids
=
sampled_token_ids
.
tolist
()
else
:
else
:
# Includes spec decode tokens.
# Includes spec decode tokens.
valid_mask
=
sampled_token_ids
!=
INVALID_TOKEN_ID
valid_sampled_token_ids
=
self
.
rejection_sampler
.
parse_output
(
gen_lens
=
valid_mask
.
sum
(
dim
=
1
).
tolist
()
sampled_token_ids
,
self
.
input_batch
.
vocab_size
)
# TODO(woosuk): Optimize this.
valid_sampled_token_ids
=
[
seq
.
tolist
()
for
seq
in
sampled_token_ids
[
valid_mask
].
split
(
gen_lens
)
]
if
not
self
.
use_spec_decode
:
if
not
self
.
use_spec_decode
:
spec_token_ids
=
None
spec_token_ids
=
None
...
@@ -1316,6 +1352,33 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1316,6 +1352,33 @@ class GPUModelRunner(LoRAModelRunnerMixin):
"initializing the engine."
)
from
e
"initializing the engine."
)
from
e
else
:
else
:
raise
e
raise
e
if
self
.
use_spec_decode
:
draft_token_ids
=
[[
0
]
for
_
in
range
(
num_reqs
)]
dummy_spec_decode_metadata
=
SpecDecodeMetadata
.
make_dummy
(
draft_token_ids
,
self
.
device
)
num_tokens
=
sum
(
len
(
ids
)
for
ids
in
draft_token_ids
)
# draft_probs = torch.randn(
# num_tokens, logits.shape[-1], device=self.device,
# dtype=logits.dtype)
draft_probs
=
None
target_logits
=
torch
.
randn
(
num_tokens
,
logits
.
shape
[
-
1
],
device
=
self
.
device
,
dtype
=
logits
.
dtype
)
# NOTE(woosuk): Here, we should use int32 because the sampler uses
# int32 for bonus_token_ids. If the dtype mismatches, re-compilation
# will occur at runtime.
bonus_token_ids
=
torch
.
zeros
(
num_reqs
,
device
=
self
.
device
,
dtype
=
torch
.
int32
)
self
.
rejection_sampler
(
dummy_spec_decode_metadata
,
draft_probs
,
target_logits
,
bonus_token_ids
,
dummy_metadata
,
)
return
sampler_output
return
sampler_output
def
profile_run
(
self
)
->
None
:
def
profile_run
(
self
)
->
None
:
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
ca796e19
...
@@ -410,6 +410,9 @@ class TPUModelRunner:
...
@@ -410,6 +410,9 @@ class TPUModelRunner:
# Do the padding and copy the tensors to the TPU.
# Do the padding and copy the tensors to the TPU.
padded_total_num_scheduled_tokens
=
_get_padded_token_len
(
padded_total_num_scheduled_tokens
=
_get_padded_token_len
(
total_num_scheduled_tokens
)
total_num_scheduled_tokens
)
# Zero out to avoid spurious values from prev iteration (last cp chunk)
self
.
input_ids_cpu
[
total_num_scheduled_tokens
:
padded_total_num_scheduled_tokens
]
=
0
self
.
input_ids
=
self
.
input_ids_cpu
[:
self
.
input_ids
=
self
.
input_ids_cpu
[:
padded_total_num_scheduled_tokens
].
to
(
padded_total_num_scheduled_tokens
].
to
(
self
.
device
)
self
.
device
)
...
...
Prev
1
…
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment