"vscode:/vscode.git/clone" did not exist on "2fecde742db1b08e402eb6b11cfc3d80f2ec8a21"
Unverified Commit 82a1fc72 authored by Jacky Lee, committed by GitHub

Fix return_dict in encodec (#31646)

* fix: use return_dict parameter

* fix: type checks

* fix: unused imports

* update: one-line if else

* remove: recursive check
parent 5e89b335
...
@@ -729,7 +729,7 @@ class EncodecModel(EncodecPreTrainedModel):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
         """
-        return_dict = return_dict or self.config.return_dict
+        return_dict = return_dict if return_dict is not None else self.config.return_dict
         chunk_length = self.config.chunk_length
         if chunk_length is None:
...
@@ -786,7 +786,7 @@ class EncodecModel(EncodecPreTrainedModel):
         >>> audio_codes = outputs.audio_codes
         >>> audio_values = outputs.audio_values
         ```"""
-        return_dict = return_dict or self.config.return_dict
+        return_dict = return_dict if return_dict is not None else self.config.return_dict
         if padding_mask is None:
             padding_mask = torch.ones_like(input_values).bool()
...
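Note on the modeling change: `return_dict or self.config.return_dict` silently overrides an explicit `return_dict=False` whenever the config default is `True`, since `False or True` evaluates to `True`. The `is not None` form falls back to the config only when the caller leaves the argument unset. A minimal standalone sketch of the two idioms (the helper name is illustrative, not part of the library):

```python
def resolve_return_dict(return_dict, config_default=True):
    # Buggy idiom: an explicit False is swallowed because `False or True` -> True
    buggy = return_dict or config_default
    # Fixed idiom: only fall back to the config when nothing was passed
    fixed = return_dict if return_dict is not None else config_default
    return buggy, fixed

print(resolve_return_dict(False))  # (True, False): only the fixed form honors False
print(resolve_return_dict(None))   # (True, True): both fall back to the config default
```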
...
@@ -19,7 +19,6 @@ import inspect
 import os
 import tempfile
 import unittest
-from typing import Dict, List, Tuple
 
 import numpy as np
 from datasets import Audio, load_dataset
...
@@ -385,32 +384,22 @@ class EncodecModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
         tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
         dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs)
 
-        def recursive_check(tuple_object, dict_object):
-            if isinstance(tuple_object, (List, Tuple)):
-                for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
-                    recursive_check(tuple_iterable_value, dict_iterable_value)
-            elif isinstance(tuple_object, Dict):
-                for tuple_iterable_value, dict_iterable_value in zip(
-                    tuple_object.values(), dict_object.values()
-                ):
-                    recursive_check(tuple_iterable_value, dict_iterable_value)
-            elif tuple_object is None:
-                return
-            else:
-                self.assertTrue(
-                    torch.allclose(
-                        set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
-                    ),
-                    msg=(
-                        "Tuple and dict output are not equal. Difference:"
-                        f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
-                        f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
-                        f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
-                    ),
-                )
-
-        recursive_check(tuple_output, dict_output)
+        self.assertTrue(isinstance(tuple_output, tuple))
+        self.assertTrue(isinstance(dict_output, dict))
+
+        for tuple_value, dict_value in zip(tuple_output, dict_output.values()):
+            self.assertTrue(
+                torch.allclose(
+                    set_nan_tensor_to_zero(tuple_value), set_nan_tensor_to_zero(dict_value), atol=1e-5
+                ),
+                msg=(
+                    "Tuple and dict output are not equal. Difference:"
+                    f" {torch.max(torch.abs(tuple_value - dict_value))}. Tuple has `nan`:"
+                    f" {torch.isnan(tuple_value).any()} and `inf`: {torch.isinf(tuple_value)}. Dict has"
+                    f" `nan`: {torch.isnan(dict_value).any()} and `inf`: {torch.isinf(dict_value)}."
+                ),
+            )
 
         for model_class in self.all_model_classes:
             model = model_class(config)
             model.to(torch_device)
...
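Note on the test change: the removed `recursive_check` helper walked arbitrarily nested lists and dicts, but the EnCodec outputs compared here are flat sequences of tensors, so the test now simply zips the tuple output against the values of the dict-style output. A minimal standalone sketch of that comparison pattern, assuming flat tensor outputs (the helper name and dummy data are illustrative only):

```python
import torch

def assert_tuple_matches_dict(tuple_output, dict_output, atol=1e-5):
    # Each positional tensor should match the corresponding named field.
    for tuple_value, dict_value in zip(tuple_output, dict_output.values()):
        assert torch.allclose(tuple_value, dict_value, atol=atol), (
            f"Max difference: {torch.max(torch.abs(tuple_value - dict_value))}"
        )

# Dummy stand-ins for model(..., return_dict=False) and model(..., return_dict=True)
tuple_out = (torch.ones(2, 3), torch.zeros(4))
dict_out = {"audio_codes": torch.ones(2, 3), "audio_values": torch.zeros(4)}
assert_tuple_matches_dict(tuple_out, dict_out)
```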