Unverified Commit e0613702 authored by Chuan (Richard) Li's avatar Chuan (Richard) Li Committed by GitHub
Browse files

[ROCm] Fix AITER ops fake impl and minor bugs (#36092)


Signed-off-by: default avatarLi <chuali@amd.com>
parent 9853a3c1
......@@ -336,9 +336,13 @@ def _rocm_aiter_fused_topk_fake(
router_logits: torch.Tensor,
top_k: int,
gate_up: bool,
) -> None:
# tuple[torch.Tensor, torch.Tensor]:
pass
) -> tuple[torch.Tensor, torch.Tensor]:
num_tokens = x.shape[0]
topk_weights = torch.empty(
(num_tokens, top_k), dtype=torch.float32, device=x.device
)
topk_indices = torch.empty((num_tokens, top_k), dtype=torch.int32, device=x.device)
return topk_weights, topk_indices
# Cache whether aiter supports FP8 MLA parameters
......@@ -1918,7 +1922,7 @@ class rocm_aiter_ops:
@staticmethod
def shuffle_weight(
self, tensor: torch.Tensor, layout: tuple[int, int] = (16, 16)
tensor: torch.Tensor, layout: tuple[int, int] = (16, 16)
) -> torch.Tensor:
from aiter.ops.shuffle import shuffle_weight
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment