Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a34bff19
Commit
a34bff19
authored
Mar 12, 2026
by
wangmin6
Browse files
Merge branch 'v0.9.2-dev-tx-cpp' into 'v0.9.2-dev'
V0.9.2 dev tx cpp See merge request dcutoolkit/deeplearing/vllm!474
parents
d761561a
badaff2d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
161 additions
and
63 deletions
+161
-63
vllm/envs.py
vllm/envs.py
+4
-0
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+71
-53
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+51
-4
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+35
-6
No files found.
vllm/envs.py
View file @
a34bff19
...
@@ -219,6 +219,7 @@ if TYPE_CHECKING:
...
@@ -219,6 +219,7 @@ if TYPE_CHECKING:
VLLM_ENABLE_SHARED_EXPERTS_FUSION
:
bool
=
False
VLLM_ENABLE_SHARED_EXPERTS_FUSION
:
bool
=
False
VLLM_USE_MOE_W16A16_TRITON
:
bool
=
False
VLLM_USE_MOE_W16A16_TRITON
:
bool
=
False
VLLM_USE_FUSED_DTBMM
:
bool
=
False
VLLM_USE_FUSED_DTBMM
:
bool
=
False
VLLM_FUSE_CAT_AND_CAST_FP8
:
bool
=
False
def
get_default_cache_root
():
def
get_default_cache_root
():
return
os
.
getenv
(
return
os
.
getenv
(
...
@@ -1404,6 +1405,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1404,6 +1405,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FUSED_DTBMM"
:
"VLLM_USE_FUSED_DTBMM"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_FUSED_DTBMM"
,
"False"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_FUSED_DTBMM"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
"VLLM_FUSE_CAT_AND_CAST_FP8"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_FUSE_CAT_AND_CAST_FP8"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
}
}
# --8<-- [end:env-vars-definition]
# --8<-- [end:env-vars-definition]
...
...
vllm/v1/attention/backends/mla/common.py
View file @
a34bff19
...
@@ -1036,33 +1036,44 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -1036,33 +1036,44 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_nope
=
self
.
kv_b_proj
(
kv_c_normed
)[
0
].
view
(
\
kv_nope
=
self
.
kv_b_proj
(
kv_c_normed
)[
0
].
view
(
\
-
1
,
self
.
num_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
)
-
1
,
self
.
num_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
)
k_nope
,
v
=
kv_nope
\
.
split
([
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
use_flash_fp8_arch
=
(
\
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
\
if
envs
.
VLLM_USE_OPT_CAT
:
and
envs
.
VLLM_USE_FLASH_ATTN_FP8
if
k_nope
.
shape
[
0
]
>
1024
:
)
from
vllm.v1.attention.backends.mla.test_concat
import
lightop_concat_prefill_helper
use_fused_fp8_op
=
use_flash_fp8_arch
and
envs
.
VLLM_FUSE_CAT_AND_CAST_FP8
k
=
lightop_concat_prefill_helper
(
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
)),
dim
=
2
)
k_pe_expanded
=
k_pe
.
expand
(
k_pe
.
shape
[
0
],
self
.
num_heads
,
k_pe
.
shape
[
-
1
])
else
:
if
use_fused_fp8_op
:
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
from
lightop
import
op
dim
=-
1
)
q
,
k
,
v
=
op
.
ds_fused_qkv_cast_fp8
(
q
,
kv_nope
,
k_pe_expanded
,
self
.
qk_nope_head_dim
,
self
.
v_head_dim
)
else
:
else
:
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
k_nope
,
v
=
kv_nope
\
dim
=-
1
)
.
split
([
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
envs
.
VLLM_USE_FLASH_ATTN_FP8
:
if
envs
.
VLLM_USE_OPT_CAT
:
q_descale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
q
.
device
)
if
k_nope
.
shape
[
0
]
>
1024
:
k_descale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
q
.
device
)
from
vllm.v1.attention.backends.mla.test_concat
import
lightop_concat_prefill_helper
v_descale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
q
.
device
)
k
=
lightop_concat_prefill_helper
(
k_nope
,
k_pe_expanded
,
dim
=
2
)
descale_shape
=
(
attn_metadata
.
prefill
.
query_start_loc
.
numel
()
-
1
,
q
.
shape
[
1
])
else
:
q_descale
=
q_descale
.
expand
(
descale_shape
)
k
=
torch
.
cat
((
k_nope
,
k_pe_expanded
),
dim
=-
1
)
k_descale
=
k_descale
.
expand
(
descale_shape
)
else
:
v_descale
=
v_descale
.
expand
(
descale_shape
)
k
=
torch
.
cat
((
k_nope
,
k_pe_expanded
),
dim
=-
1
)
q
=
q
.
to
(
torch
.
float8_e4m3fn
)
k
=
k
.
to
(
torch
.
float8_e4m3fn
)
if
use_flash_fp8_arch
:
v
=
v
.
to
(
torch
.
float8_e4m3fn
)
q_descale
=
None
k_descale
=
None
v_descale
=
None
if
not
use_fused_fp8_op
:
q
=
q
.
to
(
torch
.
float8_e4m3fn
)
k
=
k
.
to
(
torch
.
float8_e4m3fn
)
v
=
v
.
to
(
torch
.
float8_e4m3fn
)
attn_output
,
attn_softmax_lse
=
\
attn_output
,
attn_softmax_lse
=
\
self
.
_flash_attn_varlen_diff_headdims
(
self
.
_flash_attn_varlen_diff_headdims
(
q
=
q
,
q
=
q
,
...
@@ -1134,32 +1145,41 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -1134,32 +1145,41 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_nope
=
self
.
kv_b_proj
(
kv_c_normed
)[
0
].
view
(
\
kv_nope
=
self
.
kv_b_proj
(
kv_c_normed
)[
0
].
view
(
\
-
1
,
self
.
num_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
)
-
1
,
self
.
num_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
)
k_nope
,
v
=
kv_nope
\
use_flash_fp8_arch
=
(
\
.
split
([
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
\
and
envs
.
VLLM_USE_FLASH_ATTN_FP8
if
envs
.
VLLM_USE_OPT_CAT
:
)
if
k_nope
.
shape
[
0
]
>
1024
:
use_fused_fp8_op
=
use_flash_fp8_arch
and
envs
.
VLLM_FUSE_CAT_AND_CAST_FP8
from
vllm.v1.attention.backends.mla.test_concat
import
lightop_concat_prefill_helper
k
=
lightop_concat_prefill_helper
(
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
)),
if
use_fused_fp8_op
:
dim
=
2
)
from
lightop
import
op
else
:
k_pe_expanded
=
k_pe
.
expand
(
k_pe
.
shape
[
0
],
self
.
num_heads
,
k_pe
.
shape
[
-
1
])
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
q
,
k
,
v
=
op
.
ds_fused_qkv_cast_fp8
(
dim
=-
1
)
q
,
kv_nope
,
k_pe_expanded
,
self
.
qk_nope_head_dim
,
self
.
v_head_dim
)
else
:
else
:
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
k_nope
,
v
=
kv_nope
.
split
([
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
if
envs
.
VLLM_USE_OPT_CAT
:
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
envs
.
VLLM_USE_FLASH_ATTN_FP8
:
if
k_nope
.
shape
[
0
]
>
1024
:
q_descale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
q
.
device
)
from
vllm.v1.attention.backends.mla.test_concat
import
lightop_concat_prefill_helper
k_descale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
q
.
device
)
k
=
lightop_concat_prefill_helper
(
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
)),
dim
=
2
)
v_descale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
q
.
device
)
else
:
descale_shape
=
(
attn_metadata
.
prefill
.
query_start_loc
.
numel
()
-
1
,
q
.
shape
[
1
])
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
q_descale
=
q_descale
.
expand
(
descale_shape
)
else
:
k_descale
=
k_descale
.
expand
(
descale_shape
)
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
v_descale
=
v_descale
.
expand
(
descale_shape
)
if
use_flash_fp8_arch
:
q
=
q
.
to
(
torch
.
float8_e4m3fn
)
q_descale
=
None
k
=
k
.
to
(
torch
.
float8_e4m3fn
)
k_descale
=
None
v
=
v
.
to
(
torch
.
float8_e4m3fn
)
v_descale
=
None
if
not
use_fused_fp8_op
:
q
=
q
.
to
(
torch
.
float8_e4m3fn
)
k
=
k
.
to
(
torch
.
float8_e4m3fn
)
v
=
v
.
to
(
torch
.
float8_e4m3fn
)
output
=
self
.
_flash_attn_varlen_diff_headdims
(
output
=
self
.
_flash_attn_varlen_diff_headdims
(
q
=
q
,
q
=
q
,
k
=
k
,
k
=
k
,
...
@@ -1270,7 +1290,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -1270,7 +1290,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
has_decode
=
attn_metadata
.
num_decodes
>
0
has_decode
=
attn_metadata
.
num_decodes
>
0
has_prefill
=
attn_metadata
.
num_prefills
>
0
has_prefill
=
attn_metadata
.
num_prefills
>
0
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
prefill_k_pe
=
k_pe
[
num_decode_tokens
:]
prefill_k_pe
=
k_pe
[
num_decode_tokens
:]
if
not
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
if
not
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
decode_q
=
q
[:
num_decode_tokens
]
decode_q
=
q
[:
num_decode_tokens
]
...
@@ -1356,7 +1375,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -1356,7 +1375,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
False
,
False
,
1e-6
,
1e-6
,
)
)
if
has_prefill
:
if
has_prefill
:
if
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
if
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
prefill_k_c_normed
=
key_normed
[:
num_actual_toks
,
...]
prefill_k_c_normed
=
key_normed
[:
num_actual_toks
,
...]
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
a34bff19
...
@@ -3,6 +3,8 @@
...
@@ -3,6 +3,8 @@
import
os
import
os
import
copy
import
copy
import
queue
import
threading
import
gc
import
gc
import
time
import
time
import
weakref
import
weakref
...
@@ -323,6 +325,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
...
@@ -323,6 +325,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
device
=
"cpu"
,
device
=
"cpu"
,
pin_memory
=
self
.
pin_memory
)
pin_memory
=
self
.
pin_memory
)
self
.
seq_lens_np
=
self
.
seq_lens_cpu
.
numpy
()
self
.
seq_lens_np
=
self
.
seq_lens_cpu
.
numpy
()
self
.
_recv_queue
:
queue
.
Queue
[
IntermediateTensors
]
=
queue
.
Queue
()
self
.
_recv_thread
:
Optional
[
threading
.
Thread
]
=
None
self
.
_recv_stream
:
Optional
[
torch
.
cuda
.
Stream
]
=
None
self
.
_recv_event
:
Optional
[
torch
.
cuda
.
Event
]
=
None
# Layer pairings for cross-layer KV sharing.
# Layer pairings for cross-layer KV sharing.
# If an Attention layer `layer_name` is in the keys of this dict, it
# If an Attention layer `layer_name` is in the keys of this dict, it
...
@@ -730,8 +736,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
...
@@ -730,8 +736,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
self
.
positions_cpu
[:
total_num_scheduled_tokens
],
self
.
positions_cpu
[:
total_num_scheduled_tokens
],
non_blocking
=
True
)
non_blocking
=
True
)
self
.
query_start_loc
[:
num_reqs
+
1
].
c
opy_
(
current_cpu_slice_clone
=
self
.
query_start_loc
_cpu
[:
num_reqs
+
1
].
c
lone
()
self
.
query_start_loc
_cpu
[:
num_reqs
+
1
],
non_blocking
=
True
)
self
.
query_start_loc
[:
num_reqs
+
1
]
.
copy_
(
current_cpu_slice_clone
,
non_blocking
=
True
)
self
.
seq_lens
[:
num_reqs
].
copy_
(
self
.
seq_lens_cpu
[:
num_reqs
],
self
.
seq_lens
[:
num_reqs
].
copy_
(
self
.
seq_lens_cpu
[:
num_reqs
],
non_blocking
=
True
)
non_blocking
=
True
)
...
@@ -740,7 +746,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
...
@@ -740,7 +746,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# Note: pad query_start_loc to be non-decreasing, as kernels
# Note: pad query_start_loc to be non-decreasing, as kernels
# like FlashAttention requires that
# like FlashAttention requires that
self
.
query_start_loc
[
num_reqs
+
1
:].
fill_
(
self
.
query_start_loc
[
num_reqs
+
1
:].
fill_
(
self
.
query_start_loc_cpu
[
num_reqs
].
item
())
current_cpu_slice_clone
[
num_reqs
].
item
())
query_start_loc
=
self
.
query_start_loc
[:
num_reqs
+
1
]
query_start_loc
=
self
.
query_start_loc
[:
num_reqs
+
1
]
seq_lens
=
self
.
seq_lens
[:
num_reqs
]
seq_lens
=
self
.
seq_lens
[:
num_reqs
]
...
@@ -1337,6 +1343,22 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
...
@@ -1337,6 +1343,22 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
finished_recving
=
finished_recving
,
finished_recving
=
finished_recving
,
)
)
def
_recv_tensor_dict
(
self
):
with
torch
.
cuda
.
stream
(
self
.
_recv_stream
):
intermediate_tensors
=
IntermediateTensors
(
get_pp_group
().
recv_tensor_dict
(
all_gather_group
=
get_tp_group
(),
)
)
self
.
_recv_event
.
record
(
self
.
_recv_stream
)
self
.
_recv_queue
.
put
(
intermediate_tensors
)
def
_tensor_dict_recv_thread
(
self
):
torch
.
cuda
.
set_device
(
self
.
device
)
self_rank
=
get_pp_group
().
rank_in_group
while
True
:
self
.
_recv_tensor_dict
()
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
execute_model
(
def
execute_model
(
self
,
self
,
...
@@ -1426,9 +1448,18 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
...
@@ -1426,9 +1448,18 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
if
get_pp_group
().
is_first_rank
:
if
get_pp_group
().
is_first_rank
:
intermediate_tensors
=
None
intermediate_tensors
=
None
else
:
else
:
def
_recv_tensor_dict
():
return
self
.
_recv_queue
.
get
()
if
self
.
_recv_thread
is
None
:
self
.
_recv_stream
=
torch
.
cuda
.
Stream
()
self
.
_recv_event
=
torch
.
cuda
.
Event
()
self
.
_recv_thread
=
threading
.
Thread
(
target
=
self
.
_tensor_dict_recv_thread
,
daemon
=
True
,
name
=
"pp_recv_thread"
)
self
.
_recv_thread
.
start
()
intermediate_tensors
=
_recv_tensor_dict
()
torch
.
cuda
.
current_stream
().
wait_event
(
self
.
_recv_event
)
intermediate_tensors
=
self
.
sync_and_slice_intermediate_tensors
(
intermediate_tensors
=
self
.
sync_and_slice_intermediate_tensors
(
num_input_tokens
,
intermediate_tensors
,
True
)
num_input_tokens
,
intermediate_tensors
,
True
)
# Some attention backends only support CUDA Graphs in pure decode.
# Some attention backends only support CUDA Graphs in pure decode.
# If attention doesn't support CUDA Graphs for this batch, but we
# If attention doesn't support CUDA Graphs for this batch, but we
# compiled with full CUDA graphs, we have to skip them entirely.
# compiled with full CUDA graphs, we have to skip them entirely.
...
@@ -1494,6 +1525,9 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
...
@@ -1494,6 +1525,9 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
hidden_states
,
aux_hidden_states
=
model_output
hidden_states
,
aux_hidden_states
=
model_output
else
:
else
:
hidden_states
=
model_output
hidden_states
=
model_output
if
isinstance
(
model_output
,
IntermediateTensors
):
residual_clone
=
model_output
.
tensors
[
"residual"
].
clone
()
hidden_states
.
tensors
[
"residual"
]
=
residual_clone
aux_hidden_states
=
None
aux_hidden_states
=
None
# Broadcast PP output for external_launcher (torchrun)
# Broadcast PP output for external_launcher (torchrun)
...
@@ -3290,6 +3324,16 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
...
@@ -3290,6 +3324,16 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
if
get_pp_group
().
is_first_rank
:
if
get_pp_group
().
is_first_rank
:
intermediate_tensors
=
None
intermediate_tensors
=
None
else
:
else
:
def
_recv_tensor_dict
():
return
self
.
_recv_queue
.
get
()
if
self
.
_recv_thread
is
None
:
self
.
_recv_stream
=
torch
.
cuda
.
Stream
()
self
.
_recv_event
=
torch
.
cuda
.
Event
()
self
.
_recv_thread
=
threading
.
Thread
(
target
=
self
.
_tensor_dict_recv_thread
,
daemon
=
True
,
name
=
"pp_recv_thread"
)
self
.
_recv_thread
.
start
()
intermediate_tensors
=
_recv_tensor_dict
()
torch
.
cuda
.
current_stream
().
wait_event
(
self
.
_recv_event
)
intermediate_tensors
=
self
.
sync_and_slice_intermediate_tensors
(
intermediate_tensors
=
self
.
sync_and_slice_intermediate_tensors
(
num_input_tokens
,
intermediate_tensors
,
True
)
num_input_tokens
,
intermediate_tensors
,
True
)
...
@@ -3358,6 +3402,9 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
...
@@ -3358,6 +3402,9 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
hidden_states
,
aux_hidden_states
=
model_output
hidden_states
,
aux_hidden_states
=
model_output
else
:
else
:
hidden_states
=
model_output
hidden_states
=
model_output
if
isinstance
(
model_output
,
IntermediateTensors
):
residual_clone
=
model_output
.
tensors
[
"residual"
].
clone
()
hidden_states
.
tensors
[
"residual"
]
=
residual_clone
aux_hidden_states
=
None
aux_hidden_states
=
None
# Broadcast PP output for external_launcher (torchrun)
# Broadcast PP output for external_launcher (torchrun)
...
...
vllm/v1/worker/gpu_worker.py
View file @
a34bff19
...
@@ -2,7 +2,9 @@
...
@@ -2,7 +2,9 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A GPU worker class."""
"""A GPU worker class."""
import
gc
import
gc
import
queue
import
os
import
os
import
threading
from
typing
import
TYPE_CHECKING
,
Optional
from
typing
import
TYPE_CHECKING
,
Optional
import
torch
import
torch
...
@@ -81,6 +83,10 @@ class Worker(WorkerBase):
...
@@ -81,6 +83,10 @@ class Worker(WorkerBase):
torch_profiler_trace_dir
,
use_gzip
=
True
))
torch_profiler_trace_dir
,
use_gzip
=
True
))
else
:
else
:
self
.
profiler
=
None
self
.
profiler
=
None
self
.
_send_queue
:
queue
.
Queue
[
tuple
[
IntermediateTensors
,
SchedulerOutput
,
torch
.
cuda
.
Event
]]
=
queue
.
Queue
(
1
)
self
.
_send_thread
:
Optional
[
threading
.
Thread
]
=
None
self
.
_send_stream
:
Optional
[
torch
.
cuda
.
Stream
]
=
None
self
.
_send_event
:
Optional
[
torch
.
cuda
.
Event
]
=
None
def
sleep
(
self
,
level
:
int
=
1
)
->
None
:
def
sleep
(
self
,
level
:
int
=
1
)
->
None
:
free_bytes_before_sleep
=
torch
.
cuda
.
mem_get_info
()[
0
]
free_bytes_before_sleep
=
torch
.
cuda
.
mem_get_info
()[
0
]
...
@@ -305,16 +311,33 @@ class Worker(WorkerBase):
...
@@ -305,16 +311,33 @@ class Worker(WorkerBase):
def
get_model
(
self
)
->
nn
.
Module
:
def
get_model
(
self
)
->
nn
.
Module
:
return
self
.
model_runner
.
get_model
()
return
self
.
model_runner
.
get_model
()
def
_send_tensor_dict
(
self
):
intermediate_tensors
,
scheduler_output
,
event
=
self
.
_send_queue
.
get
()
assert
event
is
not
None
# 等待event在GPU执行完成
event
.
synchronize
()
def
_send_tensor_dict
():
get_pp_group
().
send_tensor_dict
(
intermediate_tensors
.
tensors
,
all_gather_group
=
get_tp_group
(),
)
_send_tensor_dict
()
def
_tensor_dict_send_thread
(
self
):
torch
.
cuda
.
set_device
(
self
.
device
)
torch
.
cuda
.
set_stream
(
self
.
_send_stream
)
while
True
:
self
.
_send_tensor_dict
()
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
execute_model
(
def
execute_model
(
self
,
self
,
scheduler_output
:
"SchedulerOutput"
,
scheduler_output
:
"SchedulerOutput"
,
)
->
Optional
[
ModelRunnerOutput
]:
)
->
Optional
[
ModelRunnerOutput
]:
intermediate_tensors
=
None
intermediate_tensors
=
None
if
not
get_pp_group
().
is_first_rank
:
intermediate_tensors
=
IntermediateTensors
(
get_pp_group
().
recv_tensor_dict
(
all_gather_group
=
get_tp_group
()))
if
envs
.
VLLM_ZERO_OVERHEAD
:
if
envs
.
VLLM_ZERO_OVERHEAD
:
use_stream
=
zero_overhead_stream
(
self
.
device
)
use_stream
=
zero_overhead_stream
(
self
.
device
)
with
torch
.
cuda
.
stream
(
use_stream
):
with
torch
.
cuda
.
stream
(
use_stream
):
...
@@ -327,8 +350,14 @@ class Worker(WorkerBase):
...
@@ -327,8 +350,14 @@ class Worker(WorkerBase):
if
parallel_config
.
distributed_executor_backend
!=
"external_launcher"
\
if
parallel_config
.
distributed_executor_backend
!=
"external_launcher"
\
and
not
get_pp_group
().
is_last_rank
:
and
not
get_pp_group
().
is_last_rank
:
assert
isinstance
(
output
,
IntermediateTensors
)
assert
isinstance
(
output
,
IntermediateTensors
)
get_pp_group
().
send_tensor_dict
(
output
.
tensors
,
if
self
.
_send_thread
is
None
:
all_gather_group
=
get_tp_group
())
self
.
_send_stream
=
torch
.
cuda
.
Stream
()
self
.
_send_thread
=
threading
.
Thread
(
target
=
self
.
_tensor_dict_send_thread
,
daemon
=
True
,
name
=
"pp_send_thread"
)
self
.
_send_thread
.
start
()
self
.
_send_event
=
torch
.
cuda
.
Event
()
send_event
=
self
.
_send_event
send_event
.
record
()
self
.
_send_queue
.
put
((
output
,
scheduler_output
,
send_event
))
return
None
return
None
assert
isinstance
(
output
,
ModelRunnerOutput
)
assert
isinstance
(
output
,
ModelRunnerOutput
)
return
output
if
self
.
is_driver_worker
else
None
return
output
if
self
.
is_driver_worker
else
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment