Commit bb596f6e authored by xiaowei.zhang's avatar xiaowei.zhang
Browse files

1. Update MOE; 2. Update sglang mHC; 3. Update test scripts; 4. Add new ops.
parent d9ebb683
{
"1": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"2": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"3": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"4": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"5": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"6": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"7": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"8": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"9": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"10": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"11": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"12": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"13": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"14": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"15": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"16": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"32": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"64": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"128": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"256": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"512": {
"BLOCK_SIZE_M": 16,
"MODE": 98
},
"1024": {
"BLOCK_SIZE_M": 32,
"MODE": 183
},
"2048": {
"BLOCK_SIZE_M": 32,
"MODE": 146
},
"4096": {
"BLOCK_SIZE_M": 32,
"MODE": 160
},
"8192": {
"BLOCK_SIZE_M": 32,
"MODE": 160
},
"16384": {
"BLOCK_SIZE_M": 32,
"MODE": 160
},
"32768": {
"BLOCK_SIZE_M": 32,
"MODE": 160
}
}
\ No newline at end of file
{
"1": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"2": {
"BLOCK_SIZE_M": 16,
"MODE": 42
},
"3": {
"BLOCK_SIZE_M": 16,
"MODE": 42
},
"4": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"5": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"6": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"7": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"8": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"9": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"10": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"11": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"12": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"13": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"14": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"15": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"16": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"32": {
"BLOCK_SIZE_M": 16,
"MODE": 43
},
"64": {
"BLOCK_SIZE_M": 16,
"MODE": 43
},
"128": {
"BLOCK_SIZE_M": 16,
"MODE": 46
},
"256": {
"BLOCK_SIZE_M": 16,
"MODE": 46
},
"512": {
"BLOCK_SIZE_M": 16,
"MODE": 43
},
"1024": {
"BLOCK_SIZE_M": 32,
"MODE": 86
},
"2048": {
"BLOCK_SIZE_M": 32,
"MODE": 86
},
"4096": {
"BLOCK_SIZE_M": 32,
"MODE": 86
},
"8192": {
"BLOCK_SIZE_M": 32,
"MODE": 86
},
"16384": {
"BLOCK_SIZE_M": 32,
"MODE": 86
},
"32768": {
"BLOCK_SIZE_M": 32,
"MODE": 86
}
}
\ No newline at end of file
{
"1": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"2": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"3": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"4": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"5": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"6": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"7": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"8": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"9": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"10": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"11": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"12": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"13": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"14": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"15": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"16": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"32": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"64": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"128": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"256": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"512": {
"BLOCK_SIZE_M": 16,
"MODE": 98
},
"1024": {
"BLOCK_SIZE_M": 32,
"MODE": 183
},
"2048": {
"BLOCK_SIZE_M": 32,
"MODE": 146
},
"4096": {
"BLOCK_SIZE_M": 32,
"MODE": 160
},
"8192": {
"BLOCK_SIZE_M": 32,
"MODE": 160
},
"16384": {
"BLOCK_SIZE_M": 32,
"MODE": 160
},
"32768": {
"BLOCK_SIZE_M": 32,
"MODE": 160
}
}
\ No newline at end of file
This source diff could not be displayed because it is too large. You can view the blob instead.
# SPDX-License-Identifier: MIT
from typing import List, Optional, Tuple
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
from typing import List
import torch
......@@ -9,38 +10,61 @@ from ..jit.core import compile_ops
MD_NAME = "module_custom_all_reduce"
@compile_ops("module_custom_all_reduce")
@compile_ops("module_custom_all_reduce", develop=True)
def init_custom_ar(
meta: torch.Tensor,
rank_data: torch.Tensor,
handles: List[torch.Tensor],
meta_ptr: int,
rank_data_ptr: int,
rank_data_sz: int,
ipc_handle_ptrs: List[int],
offsets: List[int],
rank: int,
fully_connected: bool,
) -> int: ...
@compile_ops("module_custom_all_reduce")
@compile_ops("module_custom_all_reduce", develop=True)
def all_reduce(
_fa: int,
inp: torch.Tensor,
out: torch.Tensor,
use_new: bool,
open_fp8_quant: bool,
reg_buffer: Optional[torch.Tensor] = None,
reg_inp_ptr: int,
reg_inp_bytes: int,
) -> None: ...
@compile_ops("module_custom_all_reduce", develop=True)
def reduce_scatter(
_fa: int,
inp: torch.Tensor,
out: torch.Tensor,
reg_ptr: int,
reg_bytes: int,
) -> None: ...
@compile_ops("module_custom_all_reduce")
def all_gather_reg(_fa: int, inp: torch.Tensor, out: torch.Tensor) -> None: ...
@compile_ops("module_custom_all_reduce", develop=True)
def all_gather_reg(
_fa: int,
inp: torch.Tensor,
out: torch.Tensor,
dim: int,
) -> None: ...
@compile_ops("module_custom_all_reduce")
@compile_ops("module_custom_all_reduce", develop=True)
def all_gather_unreg(
_fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
_fa: int,
inp: torch.Tensor,
reg_buffer: int,
out: torch.Tensor,
reg_bytes: int,
dim: int,
) -> None: ...
@compile_ops("module_custom_all_reduce")
@compile_ops("module_custom_all_reduce", develop=True)
def fused_allreduce_rmsnorm(
_fa: int,
inp: torch.Tensor,
......@@ -49,162 +73,102 @@ def fused_allreduce_rmsnorm(
out: torch.Tensor,
w: torch.Tensor,
eps: float,
reg_buffer: Optional[torch.Tensor] = None,
reg_ptr: int,
reg_bytes: int,
use_1stage: bool,
) -> None: ...
def all_reduce_asm_fake_tensor(
@compile_ops("module_custom_all_reduce", develop=True)
def fused_allreduce_rmsnorm_quant(
_fa: int,
inp: torch.Tensor,
ca: int,
reg_sig: torch.Tensor,
reg_buffer: torch.Tensor,
isGraph: bool,
) -> torch.Tensor:
return torch.empty_like(
inp,
dtype=inp.dtype,
device=inp.device,
)
res_inp: torch.Tensor,
res_out: torch.Tensor,
out: torch.Tensor,
scale_out: torch.Tensor,
w: torch.Tensor,
eps: float,
reg_ptr: int,
reg_bytes: int,
use_1stage: bool,
) -> None: ...
@compile_ops("module_custom_all_reduce", gen_fake=all_reduce_asm_fake_tensor)
def all_reduce_asm_(
@compile_ops("module_custom_all_reduce", develop=True)
def fused_allreduce_rmsnorm_quant_per_group(
_fa: int,
inp: torch.Tensor,
ca: int,
reg_sig: torch.Tensor,
reg_buffer: torch.Tensor,
isGraph: bool,
) -> torch.Tensor: ...
def all_reduce_rmsnorm_fake_tensors(
input: torch.Tensor,
residual_in: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
epsilon: float,
ca: int,
reg_sig: torch.Tensor,
reg_buffer: torch.Tensor,
isGraph: bool,
) -> List[torch.Tensor]:
output = torch.empty_like(
input, dtype=input.dtype, device=input.device, requires_grad=input.requires_grad
)
residual_out = torch.empty_like(
input, dtype=input.dtype, device=input.device, requires_grad=input.requires_grad
)
return [output, residual_out]
@compile_ops("module_custom_all_reduce", gen_fake=all_reduce_rmsnorm_fake_tensors)
def all_reduce_rmsnorm_(
input: torch.Tensor,
residual_in: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
epsilon: float,
ca: int,
reg_sig: torch.Tensor,
reg_buffer: torch.Tensor,
isGraph: bool,
) -> List[torch.Tensor]: ...
# def all_reduce_rmsnorm_quant_fake_tensors(
# input: torch.Tensor,
# residual_in: torch.Tensor,
# weight: torch.Tensor,
# xscale: torch.Tensor,
# bias: torch.Tensor,
# epsilon: float,
# ca: int,
# reg_sig: torch.Tensor,
# reg_buffer: torch.Tensor,
# isGraph: bool,
# ) -> List[torch.Tensor]:
# N = input.size(-1)
# M = input.numel() // N
# output = torch.empty_like(
# input, dtype=input.dtype, device=input.device, requires_grad=input.requires_grad
# )
# residual_out = torch.empty_like(
# input, dtype=input.dtype, device=input.device, requires_grad=input.requires_grad
# )
# y_scale = torch.empty((M, 1), dtype=torch.float32, device=input.device)
# return [output, residual_out, y_scale]
# @compile_ops("module_custom_all_reduce", gen_fake=all_reduce_rmsnorm_quant_fake_tensors)
# def all_reduce_rmsnorm_quant_(
# input: torch.Tensor,
# residual_in: torch.Tensor,
# weight: torch.Tensor,
# xscale: torch.Tensor,
# bias: torch.Tensor,
# epsilon: float,
# ca: int,
# reg_sig: torch.Tensor,
# reg_buffer: torch.Tensor,
# isGraph: bool,
# ) -> List[torch.Tensor]: ...
@compile_ops("module_custom_all_reduce")
res_inp: torch.Tensor,
res_out: torch.Tensor,
out: torch.Tensor,
scale_out: torch.Tensor,
w: torch.Tensor,
eps: float,
group_size: int,
reg_ptr: int,
reg_bytes: int,
use_1stage: bool,
bf16_out_ptr: int = 0,
) -> None: ...
@compile_ops("module_custom_all_reduce", develop=True)
def fused_qknorm_allreduce(
_fa: int,
qkv_in: torch.Tensor,
q_w: torch.Tensor,
k_w: torch.Tensor,
q_out: torch.Tensor,
k_out: torch.Tensor,
v_out: torch.Tensor,
eps: float,
reg_ptr: int,
reg_bytes: int,
) -> None: ...
@compile_ops("module_custom_all_reduce", develop=True)
def dispose(_fa: int) -> None: ...
@compile_ops("module_custom_all_reduce")
@compile_ops("module_custom_all_reduce", develop=True)
def meta_size() -> int: ...
@compile_ops("module_custom_all_reduce")
def register_buffer(
_fa: int, t: torch.Tensor, handles: List[torch.Tensor], offsets: List[int]
@compile_ops("module_custom_all_reduce", develop=True)
def register_input_buffer(
_fa: int, self_ptr: int, ipc_handle_ptrs: List[int], offsets: List[int]
) -> None: ...
# def gen_get_graph_buffer_ipc_meta_fake_tensors(_fa: int) -> List[torch.Tensor]:
# handle_sz = 64 # sizeof(hipIpcMemHandle_t) is 64 byte
# num_buffers = 4 # ???
# handles = torch.empty((handle_sz * num_buffers,), dtype=torch.uint8, device="cuda")
@compile_ops("module_custom_all_reduce", develop=True)
def register_output_buffer(
_fa: int, self_ptr: int, ipc_handle_ptrs: List[int], offsets: List[int]
) -> None: ...
# offset_tensor = torch.empty((num_buffers,), dtype=torch.int64, device="cuda")
# return [handles, offset_tensor]
@compile_ops("module_custom_all_reduce", develop=True)
def get_graph_buffer_count(_fa: int) -> int: ...
@compile_ops("module_custom_all_reduce")
def get_graph_buffer_ipc_meta(_fa: int) -> Tuple[torch.Tensor, torch.Tensor]: ...
@compile_ops("module_custom_all_reduce", develop=True)
def get_graph_buffer_ipc_meta(_fa: int, handle_out: int, offset_out: int) -> None: ...
@compile_ops("module_custom_all_reduce")
@compile_ops("module_custom_all_reduce", develop=True)
def register_graph_buffers(
_fa: int, handles: List[torch.Tensor], offsets: List[torch.Tensor]
_fa: int, handle_ptrs: List[int], offset_ptrs: List[int]
) -> None: ...
@compile_ops("module_custom_all_reduce")
def allocate_meta_buffer(size: int) -> torch.Tensor: ...
@compile_ops("module_custom_all_reduce", develop=True)
def allocate_meta_buffer(size: int) -> int: ...
# def get_meta_buffer_ipc_handle_fake(inp: torch.Tensor) -> torch.Tensor:
# handle_size = 64
# if not inp.is_cuda:
# raise RuntimeError("Input tensor must be on CUDA device")
# return torch.empty(handle_size, dtype=torch.uint8, device=inp.device)
@compile_ops("module_custom_all_reduce", develop=True)
def free_meta_buffer(ptr: int) -> None: ...
@compile_ops("module_custom_all_reduce")
def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor: ...
\ No newline at end of file
@compile_ops("module_custom_all_reduce", develop=True)
def get_meta_buffer_ipc_handle(inp_ptr: int, out_handle_ptr: int) -> None: ...
# SPDX-License-Identifier: MIT
from typing import List, Optional, Sequence, Tuple
import torch
from torch import Tensor
from ..jit.core import compile_ops
# Grouped GEMM entry point: takes one A and one B tensor per group and returns
# one output tensor per group.
# The Python body is an empty stub (`...`); the `compile_ops` decorator is
# presumably what binds the implementation from "module_grouped_gemm" --
# TODO confirm in jit.core.
# NOTE(review): ck_grouped_gemm_moe in this file calls this with row-padded A
# tensors and describes the per-group product as C_i = A_i @ B_i^T with B_i
# shaped [N, K]; verify the exact shape/alignment contract against the kernel.
@compile_ops("module_grouped_gemm")
def ck_grouped_gemm(
    a_tensors: List[Tensor],
    b_tensors: List[Tensor],
) -> List[Tensor]: ...
# Grouped GEMM variant that additionally receives caller-provided C tensors,
# one per group.
# The Python body is an empty stub (`...`); the `compile_ops` decorator is
# presumably what binds the implementation from "module_grouped_gemm" --
# TODO confirm in jit.core.
# NOTE(review): callers in this file pass C buffers whose row count matches the
# (padded) A row count and read results back from those buffers, so the kernel
# is assumed to write into `c_tensors` in place; the returned list is used by
# GroupedGemmMoeBuffers.run when no logical C tensors are supplied.
@compile_ops("module_grouped_gemm")
def ck_grouped_gemm_out(
    a_tensors: List[Tensor],
    b_tensors: List[Tensor],
    c_tensors: List[Tensor],
) -> List[Tensor]: ...
# CK tile alignment for the low-level kernel (see grouped_gemm_kernels.cu).
_MOE_M_ALIGN = {
torch.float16: 64,
torch.bfloat16: 64,
torch.float8_e4m3fn: 128,
torch.int8: 32,
}
_MOE_NK_ALIGN = {
torch.float16: dict(n=128, k=128),
torch.bfloat16: dict(n=128, k=128),
torch.float8_e4m3fn: dict(n=128, k=128),
torch.int8: dict(n=32, k=128),
}
def _moe_output_dtype(dtype: torch.dtype) -> torch.dtype:
if dtype is torch.int8:
return torch.int32
if dtype is torch.float8_e4m3fn:
return torch.float32
return dtype
def _align_up(x: int, align: int) -> int:
return ((x + align - 1) // align) * align
def _validate_moe_fixed_nk(b_tensors: Sequence[Tensor], dtype: torch.dtype) -> Tuple[int, int]:
if not b_tensors:
raise ValueError("ck_grouped_gemm_moe: b_tensors must not be empty")
n0, k0 = b_tensors[0].shape
nk = _MOE_NK_ALIGN[dtype]
if n0 % nk["n"] != 0 or k0 % nk["k"] != 0 or k0 < nk["k"]:
raise ValueError(
f"ck_grouped_gemm_moe: fixed N/K must satisfy N % {nk['n']} == 0, "
f"K % {nk['k']} == 0, K >= {nk['k']} for {dtype}, got N={n0}, K={k0}"
)
for i, b in enumerate(b_tensors):
if b.shape != (n0, k0):
raise ValueError(
f"ck_grouped_gemm_moe: all B tensors must share the same [N, K], "
f"group {i} has {tuple(b.shape)} vs expected ({n0}, {k0})"
)
return n0, k0
def _pad_a_rows(a: Tensor, m_align: int) -> Tuple[Tensor, int, int]:
    """Zero-pad the rows of ``a`` up to a multiple of ``m_align``.

    Returns:
        ``(tensor, m_orig, m_pad)``. ``tensor`` is ``a`` itself when no
        padding is needed; otherwise it is a new ``[m_pad, a.size(1)]`` tensor
        whose first ``m_orig`` rows are a copy of ``a`` and the rest are zero.
    """
    rows = a.size(0)
    padded_rows = _align_up(rows, m_align)
    if padded_rows != rows:
        padded = a.new_zeros(padded_rows, a.size(1))
        padded[:rows].copy_(a)
        return padded, rows, padded_rows
    return a, rows, padded_rows
def ck_grouped_gemm_moe(
    a_tensors: List[Tensor],
    b_tensors: List[Tensor],
) -> List[Tensor]:
    """
    MOE-friendly grouped GEMM with per-group dynamic M and fixed N/K.

    Each group computes C_i = A_i @ B_i^T. A_i may have arbitrary M_i >= 1;
    rows are zero-padded to the CK M-tile boundary before launch, then outputs
    are sliced back to the logical M_i.

    Args:
        a_tensors: one ``[M_i, K]`` tensor per group.
        b_tensors: one ``[N, K]`` tensor per group; all must share the same
            shape and dtype as the A tensors.

    Returns:
        One ``[M_i, N]`` tensor per group, in ``_moe_output_dtype(dtype)``.

    Raises:
        ValueError: on empty input, length/dtype/K mismatch, non-positive M,
            or N/K that violate the tile alignment for the dtype.
        KeyError: if the dtype has no entry in the alignment tables.
    """
    if len(a_tensors) != len(b_tensors):
        raise ValueError("ck_grouped_gemm_moe: a and b tensor lists must have the same length")
    if not a_tensors:
        # Guard before indexing a_tensors[0]; otherwise empty input surfaces as
        # an IndexError instead of the intended validation error.
        raise ValueError("ck_grouped_gemm_moe: a_tensors and b_tensors must not be empty")

    dtype = a_tensors[0].dtype
    m_align = _MOE_M_ALIGN[dtype]
    n, _ = _validate_moe_fixed_nk(b_tensors, dtype)

    a_padded: List[Tensor] = []
    m_orig_list: List[int] = []
    for a, b in zip(a_tensors, b_tensors):
        if a.dtype != dtype or b.dtype != dtype:
            raise ValueError("ck_grouped_gemm_moe: all tensors must share the same dtype")
        if a.size(1) != b.size(1):
            raise ValueError("ck_grouped_gemm_moe: K mismatch between A and B")
        if a.size(0) <= 0:
            raise ValueError("ck_grouped_gemm_moe: M must be positive")
        a_pad, m_orig, _ = _pad_a_rows(a, m_align)
        a_padded.append(a_pad)
        m_orig_list.append(m_orig)

    c_padded = ck_grouped_gemm(a_padded, b_tensors)
    out_dtype = _moe_output_dtype(dtype)
    # Apply the same slice + dtype normalisation to every group so the result
    # does not depend on whether M_i happened to be tile-aligned. Both the
    # slice and ``.to`` are no-copy when nothing needs to change.
    return [
        c[:m_orig, :n].to(out_dtype)
        for c, m_orig in zip(c_padded, m_orig_list)
    ]
def ck_grouped_gemm_moe_out(
    a_tensors: List[Tensor],
    b_tensors: List[Tensor],
    c_tensors: List[Tensor],
) -> List[Tensor]:
    """
    MOE grouped GEMM writing into caller-provided logical C tensors [M_i, N].

    Padded A/C buffers are allocated internally; only the valid M_i rows are
    copied into c_tensors. Groups whose M_i is already tile-aligned pass their
    C tensor to the kernel directly.

    Args:
        a_tensors: one ``[M_i, K]`` tensor per group.
        b_tensors: one ``[N, K]`` tensor per group (shared shape and dtype).
        c_tensors: one ``[M_i, N]`` tensor per group in
            ``_moe_output_dtype(dtype)``.

    Returns:
        ``c_tensors`` (the same list object).

    Raises:
        ValueError: on empty input, length/dtype/shape mismatch, or N/K that
            violate the tile alignment for the dtype.
        KeyError: if the dtype has no entry in the alignment tables.
    """
    if not (len(a_tensors) == len(b_tensors) == len(c_tensors)):
        raise ValueError("ck_grouped_gemm_moe_out: a, b, c lists must have the same length")
    if not a_tensors:
        # Guard before indexing a_tensors[0]; otherwise empty input surfaces as
        # an IndexError instead of a validation error.
        raise ValueError("ck_grouped_gemm_moe_out: a, b, c lists must not be empty")

    dtype = a_tensors[0].dtype
    m_align = _MOE_M_ALIGN[dtype]
    n, _ = _validate_moe_fixed_nk(b_tensors, dtype)
    out_dtype = _moe_output_dtype(dtype)

    a_padded: List[Tensor] = []
    c_padded: List[Tensor] = []
    m_orig_list: List[int] = []
    for a, b, c in zip(a_tensors, b_tensors, c_tensors):
        if a.dtype != dtype or b.dtype != dtype:
            raise ValueError("ck_grouped_gemm_moe_out: a/b dtype mismatch")
        if c.dtype != out_dtype:
            raise ValueError(f"ck_grouped_gemm_moe_out: c dtype must be {out_dtype}")
        if a.size(1) != b.size(1):
            raise ValueError("ck_grouped_gemm_moe_out: K mismatch between A and B")
        m_orig = a.size(0)
        if c.shape != (m_orig, n):
            raise ValueError(
                f"ck_grouped_gemm_moe_out: c shape {tuple(c.shape)} != ({m_orig}, {n})"
            )
        a_pad, _, m_pad = _pad_a_rows(a, m_align)
        a_padded.append(a_pad)
        m_orig_list.append(m_orig)
        # Aligned groups write straight into the caller's tensor; the others
        # get a scratch buffer with the padded row count.
        c_padded.append(c if m_pad == m_orig else c.new_empty(m_pad, n))

    ck_grouped_gemm_out(a_padded, b_tensors, c_padded)

    for c, c_pad, m_orig in zip(c_tensors, c_padded, m_orig_list):
        # A scratch buffer is a distinct object; copy its valid rows back.
        if c_pad is not c:
            c.copy_(c_pad[:m_orig])
    return c_tensors
class GroupedGemmMoeBuffers:
    """
    Reusable padded A/C buffers for MOE inference with fixed N/K per expert.

    Avoids per-forward allocation of the padded kernel operands when max
    tokens per expert is bounded. One ``[max_m_pad, k]`` A buffer and one
    ``[max_m_pad, n]`` C buffer are allocated per group at construction time.
    """

    def __init__(
        self,
        num_groups: int,
        n: int,
        k: int,
        dtype: torch.dtype,
        max_m: int,
        device: Optional[torch.device] = None,
    ):
        """Allocate the per-group buffers.

        Args:
            num_groups: number of groups (experts); must be positive.
            n: fixed N of every B tensor; must satisfy the dtype's N alignment.
            k: fixed K of every A/B tensor; must satisfy the dtype's K alignment.
            dtype: input dtype; selects alignment and output dtype.
            max_m: largest per-group M that ``run`` will be asked to handle.
            device: buffer device; defaults to ``torch.device("cuda")``.

        Raises:
            ValueError: if ``num_groups`` is not positive or N/K are misaligned.
            KeyError: if ``dtype`` has no entry in the alignment tables.
        """
        if num_groups <= 0:
            raise ValueError("GroupedGemmMoeBuffers: num_groups must be positive")
        nk = _MOE_NK_ALIGN[dtype]
        if n % nk["n"] != 0 or k % nk["k"] != 0 or k < nk["k"]:
            raise ValueError(f"GroupedGemmMoeBuffers: invalid fixed N={n}, K={k} for {dtype}")
        self.num_groups = num_groups
        self.n = n
        self.k = k
        self.dtype = dtype
        self.m_align = _MOE_M_ALIGN[dtype]
        self.max_m_pad = _align_up(max_m, self.m_align)
        self.out_dtype = _moe_output_dtype(dtype)
        dev = device or torch.device("cuda")
        self.a_bufs = [
            torch.zeros(self.max_m_pad, k, device=dev, dtype=dtype)
            for _ in range(num_groups)
        ]
        self.c_bufs = [
            torch.zeros(self.max_m_pad, n, device=dev, dtype=self.out_dtype)
            for _ in range(num_groups)
        ]

    def _ensure_capacity(self, m_orig: int) -> int:
        """Return the padded row count for ``m_orig``; raise if it exceeds the buffers."""
        m_pad = _align_up(m_orig, self.m_align)
        if m_pad > self.max_m_pad:
            raise ValueError(
                f"GroupedGemmMoeBuffers: M={m_orig} exceeds configured max_m "
                f"(padded max {self.max_m_pad})"
            )
        return m_pad

    def run(
        self,
        a_tensors: Sequence[Tensor],
        b_tensors: Sequence[Tensor],
        c_tensors: Optional[Sequence[Tensor]] = None,
    ) -> List[Tensor]:
        """Run the grouped GEMM through the preallocated buffers.

        Args:
            a_tensors: one ``[M_i, k]`` tensor per group.
            b_tensors: one ``[n, k]`` tensor per group.
            c_tensors: optional ``[M_i, n]`` destinations, one per group.

        Returns:
            ``c_tensors`` as a list (filled in place) when provided; otherwise
            freshly cloned ``[M_i, n]`` results, so the caller's tensors do not
            alias the internal buffers.

        Raises:
            ValueError: on group-count or shape mismatch, or when an ``M_i``
                does not fit the configured ``max_m``.
        """
        if len(a_tensors) != self.num_groups or len(b_tensors) != self.num_groups:
            raise ValueError("GroupedGemmMoeBuffers: group count mismatch")
        if c_tensors is not None and len(c_tensors) != self.num_groups:
            raise ValueError("GroupedGemmMoeBuffers: group count mismatch")

        a_padded: List[Tensor] = []
        c_padded: List[Tensor] = []
        m_orig_list: List[int] = []
        for i, (a, b) in enumerate(zip(a_tensors, b_tensors)):
            if b.shape != (self.n, self.k):
                raise ValueError(f"GroupedGemmMoeBuffers: B[{i}] shape {tuple(b.shape)} != ({self.n}, {self.k})")
            # copy_ broadcasts its source, so an [M, 1] A would silently fill
            # every column; reject anything that is not [M, k] up front.
            if a.dim() != 2 or a.size(1) != self.k:
                raise ValueError(
                    f"GroupedGemmMoeBuffers: A[{i}] shape {tuple(a.shape)} != (M, {self.k})"
                )
            m_orig = a.size(0)
            m_pad = self._ensure_capacity(m_orig)
            if c_tensors is not None and c_tensors[i].shape != (m_orig, self.n):
                raise ValueError(f"GroupedGemmMoeBuffers: c[{i}] shape mismatch")
            m_orig_list.append(m_orig)

            a_buf = self.a_bufs[i]
            a_buf[:m_orig].copy_(a)
            # Only rows [m_orig, m_pad) reach the kernel as padding; rows past
            # m_pad are never read, so there is no need to clear them.
            a_buf[m_orig:m_pad].zero_()
            a_padded.append(a_buf[:m_pad])
            c_padded.append(self.c_bufs[i][:m_pad])

        c_full = ck_grouped_gemm_out(a_padded, list(b_tensors), c_padded)

        if c_tensors is not None:
            for c, c_pad, m_orig in zip(c_tensors, c_padded, m_orig_list):
                c.copy_(c_pad[:m_orig])
            return list(c_tensors)
        return [c[:m_orig].clone() for c, m_orig in zip(c_full, m_orig_list)]
# SPDX-License-Identifier: MIT
import math
import os
import torch
from aiter import dtypes
from torch import Tensor
from ..jit.core import compile_ops
from ..jit.utils.chip_info import get_cu_num, get_gfx
from ..jit.utils.torch_guard import torch_compile_guard
def _truthy_env(name: str) -> bool:
v = os.environ.get(name, "").strip().lower()
return v in ("1", "true", "yes", "on")
def _round_to_tf32_like_tilekernels(x: torch.Tensor) -> torch.Tensor:
return (x.view(torch.int32) + 0x1000).view(torch.float32)
# Stage-1 kernel stub used by mhc_pre when the M128 variant is not selected.
# The Python body is empty (`...`); `compile_ops` presumably binds the
# implementation from "module_mhc" -- TODO confirm in jit.core.
# As called from mhc_pre: `out` is a (split_k, m, hc_mult3) fp32 view into a
# buffer whose last dim is padded to a multiple of 32, `sqrsum` is a
# (split_k, m) fp32 buffer, `x` is the residual and `fn` is passed through
# (pre-biased by _round_to_tf32_like_tilekernels when use_tf32 is requested,
# in which case mhc_pre passes use_tf32=False to this kernel).
# Returns None, so results are presumably written into `out`/`sqrsum` in place.
@compile_ops("module_mhc")
def mhc_pre_gemm_sqrsum(
    out: Tensor,
    sqrsum: Tensor,
    x: Tensor,
    fn: Tensor,
    tile_k: int = 128,  # 64 or 128
    use_tf32: bool = False,
) -> None: ...
# Alternative stage-1 kernel stub; mhc_pre selects it when `use_stage1_m128`
# is true (AITER_MHC_PRE_STAGE1=m128/tlstyle, or the auto heuristic) and then
# plans the launch with tile_m=128 and tile_k=128.
# The Python body is empty (`...`); `compile_ops` presumably binds the
# implementation from "module_mhc" -- TODO confirm in jit.core.
# mhc_pre passes the same `out`/`sqrsum` buffers as for mhc_pre_gemm_sqrsum
# and forwards `use_tf32` unchanged (no pre-biasing of `fn` on this path).
@compile_ops("module_mhc")
def mhc_pre_gemm_sqrsum_stage1_m128(
    out: Tensor,
    sqrsum: Tensor,
    x: Tensor,
    fn: Tensor,
    use_tf32: bool = False,
) -> None: ...
# Optional split-k reduction stub, only invoked by mhc_pre when split_k > 1
# and AITER_MHC_PRE_REDUCE_SPLITK is truthy.
# The Python body is empty (`...`); `compile_ops` presumably binds the
# implementation from "module_mhc" -- TODO confirm in jit.core.
# As called from mhc_pre: `out_red` is a (1, m, hc_mult3) fp32 view into a
# padded buffer and `sqrsum_red` is (1, m) fp32; `out`/`sqrsum` are the
# stage-1 outputs with a leading split_k dimension.
# NOTE(review): presumably reduces over the leading split-k axis -- verify
# against the kernel source.
@compile_ops("module_mhc")
def mhc_pre_reduce_splitk(
    out_red: Tensor,
    sqrsum_red: Tensor,
    out: Tensor,
    sqrsum: Tensor,
) -> None: ...
# Stage-2 kernel stub used by mhc_pre when the "tlstyle" variant is not
# selected.
# The Python body is empty (`...`); `compile_ops` presumably binds the
# implementation from "module_mhc" -- TODO confirm in jit.core.
# As called from mhc_pre: `post_mix` is (m, hc_mult, 1) fp32, `comb_mix` is
# (m, hc_mult, hc_mult) fp32 and `layer_input` is (m, hidden_size) bf16, all
# freshly allocated with torch.empty and returned to the caller afterwards,
# so they are presumably the outputs written by this kernel.
# `gemm_out_mul`/`gemm_out_sqrsum` receive the stage-1 (or reduced) results.
@compile_ops("module_mhc")
def mhc_pre_big_fuse(
    post_mix: Tensor,
    comb_mix: Tensor,
    layer_input: Tensor,
    gemm_out_mul: Tensor,
    gemm_out_sqrsum: Tensor,
    hc_scale: Tensor,
    hc_base: Tensor,
    residual: Tensor,
    rms_eps: float = 1e-6,
    hc_pre_eps: float = 1e-6,
    hc_sinkhorn_eps: float = 1e-6,
    hc_post_mult_value: float = 1.0,
    sinkhorn_repeat: int = 20,
) -> None: ...
# "tlstyle" variant of the stage-2 kernel stub, with the same signature as
# mhc_pre_big_fuse. mhc_pre selects it when AITER_MHC_PRE_KERNEL=tlstyle or
# when its auto heuristic (`use_tlstyle_auto`) is true.
# The Python body is empty (`...`); `compile_ops` presumably binds the
# implementation from "module_mhc" -- TODO confirm in jit.core.
# mhc_pre passes exactly the same arguments as for mhc_pre_big_fuse.
@compile_ops("module_mhc")
def mhc_pre_big_fuse_tlstyle(
    post_mix: Tensor,
    comb_mix: Tensor,
    layer_input: Tensor,
    gemm_out_mul: Tensor,
    gemm_out_sqrsum: Tensor,
    hc_scale: Tensor,
    hc_base: Tensor,
    residual: Tensor,
    rms_eps: float = 1e-6,
    hc_pre_eps: float = 1e-6,
    hc_sinkhorn_eps: float = 1e-6,
    hc_post_mult_value: float = 1.0,
    sinkhorn_repeat: int = 20,
) -> None: ...
def mhc_pre_fake(
    residual: torch.Tensor,
    fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    rms_eps: float = 1e-6,
    hc_pre_eps: float = 1e-6,
    hc_sinkhorn_eps: float = 1e-6,
    hc_post_mult_value: float = 1.0,
    sinkhorn_repeat: int = 20,  # if 0, only do pre for hc_head
    use_tf32: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Shape-only counterpart of ``mhc_pre`` (registered as its ``gen_fake``).

    Allocates uninitialised outputs with the shapes and dtypes ``mhc_pre``
    returns, derived from the first three dims of ``residual``; no kernel runs.

    Returns:
        ``(post_mix, comb_mix, layer_input)`` shaped ``(m, hc_mult, 1)`` fp32,
        ``(m, hc_mult, hc_mult)`` fp32 and ``(m, hidden_size)`` bf16.
    """
    m, hc_mult, hidden_size = (residual.size(axis) for axis in range(3))
    target = residual.device
    return (
        torch.empty(m, hc_mult, 1, dtype=dtypes.fp32, device=target),
        torch.empty(m, hc_mult, hc_mult, dtype=dtypes.fp32, device=target),
        torch.empty(m, hidden_size, dtype=dtypes.bf16, device=target),
    )
@torch_compile_guard(gen_fake=mhc_pre_fake)
def mhc_pre(
    residual: torch.Tensor,
    fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    rms_eps: float = 1e-6,
    hc_pre_eps: float = 1e-6,
    hc_sinkhorn_eps: float = 1e-6,
    hc_post_mult_value: float = 1.0,
    sinkhorn_repeat: int = 20,  # if 0, only do pre for hc_head
    use_tf32: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the two-stage mHC "pre" pipeline and return its three outputs.

    Stage 1 launches a GEMM + square-sum kernel (legacy or M128 variant) into
    split-k partial buffers; an optional reduction collapses the split-k axis;
    stage 2 launches a "big fuse" kernel (legacy or tlstyle variant) that
    fills the returned tensors. Tile size and split-k are chosen by an
    occupancy heuristic plus measured overrides.

    Args:
        residual: indexed as ``(m, hc_mult, hidden_size)`` via its first
            three dims.
        fn: ``fn.size(0)`` must be ``hc_mult * (2 + hc_mult)``, or
            ``hc_mult`` when ``sinkhorn_repeat == 0``.
        hc_scale, hc_base, rms_eps, hc_pre_eps, hc_sinkhorn_eps,
        hc_post_mult_value, sinkhorn_repeat: forwarded to the stage-2 kernel.
        use_tf32: forwarded to the M128 stage-1 kernel; on the legacy path
            ``fn`` is pre-biased in Python instead.

    Environment overrides:
        AITER_MHC_PRE_STAGE1: ``auto`` | ``aiter``/``legacy`` | ``m128``/``tlstyle``.
        AITER_MHC_PRE_KERNEL: ``aiter``/``legacy`` | ``tlstyle``; anything
            else uses the auto heuristic.
        AITER_MHC_PRE_TILE_K, AITER_MHC_PRE_SPLITK: force stage-1 launch params.
        AITER_MHC_PRE_REDUCE_SPLITK: truthy enables the split-k reduction kernel.

    Returns:
        ``(post_mix, comb_mix, layer_input)`` shaped ``(m, hc_mult, 1)`` fp32,
        ``(m, hc_mult, hc_mult)`` fp32 and ``(m, hidden_size)`` bf16.

    Raises:
        AssertionError: if ``fn.size(0)`` does not match ``hc_mult``.
        ValueError: if an environment override is invalid or incompatible.
    """
    m = residual.size(0)
    hc_mult = residual.size(1)
    hidden_size = residual.size(2)
    hc_mult3 = fn.size(0)
    assert hc_mult3 == hc_mult * 2 + hc_mult * hc_mult or (
        hc_mult3 == hc_mult and sinkhorn_repeat == 0
    )
    device = residual.device

    if m == 0:
        # With no rows the occupancy score below is 0 / ceil(0), which raises
        # ZeroDivisionError. There is nothing to compute, so return empty
        # outputs of the documented shapes without launching any kernel.
        return (
            torch.empty(m, hc_mult, 1, dtype=dtypes.fp32, device=device),
            torch.empty(m, hc_mult, hc_mult, dtype=dtypes.fp32, device=device),
            torch.empty(m, hidden_size, dtype=dtypes.bf16, device=device),
        )

    hc_hidden_size = hc_mult * hidden_size
    gfx = get_gfx()
    stage1_variant = os.environ.get("AITER_MHC_PRE_STAGE1", "auto").strip().lower()
    use_stage1_m128_auto = (
        sinkhorn_repeat > 0
        and hc_mult3 == hc_mult * (2 + hc_mult)
        and gfx != "gfx936"
        and not (hidden_size in (1280, 2560) and m <= 512)
    )
    if stage1_variant in ("", "auto"):
        use_stage1_m128 = use_stage1_m128_auto
    elif stage1_variant in ("aiter", "legacy"):
        use_stage1_m128 = False
    elif stage1_variant in ("m128", "tlstyle"):
        use_stage1_m128 = True
    else:
        # List every value the branches above actually accept.
        raise ValueError(
            "AITER_MHC_PRE_STAGE1 must be one of 'auto', 'aiter', 'legacy', "
            "'m128' or 'tlstyle' ('tlstyle' is an alias for 'm128')"
        )

    env_kernel = os.environ.get("AITER_MHC_PRE_KERNEL", "auto").strip().lower()
    use_tlstyle_auto = (
        sinkhorn_repeat > 0
        and hc_mult3 == hc_mult * (2 + hc_mult)
        and m > 128
        and not (hidden_size in (1280, 2560) and m <= 512)
    )
    if env_kernel in ("aiter", "legacy"):
        use_tlstyle = False
    elif env_kernel == "tlstyle":
        use_tlstyle = True
    else:
        # "", "auto" and any unrecognised value all fall back to the heuristic.
        use_tlstyle = use_tlstyle_auto

    prefetch_stages = 2
    tile_m = 128 if use_stage1_m128 else 16 * 4
    # tile_k -> estimated tg_per_cu (thread groups per CU, bounded by LDS/VGPR
    # occupancy):
    #   tile_k=64:  tile_n*64*4*2  = 16KB/block -> 4 blocks/CU
    #   tile_k=128: tile_n*128*4*2 = 32KB/block -> 2 blocks/CU
    tile_k_tg_dict = {128: 2} if use_stage1_m128 else {128: 2, 64: 4}
    num_cu = get_cu_num()
    selected_splitk = 1
    selected_tile_k = 128 if use_stage1_m128 else 64
    num_tg_m = (m + tile_m - 1) // tile_m

    # Data-driven split-k window:
    # - For small/medium M (num_tg_m < num_cu), keep broad search [1, 32].
    # - Once M-side TGs already cover all CUs (num_tg_m >= num_cu), prefer split-k=2.
    #   This avoids the large regression observed with split-k=1 on large batches.
    if num_tg_m >= num_cu:
        min_splitk = 2
        max_splitk = 2
    else:
        min_splitk = 1
        max_splitk = 32

    # Score = fraction of the last "wave" of thread groups that is occupied;
    # 1.0 means the launch fills the device evenly.
    selected_score = num_tg_m / (num_cu * tile_k_tg_dict[selected_tile_k])
    selected_score = selected_score / math.ceil(selected_score)
    for tile_k, tg_per_cu in tile_k_tg_dict.items():
        if (hc_hidden_size % tile_k) != 0:
            continue
        meanwhile_tg = num_cu * tg_per_cu
        for splitk in range(min_splitk, max_splitk + 1):
            if hc_hidden_size % (splitk * tile_k) != 0 or (hc_hidden_size // splitk) < (
                tile_k * prefetch_stages
            ):
                continue
            num_tg = num_tg_m * splitk
            score = num_tg / meanwhile_tg
            score = score / math.ceil(score)
            if selected_score < score:
                selected_splitk = splitk
                selected_tile_k = tile_k
                selected_score = score
            # Stop widening split-k once the launch exceeds four full waves.
            if num_tg > meanwhile_tg * 4:
                break

    # TileLang-style M128 stage1 still needs split-k parallelism when M-side
    # CTAs under-fill DCU. Once M-side CTAs already cover CUs, keep split_k low
    # to avoid excessive partial writes and stage2 reduction work.
    if use_stage1_m128 and hc_hidden_size in (4 * 4096, 4 * 7168):
        if num_tg_m >= num_cu:
            candidate_splitk = 2
        elif m >= 2048:
            candidate_splitk = 8
        else:
            candidate_splitk = 32
        if (
            hc_hidden_size % (candidate_splitk * selected_tile_k) == 0
            and (hc_hidden_size // candidate_splitk) >= selected_tile_k * prefetch_stages
        ):
            selected_splitk = candidate_splitk

    # Work-bound regime override:
    # When num_tg_m >= num_cu the splitk window is already forced to {2}, and both
    # (tile_k=64, splitk=2) and (tile_k=128, splitk=2) can land on score==1.0. The
    # strict `<` update in the loop above lets whichever is iterated first win.
    # Empirically on DCU gfx936/938 tile_k=64 is meaningfully faster in this regime
    # because it halves per-block LDS occupancy (tile_n*64*4*2 vs tile_n*128*4*2),
    # unlocking ~2x concurrent blocks per CU. Measured stage1 wins (auto vs forced
    # tile_k=64) up to ~40% at m=8192,hidden=7168 and consistent ~10% at m=8192
    # across hidden_size; large-m/large-hidden cases where auto already picks
    # tile_k=64 are unchanged.
    if not use_stage1_m128 and num_tg_m >= num_cu and selected_tile_k == 128:
        candidate_tile_k = 64
        candidate_splitk = 2
        if (
            hc_hidden_size % (candidate_splitk * candidate_tile_k) == 0
            and (hc_hidden_size // candidate_splitk)
            >= candidate_tile_k * prefetch_stages
        ):
            selected_tile_k = candidate_tile_k
            selected_splitk = candidate_splitk

    # Small/medium DeepSeek MHC stage1 override:
    # sweep data shows tile_k=64, splitk=32 wins for m<=1024 on hidden=4096/7168.
    # For m=2048 it only wins on hidden=7168; hidden=4096 regresses from extra split-k work.
    candidate_tile_k = 64
    candidate_splitk = 32
    if (
        not use_stage1_m128
        and hc_hidden_size in (4 * 4096, 4 * 7168)
        and (m <= 1024 or (m == 2048 and hc_hidden_size == 4 * 7168))
        and hc_hidden_size % (candidate_splitk * candidate_tile_k) == 0
        and (hc_hidden_size // candidate_splitk) >= candidate_tile_k * prefetch_stages
    ):
        selected_tile_k = candidate_tile_k
        selected_splitk = candidate_splitk

    # Optional manual overrides for stage1 launch search:
    #   AITER_MHC_PRE_TILE_K=64|128
    #   AITER_MHC_PRE_SPLITK=<positive int>
    env_tile_k = os.environ.get("AITER_MHC_PRE_TILE_K", "").strip()
    if env_tile_k:
        forced_tile_k = int(env_tile_k)
        if forced_tile_k not in tile_k_tg_dict:
            msg = "AITER_MHC_PRE_TILE_K must be 128 when AITER_MHC_PRE_STAGE1=m128"
            if not use_stage1_m128:
                msg = "AITER_MHC_PRE_TILE_K must be 64 or 128"
            raise ValueError(msg)
        if (hc_hidden_size % forced_tile_k) != 0:
            raise ValueError(
                f"AITER_MHC_PRE_TILE_K={forced_tile_k} is incompatible with hc_hidden_size={hc_hidden_size}"
            )
        selected_tile_k = forced_tile_k

    env_splitk = os.environ.get("AITER_MHC_PRE_SPLITK", "").strip()
    if env_splitk:
        forced_splitk = int(env_splitk)
        if forced_splitk < 1:
            raise ValueError("AITER_MHC_PRE_SPLITK must be >= 1")
        if hc_hidden_size % (forced_splitk * selected_tile_k) != 0:
            raise ValueError(
                "AITER_MHC_PRE_SPLITK is incompatible with selected tile_k/hc_hidden_size"
            )
        if (hc_hidden_size // forced_splitk) < (selected_tile_k * prefetch_stages):
            raise ValueError(
                "AITER_MHC_PRE_SPLITK violates prefetch stage constraint for selected tile_k"
            )
        selected_splitk = forced_splitk

    # Stage-1 outputs: last dim is allocated padded to a multiple of 32 and
    # the kernel is handed a view of the first hc_mult3 columns.
    out_pad = torch.empty(
        selected_splitk, m, (hc_mult3 + 31) // 32 * 32, dtype=dtypes.fp32, device=device
    )
    out = out_pad[:, :, :hc_mult3]
    sqrsum = torch.empty(selected_splitk, m, dtype=dtypes.fp32, device=device)
    if use_stage1_m128:
        mhc_pre_gemm_sqrsum_stage1_m128(out, sqrsum, residual, fn, use_tf32)
    else:
        # The legacy kernel is always called with use_tf32=False; the TF32-like
        # bias is applied to fn on the Python side instead.
        stage1_fn = _round_to_tf32_like_tilekernels(fn) if use_tf32 else fn
        mhc_pre_gemm_sqrsum(out, sqrsum, residual, stage1_fn, selected_tile_k, False)

    # Optional path: reduce split-k outputs before big_fuse and run stage2 with n_splits=1.
    # Keep stage2 input layout compatible with kernel assumptions (3D + padded stride),
    # instead of passing compact 2D tensors from direct sum().
    # Enable explicitly via AITER_MHC_PRE_REDUCE_SPLITK=1|true|yes|on.
    # Current data shows the extra kernel cost outweighs the stage2 reduction win.
    use_reduce_splitk = selected_splitk > 1 and _truthy_env("AITER_MHC_PRE_REDUCE_SPLITK")
    if use_reduce_splitk:
        out_red_pad = torch.empty(
            1, m, (hc_mult3 + 31) // 32 * 32, dtype=dtypes.fp32, device=device
        )
        out_red = out_red_pad[:, :, :hc_mult3]
        sqrsum_red = torch.empty(1, m, dtype=dtypes.fp32, device=device)
        mhc_pre_reduce_splitk(out_red, sqrsum_red, out, sqrsum)
        out = out_red
        sqrsum = sqrsum_red

    post_mix = torch.empty(m, hc_mult, 1, dtype=dtypes.fp32, device=device)
    comb_mix = torch.empty(m, hc_mult, hc_mult, dtype=dtypes.fp32, device=device)
    layer_input = torch.empty(m, hidden_size, dtype=dtypes.bf16, device=device)
    big_fuse = mhc_pre_big_fuse_tlstyle if use_tlstyle else mhc_pre_big_fuse
    big_fuse(
        post_mix,
        comb_mix,
        layer_input,
        out,
        sqrsum,
        hc_scale,
        hc_base,
        residual,
        rms_eps,
        hc_pre_eps,
        hc_sinkhorn_eps,
        hc_post_mult_value,
        sinkhorn_repeat,
    )
    return post_mix, comb_mix, layer_input
# mHC "post" kernel stub.
# The Python body is empty (`...`); `compile_ops` presumably binds the
# implementation from "module_mhc" -- TODO confirm in jit.core.
# NOTE(review): no caller is visible in this file. The signature returns None,
# so `out` is presumably the destination buffer; `post_layer_mix` and
# `comb_res_mix` look like the post_mix/comb_mix tensors produced by mhc_pre --
# verify shapes and dtypes against the kernel source.
@compile_ops("module_mhc")
def mhc_post(
    out: Tensor,
    x: Tensor,
    residual: Tensor,
    post_layer_mix: Tensor,
    comb_res_mix: Tensor,
) -> None: ...
......@@ -8,7 +8,84 @@ from ..jit.core import (
compile_ops,
)
from .enum import ActivationType, Enum, QuantType
import os
import json
# Directory containing this module; the tuned SiLU configs are located
# relative to it.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_SILU_CFG_DIR = os.path.join(SCRIPT_DIR, "../moe_c_configs/silu_configs")
# Single JSON summary file read by _load_silu_summary().
_SILU_SUMMARY = os.path.join(_SILU_CFG_DIR, "silu_config_summary.json")
# Lazily populated by _load_silu_summary(); None means "not loaded yet".
_SILU_CASES_CACHE = None
_SILU_INDEX_BY_N_CACHE = None
def _load_silu_summary():
    """Return ``(cases, index_by_n)`` from the SiLU summary JSON, cached.

    The file is read at most once per process; later calls return the
    module-level caches. When the summary file does not exist, both results
    are empty dicts (and that empty result is cached as well).
    """
    global _SILU_CASES_CACHE, _SILU_INDEX_BY_N_CACHE
    already_loaded = (
        _SILU_CASES_CACHE is not None and _SILU_INDEX_BY_N_CACHE is not None
    )
    if not already_loaded:
        if os.path.exists(_SILU_SUMMARY):
            with open(_SILU_SUMMARY, "r", encoding="utf-8") as f:
                data = json.load(f)
            _SILU_CASES_CACHE = data.get("cases", {})
            _SILU_INDEX_BY_N_CACHE = data.get("index_by_n", {})
        else:
            _SILU_CASES_CACHE = {}
            _SILU_INDEX_BY_N_CACHE = {}
    return _SILU_CASES_CACHE, _SILU_INDEX_BY_N_CACHE
def load_silu_tune_config(M: int, N: int):
    """Look up the tuned ``(rows_per_block, vec_size)`` for an ``M x N`` problem.

    Resolution order:
      1. exact ``"M=<M>,N=<N>"`` entry in the summary's ``cases``;
      2. the entry with the same N whose M is closest, using the prebuilt
         ``index_by_n`` list (expected to be sorted by M);
      3. the same nearest-M search done by scanning ``cases`` (for summaries
         without ``index_by_n``);
      4. the default ``(1, 2)``.
    """
    N = int(N)
    cases, index_by_n = _load_silu_summary()

    # 1) Summary only: prefer an exact key hit.
    exact_key = f"M={M},N={N}"
    if exact_key in cases:
        hit = cases[exact_key]
        return hit["rows_per_block"], hit["vec_size"]

    # 2) Same N: pick the configuration minimising |M - M_i| (prebuilt index first).
    entries = index_by_n.get(str(N))
    if entries:
        m_list = [int(e["M"]) for e in entries]
        # Manual lower_bound (avoid importing bisect).
        lo, hi = 0, len(m_list)
        while lo < hi:
            mid = (lo + hi) // 2
            if m_list[mid] < M:
                lo = mid + 1
            else:
                hi = mid
        # Neighbours of the insertion point, upper one first so that it wins ties.
        neighbours = entries[max(lo - 1, 0) : lo + 1][::-1]
        best = min(neighbours, key=lambda e: abs(int(e["M"]) - M))
        return best["rows_per_block"], best["vec_size"]

    # Backward-compatible slow path when old summary has no index_by_n.
    same_n = [
        v
        for v in cases.values()
        if int(v.get("N", -1)) == N and int(v.get("M", -1)) >= 0
    ]
    closest = min(same_n, key=lambda v: abs(int(v["M"]) - M), default=None)
    if closest is not None:
        return closest["rows_per_block"], closest["vec_size"]

    # 3) Fallback defaults.
    return 1, 2
@compile_ops("module_moe_c_kernel")
......@@ -24,7 +101,8 @@ def moe_c_moe_gemm_marlin_w8a8(
num_tokens_post_pad: torch.Tensor,
top_k : int,
mode :int,
delta: int)-> torch.Tensor:
delta: int,
size_m: int)-> torch.Tensor:
"""
---------------------------------------------------------------
# MoE 场景下 8bit 量化的 GEMM 计算(Marlin 优化版)
......@@ -44,6 +122,30 @@ def moe_c_moe_gemm_marlin_w8a8(
pass
@compile_ops("module_moe_c_kernel")
def moe_c_moe_gemm_marlin_w8a8_tensorwise(
    input: torch.Tensor,
    b_qweight : torch.Tensor,
    output : torch.Tensor,
    a_scale: torch.Tensor,
    b_scale : torch.Tensor,
    topk_weights : Optional[torch.Tensor],
    sorted_token_ids: torch.Tensor,
    expert_ids : torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
    top_k : int,
    mode :int,
    delta: int,
    size_m: int)-> torch.Tensor:
    """
    Marlin W8A8 MoE GEMM with tensorwise weight scales.

    b_scale must contain one scale per expert and use shape (E, 1, 1).

    The Python body is an empty placeholder; the function is wrapped by
    ``compile_ops("module_moe_c_kernel")``.
    NOTE(review): presumably the decorator binds this signature to the compiled
    op of the same name -- confirm against ``compile_ops``.
    NOTE(review): the meaning of ``mode``, ``delta`` and ``size_m`` is not
    visible in this file -- document once confirmed.
    """
    pass
@compile_ops("module_moe_c_kernel")
def moe_c_moe_gemm_marlin_w4a8(
input: torch.Tensor,
......@@ -57,7 +159,8 @@ def moe_c_moe_gemm_marlin_w4a8(
num_tokens_post_pad: torch.Tensor,
top_k : int,
mode :int,
delta: int)-> torch.Tensor:
delta: int,
size_m: int)-> torch.Tensor:
"""
---------------------------------------------------------------
# MoE 场景下 8bit 量化的 GEMM 计算(Marlin 优化版)
......@@ -91,7 +194,8 @@ def moe_c_moe_gemm_marlin_w8a8_fp8(
num_tokens_post_pad: torch.Tensor,
top_k : int,
mode :int,
delta: int)-> torch.Tensor:
delta: int,
size_m: int)-> torch.Tensor:
"""
---------------------------------------------------------------
# MoE 场景下 8bit 量化的 GEMM 计算(Marlin 优化版)
......@@ -110,6 +214,29 @@ def moe_c_moe_gemm_marlin_w8a8_fp8(
pass
@compile_ops("module_moe_c_kernel")
def moe_c_moe_gemm_marlin_w8a8_fp8_tensorwise(
    input: torch.Tensor,
    b_qweight : torch.Tensor,
    output : torch.Tensor,
    a_scale: torch.Tensor,
    b_scale : torch.Tensor,
    topk_weights : Optional[torch.Tensor],
    sorted_token_ids: torch.Tensor,
    expert_ids : torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
    top_k : int,
    mode :int,
    delta: int,
    size_m: int)-> torch.Tensor:
    """
    Marlin FP8 W8A8 MoE GEMM with tensorwise weight scales.

    b_scale must contain one scale per expert and use shape (E, 1, 1).

    The Python body is an empty placeholder; the function is wrapped by
    ``compile_ops("module_moe_c_kernel")``.
    NOTE(review): presumably the decorator binds this signature to the compiled
    op of the same name -- confirm against ``compile_ops``.
    NOTE(review): the meaning of ``mode``, ``delta`` and ``size_m`` is not
    visible in this file -- document once confirmed.
    """
    pass
@compile_ops("module_moe_c_kernel")
def moe_c_moe_gemm_marlin_w4a16(
input: torch.Tensor,
......@@ -132,6 +259,32 @@ def moe_c_moe_gemm_marlin_w4a16(
必须配合对应的权重 Shuffle 函数使用,否则会导致计算结果完全错误:
---------------------------------------------------------------
"""
pass
@compile_ops("module_moe_c_kernel")
def moe_c_moe_gemm_marlin_w8a16(
    input: torch.Tensor,
    b_qweight : torch.Tensor,
    output : torch.Tensor,
    b_scale: torch.Tensor,
    topk_weights : Optional[torch.Tensor],
    sorted_token_ids: torch.Tensor,
    expert_ids : torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
    top_k : int,
    mode :int,
    delta: int)-> torch.Tensor:
    """
    ---------------------------------------------------------------
    # 4-bit quantized GEMM for the MoE case (Marlin-optimized)
    ## Key precondition
    Must be used together with the matching weight shuffle function;
    otherwise the results will be completely wrong.
    ---------------------------------------------------------------
    NOTE(review): "4-bit" is carried over from the original docstring, but the
    function name says w8a16 -- this text may have been copied from the w4a16
    variant; confirm the weight bit width.
    """
......@@ -346,9 +499,15 @@ def moe_c_topk_softmax(
@compile_ops("module_moe_c_kernel")
def moe_c_silu_and_mul( out : torch.Tensor,
input : torch.Tensor) -> None:
input : torch.Tensor,
rows_per_block: int = 1,
vec_size: int = 2) -> None:
pass
@compile_ops("module_moe_c_kernel")
def moe_c_moe_sum(
input: torch.Tensor, # 移除 C++ 引用 &
......@@ -357,6 +516,13 @@ def moe_c_moe_sum(
) -> None:
pass
@compile_ops("module_moe_c_kernel")
def moe_c_moe_sum_opt_v2(input: torch.Tensor,output: torch.Tensor,
    routed_scaling_factor: float = 1.0) -> torch.Tensor:
    """Signature stub for the ``moe_sum_opt_v2`` op of ``module_moe_c_kernel``.

    The Python body is an empty placeholder wrapped by ``compile_ops``.
    NOTE(review): presumably reduces ``input`` into ``output`` and applies
    ``routed_scaling_factor`` -- confirm against the kernel implementation.
    NOTE(review): the annotation promises a ``torch.Tensor`` although an
    ``output`` buffer is also passed in -- confirm what is actually returned.
    """
    pass
@compile_ops("module_moe_c_kernel")
def moe_c_moe_align_block_size(
topk_ids: torch.Tensor,
......
......@@ -15,7 +15,7 @@ def moe_sorting_fwd(
sorted_expert_ids: torch.Tensor,
tokens_positions_per_expert: torch.Tensor,
num_valid_ids: torch.Tensor,
moe_buf: torch.Tensor,
moe_buf: Optional[torch.Tensor],
num_experts: int,
unit_size: int,
local_expert_mask: Optional[torch.Tensor] = None,
......
......@@ -413,3 +413,42 @@ def partial_transpose(
input: Tensor,
num_rows: Tensor,
) -> None: ...
@compile_ops("module_quant")
def moe_swiglu_dynamic_quant(
    scatter_tokens: torch.Tensor,
    smooth: torch.Tensor,
    experts_tokens_count: torch.Tensor,
    experts_tokens_start: torch.Tensor,
    output: torch.Tensor,
    scales: torch.Tensor,
    beta: float,
) -> None:
    """Signature stub for the ``moe_swiglu_dynamic_quant`` op of ``module_quant``.

    ``output`` and ``scales`` are caller-allocated buffers; the wrapper in this
    file allocates them as int8 ``(rows, cols // 2)`` and float32 ``(rows,)``.
    The Python body is a placeholder wrapped by ``compile_ops``.
    NOTE(review): presumably applies SwiGLU followed by per-row dynamic int8
    quantization -- confirm against the kernel implementation.
    """
    ...
def moe_swiglu_dynamic_quant_wrapper(
    scatter_tokens: torch.Tensor,
    smooth: torch.Tensor,
    experts_tokens_count: torch.Tensor,
    experts_tokens_start: torch.Tensor,
    beta: float = 1.0,
):
    """Allocate the result buffers and invoke ``moe_swiglu_dynamic_quant``.

    ``scatter_tokens`` must be 2-D ``(rows, cols)``.  Returns ``(output, scales)``
    where ``output`` is an int8 tensor of shape ``(rows, cols // 2)`` and
    ``scales`` is a float32 tensor of shape ``(rows,)``, both on the device of
    ``scatter_tokens`` and both filled by the op.
    """
    num_rows, packed_cols = scatter_tokens.shape
    target_device = scatter_tokens.device

    # The op writes half as many columns as it reads.
    output = torch.empty(
        (num_rows, packed_cols // 2), dtype=torch.int8, device=target_device
    )
    scales = torch.empty((num_rows,), dtype=torch.float32, device=target_device)

    op_args = (
        scatter_tokens,
        smooth,
        experts_tokens_count,
        experts_tokens_start,
        output,
        scales,
        beta,
    )
    moe_swiglu_dynamic_quant(*op_args)
    return output, scales
......@@ -120,3 +120,29 @@ def rmsnorm2d_fwd_with_add_dynamicquant(
weight: Tensor,
epsilon: float,
) -> None: ...
@compile_ops("module_rmsnorm", gen_fake=gen_rms_norm_fake_tensor)
def head_rms_norm(
    input: Tensor,  # [num_tokens, num_heads * head_dim]
    weight: Tensor,  # [num_heads * head_dim]
    epsilon: float,
    norm_head_dim: int,  # head_dim (size of each head's normalization window)
) -> Tensor:
    """
    Apply RMS normalization per head independently.

    Unlike standard rms_norm which normalizes over the entire last dimension,
    head_rms_norm normalizes each head's head_dim elements separately with
    its own weight parameters.

    The Python body is a placeholder; the function is wrapped by
    ``compile_ops("module_rmsnorm", ...)`` with ``gen_rms_norm_fake_tensor``
    passed as ``gen_fake``.

    Args:
        input: shape [num_tokens, num_heads * head_dim]
        weight: shape [num_heads * head_dim]
        epsilon: small value for numerical stability
        norm_head_dim: the dimension of each head (head_dim)

    Returns:
        Tensor with same shape as input
    """
    ...
......@@ -66,6 +66,39 @@ def _w8a8_marlin_weight_2(weight_input # [size_n, size_k// 2 ]
marlin_q_w = _marlin_weights_2(weight, k_tile=64, n_tile=16, pack_factor=8)
return marlin_q_w
#w8a16
def w8a16_marlin_weight_1(weight_input  # [E, size_n, size_k], one byte per element
                          ):
    """Shuffle a batch of 8-bit weight matrices into the tiled w8a16 layout.

    The input is treated as ``E`` matrices of ``size_n`` rows by ``size_k``
    bytes.  Bytes are moved in groups of four (one 32-bit word).  Viewing the
    words as ``(E, n//16, 16, k//128, 4, 8)`` -- 16-row groups, 128-byte column
    blocks, four 32-byte sub-blocks, eight words each -- the result stores them
    in the order ``(E, k//128, n//16, 4, 16, 8)``.

    Args:
        weight_input: 3-D tensor with a 1-byte dtype; ``size_n`` must be a
            multiple of 16 and ``size_k`` a multiple of 128.  Non-contiguous
            input is accepted (it is flattened with ``reshape``).

    Returns:
        A new contiguous ``torch.uint8`` tensor with the same shape as the input.

    Raises:
        RuntimeError: from ``view`` when the shape constraints are not met.
    """
    e, n, k = weight_input.shape
    # int32 is used purely as a 4-byte container: the bytes are only moved,
    # never interpreted, so the result is bit-identical to a uint32 view while
    # avoiding torch.uint32's limited operator support.
    words = weight_input.reshape(-1).view(torch.int32)
    tiles = words.view(e, n // 16, 16, k // 128, 4, 8)
    # One permutation replaces the former two transpose+contiguous passes:
    # (E, n16, row, k128, sub, word) -> (E, k128, n16, sub, row, word).
    shuffled = tiles.permute(0, 3, 1, 4, 2, 5).contiguous()
    return shuffled.view(-1).view(torch.uint8).view(e, n, k)
def w8a16_marlin_weight_2(weight_input  # [E, size_k, size_n], one byte per element
                          ):
    """Shuffle a batch of 8-bit weight matrices into the tiled w8a16 layout.

    The input is treated as ``E`` matrices whose axes this function names
    ``(k, n)``: ``k`` rows by ``n`` bytes.  Bytes are moved in groups of four
    (one 32-bit word).  Viewing the words as ``(E, k//16, 16, n//128, 4, 8)``
    -- 16-row groups, 128-byte column blocks, four 32-byte sub-blocks, eight
    words each -- the result stores them in the order
    ``(E, n//128, k//16, 4, 16, 8)``.

    Args:
        weight_input: 3-D tensor with a 1-byte dtype; the row count must be a
            multiple of 16 and the column count a multiple of 128.
            Non-contiguous input is accepted (it is flattened with ``reshape``).

    Returns:
        A new contiguous ``torch.uint8`` tensor with the same shape as the input.

    Raises:
        RuntimeError: from ``view`` when the shape constraints are not met.
    """
    e, k, n = weight_input.shape
    # int32 is used purely as a 4-byte container: the bytes are only moved,
    # never interpreted, so the result is bit-identical to a uint32 view while
    # avoiding torch.uint32's limited operator support.
    words = weight_input.reshape(-1).view(torch.int32)
    tiles = words.view(e, k // 16, 16, n // 128, 4, 8)
    # One permutation replaces the former two transpose+contiguous passes:
    # (E, k16, row, n128, sub, word) -> (E, n128, k16, sub, row, word).
    shuffled = tiles.permute(0, 3, 1, 4, 2, 5).contiguous()
    return shuffled.view(-1).view(torch.uint8).view(e, k, n)
def _marlin_weights(
......
# SPDX-License-Identifier: MIT
from .sparse_mla_fwd import tilelang_sparse_fwd, ref_sparse_mla_fwd_interface
from .mhc import hc_split_sinkhorn, mhc_fused_tilelang, mhc_post_fwd, mhc_pre_big_fuse, pre_big_fuse_tilelang
__all__ = ["tilelang_sparse_fwd", "ref_sparse_mla_fwd_interface"]
__all__ = [
"tilelang_sparse_fwd",
"ref_sparse_mla_fwd_interface",
"mhc_pre_big_fuse",
"pre_big_fuse_tilelang",
"mhc_post_fwd",
"hc_split_sinkhorn",
"mhc_fused_tilelang",
]
# SPDX-License-Identifier: MIT
from .hc_split_sinkhorn_kernel import hc_split_sinkhorn
from .mhc_fused_post_pre_kernel import mhc_fused_tilelang
from .post_kernel import mhc_post_fwd
from .pre_big_fuse import mhc_pre_big_fuse
from .pre_big_fuse_kernel import pre_big_fuse_tilelang
__all__ = ["mhc_pre_big_fuse", "pre_big_fuse_tilelang", "mhc_post_fwd", "hc_split_sinkhorn", "mhc_fused_tilelang"]
import tilelang
import torch
from tilelang import language as T
# Pass configuration handed to tilelang.jit by the Sinkhorn kernel factories in
# this module: warp specialization and TMA lowering are both disabled.
_PASS_CONFIGS = {
    tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
}
@tilelang.jit(pass_configs=_PASS_CONFIGS)
def _mhc_split_sinkhorn_fwd_orig(
    hc: int,
    sinkhorn_iters: int,
    eps: float,
    threads: int,
) -> tilelang.JITKernel:
    """Build the per-token kernel that splits ``mixes`` into pre/post/comb.

    One ``T.Kernel`` grid index handles one row of ``mixes`` (length
    ``(2 + hc) * hc``):
      * ``pre``  = sigmoid(first ``hc`` values * hc_scale[0] + base) + eps
      * ``post`` = 2 * sigmoid(next ``hc`` values * hc_scale[1] + base)
      * ``comb`` = the remaining ``hc * hc`` values, scaled by hc_scale[2] and
        biased, passed through a row softmax (+ eps), one column
        normalization, then ``sinkhorn_iters - 1`` rounds of row and column
        normalization.
    """
    n = T.dynamic('n')
    mix_hc = (2 + hc) * hc

    @T.prim_func
    def _mhc_split_sinkhorn_fwd_orig_kernel(
        mixes: T.Tensor[(n, mix_hc), T.float32],
        hc_scale: T.Tensor[(3,), T.float32],
        hc_base: T.Tensor[(mix_hc,), T.float32],
        pre: T.Tensor[(n, hc), T.float32],
        post: T.Tensor[(n, hc), T.float32],
        comb: T.Tensor[(n, hc, hc), T.float32],
    ) -> None:
        with T.Kernel(n, threads=threads) as i:
            mixes_shared = T.alloc_shared(mix_hc, T.float32)
            comb_frag = T.alloc_fragment((hc, hc), T.float32)
            row_sum = T.alloc_fragment(hc, T.float32)
            col_sum = T.alloc_fragment(hc, T.float32)
            row_max = T.alloc_fragment(hc, T.float32)
            # Stage this token's whole mix vector once.
            T.copy(mixes[i, :], mixes_shared)
            for j in T.Parallel(hc):
                pre[i, j] = T.sigmoid(mixes_shared[j] * hc_scale[0] + hc_base[j]) + eps
            for j in T.Parallel(hc):
                post[i, j] = 2 * T.sigmoid(mixes_shared[j + hc] * hc_scale[1] + hc_base[j + hc])
            # The hc x hc block starts after the 2 * hc pre/post entries.
            for j, k in T.Parallel(hc, hc):
                comb_frag[j, k] = (
                    mixes_shared[j * hc + k + hc * 2] * hc_scale[2]
                    + hc_base[j * hc + k + hc * 2]
                )
            # Row softmax: subtract the row max before exponentiating.
            T.reduce_max(comb_frag, row_max, dim=1)
            for j, k in T.Parallel(hc, hc):
                comb_frag[j, k] = T.exp(comb_frag[j, k] - row_max[j])
            T.reduce_sum(comb_frag, row_sum, dim=1)
            for j, k in T.Parallel(hc, hc):
                comb_frag[j, k] = comb_frag[j, k] / row_sum[j] + eps
            # First column normalization (completes iteration 1 of sinkhorn_iters).
            T.reduce_sum(comb_frag, col_sum, dim=0)
            for j, k in T.Parallel(hc, hc):
                comb_frag[j, k] = comb_frag[j, k] / (col_sum[k] + eps)
            # Remaining iterations alternate row and column normalization.
            for _ in T.serial(sinkhorn_iters - 1):
                T.reduce_sum(comb_frag, row_sum, dim=1)
                for j, k in T.Parallel(hc, hc):
                    comb_frag[j, k] = comb_frag[j, k] / (row_sum[j] + eps)
                T.reduce_sum(comb_frag, col_sum, dim=0)
                for j, k in T.Parallel(hc, hc):
                    comb_frag[j, k] = comb_frag[j, k] / (col_sum[k] + eps)
            T.copy(comb_frag, comb[i, :, :])

    return _mhc_split_sinkhorn_fwd_orig_kernel
@tilelang.jit(pass_configs=_PASS_CONFIGS)
def _mhc_split_sinkhorn_fwd(
    hc: int,
    sinkhorn_iters: int,
    eps: float,
    token_block_size: int,
    threads: int,
) -> tilelang.JITKernel:
    """Build the token-blocked variant of the split + Sinkhorn kernel.

    Same arithmetic as the per-token variant, but each ``T.Kernel`` grid index
    handles ``token_block_size`` consecutive rows of ``mixes``.  All stores to
    ``pre``, ``post`` and ``comb`` are guarded with ``idx < n`` so a partial
    last block does not write past the end.
    """
    n = T.dynamic('n')
    mix_hc = (2 + hc) * hc

    @T.prim_func
    def _mhc_split_sinkhorn_fwd_kernel(
        mixes: T.Tensor[(n, mix_hc), T.float32],
        hc_scale: T.Tensor[(3,), T.float32],
        hc_base: T.Tensor[(mix_hc,), T.float32],
        pre: T.Tensor[(n, hc), T.float32],
        post: T.Tensor[(n, hc), T.float32],
        comb: T.Tensor[(n, hc, hc), T.float32],
    ) -> None:
        with T.Kernel(T.ceildiv(n, token_block_size), threads=threads) as pid_x:
            mixes_shared = T.alloc_shared((token_block_size, mix_hc), T.float32)
            comb_frag = T.alloc_fragment((token_block_size, hc, hc), T.float32)
            row_sum = T.alloc_fragment((token_block_size, hc), T.float32)
            col_sum = T.alloc_fragment((token_block_size, hc), T.float32)
            row_max = T.alloc_fragment((token_block_size, hc), T.float32)
            # Stage this block's rows of `mixes`.
            # NOTE(review): no explicit tail guard on this load -- presumably
            # T.copy handles the partial last block; confirm.
            T.copy(mixes[pid_x * token_block_size, 0], mixes_shared)
            for i, j in T.Parallel(token_block_size, hc):
                idx = pid_x * token_block_size + i
                if idx < n:
                    pre[idx, j] = T.sigmoid(mixes_shared[i, j] * hc_scale[0] + hc_base[j]) + eps
            for i, j in T.Parallel(token_block_size, hc):
                idx = pid_x * token_block_size + i
                if idx < n:
                    post[idx, j] = 2 * T.sigmoid(
                        mixes_shared[i, j + hc] * hc_scale[1] + hc_base[j + hc]
                    )
            # The hc x hc block starts after the 2 * hc pre/post entries.
            for i, j, k in T.Parallel(token_block_size, hc, hc):
                comb_frag[i, j, k] = (
                    mixes_shared[i, j * hc + k + hc * 2] * hc_scale[2]
                    + hc_base[j * hc + k + hc * 2]
                )
            # Row softmax per token: subtract the row max before exponentiating.
            T.reduce_max(comb_frag, row_max, dim=2)
            for i, j, k in T.Parallel(token_block_size, hc, hc):
                comb_frag[i, j, k] = T.exp(comb_frag[i, j, k] - row_max[i, j])
            T.reduce_sum(comb_frag, row_sum, dim=2)
            for i, j, k in T.Parallel(token_block_size, hc, hc):
                comb_frag[i, j, k] = comb_frag[i, j, k] / row_sum[i, j] + eps
            # First column normalization (completes iteration 1 of sinkhorn_iters).
            T.reduce_sum(comb_frag, col_sum, dim=1)
            for i, j, k in T.Parallel(token_block_size, hc, hc):
                comb_frag[i, j, k] = comb_frag[i, j, k] / (col_sum[i, k] + eps)
            # Remaining iterations alternate row and column normalization.
            for _ in T.serial(sinkhorn_iters - 1):
                T.reduce_sum(comb_frag, row_sum, dim=2)
                for i, j, k in T.Parallel(token_block_size, hc, hc):
                    comb_frag[i, j, k] = comb_frag[i, j, k] / (row_sum[i, j] + eps)
                T.reduce_sum(comb_frag, col_sum, dim=1)
                for i, j, k in T.Parallel(token_block_size, hc, hc):
                    comb_frag[i, j, k] = comb_frag[i, j, k] / (col_sum[i, k] + eps)
            for i, j, k in T.Parallel(token_block_size, hc, hc):
                idx = pid_x * token_block_size + i
                if idx < n:
                    comb[idx, j, k] = comb_frag[i, j, k]

    return _mhc_split_sinkhorn_fwd_kernel
def mhc_split_sinkhorn(
    mixes: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    hc_mult: int = 4,
    sinkhorn_iters: int = 20,
    eps: float = 1e-6,
    token_block_size: int = 32,
    threads: int = 128,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Split ``mixes`` into ``(pre, post, comb)`` using a TileLang kernel.

    ``mixes`` is 3-D ``(b, s, (2 + hc_mult) * hc_mult)``.  The outputs are
    allocated with ``mixes.new_empty`` as ``(b, s, hc_mult)``,
    ``(b, s, hc_mult)`` and ``(b, s, hc_mult, hc_mult)`` and filled in place
    through flattened views.

    The per-token kernel is used while the token count is below
    ``threads * token_block_size // 4``; otherwise the token-blocked kernel is
    used.
    """
    batch, seq_len, _ = mixes.size()
    num_tokens = batch * seq_len
    mix_width = (2 + hc_mult) * hc_mult

    pre = mixes.new_empty(batch, seq_len, hc_mult)
    post = mixes.new_empty(batch, seq_len, hc_mult)
    comb = mixes.new_empty(batch, seq_len, hc_mult, hc_mult)

    few_tokens = threads * token_block_size // 4 > num_tokens
    if few_tokens:
        kernel = _mhc_split_sinkhorn_fwd_orig(hc_mult, sinkhorn_iters, eps, threads)
    else:
        kernel = _mhc_split_sinkhorn_fwd(hc_mult, sinkhorn_iters, eps, token_block_size, threads)

    # The kernels operate on a flat token axis; the views alias the outputs.
    flat_mixes = mixes.contiguous().view(-1, mix_width)
    kernel(
        flat_mixes,
        hc_scale.contiguous(),
        hc_base.contiguous(),
        pre.view(-1, hc_mult),
        post.view(-1, hc_mult),
        comb.view(-1, hc_mult, hc_mult),
    )
    return pre, post, comb
# public alias
def hc_split_sinkhorn(*args, **kwargs):
    """Public alias: forward every argument to ``mhc_split_sinkhorn`` unchanged."""
    outputs = mhc_split_sinkhorn(*args, **kwargs)
    return outputs
import math
import tilelang
from tilelang import language as T
@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10,
    },
)
def mhc_post_tilelang(
    a,
    b,
    c,
    d,
    x,
    mhc: int,
    hidden: int,
    n_thr: int = 128,
    h_blk: int = 1024,
) -> tilelang.JITKernel:
    """Per-token post-mix: ``x[n, o, h] = c[n, o] * d[n, h] + sum_i a[n, i, o] * b[n, i, h]``.

    One ``T.Kernel`` grid index handles one token and walks the hidden axis in
    chunks of ``gcd(hidden, h_blk)`` elements.  Inputs are staged through
    shared buffers, the arithmetic is done in float32, and the result is
    written back as bfloat16.

    NOTE(review): there is no ``T.prim_func`` and no ``return`` here; the
    tensor arguments are declared via in-body annotations -- presumably this
    relies on ``tilelang.jit`` tracing the function body directly; confirm.
    """
    n = T.dynamic("num_tokens")
    h = hidden
    # Shrink the chunk so that it divides `hidden` exactly.
    h_blk = math.gcd(hidden, h_blk)
    a: T.Tensor((n, mhc, mhc), T.float32)  # type: ignore[no-redef, valid-type]
    b: T.Tensor((n, mhc, h), T.bfloat16)  # type: ignore[no-redef, valid-type]
    c: T.Tensor((n, mhc), T.float32)  # type: ignore[no-redef, valid-type]
    d: T.Tensor((n, h), T.bfloat16)  # type: ignore[no-redef, valid-type]
    x: T.Tensor((n, mhc, h), T.bfloat16)  # type: ignore[no-redef, valid-type]
    with T.Kernel(n, threads=n_thr) as i_n:
        x_shared = T.alloc_shared((mhc, h_blk), T.bfloat16)
        b_shared = T.alloc_shared((mhc, h_blk), T.bfloat16)
        d_shared = T.alloc_shared(h_blk, T.bfloat16)
        x_local = T.alloc_fragment((mhc, h_blk), T.float32)
        b_local = T.alloc_fragment((mhc, h_blk), T.float32)
        d_local = T.alloc_fragment(h_blk, T.float32)
        a_local = T.alloc_fragment((mhc, mhc), T.float32)
        c_local = T.alloc_fragment(mhc, T.float32)
        # Per-token mixing coefficients are loaded once and reused for every chunk.
        T.copy(a[i_n, 0, 0], a_local)
        T.copy(c[i_n, 0], c_local)
        for i0_h in T.Pipelined(T.ceildiv(h, h_blk), num_stages=2):
            T.copy(b[i_n, 0, i0_h * h_blk], b_shared)
            T.copy(d[i_n, i0_h * h_blk], d_shared)
            T.copy(b_shared, b_local)
            T.copy(d_shared, d_local)
            for i_hco, i1_h in T.Parallel(mhc, h_blk):
                x_local[i_hco, i1_h] = c_local[i_hco] * d_local[i1_h]
                # `a` is indexed [input stream, output stream].
                for i_hci in T.serial(mhc):
                    x_local[i_hco, i1_h] += a_local[i_hci, i_hco] * b_local[i_hci, i1_h]
            T.copy(x_local, x_shared)
            T.copy(x_shared, x[i_n, 0, i0_h * h_blk])
@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10,
    },
)
def mhc_fused_tilelang(
    comb_mix,
    residual_in,
    post_mix,
    x_in,
    weight_t,
    yp_out,
    rp_out,
    residual_out,
    mhc: int,
    hidden: int,
    n_out: int,
    n_thr: int = 256,
    h_blk: int = 256,
    tile_n: int = 1,
    split_k: int = 1,
) -> tilelang.JITKernel:
    """Fused post-mix, projection partial sums and squared-norm partial sums.

    For token ``n`` and hidden index ``h`` the new residual is
    ``new_r[j] = post_mix[n, j] * x_in[n, h] + sum_k comb_mix[n, k, j] * residual_in[n, k, h]``.
    Each grid index ``(token, output tile, k-split)`` then produces:
      * ``residual_out[n, j, h] = new_r[j]``          (only when ``i_nt == 0``)
      * ``rp_out[i_ks, n]   = sum over j, h of new_r[j] ** 2``   (only when ``i_nt == 0``)
      * ``yp_out[i_ks, n, col] = sum over j, h of weight_t[col, j, h] * new_r[j]``
    where the sums run over this split's slice of the hidden axis.  Partial
    results of different splits land in different ``i_ks`` rows; combining them
    is left to the caller.

    NOTE(review): the ``split_k`` argument is immediately replaced by a dynamic
    symbol, and ``h_blk`` is computed but not used in the kernel body.
    NOTE(review): ``h_iters`` and ``n_tiles`` use floor division -- this assumes
    the per-split hidden size is a multiple of ``n_thr`` and ``n_out`` a
    multiple of ``tile_n``; confirm in the caller.
    NOTE(review): there is no ``T.prim_func`` and no ``return`` here --
    presumably ``tilelang.jit`` traces the function body directly; confirm.
    """
    m = T.dynamic("num_tokens")
    split_k = T.dynamic("split_k")
    h = hidden
    h_blk = math.gcd(hidden, h_blk)
    h_per_split = h // split_k
    n_tiles = n_out // tile_n
    h_iters = h_per_split // n_thr
    # Warp width is fixed at 64 lanes here.
    warp_size = 64
    num_warps = n_thr // warp_size
    comb_mix: T.Tensor((m, mhc, mhc), T.float32)  # type: ignore[no-redef, valid-type]
    residual_in: T.Tensor((m, mhc, h), T.bfloat16)  # type: ignore[no-redef, valid-type]
    post_mix: T.Tensor((m, mhc), T.float32)  # type: ignore[no-redef, valid-type]
    x_in: T.Tensor((m, h), T.bfloat16)  # type: ignore[no-redef, valid-type]
    weight_t: T.Tensor((n_out, mhc, h), T.float32)  # type: ignore[no-redef, valid-type]
    yp_out: T.Tensor((split_k, m, n_out), T.float32)  # type: ignore[no-redef, valid-type]
    rp_out: T.Tensor((split_k, m), T.float32)  # type: ignore[no-redef, valid-type]
    residual_out: T.Tensor((m, mhc, h), T.bfloat16)  # type: ignore[no-redef, valid-type]
    with T.Kernel(m, n_tiles, split_k, threads=n_thr) as (i_n, i_nt, i_ks):
        tid = T.get_thread_binding()
        # warp_id = tid // warp_size
        # lane = tid % warp_size
        warp_id = T.get_warp_idx()
        lane = T.get_lane_idx()
        h_split_start = i_ks * h_per_split
        # One row per warp: tile_n projection partials plus one squared-norm slot.
        s_warp = T.alloc_shared((num_warps, tile_n + 1), T.float32)
        s_post = T.alloc_shared((mhc,), T.float32)
        s_comb = T.alloc_shared((mhc, mhc), T.float32)
        pm = T.alloc_local((mhc,), T.float32)
        cm = T.alloc_local((mhc, mhc), T.float32)
        acc = T.alloc_local((tile_n,), T.float32)
        sqr = T.alloc_local((1,), T.float32)
        new_r = T.alloc_local((mhc,), T.float32)
        T.clear(acc)
        T.clear(sqr)
        # Stage the per-token coefficients in shared memory, then copy them
        # into per-thread locals.
        T.copy(post_mix[i_n, 0], s_post)
        T.copy(comb_mix[i_n, 0, 0], s_comb)
        for j in T.unroll(mhc):
            pm[j] = s_post[j]
        for j in T.unroll(mhc):
            for k in T.unroll(mhc):
                cm[k, j] = s_comb[k, j]
        # Each thread handles one hidden element per iteration, strided by n_thr.
        for it in T.serial(h_iters):
            h_idx = h_split_start + it * n_thr + tid
            for j in T.unroll(mhc):
                new_r[j] = pm[j] * x_in[i_n, h_idx]
                for k in T.unroll(mhc):
                    new_r[j] += cm[k, j] * residual_in[i_n, k, h_idx]
            # Only the first output tile writes the residual and the squared
            # norm, so they are produced exactly once per (token, split).
            if i_nt == 0:
                for j in T.unroll(mhc):
                    residual_out[i_n, j, h_idx] = new_r[j]
                    sqr[0] += new_r[j] * new_r[j]
            for n in T.unroll(tile_n):
                for j in T.unroll(mhc):
                    acc[n] += weight_t[i_nt * tile_n + n, j, h_idx] * new_r[j]
        # Stage 1: reduce within each warp.
        for n in T.unroll(tile_n):
            acc[n] = T.warp_reduce_sum(acc[n])
        if i_nt == 0:
            sqr[0] = T.warp_reduce_sum(sqr[0])
        if lane == 0:
            for n in T.unroll(tile_n):
                s_warp[warp_id, n] = acc[n]
            if i_nt == 0:
                s_warp[warp_id, tile_n] = sqr[0]
        T.sync_threads()
        # Stage 2: warp 0 sums the per-warp partials; lane `l` owns output column `l`.
        if warp_id == 0:
            if lane < tile_n:
                v = T.alloc_var(T.float32, init=0.0)
                for w in T.unroll(num_warps):
                    v += s_warp[w, lane]
                yp_out[i_ks, i_n, i_nt * tile_n + lane] = v
            if i_nt == 0 and lane == 0:
                v2 = T.alloc_var(T.float32, init=0.0)
                for w in T.unroll(num_warps):
                    v2 += s_warp[w, tile_n]
                rp_out[i_ks, i_n] = v2
import tilelang
import torch
from tilelang import language as T
# Pass configuration handed to tilelang.jit by the _mhc_pre_norm_fn_* kernel
# factories in this module: WGMMA is disabled, aggressive shared-memory merging
# and fast-math are enabled.
_PASS_CONFIGS = {
    tilelang.PassConfigKey.TL_DISABLE_WGMMA: True,
    tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE: True,
    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}
@tilelang.jit
def _mhc_fn_normw_merge_fwd(m: int, n: int, dtype: T.dtype = T.float32) -> tilelang.JITKernel:
    """Build a kernel computing ``out_fn[i, j] = fn[i, j] * normw[j]``.

    The grid is ``(m, ceil(n / 256))``; each grid index handles one row and one
    256-wide column block, with an ``i_n < n`` guard for the last block.
    """
    n_blk = 256

    @T.prim_func
    def _mhc_fn_normw_merge_fwd_(
        fn: T.Tensor[(m, n), dtype],
        normw: T.Tensor[n, dtype],
        out_fn: T.Tensor[(m, n), dtype],
    ) -> None:
        _ = dtype
        with T.Kernel(m, T.ceildiv(n, n_blk)) as (pid_m, pid_n):
            for i1_n in T.Parallel(n_blk):
                i_n = pid_n * n_blk + i1_n
                if i_n < n:
                    out_fn[pid_m, i_n] = fn[pid_m, i_n] * normw[i_n]

    return _mhc_fn_normw_merge_fwd_
@tilelang.jit
def _mhc_fn_normw_merge_bwd(m: int, n: int, dtype: T.dtype = T.float32) -> tilelang.JITKernel:
    """Build the backward kernel of ``out_fn = fn * normw``.

    Both gradients are *accumulated* into the existing buffers:
      * ``fn_grad[i, j]  += out_fn_grad[i, j] * normw[j]``
      * ``normw_grad[j]  += sum_i out_fn_grad[i, j] * fn[i, j]``
    Each grid index owns one 256-wide column block and loops serially over all
    ``m`` rows.
    """
    n_blk = 256

    @T.prim_func
    def _mhc_fn_normw_merge_bwd_(
        fn: T.Tensor[(m, n), dtype],
        normw: T.Tensor[n, dtype],
        out_fn_grad: T.Tensor[(m, n), dtype],
        fn_grad: T.Tensor[(m, n), dtype],
        normw_grad: T.Tensor[n, dtype],
    ) -> None:
        _ = dtype
        with T.Kernel(T.ceildiv(n, n_blk)) as pid_n:
            normw_frag = T.alloc_fragment(n_blk, dtype)
            T.copy(normw[pid_n * n_blk], normw_frag)
            # Column-wise accumulator for the weight gradient of this block.
            normw_grad_frag = T.alloc_fragment(n_blk, dtype)
            T.clear(normw_grad_frag)
            for i_m in T.serial(m):
                for i1_n in T.Parallel(n_blk):
                    i_n = pid_n * n_blk + i1_n
                    if i_n < n:
                        fn_grad[i_m, i_n] += out_fn_grad[i_m, i_n] * normw_frag[i1_n]
                        normw_grad_frag[i1_n] += out_fn_grad[i_m, i_n] * fn[i_m, i_n]
            # NOTE(review): unlike the loop above, this store has no `< n` guard;
            # confirm that n is a multiple of 256 or that the store is
            # bounds-checked by the compiler.
            for i1_n in T.Parallel(n_blk):
                normw_grad[pid_n * n_blk + i1_n] += normw_grad_frag[i1_n]

    return _mhc_fn_normw_merge_bwd_
@tilelang.jit(pass_configs=_PASS_CONFIGS)
def _mhc_pre_norm_fn_fwd_mul(
    mhc_mult3: int,
    n_rms_group: int,
    rms_group_size: int,
    token_block: int = 64,
    hidden_block: int = 256,
) -> tilelang.JITKernel:
    """Build the kernel computing ``x @ fn^T`` and ``sum(x ** 2)`` per RMS group.

    For every token block and RMS group the kernel walks the group's hidden
    slice in chunks of ``hidden_block`` and produces:
      * ``out[t, g, j]  = sum_h x[t, g*G + h] * fn[j, g*G + h]`` for ``j < mhc_mult3``
      * ``sqrsum[t, g]  = sum_h x[t, g*G + h] ** 2``
    with ``G = rms_group_size``.  The projection is padded to 32 output columns
    internally, hence the ``mhc_mult3 <= 32`` requirement.

    NOTE(review): the final stores to ``sqrsum`` and ``out`` carry no explicit
    token-tail guard -- presumably ``num_tokens`` is a multiple of
    ``token_block`` or the compiler bounds-checks them; confirm.
    """
    assert mhc_mult3 <= 32
    num_tokens = T.dynamic('num_tokens')
    assert rms_group_size % hidden_block == 0

    @T.prim_func
    def _mhc_pre_norm_fn_fwd_mul_kernel(
        x: T.Tensor[(num_tokens, n_rms_group * rms_group_size), T.bfloat16],
        fn: T.Tensor[(mhc_mult3, n_rms_group * rms_group_size), T.float32],
        out: T.Tensor[(num_tokens, n_rms_group, mhc_mult3), T.float32],
        sqrsum: T.Tensor[(num_tokens, n_rms_group), T.float32],
    ) -> None:
        _ = mhc_mult3
        with T.Kernel(T.ceildiv(num_tokens, token_block), n_rms_group, threads=256) as (pid_x, pid_y):
            out_frag = T.alloc_fragment((token_block, 32), T.float32)
            # Squares are accumulated in 16 partial columns and reduced once at the end.
            sqrsum_part = T.alloc_fragment((token_block, 16), T.float32)
            T.clear(out_frag)
            T.clear(sqrsum_part)
            for pz in T.Pipelined(rms_group_size // hidden_block, num_stages=0):
                x_frag_pre = T.alloc_fragment((token_block, hidden_block), T.bfloat16)
                fn_frag_pre = T.alloc_fragment((32, hidden_block), T.float32)
                x_frag_16 = T.alloc_fragment((token_block, hidden_block), T.bfloat16)
                x_frag = T.alloc_fragment((token_block, hidden_block), T.float32)
                fn_frag = T.alloc_fragment((32, hidden_block), T.float32)
                x_smem_16 = T.alloc_shared((token_block, hidden_block), T.bfloat16)
                fn_smem = T.alloc_shared((32, hidden_block), T.float32)
                T.annotate_layout({x_smem_16: tilelang.layout.make_hcu_swizzled_layout(x_smem_16, major_pack=2)})
                T.annotate_layout({fn_smem: tilelang.layout.make_hcu_swizzled_layout(fn_smem, major_pack=2)})
                # global -> fragment -> swizzled shared -> fragment, then widen x to float32.
                T.copy(x[pid_x * token_block, pid_y * rms_group_size + pz * hidden_block], x_frag_pre)
                T.copy(fn[0, pid_y * rms_group_size + pz * hidden_block], fn_frag_pre)
                T.copy(x_frag_pre, x_smem_16)
                T.copy(x_smem_16, x_frag_16)
                T.copy(x_frag_16, x_frag)
                T.copy(fn_frag_pre, fn_smem)
                T.copy(fn_smem, fn_frag)
                for jj in T.serial(hidden_block // 16):
                    for i, j in T.Parallel(token_block, 16):
                        sqrsum_part[i, j] += x_frag[i, jj * 16 + j] * x_frag[i, jj * 16 + j]
                # out_frag += x_frag @ fn_frag^T, accumulated across hidden chunks.
                T.gemm(
                    x_frag,
                    fn_frag,
                    out_frag,
                    transpose_A=False,
                    transpose_B=True,
                    clear_accum=False,
                    k_pack=2,
                    policy=T.GemmWarpPolicy.FullRow,
                    use_tf32=True,
                )
            sqrsum_l = T.alloc_fragment(token_block, T.float32)
            T.reduce_sum(sqrsum_part, sqrsum_l)
            out_shared = T.alloc_shared((token_block, 32), T.float32)
            T.annotate_layout({out_shared: tilelang.layout.make_hcu_swizzled_layout(out_shared, major_pack=2)})
            T.copy(out_frag, out_shared)
            for i in T.Parallel(token_block):
                sqrsum[pid_x * token_block + i, pid_y] = sqrsum_l[i]
            for i, j in T.Parallel(token_block, 32):
                # Only the first mhc_mult3 of the 32 padded columns exist in `out`.
                # (This bound was previously hard-coded as 24, which is only
                # correct when mhc_mult3 == 24.)
                if j < mhc_mult3:
                    out[pid_x * token_block + i, pid_y, j] = out_shared[i, j]

    return _mhc_pre_norm_fn_fwd_mul_kernel
@tilelang.jit(pass_configs=_PASS_CONFIGS)
def _mhc_pre_norm_fn_fwd_norm(
    mhc_mult3: int,
    n_rms_group: int,
    rms_group_size: int,
    rms_eps: float,
    n_splits: int,
) -> tilelang.JITKernel:
    """Build the kernel that merges split partials and applies the RMS scale.

    Per token and RMS group ``k`` it sums the ``n_splits`` partial results,
    stores the merged ``sqrsum`` and ``out_mul`` values, and accumulates
    ``out[t, j] = sum_k out_mul[t, k, j] * rsqrt(sqrsum[t, k] / rms_group_size + rms_eps)``.
    """
    num_tokens = T.dynamic('num_tokens')
    n_thr = 32

    @T.prim_func
    def _mhc_pre_norm_fn_fwd_norm_kernel(
        out_mul_splitted: T.Tensor[(n_splits, num_tokens, n_rms_group, mhc_mult3), T.float32],
        sqrsum_splitted: T.Tensor[(n_splits, num_tokens, n_rms_group), T.float32],
        out_mul: T.Tensor[(num_tokens, n_rms_group, mhc_mult3), T.float32],
        sqrsum: T.Tensor[(num_tokens, n_rms_group), T.float32],
        out: T.Tensor[(num_tokens, mhc_mult3), T.float32],
    ) -> None:
        with T.Kernel(num_tokens, threads=n_thr) as pid:
            rms = T.alloc_fragment(1, T.float32)
            out_l = T.alloc_fragment(mhc_mult3, T.float32)
            out_l0 = T.alloc_fragment(mhc_mult3, T.float32)
            T.clear(out_l)
            for k in T.serial(n_rms_group):
                # Merge the per-split squared sums for this group.
                rms[0] = 0
                for i_split in T.serial(n_splits):
                    rms[0] += sqrsum_splitted[i_split, pid, k]
                # A single thread stores the merged squared sum.
                if T.get_thread_binding() == 0:
                    sqrsum[pid, k] = rms[0]
                # Reuse the same slot for the reciprocal RMS.
                rms[0] = T.rsqrt(rms[0] / rms_group_size + rms_eps)
                for j in T.Parallel(mhc_mult3):
                    out_l0[j] = 0
                    for i_split in T.serial(n_splits):
                        out_l0[j] += out_mul_splitted[i_split, pid, k, j]
                    out_l[j] += out_l0[j] * rms[0]
                # The un-normalized merged product is stored alongside the output.
                T.copy(out_l0, out_mul[pid, k, :])
            T.copy(out_l[:], out[pid, :])

    return _mhc_pre_norm_fn_fwd_norm_kernel
@tilelang.jit(pass_configs=_PASS_CONFIGS)
def _mhc_pre_norm_fn_bwd_norm(
    mhc_mult3: int,
    n_rms_group: int,
    rms_group_size: int,
    rms_eps: float,
) -> tilelang.JITKernel:
    """Build the backward kernel of the RMS-scaling step.

    With ``rms = rsqrt(sqrsum / rms_group_size + rms_eps)`` and
    ``out[t, j] = sum_k out_mul[t, k, j] * rms[t, k]`` it writes, per token and group:
      * ``out_mul_grad[t, k, j] = out_grad[t, j] * rms``
      * ``sqrsum_grad[t, k] = (sum_j out_grad[t, j] * out_mul[t, k, j]) * rms
        / (sqrsum + rms_eps * rms_group_size) / -2``
    """
    num_tokens = T.dynamic('num_tokens')
    n_thr = 32

    @T.prim_func
    def _mhc_pre_norm_fn_bwd_norm_kernel(
        # Gradient of output
        out_grad: T.Tensor[(num_tokens, mhc_mult3), T.float32],
        # Saved inputs
        out_mul: T.Tensor[(num_tokens, n_rms_group, mhc_mult3), T.float32],
        sqrsum: T.Tensor[(num_tokens, n_rms_group), T.float32],
        # Computed gradient of inputs
        out_mul_grad: T.Tensor[(num_tokens, n_rms_group, mhc_mult3), T.float32],
        sqrsum_grad: T.Tensor[(num_tokens, n_rms_group), T.float32],
    ) -> None:
        with T.Kernel(num_tokens, n_rms_group, threads=n_thr) as (pid_i, pid_k):
            sqrsum_frag = T.alloc_fragment(1, T.float32)
            sqrsum_frag[0] = sqrsum[pid_i, pid_k]
            # Recompute the forward reciprocal RMS from the saved squared sum.
            rms_frag = T.alloc_fragment(1, T.float32)
            rms_frag[0] = T.rsqrt(sqrsum_frag[0] / rms_group_size + rms_eps)
            # Reducer collecting d(out)/d(rms) across the parallel j loop.
            rms_grad_frag = T.alloc_reducer(1, T.float32, replication='all')
            T.clear(rms_grad_frag)
            for j in T.Parallel(mhc_mult3):
                out_mul_grad[pid_i, pid_k, j] = out_grad[pid_i, j] * rms_frag[0]
                rms_grad_frag[0] += out_grad[pid_i, j] * out_mul[pid_i, pid_k, j]
            T.finalize_reducer(rms_grad_frag)
            # Chain rule through rsqrt: d(rms)/d(sqrsum) = -rms / (2 * (sqrsum + eps * group_size)).
            for kk in T.Parallel(1):
                sqrsum_grad[pid_i, pid_k + kk] = rms_grad_frag[kk] * rms_frag[kk] / (sqrsum_frag[kk] + rms_eps * rms_group_size) / -2

    return _mhc_pre_norm_fn_bwd_norm_kernel
@tilelang.jit(pass_configs=_PASS_CONFIGS)
def _mhc_pre_norm_fn_bwd_mul(
    mhc_mult3: int,
    n_rms_group: int,
    rms_group_size: int,
    token_block: int = 128,
    hidden_block: int = 128,
) -> tilelang.JITKernel:
    """Build the backward kernel of the ``x @ fn^T`` / ``sum(x ** 2)`` step.

    Each grid index owns one ``hidden_block``-wide slice of one RMS group and
    loops serially over all token blocks:
      * ``fn_grad`` slice  = ``out_mul_grad^T @ x`` accumulated over all tokens
        (written once at the end, overwriting the destination);
      * ``x_grad`` slice  += ``out_mul_grad @ fn + 2 * x * sqrsum_grad``
        (read-modify-write: the existing ``x_grad`` content is the starting value).
    The ``mhc_mult3`` axis is zero-padded to 32 for the matrix products.

    NOTE(review): token-block loads and stores carry no explicit tail guard --
    presumably ``num_tokens`` is a multiple of ``token_block`` or the compiler
    bounds-checks them; confirm.
    """
    assert mhc_mult3 <= 32
    num_tokens = T.dynamic('num_tokens')
    assert rms_group_size % hidden_block == 0

    @T.prim_func
    def _mhc_pre_norm_fn_bwd_mul_kernel(
        # Gradient of output
        out_mul_grad: T.Tensor[(num_tokens, n_rms_group, mhc_mult3), T.float32],
        sqrsum_grad: T.Tensor[(num_tokens, n_rms_group), T.float32],
        # Saved inputs
        x: T.Tensor[(num_tokens, n_rms_group * rms_group_size), T.bfloat16],
        fn: T.Tensor[(mhc_mult3, n_rms_group * rms_group_size), T.float32],
        # Computed gradient of inputs
        x_grad: T.Tensor[(num_tokens, n_rms_group * rms_group_size), T.bfloat16],
        fn_grad: T.Tensor[(mhc_mult3, n_rms_group * rms_group_size), T.float32],
    ) -> None:
        with T.Kernel(n_rms_group, T.ceildiv(rms_group_size, hidden_block)) as (pid_y, pid_z):
            # Column offset of this grid index's hidden slice.
            yz = pid_y * rms_group_size + pid_z * hidden_block
            # fn slice, zero-padded from mhc_mult3 to 32 rows.
            fn_smem = T.alloc_shared((32, hidden_block), T.float32)
            for i, j in T.Parallel(32, hidden_block):
                if i < mhc_mult3:
                    fn_smem[i, j] = fn[i, yz + j]
                else:
                    fn_smem[i, j] = 0
            fn_grad_frag = T.alloc_fragment((32, hidden_block), T.float32)
            T.fill(fn_grad_frag, 0)
            for px in T.serial(T.ceildiv(num_tokens, token_block)):
                x_smem = T.alloc_shared((token_block, hidden_block), T.float32)
                T.copy(x[px * token_block, yz], x_smem)
                # Upstream gradient, zero-padded from mhc_mult3 to 32 columns.
                padded_grad = T.alloc_shared((token_block, 32), T.float32)
                for i, j in T.Parallel(token_block, 32):
                    if j < mhc_mult3:
                        padded_grad[i, j] = out_mul_grad[px * token_block + i, pid_y, j]
                    else:
                        padded_grad[i, j] = 0
                # Start from the current x_grad content so the result accumulates into it.
                x_grad_frag = T.alloc_fragment((token_block, hidden_block), T.float32)
                T.copy(x_grad[px * token_block, yz], x_grad_frag)
                # fn_grad_frag += padded_grad^T @ x
                T.gemm(
                    padded_grad,
                    x_smem,
                    fn_grad_frag,
                    transpose_A=True,
                    transpose_B=False,
                    clear_accum=False,
                )
                # x_grad_frag += padded_grad @ fn
                T.gemm(
                    padded_grad,
                    fn_smem,
                    x_grad_frag,
                    transpose_A=False,
                    transpose_B=False,
                    clear_accum=False,
                )
                # Contribution through sqrsum = sum(x ** 2): d(sqrsum)/dx = 2 * x.
                sqrsum_grad_frag = T.alloc_fragment((token_block, 1), T.float32)
                T.copy(sqrsum_grad[px * token_block, pid_y], sqrsum_grad_frag)
                for i, j in T.Parallel(token_block, hidden_block):
                    x_grad_frag[i, j] += 2 * x_smem[i, j] * sqrsum_grad_frag[i, 0]
                T.copy(x_grad_frag, x_grad[px * token_block, yz])
            T.copy(fn_grad_frag, fn_grad[0, yz])

    return _mhc_pre_norm_fn_bwd_mul_kernel
def round_to_tf32(x: torch.Tensor) -> torch.Tensor:
    """Add ``0x1000`` to the raw 32-bit pattern of every element of ``x``.

    The tensor is reinterpreted as int32, the constant is added, and the result
    is reinterpreted as float32.  The low mantissa bits are *not* cleared.
    NOTE(review): presumably this is a half-step bias so that a later
    truncation to TF32 rounds to nearest -- confirm against the consumer.
    """
    raw_bits = x.view(torch.int32)
    biased_bits = raw_bits + 0x1000
    return biased_bits.view(torch.float32)
# SPDX-License-Identifier: MIT
import math
import tilelang
import torch
from tilelang import language as T
# Global guards for validating split-k stage0/stage1 kernels.
# Number of multiprocessors reported for the default "cuda" device.  This is
# evaluated at import time, so importing this module queries the device.
# NOTE(review): the use of cu_count is not visible in this part of the file --
# confirm how it relates to the "guards" mentioned above.
cu_count = torch.cuda.get_device_properties("cuda").multi_processor_count
@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10,
        tilelang.PassConfigKey.TL_DISABLE_VECTORIZE_256: True,
        tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE: True,
    },
)
def _mhc_post_fwd(mhc: int, hidden: int, n_thr: int = 128, h_blk: int = 1024) -> tilelang.JITKernel:
    """Build the post-mix kernel: one grid index per token, looping over hidden chunks.

    Computes ``x[n, o, h] = c[n, o] * d[n, h] + sum_i a[n, i, o] * b[n, i, h]``.
    The hidden axis is processed in chunks of ``gcd(hidden, h_blk)``; ``b`` is
    staged through a shared buffer, ``d`` is loaded straight into a fragment.
    """
    n = T.dynamic("num_tokens")
    h = hidden
    # Shrink the chunk so that it divides `hidden` exactly.
    h_blk = math.gcd(hidden, h_blk)

    @T.prim_func
    def _mhc_post_fwd_kernel(
        a: T.Tensor[(n, mhc, mhc), T.float32],
        b: T.Tensor[(n, mhc, h), T.bfloat16],
        c: T.Tensor[(n, mhc), T.float32],
        d: T.Tensor[(n, h), T.bfloat16],
        x: T.Tensor[(n, mhc, h), T.bfloat16],
    ) -> None:
        with T.Kernel(n, threads=n_thr) as pid_n:
            b_shared = T.alloc_shared((mhc, h_blk), T.bfloat16)
            x_local = T.alloc_fragment((mhc, h_blk), T.float32)
            b_local = T.alloc_fragment((mhc, h_blk), T.float32)
            d_local = T.alloc_fragment(h_blk, T.float32)
            a_local = T.alloc_fragment((mhc, mhc), T.float32)
            c_local = T.alloc_fragment(mhc, T.float32)
            # Per-token mixing coefficients are loaded once and reused for every chunk.
            T.copy(a[pid_n, 0, 0], a_local)
            T.copy(c[pid_n, 0], c_local)
            for i0_h in T.Pipelined(T.ceildiv(h, h_blk), num_stages=1):
                T.copy(b[pid_n, 0, i0_h * h_blk], b_shared, disable_tma=True)
                T.copy(b_shared, b_local)
                T.copy(d[pid_n, i0_h * h_blk], d_local, disable_tma=True)
                for i_mhco, i1_h in T.Parallel(mhc, h_blk):
                    x_local[i_mhco, i1_h] = c_local[i_mhco] * d_local[i1_h]
                    # `a` is indexed [input stream, output stream].
                    for i_mhci in T.serial(mhc):
                        x_local[i_mhco, i1_h] += a_local[i_mhci, i_mhco] * b_local[i_mhci, i1_h]
                T.copy(x_local, x[pid_n, 0, i0_h * h_blk], disable_tma=True, coalesced_width=8)

    return _mhc_post_fwd_kernel
@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10,
        tilelang.PassConfigKey.TL_DISABLE_VECTORIZE_256: True,
        tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE: True,
    },
)
def _mhc_post_fwd_split_h(
    mhc: int,
    hidden: int,
    n_thr: int = 128,
    h_blk: int = 1024,
) -> tilelang.JITKernel:
    """Build the post-mix kernel with the hidden axis split across the grid.

    Computes ``x[n, o, h] = c[n, o] * d[n, h] + sum_i a[n, i, o] * b[n, i, h]``.
    The grid is ``(tokens, ceil(hidden / h_blk))`` with
    ``h_blk = gcd(hidden, h_blk)``; each grid index handles exactly one hidden
    chunk of one token, so there is no loop over chunks.
    """
    n = T.dynamic("num_tokens")
    h = hidden
    # Shrink the chunk so that it divides `hidden` exactly.
    h_blk = math.gcd(hidden, h_blk)

    @T.prim_func
    def _mhc_post_fwd_split_h_kernel(
        a: T.Tensor[(n, mhc, mhc), T.float32],
        b: T.Tensor[(n, mhc, h), T.bfloat16],
        c: T.Tensor[(n, mhc), T.float32],
        d: T.Tensor[(n, h), T.bfloat16],
        x: T.Tensor[(n, mhc, h), T.bfloat16],
    ) -> None:
        with T.Kernel(n, T.ceildiv(h, h_blk), threads=n_thr) as (pid_n, pid_h):
            b_shared = T.alloc_shared((mhc, h_blk), T.bfloat16)
            x_local = T.alloc_fragment((mhc, h_blk), T.float32)
            b_local = T.alloc_fragment((mhc, h_blk), T.float32)
            d_local = T.alloc_fragment(h_blk, T.float32)
            a_local = T.alloc_fragment((mhc, mhc), T.float32)
            c_local = T.alloc_fragment(mhc, T.float32)
            T.copy(a[pid_n, 0, 0], a_local)
            T.copy(c[pid_n, 0], c_local)
            # First hidden index of the chunk owned by this grid index.
            h_start = pid_h * h_blk
            T.copy(b[pid_n, 0, h_start], b_shared, disable_tma=True)
            T.copy(b_shared, b_local)
            T.copy(d[pid_n, h_start], d_local, disable_tma=True)
            for i_mhco, i1_h in T.Parallel(mhc, h_blk):
                x_local[i_mhco, i1_h] = c_local[i_mhco] * d_local[i1_h]
                # `a` is indexed [input stream, output stream].
                for i_mhci in T.serial(mhc):
                    x_local[i_mhco, i1_h] += a_local[i_mhci, i_mhco] * b_local[i_mhci, i1_h]
            T.copy(x_local, x[pid_n, 0, h_start], disable_tma=True, coalesced_width=8)

    return _mhc_post_fwd_split_h_kernel
@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10,
        tilelang.PassConfigKey.TL_DISABLE_VECTORIZE_256: True,
        tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE: True,
    },
)
def _mhc_post_fwd_wo_shmem(
    mhc: int,
    hidden: int,
    n_thr: int = 128,
    h_blk: int = 1024,
) -> tilelang.JITKernel:
    """Build the post-mix kernel variant that allocates no shared buffer.

    Computes ``x[n, o, h] = c[n, o] * d[n, h] + sum_i a[n, i, o] * b[n, i, h]``
    with one grid index per token and a loop over hidden chunks of
    ``gcd(hidden, h_blk)``.  Both ``b`` and ``d`` are copied directly into
    fragments.
    """
    n = T.dynamic("num_tokens")
    h = hidden
    # Shrink the chunk so that it divides `hidden` exactly.
    h_blk = math.gcd(hidden, h_blk)

    @T.prim_func
    def _mhc_post_fwd_wo_shmem_kernel(
        a: T.Tensor[(n, mhc, mhc), T.float32],
        b: T.Tensor[(n, mhc, h), T.bfloat16],
        c: T.Tensor[(n, mhc), T.float32],
        d: T.Tensor[(n, h), T.bfloat16],
        x: T.Tensor[(n, mhc, h), T.bfloat16],
    ) -> None:
        with T.Kernel(n, threads=n_thr) as pid_n:
            x_local = T.alloc_fragment((mhc, h_blk), T.float32)
            b_local = T.alloc_fragment((mhc, h_blk), T.float32)
            d_local = T.alloc_fragment(h_blk, T.float32)
            a_local = T.alloc_fragment((mhc, mhc), T.float32)
            c_local = T.alloc_fragment(mhc, T.float32)
            # Per-token mixing coefficients are loaded once and reused for every chunk.
            T.copy(a[pid_n, 0, 0], a_local)
            T.copy(c[pid_n, 0], c_local)
            for i0_h in T.Pipelined(T.ceildiv(h, h_blk), num_stages=0):
                T.copy(b[pid_n, 0, i0_h * h_blk], b_local, disable_tma=True)
                T.copy(d[pid_n, i0_h * h_blk], d_local, disable_tma=True)
                for i_mhco, i1_h in T.Parallel(mhc, h_blk):
                    x_local[i_mhco, i1_h] = c_local[i_mhco] * d_local[i1_h]
                    # `a` is indexed [input stream, output stream].
                    for i_mhci in T.serial(mhc):
                        x_local[i_mhco, i1_h] += a_local[i_mhci, i_mhco] * b_local[i_mhci, i1_h]
                T.copy(x_local, x[pid_n, 0, i0_h * h_blk], disable_tma=True, coalesced_width=8)

    return _mhc_post_fwd_wo_shmem_kernel
def mhc_post_fwd(
x: torch.Tensor,
residual: torch.Tensor,
post_layer_mix: torch.Tensor,
comb_res_mix: torch.Tensor,
out: torch.Tensor | None = None,
) -> torch.Tensor:
num_tokens, mhc, hidden = residual.shape
assert x.dtype == torch.bfloat16, f"x.dtype={x.dtype}"
assert residual.dtype == torch.bfloat16, f"residual.dtype={residual.dtype}"
assert post_layer_mix.dtype == torch.float32, f"post_layer_mix.dtype={post_layer_mix.dtype}"
assert comb_res_mix.dtype == torch.float32, f"comb_res_mix.dtype={comb_res_mix.dtype}"
assert x.shape == (num_tokens, hidden), f"x.shape={x.shape}"
assert post_layer_mix.shape == (num_tokens, mhc), f"post_layer_mix.shape={post_layer_mix.shape}"
assert comb_res_mix.shape == (num_tokens, mhc, mhc), f"comb_res_mix.shape={comb_res_mix.shape}"
residual = residual.contiguous()
assert x.is_contiguous()
assert post_layer_mix.is_contiguous()
assert comb_res_mix.is_contiguous()
if out is None:
out = torch.empty_like(residual)
n = num_tokens
h_tiles = math.gcd(hidden, 1024)
h_tiles = hidden // h_tiles
n_thr = 128
if n < cu_count * 2 and h_tiles > 1:
# increase cu num usage by adding h_split
kernel = _mhc_post_fwd_split_h(mhc, hidden, n_thr=n_thr)
elif n < cu_count * 2:
# use shared mem and stage pipeline
kernel = _mhc_post_fwd(mhc, hidden, n_thr=n_thr)
else:
# only use registers and no pipeline
kernel = _mhc_post_fwd_wo_shmem(mhc, hidden, n_thr=n_thr)
kernel(
comb_res_mix,
residual,
post_layer_mix,
x,
out,
)
return out
import functools
import math
from typing import NamedTuple
import tilelang
import torch
from tilelang import language as T
from .norm_fn_kernel import _mhc_pre_norm_fn_fwd_mul
from .pre_norm_fn_splitk_kernel import mhc_pre_gemm_sqrsum_splitk_kernel
from .pre_big_fuse_kernel import _mhc_pre_big_fuse
# Multiprocessor count of the "cuda" device, queried once at import time.
# It is passed to get_block_info() to pick the split-k / tiling heuristics.
cu_count = torch.cuda.get_device_properties("cuda").multi_processor_count
class PreBigFuseBlockInfo(NamedTuple):
    """Tiling decision returned by ``get_block_info``."""

    # Tokens processed per block.
    token_block: int
    # Width of one tile along the flattened (mhc * hidden) dimension.
    hidden_block: int
    # Number of hidden tiles: mhc_hidden_size // hidden_block.
    hidden_loop: int
    # Split-k factor chosen by the heuristic (kept even when split-k is rejected).
    n_splits_pre: int
    # True when the split-k path should be used.
    use_small_token_splitk: bool


@functools.lru_cache(maxsize=1024)
def get_block_info(num_tokens: int, mhc_hidden_size: int, cu_count: int) -> PreBigFuseBlockInfo:
    """Choose block sizes and the split-k factor for the pre-norm GEMM stage.

    The fewer token blocks there are relative to ``cu_count``, the larger the
    split-k factor. Split-k is only enabled when the factor is greater than 1,
    the token-block count is at most ``2 * cu_count`` and the factor evenly
    divides the number of hidden tiles. Otherwise the result falls back to
    ``token_block=64, hidden_block=128``.

    Args:
        num_tokens: Token count (the visible caller rounds it up to a multiple of 32).
        mhc_hidden_size: Size of the flattened ``mhc * hidden`` dimension.
        cu_count: Number of multiprocessors on the device.

    Returns:
        A ``PreBigFuseBlockInfo`` whose ``hidden_loop`` always matches the
        returned ``hidden_block``.
    """
    token_block = 128  # use 128 for better performance
    hidden_block = 128  # with hidden_block = 128, the occupancy is 2
    hidden_loop = mhc_hidden_size // hidden_block
    token_loop = (num_tokens + token_block - 1) // token_block

    if token_loop <= 2:
        # At most 256 tokens: always try 64 splits; only the token block shrinks.
        n_splits_pre = 64
        if num_tokens > 128:
            token_block = 128
        elif num_tokens > 64:
            token_block = 64
        else:
            token_block = 32
        if hidden_loop % n_splits_pre != 0:
            # Halve the hidden tile so that there are twice as many tiles to split.
            hidden_block = 64
            hidden_loop = mhc_hidden_size // hidden_block
    elif token_loop <= 4:
        n_splits_pre = 32
    elif token_loop <= cu_count // 8:
        n_splits_pre = 16
    elif token_loop <= cu_count * 0.75:
        # Also covers the former `token_loop <= cu_count // 4` branch, which
        # selected the same factor and is always within this bound.
        n_splits_pre = 8
    elif token_loop <= cu_count * 2:
        n_splits_pre = 4
    else:
        n_splits_pre = 1

    final_token_loop = (num_tokens + token_block - 1) // token_block
    use_small_token_splitk = (
        n_splits_pre > 1
        and final_token_loop <= cu_count * 2
        and hidden_loop > 0
        and hidden_loop % n_splits_pre == 0
    )
    if not use_small_token_splitk:
        token_block = 64
        hidden_block = 128
        # Fix: hidden_block may have been 64 above, so recompute hidden_loop to
        # keep it consistent with the block size that is actually returned.
        hidden_loop = mhc_hidden_size // hidden_block

    return PreBigFuseBlockInfo(
        token_block=token_block,
        hidden_block=hidden_block,
        hidden_loop=hidden_loop,
        n_splits_pre=n_splits_pre,
        use_small_token_splitk=use_small_token_splitk,
    )
@functools.lru_cache(maxsize=128)
def _round_to_tf32_kernel(n_elem: int) -> tilelang.JITKernel:
    """Return the rounding kernel built for ``n_elem`` elements, memoized per size."""
    compiled = _compile_round_to_tf32(n_elem)
    return compiled
@tilelang.jit  # inp, out both passed in; out_idx would mean only inp is passed and out is allocated inside the adapter
def _compile_round_to_tf32(n_elem: int) -> tilelang.JITKernel:
    """Bitcast float32 -> int32, add 0x1000, bitcast back (1D linear scan for coalescing).

    The kernel adds ``0x1000`` to the raw bit pattern of every element and
    stores the result as float32 again; the low mantissa bits are not masked
    here.

    Args:
        n_elem: Static number of float32 elements in the flat input and output.

    Returns:
        The ``T.prim_func`` below, wrapped by ``tilelang.jit``; it takes the flat
        input and a preallocated flat output.
    """
    # NOTE(review): 0x1000 is bit 12, presumably half a TF32 ulp so that a later
    # truncation to TF32 rounds to nearest - confirm with the consumer kernel.
    _TF32_ROUND_BITS = 0x1000
    _ROUND_TO_TF32_BLK_MAX = 2048
    # Largest tile <= 2048 that divides n_elem exactly, so no tail handling is needed.
    n_blk = math.gcd(_ROUND_TO_TF32_BLK_MAX, n_elem)

    @T.prim_func
    def _round_to_tf32_prim(
        inp: T.Tensor[(n_elem,), T.float32],
        out: T.Tensor[(n_elem,), T.float32],
    ) -> None:
        with T.Kernel(T.ceildiv(n_elem, n_blk)) as pid:
            input_frag = T.alloc_fragment((n_blk,), T.float32)
            output_frag = T.alloc_fragment((n_blk,), T.float32)
            T.copy(inp[pid * n_blk], input_frag)
            # int32 view of the fragment that was just loaded.
            input_int = T.view(input_frag, (n_blk,), T.int32)
            for t in T.Parallel(n_blk):
                # Integer add on the bit pattern, then reinterpret the bits as float32.
                input_int[t] += T.int32(_TF32_ROUND_BITS)
                output_frag[t] = T.reinterpret(input_int[t], T.float32)
            T.copy(output_frag, out[pid * n_blk])

    return _round_to_tf32_prim
def round_to_tf32(fn: torch.Tensor) -> torch.Tensor:
    """TF32 grid rounding via TileLang (flat numel; preserves original shape).

    Args:
        fn: Tensor to round. The kernel is declared for float32 data.

    Returns:
        A new contiguous tensor with the same shape and dtype as ``fn``. An
        empty input yields an empty output without launching a kernel.
    """
    ne = int(fn.numel())
    # Force a contiguous output. With the default preserve_format, empty_like
    # copies the strides of a dense non-contiguous input (e.g. a transpose);
    # flattening such a tensor returns a copy, so the kernel would write into
    # the copy and the tensor returned here would stay uninitialised.
    out = torch.empty_like(fn, memory_format=torch.contiguous_format)
    if ne == 0:
        # Nothing to round, and a zero-sized launch is not meaningful.
        return out
    # reshape() may copy a non-contiguous input, which is fine for a read-only
    # operand; view() on the output guarantees the kernel writes into `out`.
    _round_to_tf32_kernel(ne)(fn.reshape(ne), out.view(ne))
    return out
def mhc_pre_big_fuse(
    residual: torch.Tensor,
    fn: torch.Tensor,
    mhc_scale: torch.Tensor,
    mhc_base: torch.Tensor,
    rms_eps: float,
    mhc_pre_eps: float,
    mhc_sinkhorn_eps: float,
    mhc_post_mult_value: float,
    sinkhorn_repeat: int,
    n_splits: int = 16,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the pre-norm GEMM stage followed by the fused pre kernel.

    Stage 1 multiplies the flattened residual by ``fn`` and accumulates a
    per-token square sum, either with the split-k kernel or with the single-pass
    kernel, as decided by ``get_block_info``. Stage 2 (``_mhc_pre_big_fuse``)
    consumes those results and fills the three outputs.

    Args:
        residual: bfloat16, shape ``(..., mhc_mult, hidden_size)``.
        fn: float32, shape ``(mhc_mult * (mhc_mult + 2), mhc_mult * hidden_size)``.
        mhc_scale: float32, shape ``(3,)``.
        mhc_base: float32, shape ``(mhc_mult * (mhc_mult + 2),)``.
        rms_eps: Forwarded to the fused kernel.
        mhc_pre_eps: Forwarded to the fused kernel.
        mhc_sinkhorn_eps: Forwarded to the fused kernel.
        mhc_post_mult_value: Forwarded to the fused kernel.
        sinkhorn_repeat: Forwarded to the fused kernel.
        n_splits: Accepted for interface compatibility; the value is always
            replaced by the split count chosen here (the heuristic's factor, or 1).

    Returns:
        ``(post_mix, comb_mix, layer_input)`` with shapes
        ``(..., mhc_mult, 1)``, ``(..., mhc_mult, mhc_mult)`` and
        ``(..., hidden_size)``.

    Raises:
        AssertionError: If a dtype or shape requirement is violated.
    """
    assert residual.dtype == torch.bfloat16
    assert fn.dtype == torch.float32
    assert mhc_scale.dtype == torch.float32
    assert mhc_base.dtype == torch.float32

    mhc_mult, hidden_size = residual.shape[-2], residual.shape[-1]
    mhc_mult2 = mhc_mult * mhc_mult
    mhc_mult3 = mhc_mult * 2 + mhc_mult2
    mhc_hidden_size = mhc_mult * hidden_size

    assert fn.shape[0] == mhc_mult3
    assert fn.shape[1] == mhc_hidden_size
    assert mhc_scale.shape == (3,)
    assert mhc_base.shape == (mhc_mult3,)

    # Collapse every leading dimension into a single token axis.
    outer_shape = residual.shape[:-2]
    residual_flat = residual.view(-1, mhc_mult, hidden_size)
    num_tokens = residual_flat.shape[0]

    device = residual.device
    f32 = {"dtype": torch.float32, "device": device}
    post_mix = torch.empty(num_tokens, mhc_mult, **f32)
    comb_mix = torch.empty(num_tokens, mhc_mult2, **f32)
    layer_input = torch.empty(num_tokens, hidden_size, dtype=torch.bfloat16, device=device)

    # Bucket by 32 so get_block_info cache keys align with common launch
    # granularity; the real buffers still use num_tokens.
    bucketed_tokens = (int(num_tokens) + 31) // 32 * 32
    plan = get_block_info(bucketed_tokens, mhc_hidden_size, cu_count)

    fn = round_to_tf32(fn)
    residual_2d = residual_flat.view(-1, mhc_hidden_size)

    if plan.use_small_token_splitk:
        n_splits = plan.n_splits_pre
        # Only stage 0 is launched; the per-split partials are handed to the
        # fused kernel, which reduces them over the split axis itself.
        stage0_kernel, _stage1_kernel = mhc_pre_gemm_sqrsum_splitk_kernel(
            mhc_mult3,
            mhc_hidden_size,
            split_k=n_splits,
            token_block=plan.token_block,
            hidden_block=plan.hidden_block,
        )
        gemm_out_mul = torch.empty(n_splits, num_tokens, mhc_mult3, **f32)
        gemm_out_sqrsum = torch.empty(n_splits, num_tokens, **f32)
        stage0_kernel(residual_2d, fn, gemm_out_mul, gemm_out_sqrsum)
    else:
        n_splits = 1
        gemm_out_mul = torch.empty(1, num_tokens, mhc_mult3, **f32)
        gemm_out_sqrsum = torch.empty(1, num_tokens, **f32)
        fwd_mul_kernel = _mhc_pre_norm_fn_fwd_mul(
            mhc_mult3,
            1,
            mhc_hidden_size,
            token_block=plan.token_block,
            hidden_block=plan.hidden_block,
        )
        fwd_mul_kernel(
            residual_2d,
            fn,
            gemm_out_mul.view(-1, 1, mhc_mult3),
            gemm_out_sqrsum.view(-1, 1),
        )

    fused_kernel = _mhc_pre_big_fuse(
        hidden_size,
        rms_eps,
        mhc_pre_eps,
        mhc_sinkhorn_eps,
        mhc_post_mult_value,
        sinkhorn_repeat,
        n_splits=n_splits,
        mhc_mult=mhc_mult,
    )
    fused_kernel(
        gemm_out_mul,
        gemm_out_sqrsum,
        mhc_scale,
        mhc_base,
        residual_flat,
        post_mix,
        comb_mix,
        layer_input,
    )

    # Restore the caller's leading dimensions.
    return (
        post_mix.view(*outer_shape, mhc_mult, 1),
        comb_mix.view(*outer_shape, mhc_mult, mhc_mult),
        layer_input.view(*outer_shape, hidden_size),
    )
import math
import tilelang
import torch
from tilelang import language as T
@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10,
        tilelang.PassConfigKey.TL_DISABLE_VECTORIZE_256: True,
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def _mhc_pre_big_fuse(
    hidden_size: int,
    rms_eps: float,
    mhc_pre_eps: float,
    mhc_sinkhorn_eps: float,
    mhc_post_mult_value: float,
    sinkhorn_repeat: int,
    n_splits: int = 16,
    mhc_mult: int = 4,
):
    """Build the fused per-token kernel that turns the GEMM results into the mixes.

    One program per token (128 threads) performs, in order:

    1. Sum the ``n_splits`` partial square sums and partial GEMM rows, derive
       ``rms = rsqrt(sqrsum / (mhc_mult * hidden_size) + rms_eps)`` and store
       the scaled mixes in a shared buffer.
    2. Threads with ``tx < 64`` write ``post_mix`` (a scaled sigmoid) and
       ``comb_mix`` (row softmax plus eps, a column normalization, then
       ``sinkhorn_repeat - 1`` further row/column normalizations).
    3. The remaining threads compute the pre mix (sigmoid plus ``mhc_pre_eps``)
       and write ``layer_input`` as the pre-mix weighted sum of the residual
       streams.

    Args:
        hidden_size: Static hidden size of one residual stream.
        rms_eps: Epsilon added inside the rsqrt.
        mhc_pre_eps: Added to the pre-mix sigmoid.
        mhc_sinkhorn_eps: Epsilon used by the softmax / Sinkhorn normalizations.
        mhc_post_mult_value: Multiplier applied to the post-mix sigmoid.
        sinkhorn_repeat: Total normalization rounds; the loop below runs
            ``sinkhorn_repeat - 1`` times after the initial round.
        n_splits: Size of the leading split axis of the two GEMM inputs.
        mhc_mult: Number of residual streams.

    Returns:
        The ``T.prim_func`` below, wrapped by ``tilelang.jit``.
    """
    num_tokens = T.dynamic('num_tokens')
    # Mix vector layout: mhc_mult pre + mhc_mult post + mhc_mult**2 comb entries.
    mhc_mult3 = mhc_mult * (2 + mhc_mult)
    # Tile width that divides hidden_size exactly (at most 512).
    hidden_block = math.gcd(512, hidden_size)

    @T.prim_func
    def mhc_pre_big_fuse(
        gemm_out_mul: T.Tensor[(n_splits, num_tokens, mhc_mult3), T.float32],
        gemm_out_sqrsum: T.Tensor[(n_splits, num_tokens), T.float32],
        mhc_scale: T.Tensor[(3,), T.float32],
        mhc_base: T.Tensor[(mhc_mult3,), T.float32],
        residual: T.Tensor[(num_tokens, mhc_mult, hidden_size), T.bfloat16],
        # outputs
        post_mix: T.Tensor[(num_tokens, mhc_mult), T.float32],
        comb_mix: T.Tensor[(num_tokens, mhc_mult * mhc_mult), T.float32],
        layer_input: T.Tensor[(num_tokens, hidden_size), T.bfloat16],
    ) -> None:
        threads = 128
        n_splits_aligned = tilelang.math.next_power_of_2(n_splits)
        if n_splits >= 4:
            # One group of splits per 32 threads; each group sums group_rows splits.
            split_groups = threads // 32
            # assert n_splits % split_groups == 0
            group_rows = n_splits // split_groups
        with T.Kernel(num_tokens, threads=threads) as pid:
            ##################################################################
            # _mhc_pre_norm_fn_fwd_norm
            tx = T.get_thread_binding()
            mixes_shared = T.alloc_shared(mhc_mult3, T.float32)
            rms = T.alloc_fragment(1, T.float32)
            if n_splits >= 4 and n_splits % split_groups == 0:
                # Many splits: reduce the partials in split_groups parallel groups.
                sqrsum = T.alloc_fragment(n_splits_aligned, T.float32)
                T.copy(gemm_out_sqrsum[:, pid], sqrsum)
                T.reduce_sum(sqrsum, rms)
                rms[0] = T.rsqrt(rms[0] / (mhc_mult * hidden_size) + rms_eps)
                # The column axis is a fixed 32 wide; columns >= mhc_mult3 stay at zero.
                mixes_pre = T.alloc_fragment((split_groups, 32), T.float32)
                mixes_aligned = T.alloc_fragment(32, T.float32)
                T.clear(mixes_pre)
                for r in T.serial(group_rows):
                    for i, j in T.Parallel(split_groups, 32):
                        if j < mhc_mult3:
                            mixes_pre[i, j] += gemm_out_mul[i * group_rows + r, pid, j]
                # Combine the per-group partial sums.
                T.reduce_sum(mixes_pre, mixes_aligned, dim=0)
                for i in T.Parallel(32):
                    if i < mhc_mult3:
                        mixes_shared[i] = mixes_aligned[i] * rms[0]
            elif n_splits >= 2:
                # Few splits (or a count not divisible by split_groups): serial sum per column.
                sqrsum = T.alloc_fragment(n_splits_aligned, T.float32)
                T.copy(gemm_out_sqrsum[:, pid], sqrsum)
                T.reduce_sum(sqrsum, rms)
                rms[0] = T.rsqrt(rms[0] / (mhc_mult * hidden_size) + rms_eps)
                mixes = T.alloc_fragment(mhc_mult3, T.float32)
                for j in T.Parallel(mhc_mult3):
                    mixes[j] = 0
                    for i in T.serial(n_splits):
                        mixes[j] += gemm_out_mul[i, pid, j]
                    mixes[j] *= rms[0]
                T.copy(mixes, mixes_shared, disable_tma=True)
            else:
                # Single split: no reduction needed.
                rms[0] = gemm_out_sqrsum[0, pid]
                rms[0] = T.rsqrt(rms[0] / (mhc_mult * hidden_size) + rms_eps)
                mixes = T.alloc_fragment(mhc_mult3, T.float32)
                for j in T.Parallel(mhc_mult3):
                    mixes[j] = gemm_out_mul[0, pid, j]
                    mixes[j] *= rms[0]
                T.copy(mixes, mixes_shared, disable_tma=True)
            if tx < 64:
                ##################################################################
                # _mhc_pre_split_mixes_fwd (post & comb)
                cm = T.alloc_fragment((mhc_mult, mhc_mult), T.float32)
                # Post mix reads entries [mhc_mult, 2 * mhc_mult) of the mix vector.
                for j in T.Parallel(mhc_mult):
                    post_mix[pid, j] = T.sigmoid(mixes_shared[j + mhc_mult] * mhc_scale[1] + mhc_base[j + mhc_mult]) * mhc_post_mult_value
                # Comb logits read entries from 2 * mhc_mult onwards, row-major.
                for j, k in T.Parallel(mhc_mult, mhc_mult):
                    cm[j, k] = mixes_shared[j * mhc_mult + k + mhc_mult * 2] * mhc_scale[2] + mhc_base[j * mhc_mult + k + mhc_mult * 2]
                ##################################################################
                # _mhc_sinkhorn_fwd
                row_sum = T.alloc_fragment(mhc_mult, T.float32)
                col_sum = T.alloc_fragment(mhc_mult, T.float32)
                # comb = comb.softmax(-1) + eps
                row_max = T.alloc_fragment(mhc_mult, T.float32)
                T.reduce_max(cm, row_max, dim=1)
                for j, k in T.Parallel(mhc_mult, mhc_mult):
                    cm[j, k] = T.exp(cm[j, k] - row_max[j])
                T.reduce_sum(cm, row_sum, dim=1)
                for j, k in T.Parallel(mhc_mult, mhc_mult):
                    cm[j, k] = cm[j, k] / row_sum[j] + mhc_sinkhorn_eps
                # comb = comb / (comb.sum(-2) + eps)
                T.reduce_sum(cm, col_sum, dim=0)
                for j, k in T.Parallel(mhc_mult, mhc_mult):
                    cm[j, k] = cm[j, k] / (col_sum[k] + mhc_sinkhorn_eps)
                # The softmax + column pass above is the first round; run the rest.
                for _ in T.serial(sinkhorn_repeat - 1):
                    # comb = comb / (comb.sum(-1) + eps)
                    T.reduce_sum(cm, row_sum, dim=1)
                    for j, k in T.Parallel(mhc_mult, mhc_mult):
                        cm[j, k] = cm[j, k] / (row_sum[j] + mhc_sinkhorn_eps)
                    # comb = comb / (comb.sum(-2) + eps)
                    T.reduce_sum(cm, col_sum, dim=0)
                    for j, k in T.Parallel(mhc_mult, mhc_mult):
                        cm[j, k] = cm[j, k] / (col_sum[k] + mhc_sinkhorn_eps)
                # save comb_mix to global memory
                for j, k in T.Parallel(mhc_mult, mhc_mult):
                    comb_mix[pid, j * mhc_mult + k] = cm[j, k]
            else:
                ##################################################################
                # _mhc_pre_split_mixes_fwd (pre)
                # Despite the name, this buffer is allocated as a fragment.
                pre_mix_shared = T.alloc_fragment(mhc_mult, T.float32)
                # Pre mix reads entries [0, mhc_mult) of the mix vector.
                for j in T.serial(mhc_mult):
                    pre_mix_shared[j] = (
                        T.sigmoid(
                            mixes_shared[j] * mhc_scale[0] + mhc_base[j],
                        )
                        + mhc_pre_eps
                    )
                ###################################################################
                # _mhc_pre_apply_mix_fwd
                for i0_h in T.Pipelined(hidden_size // hidden_block, num_stages=0):
                    # xs = T.alloc_shared((mhc_mult, hidden_block), T.bfloat16)
                    xl = T.alloc_fragment((mhc_mult, hidden_block), T.float32)
                    T.copy(residual[pid, 0, i0_h * hidden_block], xl, disable_tma=True)
                    # T.copy(xs, xl, disable_tma=True)
                    ol = T.alloc_fragment(hidden_block, T.float32)
                    T.clear(ol)
                    # layer_input tile = sum over streams of pre_mix[stream] * residual[stream].
                    for i_mhc in T.serial(mhc_mult):
                        pre = pre_mix_shared[i_mhc]
                        for i1_h in T.Parallel(hidden_block):
                            ol[i1_h] += pre * xl[i_mhc, i1_h]
                    T.copy(ol, layer_input[pid, i0_h * hidden_block], disable_tma=True)

    return mhc_pre_big_fuse
def pre_big_fuse_tilelang(
    gemm_out_mul: torch.Tensor,
    gemm_out_sqrsum: torch.Tensor,
    mhc_scale: torch.Tensor,
    mhc_base: torch.Tensor,
    residual: torch.Tensor,
    post_mix: torch.Tensor,
    comb_mix: torch.Tensor,
    layer_input: torch.Tensor,
    hidden_size: int,
    rms_eps: float,
    mhc_pre_eps: float,
    mhc_sinkhorn_eps: float,
    mhc_post_mult_value: float,
    sinkhorn_repeat: int,
    n_splits: int = 16,
    mhc_mult: int = 4,
) -> None:
    """Build the fused pre kernel for the given static settings and launch it.

    The first five tensors are inputs. ``post_mix``, ``comb_mix`` and
    ``layer_input`` are preallocated outputs that the kernel fills in place.
    The scalar arguments are forwarded unchanged to ``_mhc_pre_big_fuse``.
    """
    # Step 1: obtain the kernel specialised for these static parameters.
    fused_kernel = _mhc_pre_big_fuse(
        hidden_size,
        rms_eps,
        mhc_pre_eps,
        mhc_sinkhorn_eps,
        mhc_post_mult_value,
        sinkhorn_repeat,
        n_splits=n_splits,
        mhc_mult=mhc_mult,
    )
    # Step 2: launch it; the outputs are written in place.
    fused_kernel(
        gemm_out_mul,
        gemm_out_sqrsum,
        mhc_scale,
        mhc_base,
        residual,
        post_mix,
        comb_mix,
        layer_input,
    )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment