Unverified Commit e0613702 authored by Chuan (Richard) Li's avatar Chuan (Richard) Li Committed by GitHub
Browse files

[ROCm] Fix AITER ops fake impl and minor bugs (#36092)


Signed-off-by: default avatarLi <chuali@amd.com>
parent 9853a3c1
...@@ -336,9 +336,13 @@ def _rocm_aiter_fused_topk_fake( ...@@ -336,9 +336,13 @@ def _rocm_aiter_fused_topk_fake(
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
gate_up: bool, gate_up: bool,
) -> None: ) -> tuple[torch.Tensor, torch.Tensor]:
# tuple[torch.Tensor, torch.Tensor]: num_tokens = x.shape[0]
pass topk_weights = torch.empty(
(num_tokens, top_k), dtype=torch.float32, device=x.device
)
topk_indices = torch.empty((num_tokens, top_k), dtype=torch.int32, device=x.device)
return topk_weights, topk_indices
# Cache whether aiter supports FP8 MLA parameters # Cache whether aiter supports FP8 MLA parameters
...@@ -1918,7 +1922,7 @@ class rocm_aiter_ops: ...@@ -1918,7 +1922,7 @@ class rocm_aiter_ops:
@staticmethod @staticmethod
def shuffle_weight( def shuffle_weight(
self, tensor: torch.Tensor, layout: tuple[int, int] = (16, 16) tensor: torch.Tensor, layout: tuple[int, int] = (16, 16)
) -> torch.Tensor: ) -> torch.Tensor:
from aiter.ops.shuffle import shuffle_weight from aiter.ops.shuffle import shuffle_weight
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment