Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
215f33b0
Commit
215f33b0
authored
Nov 20, 2024
by
王敏
Browse files
[fix]修复单测test_medusa_correctness.py中的错误
parent
0ec3aa4b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
7 additions
and
6 deletions
+7
-6
tests/spec_decode/e2e/test_medusa_correctness.py
tests/spec_decode/e2e/test_medusa_correctness.py
+1
-1
vllm/spec_decode/medusa_worker.py
vllm/spec_decode/medusa_worker.py
+5
-4
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+1
-1
No files found.
tests/spec_decode/e2e/test_medusa_correctness.py
View file @
215f33b0
...
@@ -36,7 +36,7 @@ SPEC_MODEL = "abhigoyal/vllm-medusa-llama-68m-random"
...
@@ -36,7 +36,7 @@ SPEC_MODEL = "abhigoyal/vllm-medusa-llama-68m-random"
MAX_SPEC_TOKENS
=
5
MAX_SPEC_TOKENS
=
5
# precision
# precision
PRECISION
=
"float
32
"
PRECISION
=
"float
16
"
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
...
vllm/spec_decode/medusa_worker.py
View file @
215f33b0
...
@@ -40,9 +40,10 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
...
@@ -40,9 +40,10 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
self
.
medusa_buffers
=
None
self
.
medusa_buffers
=
None
if
hasattr
(
self
.
model_runner
.
model
,
'medusa_choices'
):
if
hasattr
(
self
.
model_runner
.
model
,
'medusa_choices'
):
self
.
medusa_choices
=
self
.
model_runner
.
model
.
medusa_choices
self
.
medusa_choices
=
self
.
model_runner
.
model
.
medusa_choices
self
.
medusa_buffers
=
self
.
generate_medusa_buffers
(
if
self
.
medusa_choices
is
not
None
:
self
.
medusa_choices
,
device
=
self
.
device
self
.
medusa_buffers
=
self
.
generate_medusa_buffers
(
)
self
.
medusa_choices
,
device
=
self
.
device
)
if
self
.
medusa_buffers
is
None
:
if
self
.
medusa_buffers
is
None
:
self
.
_proposer
=
Top1Proposer
(
self
.
_proposer
=
Top1Proposer
(
...
@@ -342,7 +343,7 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
...
@@ -342,7 +343,7 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
# Move the tensors in the dictionary to the specified device
# Move the tensors in the dictionary to the specified device
medusa_buffers
=
{
medusa_buffers
=
{
k
:
v
.
clone
().
to
(
device
)
k
:
(
v
.
clone
().
to
(
device
)
if
k
!=
"tree_position_ids"
else
v
.
clone
())
if
isinstance
(
v
,
torch
.
Tensor
)
if
isinstance
(
v
,
torch
.
Tensor
)
else
torch
.
tensor
(
v
,
device
=
device
)
else
torch
.
tensor
(
v
,
device
=
device
)
for
k
,
v
in
medusa_buffers
.
items
()
for
k
,
v
in
medusa_buffers
.
items
()
...
...
vllm/worker/model_runner.py
View file @
215f33b0
...
@@ -512,7 +512,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -512,7 +512,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
inter_data
.
input_positions
[
seq_idx
]
=
list
(
range
(
context_len
,
seq_len
))
inter_data
.
input_positions
[
seq_idx
]
=
list
(
range
(
context_len
,
seq_len
))
if
seq_group_metadata
.
tree_position_ids
is
not
None
:
if
seq_group_metadata
.
tree_position_ids
is
not
None
:
inter_data
.
input_positions
[
seq_idx
]
=
seq_group_metadata
.
tree_position_ids
.
contiguous
().
cpu
().
tolist
()
inter_data
.
input_positions
[
seq_idx
]
=
seq_group_metadata
.
tree_position_ids
.
contiguous
().
tolist
()
inter_data
.
tree_attn_masks
[
seq_idx
]
=
seq_group_metadata
.
tree_attn_masks
inter_data
.
tree_attn_masks
[
seq_idx
]
=
seq_group_metadata
.
tree_attn_masks
inter_data
.
query_lens
[
inter_data
.
query_lens
[
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment