Unverified Commit 65c08928 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Model] Remove unnecessary weight initialization logic (#11736)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: default avatarIsotr0py <2037008807@qq.com>
Co-authored-by: default avatarIsotr0py <2037008807@qq.com>
parent ba214dff
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
Shared resampler perceiver network used in multimodal models and Shared resampler perceiver network used in multimodal models and
related helpers for sincos positional embeddings. related helpers for sincos positional embeddings.
Example models: Qwen (Qwen-VL), Minicpmv2.0 Example models: Qwen (Qwen-VL), MiniCPM-V 2.0
""" """
import math import math
from functools import partial from functools import partial
...@@ -37,7 +37,6 @@ import numpy as np ...@@ -37,7 +37,6 @@ import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
from torch.nn.init import trunc_normal_
from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
...@@ -169,8 +168,8 @@ class BaseResampler(nn.Module): ...@@ -169,8 +168,8 @@ class BaseResampler(nn.Module):
self.embed_dim = embed_dim self.embed_dim = embed_dim
self.num_heads = num_heads self.num_heads = num_heads
self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim)) self.query = nn.Parameter(torch.empty(self.num_queries, embed_dim))
trunc_normal_(self.query, std=0.02)
if kv_dim is not None and kv_dim != embed_dim: if kv_dim is not None and kv_dim != embed_dim:
self.kv_proj = ReplicatedLinear(kv_dim, self.kv_proj = ReplicatedLinear(kv_dim,
embed_dim, embed_dim,
...@@ -190,16 +189,7 @@ class BaseResampler(nn.Module): ...@@ -190,16 +189,7 @@ class BaseResampler(nn.Module):
self.ln_post = norm_layer(embed_dim) if do_post_projection else None self.ln_post = norm_layer(embed_dim) if do_post_projection else None
self.proj = nn.Parameter( self.proj = nn.Parameter(
(embed_dim**-0.5) * (embed_dim**-0.5) *
torch.randn(embed_dim, embed_dim)) if do_post_projection else None torch.empty(embed_dim, embed_dim)) if do_post_projection else None
def _init_weights(self, m: nn.Module) -> None:
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def _repeat(self, query, N: int): def _repeat(self, query, N: int):
return query.unsqueeze(1).repeat(1, N, 1) return query.unsqueeze(1).repeat(1, N, 1)
...@@ -240,8 +230,6 @@ class Resampler2(BaseResampler): ...@@ -240,8 +230,6 @@ class Resampler2(BaseResampler):
self.pos_embed = nn.Parameter( self.pos_embed = nn.Parameter(
torch.from_numpy(pos_embed_arr).requires_grad_(False)) torch.from_numpy(pos_embed_arr).requires_grad_(False))
self.apply(self._init_weights)
def forward( def forward(
self, self,
x: torch.Tensor, x: torch.Tensor,
......
...@@ -3,7 +3,6 @@ from typing import (Callable, Iterable, List, Mapping, Optional, Set, Tuple, ...@@ -3,7 +3,6 @@ from typing import (Callable, Iterable, List, Mapping, Optional, Set, Tuple,
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.nn.init import trunc_normal_
from transformers import BatchFeature, PretrainedConfig from transformers import BatchFeature, PretrainedConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
...@@ -216,9 +215,7 @@ class AriaProjector(nn.Module): ...@@ -216,9 +215,7 @@ class AriaProjector(nn.Module):
self.num_heads = num_heads self.num_heads = num_heads
self.query = nn.Parameter( self.query = nn.Parameter(
torch.zeros(max(patch_to_query_dict.values()), self.embed_dim)) torch.empty(max(patch_to_query_dict.values()), self.embed_dim))
trunc_normal_(self.query, std=0.02)
self.cross_attn = CrossAttention(kv_dim, embed_dim, num_heads) self.cross_attn = CrossAttention(kv_dim, embed_dim, num_heads)
......
...@@ -141,8 +141,6 @@ class Resampler2_5(BaseResampler): ...@@ -141,8 +141,6 @@ class Resampler2_5(BaseResampler):
self.max_size = max_size self.max_size = max_size
self._set_2d_pos_cache(self.max_size) self._set_2d_pos_cache(self.max_size)
self.apply(self._init_weights)
def _set_2d_pos_cache(self, def _set_2d_pos_cache(self,
max_size: Tuple[int, int], max_size: Tuple[int, int],
device: torch.types.Device = "cpu") -> None: device: torch.types.Device = "cpu") -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment