Unverified Commit 17dc9c7f authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

[CI] Bump `mypy` version (#34950)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 7eca8591
...@@ -55,7 +55,7 @@ repos: ...@@ -55,7 +55,7 @@ repos:
language: python language: python
types_or: [python, pyi] types_or: [python, pyi]
require_serial: true require_serial: true
additional_dependencies: [mypy==1.11.1, regex, types-cachetools, types-setuptools, types-PyYAML, types-requests, types-torch, pydantic] additional_dependencies: ["mypy[faster-cache]==1.15.0", regex, types-cachetools, types-setuptools, types-PyYAML, types-requests, types-torch, pydantic]
- id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward - id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.10 name: Run mypy for Python 3.10
entry: python tools/pre_commit/mypy.py 1 "3.10" entry: python tools/pre_commit/mypy.py 1 "3.10"
......
...@@ -94,12 +94,9 @@ def test_rotary_embedding( ...@@ -94,12 +94,9 @@ def test_rotary_embedding(
positions = torch.randint(0, max_position, (batch_size, seq_len)) positions = torch.randint(0, max_position, (batch_size, seq_len))
query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size) query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
query = torch.randn(query_shape, dtype=dtype)
key = torch.randn_like(query) if use_key else None
# slice tensor if required, noop otherwise # slice tensor if required, noop otherwise
query = query[..., :head_size] query = torch.randn(query_shape, dtype=dtype)[..., :head_size]
key = key[..., :head_size] if use_key else None key = torch.randn_like(query)[..., :head_size] if use_key else None
# NOTE(woosuk): The reference implementation should be executed first # NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place. # because the custom kernel is in-place.
......
...@@ -62,7 +62,7 @@ def test_rotary_embedding_opcheck( ...@@ -62,7 +62,7 @@ def test_rotary_embedding_opcheck(
) )
key = torch.randn_like(query) if use_key else None key = torch.randn_like(query) if use_key else None
query = query[..., :head_size] query = query[..., :head_size]
key = key[..., :head_size] if use_key else None key = key[..., :head_size] if key is not None else None
rotary_embedding_opcheck(rot, positions, query, key) rotary_embedding_opcheck(rot, positions, query, key)
...@@ -73,5 +73,5 @@ def test_rotary_embedding_opcheck( ...@@ -73,5 +73,5 @@ def test_rotary_embedding_opcheck(
rot, rot,
positions, positions,
query.flatten(start_dim=-2), query.flatten(start_dim=-2),
key.flatten(start_dim=-2) if use_key else None, key.flatten(start_dim=-2) if key is not None else None,
) )
...@@ -298,13 +298,13 @@ def test_selective_scan( ...@@ -298,13 +298,13 @@ def test_selective_scan(
C = torch.randn(C_shape, device=device, dtype=wtype if not is_variable_C else itype) C = torch.randn(C_shape, device=device, dtype=wtype if not is_variable_C else itype)
C_ref = C.clone() C_ref = C.clone()
D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None
D_ref = D.clone() D_ref = D.clone() if D is not None else None
z = ( z = (
torch.randn(batch_size, dim, seqlen, device=device, dtype=itype) torch.randn(batch_size, dim, seqlen, device=device, dtype=itype)
if has_z if has_z
else None else None
) )
z_ref = z.clone() if has_z else None z_ref = z.clone() if z is not None else None
delta_bias = ( delta_bias = (
(0.5 * torch.rand(dim, device=device, dtype=torch.float32)) (0.5 * torch.rand(dim, device=device, dtype=torch.float32))
if has_delta_bias if has_delta_bias
...@@ -493,7 +493,7 @@ def test_selective_state_update_varlen(dim, dstate, has_z, itype, max_seq_len): ...@@ -493,7 +493,7 @@ def test_selective_state_update_varlen(dim, dstate, has_z, itype, max_seq_len):
B[idx : idx + 1], B[idx : idx + 1],
C[idx : idx + 1], C[idx : idx + 1],
D=D, D=D,
z=z[idx : idx + 1] if has_z else None, z=z[idx : idx + 1] if z is not None else None,
dt_bias=dt_bias, dt_bias=dt_bias,
dt_softplus=True, dt_softplus=True,
) )
...@@ -578,7 +578,7 @@ def test_selective_scan_varlen( ...@@ -578,7 +578,7 @@ def test_selective_scan_varlen(
C = torch.randn(C_shape, device=device, dtype=wtype if not is_variable_C else itype) C = torch.randn(C_shape, device=device, dtype=wtype if not is_variable_C else itype)
C_ref = C.clone() C_ref = C.clone()
D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None
D_ref = D.clone() D_ref = D.clone() if D is not None else None
z = torch.randn(dim, seqlen, device=device, dtype=itype) z = torch.randn(dim, seqlen, device=device, dtype=itype)
z_ref = z.clone() z_ref = z.clone()
delta_bias = ( delta_bias = (
...@@ -750,7 +750,7 @@ def test_selective_state_update_with_batch_indices( ...@@ -750,7 +750,7 @@ def test_selective_state_update_with_batch_indices(
B[:batch_size], B[:batch_size],
C[:batch_size], C[:batch_size],
D=D, D=D,
z=z[:batch_size], z=z[:batch_size] if z is not None else None,
dt_bias=dt_bias, dt_bias=dt_bias,
dt_softplus=True, dt_softplus=True,
) )
...@@ -934,7 +934,7 @@ def test_selective_state_update_with_num_accepted_tokens( ...@@ -934,7 +934,7 @@ def test_selective_state_update_with_num_accepted_tokens(
B[global_idx : global_idx + 1], B[global_idx : global_idx + 1],
C[global_idx : global_idx + 1], C[global_idx : global_idx + 1],
D=D, D=D,
z=z[global_idx : global_idx + 1] if has_z else None, z=z[global_idx : global_idx + 1] if z is not None else None,
dt_bias=dt_bias, dt_bias=dt_bias,
dt_softplus=True, dt_softplus=True,
) )
...@@ -1061,7 +1061,7 @@ def test_selective_state_update_varlen_with_num_accepted( ...@@ -1061,7 +1061,7 @@ def test_selective_state_update_varlen_with_num_accepted(
B[global_idx : global_idx + 1], B[global_idx : global_idx + 1],
C[global_idx : global_idx + 1], C[global_idx : global_idx + 1],
D=D, D=D,
z=z[global_idx : global_idx + 1] if has_z else None, z=z[global_idx : global_idx + 1] if z is not None else None,
dt_bias=dt_bias, dt_bias=dt_bias,
dt_softplus=True, dt_softplus=True,
) )
......
...@@ -57,11 +57,11 @@ def opcheck_fp8_quant( ...@@ -57,11 +57,11 @@ def opcheck_fp8_quant(
@pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("scale_ub", SCALE_UBS) @pytest.mark.parametrize("do_scale_ub", SCALE_UBS)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode() @torch.inference_mode()
def test_dynamic_per_token_fp8_quant( def test_dynamic_per_token_fp8_quant(
num_tokens: int, hidden_size: int, dtype: torch.dtype, scale_ub: bool, seed: int num_tokens: int, hidden_size: int, dtype: torch.dtype, do_scale_ub: bool, seed: int
) -> None: ) -> None:
set_random_seed(seed) set_random_seed(seed)
...@@ -70,7 +70,7 @@ def test_dynamic_per_token_fp8_quant( ...@@ -70,7 +70,7 @@ def test_dynamic_per_token_fp8_quant(
) # avoid nans ) # avoid nans
scale_ub = ( scale_ub = (
torch.mean(x).to(dtype=torch.float32, device="cuda") if scale_ub else None torch.mean(x).to(dtype=torch.float32, device="cuda") if do_scale_ub else None
) )
ref_out, ref_scales = ref_dynamic_per_token_quant(x, FP8_DTYPE, scale_ub) ref_out, ref_scales = ref_dynamic_per_token_quant(x, FP8_DTYPE, scale_ub)
ops_out, ops_scales = ops.scaled_fp8_quant( ops_out, ops_scales = ops.scaled_fp8_quant(
......
...@@ -3,11 +3,11 @@ ...@@ -3,11 +3,11 @@
import os import os
from collections.abc import Callable from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Literal from typing import TYPE_CHECKING, Any, Literal, overload
import torch import torch
from pydantic import Field, field_validator, model_validator from pydantic import Field, field_validator, model_validator
from torch.distributed import ProcessGroup, ReduceOp from torch.distributed import ProcessGroup, ReduceOp, Store
from typing_extensions import Self from typing_extensions import Self
import vllm.envs as envs import vllm.envs as envs
...@@ -507,7 +507,17 @@ class ParallelConfig: ...@@ -507,7 +507,17 @@ class ParallelConfig:
def get_next_stateless_eplb_group_port(self) -> list[int]: def get_next_stateless_eplb_group_port(self) -> list[int]:
return self._stateless_eplb_group_port_list.pop() return self._stateless_eplb_group_port_list.pop()
def stateless_init_dp_group(self, return_store: bool = False) -> ProcessGroup: @overload
def stateless_init_dp_group(
self, return_store: Literal[False] = ...
) -> ProcessGroup: ...
@overload
def stateless_init_dp_group(
self, return_store: Literal[True] = ...
) -> tuple[ProcessGroup, Store]: ...
def stateless_init_dp_group(
self, return_store: bool = False
) -> ProcessGroup | tuple[ProcessGroup, Store]:
# NOTE: In high-concurrency scenarios multiple processes # NOTE: In high-concurrency scenarios multiple processes
# can pick the same (currently free) port through a race # can pick the same (currently free) port through a race
# condition when calling `get_open_port()`. When the first # condition when calling `get_open_port()`. When the first
......
...@@ -4,7 +4,7 @@ import enum ...@@ -4,7 +4,7 @@ import enum
import time import time
import weakref import weakref
from datetime import timedelta from datetime import timedelta
from typing import TYPE_CHECKING, Literal from typing import TYPE_CHECKING, Literal, TypeAlias
import torch.distributed import torch.distributed
...@@ -61,6 +61,14 @@ class ScaleDownRemovingEngineState(enum.IntEnum): ...@@ -61,6 +61,14 @@ class ScaleDownRemovingEngineState(enum.IntEnum):
COMPLETE = 2 COMPLETE = 2
EngineState: TypeAlias = (
ScaleUpExistingEngineState
| ScaleUpNewEngineState
| ScaleDownRemainingEngineState
| ScaleDownRemovingEngineState
)
class _BarrierTimeoutError(RuntimeError): class _BarrierTimeoutError(RuntimeError):
""" """
Exception raised for timeout Exception raised for timeout
...@@ -87,14 +95,13 @@ class ElasticEPScalingState: ...@@ -87,14 +95,13 @@ class ElasticEPScalingState:
self.old_dp_group = self.engine_core.dp_group if worker_type != "new" else None self.old_dp_group = self.engine_core.dp_group if worker_type != "new" else None
self.old_dp_store = self.engine_core.dp_store if worker_type != "new" else None self.old_dp_store = self.engine_core.dp_store if worker_type != "new" else None
self.new_parallel_config: ParallelConfig = new_parallel_config self.new_parallel_config: ParallelConfig = new_parallel_config
self.new_dp_group: torch.distributed.ProcessGroup | None = ( self.new_dp_group = self.engine_core.dp_group if worker_type == "new" else None
self.engine_core.dp_group if worker_type == "new" else None
)
self.new_dp_store = self.engine_core.dp_store if worker_type == "new" else None self.new_dp_store = self.engine_core.dp_store if worker_type == "new" else None
self.worker_type = worker_type self.worker_type = worker_type
self.scale_type = scale_type self.scale_type = scale_type
self.reconfig_request = reconfig_request self.reconfig_request = reconfig_request
self.state: EngineState
if scale_type == "scale_up": if scale_type == "scale_up":
self.state = ( self.state = (
ScaleUpNewEngineState.PREPARE ScaleUpNewEngineState.PREPARE
...@@ -182,9 +189,9 @@ class ElasticEPScalingState: ...@@ -182,9 +189,9 @@ class ElasticEPScalingState:
engine step, and will synchronize with the other EngineCores in the engine step, and will synchronize with the other EngineCores in the
next step with a barrier without timeout. next step with a barrier without timeout.
""" """
dp_store = self.new_dp_store if use_new_group else self.old_dp_store
dp_group = self.new_dp_group if use_new_group else self.old_dp_group dp_group = self.new_dp_group if use_new_group else self.old_dp_group
assert dp_group is not None dp_store = self.new_dp_store if use_new_group else self.old_dp_store
assert dp_group is not None and dp_store is not None
group_rank = dp_group.rank() group_rank = dp_group.rank()
group_size = dp_group.size() group_size = dp_group.size()
...@@ -212,6 +219,7 @@ class ElasticEPScalingState: ...@@ -212,6 +219,7 @@ class ElasticEPScalingState:
def _progress_existing_engine(self) -> bool: def _progress_existing_engine(self) -> bool:
state = self.state state = self.state
assert self.old_dp_group is not None and self.old_dp_store is not None
if state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT: if state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT:
return False return False
...@@ -265,11 +273,12 @@ class ElasticEPScalingState: ...@@ -265,11 +273,12 @@ class ElasticEPScalingState:
elif state == ScaleUpExistingEngineState.SWITCH_AND_PREPARE: elif state == ScaleUpExistingEngineState.SWITCH_AND_PREPARE:
self._switch_and_prepare() self._switch_and_prepare()
self.state = ScaleUpExistingEngineState.EPLB_RESHUFFLE self.state = ScaleUpExistingEngineState.EPLB_RESHUFFLE
assert self.new_dp_store is not None
self.new_dp_store.add("eep_barrier_engine_count", 1) self.new_dp_store.add("eep_barrier_engine_count", 1)
return True return True
elif state == ScaleUpExistingEngineState.EPLB_RESHUFFLE: elif state == ScaleUpExistingEngineState.EPLB_RESHUFFLE:
assert self.new_dp_group is not None assert self.new_dp_group is not None and self.new_dp_store is not None
if ( if (
int(self.new_dp_store.get("eep_barrier_engine_count")) int(self.new_dp_store.get("eep_barrier_engine_count"))
< self.new_dp_group.size() < self.new_dp_group.size()
...@@ -292,7 +301,7 @@ class ElasticEPScalingState: ...@@ -292,7 +301,7 @@ class ElasticEPScalingState:
def _progress_new_engine(self) -> bool: def _progress_new_engine(self) -> bool:
state = self.state state = self.state
assert self.new_dp_group is not None assert self.new_dp_group is not None and self.new_dp_store is not None
if state == ScaleUpNewEngineState.PREPARE: if state == ScaleUpNewEngineState.PREPARE:
tensor = torch.tensor([0, 0, 0], dtype=torch.int32, device="cpu") tensor = torch.tensor([0, 0, 0], dtype=torch.int32, device="cpu")
...@@ -330,6 +339,7 @@ class ElasticEPScalingState: ...@@ -330,6 +339,7 @@ class ElasticEPScalingState:
def _progress_remaining_engine(self) -> bool: def _progress_remaining_engine(self) -> bool:
state = self.state state = self.state
assert self.old_dp_group is not None and self.old_dp_store is not None
if state == ScaleDownRemainingEngineState.PREPARE: if state == ScaleDownRemainingEngineState.PREPARE:
self.state = ScaleDownRemainingEngineState.EPLB_RESHUFFLE self.state = ScaleDownRemainingEngineState.EPLB_RESHUFFLE
...@@ -369,6 +379,7 @@ class ElasticEPScalingState: ...@@ -369,6 +379,7 @@ class ElasticEPScalingState:
def _progress_removing_engine(self) -> bool: def _progress_removing_engine(self) -> bool:
state = self.state state = self.state
assert self.old_dp_group is not None and self.old_dp_store is not None
if state == ScaleDownRemovingEngineState.PREPARE: if state == ScaleDownRemovingEngineState.PREPARE:
self.state = ScaleDownRemovingEngineState.EPLB_RESHUFFLE self.state = ScaleDownRemovingEngineState.EPLB_RESHUFFLE
...@@ -401,6 +412,7 @@ class ElasticEPScalingState: ...@@ -401,6 +412,7 @@ class ElasticEPScalingState:
def handle_notification(self, notification_type: EEPNotificationType): def handle_notification(self, notification_type: EEPNotificationType):
assert self.worker_type != "new" assert self.worker_type != "new"
assert self.old_dp_store is not None
if ( if (
notification_type == EEPNotificationType.NEW_CORE_ENGINES_INIT_READY notification_type == EEPNotificationType.NEW_CORE_ENGINES_INIT_READY
and self.state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT and self.state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT
...@@ -429,6 +441,7 @@ class ElasticEPScalingState: ...@@ -429,6 +441,7 @@ class ElasticEPScalingState:
) )
def _create_standby_groups(self): def _create_standby_groups(self):
assert self.old_dp_group is not None
self.new_dp_group, self.new_dp_store = ( self.new_dp_group, self.new_dp_store = (
self.new_parallel_config.stateless_init_dp_group(return_store=True) self.new_parallel_config.stateless_init_dp_group(return_store=True)
) )
...@@ -439,7 +452,7 @@ class ElasticEPScalingState: ...@@ -439,7 +452,7 @@ class ElasticEPScalingState:
logger.info("[Elastic EP] Created standby communication groups") logger.info("[Elastic EP] Created standby communication groups")
def _transfer_weights(self): def _transfer_weights(self):
assert self.reconfig_request is not None assert self.reconfig_request is not None and self.old_dp_group is not None
old_dp_size = self.old_dp_group.size() old_dp_size = self.old_dp_group.size()
new_dp_size = self.reconfig_request.new_data_parallel_size new_dp_size = self.reconfig_request.new_data_parallel_size
...@@ -450,6 +463,7 @@ class ElasticEPScalingState: ...@@ -450,6 +463,7 @@ class ElasticEPScalingState:
logger.info("[Elastic EP] Transferred weights to new workers") logger.info("[Elastic EP] Transferred weights to new workers")
def _transfer_expert_mapping(self): def _transfer_expert_mapping(self):
assert self.old_dp_group is not None
self.model_executor.collective_rpc( self.model_executor.collective_rpc(
"elastic_ep_execute", args=("broadcast_expert_mapping",) "elastic_ep_execute", args=("broadcast_expert_mapping",)
) )
...@@ -458,7 +472,7 @@ class ElasticEPScalingState: ...@@ -458,7 +472,7 @@ class ElasticEPScalingState:
def _sync_kv_cache_memory_size(self): def _sync_kv_cache_memory_size(self):
assert self.engine_core.available_gpu_memory_for_kv_cache > 0 assert self.engine_core.available_gpu_memory_for_kv_cache > 0
assert self.new_dp_group is not None assert self.new_dp_group is not None and self.old_dp_group is not None
ParallelConfig.sync_kv_cache_memory_size( ParallelConfig.sync_kv_cache_memory_size(
self.new_dp_group, self.new_dp_group,
self.engine_core.available_gpu_memory_for_kv_cache, self.engine_core.available_gpu_memory_for_kv_cache,
...@@ -507,7 +521,7 @@ class ElasticEPScalingState: ...@@ -507,7 +521,7 @@ class ElasticEPScalingState:
logger.info("[Elastic EP] EPLB reshuffle completed") logger.info("[Elastic EP] EPLB reshuffle completed")
def _eplb_reshuffle_before_scale_down(self): def _eplb_reshuffle_before_scale_down(self):
assert self.reconfig_request is not None assert self.reconfig_request is not None and self.old_dp_group is not None
self.model_executor.collective_rpc( self.model_executor.collective_rpc(
"elastic_ep_execute", "elastic_ep_execute",
args=( args=(
......
...@@ -336,6 +336,7 @@ class TpKVTopology: ...@@ -336,6 +336,7 @@ class TpKVTopology:
self._cross_layers_blocks = ( self._cross_layers_blocks = (
len(self.tensor_shape) == len(kv_cache_shape) + 1 len(self.tensor_shape) == len(kv_cache_shape) + 1
) )
self.tensor_shape: torch.Size
if self._cross_layers_blocks: if self._cross_layers_blocks:
logger.debug("Using cross-layer KV cache") logger.debug("Using cross-layer KV cache")
......
...@@ -972,6 +972,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -972,6 +972,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# Early-out for cascade attention # Early-out for cascade attention
if use_cascade: if use_cascade:
assert num_blocks_np is not None
# Grab the blocks of the shared prefix from the first request. # Grab the blocks of the shared prefix from the first request.
num_common_kv_blocks = common_prefix_len // page_size num_common_kv_blocks = common_prefix_len // page_size
...@@ -1117,6 +1118,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -1117,6 +1118,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
max_seq_len=max_seq_len, max_seq_len=max_seq_len,
) )
else: else:
assert seq_lens_cpu is not None
pure_decode = num_prefills == 0 pure_decode = num_prefills == 0
use_cudagraph = ( use_cudagraph = (
self.enable_cuda_graph self.enable_cuda_graph
......
...@@ -88,14 +88,14 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] ...@@ -88,14 +88,14 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
self.num_spec: int = self.speculative_config.num_speculative_tokens self.num_spec: int = self.speculative_config.num_speculative_tokens
else: else:
self.num_spec = 0 self.num_spec = 0
self.use_spec_decode = self.num_spec > 0 self.use_spec_decode: bool = self.num_spec > 0
self._init_reorder_batch_threshold(1, self.use_spec_decode) self._init_reorder_batch_threshold(1, self.use_spec_decode)
self.use_full_cuda_graph = ( self.use_full_cuda_graph: bool = (
self.compilation_config.cudagraph_mode.has_full_cudagraphs() self.compilation_config.cudagraph_mode.has_full_cudagraphs()
) )
self.decode_cudagraph_max_bs = ( self.decode_cudagraph_max_bs: int = (
self.vllm_config.scheduler_config.max_num_seqs * (self.num_spec + 1) self.vllm_config.scheduler_config.max_num_seqs * (self.num_spec + 1)
) )
if self.compilation_config.max_cudagraph_capture_size is not None: if self.compilation_config.max_cudagraph_capture_size is not None:
...@@ -104,42 +104,42 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] ...@@ -104,42 +104,42 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
self.compilation_config.max_cudagraph_capture_size, self.compilation_config.max_cudagraph_capture_size,
) )
self.spec_state_indices_tensor = torch.empty( self.spec_state_indices_tensor: torch.Tensor = torch.empty(
(self.decode_cudagraph_max_bs, self.num_spec + 1), (self.decode_cudagraph_max_bs, self.num_spec + 1),
dtype=torch.int32, dtype=torch.int32,
device=device, device=device,
) )
self.non_spec_state_indices_tensor = torch.empty( self.non_spec_state_indices_tensor: torch.Tensor = torch.empty(
(self.decode_cudagraph_max_bs,), (self.decode_cudagraph_max_bs,),
dtype=torch.int32, dtype=torch.int32,
device=device, device=device,
) )
self.spec_sequence_masks = torch.empty( self.spec_sequence_masks: torch.Tensor = torch.empty(
(self.decode_cudagraph_max_bs,), (self.decode_cudagraph_max_bs,),
dtype=torch.bool, dtype=torch.bool,
device=device, device=device,
) )
self.spec_token_indx = torch.empty( self.spec_token_indx: torch.Tensor = torch.empty(
(self.decode_cudagraph_max_bs * (self.num_spec + 1),), (self.decode_cudagraph_max_bs * (self.num_spec + 1),),
dtype=torch.int32, dtype=torch.int32,
device=device, device=device,
) )
self.non_spec_token_indx = torch.empty( self.non_spec_token_indx: torch.Tensor = torch.empty(
(self.decode_cudagraph_max_bs * (self.num_spec + 1),), (self.decode_cudagraph_max_bs * (self.num_spec + 1),),
dtype=torch.int32, dtype=torch.int32,
device=device, device=device,
) )
self.spec_query_start_loc = torch.empty( self.spec_query_start_loc: torch.Tensor = torch.empty(
(self.decode_cudagraph_max_bs + 1,), (self.decode_cudagraph_max_bs + 1,),
dtype=torch.int32, dtype=torch.int32,
device=device, device=device,
) )
self.non_spec_query_start_loc = torch.empty( self.non_spec_query_start_loc: torch.Tensor = torch.empty(
(self.decode_cudagraph_max_bs + 1,), (self.decode_cudagraph_max_bs + 1,),
dtype=torch.int32, dtype=torch.int32,
device=device, device=device,
) )
self.num_accepted_tokens = torch.empty( self.num_accepted_tokens: torch.Tensor = torch.empty(
(self.decode_cudagraph_max_bs,), (self.decode_cudagraph_max_bs,),
dtype=torch.int32, dtype=torch.int32,
device=device, device=device,
...@@ -322,6 +322,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] ...@@ -322,6 +322,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
and num_spec_decodes <= self.decode_cudagraph_max_bs and num_spec_decodes <= self.decode_cudagraph_max_bs
and num_spec_decode_tokens <= self.decode_cudagraph_max_bs and num_spec_decode_tokens <= self.decode_cudagraph_max_bs
): ):
assert spec_sequence_masks is not None
self.spec_state_indices_tensor[:num_spec_decodes].copy_( self.spec_state_indices_tensor[:num_spec_decodes].copy_(
spec_state_indices_tensor, non_blocking=True spec_state_indices_tensor, non_blocking=True
) )
......
...@@ -98,8 +98,8 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): ...@@ -98,8 +98,8 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
self.use_spec_decode = self.num_spec_tokens > 0 self.use_spec_decode = self.num_spec_tokens > 0
assert isinstance(kv_cache_spec, MambaSpec) assert isinstance(kv_cache_spec, MambaSpec)
self.compilation_config = vllm_config.compilation_config scheduler_config = vllm_config.scheduler_config
self.decode_cudagraph_max_bs = self.vllm_config.scheduler_config.max_num_seqs self.decode_cudagraph_max_bs: int = scheduler_config.max_num_seqs
if self.compilation_config.max_cudagraph_capture_size is not None: if self.compilation_config.max_cudagraph_capture_size is not None:
self.decode_cudagraph_max_bs = min( self.decode_cudagraph_max_bs = min(
self.decode_cudagraph_max_bs, self.decode_cudagraph_max_bs,
...@@ -114,7 +114,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): ...@@ -114,7 +114,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
# Speculative decoding not supported with prefix caching, # Speculative decoding not supported with prefix caching,
# so keep shape consistent with prefill buffer # so keep shape consistent with prefill buffer
# TODO: reduce this size as needed for decode-only cudagraph capture # TODO: reduce this size as needed for decode-only cudagraph capture
self.state_indices_tensor_d = torch.empty( self.state_indices_tensor_d: torch.Tensor = torch.empty(
( (
self.decode_cudagraph_max_bs, self.decode_cudagraph_max_bs,
max_num_blocks, max_num_blocks,
...@@ -122,12 +122,12 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): ...@@ -122,12 +122,12 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
dtype=torch.int32, dtype=torch.int32,
device=device, device=device,
) )
self.block_idx_last_scheduled_token = torch.empty( self.block_idx_last_scheduled_token: torch.Tensor = torch.empty(
(self.decode_cudagraph_max_bs,), (self.decode_cudagraph_max_bs,),
dtype=torch.int32, dtype=torch.int32,
device=device, device=device,
) )
self.block_idx_last_computed_token = torch.empty( self.block_idx_last_computed_token: torch.Tensor = torch.empty(
(self.decode_cudagraph_max_bs,), (self.decode_cudagraph_max_bs,),
dtype=torch.int32, dtype=torch.int32,
device=device, device=device,
...@@ -142,7 +142,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): ...@@ -142,7 +142,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
# For speculative decoding, we need to store the following buffers # For speculative decoding, we need to store the following buffers
# for CUDA graph capture during decode # for CUDA graph capture during decode
if self.num_spec_tokens > 0: if self.num_spec_tokens > 0:
self.decode_num_accepted_tokens = torch.empty( self.decode_num_accepted_tokens: torch.Tensor = torch.empty(
(self.decode_cudagraph_max_bs,), (self.decode_cudagraph_max_bs,),
dtype=torch.int32, dtype=torch.int32,
device=device, device=device,
......
...@@ -1539,18 +1539,18 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -1539,18 +1539,18 @@ class DPEngineCoreProc(EngineCoreProc):
def _init_data_parallel(self, vllm_config: VllmConfig): def _init_data_parallel(self, vllm_config: VllmConfig):
# Configure GPUs and stateless process group for data parallel. # Configure GPUs and stateless process group for data parallel.
dp_rank = vllm_config.parallel_config.data_parallel_rank parallel_config = vllm_config.parallel_config
dp_size = vllm_config.parallel_config.data_parallel_size dp_rank = parallel_config.data_parallel_rank
local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local dp_size = parallel_config.data_parallel_size
local_dp_rank = parallel_config.data_parallel_rank_local
assert dp_size > 1 assert dp_size > 1
assert local_dp_rank is not None assert local_dp_rank is not None
assert 0 <= local_dp_rank <= dp_rank < dp_size assert 0 <= local_dp_rank <= dp_rank < dp_size
self.dp_rank = dp_rank self.dp_rank = dp_rank
self.dp_group, self.dp_store = ( dp_group, dp_store = parallel_config.stateless_init_dp_group(return_store=True)
vllm_config.parallel_config.stateless_init_dp_group(return_store=True) self.dp_group, self.dp_store = dp_group, dp_store
)
def shutdown(self): def shutdown(self):
super().shutdown() super().shutdown()
......
...@@ -309,12 +309,16 @@ class AdapterLogitsProcessor(LogitsProcessor): ...@@ -309,12 +309,16 @@ class AdapterLogitsProcessor(LogitsProcessor):
""" """
if req_lp := self.new_req_logits_processor(params): if req_lp := self.new_req_logits_processor(params):
args = ( if len(inspect.signature(req_lp).parameters) == 3:
[prompt_ids, output_ids] if prompt_ids is None:
if (len(inspect.signature(req_lp).parameters) == 3) raise ValueError(
else [output_ids] "Prompt token ids are required for this "
) "logits processor but were not provided."
return partial(req_lp, *args) # type: ignore[misc] )
args = [prompt_ids, output_ids]
else:
args = [output_ids]
return partial(req_lp, *args)
return None return None
def update_state(self, batch_update: BatchUpdate | None): def update_state(self, batch_update: BatchUpdate | None):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment