Unverified Commit ebbb75e9 authored by blzheng, committed by GitHub

[CPU] Fix TP padding issue on Phi-4 (#8289)

parent b341b7db
@@ -49,14 +49,25 @@ def get_num_heads_padding_size(tp_size, weight_block_size):
 def update_intermediate_size(model_config, attr_name, intermediate_padding_size):
-    if hasattr(model_config.hf_config, attr_name):
-        attr_value = getattr(model_config.hf_config, attr_name)
-        if attr_value % intermediate_padding_size != 0:
-            from sglang.srt.layers.vocab_parallel_embedding import pad_vocab_size
-
-            attr_value = pad_vocab_size(attr_value, intermediate_padding_size)
-            setattr(model_config.hf_config, attr_name, attr_value)
-            setattr(model_config.hf_text_config, attr_name, attr_value)
+    attr_value = intermediate_padding_size
+    if hasattr(model_config, "hf_config") and hasattr(
+        model_config.hf_config, attr_name
+    ):
+        attr_value = getattr(model_config.hf_config, attr_name)
+    elif hasattr(model_config, attr_name):
+        attr_value = getattr(model_config, attr_name)
+
+    if attr_value % intermediate_padding_size != 0:
+        from sglang.srt.layers.vocab_parallel_embedding import pad_vocab_size
+
+        attr_value = pad_vocab_size(attr_value, intermediate_padding_size)
+        if hasattr(model_config, "hf_config"):
+            setattr(model_config.hf_config, attr_name, attr_value)
+            if hasattr(model_config, "hf_text_config"):
+                setattr(model_config.hf_text_config, attr_name, attr_value)
+        else:
+            setattr(model_config, attr_name, attr_value)
     return model_config
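For intuition, the padding rounds the attribute up to the next multiple of intermediate_padding_size via pad_vocab_size. A minimal standalone sketch of that arithmetic (round_up and the example numbers are illustrative, not taken from the patch):

def round_up(value: int, multiple: int) -> int:
    # Assumed behaviour of pad_vocab_size: round up to the nearest multiple.
    return (value + multiple - 1) // multiple * multiple

# Hypothetical numbers: an intermediate_size of 4304 with a padding unit of 128
# would be written back into the config as 4352.
assert round_up(4304, 128) == 4352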
@@ -118,4 +129,28 @@ def adjust_config_with_unaligned_cpu_tp(
     model_config = update_intermediate_size(
         model_config, "intermediate_size_mlp", intermediate_padding_size
     )
+    if (
+        hasattr(model_config.hf_config, "vision_config")
+        and model_config.hf_config.vision_config.model_type == "siglip_vision_model"
+    ):
+        model_config.hf_config.vision_config.original_num_attention_heads = (
+            model_config.num_attention_heads
+        )
+        if model_config.hf_config.vision_config.num_attention_heads % tp_size != 0:
+            model_config.hf_config.vision_config.head_dim = (
+                model_config.hf_config.vision_config.hidden_size
+                // model_config.hf_config.vision_config.num_attention_heads
+            )
+            from sglang.srt.layers.vocab_parallel_embedding import pad_vocab_size
+
+            pad_size = get_num_heads_padding_size(tp_size, weight_block_size)
+            model_config.hf_config.vision_config.num_attention_heads = pad_vocab_size(
+                model_config.hf_config.vision_config.num_attention_heads, pad_size
+            )
+        model_config.hf_config.vision_config = update_intermediate_size(
+            model_config.hf_config.vision_config,
+            "intermediate_size",
+            intermediate_padding_size,
+        )
     return model_config
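To see what this means for Phi-4's SigLIP encoder (hidden_size 1152, 16 attention heads): when tp_size does not divide 16, head_dim is pinned to the original per-head width before the head count is padded. A rough sketch with an illustrative tp_size; the real pad size comes from get_num_heads_padding_size(tp_size, weight_block_size):

hidden_size = 1152
num_attention_heads = 16
tp_size = 6  # hypothetical TP size that does not divide 16

if num_attention_heads % tp_size != 0:
    # Pin the per-head width before padding: 1152 // 16 == 72.
    head_dim = hidden_size // num_attention_heads
    # Pad the head count up to a multiple of tp_size: 16 -> 18 in this example.
    num_attention_heads = -(-num_attention_heads // tp_size) * tp_size

print(head_dim, num_attention_heads)  # 72 18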
@@ -129,6 +129,25 @@ def get_config(
     config = AutoConfig.from_pretrained(
         model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
     )
+    if (
+        config.architectures is not None
+        and config.architectures[0] == "Phi4MMForCausalLM"
+    ):
+        # Phi4MMForCausalLM uses a hard-coded vision_config. See:
+        # https://github.com/vllm-project/vllm/blob/6071e989df1531b59ef35568f83f7351afb0b51e/vllm/model_executor/models/phi4mm.py#L71
+        # We set it here to support cases where num_attention_heads is not divisible by the TP size.
+        from transformers import SiglipVisionConfig
+
+        vision_config = {
+            "hidden_size": 1152,
+            "image_size": 448,
+            "intermediate_size": 4304,
+            "model_type": "siglip_vision_model",
+            "num_attention_heads": 16,
+            "num_hidden_layers": 26,  # Model is originally 27-layer, we only need the first 26 layers for feature extraction.
+            "patch_size": 14,
+        }
+        config.vision_config = SiglipVisionConfig(**vision_config)
 
     text_config = get_hf_text_config(config=config)
     if isinstance(model, str) and text_config is not None:
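As a quick sanity check on the hard-coded values (a standalone sketch, not part of the patch): the injected SigLIP config gives a per-head width of 72, which is exactly the head_dim the CPU TP adjustment above pins before padding the head count.

from transformers import SiglipVisionConfig

cfg = SiglipVisionConfig(
    hidden_size=1152,
    image_size=448,
    intermediate_size=4304,
    num_attention_heads=16,
    num_hidden_layers=26,
    patch_size=14,
)
# 1152 hidden units over 16 heads -> 72 per head.
assert cfg.hidden_size // cfg.num_attention_heads == 72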
...
@@ -110,6 +110,20 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
     return param[shard_id], loaded_weight
 
 
+def adjust_shard_offsets(shard_offsets, loaded_weight, dim):
+    actual_weight_size = loaded_weight.size(dim)
+    target_weight_size = shard_offsets[-1][-1] + shard_offsets[-1][-2]
+    if actual_weight_size != target_weight_size:
+        new_shard_offsets = []
+        new_offset = 0
+        for shard_id, shard_offset, shard_size in shard_offsets:
+            actual_shard_size = actual_weight_size * shard_size // target_weight_size
+            new_shard_offsets.append((shard_id, new_offset, actual_shard_size))
+            new_offset += actual_shard_size
+        return new_shard_offsets
+    return shard_offsets
+
+
 class LinearBase(torch.nn.Module):
     """Base linear layer.
@@ -535,6 +549,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             packed_dim = getattr(param, "packed_dim", None)
             use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
 
+            if _is_cpu:
+                shard_offsets = adjust_shard_offsets(
+                    shard_offsets, loaded_weight, output_dim
+                )
+
             for shard_id, shard_offset, shard_size in shard_offsets:
                 # Special case for Quantization.
                 # If quantized, we need to adjust the offset and size to account
@@ -977,6 +996,11 @@ class QKVParallelLinear(ColumnParallelLinear):
             use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
             packed_dim = getattr(param, "packed_dim", None)
 
+            if _is_cpu:
+                shard_offsets = adjust_shard_offsets(
+                    shard_offsets, loaded_weight, output_dim
+                )
+
             for shard_id, shard_offset, shard_size in shard_offsets:
                 # Special case for Quantized Weights.
                 # If quantized, we need to adjust the offset and size to account
...
@@ -116,9 +116,15 @@ class UnquantizedLinearMethod(LinearMethodBase):
     ) -> torch.Tensor:
 
         if use_intel_amx_backend(layer):
-            return torch.ops.sgl_kernel.weight_packed_linear(
+            x_shapes = x.shape
+            if len(x_shapes) == 3:
+                x = x.view(-1, x.shape[-1])
+            output = torch.ops.sgl_kernel.weight_packed_linear(
                 x, layer.weight, bias, True  # is_vnni
             )
+            if len(x_shapes) == 3:
+                output = output.view(x_shapes[0], x_shapes[1], -1)
+            return output
 
         return F.linear(x, layer.weight, bias)
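The reshape suggests the AMX weight_packed_linear kernel consumes a 2D activation, so a batched 3D input is flattened to (batch * seq, hidden) for the matmul and restored afterwards. A minimal sketch of the same pattern with a plain matmul standing in for the kernel (the helper below is illustrative, not the kernel's API):

import torch

def linear_2d_only(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Stand-in for a kernel that only accepts 2D inputs.
    assert x.dim() == 2
    return x @ weight.t()

x = torch.randn(2, 5, 64)   # (batch, seq, hidden)
w = torch.randn(128, 64)    # (out_features, hidden)

shape = x.shape
out = linear_2d_only(x.view(-1, shape[-1]), w)  # flatten to (batch * seq, hidden)
out = out.view(shape[0], shape[1], -1)          # restore (batch, seq, out_features)
assert out.shape == (2, 5, 128)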
...
@@ -54,25 +54,6 @@ VISION_ENCODER_TO_PROCESSING_CONFIG = {
 }
 
 
-def get_navit_vision_model():
-    vision_config = {
-        "hidden_size": 1152,
-        "image_size": 448,
-        "intermediate_size": 4304,
-        "model_type": "siglip_vision_model",
-        "num_attention_heads": 16,
-        "num_hidden_layers": 26,  # Model is originally 27-layer, we only need the first 26 layers for feature extraction.
-        "patch_size": 14,
-    }
-    model_config = SiglipVisionConfig(**vision_config)
-
-    vision_model = Idefics2VisionTransformer(
-        config=model_config, require_post_norm=False
-    )
-
-    return vision_model
-
-
 class Phi4MMImageEncoder(nn.Module):
     """Image embedding."""
@@ -88,8 +69,9 @@ class Phi4MMImageEncoder(nn.Module):
         # n_embed or hidden_size
         hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size
         self.type_feature = "patch"
 
-        self.img_processor = get_navit_vision_model()
+        self.img_processor = Idefics2VisionTransformer(
+            config=config.vision_config, require_post_norm=False
+        )
 
         pe_weight = self.img_processor.embeddings.position_embedding.weight
         L, D = pe_weight.size()
...