Unverified Commit 4c068936 authored by Julien Chaumond, committed by GitHub

Fix nn.DataParallel compatibility in PyTorch 1.5 (#4300)

* Test case for #3936

* multigpu tests pass on pytorch 1.4.0

* Fixup

* multigpu tests pass on pytorch 1.5.0

* Update src/transformers/modeling_utils.py

* Update src/transformers/modeling_utils.py

* rename multigpu to require_multigpu

* more doc
parent 9de4afa8
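
What actually breaks (see #3936 referenced above): starting with PyTorch 1.5, nn.DataParallel builds its per-GPU replicas by re-attaching weights as plain tensor attributes instead of nn.Parameters, so replica.parameters() yields nothing and any next(self.parameters()) call inside forward() raises StopIteration. A minimal sketch of the failure mode, using a hypothetical toy module rather than transformers code:

import torch
from torch import nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        # Pre-1.5 idiom: read the module's device off its first parameter.
        # On a torch>=1.5 DataParallel replica, parameters() is empty,
        # so this raises StopIteration inside parallel_apply.
        target = next(self.parameters()).device
        return self.linear(x.to(target))

model = nn.DataParallel(Toy().to(0))     # needs >= 2 GPUs to actually replicate
out = model(torch.ones(8, 4, device=0))  # fails on torch 1.5, worked on 1.4

The diff below removes every such call from the forward paths and hardens the two properties that remain.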
@@ -550,7 +550,7 @@ class AlbertModel(AlbertPreTrainedModel):
             token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
 
         extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
+        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
...
@@ -703,9 +703,7 @@ class BertModel(BertPreTrainedModel):
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
-        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
-            attention_mask, input_shape, self.device
-        )
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
 
         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
...
@@ -704,7 +704,7 @@ class T5Stack(T5PreTrainedModel):
             past_key_value_states = [None] * len(self.block)
 
         # ourselves in which case we just need to make it broadcastable to all heads.
-        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, self.device)
+        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device)
 
         if self.is_decoder and encoder_attention_mask is not None:
             encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
...
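The T5 change above uses a different, replica-safe pattern: a scattered input tensor always knows which GPU it landed on, so the device can be read from the data instead of from the parameters. As a hedged one-liner (the input names are the usual transformers ones, assumed here):

# Replica-safe device lookup inside forward(): derive it from an input tensor.
device = input_ids.device if input_ids is not None else inputs_embeds.device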
@@ -17,7 +17,7 @@
 import inspect
 import logging
 import os
-from typing import Callable, Tuple
+from typing import Callable, List, Tuple
 
 import torch
 from torch import Tensor, device, dtype, nn
@@ -110,11 +110,33 @@ class ModuleUtilsMixin:
     @property
     def device(self) -> device:
-        return next(self.parameters()).device
+        try:
+            return next(self.parameters()).device
+        except StopIteration:
+            # For nn.DataParallel compatibility in PyTorch 1.5
+
+            def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
+                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+                return tuples
+
+            gen = self._named_members(get_members_fn=find_tensor_attributes)
+            first_tuple = next(gen)
+            return first_tuple[1].device
 
     @property
     def dtype(self) -> dtype:
-        return next(self.parameters()).dtype
+        try:
+            return next(self.parameters()).dtype
+        except StopIteration:
+            # For nn.DataParallel compatibility in PyTorch 1.5
+
+            def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
+                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+                return tuples
+
+            gen = self._named_members(get_members_fn=find_tensor_attributes)
+            first_tuple = next(gen)
+            return first_tuple[1].dtype
 
     def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
         """type: torch.Tensor -> torch.Tensor"""
...
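The fallback scans for tensors stored directly in each submodule's __dict__, which is exactly where torch 1.5 replicas keep their de-parameterized weights; _named_members handles the recursion and de-duplication. A standalone sketch of the same scan (hypothetical helper, for illustration only):

import torch
from torch import nn

def first_tensor_attribute(root: nn.Module) -> torch.Tensor:
    """Return the first plain-tensor attribute found on any submodule."""
    for module in root.modules():  # depth-first walk, root included
        for value in module.__dict__.values():
            # On a regular module, weights live in _parameters/_buffers and are
            # skipped here; on a torch 1.5 replica they show up as attributes.
            if torch.is_tensor(value):
                return value
    raise ValueError("module has no tensor attributes")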
@@ -623,7 +623,7 @@ class XLNetModel(XLNetPreTrainedModel):
         mask_lo = torch.tril(attn_mask, diagonal=-1)
         ret = torch.cat([ret[:, :qlen] + mask_lo, ret[:, qlen:]], dim=1)
-        ret = ret.to(next(self.parameters()))
+        ret = ret.to(self.device)
 
         return ret
 
     def cache_mem(self, curr_out, prev_mem):
@@ -685,7 +685,7 @@ class XLNetModel(XLNetPreTrainedModel):
             fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
         pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)
-        pos_emb = pos_emb.to(next(self.parameters()))
+        pos_emb = pos_emb.to(self.device)
 
         return pos_emb
 
     @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
@@ -761,8 +761,8 @@ class XLNetModel(XLNetPreTrainedModel):
         mlen = mems[0].shape[0] if mems is not None and mems[0] is not None else 0
         klen = mlen + qlen
 
-        dtype_float = next(self.parameters()).dtype
-        device = next(self.parameters()).device
+        dtype_float = self.dtype
+        device = self.device
 
         # Attention mask
         # causal attention mask
...
@@ -23,7 +23,7 @@ from typing import List
 from transformers import is_torch_available
 
-from .utils import require_torch, slow, torch_device
+from .utils import require_multigpu, require_torch, slow, torch_device
 
 if is_torch_available():
@@ -758,6 +758,31 @@ class ModelTesterMixin:
                 return True
         return False
 
+    @require_multigpu
+    def test_multigpu_data_parallel_forward(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        # some params shouldn't be scattered by nn.DataParallel
+        # so just remove them if they are present.
+        blacklist_non_batched_params = ["head_mask"]
+        for k in blacklist_non_batched_params:
+            inputs_dict.pop(k, None)
+
+        # move input tensors to cuda:0
+        for k, v in inputs_dict.items():
+            if torch.is_tensor(v):
+                inputs_dict[k] = v.to(0)
+
+        for model_class in self.all_model_classes:
+            model = model_class(config=config)
+            model.to(0)
+            model.eval()
+
+            # Wrap model in nn.DataParallel
+            model = torch.nn.DataParallel(model)
+            with torch.no_grad():
+                _ = model(**inputs_dict)
+
 
 global_rng = random.Random()
...
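Two details of the new common test are worth spelling out. nn.DataParallel scatters every tensor argument along dim 0, one chunk per GPU: batched inputs split cleanly, but head_mask has no batch dimension (its leading dim is the number of layers), which is why it is popped before the forward call. The same dim-0 rule presumably motivates the batch_size=13 -> 14 bumps in the test hunks below: an even batch divides uniformly across two GPUs. A hedged illustration with assumed shapes:

import torch

input_ids = torch.zeros(14, 7, dtype=torch.long)  # (batch, seq): 7 rows per replica on 2 GPUs
head_mask = torch.ones(12, 12)                    # (num_layers, num_heads): no batch dim

# What scatter effectively does to each tensor argument:
print([t.shape for t in torch.chunk(input_ids, 2, dim=0)])  # [(7, 7), (7, 7)] -> fine
print([t.shape for t in torch.chunk(head_mask, 2, dim=0)])  # [(6, 12), (6, 12)] -> truncated mask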
@@ -41,7 +41,7 @@ class CTRLModelTest(ModelTesterMixin, unittest.TestCase):
     def __init__(
         self,
         parent,
-        batch_size=13,
+        batch_size=14,
         seq_length=7,
         is_training=True,
         use_token_type_ids=True,
...
@@ -46,7 +46,7 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
     def __init__(
         self,
         parent,
-        batch_size=13,
+        batch_size=14,
         seq_length=7,
         is_training=True,
         use_token_type_ids=True,
...
@@ -19,7 +19,7 @@ from transformers import is_torch_available
 from .test_configuration_common import ConfigTester
 from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
-from .utils import require_torch, slow, torch_device
+from .utils import require_multigpu, require_torch, slow, torch_device
 
 if is_torch_available():
@@ -448,9 +448,14 @@ class ReformerTesterMixin:
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_reformer_model_fp16_generate(*config_and_inputs)
 
+    @require_multigpu
+    def test_multigpu_data_parallel_forward(self):
+        # Opt-out of this test.
+        pass
+
 
 @require_torch
-class ReformerLocalAttnModelTest(ModelTesterMixin, ReformerTesterMixin, unittest.TestCase):
+class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase):
     all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else ()
     all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
     test_pruning = False
@@ -504,7 +509,7 @@ class ReformerLocalAttnModelTest(ModelTesterMixin, ReformerTesterMixin, unittest
 
 @require_torch
-class ReformerLSHAttnModelTest(ModelTesterMixin, unittest.TestCase, ReformerTesterMixin):
+class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase):
     all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else ()
     all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
     test_pruning = False
...
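The base-class reordering in the two Reformer test classes is what makes the opt-out effective: with ReformerTesterMixin listed first, Python's method resolution order picks the mixin's pass-through test_multigpu_data_parallel_forward over the real test inherited from ModelTesterMixin. A minimal illustration of the rule, with hypothetical class names:

class RealTests:
    def test(self):
        return "real test"

class OptOut:
    def test(self):
        return "skipped"

class Wrong(RealTests, OptOut):  # opt-out never fires
    pass

class Right(OptOut, RealTests):  # leftmost base wins under MRO
    pass

print(Wrong().test())  # real test
print(Right().test())  # skipped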
@@ -21,7 +21,7 @@ from transformers import is_torch_available
 from .test_configuration_common import ConfigTester
 from .test_modeling_common import ModelTesterMixin, ids_tensor
-from .utils import require_torch, slow, torch_device
+from .utils import require_multigpu, require_torch, slow, torch_device
 
 if is_torch_available():
@@ -43,7 +43,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
     def __init__(
         self,
         parent,
-        batch_size=13,
+        batch_size=14,
         seq_length=7,
         mem_len=30,
         clamp_len=15,
@@ -207,6 +207,11 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
         output_result = self.model_tester.create_transfo_xl_lm_head(*config_and_inputs)
         self.model_tester.check_transfo_xl_lm_head_output(output_result)
 
+    @require_multigpu
+    def test_multigpu_data_parallel_forward(self):
+        # Opt-out of this test.
+        pass
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in list(TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
...
@@ -61,7 +61,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
     def __init__(
         self,
         parent,
-        batch_size=13,
+        batch_size=14,
         seq_length=7,
         mem_len=10,
         clamp_len=-1,
...
@@ -94,6 +94,25 @@ def require_tf(test_case):
     return test_case
 
+def require_multigpu(test_case):
+    """
+    Decorator marking a test that requires a multi-GPU setup (in PyTorch).
+
+    These tests are skipped on a machine without multiple GPUs.
+
+    To run *only* the multigpu tests, assuming all test names contain multigpu:
+    $ pytest -sv ./tests -k "multigpu"
+    """
+    if not _torch_available:
+        return unittest.skip("test requires PyTorch")(test_case)
+
+    import torch
+
+    if torch.cuda.device_count() < 2:
+        return unittest.skip("test requires multiple GPUs")(test_case)
+
+    return test_case
+
 
 if _torch_available:
     # Set the USE_CUDA environment variable to select a GPU.
     torch_device = "cuda" if parse_flag_from_env("USE_CUDA") else "cpu"
...
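A hedged usage sketch for the new decorator (hypothetical test, not part of this diff): it composes like the existing require_torch helper, and keeping "multigpu" in the method name preserves the pytest -k filter from the docstring:

import unittest
from .utils import require_multigpu

class SomeModelTest(unittest.TestCase):
    @require_multigpu
    def test_multigpu_forward(self):
        import torch

        # Only runs when torch.cuda.device_count() >= 2.
        self.assertGreaterEqual(torch.cuda.device_count(), 2)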