"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "9611c2d0aae7a1a667a3eecaa92756fea1073f20"
Unverified commit 35cb101e, authored by Stas Bekman and committed by GitHub

DataParallel fixes (#5733)

* DataParallel fixes:

1. switched to a more precise check
-        if self.args.n_gpu > 1:
+        if isinstance(model, nn.DataParallel):

2. fix tests: apply the same fixup under DataParallel that the training module uses

* another fix
parent 290b6e18
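For context, a minimal sketch (not part of the commit) of why the isinstance check is more precise: args.n_gpu counts visible GPUs, which can exceed 1 even when this particular model was never wrapped in nn.DataParallel (for example under DistributedDataParallel, where each process drives a single replica), whereas the isinstance test fires exactly when the wrapper is present.

import torch
import torch.nn as nn

model = nn.Linear(4, 2)

# Coarse check: true whenever more than one GPU is visible, even if this
# model was never wrapped in nn.DataParallel.
needs_tuple_coarse = torch.cuda.device_count() > 1

# Precise check: true only when the model really is the DataParallel
# wrapper, i.e. exactly when its gather step must recombine replica outputs.
wrapped = nn.DataParallel(model)  # acts as a passthrough on a CPU-only box
assert isinstance(wrapped, nn.DataParallel)
assert not isinstance(model, nn.DataParallel)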
@@ -199,6 +199,9 @@ def train(args, train_dataset, model, tokenizer):
                         {"langs": (torch.ones(batch[0].shape, dtype=torch.int64) * args.lang_id).to(args.device)}
                     )
+            if isinstance(model, torch.nn.DataParallel):
+                inputs["return_tuple"] = True
             outputs = model(**inputs)
             # model outputs are always tuple in transformers (see doc)
             loss = outputs[0]

@@ -623,7 +623,7 @@ class Trainer:
         if self.args.past_index >= 0 and self._past is not None:
             inputs["mems"] = self._past
         # Our model outputs do not work with DataParallel, so forcing return tuple.
-        if self.args.n_gpu > 1:
+        if isinstance(model, nn.DataParallel):
             inputs["return_tuple"] = True
         outputs = model(**inputs)

@@ -826,7 +826,7 @@ class Trainer:
             if self.args.past_index >= 0:
                 inputs["mems"] = past
             # Our model outputs do not work with DataParallel, so forcing return tuple.
-            if self.args.n_gpu > 1:
+            if isinstance(model, nn.DataParallel):
                 inputs["return_tuple"] = True
             with torch.no_grad():

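The comment in both hunks points at DataParallel's gather step: it can recombine tensors nested in tuples, lists, and dicts from each replica, but not an arbitrary model-output object, hence forcing plain tuples before the parallel forward. A self-contained sketch under that assumption (TinyModel and TinyOutput are illustrative stand-ins, not the library's API):

import torch
import torch.nn as nn
from dataclasses import dataclass

@dataclass
class TinyOutput:
    logits: torch.Tensor  # plain object: DataParallel's gather cannot recombine it

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, input_ids, return_tuple=False):
        logits = self.linear(input_ids)
        # Tuples of tensors are safe to gather across replicas; the dataclass
        # above stands in for the output objects that are not.
        return (logits,) if return_tuple else TinyOutput(logits=logits)

model = nn.DataParallel(TinyModel())  # falls back to a plain forward on CPU
inputs = {"input_ids": torch.randn(8, 4)}
if isinstance(model, nn.DataParallel):  # the commit's precise check
    inputs["return_tuple"] = True
outputs = model(**inputs)
first = outputs[0]  # index into the tuple, as the surrounding code does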
@@ -803,6 +803,8 @@ class ModelTesterMixin:
             # Wrap model in nn.DataParallel
             model = torch.nn.DataParallel(model)
+            # Our model outputs do not work with DataParallel, so forcing return tuple.
+            inputs_dict["return_tuple"] = True
             with torch.no_grad():
                 _ = model(**self._prepare_for_class(inputs_dict, model_class))

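And a compact sketch of the mirrored test-side fixup, so the common test suite exercises the forward pass under the wrapper the same way training does (FakeModel and fake_inputs are hypothetical stand-ins for the mixin's model_class and inputs_dict):

import torch
import torch.nn as nn

class FakeModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, input_ids, return_tuple=False):
        out = self.linear(input_ids)
        return (out,) if return_tuple else {"logits": out}

fake_inputs = {"input_ids": torch.randn(2, 4)}

# Wrap model in nn.DataParallel, as the test does.
model = nn.DataParallel(FakeModel())
# Force tuple outputs, the same fixup the training module applies.
fake_inputs["return_tuple"] = True
with torch.no_grad():
    _ = model(**fake_inputs)  # must run cleanly under the wrapper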