Unverified Commit cbf036f7 authored by NielsRogge, committed by GitHub

Add test (#14810)

parent c4a0fb51
@@ -28,7 +28,7 @@ from datasets import load_dataset
 from transformers import PerceiverConfig
 from transformers.file_utils import is_torch_available, is_vision_available
 from transformers.models.auto import get_values
-from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device
 from .test_configuration_common import ConfigTester
 from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
@@ -757,6 +757,31 @@ class PerceiverModelTest(ModelTesterMixin, unittest.TestCase):
         loss.backward()
+    @require_torch_multi_gpu
+    def test_multi_gpu_data_parallel_forward(self):
+        for model_class in self.all_model_classes:
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_model_class(model_class)
+
+            # some params shouldn't be scattered by nn.DataParallel
+            # so just remove them if they are present.
+            blacklist_non_batched_params = ["head_mask", "decoder_head_mask", "cross_attn_head_mask"]
+            for k in blacklist_non_batched_params:
+                inputs_dict.pop(k, None)
+
+            # move input tensors to cuda:0
+            for k, v in inputs_dict.items():
+                if torch.is_tensor(v):
+                    inputs_dict[k] = v.to(0)
+
+            model = model_class(config=config)
+            model.to(0)
+            model.eval()
+
+            # Wrap model in nn.DataParallel
+            model = nn.DataParallel(model)
+            with torch.no_grad():
+                _ = model(**self._prepare_for_class(inputs_dict, model_class))
     @unittest.skip(reason="Perceiver models don't have a typical head like is the case with BERT")
     def test_save_load_fast_init_from_base(self):
         pass
...
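For readers unfamiliar with the pattern the new test exercises, the sketch below strips it to its essence: parameters and inputs are placed on cuda:0, the module is wrapped in nn.DataParallel, and a single forward pass lets PyTorch replicate the module and scatter the batch dimension across the available GPUs. TinyModel, its hidden size, and the tensor shapes are hypothetical stand-ins for the Perceiver classes used in the actual test, not transformers code.

# Minimal sketch of the nn.DataParallel forward pattern; TinyModel and the
# shapes below are hypothetical stand-ins, not part of transformers.
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    def __init__(self, hidden_size=32):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, inputs):
        return self.linear(inputs)


if torch.cuda.device_count() > 1:
    model = TinyModel().to(0)  # parameters live on cuda:0
    model.eval()

    # DataParallel replicates the module on each visible GPU and scatters
    # the batch (dim 0) of every tensor argument across the replicas.
    model = nn.DataParallel(model)

    batch = torch.randn(8, 32).to(0)  # batch dim is split across GPUs
    with torch.no_grad():
        output = model(inputs=batch)
    print(output.shape)  # outputs gathered back on cuda:0 -> torch.Size([8, 32])

On a machine with at least two GPUs, the new test itself can be run with something like pytest tests/test_modeling_perceiver.py -k test_multi_gpu_data_parallel_forward (the path assumes the flat tests/ layout the imports above suggest); on single-GPU machines the @require_torch_multi_gpu decorator marks it as skipped.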