Unverified commit c6d7f8d3 authored by Lianmin Zheng, committed by GitHub

Add some fused elementwise kernels for grok-1 (#4398)


Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: Hanming Lu <69857889+hanming-lu@users.noreply.github.com>
parent a5a892ff
from typing import Tuple
import torch
import triton
import triton.language as tl
fused_softcap_autotune = triton.autotune(
configs=[
triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=8),
triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=16),
triton.Config(kwargs={"BLOCK_SIZE": 256}, num_warps=4),
triton.Config(kwargs={"BLOCK_SIZE": 256}, num_warps=8),
triton.Config(kwargs={"BLOCK_SIZE": 512}, num_warps=4),
triton.Config(kwargs={"BLOCK_SIZE": 512}, num_warps=8),
triton.Config(kwargs={"BLOCK_SIZE": 512}, num_warps=16),
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4),
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8),
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16),
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=32),
triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=32),
triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_warps=32),
triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32),
triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32),
triton.Config(kwargs={"BLOCK_SIZE": 32768}, num_warps=32),
],
key=["n_ele"],
)
@triton.jit
def fused_softcap_kernel(
output_ptr,
input_ptr,
n_ele,
softcap_const: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_ele
x = tl.load(input_ptr + offsets, mask=mask)
fx = x.to(tl.float32)
fxs = fx / softcap_const
exped = tl.exp(2 * fxs)
top = exped - 1
bottom = exped + 1
output = top / bottom * softcap_const
tl.store(output_ptr + offsets, output, mask=mask)
fused_softcap_kernel_autotuned = fused_softcap_autotune(fused_softcap_kernel)
def fused_softcap(x, softcap_const, autotune=False):
output = torch.empty_like(x, dtype=torch.float32)
n_elements = output.numel()
if autotune:
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
fused_softcap_kernel_autotuned[grid](output, x, n_elements, softcap_const)
else:
fused_softcap_kernel[(triton.cdiv(n_elements, 128),)](
output, x, n_elements, softcap_const, BLOCK_SIZE=128, num_warps=8
)
return output
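# A minimal sanity-check sketch for fused_softcap (the shapes and softcap
# constant below are illustrative assumptions, not values mandated by the
# kernel). softcap(x) = c * tanh(x / c); the kernel evaluates tanh via the
# identity tanh(y) = (exp(2y) - 1) / (exp(2y) + 1) in float32.
def _fused_softcap_reference_check(softcap_const: float = 30.0) -> None:
    x = torch.randn(4, 8192, device="cuda", dtype=torch.bfloat16)
    ref = torch.tanh(x.float() / softcap_const) * softcap_const
    out = fused_softcap(x, softcap_const)  # float32 output, same shape as x
    torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-3)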
# cast to float + softcap
class Softcap:
def __init__(self, softcap_const: float):
self.softcap_const = softcap_const
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.is_cuda:
return self.forward_cuda(x)
else:
return self.forward_native(x)
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
return torch.tanh(x.float() / self.softcap_const) * self.softcap_const
def forward_cuda(self, x: torch.Tensor, autotune=False) -> torch.Tensor:
return fused_softcap(x, self.softcap_const, autotune=autotune)
rmsnorm_autotune = triton.autotune(
configs=[
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4, num_stages=1),
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8, num_stages=1),
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16, num_stages=1),
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4),
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8),
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16),
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4, num_stages=4),
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8, num_stages=4),
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16, num_stages=4),
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8, num_stages=8),
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16, num_stages=8),
triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=8),
triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=16),
triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=8, num_stages=4),
triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=16, num_stages=4),
triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_warps=8),
triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_warps=16),
triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=8),
triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=16),
triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32),
triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=8, num_stages=1),
triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=16, num_stages=1),
triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32, num_stages=1),
triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=8, num_stages=4),
triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=16, num_stages=4),
triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32, num_stages=4),
triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=8),
triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=16),
triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32),
triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=8, num_stages=1),
triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=16, num_stages=1),
triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32, num_stages=1),
triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=8, num_stages=4),
triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=16, num_stages=4),
triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32, num_stages=4),
],
key=["hidden_dim"],
)
@triton.jit
def fused_dual_residual_rmsnorm_kernel(
output_ptr,
mid_ptr,
activ_ptr,
residual_ptr,
weight1_ptr,
weight2_ptr,
eps: tl.constexpr,
hidden_dim: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
input_start = pid * hidden_dim
offsets = tl.arange(0, BLOCK_SIZE)
mask = offsets < hidden_dim
a_ = tl.load(activ_ptr + input_start + offsets, mask=mask, other=0.0)
a = a_.to(tl.float32)
rms = tl.sqrt(tl.sum(a * a, axis=0) / hidden_dim + eps)
r = tl.load(residual_ptr + input_start + offsets, mask=mask, other=0.0)
w1_ = tl.load(weight1_ptr + offsets, mask=mask, other=0.0)
w1 = w1_.to(tl.float32)
a2r = r + (a / rms * w1).to(r.dtype)
tl.store(
mid_ptr + input_start + offsets,
a2r,
mask=mask,
)
a2r = a2r.to(tl.float32)
rms2 = tl.sqrt(tl.sum(a2r * a2r, axis=0) / hidden_dim + eps)
w2_ = tl.load(weight2_ptr + offsets, mask=mask, other=0.0)
w2 = w2_.to(tl.float32)
tl.store(
output_ptr + input_start + offsets,
a2r / rms2 * w2, # implicitly casts to output dtype here
mask=mask,
)
fused_dual_residual_rmsnorm_kernel_autotune = rmsnorm_autotune(
fused_dual_residual_rmsnorm_kernel
)
def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=False):
assert len(x.shape) == 2
assert x.shape == residual.shape and x.dtype == residual.dtype
output, mid = torch.empty_like(x), torch.empty_like(x)
bs, hidden_dim = x.shape
if autotune:
fused_dual_residual_rmsnorm_kernel_autotune[(bs,)](
output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim
)
else:
config = {
"BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
"num_warps": max(
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
),
}
fused_dual_residual_rmsnorm_kernel[(bs,)](
output,
mid,
x,
residual,
weight1,
weight2,
eps=eps,
hidden_dim=hidden_dim,
**config,
)
return output, mid
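# Eager reference sketch (matches the fused kernel up to low-precision
# rounding from the intermediate dtype casts): one RMSNorm of x, addition of
# the residual, then a second RMSNorm of that sum. This helper is an
# illustrative assumption, not part of the PR.
def _dual_rmsnorm_reference(x, residual, weight1, weight2, eps):
    def rmsnorm(t, w):
        var = t.float().pow(2).mean(dim=-1, keepdim=True)
        return (t.float() * torch.rsqrt(var + eps) * w.float()).to(t.dtype)

    mid = rmsnorm(x, weight1) + residual
    return rmsnorm(mid, weight2), mid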
@triton.jit
def fused_rmsnorm_kernel(
output_ptr,
activ_ptr,
weight_ptr,
eps: tl.constexpr,
hidden_dim: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
input_start = pid * hidden_dim
offsets = tl.arange(0, BLOCK_SIZE)
mask = offsets < hidden_dim
a_ = tl.load(activ_ptr + input_start + offsets, mask=mask, other=0.0)
a = a_.to(tl.float32)
rms = tl.sqrt(tl.sum(a * a, axis=0) / hidden_dim + eps)
w1_ = tl.load(weight_ptr + offsets, mask=mask, other=0.0)
w1 = w1_.to(tl.float32)
a_rms = a / rms * w1
tl.store(
output_ptr + input_start + offsets,
a_rms, # implicitly casts to output dtype here
mask=mask,
)
def fused_rmsnorm(x, weight, eps, autotune=False, inplace=False):
assert len(x.shape) == 2
if inplace:
output = x
else:
output = torch.empty_like(x)
bs, hidden_dim = x.shape
config = {
"BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
"num_warps": max(
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
),
}
fused_rmsnorm_kernel[(bs,)](
output, x, weight, eps=eps, hidden_dim=hidden_dim, **config
)
return output
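# Hedged sanity-check sketch for fused_rmsnorm against an eager RMSNorm
# (shapes, eps, and tolerances below are illustrative assumptions).
def _fused_rmsnorm_reference_check(eps: float = 1e-5) -> None:
    x = torch.randn(8, 6144, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(6144, device="cuda", dtype=torch.bfloat16)
    ref = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps) * w.float()
    out = fused_rmsnorm(x, w, eps)  # same shape and dtype as x
    torch.testing.assert_close(out.float(), ref, rtol=2e-2, atol=2e-2)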
class FusedDualResidualRMSNorm:
"""
Fused implementation of
y = RMSNorm2(RMSNorm1(x) + residual))
"""
def __init__(self, rmsnorm1, rmsnorm2) -> None: # the one after rmsnorm1
self.rmsnorm1 = rmsnorm1
self.rmsnorm2 = rmsnorm2
self.variance_epsilon = self.rmsnorm1.variance_epsilon
assert self.rmsnorm1.variance_epsilon == self.rmsnorm2.variance_epsilon
assert self.rmsnorm1.weight.shape == self.rmsnorm2.weight.shape
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
def forward(
self, x: torch.Tensor, residual: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
if x.is_cuda:
return self.forward_cuda(x, residual)
else:
return self.forward_flashinfer(x, residual)
def forward_cuda(
self, x: torch.Tensor, residual: torch.Tensor, autotune=False
) -> Tuple[torch.Tensor, torch.Tensor]:
return fused_dual_residual_rmsnorm(
x,
residual,
self.rmsnorm1.weight,
self.rmsnorm2.weight,
self.variance_epsilon,
autotune=autotune,
)
def forward_flashinfer(
self,
x: torch.Tensor,
residual: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
normed1 = self.rmsnorm1(x)
residual = normed1 + residual
return self.rmsnorm2(residual), residual
def forward_native(
self,
x: torch.Tensor,
residual: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
normed1 = self.rmsnorm1.forward_native(x)
residual = normed1 + residual
return self.rmsnorm2.forward_native(residual), residual
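# Hedged usage sketch for the wrapper: any two norm objects exposing .weight
# and .variance_epsilon can be fused. The tiny stand-in class below is an
# illustrative assumption, not sglang's RMSNorm module.
def _fused_dual_norm_example(hidden_dim: int = 1024, eps: float = 1e-5) -> None:
    class _Norm:
        def __init__(self):
            self.weight = torch.ones(hidden_dim, device="cuda", dtype=torch.bfloat16)
            self.variance_epsilon = eps

    fused_norm = FusedDualResidualRMSNorm(_Norm(), _Norm())
    x = torch.randn(4, hidden_dim, device="cuda", dtype=torch.bfloat16)
    residual = torch.randn_like(x)
    out, new_residual = fused_norm.forward_cuda(x, residual)
    assert out.shape == x.shape and new_residual.shape == x.shape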
# gelu on first half of vector
@triton.jit
def gelu_and_mul_kernel(
out_hidden_states_ptr, # (bs, hidden_dim)
out_scales_ptr, # (bs,)
hidden_states_ptr, # (bs, hidden_dim * 2)
quant_max: tl.constexpr,
static_scale: tl.constexpr,
hidden_dim: tl.constexpr, # the output hidden_dim
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
input_start = pid * hidden_dim * 2
output_start = pid * hidden_dim
input1_offs = tl.arange(0, BLOCK_SIZE)
mask = tl.arange(0, BLOCK_SIZE) < hidden_dim # shared for input1, input3, output
input3_offs = hidden_dim + tl.arange(0, BLOCK_SIZE)
output_offs = tl.arange(0, BLOCK_SIZE)
x1 = tl.load(
hidden_states_ptr + input_start + input1_offs, mask=mask, other=0.0
).to(tl.float32)
x3 = tl.load(
hidden_states_ptr + input_start + input3_offs, mask=mask, other=0.0
).to(tl.float32)
# gelu
# cast down before mul to better match training?
gelu_x1 = 0.5 * (1.0 + tl.erf(x1 * 0.7071067811865475)) * x1
out = x3 * gelu_x1.to(hidden_states_ptr.dtype.element_ty)
if quant_max is not None:
raise NotImplementedError()
tl.store(out_hidden_states_ptr + output_start + output_offs, out, mask=mask)
def gelu_and_mul_triton(
hidden_states,
scales=None,
quantize=None, # dtype to quantize to
out=None,
):
bs, in_hidden_dim = hidden_states.shape
hidden_dim = in_hidden_dim // 2
if out is None:
out_hidden_states = torch.empty(
(bs, hidden_dim),
dtype=quantize or hidden_states.dtype,
device=hidden_states.device,
)
else:
assert out.shape == (bs, hidden_dim)
assert out.dtype == (quantize or hidden_states.dtype)
out_hidden_states = out
out_scales = None
static_scale = False
if quantize is not None:
if scales is None:
out_scales = torch.empty(
(bs,), dtype=torch.float32, device=hidden_states.device
)
else:
out_scales = scales
static_scale = True
config = {
# 8 elements per thread (not tuned)
"num_warps": max(
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), 32), 4
),
}
gelu_and_mul_kernel[(bs,)](
out_hidden_states,
out_scales,
hidden_states,
quant_max=torch.finfo(quantize).max if quantize is not None else None,
static_scale=static_scale,
hidden_dim=hidden_dim,
BLOCK_SIZE=triton.next_power_of_2(hidden_dim),
**config,
)
if quantize is not None:
return out_hidden_states, out_scales
else:
return out_hidden_states, None
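# Eager reference sketch for the unquantized path of gelu_and_mul_triton
# (quantize=None): exact erf-based GELU on the first half of each row, cast
# down to the input dtype (as the kernel does), then multiplied by the second
# half. This helper is an illustrative assumption, not part of the PR.
def _gelu_and_mul_reference(hidden_states: torch.Tensor) -> torch.Tensor:
    x1, x3 = hidden_states.float().chunk(2, dim=-1)
    gelu_x1 = 0.5 * (1.0 + torch.erf(x1 * 0.7071067811865475)) * x1
    return (x3 * gelu_x1.to(hidden_states.dtype)).to(hidden_states.dtype)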
from typing import Tuple
import torch
import triton
import triton.language as tl
from sglang.srt.layers.moe.topk import fused_topk
@triton.jit
def fused_moe_router_kernel(
input_ptr, # input (bs, hidden_dim)
moe_router_weight_ptr, # input (num_experts, hidden_dim)
topk_weights_ptr, # output (bs, topk)
topk_ids_ptr, # output (bs, topk)
num_experts: tl.constexpr,
topk: tl.constexpr,
moe_softcapping: tl.constexpr,
moe_renormalize: tl.constexpr, # not supported
hidden_dim: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
offsets = tl.arange(0, BLOCK_SIZE)
mask = offsets < hidden_dim
# moe_router_weight is k major
expert_offsets = tl.arange(0, num_experts)[:, None]
router_mask = mask[None, :]
w_router = tl.load(
moe_router_weight_ptr + expert_offsets * hidden_dim + offsets[None, :],
mask=router_mask,
other=0.0,
)
x = tl.load(input_ptr + pid * hidden_dim + offsets, mask=mask, other=0.0)
# todo: tl.dot?
logits = tl.sum((w_router.to(tl.float32) * x[None, :].to(tl.float32)), axis=-1)
# logit softcap
logits_scaled = logits / moe_softcapping
exped = tl.exp(2 * logits_scaled)
top = exped - 1
bottom = exped + 1
logits_softcapped = top / bottom * moe_softcapping
# topk
# assert 1 <= topk <= num_experts
# 5.38 us
top1 = tl.argmax(logits_softcapped, axis=0)
tl.store(topk_ids_ptr + pid * topk + 0, top1) # 5.63 us
top1_v = tl.max(logits_softcapped, axis=0)
invsumexp = 1.0 / tl.sum(tl.exp(logits_softcapped - top1_v), axis=0)
tl.store(
topk_weights_ptr + pid * topk + 0,
invsumexp,
) # 5.73 us
if topk >= 2:
top2 = tl.argmax(
tl.where(
tl.arange(0, num_experts) != top1, logits_softcapped, float("-inf")
),
axis=0,
)
tl.store(topk_ids_ptr + pid * topk + 1, top2)
top2_v = tl.sum(logits_softcapped * (tl.arange(0, num_experts) == top2), axis=0)
tl.store(
topk_weights_ptr + pid * topk + 1,
tl.exp(top2_v - top1_v) * invsumexp,
) # 5.95us
# probably slow
if topk > 2:
topk_mask = tl.full(logits_softcapped.shape, 1.0, dtype=logits_softcapped.dtype)
topk_mask = tl.where(
tl.arange(0, num_experts) != top1, topk_mask, float("-inf")
)
topk_mask = tl.where(
tl.arange(0, num_experts) != top2, topk_mask, float("-inf")
)
for i in range(2, topk):
topi = tl.argmax(logits_softcapped + topk_mask, axis=0)
topk_mask = tl.where(
tl.arange(0, num_experts) != topi, topk_mask, float("-inf")
)
tl.store(topk_ids_ptr + pid * topk + i, topi)
topi_v = tl.sum(
logits_softcapped * (tl.arange(0, num_experts) == topi), axis=0
)
tl.store(
topk_weights_ptr + pid * topk + i,
tl.exp(topi_v - top1_v) * invsumexp,
)
# assert not moe_renormalize, "moe weight renormalization not implemented"
def fused_moe_router_impl(
x: torch.Tensor,
router_weight: torch.Tensor,
topk: int,
moe_softcapping: float,
):
assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
bs, hidden_dim = x.shape
num_experts = router_weight.shape[0]
# router_logits = torch.empty((bs, num_experts), dtype=torch.float32, device=x.device)
topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
grid = lambda meta: (bs,)
config = {
"BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
"num_warps": max(
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
),
}
fused_moe_router_kernel[grid](
x,
router_weight,
topk_weights,
topk_ids,
num_experts=num_experts,
topk=topk,
moe_softcapping=moe_softcapping,
moe_renormalize=False,
hidden_dim=hidden_dim,
**config,
)
return topk_weights, topk_ids
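# Eager reference sketch for fused_moe_router_impl: softcapped router logits,
# a softmax over all experts, and the probabilities gathered at the top-k
# expert indices (no renormalization over the selected experts). Illustrative
# helper only; not part of the PR.
def _moe_router_reference(x, router_weight, topk, moe_softcapping):
    logits = x.float() @ router_weight.float().T
    logits = torch.tanh(logits / moe_softcapping) * moe_softcapping
    probs = torch.softmax(logits, dim=-1)
    topk_weights, topk_ids = probs.topk(topk, dim=-1)
    return topk_weights, topk_ids.to(torch.int32)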
@triton.jit
def fused_moe_router_large_bs_kernel(
a_ptr, # input (bs, hidden_dim)
b_ptr, # input (num_experts, hidden_dim)
topk_weights_ptr, # output (bs, topk)
topk_ids_ptr, # output (bs, topk)
bs,
num_experts: tl.constexpr,
topk: tl.constexpr,  # only topk == 1 is supported
moe_softcapping: tl.constexpr,
moe_renormalize: tl.constexpr, # not supported
K: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
stride_am: tl.constexpr,
stride_bn: tl.constexpr,
):
# 1. get block id
pid = tl.program_id(axis=0)
# 2. create pointers for the first block of A and B
# 2.1. setup a_ptrs with offsets in m and k
offs_m = pid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)[:, None]
bs_mask = offs_m < bs
offs_k = tl.arange(0, BLOCK_SIZE_K)[None, :]
a_ptrs = a_ptr + (offs_m * stride_am + offs_k)
# 2.2. setup b_ptrs with offsets in k and n.
# Note: b matrix is k-major.
offs_k = tl.arange(0, BLOCK_SIZE_K)[None, :]
offs_n = tl.arange(0, BLOCK_SIZE_N)[:, None]
expert_mask = offs_n < num_experts
b_ptrs = b_ptr + (offs_n * stride_bn + offs_k)
# 3. Create an accumulator of float32 of size [BLOCK_SIZE_M, BLOCK_SIZE_N]
# 3.1. iterate in K dimension
# 3.2. transpose tile B
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, K // BLOCK_SIZE_K): # hidden_dim % BLOCK_SIZE_K == 0
a = tl.load(
a_ptrs,
mask=bs_mask,
other=0.0,
).to(tl.float32)
b = tl.load(b_ptrs, mask=expert_mask, other=0.0).to(tl.float32).T
acc += tl.dot(a, b)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K
b_ptrs += BLOCK_SIZE_K
# 4. logit softcap
logits_scaled = acc / moe_softcapping
exped = tl.exp(2 * logits_scaled)
logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping
# 5. top1
cond = tl.arange(0, BLOCK_SIZE_N)[None, :] < num_experts
top1 = tl.argmax(tl.where(cond, logits_softcapped, float("-inf")), axis=1)
top1_v = tl.max(
tl.where(cond, logits_softcapped, float("-inf")), axis=1, keep_dims=True
)
invsumexp = 1.0 / tl.sum(
tl.where(cond, tl.exp(logits_softcapped - top1_v), 0.0), axis=1
)
# 6. store to output
offs_topk = pid * topk * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
topk_mask = offs_topk < bs
tl.store(topk_ids_ptr + offs_topk, top1, mask=topk_mask)
tl.store(
topk_weights_ptr + offs_topk,
invsumexp,
mask=topk_mask,
)
def fused_moe_router_large_bs_impl(
x: torch.Tensor,
router_weight: torch.Tensor,
topk: int,
moe_softcapping: float,
BLOCK_SIZE_M: int,
BLOCK_SIZE_N: int,
BLOCK_SIZE_K: int,
):
assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
bs, hidden_dim = x.shape
num_experts = router_weight.shape[0]
assert num_experts <= BLOCK_SIZE_N
assert hidden_dim % BLOCK_SIZE_K == 0
assert topk == 1
topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
grid = (triton.cdiv(bs, BLOCK_SIZE_M) * triton.cdiv(num_experts, BLOCK_SIZE_N),)
fused_moe_router_large_bs_kernel[grid](
a_ptr=x,
b_ptr=router_weight,
topk_weights_ptr=topk_weights,
topk_ids_ptr=topk_ids,
bs=bs,
num_experts=num_experts,
topk=topk,
moe_softcapping=moe_softcapping,
moe_renormalize=False,
K=hidden_dim,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=BLOCK_SIZE_K,
stride_am=hidden_dim,
stride_bn=hidden_dim,
)
return topk_weights, topk_ids
def fused_moe_router_shim(
moe_softcapping,
hidden_states,
gating_output,
topk,
renormalize,
):
assert not renormalize
assert (
len(hidden_states.shape) == 2
and hidden_states.shape[1] == gating_output.shape[1]
)
bs, hidden_dim = hidden_states.shape
num_experts = gating_output.shape[0]
BLOCK_SIZE_M = 32
BLOCK_SIZE_N = 16
BLOCK_SIZE_K = 256
if (
bs >= 512
and topk == 1
and num_experts <= BLOCK_SIZE_N
and hidden_dim % BLOCK_SIZE_K == 0
):
return fused_moe_router_large_bs_impl(
x=hidden_states,
router_weight=gating_output,
topk=topk,
moe_softcapping=moe_softcapping,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=BLOCK_SIZE_K,
)
else:
return fused_moe_router_impl(
x=hidden_states,
router_weight=gating_output,
topk=topk,
moe_softcapping=moe_softcapping,
)
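# Hedged usage sketch for the shim: the tiled large-batch kernel is taken only
# for bs >= 512 with top-1 routing and hidden_dim divisible by BLOCK_SIZE_K;
# everything else falls back to the per-token kernel above. Note that
# gating_output here is the router weight matrix (num_experts, hidden_dim),
# not precomputed logits. Sizes below are illustrative assumptions.
def _router_shim_example() -> None:
    bs, hidden_dim, num_experts, topk = 16, 1024, 8, 2
    x = torch.randn(bs, hidden_dim, device="cuda", dtype=torch.bfloat16)
    router_weight = torch.randn(num_experts, hidden_dim, device="cuda", dtype=torch.bfloat16)
    topk_weights, topk_ids = fused_moe_router_shim(
        moe_softcapping=30.0,
        hidden_states=x,
        gating_output=router_weight,
        topk=topk,
        renormalize=False,
    )
    assert topk_weights.shape == (bs, topk) and topk_ids.shape == (bs, topk)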
class FusedMoeRouter:
def __init__(self, router_linear, topk, moe_softcapping) -> None:
self.router_linear = router_linear
self.topk = topk
self.moe_softcapping = moe_softcapping
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
def forward(
self, x: torch.Tensor, residual: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
if x.is_cuda:
return self.forward_cuda(x, residual)
else:
return self.forward_vllm(x, residual)
def forward_cuda(
self, x: torch.Tensor, autotune=False
) -> Tuple[torch.Tensor, torch.Tensor]:
return fused_moe_router_shim(
moe_softcapping=self.moe_softcapping,
hidden_states=x,
gating_output=self.router_linear.weight,
topk=self.topk,
renormalize=False,
)
def forward_vllm(
self,
x: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
# g, _ = self.router_linear.forward(x)
g = x.float() @ self.router_linear.weight.T.float()
g = torch.tanh(g.float() / self.moe_softcapping) * self.moe_softcapping
return fused_topk(x, g, self.topk, False)
@@ -15,28 +15,36 @@
# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
"""Inference-only Grok1 model."""
from typing import Iterable, List, Optional, Tuple
import functools
import json
import logging
import math
import os
import warnings
from typing import Iterable, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PretrainedConfig
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.elementwise import fused_dual_residual_rmsnorm, fused_rmsnorm
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.router import fused_moe_router_shim
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
@@ -44,47 +52,17 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.loader import DefaultModelLoader
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix
from sglang.srt.utils import dump_to_file
logger = logging.getLogger(__name__)
class Grok1MLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
reduce_results=True,
use_presharded_weights: bool = False,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=add_prefix("gate_up_proj", prefix),
use_presharded_weights=use_presharded_weights,
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=add_prefix("down_proj", prefix),
reduce_results=reduce_results,
use_presharded_weights=use_presharded_weights,
)
self.act_fn = GeluAndMul(approximate="tanh")
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
debug_tensor_dump_output_folder = None
debug_tensor_dump_inject = False
class Grok1MoE(nn.Module):
@@ -108,51 +86,55 @@ class Grok1MoE(nn.Module):
tp_size: Optional[int] = None,
reduce_results=True,
use_presharded_weights: bool = False,
prefix: str = "",
inplace: bool = True,
no_combine: bool = False,
):
super().__init__()
self.hidden_size = hidden_size
# Gate always runs at half / full precision for now.
# Gate always runs at full precision for stability (see https://arxiv.org/pdf/2101.03961)
self.gate = ReplicatedLinear(
hidden_size,
num_experts,
bias=False,
params_dtype=params_dtype,
params_dtype=torch.float32,
quant_config=None,
prefix=add_prefix("gate", prefix),
)
self.router_logit_softcapping = getattr(
config, "router_logit_softcapping", 30.0
)
self.experts = FusedMoE(
custom_routing_function = functools.partial(
fused_moe_router_shim, self.router_logit_softcapping
)
kwargs = {}
if global_server_args_dict["enable_ep_moe"]:
MoEImpl = EPMoE
else:
MoEImpl = FusedMoE
kwargs["reduce_results"] = reduce_results
kwargs["use_presharded_weights"] = use_presharded_weights
kwargs["inplace"] = inplace
kwargs["no_combine"] = no_combine
self.experts = MoEImpl(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
reduce_results=reduce_results,
renormalize=False,
quant_config=quant_config,
tp_size=tp_size,
custom_routing_function=custom_routing_function,
activation="gelu",
use_presharded_weights=use_presharded_weights,
prefix=add_prefix("experts", prefix),
**kwargs,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
router_logits = 30.0 * F.tanh(router_logits / 30.0)
# need to assert self.gate.quant_method is unquantized
final_hidden_states = self.experts(hidden_states, router_logits)
return final_hidden_states.view(orig_shape)
return self.experts(hidden_states, self.gate.weight)
class Grok1Attention(nn.Module):
@@ -167,31 +149,33 @@ class Grok1Attention(nn.Module):
rope_theta: float = 10000,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
prefix: str = "",
load_presharded_attn: bool = False,
) -> None:
super().__init__()
self.config = config
self.layer_id = layer_id
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
attn_tp_rank = get_tensor_model_parallel_rank()
attn_tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
assert self.total_num_heads % attn_tp_size == 0
self.num_heads = self.total_num_heads // attn_tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
if self.total_num_kv_heads >= attn_tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
assert self.total_num_kv_heads % attn_tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
assert attn_tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
self.head_dim = getattr(config, "head_dim", 128)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.load_presharded_attn = load_presharded_attn
self.qkv_proj = QKVParallelLinear(
hidden_size,
@@ -200,7 +184,9 @@ class Grok1Attention(nn.Module):
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=add_prefix("qkv_proj", prefix),
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
load_presharded_attn=self.load_presharded_attn,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
@@ -208,7 +194,9 @@ class Grok1Attention(nn.Module):
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=add_prefix("o_proj", prefix),
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
use_presharded_weights=self.load_presharded_attn,
)
self.rotary_emb = get_rope(
self.head_dim,
@@ -227,7 +215,6 @@ class Grok1Attention(nn.Module):
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
logit_cap=logit_cap,
prefix=add_prefix("attn", prefix),
)
def forward(
@@ -236,10 +223,73 @@
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
if hidden_states.shape[0] == 0:
assert (
not self.o_proj.reduce_results
), "short-circuiting allreduce will lead to hangs"
return hidden_states
if debug_tensor_dump_output_folder:
dump_to_file(
debug_tensor_dump_output_folder,
f"attn_input_{self.layer_id}",
hidden_states,
)
if debug_tensor_dump_inject:
name = os.path.join(
debug_tensor_dump_output_folder,
f"jax_dump_attn_input_{self.layer_id}.npy",
)
logger.info(f"Load {name} from jax.")
x = np.load(name)
hidden_states = torch.tensor(x[0, : hidden_states.shape[0]]).to(
hidden_states
)
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
if debug_tensor_dump_output_folder:
num_tokens = q.shape[0]
num_heads_q = self.num_heads
head_dim = self.head_dim
num_heads_kv = k.numel() // (num_tokens * head_dim)
dump_to_file(
debug_tensor_dump_output_folder,
f"q_{self.layer_id}",
tensor_model_parallel_all_gather(
q.reshape(num_tokens, num_heads_q, head_dim).contiguous(), dim=1
).contiguous(),
)
dump_to_file(
debug_tensor_dump_output_folder,
f"k_{self.layer_id}",
tensor_model_parallel_all_gather(
k.reshape(num_tokens, num_heads_kv, head_dim).contiguous(), dim=1
).contiguous(),
)
dump_to_file(
debug_tensor_dump_output_folder,
f"v_{self.layer_id}",
tensor_model_parallel_all_gather(
v.reshape(num_tokens, num_heads_kv, head_dim).contiguous(), dim=1
).contiguous(),
)
attn_output = self.attn(q, k, v, forward_batch)
if debug_tensor_dump_output_folder:
dump_to_file(
debug_tensor_dump_output_folder,
f"attn_output_{self.layer_id}",
tensor_model_parallel_all_gather(
attn_output.reshape(num_tokens, num_heads_q, head_dim).contiguous(),
dim=1,
).contiguous(),
)
output, _ = self.o_proj(attn_output)
return output
@@ -250,8 +300,9 @@ class Grok1DecoderLayer(nn.Module):
config: PretrainedConfig,
layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None,
use_presharded_weights: bool = False,
prefix: str = "",
load_presharded_moe: bool = False,
load_presharded_attn: bool = False,
load_presharded_mlp: bool = False,
) -> None:
super().__init__()
self.num_experts = config.num_local_experts
@@ -268,7 +319,8 @@ class Grok1DecoderLayer(nn.Module):
layer_id=layer_id,
rope_theta=rope_theta,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
reduce_results=False,
load_presharded_attn=load_presharded_attn,
)
self.block_sparse_moe = Grok1MoE(
config=config,
@@ -282,38 +334,68 @@
),
quant_config=quant_config,
reduce_results=True,
use_presharded_weights=use_presharded_weights,
prefix=add_prefix("block_sparse_moe", prefix),
use_presharded_weights=load_presharded_moe,
inplace=True,
no_combine=False, # just a suggestion to not combine topk
)
self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.ffn = self.block_sparse_moe
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
residual: Optional[torch.Tensor] = None,
deferred_norm: Optional[RMSNorm] = None,
) -> Tuple[torch.Tensor, torch.Tensor, RMSNorm]:
# Self Attention
hidden_states = (
self.post_attn_norm(
self.self_attn(
positions=positions,
hidden_states=self.pre_attn_norm(hidden_states),
forward_batch=forward_batch,
)
if deferred_norm is not None:
assert residual is not None
# here hidden_states is the previous layer's FFN output; residual is the residual from after that layer's attention block
hidden_states, residual = fused_dual_residual_rmsnorm(
hidden_states,
residual,
deferred_norm.weight,
self.pre_attn_norm.weight,
deferred_norm.variance_epsilon,
)
+ hidden_states
else:
# here hidden_states is the residual
hidden_states, residual = (
fused_rmsnorm(
hidden_states,
self.pre_attn_norm.weight,
self.pre_attn_norm.variance_epsilon,
),
hidden_states,
)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
)
# Fully Connected
hidden_states = (
self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(hidden_states)))
+ hidden_states
if get_tensor_model_parallel_world_size() > 1:
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
hidden_states, residual = fused_dual_residual_rmsnorm(
hidden_states,
residual,
self.post_attn_norm.weight,
self.pre_moe_norm.weight,
self.post_attn_norm.variance_epsilon,
)
return hidden_states
# Fully Connected
hidden_states = self.ffn(hidden_states)
return hidden_states, residual, self.post_moe_norm # defer layernorm
class Grok1Model(nn.Module):
@@ -321,8 +403,10 @@
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
use_presharded_weights: bool = False,
prefix: str = "",
load_presharded_moe: bool = False,
load_presharded_embedding: bool = False,
load_presharded_attn: bool = False,
load_presharded_mlp: bool = False,
) -> None:
super().__init__()
self.config = config
@@ -332,7 +416,7 @@ class Grok1Model(nn.Module):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
prefix=add_prefix("embed_tokens", prefix),
use_presharded_weights=load_presharded_embedding,
)
self.layers = nn.ModuleList(
[
@@ -340,8 +424,9 @@
config,
i,
quant_config=quant_config,
use_presharded_weights=use_presharded_weights,
prefix=add_prefix(f"layers.{i}", prefix),
load_presharded_moe=load_presharded_moe,
load_presharded_attn=load_presharded_attn,
load_presharded_mlp=load_presharded_mlp,
)
for i in range(config.num_hidden_layers)
]
@@ -361,10 +446,48 @@
else:
hidden_states = input_embeds
residual, deferred_norm = None, None
for i in range(len(self.layers)):
hidden_states = self.layers[i](positions, hidden_states, forward_batch)
hidden_states = self.norm(hidden_states)
hidden_states.mul_(self.config.output_multiplier_scale)
hidden_states, residual, deferred_norm = self.layers[i](
positions, hidden_states, forward_batch, residual, deferred_norm
)
if debug_tensor_dump_output_folder:
hidden_states = (
fused_rmsnorm(
hidden_states,
deferred_norm.weight,
deferred_norm.variance_epsilon,
)
+ residual
)
dump_to_file(
debug_tensor_dump_output_folder,
"last_hidden_before_norm",
hidden_states,
)
hidden_states = fused_rmsnorm(
hidden_states,
self.norm.weight,
self.norm.variance_epsilon,
)
dump_to_file(
debug_tensor_dump_output_folder,
"last_hidden_after_norm",
hidden_states,
)
else:
hidden_states, _ = fused_dual_residual_rmsnorm(
hidden_states,
residual,
deferred_norm.weight,
self.norm.weight,
deferred_norm.variance_epsilon,
)
return hidden_states
@@ -373,31 +496,77 @@ class Grok1ForCausalLM(nn.Module):
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
if (
# Get presharded weights.
self.load_presharded_mlp = getattr(config, "load_presharded_mlp", False)
self.load_presharded_moe = (
self.config.num_local_experts > 0
and get_tensor_model_parallel_world_size() > 1
):
self.use_presharded_weights = True
)
self.load_presharded_attn = getattr(config, "load_presharded_attn", False)
self.load_presharded_embedding = getattr(
config, "load_presharded_embedding", False
)
self.is_weights_presharded = (
self.load_presharded_mlp
or self.load_presharded_moe
or self.load_presharded_attn
or self.load_presharded_embedding
)
if self.is_weights_presharded:
setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
else:
self.use_presharded_weights = False
default_replicate_lm_head = False
self.replicate_lm_head = getattr(
config, "replicate_lm_head", default_replicate_lm_head
)
self.model = Grok1Model(
config,
quant_config=quant_config,
use_presharded_weights=self.use_presharded_weights,
prefix=add_prefix("model", prefix),
load_presharded_moe=self.load_presharded_moe,
load_presharded_embedding=self.load_presharded_embedding,
load_presharded_attn=self.load_presharded_attn,
load_presharded_mlp=self.load_presharded_mlp,
)
self.lm_head = ParallelLMHead(
config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
)
self.logits_processor = LogitsProcessor(config)
lm_head_params_dtype = None
if self.replicate_lm_head:
self.lm_head = ReplicatedLinear(
config.hidden_size,
config.vocab_size,
bias=False,
params_dtype=lm_head_params_dtype,
)
self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
use_presharded_weights=self.load_presharded_embedding,
params_dtype=lm_head_params_dtype,
)
self.logits_processor = LogitsProcessor(config)
# Dump tensors for debugging
global debug_tensor_dump_output_folder, debug_tensor_dump_inject
debug_tensor_dump_output_folder = global_server_args_dict[
"debug_tensor_dump_output_folder"
]
debug_tensor_dump_inject = global_server_args_dict["debug_tensor_dump_inject"]
warnings.filterwarnings("ignore", category=FutureWarning)
if get_tensor_model_parallel_rank() == 0:
logger.info(
f"#parameters (analytical): {self.get_num_params_analytical() / 1e9:.2f} B, "
f"#parameters (actual): {self.get_num_params_torch() / 1e9:.2f} B"
)
def forward(
self,
@@ -406,6 +575,9 @@
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if debug_tensor_dump_output_folder:
dump_to_file(debug_tensor_dump_output_folder, "input_ids", input_ids)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
@@ -414,21 +586,28 @@ class Grok1ForCausalLM(nn.Module):
def load_weights(
self,
weights: Iterable[Tuple[str, torch.Tensor]],
):
num_experts = self.config.num_local_experts
stacked_params_mapping = [
num_experts: Optional[int] = None,
ignore_parent_name: bool = False,
) -> dict[str, torch.Tensor]:
if num_experts is None:
num_experts = self.config.num_local_experts
stacked_params_mapping = []
stacked_params_mapping += [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
stacked_params_mapping += [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
expert_params_mapping = MoEImpl.make_expert_params_mapping(
ckpt_gate_proj_name="w1",
ckpt_down_proj_name="w2",
ckpt_up_proj_name="w3",
@@ -439,14 +618,25 @@ class Grok1ForCausalLM(nn.Module):
all_names = set(params_dict.keys())
hit_names = set()
def load_weight_wrapper(name, loaded_weight, *args, **kwargs):
def load_weight_wrapper(
name: str, loaded_weight: torch.Tensor, *args, **kwargs
):
if ignore_parent_name:
name = name.split(".")[-1]
if name not in params_dict:
return
# Fuse constant multipliers into the weights
if "lm_head" in name:
loaded_weight = (
loaded_weight.to(torch.float32)
* self.config.output_multiplier_scale
)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight, *args, **kwargs)
hit_names.add(name)
for name, loaded_weight in weights:
@@ -460,7 +650,6 @@ class Grok1ForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
load_weight_wrapper(name, loaded_weight, shard_id)
break
else:
@@ -487,13 +676,79 @@ class Grok1ForCausalLM(nn.Module):
load_weight_wrapper(name=name, loaded_weight=loaded_weight)
if len(hit_names) > 5:
missing = all_names - hit_names
missing_exclude_scales = {x for x in missing if "scale" not in x}
logger.info(
f"#all_names: {len(all_names)}, #hit_names: {len(hit_names)}, #missing_exclude_scales: {len(missing_exclude_scales)}",
)
if len(missing_exclude_scales) > 0:
raise ValueError(
f"load_weights failed because some weights are missing: {missing_exclude_scales=}."
)
elif len(hit_names) == 0:
raise ValueError("load_weights failed because it did not hit any names.")
return hit_names
def get_num_params_analytical(self):
cfg = self.config
moe_intermediate_size = getattr(
cfg,
"moe_intermediate_size",
getattr(cfg, "intermediate_size", None),
)
num_experts = cfg.num_local_experts
wq = (
cfg.num_hidden_layers
* cfg.hidden_size
* cfg.num_attention_heads
* cfg.head_dim
)
wkv = (
cfg.num_hidden_layers
* cfg.hidden_size
* cfg.num_key_value_heads
* cfg.head_dim
* 2
)
out = (
cfg.num_hidden_layers
* cfg.hidden_size
* cfg.num_attention_heads
* cfg.head_dim
)
ffn1 = (
cfg.num_hidden_layers
* num_experts
* cfg.hidden_size
* moe_intermediate_size
* 2
)
ffn2 = (
cfg.num_hidden_layers
* num_experts
* cfg.hidden_size
* moe_intermediate_size
)
embed = cfg.hidden_size * cfg.vocab_size * 2
return wq + wkv + out + ffn1 + ffn2 + embed
def get_num_params_torch(self):
return (
sum(p.numel() for p in self.parameters())
* get_tensor_model_parallel_world_size()
)
old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights")
def _prepare_presharded_weights(
self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
) -> Tuple[str, List[str], bool]:
) -> Tuple[str, list[str], bool]:
import glob
import os
@@ -522,7 +777,7 @@ def _prepare_presharded_weights(
# The new format
allow_patterns += [f"*-TP-{tp_rank:03d}.safetensors", "*-TP-common.safetensors"]
hf_weights_files: List[str] = []
hf_weights_files = []
for pattern in allow_patterns:
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))