Unverified commit 6a7988c5, authored by youkaichao and committed by GitHub
Browse files

Refactor pplx init logic to make it modular (prepare for deepep) (#18200)


Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 022d8abe
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import importlib.util
from typing import TYPE_CHECKING
import torch import torch
import torch.distributed as dist
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from .base_device_communicator import All2AllManagerBase, Cache
class All2AllBase: logger = init_logger(__name__)
def __init__(self, cpu_group, model):
self.cpu_group = cpu_group
# compute some common properties
from vllm.distributed.parallel_state import (get_dp_group,
get_ep_group,
get_tp_group,
in_the_same_node_as)
# all2all lives in ep group, which is merged from dp and tp group
self.dp_group = get_dp_group()
self.tp_group = get_tp_group()
self.ep_group = get_ep_group()
self.dp_rank = self.dp_group.rank_in_group
self.dp_world_size = self.dp_group.world_size
# all2all communication often has separate implementations for
# intra-node and inter-node communication
self.intranode = in_the_same_node_as(cpu_group, source_rank=0)
self.internode = not self.intranode
def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
raise NotImplementedError
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
def destroy(self): if TYPE_CHECKING:
pass from vllm.model_executor.layers.fused_moe.layer import FusedMoE
else:
FusedMoE = None
class NaiveAll2All(All2AllBase): class NaiveAll2AllManager(All2AllManagerBase):
""" """
A naive implementation of all2all communication. A naive implementation of all2all communication.
It uses all-reduce under the hood, which is not It uses all-reduce under the hood, which is not
...@@ -46,8 +26,8 @@ class NaiveAll2All(All2AllBase): ...@@ -46,8 +26,8 @@ class NaiveAll2All(All2AllBase):
debugging. debugging.
""" """
def __init__(self, cpu_group, model): def __init__(self, cpu_group):
super().__init__(cpu_group, model) super().__init__(cpu_group)
def naive_multicast(self, x: torch.Tensor, def naive_multicast(self, x: torch.Tensor,
cu_tokens_across_dp_cpu: torch.Tensor): cu_tokens_across_dp_cpu: torch.Tensor):
...@@ -91,3 +71,56 @@ class NaiveAll2All(All2AllBase): ...@@ -91,3 +71,56 @@ class NaiveAll2All(All2AllBase):
def destroy(self): def destroy(self):
pass pass
class PPLXAll2AllManager(All2AllManagerBase):
    """
    All2All communication based on PPLX kernels.

    Construction requires the ``pplx_kernels`` package. When the manager
    is flagged as inter-node, NVSHMEM is initialised at construction time
    and finalised again in ``destroy``.
    """

    def __init__(self, cpu_group):
        # Fail early, before any group state is computed, if the optional
        # kernel package is missing.
        pplx_spec = importlib.util.find_spec("pplx_kernels")
        assert pplx_spec is not None, "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels."  # noqa
        super().__init__(cpu_group)

        if self.internode:
            # inter-node communication needs nvshmem,
            # intra-node communication uses p2p mapping directly
            from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
                                              nvshmem_get_unique_id,
                                              nvshmem_init)
            logger.debug(
                "Initialize NVSHMEM for pplx_kernels: "
                "rank=%d, world size=%d", self.rank, self.world_size)
            # Rank 0 generates the unique id; every other rank allocates
            # an empty one that the broadcast below fills in.
            if self.rank == 0:
                uid = nvshmem_get_unique_id()
            else:
                uid = nvshmem_alloc_empty_unique_id()
            # Broadcast from the first rank of the CPU group.
            first_rank = dist.get_process_group_ranks(self.cpu_group)[0]
            dist.broadcast(uid, src=first_rank, group=self.cpu_group)
            logger.debug("PPLX NVSHMEM UID = %s", uid)
            nvshmem_init(uid, self.rank, self.world_size)

        self.handle_cache = Cache()

    def get_handle(self, kwargs):
        """Return the cached AllToAll handle for ``kwargs``, creating it
        with the inter-node or intra-node factory on first use."""
        import pplx_kernels as pplx
        if self.internode:
            factory = pplx.AllToAll.internode
        else:
            factory = pplx.AllToAll.intranode
        return self.handle_cache.get_or_create(kwargs, factory)

    def dispatch(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
        raise NotImplementedError

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def destroy(self):
        """Destroy every cached handle, then finalise NVSHMEM if it was
        initialised by this manager."""
        cache = self.handle_cache
        with cache._lock:
            for handle in cache._cache.values():
                handle.destroy()

        if self.internode:
            from pplx_kernels.nvshmem import nvshmem_finalize
            logger.debug("PPLX NVSHMEM finalize")
            nvshmem_finalize()
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import threading
from typing import Optional from typing import Optional
from weakref import WeakValueDictionary
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
class Cache:
    """Lock-protected cache of objects keyed by their constructor kwargs.

    Entries are stored in a ``WeakValueDictionary``, so the cache itself
    does not keep the created objects alive.
    """

    def __init__(self):
        self._cache: WeakValueDictionary = WeakValueDictionary()
        self._lock = threading.RLock()  # Reentrant lock for thread safety

    def get_or_create(self, kwargs, func):
        """Return the object cached for ``kwargs``, building it with
        ``func(**kwargs)`` when there is no live entry.

        Args:
            kwargs: Mapping of keyword arguments; keys must be sortable
                and values hashable, since together they form the key.
            func: Callable invoked as ``func(**kwargs)`` on a cache miss.

        Returns:
            The cached or newly created object.
        """
        # Sort the items so the key does not depend on insertion order.
        key = tuple(sorted(kwargs.items()))
        with self._lock:
            cached = self._cache.get(key)
            if cached is not None:
                return cached
            created = func(**kwargs)
            self._cache[key] = created
            return created
class All2AllManagerBase:
    """Base class for all2all communication managers.

    The constructor records the DP/TP groups and this process's rank and
    world size within ``cpu_group``, and decides whether the group is
    confined to one node. Subclasses implement ``get_handle``,
    ``dispatch`` and ``combine``; ``destroy`` is a no-op here.
    """

    def __init__(self, cpu_group):
        self.cpu_group = cpu_group

        # compute some common properties
        from vllm.distributed.parallel_state import (get_dp_group,
                                                     get_tp_group,
                                                     in_the_same_node_as)

        # all2all lives in ep group, which is merged from dp and tp group
        self.dp_group = get_dp_group()
        self.tp_group = get_tp_group()
        # no self.ep_group since self.ep_group is still in construction
        # when we create this object
        self.dp_rank = self.dp_group.rank_in_group
        self.dp_world_size = self.dp_group.world_size
        self.rank = dist.get_rank(cpu_group)
        self.world_size = dist.get_world_size(cpu_group)

        # all2all communication often has separate implementations for
        # intra-node and inter-node communication.
        # in_the_same_node_as reports one flag per rank (a list) in
        # vLLM's parallel_state; `not <non-empty list>` is always False,
        # which would make `internode` permanently False. Collapse the
        # flags to a single bool first.
        self.intranode = self._all_on_source_node(
            in_the_same_node_as(cpu_group, source_rank=0))
        self.internode = not self.intranode

    @staticmethod
    def _all_on_source_node(same_node_flags) -> bool:
        """Collapse a same-node result (bool or iterable of bools) to one
        bool that is True only if every rank shares the source's node."""
        if isinstance(same_node_flags, bool):
            return same_node_flags
        return all(same_node_flags)

    def get_handle(self, kwargs):
        """Return a handle for the all2all communication described by
        ``kwargs``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        # get a handle for the all2all communication,
        # based on the kwargs.
        # different layers can have different configs,
        # e.g. one layer has hidden size 1024, another has 2048.
        # usually the underlying implementation caches the handle
        # and reuse it for the same config.
        raise NotImplementedError

    def dispatch(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
        raise NotImplementedError

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def destroy(self):
        """Release resources held by the manager; nothing to do here."""
        pass
class DeviceCommunicatorBase: class DeviceCommunicatorBase:
""" """
Base class for device-specific communicator. Base class for device-specific communicator.
...@@ -31,6 +96,18 @@ class DeviceCommunicatorBase: ...@@ -31,6 +96,18 @@ class DeviceCommunicatorBase:
self.rank_in_group = dist.get_group_rank(self.cpu_group, self.rank_in_group = dist.get_group_rank(self.cpu_group,
self.global_rank) self.global_rank)
use_ep = False
from vllm.config import get_current_vllm_config
config = get_current_vllm_config()
if config is not None:
# as long as we use data parallel (coupled data parallel
# where all data parallel ranks execute forward together),
# we initialize the all2all manager used in expert parallel.
use_ep = config.parallel_config.data_parallel_size > 1
self.use_all2all = "ep" in unique_name and use_ep
self.all2all_manager: Optional[All2AllManagerBase] = None
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
dist.all_reduce(input_, group=self.device_group) dist.all_reduce(input_, group=self.device_group)
return input_ return input_
...@@ -154,9 +231,17 @@ class DeviceCommunicatorBase: ...@@ -154,9 +231,17 @@ class DeviceCommunicatorBase:
model: torch.nn.Module) -> None: model: torch.nn.Module) -> None:
""" """
Prepare the communication buffer for the model. Prepare the communication buffer for the model.
This is a no-op in the base class.
""" """
pass if not self.use_all2all:
return
moe_modules = [
module for module in model.modules()
if module.__class__.__name__ == "FusedMoE"
]
for module in moe_modules:
module.quant_method.init_prepare_finalize(module.moe_config,
module.quant_config)
def dispatch( def dispatch(
self, hidden_states: torch.Tensor, self, hidden_states: torch.Tensor,
......
...@@ -6,10 +6,12 @@ import torch ...@@ -6,10 +6,12 @@ import torch
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger
from .all2all import All2AllBase
from .base_device_communicator import DeviceCommunicatorBase from .base_device_communicator import DeviceCommunicatorBase
logger = init_logger(__name__)
class CudaCommunicator(DeviceCommunicatorBase): class CudaCommunicator(DeviceCommunicatorBase):
...@@ -31,8 +33,6 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -31,8 +33,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
use_pynccl = "ep" not in unique_name use_pynccl = "ep" not in unique_name
self.use_pynccl = use_pynccl self.use_pynccl = use_pynccl
self.use_all2all = "ep" in unique_name
self.all2all_impl: Optional[All2AllBase] = None
self.use_custom_allreduce = use_custom_allreduce self.use_custom_allreduce = use_custom_allreduce
# lazy import to avoid documentation build error # lazy import to avoid documentation build error
...@@ -56,6 +56,19 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -56,6 +56,19 @@ class CudaCommunicator(DeviceCommunicatorBase):
device=self.device, device=self.device,
) )
if self.use_all2all:
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if all2all_backend == "naive":
from .all2all import NaiveAll2AllManager
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
logger.info("Using naive all2all manager.")
elif all2all_backend == "pplx":
from .all2all import PPLXAll2AllManager
self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
logger.info("Using PPLX all2all manager.")
else:
raise ValueError(f"Unknown all2all backend: {all2all_backend}")
def all_reduce(self, input_): def all_reduce(self, input_):
# always try custom allreduce first, # always try custom allreduce first,
# and then pynccl. # and then pynccl.
...@@ -136,31 +149,19 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -136,31 +149,19 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.pynccl_comm = None self.pynccl_comm = None
if self.ca_comm is not None: if self.ca_comm is not None:
self.ca_comm = None self.ca_comm = None
if self.all2all_impl is not None: if self.all2all_manager is not None:
self.all2all_impl.destroy() self.all2all_manager.destroy()
self.all2all_impl = None self.all2all_manager = None
def prepare_communication_buffer_for_model(self,
model: torch.nn.Module) -> None:
"""
Prepare the communication buffer for the model.
"""
if not self.use_all2all:
return
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if all2all_backend == "naive":
from .all2all import NaiveAll2All
self.all2all_impl = NaiveAll2All(self.cpu_group, model)
def dispatch( def dispatch(
self, hidden_states: torch.Tensor, self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
assert self.all2all_impl is not None assert self.all2all_manager is not None
hidden_states, router_logits = self.all2all_impl.dispatch( hidden_states, router_logits = self.all2all_manager.dispatch(
hidden_states, router_logits) hidden_states, router_logits)
return hidden_states, router_logits return hidden_states, router_logits
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
assert self.all2all_impl is not None assert self.all2all_manager is not None
hidden_states = self.all2all_impl.combine(hidden_states) hidden_states = self.all2all_manager.combine(hidden_states)
return hidden_states return hidden_states
...@@ -23,7 +23,6 @@ If you only need to use the distributed environment without model/pipeline ...@@ -23,7 +23,6 @@ If you only need to use the distributed environment without model/pipeline
""" """
import contextlib import contextlib
import gc import gc
import importlib.util
import pickle import pickle
import weakref import weakref
from collections import namedtuple from collections import namedtuple
...@@ -43,7 +42,7 @@ from vllm.distributed.device_communicators.base_device_communicator import ( ...@@ -43,7 +42,7 @@ from vllm.distributed.device_communicators.base_device_communicator import (
from vllm.distributed.utils import StatelessProcessGroup from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname, from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname,
run_once, supports_custom_op) supports_custom_op)
@dataclass @dataclass
...@@ -791,10 +790,14 @@ class GroupCoordinator: ...@@ -791,10 +790,14 @@ class GroupCoordinator:
if self.device_communicator is not None: if self.device_communicator is not None:
return self.device_communicator.dispatch(hidden_states, return self.device_communicator.dispatch(hidden_states,
router_logits) router_logits)
else:
return hidden_states, router_logits
def combine(self, hidden_states) -> torch.Tensor: def combine(self, hidden_states) -> torch.Tensor:
if self.device_communicator is not None: if self.device_communicator is not None:
return self.device_communicator.combine(hidden_states) return self.device_communicator.combine(hidden_states)
else:
return hidden_states
_WORLD: Optional[GroupCoordinator] = None _WORLD: Optional[GroupCoordinator] = None
...@@ -959,49 +962,9 @@ def init_distributed_environment( ...@@ -959,49 +962,9 @@ def init_distributed_environment(
"world group already initialized with a different world size") "world group already initialized with a different world size")
PPLX_DID_INIT: bool = False
@run_once
def pplx_init(rank, world_size):
    """Initialise NVSHMEM for the PPLX kernels (best effort).

    Does nothing when ``pplx_kernels`` is not installed or when
    ``world_size`` is not greater than 1. Any exception raised during
    initialisation is logged and swallowed; ``PPLX_DID_INIT`` is set to
    True only when initialisation completes.
    """
    global PPLX_DID_INIT

    pplx_available = importlib.util.find_spec("pplx_kernels") is not None
    if not (pplx_available and world_size > 1):
        return

    from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
                                      nvshmem_get_unique_id, nvshmem_init)
    try:
        logger.debug(
            "Initialize NVSHMEM for PPLX kernels: rank=%d, "
            "world size=%d", rank, world_size)
        # Rank 0 generates the id; other ranks allocate an empty one to
        # receive the broadcast.
        if rank == 0:
            local_uid = nvshmem_get_unique_id()
        else:
            local_uid = nvshmem_alloc_empty_unique_id()
        # The id is moved to the GPU for the world-group broadcast and
        # copied back to the CPU afterwards.
        device_uid = local_uid.cuda()
        get_world_group().broadcast(device_uid, src=0)
        shared_uid = device_uid.to(device='cpu')
        logger.debug("PPLX NVSHMEM UID = %s", shared_uid)
        nvshmem_init(shared_uid, rank, world_size)
        PPLX_DID_INIT = True
    except Exception as ex:
        logger.error("Failed to initialize NVSHMEM for PPLX: %s", ex)
@run_once
def pplx_finalize():
    """Tear down PPLX state if ``pplx_init`` completed.

    Destroys the cached AllToAll objects first, then finalises NVSHMEM.
    Does nothing when ``PPLX_DID_INIT`` is False.
    """
    if not PPLX_DID_INIT:
        return

    from pplx_kernels.nvshmem import nvshmem_finalize
    logger.debug("PPLX NVSHMEM finalize")
    from vllm.model_executor.layers.fused_moe.layer import (
        _all_to_all_cache)
    _all_to_all_cache.destroy()
    nvshmem_finalize()
def initialize_model_parallel( def initialize_model_parallel(
tensor_model_parallel_size: int = 1, tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1,
enable_expert_parallel: bool = False,
backend: Optional[str] = None, backend: Optional[str] = None,
) -> None: ) -> None:
""" """
...@@ -1104,14 +1067,10 @@ def initialize_model_parallel( ...@@ -1104,14 +1067,10 @@ def initialize_model_parallel(
_DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group, _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group,
_EP.rank_in_group) _EP.rank_in_group)
if enable_expert_parallel:
pplx_init(rank, world_size)
def ensure_model_parallel_initialized( def ensure_model_parallel_initialized(
tensor_model_parallel_size: int, tensor_model_parallel_size: int,
pipeline_model_parallel_size: int, pipeline_model_parallel_size: int,
enable_expert_parallel: bool = False,
backend: Optional[str] = None, backend: Optional[str] = None,
) -> None: ) -> None:
"""Helper to initialize model parallel groups if they are not initialized, """Helper to initialize model parallel groups if they are not initialized,
...@@ -1122,8 +1081,7 @@ def ensure_model_parallel_initialized( ...@@ -1122,8 +1081,7 @@ def ensure_model_parallel_initialized(
get_world_group().device_group) get_world_group().device_group)
if not model_parallel_is_initialized(): if not model_parallel_is_initialized():
initialize_model_parallel(tensor_model_parallel_size, initialize_model_parallel(tensor_model_parallel_size,
pipeline_model_parallel_size, pipeline_model_parallel_size, backend)
enable_expert_parallel, backend)
return return
assert ( assert (
...@@ -1202,8 +1160,6 @@ def destroy_model_parallel(): ...@@ -1202,8 +1160,6 @@ def destroy_model_parallel():
"""Set the groups to none and destroy them.""" """Set the groups to none and destroy them."""
global _TP global _TP
pplx_finalize()
if _TP: if _TP:
_TP.destroy() _TP.destroy()
_TP = None _TP = None
......
...@@ -809,6 +809,9 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -809,6 +809,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: int(os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5557")), lambda: int(os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5557")),
# all2all backend for vllm's expert parallel communication # all2all backend for vllm's expert parallel communication
# Available options:
# - "naive": naive all2all implementation using all-reduce
# - "pplx": use pplx kernels
"VLLM_ALL2ALL_BACKEND": "VLLM_ALL2ALL_BACKEND":
lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"), lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"),
} }
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import importlib import importlib
import threading
from abc import abstractmethod from abc import abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Callable, Optional from typing import Callable, Optional
from weakref import WeakValueDictionary
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -73,7 +71,8 @@ class FusedMoEParallelConfig: ...@@ -73,7 +71,8 @@ class FusedMoEParallelConfig:
@property @property
def use_pplx_kernels(self): def use_pplx_kernels(self):
return self.dp_size > 1 and self.use_ep and has_pplx return self.dp_size > 1 and self.use_ep and \
envs.VLLM_ALL2ALL_BACKEND == "pplx"
@staticmethod @staticmethod
def make(tp_size_: int, dp_size_: int, def make(tp_size_: int, dp_size_: int,
...@@ -196,6 +195,8 @@ class MoEConfig: ...@@ -196,6 +195,8 @@ class MoEConfig:
# TODO: add more quantization params, blocked, per-token, etc. # TODO: add more quantization params, blocked, per-token, etc.
block_size: int = 128 block_size: int = 128
max_num_tokens: int = MOE_DP_CHUNK_SIZE
@property @property
def tp_size(self): def tp_size(self):
return self.moe_parallel_config.tp_size return self.moe_parallel_config.tp_size
...@@ -244,13 +245,59 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -244,13 +245,59 @@ class FusedMoEMethodBase(QuantizeMethodBase):
params_dtype: torch.dtype, **extra_weight_attrs): params_dtype: torch.dtype, **extra_weight_attrs):
raise NotImplementedError raise NotImplementedError
def set_prepare_finalize( def init_prepare_finalize(self, moe: MoEConfig,
self, quant_config: Optional[QuantizationConfig]):
dp_size: int, all2all_manager = get_ep_group().device_communicator.all2all_manager
world_size: int, assert all2all_manager is not None
prepare_finalize: FusedMoEPrepareAndFinalize,
) -> bool: prepare_finalize = None
return False if moe.use_pplx_kernels:
all_to_all_args = dict(
max_num_tokens=moe.max_num_tokens,
num_experts=moe.num_experts,
experts_per_token=moe.experts_per_token, # topk
rank=all2all_manager.rank,
world_size=all2all_manager.world_size,
# dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size,
hidden_dim=moe.hidden_dim,
hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize,
# For blocked per token: set to
# ceil_div(hidden_dim, block_size) * sizeof(float32)
# For per-token: set to sizeof(float32)
hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else (
(moe.hidden_dim + moe.block_size - 1) // moe.block_size *
torch.float32.itemsize)),
group_name=all2all_manager.cpu_group.group_name,
)
handle = all2all_manager.get_handle(all_to_all_args)
prepare_finalize = PplxPrepareAndFinalize(
handle,
max_num_tokens=moe.max_num_tokens,
world_size=all2all_manager.world_size,
rank=all2all_manager.rank,
# dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size,
quant_dtype=moe.in_dtype,
)
if prepare_finalize is not None:
experts = self.select_gemm_impl(prepare_finalize)
self.fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
)
def select_gemm_impl(
    self, prepare_finalize: Optional[FusedMoEPrepareAndFinalize]
) -> FusedMoEPermuteExpertsUnpermute:
    """Choose the gemm implementation that matches ``prepare_finalize``.

    Raises:
        NotImplementedError: always; subclasses must override this.
    """
    # based on the all2all implementation, select the appropriate
    # gemm implementation
    message = ("Subclass must select appropriate gemm implementation"
               " based on the prepare_finalize")
    raise NotImplementedError(message)
@abstractmethod @abstractmethod
def apply( def apply(
...@@ -274,53 +321,13 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -274,53 +321,13 @@ class FusedMoEMethodBase(QuantizeMethodBase):
raise NotImplementedError raise NotImplementedError
class AllToAllCache:
    """Lock-protected, weakly-referencing cache of PPLX AllToAll objects
    keyed by their construction kwargs."""

    def __init__(self):
        self._cache: WeakValueDictionary = WeakValueDictionary()
        self._lock = threading.RLock()  # Reentrant lock for thread safety

    def destroy(self):
        """Call ``destroy()`` on every object still held in the cache."""
        with self._lock:
            # TODO: can we do del self._cache?
            for a2a in self._cache.values():
                a2a.destroy()

    def get_or_create(self, **kwargs):
        """Return the AllToAll cached for ``kwargs``, creating it with
        ``pplx.AllToAll.internode(**kwargs)`` on a cache miss."""
        assert has_pplx
        import pplx_kernels as pplx

        # Create a hashable key from the kwargs
        key = tuple(sorted(kwargs.items()))

        with self._lock:
            cached = self._cache.get(key)
            if cached is not None:
                return cached
            # TODO (varun): Add support to switch to intranode
            # when all communications are within the same
            # node.
            logger.debug("Create AllToAll %s", kwargs)
            created = pplx.AllToAll.internode(**kwargs)
            self._cache[key] = created
            return created
# Global singleton
_all_to_all_cache = AllToAllCache()
# Factory function as a cleaner interface
def get_all_to_all(**kwargs):
    """Return the AllToAll for ``kwargs`` from the module-level cache."""
    cache = _all_to_all_cache
    return cache.get_or_create(**kwargs)
@CustomOp.register("unquantized_fused_moe") @CustomOp.register("unquantized_fused_moe")
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization.""" """MoE method without quantization."""
def __init__(self, moe: MoEConfig): def __init__(self, moe: MoEConfig):
super().__init__() super().__init__()
self.fused_experts = fused_experts self.fused_experts = fused_experts # type: ignore
self.moe = moe self.moe = moe
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
...@@ -330,6 +337,42 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -330,6 +337,42 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
else: else:
self.rocm_aiter_fused_experts = None # type: ignore self.rocm_aiter_fused_experts = None # type: ignore
def select_gemm_impl(
        self, prepare_finalize: Optional[FusedMoEPrepareAndFinalize]):
    """Pick the unquantized expert GEMM implementation for
    ``prepare_finalize``.

    Batched or PPLX prepare/finalize objects get ``BatchedTritonExperts``,
    sized from ``MOE_DP_CHUNK_SIZE`` and the all2all manager's world
    sizes. Anything else gets ``TritonExperts``. All quantization flags
    are off in both cases.
    """
    assert self.fused_experts == fused_experts

    all2all_manager = get_ep_group().device_communicator.all2all_manager
    assert all2all_manager is not None

    uses_batched_format = isinstance(
        prepare_finalize,
        (BatchedPrepareAndFinalize, PplxPrepareAndFinalize))

    if not uses_batched_format:
        logger.debug("TritonExperts %s", self.moe)
        return TritonExperts(
            use_fp8_w8a8=False,
            use_int8_w8a8=False,
            use_int8_w8a16=False,
            use_int4_w4a16=False,
            block_shape=None,
            per_channel_quant=False,
        )

    logger.debug("BatchedTritonExperts %s", self.moe)
    return BatchedTritonExperts(
        max_num_tokens=MOE_DP_CHUNK_SIZE,
        world_size=all2all_manager.world_size,
        # dp_size actually means tp_size, bug in pplx kernels
        dp_size=all2all_manager.tp_group.world_size,
        use_fp8_w8a8=False,
        use_int8_w8a8=False,
        use_int8_w8a16=False,
        use_int4_w4a16=False,
        block_shape=None,
    )
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int, hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs): params_dtype: torch.dtype, **extra_weight_attrs):
...@@ -429,47 +472,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -429,47 +472,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
activation=activation, activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input) apply_router_weight_on_input=apply_router_weight_on_input)
def set_prepare_finalize(
    self,
    dp_size: int,
    world_size: int,
    prepare_finalize: FusedMoEPrepareAndFinalize,
) -> bool:
    """Wrap ``prepare_finalize`` and a matching experts implementation in
    a ``FusedMoEModularKernel`` and install it as ``self.fused_experts``.

    Batched or PPLX prepare/finalize objects are paired with
    ``BatchedTritonExperts``; anything else with ``TritonExperts``.

    Returns:
        Always True.
    """
    assert self.fused_experts == fused_experts

    uses_batched_format = isinstance(
        prepare_finalize,
        (BatchedPrepareAndFinalize, PplxPrepareAndFinalize))

    experts: Optional[FusedMoEPermuteExpertsUnpermute]
    if uses_batched_format:
        logger.debug("BatchedTritonExperts %s", self.moe)
        experts = BatchedTritonExperts(
            max_num_tokens=MOE_DP_CHUNK_SIZE,
            world_size=world_size,
            dp_size=dp_size,
            use_fp8_w8a8=False,
            use_int8_w8a8=False,
            use_int8_w8a16=False,
            use_int4_w4a16=False,
            block_shape=None,
        )
    else:
        logger.debug("TritonExperts %s", self.moe)
        experts = TritonExperts(
            use_fp8_w8a8=False,
            use_int8_w8a8=False,
            use_int8_w8a16=False,
            use_int4_w4a16=False,
            block_shape=None,
            per_channel_quant=False,
        )

    modular_kernel = FusedMoEModularKernel(
        prepare_finalize,
        experts,
    )
    self.fused_experts = modular_kernel
    return True
def forward_cuda( def forward_cuda(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -679,45 +681,6 @@ def determine_expert_map( ...@@ -679,45 +681,6 @@ def determine_expert_map(
return (local_num_experts, expert_map) return (local_num_experts, expert_map)
def _construct_prepare_finalize(
    moe: MoEConfig, quant_config: Optional[QuantizationConfig]
) -> Optional[FusedMoEPrepareAndFinalize]:
    """Build the PPLX prepare/finalize object for ``moe``.

    Returns:
        A ``PplxPrepareAndFinalize`` when ``moe.use_pplx_kernels`` is
        set, otherwise None. ``quant_config`` is not read here.
    """
    max_num_tokens = MOE_DP_CHUNK_SIZE
    world_size = moe.ep_size
    dp_size = moe.ep_size // moe.dp_size  # dp_size actually means TP.
    rank = moe.ep_rank

    if not moe.use_pplx_kernels:
        return None

    logger.debug("using PplxPrepareAndFinalize")

    # For blocked per token: set to
    # ceil_div(hidden_dim, block_size) * sizeof(float32)
    # For per-token: set to sizeof(float32)
    if moe.in_dtype.itemsize != 1:
        scale_bytes = 0
    else:
        num_blocks = (moe.hidden_dim + moe.block_size - 1) // moe.block_size
        scale_bytes = num_blocks * torch.float32.itemsize

    all_to_all = get_all_to_all(
        max_num_tokens=max_num_tokens,
        num_experts=moe.num_experts,
        experts_per_token=moe.experts_per_token,  # topk
        rank=rank,
        world_size=world_size,
        dp_size=dp_size,
        hidden_dim=moe.hidden_dim,
        hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize,
        hidden_dim_scale_bytes=scale_bytes)

    return PplxPrepareAndFinalize(
        all_to_all,
        max_num_tokens=max_num_tokens,
        world_size=world_size,
        rank=rank,
        dp_size=dp_size,
        quant_dtype=moe.in_dtype,
    )
class FusedMoE(torch.nn.Module): class FusedMoE(torch.nn.Module):
"""FusedMoE layer for MoE models. """FusedMoE layer for MoE models.
...@@ -831,7 +794,10 @@ class FusedMoE(torch.nn.Module): ...@@ -831,7 +794,10 @@ class FusedMoE(torch.nn.Module):
moe_parallel_config=self.moe_parallel_config, moe_parallel_config=self.moe_parallel_config,
# TODO (bnell): this needs to be fixed for quantized types. # TODO (bnell): this needs to be fixed for quantized types.
in_dtype=params_dtype, in_dtype=params_dtype,
max_num_tokens=MOE_DP_CHUNK_SIZE,
) )
self.moe_config = moe
self.quant_config = quant_config
# Note: get_quant_method will look at the layer's local_num_experts # Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first. # for heuristic purposes, so it must be initialized first.
...@@ -839,25 +805,13 @@ class FusedMoE(torch.nn.Module): ...@@ -839,25 +805,13 @@ class FusedMoE(torch.nn.Module):
if quant_config is None: if quant_config is None:
quant_method = UnquantizedFusedMoEMethod(moe) quant_method = UnquantizedFusedMoEMethod(moe)
prepare_finalize = _construct_prepare_finalize(moe, quant_config)
else: else:
quant_method = quant_config.get_quant_method(self, prefix) quant_method = quant_config.get_quant_method(self, prefix)
# No pplx for quantized types yet.
prepare_finalize = None
assert quant_method is not None assert quant_method is not None
assert isinstance(quant_method, FusedMoEMethodBase) assert isinstance(quant_method, FusedMoEMethodBase)
self.quant_method = quant_method self.quant_method = quant_method
if prepare_finalize is not None:
world_size = moe.ep_size
dp_size = int(moe.ep_size // moe.dp_size)
success = self.quant_method.set_prepare_finalize(
dp_size, world_size, prepare_finalize)
if not success:
logger.warning("DP+EP not supported for %s.",
type(self.quant_method))
moe_quant_params = { moe_quant_params = {
"num_experts": self.local_num_experts, "num_experts": self.local_num_experts,
"hidden_size": hidden_size, "hidden_size": hidden_size,
......
...@@ -9,7 +9,6 @@ from vllm.model_executor.layers.fused_moe.utils import ( ...@@ -9,7 +9,6 @@ from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input) moe_kernel_quantize_input)
# Note use: layer.get_all_to_all() to get an AllToAll instance
# The max_num_tokens, world_size and dp_size must be the same # The max_num_tokens, world_size and dp_size must be the same
# as the ones used to create the AllToAll. # as the ones used to create the AllToAll.
class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
......
...@@ -10,7 +10,6 @@ from torch.nn import Module ...@@ -10,7 +10,6 @@ from torch.nn import Module
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
import vllm.envs as envs import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -461,7 +460,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -461,7 +460,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
logger.warning_once( logger.warning_once(
"DeepGemm not supported on the current platform.") "DeepGemm not supported on the current platform.")
self.fused_experts = functools.partial( self.fused_experts = functools.partial( # type: ignore
fused_experts, fused_experts,
block_shape=self.quant_config.weight_block_size, block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm) allow_deep_gemm=self.allow_deep_gemm)
...@@ -791,17 +790,12 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -791,17 +790,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
del layer.w13_input_scale del layer.w13_input_scale
del layer.w2_input_scale del layer.w2_input_scale
def set_prepare_finalize( def select_gemm_impl(self, prepare_finalize):
self,
dp_size: int,
world_size: int,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
) -> bool:
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts) TritonOrDeepGemmExperts)
if self.use_marlin or self.rocm_aiter_moe_enabled: assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
return False "Marlin and ROCm AITER are not supported with all2all yet.")
experts = TritonOrDeepGemmExperts( experts = TritonOrDeepGemmExperts(
use_fp8_w8a8=True, use_fp8_w8a8=True,
...@@ -809,12 +803,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -809,12 +803,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
allow_deep_gemm=self.allow_deep_gemm, allow_deep_gemm=self.allow_deep_gemm,
) )
self.fused_experts = mk.FusedMoEModularKernel( return experts
prepare_finalize,
experts,
)
return True
def apply( def apply(
self, self,
......
...@@ -158,6 +158,7 @@ class CudaPlatformBase(Platform): ...@@ -158,6 +158,7 @@ class CudaPlatformBase(Platform):
"currently not supported with CUDA Graphs.") "currently not supported with CUDA Graphs.")
vllm_config.model_config.enforce_eager = True vllm_config.model_config.enforce_eager = True
compilation_config.use_cudagraph = False compilation_config.use_cudagraph = False
# FIXME: inductor breaks cudagraph (from @bnell)
compilation_config.use_inductor = False compilation_config.use_inductor = False
@classmethod @classmethod
......
...@@ -348,8 +348,7 @@ def init_worker_distributed_environment( ...@@ -348,8 +348,7 @@ def init_worker_distributed_environment(
distributed_init_method, local_rank) distributed_init_method, local_rank)
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size, parallel_config.pipeline_parallel_size)
parallel_config.enable_expert_parallel)
ensure_kv_transfer_initialized(vllm_config) ensure_kv_transfer_initialized(vllm_config)
......
...@@ -265,8 +265,7 @@ def init_tpu_worker_distributed_environment( ...@@ -265,8 +265,7 @@ def init_tpu_worker_distributed_environment(
backend="gloo", backend="gloo",
) )
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size, parallel_config.pipeline_parallel_size)
parallel_config.enable_expert_parallel)
try: try:
......
...@@ -390,8 +390,7 @@ class CPUWorker(LocalOrDistributedWorkerBase): ...@@ -390,8 +390,7 @@ class CPUWorker(LocalOrDistributedWorkerBase):
ensure_model_parallel_initialized( ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size, parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size, parallel_config.pipeline_parallel_size)
parallel_config.enable_expert_parallel)
def get_cache_block_size_bytes(self) -> int: def get_cache_block_size_bytes(self) -> int:
"""Return the size in bytes of a single KV cache block. """Return the size in bytes of a single KV cache block.
......
...@@ -415,8 +415,7 @@ def init_worker_distributed_environment( ...@@ -415,8 +415,7 @@ def init_worker_distributed_environment(
backend='hccl') backend='hccl')
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size, parallel_config.pipeline_parallel_size)
parallel_config.enable_expert_parallel)
if torch.distributed.is_initialized(): if torch.distributed.is_initialized():
torch_world_size = torch.distributed.get_world_size() torch_world_size = torch.distributed.get_world_size()
...@@ -442,8 +441,7 @@ def init_worker_distributed_environment( ...@@ -442,8 +441,7 @@ def init_worker_distributed_environment(
torch.distributed.all_reduce(dummy_tensor_hpu) torch.distributed.all_reduce(dummy_tensor_hpu)
assert dummy_tensor_hpu.item() == parallel_config.world_size assert dummy_tensor_hpu.item() == parallel_config.world_size
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size, parallel_config.pipeline_parallel_size)
parallel_config.enable_expert_parallel)
def raise_if_cache_size_invalid(num_gpu_blocks, block_size, max_model_len, def raise_if_cache_size_invalid(num_gpu_blocks, block_size, max_model_len,
......
...@@ -76,8 +76,7 @@ class TPUWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -76,8 +76,7 @@ class TPUWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
) )
ensure_model_parallel_initialized( ensure_model_parallel_initialized(
self.parallel_config.tensor_parallel_size, self.parallel_config.tensor_parallel_size,
self.parallel_config.pipeline_parallel_size, self.parallel_config.pipeline_parallel_size)
self.parallel_config.enable_expert_parallel)
# Device initialization should happen after initializing the distributed # Device initialization should happen after initializing the distributed
# runtime. # runtime.
......
...@@ -529,8 +529,7 @@ def init_worker_distributed_environment( ...@@ -529,8 +529,7 @@ def init_worker_distributed_environment(
init_distributed_environment(parallel_config.world_size, rank, init_distributed_environment(parallel_config.world_size, rank,
distributed_init_method, local_rank) distributed_init_method, local_rank)
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size, parallel_config.pipeline_parallel_size)
parallel_config.enable_expert_parallel)
ensure_kv_transfer_initialized(vllm_config) ensure_kv_transfer_initialized(vllm_config)
......
...@@ -175,8 +175,7 @@ class XPUWorker(LoRANotSupportedWorkerBase, Worker): ...@@ -175,8 +175,7 @@ class XPUWorker(LoRANotSupportedWorkerBase, Worker):
ensure_model_parallel_initialized( ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size, parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size, parallel_config.pipeline_parallel_size)
parallel_config.enable_expert_parallel)
# global all_reduce needed for overall oneccl warm up # global all_reduce needed for overall oneccl warm up
torch.distributed.all_reduce(torch.zeros(1).xpu()) torch.distributed.all_reduce(torch.zeros(1).xpu())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment