Commit 10400c58 authored by 王敏
Browse files

[feat]优化deepep高吞吐模式

parent 0acf61d6
......@@ -12,7 +12,7 @@ import torch
import vllm.envs as envs
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.utils import cdiv
from vllm.utils import cdiv, async_tensor_h2d
#
# This file defines a set of base classes used to make MoE kernels more modular.
......@@ -97,6 +97,8 @@ class FusedMoEActivationFormat(Enum):
BatchedExperts = "batched_experts",
@dataclass
class ExpertTokensMetadata:
"""
......@@ -110,11 +112,16 @@ class ExpertTokensMetadata:
def make_from_list(
    expert_num_tokens_list: list[int], device: str
) -> "ExpertTokensMetadata":
    """Build an ExpertTokensMetadata from a Python list of per-expert counts.

    Args:
        expert_num_tokens_list: Token counts, one entry per expert.
        device: Target device for the device-side copy of the counts.

    Returns:
        An ExpertTokensMetadata whose ``expert_num_tokens_cpu`` is an int32
        CPU tensor holding the counts and whose ``expert_num_tokens`` is the
        same data copied to ``device``.
    """
    # Allocate the host tensor in pinned (page-locked) memory; per the PyTorch
    # docs this is what allows the non_blocking copy below to be asynchronous
    # with respect to the host.
    # NOTE(review): pin_memory=True presumably requires an accelerator runtime
    # to be present -- confirm this path is never taken on CPU-only hosts.
    expert_num_tokens_cpu = torch.tensor(
        expert_num_tokens_list, device="cpu", dtype=torch.int32, pin_memory=True
    )
    # The host-to-device copy is issued without blocking; the CPU tensor is
    # kept in the returned metadata alongside the device tensor.
    expert_num_tokens = expert_num_tokens_cpu.to(device=device, non_blocking=True)
    return ExpertTokensMetadata(
        expert_num_tokens=expert_num_tokens,
        expert_num_tokens_cpu=expert_num_tokens_cpu,
    )
......
......@@ -784,6 +784,7 @@ def deepgemm_moe_permute(
expert_num_tokens: Optional[torch.Tensor] = None,
expert_num_tokens_cpu: Optional[torch.Tensor] = None,
aq_out: torch.Tensor | None = None,
M_sum: int | None = None,
):
assert aq.ndim == 2
assert topk_ids.dtype.is_signed, "The kernel uses -1 to represent invalid topk_ids"
......@@ -792,13 +793,14 @@ def deepgemm_moe_permute(
block_m = block_shape[0]
M_sum = compute_aligned_M(
M=topk_ids.size(0),
num_topk=topk_ids.size(1),
local_num_experts=local_num_experts,
alignment=block_m,
expert_num_tokens_cpu=expert_num_tokens_cpu,
)
if M_sum is None:
M_sum = compute_aligned_M(
M=topk_ids.size(0),
num_topk=topk_ids.size(1),
local_num_experts=local_num_experts,
alignment=block_m,
expert_num_tokens_cpu=expert_num_tokens_cpu,
)
expert_start_loc = torch.empty(
(local_num_experts), device=device, dtype=torch.int32
......
......@@ -346,6 +346,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
expert_num_tokens=expert_num_tokens,
expert_num_tokens_cpu=expert_num_tokens_cpu,
aq_out=a1q_perm,
M_sum=M_sum
)
# if expert_map is not None:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment