Unverified Commit 45232a45 authored by TJian's avatar TJian Committed by GitHub
Browse files

[FEAT] [Perf] [Gemma4] Fused Gemma4 Routing Function Triton (#39083)


Signed-off-by: default avatartjtanaa <tunjian.tan@embeddedllm.com>
parent 03ce1c6e
...@@ -170,6 +170,7 @@ eles = "eles" ...@@ -170,6 +170,7 @@ eles = "eles"
datas = "datas" datas = "datas"
ser = "ser" ser = "ser"
ure = "ure" ure = "ure"
VALU = "VALU"
# Walsh-Hadamard Transform # Walsh-Hadamard Transform
wht = "wht" wht = "wht"
WHT = "WHT" WHT = "WHT"
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from vllm.model_executor.models.gemma4 import (
gemma4_fused_routing_kernel_triton,
gemma4_routing_function_torch,
)
def sort_by_id(w, ids):
order = ids.argsort(dim=-1)
return w.gather(1, order), ids.gather(1, order)
# Gemma4 MoE Model has context length of 250K
# the minus 1 is to ensure that edge cases are tested
@pytest.mark.parametrize("num_tokens", [1, 2, 2048, 250000])
@pytest.mark.parametrize("num_experts", [128]) # gemma4 moe experts
@pytest.mark.parametrize("topk", [8]) # gemma4 topk
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half, torch.float32])
def test_gemma4_routing_kernel_triton(
num_tokens: int,
num_experts: int,
topk: int,
dtype: torch.dtype,
):
torch.manual_seed(0)
gating = torch.randn(num_tokens, num_experts, dtype=dtype, device="cuda")
scales = torch.rand(num_experts, dtype=torch.float32, device="cuda")
ref_w, ref_ids = gemma4_routing_function_torch(gating, topk, scales)
tri_w, tri_ids = gemma4_fused_routing_kernel_triton(gating, topk, scales)
# Sort by expert id — to remove tie-breaking differences
ref_ws, ref_is = sort_by_id(ref_w, ref_ids)
tri_ws, tri_is = sort_by_id(tri_w, tri_ids)
ids_match = (ref_is == tri_is).all().item()
weights_match = torch.allclose(ref_ws, tri_ws, atol=1e-2, rtol=1e-2)
all_match = ids_match and weights_match
max_err = (ref_ws - tri_ws).abs().max().item()
print(
f"T={num_tokens:5d} E={num_experts:4d} K={topk} "
f"{str(dtype).split('.')[-1]:7s} ids={ids_match} max_Δweight={max_err:.2e}"
)
if not all_match:
bad = (ref_is != tri_is).any(dim=-1).nonzero(as_tuple=True)[0]
if len(bad):
r = bad[0].item()
print(
f" first bad row {r}: ref_ids={ref_ids[r].tolist()} "
f"tri_ids={tri_ids[r].tolist()}"
)
assert all_match
...@@ -57,7 +57,9 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -57,7 +57,9 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, default_weight_loader,
maybe_remap_kv_scale_name, maybe_remap_kv_scale_name,
) )
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata
from .interfaces import ( from .interfaces import (
...@@ -79,6 +81,120 @@ from .utils import ( ...@@ -79,6 +81,120 @@ from .utils import (
logger = init_logger(__name__) logger = init_logger(__name__)
@triton.jit
def _gemma4_routing_kernel(
gating_ptr,
per_expert_scale_ptr,
topk_weights_ptr,
topk_ids_ptr,
E: tl.constexpr,
K: tl.constexpr,
BLOCK_E: tl.constexpr,
):
pid = tl.program_id(0)
offs_e = tl.arange(0, BLOCK_E)
valid = offs_e < E
logits = tl.load(
gating_ptr + pid * E + offs_e,
mask=valid,
other=-float("inf"),
).to(tl.float32)
max_l = tl.max(logits, axis=0)
# Float32 → ascending-sortable bijection
MIN32 = -2147483648
logit_bits = logits.to(tl.int32, bitcast=True)
sign_b = logit_bits >> 31
key = tl.where(sign_b == 0, logit_bits ^ -1, logit_bits ^ MIN32)
key = tl.where(valid, key, 0x7FFFFFFF)
sk64 = key.to(tl.int64) & 0x00000000FFFFFFFF
packed = (sk64 << 32) | offs_e.to(tl.int64)
sorted_p = tl.sort(packed, descending=False)
# Vectorized extraction of ALL sorted elements — no K-loop, no cross-lane reductions
all_keys = ((sorted_p >> 32) & 0x00000000FFFFFFFF).to(tl.int32)
all_ids = (sorted_p & 0x00000000FFFFFFFF).to(tl.int32)
# Inverse bijection: recover original logit bits
sign_k = all_keys >> 31
all_bits = tl.where(sign_k < 0, all_keys ^ -1, all_keys ^ MIN32)
all_logits = all_bits.to(tl.float32, bitcast=True)
# Compute raw_exp for ALL BLOCK_E elements — vectorized, ~2 VALU clocks
all_raw_exp = tl.math.exp2((all_logits - max_l) * 1.4426950408889634)
# Sum only top-K for renorm — ONE masked reduction
top_mask = offs_e < K
renorm_raw = tl.sum(tl.where(top_mask, all_raw_exp, 0.0), axis=0)
renorm_raw = tl.where(renorm_raw > 0.0, renorm_raw, 1.0)
inv_renorm = 1.0 / renorm_raw
# Load scales for top-K only (masked gather; scale array is tiny → L1 cached)
all_scales = tl.load(
per_expert_scale_ptr + all_ids.to(tl.int64),
mask=top_mask,
other=1.0,
).to(tl.float32)
# Final weights: vectorized multiply (only top-K will be stored)
all_weights = (all_raw_exp * inv_renorm * all_scales).to(tl.float32)
# Write results with TWO masked stores — replaces K × 2 serial scalar stores
base_off = pid * K + offs_e
tl.store(topk_ids_ptr + base_off, all_ids, mask=top_mask)
tl.store(topk_weights_ptr + base_off, all_weights, mask=top_mask)
def gemma4_fused_routing_kernel_triton(
gating_output: torch.Tensor,
topk: int,
per_expert_scale: torch.Tensor,
num_warps: int = 1,
) -> tuple[torch.Tensor, torch.Tensor]:
gating_output = gating_output.contiguous()
per_expert_scale = per_expert_scale.contiguous()
T, E = gating_output.shape
weights = torch.empty(T, topk, dtype=torch.float32, device=gating_output.device)
ids = torch.empty(T, topk, dtype=torch.int32, device=gating_output.device)
BLOCK_E = triton.next_power_of_2(E)
_gemma4_routing_kernel[(T,)](
gating_output,
per_expert_scale,
weights,
ids,
E,
topk,
BLOCK_E,
num_warps=num_warps,
)
return weights, ids
def gemma4_routing_function_torch(
gating_output: torch.Tensor,
topk: int,
per_expert_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
_, topk_ids = torch.topk(gating_output, k=topk, dim=-1)
router_probabilities = torch.nn.functional.softmax(gating_output, dim=-1)
indicator = torch.nn.functional.one_hot(
topk_ids, num_classes=gating_output.size(-1)
).sum(dim=-2)
gate_weights = indicator * router_probabilities
renorm_factor = torch.sum(gate_weights, dim=-1, keepdim=True)
renorm_factor = torch.where(renorm_factor > 0.0, renorm_factor, 1.0)
dispatch_weights = gate_weights / renorm_factor
topk_weights = dispatch_weights.gather(1, topk_ids)
# Fold per_expert_scale into routing weights
expert_scales = per_expert_scale[topk_ids].to(topk_weights.dtype)
topk_weights = topk_weights * expert_scales
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
def _get_text_config(config): def _get_text_config(config):
"""Dereference text_config if config is a nested Gemma4Config. """Dereference text_config if config is a nested Gemma4Config.
...@@ -216,22 +332,12 @@ class Gemma4MoE(nn.Module): ...@@ -216,22 +332,12 @@ class Gemma4MoE(nn.Module):
topk: int, topk: int,
renormalize: bool, renormalize: bool,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
_, topk_ids = torch.topk(gating_output, k=topk, dim=-1) if current_platform.is_cuda_alike() or current_platform.is_xpu():
router_probabilities = torch.nn.functional.softmax(gating_output, dim=-1) return gemma4_fused_routing_kernel_triton(
indicator = torch.nn.functional.one_hot( gating_output, topk, per_expert_scale
topk_ids, num_classes=gating_output.size(-1) )
).sum(dim=-2)
gate_weights = indicator * router_probabilities
renorm_factor = torch.sum(gate_weights, dim=-1, keepdim=True)
renorm_factor = torch.where(renorm_factor > 0.0, renorm_factor, 1.0)
dispatch_weights = gate_weights / renorm_factor
topk_weights = dispatch_weights.gather(1, topk_ids)
# Fold per_expert_scale into routing weights return gemma4_routing_function_torch(gating_output, topk, per_expert_scale)
expert_scales = per_expert_scale[topk_ids].to(topk_weights.dtype)
topk_weights = topk_weights * expert_scales
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
# FusedMoE experts with custom Gemma4 routing # FusedMoE experts with custom Gemma4 routing
self.experts = FusedMoE( self.experts = FusedMoE(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment