Commit 10400c58 authored by 王敏
Browse files

[feat]优化deepep高吞吐模式

parent 0acf61d6
...@@ -12,7 +12,7 @@ import torch ...@@ -12,7 +12,7 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.utils import cdiv from vllm.utils import cdiv, async_tensor_h2d
# #
# This file defines a set of base classes used to make MoE kernels more modular. # This file defines a set of base classes used to make MoE kernels more modular.
...@@ -97,6 +97,8 @@ class FusedMoEActivationFormat(Enum): ...@@ -97,6 +97,8 @@ class FusedMoEActivationFormat(Enum):
BatchedExperts = "batched_experts", BatchedExperts = "batched_experts",
@dataclass @dataclass
class ExpertTokensMetadata: class ExpertTokensMetadata:
""" """
...@@ -110,11 +112,16 @@ class ExpertTokensMetadata: ...@@ -110,11 +112,16 @@ class ExpertTokensMetadata:
def make_from_list( def make_from_list(
expert_num_tokens_list: list[int], device: str expert_num_tokens_list: list[int], device: str
) -> "ExpertTokensMetadata": ) -> "ExpertTokensMetadata":
# expert_num_tokens_cpu = torch.tensor(
# expert_num_tokens_list, device="cpu", dtype=torch.int32
# )
expert_num_tokens_cpu = torch.tensor( expert_num_tokens_cpu = torch.tensor(
expert_num_tokens_list, device="cpu", dtype=torch.int32 expert_num_tokens_list, device="cpu", dtype=torch.int32, pin_memory=True
) )
expert_num_tokens = expert_num_tokens_cpu.to(device=device, non_blocking=True)
return ExpertTokensMetadata( return ExpertTokensMetadata(
expert_num_tokens=expert_num_tokens_cpu.to(device, non_blocking=True), expert_num_tokens=expert_num_tokens,
expert_num_tokens_cpu=expert_num_tokens_cpu, expert_num_tokens_cpu=expert_num_tokens_cpu,
) )
......
...@@ -784,6 +784,7 @@ def deepgemm_moe_permute( ...@@ -784,6 +784,7 @@ def deepgemm_moe_permute(
expert_num_tokens: Optional[torch.Tensor] = None, expert_num_tokens: Optional[torch.Tensor] = None,
expert_num_tokens_cpu: Optional[torch.Tensor] = None, expert_num_tokens_cpu: Optional[torch.Tensor] = None,
aq_out: torch.Tensor | None = None, aq_out: torch.Tensor | None = None,
M_sum: int | None = None,
): ):
assert aq.ndim == 2 assert aq.ndim == 2
assert topk_ids.dtype.is_signed, "The kernel uses -1 to represent invalid topk_ids" assert topk_ids.dtype.is_signed, "The kernel uses -1 to represent invalid topk_ids"
...@@ -792,13 +793,14 @@ def deepgemm_moe_permute( ...@@ -792,13 +793,14 @@ def deepgemm_moe_permute(
block_m = block_shape[0] block_m = block_shape[0]
M_sum = compute_aligned_M( if M_sum is None:
M=topk_ids.size(0), M_sum = compute_aligned_M(
num_topk=topk_ids.size(1), M=topk_ids.size(0),
local_num_experts=local_num_experts, num_topk=topk_ids.size(1),
alignment=block_m, local_num_experts=local_num_experts,
expert_num_tokens_cpu=expert_num_tokens_cpu, alignment=block_m,
) expert_num_tokens_cpu=expert_num_tokens_cpu,
)
expert_start_loc = torch.empty( expert_start_loc = torch.empty(
(local_num_experts), device=device, dtype=torch.int32 (local_num_experts), device=device, dtype=torch.int32
......
...@@ -346,6 +346,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -346,6 +346,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
expert_num_tokens=expert_num_tokens, expert_num_tokens=expert_num_tokens,
expert_num_tokens_cpu=expert_num_tokens_cpu, expert_num_tokens_cpu=expert_num_tokens_cpu,
aq_out=a1q_perm, aq_out=a1q_perm,
M_sum=M_sum
) )
# if expert_map is not None: # if expert_map is not None:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment