Unverified Commit 2af87d01 authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[VITS] Fix nightly tests (#25986)

* fix tokenizer

* make bs even

* fix multi gpu test

* style

* model forward

* fix torch import

* revert tok pin
parent 3744126c
......@@ -27,6 +27,7 @@ from transformers.testing_utils import (
is_flaky,
is_torch_available,
require_torch,
require_torch_multi_gpu,
slow,
torch_device,
)
......@@ -177,6 +178,30 @@ class VitsModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_forward(*config_and_inputs)
@require_torch_multi_gpu
# override to force all elements of the batch to have the same sequence length across GPUs
def test_multi_gpu_data_parallel_forward(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.use_stochastic_duration_prediction = False
# move input tensors to cuda:O
for key, value in inputs_dict.items():
if torch.is_tensor(value):
# make all elements of the batch the same -> ensures the output seq lengths are the same for DP
value[1:] = value[0]
inputs_dict[key] = value.to(0)
for model_class in self.all_model_classes:
model = model_class(config=config)
model.to(0)
model.eval()
# Wrap model in nn.DataParallel
model = torch.nn.DataParallel(model)
set_seed(555)
with torch.no_grad():
_ = model(**self._prepare_for_class(inputs_dict, model_class)).waveform
@unittest.skip("VITS is not deterministic")
def test_determinism(self):
pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment