Commit 3a716536 authored by Li Xiaohui
Browse files

Add VL model adaptations and version update

parent 259605da
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import glob
# Manifests to scrub: every requirements/*.txt file plus pyproject.toml.
requires_files = [*glob.glob('requirements/*.txt'), "pyproject.toml"]

for path in requires_files:
    print(f">>> cleaning {path}")
    with open(path) as src:
        content = src.readlines()

    # Case-insensitive substring match: any line mentioning "torch" is dropped.
    dropped = [row for row in content if 'torch' in row.lower()]
    if dropped:
        print("removed:")
        # Only rewrite the file when at least one line actually has to go.
        kept = [row for row in content if 'torch' not in row.lower()]
        with open(path, 'w') as dst:
            dst.writelines(kept)
        for row in dropped:
            print(row.strip())

    print(f"<<< done cleaning {path}")
    print()
\ No newline at end of file
......@@ -280,6 +280,9 @@ class Glm4vVisionAttention(nn.Module):
f"GLM-4V does not support {self.attn_backend} backend now.")
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
if qkv.dim() == 2:
qkv = qkv.unsqueeze(1) # dim加上batch维度
# [s, b, 3 * head * head_dim]
seq_len, bs, _ = qkv.shape
......@@ -424,16 +427,56 @@ class Glm4vVisionBlock(nn.Module):
max_seqlen: Optional[int] = None, # Only used for Flash Attention
seqlens: Optional[list[int]] = None, # Only used for xFormers
) -> torch.Tensor:
# -------------------------
# 1) Attention
# -------------------------
normed_x = self.norm1(x)
x_attn = self.attn(
self.norm1(x),
normed_x,
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
max_seqlen=max_seqlen,
seqlens=seqlens,
)
x_fused_norm, residual = self.norm2(x, residual=x_attn)
x = residual + self.mlp(x_fused_norm)
# 保证 attn 输出为 3D tensor
if x_attn.dim() == 2:
x_attn = x_attn.unsqueeze(1)
elif x_attn.dim() == 1:
x_attn = x_attn.unsqueeze(1).unsqueeze(2)
assert x_attn.dim() == 3, f"x_attn must be 3D, got {x_attn.shape}"
# -------------------------
# 2) norm2 + residual
# -------------------------
x_fused_norm, residual = self.norm2(x, residual=x_attn)
# -------------------------
# 3) MLP 前形状检查(核心)
# ------------------------
if x_fused_norm.dim() == 3 and x_fused_norm.shape[1] == 1:
mlp_in = x_fused_norm.squeeze(1)
restore_3d = True
elif x_fused_norm.dim() == 2:
mlp_in = x_fused_norm
restore_3d = False
else:
raise RuntimeError(f"Unexpected x_fused_norm shape {x_fused_norm.shape}, expect (N,D) or (N,1,D)")
# -------------------------
# 4) MLP
# ------------------------
out = self.mlp(mlp_in)
# MLP 可能返回 (N,D),恢复回三维
if restore_3d:
out = out.unsqueeze(1)
# -------------------------
# 5) residual + mlp_out
# -------------------------
assert out.shape == residual.shape, \
f"residual {residual.shape} vs mlp_out {out.shape} mismatch"
x = residual + out
return x
......
......@@ -874,10 +874,9 @@ class Qwen2_5_VisionTransformer(nn.Module):
hidden_states = hidden_states[reverse_indices, :]
return hidden_states
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("attn.qkv.", "attn.q.", "q"),
("attn.qkv.", "attn.k.", "k"),
("attn.qkv.", "attn.v.", "v"),
......@@ -886,65 +885,43 @@ class Qwen2_5_VisionTransformer(nn.Module):
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: set[str] = set()
for name, loaded_weight in weights:
original_name = name
if original_name.endswith(".bias"):
if original_name not in params_dict:
continue
param = params_dict[original_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(original_name)
continue
if name.startswith("vision_tower."):
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
name = name.replace(weight_name, param_name)
if name not in params_dict:
break
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
if name not in params_dict:
fixed_name = name.replace("model.", "")
if fixed_name in params_dict:
name = fixed_name
if name not in params_dict:
continue
param = params_dict[name]
if hasattr(loaded_weight, "ndim") and loaded_weight.ndim == 2:
if tuple(param.data.shape) == tuple(loaded_weight.t().shape):
loaded_weight = loaded_weight.t().contiguous()
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
if self.use_llama_nn and self.quant_method is None:
lay_key_words = [
"attn.qkv.weight",
"attn.proj.weight",
"mlp.0.weight",
"mlp.2.weight",
"self_attn.qkv_proj.weight",
"self_attn.o_proj.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj.weight",
"lm_head.weight",
]
combined_words = "|".join(lay_key_words)
# lay_qkv_words = ["attn.qkv.weight"]
# qkv_words = "|".join(lay_qkv_words)
# lay_qkv_bias_words = ["attn.qkv.bias"]
# qkv_bias_words = "|".join(lay_qkv_bias_words)
for layername in loaded_params:
weight = params_dict[layername]
# if self.use_fa_pad and (re.findall(qkv_bias_words, layername)):
# weight.data = pad_weight(weight.data, 32)
matches = re.findall(combined_words, layername)
if matches:
# if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
# weight.data = pad_weight(weight.data, 32)
# if self.use_fa_pad and (re.findall(qkv_words, layername)):
# if not gemm_bank_conf(weight.data.shape[0]):
# weight.data = pad_weight(weight.data, 32)
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1],-1)
return loaded_params
......@@ -1541,4 +1518,4 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
language_model="language_model",
connector="visual.merger.",
tower_model="visual.",
)
\ No newline at end of file
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
try:
    # The version of this build is pinned here instead of being imported
    # from the ._version module.
    __version__ = "0.11.0"
    __version_tuple__ = (0, 11, 0)
    __hcu_version__ = '0.11.0+das.opt1.alpha.dtk25042'
    from vllm.version import __version__, __version_tuple__, __hcu_version__
except Exception as e:
    import warnings

    # Interpolate the exception itself: writing " + str(e)" inside the string
    # literal would print that text verbatim and hide the real failure reason.
    warnings.warn(f"Failed to read commit hash:\n{e}",
                  RuntimeWarning,
                  stacklevel=2)

    # Fall back to a "dev" version; the (0, 0, ...) tuple marks a dev tree.
    __version__ = "dev"
    __version_tuple__ = (0, 0, __version__)
def _prev_minor_version_was(version_str):
"""Check whether a given version matches the previous minor version.
'''Check whether a given version matches the previous minor version.
Return True if version_str matches the previous minor version.
......@@ -23,19 +24,19 @@ def _prev_minor_version_was(version_str):
supplied version_str is '0.6'.
Used for --show-hidden-metrics-for-version.
"""
'''
# Match anything if this is a dev tree
if __version_tuple__[0:2] == (0, 0):
return True
# Note - this won't do the right thing when we release 1.0!
assert __version_tuple__[0] == 0
# assert __version_tuple__[0] == 0
assert isinstance(__version_tuple__[1], int)
return version_str == f"{__version_tuple__[0]}.{__version_tuple__[1] - 1}"
def _prev_minor_version():
    """Return the previous minor version string, for use in tests.

    The result is "<major>.<minor - 1>", built from the first two entries of
    the module-level ``__version_tuple__``.
    """
    major, minor = __version_tuple__[0], __version_tuple__[1]
    # With a minor of 0 (as in a dev tree) this yields "0.-1", which is fine
    # for the testing purpose this helper serves.
    assert isinstance(minor, int)
    return f"{major}.{minor - 1}"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment