Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
17dc9c7f
Unverified
Commit
17dc9c7f
authored
Mar 04, 2026
by
Harry Mellor
Committed by
GitHub
Mar 04, 2026
Browse files
[CI] Bump `mypy` version (#34950)
Signed-off-by:
Harry Mellor
<
19981378+hmellor@users.noreply.github.com
>
parent
7eca8591
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
90 additions
and
61 deletions
+90
-61
.pre-commit-config.yaml
.pre-commit-config.yaml
+1
-1
tests/kernels/core/test_pos_encoding.py
tests/kernels/core/test_pos_encoding.py
+2
-5
tests/kernels/core/test_rotary_embedding.py
tests/kernels/core/test_rotary_embedding.py
+2
-2
tests/kernels/mamba/test_mamba_ssm.py
tests/kernels/mamba/test_mamba_ssm.py
+7
-7
tests/kernels/quantization/test_fp8_quant.py
tests/kernels/quantization/test_fp8_quant.py
+3
-3
vllm/config/parallel.py
vllm/config/parallel.py
+13
-3
vllm/distributed/elastic_ep/elastic_state.py
vllm/distributed/elastic_ep/elastic_state.py
+25
-11
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+1
-0
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+2
-0
vllm/v1/attention/backends/gdn_attn.py
vllm/v1/attention/backends/gdn_attn.py
+12
-11
vllm/v1/attention/backends/mamba_attn.py
vllm/v1/attention/backends/mamba_attn.py
+6
-6
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+6
-6
vllm/v1/sample/logits_processor/__init__.py
vllm/v1/sample/logits_processor/__init__.py
+10
-6
No files found.
.pre-commit-config.yaml
View file @
17dc9c7f
...
@@ -55,7 +55,7 @@ repos:
...
@@ -55,7 +55,7 @@ repos:
language
:
python
language
:
python
types_or
:
[
python
,
pyi
]
types_or
:
[
python
,
pyi
]
require_serial
:
true
require_serial
:
true
additional_dependencies
:
[
mypy==1.1
1.1
,
regex
,
types-cachetools
,
types-setuptools
,
types-PyYAML
,
types-requests
,
types-torch
,
pydantic
]
additional_dependencies
:
[
"
mypy
[faster-cache]
==1.1
5.0"
,
regex
,
types-cachetools
,
types-setuptools
,
types-PyYAML
,
types-requests
,
types-torch
,
pydantic
]
-
id
:
mypy-3.10
# TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
-
id
:
mypy-3.10
# TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name
:
Run mypy for Python
3.10
name
:
Run mypy for Python
3.10
entry
:
python tools/pre_commit/mypy.py 1 "3.10"
entry
:
python tools/pre_commit/mypy.py 1 "3.10"
...
...
tests/kernels/core/test_pos_encoding.py
View file @
17dc9c7f
...
@@ -94,12 +94,9 @@ def test_rotary_embedding(
...
@@ -94,12 +94,9 @@ def test_rotary_embedding(
positions
=
torch
.
randint
(
0
,
max_position
,
(
batch_size
,
seq_len
))
positions
=
torch
.
randint
(
0
,
max_position
,
(
batch_size
,
seq_len
))
query_shape
=
tensor_shape_fn
(
batch_size
,
seq_len
,
num_heads
,
head_size
)
query_shape
=
tensor_shape_fn
(
batch_size
,
seq_len
,
num_heads
,
head_size
)
query
=
torch
.
randn
(
query_shape
,
dtype
=
dtype
)
key
=
torch
.
randn_like
(
query
)
if
use_key
else
None
# slice tensor if required, noop otherwise
# slice tensor if required, noop otherwise
query
=
query
[...,
:
head_size
]
query
=
torch
.
randn
(
query_shape
,
dtype
=
dtype
)
[...,
:
head_size
]
key
=
key
[...,
:
head_size
]
if
use_key
else
None
key
=
torch
.
randn_like
(
query
)
[...,
:
head_size
]
if
use_key
else
None
# NOTE(woosuk): The reference implementation should be executed first
# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
# because the custom kernel is in-place.
...
...
tests/kernels/core/test_rotary_embedding.py
View file @
17dc9c7f
...
@@ -62,7 +62,7 @@ def test_rotary_embedding_opcheck(
...
@@ -62,7 +62,7 @@ def test_rotary_embedding_opcheck(
)
)
key
=
torch
.
randn_like
(
query
)
if
use_key
else
None
key
=
torch
.
randn_like
(
query
)
if
use_key
else
None
query
=
query
[...,
:
head_size
]
query
=
query
[...,
:
head_size
]
key
=
key
[...,
:
head_size
]
if
use_key
else
None
key
=
key
[...,
:
head_size
]
if
key
is
not
None
else
None
rotary_embedding_opcheck
(
rot
,
positions
,
query
,
key
)
rotary_embedding_opcheck
(
rot
,
positions
,
query
,
key
)
...
@@ -73,5 +73,5 @@ def test_rotary_embedding_opcheck(
...
@@ -73,5 +73,5 @@ def test_rotary_embedding_opcheck(
rot
,
rot
,
positions
,
positions
,
query
.
flatten
(
start_dim
=-
2
),
query
.
flatten
(
start_dim
=-
2
),
key
.
flatten
(
start_dim
=-
2
)
if
use_key
else
None
,
key
.
flatten
(
start_dim
=-
2
)
if
key
is
not
None
else
None
,
)
)
tests/kernels/mamba/test_mamba_ssm.py
View file @
17dc9c7f
...
@@ -298,13 +298,13 @@ def test_selective_scan(
...
@@ -298,13 +298,13 @@ def test_selective_scan(
C
=
torch
.
randn
(
C_shape
,
device
=
device
,
dtype
=
wtype
if
not
is_variable_C
else
itype
)
C
=
torch
.
randn
(
C_shape
,
device
=
device
,
dtype
=
wtype
if
not
is_variable_C
else
itype
)
C_ref
=
C
.
clone
()
C_ref
=
C
.
clone
()
D
=
torch
.
randn
(
dim
,
device
=
device
,
dtype
=
torch
.
float32
)
if
has_D
else
None
D
=
torch
.
randn
(
dim
,
device
=
device
,
dtype
=
torch
.
float32
)
if
has_D
else
None
D_ref
=
D
.
clone
()
D_ref
=
D
.
clone
()
if
D
is
not
None
else
None
z
=
(
z
=
(
torch
.
randn
(
batch_size
,
dim
,
seqlen
,
device
=
device
,
dtype
=
itype
)
torch
.
randn
(
batch_size
,
dim
,
seqlen
,
device
=
device
,
dtype
=
itype
)
if
has_z
if
has_z
else
None
else
None
)
)
z_ref
=
z
.
clone
()
if
has_z
else
None
z_ref
=
z
.
clone
()
if
z
is
not
None
else
None
delta_bias
=
(
delta_bias
=
(
(
0.5
*
torch
.
rand
(
dim
,
device
=
device
,
dtype
=
torch
.
float32
))
(
0.5
*
torch
.
rand
(
dim
,
device
=
device
,
dtype
=
torch
.
float32
))
if
has_delta_bias
if
has_delta_bias
...
@@ -493,7 +493,7 @@ def test_selective_state_update_varlen(dim, dstate, has_z, itype, max_seq_len):
...
@@ -493,7 +493,7 @@ def test_selective_state_update_varlen(dim, dstate, has_z, itype, max_seq_len):
B
[
idx
:
idx
+
1
],
B
[
idx
:
idx
+
1
],
C
[
idx
:
idx
+
1
],
C
[
idx
:
idx
+
1
],
D
=
D
,
D
=
D
,
z
=
z
[
idx
:
idx
+
1
]
if
has_z
else
None
,
z
=
z
[
idx
:
idx
+
1
]
if
z
is
not
None
else
None
,
dt_bias
=
dt_bias
,
dt_bias
=
dt_bias
,
dt_softplus
=
True
,
dt_softplus
=
True
,
)
)
...
@@ -578,7 +578,7 @@ def test_selective_scan_varlen(
...
@@ -578,7 +578,7 @@ def test_selective_scan_varlen(
C
=
torch
.
randn
(
C_shape
,
device
=
device
,
dtype
=
wtype
if
not
is_variable_C
else
itype
)
C
=
torch
.
randn
(
C_shape
,
device
=
device
,
dtype
=
wtype
if
not
is_variable_C
else
itype
)
C_ref
=
C
.
clone
()
C_ref
=
C
.
clone
()
D
=
torch
.
randn
(
dim
,
device
=
device
,
dtype
=
torch
.
float32
)
if
has_D
else
None
D
=
torch
.
randn
(
dim
,
device
=
device
,
dtype
=
torch
.
float32
)
if
has_D
else
None
D_ref
=
D
.
clone
()
D_ref
=
D
.
clone
()
if
D
is
not
None
else
None
z
=
torch
.
randn
(
dim
,
seqlen
,
device
=
device
,
dtype
=
itype
)
z
=
torch
.
randn
(
dim
,
seqlen
,
device
=
device
,
dtype
=
itype
)
z_ref
=
z
.
clone
()
z_ref
=
z
.
clone
()
delta_bias
=
(
delta_bias
=
(
...
@@ -750,7 +750,7 @@ def test_selective_state_update_with_batch_indices(
...
@@ -750,7 +750,7 @@ def test_selective_state_update_with_batch_indices(
B
[:
batch_size
],
B
[:
batch_size
],
C
[:
batch_size
],
C
[:
batch_size
],
D
=
D
,
D
=
D
,
z
=
z
[:
batch_size
],
z
=
z
[:
batch_size
]
if
z
is
not
None
else
None
,
dt_bias
=
dt_bias
,
dt_bias
=
dt_bias
,
dt_softplus
=
True
,
dt_softplus
=
True
,
)
)
...
@@ -934,7 +934,7 @@ def test_selective_state_update_with_num_accepted_tokens(
...
@@ -934,7 +934,7 @@ def test_selective_state_update_with_num_accepted_tokens(
B
[
global_idx
:
global_idx
+
1
],
B
[
global_idx
:
global_idx
+
1
],
C
[
global_idx
:
global_idx
+
1
],
C
[
global_idx
:
global_idx
+
1
],
D
=
D
,
D
=
D
,
z
=
z
[
global_idx
:
global_idx
+
1
]
if
has_z
else
None
,
z
=
z
[
global_idx
:
global_idx
+
1
]
if
z
is
not
None
else
None
,
dt_bias
=
dt_bias
,
dt_bias
=
dt_bias
,
dt_softplus
=
True
,
dt_softplus
=
True
,
)
)
...
@@ -1061,7 +1061,7 @@ def test_selective_state_update_varlen_with_num_accepted(
...
@@ -1061,7 +1061,7 @@ def test_selective_state_update_varlen_with_num_accepted(
B
[
global_idx
:
global_idx
+
1
],
B
[
global_idx
:
global_idx
+
1
],
C
[
global_idx
:
global_idx
+
1
],
C
[
global_idx
:
global_idx
+
1
],
D
=
D
,
D
=
D
,
z
=
z
[
global_idx
:
global_idx
+
1
]
if
has_z
else
None
,
z
=
z
[
global_idx
:
global_idx
+
1
]
if
z
is
not
None
else
None
,
dt_bias
=
dt_bias
,
dt_bias
=
dt_bias
,
dt_softplus
=
True
,
dt_softplus
=
True
,
)
)
...
...
tests/kernels/quantization/test_fp8_quant.py
View file @
17dc9c7f
...
@@ -57,11 +57,11 @@ def opcheck_fp8_quant(
...
@@ -57,11 +57,11 @@ def opcheck_fp8_quant(
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
HIDDEN_SIZES
)
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
HIDDEN_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"scale_ub"
,
SCALE_UBS
)
@
pytest
.
mark
.
parametrize
(
"
do_
scale_ub"
,
SCALE_UBS
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_dynamic_per_token_fp8_quant
(
def
test_dynamic_per_token_fp8_quant
(
num_tokens
:
int
,
hidden_size
:
int
,
dtype
:
torch
.
dtype
,
scale_ub
:
bool
,
seed
:
int
num_tokens
:
int
,
hidden_size
:
int
,
dtype
:
torch
.
dtype
,
do_
scale_ub
:
bool
,
seed
:
int
)
->
None
:
)
->
None
:
set_random_seed
(
seed
)
set_random_seed
(
seed
)
...
@@ -70,7 +70,7 @@ def test_dynamic_per_token_fp8_quant(
...
@@ -70,7 +70,7 @@ def test_dynamic_per_token_fp8_quant(
)
# avoid nans
)
# avoid nans
scale_ub
=
(
scale_ub
=
(
torch
.
mean
(
x
).
to
(
dtype
=
torch
.
float32
,
device
=
"cuda"
)
if
scale_ub
else
None
torch
.
mean
(
x
).
to
(
dtype
=
torch
.
float32
,
device
=
"cuda"
)
if
do_
scale_ub
else
None
)
)
ref_out
,
ref_scales
=
ref_dynamic_per_token_quant
(
x
,
FP8_DTYPE
,
scale_ub
)
ref_out
,
ref_scales
=
ref_dynamic_per_token_quant
(
x
,
FP8_DTYPE
,
scale_ub
)
ops_out
,
ops_scales
=
ops
.
scaled_fp8_quant
(
ops_out
,
ops_scales
=
ops
.
scaled_fp8_quant
(
...
...
vllm/config/parallel.py
View file @
17dc9c7f
...
@@ -3,11 +3,11 @@
...
@@ -3,11 +3,11 @@
import
os
import
os
from
collections.abc
import
Callable
from
collections.abc
import
Callable
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
,
overload
import
torch
import
torch
from
pydantic
import
Field
,
field_validator
,
model_validator
from
pydantic
import
Field
,
field_validator
,
model_validator
from
torch.distributed
import
ProcessGroup
,
ReduceOp
from
torch.distributed
import
ProcessGroup
,
ReduceOp
,
Store
from
typing_extensions
import
Self
from
typing_extensions
import
Self
import
vllm.envs
as
envs
import
vllm.envs
as
envs
...
@@ -507,7 +507,17 @@ class ParallelConfig:
...
@@ -507,7 +507,17 @@ class ParallelConfig:
def
get_next_stateless_eplb_group_port
(
self
)
->
list
[
int
]:
def
get_next_stateless_eplb_group_port
(
self
)
->
list
[
int
]:
return
self
.
_stateless_eplb_group_port_list
.
pop
()
return
self
.
_stateless_eplb_group_port_list
.
pop
()
def
stateless_init_dp_group
(
self
,
return_store
:
bool
=
False
)
->
ProcessGroup
:
@
overload
def
stateless_init_dp_group
(
self
,
return_store
:
Literal
[
False
]
=
...
)
->
ProcessGroup
:
...
@
overload
def
stateless_init_dp_group
(
self
,
return_store
:
Literal
[
True
]
=
...
)
->
tuple
[
ProcessGroup
,
Store
]:
...
def
stateless_init_dp_group
(
self
,
return_store
:
bool
=
False
)
->
ProcessGroup
|
tuple
[
ProcessGroup
,
Store
]:
# NOTE: In high-concurrency scenarios multiple processes
# NOTE: In high-concurrency scenarios multiple processes
# can pick the same (currently free) port through a race
# can pick the same (currently free) port through a race
# condition when calling `get_open_port()`. When the first
# condition when calling `get_open_port()`. When the first
...
...
vllm/distributed/elastic_ep/elastic_state.py
View file @
17dc9c7f
...
@@ -4,7 +4,7 @@ import enum
...
@@ -4,7 +4,7 @@ import enum
import
time
import
time
import
weakref
import
weakref
from
datetime
import
timedelta
from
datetime
import
timedelta
from
typing
import
TYPE_CHECKING
,
Literal
from
typing
import
TYPE_CHECKING
,
Literal
,
TypeAlias
import
torch.distributed
import
torch.distributed
...
@@ -61,6 +61,14 @@ class ScaleDownRemovingEngineState(enum.IntEnum):
...
@@ -61,6 +61,14 @@ class ScaleDownRemovingEngineState(enum.IntEnum):
COMPLETE
=
2
COMPLETE
=
2
EngineState
:
TypeAlias
=
(
ScaleUpExistingEngineState
|
ScaleUpNewEngineState
|
ScaleDownRemainingEngineState
|
ScaleDownRemovingEngineState
)
class
_BarrierTimeoutError
(
RuntimeError
):
class
_BarrierTimeoutError
(
RuntimeError
):
"""
"""
Exception raised for timeout
Exception raised for timeout
...
@@ -87,14 +95,13 @@ class ElasticEPScalingState:
...
@@ -87,14 +95,13 @@ class ElasticEPScalingState:
self
.
old_dp_group
=
self
.
engine_core
.
dp_group
if
worker_type
!=
"new"
else
None
self
.
old_dp_group
=
self
.
engine_core
.
dp_group
if
worker_type
!=
"new"
else
None
self
.
old_dp_store
=
self
.
engine_core
.
dp_store
if
worker_type
!=
"new"
else
None
self
.
old_dp_store
=
self
.
engine_core
.
dp_store
if
worker_type
!=
"new"
else
None
self
.
new_parallel_config
:
ParallelConfig
=
new_parallel_config
self
.
new_parallel_config
:
ParallelConfig
=
new_parallel_config
self
.
new_dp_group
:
torch
.
distributed
.
ProcessGroup
|
None
=
(
self
.
new_dp_group
=
self
.
engine_core
.
dp_group
if
worker_type
==
"new"
else
None
self
.
engine_core
.
dp_group
if
worker_type
==
"new"
else
None
)
self
.
new_dp_store
=
self
.
engine_core
.
dp_store
if
worker_type
==
"new"
else
None
self
.
new_dp_store
=
self
.
engine_core
.
dp_store
if
worker_type
==
"new"
else
None
self
.
worker_type
=
worker_type
self
.
worker_type
=
worker_type
self
.
scale_type
=
scale_type
self
.
scale_type
=
scale_type
self
.
reconfig_request
=
reconfig_request
self
.
reconfig_request
=
reconfig_request
self
.
state
:
EngineState
if
scale_type
==
"scale_up"
:
if
scale_type
==
"scale_up"
:
self
.
state
=
(
self
.
state
=
(
ScaleUpNewEngineState
.
PREPARE
ScaleUpNewEngineState
.
PREPARE
...
@@ -182,9 +189,9 @@ class ElasticEPScalingState:
...
@@ -182,9 +189,9 @@ class ElasticEPScalingState:
engine step, and will synchronize with the other EngineCores in the
engine step, and will synchronize with the other EngineCores in the
next step with a barrier without timeout.
next step with a barrier without timeout.
"""
"""
dp_store
=
self
.
new_dp_store
if
use_new_group
else
self
.
old_dp_store
dp_group
=
self
.
new_dp_group
if
use_new_group
else
self
.
old_dp_group
dp_group
=
self
.
new_dp_group
if
use_new_group
else
self
.
old_dp_group
assert
dp_group
is
not
None
dp_store
=
self
.
new_dp_store
if
use_new_group
else
self
.
old_dp_store
assert
dp_group
is
not
None
and
dp_store
is
not
None
group_rank
=
dp_group
.
rank
()
group_rank
=
dp_group
.
rank
()
group_size
=
dp_group
.
size
()
group_size
=
dp_group
.
size
()
...
@@ -212,6 +219,7 @@ class ElasticEPScalingState:
...
@@ -212,6 +219,7 @@ class ElasticEPScalingState:
def
_progress_existing_engine
(
self
)
->
bool
:
def
_progress_existing_engine
(
self
)
->
bool
:
state
=
self
.
state
state
=
self
.
state
assert
self
.
old_dp_group
is
not
None
and
self
.
old_dp_store
is
not
None
if
state
==
ScaleUpExistingEngineState
.
WAIT_NEW_CORE_ENGINES_INIT
:
if
state
==
ScaleUpExistingEngineState
.
WAIT_NEW_CORE_ENGINES_INIT
:
return
False
return
False
...
@@ -265,11 +273,12 @@ class ElasticEPScalingState:
...
@@ -265,11 +273,12 @@ class ElasticEPScalingState:
elif
state
==
ScaleUpExistingEngineState
.
SWITCH_AND_PREPARE
:
elif
state
==
ScaleUpExistingEngineState
.
SWITCH_AND_PREPARE
:
self
.
_switch_and_prepare
()
self
.
_switch_and_prepare
()
self
.
state
=
ScaleUpExistingEngineState
.
EPLB_RESHUFFLE
self
.
state
=
ScaleUpExistingEngineState
.
EPLB_RESHUFFLE
assert
self
.
new_dp_store
is
not
None
self
.
new_dp_store
.
add
(
"eep_barrier_engine_count"
,
1
)
self
.
new_dp_store
.
add
(
"eep_barrier_engine_count"
,
1
)
return
True
return
True
elif
state
==
ScaleUpExistingEngineState
.
EPLB_RESHUFFLE
:
elif
state
==
ScaleUpExistingEngineState
.
EPLB_RESHUFFLE
:
assert
self
.
new_dp_group
is
not
None
assert
self
.
new_dp_group
is
not
None
and
self
.
new_dp_store
is
not
None
if
(
if
(
int
(
self
.
new_dp_store
.
get
(
"eep_barrier_engine_count"
))
int
(
self
.
new_dp_store
.
get
(
"eep_barrier_engine_count"
))
<
self
.
new_dp_group
.
size
()
<
self
.
new_dp_group
.
size
()
...
@@ -292,7 +301,7 @@ class ElasticEPScalingState:
...
@@ -292,7 +301,7 @@ class ElasticEPScalingState:
def
_progress_new_engine
(
self
)
->
bool
:
def
_progress_new_engine
(
self
)
->
bool
:
state
=
self
.
state
state
=
self
.
state
assert
self
.
new_dp_group
is
not
None
assert
self
.
new_dp_group
is
not
None
and
self
.
new_dp_store
is
not
None
if
state
==
ScaleUpNewEngineState
.
PREPARE
:
if
state
==
ScaleUpNewEngineState
.
PREPARE
:
tensor
=
torch
.
tensor
([
0
,
0
,
0
],
dtype
=
torch
.
int32
,
device
=
"cpu"
)
tensor
=
torch
.
tensor
([
0
,
0
,
0
],
dtype
=
torch
.
int32
,
device
=
"cpu"
)
...
@@ -330,6 +339,7 @@ class ElasticEPScalingState:
...
@@ -330,6 +339,7 @@ class ElasticEPScalingState:
def
_progress_remaining_engine
(
self
)
->
bool
:
def
_progress_remaining_engine
(
self
)
->
bool
:
state
=
self
.
state
state
=
self
.
state
assert
self
.
old_dp_group
is
not
None
and
self
.
old_dp_store
is
not
None
if
state
==
ScaleDownRemainingEngineState
.
PREPARE
:
if
state
==
ScaleDownRemainingEngineState
.
PREPARE
:
self
.
state
=
ScaleDownRemainingEngineState
.
EPLB_RESHUFFLE
self
.
state
=
ScaleDownRemainingEngineState
.
EPLB_RESHUFFLE
...
@@ -369,6 +379,7 @@ class ElasticEPScalingState:
...
@@ -369,6 +379,7 @@ class ElasticEPScalingState:
def
_progress_removing_engine
(
self
)
->
bool
:
def
_progress_removing_engine
(
self
)
->
bool
:
state
=
self
.
state
state
=
self
.
state
assert
self
.
old_dp_group
is
not
None
and
self
.
old_dp_store
is
not
None
if
state
==
ScaleDownRemovingEngineState
.
PREPARE
:
if
state
==
ScaleDownRemovingEngineState
.
PREPARE
:
self
.
state
=
ScaleDownRemovingEngineState
.
EPLB_RESHUFFLE
self
.
state
=
ScaleDownRemovingEngineState
.
EPLB_RESHUFFLE
...
@@ -401,6 +412,7 @@ class ElasticEPScalingState:
...
@@ -401,6 +412,7 @@ class ElasticEPScalingState:
def
handle_notification
(
self
,
notification_type
:
EEPNotificationType
):
def
handle_notification
(
self
,
notification_type
:
EEPNotificationType
):
assert
self
.
worker_type
!=
"new"
assert
self
.
worker_type
!=
"new"
assert
self
.
old_dp_store
is
not
None
if
(
if
(
notification_type
==
EEPNotificationType
.
NEW_CORE_ENGINES_INIT_READY
notification_type
==
EEPNotificationType
.
NEW_CORE_ENGINES_INIT_READY
and
self
.
state
==
ScaleUpExistingEngineState
.
WAIT_NEW_CORE_ENGINES_INIT
and
self
.
state
==
ScaleUpExistingEngineState
.
WAIT_NEW_CORE_ENGINES_INIT
...
@@ -429,6 +441,7 @@ class ElasticEPScalingState:
...
@@ -429,6 +441,7 @@ class ElasticEPScalingState:
)
)
def
_create_standby_groups
(
self
):
def
_create_standby_groups
(
self
):
assert
self
.
old_dp_group
is
not
None
self
.
new_dp_group
,
self
.
new_dp_store
=
(
self
.
new_dp_group
,
self
.
new_dp_store
=
(
self
.
new_parallel_config
.
stateless_init_dp_group
(
return_store
=
True
)
self
.
new_parallel_config
.
stateless_init_dp_group
(
return_store
=
True
)
)
)
...
@@ -439,7 +452,7 @@ class ElasticEPScalingState:
...
@@ -439,7 +452,7 @@ class ElasticEPScalingState:
logger
.
info
(
"[Elastic EP] Created standby communication groups"
)
logger
.
info
(
"[Elastic EP] Created standby communication groups"
)
def
_transfer_weights
(
self
):
def
_transfer_weights
(
self
):
assert
self
.
reconfig_request
is
not
None
assert
self
.
reconfig_request
is
not
None
and
self
.
old_dp_group
is
not
None
old_dp_size
=
self
.
old_dp_group
.
size
()
old_dp_size
=
self
.
old_dp_group
.
size
()
new_dp_size
=
self
.
reconfig_request
.
new_data_parallel_size
new_dp_size
=
self
.
reconfig_request
.
new_data_parallel_size
...
@@ -450,6 +463,7 @@ class ElasticEPScalingState:
...
@@ -450,6 +463,7 @@ class ElasticEPScalingState:
logger
.
info
(
"[Elastic EP] Transferred weights to new workers"
)
logger
.
info
(
"[Elastic EP] Transferred weights to new workers"
)
def
_transfer_expert_mapping
(
self
):
def
_transfer_expert_mapping
(
self
):
assert
self
.
old_dp_group
is
not
None
self
.
model_executor
.
collective_rpc
(
self
.
model_executor
.
collective_rpc
(
"elastic_ep_execute"
,
args
=
(
"broadcast_expert_mapping"
,)
"elastic_ep_execute"
,
args
=
(
"broadcast_expert_mapping"
,)
)
)
...
@@ -458,7 +472,7 @@ class ElasticEPScalingState:
...
@@ -458,7 +472,7 @@ class ElasticEPScalingState:
def
_sync_kv_cache_memory_size
(
self
):
def
_sync_kv_cache_memory_size
(
self
):
assert
self
.
engine_core
.
available_gpu_memory_for_kv_cache
>
0
assert
self
.
engine_core
.
available_gpu_memory_for_kv_cache
>
0
assert
self
.
new_dp_group
is
not
None
assert
self
.
new_dp_group
is
not
None
and
self
.
old_dp_group
is
not
None
ParallelConfig
.
sync_kv_cache_memory_size
(
ParallelConfig
.
sync_kv_cache_memory_size
(
self
.
new_dp_group
,
self
.
new_dp_group
,
self
.
engine_core
.
available_gpu_memory_for_kv_cache
,
self
.
engine_core
.
available_gpu_memory_for_kv_cache
,
...
@@ -507,7 +521,7 @@ class ElasticEPScalingState:
...
@@ -507,7 +521,7 @@ class ElasticEPScalingState:
logger
.
info
(
"[Elastic EP] EPLB reshuffle completed"
)
logger
.
info
(
"[Elastic EP] EPLB reshuffle completed"
)
def
_eplb_reshuffle_before_scale_down
(
self
):
def
_eplb_reshuffle_before_scale_down
(
self
):
assert
self
.
reconfig_request
is
not
None
assert
self
.
reconfig_request
is
not
None
and
self
.
old_dp_group
is
not
None
self
.
model_executor
.
collective_rpc
(
self
.
model_executor
.
collective_rpc
(
"elastic_ep_execute"
,
"elastic_ep_execute"
,
args
=
(
args
=
(
...
...
vllm/distributed/kv_transfer/kv_connector/utils.py
View file @
17dc9c7f
...
@@ -336,6 +336,7 @@ class TpKVTopology:
...
@@ -336,6 +336,7 @@ class TpKVTopology:
self
.
_cross_layers_blocks
=
(
self
.
_cross_layers_blocks
=
(
len
(
self
.
tensor_shape
)
==
len
(
kv_cache_shape
)
+
1
len
(
self
.
tensor_shape
)
==
len
(
kv_cache_shape
)
+
1
)
)
self
.
tensor_shape
:
torch
.
Size
if
self
.
_cross_layers_blocks
:
if
self
.
_cross_layers_blocks
:
logger
.
debug
(
"Using cross-layer KV cache"
)
logger
.
debug
(
"Using cross-layer KV cache"
)
...
...
vllm/v1/attention/backends/flashinfer.py
View file @
17dc9c7f
...
@@ -972,6 +972,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -972,6 +972,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# Early-out for cascade attention
# Early-out for cascade attention
if
use_cascade
:
if
use_cascade
:
assert
num_blocks_np
is
not
None
# Grab the blocks of the shared prefix from the first request.
# Grab the blocks of the shared prefix from the first request.
num_common_kv_blocks
=
common_prefix_len
//
page_size
num_common_kv_blocks
=
common_prefix_len
//
page_size
...
@@ -1117,6 +1118,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -1117,6 +1118,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
max_seq_len
=
max_seq_len
,
max_seq_len
=
max_seq_len
,
)
)
else
:
else
:
assert
seq_lens_cpu
is
not
None
pure_decode
=
num_prefills
==
0
pure_decode
=
num_prefills
==
0
use_cudagraph
=
(
use_cudagraph
=
(
self
.
enable_cuda_graph
self
.
enable_cuda_graph
...
...
vllm/v1/attention/backends/gdn_attn.py
View file @
17dc9c7f
...
@@ -88,14 +88,14 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
...
@@ -88,14 +88,14 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
self
.
num_spec
:
int
=
self
.
speculative_config
.
num_speculative_tokens
self
.
num_spec
:
int
=
self
.
speculative_config
.
num_speculative_tokens
else
:
else
:
self
.
num_spec
=
0
self
.
num_spec
=
0
self
.
use_spec_decode
=
self
.
num_spec
>
0
self
.
use_spec_decode
:
bool
=
self
.
num_spec
>
0
self
.
_init_reorder_batch_threshold
(
1
,
self
.
use_spec_decode
)
self
.
_init_reorder_batch_threshold
(
1
,
self
.
use_spec_decode
)
self
.
use_full_cuda_graph
=
(
self
.
use_full_cuda_graph
:
bool
=
(
self
.
compilation_config
.
cudagraph_mode
.
has_full_cudagraphs
()
self
.
compilation_config
.
cudagraph_mode
.
has_full_cudagraphs
()
)
)
self
.
decode_cudagraph_max_bs
=
(
self
.
decode_cudagraph_max_bs
:
int
=
(
self
.
vllm_config
.
scheduler_config
.
max_num_seqs
*
(
self
.
num_spec
+
1
)
self
.
vllm_config
.
scheduler_config
.
max_num_seqs
*
(
self
.
num_spec
+
1
)
)
)
if
self
.
compilation_config
.
max_cudagraph_capture_size
is
not
None
:
if
self
.
compilation_config
.
max_cudagraph_capture_size
is
not
None
:
...
@@ -104,42 +104,42 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
...
@@ -104,42 +104,42 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
self
.
compilation_config
.
max_cudagraph_capture_size
,
self
.
compilation_config
.
max_cudagraph_capture_size
,
)
)
self
.
spec_state_indices_tensor
=
torch
.
empty
(
self
.
spec_state_indices_tensor
:
torch
.
Tensor
=
torch
.
empty
(
(
self
.
decode_cudagraph_max_bs
,
self
.
num_spec
+
1
),
(
self
.
decode_cudagraph_max_bs
,
self
.
num_spec
+
1
),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
device
,
device
=
device
,
)
)
self
.
non_spec_state_indices_tensor
=
torch
.
empty
(
self
.
non_spec_state_indices_tensor
:
torch
.
Tensor
=
torch
.
empty
(
(
self
.
decode_cudagraph_max_bs
,),
(
self
.
decode_cudagraph_max_bs
,),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
device
,
device
=
device
,
)
)
self
.
spec_sequence_masks
=
torch
.
empty
(
self
.
spec_sequence_masks
:
torch
.
Tensor
=
torch
.
empty
(
(
self
.
decode_cudagraph_max_bs
,),
(
self
.
decode_cudagraph_max_bs
,),
dtype
=
torch
.
bool
,
dtype
=
torch
.
bool
,
device
=
device
,
device
=
device
,
)
)
self
.
spec_token_indx
=
torch
.
empty
(
self
.
spec_token_indx
:
torch
.
Tensor
=
torch
.
empty
(
(
self
.
decode_cudagraph_max_bs
*
(
self
.
num_spec
+
1
),),
(
self
.
decode_cudagraph_max_bs
*
(
self
.
num_spec
+
1
),),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
device
,
device
=
device
,
)
)
self
.
non_spec_token_indx
=
torch
.
empty
(
self
.
non_spec_token_indx
:
torch
.
Tensor
=
torch
.
empty
(
(
self
.
decode_cudagraph_max_bs
*
(
self
.
num_spec
+
1
),),
(
self
.
decode_cudagraph_max_bs
*
(
self
.
num_spec
+
1
),),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
device
,
device
=
device
,
)
)
self
.
spec_query_start_loc
=
torch
.
empty
(
self
.
spec_query_start_loc
:
torch
.
Tensor
=
torch
.
empty
(
(
self
.
decode_cudagraph_max_bs
+
1
,),
(
self
.
decode_cudagraph_max_bs
+
1
,),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
device
,
device
=
device
,
)
)
self
.
non_spec_query_start_loc
=
torch
.
empty
(
self
.
non_spec_query_start_loc
:
torch
.
Tensor
=
torch
.
empty
(
(
self
.
decode_cudagraph_max_bs
+
1
,),
(
self
.
decode_cudagraph_max_bs
+
1
,),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
device
,
device
=
device
,
)
)
self
.
num_accepted_tokens
=
torch
.
empty
(
self
.
num_accepted_tokens
:
torch
.
Tensor
=
torch
.
empty
(
(
self
.
decode_cudagraph_max_bs
,),
(
self
.
decode_cudagraph_max_bs
,),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
device
,
device
=
device
,
...
@@ -322,6 +322,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
...
@@ -322,6 +322,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
and
num_spec_decodes
<=
self
.
decode_cudagraph_max_bs
and
num_spec_decodes
<=
self
.
decode_cudagraph_max_bs
and
num_spec_decode_tokens
<=
self
.
decode_cudagraph_max_bs
and
num_spec_decode_tokens
<=
self
.
decode_cudagraph_max_bs
):
):
assert
spec_sequence_masks
is
not
None
self
.
spec_state_indices_tensor
[:
num_spec_decodes
].
copy_
(
self
.
spec_state_indices_tensor
[:
num_spec_decodes
].
copy_
(
spec_state_indices_tensor
,
non_blocking
=
True
spec_state_indices_tensor
,
non_blocking
=
True
)
)
...
...
vllm/v1/attention/backends/mamba_attn.py
View file @
17dc9c7f
...
@@ -98,8 +98,8 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
...
@@ -98,8 +98,8 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
self
.
use_spec_decode
=
self
.
num_spec_tokens
>
0
self
.
use_spec_decode
=
self
.
num_spec_tokens
>
0
assert
isinstance
(
kv_cache_spec
,
MambaSpec
)
assert
isinstance
(
kv_cache_spec
,
MambaSpec
)
s
elf
.
compilation
_config
=
vllm_config
.
compilation
_config
s
cheduler
_config
=
vllm_config
.
scheduler
_config
self
.
decode_cudagraph_max_bs
=
self
.
vllm_config
.
scheduler_config
.
max_num_seqs
self
.
decode_cudagraph_max_bs
:
int
=
scheduler_config
.
max_num_seqs
if
self
.
compilation_config
.
max_cudagraph_capture_size
is
not
None
:
if
self
.
compilation_config
.
max_cudagraph_capture_size
is
not
None
:
self
.
decode_cudagraph_max_bs
=
min
(
self
.
decode_cudagraph_max_bs
=
min
(
self
.
decode_cudagraph_max_bs
,
self
.
decode_cudagraph_max_bs
,
...
@@ -114,7 +114,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
...
@@ -114,7 +114,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
# Speculative decoding not supported with prefix caching,
# Speculative decoding not supported with prefix caching,
# so keep shape consistent with prefill buffer
# so keep shape consistent with prefill buffer
# TODO: reduce this size as needed for decode-only cudagraph capture
# TODO: reduce this size as needed for decode-only cudagraph capture
self
.
state_indices_tensor_d
=
torch
.
empty
(
self
.
state_indices_tensor_d
:
torch
.
Tensor
=
torch
.
empty
(
(
(
self
.
decode_cudagraph_max_bs
,
self
.
decode_cudagraph_max_bs
,
max_num_blocks
,
max_num_blocks
,
...
@@ -122,12 +122,12 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
...
@@ -122,12 +122,12 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
device
,
device
=
device
,
)
)
self
.
block_idx_last_scheduled_token
=
torch
.
empty
(
self
.
block_idx_last_scheduled_token
:
torch
.
Tensor
=
torch
.
empty
(
(
self
.
decode_cudagraph_max_bs
,),
(
self
.
decode_cudagraph_max_bs
,),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
device
,
device
=
device
,
)
)
self
.
block_idx_last_computed_token
=
torch
.
empty
(
self
.
block_idx_last_computed_token
:
torch
.
Tensor
=
torch
.
empty
(
(
self
.
decode_cudagraph_max_bs
,),
(
self
.
decode_cudagraph_max_bs
,),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
device
,
device
=
device
,
...
@@ -142,7 +142,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
...
@@ -142,7 +142,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
# For speculative decoding, we need to store the following buffers
# For speculative decoding, we need to store the following buffers
# for CUDA graph capture during decode
# for CUDA graph capture during decode
if
self
.
num_spec_tokens
>
0
:
if
self
.
num_spec_tokens
>
0
:
self
.
decode_num_accepted_tokens
=
torch
.
empty
(
self
.
decode_num_accepted_tokens
:
torch
.
Tensor
=
torch
.
empty
(
(
self
.
decode_cudagraph_max_bs
,),
(
self
.
decode_cudagraph_max_bs
,),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
device
,
device
=
device
,
...
...
vllm/v1/engine/core.py
View file @
17dc9c7f
...
@@ -1539,18 +1539,18 @@ class DPEngineCoreProc(EngineCoreProc):
...
@@ -1539,18 +1539,18 @@ class DPEngineCoreProc(EngineCoreProc):
def
_init_data_parallel
(
self
,
vllm_config
:
VllmConfig
):
def
_init_data_parallel
(
self
,
vllm_config
:
VllmConfig
):
# Configure GPUs and stateless process group for data parallel.
# Configure GPUs and stateless process group for data parallel.
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
parallel_config
=
vllm_config
.
parallel_config
dp_size
=
vllm_config
.
parallel_config
.
data_parallel_size
dp_rank
=
parallel_config
.
data_parallel_rank
local_dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank_local
dp_size
=
parallel_config
.
data_parallel_size
local_dp_rank
=
parallel_config
.
data_parallel_rank_local
assert
dp_size
>
1
assert
dp_size
>
1
assert
local_dp_rank
is
not
None
assert
local_dp_rank
is
not
None
assert
0
<=
local_dp_rank
<=
dp_rank
<
dp_size
assert
0
<=
local_dp_rank
<=
dp_rank
<
dp_size
self
.
dp_rank
=
dp_rank
self
.
dp_rank
=
dp_rank
self
.
dp_group
,
self
.
dp_store
=
(
dp_group
,
dp_store
=
parallel_config
.
stateless_init_dp_group
(
return_store
=
True
)
vllm_config
.
parallel_config
.
stateless_init_dp_group
(
return_store
=
True
)
self
.
dp_group
,
self
.
dp_store
=
dp_group
,
dp_store
)
def
shutdown
(
self
):
def
shutdown
(
self
):
super
().
shutdown
()
super
().
shutdown
()
...
...
vllm/v1/sample/logits_processor/__init__.py
View file @
17dc9c7f
...
@@ -309,12 +309,16 @@ class AdapterLogitsProcessor(LogitsProcessor):
...
@@ -309,12 +309,16 @@ class AdapterLogitsProcessor(LogitsProcessor):
"""
"""
if
req_lp
:
=
self
.
new_req_logits_processor
(
params
):
if
req_lp
:
=
self
.
new_req_logits_processor
(
params
):
args
=
(
if
len
(
inspect
.
signature
(
req_lp
).
parameters
)
==
3
:
[
prompt_ids
,
output_ids
]
if
prompt_ids
is
None
:
if
(
len
(
inspect
.
signature
(
req_lp
).
parameters
)
==
3
)
raise
ValueError
(
else
[
output_ids
]
"Prompt token ids are required for this "
"logits processor but were not provided."
)
)
return
partial
(
req_lp
,
*
args
)
# type: ignore[misc]
args
=
[
prompt_ids
,
output_ids
]
else
:
args
=
[
output_ids
]
return
partial
(
req_lp
,
*
args
)
return
None
return
None
def
update_state
(
self
,
batch_update
:
BatchUpdate
|
None
):
def
update_state
(
self
,
batch_update
:
BatchUpdate
|
None
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment