"git@developer.sourcefind.cn:change/sglang.git" did not exist on "bb418ced802c6dbb6b0ae0d65218327129148769"
Unverified commit 86d10d22, authored by Lianmin Zheng and committed by GitHub

Update grok.py and tiktoken tokenizer (#9532)

parent 83871aa1
@@ -162,12 +162,16 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
     ):
         super().__init__()
 
-        # Create TokenizerInfo with model's EOS tokens as the authoritative stop tokens
-        # This ensures consistency between what the model considers EOS and what XGrammar uses
-        tokenizer_info = TokenizerInfo.from_huggingface(
-            tokenizer, vocab_size=vocab_size, stop_token_ids=model_eos_token_ids
-        )
-        override_stop_tokens = None
+        if hasattr(tokenizer, "init_xgrammar"):
+            # For special tokenizer
+            tokenizer_info, override_stop_tokens = tokenizer.init_xgrammar()
+        else:
+            # Create TokenizerInfo with model's EOS tokens as the authoritative stop tokens
+            # This ensures consistency between what the model considers EOS and what XGrammar uses
+            tokenizer_info = TokenizerInfo.from_huggingface(
+                tokenizer, vocab_size=vocab_size, stop_token_ids=model_eos_token_ids
+            )
+            override_stop_tokens = None
 
         self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
         self.vocab_size = vocab_size
...
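With this branch, any tokenizer can opt into constrained decoding by exposing an init_xgrammar() method that returns a (tokenizer_info, override_stop_tokens) pair. A minimal sketch of such a tokenizer (the class name and toy vocabulary are illustrative, not part of this commit):

from xgrammar import TokenizerInfo


class MyJsonVocabTokenizer:  # hypothetical example
    def init_xgrammar(self):
        # Token strings ordered by token id (toy vocabulary for illustration only).
        encoded_vocab = ["<|pad|>", "<|separator|>", "<|eos|>", "a", "b"]
        override_stop_tokens = [2]  # id of <|eos|> in this toy vocabulary
        tokenizer_info = TokenizerInfo(encoded_vocab, stop_token_ids=override_stop_tokens)
        return tokenizer_info, override_stop_tokens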
@@ -263,6 +263,11 @@ def get_tokenizer(
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
     """Gets a tokenizer for the given model name via Huggingface."""
+    if tokenizer_name.endswith(".json"):
+        from sglang.srt.tokenizer.tiktoken_tokenizer import TiktokenTokenizer
+
+        return TiktokenTokenizer(tokenizer_name)
+
     if tokenizer_mode == "slow":
         if kwargs.get("use_fast", False):
             raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
...
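With this hook, passing a path that ends in .json makes get_tokenizer return the new tiktoken-based wrapper instead of a Hugging Face tokenizer. A hedged usage sketch (the import location and file path are assumptions; the diff does not show which module defines get_tokenizer):

# Assumed import location; adjust to wherever get_tokenizer lives in your tree.
from sglang.srt.hf_transformers_utils import get_tokenizer

tokenizer = get_tokenizer("/path/to/tokenizer.json")  # tiktoken-style JSON dump, placeholder path
print(tokenizer.decode(tokenizer.encode("Hello world")))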
@@ -20,6 +20,14 @@ if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 
 
+def logit_capping_mod(logit_capping_method, logit_cap):
+    # positive logit_cap -> tanh cap
+    if logit_capping_method == "tanh":
+        return logit_cap
+    else:
+        raise ValueError()
+
+
 @dataclass
 class ForwardMetadata:
     attn_logits: torch.Tensor
@@ -718,6 +726,8 @@ class TritonAttnBackend(AttentionBackend):
                 layer, forward_batch.out_cache_loc, k, v
             )
 
+        logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap)
+
         causal = True
         if layer.attn_type == AttentionType.ENCODER_ONLY:
             causal = False
@@ -750,10 +760,11 @@ class TritonAttnBackend(AttentionBackend):
             self.forward_metadata.mask_indptr,
             self.forward_metadata.max_extend_len,
             layer.scaling,
-            layer.logit_cap,
+            logit_cap=logits_soft_cap,
             sliding_window_size=sliding_window_size,
             sinks=sinks,
             window_kv_offsets=window_kv_offsets,
+            xai_temperature_len=layer.xai_temperature_len,
         )
         return o
@@ -777,6 +788,8 @@ class TritonAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)
 
+        logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap)
+
         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(
                 layer, forward_batch.out_cache_loc, k, v
@@ -801,8 +814,9 @@ class TritonAttnBackend(AttentionBackend):
             self.forward_metadata.num_kv_splits,
             self.max_kv_splits,
             layer.scaling,
-            layer.logit_cap,
+            logit_cap=logits_soft_cap,
             sinks=sinks,
+            xai_temperature_len=layer.xai_temperature_len,
         )
         return o
...
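The backend now resolves the soft cap once per call through logit_capping_mod and passes it to the kernels, which apply the usual tanh cap when it is positive. A plain PyTorch reference of that transform, for intuition only (not part of the commit):

import torch


def soft_cap_reference(scores: torch.Tensor, logit_cap: float) -> torch.Tensor:
    # Mirrors the kernel branch `qk = logit_cap * tanh(qk / logit_cap)`;
    # a non-positive cap leaves the scores unchanged.
    if logit_cap > 0:
        return logit_cap * torch.tanh(scores / logit_cap)
    return scores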
@@ -69,6 +69,7 @@ def _fwd_kernel_stage1(
     logit_cap: tl.constexpr,
     Lk: tl.constexpr,
     Lv: tl.constexpr,
+    xai_temperature_len: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -85,6 +86,12 @@ def _fwd_kernel_stage1(
     cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
     kv_splits = tl.load(num_kv_splits + cur_batch)
 
+    if xai_temperature_len > 0:
+        offs_qidx = cur_batch_seq_len - 1
+        xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
+        _qtemp = tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale
+        xai_temperature_reg = tl.where(offs_qidx > xai_temperature_len, _qtemp, 1.0)
+
     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
 
     kv_len_per_split = (
@@ -122,6 +129,9 @@ def _fwd_kernel_stage1(
             if logit_cap > 0:
                 qk = logit_cap * tanh(qk / logit_cap)
 
+            if xai_temperature_len > 0:
+                qk *= xai_temperature_reg
+
             qk = tl.where(offs_n < split_kv_end, qk, float("-inf"))
 
             offs_buf_v = (
@@ -181,6 +191,7 @@ def _decode_att_m_fwd(
     max_kv_splits,
     sm_scale,
     logit_cap,
+    xai_temperature_len=-1,
 ):
     BLOCK = 64
     # [TODO] work around SGPR limit on MI3xx
@@ -230,6 +241,7 @@ def _decode_att_m_fwd(
         BLOCK_N=BLOCK,
         MIN_BLOCK_KV=_MIN_BLOCK_KV,
         logit_cap=logit_cap,
+        xai_temperature_len=xai_temperature_len,
         num_warps=num_warps,
         num_stages=2,
         Lk=Lk,
@@ -266,6 +278,7 @@ def _fwd_grouped_kernel_stage1(
     BLOCK_H: tl.constexpr,
     MIN_BLOCK_KV: tl.constexpr,
     logit_cap: tl.constexpr,
+    xai_temperature_len: tl.constexpr,
     Lk: tl.constexpr,
     Lv: tl.constexpr,
 ):
@@ -291,6 +304,12 @@ def _fwd_grouped_kernel_stage1(
     cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
     kv_splits = tl.load(num_kv_splits + cur_batch)
 
+    if xai_temperature_len > 0:
+        offs_qidx = cur_batch_seq_len - 1
+        xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
+        _qtemp = tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale
+        xai_temperature_reg = tl.where(offs_qidx > xai_temperature_len, _qtemp, 1.0)
+
     offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
 
     if BLOCK_DPE > 0:
@@ -351,6 +370,9 @@ def _fwd_grouped_kernel_stage1(
             if logit_cap > 0:
                 qk = logit_cap * tanh(qk / logit_cap)
 
+            if xai_temperature_len > 0:
+                qk *= xai_temperature_reg[:, None]
+
             qk = tl.where(
                 mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf")
             )
@@ -413,6 +435,7 @@ def _decode_grouped_att_m_fwd(
     max_kv_splits,
     sm_scale,
     logit_cap,
+    xai_temperature_len=-1,
 ):
     BLOCK = 32
     Lk = k_buffer.shape[-1]
@@ -480,6 +503,7 @@ def _decode_grouped_att_m_fwd(
         BLOCK_H=BLOCK_H,
         MIN_BLOCK_KV=_MIN_BLOCK_KV,
         logit_cap=logit_cap,
+        xai_temperature_len=xai_temperature_len,
         num_warps=4,
         num_stages=num_stages,
         Lk=Lk,
@@ -620,6 +644,7 @@ def decode_attention_fwd_normal(
     sm_scale,
     logit_cap=0.0,
     sinks=None,
+    xai_temperature_len=-1,
 ):
     _decode_att_m_fwd(
         q,
@@ -633,6 +658,7 @@ def decode_attention_fwd_normal(
         max_kv_splits,
         sm_scale,
         logit_cap,
+        xai_temperature_len,
     )
     _decode_softmax_reducev_fwd(
         attn_logits,
@@ -661,6 +687,7 @@ def decode_attention_fwd_grouped(
     sm_scale,
     logit_cap=0.0,
     sinks=None,
+    xai_temperature_len=-1,
 ):
     _decode_grouped_att_m_fwd(
         q,
@@ -674,6 +701,7 @@ def decode_attention_fwd_grouped(
         max_kv_splits,
         sm_scale,
         logit_cap,
+        xai_temperature_len,
     )
     _decode_softmax_reducev_fwd(
         attn_logits,
@@ -702,6 +730,7 @@ def decode_attention_fwd(
     sm_scale,
     logit_cap=0.0,
     sinks=None,
+    xai_temperature_len=-1,
 ):
     assert max_kv_splits == attn_logits.shape[2]
     assert q.shape[0] <= kv_indptr.shape[0] - 1
@@ -725,6 +754,7 @@ def decode_attention_fwd(
             sm_scale,
             logit_cap=logit_cap,
             sinks=sinks,
+            xai_temperature_len=xai_temperature_len,
         )
     else:
         # GQA/MQA/MLA
@@ -742,4 +772,5 @@ def decode_attention_fwd(
             sm_scale,
             logit_cap=logit_cap,
             sinks=sinks,
+            xai_temperature_len=xai_temperature_len,
         )
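The xai_temperature_len plumbing adds a position-dependent temperature: once a query token's position exceeds xai_temperature_len, its pre-softmax attention scores are multiplied by log2(position) / log2(xai_temperature_len). A small PyTorch sketch of the factor the kernels compute (the function name is illustrative, not part of the commit):

import torch


def xai_temperature_factor(positions: torch.Tensor, xai_temperature_len: int) -> torch.Tensor:
    # 1.0 for positions <= xai_temperature_len, otherwise log2(position) / log2(len),
    # matching the in-kernel `xai_temperature_reg` computation.
    if xai_temperature_len <= 0:
        return torch.ones_like(positions, dtype=torch.float32)
    scale = torch.log2(positions.float()) / torch.log2(torch.tensor(float(xai_temperature_len)))
    return torch.where(positions > xai_temperature_len, scale, torch.ones_like(scale))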
@@ -69,6 +69,7 @@ def _fwd_kernel(
     stride_buf_vh,
     SLIDING_WINDOW_SIZE: tl.constexpr,
     logit_cap: tl.constexpr,
+    xai_temperature_len: tl.constexpr,
     Lq: tl.constexpr,
     Lv: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
@@ -109,6 +110,15 @@ def _fwd_kernel(
     mask_d = offs_d < Lq
     mask_dv = offs_dv < Lv
 
+    if xai_temperature_len > 0:
+        offs_qidx = cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m
+        xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
+        xai_temperature_reg = tl.where(
+            offs_qidx > xai_temperature_len,
+            tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale,
+            1.0,
+        )
+
     offs_q = (
         (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_qbs
@@ -203,6 +213,9 @@ def _fwd_kernel(
         if logit_cap > 0:
             qk = logit_cap * tanh(qk / logit_cap)
 
+        if xai_temperature_len > 0:
+            qk *= xai_temperature_reg[:, None]
+
         qk = tl.where(final_mask, qk, float("-inf"))
 
         row_max = tl.max(qk, 1)
@@ -306,6 +319,9 @@ def _fwd_kernel(
         if logit_cap > 0:
             qk = logit_cap * tanh(qk / logit_cap)
 
+        if xai_temperature_len > 0:
+            qk *= xai_temperature_reg[:, None]
+
         qk = tl.where(final_mask, qk, float("-inf"))
 
         row_max = tl.max(qk, 1)
@@ -373,6 +389,7 @@ def extend_attention_fwd(
     sliding_window_size=-1,
     sinks=None,
     window_kv_offsets=None,
+    xai_temperature_len=-1,
 ):
     """
     q_extend, k_extend, v_extend, o_extend: contiguous tensors
@@ -477,6 +494,7 @@ def extend_attention_fwd(
         v_buffer.stride(1),
         SLIDING_WINDOW_SIZE=sliding_window_size,
         logit_cap=logit_cap,
+        xai_temperature_len=xai_temperature_len,
         BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_DPE=BLOCK_DPE,
         BLOCK_DV=BLOCK_DV,
...
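Note that the extend (prefill) kernel computes the absolute query position as cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m, while the decode kernels use cur_batch_seq_len - 1; both are the token's absolute position in the sequence, so chunked prefill and decode apply the same temperature factor to the same position.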
@@ -486,3 +486,97 @@ def gelu_and_mul_triton(
        return out_hidden_states, out_scales
    else:
        return out_hidden_states, None
# silu on first half of vector
@triton.jit
def silu_and_mul_kernel(
    out_hidden_states_ptr,  # (bs, hidden_dim)
    out_scales_ptr,  # (bs,)
    hidden_states_ptr,  # (bs, hidden_dim * 2)
    quant_max: tl.constexpr,
    static_scale: tl.constexpr,
    hidden_dim: tl.constexpr,  # the output hidden_dim
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)

    input_start = pid * hidden_dim * 2
    output_start = pid * hidden_dim

    input1_offs = tl.arange(0, BLOCK_SIZE)
    mask = tl.arange(0, BLOCK_SIZE) < hidden_dim  # shared for input1, input3, output
    input3_offs = hidden_dim + tl.arange(0, BLOCK_SIZE)
    output_offs = tl.arange(0, BLOCK_SIZE)

    x1 = tl.load(
        hidden_states_ptr + input_start + input1_offs, mask=mask, other=0.0
    ).to(tl.float32)
    x3 = tl.load(
        hidden_states_ptr + input_start + input3_offs, mask=mask, other=0.0
    ).to(tl.float32)

    # silu
    # cast down before mul to better match training?
    silu_x1 = x1 * tl.sigmoid(x1)
    out = x3 * silu_x1.to(hidden_states_ptr.dtype.element_ty)

    if quant_max is not None:
        raise NotImplementedError()

    tl.store(out_hidden_states_ptr + output_start + output_offs, out, mask=mask)


def silu_and_mul_triton(
    hidden_states,
    scales=None,
    quantize=None,  # dtype to quantize to
    out=None,
):
    bs, in_hidden_dim = hidden_states.shape
    hidden_dim = in_hidden_dim // 2

    if out is None:
        out_hidden_states = torch.empty(
            (bs, hidden_dim),
            dtype=quantize or hidden_states.dtype,
            device=hidden_states.device,
        )
    else:
        assert out.shape == (bs, hidden_dim)
        assert out.dtype == (quantize or hidden_states.dtype)
        out_hidden_states = out
    out_scales = None
    static_scale = False
    if quantize is not None:
        if scales is None:
            out_scales = torch.empty(
                (bs,), dtype=torch.float32, device=hidden_states.device
            )
        else:
            out_scales = scales
            static_scale = True

    max_warps = 16 if _is_hip else 32
    config = {
        # 8 ele per thread (not tuned)
        "num_warps": max(
            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), max_warps), 4
        ),
    }

    silu_and_mul_kernel[(bs,)](
        out_hidden_states,
        out_scales,
        hidden_states,
        quant_max=torch.finfo(quantize).max if quantize is not None else None,
        static_scale=static_scale,
        hidden_dim=hidden_dim,
        BLOCK_SIZE=triton.next_power_of_2(hidden_dim),
        **config,
    )

    if quantize is not None:
        return out_hidden_states, out_scales
    else:
        return out_hidden_states, None
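Functionally, silu_and_mul_triton is the usual SwiGLU-style gate: SiLU of the first half of the last dimension multiplied by the second half. A plain PyTorch reference for the unquantized path (equivalent up to rounding, useful for testing; not part of the commit):

import torch
import torch.nn.functional as F


def silu_and_mul_reference(hidden_states: torch.Tensor) -> torch.Tensor:
    # hidden_states: (bs, 2 * hidden_dim) -> (bs, hidden_dim)
    x1, x3 = hidden_states.chunk(2, dim=-1)
    # The kernel computes SiLU in fp32 and casts back to the input dtype before multiplying.
    silu_x1 = F.silu(x1.float()).to(hidden_states.dtype)
    return x3 * silu_x1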
@@ -45,11 +45,14 @@ def fused_moe_router_kernel(
     logits = tl.sum((w_router.to(tl.float32) * x[None, :].to(tl.float32)), axis=-1)
 
     # logit softcap
-    logits_scaled = logits / moe_softcapping
-    exped = tl.exp(2 * logits_scaled)
-    top = exped - 1
-    bottom = exped + 1
-    logits_softcapped = top / bottom * moe_softcapping
+    if moe_softcapping == 0:
+        logits_softcapped = logits
+    else:
+        logits_scaled = logits / moe_softcapping
+        exped = tl.exp(2 * logits_scaled)
+        top = exped - 1
+        bottom = exped + 1
+        logits_softcapped = top / bottom * moe_softcapping
 
     # Add bias after softcapping
     if is_correction_bias:
@@ -207,9 +210,12 @@ def fused_moe_router_large_bs_kernel(
         b_ptrs += BLOCK_SIZE_K
 
     # 4. logit softcap
-    logits_scaled = acc / moe_softcapping
-    exped = tl.exp(2 * logits_scaled)
-    logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping
+    if moe_softcapping == 0:
+        logits_softcapped = acc
+    else:
+        logits_scaled = acc / moe_softcapping
+        exped = tl.exp(2 * logits_scaled)
+        logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping
 
     # 5. top1
     arange_block_size_n = tl.arange(0, BLOCK_SIZE_N)[None, :]
@@ -234,7 +240,7 @@ def fused_moe_router_large_bs_kernel(
 
     # 7. handle topk == 2
     if topk == 2:
-        cond_top2 = (arange_block_size_n < num_experts) and (
+        cond_top2 = (arange_block_size_n < num_experts) & (
            arange_block_size_n != top1[:, None]
         )
         top2 = tl.argmax(
...
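The new guard matters because the softcap formula divides by moe_softcapping, which breaks for models that disable capping (cap = 0). The capped branch is simply tanh written via exp; the topk == 2 fix also replaces Python `and` with the elementwise `&`, which tensor-valued conditions in Triton require. A hedged PyTorch reference of the routing softcap (not part of the commit):

import torch


def moe_softcap_reference(logits: torch.Tensor, moe_softcapping: float) -> torch.Tensor:
    # cap == 0 means "no softcapping"; otherwise cap * tanh(logits / cap),
    # i.e. cap * (exp(2x) - 1) / (exp(2x) + 1) with x = logits / cap.
    if moe_softcapping == 0:
        return logits
    return moe_softcapping * torch.tanh(logits / moe_softcapping)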
@@ -52,6 +52,8 @@ class RadixAttention(nn.Module):
         v_head_dim: int = -1,
         sliding_window_size: int = -1,
         is_cross_attention: bool = False,
+        pos_encoding_mode: str = "NONE",
+        logit_capping_method: str = "tanh",
         quant_config: Optional[QuantizationConfig] = None,
         attn_type: AttentionType = AttentionType.DECODER,
         use_irope: bool = False,
@@ -81,6 +83,10 @@ class RadixAttention(nn.Module):
             self.quant_method.create_weights(self)
 
         self.attn_type = attn_type
+        self.pos_encoding_mode = pos_encoding_mode
+        self.logit_capping_method = logit_capping_method
+        self.xai_temperature_len = -1
+
     def forward(
         self,
         q,
...
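These fields let a model choose the capping method per attention layer and opt into the xAI temperature scaling; the model-side code that uses them is not shown in this view. A hypothetical sketch of how a model definition might wire them up (argument names and values are illustrative, not taken from the commit):

# Inside a hypothetical model's attention module:
self.attn = RadixAttention(
    num_heads,
    head_dim,
    scaling,
    num_kv_heads=num_kv_heads,
    layer_id=layer_id,
    logit_cap=30.0,  # value of the tanh soft cap (illustrative)
    logit_capping_method="tanh",  # resolved by logit_capping_mod in the Triton backend
)
# Enable length-dependent temperature beyond an (illustrative) threshold:
self.attn.xai_temperature_len = 8192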
(A collapsed diff, presumably the grok.py changes named in the commit title, is not shown here. The new tiktoken tokenizer module, imported above as sglang.srt.tokenizer.tiktoken_tokenizer, follows.)
import functools
import json
from typing import AbstractSet, Collection, List, Literal, Union


class TiktokenProcessor:
    def __init__(self, name: str):
        self.tokenizer = TiktokenTokenizer(name)

    def image_processor(self, image):
        return {"pixel_values": [image]}


RESERVED_TOKEN_TEXTS = [f"<|reserved_{i}|>" for i in range(3, 128)]
CONTROL_TOKEN_TEXTS = [f"<|control{i}|>" for i in range(1, 705)]

PAD = "<|pad|>"
EOS = "<|eos|>"
SEP = "<|separator|>"

DEFAULT_SPECIAL_TOKENS = [PAD, SEP, EOS]
DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": EOS, "eos": SEP}

# default + separate each single digit
PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""


class TiktokenTokenizer:
    def __init__(self, tokenizer_path):
        import tiktoken
        from jinja2 import Template

        # Read the JSON
        with open(tokenizer_path, "rb") as fin:
            xtok_dict = json.load(fin)

        # Copy from train/xlm/tokenizers/tiktoken_wrapper.py::Encoding::from_xtok_dict
        mergeable_ranks = {
            bytes(item["bytes"]): item["token"] for item in xtok_dict["regular_tokens"]
        }
        special_tokens = {
            bytes(item["bytes"]).decode(): item["token"]
            for item in xtok_dict["special_tokens"]
        }
        if xtok_dict["word_split"] == "V1":
            pad_str = PAT_STR_B
        else:
            assert False, f"Unknown word_split: {xtok_dict['word_split']}"
        pad_str = xtok_dict.get("pat_str", pad_str)

        kwargs = {
            "name": tokenizer_path,
            "pat_str": pad_str,
            "mergeable_ranks": mergeable_ranks,
            "special_tokens": special_tokens,
        }
        if "default_allowed_special" in xtok_dict:
            default_allowed_special = set(
                [
                    bytes(bytes_list).decode()
                    for bytes_list in xtok_dict["default_allowed_special"]
                ]
            )
        if "vocab_size" in xtok_dict:
            kwargs["explicit_n_vocab"] = xtok_dict["vocab_size"]

        # Copy from train/xlm/tokenizers/tiktoken_wrapper.py::Encoding::__init__
        default_allowed_special = None
        control_tokens = DEFAULT_CONTROL_TOKENS
        tokenizer = tiktoken.Encoding(**kwargs)
        tokenizer._default_allowed_special = default_allowed_special or set()
        tokenizer._control_tokens = control_tokens

        def encode_patched(
            self,
            text: str,
            *,
            allowed_special: Union[
                Literal["all"], AbstractSet[str]
            ] = set(),  # noqa: B006
            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        ) -> List[int]:
            if isinstance(allowed_special, set):
                allowed_special |= self._default_allowed_special
            return tiktoken.Encoding.encode(
                self,
                text,
                allowed_special=allowed_special,
                disallowed_special=(),
            )

        tokenizer.encode = functools.partial(encode_patched, tokenizer)

        # Allow more tokens to prevent crash
        tokenizer._default_allowed_special |= set(DEFAULT_CONTROL_TOKENS.values())
        tokenizer._default_allowed_special |= set(
            CONTROL_TOKEN_TEXTS + RESERVED_TOKEN_TEXTS
        )

        # Convert to HF interface
        self.tokenizer = tokenizer
        self.bos_token_id = None
        self.eos_token_id = tokenizer._special_tokens[EOS]
        self.vocab_size = tokenizer.n_vocab
        self.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
        self.chat_template_jinja = Template(self.chat_template)
        self.additional_stop_token_ids = None

    def encode(self, x, add_special_tokens=False):
        return self.tokenizer.encode(x)

    def decode(self, x, *args, **kwargs):
        return self.tokenizer.decode(x)

    def batch_decode(
        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
    ):
        if len(batch) > 0 and isinstance(batch[0], int):
            batch = [[x] for x in batch]
        return self.tokenizer.decode_batch(batch)

    def apply_chat_template(
        self, messages, tokenize, add_generation_prompt, tools=None
    ):
        ret = self.chat_template_jinja.render(
            messages=messages, add_generation_prompt=add_generation_prompt
        )
        return self.encode(ret) if tokenize else ret

    def __call__(self, text, **kwargs):
        return {
            "input_ids": self.encode(text),
        }

    def init_xgrammar(self):
        from xgrammar import TokenizerInfo

        XGRAMMAR_SPECIAL_TOKEN_TEMPLATE = "<|xg_special_token_{}|>"

        enc = self.tokenizer
        encoded_vocab = {**enc._mergeable_ranks, **enc._special_tokens}
        encoded_vocab = [
            token for token, _ in sorted(encoded_vocab.items(), key=lambda x: x[1])
        ]
        override_stop_tokens = [2]  # eos

        # These are treated as special tokens in xgrammar; we want to avoid them
        # For now, xgrammar treats anything starting with b'\x00' as a special token
        xgrammar_special_token_ids = []
        for i, token in enumerate(encoded_vocab):
            if isinstance(token, bytes) and token.startswith(b"\x00"):
                xgrammar_special_token_ids.append(i)

        for i, id in enumerate(xgrammar_special_token_ids):
            encoded_vocab[id] = XGRAMMAR_SPECIAL_TOKEN_TEMPLATE.format(i)

        tokenizer_info = TokenizerInfo(
            encoded_vocab, stop_token_ids=override_stop_tokens
        )
        assert len(tokenizer_info.special_token_ids) == 0

        return tokenizer_info, override_stop_tokens
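Putting the pieces together: get_tokenizer detects the .json path and builds a TiktokenTokenizer, which mimics the Hugging Face interface and additionally exposes init_xgrammar() for the constrained-decoding backend. A usage sketch (the file path and messages are placeholders):

from sglang.srt.tokenizer.tiktoken_tokenizer import TiktokenTokenizer

tok = TiktokenTokenizer("/path/to/tokenizer.json")  # tiktoken-style vocab dump (placeholder path)

ids = tok.encode("Hello world")
text = tok.decode(ids)

prompt = tok.apply_chat_template(
    [{"role": "user", "content": "Hi"}],
    tokenize=False,
    add_generation_prompt=True,
)

# Structured-output backends query the tokenizer for its XGrammar metadata:
tokenizer_info, override_stop_tokens = tok.init_xgrammar()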