"src/vscode:/vscode.git/clone" did not exist on "3fe026e06c895049fe6f072fc2b394b2c8e85551"
Unverified Commit f8f9244a authored by Adarsh Shirawalmath, committed by GitHub

[Bug Fix] Add partial rotary factor support for Phi-4 and upgrade to transformers v4.50.0 (#3984)


Co-authored-by: Chayenne <zhaochen20@outlook.com>
parent ecbfe58b
@@ -35,7 +35,7 @@ runtime_common = [
     "python-multipart",
     "pyzmq>=25.1.2",
     "torchao>=0.7.0",
-    "transformers==4.48.3",
+    "transformers==4.50.0",
     "uvicorn",
     "uvloop",
     "xgrammar==0.1.16",
@@ -2,21 +2,12 @@ from sglang.srt.configs.chatglm import ChatGLMConfig
 from sglang.srt.configs.dbrx import DbrxConfig
 from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config
 from sglang.srt.configs.exaone import ExaoneConfig
-from sglang.srt.configs.gemma3 import Gemma3Config, Gemma3TextConfig
 from sglang.srt.configs.janus_pro import MultiModalityConfig
-from sglang.srt.configs.qwen2_5_vl_config import (
-    Qwen2_5_VLConfig,
-    Qwen2_5_VLVisionConfig,
-)

 __all__ = [
     "ExaoneConfig",
     "ChatGLMConfig",
     "DbrxConfig",
     "DeepseekVL2Config",
-    "Qwen2_5_VLConfig",
-    "Qwen2_5_VLVisionConfig",
     "MultiModalityConfig",
-    "Gemma3Config",
-    "Gemma3TextConfig",
 ]
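With the bump to transformers 4.50.0, the Gemma3 and Qwen2.5-VL configuration classes ship upstream, so sglang's local copies are dropped from sglang.srt.configs. A minimal sketch of the equivalent imports after this change, assuming transformers>=4.50.0 is installed (the Qwen2.5-VL path mirrors the one used later in this diff):

# Local copies removed in this commit:
#   from sglang.srt.configs import Gemma3Config, Qwen2_5_VLConfig
# Upstream equivalents in transformers>=4.50.0:
from transformers import Gemma3Config, Gemma3TextConfig
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
    Qwen2_5_VLConfig,
    Qwen2_5_VLVisionConfig,
)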
This diff is collapsed.
This diff is collapsed.
@@ -35,10 +35,7 @@ from sglang.srt.configs import (
     DbrxConfig,
     DeepseekVL2Config,
     ExaoneConfig,
-    Gemma3Config,
-    Gemma3TextConfig,
     MultiModalityConfig,
-    Qwen2_5_VLConfig,
 )
 from sglang.srt.connector import create_remote_connector
 from sglang.srt.utils import is_remote_url
@@ -47,11 +44,8 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     ChatGLMConfig.model_type: ChatGLMConfig,
     DbrxConfig.model_type: DbrxConfig,
     ExaoneConfig.model_type: ExaoneConfig,
-    Qwen2_5_VLConfig.model_type: Qwen2_5_VLConfig,
     DeepseekVL2Config.model_type: DeepseekVL2Config,
     MultiModalityConfig.model_type: MultiModalityConfig,
-    Gemma3Config.model_type: Gemma3Config,
-    Gemma3TextConfig.model_type: Gemma3TextConfig,
 }

 for name, cls in _CONFIG_REGISTRY.items():
@@ -223,11 +217,26 @@ def get_processor(
     tokenizer_revision: Optional[str] = None,
     **kwargs,
 ):
+    # pop 'revision' from kwargs if present.
+    revision = kwargs.pop("revision", tokenizer_revision)
+
+    config = AutoConfig.from_pretrained(
+        tokenizer_name,
+        trust_remote_code=trust_remote_code,
+        revision=revision,
+        **kwargs,
+    )
+
+    # fix: for Qwen2-VL model, inject default 'size' if not provided.
+    if config.model_type in {"qwen2_vl"}:
+        if "size" not in kwargs:
+            kwargs["size"] = {"shortest_edge": 3136, "longest_edge": 1003520}
+
     processor = AutoProcessor.from_pretrained(
         tokenizer_name,
         *args,
         trust_remote_code=trust_remote_code,
-        tokenizer_revision=tokenizer_revision,
+        revision=revision,
         **kwargs,
     )
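The reworked get_processor above resolves the checkpoint's config first, so it can special-case qwen2_vl checkpoints that omit an image size and pass a single revision argument through to AutoProcessor. A minimal usage sketch under those assumptions; the checkpoint name and keyword arguments are only illustrative:

from sglang.srt.hf_transformers_utils import get_processor

# Any Qwen2-VL checkpoint works here; the name below is just an example.
processor = get_processor(
    "Qwen/Qwen2-VL-7B-Instruct",
    trust_remote_code=True,
    revision="main",
)
# If the checkpoint's processor config lacks a `size` entry, the helper now
# injects {"shortest_edge": 3136, "longest_edge": 1003520} before delegating
# to AutoProcessor.from_pretrained.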
@@ -441,16 +441,12 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
     ):
         super().__init__()

-        if rotary_dim != head_size:
-            raise ValueError(
-                f"`Phi3LongRoPEScaledRotaryEmbedding` does not support \
-                    rotary_dim != head_size ({rotary_dim}!={head_size})."
-            )
         if is_neox_style is False:
             raise ValueError(
                 "`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
             )

+        self.rotary_dim = rotary_dim
         self.head_size = head_size
         self.max_position_embeddings = max_position_embeddings
         self.original_max_position_embeddings = original_max_position_embeddings
@@ -499,8 +495,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
             * (
                 self.base
                 ** (
-                    torch.arange(0, self.head_size, 2, dtype=torch.float)
-                    / self.head_size
+                    torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
+                    / self.rotary_dim
                 )
             )
         )
@@ -549,8 +545,15 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
         cos = cos.repeat(1, 2).unsqueeze(-2)
         sin = sin.repeat(1, 2).unsqueeze(-2)

-        query = query * cos + _rotate_neox(query) * sin
-        key = key * cos + _rotate_neox(key) * sin
+        query_rot = query[..., : self.rotary_dim]
+        query_pass = query[..., self.rotary_dim :]
+        query_rot = query_rot * cos + _rotate_neox(query_rot) * sin
+        query = torch.cat((query_rot, query_pass), dim=-1)
+
+        key_rot = key[..., : self.rotary_dim]
+        key_pass = key[..., self.rotary_dim :]
+        key_rot = key_rot * cos + _rotate_neox(key_rot) * sin
+        key = torch.cat((key_rot, key_pass), dim=-1)

         return query.flatten(-2), key.flatten(-2)
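The hunk above is the core of the partial rotary support: only the first rotary_dim channels of each head are rotated, while the remaining head_size - rotary_dim channels pass through untouched. A self-contained sketch of that split, with _rotate_neox reimplemented locally and purely illustrative dimensions (not taken from any specific checkpoint):

import torch

def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
    # Neox-style rotation: negate the second half and swap it with the first.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_partial_rotary(q, cos, sin, rotary_dim):
    # cos/sin are expected to have last dimension == rotary_dim.
    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    q_rot = q_rot * cos + _rotate_neox(q_rot) * sin
    return torch.cat((q_rot, q_pass), dim=-1)

# Illustrative shapes: 1 token, 1 head, head_size=8, rotary_dim=4.
q = torch.randn(1, 1, 8)
cos = torch.ones(1, 1, 4)
sin = torch.zeros(1, 1, 4)
out = apply_partial_rotary(q, cos, sin, rotary_dim=4)
assert torch.equal(out, q)  # a zero rotation angle leaves q unchanged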
@@ -21,11 +21,11 @@ from torch import nn
 from transformers import (
     ROPE_INIT_FUNCTIONS,
     AutoModel,
+    Gemma3TextConfig,
     PretrainedConfig,
     PreTrainedModel,
 )

-from sglang.srt.configs.gemma3 import Gemma3TextConfig
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import Gemma3RMSNorm
@@ -21,9 +21,15 @@ from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict
 import torch
 from torch import nn
-from transformers import AutoModel, PreTrainedModel
+from transformers import (
+    AutoModel,
+    BatchFeature,
+    Gemma3Config,
+    Gemma3Processor,
+    PreTrainedModel,
+)
+from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs

-from sglang.srt.configs import Gemma3Config
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.layernorm import Gemma3RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -129,6 +129,8 @@ class LlamaAttention(nn.Module):
         self.head_dim = getattr(
             config, "head_dim", self.hidden_size // self.total_num_heads
         )
+        partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
+        self.rotary_dim = int(partial_rotary_factor * self.head_dim)
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -154,7 +156,7 @@ class LlamaAttention(nn.Module):
         self.rotary_emb = get_rope(
             self.head_dim,
-            rotary_dim=self.head_dim,
+            rotary_dim=self.rotary_dim,
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
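The two LlamaAttention hunks above read partial_rotary_factor from the model config (defaulting to 1, i.e. full rotary) and hand the derived rotary_dim to get_rope; this is what lets Phi-4-style checkpoints with a partial factor reuse the shared attention path. A minimal sketch of the derivation with made-up numbers, not values from any particular checkpoint:

# Illustrative only: a config with head_dim 128 and a partial factor of 0.75.
head_dim = 128
partial_rotary_factor = 0.75  # getattr(config, "partial_rotary_factor", 1) in the real code
rotary_dim = int(partial_rotary_factor * head_dim)
assert rotary_dim == 96  # only 96 of the 128 channels per head receive RoPE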
@@ -34,8 +34,15 @@ from einops import rearrange
 from transformers import AutoModel, Qwen2VLConfig
 from transformers.activations import ACT2FN
 from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
+from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
+from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
+    Qwen2_5_VLConfig,
+    Qwen2_5_VLVisionConfig,
+)
+from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+    Qwen2_5_VLForConditionalGeneration,
+)

-from sglang.srt.configs import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -714,4 +721,3 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):

 EntryClass = [Qwen2_5_VLForConditionalGeneration]
-AutoModel.register(Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration)
@@ -20,7 +20,7 @@ pip install flashinfer_python==0.2.3 --find-links ${FLASHINFER_REPO} --force-reinstall
 pip install torch_memory_saver --force-reinstall
-pip install transformers==4.48.3 sentence_transformers accelerate==1.4.0 peft pandas datasets
+pip install transformers==4.50.0 sentence_transformers accelerate==1.4.0 peft pandas datasets

 # For compiling xgrammar kernels
 pip install cuda-python nvidia-cuda-nvrtc-cu12