"docs/vscode:/vscode.git/clone" did not exist on "70fbdb26e99d7d0b0299acc108f83bd63626e589"
Unverified commit 59d260f5, authored by Bijaya Dangol, committed by GitHub
Browse files

[Model] Add Grok-2 (#31847)


Signed-off-by: dangoldbj <dangoldbj23@gmail.com>
parent 18d4e481
......@@ -399,6 +399,7 @@ th {
| `GraniteMoeSharedForCausalLM` | Granite MoE Shared | `ibm-research/moe-7b-1b-active-shared-experts` (test model) | ✅︎ | ✅︎ |
| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ |
| `Grok1ModelForCausalLM` | Grok1 | `hpcai-tech/grok-1`. | ✅︎ | ✅︎ |
| `Grok1ForCausalLM` | Grok2 | `xai-org/grok-2` | ✅︎ | ✅︎ |
| `HunYuanDenseV1ForCausalLM` | Hunyuan Dense | `tencent/Hunyuan-7B-Instruct` | ✅︎ | ✅︎ |
| `HunYuanMoEV1ForCausalLM` | Hunyuan-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | ✅︎ | ✅︎ |
| `HCXVisionForCausalLM` | HyperCLOVAX-SEED-Vision-Instruct-3B | `naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B` | | |
......@@ -459,6 +460,9 @@ th {
| `Zamba2ForCausalLM` | Zamba2 | `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. | | |
| `LongcatFlashForCausalLM` | LongCat-Flash | `meituan-longcat/LongCat-Flash-Chat`, `meituan-longcat/LongCat-Flash-Chat-FP8` | ✅︎ | ✅︎ |
!!! note
    Grok2 requires the `tiktoken` package to be installed and a `tokenizer.tok.json` file alongside the model. You can optionally override MoE router renormalization by setting `moe_router_renormalize` in the model's HF config.
Some models are supported only via the [Transformers modeling backend](#transformers). The purpose of the table below is to acknowledge models which we officially support in this way. The logs will say that the Transformers modeling backend is being used, and you will see no warning that this is fallback behaviour. This means that, if you have issues with any of the models listed below, please [make an issue](https://github.com/vllm-project/vllm/issues/new/choose) and we'll do our best to fix it!
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from ...utils import dummy_hf_overrides
MODELS = ["xai-org/grok-2"]
def _grok2_dummy_overrides(hf_config):
    """Shrink a Grok-2 HF config so a dummy-weight model is cheap to build.

    The config is first passed through ``dummy_hf_overrides`` for the
    ``Grok1ForCausalLM`` architecture, then its text config is updated with
    small hidden/attention/MoE dimensions. The overridden config is returned.
    """
    tiny_dims = {
        "hidden_size": 256,
        "intermediate_size": 512,
        "moe_intermediate_size": 256,
        "num_attention_heads": 4,
        "num_key_value_heads": 2,
        "head_dim": 64,
    }
    shrunk_config = dummy_hf_overrides(hf_config, model_arch="Grok1ForCausalLM")
    shrunk_config.get_text_config().update(tiny_dims)
    return shrunk_config
@pytest.mark.parametrize("model", MODELS)
def test_dummy_generate(vllm_runner, monkeypatch, model: str) -> None:
    """Smoke-test greedy generation for Grok-2 with dummy weights."""
    prompt = "Hello from Grok-2"
    runner_kwargs = {
        "load_format": "dummy",
        "max_model_len": 128,
        "hf_overrides": _grok2_dummy_overrides,
        "enforce_eager": True,
    }

    with monkeypatch.context() as m:
        # The hf_overrides callable has to be serialized for the engine.
        m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

        with vllm_runner(model, **runner_kwargs) as llm:
            prompt_len = len(llm.get_llm().get_tokenizer().encode(prompt))
            output_ids, output_str = llm.generate_greedy([prompt], max_tokens=1)[0]

            # The returned ids must be longer than the prompt alone.
            assert len(output_ids) > prompt_len
            assert output_str is not None
......@@ -289,6 +289,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Grok1ModelForCausalLM": _HfExamplesInfo(
"hpcai-tech/grok-1", trust_remote_code=True
),
"Grok1ForCausalLM": _HfExamplesInfo("xai-org/grok-2", trust_remote_code=True),
"HunYuanDenseV1ForCausalLM": _HfExamplesInfo("tencent/Hunyuan-7B-Instruct"),
"HunYuanMoEV1ForCausalLM": _HfExamplesInfo(
"tencent/Hunyuan-A13B-Instruct", trust_remote_code=True
......
......@@ -10,6 +10,7 @@ from transformers import (
)
from vllm.tokenizers import TokenizerLike, get_tokenizer
from vllm.tokenizers.grok2 import Grok2Tokenizer
from vllm.tokenizers.mistral import MistralTokenizer
......@@ -37,6 +38,10 @@ def test_tokenizer_like_protocol():
assert isinstance(tokenizer, MistralTokenizer)
_assert_tokenizer_like(tokenizer)
tokenizer = get_tokenizer("xai-org/grok-2", tokenizer_mode="grok2")
assert isinstance(tokenizer, Grok2Tokenizer)
_assert_tokenizer_like(tokenizer)
@pytest.mark.parametrize("tokenizer_name", ["facebook/opt-125m", "gpt2"])
def test_tokenizer_revision(tokenizer_name: str):
......
......@@ -21,8 +21,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Grok1 model."""
"""Inference-only Grok (Grok1/Grok2) model."""
import math
from collections.abc import Iterable
from itertools import islice
from typing import Any
......@@ -35,9 +36,12 @@ from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
......@@ -68,6 +72,100 @@ from .utils import (
DEFAULT_ATTN_OUTPUT_MULTIPLIER = 0.08838834764831845
DEFAULT_OUTPUT_MULTIPLIER_SCALE = 0.5773502691896257
DEFAULT_EMBEDDING_MULTIPLIER_SCALE = 78.38367176906169
DEFAULT_ROUTER_LOGIT_SOFTCAP = 30.0
logger = init_logger(__name__)
def _get_num_experts(config) -> int:
return getattr(config, "num_experts", getattr(config, "num_local_experts", 8))
def _get_moe_intermediate_size(config) -> int:
return getattr(config, "moe_intermediate_size", config.intermediate_size)
def _get_grok_version(config) -> str:
"""Detect Grok version from HF config using multiple heuristics."""
# Check for Grok2-specific attributes (both for robust detection)
has_residual_moe = getattr(config, "residual_moe", False)
has_moe_intermediate_size = hasattr(config, "moe_intermediate_size")
if has_residual_moe or has_moe_intermediate_size:
return "grok2"
return "grok1" # Default to Grok1
def _get_rope_parameters(config) -> dict[str, Any] | None:
rope_parameters = getattr(config, "rope_parameters", None)
if rope_parameters is None:
rope_type = getattr(config, "rope_type", None)
if rope_type is None:
return None
rope_parameters = {"rope_type": rope_type}
rope_theta = getattr(config, "rope_theta", None)
if rope_theta is not None:
rope_parameters["rope_theta"] = rope_theta
scaling_factor = getattr(config, "scaling_factor", None)
if scaling_factor is not None:
rope_parameters["factor"] = scaling_factor
for name in (
"original_max_position_embeddings",
"extrapolation_factor",
"attn_factor",
"beta_fast",
"beta_slow",
):
value = getattr(config, name, None)
if value is not None:
rope_parameters[name] = value
if rope_parameters.get("rope_type") == "original":
rope_parameters = dict(rope_parameters)
rope_parameters["rope_type"] = "default"
return rope_parameters
def _get_moe_renormalize(config) -> bool:
explicit_value = getattr(
config, "moe_router_renormalize", getattr(config, "moe_renormalize", None)
)
if explicit_value is not None:
return bool(explicit_value)
return not getattr(config, "residual_moe", False)
class Grok1MLP(nn.Module):
    """Gated MLP built from a merged gate/up projection and a down projection.

    The gate and up projections are fused into one
    ``MergedColumnParallelLinear``; ``GeluAndMul`` is applied to its output
    before the ``RowParallelLinear`` down projection. No biases are used.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Options shared by both projections.
        common_kwargs = {"bias": False, "quant_config": quant_config}

        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size, intermediate_size],
            prefix=f"{prefix}.gate_up_proj",
            **common_kwargs,
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            prefix=f"{prefix}.down_proj",
            **common_kwargs,
        )
        self.act_fn = GeluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, activation, then the down projection."""
        # Each projection returns a tuple; only its first element is used.
        gate_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(gate_up)
        projected, _ = self.down_proj(activated)
        return projected
class Grok1MoE(nn.Module):
......@@ -85,9 +183,11 @@ class Grok1MoE(nn.Module):
top_k: int,
hidden_size: int,
intermediate_size: int,
router_logit_soft_cap: float = 0.0,
params_dtype: torch.dtype | None = None,
quant_config: QuantizationConfig | None = None,
tp_size: int | None = None,
renormalize: bool = False,
prefix: str = "",
):
super().__init__()
......@@ -110,12 +210,13 @@ class Grok1MoE(nn.Module):
intermediate_size=intermediate_size,
params_dtype=params_dtype,
reduce_results=True,
renormalize=True,
renormalize=renormalize,
quant_config=quant_config,
tp_size=tp_size,
activation="gelu",
prefix=f"{prefix}.experts",
)
self.router_logit_soft_cap = router_logit_soft_cap
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape.
......@@ -123,7 +224,10 @@ class Grok1MoE(nn.Module):
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
router_logits = 30.0 * F.tanh(router_logits / 30.0)
if self.router_logit_soft_cap > 0:
router_logits = self.router_logit_soft_cap * F.tanh(
router_logits / self.router_logit_soft_cap
)
final_hidden_states = self.experts(hidden_states, router_logits)
return final_hidden_states.view(orig_shape)
......@@ -187,6 +291,15 @@ class Grok1Attention(nn.Module):
)
attn_logits_soft_cap = max(getattr(config, "attn_logit_softcapping", 30.0), 0.0)
attn_logit_softcapping_method = getattr(
config, "attn_logit_softcapping_method", None
)
if attn_logit_softcapping_method not in (None, "tanh"):
logger.warning_once(
"Grok attention logit softcapping method '%s' is not "
"supported; falling back to default behavior.",
attn_logit_softcapping_method,
)
self.attn = Attention(
self.num_heads,
......@@ -238,30 +351,50 @@ class Grok1DecoderLayer(nn.Module):
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
rope_parameters=getattr(config, "rope_parameters", None),
rope_parameters=_get_rope_parameters(config),
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
config=config,
) # Pass config to Grok1Attention
# Grok1 uses "num_experts" in its config
num_experts = getattr(config, "num_experts", 8)
num_experts = _get_num_experts(config)
num_experts_per_tok = getattr(config, "num_experts_per_tok", 2)
moe_intermediate_size = _get_moe_intermediate_size(config)
moe_renormalize = _get_moe_renormalize(config)
self.moe_block = Grok1MoE(
num_experts=num_experts,
top_k=num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
intermediate_size=moe_intermediate_size,
router_logit_soft_cap=max(
getattr(
config,
"router_logit_softcapping",
DEFAULT_ROUTER_LOGIT_SOFTCAP,
),
0.0,
),
quant_config=quant_config,
renormalize=moe_renormalize,
prefix=f"{prefix}.moe_block",
)
self.residual_moe = getattr(config, "residual_moe", False)
self.residual_moe_scale = 1.0 / math.sqrt(2.0)
self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.mlp = None
if self.residual_moe:
self.mlp = Grok1MLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
def forward(
self,
......@@ -286,7 +419,13 @@ class Grok1DecoderLayer(nn.Module):
# MoE block with normalization
hidden_states, residual = self.pre_moe_norm(hidden_states, residual)
hidden_states = self.moe_block(hidden_states)
if self.residual_moe:
assert self.mlp is not None
hidden_states = (
self.moe_block(hidden_states) + self.mlp(hidden_states)
) * self.residual_moe_scale
else:
hidden_states = self.moe_block(hidden_states)
hidden_states = self.post_moe_norm(hidden_states)
return hidden_states, residual
......@@ -294,7 +433,16 @@ class Grok1DecoderLayer(nn.Module):
@support_torch_compile
class Grok1Model(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
ckpt_gate_proj_name: str = "linear",
ckpt_down_proj_name: str = "linear_1",
ckpt_up_proj_name: str = "linear_v",
weight_name_remapping: dict[str, str] | None = None,
):
super().__init__()
config = vllm_config.model_config.hf_config
......@@ -305,6 +453,12 @@ class Grok1Model(nn.Module):
self.quant_config = quant_config
self.padding_idx = config.pad_token_id
# Store expert naming for weight loading
self.ckpt_gate_proj_name = ckpt_gate_proj_name
self.ckpt_down_proj_name = ckpt_down_proj_name
self.ckpt_up_proj_name = ckpt_up_proj_name
self.weight_name_remapping = weight_name_remapping or {}
self.vocab_size = config.vocab_size
self.embedding_multiplier_scale = getattr(
......@@ -365,14 +519,13 @@ class Grok1Model(nn.Module):
return hidden_states
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Map Grok1's unique expert parameter names to standard names
# Grok1 uses "num_experts" in its config
num_experts = getattr(self.config, "num_experts", 8)
# Map expert parameter names to standard names
num_experts = _get_num_experts(self.config)
return FusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="linear", # Grok1 specific
ckpt_down_proj_name="linear_1", # Grok1 specific
ckpt_up_proj_name="linear_v", # Grok1 specific
ckpt_gate_proj_name=self.ckpt_gate_proj_name,
ckpt_down_proj_name=self.ckpt_down_proj_name,
ckpt_up_proj_name=self.ckpt_up_proj_name,
num_experts=num_experts,
)
......@@ -382,12 +535,18 @@ class Grok1Model(nn.Module):
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("mlp.gate_up_proj", "mlp.gate_proj", 0),
("mlp.gate_up_proj", "mlp.up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
expert_params_mapping = self.get_expert_mapping()
for name, loaded_weight in weights:
# Apply version-specific weight name remapping
for old_pattern, new_pattern in self.weight_name_remapping.items():
if old_pattern in name:
name = name.replace(old_pattern, new_pattern)
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
......@@ -418,6 +577,8 @@ class Grok1Model(nn.Module):
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
......@@ -464,6 +625,8 @@ class Grok1Model(nn.Module):
if "norm.scale" in name:
name = name.replace("scale", "weight")
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
......@@ -473,9 +636,12 @@ class Grok1Model(nn.Module):
return loaded_params
class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
class GrokBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"""Base class for Grok models with shared logic."""
fall_back_to_pt_during_load = False
# Subclasses should override these
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
......@@ -484,6 +650,15 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
],
}
# Expert weight naming - subclasses override these
ckpt_gate_proj_name: str = "linear"
ckpt_down_proj_name: str = "linear_1"
ckpt_up_proj_name: str = "linear_v"
def get_weight_name_remapping(self) -> dict[str, str]:
"""Return weight name remapping for this version. Override in subclasses."""
return {}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
......@@ -491,11 +666,15 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = Grok1Model(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"),
ckpt_gate_proj_name=self.ckpt_gate_proj_name,
ckpt_down_proj_name=self.ckpt_down_proj_name,
ckpt_up_proj_name=self.ckpt_up_proj_name,
weight_name_remapping=self.get_weight_name_remapping(),
)
self.lm_head = ParallelLMHead(
......@@ -512,7 +691,9 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
config, "output_multiplier_scale", DEFAULT_OUTPUT_MULTIPLIER_SCALE
)
self.logits_processor = LogitsProcessor(
config.vocab_size, scale=self.output_multiplier_scale
config.vocab_size,
scale=self.output_multiplier_scale,
soft_cap=getattr(config, "final_logit_softcapping", None),
)
self.make_empty_intermediate_tensors = (
......@@ -553,3 +734,70 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return self.model.get_expert_mapping()
class Grok1ForCausalLM(GrokBaseForCausalLM):
    """Grok1 variant: Grok1 expert weight names and no name remapping."""

    # Names of the expert projections as they appear in Grok1 checkpoints.
    ckpt_gate_proj_name = "linear"
    ckpt_down_proj_name = "linear_1"
    ckpt_up_proj_name = "linear_v"

    def get_weight_name_remapping(self) -> dict[str, str]:
        """Return an empty mapping; Grok1 weight names are used unchanged."""
        no_remapping: dict[str, str] = {}
        return no_remapping
class Grok2ForCausalLM(GrokBaseForCausalLM):
    """Grok2 variant: extra packed MLP modules and Grok2 checkpoint names."""

    # In addition to the fused QKV projection, Grok2 packs the MLP gate and
    # up projections into a single module.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    # Names of the expert projections as they appear in Grok2 checkpoints.
    ckpt_gate_proj_name = "w1"
    ckpt_down_proj_name = "w2"
    ckpt_up_proj_name = "w3"

    def get_weight_name_remapping(self) -> dict[str, str]:
        """Return substring replacements applied to checkpoint weight names.

        Keys are the substrings found in Grok2 checkpoints; values are the
        module names they are rewritten to.
        """
        remapping = {
            ".self_attn.": ".attn.",
            ".block_sparse_moe.": ".moe_block.",
        }
        return remapping
# Maps the version string from `_get_grok_version` to its implementation.
_GROK_VERSIONS: dict[str, type[GrokBaseForCausalLM]] = {
    "grok1": Grok1ForCausalLM,
    "grok2": Grok2ForCausalLM,
}


class GrokForCausalLM(GrokBaseForCausalLM):
    """Factory that builds the Grok1 or Grok2 model matching the HF config."""

    def __new__(cls, *, vllm_config: VllmConfig, prefix: str = ""):
        """Construct and return the version-specific model instance.

        Raises:
            ValueError: If the detected version has no registered class.
        """
        version = _get_grok_version(vllm_config.model_config.hf_config)
        if _GROK_VERSIONS.get(version) is None:
            raise ValueError(f"Unsupported Grok version: {version}")
        impl_cls = _GROK_VERSIONS[version]

        # Merge class attributes for LoRA/quantization compatibility. The
        # merged mapping is stored on this factory class itself.
        cls.packed_modules_mapping = {
            **cls.packed_modules_mapping,
            **impl_cls.packed_modules_mapping,
        }
        return impl_cls(vllm_config=vllm_config, prefix=prefix)
......@@ -119,7 +119,8 @@ _TEXT_GENERATION_MODELS = {
"GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"), # noqa: E501
"GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"), # noqa: E501
"GritLM": ("gritlm", "GritLM"),
"Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
"Grok1ModelForCausalLM": ("grok1", "GrokForCausalLM"),
"Grok1ForCausalLM": ("grok1", "GrokForCausalLM"),
"HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
"HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
"HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tokenizer for Grok-2 .tok.json format."""
import functools
import json
from collections.abc import Collection, Set
from pathlib import Path
from typing import Any, Literal, overload
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import (
EntryNotFoundError,
HfHubHTTPError,
RepositoryNotFoundError,
RevisionNotFoundError,
)
from transformers import BatchEncoding
from transformers.utils import chat_template_utils as hf_chat_utils
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.logger import init_logger
from .protocol import TokenizerLike
logger = init_logger(__name__)
# Text of the pad, end-of-sequence, and separator special tokens.
PAD = "<|pad|>"
EOS = "<|eos|>"
SEP = "<|separator|>"
# "<|reserved_3|>" through "<|reserved_127|>".
RESERVED_TOKEN_TEXTS = [f"<|reserved_{i}|>" for i in range(3, 128)]
# "<|control1|>" through "<|control704|>".
CONTROL_TOKEN_TEXTS = [f"<|control{i}|>" for i in range(1, 705)]
DEFAULT_SPECIAL_TOKENS = [PAD, SEP, EOS]
# Short role name -> special token text.
DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": SEP, "eos": EOS}
# Jinja chat template used when no other template is available. Each user,
# system, or assistant message is rendered as "Human: ", "System: ", or
# "Assistant: " followed by the content, "<|separator|>", and a blank line.
# User and system content is stripped; assistant content is not. Messages with
# any other role are skipped. A trailing "Assistant:" is emitted when
# `add_generation_prompt` is set.
DEFAULT_CHAT_TEMPLATE = (
"{% for message in messages %}"
"{% if message['role'] == 'user' %}"
"{{ 'Human: ' + message['content'].strip() + '<|separator|>\\n\\n' }}"
"{% elif message['role'] == 'system' %}"
"{{ 'System: ' + message['content'].strip() + '<|separator|>\\n\\n' }}"
"{% elif message['role'] == 'assistant' %}"
"{{ 'Assistant: ' + message['content'] + '<|separator|>\\n\\n' }}"
"{% endif %}"
"{% endfor %}"
"{% if add_generation_prompt %}"
"{{ 'Assistant:' }}"
"{% endif %}"
)
# Default + separate each single digit.
# Pre-tokenization split pattern; note `\p{N}` matches exactly one digit.
PAT_STR_B = (
r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}|"""
r""" ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
)
def _maybe_load_tokenizer_config(
model_path: Path,
*,
repo_id: str | None,
revision: str | None,
download_dir: str | None,
) -> dict[str, Any]:
config_path = model_path / "tokenizer_config.json"
if config_path.is_file():
with config_path.open("r", encoding="utf-8") as f:
return json.load(f)
if repo_id is None:
return {}
try:
config_file = hf_hub_download(
repo_id=repo_id,
filename="tokenizer_config.json",
revision=revision,
cache_dir=download_dir,
)
except (RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError):
# If the repo, revision, or file does not exist, fall back silently.
return {}
except HfHubHTTPError as exc:
logger.warning(
"Failed to download tokenizer_config.json from %s. "
"This may be due to a network or authentication issue. "
"The default chat template will be used. Error: %s",
repo_id,
exc,
)
return {}
try:
with Path(config_file).open("r", encoding="utf-8") as f:
return json.load(f)
except json.JSONDecodeError as exc:
logger.warning(
"Failed to parse tokenizer_config.json. "
"The default chat template will be used. Error: %s",
exc,
)
return {}
except OSError as exc:
logger.warning(
"Failed to open tokenizer_config.json. "
"The default chat template will be used. Error: %s",
exc,
)
return {}
def _load_tiktoken_encoding(
    vocab_file: Path,
) -> tuple[Any, dict[str, int]]:
    """Build a tiktoken ``Encoding`` from a Grok-2 ``.tok.json`` vocab file.

    The returned encoding has its ``encode`` replaced by a wrapper that always
    allows the tokenizer's default special tokens and never rejects special
    token text in the input.

    Args:
        vocab_file: Path to the JSON vocabulary file.

    Returns:
        A ``(encoding, special_tokens)`` pair, where ``special_tokens`` maps
        special token text to token id.

    Raises:
        ImportError: If ``tiktoken`` is not installed.
        ValueError: If the file's ``word_split`` is not ``"V1"``.
    """
    try:
        import tiktoken
    except ImportError as exc:
        raise ImportError("Grok-2 tokenizer requires the `tiktoken` package.") from exc

    with vocab_file.open("rb") as f:
        xtok_dict = json.load(f)

    mergeable_ranks = {
        bytes(item["bytes"]): item["token"]
        for item in xtok_dict.get("regular_tokens", [])
    }
    special_tokens = {
        bytes(item["bytes"]).decode("utf-8", errors="replace"): item["token"]
        for item in xtok_dict.get("special_tokens", [])
    }

    # Only the "V1" word split is known; an explicit `pat_str` in the file
    # overrides the built-in pattern but does not bypass this check.
    word_split = xtok_dict.get("word_split")
    if word_split != "V1":
        raise ValueError(f"Unknown word_split: {word_split!r}")
    pat_str = xtok_dict.get("pat_str", PAT_STR_B)

    kwargs = {
        "name": str(vocab_file),
        "pat_str": pat_str,
        "mergeable_ranks": mergeable_ranks,
        "special_tokens": special_tokens,
    }
    if "vocab_size" in xtok_dict:
        kwargs["explicit_n_vocab"] = xtok_dict["vocab_size"]
    tokenizer = tiktoken.Encoding(**kwargs)

    default_allowed_special: set[str] = set()
    if "default_allowed_special" in xtok_dict:
        default_allowed_special = {
            bytes(bytes_list).decode("utf-8", errors="replace")
            for bytes_list in xtok_dict["default_allowed_special"]
        }
    tokenizer._default_allowed_special = default_allowed_special
    tokenizer._control_tokens = DEFAULT_CONTROL_TOKENS

    def encode_patched(
        self,
        text: str,
        *,
        allowed_special: Literal["all"] | Set[str] = frozenset(),
        disallowed_special: Literal["all"] | Collection[str] = "all",
    ) -> list[int]:
        # `disallowed_special` is accepted for signature compatibility only;
        # nothing is ever disallowed.
        del disallowed_special
        if isinstance(allowed_special, Set):
            # Build a fresh set: never mutate the caller's set or the shared
            # default argument, and treat frozensets like any other set.
            allowed_special = set(allowed_special) | self._default_allowed_special
        return tiktoken.Encoding.encode(
            self,
            text,
            allowed_special=allowed_special,
            disallowed_special=(),
        )

    tokenizer.encode = functools.partial(encode_patched, tokenizer)
    # Control and reserved tokens are always encodable as special tokens.
    tokenizer._default_allowed_special |= set(DEFAULT_CONTROL_TOKENS.values())
    tokenizer._default_allowed_special |= set(
        CONTROL_TOKEN_TEXTS + RESERVED_TOKEN_TEXTS
    )
    return tokenizer, special_tokens
class Grok2Tokenizer(TokenizerLike):
    """Tokenizer for Grok-2 backed by a tiktoken encoding.

    The vocabulary comes from a ``tokenizer.tok.json`` file. An optional
    ``tokenizer_config.json`` supplies the chat template; without one,
    ``DEFAULT_CHAT_TEMPLATE`` is used.
    """

    @classmethod
    def from_pretrained(
        cls,
        path_or_repo_id: str | Path,
        *args,
        trust_remote_code: bool = False,
        revision: str | None = None,
        download_dir: str | None = None,
        **kwargs,
    ) -> "Grok2Tokenizer":
        """Create a tokenizer from a vocab file, a directory, or a Hub repo.

        Args:
            path_or_repo_id: Path to the vocab file, a directory containing
                ``tokenizer.tok.json``, or otherwise a Hub repository id.
            *args: Ignored (a debug message is logged once).
            trust_remote_code: Accepted but not used by this tokenizer.
            revision: Hub revision used for downloads.
            download_dir: Cache directory used for downloads.
            **kwargs: Only ``truncation_side`` is read (default ``"left"``).

        Raises:
            FileNotFoundError: If the resolved vocab file does not exist.
        """
        if args:
            logger.debug_once("Ignoring extra positional args for Grok2Tokenizer.")

        path = Path(path_or_repo_id)
        if path.is_file():
            vocab_file = path
            model_path = path.parent
            repo_id = None
        elif path.is_dir():
            vocab_file = path / "tokenizer.tok.json"
            model_path = path
            repo_id = None
        else:
            # Not a local path: treat the argument as a Hub repository id.
            vocab_file = Path(
                hf_hub_download(
                    repo_id=str(path_or_repo_id),
                    filename="tokenizer.tok.json",
                    revision=revision,
                    cache_dir=download_dir,
                )
            )
            model_path = vocab_file.parent
            repo_id = str(path_or_repo_id)

        if not vocab_file.is_file():
            raise FileNotFoundError(f"tokenizer.tok.json not found at {vocab_file}.")

        config = _maybe_load_tokenizer_config(
            model_path,
            repo_id=repo_id,
            revision=revision,
            download_dir=download_dir,
        )

        return cls(
            vocab_file=vocab_file,
            name_or_path=str(path_or_repo_id),
            truncation_side=kwargs.get("truncation_side", "left"),
            chat_template=config.get("chat_template"),
            init_kwargs=config,
        )

    def __init__(
        self,
        *,
        vocab_file: Path,
        name_or_path: str,
        truncation_side: str,
        chat_template: str | None,
        init_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """Load the vocabulary and derive token lookup tables and special ids.

        Args:
            vocab_file: Path to the ``.tok.json`` vocabulary.
            name_or_path: Identifier stored on the instance.
            truncation_side: ``"left"`` keeps the tail when truncating; any
                other value keeps the head.
            chat_template: Chat template; a falsy value selects
                ``DEFAULT_CHAT_TEMPLATE``.
            init_kwargs: Stored as ``self.init_kwargs`` (``{}`` when falsy).
        """
        super().__init__()
        self.name_or_path = name_or_path
        self._truncation_side = truncation_side
        self.init_kwargs = init_kwargs or {}
        self._chat_template = chat_template or DEFAULT_CHAT_TEMPLATE

        self._tokenizer, self._special_tokens = _load_tiktoken_encoding(vocab_file)
        # Precomputed so skip_special_tokens filtering is O(1) per token.
        self._special_token_ids: frozenset[int] = frozenset(
            self._special_tokens.values()
        )

        self._token_to_id: dict[str, int] = {}
        self._id_to_token: dict[int, str] = {}
        for token, token_id in self._tokenizer._mergeable_ranks.items():
            # Byte sequences that are not valid UTF-8 get replacement chars.
            token_str = token.decode("utf-8", errors="replace")
            self._token_to_id[token_str] = token_id
            self._id_to_token[token_id] = token_str

        # Special tokens are added last so they win on any text collision.
        for token, token_id in self._special_tokens.items():
            self._token_to_id[token] = token_id
            self._id_to_token[token_id] = token

        # BOS falls back through SEP, PAD, EOS, and finally id 0.
        bos_token_id = 0
        for candidate in (SEP, PAD, EOS):
            candidate_id = self._special_tokens.get(candidate)
            if candidate_id is not None:
                bos_token_id = candidate_id
                break
        self._bos_token_id = bos_token_id

        self._eos_token_id = self._special_tokens.get(EOS, self._bos_token_id)
        self._pad_token_id = self._special_tokens.get(PAD, self._eos_token_id)
        # Unknown tokens map to the pad id.
        self._unk_token_id = self._pad_token_id

    def num_special_tokens_to_add(self) -> int:
        """Return 0: encoding never adds special tokens."""
        return 0

    @property
    def all_special_tokens(self) -> list[str]:
        """Texts of all special tokens from the vocab file."""
        return list(self._special_tokens.keys())

    @property
    def all_special_ids(self) -> list[int]:
        """Ids of all special tokens from the vocab file."""
        return list(self._special_tokens.values())

    @property
    def bos_token_id(self) -> int:
        """Id used as BOS (SEP, else PAD, else EOS, else 0)."""
        return self._bos_token_id

    @property
    def eos_token_id(self) -> int:
        """Id of the EOS token, or the BOS id when EOS is missing."""
        return self._eos_token_id

    @property
    def pad_token_id(self) -> int:
        """Id of the PAD token, or the EOS id when PAD is missing."""
        return self._pad_token_id

    @property
    def is_fast(self) -> bool:
        """Always ``False``."""
        return False

    @property
    def vocab_size(self) -> int:
        """``n_vocab`` of the underlying encoding."""
        return self._tokenizer.n_vocab

    @property
    def max_token_id(self) -> int:
        """Largest valid token id (``n_vocab - 1``)."""
        return self._tokenizer.n_vocab - 1

    @property
    def truncation_side(self) -> str:
        """Side from which tokens are dropped when truncating."""
        return self._truncation_side

    def get_vocab(self) -> dict[str, int]:
        """Return a copy of the token-text to id mapping."""
        return dict(self._token_to_id)

    def get_added_vocab(self) -> dict[str, int]:
        """Return a copy of the special-token text to id mapping."""
        return dict(self._special_tokens)

    def _maybe_truncate(self, tokens: list[int], max_length: int | None) -> list[int]:
        """Truncate ``tokens`` to at most ``max_length`` items.

        ``None`` disables truncation. A non-positive ``max_length`` yields an
        empty list on either side (``tokens[-0:]`` would otherwise return the
        whole list for left truncation).
        """
        if max_length is None or len(tokens) <= max_length:
            return tokens
        if max_length <= 0:
            return []
        if self.truncation_side == "left":
            return tokens[-max_length:]
        return tokens[:max_length]

    def encode(
        self,
        text: str,
        truncation: bool | None = None,
        max_length: int | None = None,
        add_special_tokens: bool = True,
    ) -> list[int]:
        """Encode ``text`` to token ids, optionally truncating.

        ``add_special_tokens`` is accepted but ignored. ``max_length`` only
        takes effect when ``truncation`` is truthy.
        """
        del add_special_tokens
        tokens = self._tokenizer.encode(text)
        if truncation:
            tokens = self._maybe_truncate(tokens, max_length)
        return tokens

    def decode(self, ids: list[int] | int, skip_special_tokens: bool = False) -> str:
        """Decode one id or a list of ids, optionally dropping special ids."""
        if isinstance(ids, int):
            ids = [ids]
        if skip_special_tokens:
            ids = [
                token_id
                for token_id in ids
                if token_id not in self._special_token_ids
            ]
        return self._tokenizer.decode(ids)

    @overload
    def convert_tokens_to_ids(self, tokens: str) -> int: ...

    @overload
    def convert_tokens_to_ids(self, tokens: list[str]) -> list[int]: ...

    def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
        """Map token text(s) to id(s); unknown tokens map to the unk id."""
        if isinstance(tokens, str):
            return self._token_to_id.get(tokens, self._unk_token_id)
        return [self._token_to_id.get(token, self._unk_token_id) for token in tokens]

    def convert_ids_to_tokens(
        self, ids: list[int], skip_special_tokens: bool = False
    ) -> list[str]:
        """Map ids to token texts; unknown ids become ``"<|unk|>"``."""
        return [
            self._id_to_token.get(token_id, "<|unk|>")
            for token_id in ids
            if not (skip_special_tokens and token_id in self._special_token_ids)
        ]

    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        """Join token texts into a string by mapping them to ids and decoding."""
        token_ids = self.convert_tokens_to_ids(tokens)
        return self.decode(token_ids, skip_special_tokens=False)

    def __call__(
        self,
        text: str | list[str],
        text_pair: str | None = None,
        add_special_tokens: bool = True,
        truncation: bool = False,
        max_length: int | None = None,
    ) -> BatchEncoding:
        """Encode a string or a list of strings into a ``BatchEncoding``.

        The result has ``input_ids`` and an all-ones ``attention_mask``. For a
        list input both are lists of lists; no padding is applied.

        Raises:
            NotImplementedError: If ``text_pair`` is given.
        """
        if text_pair is not None:
            raise NotImplementedError("text_pair is not supported for Grok2Tokenizer.")

        if isinstance(text, list):
            input_ids_batch: list[list[int]] = [
                self.encode(
                    item,
                    truncation=truncation,
                    max_length=max_length,
                    add_special_tokens=add_special_tokens,
                )
                for item in text
            ]
            attention_mask_batch = [[1] * len(ids) for ids in input_ids_batch]
            return BatchEncoding(
                {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch}
            )

        input_ids = self.encode(
            text,
            truncation=truncation,
            max_length=max_length,
            add_special_tokens=add_special_tokens,
        )
        attention_mask = [1] * len(input_ids)
        return BatchEncoding({"input_ids": input_ids, "attention_mask": attention_mask})

    def get_chat_template(
        self, chat_template: str | None, tools: list[dict[str, Any]] | None = None
    ) -> str | None:
        """Return ``chat_template`` if truthy, else the tokenizer's template.

        ``tools`` is accepted but does not influence the selection.
        """
        del tools
        return chat_template or self._chat_template

    def apply_chat_template(
        self,
        messages: list[ChatCompletionMessageParam],
        tools: list[dict[str, Any]] | None = None,
        chat_template: str | None = None,
        tokenize: bool = False,
        **kwargs,
    ) -> str | list[int]:
        """Render ``messages`` with the chat template.

        Returns the rendered prompt, or its token ids when ``tokenize`` is
        true. Extra keyword arguments are forwarded to the template renderer.

        Raises:
            ValueError: If no chat template is available.
        """
        template = self.get_chat_template(chat_template, tools=tools)
        if template is None:
            raise ValueError(
                "No chat template available. Provide `chat_template` explicitly."
            )
        prompt = hf_chat_utils.apply_chat_template(
            conversation=messages,
            chat_template=template,
            tools=tools,
            **kwargs,
        )
        if tokenize:
            return self.encode(prompt, add_special_tokens=False)
        return prompt
......@@ -31,6 +31,7 @@ logger = init_logger(__name__)
_VLLM_TOKENIZERS = {
"deepseek_v32": ("deepseek_v32", "DeepseekV32Tokenizer"),
"grok2": ("grok2", "Grok2Tokenizer"),
"hf": ("hf", "CachedHfTokenizer"),
"mistral": ("mistral", "MistralTokenizer"),
}
......@@ -151,6 +152,17 @@ def resolve_tokenizer_args(
if len(files_list) > 0:
tokenizer_mode = "mistral"
# Try to use Grok2 tiktoken tokenizer if possible
if tokenizer_mode == "auto":
allow_patterns = ["tokenizer.tok.json"]
files_list = list_filtered_repo_files(
model_name_or_path=str(tokenizer_name),
allow_patterns=allow_patterns,
revision=revision,
)
if len(files_list) > 0:
tokenizer_mode = "grok2"
# Fallback to HF tokenizer
if tokenizer_mode == "auto":
tokenizer_mode = "hf"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment