"src/git@developer.sourcefind.cn:modelzoo/video_migraphx.git" did not exist on "d1d43032ed8fff53511ca9a33158c281b6265fdf"
Unverified Commit f8f9244a authored by Adarsh Shirawalmath, committed by GitHub

[Bug Fix] Add partial rotary factor support for Phi-4 and upgrade to transformers v4.50.0 (#3984)


Co-authored-by: Chayenne <zhaochen20@outlook.com>
parent ecbfe58b
......@@ -35,7 +35,7 @@ runtime_common = [
"python-multipart",
"pyzmq>=25.1.2",
"torchao>=0.7.0",
"transformers==4.48.3",
"transformers==4.50.0",
"uvicorn",
"uvloop",
"xgrammar==0.1.16",
......
......@@ -2,21 +2,12 @@ from sglang.srt.configs.chatglm import ChatGLMConfig
from sglang.srt.configs.dbrx import DbrxConfig
from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config
from sglang.srt.configs.exaone import ExaoneConfig
from sglang.srt.configs.gemma3 import Gemma3Config, Gemma3TextConfig
from sglang.srt.configs.janus_pro import MultiModalityConfig
from sglang.srt.configs.qwen2_5_vl_config import (
Qwen2_5_VLConfig,
Qwen2_5_VLVisionConfig,
)
__all__ = [
"ExaoneConfig",
"ChatGLMConfig",
"DbrxConfig",
"DeepseekVL2Config",
"Qwen2_5_VLConfig",
"Qwen2_5_VLVisionConfig",
"MultiModalityConfig",
"Gemma3Config",
"Gemma3TextConfig",
]
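
Context for the removal hunks in this file and in hf_transformers_utils.py below: transformers v4.50.0 ships the Gemma3 and Qwen2.5-VL configuration classes itself, so the vendored sglang copies can be dropped and only sglang-specific configs (ChatGLM, DBRX, EXAONE, DeepSeek-VL2, Janus-Pro) stay registered locally. A minimal sketch of the imports callers can rely on after the upgrade (not part of this diff):

from transformers import Gemma3Config, Gemma3TextConfig
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
    Qwen2_5_VLConfig,
    Qwen2_5_VLVisionConfig,  # vision tower config, imported the same way
)

# The model_type strings stay the same, so existing registry lookups keep working.
print(Gemma3Config.model_type)       # "gemma3"
print(Gemma3TextConfig.model_type)   # "gemma3_text"
print(Qwen2_5_VLConfig.model_type)   # "qwen2_5_vl"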
......@@ -35,10 +35,7 @@ from sglang.srt.configs import (
DbrxConfig,
DeepseekVL2Config,
ExaoneConfig,
Gemma3Config,
Gemma3TextConfig,
MultiModalityConfig,
Qwen2_5_VLConfig,
)
from sglang.srt.connector import create_remote_connector
from sglang.srt.utils import is_remote_url
......@@ -47,11 +44,8 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
ChatGLMConfig.model_type: ChatGLMConfig,
DbrxConfig.model_type: DbrxConfig,
ExaoneConfig.model_type: ExaoneConfig,
Qwen2_5_VLConfig.model_type: Qwen2_5_VLConfig,
DeepseekVL2Config.model_type: DeepseekVL2Config,
MultiModalityConfig.model_type: MultiModalityConfig,
Gemma3Config.model_type: Gemma3Config,
Gemma3TextConfig.model_type: Gemma3TextConfig,
}
for name, cls in _CONFIG_REGISTRY.items():
......@@ -223,11 +217,26 @@ def get_processor(
tokenizer_revision: Optional[str] = None,
**kwargs,
):
# pop 'revision' from kwargs if present.
revision = kwargs.pop("revision", tokenizer_revision)
config = AutoConfig.from_pretrained(
tokenizer_name,
trust_remote_code=trust_remote_code,
revision=revision,
**kwargs,
)
# fix: for Qwen2-VL model, inject default 'size' if not provided.
if config.model_type in {"qwen2_vl"}:
if "size" not in kwargs:
kwargs["size"] = {"shortest_edge": 3136, "longest_edge": 1003520}
processor = AutoProcessor.from_pretrained(
tokenizer_name,
*args,
trust_remote_code=trust_remote_code,
tokenizer_revision=tokenizer_revision,
revision=revision,
**kwargs,
)
......
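The new get_processor code does two things: it forwards an explicit revision to both the config and processor lookups, and it injects a default size for Qwen2-VL checkpoints when the caller does not pass one (3136 = 56·56 and 1003520 = 28·28·1280, Qwen2-VL's default min/max pixel counts). A condensed, standalone sketch of the same behaviour (the helper name and defaults here are illustrative, not sglang's API):

from transformers import AutoConfig, AutoProcessor

def load_processor(name: str, trust_remote_code: bool = False, **kwargs):
    # Honour an explicit revision for both the config and the processor.
    revision = kwargs.pop("revision", None)
    config = AutoConfig.from_pretrained(
        name, trust_remote_code=trust_remote_code, revision=revision, **kwargs
    )
    # Newer Qwen2-VL image processors express their pixel bounds as a `size`
    # dict; fall back to the model defaults if the caller gave none.
    if config.model_type == "qwen2_vl" and "size" not in kwargs:
        kwargs["size"] = {"shortest_edge": 3136, "longest_edge": 1003520}
    return AutoProcessor.from_pretrained(
        name, trust_remote_code=trust_remote_code, revision=revision, **kwargs
    )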
......@@ -441,16 +441,12 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
):
super().__init__()
if rotary_dim != head_size:
raise ValueError(
f"`Phi3LongRoPEScaledRotaryEmbedding` does not support \
rotary_dim != head_size ({rotary_dim}!={head_size})."
)
if is_neox_style is False:
raise ValueError(
"`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
)
self.rotary_dim = rotary_dim
self.head_size = head_size
self.max_position_embeddings = max_position_embeddings
self.original_max_position_embeddings = original_max_position_embeddings
......@@ -499,8 +495,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
* (
self.base
** (
torch.arange(0, self.head_size, 2, dtype=torch.float)
/ self.head_size
torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
/ self.rotary_dim
)
)
)
......@@ -549,8 +545,15 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
cos = cos.repeat(1, 2).unsqueeze(-2)
sin = sin.repeat(1, 2).unsqueeze(-2)
query = query * cos + _rotate_neox(query) * sin
key = key * cos + _rotate_neox(key) * sin
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = query_rot * cos + _rotate_neox(query_rot) * sin
query = torch.cat((query_rot, query_pass), dim=-1)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = key_rot * cos + _rotate_neox(key_rot) * sin
key = torch.cat((key_rot, key_pass), dim=-1)
return query.flatten(-2), key.flatten(-2)
......
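This hunk is the core of the Phi-4 fix: instead of requiring rotary_dim == head_size, the embedding now rotates only the first rotary_dim channels of each head and passes the rest through, and the inverse frequencies are built from rotary_dim rather than head_size. A toy illustration of the split, using hypothetical sizes:

import torch

head_size, rotary_dim = 128, 96          # e.g. partial_rotary_factor = 0.75
q = torch.randn(2, 4, head_size)         # [..., num_heads, head_size]

q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
assert q_rot.shape[-1] == rotary_dim
assert q_pass.shape[-1] == head_size - rotary_dim  # left untouched by RoPE

# Frequencies now span rotary_dim, matching the arange(0, self.rotary_dim, 2) change above.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
print(inv_freq.shape)  # torch.Size([48])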
......@@ -21,11 +21,11 @@ from torch import nn
from transformers import (
ROPE_INIT_FUNCTIONS,
AutoModel,
Gemma3TextConfig,
PretrainedConfig,
PreTrainedModel,
)
from sglang.srt.configs.gemma3 import Gemma3TextConfig
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.layernorm import Gemma3RMSNorm
......
......@@ -21,9 +21,15 @@ from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict
import torch
from torch import nn
from transformers import AutoModel, PreTrainedModel
from transformers import (
AutoModel,
BatchFeature,
Gemma3Config,
Gemma3Processor,
PreTrainedModel,
)
from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
from sglang.srt.configs import Gemma3Config
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.layernorm import Gemma3RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
......
......@@ -129,6 +129,8 @@ class LlamaAttention(nn.Module):
self.head_dim = getattr(
config, "head_dim", self.hidden_size // self.total_num_heads
)
partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
self.rotary_dim = int(partial_rotary_factor * self.head_dim)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
......@@ -154,7 +156,7 @@ class LlamaAttention(nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
rotary_dim=self.rotary_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
......
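The same idea reaches LlamaAttention, the code path this PR targets for Phi-4: the rotary width is derived from partial_rotary_factor instead of being assumed equal to the head dimension, and configs without the field keep full rotation. A small sketch of the derivation with hypothetical config values:

from types import SimpleNamespace

def rotary_dim_for(config) -> int:
    head_dim = getattr(
        config, "head_dim", config.hidden_size // config.num_attention_heads
    )
    partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
    return int(partial_rotary_factor * head_dim)

full_rotary = SimpleNamespace(hidden_size=4096, num_attention_heads=32)
partial_rotary = SimpleNamespace(
    hidden_size=3072, num_attention_heads=24, partial_rotary_factor=0.75
)
print(rotary_dim_for(full_rotary))     # 128 -> whole head is rotated
print(rotary_dim_for(partial_rotary))  # 96  -> only 3/4 of each head is rotated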
......@@ -34,8 +34,15 @@ from einops import rearrange
from transformers import AutoModel, Qwen2VLConfig
from transformers.activations import ACT2FN
from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
Qwen2_5_VLConfig,
Qwen2_5_VLVisionConfig,
)
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLForConditionalGeneration,
)
from sglang.srt.configs import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
......@@ -714,4 +721,3 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
EntryClass = [Qwen2_5_VLForConditionalGeneration]
AutoModel.register(Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration)
......@@ -20,7 +20,7 @@ pip install flashinfer_python==0.2.3 --find-links ${FLASHINFER_REPO} --force-rei
pip install torch_memory_saver --force-reinstall
pip install transformers==4.48.3 sentence_transformers accelerate==1.4.0 peft pandas datasets
pip install transformers==4.50.0 sentence_transformers accelerate==1.4.0 peft pandas datasets
# For compiling xgrammar kernels
pip install cuda-python nvidia-cuda-nvrtc-cu12
......