Unverified Commit da2df84a authored by Mashiro, committed by GitHub

[Fix] Make `MMDistributedDataParallel` compatible with PyTorch 1.12 (#2107)

* make `MMDistributedDataParallel` compatible with PyTorch 1.12

* override _run_ddp_forward

* overwrite _run_ddp_forward

* refactor docstring
parent 6a03918f
# Copyright (c) OpenMMLab. All rights reserved.
-from typing import List, Tuple
+from typing import Any, List, Tuple
import torch
from torch.nn.parallel.distributed import (DistributedDataParallel,
@@ -140,3 +140,28 @@ class MMDistributedDataParallel(DistributedDataParallel):
                and digit_version(TORCH_VERSION) > digit_version('1.2')):
            self.require_forward_param_sync = False
        return output

    def _run_ddp_forward(self, *inputs, **kwargs) -> Any:
        """Processes inputs and runs ``self.module.forward``.

        PyTorch 1.12.0 performs ``self.module.forward`` inside
        ``_run_ddp_forward`` and deprecates using
        ``DistributedDataParallel.to_kwargs`` to process inputs, so the
        inputs can no longer be processed by
        :meth:`MMDistributedDataParallel.to_kwargs`. Therefore,
        ``MMDistributedDataParallel`` overrides this method to call
        :meth:`to_kwargs` explicitly.

        See more information in
        `<https://github.com/open-mmlab/mmsegmentation/issues/1742>`_.

        Returns:
            Any: Forward result of :attr:`module`.
        """
        # PyTorch may run the forward on an internal replicated-tensor copy
        # of the wrapped module instead of ``self.module``.
        module_to_run = self._replicated_tensor_module if \
            self._use_replicated_tensor_module else self.module

        if self.device_ids:
            # Scatter inputs with the subclass's ``to_kwargs`` so that
            # mmcv-specific containers are moved to ``device_ids[0]`` first.
            inputs, kwargs = self.to_kwargs(  # type: ignore
                inputs, kwargs, self.device_ids[0])
            return module_to_run(*inputs[0], **kwargs[0])  # type: ignore
        else:
            return module_to_run(*inputs, **kwargs)
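For context, here is a minimal usage sketch (not part of the commit) showing how a wrapped model is typically exercised so that the overridden `_run_ddp_forward` is reached on PyTorch >= 1.12. The single-process NCCL setup and the toy `nn.Linear` module are illustrative assumptions; in real MMCV training the model and launcher come from the framework.

```python
# Minimal sketch. Assumptions: one GPU, a single-process NCCL group, and
# mmcv installed so MMDistributedDataParallel imports from mmcv.parallel.
import os

import torch
import torch.distributed as dist
import torch.nn as nn

from mmcv.parallel import MMDistributedDataParallel

os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group('nccl', rank=0, world_size=1)

model = nn.Linear(4, 2).cuda()
ddp_model = MMDistributedDataParallel(
    model, device_ids=[torch.cuda.current_device()])

# On PyTorch >= 1.12 this forward call goes through the overridden
# _run_ddp_forward above, which routes the inputs through self.to_kwargs
# so they are scattered onto device_ids[0] before self.module.forward runs.
out = ddp_model(torch.randn(8, 4))
print(out.shape)  # torch.Size([8, 2])

dist.destroy_process_group()
```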