Unverified commit c6d7f8d3 authored by Lianmin Zheng, committed by GitHub

Add some fused elementwise kernels for grok-1 (#4398)


Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: Hanming Lu <69857889+hanming-lu@users.noreply.github.com>
parent a5a892ff
from typing import Tuple
import torch
import triton
import triton.language as tl
fused_softcap_autotune = triton.autotune(
configs=[
triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=8),
triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=16),
triton.Config(kwargs={"BLOCK_SIZE": 256}, num_warps=4),
triton.Config(kwargs={"BLOCK_SIZE": 256}, num_warps=8),
triton.Config(kwargs={"BLOCK_SIZE": 512}, num_warps=4),
triton.Config(kwargs={"BLOCK_SIZE": 512}, num_warps=8),
triton.Config(kwargs={"BLOCK_SIZE": 512}, num_warps=16),
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4),
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8),
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16),
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=32),
triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=32),
triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_warps=32),
triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32),
triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32),
triton.Config(kwargs={"BLOCK_SIZE": 32768}, num_warps=32),
],
key=["n_ele"],
)
@triton.jit
def fused_softcap_kernel(
output_ptr,
input_ptr,
n_ele,
softcap_const: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_ele
x = tl.load(input_ptr + offsets, mask=mask)
fx = x.to(tl.float32)
fxs = fx / softcap_const
exped = tl.exp(2 * fxs)
top = exped - 1
bottom = exped + 1
output = top / bottom * softcap_const
tl.store(output_ptr + offsets, output, mask=mask)
fused_softcap_kernel_autotuned = fused_softcap_autotune(fused_softcap_kernel)
def fused_softcap(x, softcap_const, autotune=False):
output = torch.empty_like(x, dtype=torch.float32)
n_elements = output.numel()
if autotune:
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
fused_softcap_kernel_autotuned[grid](output, x, n_elements, softcap_const)
else:
fused_softcap_kernel[(triton.cdiv(n_elements, 128),)](
output, x, n_elements, softcap_const, BLOCK_SIZE=128, num_warps=8
)
return output
# cast to float + softcap
class Softcap:
def __init__(self, softcap_const: float):
self.softcap_const = softcap_const
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.is_cuda:
return self.forward_cuda(x)
else:
return self.forward_native(x)
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
return torch.tanh(x.float() / self.softcap_const) * self.softcap_const
def forward_cuda(self, x: torch.Tensor, autotune=False) -> torch.Tensor:
return fused_softcap(x, self.softcap_const, autotune=autotune)
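# Hedged sanity-check sketch (not part of the original diff): the fused CUDA
# path and the eager native path should agree numerically, since
# (e^(2x/c) - 1) / (e^(2x/c) + 1) * c == tanh(x / c) * c. `check_softcap` is a
# hypothetical helper name; it assumes a CUDA device is available.
def check_softcap(softcap_const: float = 30.0) -> None:
    softcap = Softcap(softcap_const)
    x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
    out_fused = softcap.forward_cuda(x)  # Triton kernel, returns float32
    out_ref = softcap.forward_native(x)  # tanh(x / c) * c in eager PyTorch
    torch.testing.assert_close(out_fused, out_ref, rtol=1e-3, atol=1e-3)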
rmsnorm_autotune = triton.autotune(
configs=[
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4, num_stages=1),
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8, num_stages=1),
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16, num_stages=1),
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4),
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8),
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16),
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4, num_stages=4),
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8, num_stages=4),
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16, num_stages=4),
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8, num_stages=8),
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16, num_stages=8),
triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=8),
triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=16),
triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=8, num_stages=4),
triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=16, num_stages=4),
triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_warps=8),
triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_warps=16),
triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=8),
triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=16),
triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32),
triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=8, num_stages=1),
triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=16, num_stages=1),
triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32, num_stages=1),
triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=8, num_stages=4),
triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=16, num_stages=4),
triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32, num_stages=4),
triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=8),
triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=16),
triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32),
triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=8, num_stages=1),
triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=16, num_stages=1),
triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32, num_stages=1),
triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=8, num_stages=4),
triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=16, num_stages=4),
triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32, num_stages=4),
],
key=["hidden_dim"],
)
@triton.jit
def fused_dual_residual_rmsnorm_kernel(
output_ptr,
mid_ptr,
activ_ptr,
residual_ptr,
weight1_ptr,
weight2_ptr,
eps: tl.constexpr,
hidden_dim: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
input_start = pid * hidden_dim
offsets = tl.arange(0, BLOCK_SIZE)
mask = offsets < hidden_dim
a_ = tl.load(activ_ptr + input_start + offsets, mask=mask, other=0.0)
a = a_.to(tl.float32)
rms = tl.sqrt(tl.sum(a * a, axis=0) / hidden_dim + eps)
r = tl.load(residual_ptr + input_start + offsets, mask=mask, other=0.0)
w1_ = tl.load(weight1_ptr + offsets, mask=mask, other=0.0)
w1 = w1_.to(tl.float32)
a2r = r + (a / rms * w1).to(r.dtype)
tl.store(
mid_ptr + input_start + offsets,
a2r,
mask=mask,
)
a2r = a2r.to(tl.float32)
rms2 = tl.sqrt(tl.sum(a2r * a2r, axis=0) / hidden_dim + eps)
w2_ = tl.load(weight2_ptr + offsets, mask=mask, other=0.0)
w2 = w2_.to(tl.float32)
tl.store(
output_ptr + input_start + offsets,
a2r / rms2 * w2, # implicitly casts to output dtype here
mask=mask,
)
fused_dual_residual_rmsnorm_kernel_autotune = rmsnorm_autotune(
fused_dual_residual_rmsnorm_kernel
)
def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=False):
assert len(x.shape) == 2
assert x.shape == residual.shape and x.dtype == residual.dtype
output, mid = torch.empty_like(x), torch.empty_like(x)
bs, hidden_dim = x.shape
if autotune:
fused_dual_residual_rmsnorm_kernel_autotune[(bs,)](
output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim
)
else:
config = {
"BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
"num_warps": max(
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
),
}
fused_dual_residual_rmsnorm_kernel[(bs,)](
output,
mid,
x,
residual,
weight1,
weight2,
eps=eps,
hidden_dim=hidden_dim,
**config,
)
return output, mid
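# Hedged reference sketch (not part of the original diff): eager-mode equivalent
# of fused_dual_residual_rmsnorm, useful for testing. The helper name below is
# hypothetical; the math mirrors the kernel: mid = residual + rmsnorm(x, w1),
# output = rmsnorm(mid, w2), with reductions done in float32.
def dual_residual_rmsnorm_reference(x, residual, weight1, weight2, eps):
    def _rmsnorm(t, w):
        tf = t.float()
        rms = torch.sqrt(tf.pow(2).mean(dim=-1, keepdim=True) + eps)
        return (tf / rms * w.float()).to(t.dtype)

    mid = residual + _rmsnorm(x, weight1)
    output = _rmsnorm(mid, weight2)
    return output, mid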
@triton.jit
def fused_rmsnorm_kernel(
output_ptr,
activ_ptr,
weight_ptr,
eps: tl.constexpr,
hidden_dim: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
input_start = pid * hidden_dim
offsets = tl.arange(0, BLOCK_SIZE)
mask = offsets < hidden_dim
a_ = tl.load(activ_ptr + input_start + offsets, mask=mask, other=0.0)
a = a_.to(tl.float32)
rms = tl.sqrt(tl.sum(a * a, axis=0) / hidden_dim + eps)
w1_ = tl.load(weight_ptr + offsets, mask=mask, other=0.0)
w1 = w1_.to(tl.float32)
a_rms = a / rms * w1
tl.store(
output_ptr + input_start + offsets,
a_rms, # implicitly casts to output dtype here
mask=mask,
)
def fused_rmsnorm(x, weight, eps, autotune=False, inplace=False):
assert len(x.shape) == 2
if inplace:
output = x
else:
output = torch.empty_like(x)
bs, hidden_dim = x.shape
config = {
"BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
"num_warps": max(
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
),
}
fused_rmsnorm_kernel[(bs,)](
output, x, weight, eps=eps, hidden_dim=hidden_dim, **config
)
return output
class FusedDualResidualRMSNorm:
"""
Fused implementation of
        y = RMSNorm2(RMSNorm1(x) + residual)
"""
    def __init__(self, rmsnorm1, rmsnorm2) -> None:  # rmsnorm2 is the norm applied after rmsnorm1
self.rmsnorm1 = rmsnorm1
self.rmsnorm2 = rmsnorm2
self.variance_epsilon = self.rmsnorm1.variance_epsilon
assert self.rmsnorm1.variance_epsilon == self.rmsnorm2.variance_epsilon
assert self.rmsnorm1.weight.shape == self.rmsnorm2.weight.shape
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
def forward(
self, x: torch.Tensor, residual: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
if x.is_cuda:
return self.forward_cuda(x, residual)
else:
return self.forward_flashinfer(x, residual)
def forward_cuda(
self, x: torch.Tensor, residual: torch.Tensor, autotune=False
) -> Tuple[torch.Tensor, torch.Tensor]:
return fused_dual_residual_rmsnorm(
x,
residual,
self.rmsnorm1.weight,
self.rmsnorm2.weight,
self.variance_epsilon,
autotune=autotune,
)
def forward_flashinfer(
self,
x: torch.Tensor,
residual: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
normed1 = self.rmsnorm1(x)
residual = normed1 + residual
return self.rmsnorm2(residual), residual
def forward_native(
self,
x: torch.Tensor,
residual: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
normed1 = self.rmsnorm1.forward_native(x)
residual = normed1 + residual
return self.rmsnorm2.forward_native(residual), residual
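# Hedged usage sketch (not part of the original diff): FusedDualResidualRMSNorm
# only needs norm objects exposing `weight` and `variance_epsilon` (plus
# `forward_native` / `__call__` for the fallback paths). `_SimpleRMSNorm` is a
# hypothetical stand-in, not the project's RMSNorm layer.
class _SimpleRMSNorm:
    def __init__(self, hidden_dim: int, eps: float = 1e-5, device="cuda"):
        self.weight = torch.ones(hidden_dim, device=device, dtype=torch.bfloat16)
        self.variance_epsilon = eps

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        xf = x.float()
        rms = torch.sqrt(xf.pow(2).mean(dim=-1, keepdim=True) + self.variance_epsilon)
        return (xf / rms * self.weight.float()).to(x.dtype)

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_native(x)


# Example (illustrative shapes):
#   norm = FusedDualResidualRMSNorm(_SimpleRMSNorm(4096), _SimpleRMSNorm(4096))
#   out, new_residual = norm(x, residual)   # x, residual: (bs, 4096) CUDA tensors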
# gelu on first half of vector
@triton.jit
def gelu_and_mul_kernel(
out_hidden_states_ptr, # (bs, hidden_dim)
out_scales_ptr, # (bs,)
hidden_states_ptr, # (bs, hidden_dim * 2)
quant_max: tl.constexpr,
static_scale: tl.constexpr,
hidden_dim: tl.constexpr, # the output hidden_dim
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
input_start = pid * hidden_dim * 2
output_start = pid * hidden_dim
input1_offs = tl.arange(0, BLOCK_SIZE)
mask = tl.arange(0, BLOCK_SIZE) < hidden_dim # shared for input1, input3, output
input3_offs = hidden_dim + tl.arange(0, BLOCK_SIZE)
output_offs = tl.arange(0, BLOCK_SIZE)
x1 = tl.load(
hidden_states_ptr + input_start + input1_offs, mask=mask, other=0.0
).to(tl.float32)
x3 = tl.load(
hidden_states_ptr + input_start + input3_offs, mask=mask, other=0.0
).to(tl.float32)
# gelu
# cast down before mul to better match training?
gelu_x1 = 0.5 * (1.0 + tl.erf(x1 * 0.7071067811865475)) * x1
out = x3 * gelu_x1.to(hidden_states_ptr.dtype.element_ty)
if quant_max is not None:
raise NotImplementedError()
tl.store(out_hidden_states_ptr + output_start + output_offs, out, mask=mask)
def gelu_and_mul_triton(
hidden_states,
scales=None,
quantize=None, # dtype to quantize to
out=None,
):
bs, in_hidden_dim = hidden_states.shape
hidden_dim = in_hidden_dim // 2
if out is None:
out_hidden_states = torch.empty(
(bs, hidden_dim),
dtype=quantize or hidden_states.dtype,
device=hidden_states.device,
)
else:
assert out.shape == (bs, hidden_dim)
assert out.dtype == (quantize or hidden_states.dtype)
out_hidden_states = out
out_scales = None
static_scale = False
if quantize is not None:
if scales is None:
out_scales = torch.empty(
(bs,), dtype=torch.float32, device=hidden_states.device
)
else:
out_scales = scales
static_scale = True
config = {
# 8 ele per thread (not tuned)
"num_warps": max(
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), 32), 4
),
}
gelu_and_mul_kernel[(bs,)](
out_hidden_states,
out_scales,
hidden_states,
quant_max=torch.finfo(quantize).max if quantize is not None else None,
static_scale=static_scale,
hidden_dim=hidden_dim,
BLOCK_SIZE=triton.next_power_of_2(hidden_dim),
**config,
)
if quantize is not None:
return out_hidden_states, out_scales
else:
return out_hidden_states, None
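# Hedged reference sketch (not part of the original diff): eager-mode equivalent
# of the gelu_and_mul kernel without the quantization path; the helper name is
# hypothetical. The kernel applies exact (erf-based) GELU to the first half of
# the last dimension, casts it back to the input dtype, multiplies by the second
# half in float32, and casts on store. Note gelu_and_mul_triton itself returns a
# (hidden_states, scales) tuple.
def gelu_and_mul_reference(hidden_states: torch.Tensor) -> torch.Tensor:
    d = hidden_states.shape[-1] // 2
    x1, x3 = hidden_states[..., :d], hidden_states[..., d:]
    gelu_x1 = torch.nn.functional.gelu(x1.float(), approximate="none")
    out = x3.float() * gelu_x1.to(hidden_states.dtype).float()
    return out.to(hidden_states.dtype)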
# ---- second file in this diff: fused MoE router kernels ----
from typing import Tuple
import torch
import triton
import triton.language as tl
from sglang.srt.layers.moe.topk import fused_topk
@triton.jit
def fused_moe_router_kernel(
input_ptr, # input (bs, hidden_dim)
moe_router_weight_ptr, # input (num_experts, hidden_dim)
topk_weights_ptr, # output (bs, topk)
topk_ids_ptr, # output (bs, topk)
num_experts: tl.constexpr,
topk: tl.constexpr,
moe_softcapping: tl.constexpr,
moe_renormalize: tl.constexpr, # not supported
hidden_dim: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
offsets = tl.arange(0, BLOCK_SIZE)
mask = offsets < hidden_dim
# moe_router_weight is k major
expert_offsets = tl.arange(0, num_experts)[:, None]
router_mask = mask[None, :]
w_router = tl.load(
moe_router_weight_ptr + expert_offsets * hidden_dim + offsets[None, :],
mask=router_mask,
other=0.0,
)
x = tl.load(input_ptr + pid * hidden_dim + offsets, mask=mask, other=0.0)
# todo: tl.dot?
logits = tl.sum((w_router.to(tl.float32) * x[None, :].to(tl.float32)), axis=-1)
# logit softcap
logits_scaled = logits / moe_softcapping
exped = tl.exp(2 * logits_scaled)
top = exped - 1
bottom = exped + 1
logits_softcapped = top / bottom * moe_softcapping
# topk
# assert 1 <= topk <= num_experts
# 5.38 us
top1 = tl.argmax(logits_softcapped, axis=0)
tl.store(topk_ids_ptr + pid * topk + 0, top1) # 5.63 us
top1_v = tl.max(logits_softcapped, axis=0)
invsumexp = 1.0 / tl.sum(tl.exp(logits_softcapped - top1_v), axis=0)
tl.store(
topk_weights_ptr + pid * topk + 0,
invsumexp,
) # 5.73 us
if topk >= 2:
top2 = tl.argmax(
tl.where(
tl.arange(0, num_experts) != top1, logits_softcapped, float("-inf")
),
axis=0,
)
tl.store(topk_ids_ptr + pid * topk + 1, top2)
top2_v = tl.sum(logits_softcapped * (tl.arange(0, num_experts) == top2), axis=0)
tl.store(
topk_weights_ptr + pid * topk + 1,
tl.exp(top2_v - top1_v) * invsumexp,
) # 5.95us
# probably slow
if topk > 2:
topk_mask = tl.full(logits_softcapped.shape, 1.0, dtype=logits_softcapped.dtype)
topk_mask = tl.where(
tl.arange(0, num_experts) != top1, topk_mask, float("-inf")
)
topk_mask = tl.where(
tl.arange(0, num_experts) != top2, topk_mask, float("-inf")
)
for i in range(2, topk):
topi = tl.argmax(logits_softcapped + topk_mask, axis=0)
topk_mask = tl.where(
tl.arange(0, num_experts) != topi, topk_mask, float("-inf")
)
tl.store(topk_ids_ptr + pid * topk + i, topi)
topi_v = tl.sum(
logits_softcapped * (tl.arange(0, num_experts) == topi), axis=0
)
tl.store(
topk_weights_ptr + pid * topk + i,
tl.exp(topi_v - top1_v) * invsumexp,
)
# assert not moe_renormalize, "moe weight renormalization not implemented"
def fused_moe_router_impl(
x: torch.Tensor,
router_weight: torch.Tensor,
topk: int,
moe_softcapping: float,
):
assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
bs, hidden_dim = x.shape
num_experts = router_weight.shape[0]
# router_logits = torch.empty((bs, num_experts), dtype=torch.float32, device=x.device)
topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
grid = lambda meta: (bs,)
config = {
"BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
"num_warps": max(
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
),
}
fused_moe_router_kernel[grid](
x,
router_weight,
topk_weights,
topk_ids,
num_experts=num_experts,
topk=topk,
moe_softcapping=moe_softcapping,
moe_renormalize=False,
hidden_dim=hidden_dim,
**config,
)
return topk_weights, topk_ids
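# Hedged reference sketch (not part of the original diff): eager-mode equivalent
# of fused_moe_router_impl; the helper name is hypothetical. The kernel returns,
# for each selected expert, its softmax probability over *all* experts (computed
# on the softcapped logits); it does not renormalize over the top-k subset.
def moe_router_reference(x, router_weight, topk, moe_softcapping):
    logits = x.float() @ router_weight.float().T  # (bs, num_experts)
    capped = torch.tanh(logits / moe_softcapping) * moe_softcapping
    probs = torch.softmax(capped, dim=-1)
    topk_weights, topk_ids = torch.topk(probs, topk, dim=-1)
    return topk_weights, topk_ids.to(torch.int32)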
@triton.jit
def fused_moe_router_large_bs_kernel(
a_ptr, # input (bs, hidden_dim)
b_ptr, # input (num_experts, hidden_dim)
topk_weights_ptr, # output (bs, topk)
topk_ids_ptr, # output (bs, topk)
bs,
num_experts: tl.constexpr,
    topk: tl.constexpr,  # only supports topk == 1
moe_softcapping: tl.constexpr,
moe_renormalize: tl.constexpr, # not supported
K: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
stride_am: tl.constexpr,
stride_bn: tl.constexpr,
):
# 1. get block id
pid = tl.program_id(axis=0)
# 2. create pointers for the first block of A and B
# 2.1. setup a_ptrs with offsets in m and k
offs_m = pid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)[:, None]
bs_mask = offs_m < bs
offs_k = tl.arange(0, BLOCK_SIZE_K)[None, :]
a_ptrs = a_ptr + (offs_m * stride_am + offs_k)
# 2.2. setup b_ptrs with offsets in k and n.
# Note: b matrix is k-major.
offs_k = tl.arange(0, BLOCK_SIZE_K)[None, :]
offs_n = tl.arange(0, BLOCK_SIZE_N)[:, None]
expert_mask = offs_n < num_experts
b_ptrs = b_ptr + (offs_n * stride_bn + offs_k)
# 3. Create an accumulator of float32 of size [BLOCK_SIZE_M, BLOCK_SIZE_N]
# 3.1. iterate in K dimension
# 3.2. transpose tile B
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, K // BLOCK_SIZE_K): # hidden_dim % BLOCK_SIZE_K == 0
a = tl.load(
a_ptrs,
mask=bs_mask,
other=0.0,
).to(tl.float32)
b = tl.load(b_ptrs, mask=expert_mask, other=0.0).to(tl.float32).T
acc += tl.dot(a, b)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K
b_ptrs += BLOCK_SIZE_K
# 4. logit softcap
logits_scaled = acc / moe_softcapping
exped = tl.exp(2 * logits_scaled)
logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping
# 5. top1
cond = tl.arange(0, BLOCK_SIZE_N)[None, :] < num_experts
top1 = tl.argmax(tl.where(cond, logits_softcapped, float("-inf")), axis=1)
top1_v = tl.max(
tl.where(cond, logits_softcapped, float("-inf")), axis=1, keep_dims=True
)
invsumexp = 1.0 / tl.sum(
tl.where(cond, tl.exp(logits_softcapped - top1_v), 0.0), axis=1
)
# 6. store to output
offs_topk = pid * topk * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
topk_mask = offs_topk < bs
tl.store(topk_ids_ptr + offs_topk, top1, mask=topk_mask)
tl.store(
topk_weights_ptr + offs_topk,
invsumexp,
mask=topk_mask,
)
def fused_moe_router_large_bs_impl(
x: torch.Tensor,
router_weight: torch.Tensor,
topk: int,
moe_softcapping: float,
BLOCK_SIZE_M: int,
BLOCK_SIZE_N: int,
BLOCK_SIZE_K: int,
):
assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
bs, hidden_dim = x.shape
num_experts = router_weight.shape[0]
assert num_experts <= BLOCK_SIZE_N
assert hidden_dim % BLOCK_SIZE_K == 0
assert topk == 1
topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
grid = (triton.cdiv(bs, BLOCK_SIZE_M) * triton.cdiv(num_experts, BLOCK_SIZE_N),)
fused_moe_router_large_bs_kernel[grid](
a_ptr=x,
b_ptr=router_weight,
topk_weights_ptr=topk_weights,
topk_ids_ptr=topk_ids,
bs=bs,
num_experts=num_experts,
topk=topk,
moe_softcapping=moe_softcapping,
moe_renormalize=False,
K=hidden_dim,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=BLOCK_SIZE_K,
stride_am=hidden_dim,
stride_bn=hidden_dim,
)
return topk_weights, topk_ids
def fused_moe_router_shim(
moe_softcapping,
hidden_states,
gating_output,
topk,
renormalize,
):
assert not renormalize
assert (
len(hidden_states.shape) == 2
and hidden_states.shape[1] == gating_output.shape[1]
)
bs, hidden_dim = hidden_states.shape
num_experts = gating_output.shape[0]
BLOCK_SIZE_M = 32
BLOCK_SIZE_N = 16
BLOCK_SIZE_K = 256
if (
bs >= 512
and topk == 1
and num_experts <= BLOCK_SIZE_N
and hidden_dim % BLOCK_SIZE_K == 0
):
return fused_moe_router_large_bs_impl(
x=hidden_states,
router_weight=gating_output,
topk=topk,
moe_softcapping=moe_softcapping,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=BLOCK_SIZE_K,
)
else:
return fused_moe_router_impl(
x=hidden_states,
router_weight=gating_output,
topk=topk,
moe_softcapping=moe_softcapping,
)
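# Hedged usage note (not part of the original diff): the shim dispatches to the
# tl.dot-based large-batch kernel only when bs >= 512, topk == 1, num_experts
# fits in one BLOCK_SIZE_N tile (<= 16), and hidden_dim is a multiple of
# BLOCK_SIZE_K (256); otherwise it uses the per-row kernel above. Illustrative
# shapes (not asserted to match any particular model):
#   w = torch.randn(8, 6144, device="cuda", dtype=torch.bfloat16)     # (num_experts, hidden_dim)
#   x = torch.randn(1024, 6144, device="cuda", dtype=torch.bfloat16)  # (bs, hidden_dim)
#   topk_weights, topk_ids = fused_moe_router_shim(30.0, x, w, 1, False)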
class FusedMoeRouter:
def __init__(self, router_linear, topk, moe_softcapping) -> None:
self.router_linear = router_linear
self.topk = topk
self.moe_softcapping = moe_softcapping
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
    def forward(
        self, x: torch.Tensor, residual: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # residual is accepted for interface compatibility but is not used by the router
        if x.is_cuda:
            return self.forward_cuda(x)
        else:
            return self.forward_vllm(x)
def forward_cuda(
self, x: torch.Tensor, autotune=False
) -> Tuple[torch.Tensor, torch.Tensor]:
return fused_moe_router_shim(
moe_softcapping=self.moe_softcapping,
hidden_states=x,
gating_output=self.router_linear.weight,
topk=self.topk,
renormalize=False,
)
def forward_vllm(
self,
x: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
# g, _ = self.router_linear.forward(x)
g = x.float() @ self.router_linear.weight.T.float()
g = torch.tanh(g.float() / self.moe_softcapping) * self.moe_softcapping
return fused_topk(x, g, self.topk, False)
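# Hedged usage sketch (not part of the original diff): FusedMoeRouter wraps a
# gating linear layer whose weight has shape (num_experts, hidden_dim), e.g. a
# bias-free torch.nn.Linear. The shapes below are illustrative only.
#   router_linear = torch.nn.Linear(4096, 8, bias=False).to(device="cuda", dtype=torch.bfloat16)
#   router = FusedMoeRouter(router_linear, topk=2, moe_softcapping=30.0)
#   x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)
#   topk_weights, topk_ids = router(x)   # residual is optional and unused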