Unverified Commit b868526d authored by Brayden Zhong, committed by GitHub

Fix one more issue reported by torchfix (#4859)

parent 502524e2
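For context (not part of the diff): torchfix flags torch.nn.utils.weight_norm as deprecated in favor of torch.nn.utils.parametrizations.weight_norm, which registers a parametrization on the module instead of adding weight_g / weight_v parameters directly. A minimal sketch of the replacement API, with illustrative layer sizes only:

import torch
from torch import nn
from torch.nn.utils import parametrizations

# Wrap a linear layer with the non-deprecated weight_norm parametrization.
linear = parametrizations.weight_norm(nn.Linear(16, 8, bias=False), name="weight")

x = torch.randn(2, 16)
y = linear(x)  # forward pass is unchanged; the weight is recomputed from its g/v factors
print(y.shape)  # torch.Size([2, 8])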
@@ -25,7 +25,7 @@ import torch.nn.functional as F
 import torch.nn.utils.parametrize as P
 import torch.types
 from torch import nn
-from torch.nn.utils import weight_norm
+from torch.nn.utils import parametrizations
 from tqdm import tqdm
 from transformers import LlamaConfig, LlamaModel, PretrainedConfig, PreTrainedModel
 from transformers.activations import ACT2FN
@@ -585,7 +585,7 @@ class ConditionalChatTTS(PreTrainedModel):
         self.emb_text = nn.Embedding(config.num_text_tokens, config.hidden_size)
         self.head_code = nn.ModuleList(
             [
-                weight_norm(
+                parametrizations.weight_norm(
                     nn.Linear(config.hidden_size, config.num_audio_tokens, bias=False),
                     name="weight",
                 )
@@ -1859,11 +1859,22 @@ class MiniCPMO(MiniCPMBaseModel):
                 # the checkpoint. Skip them.
                 continue
-            # adapt to parametrization
+            # For weight_norm parametrization, handle both old and new formats
             if self.config.init_tts and "tts" in name:
-                name = name.replace(".parametrizations", "")
-                name = name.replace(".weight.original0", ".weight_g")
-                name = name.replace(".weight.original1", ".weight_v")
+                # Handle loading from older checkpoints with weight_g/weight_v format
+                if ".weight_g" in name or ".weight_v" in name:
+                    name = name.replace(
+                        ".weight_g", ".parametrizations.weight.original0"
+                    )
+                    name = name.replace(
+                        ".weight_v", ".parametrizations.weight.original1"
+                    )
+                elif ".weight" in name and name not in params_dict:
+                    param_name = name.replace(
+                        ".weight", ".parametrizations.weight.original0"
+                    )
+                    if param_name in params_dict:
+                        name = param_name
             # adapt to VisionAttention
             if "vpm" in name:
...
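The remapping above exists because the two weight_norm implementations serialize differently: the deprecated helper stores weight_g / weight_v keys, while the parametrization stores parametrizations.weight.original0 / original1. A minimal sketch of the key difference and the rewrite direction (the key "tts.head_code.0.weight_g" below is illustrative, not taken from the checkpoint):

import torch
from torch import nn
from torch.nn.utils import parametrizations

layer = parametrizations.weight_norm(nn.Linear(8, 4, bias=False), name="weight")
print(sorted(layer.state_dict().keys()))
# ['parametrizations.weight.original0', 'parametrizations.weight.original1']

# An older checkpoint written with the deprecated torch.nn.utils.weight_norm
# would instead contain weight_g / weight_v keys, so they are rewritten on load:
old_key = "tts.head_code.0.weight_g"  # illustrative key, not from the model
new_key = old_key.replace(".weight_g", ".parametrizations.weight.original0")
print(new_key)  # tts.head_code.0.parametrizations.weight.original0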