Unverified Commit 4bd18ec0 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Minor] Fix type annotation in fused moe (#3045)

parent 2410e320
```diff
@@ -2,7 +2,7 @@
 import functools
 import json
 import os
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple

 import torch
 import triton
@@ -137,7 +137,7 @@ def fused_moe_kernel(
 def moe_align_block_size(
         topk_ids: torch.Tensor, block_size: int,
-        num_experts: int) -> (torch.Tensor, torch.Tensor, torch.Tensor):
+        num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Aligns the token distribution across experts to be compatible with block size for matrix multiplication.
@@ -185,7 +185,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
                             sorted_token_ids: torch.Tensor,
                             expert_ids: torch.Tensor,
                             num_tokens_post_padded: torch.Tensor,
-                            mul_routed_weight: bool, top_k: int, config: dict):
+                            mul_routed_weight: bool, top_k: int,
+                            config: Dict[str, Any]) -> None:
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
```
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment