Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cd4a72a2
Unverified
Commit
cd4a72a2
authored
Feb 17, 2025
by
Woosuk Kwon
Committed by
GitHub
Feb 17, 2025
Browse files
[V1][Spec decode] Move drafter to model runner (#13363)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
6ac485a9
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
84 additions
and
57 deletions
+84
-57
tests/v1/core/test_scheduler.py
tests/v1/core/test_scheduler.py
+7
-0
vllm/v1/core/scheduler.py
vllm/v1/core/scheduler.py
+4
-7
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+0
-30
vllm/v1/outputs.py
vllm/v1/outputs.py
+3
-0
vllm/v1/request.py
vllm/v1/request.py
+0
-12
vllm/v1/spec_decode/ngram_proposer.py
vllm/v1/spec_decode/ngram_proposer.py
+15
-8
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+7
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+47
-0
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+1
-0
No files found.
tests/v1/core/test_scheduler.py
View file @
cd4a72a2
...
@@ -203,6 +203,7 @@ def test_schedule_partial_requests():
...
@@ -203,6 +203,7 @@ def test_schedule_partial_requests():
req_ids
=
[
request
.
request_id
for
request
in
requests
],
req_ids
=
[
request
.
request_id
for
request
in
requests
],
req_id_to_index
=
req_to_index
,
req_id_to_index
=
req_to_index
,
sampled_token_ids
=
[[
0
]
for
_
in
range
(
len
(
requests
))],
sampled_token_ids
=
[[
0
]
for
_
in
range
(
len
(
requests
))],
spec_token_ids
=
None
,
logprobs
=
None
,
logprobs
=
None
,
prompt_logprobs_dict
=
{},
prompt_logprobs_dict
=
{},
)
)
...
@@ -259,6 +260,7 @@ def test_stop_via_update_from_output():
...
@@ -259,6 +260,7 @@ def test_stop_via_update_from_output():
sampled_token_ids
=
[[
EOS_TOKEN_ID
],
sampled_token_ids
=
[[
EOS_TOKEN_ID
],
[
10
,
[
10
,
11
]],
# First request hits EOS, second continues
11
]],
# First request hits EOS, second continues
spec_token_ids
=
None
,
logprobs
=
None
,
logprobs
=
None
,
prompt_logprobs_dict
=
{})
prompt_logprobs_dict
=
{})
...
@@ -307,6 +309,7 @@ def test_stop_via_update_from_output():
...
@@ -307,6 +309,7 @@ def test_stop_via_update_from_output():
},
},
sampled_token_ids
=
[[
10
,
42
,
12
],
sampled_token_ids
=
[[
10
,
42
,
12
],
[
13
,
14
]],
# First request hits stop token
[
13
,
14
]],
# First request hits stop token
spec_token_ids
=
None
,
logprobs
=
None
,
logprobs
=
None
,
prompt_logprobs_dict
=
{})
prompt_logprobs_dict
=
{})
...
@@ -354,6 +357,7 @@ def test_stop_via_update_from_output():
...
@@ -354,6 +357,7 @@ def test_stop_via_update_from_output():
},
},
sampled_token_ids
=
[[
10
,
11
,
12
],
sampled_token_ids
=
[[
10
,
11
,
12
],
[
13
]],
# First request exceeds max_tokens
[
13
]],
# First request exceeds max_tokens
spec_token_ids
=
None
,
logprobs
=
None
,
logprobs
=
None
,
prompt_logprobs_dict
=
{})
prompt_logprobs_dict
=
{})
...
@@ -394,6 +398,7 @@ def test_stop_via_update_from_output():
...
@@ -394,6 +398,7 @@ def test_stop_via_update_from_output():
req_ids
=
[
requests
[
0
].
request_id
],
req_ids
=
[
requests
[
0
].
request_id
],
req_id_to_index
=
{
requests
[
0
].
request_id
:
0
},
req_id_to_index
=
{
requests
[
0
].
request_id
:
0
},
sampled_token_ids
=
[[
EOS_TOKEN_ID
,
10
,
11
]],
sampled_token_ids
=
[[
EOS_TOKEN_ID
,
10
,
11
]],
spec_token_ids
=
None
,
logprobs
=
None
,
logprobs
=
None
,
prompt_logprobs_dict
=
{})
prompt_logprobs_dict
=
{})
...
@@ -434,6 +439,7 @@ def test_schedule_concurrent_batches():
...
@@ -434,6 +439,7 @@ def test_schedule_concurrent_batches():
req_ids
=
[
requests
[
0
].
request_id
],
req_ids
=
[
requests
[
0
].
request_id
],
req_id_to_index
=
{
requests
[
0
].
request_id
:
0
},
req_id_to_index
=
{
requests
[
0
].
request_id
:
0
},
sampled_token_ids
=
[[
0
]],
sampled_token_ids
=
[[
0
]],
spec_token_ids
=
None
,
logprobs
=
None
,
logprobs
=
None
,
prompt_logprobs_dict
=
{},
prompt_logprobs_dict
=
{},
)
)
...
@@ -450,6 +456,7 @@ def test_schedule_concurrent_batches():
...
@@ -450,6 +456,7 @@ def test_schedule_concurrent_batches():
req_ids
=
[
requests
[
1
].
request_id
],
req_ids
=
[
requests
[
1
].
request_id
],
req_id_to_index
=
{
requests
[
1
].
request_id
:
0
},
req_id_to_index
=
{
requests
[
1
].
request_id
:
0
},
sampled_token_ids
=
[[
0
]],
sampled_token_ids
=
[[
0
]],
spec_token_ids
=
None
,
logprobs
=
None
,
logprobs
=
None
,
prompt_logprobs_dict
=
{},
prompt_logprobs_dict
=
{},
)
)
...
...
vllm/v1/core/scheduler.py
View file @
cd4a72a2
...
@@ -474,6 +474,7 @@ class Scheduler:
...
@@ -474,6 +474,7 @@ class Scheduler:
model_runner_output
:
"ModelRunnerOutput"
,
model_runner_output
:
"ModelRunnerOutput"
,
)
->
EngineCoreOutputs
:
)
->
EngineCoreOutputs
:
sampled_token_ids
=
model_runner_output
.
sampled_token_ids
sampled_token_ids
=
model_runner_output
.
sampled_token_ids
spec_token_ids
=
model_runner_output
.
spec_token_ids
logprobs
=
model_runner_output
.
logprobs
logprobs
=
model_runner_output
.
logprobs
prompt_logprobs_dict
=
model_runner_output
.
prompt_logprobs_dict
prompt_logprobs_dict
=
model_runner_output
.
prompt_logprobs_dict
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
...
@@ -530,13 +531,9 @@ class Scheduler:
...
@@ -530,13 +531,9 @@ class Scheduler:
self
.
encoder_cache_manager
.
free_encoder_input
(
self
.
encoder_cache_manager
.
free_encoder_input
(
request
,
input_id
)
request
,
input_id
)
if
request
.
num_computed_tokens
>=
request
.
num_tokens
:
# Add newly generated spec token ids to the request.
# Clear the spec tokens as the request has generated
if
spec_token_ids
is
not
None
:
# a new token. Here, We assume all spec tokens are verified
request
.
spec_token_ids
=
spec_token_ids
[
req_index
]
# if we perform speculative decoding for this request.
# Therefore, we can clear all spec tokens after
# the generation step.
request
.
clear_spec_tokens
()
# Get prompt logprobs for this request.
# Get prompt logprobs for this request.
prompt_logprobs_tensors
=
prompt_logprobs_dict
.
get
(
req_id
)
prompt_logprobs_tensors
=
prompt_logprobs_dict
.
get
(
req_id
)
...
...
vllm/v1/engine/core.py
View file @
cd4a72a2
...
@@ -27,7 +27,6 @@ from vllm.v1.executor.abstract import Executor
...
@@ -27,7 +27,6 @@ from vllm.v1.executor.abstract import Executor
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.version
import
__version__
as
VLLM_VERSION
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -86,15 +85,6 @@ class EngineCore:
...
@@ -86,15 +85,6 @@ class EngineCore:
self
.
batch_queue_size
)
self
.
batch_queue_size
)
self
.
batch_queue
=
queue
.
Queue
(
self
.
batch_queue_size
)
self
.
batch_queue
=
queue
.
Queue
(
self
.
batch_queue_size
)
# Setup speculative decode.
# TODO: find a better way to check if we are using ngram.
self
.
use_spec_decode
=
False
if
self
.
scheduler
.
speculative_config
:
assert
self
.
scheduler
.
speculative_config
.
ngram_prompt_lookup_min
\
,
"Only ngram spec decode is supported in V1."
self
.
proposer
=
NgramProposer
()
self
.
use_spec_decode
=
True
def
_initialize_kv_caches
(
self
,
def
_initialize_kv_caches
(
self
,
vllm_config
:
VllmConfig
)
->
Tuple
[
int
,
int
]:
vllm_config
:
VllmConfig
)
->
Tuple
[
int
,
int
]:
start
=
time
.
time
()
start
=
time
.
time
()
...
@@ -158,9 +148,6 @@ class EngineCore:
...
@@ -158,9 +148,6 @@ class EngineCore:
return
EngineCoreOutputs
(
return
EngineCoreOutputs
(
outputs
=
[],
scheduler_stats
=
self
.
scheduler
.
make_stats
())
outputs
=
[],
scheduler_stats
=
self
.
scheduler
.
make_stats
())
if
self
.
use_spec_decode
:
self
.
propose_tokens
()
scheduler_output
=
self
.
scheduler
.
schedule
()
scheduler_output
=
self
.
scheduler
.
schedule
()
output
=
self
.
model_executor
.
execute_model
(
scheduler_output
)
output
=
self
.
model_executor
.
execute_model
(
scheduler_output
)
engine_core_outputs
=
self
.
scheduler
.
update_from_output
(
engine_core_outputs
=
self
.
scheduler
.
update_from_output
(
...
@@ -221,23 +208,6 @@ class EngineCore:
...
@@ -221,23 +208,6 @@ class EngineCore:
def
profile
(
self
,
is_start
:
bool
=
True
):
def
profile
(
self
,
is_start
:
bool
=
True
):
self
.
model_executor
.
profile
(
is_start
)
self
.
model_executor
.
profile
(
is_start
)
def
propose_tokens
(
self
):
assert
self
.
scheduler
.
speculative_config
is
not
None
for
req
in
self
.
scheduler
.
running
:
# Ignore requests that are doing chunked prefill.
if
req
.
num_computed_tokens
<
req
.
num_tokens
-
1
:
continue
# Ignore requests that already have spec tokens.
if
req
.
spec_token_ids
:
continue
spec_tokens
=
self
.
proposer
.
propose
(
req
.
all_token_ids
,
self
.
scheduler
.
speculative_config
.
ngram_prompt_lookup_min
,
self
.
scheduler
.
speculative_config
.
num_speculative_tokens
,
)
if
spec_tokens
:
req
.
append_spec_token_ids
(
spec_tokens
)
def
reset_prefix_cache
(
self
):
def
reset_prefix_cache
(
self
):
self
.
scheduler
.
reset_prefix_cache
()
self
.
scheduler
.
reset_prefix_cache
()
...
...
vllm/v1/outputs.py
View file @
cd4a72a2
...
@@ -67,6 +67,9 @@ class ModelRunnerOutput:
...
@@ -67,6 +67,9 @@ class ModelRunnerOutput:
# each request due to speculative/jump decoding.
# each request due to speculative/jump decoding.
sampled_token_ids
:
List
[
List
[
int
]]
sampled_token_ids
:
List
[
List
[
int
]]
# num_reqs x num_spec_tokens
spec_token_ids
:
Optional
[
List
[
List
[
int
]]]
# [num_reqs, max_num_logprobs + 1]
# [num_reqs, max_num_logprobs + 1]
# [num_reqs, max_num_logprobs + 1]
# [num_reqs, max_num_logprobs + 1]
# [num_reqs]
# [num_reqs]
...
...
vllm/v1/request.py
View file @
cd4a72a2
...
@@ -104,18 +104,6 @@ class Request:
...
@@ -104,18 +104,6 @@ class Request:
self
.
_output_token_ids
.
extend
(
token_ids
)
self
.
_output_token_ids
.
extend
(
token_ids
)
self
.
_all_token_ids
.
extend
(
token_ids
)
self
.
_all_token_ids
.
extend
(
token_ids
)
def
append_spec_token_ids
(
self
,
token_ids
:
Union
[
int
,
List
[
int
]],
)
->
None
:
if
isinstance
(
token_ids
,
int
):
self
.
spec_token_ids
.
append
(
token_ids
)
else
:
self
.
spec_token_ids
.
extend
(
token_ids
)
def
clear_spec_tokens
(
self
)
->
None
:
self
.
spec_token_ids
.
clear
()
@
property
@
property
def
num_tokens
(
self
)
->
int
:
def
num_tokens
(
self
)
->
int
:
return
len
(
self
.
_all_token_ids
)
return
len
(
self
.
_all_token_ids
)
...
...
vllm/v1/spec_decode/ngram_proposer.py
View file @
cd4a72a2
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
from
vllm.v1.utils
import
ConstantList
import
numpy
as
np
class
NgramProposer
:
class
NgramProposer
:
...
@@ -9,8 +9,12 @@ class NgramProposer:
...
@@ -9,8 +9,12 @@ class NgramProposer:
def
__init__
(
self
):
def
__init__
(
self
):
pass
pass
def
propose
(
self
,
context_token_ids
:
ConstantList
[
int
],
n
:
int
,
def
propose
(
k
:
int
)
->
Optional
[
List
[
int
]]:
self
,
context_token_ids
:
np
.
ndarray
,
n
:
int
,
k
:
int
,
)
->
Optional
[
np
.
ndarray
]:
"""Proposes the next sequence of tokens based on n-gram pattern
"""Proposes the next sequence of tokens based on n-gram pattern
matching in the context. The function finds matches of the last n
matching in the context. The function finds matches of the last n
tokens in the previous context, and returns k tokens that followed
tokens in the previous context, and returns k tokens that followed
...
@@ -25,8 +29,8 @@ class NgramProposer:
...
@@ -25,8 +29,8 @@ class NgramProposer:
the maximum amount of tokens until the end.
the maximum amount of tokens until the end.
Returns:
Returns:
List[int]
: The sequence of tokens that followed
np.ndarray
: The sequence of tokens that followed
the matched n-gram in the context.
the matched n-gram in the context.
None: If no matching n-gram pattern is found.
None: If no matching n-gram pattern is found.
Example:
Example:
...
@@ -66,9 +70,12 @@ class NgramProposer:
...
@@ -66,9 +70,12 @@ class NgramProposer:
return
lps
return
lps
@
staticmethod
@
staticmethod
def
_find_subarray_kmp
(
context_token_ids
:
ConstantList
[
int
],
n
:
int
,
def
_find_subarray_kmp
(
k
:
int
)
->
Optional
[
List
[
int
]]:
context_token_ids
:
np
.
ndarray
,
context_len
=
len
(
context_token_ids
)
n
:
int
,
k
:
int
,
)
->
Optional
[
np
.
ndarray
]:
context_len
=
context_token_ids
.
shape
[
0
]
assert
n
>
0
assert
n
>
0
pattern
=
context_token_ids
[
-
n
:]
pattern
=
context_token_ids
[
-
n
:]
...
...
vllm/v1/worker/gpu_input_batch.py
View file @
cd4a72a2
...
@@ -78,6 +78,7 @@ class InputBatch:
...
@@ -78,6 +78,7 @@ class InputBatch:
)
)
self
.
token_ids_cpu
=
self
.
token_ids_cpu_tensor
.
numpy
()
self
.
token_ids_cpu
=
self
.
token_ids_cpu_tensor
.
numpy
()
self
.
num_tokens
=
np
.
zeros
(
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_tokens
=
np
.
zeros
(
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_tokens_no_spec
=
np
.
zeros
(
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_prompt_tokens
=
np
.
zeros
(
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_prompt_tokens
=
np
.
zeros
(
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_computed_tokens_cpu
=
np
.
empty
(
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_computed_tokens_cpu
=
np
.
empty
(
max_num_reqs
,
dtype
=
np
.
int32
)
...
@@ -217,7 +218,11 @@ class InputBatch:
...
@@ -217,7 +218,11 @@ class InputBatch:
end_idx
=
start_idx
+
len
(
request
.
output_token_ids
)
end_idx
=
start_idx
+
len
(
request
.
output_token_ids
)
self
.
token_ids_cpu
[
req_index
,
self
.
token_ids_cpu
[
req_index
,
start_idx
:
end_idx
]
=
request
.
output_token_ids
start_idx
:
end_idx
]
=
request
.
output_token_ids
# Number of token ids in token_ids_cpu.
# NOTE(woosuk): This may include spec decode tokens.
self
.
num_tokens
[
req_index
]
=
request
.
num_tokens
self
.
num_tokens
[
req_index
]
=
request
.
num_tokens
# Number of tokens without spec decode tokens.
self
.
num_tokens_no_spec
[
req_index
]
=
request
.
num_tokens
self
.
num_computed_tokens_cpu
[
req_index
]
=
request
.
num_computed_tokens
self
.
num_computed_tokens_cpu
[
req_index
]
=
request
.
num_computed_tokens
self
.
block_table
.
add_row
(
req_index
,
request
.
block_ids
)
self
.
block_table
.
add_row
(
req_index
,
request
.
block_ids
)
...
@@ -356,6 +361,8 @@ class InputBatch:
...
@@ -356,6 +361,8 @@ class InputBatch:
self
.
token_ids_cpu
[
empty_index
,
:
num_tokens
]
=
self
.
token_ids_cpu
[
self
.
token_ids_cpu
[
empty_index
,
:
num_tokens
]
=
self
.
token_ids_cpu
[
last_req_index
,
:
num_tokens
]
last_req_index
,
:
num_tokens
]
self
.
num_tokens
[
empty_index
]
=
num_tokens
self
.
num_tokens
[
empty_index
]
=
num_tokens
self
.
num_tokens_no_spec
[
empty_index
]
=
self
.
num_tokens_no_spec
[
last_req_index
]
self
.
num_prompt_tokens
[
empty_index
]
=
self
.
num_prompt_tokens
[
self
.
num_prompt_tokens
[
empty_index
]
=
self
.
num_prompt_tokens
[
last_req_index
]
last_req_index
]
self
.
num_computed_tokens_cpu
[
self
.
num_computed_tokens_cpu
[
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
cd4a72a2
...
@@ -33,6 +33,7 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
...
@@ -33,6 +33,7 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
from
vllm.v1.outputs
import
LogprobsTensors
,
ModelRunnerOutput
from
vllm.v1.outputs
import
LogprobsTensors
,
ModelRunnerOutput
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.rejection_sampler
import
INVALID_TOKEN_ID
from
vllm.v1.sample.rejection_sampler
import
INVALID_TOKEN_ID
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.v1.utils
import
bind_kv_cache
from
vllm.v1.utils
import
bind_kv_cache
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
...
@@ -117,6 +118,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -117,6 +118,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# req_id -> (input_id -> encoder_output)
# req_id -> (input_id -> encoder_output)
self
.
encoder_cache
:
Dict
[
str
,
Dict
[
int
,
torch
.
Tensor
]]
=
{}
self
.
encoder_cache
:
Dict
[
str
,
Dict
[
int
,
torch
.
Tensor
]]
=
{}
# Set up speculative decoding.
self
.
use_spec_decode
=
False
if
self
.
speculative_config
:
# TODO: find a better way to check if we are using ngram.
assert
self
.
speculative_config
.
ngram_prompt_lookup_min
,
\
"Currently, only ngram spec decode is supported in V1."
self
.
drafter
=
NgramProposer
()
self
.
use_spec_decode
=
True
# Request states.
# Request states.
self
.
requests
:
Dict
[
str
,
CachedRequestState
]
=
{}
self
.
requests
:
Dict
[
str
,
CachedRequestState
]
=
{}
# Persistent batch.
# Persistent batch.
...
@@ -367,6 +377,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -367,6 +377,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
input_batch
.
token_ids_cpu
[
self
.
input_batch
.
token_ids_cpu
[
req_index
,
req_index
,
start_token_index
:
end_token_index
]
=
req_data
.
new_token_ids
start_token_index
:
end_token_index
]
=
req_data
.
new_token_ids
self
.
input_batch
.
num_tokens_no_spec
[
req_index
]
=
end_token_index
# Add spec_token_ids to token_ids_cpu.
# Add spec_token_ids to token_ids_cpu.
spec_token_ids
=
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
spec_token_ids
=
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
req_id
,
[])
req_id
,
[])
...
@@ -1009,15 +1020,51 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1009,15 +1020,51 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for
seq
in
sampled_token_ids
[
valid_mask
].
split
(
gen_lens
)
for
seq
in
sampled_token_ids
[
valid_mask
].
split
(
gen_lens
)
]
]
if
not
self
.
use_spec_decode
:
spec_token_ids
=
None
else
:
spec_token_ids
=
self
.
generate_draft_token_ids
(
valid_sampled_token_ids
)
model_runner_output
=
ModelRunnerOutput
(
model_runner_output
=
ModelRunnerOutput
(
req_ids
=
req_ids
,
req_ids
=
req_ids
,
req_id_to_index
=
self
.
input_batch
.
req_id_to_index
,
req_id_to_index
=
self
.
input_batch
.
req_id_to_index
,
sampled_token_ids
=
valid_sampled_token_ids
,
sampled_token_ids
=
valid_sampled_token_ids
,
spec_token_ids
=
spec_token_ids
,
logprobs
=
logprobs_lists
,
logprobs
=
logprobs_lists
,
prompt_logprobs_dict
=
prompt_logprobs_dict
,
prompt_logprobs_dict
=
prompt_logprobs_dict
,
)
)
return
model_runner_output
return
model_runner_output
def
generate_draft_token_ids
(
self
,
sampled_token_ids
:
List
[
List
[
int
]],
)
->
List
[
List
[
int
]]:
# TODO(woosuk): Optimize.
num_reqs
=
len
(
sampled_token_ids
)
draft_token_ids
:
List
[
List
[
int
]]
=
[]
for
i
in
range
(
num_reqs
):
if
len
(
sampled_token_ids
[
i
])
==
0
:
# Skip speculative decoding.
draft_token_ids
.
append
([])
continue
# Add sampled_token_ids to token_ids_cpu.
start_idx
=
self
.
input_batch
.
num_tokens_no_spec
[
i
]
end_idx
=
start_idx
+
len
(
sampled_token_ids
[
i
])
self
.
input_batch
.
token_ids_cpu
[
i
,
start_idx
:
end_idx
]
=
sampled_token_ids
[
i
]
drafter_output
=
self
.
drafter
.
propose
(
self
.
input_batch
.
token_ids_cpu
[
i
,
:
end_idx
],
self
.
speculative_config
.
ngram_prompt_lookup_min
,
self
.
speculative_config
.
num_speculative_tokens
,
)
if
drafter_output
is
None
or
len
(
drafter_output
)
==
0
:
draft_token_ids
.
append
([])
else
:
draft_token_ids
.
append
(
drafter_output
.
tolist
())
return
draft_token_ids
def
load_model
(
self
)
->
None
:
def
load_model
(
self
)
->
None
:
logger
.
info
(
"Starting to load model %s..."
,
self
.
model_config
.
model
)
logger
.
info
(
"Starting to load model %s..."
,
self
.
model_config
.
model
)
with
DeviceMemoryProfiler
()
as
m
:
# noqa: SIM117
with
DeviceMemoryProfiler
()
as
m
:
# noqa: SIM117
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
cd4a72a2
...
@@ -696,6 +696,7 @@ class TPUModelRunner:
...
@@ -696,6 +696,7 @@ class TPUModelRunner:
req_ids
=
all_req_ids
,
req_ids
=
all_req_ids
,
req_id_to_index
=
self
.
input_batch
.
req_id_to_index
,
req_id_to_index
=
self
.
input_batch
.
req_id_to_index
,
sampled_token_ids
=
[[
token_id
]
for
token_id
in
sampled_token_ids
],
sampled_token_ids
=
[[
token_id
]
for
token_id
in
sampled_token_ids
],
spec_token_ids
=
None
,
logprobs
=
None
,
logprobs
=
None
,
prompt_logprobs_dict
=
prompt_logprobs_dict
,
# type: ignore[arg-type]
prompt_logprobs_dict
=
prompt_logprobs_dict
,
# type: ignore[arg-type]
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment