Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
89639c96
Commit
89639c96
authored
Dec 23, 2025
by
jujl1
Browse files
feat: triton kernel 实现 update_input
parent
0936ee97
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
71 additions
and
15 deletions
+71
-15
vllm/zero_overhead/v1/core.py
vllm/zero_overhead/v1/core.py
+1
-1
vllm/zero_overhead/v1/gpu_model_runner.py
vllm/zero_overhead/v1/gpu_model_runner.py
+70
-14
No files found.
vllm/zero_overhead/v1/core.py
View file @
89639c96
...
@@ -205,7 +205,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
...
@@ -205,7 +205,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
num_new
=
len
(
generated_token_ids
)
num_new
=
len
(
generated_token_ids
)
if
(
model_runner_output
.
fix_req_ids
and
req_id
in
model_runner_output
.
fix_req_ids
if
(
model_runner_output
.
fix_req_ids
and
req_id
in
model_runner_output
.
fix_req_ids
and
request
.
num_computed_tokens
<
=
request
.
num_prompt_tokens
+
num_new
):
and
request
.
num_computed_tokens
>
=
request
.
num_prompt_tokens
+
num_new
):
req_idx
=
model_runner_output
.
fix_req_ids
.
index
(
req_id
)
req_idx
=
model_runner_output
.
fix_req_ids
.
index
(
req_id
)
num_new
=
len
(
model_runner_output
.
fix_sampled_token_ids
[
req_idx
])
num_new
=
len
(
model_runner_output
.
fix_sampled_token_ids
[
req_idx
])
num_tokens_rejected
=
(
len
(
scheduled_spec_token_ids
)
+
1
-
num_new
)
num_tokens_rejected
=
(
len
(
scheduled_spec_token_ids
)
+
1
-
num_new
)
...
...
vllm/zero_overhead/v1/gpu_model_runner.py
View file @
89639c96
...
@@ -23,6 +23,46 @@ from vllm.profiler.prof import profile
...
@@ -23,6 +23,46 @@ from vllm.profiler.prof import profile
from
vllm.two_batch_overlap.v1.model_input_split_v1
import
tbo_split_and_execute_model
from
vllm.two_batch_overlap.v1.model_input_split_v1
import
tbo_split_and_execute_model
from
vllm.v1.spec_decode.utils
import
DraftProbs
from
vllm.v1.spec_decode.utils
import
DraftProbs
import
triton
import
triton.language
as
tl
@
triton
.
jit
def
fused_last_valid_scatter_kernel
(
last_ids_ptr
,
# [B, T]
input_ids_ptr
,
# [N]
update_req_ptr
,
# [U]
input_pos_ptr
,
# [U]
stride0
,
stride1
,
T
,
BLOCK_T
:
tl
.
constexpr
,
):
pid
=
tl
.
program_id
(
0
)
# indices
req_idx
=
tl
.
load
(
update_req_ptr
+
pid
)
input_pos
=
tl
.
load
(
input_pos_ptr
+
pid
)
# load row
offs
=
tl
.
arange
(
0
,
BLOCK_T
)
mask
=
offs
<
T
row_ptr
=
last_ids_ptr
+
req_idx
*
stride0
+
offs
*
stride1
vals
=
tl
.
load
(
row_ptr
,
mask
=
mask
,
other
=-
1
)
# ✅ 正确做法:index reduction
idx
=
tl
.
where
(
vals
!=
-
1
,
offs
,
-
1
)
last_idx
=
tl
.
max
(
idx
,
axis
=
0
)
# load last token
last_val
=
tl
.
load
(
last_ids_ptr
+
req_idx
*
stride0
+
last_idx
*
stride1
,
mask
=
last_idx
>=
0
,
other
=
0
,
)
# scatter
tl
.
store
(
input_ids_ptr
+
input_pos
,
last_val
)
class
V1ZeroModelRunner
(
GPUModelRunner
):
class
V1ZeroModelRunner
(
GPUModelRunner
):
def
__init__
(
self
,
vllm_config
,
device
):
def
__init__
(
self
,
vllm_config
,
device
):
...
@@ -302,18 +342,30 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -302,18 +342,30 @@ class V1ZeroModelRunner(GPUModelRunner):
True
)
True
)
last_draft_token_ids
=
self
.
last_draft_token_ids
.
flatten
().
to
(
torch
.
int
)
last_draft_token_ids
=
self
.
last_draft_token_ids
.
flatten
().
to
(
torch
.
int
)
input_ids
[
input_ids_indices_tensor
]
=
last_draft_token_ids
[
update_req_indices_tensor
]
input_ids
[
input_ids_indices_tensor
]
=
last_draft_token_ids
[
update_req_indices_tensor
]
def
find_last_valid_vectorized
(
tensor
):
"""
def
fused_update_input_ids
(
向量化方法找到每行最后一个非-1元素
last_sampled_token_ids
,
"""
input_ids
,
mask
=
tensor
!=
-
1
update_req_indices
,
reversed_mask
=
mask
.
flip
(
dims
=
[
1
])
# 沿着列方向反转
input_ids_indices
,
_
,
col_indices
=
torch
.
max
(
reversed_mask
.
int
(),
dim
=
1
)
):
original_col_indices
=
tensor
.
size
(
1
)
-
1
-
col_indices
B
,
T
=
last_sampled_token_ids
.
shape
result
=
tensor
[
torch
.
arange
(
tensor
.
size
(
0
)),
original_col_indices
]
U
=
update_req_indices
.
numel
()
all_invalid
=
~
mask
.
any
(
dim
=
1
)
result
[
all_invalid
]
=
-
1
# 或者设置为其他默认值
BLOCK_T
=
1024
return
result
assert
T
<=
BLOCK_T
grid
=
(
U
,)
fused_last_valid_scatter_kernel
[
grid
](
last_sampled_token_ids
,
input_ids
,
update_req_indices
,
input_ids_indices
,
last_sampled_token_ids
.
stride
(
0
),
last_sampled_token_ids
.
stride
(
1
),
T
,
BLOCK_T
=
BLOCK_T
,
)
update_req_indices
=
[]
update_req_indices
=
[]
input_ids_indices
=
[]
input_ids_indices
=
[]
...
@@ -332,8 +384,12 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -332,8 +384,12 @@ class V1ZeroModelRunner(GPUModelRunner):
input_ids_indices_tensor
=
async_tensor_h2d
(
input_ids_indices
,
torch
.
int32
,
input_ids_indices_tensor
=
async_tensor_h2d
(
input_ids_indices
,
torch
.
int32
,
self
.
device
,
self
.
device
,
True
)
True
)
last_sampled_token_ids
=
find_last_valid_vectorized
(
self
.
last_sampled_token_ids
).
flatten
()
fused_update_input_ids
(
input_ids
[
input_ids_indices_tensor
]
=
last_sampled_token_ids
[
update_req_indices_tensor
]
self
.
last_sampled_token_ids
,
input_ids
,
update_req_indices_tensor
,
input_ids_indices_tensor
)
def
propose_draft_token_ids
(
def
propose_draft_token_ids
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment