Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
03e822d1
Commit
03e822d1
authored
Nov 27, 2025
by
王敏
Browse files
去掉宽松mtp中的隐式同步
parent
deae0a22
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
11 additions
and
6 deletions
+11
-6
vllm/v1/sample/rejection_sampler_opt.py
vllm/v1/sample/rejection_sampler_opt.py
+4
-5
vllm/v1/spec_decode/utils.py
vllm/v1/spec_decode/utils.py
+7
-1
No files found.
vllm/v1/sample/rejection_sampler_opt.py
View file @
03e822d1
...
@@ -87,12 +87,8 @@ class OptRejectionSampler(nn.Module):
...
@@ -87,12 +87,8 @@ class OptRejectionSampler(nn.Module):
assert
metadata
.
max_spec_len
<=
MAX_SPEC_LEN
assert
metadata
.
max_spec_len
<=
MAX_SPEC_LEN
target_probs
=
target_logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
target_probs
=
target_logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
draft_token_ids
=
metadata
.
draft_token_ids
mask
=
draft_token_ids
.
eq
(
-
1
).
to
(
torch
.
bool
)
draft_token_ids
=
torch
.
where
(
mask
,
0
,
draft_token_ids
).
to
(
torch
.
long
)
# 兼容第一次decode
output_token_ids
=
rejection_sample
(
output_token_ids
=
rejection_sample
(
draft_token_ids
,
metadata
.
draft_token_ids
,
metadata
.
num_draft_tokens
,
metadata
.
num_draft_tokens
,
metadata
.
max_spec_len
,
metadata
.
max_spec_len
,
metadata
.
cu_num_draft_tokens
,
metadata
.
cu_num_draft_tokens
,
...
@@ -225,6 +221,8 @@ def rejection_random_sample_kernel(
...
@@ -225,6 +221,8 @@ def rejection_random_sample_kernel(
for
pos
in
range
(
num_draft_tokens
):
for
pos
in
range
(
num_draft_tokens
):
if
not
rejected
:
if
not
rejected
:
draft_token_id
=
tl
.
load
(
draft_token_ids_ptr
+
start_idx
+
pos
)
draft_token_id
=
tl
.
load
(
draft_token_ids_ptr
+
start_idx
+
pos
)
if
draft_token_id
<
0
:
draft_token_id
=
0
if
NO_DRAFT_PROBS
:
if
NO_DRAFT_PROBS
:
draft_prob
=
1
draft_prob
=
1
else
:
else
:
...
@@ -235,6 +233,7 @@ def rejection_random_sample_kernel(
...
@@ -235,6 +233,7 @@ def rejection_random_sample_kernel(
(
start_idx
+
pos
)
*
vocab_size
+
(
start_idx
+
pos
)
*
vocab_size
+
draft_token_id
)
draft_token_id
)
draft_token_id
=
draft_token_id
.
to
(
tl
.
int64
)
target_token_id
=
tl
.
load
(
target_token_ids_ptr
+
(
start_idx
+
pos
))
target_token_id
=
tl
.
load
(
target_token_ids_ptr
+
(
start_idx
+
pos
))
target_token_id
=
target_token_id
.
to
(
tl
.
int64
)
target_token_id
=
target_token_id
.
to
(
tl
.
int64
)
uniform_prob
=
tl
.
load
(
uniform_probs_ptr
+
start_idx
+
pos
)
uniform_prob
=
tl
.
load
(
uniform_probs_ptr
+
start_idx
+
pos
)
...
...
vllm/v1/spec_decode/utils.py
View file @
03e822d1
...
@@ -6,6 +6,7 @@ import torch
...
@@ -6,6 +6,7 @@ import torch
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils
import
async_tensor_h2d
_SAMPLING_EPS
=
1e-5
_SAMPLING_EPS
=
1e-5
...
@@ -80,5 +81,10 @@ class DraftProbs(ABC): # type: ignore[call-arg]
...
@@ -80,5 +81,10 @@ class DraftProbs(ABC): # type: ignore[call-arg]
def
get_probs
(
self
,
req_ids
:
list
[
str
]):
def
get_probs
(
self
,
req_ids
:
list
[
str
]):
index
=
[
self
.
_req_ids
.
index
(
req_id
)
for
req_id
in
req_ids
]
index
=
[
self
.
_req_ids
.
index
(
req_id
)
for
req_id
in
req_ids
]
return
self
.
draft_probs
[
index
]
index_tensor
=
async_tensor_h2d
(
index
,
dtype
=
torch
.
int32
,
target_device
=
self
.
draft_probs
.
device
,
pin_memory
=
True
)
return
self
.
draft_probs
[
index_tensor
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment