Unverified Commit 4c068936 authored by Julien Chaumond, committed by GitHub

Fix nn.DataParallel compatibility in PyTorch 1.5 (#4300)

* Test case for #3936

* multigpu tests pass on pytorch 1.4.0

* Fixup

* multigpu tests pass on pytorch 1.5.0

* Update src/transformers/modeling_utils.py

* Update src/transformers/modeling_utils.py

* rename multigpu to require_multigpu

* mode doc
parent 9de4afa8
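
Background on the fix: in PyTorch 1.5, nn.DataParallel builds its per-GPU replicas in a way that no longer registers parameters on the replica modules, so Module.parameters() comes back empty inside a replica's forward() and the next(self.parameters()) pattern used throughout the modeling code raises StopIteration (issue #3936). A minimal repro sketch, assuming PyTorch 1.5 and at least two CUDA GPUs; the Toy module is illustrative and not part of the library:

import torch
from torch import nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        # the pattern this commit removes from the modeling code
        target_dtype = next(self.parameters()).dtype
        return self.linear(x.to(target_dtype))

if torch.cuda.device_count() >= 2:
    model = nn.DataParallel(Toy().to(0))
    _ = model(torch.randn(8, 4, device=0))  # StopIteration inside the replicas on PyTorch 1.5
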
......@@ -550,7 +550,7 @@ class AlbertModel(AlbertPreTrainedModel):
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
- extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
......
......@@ -703,9 +703,7 @@ class BertModel(BertPreTrainedModel):
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
-     attention_mask, input_shape, self.device
- )
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
......
......@@ -704,7 +704,7 @@ class T5Stack(T5PreTrainedModel):
past_key_value_states = [None] * len(self.block)
# ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, self.device)
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device)
if self.is_decoder and encoder_attention_mask is not None:
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
......
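A note on the T5 change above: instead of looking the device up on the module, the mask helper now takes it from inputs_embeds. Under nn.DataParallel each replica receives inputs already scattered onto its own GPU, so an input tensor is always a reliable source for the device, with no module-level lookup required. A tiny, purely illustrative check (pure PyTorch, names are stand-ins):

import torch

inputs_embeds = torch.randn(2, 7, 16)  # stand-in for the real T5 embeddings
attention_mask = torch.ones(inputs_embeds.shape[:2], device=inputs_embeds.device)
print(attention_mask.device)  # matches whatever device the inputs were scattered to
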
......@@ -17,7 +17,7 @@
import inspect
import logging
import os
- from typing import Callable, Tuple
+ from typing import Callable, List, Tuple
import torch
from torch import Tensor, device, dtype, nn
......@@ -110,11 +110,33 @@ class ModuleUtilsMixin:
@property
def device(self) -> device:
    try:
        return next(self.parameters()).device
    except StopIteration:
        # For nn.DataParallel compatibility in PyTorch 1.5

        def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
            return tuples

        gen = self._named_members(get_members_fn=find_tensor_attributes)
        first_tuple = next(gen)
        return first_tuple[1].device

@property
def dtype(self) -> dtype:
    try:
        return next(self.parameters()).dtype
    except StopIteration:
        # For nn.DataParallel compatibility in PyTorch 1.5

        def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
            return tuples

        gen = self._named_members(get_members_fn=find_tensor_attributes)
        first_tuple = next(gen)
        return first_tuple[1].dtype
def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
"""type: torch.Tensor -> torch.Tensor"""
......
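With the fallback above in place, the modeling files can rely on self.dtype and self.device instead of next(self.parameters()), and the same properties also work on a DataParallel replica. A hedged usage check; the model class and tiny config values are illustrative only:

import torch
from transformers import BertConfig, BertModel

model = BertModel(BertConfig(hidden_size=32, num_hidden_layers=2,
                             num_attention_heads=2, intermediate_size=64))
print(model.device, model.dtype)  # cpu torch.float32
mask = torch.ones(1, 8)
extended = mask[:, None, None, :].to(dtype=model.dtype)  # fp16-safe cast, as in the Albert hunk
extended = (1.0 - extended) * -10000.0
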
......@@ -623,7 +623,7 @@ class XLNetModel(XLNetPreTrainedModel):
mask_lo = torch.tril(attn_mask, diagonal=-1)
ret = torch.cat([ret[:, :qlen] + mask_lo, ret[:, qlen:]], dim=1)
- ret = ret.to(next(self.parameters()))
+ ret = ret.to(self.device)
return ret
def cache_mem(self, curr_out, prev_mem):
......@@ -685,7 +685,7 @@ class XLNetModel(XLNetPreTrainedModel):
fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)
- pos_emb = pos_emb.to(next(self.parameters()))
+ pos_emb = pos_emb.to(self.device)
return pos_emb
@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
......@@ -761,8 +761,8 @@ class XLNetModel(XLNetPreTrainedModel):
mlen = mems[0].shape[0] if mems is not None and mems[0] is not None else 0
klen = mlen + qlen
- dtype_float = next(self.parameters()).dtype
- device = next(self.parameters()).device
+ dtype_float = self.dtype
+ device = self.device
# Attention mask
# causal attention mask
......
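One PyTorch detail worth keeping in mind when reading the XLNet hunks above: passing a tensor to Tensor.to matches both its dtype and its device, while passing a device only moves the data; the dtype lookups in this file go through self.dtype separately, as in the dtype_float change. A small standalone check of the two overloads (pure PyTorch, nothing transformers-specific):

import torch

reference = torch.zeros(1, dtype=torch.float16)
x = torch.ones(2, dtype=torch.float32)
print(x.to(reference).dtype)         # torch.float16 (dtype follows the reference tensor)
print(x.to(reference.device).dtype)  # torch.float32 (only the device is considered)
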
......@@ -23,7 +23,7 @@ from typing import List
from transformers import is_torch_available
- from .utils import require_torch, slow, torch_device
+ from .utils import require_multigpu, require_torch, slow, torch_device
if is_torch_available():
......@@ -758,6 +758,31 @@ class ModelTesterMixin:
return True
return False
@require_multigpu
def test_multigpu_data_parallel_forward(self):
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    # some params shouldn't be scattered by nn.DataParallel
    # so just remove them if they are present.
    blacklist_non_batched_params = ["head_mask"]
    for k in blacklist_non_batched_params:
        inputs_dict.pop(k, None)

    # move input tensors to cuda:0
    for k, v in inputs_dict.items():
        if torch.is_tensor(v):
            inputs_dict[k] = v.to(0)

    for model_class in self.all_model_classes:
        model = model_class(config=config)
        model.to(0)
        model.eval()

        # Wrap model in nn.DataParallel
        model = torch.nn.DataParallel(model)
        with torch.no_grad():
            _ = model(**inputs_dict)
global_rng = random.Random()
......
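The new common test above wraps every model class in nn.DataParallel and runs a single forward pass with the batch scattered from cuda:0. A condensed standalone sketch of what it exercises, assuming at least two CUDA GPUs; the concrete model and config values are illustrative, not the test's own:

import torch
from transformers import BertConfig, BertModel

if torch.cuda.device_count() >= 2:
    config = BertConfig(hidden_size=32, num_hidden_layers=2,
                        num_attention_heads=2, intermediate_size=64)
    model = BertModel(config).to(0)
    model.eval()
    model = torch.nn.DataParallel(model)  # scatters the batch across the visible GPUs
    input_ids = torch.randint(0, config.vocab_size, (8, 7), device=0)
    with torch.no_grad():
        _ = model(input_ids=input_ids)  # raised StopIteration before this fix
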
......@@ -41,7 +41,7 @@ class CTRLModelTest(ModelTesterMixin, unittest.TestCase):
def __init__(
self,
parent,
- batch_size=13,
+ batch_size=14,
seq_length=7,
is_training=True,
use_token_type_ids=True,
......
......@@ -46,7 +46,7 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
def __init__(
self,
parent,
- batch_size=13,
+ batch_size=14,
seq_length=7,
is_training=True,
use_token_type_ids=True,
......
......@@ -19,7 +19,7 @@ from transformers import is_torch_available
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
- from .utils import require_torch, slow, torch_device
+ from .utils import require_multigpu, require_torch, slow, torch_device
if is_torch_available():
......@@ -448,9 +448,14 @@ class ReformerTesterMixin:
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_model_fp16_generate(*config_and_inputs)
@require_multigpu
def test_multigpu_data_parallel_forward(self):
    # Opt-out of this test.
    pass
@require_torch
- class ReformerLocalAttnModelTest(ModelTesterMixin, ReformerTesterMixin, unittest.TestCase):
+ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase):
all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else ()
all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
test_pruning = False
......@@ -504,7 +509,7 @@ class ReformerLocalAttnModelTest(ModelTesterMixin, ReformerTesterMixin, unittest
@require_torch
- class ReformerLSHAttnModelTest(ModelTesterMixin, unittest.TestCase, ReformerTesterMixin):
+ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase):
all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else ()
all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
test_pruning = False
......
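Why the base-class order changes above: ReformerTesterMixin now defines the opt-out override of test_multigpu_data_parallel_forward, and listing it before ModelTesterMixin puts that override first in the MRO so it shadows the common test. A toy illustration with stand-in classes (not the real test classes):

class Common:
    def test_multigpu_data_parallel_forward(self):
        return "common"

class OptOut:
    def test_multigpu_data_parallel_forward(self):
        return "opt-out"

class OldOrder(Common, OptOut):  # old order: the common test is still picked up
    pass

class NewOrder(OptOut, Common):  # new order: the opt-out wins
    pass

print(OldOrder().test_multigpu_data_parallel_forward())  # common
print(NewOrder().test_multigpu_data_parallel_forward())  # opt-out
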
......@@ -21,7 +21,7 @@ from transformers import is_torch_available
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
- from .utils import require_torch, slow, torch_device
+ from .utils import require_multigpu, require_torch, slow, torch_device
if is_torch_available():
......@@ -43,7 +43,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
def __init__(
self,
parent,
- batch_size=13,
+ batch_size=14,
seq_length=7,
mem_len=30,
clamp_len=15,
......@@ -207,6 +207,11 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
output_result = self.model_tester.create_transfo_xl_lm_head(*config_and_inputs)
self.model_tester.check_transfo_xl_lm_head_output(output_result)
@require_multigpu
def test_multigpu_data_parallel_forward(self):
    # Opt-out of this test.
    pass
@slow
def test_model_from_pretrained(self):
for model_name in list(TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
......
......@@ -61,7 +61,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
def __init__(
self,
parent,
- batch_size=13,
+ batch_size=14,
seq_length=7,
mem_len=10,
clamp_len=-1,
......
......@@ -94,6 +94,25 @@ def require_tf(test_case):
return test_case
def require_multigpu(test_case):
    """
    Decorator marking a test that requires a multi-GPU setup (in PyTorch).

    These tests are skipped on a machine without multiple GPUs.

    To run *only* the multigpu tests, assuming all test names contain multigpu:
    $ pytest -sv ./tests -k "multigpu"
    """
    if not _torch_available:
        return unittest.skip("test requires PyTorch")(test_case)

    import torch

    if torch.cuda.device_count() < 2:
        return unittest.skip("test requires multiple GPUs")(test_case)

    return test_case
if _torch_available:
# Set the USE_CUDA environment variable to select a GPU.
torch_device = "cuda" if parse_flag_from_env("USE_CUDA") else "cpu"
......
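A hedged sketch of the new decorator in use; the test module and names below are illustrative, and the relative import only resolves inside the tests package, mirroring the from .utils imports added in the test files above:

import unittest

from .utils import require_multigpu, require_torch

@require_torch
class DataParallelSmokeTest(unittest.TestCase):
    @require_multigpu
    def test_sees_at_least_two_gpus(self):
        import torch  # safe here: require_multigpu already guaranteed torch and >= 2 GPUs

        self.assertGreaterEqual(torch.cuda.device_count(), 2)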