"src/nni_manager/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "c1a5b1eedec9cbfa3966c36ce62b4553afb10dd7"
Unverified Commit 82a1fc72 authored by Jacky Lee's avatar Jacky Lee Committed by GitHub
Browse files

Fix return_dict in encodec (#31646)

* fix: use return_dict parameter

* fix: type checks

* fix: unused imports

* update: one-line if else

* remove: recursive check
parent 5e89b335
......@@ -729,7 +729,7 @@ class EncodecModel(EncodecPreTrainedModel):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
return_dict = return_dict or self.config.return_dict
return_dict = return_dict if return_dict is not None else self.config.return_dict
chunk_length = self.config.chunk_length
if chunk_length is None:
......@@ -786,7 +786,7 @@ class EncodecModel(EncodecPreTrainedModel):
>>> audio_codes = outputs.audio_codes
>>> audio_values = outputs.audio_values
```"""
return_dict = return_dict or self.config.return_dict
return_dict = return_dict if return_dict is not None else self.config.return_dict
if padding_mask is None:
padding_mask = torch.ones_like(input_values).bool()
......
......@@ -19,7 +19,6 @@ import inspect
import os
import tempfile
import unittest
from typing import Dict, List, Tuple
import numpy as np
from datasets import Audio, load_dataset
......@@ -385,31 +384,21 @@ class EncodecModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs)
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (List, Tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif isinstance(tuple_object, Dict):
for tuple_iterable_value, dict_iterable_value in zip(
tuple_object.values(), dict_object.values()
):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
self.assertTrue(
torch.allclose(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
),
msg=(
"Tuple and dict output are not equal. Difference:"
f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
),
)
recursive_check(tuple_output, dict_output)
self.assertTrue(isinstance(tuple_output, tuple))
self.assertTrue(isinstance(dict_output, dict))
for tuple_value, dict_value in zip(tuple_output, dict_output.values()):
self.assertTrue(
torch.allclose(
set_nan_tensor_to_zero(tuple_value), set_nan_tensor_to_zero(dict_value), atol=1e-5
),
msg=(
"Tuple and dict output are not equal. Difference:"
f" {torch.max(torch.abs(tuple_value - dict_value))}. Tuple has `nan`:"
f" {torch.isnan(tuple_value).any()} and `inf`: {torch.isinf(tuple_value)}. Dict has"
f" `nan`: {torch.isnan(dict_value).any()} and `inf`: {torch.isinf(dict_value)}."
),
)
for model_class in self.all_model_classes:
model = model_class(config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment