"docs/source/vscode:/vscode.git/clone" did not exist on "ec92f983af5295fc92414a37b988d8384785988a"
Unverified Commit 4124a09f authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[EnCodec] Changes for 32kHz ckpt (#24296)

* [EnCodec] Changes for 32kHz ckpt

* Update src/transformers/models/encodec/convert_encodec_checkpoint_to_pytorch.py

* Update src/transformers/models/encodec/convert_encodec_checkpoint_to_pytorch.py
parent 01b55779
......@@ -90,6 +90,9 @@ class EncodecConfig(PretrainedConfig):
Number of discret codes that make up VQVAE.
codebook_dim (`int`, *optional*):
Dimension of the codebook vectors. If not defined, uses `hidden_size`.
use_conv_shortcut (`bool`, *optional*, defaults to `True`):
Whether to use a convolutional layer as the 'skip' connection in the `EncodecResnetBlock` block. If False,
an identity function will be used, giving a generic residual connection.
Example:
......@@ -131,6 +134,7 @@ class EncodecConfig(PretrainedConfig):
trim_right_ratio=1.0,
codebook_size=1024,
codebook_dim=None,
use_conv_shortcut=True,
**kwargs,
):
self.target_bandwidths = target_bandwidths
......@@ -155,6 +159,7 @@ class EncodecConfig(PretrainedConfig):
self.trim_right_ratio = trim_right_ratio
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim if codebook_dim is not None else hidden_size
self.use_conv_shortcut = use_conv_shortcut
if self.norm_type not in ["weight_norm", "time_group_norm"]:
raise ValueError(
......
......@@ -28,6 +28,7 @@ from transformers import (
# checkpoints downloaded from:
# https://dl.fbaipublicfiles.com/encodec/v0/encodec_24khz-d7cc33bc.th
# https://huggingface.co/facebook/musicgen-small/resolve/main/compression_state_dict.bin
# https://dl.fbaipublicfiles.com/encodec/v0/encodec_48khz-7e698e3e.th
......@@ -206,7 +207,7 @@ def should_ignore(name, ignore_keys):
def recursively_load_weights(orig_dict, hf_model, model_name):
unused_weights = []
if model_name == "encodec_24khz":
if model_name == "encodec_24khz" or "encodec_32khz":
MAPPING = MAPPING_24K
elif model_name == "encodec_48khz":
MAPPING = MAPPING_48K
......@@ -292,6 +293,15 @@ def convert_checkpoint(
if model_name == "encodec_24khz":
pass # config is already correct
elif model_name == "encodec_32khz":
config.upsampling_ratios = [8, 5, 4, 4]
config.target_bandwidths = [2.2]
config.num_filters = 64
config.sampling_rate = 32_000
config.codebook_size = 2048
config.use_causal_conv = False
config.normalize = False
config.use_conv_shortcut = False
elif model_name == "encodec_48khz":
config.upsampling_ratios = [8, 5, 4, 2]
config.target_bandwidths = [3.0, 6.0, 12.0, 24.0]
......@@ -316,6 +326,9 @@ def convert_checkpoint(
feature_extractor.save_pretrained(pytorch_dump_folder_path)
original_checkpoint = torch.load(checkpoint_path)
if "best_state" in original_checkpoint:
# we might have a training state saved, in which case discard the yaml results and just retain the weights
original_checkpoint = original_checkpoint["best_state"]
recursively_load_weights(original_checkpoint, model, model_name)
model.save_pretrained(pytorch_dump_folder_path)
......@@ -331,7 +344,7 @@ if __name__ == "__main__":
"--model",
default="encodec_24khz",
type=str,
help="The model to convert. Should be one of 'encodec_24khz', 'encodec_48khz'.",
help="The model to convert. Should be one of 'encodec_24khz', 'encodec_32khz', 'encodec_48khz'.",
)
parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint")
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
......
......@@ -259,7 +259,10 @@ class EncodecResnetBlock(nn.Module):
block += [EncodecConv1d(config, in_chs, out_chs, kernel_size, dilation=dilation)]
self.block = nn.ModuleList(block)
if config.use_conv_shortcut:
self.shortcut = EncodecConv1d(config, dim, dim, kernel_size=1)
else:
self.shortcut = nn.Identity()
def forward(self, hidden_states):
residual = hidden_states
......
......@@ -385,6 +385,11 @@ class EncodecModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
def test_identity_shortcut(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
config.use_conv_shortcut = False
self.model_tester.create_and_check_model_forward(config, inputs_dict)
def normalize(arr):
norm = np.linalg.norm(arr)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment