Unverified Commit 86d10d22 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Update grok.py and tiktoken tokenizer (#9532)

parent 83871aa1
......@@ -162,12 +162,16 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
):
super().__init__()
# Create TokenizerInfo with model's EOS tokens as the authoritative stop tokens
# This ensures consistency between what the model considers EOS and what XGrammar uses
tokenizer_info = TokenizerInfo.from_huggingface(
tokenizer, vocab_size=vocab_size, stop_token_ids=model_eos_token_ids
)
override_stop_tokens = None
if hasattr(tokenizer, "init_xgrammar"):
# For special tokenizer
tokenizer_info, override_stop_tokens = tokenizer.init_xgrammar()
else:
# Create TokenizerInfo with model's EOS tokens as the authoritative stop tokens
# This ensures consistency between what the model considers EOS and what XGrammar uses
tokenizer_info = TokenizerInfo.from_huggingface(
tokenizer, vocab_size=vocab_size, stop_token_ids=model_eos_token_ids
)
override_stop_tokens = None
self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
self.vocab_size = vocab_size
......
......@@ -263,6 +263,11 @@ def get_tokenizer(
**kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
"""Gets a tokenizer for the given model name via Huggingface."""
if tokenizer_name.endswith(".json"):
from sglang.srt.tokenizer.tiktoken_tokenizer import TiktokenTokenizer
return TiktokenTokenizer(tokenizer_name)
if tokenizer_mode == "slow":
if kwargs.get("use_fast", False):
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
......
......@@ -20,6 +20,14 @@ if TYPE_CHECKING:
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
def logit_capping_mod(logit_capping_method, logit_cap):
# positive logit_cap -> tanh cap
if logit_capping_method == "tanh":
return logit_cap
else:
raise ValueError()
@dataclass
class ForwardMetadata:
attn_logits: torch.Tensor
......@@ -718,6 +726,8 @@ class TritonAttnBackend(AttentionBackend):
layer, forward_batch.out_cache_loc, k, v
)
logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap)
causal = True
if layer.attn_type == AttentionType.ENCODER_ONLY:
causal = False
......@@ -750,10 +760,11 @@ class TritonAttnBackend(AttentionBackend):
self.forward_metadata.mask_indptr,
self.forward_metadata.max_extend_len,
layer.scaling,
layer.logit_cap,
logit_cap=logits_soft_cap,
sliding_window_size=sliding_window_size,
sinks=sinks,
window_kv_offsets=window_kv_offsets,
xai_temperature_len=layer.xai_temperature_len,
)
return o
......@@ -777,6 +788,8 @@ class TritonAttnBackend(AttentionBackend):
else:
o = torch.empty_like(q)
logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap)
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
......@@ -801,8 +814,9 @@ class TritonAttnBackend(AttentionBackend):
self.forward_metadata.num_kv_splits,
self.max_kv_splits,
layer.scaling,
layer.logit_cap,
logit_cap=logits_soft_cap,
sinks=sinks,
xai_temperature_len=layer.xai_temperature_len,
)
return o
......
......@@ -69,6 +69,7 @@ def _fwd_kernel_stage1(
logit_cap: tl.constexpr,
Lk: tl.constexpr,
Lv: tl.constexpr,
xai_temperature_len: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
......@@ -85,6 +86,12 @@ def _fwd_kernel_stage1(
cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
kv_splits = tl.load(num_kv_splits + cur_batch)
if xai_temperature_len > 0:
offs_qidx = cur_batch_seq_len - 1
xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
_qtemp = tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale
xai_temperature_reg = tl.where(offs_qidx > xai_temperature_len, _qtemp, 1.0)
off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
kv_len_per_split = (
......@@ -122,6 +129,9 @@ def _fwd_kernel_stage1(
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)
if xai_temperature_len > 0:
qk *= xai_temperature_reg
qk = tl.where(offs_n < split_kv_end, qk, float("-inf"))
offs_buf_v = (
......@@ -181,6 +191,7 @@ def _decode_att_m_fwd(
max_kv_splits,
sm_scale,
logit_cap,
xai_temperature_len=-1,
):
BLOCK = 64
# [TODO] work around SGPR limit on MI3xx
......@@ -230,6 +241,7 @@ def _decode_att_m_fwd(
BLOCK_N=BLOCK,
MIN_BLOCK_KV=_MIN_BLOCK_KV,
logit_cap=logit_cap,
xai_temperature_len=xai_temperature_len,
num_warps=num_warps,
num_stages=2,
Lk=Lk,
......@@ -266,6 +278,7 @@ def _fwd_grouped_kernel_stage1(
BLOCK_H: tl.constexpr,
MIN_BLOCK_KV: tl.constexpr,
logit_cap: tl.constexpr,
xai_temperature_len: tl.constexpr,
Lk: tl.constexpr,
Lv: tl.constexpr,
):
......@@ -291,6 +304,12 @@ def _fwd_grouped_kernel_stage1(
cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
kv_splits = tl.load(num_kv_splits + cur_batch)
if xai_temperature_len > 0:
offs_qidx = cur_batch_seq_len - 1
xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
_qtemp = tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale
xai_temperature_reg = tl.where(offs_qidx > xai_temperature_len, _qtemp, 1.0)
offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
if BLOCK_DPE > 0:
......@@ -351,6 +370,9 @@ def _fwd_grouped_kernel_stage1(
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)
if xai_temperature_len > 0:
qk *= xai_temperature_reg[:, None]
qk = tl.where(
mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf")
)
......@@ -413,6 +435,7 @@ def _decode_grouped_att_m_fwd(
max_kv_splits,
sm_scale,
logit_cap,
xai_temperature_len=-1,
):
BLOCK = 32
Lk = k_buffer.shape[-1]
......@@ -480,6 +503,7 @@ def _decode_grouped_att_m_fwd(
BLOCK_H=BLOCK_H,
MIN_BLOCK_KV=_MIN_BLOCK_KV,
logit_cap=logit_cap,
xai_temperature_len=xai_temperature_len,
num_warps=4,
num_stages=num_stages,
Lk=Lk,
......@@ -620,6 +644,7 @@ def decode_attention_fwd_normal(
sm_scale,
logit_cap=0.0,
sinks=None,
xai_temperature_len=-1,
):
_decode_att_m_fwd(
q,
......@@ -633,6 +658,7 @@ def decode_attention_fwd_normal(
max_kv_splits,
sm_scale,
logit_cap,
xai_temperature_len,
)
_decode_softmax_reducev_fwd(
attn_logits,
......@@ -661,6 +687,7 @@ def decode_attention_fwd_grouped(
sm_scale,
logit_cap=0.0,
sinks=None,
xai_temperature_len=-1,
):
_decode_grouped_att_m_fwd(
q,
......@@ -674,6 +701,7 @@ def decode_attention_fwd_grouped(
max_kv_splits,
sm_scale,
logit_cap,
xai_temperature_len,
)
_decode_softmax_reducev_fwd(
attn_logits,
......@@ -702,6 +730,7 @@ def decode_attention_fwd(
sm_scale,
logit_cap=0.0,
sinks=None,
xai_temperature_len=-1,
):
assert max_kv_splits == attn_logits.shape[2]
assert q.shape[0] <= kv_indptr.shape[0] - 1
......@@ -725,6 +754,7 @@ def decode_attention_fwd(
sm_scale,
logit_cap=logit_cap,
sinks=sinks,
xai_temperature_len=xai_temperature_len,
)
else:
# GQA/MQA/MLA
......@@ -742,4 +772,5 @@ def decode_attention_fwd(
sm_scale,
logit_cap=logit_cap,
sinks=sinks,
xai_temperature_len=xai_temperature_len,
)
......@@ -69,6 +69,7 @@ def _fwd_kernel(
stride_buf_vh,
SLIDING_WINDOW_SIZE: tl.constexpr,
logit_cap: tl.constexpr,
xai_temperature_len: tl.constexpr,
Lq: tl.constexpr,
Lv: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
......@@ -109,6 +110,15 @@ def _fwd_kernel(
mask_d = offs_d < Lq
mask_dv = offs_dv < Lv
if xai_temperature_len > 0:
offs_qidx = cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m
xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
xai_temperature_reg = tl.where(
offs_qidx > xai_temperature_len,
tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale,
1.0,
)
offs_q = (
(cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
* stride_qbs
......@@ -203,6 +213,9 @@ def _fwd_kernel(
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)
if xai_temperature_len > 0:
qk *= xai_temperature_reg[:, None]
qk = tl.where(final_mask, qk, float("-inf"))
row_max = tl.max(qk, 1)
......@@ -306,6 +319,9 @@ def _fwd_kernel(
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)
if xai_temperature_len > 0:
qk *= xai_temperature_reg[:, None]
qk = tl.where(final_mask, qk, float("-inf"))
row_max = tl.max(qk, 1)
......@@ -373,6 +389,7 @@ def extend_attention_fwd(
sliding_window_size=-1,
sinks=None,
window_kv_offsets=None,
xai_temperature_len=-1,
):
"""
q_extend, k_extend, v_extend, o_extend: contiguous tensors
......@@ -477,6 +494,7 @@ def extend_attention_fwd(
v_buffer.stride(1),
SLIDING_WINDOW_SIZE=sliding_window_size,
logit_cap=logit_cap,
xai_temperature_len=xai_temperature_len,
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_DPE=BLOCK_DPE,
BLOCK_DV=BLOCK_DV,
......
......@@ -486,3 +486,97 @@ def gelu_and_mul_triton(
return out_hidden_states, out_scales
else:
return out_hidden_states, None
# silu on first half of vector
@triton.jit
def silu_and_mul_kernel(
out_hidden_states_ptr, # (bs, hidden_dim)
out_scales_ptr, # (bs,)
hidden_states_ptr, # (bs, hidden_dim * 2)
quant_max: tl.constexpr,
static_scale: tl.constexpr,
hidden_dim: tl.constexpr, # the output hidden_dim
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
input_start = pid * hidden_dim * 2
output_start = pid * hidden_dim
input1_offs = tl.arange(0, BLOCK_SIZE)
mask = tl.arange(0, BLOCK_SIZE) < hidden_dim # shared for input1, input3, output
input3_offs = hidden_dim + tl.arange(0, BLOCK_SIZE)
output_offs = tl.arange(0, BLOCK_SIZE)
x1 = tl.load(
hidden_states_ptr + input_start + input1_offs, mask=mask, other=0.0
).to(tl.float32)
x3 = tl.load(
hidden_states_ptr + input_start + input3_offs, mask=mask, other=0.0
).to(tl.float32)
# silu
# cast down before mul to better match training?
silu_x1 = x1 * tl.sigmoid(x1)
out = x3 * silu_x1.to(hidden_states_ptr.dtype.element_ty)
if quant_max is not None:
raise NotImplementedError()
tl.store(out_hidden_states_ptr + output_start + output_offs, out, mask=mask)
def silu_and_mul_triton(
hidden_states,
scales=None,
quantize=None, # dtype to quantize to
out=None,
):
bs, in_hidden_dim = hidden_states.shape
hidden_dim = in_hidden_dim // 2
if out is None:
out_hidden_states = torch.empty(
(bs, hidden_dim),
dtype=quantize or hidden_states.dtype,
device=hidden_states.device,
)
else:
assert out.shape == (bs, hidden_dim)
assert out.dtype == (quantize or hidden_states.dtype)
out_hidden_states = out
out_scales = None
static_scale = False
if quantize is not None:
if scales is None:
out_scales = torch.empty(
(bs,), dtype=torch.float32, device=hidden_states.device
)
else:
out_scales = scales
static_scale = True
max_warps = 16 if _is_hip else 32
config = {
# 8 ele per thread (not tuned)
"num_warps": max(
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), max_warps), 4
),
}
silu_and_mul_kernel[(bs,)](
out_hidden_states,
out_scales,
hidden_states,
quant_max=torch.finfo(quantize).max if quantize is not None else None,
static_scale=static_scale,
hidden_dim=hidden_dim,
BLOCK_SIZE=triton.next_power_of_2(hidden_dim),
**config,
)
if quantize is not None:
return out_hidden_states, out_scales
else:
return out_hidden_states, None
......@@ -45,11 +45,14 @@ def fused_moe_router_kernel(
logits = tl.sum((w_router.to(tl.float32) * x[None, :].to(tl.float32)), axis=-1)
# logit softcap
logits_scaled = logits / moe_softcapping
exped = tl.exp(2 * logits_scaled)
top = exped - 1
bottom = exped + 1
logits_softcapped = top / bottom * moe_softcapping
if moe_softcapping == 0:
logits_softcapped = logits
else:
logits_scaled = logits / moe_softcapping
exped = tl.exp(2 * logits_scaled)
top = exped - 1
bottom = exped + 1
logits_softcapped = top / bottom * moe_softcapping
# Add bias after softcapping
if is_correction_bias:
......@@ -207,9 +210,12 @@ def fused_moe_router_large_bs_kernel(
b_ptrs += BLOCK_SIZE_K
# 4. logit softcap
logits_scaled = acc / moe_softcapping
exped = tl.exp(2 * logits_scaled)
logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping
if moe_softcapping == 0:
logits_softcapped = acc
else:
logits_scaled = acc / moe_softcapping
exped = tl.exp(2 * logits_scaled)
logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping
# 5. top1
arange_block_size_n = tl.arange(0, BLOCK_SIZE_N)[None, :]
......@@ -234,7 +240,7 @@ def fused_moe_router_large_bs_kernel(
# 7. handle topk == 2
if topk == 2:
cond_top2 = (arange_block_size_n < num_experts) and (
cond_top2 = (arange_block_size_n < num_experts) & (
arange_block_size_n != top1[:, None]
)
top2 = tl.argmax(
......
......@@ -52,6 +52,8 @@ class RadixAttention(nn.Module):
v_head_dim: int = -1,
sliding_window_size: int = -1,
is_cross_attention: bool = False,
pos_encoding_mode: str = "NONE",
logit_capping_method: str = "tanh",
quant_config: Optional[QuantizationConfig] = None,
attn_type: AttentionType = AttentionType.DECODER,
use_irope: bool = False,
......@@ -81,6 +83,10 @@ class RadixAttention(nn.Module):
self.quant_method.create_weights(self)
self.attn_type = attn_type
self.pos_encoding_mode = pos_encoding_mode
self.logit_capping_method = logit_capping_method
self.xai_temperature_len = -1
def forward(
self,
q,
......
......@@ -16,7 +16,6 @@
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
"""Inference-only Grok1 model."""
import functools
import json
import logging
import math
import os
......@@ -35,9 +34,16 @@ from sglang.srt.distributed import (
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.elementwise import fused_dual_residual_rmsnorm, fused_rmsnorm
from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.elementwise import (
experts_combine_triton,
fused_dual_residual_rmsnorm,
fused_rmsnorm,
gelu_and_mul_triton,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
......@@ -49,7 +55,12 @@ from sglang.srt.layers.moe.router import fused_moe_router_shim
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.rotary_embedding import (
RotaryEmbedding,
_yarn_find_correction_range,
_yarn_get_mscale,
get_rope,
)
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
......@@ -58,13 +69,60 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.loader import DefaultModelLoader
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import dump_to_file
from sglang.srt.utils import add_prefix, dispose_tensor, dump_to_file
logger = logging.getLogger(__name__)
# Dump tensors for debugging
debug_tensor_dump_output_folder = None
debug_tensor_dump_prefill_only = False
# Skip all the other tensor dumps, only dump the target logits
debug_tensor_dump_only_target_logprobs = False
debug_tensor_dump_inject = False
debug_tensor_dump_layers = None
debug_tensor_dump_test = False
class Grok1MLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
layer_id: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
reduce_results=True,
use_presharded_weights: bool = False,
split_gate_up: bool = False,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=add_prefix("gate_up_proj", prefix),
use_presharded_weights=use_presharded_weights,
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=add_prefix("down_proj", prefix),
reduce_results=reduce_results,
use_presharded_weights=use_presharded_weights,
)
self.act_fn = GeluAndMul(approximate="tanh")
self.layer_id = layer_id
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x, _ = gelu_and_mul_triton(gate_up)
x, _ = self.down_proj(x)
return x
class Grok1MoE(nn.Module):
......@@ -87,10 +145,11 @@ class Grok1MoE(nn.Module):
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
reduce_results=True,
reduce_results: bool = True,
use_presharded_weights: bool = False,
inplace: bool = True,
no_combine: bool = False,
prefix: str = "",
):
super().__init__()
self.hidden_size = hidden_size
......@@ -145,6 +204,135 @@ class Grok1MoE(nn.Module):
return self.experts(hidden_states, topk_output)
def _yarn_linear_ramp_mask(
low: float, high: float, dim: int, dtype: torch.dtype
) -> torch.Tensor:
if low == high:
low -= 0.001 # Prevent singularity
linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func
def get_rope_scaling(config):
rope_type = getattr(config, "rope_type", None)
if rope_type:
original_max_position_embeddings = getattr(
config, "original_max_position_embeddings", None
)
scaling_factor = getattr(config, "scaling_factor", None)
extrapolation_factor = getattr(config, "extrapolation_factor", 1.0)
attn_factor = getattr(config, "attn_factor", 1.0)
beta_fast = getattr(config, "beta_fast", 32)
beta_slow = getattr(config, "beta_slow", 1)
rope_scaling = {
"extra_method": rope_type,
"max_position_embeddings": original_max_position_embeddings,
"scaling_factor": scaling_factor,
"extrapolation_factor": extrapolation_factor,
"attn_factor": attn_factor,
"beta_fast": beta_fast,
"beta_slow": beta_slow,
"dtype": torch.float,
}
return rope_scaling
else:
return None
class ScalingRotaryEmbedding(RotaryEmbedding):
"""Scale the RotaryEmbedding in a way similar to YaRN method. https://arxiv.org/pdf/2309.00071."""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
*,
extra_method: str = "yarn_log",
extrapolation_factor: float = 1,
attn_factor: float = 1,
beta_fast: int = 32,
beta_slow: int = 1,
) -> None:
self.scaling_factor = scaling_factor
self.extra_method = extra_method
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
# Get n-d magnitude scaling corrected for interpolation
self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor)
super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
pos_freqs = self.base ** (
torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
low, high = _yarn_find_correction_range(
self.beta_fast,
self.beta_slow,
self.rotary_dim,
self.base,
self.max_position_embeddings,
)
# Get n-d rotational scaling corrected for extrapolation
inv_freq_mask = (
1
- _yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float)
) * self.extrapolation_factor
if self.extra_method in ["original"]:
inv_freq = inv_freq_extrapolation
elif self.extra_method in ["yarn", "yarn_linear"]:
inv_freq = (
inv_freq_interpolation * (1 - inv_freq_mask)
+ inv_freq_extrapolation * inv_freq_mask
)
elif self.extra_method == "yarn_log":
inv_freq = torch.exp(
torch.log(inv_freq_extrapolation) * inv_freq_mask
+ torch.log(inv_freq_interpolation) * (1.0 - inv_freq_mask)
)
elif self.extra_method == "theta_scale":
exponents = torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
theta_scale_exponent = self.base ** (
math.log(
self.max_position_embeddings * self.scaling_factor / (2 * math.pi)
)
/ math.log(self.max_position_embeddings / (2 * math.pi))
)
inv_freq = torch.tensor(
1.0 / (theta_scale_exponent ** (exponents / self.rotary_dim)),
dtype=torch.float32,
)
else:
raise ValueError(f"Unknown extrapolation method: {self.extra_method}")
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.scaling_factor)
t = torch.arange(
self.max_position_embeddings * self.scaling_factor, dtype=torch.float32
)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
# cos = freqs.cos() * self.mscale
# sin = freqs.sin() * self.mscale
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache
class Grok1Attention(nn.Module):
def __init__(
self,
......@@ -157,7 +345,9 @@ class Grok1Attention(nn.Module):
rope_theta: float = 10000,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
alt_stream: Optional[torch.cuda.Stream] = None,
load_presharded_attn: bool = False,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
......@@ -183,7 +373,9 @@ class Grok1Attention(nn.Module):
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
rope_scaling = get_rope_scaling(config)
self.load_presharded_attn = load_presharded_attn
self.alt_stream = alt_stream or torch.cuda.Stream()
self.qkv_proj = QKVParallelLinear(
hidden_size,
......@@ -195,6 +387,7 @@ class Grok1Attention(nn.Module):
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
load_presharded_attn=self.load_presharded_attn,
prefix=add_prefix("qkv_proj", prefix),
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
......@@ -205,6 +398,7 @@ class Grok1Attention(nn.Module):
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
use_presharded_weights=self.load_presharded_attn,
prefix=add_prefix("o_proj", prefix),
)
self.rotary_emb = get_rope(
self.head_dim,
......@@ -214,7 +408,37 @@ class Grok1Attention(nn.Module):
is_neox_style=True,
)
self.rope_rotate_half_dims = getattr(config, "rope_rotate_half_dims", False)
if rope_scaling is not None:
self.rotary_emb = ScalingRotaryEmbedding(
self.head_dim,
rotary_dim=(
self.head_dim
if not self.rope_rotate_half_dims
else self.head_dim // 2
),
base=int(self.rope_theta),
is_neox_style=True,
**rope_scaling,
)
pos_encoding_mode = "NONE"
else:
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=(
self.head_dim
if not self.rope_rotate_half_dims
else self.head_dim // 2
),
max_position=max_position,
base=int(self.rope_theta),
is_neox_style=True,
)
pos_encoding_mode = "NONE"
logit_cap = max(getattr(config, "attn_logit_softcapping", 30.0), 0.0)
logit_capping_method = getattr(config, "attn_logit_softcapping_method", "tanh")
self.attn = RadixAttention(
self.num_heads,
......@@ -224,7 +448,11 @@ class Grok1Attention(nn.Module):
layer_id=layer_id,
logit_cap=logit_cap,
quant_config=quant_config,
pos_encoding_mode=pos_encoding_mode,
logit_capping_method=logit_capping_method,
prefix=add_prefix("attn", prefix),
)
self.attn.xai_temperature_len = getattr(self.config, "attn_temperature_len", -1)
def forward(
self,
......@@ -256,6 +484,8 @@ class Grok1Attention(nn.Module):
)
qkv, _ = self.qkv_proj(hidden_states)
dispose_tensor(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
......@@ -288,6 +518,7 @@ class Grok1Attention(nn.Module):
)
attn_output = self.attn(q, k, v, forward_batch)
del q, k, v, qkv
if debug_tensor_dump_output_folder:
dump_to_file(
......@@ -312,49 +543,89 @@ class Grok1DecoderLayer(nn.Module):
load_presharded_moe: bool = False,
load_presharded_attn: bool = False,
load_presharded_mlp: bool = False,
alt_stream: Optional[torch.cuda.Stream] = None,
skip_moe: bool = False,
prefix: str = "",
) -> None:
super().__init__()
self.num_experts = config.num_local_experts
self.hidden_size = config.hidden_size
self.residual_moe = getattr(config, "residual_moe", False)
self.layer_id = layer_id
self.alt_stream = alt_stream or torch.cuda.Stream()
rope_theta = getattr(config, "rope_theta", 10000)
self.self_attn = Grok1Attention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
max_position=(
config.context_len
if hasattr(config, "context_len")
else config.max_position_embeddings
),
num_kv_heads=config.num_key_value_heads,
layer_id=layer_id,
rope_theta=rope_theta,
quant_config=quant_config,
reduce_results=False,
alt_stream=self.alt_stream,
load_presharded_attn=load_presharded_attn,
prefix=add_prefix("attn", prefix),
)
self.block_sparse_moe = Grok1MoE(
config=config,
layer_id=layer_id,
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=getattr(
config,
"moe_intermediate_size",
getattr(config, "intermediate_size", None),
),
quant_config=quant_config,
reduce_results=True,
use_presharded_weights=load_presharded_moe,
inplace=True,
no_combine=False, # just a suggestion to not combine topk
)
split_gate_up = not getattr(config, "merge_gate_up", True)
if self.num_experts > 0:
self.block_sparse_moe = Grok1MoE(
config=config,
layer_id=layer_id,
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=getattr(
config,
"moe_intermediate_size",
getattr(config, "intermediate_size", None),
),
quant_config=quant_config,
reduce_results=not self.residual_moe,
use_presharded_weights=load_presharded_moe,
inplace=False, # not self.residual_moe,
no_combine=False, # self.residual_moe, # just a suggestion to not combine topk
prefix=add_prefix("block_sparse_moe", prefix),
)
if self.residual_moe:
self.mlp = Grok1MLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
quant_config=quant_config,
reduce_results=False,
use_presharded_weights=load_presharded_mlp,
layer_id=layer_id,
split_gate_up=split_gate_up,
)
else:
raise NotImplementedError()
self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.ffn = self.block_sparse_moe
if self.num_experts > 0:
if self.residual_moe:
# NOTE: self.block_sparse_moe modifies the input in-place,
# so we have to call it later. Be aware of any possible related errors.
if get_tensor_model_parallel_world_size() > 1:
self.ffn = lambda x: tensor_model_parallel_all_reduce(
self.moe_with_rmoe(x)
)
else:
self.ffn = self.moe_with_rmoe
else:
self.ffn = self.block_sparse_moe
else:
raise NotImplementedError()
def forward(
self,
......@@ -364,6 +635,10 @@ class Grok1DecoderLayer(nn.Module):
residual: Optional[torch.Tensor] = None,
deferred_norm: Optional[RMSNorm] = None,
) -> Tuple[torch.Tensor, torch.Tensor, RMSNorm]:
hidden_states_original = hidden_states
residual_original = residual
# Self Attention
if deferred_norm is not None:
assert residual is not None
......@@ -386,6 +661,14 @@ class Grok1DecoderLayer(nn.Module):
hidden_states,
)
if residual_original is not None:
dispose_tensor(residual_original)
dispose_flag = False
if residual is not hidden_states_original:
dispose_flag = True
dispose_tensor(hidden_states_original)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
......@@ -403,10 +686,23 @@ class Grok1DecoderLayer(nn.Module):
self.post_attn_norm.variance_epsilon,
)
if not dispose_flag:
dispose_tensor(hidden_states_original)
# Fully Connected
hidden_states = self.ffn(hidden_states)
return hidden_states, residual, self.post_moe_norm # defer layernorm
def moe_with_rmoe(self, x):
current_stream = torch.cuda.current_stream()
self.alt_stream.wait_stream(current_stream)
mlp_result = self.mlp(x)
with torch.cuda.stream(self.alt_stream):
# moe should not be inplace because of stream race condition
moe_result = self.block_sparse_moe(x)
current_stream.wait_stream(self.alt_stream)
return (mlp_result + moe_result) / 1.4142135623730951
class Grok1Model(nn.Module):
def __init__(
......@@ -417,6 +713,8 @@ class Grok1Model(nn.Module):
load_presharded_embedding: bool = False,
load_presharded_attn: bool = False,
load_presharded_mlp: bool = False,
replicate_embedding: bool = False,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
......@@ -427,7 +725,11 @@ class Grok1Model(nn.Module):
config.vocab_size,
config.hidden_size,
use_presharded_weights=load_presharded_embedding,
enable_tp=not replicate_embedding,
prefix=add_prefix("embed_tokens", prefix),
)
self.alt_stream = torch.cuda.Stream()
self.layers = nn.ModuleList(
[
Grok1DecoderLayer(
......@@ -437,6 +739,7 @@ class Grok1Model(nn.Module):
load_presharded_moe=load_presharded_moe,
load_presharded_attn=load_presharded_attn,
load_presharded_mlp=load_presharded_mlp,
alt_stream=self.alt_stream,
)
for i in range(config.num_hidden_layers)
]
......@@ -506,6 +809,7 @@ class Grok1ForCausalLM(nn.Module):
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
......@@ -514,7 +818,8 @@ class Grok1ForCausalLM(nn.Module):
# Get presharded weights.
self.load_presharded_mlp = getattr(config, "load_presharded_mlp", False)
self.load_presharded_moe = (
self.config.num_local_experts > 0
getattr(config, "load_presharded_moe", True)
and self.config.num_local_experts > 0
and get_tensor_model_parallel_world_size() > 1
)
self.load_presharded_attn = getattr(config, "load_presharded_attn", False)
......@@ -529,6 +834,11 @@ class Grok1ForCausalLM(nn.Module):
or self.load_presharded_embedding
)
default_replicate_lm_head = False
self.replicate_lm_head = getattr(
config, "replicate_lm_head", default_replicate_lm_head
)
if self.is_weights_presharded:
setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
......@@ -536,6 +846,7 @@ class Grok1ForCausalLM(nn.Module):
self.replicate_lm_head = getattr(
config, "replicate_lm_head", default_replicate_lm_head
)
self.replicate_embedding = getattr(config, "replicate_embedding", False)
self.model = Grok1Model(
config,
......@@ -544,6 +855,8 @@ class Grok1ForCausalLM(nn.Module):
load_presharded_embedding=self.load_presharded_embedding,
load_presharded_attn=self.load_presharded_attn,
load_presharded_mlp=self.load_presharded_mlp,
replicate_embedding=self.replicate_embedding,
prefix=add_prefix("model", prefix),
)
lm_head_params_dtype = None
......@@ -553,6 +866,7 @@ class Grok1ForCausalLM(nn.Module):
config.vocab_size,
bias=False,
params_dtype=lm_head_params_dtype,
prefix=add_prefix("lm_head", prefix),
)
self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
else:
......@@ -561,6 +875,7 @@ class Grok1ForCausalLM(nn.Module):
config.hidden_size,
use_presharded_weights=self.load_presharded_embedding,
params_dtype=lm_head_params_dtype,
prefix=add_prefix("lm_head", prefix),
)
self.logits_processor = LogitsProcessor(config)
......@@ -577,6 +892,7 @@ class Grok1ForCausalLM(nn.Module):
f"#parameters (analytical): {self.get_num_params_analytical() / 1e9:.2f} B, "
f"#parameters (actual): {self.get_num_params_torch() / 1e9:.2f} B"
)
self.loaded_param_names = set()
def forward(
self,
......@@ -596,11 +912,13 @@ class Grok1ForCausalLM(nn.Module):
def load_weights(
self,
weights: Iterable[Tuple[str, torch.Tensor]],
num_experts: Optional[int] = None,
ignore_parent_name: bool = False,
check_hit_names: bool = True,
model_config: PretrainedConfig | None = None,
) -> dict[str, torch.Tensor]:
if num_experts is None:
num_experts = self.config.num_local_experts
if model_config is None:
model_config = self.config
stacked_params_mapping = []
stacked_params_mapping += [
# (param_name, shard_name, shard_id)
......@@ -616,6 +934,7 @@ class Grok1ForCausalLM(nn.Module):
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
num_experts = model_config.num_local_experts
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="w1",
ckpt_down_proj_name="w2",
......@@ -630,23 +949,26 @@ class Grok1ForCausalLM(nn.Module):
def load_weight_wrapper(
name: str, loaded_weight: torch.Tensor, *args, **kwargs
):
if ignore_parent_name:
name = name.split(".")[-1]
if name not in params_dict:
return
# Fuse constant multipliers into the weights
if "lm_head" in name:
loaded_weight = (
loaded_weight.to(torch.float32)
* self.config.output_multiplier_scale
* model_config.output_multiplier_scale
)
original_name = name
if ignore_parent_name:
name = name.split(".")[-1]
if name not in params_dict:
logger.info(f"Skipping {name=} in load_weights_wrapper")
return
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight, *args, **kwargs)
hit_names.add(name)
self.loaded_param_names.add(original_name)
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
......@@ -685,19 +1007,22 @@ class Grok1ForCausalLM(nn.Module):
load_weight_wrapper(name=name, loaded_weight=loaded_weight)
if len(hit_names) > 5:
missing = all_names - hit_names
missing_exclude_scales = {x for x in missing if "scale" not in x}
logger.info(
f"#all_names: {len(all_names)}, #hit_names: {len(hit_names)}, #missing_exclude_scales: {len(missing_exclude_scales)}",
)
if len(missing_exclude_scales) > 0:
raise ValueError(
f"load_weights failed because some weights are missing: {missing_exclude_scales=}."
if check_hit_names:
if len(hit_names) > 5:
missing = all_names - hit_names
missing_exclude_scales = {x for x in missing if "scale" not in x}
logger.info(
f"#all_names: {len(all_names)}, #hit_names: {len(hit_names)}, #missing_exclude_scales: {len(missing_exclude_scales)}",
)
if len(missing_exclude_scales) > 0:
raise ValueError(
f"load_weights failed because some weights are missing: {missing_exclude_scales=}."
)
elif len(hit_names) == 0:
raise ValueError("load_weights failed because it did not hit any names.")
elif len(hit_names) == 0:
raise ValueError(
f"load_weights failed because it did not hit any names. {all_names=} {hit_names=}"
)
return hit_names
......@@ -708,7 +1033,11 @@ class Grok1ForCausalLM(nn.Module):
"moe_intermediate_size",
getattr(cfg, "intermediate_size", None),
)
num_experts = cfg.num_local_experts
residual_moe = getattr(cfg, "residual_moe", False)
if cfg.num_local_experts > 0:
num_experts = cfg.num_local_experts + (1 if residual_moe else 0)
else:
num_experts = 1
wq = (
cfg.num_hidden_layers
......
import functools
import json
from typing import AbstractSet, Collection, List, Literal, Union
class TiktokenProcessor:
def __init__(self, name: str):
self.tokenizer = TiktokenTokenizer(name)
def image_processor(self, image):
return {"pixel_values": [image]}
RESERVED_TOKEN_TEXTS = [f"<|reserved_{i}|>" for i in range(3, 128)]
CONTROL_TOKEN_TEXTS = [f"<|control{i}|>" for i in range(1, 705)]
PAD = "<|pad|>"
EOS = "<|eos|>"
SEP = "<|separator|>"
DEFAULT_SPECIAL_TOKENS = [PAD, SEP, EOS]
DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": EOS, "eos": SEP}
# default + separate each single digit
PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
class TiktokenTokenizer:
def __init__(self, tokenizer_path):
import tiktoken
from jinja2 import Template
# Read the JSON
with open(tokenizer_path, "rb") as fin:
xtok_dict = json.load(fin)
# Copy from train/xlm/tokenizers/tiktoken_wrapper.py::Encoding::from_xtok_dict
mergeable_ranks = {
bytes(item["bytes"]): item["token"] for item in xtok_dict["regular_tokens"]
}
special_tokens = {
bytes(item["bytes"]).decode(): item["token"]
for item in xtok_dict["special_tokens"]
}
if xtok_dict["word_split"] == "V1":
pad_str = PAT_STR_B
else:
assert False, f"Unknown word_split: {xtok_dict['word_split']}"
pad_str = xtok_dict.get("pat_str", pad_str)
kwargs = {
"name": tokenizer_path,
"pat_str": pad_str,
"mergeable_ranks": mergeable_ranks,
"special_tokens": special_tokens,
}
if "default_allowed_special" in xtok_dict:
default_allowed_special = set(
[
bytes(bytes_list).decode()
for bytes_list in xtok_dict["default_allowed_special"]
]
)
if "vocab_size" in xtok_dict:
kwargs["explicit_n_vocab"] = xtok_dict["vocab_size"]
# Copy from train/xlm/tokenizers/tiktoken_wrapper.py::Encoding::__init__
default_allowed_special = None
control_tokens = DEFAULT_CONTROL_TOKENS
tokenizer = tiktoken.Encoding(**kwargs)
tokenizer._default_allowed_special = default_allowed_special or set()
tokenizer._control_tokens = control_tokens
def encode_patched(
self,
text: str,
*,
allowed_special: Union[
Literal["all"], AbstractSet[str]
] = set(), # noqa: B006
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
) -> List[int]:
if isinstance(allowed_special, set):
allowed_special |= self._default_allowed_special
return tiktoken.Encoding.encode(
self,
text,
allowed_special=allowed_special,
disallowed_special=(),
)
tokenizer.encode = functools.partial(encode_patched, tokenizer)
# Allow more tokens to prevent crash
tokenizer._default_allowed_special |= set(DEFAULT_CONTROL_TOKENS.values())
tokenizer._default_allowed_special |= set(
CONTROL_TOKEN_TEXTS + RESERVED_TOKEN_TEXTS
)
# Convert to HF interface
self.tokenizer = tokenizer
self.bos_token_id = None
self.eos_token_id = tokenizer._special_tokens[EOS]
self.vocab_size = tokenizer.n_vocab
self.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
self.chat_template_jinja = Template(self.chat_template)
self.additional_stop_token_ids = None
def encode(self, x, add_special_tokens=False):
return self.tokenizer.encode(x)
def decode(self, x, *args, **kwargs):
return self.tokenizer.decode(x)
def batch_decode(
self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
):
if len(batch) > 0 and isinstance(batch[0], int):
batch = [[x] for x in batch]
return self.tokenizer.decode_batch(batch)
def apply_chat_template(
self, messages, tokenize, add_generation_prompt, tools=None
):
ret = self.chat_template_jinja.render(
messages=messages, add_generation_prompt=add_generation_prompt
)
return self.encode(ret) if tokenize else ret
def __call__(self, text, **kwargs):
return {
"input_ids": self.encode(text),
}
def init_xgrammar(self):
from xgrammar import TokenizerInfo
XGRAMMAR_SPECIAL_TOKEN_TEMPLATE = "<|xg_special_token_{}|>"
enc = self.tokenizer
encoded_vocab = {**enc._mergeable_ranks, **enc._special_tokens}
encoded_vocab = [
token for token, _ in sorted(encoded_vocab.items(), key=lambda x: x[1])
]
override_stop_tokens = [2] # eos
# These are treated as special tokens in xgrammar; we want to avoid them
# For now, xgrammar treats anything starting with b'\x00' as a special token
xgrammar_special_token_ids = []
for i, token in enumerate(encoded_vocab):
if isinstance(token, bytes) and token.startswith(b"\x00"):
xgrammar_special_token_ids.append(i)
for i, id in enumerate(xgrammar_special_token_ids):
encoded_vocab[id] = XGRAMMAR_SPECIAL_TOKEN_TEMPLATE.format(i)
tokenizer_info = TokenizerInfo(
encoded_vocab, stop_token_ids=override_stop_tokens
)
assert len(tokenizer_info.special_token_ids) == 0
return tokenizer_info, override_stop_tokens
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment