Unverified commit ce498a6d, authored by Sage Moore, committed by GitHub
Browse files

Change the type signature of MixtureOfExperts.expert_weights to...


Change the type signature of MixtureOfExperts.expert_weights to MutableSequence[Sequence[Tensor]] (#33573)
Signed-off-by: Sage Moore <sagmoore@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
parent 9f14c922
......@@ -6,7 +6,7 @@ The actual execution of the rearrangement.
This involves the exchange of expert weights between GPUs.
"""
from collections.abc import Iterable, Sequence
from collections.abc import Sequence
from dataclasses import dataclass
import numpy as np
......@@ -153,7 +153,7 @@ def move_to_buffer(
num_local_experts: int,
old_indices: np.ndarray,
new_indices: np.ndarray,
expert_weights: Iterable[torch.Tensor],
expert_weights: Sequence[torch.Tensor],
expert_weights_buffers: Sequence[torch.Tensor],
cuda_stream: torch.cuda.Stream | None,
ep_group: ProcessGroup,
......@@ -355,7 +355,7 @@ def move_to_buffer(
def move_from_buffer(
expert_weights: Iterable[torch.Tensor],
expert_weights: Sequence[torch.Tensor],
expert_weights_buffers: list[torch.Tensor],
is_unchanged: np.ndarray,
is_received_locally: np.ndarray,
......@@ -436,7 +436,7 @@ def move_from_buffer(
async def transfer_layer(
old_global_expert_indices: torch.Tensor,
new_global_expert_indices: torch.Tensor,
expert_weights: Sequence[Iterable[torch.Tensor]],
expert_weights: Sequence[Sequence[torch.Tensor]],
expert_weights_buffer: Sequence[torch.Tensor],
ep_group: ProcessGroup,
is_profile: bool = False,
......@@ -488,7 +488,8 @@ async def transfer_layer(
assert old_global_expert_indices.shape[1] == new_global_expert_indices.shape[1]
num_moe_layers, num_physical_experts = old_global_expert_indices.shape
assert len(expert_weights) == num_moe_layers
num_local_physical_experts = next(iter(expert_weights[0])).shape[0]
assert len(expert_weights[0]) >= 1
num_local_physical_experts = expert_weights[0][0].shape[0]
assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts)
assert num_physical_experts == ep_size * num_local_physical_experts
......@@ -510,7 +511,7 @@ async def transfer_layer(
def rearrange_expert_weights_inplace(
old_global_expert_indices: torch.Tensor,
new_global_expert_indices: torch.Tensor,
expert_weights: Sequence[Iterable[torch.Tensor]],
expert_weights: Sequence[Sequence[torch.Tensor]],
ep_group: ProcessGroup,
is_profile: bool = False,
rank_mapping: dict[int, int] | None = None,
......@@ -553,8 +554,9 @@ def rearrange_expert_weights_inplace(
num_moe_layers, num_physical_experts = old_global_expert_indices.shape
assert len(expert_weights) == num_moe_layers
assert len(expert_weights[0]) >= 1
num_local_physical_experts = next(iter(expert_weights[0])).shape[0]
num_local_physical_experts = expert_weights[0][0].shape[0]
assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts)
ep_size = ep_group.size()
......
......@@ -2,7 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
from collections.abc import AsyncGenerator, Callable, Iterable, Mapping, MutableSequence
from collections.abc import (
AsyncGenerator,
Callable,
Iterable,
Mapping,
MutableSequence,
Sequence,
)
from contextlib import ExitStack, contextmanager, nullcontext
from typing import (
TYPE_CHECKING,
......@@ -818,7 +825,7 @@ class MixtureOfExperts(Protocol):
Check if the model is a mixture of experts (MoE) model.
"""
expert_weights: MutableSequence[Iterable[Tensor]]
expert_weights: MutableSequence[Sequence[Tensor]]
"""
Expert weights saved in this rank.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment