Unverified Commit b868526d authored by Brayden Zhong, committed by GitHub

Fix one more issue reported by torchfix (#4859)

parent 502524e2
@@ -25,7 +25,7 @@ import torch.nn.functional as F
 import torch.nn.utils.parametrize as P
 import torch.types
 from torch import nn
-from torch.nn.utils import weight_norm
+from torch.nn.utils import parametrizations
 from tqdm import tqdm
 from transformers import LlamaConfig, LlamaModel, PretrainedConfig, PreTrainedModel
 from transformers.activations import ACT2FN
@@ -585,7 +585,7 @@ class ConditionalChatTTS(PreTrainedModel):
         self.emb_text = nn.Embedding(config.num_text_tokens, config.hidden_size)
         self.head_code = nn.ModuleList(
             [
-                weight_norm(
+                parametrizations.weight_norm(
                     nn.Linear(config.hidden_size, config.num_audio_tokens, bias=False),
                     name="weight",
                 )
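
Background on the change above: `torch.nn.utils.weight_norm` is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm` (the pattern TorchFix flags). The two wrappers build equivalent modules but register the underlying tensors under different parameter names, which is why the checkpoint-loading hunk below also has to change. A minimal sketch, with hypothetical layer sizes, illustrating the naming difference:

```python
# Minimal sketch contrasting the deprecated helper with the parametrization API.
# The Linear dimensions here are hypothetical placeholders.
import torch.nn as nn
from torch.nn.utils import parametrizations, weight_norm

old_style = weight_norm(nn.Linear(16, 8, bias=False), name="weight")
new_style = parametrizations.weight_norm(nn.Linear(16, 8, bias=False), name="weight")

# Deprecated API: magnitude/direction stored as plain parameters.
print([n for n, _ in old_style.named_parameters()])
# ['weight_g', 'weight_v']

# Parametrization API: the same tensors live under parametrizations.weight.*
print([n for n, _ in new_style.named_parameters()])
# ['parametrizations.weight.original0', 'parametrizations.weight.original1']
```
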
@@ -1859,11 +1859,22 @@ class MiniCPMO(MiniCPMBaseModel):
                 # the checkpoint. Skip them.
                 continue
-            # adapt to parametrization
+            # For weight_norm parametrization, handle both old and new formats
             if self.config.init_tts and "tts" in name:
-                name = name.replace(".parametrizations", "")
-                name = name.replace(".weight.original0", ".weight_g")
-                name = name.replace(".weight.original1", ".weight_v")
+                # Handle loading from older checkpoints with weight_g/weight_v format
+                if ".weight_g" in name or ".weight_v" in name:
+                    name = name.replace(
+                        ".weight_g", ".parametrizations.weight.original0"
+                    )
+                    name = name.replace(
+                        ".weight_v", ".parametrizations.weight.original1"
+                    )
+                elif ".weight" in name and name not in params_dict:
+                    param_name = name.replace(
+                        ".weight", ".parametrizations.weight.original0"
+                    )
+                    if param_name in params_dict:
+                        name = param_name
             # adapt to VisionAttention
             if "vpm" in name: