"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "a26ce4dee116a1d5d9099c8a94e22d1e31ad631c"
Unverified Commit 2af87d01 authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[VITS] Fix nightly tests (#25986)

* fix tokenizer

* make bs even

* fix multi gpu test

* style

* model forward

* fix torch import

* revert tok pin
parent 3744126c
...@@ -27,6 +27,7 @@ from transformers.testing_utils import ( ...@@ -27,6 +27,7 @@ from transformers.testing_utils import (
is_flaky, is_flaky,
is_torch_available, is_torch_available,
require_torch, require_torch,
require_torch_multi_gpu,
slow, slow,
torch_device, torch_device,
) )
...@@ -177,6 +178,30 @@ class VitsModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -177,6 +178,30 @@ class VitsModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_forward(*config_and_inputs) self.model_tester.create_and_check_model_forward(*config_and_inputs)
@require_torch_multi_gpu
# override to force all elements of the batch to have the same sequence length across GPUs
def test_multi_gpu_data_parallel_forward(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.use_stochastic_duration_prediction = False
# move input tensors to cuda:O
for key, value in inputs_dict.items():
if torch.is_tensor(value):
# make all elements of the batch the same -> ensures the output seq lengths are the same for DP
value[1:] = value[0]
inputs_dict[key] = value.to(0)
for model_class in self.all_model_classes:
model = model_class(config=config)
model.to(0)
model.eval()
# Wrap model in nn.DataParallel
model = torch.nn.DataParallel(model)
set_seed(555)
with torch.no_grad():
_ = model(**self._prepare_for_class(inputs_dict, model_class)).waveform
@unittest.skip("VITS is not deterministic") @unittest.skip("VITS is not deterministic")
def test_determinism(self): def test_determinism(self):
pass pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment