Unverified Commit 4124a09f authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[EnCodec] Changes for 32kHz ckpt (#24296)

* [EnCodec] Changes for 32kHz ckpt

* Update src/transformers/models/encodec/convert_encodec_checkpoint_to_pytorch.py

* Update src/transformers/models/encodec/convert_encodec_checkpoint_to_pytorch.py
parent 01b55779
...@@ -90,6 +90,9 @@ class EncodecConfig(PretrainedConfig): ...@@ -90,6 +90,9 @@ class EncodecConfig(PretrainedConfig):
Number of discrete codes that make up VQVAE. Number of discrete codes that make up VQVAE.
codebook_dim (`int`, *optional*): codebook_dim (`int`, *optional*):
Dimension of the codebook vectors. If not defined, uses `hidden_size`. Dimension of the codebook vectors. If not defined, uses `hidden_size`.
use_conv_shortcut (`bool`, *optional*, defaults to `True`):
Whether to use a convolutional layer as the 'skip' connection in the `EncodecResnetBlock` block. If False,
an identity function will be used, giving a generic residual connection.
Example: Example:
...@@ -131,6 +134,7 @@ class EncodecConfig(PretrainedConfig): ...@@ -131,6 +134,7 @@ class EncodecConfig(PretrainedConfig):
trim_right_ratio=1.0, trim_right_ratio=1.0,
codebook_size=1024, codebook_size=1024,
codebook_dim=None, codebook_dim=None,
use_conv_shortcut=True,
**kwargs, **kwargs,
): ):
self.target_bandwidths = target_bandwidths self.target_bandwidths = target_bandwidths
...@@ -155,6 +159,7 @@ class EncodecConfig(PretrainedConfig): ...@@ -155,6 +159,7 @@ class EncodecConfig(PretrainedConfig):
self.trim_right_ratio = trim_right_ratio self.trim_right_ratio = trim_right_ratio
self.codebook_size = codebook_size self.codebook_size = codebook_size
self.codebook_dim = codebook_dim if codebook_dim is not None else hidden_size self.codebook_dim = codebook_dim if codebook_dim is not None else hidden_size
self.use_conv_shortcut = use_conv_shortcut
if self.norm_type not in ["weight_norm", "time_group_norm"]: if self.norm_type not in ["weight_norm", "time_group_norm"]:
raise ValueError( raise ValueError(
......
...@@ -28,6 +28,7 @@ from transformers import ( ...@@ -28,6 +28,7 @@ from transformers import (
# checkpoints downloaded from: # checkpoints downloaded from:
# https://dl.fbaipublicfiles.com/encodec/v0/encodec_24khz-d7cc33bc.th # https://dl.fbaipublicfiles.com/encodec/v0/encodec_24khz-d7cc33bc.th
# https://huggingface.co/facebook/musicgen-small/resolve/main/compression_state_dict.bin
# https://dl.fbaipublicfiles.com/encodec/v0/encodec_48khz-7e698e3e.th # https://dl.fbaipublicfiles.com/encodec/v0/encodec_48khz-7e698e3e.th
...@@ -206,7 +207,7 @@ def should_ignore(name, ignore_keys): ...@@ -206,7 +207,7 @@ def should_ignore(name, ignore_keys):
def recursively_load_weights(orig_dict, hf_model, model_name): def recursively_load_weights(orig_dict, hf_model, model_name):
unused_weights = [] unused_weights = []
if model_name == "encodec_24khz": if model_name in ["encodec_24khz", "encodec_32khz"]:
MAPPING = MAPPING_24K MAPPING = MAPPING_24K
elif model_name == "encodec_48khz": elif model_name == "encodec_48khz":
MAPPING = MAPPING_48K MAPPING = MAPPING_48K
...@@ -292,6 +293,15 @@ def convert_checkpoint( ...@@ -292,6 +293,15 @@ def convert_checkpoint(
if model_name == "encodec_24khz": if model_name == "encodec_24khz":
pass # config is already correct pass # config is already correct
elif model_name == "encodec_32khz":
config.upsampling_ratios = [8, 5, 4, 4]
config.target_bandwidths = [2.2]
config.num_filters = 64
config.sampling_rate = 32_000
config.codebook_size = 2048
config.use_causal_conv = False
config.normalize = False
config.use_conv_shortcut = False
elif model_name == "encodec_48khz": elif model_name == "encodec_48khz":
config.upsampling_ratios = [8, 5, 4, 2] config.upsampling_ratios = [8, 5, 4, 2]
config.target_bandwidths = [3.0, 6.0, 12.0, 24.0] config.target_bandwidths = [3.0, 6.0, 12.0, 24.0]
...@@ -316,6 +326,9 @@ def convert_checkpoint( ...@@ -316,6 +326,9 @@ def convert_checkpoint(
feature_extractor.save_pretrained(pytorch_dump_folder_path) feature_extractor.save_pretrained(pytorch_dump_folder_path)
original_checkpoint = torch.load(checkpoint_path) original_checkpoint = torch.load(checkpoint_path)
if "best_state" in original_checkpoint:
# we might have a training state saved, in which case discard the yaml results and just retain the weights
original_checkpoint = original_checkpoint["best_state"]
recursively_load_weights(original_checkpoint, model, model_name) recursively_load_weights(original_checkpoint, model, model_name)
model.save_pretrained(pytorch_dump_folder_path) model.save_pretrained(pytorch_dump_folder_path)
...@@ -331,7 +344,7 @@ if __name__ == "__main__": ...@@ -331,7 +344,7 @@ if __name__ == "__main__":
"--model", "--model",
default="encodec_24khz", default="encodec_24khz",
type=str, type=str,
help="The model to convert. Should be one of 'encodec_24khz', 'encodec_48khz'.", help="The model to convert. Should be one of 'encodec_24khz', 'encodec_32khz', 'encodec_48khz'.",
) )
parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint") parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint")
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert") parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
......
...@@ -259,7 +259,10 @@ class EncodecResnetBlock(nn.Module): ...@@ -259,7 +259,10 @@ class EncodecResnetBlock(nn.Module):
block += [EncodecConv1d(config, in_chs, out_chs, kernel_size, dilation=dilation)] block += [EncodecConv1d(config, in_chs, out_chs, kernel_size, dilation=dilation)]
self.block = nn.ModuleList(block) self.block = nn.ModuleList(block)
self.shortcut = EncodecConv1d(config, dim, dim, kernel_size=1) if config.use_conv_shortcut:
self.shortcut = EncodecConv1d(config, dim, dim, kernel_size=1)
else:
self.shortcut = nn.Identity()
def forward(self, hidden_states): def forward(self, hidden_states):
residual = hidden_states residual = hidden_states
......
...@@ -385,6 +385,11 @@ class EncodecModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase) ...@@ -385,6 +385,11 @@ class EncodecModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
msg=f"Parameter {name} of model {model_class} seems not properly initialized", msg=f"Parameter {name} of model {model_class} seems not properly initialized",
) )
def test_identity_shortcut(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
config.use_conv_shortcut = False
self.model_tester.create_and_check_model_forward(config, inputs_dict)
def normalize(arr): def normalize(arr):
norm = np.linalg.norm(arr) norm = np.linalg.norm(arr)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment