Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b01c8270
Commit
b01c8270
authored
Apr 27, 2025
by
lizhigong
Browse files
delete triton kernel ,use tensor indices
parent
1ed30424
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
25 additions
and
43 deletions
+25
-43
vllm/zero_overhead/v0/llm_engine.py
vllm/zero_overhead/v0/llm_engine.py
+1
-2
vllm/zero_overhead/v0/model_runner.py
vllm/zero_overhead/v0/model_runner.py
+12
-3
vllm/zero_overhead/v0/sampler.py
vllm/zero_overhead/v0/sampler.py
+3
-3
vllm/zero_overhead/v0/update_input.py
vllm/zero_overhead/v0/update_input.py
+0
-34
vllm/zero_overhead/v0/utils.py
vllm/zero_overhead/v0/utils.py
+9
-1
No files found.
vllm/zero_overhead/v0/llm_engine.py
View file @
b01c8270
...
@@ -263,7 +263,6 @@ class ZeroOverheadEngine(LLMEngine):
...
@@ -263,7 +263,6 @@ class ZeroOverheadEngine(LLMEngine):
self
.
_skip_scheduling_next_step
=
False
self
.
_skip_scheduling_next_step
=
False
self
.
async_d2h
=
None
self
.
async_d2h
=
None
self
.
last_record
=
None
self
.
last_record
=
None
assert
os
.
environ
.
get
(
'HIP_ALLOC_INITIALIZE'
)
==
'0'
self
.
async_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
async_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
thread_running
=
False
self
.
thread_running
=
False
self
.
q_recorder
=
queue
.
Queue
()
self
.
q_recorder
=
queue
.
Queue
()
...
@@ -410,7 +409,7 @@ class ZeroOverheadEngine(LLMEngine):
...
@@ -410,7 +409,7 @@ class ZeroOverheadEngine(LLMEngine):
#sample_out_list = output[0].sampler_out_tenosr.cpu().tolist()
#sample_out_list = output[0].sampler_out_tenosr.cpu().tolist()
sample_out_list
=
self
.
async_d2h
.
tolist
()
sample_out_list
=
self
.
async_d2h
.
tolist
()
sample_out_ids
=
last_sampler
.
seq_id
.
tolist
()
sample_out_ids
=
last_sampler
.
seq_id
s
for
seq_group_metadata
,
sequence_group_outputs
,
scheduled_seq_group
in
\
for
seq_group_metadata
,
sequence_group_outputs
,
scheduled_seq_group
in
\
zip
(
seq_group_metadata_list
,
output
[
0
],
scheduled_seq_groups
):
zip
(
seq_group_metadata_list
,
output
[
0
],
scheduled_seq_groups
):
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
=
scheduled_seq_group
.
seq_group
...
...
vllm/zero_overhead/v0/model_runner.py
View file @
b01c8270
...
@@ -36,11 +36,20 @@ class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder):
...
@@ -36,11 +36,20 @@ class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder):
model_input
=
super
().
build
()
model_input
=
super
().
build
()
last_sampler
=
get_last_sampler
()
last_sampler
=
get_last_sampler
()
if
last_sampler
is
not
None
:
if
last_sampler
is
not
None
:
input_ids
=
async_tensor_h2d
(
self
.
req_ids
,
torch
.
long
,
update_indices
=
[]
select_indices
=
[]
for
i
,
seq_id
in
enumerate
(
self
.
req_ids
):
for
j
,
seq_id_
in
enumerate
(
last_sampler
.
seq_ids
):
if
seq_id
==
seq_id_
:
select_indices
.
append
(
j
)
update_indices
.
append
(
i
)
break
select_indices
=
async_tensor_h2d
(
select_indices
,
torch
.
long
,
self
.
runner
.
device
,
self
.
runner
.
device
,
self
.
runner
.
pin_memory
)
self
.
runner
.
pin_memory
)
last_id
s
=
async_tensor_h2d
(
last_sampler
.
seq_id
.
tolist
()
,
torch
.
long
,
update_indice
s
=
async_tensor_h2d
(
update_indices
,
torch
.
long
,
self
.
runner
.
device
,
self
.
runner
.
device
,
self
.
runner
.
pin_memory
)
self
.
runner
.
pin_memory
)
UpdateInputTokens
(
model_input
.
input_tokens
,
input_ids
,
last_sampler
.
sampled_token_ids_tensor
,
last_ids
)
if
len
(
select_indices
)
>
0
:
model_input
.
input_tokens
[
update_indices
]
=
last_sampler
.
sampled_token_ids_tensor
[
select_indices
,
0
]
return
model_input
return
model_input
vllm/zero_overhead/v0/sampler.py
View file @
b01c8270
...
@@ -22,7 +22,7 @@ else:
...
@@ -22,7 +22,7 @@ else:
class
SampleRecorder
:
class
SampleRecorder
:
def
__init__
(
self
):
def
__init__
(
self
):
self
.
seq_id
:
torch
.
Tensor
=
None
self
.
seq_id
s
:
torch
.
Tensor
=
None
self
.
sampled_token_ids_tensor
:
torch
.
Tensor
=
None
self
.
sampled_token_ids_tensor
:
torch
.
Tensor
=
None
last_sampler
=
None
last_sampler
=
None
...
@@ -275,10 +275,10 @@ def _sample_with_torch(
...
@@ -275,10 +275,10 @@ def _sample_with_torch(
t
:
[]
t
:
[]
for
t
in
SamplingType
for
t
in
SamplingType
}
}
last_sampler
.
seq_id
=
torch
.
zeros
(
len
(
sampling_metadata
.
seq_groups
),
dtype
=
torch
.
int32
)
last_sampler
.
seq_id
s
=
[]
categorized_sample_indices
=
sampling_metadata
.
categorized_sample_indices
categorized_sample_indices
=
sampling_metadata
.
categorized_sample_indices
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
last_sampler
.
seq_id
[
i
]
=
seq_group
.
seq_ids
[
0
]
last_sampler
.
seq_id
s
.
append
(
seq_group
.
seq_ids
[
0
]
)
sampling_params
=
seq_group
.
sampling_params
sampling_params
=
seq_group
.
sampling_params
sampling_type
=
sampling_params
.
sampling_type
sampling_type
=
sampling_params
.
sampling_type
categorized_seq_group_ids
[
sampling_type
].
append
(
i
)
categorized_seq_group_ids
[
sampling_type
].
append
(
i
)
...
...
vllm/zero_overhead/v0/update_input.py
deleted
100644 → 0
View file @
1ed30424
import
torch
import
triton
import
triton.language
as
tl
@
triton
.
jit
def
_update_input_tokens
(
sample_output
,
seq_ids
,
input_tokens
,
input_seq_ids
,
BATCH_SIZE1
,
BATCH_SIZE2
,
):
pid
=
tl
.
program_id
(
0
)
if
pid
>=
BATCH_SIZE2
:
return
output_token
=
tl
.
load
(
input_tokens
+
pid
)
_input_seq_id
=
tl
.
load
(
input_seq_ids
+
pid
)
for
i
in
range
(
BATCH_SIZE1
):
_seq_ids
=
tl
.
load
(
seq_ids
+
i
)
if
_seq_ids
==
_input_seq_id
:
output_token
=
tl
.
load
(
sample_output
+
i
)
tl
.
store
(
input_tokens
+
pid
,
output_token
)
_update_input_tokens_ptr
=
None
def
UpdateInputTokens
(
input_tokens
,
input_seq_ids
,
last_sample
,
last_ids
):
global
_update_input_tokens_ptr
grid
=
[
input_seq_ids
.
shape
[
0
],
1
,
1
]
if
_update_input_tokens_ptr
is
None
:
_update_input_tokens_ptr
=
_update_input_tokens
[
grid
](
last_sample
,
last_ids
,
input_tokens
,
input_seq_ids
,
last_ids
.
shape
[
0
],
input_seq_ids
.
shape
[
0
])
else
:
_update_input_tokens_ptr
[
grid
](
last_sample
,
last_ids
,
input_tokens
,
input_seq_ids
,
last_ids
.
shape
[
0
],
input_seq_ids
.
shape
[
0
])
\ No newline at end of file
vllm/zero_overhead/v0/utils.py
View file @
b01c8270
...
@@ -10,3 +10,11 @@ def is_zero_overhead():
...
@@ -10,3 +10,11 @@ def is_zero_overhead():
def
is_zero_no_thread
():
def
is_zero_no_thread
():
return
zero_no_thread
and
zero_overhead
return
zero_no_thread
and
zero_overhead
def
UpdateInputTokens
(
input_tokens
,
last_sample
,
indices
):
global
_update_input_tokens_ptr
grid
=
[
input_tokens
.
shape
[
0
],
1
,
1
]
if
_update_input_tokens_ptr
is
None
:
_update_input_tokens_ptr
=
_update_input_tokens
[
grid
](
last_sample
,
input_tokens
,
indices
,
input_tokens
.
shape
[
0
])
else
:
_update_input_tokens_ptr
[
grid
](
last_sample
,
input_tokens
,
indices
,
input_tokens
.
shape
[
0
])
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment