Unverified Commit a1cc9f33 authored by Chengji Yao's avatar Chengji Yao Committed by GitHub
Browse files

[TPU] remove transpose ops in moe kernel (#18923)


Signed-off-by: default avatarChengji Yao <chengjiyao@google.com>
parent a521ef06
...@@ -18,9 +18,9 @@ setuptools==78.1.0 ...@@ -18,9 +18,9 @@ setuptools==78.1.0
--find-links https://storage.googleapis.com/libtpu-releases/index.html --find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
torch==2.8.0.dev20250518 torch==2.8.0.dev20250529
torchvision==0.22.0.dev20250518 torchvision==0.22.0.dev20250529
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250529-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250529-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250529-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
...@@ -26,7 +26,7 @@ TOP_KS = [2, 6] ...@@ -26,7 +26,7 @@ TOP_KS = [2, 6]
# The Pallas GMM kernel requires num_tokens * topk to be a multiple of 16 # The Pallas GMM kernel requires num_tokens * topk to be a multiple of 16
@pytest.mark.parametrize("m", [8, 16, 64, 2048]) @pytest.mark.parametrize("m", [8, 16, 64, 2048])
@pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 511, 1024]) @pytest.mark.parametrize("k", [128, 512, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE) @pytest.mark.parametrize("ep_size", EP_SIZE)
......
...@@ -67,15 +67,10 @@ def fused_moe( ...@@ -67,15 +67,10 @@ def fused_moe(
token_indices = token_indices[topk_argsort_indices] token_indices = token_indices[topk_argsort_indices]
group_sizes = _histogram(topk_indices.to(torch.int32), 0, num_experts - 1) group_sizes = _histogram(topk_indices.to(torch.int32), 0, num_experts - 1)
# NOTE(woosuk): The GMM Pallas kernel requires a different weight layout
# from HF Transformers.
w1 = w1.transpose(1, 2)
w2 = w2.transpose(1, 2)
x = hidden_states[token_indices] x = hidden_states[token_indices]
x = torch.ops.xla.gmm(x, w1, group_sizes) x = torch.ops.xla.gmm(x, w1, group_sizes, transpose_rhs=True)
x = F.silu(x[..., :intermediate_size]) * x[..., intermediate_size:] x = F.silu(x[..., :intermediate_size]) * x[..., intermediate_size:]
x = torch.ops.xla.gmm(x, w2, group_sizes) x = torch.ops.xla.gmm(x, w2, group_sizes, transpose_rhs=True)
x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size) x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)
x = x * topk_weights.unsqueeze(dim=-1) x = x * topk_weights.unsqueeze(dim=-1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment