Commit ab485158 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-ds-deepep-wm' into 'v0.9.2-dev-ds'

整合mori和deepep相关代码

See merge request dcutoolkit/deeplearing/vllm!239
parents db2c32b0 24fed7be
......@@ -4755,7 +4755,7 @@ class VllmConfig:
batch_size_capture_list = []
if self.model_config is not None and \
not self.model_config.enforce_eager:
if self.model_config.use_mla and self.compilation_config.full_cuda_graph and self.scheduler_config.max_num_seqs<=512:
if self.model_config.use_mla and self.scheduler_config.max_num_seqs<=512:
cuda_graph_sizes = [self.scheduler_config.max_num_seqs]
else:
cuda_graph_sizes = self.scheduler_config.cuda_graph_sizes
......
......@@ -140,7 +140,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
# This is the DeepEP default. Stick to it till we can establish
# reasonable defaults based on profiling.
self.num_sms = 20
self.num_sms = 24#20
def get_handle(self, kwargs):
raise NotImplementedError
......@@ -166,13 +166,21 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
def _make_all2all_kwargs(self) -> dict[Any, Any]:
# Defaults for internode and intranode are taken from DeepEP tests.
num_nvl_bytes = 1024 * 1024 * 1024
num_nvl_bytes = int(2e9/2)#1024 * 1024 * 1024
num_rdma_bytes = None
num_qps_per_rank = None
if self.internode:
num_rdma_bytes = 1024 * 1024 * 1024
num_qps_per_rank = self.num_sms // 2
num_rdma_bytes = int(1e9/2) #1024 * 1024 * 1024
num_qps_per_rank = 30 #self.num_sms // 2
# import deep_ep
# num_nvl_bytes, num_rdma_bytes = 0, 0
# hidden_size = 7168
# hidden_bytes = hidden_size * 2
# for config in (deep_ep.Buffer.get_dispatch_config(self.cpu_group.size()), deep_ep.Buffer.get_combine_config(self.cpu_group.size())):
# num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, self.cpu_group.size()), num_nvl_bytes)
# num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, self.cpu_group.size()), num_rdma_bytes)
else:
num_rdma_bytes = 0
num_qps_per_rank = 1
......@@ -183,7 +191,9 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
num_nvl_bytes=num_nvl_bytes,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=False,
num_qps_per_rank=num_qps_per_rank)
num_qps_per_rank=num_qps_per_rank,
explicitly_destroy=False,
use_default_stream_as_comm_stream=False)
def get_handle(self, kwargs):
......
......@@ -87,6 +87,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
from .all2all import DeepEPLLAll2AllManager
self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
logger.info("Using DeepEP Low-Latency all2all manager.")
elif all2all_backend == "mori":
pass
else:
raise ValueError(f"Unknown all2all backend: {all2all_backend}")
......
......@@ -951,7 +951,7 @@ def init_distributed_environment(
parallel_config = config.parallel_config
data_parallel_size = parallel_config.data_parallel_size
use_mori_ep = envs.VLLM_USE_MORI_EP and data_parallel_size > 1 and parallel_config.enable_expert_parallel
use_mori_ep = envs.VLLM_ALL2ALL_BACKEND == 'mori' and data_parallel_size > 1 and parallel_config.enable_expert_parallel
if use_mori_ep:
backend="cpu:gloo,cuda:nccl"
torch.distributed.init_process_group(
......
......@@ -173,9 +173,9 @@ if TYPE_CHECKING:
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
USE_FUSED_RMS_QUANT: bool = False
USE_FUSED_SILU_MUL_QUANT: bool = False
VLLM_USE_MORI_EP: bool = False
VLLM_P2P_ASYNC: bool = False
VLLM_P2P_BUF_TOKENS: int = 30000
VLLM_ENABLE_MOE_GROUP_GEMM: bool = False
def get_default_cache_root():
return os.getenv(
......@@ -945,6 +945,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# - "pplx": use pplx kernels
# - "deepep_high_throughput", use deepep high-throughput kernels
# - "deepep_low_latency", use deepep low-latency kernels
# - "mori", use mori kernels
"VLLM_ALL2ALL_BACKEND":
lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"),
......@@ -1144,11 +1145,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.getenv('USE_FUSED_SILU_MUL_QUANT', '0').lower() in
("true", "1")),
# vLLM will use all_to_all ep mode
"VLLM_USE_MORI_EP":
lambda: (os.environ.get("VLLM_USE_MORI_EP", "True").lower() in
("true", "1")),
# vllm pd separation will be used async
"VLLM_P2P_ASYNC":
lambda: bool(int(os.getenv("VLLM_P2P_ASYNC", "0"))),
......@@ -1156,6 +1152,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
# pd separation p2p async buf tokens
"VLLM_P2P_BUF_TOKENS":
lambda: int(os.getenv("VLLM_P2P_BUF_TOKENS", "30000")),
# enable the MoE group-GEMM expert path (disabled by default)
"VLLM_ENABLE_MOE_GROUP_GEMM":
lambda: (os.environ.get("VLLM_ENABLE_MOE_GROUP_GEMM", "False").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -136,8 +136,8 @@ def set_forward_context(
forward_start_time = time.perf_counter()
dp_metadata: Optional[DPMetadata] = None
dp_size = vllm_config.parallel_config.data_parallel_size
use_mori_ep = envs.VLLM_USE_MORI_EP and dp_size > 1 and vllm_config.parallel_config.enable_expert_parallel
if not use_mori_ep and dp_size > 1 and (
use_navie_ep = envs.VLLM_ALL2ALL_BACKEND == 'naive' and dp_size > 1 and vllm_config.parallel_config.enable_expert_parallel
if use_navie_ep and dp_size > 1 and (
attn_metadata is not None or num_tokens is not None) :
dp_metadata = DPMetadata.make(vllm_config.parallel_config,
attn_metadata, num_tokens or 0,
......
......@@ -59,6 +59,8 @@ if HAS_TRITON:
get_config_file_name, grouped_topk)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.triton_group_gemm_moe import (
TritonOrGroupGemmExperts)
__all__ += [
"fused_moe",
......@@ -75,4 +77,5 @@ if HAS_TRITON:
"BatchedDeepGemmExperts",
"TritonOrDeepGemmExperts",
"BatchedTritonOrDeepGemmExperts",
"TritonOrGroupGemmExperts",
]
......@@ -4,12 +4,15 @@ from typing import Optional
import deep_ep
import torch
import torch.distributed as dist
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
from vllm.distributed.parallel_state import get_ep_group
class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
......@@ -54,6 +57,10 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
if self.dp_size not in self.available_rank_configs:
return None
return deep_ep.Buffer.get_combine_config(self.dp_size)
def sync(self):
    """Block the calling rank on a ``torch.distributed`` barrier.

    ``dist.barrier()`` is called without a ``group`` argument, so it
    synchronizes on the default process group.
    """
    # NOTE(review): this file imports get_ep_group, but the barrier is
    # not restricted to the EP group -- confirm that is intended.
    # torch.cuda.synchronize()
    dist.barrier()
def _do_dispatch(self, tokens: torch.Tensor,
token_scales: Optional[torch.Tensor],
......@@ -205,13 +212,14 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool) -> None:
apply_router_weight_on_input: bool,
apply_weights_and_reduce: bool = True) -> None:
assert self.handle is not None
# fused_expert_output can have 0 tokens - This happens when none of the
# tokens from the all2all reach this EP rank.
if fused_expert_output.numel() != 0:
if fused_expert_output.numel() != 0 and apply_weights_and_reduce:
fused_expert_output = self._apply_weights_and_reduce(
num_tokens=topk_ids.size(0),
fused_expert_output=fused_expert_output,
......@@ -227,5 +235,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
previous_event=None,
async_finish=False,
allocate_on_comm_stream=False)
# Respect inplace outputs.
output.copy_(combined_x, non_blocking=True)
......@@ -162,7 +162,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool) -> None:
apply_router_weight_on_input: bool,
apply_weights_and_reduce: bool = True) -> None:
assert self.handle is not None
......
......@@ -596,6 +596,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
apply_weights_and_reduce: bool = True
) -> None:
num_tokens = topk_ids.size(0)
num_local_experts = fused_expert_output.size(0)
......
......@@ -28,8 +28,9 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEParallelConfig)
# yapf: enable
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat, FusedMoEModularKernel,
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize)
FusedMoEActivationFormat, FusedMoEModularKernel,
DeepGemmBannedFusedMoEModularKernel, FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize)
# from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
# is_rocm_aiter_moe_enabled)
from vllm.model_executor.layers.quantization.base_config import (
......@@ -40,7 +41,7 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx
from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx, has_deep_gemm
from vllm import _custom_ops as ops
......@@ -184,10 +185,17 @@ class FusedMoEMethodBase(QuantizeMethodBase):
logger.debug("%s", prepare_finalize.__class__.__name__)
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
experts = self.select_gemm_impl(prepare_finalize, moe)
self.fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
)
if has_deep_gemm():
self.fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
)
else:
self.fused_experts = DeepGemmBannedFusedMoEModularKernel(
prepare_finalize,
experts,
)
def select_gemm_impl(
self,
......
......@@ -149,6 +149,7 @@ class FusedMoEPrepareAndFinalize(ABC):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
apply_weights_and_reduce: bool = True
) -> None:
"""
Perform any combine plus apply weights and perform a reduction on the
......@@ -355,6 +356,168 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
assigned to each expert when using batched experts format input.
"""
raise NotImplementedError
class CustomizedFusedMoEPermuteExpertsUnpermute(ABC):
    """
    An abstract base class for the [Permute-Experts-Unpermute] step described
    above.

    This variant is the experts interface consumed by
    DeepGemmBannedFusedMoEModularKernel. That kernel does not allocate any
    scratch space: it calls `apply` with `output`, `workspace13` and
    `workspace2` set to None and uses the value returned by `apply` as the
    fused expert output. It also forwards `topk_weights`, `use_nn_moe`,
    `shared_output` and `routed_scaling_factor` to `apply`.
    """

    def __init__(
        self,
        quant_config: Optional[FusedMoEQuantConfig],
    ):
        # Fall back to a default-constructed config so the quantization
        # properties below never have to deal with a missing config.
        if quant_config is not None:
            self.quant_config = quant_config
        else:
            self.quant_config = FusedMoEQuantConfig()

    @property
    @abstractmethod
    def activation_formats(
        self) -> tuple[FusedMoEActivationFormat, FusedMoEActivationFormat]:
        """
        A property which is a tuple of the input and output activation formats
        for the 'apply' method.
        """
        raise NotImplementedError

    @property
    def quant_dtype(self) -> Optional[torch.dtype]:
        """The `quant_dtype` field of the quantization config."""
        return self.quant_config.quant_dtype

    @property
    def block_shape(self) -> Optional[list[int]]:
        """The `block_shape` field of the quantization config."""
        return self.quant_config.block_shape

    @property
    def per_act_token_quant(self) -> bool:
        """The `per_act_token_quant` field of the quantization config."""
        return self.quant_config.per_act_token_quant

    @property
    def per_out_ch_quant(self) -> bool:
        """The `per_out_ch_quant` field of the quantization config."""
        return self.quant_config.per_out_ch_quant

    # TODO (bnell): make this return a CHUNK_SIZE or None instead?
    @abstractmethod
    def supports_chunking(self) -> bool:
        """
        A flag indicating whether or not this class supports activation
        chunking.
        """
        raise NotImplementedError

    @abstractmethod
    def supports_expert_map(self) -> bool:
        """
        A flag indicating whether or not this class supports expert maps
        """
        raise NotImplementedError

    @abstractmethod
    def workspace_shapes(
        self,
        a: torch.Tensor,
        aq: torch.Tensor,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
        """
        Compute the shapes for the temporary and final outputs of the two gemms
        and activation in the fused expert function. Since the gemms are
        independent, the workspace for the first gemm can be shared with the
        workspace for the last gemm.

        Returns a tuple of:
        - workspace13 shape tuple: must be large enough to hold the
          result of either expert gemm.
        - workspace2 shape tuple: must be large enough to hold the
          result of the activation function.
        - output shape tuple: must be exact size of the final gemm output.
        - Workspace type: The dtype to use for the workspace tensors.
        - Note: in order for activation chunking to work, the first dimension
          of each tuple must be the number of tokens.
        """
        raise NotImplementedError

    def activation(self, activation: str, output: torch.Tensor,
                   input: torch.Tensor) -> None:
        """
        Apply the gated activation named by `activation` ("silu" or "gelu"),
        reading from `input` and writing into `output`.

        The last dimension of `input` must be twice that of `output`.
        Raises ValueError for any other activation name.
        """
        assert output.size(-1) * 2 == input.size(-1)
        if activation == "silu":
            torch.ops._C.silu_and_mul(output, input)
        elif activation == "gelu":
            torch.ops._C.gelu_and_mul(output, input)
        else:
            raise ValueError(f"Unsupported FusedMoe activation: {activation}")

    def enable_chunking(self):
        """
        Chunking is enabled only when the
        VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING env flag is set and this
        class reports support for it.
        """
        return envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and \
            self.supports_chunking()

    @abstractmethod
    def apply(
        self,
        output: Optional[torch.Tensor],
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int,
        expert_map: Optional[torch.Tensor],
        w1_scale: Optional[torch.Tensor],
        w2_scale: Optional[torch.Tensor],
        w1_zp: Optional[torch.Tensor],
        w2_zp: Optional[torch.Tensor],
        a1q_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],
        workspace13: Optional[torch.Tensor],
        workspace2: Optional[torch.Tensor],
        topk_weights: Optional[torch.Tensor] = None,
        expert_num_tokens: Optional[torch.Tensor] = None,
        use_nn_moe: Optional[bool] = False,
        shared_output: Optional[torch.Tensor] = None,
        routed_scaling_factor: Optional[float] = None,
    ):
        """
        This function computes the intermediate result of a Mixture of Experts
        (MoE) layer using two sets of weights, w1 and w2.

        The defaulted trailing arguments (topk_weights onwards) should be
        passed by keyword, which is how DeepGemmBannedFusedMoEModularKernel
        calls this method.

        Parameters:
        - output: (Optional[torch.Tensor]): The unweighted, unreduced output
          tensor. DeepGemmBannedFusedMoEModularKernel passes None and uses
          the return value of this method instead.
        - hidden_states: (torch.Tensor): The (quantized) input tensor to the MoE
          layer.
        - w1 (torch.Tensor): The first set of expert weights.
        - w2 (torch.Tensor): The second set of expert weights.
        - topk_ids (torch.Tensor): A map of row to expert id.
        - activation (str): The activation function to apply after the first
          MoE layer.
        - global_num_experts (int): The total number of experts in the global
          expert space.
        - expert_map (Optional[torch.Tensor]):  A tensor mapping expert indices
          from the global expert space to the local expert space of the expert
          parallel shard.
        - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
        - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
        - w1_zp (Optional[torch.Tensor]): Optional zero points to be used for
          w1.
        - w2_zp (Optional[torch.Tensor]): Optional zero points to be used for
          w2.
        - a1q_scale (Optional[torch.Tensor]): Optional quantized scale to be
          used for a1.
        - a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
        - workspace13 (Optional[torch.Tensor]): A scratch tensor used for gemm
          outputs; must be large enough to hold output of either MoE gemm.
          DeepGemmBannedFusedMoEModularKernel passes None.
        - workspace2 (Optional[torch.Tensor]): A scratch tensor used for the
          activation function. DeepGemmBannedFusedMoEModularKernel passes
          None.
        - topk_weights (Optional[torch.Tensor]): The topk weights for the rows
          in topk_ids, as returned by the prepare step.
        - expert_num_tokens: An optional tensor containing the number of tokens
          assigned to each expert when using batched experts format input.
        - use_nn_moe (Optional[bool]): Implementation-specific flag forwarded
          unchanged from the modular kernel.
        - shared_output (Optional[torch.Tensor]): Optional tensor forwarded
          unchanged from the modular kernel (presumably the shared-expert
          output to fuse into the result -- confirm with the implementation).
        - routed_scaling_factor (Optional[float]): Optional scalar forwarded
          unchanged from the modular kernel.
        """
        raise NotImplementedError
def _chunk_scales(scales: Optional[torch.Tensor], start: int,
......@@ -596,3 +759,145 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_ids, apply_router_weight_on_input)
return output
@final
class DeepGemmBannedFusedMoEModularKernel(torch.nn.Module):
    """
    Modular MoE kernel that pairs a FusedMoEPrepareAndFinalize object with a
    CustomizedFusedMoEPermuteExpertsUnpermute object behind a `forward`
    whose arguments mirror the `fused_experts` function in fused_moe.py.

    Unlike FusedMoEModularKernel, no scratch tensors are allocated here:
    the experts object is invoked with `output`, `workspace13` and
    `workspace2` set to None, and finalize is asked to skip the
    apply-weights-and-reduce step.

    Note: Instances of this class should only be used for a single model
    layer due to any layer specific state that may be used by the component
    objects.
    """

    def __init__(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        fused_experts: CustomizedFusedMoEPermuteExpertsUnpermute,
    ):
        super().__init__()
        self.prepare_finalize = prepare_finalize
        self.fused_experts = fused_experts

        # The format produced by prepare must be the format the experts
        # consume.
        produced_format = prepare_finalize.activation_format
        consumed_format = fused_experts.activation_formats[0]
        assert produced_format == consumed_format, (
            f"{prepare_finalize.__class__.__name__}."
            f"{prepare_finalize.activation_format} == "
            f"{fused_experts.__class__.__name__}."
            f"{fused_experts.activation_formats[0]}")

    def forward(
        self,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        inplace: bool = False,
        activation: str = "silu",
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        w1_scale: Optional[torch.Tensor] = None,
        w2_scale: Optional[torch.Tensor] = None,
        w1_zp: Optional[torch.Tensor] = None,
        w2_zp: Optional[torch.Tensor] = None,
        a1_scale: Optional[torch.Tensor] = None,
        a2_scale: Optional[torch.Tensor] = None,
        use_nn_moe: Optional[bool] = False,
        apply_router_weight_on_input: bool = False,
        shared_output: Optional[torch.Tensor] = None,
        routed_scaling_factor: Optional[float] = None,
    ) -> torch.Tensor:
        """
        Run one MoE layer: prepare (quantize / dispatch), expert computation,
        then finalize (combine) into the output tensor.

        Parameters:
        - hidden_states: (torch.Tensor): The input tensor to the MoE layer.
        - w1 (torch.Tensor): The first set of expert weights.
        - w2 (torch.Tensor): The second set of expert weights.
        - topk_weights (torch.Tensor): The topk weights for the selected
          experts.
        - topk_ids (torch.Tensor): A map of row to expert id.
        - inplace (bool): If True, the result is written into hidden_states;
          otherwise a zero-initialized tensor of the same shape is used.
          Defaults to False.
        - activation (str): The activation function to apply after the first
          MoE layer.
        - global_num_experts (int): The total number of experts in the global
          expert space. -1 means "use w1.size(0)".
        - expert_map (Optional[torch.Tensor]):  A tensor mapping expert indices
          from the global expert space to the local expert space of the expert
          parallel shard.
        - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
        - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
        - w1_zp (Optional[torch.Tensor]): Optional zero points to be used for
          w1.
        - w2_zp (Optional[torch.Tensor]): Optional zero points to be used for
          w2.
        - a1_scale (Optional[torch.Tensor]): Optional scale to be used for a1.
        - a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
        - use_nn_moe (Optional[bool]): Forwarded unchanged to the experts
          object.
        - apply_router_weight_on_input (bool): When true, the topk weights are
          applied directly on the inputs. This is only applicable when topk is
          1.
        - shared_output (Optional[torch.Tensor]): Forwarded unchanged to the
          experts object.
        - routed_scaling_factor (Optional[float]): Forwarded unchanged to the
          experts object.

        Returns:
        - torch.Tensor: The output tensor after applying the MoE layer.
        """
        layer_input = hidden_states
        if inplace:
            output = layer_input
        else:
            output = torch.zeros_like(layer_input)

        num_local_experts = w1.size(0)
        if global_num_experts == -1:
            global_num_experts = num_local_experts

        prepared = self.prepare_finalize.prepare(
            layer_input,
            a1_scale,
            a2_scale,
            topk_weights,
            topk_ids,
            global_num_experts,
            expert_map,
            apply_router_weight_on_input,
            self.fused_experts.quant_config,
        )
        (a1q, a1q_scale, expert_num_tokens, gathered_topk_ids,
         gathered_topk_weights) = prepared

        # Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
        if gathered_topk_ids is not None:
            topk_ids = gathered_topk_ids
        if gathered_topk_weights is not None:
            topk_weights = gathered_topk_weights

        # No output buffer or workspaces are supplied; the experts object
        # returns its own result tensor.
        fused_out = self.fused_experts.apply(
            None,
            a1q,
            w1,
            w2,
            topk_ids,
            topk_weights=topk_weights,
            activation=activation,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            w1_zp=w1_zp,
            w2_zp=w2_zp,
            a1q_scale=a1q_scale,
            a2_scale=a2_scale,
            workspace13=None,
            workspace2=None,
            use_nn_moe=use_nn_moe,
            expert_num_tokens=expert_num_tokens,
            shared_output=shared_output,
            routed_scaling_factor=routed_scaling_factor,
        )

        # topk_weights are already handed to the experts object above, so
        # finalize is told not to apply weights / reduce again.
        self.prepare_finalize.finalize(
            output,
            fused_out,
            topk_weights,
            topk_ids,
            apply_router_weight_on_input,
            apply_weights_and_reduce=False,
        )

        return output
import os
import logging
from typing import Callable, List, Optional, Tuple
from dataclasses import dataclass
from typing import Callable, Optional
from collections.abc import Iterable
import torch
import torch.nn.functional as F
import torch.distributed as dist
from vllm.logger import init_logger
from vllm.platforms import current_platform
......@@ -18,10 +15,8 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher import MoEAlltoAllTokenDispatcher
from vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis import EpMoeConfig
from vllm.model_executor.layers.fused_moe.mori_moe.ep_moe_utlis import EpMoeConfig
from vllm.utils import direct_register_custom_op
import torch.distributed as dist
try:
import mori
......@@ -35,8 +30,8 @@ logger = init_logger(__name__)
_MORI_OP = None
@CustomOp.register("unquantized_ep_moe")
class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
@CustomOp.register("unquantized_mori_moe")
class UnquantizedMoriMoeMethod(UnquantizedFusedMoEMethod):
"""MoE method without quantization."""
def __init__(self, moe: FusedMoEConfig):
......@@ -44,9 +39,9 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
self.topk_indices_dtype = None
self.moe = moe
self.rocm_aiter_moe_enabled = False # is_rocm_aiter_moe_enabled()
self.rocm_aiter_moe_enabled = False
def apply_ep(
def apply_mori_ep(
self,
layer: torch.nn.Module,
hidden_states: torch.Tensor,
......@@ -162,7 +157,7 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
forward_native = forward_cuda
class EPMoE(FusedMoE):
class MoriMoE(FusedMoE):
"""
dp+ep MoE Expert Parallel Impl
......@@ -194,7 +189,6 @@ class EPMoE(FusedMoE):
enable_eplb: bool = False,
num_redundant_experts: int = 0,
moe_permute_fusion: bool = False,
moe_shared_expert_overlap: bool = False
):
super().__init__(num_experts, top_k, hidden_size,
intermediate_size, params_dtype,
......@@ -215,7 +209,6 @@ class EPMoE(FusedMoE):
moe_router_topk=self.top_k,
# TODO: support fusion permute
moe_permute_fusion=moe_permute_fusion,
moe_shared_expert_overlap=moe_shared_expert_overlap,
ep_size=self.ep_size,
num_moe_experts=self.global_num_experts,
routed_scaling_factor=self.routed_scaling_factor,
......@@ -228,23 +221,15 @@ class EPMoE(FusedMoE):
self.local_expert_indices = [
local_expert_indices_offset + i for i in range(self.local_num_experts)
]
self.use_shared_expert = False
self.token_dispatcher = MoEAlltoAllTokenDispatcher(
self.local_num_experts, self.local_expert_indices,
config=self.ep_moe_config, layer_name=f"{self.layer_name}.token_dispatcher",
)
self.shared_expert_overlap = moe_shared_expert_overlap
self.shared_experts = None
self.scales = None
self.use_int8_dispatch = True
vllm_config = get_current_vllm_config()
self.max_num_inp_token_per_rank = vllm_config.scheduler_config.max_num_seqs
self.max_num_inp_token_per_rank = 1024 #vllm_config.scheduler_config.max_num_seqs
self.mori_op = self.get_mori_op()
self.first = True
def get_mori_op(self):
global _MORI_OP
......@@ -253,10 +238,6 @@ class EPMoE(FusedMoE):
assert world_group is not None
torch._C._distributed_c10d._register_process_group("mori_ep", get_ep_group().device_group)
mori.shmem.shmem_torch_process_group_init("mori_ep")
# world_group = torch.distributed.group.WORLD
# assert world_group is not None
# torch._C._distributed_c10d._register_process_group("default", world_group)
# mori.shmem.shmem_torch_process_group_init("default")
vllm_config = get_current_vllm_config()
multi_node = self.ep_size / 8 > 1
......@@ -278,8 +259,7 @@ class EPMoE(FusedMoE):
num_experts_per_token=self.top_k,
max_token_type_size=2,
block_num=80,
warp_num_per_block=16,
# kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode
warp_num_per_block=4,
kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode if multi_node else \
mori.ops.EpDispatchCombineKernelType.IntraNode
)
......@@ -291,14 +271,11 @@ class EPMoE(FusedMoE):
if self.shared_experts is None:
self.shared_experts = shared_experts
if self.shared_expert_overlap:
self.token_dispatcher.set_shared_experts(self.shared_experts)
def create_quant_method(self, moe, quant_config, prefix):
# Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first.
quant_method: Optional[QuantizeMethodBase] = None
quant_method = (UnquantizedEPGroupedGemmMethod(moe) if quant_config is None
quant_method = (UnquantizedMoriMoeMethod(moe) if quant_config is None
else quant_config.get_quant_method(self, prefix))
assert quant_method is not None
......@@ -311,7 +288,7 @@ class EPMoE(FusedMoE):
def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
return torch.ops.vllm.ep_moe_forward(hidden_states, router_logits,
return torch.ops.vllm.mori_moe_forward(hidden_states, router_logits,
self.layer_name)
def get_expert_weights(self) -> Iterable[torch.Tensor]:
......@@ -351,7 +328,7 @@ class EPMoE(FusedMoE):
routed_scaling_factor=self.routed_scaling_factor,
use_fused_gate=self.use_fused_gate)
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None:
if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
if self.use_int8_dispatch:
......@@ -378,33 +355,10 @@ class EPMoE(FusedMoE):
hidden_states,
topk_weights,
scales,
topk_ids,
topk_ids
)
# self.sync()
# expect_m = topk_ids.shape[0] * self.ep_size
# dispatch_output_clip = dispatch_output[:expect_m]
# dispatch_weights_clip = dispatch_weights[:expect_m]
# dispatch_indices_clip = dispatch_indices[:expect_m]
# dispatch_scales_clip = dispatch_scales[:expect_m]
# expert_output = self.quant_method.apply_ep(
# layer=self,
# x=dispatch_output_clip,
# topk_weights=dispatch_weights_clip,
# topk_ids=dispatch_indices_clip,
# global_num_experts=self.global_num_experts,
# expert_map=self.expert_map,
# activation=self.activation,
# apply_router_weight_on_input=self.apply_router_weight_on_input,
# use_nn_moe=self.use_nn_moe,
# num_local_tokens=dispatch_recv_num_token,
# config_select_bs=hidden_states.shape[0],
# scales=dispatch_scales_clip if self.use_int8_dispatch else None
# #routed_scaling_factor=self.routed_scaling_factor,
# )
expert_output = self.quant_method.apply_ep(
expert_output = self.quant_method.apply_mori_ep(
layer=self,
x=dispatch_output,
topk_weights=dispatch_weights,
......@@ -415,10 +369,10 @@ class EPMoE(FusedMoE):
apply_router_weight_on_input=self.apply_router_weight_on_input,
use_nn_moe=self.use_nn_moe,
num_local_tokens=dispatch_recv_num_token,
config_select_bs=hidden_states.shape[0],
expect_m=hidden_states.shape[0],
scales=dispatch_scales if self.use_int8_dispatch else None
# routed_scaling_factor=self.routed_scaling_factor,
)
# self.sync()
combine_output, _ = self.mori_op.combine(expert_output, dispatch_weights, topk_ids)
......@@ -426,13 +380,7 @@ class EPMoE(FusedMoE):
# self.sync()
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None:
# if shared_expert_overlap is True, the expert calculation happens in
# the token_dispatcher to overlap communications and computations
# shared_output = (
# self.maybe_all_reduce_tensor_model_parallel(
# shared_output))
if self.shared_experts is not None:
if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output
else:
......@@ -444,7 +392,7 @@ class EPMoE(FusedMoE):
return final_hidden_states
def ep_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
def mori_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
......@@ -453,16 +401,16 @@ def ep_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
return self.forward_impl(hidden_states, router_logits)
def ep_moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
def mori_moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str) -> torch.Tensor:
return torch.empty_like(hidden_states)
direct_register_custom_op(
op_name="ep_moe_forward",
op_func=ep_moe_forward,
op_name="mori_moe_forward",
op_func=mori_moe_forward,
mutates_args=["hidden_states", "router_logits"],
fake_impl=ep_moe_forward_fake,
fake_impl=mori_moe_forward_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order,),
)
\ No newline at end of file
......@@ -207,6 +207,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
apply_weights_and_reduce: bool = True
) -> None:
# This argument is optional
# There's not much point setting this unless it is != topk_ids.size(0)
......
......@@ -61,6 +61,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
apply_weights_and_reduce: bool = True
) -> None:
_moe_unpermute_and_reduce(output, fused_expert_output, None,
topk_weights, apply_router_weight_on_input)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape)
class TritonOrGroupGemmExperts(mk.CustomizedFusedMoEPermuteExpertsUnpermute):
    """
    Experts implementation that delegates the whole expert computation to a
    caller-supplied callable (`fused_experts`).

    The quantization flags are only used to build the FusedMoEQuantConfig
    held by the base class; the actual math is whatever `fused_experts`
    does.
    """

    def __init__(
        self,
        use_fp8_w8a8: bool = False,
        use_int8_w8a8: bool = False,
        use_int8_w8a16: bool = False,
        use_int4_w4a16: bool = False,
        per_act_token_quant: bool = False,
        block_shape: Optional[list[int]] = None,
        allow_group_gemm: bool = False,
        fused_experts=None,
    ):
        """
        Parameters:
        - use_fp8_w8a8 / use_int8_w8a8 / use_int8_w8a16 / use_int4_w4a16:
          quantization scheme flags forwarded to FusedMoEQuantConfig.make.
        - per_act_token_quant, block_shape: forwarded to
          FusedMoEQuantConfig.make.
        - allow_group_gemm: kept on the instance for inspection; it does not
          influence `apply`.
        - fused_experts: callable invoked by `apply`. It must be provided
          before `apply` is called.
        """
        super().__init__(
            FusedMoEQuantConfig.make(
                use_fp8_w8a8=use_fp8_w8a8,
                use_int8_w8a8=use_int8_w8a8,
                use_int8_w8a16=use_int8_w8a16,
                use_int4_w4a16=use_int4_w4a16,
                per_act_token_quant=per_act_token_quant,
                block_shape=block_shape,
            ))
        # Previously this argument was accepted and silently dropped; keep it
        # so the requested setting is at least observable.
        self.allow_group_gemm = allow_group_gemm
        self.fused_experts = fused_experts

    @property
    def activation_formats(
        self
    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
        """Both the input and the output use the Standard activation format."""
        return (mk.FusedMoEActivationFormat.Standard,
                mk.FusedMoEActivationFormat.Standard)

    def supports_chunking(self) -> bool:
        """Always reports chunking support."""
        # NOTE(review): workspace_shapes below is not implemented, so any
        # chunking path that queries it would fail -- confirm this is never
        # reached.
        return True

    def supports_expert_map(self) -> bool:
        """Always reports expert-map support."""
        return True

    def workspace_shapes(
        self,
        a: torch.Tensor,
        aq: torch.Tensor,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
        """Not implemented: `apply` does not use the workspace tensors."""
        raise NotImplementedError

    def apply(
        self,
        output: Optional[torch.Tensor],
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int,
        expert_map: Optional[torch.Tensor],
        w1_scale: Optional[torch.Tensor],
        w2_scale: Optional[torch.Tensor],
        w1_zp: Optional[torch.Tensor],
        w2_zp: Optional[torch.Tensor],
        a1q_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],
        workspace13: Optional[torch.Tensor],
        workspace2: Optional[torch.Tensor],
        topk_weights: Optional[torch.Tensor] = None,
        expert_num_tokens: Optional[torch.Tensor] = None,
        use_nn_moe: Optional[bool] = False,
        shared_output: Optional[torch.Tensor] = None,
        routed_scaling_factor: Optional[float] = None,
    ):
        """
        Forward the expert computation to `self.fused_experts` and return
        whatever it returns.

        `output`, `w1_zp`, `w2_zp`, `workspace13` and `workspace2` are
        accepted for interface compatibility but are not used.
        `hidden_states` is passed as `x`, `a1q_scale` as `a1_scale`, and
        `apply_router_weight_on_input` is always passed as False.

        Raises:
        - AssertionError: if no `fused_experts` callable was configured.
        """
        # Explicit raise instead of `assert` so the check survives
        # `python -O`; the exception type is unchanged.
        if self.fused_experts is None:
            raise AssertionError(
                "TritonOrGroupGemmExperts.apply requires a fused_experts "
                "callable, but none was provided")

        return self.fused_experts(
            x=hidden_states,
            w1=w1,
            w2=w2,
            topk_ids=topk_ids,
            topk_weights=topk_weights,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            apply_router_weight_on_input=False,
            activation=activation,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1q_scale,
            a2_scale=a2_scale,
            expert_num_tokens=expert_num_tokens,
            use_nn_moe=use_nn_moe,
            shared_output=shared_output,
            routed_scaling_factor=routed_scaling_factor,
        )
......@@ -4,10 +4,12 @@ import os
import torch
from torch.nn.parameter import Parameter
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.model_executor.utils import set_weight_attrs
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import get_tensor_model_parallel_world_size, get_dp_group
from vllm.logger import init_logger
from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.linear import (LinearBase,LinearMethodBase)
from vllm.model_executor.layers.quantization.base_config import (QuantizationConfig,
......@@ -125,6 +127,10 @@ class SlimQuantW4A8Int8MarlinConfig(QuantizationConfig):
def get_scaled_act_names(self) -> List[str]:
    """Return the scaled-activation names; this config reports none."""
    scaled_act_names: List[str] = []
    return scaled_act_names
@property
def weight_block_size(self):
    """Return the weight block size as a fresh two-element list, ``[128, 128]``."""
    block_dim = 128
    return [block_dim, block_dim]
class SlimQuantW4A8Int8MarlinMoEMethod:
......@@ -154,6 +160,15 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
def __init__(self, quant_config):
    """Store the quant config and derive the expert-parallel feature flags.

    ``fused_experts`` is bound to ``w4a8_marlin_forward``. ``use_deepep``
    and ``enable_moe_group_gemm`` are both gated on
    ``parallel_config.enable_expert_parallel`` of the current vLLM config.
    """
    self.quant_config = quant_config
    self.fused_experts = self.w4a8_marlin_forward

    parallel_config = get_current_vllm_config().parallel_config
    ep_enabled = parallel_config.enable_expert_parallel

    # DeepEP is in use only with expert parallelism and one of the two
    # DeepEP all2all backends selected through the environment.
    self.use_deepep = ep_enabled and envs.VLLM_ALL2ALL_BACKEND in (
        "deepep_high_throughput",
        "deepep_low_latency",
    )
    self.enable_moe_group_gemm = (ep_enabled
                                  and envs.VLLM_ENABLE_MOE_GROUP_GEMM)
def create_weights(
self,
......@@ -218,7 +233,55 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
layer.w13_weight = Parameter(w4a8_weight_repack_impl(layer.w13_weight), requires_grad=False)
layer.w2_weight = Parameter(w4a8_weight_repack_impl(layer.w2_weight), requires_grad=False)
def apply_ep( #dp+ep
def w4a8_marlin_forward(self,
                        x: torch.Tensor,
                        w1: torch.Tensor,
                        w2: torch.Tensor,
                        topk_ids: torch.Tensor,
                        topk_weights: torch.Tensor,
                        global_num_experts: int = -1,
                        expert_map: Optional[torch.Tensor] = None,
                        apply_router_weight_on_input: bool = False,
                        activation: str = "silu",
                        w1_scale: Optional[torch.Tensor] = None,
                        w2_scale: Optional[torch.Tensor] = None,
                        a1_scale: Optional[torch.Tensor] = None,
                        a2_scale: Optional[torch.Tensor] = None,
                        expert_num_tokens: Optional[torch.Tensor] = None,
                        use_nn_moe: Optional[bool] = False,
                        routed_scaling_factor: Optional[float] = None,
                        shared_output: Optional[torch.Tensor] = None,
                        **_):
    """Run the W4A8 Marlin fused-experts implementation.

    When ``self.enable_moe_group_gemm`` is false, the workspace and global
    reduce buffer are taken from ``MarlinMoeWorkspace(x.device)`` and the
    call is forwarded to ``fused_experts_impl_w4a8_marlin`` with
    ``inplace``, ``use_int4_w4a8`` and ``per_channel_quant`` fixed to
    ``True``. ``expert_num_tokens`` and any extra keyword arguments are
    accepted but unused.

    Returns:
        The result of ``fused_experts_impl_w4a8_marlin``, or ``None`` when
        group GEMM is enabled.
    """
    if self.enable_moe_group_gemm:
        # TODO: the group-GEMM path is not implemented yet.
        # NOTE(review): this silently hands ``None`` back to the caller --
        # confirm callers tolerate it, or raise an explicit error instead.
        return None

    marlin_workspace, reduce_buffer = MarlinMoeWorkspace(
        x.device).get_buffers()
    return fused_experts_impl_w4a8_marlin(
        x,
        w1,
        w2,
        topk_ids=topk_ids,
        topk_weights=topk_weights,
        workspace=marlin_workspace,
        global_reduce_buffer=reduce_buffer,
        inplace=True,
        use_int4_w4a8=True,
        per_channel_quant=True,
        activation=activation,
        expert_map=expert_map,
        apply_router_weight_on_input=apply_router_weight_on_input,
        global_num_experts=global_num_experts,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        use_nn_moe=use_nn_moe,
        shared_output=shared_output,
        routed_scaling_factor=routed_scaling_factor,
    )
def apply_mori_ep(
self,
layer: torch.nn.Module,
x: torch.Tensor,
......@@ -230,7 +293,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
activation: str = "silu",
use_nn_moe: Optional[bool] = False,
num_local_tokens: Optional[torch.Tensor] = None,
config_select_bs: Optional[int] = None,
expect_m: Optional[int] = None,
routed_scaling_factor: Optional[float] = None,
scales: Optional[torch.Tensor] = None,
**_
......@@ -253,12 +316,11 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
global_num_experts=global_num_experts,
w1_scale=(layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a1_scale=scales,
a2_scale=layer.w2_input_scale,
use_nn_moe=use_nn_moe,
num_local_tokens=num_local_tokens,
config_select_bs=config_select_bs,
q_scales=scales
expect_m=expect_m,
)
def apply(
......@@ -301,29 +363,25 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
indices_type=torch.int64 if self.use_deepep else None,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate
)
workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
return fused_experts_impl_w4a8_marlin(
return self.fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
workspace=workspace,
global_reduce_buffer=global_reduce_buffer,
inplace=True,
use_int4_w4a8=True,
per_channel_quant=True,
activation=activation,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=(layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
......@@ -335,10 +393,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
moe: FusedMoEConfig,
) -> FusedMoEPermuteExpertsUnpermute:
from vllm.model_executor.layers.fused_moe import (
BatchedGroupedGemmExperts, GroupedGemmGemmExperts)
assert not self.rocm_aiter_moe_enabled, (
"ROCm AITER are not supported with all2all yet.")
TritonOrGroupGemmExperts)
if (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts):
......@@ -350,21 +405,16 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
"max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
self.__class__.__name__, max_num_tokens_per_rank,
self.quant_config.weight_block_size, False)
return BatchedGroupedGemmExperts(
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
use_fp8_w8a8=False,
block_shape=self.quant_config.weight_block_size,
per_act_token_quant=True,
allow_deep_gemm=False,
)
return None
else:
logger.debug(
"GroupedGemmGemmExperts(%s): block_size=%s, per_act_token=%s",
"TritonOrGroupGemmExperts(%s): block_size=%s, per_act_token=%s",
self.__class__.__name__, self.quant_config.weight_block_size,
False)
return GroupedGemmGemmExperts(
return TritonOrGroupGemmExperts(
use_fp8_w8a8=False,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=False,
allow_group_gemm=False,
fused_experts=self.w4a8_marlin_forward
)
......@@ -178,7 +178,7 @@ class DeepSeekMTP(nn.Module, SupportsPP):
parallel_config = vllm_config.parallel_config
dp_size = get_dp_group().world_size
self.use_mori_ep = envs.VLLM_USE_MORI_EP and dp_size > 1 and parallel_config.enable_expert_parallel
self.use_mori_ep = envs.VLLM_ALL2ALL_BACKEND == 'mori' and dp_size > 1 and parallel_config.enable_expert_parallel
def forward(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment