Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
faf71bcd
Unverified
Commit
faf71bcd
authored
Jun 05, 2024
by
Nick Hill
Committed by
GitHub
Jun 05, 2024
Browse files
[Speculative Decoding] Add `ProposerWorkerBase` abstract class (#5252)
parent
f270a395
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
91 additions
and
60 deletions
+91
-60
tests/spec_decode/test_dynamic_spec_decode.py
tests/spec_decode/test_dynamic_spec_decode.py
+2
-2
tests/spec_decode/test_multi_step_worker.py
tests/spec_decode/test_multi_step_worker.py
+12
-9
tests/spec_decode/test_ngram_worker.py
tests/spec_decode/test_ngram_worker.py
+12
-9
vllm/spec_decode/interfaces.py
vllm/spec_decode/interfaces.py
+1
-1
vllm/spec_decode/multi_step_worker.py
vllm/spec_decode/multi_step_worker.py
+9
-6
vllm/spec_decode/ngram_worker.py
vllm/spec_decode/ngram_worker.py
+5
-28
vllm/spec_decode/proposer_worker_base.py
vllm/spec_decode/proposer_worker_base.py
+44
-0
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+3
-2
vllm/spec_decode/top1_proposer.py
vllm/spec_decode/top1_proposer.py
+3
-3
No files found.
tests/spec_decode/test_dynamic_spec_decode.py
View file @
faf71bcd
...
@@ -68,13 +68,13 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
...
@@ -68,13 +68,13 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
if
queue_size
<
disable_by_batch_size
:
if
queue_size
<
disable_by_batch_size
:
# Should raise exception when executing the mocked draft model.
# Should raise exception when executing the mocked draft model.
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
proposer
.
get_proposals
(
execute_model_req
=
ExecuteModelRequest
(
proposer
.
get_
spec_
proposals
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
),
)
num_lookahead_slots
=
k
),
)
else
:
else
:
# Should not execute the draft model because spec decode is disabled
# Should not execute the draft model because spec decode is disabled
# for all requests. Accordingly, the proposal length should be 0.
# for all requests. Accordingly, the proposal length should be 0.
proposals
=
proposer
.
get_proposals
(
proposals
=
proposer
.
get_
spec_
proposals
(
execute_model_req
=
ExecuteModelRequest
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
),
)
num_lookahead_slots
=
k
),
)
...
...
tests/spec_decode/test_multi_step_worker.py
View file @
faf71bcd
...
@@ -307,9 +307,10 @@ def test_draft_proposals_full_speculation_len():
...
@@ -307,9 +307,10 @@ def test_draft_proposals_full_speculation_len():
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
)
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
)
proposals
=
proposer
.
get_proposals
(
execute_model_req
=
ExecuteModelRequest
(
proposals
=
proposer
.
get_spec_proposals
(
seq_group_metadata_list
=
seq_group_metadata_list
,
execute_model_req
=
ExecuteModelRequest
(
num_lookahead_slots
=
k
),
)
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
),
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
...
@@ -344,9 +345,10 @@ def test_draft_proposals_no_speculations():
...
@@ -344,9 +345,10 @@ def test_draft_proposals_no_speculations():
k
,
k
,
prompt_len
=
prompt_len
)
prompt_len
=
prompt_len
)
proposals
=
proposer
.
get_proposals
(
execute_model_req
=
ExecuteModelRequest
(
proposals
=
proposer
.
get_spec_proposals
(
seq_group_metadata_list
=
seq_group_metadata_list
,
execute_model_req
=
ExecuteModelRequest
(
num_lookahead_slots
=
k
),
)
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
),
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
...
@@ -415,9 +417,10 @@ def test_draft_proposals_mixed_k():
...
@@ -415,9 +417,10 @@ def test_draft_proposals_mixed_k():
prev_output_token_len
=
prev_output_token_len
,
prev_output_token_len
=
prev_output_token_len
,
)
)
proposals
=
proposer
.
get_proposals
(
execute_model_req
=
ExecuteModelRequest
(
proposals
=
proposer
.
get_spec_proposals
(
seq_group_metadata_list
=
seq_group_metadata_list
,
execute_model_req
=
ExecuteModelRequest
(
num_lookahead_slots
=
k
),
)
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
),
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
...
...
tests/spec_decode/test_ngram_worker.py
View file @
faf71bcd
...
@@ -50,9 +50,10 @@ def test_ngram_algo_correctness_for_single_no_match():
...
@@ -50,9 +50,10 @@ def test_ngram_algo_correctness_for_single_no_match():
block_size
,
block_size
,
final_prompt_lens
=
final_prompt_lens
)
final_prompt_lens
=
final_prompt_lens
)
proposals
=
proposer
.
get_proposals
(
execute_model_req
=
ExecuteModelRequest
(
proposals
=
proposer
.
get_spec_proposals
(
seq_group_metadata_list
=
seq_group_metadata_list
,
execute_model_req
=
ExecuteModelRequest
(
num_lookahead_slots
=
proposal_len
),
)
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
proposal_len
),
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
...
@@ -117,9 +118,10 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
...
@@ -117,9 +118,10 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
block_size
,
block_size
,
final_prompt_lens
=
final_prompt_lens
)
final_prompt_lens
=
final_prompt_lens
)
proposals
=
proposer
.
get_proposals
(
execute_model_req
=
ExecuteModelRequest
(
proposals
=
proposer
.
get_spec_proposals
(
seq_group_metadata_list
=
seq_group_metadata_list
,
execute_model_req
=
ExecuteModelRequest
(
num_lookahead_slots
=
proposal_len
),
)
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
proposal_len
),
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
...
@@ -188,9 +190,10 @@ def test_ngram_algo_correctness_for_batches_match_all():
...
@@ -188,9 +190,10 @@ def test_ngram_algo_correctness_for_batches_match_all():
block_size
,
block_size
,
final_prompt_lens
=
final_prompt_lens
)
final_prompt_lens
=
final_prompt_lens
)
proposals
=
proposer
.
get_proposals
(
execute_model_req
=
ExecuteModelRequest
(
proposals
=
proposer
.
get_spec_proposals
(
seq_group_metadata_list
=
seq_group_metadata_list
,
execute_model_req
=
ExecuteModelRequest
(
num_lookahead_slots
=
proposal_len
),
)
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
proposal_len
),
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
...
...
vllm/spec_decode/interfaces.py
View file @
faf71bcd
...
@@ -55,7 +55,7 @@ class SpeculativeScores:
...
@@ -55,7 +55,7 @@ class SpeculativeScores:
class
SpeculativeProposer
(
ABC
):
class
SpeculativeProposer
(
ABC
):
@
abstractmethod
@
abstractmethod
def
get_proposals
(
def
get_
spec_
proposals
(
self
,
self
,
execute_model_req
:
ExecuteModelRequest
,
execute_model_req
:
ExecuteModelRequest
,
)
->
SpeculativeProposals
:
)
->
SpeculativeProposals
:
...
...
vllm/spec_decode/multi_step_worker.py
View file @
faf71bcd
...
@@ -7,11 +7,12 @@ import torch
...
@@ -7,11 +7,12 @@ import torch
from
vllm.sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
from
vllm.sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
SequenceGroupMetadata
)
SequenceGroupMetadata
)
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.worker.worker
import
Worker
from
vllm.worker.worker
import
Worker
class
MultiStepWorker
(
Worker
):
class
MultiStepWorker
(
Worker
,
ProposerWorkerBase
):
"""The MultiStepWorker is equivalent to a Worker except that it allows
"""The MultiStepWorker is equivalent to a Worker except that it allows
multiple forward passes in a single call, assuming the scheduler has
multiple forward passes in a single call, assuming the scheduler has
allocated enough space to store the additional KV. This reduces overhead
allocated enough space to store the additional KV. This reduces overhead
...
@@ -33,7 +34,7 @@ class MultiStepWorker(Worker):
...
@@ -33,7 +34,7 @@ class MultiStepWorker(Worker):
super
().
init_device
()
super
().
init_device
()
self
.
_proposer
=
Top1Proposer
(
self
.
_proposer
=
Top1Proposer
(
weakref
.
proxy
(
self
),
weakref
.
proxy
(
self
),
# type: ignore[arg-type]
self
.
device
,
self
.
device
,
self
.
vocab_size
,
self
.
vocab_size
,
max_proposal_len
=
self
.
max_model_len
,
max_proposal_len
=
self
.
max_model_len
,
...
@@ -92,11 +93,12 @@ class MultiStepWorker(Worker):
...
@@ -92,11 +93,12 @@ class MultiStepWorker(Worker):
speculative tokens per sequence is determined by max_proposal_len.
speculative tokens per sequence is determined by max_proposal_len.
"""
"""
return
self
.
_proposer
.
get_proposals
(
execute_model_req
)
return
self
.
_proposer
.
get_
spec_
proposals
(
execute_model_req
)
@
staticmethod
def
_append_new_tokens
(
def
_append_new_tokens
(
self
,
model_output
:
SamplerOutput
,
model_output
:
List
[
SamplerOutput
]
,
seq_group_metadata_list
:
SequenceGroupMetadata
)
->
None
:
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
)
->
None
:
"""Given model output from a single run, append the tokens to the
"""Given model output from a single run, append the tokens to the
sequences. This is normally done outside of the worker, but it is
sequences. This is normally done outside of the worker, but it is
required if the worker is to perform multiple forward passes.
required if the worker is to perform multiple forward passes.
...
@@ -116,8 +118,9 @@ class MultiStepWorker(Worker):
...
@@ -116,8 +118,9 @@ class MultiStepWorker(Worker):
seq
.
append_token_id
(
token_id
,
token_logprob
.
logprob
)
seq
.
append_token_id
(
token_id
,
token_logprob
.
logprob
)
seq
.
update_num_computed_tokens
(
1
)
seq
.
update_num_computed_tokens
(
1
)
@
staticmethod
def
_shallow_copy_inputs
(
def
_shallow_copy_inputs
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
)
->
List
[
SequenceGroupMetadata
]:
)
->
List
[
SequenceGroupMetadata
]:
"""Copy input data structures to remove side-effects when input data
"""Copy input data structures to remove side-effects when input data
structures are shared with other modules.
structures are shared with other modules.
...
...
vllm/spec_decode/ngram_worker.py
View file @
faf71bcd
...
@@ -5,15 +5,16 @@ import torch
...
@@ -5,15 +5,16 @@ import torch
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.proposer_worker_base
import
NonLLMProposerWorkerBase
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
class
NGramWorker
(
LoraNotSupportedWorkerBase
):
class
NGramWorker
(
NonLLMProposerWorkerBase
,
LoraNotSupportedWorkerBase
):
"""NGramWorker provides a light drafter without need for model.
"""NGramWorker provides a light drafter without need for model.
Current NGramWorker only implement prompt lookup decoding,
Current NGramWorker only implement prompt lookup decoding,
and in future we may also do RAG type drafter and other scen
e
rios
and in future we may also do RAG type drafter and other scen
a
rios
which don't rely on LLM model to give proposals.
which don't rely on LLM model to give proposals.
"""
"""
...
@@ -38,34 +39,11 @@ class NGramWorker(LoraNotSupportedWorkerBase):
...
@@ -38,34 +39,11 @@ class NGramWorker(LoraNotSupportedWorkerBase):
# Current only support Top1Proposer
# Current only support Top1Proposer
self
.
_proposer
=
Top1Proposer
(
self
.
_proposer
=
Top1Proposer
(
weakref
.
proxy
(
self
),
weakref
.
proxy
(
self
),
# type: ignore[arg-type]
device
=
self
.
device
,
device
=
self
.
device
,
vocab_size
=
self
.
vocab_size
,
vocab_size
=
self
.
vocab_size
,
)
)
def
set_include_gpu_probs_tensor
(
self
):
# NGram don't need gpu sampler
pass
def
execute_model
(
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
)
->
None
:
"""NGram doesn't depend on model execution, just pass this function"""
pass
def
determine_num_available_blocks
(
self
)
->
None
:
"""NGram doesn't depend on model execution, no need to check blocks"""
pass
def
initialize_cache
(
self
,
num_gpu_blocks
:
int
,
num_cpu_blocks
:
int
)
->
None
:
"""As there is no cache need to handle, just pass this function"""
pass
def
get_cache_block_size_bytes
(
self
):
"""Return the size of a cache block in bytes."""
return
0
def
sampler_output
(
def
sampler_output
(
self
,
self
,
execute_model_req
:
ExecuteModelRequest
,
execute_model_req
:
ExecuteModelRequest
,
...
@@ -97,7 +75,6 @@ class NGramWorker(LoraNotSupportedWorkerBase):
...
@@ -97,7 +75,6 @@ class NGramWorker(LoraNotSupportedWorkerBase):
-
1
,
-
1
,
):
):
ngram_tensor
=
input_ids
[
-
ngram_size
:]
ngram_tensor
=
input_ids
[
-
ngram_size
:]
proposal_start_idx
=
None
if
ngram_size
==
1
:
if
ngram_size
==
1
:
# Do not match itself and do not use unfold and all
# Do not match itself and do not use unfold and all
matches
=
(
input_ids
[:
-
1
]
==
ngram_tensor
)
matches
=
(
input_ids
[:
-
1
]
==
ngram_tensor
)
...
@@ -161,7 +138,7 @@ class NGramWorker(LoraNotSupportedWorkerBase):
...
@@ -161,7 +138,7 @@ class NGramWorker(LoraNotSupportedWorkerBase):
speculative tokens per sequence is determined by max_proposal_len.
speculative tokens per sequence is determined by max_proposal_len.
"""
"""
return
self
.
_proposer
.
get_proposals
(
execute_model_req
)
return
self
.
_proposer
.
get_
spec_
proposals
(
execute_model_req
)
def
_raise_if_unsupported
(
def
_raise_if_unsupported
(
self
,
self
,
...
...
vllm/spec_decode/proposer_worker_base.py
0 → 100644
View file @
faf71bcd
from
abc
import
ABC
,
abstractmethod
from
typing
import
List
,
Optional
,
Tuple
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.spec_decode.interfaces
import
SpeculativeProposer
from
vllm.worker.worker_base
import
WorkerBase
class ProposerWorkerBase(WorkerBase, SpeculativeProposer):
    """Common interface implemented by every proposer worker.

    Combines the ``WorkerBase`` contract with the ``SpeculativeProposer``
    interface.
    """

    @abstractmethod
    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
    ) -> Tuple[Optional[List[SamplerOutput]], bool]:
        """Produce sampler outputs for ``execute_model_req``.

        Concrete workers must override this method; the base version
        always raises ``NotImplementedError``.

        NOTE(review): the meaning of the returned bool is not visible
        from this class -- confirm against the concrete workers.
        """
        raise NotImplementedError()

    def set_include_gpu_probs_tensor(self):
        """Optional hook; the default implementation does nothing."""
        return None
class NonLLMProposerWorkerBase(ProposerWorkerBase, ABC):
    """Base for proposer workers not backed by a model with a KV cache.

    Provides trivial implementations of the model-execution and
    cache-related worker methods.
    """

    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Return an empty list.

        Proposals are obtained through ``get_spec_proposals`` instead.
        """
        no_outputs: List[SamplerOutput] = []
        return no_outputs

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Always raise ``NotImplementedError``.

        This is only ever invoked on the target model, never on the
        proposer.
        """
        raise NotImplementedError()

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Accept the block counts and do nothing with them."""
        return None

    def get_cache_block_size_bytes(self) -> int:
        """Report a cache block size of zero bytes."""
        return 0
vllm/spec_decode/spec_decode_worker.py
View file @
faf71bcd
...
@@ -14,6 +14,7 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals,
...
@@ -14,6 +14,7 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals,
from
vllm.spec_decode.metrics
import
AsyncMetricsCollector
from
vllm.spec_decode.metrics
import
AsyncMetricsCollector
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.ngram_worker
import
NGramWorker
from
vllm.spec_decode.ngram_worker
import
NGramWorker
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
from
vllm.spec_decode.util
import
(
create_sequence_group_output
,
from
vllm.spec_decode.util
import
(
create_sequence_group_output
,
get_all_num_logprobs
,
get_all_seq_ids
,
get_all_num_logprobs
,
get_all_seq_ids
,
get_sampled_token_logprobs
,
nvtx_range
,
get_sampled_token_logprobs
,
nvtx_range
,
...
@@ -117,7 +118,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -117,7 +118,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
def
__init__
(
def
__init__
(
self
,
self
,
proposer_worker
:
WorkerBase
,
proposer_worker
:
Proposer
WorkerBase
,
scorer_worker
:
WorkerBase
,
scorer_worker
:
WorkerBase
,
rejection_sampler
:
RejectionSampler
,
rejection_sampler
:
RejectionSampler
,
metrics_collector
:
Optional
[
AsyncMetricsCollector
]
=
None
,
metrics_collector
:
Optional
[
AsyncMetricsCollector
]
=
None
,
...
@@ -260,7 +261,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -260,7 +261,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# This is required as if the number of draft model runs changes
# This is required as if the number of draft model runs changes
# dynamically, the non-driver workers won't know unless we perform a
# dynamically, the non-driver workers won't know unless we perform a
# communication to inform the
n
.
# communication to inform the
m
.
broadcast_dict
=
dict
(
broadcast_dict
=
dict
(
num_lookahead_slots
=
num_lookahead_slots
,
num_lookahead_slots
=
num_lookahead_slots
,
disable_all_speculation
=
disable_all_speculation
,
disable_all_speculation
=
disable_all_speculation
,
...
...
vllm/spec_decode/top1_proposer.py
View file @
faf71bcd
...
@@ -6,8 +6,8 @@ from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
...
@@ -6,8 +6,8 @@ from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata
)
SequenceGroupMetadata
)
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeProposer
)
SpeculativeProposer
)
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
from
vllm.spec_decode.util
import
sampler_output_to_torch
from
vllm.spec_decode.util
import
sampler_output_to_torch
from
vllm.worker.worker_base
import
WorkerBase
class
Top1Proposer
(
SpeculativeProposer
):
class
Top1Proposer
(
SpeculativeProposer
):
...
@@ -29,7 +29,7 @@ class Top1Proposer(SpeculativeProposer):
...
@@ -29,7 +29,7 @@ class Top1Proposer(SpeculativeProposer):
def
__init__
(
def
__init__
(
self
,
self
,
worker
:
WorkerBase
,
worker
:
Proposer
WorkerBase
,
device
:
str
,
device
:
str
,
vocab_size
:
int
,
vocab_size
:
int
,
max_proposal_len
:
Optional
[
int
]
=
None
,
max_proposal_len
:
Optional
[
int
]
=
None
,
...
@@ -39,7 +39,7 @@ class Top1Proposer(SpeculativeProposer):
...
@@ -39,7 +39,7 @@ class Top1Proposer(SpeculativeProposer):
self
.
max_proposal_len
=
max_proposal_len
self
.
max_proposal_len
=
max_proposal_len
self
.
_vocab_size
=
vocab_size
self
.
_vocab_size
=
vocab_size
def
get_proposals
(
def
get_
spec_
proposals
(
self
,
self
,
execute_model_req
:
ExecuteModelRequest
,
execute_model_req
:
ExecuteModelRequest
,
)
->
SpeculativeProposals
:
)
->
SpeculativeProposals
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment