Unverified Commit 17dc9c7f authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

[CI] Bump `mypy` version (#34950)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 7eca8591
......@@ -55,7 +55,7 @@ repos:
language: python
types_or: [python, pyi]
require_serial: true
additional_dependencies: [mypy==1.11.1, regex, types-cachetools, types-setuptools, types-PyYAML, types-requests, types-torch, pydantic]
additional_dependencies: ["mypy[faster-cache]==1.15.0", regex, types-cachetools, types-setuptools, types-PyYAML, types-requests, types-torch, pydantic]
- id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.10
entry: python tools/pre_commit/mypy.py 1 "3.10"
......
......@@ -94,12 +94,9 @@ def test_rotary_embedding(
positions = torch.randint(0, max_position, (batch_size, seq_len))
query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
query = torch.randn(query_shape, dtype=dtype)
key = torch.randn_like(query) if use_key else None
# slice tensor if required, noop otherwise
query = query[..., :head_size]
key = key[..., :head_size] if use_key else None
query = torch.randn(query_shape, dtype=dtype)[..., :head_size]
key = torch.randn_like(query)[..., :head_size] if use_key else None
# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
......
......@@ -62,7 +62,7 @@ def test_rotary_embedding_opcheck(
)
key = torch.randn_like(query) if use_key else None
query = query[..., :head_size]
key = key[..., :head_size] if use_key else None
key = key[..., :head_size] if key is not None else None
rotary_embedding_opcheck(rot, positions, query, key)
......@@ -73,5 +73,5 @@ def test_rotary_embedding_opcheck(
rot,
positions,
query.flatten(start_dim=-2),
key.flatten(start_dim=-2) if use_key else None,
key.flatten(start_dim=-2) if key is not None else None,
)
......@@ -298,13 +298,13 @@ def test_selective_scan(
C = torch.randn(C_shape, device=device, dtype=wtype if not is_variable_C else itype)
C_ref = C.clone()
D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None
D_ref = D.clone()
D_ref = D.clone() if D is not None else None
z = (
torch.randn(batch_size, dim, seqlen, device=device, dtype=itype)
if has_z
else None
)
z_ref = z.clone() if has_z else None
z_ref = z.clone() if z is not None else None
delta_bias = (
(0.5 * torch.rand(dim, device=device, dtype=torch.float32))
if has_delta_bias
......@@ -493,7 +493,7 @@ def test_selective_state_update_varlen(dim, dstate, has_z, itype, max_seq_len):
B[idx : idx + 1],
C[idx : idx + 1],
D=D,
z=z[idx : idx + 1] if has_z else None,
z=z[idx : idx + 1] if z is not None else None,
dt_bias=dt_bias,
dt_softplus=True,
)
......@@ -578,7 +578,7 @@ def test_selective_scan_varlen(
C = torch.randn(C_shape, device=device, dtype=wtype if not is_variable_C else itype)
C_ref = C.clone()
D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None
D_ref = D.clone()
D_ref = D.clone() if D is not None else None
z = torch.randn(dim, seqlen, device=device, dtype=itype)
z_ref = z.clone()
delta_bias = (
......@@ -750,7 +750,7 @@ def test_selective_state_update_with_batch_indices(
B[:batch_size],
C[:batch_size],
D=D,
z=z[:batch_size],
z=z[:batch_size] if z is not None else None,
dt_bias=dt_bias,
dt_softplus=True,
)
......@@ -934,7 +934,7 @@ def test_selective_state_update_with_num_accepted_tokens(
B[global_idx : global_idx + 1],
C[global_idx : global_idx + 1],
D=D,
z=z[global_idx : global_idx + 1] if has_z else None,
z=z[global_idx : global_idx + 1] if z is not None else None,
dt_bias=dt_bias,
dt_softplus=True,
)
......@@ -1061,7 +1061,7 @@ def test_selective_state_update_varlen_with_num_accepted(
B[global_idx : global_idx + 1],
C[global_idx : global_idx + 1],
D=D,
z=z[global_idx : global_idx + 1] if has_z else None,
z=z[global_idx : global_idx + 1] if z is not None else None,
dt_bias=dt_bias,
dt_softplus=True,
)
......
......@@ -57,11 +57,11 @@ def opcheck_fp8_quant(
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("scale_ub", SCALE_UBS)
@pytest.mark.parametrize("do_scale_ub", SCALE_UBS)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_dynamic_per_token_fp8_quant(
num_tokens: int, hidden_size: int, dtype: torch.dtype, scale_ub: bool, seed: int
num_tokens: int, hidden_size: int, dtype: torch.dtype, do_scale_ub: bool, seed: int
) -> None:
set_random_seed(seed)
......@@ -70,7 +70,7 @@ def test_dynamic_per_token_fp8_quant(
) # avoid nans
scale_ub = (
torch.mean(x).to(dtype=torch.float32, device="cuda") if scale_ub else None
torch.mean(x).to(dtype=torch.float32, device="cuda") if do_scale_ub else None
)
ref_out, ref_scales = ref_dynamic_per_token_quant(x, FP8_DTYPE, scale_ub)
ops_out, ops_scales = ops.scaled_fp8_quant(
......
......@@ -3,11 +3,11 @@
import os
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, Literal, overload
import torch
from pydantic import Field, field_validator, model_validator
from torch.distributed import ProcessGroup, ReduceOp
from torch.distributed import ProcessGroup, ReduceOp, Store
from typing_extensions import Self
import vllm.envs as envs
......@@ -507,7 +507,17 @@ class ParallelConfig:
def get_next_stateless_eplb_group_port(self) -> list[int]:
return self._stateless_eplb_group_port_list.pop()
def stateless_init_dp_group(self, return_store: bool = False) -> ProcessGroup:
@overload
def stateless_init_dp_group(
self, return_store: Literal[False] = ...
) -> ProcessGroup: ...
@overload
def stateless_init_dp_group(
self, return_store: Literal[True] = ...
) -> tuple[ProcessGroup, Store]: ...
def stateless_init_dp_group(
self, return_store: bool = False
) -> ProcessGroup | tuple[ProcessGroup, Store]:
# NOTE: In high-concurrency scenarios multiple processes
# can pick the same (currently free) port through a race
# condition when calling `get_open_port()`. When the first
......
......@@ -4,7 +4,7 @@ import enum
import time
import weakref
from datetime import timedelta
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Literal, TypeAlias
import torch.distributed
......@@ -61,6 +61,14 @@ class ScaleDownRemovingEngineState(enum.IntEnum):
COMPLETE = 2
EngineState: TypeAlias = (
ScaleUpExistingEngineState
| ScaleUpNewEngineState
| ScaleDownRemainingEngineState
| ScaleDownRemovingEngineState
)
class _BarrierTimeoutError(RuntimeError):
"""
Exception raised for timeout
......@@ -87,14 +95,13 @@ class ElasticEPScalingState:
self.old_dp_group = self.engine_core.dp_group if worker_type != "new" else None
self.old_dp_store = self.engine_core.dp_store if worker_type != "new" else None
self.new_parallel_config: ParallelConfig = new_parallel_config
self.new_dp_group: torch.distributed.ProcessGroup | None = (
self.engine_core.dp_group if worker_type == "new" else None
)
self.new_dp_group = self.engine_core.dp_group if worker_type == "new" else None
self.new_dp_store = self.engine_core.dp_store if worker_type == "new" else None
self.worker_type = worker_type
self.scale_type = scale_type
self.reconfig_request = reconfig_request
self.state: EngineState
if scale_type == "scale_up":
self.state = (
ScaleUpNewEngineState.PREPARE
......@@ -182,9 +189,9 @@ class ElasticEPScalingState:
engine step, and will synchronize with the other EngineCores in the
next step with a barrier without timeout.
"""
dp_store = self.new_dp_store if use_new_group else self.old_dp_store
dp_group = self.new_dp_group if use_new_group else self.old_dp_group
assert dp_group is not None
dp_store = self.new_dp_store if use_new_group else self.old_dp_store
assert dp_group is not None and dp_store is not None
group_rank = dp_group.rank()
group_size = dp_group.size()
......@@ -212,6 +219,7 @@ class ElasticEPScalingState:
def _progress_existing_engine(self) -> bool:
state = self.state
assert self.old_dp_group is not None and self.old_dp_store is not None
if state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT:
return False
......@@ -265,11 +273,12 @@ class ElasticEPScalingState:
elif state == ScaleUpExistingEngineState.SWITCH_AND_PREPARE:
self._switch_and_prepare()
self.state = ScaleUpExistingEngineState.EPLB_RESHUFFLE
assert self.new_dp_store is not None
self.new_dp_store.add("eep_barrier_engine_count", 1)
return True
elif state == ScaleUpExistingEngineState.EPLB_RESHUFFLE:
assert self.new_dp_group is not None
assert self.new_dp_group is not None and self.new_dp_store is not None
if (
int(self.new_dp_store.get("eep_barrier_engine_count"))
< self.new_dp_group.size()
......@@ -292,7 +301,7 @@ class ElasticEPScalingState:
def _progress_new_engine(self) -> bool:
state = self.state
assert self.new_dp_group is not None
assert self.new_dp_group is not None and self.new_dp_store is not None
if state == ScaleUpNewEngineState.PREPARE:
tensor = torch.tensor([0, 0, 0], dtype=torch.int32, device="cpu")
......@@ -330,6 +339,7 @@ class ElasticEPScalingState:
def _progress_remaining_engine(self) -> bool:
state = self.state
assert self.old_dp_group is not None and self.old_dp_store is not None
if state == ScaleDownRemainingEngineState.PREPARE:
self.state = ScaleDownRemainingEngineState.EPLB_RESHUFFLE
......@@ -369,6 +379,7 @@ class ElasticEPScalingState:
def _progress_removing_engine(self) -> bool:
state = self.state
assert self.old_dp_group is not None and self.old_dp_store is not None
if state == ScaleDownRemovingEngineState.PREPARE:
self.state = ScaleDownRemovingEngineState.EPLB_RESHUFFLE
......@@ -401,6 +412,7 @@ class ElasticEPScalingState:
def handle_notification(self, notification_type: EEPNotificationType):
assert self.worker_type != "new"
assert self.old_dp_store is not None
if (
notification_type == EEPNotificationType.NEW_CORE_ENGINES_INIT_READY
and self.state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT
......@@ -429,6 +441,7 @@ class ElasticEPScalingState:
)
def _create_standby_groups(self):
assert self.old_dp_group is not None
self.new_dp_group, self.new_dp_store = (
self.new_parallel_config.stateless_init_dp_group(return_store=True)
)
......@@ -439,7 +452,7 @@ class ElasticEPScalingState:
logger.info("[Elastic EP] Created standby communication groups")
def _transfer_weights(self):
assert self.reconfig_request is not None
assert self.reconfig_request is not None and self.old_dp_group is not None
old_dp_size = self.old_dp_group.size()
new_dp_size = self.reconfig_request.new_data_parallel_size
......@@ -450,6 +463,7 @@ class ElasticEPScalingState:
logger.info("[Elastic EP] Transferred weights to new workers")
def _transfer_expert_mapping(self):
assert self.old_dp_group is not None
self.model_executor.collective_rpc(
"elastic_ep_execute", args=("broadcast_expert_mapping",)
)
......@@ -458,7 +472,7 @@ class ElasticEPScalingState:
def _sync_kv_cache_memory_size(self):
assert self.engine_core.available_gpu_memory_for_kv_cache > 0
assert self.new_dp_group is not None
assert self.new_dp_group is not None and self.old_dp_group is not None
ParallelConfig.sync_kv_cache_memory_size(
self.new_dp_group,
self.engine_core.available_gpu_memory_for_kv_cache,
......@@ -507,7 +521,7 @@ class ElasticEPScalingState:
logger.info("[Elastic EP] EPLB reshuffle completed")
def _eplb_reshuffle_before_scale_down(self):
assert self.reconfig_request is not None
assert self.reconfig_request is not None and self.old_dp_group is not None
self.model_executor.collective_rpc(
"elastic_ep_execute",
args=(
......
......@@ -336,6 +336,7 @@ class TpKVTopology:
self._cross_layers_blocks = (
len(self.tensor_shape) == len(kv_cache_shape) + 1
)
self.tensor_shape: torch.Size
if self._cross_layers_blocks:
logger.debug("Using cross-layer KV cache")
......
......@@ -972,6 +972,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# Early-out for cascade attention
if use_cascade:
assert num_blocks_np is not None
# Grab the blocks of the shared prefix from the first request.
num_common_kv_blocks = common_prefix_len // page_size
......@@ -1117,6 +1118,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
max_seq_len=max_seq_len,
)
else:
assert seq_lens_cpu is not None
pure_decode = num_prefills == 0
use_cudagraph = (
self.enable_cuda_graph
......
......@@ -88,14 +88,14 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
self.num_spec: int = self.speculative_config.num_speculative_tokens
else:
self.num_spec = 0
self.use_spec_decode = self.num_spec > 0
self.use_spec_decode: bool = self.num_spec > 0
self._init_reorder_batch_threshold(1, self.use_spec_decode)
self.use_full_cuda_graph = (
self.use_full_cuda_graph: bool = (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
)
self.decode_cudagraph_max_bs = (
self.decode_cudagraph_max_bs: int = (
self.vllm_config.scheduler_config.max_num_seqs * (self.num_spec + 1)
)
if self.compilation_config.max_cudagraph_capture_size is not None:
......@@ -104,42 +104,42 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
self.compilation_config.max_cudagraph_capture_size,
)
self.spec_state_indices_tensor = torch.empty(
self.spec_state_indices_tensor: torch.Tensor = torch.empty(
(self.decode_cudagraph_max_bs, self.num_spec + 1),
dtype=torch.int32,
device=device,
)
self.non_spec_state_indices_tensor = torch.empty(
self.non_spec_state_indices_tensor: torch.Tensor = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.int32,
device=device,
)
self.spec_sequence_masks = torch.empty(
self.spec_sequence_masks: torch.Tensor = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.bool,
device=device,
)
self.spec_token_indx = torch.empty(
self.spec_token_indx: torch.Tensor = torch.empty(
(self.decode_cudagraph_max_bs * (self.num_spec + 1),),
dtype=torch.int32,
device=device,
)
self.non_spec_token_indx = torch.empty(
self.non_spec_token_indx: torch.Tensor = torch.empty(
(self.decode_cudagraph_max_bs * (self.num_spec + 1),),
dtype=torch.int32,
device=device,
)
self.spec_query_start_loc = torch.empty(
self.spec_query_start_loc: torch.Tensor = torch.empty(
(self.decode_cudagraph_max_bs + 1,),
dtype=torch.int32,
device=device,
)
self.non_spec_query_start_loc = torch.empty(
self.non_spec_query_start_loc: torch.Tensor = torch.empty(
(self.decode_cudagraph_max_bs + 1,),
dtype=torch.int32,
device=device,
)
self.num_accepted_tokens = torch.empty(
self.num_accepted_tokens: torch.Tensor = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.int32,
device=device,
......@@ -322,6 +322,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
and num_spec_decodes <= self.decode_cudagraph_max_bs
and num_spec_decode_tokens <= self.decode_cudagraph_max_bs
):
assert spec_sequence_masks is not None
self.spec_state_indices_tensor[:num_spec_decodes].copy_(
spec_state_indices_tensor, non_blocking=True
)
......
......@@ -98,8 +98,8 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
self.use_spec_decode = self.num_spec_tokens > 0
assert isinstance(kv_cache_spec, MambaSpec)
self.compilation_config = vllm_config.compilation_config
self.decode_cudagraph_max_bs = self.vllm_config.scheduler_config.max_num_seqs
scheduler_config = vllm_config.scheduler_config
self.decode_cudagraph_max_bs: int = scheduler_config.max_num_seqs
if self.compilation_config.max_cudagraph_capture_size is not None:
self.decode_cudagraph_max_bs = min(
self.decode_cudagraph_max_bs,
......@@ -114,7 +114,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
# Speculative decoding not supported with prefix caching,
# so keep shape consistent with prefill buffer
# TODO: reduce this size as needed for decode-only cudagraph capture
self.state_indices_tensor_d = torch.empty(
self.state_indices_tensor_d: torch.Tensor = torch.empty(
(
self.decode_cudagraph_max_bs,
max_num_blocks,
......@@ -122,12 +122,12 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
dtype=torch.int32,
device=device,
)
self.block_idx_last_scheduled_token = torch.empty(
self.block_idx_last_scheduled_token: torch.Tensor = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.int32,
device=device,
)
self.block_idx_last_computed_token = torch.empty(
self.block_idx_last_computed_token: torch.Tensor = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.int32,
device=device,
......@@ -142,7 +142,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
# For speculative decoding, we need to store the following buffers
# for CUDA graph capture during decode
if self.num_spec_tokens > 0:
self.decode_num_accepted_tokens = torch.empty(
self.decode_num_accepted_tokens: torch.Tensor = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.int32,
device=device,
......
......@@ -1539,18 +1539,18 @@ class DPEngineCoreProc(EngineCoreProc):
def _init_data_parallel(self, vllm_config: VllmConfig):
# Configure GPUs and stateless process group for data parallel.
dp_rank = vllm_config.parallel_config.data_parallel_rank
dp_size = vllm_config.parallel_config.data_parallel_size
local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
parallel_config = vllm_config.parallel_config
dp_rank = parallel_config.data_parallel_rank
dp_size = parallel_config.data_parallel_size
local_dp_rank = parallel_config.data_parallel_rank_local
assert dp_size > 1
assert local_dp_rank is not None
assert 0 <= local_dp_rank <= dp_rank < dp_size
self.dp_rank = dp_rank
self.dp_group, self.dp_store = (
vllm_config.parallel_config.stateless_init_dp_group(return_store=True)
)
dp_group, dp_store = parallel_config.stateless_init_dp_group(return_store=True)
self.dp_group, self.dp_store = dp_group, dp_store
def shutdown(self):
super().shutdown()
......
......@@ -309,12 +309,16 @@ class AdapterLogitsProcessor(LogitsProcessor):
"""
if req_lp := self.new_req_logits_processor(params):
args = (
[prompt_ids, output_ids]
if (len(inspect.signature(req_lp).parameters) == 3)
else [output_ids]
)
return partial(req_lp, *args) # type: ignore[misc]
if len(inspect.signature(req_lp).parameters) == 3:
if prompt_ids is None:
raise ValueError(
"Prompt token ids are required for this "
"logits processor but were not provided."
)
args = [prompt_ids, output_ids]
else:
args = [output_ids]
return partial(req_lp, *args)
return None
def update_state(self, batch_update: BatchUpdate | None):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment