data_parallel.py 249 Bytes
Newer Older
Kai Chen's avatar
Kai Chen committed
1
2
3
4
5
6
7
8
9
from torch.nn.parallel import DataParallel

from .scatter_gather import scatter_kwargs


class MMDataParallel(DataParallel):
    """``DataParallel`` variant that scatters inputs via ``scatter_kwargs``.

    The only customization visible here is :meth:`scatter`, which routes
    the positional and keyword inputs through the package-local
    ``scatter_kwargs`` helper imported from ``.scatter_gather``.
    """

    def scatter(self, inputs, kwargs, device_ids):
        """Split ``inputs`` and ``kwargs`` across ``device_ids``.

        Args:
            inputs: Positional inputs, forwarded unchanged to
                ``scatter_kwargs``.
            kwargs: Keyword inputs, forwarded unchanged to
                ``scatter_kwargs``.
            device_ids: Target devices, forwarded unchanged to
                ``scatter_kwargs``.

        Returns:
            Whatever ``scatter_kwargs`` returns for these arguments.
        """
        # ``self.dim`` is not set in this class; presumably it comes from
        # the ``DataParallel`` base initializer -- confirm against torch.
        split_dim = self.dim
        scattered = scatter_kwargs(inputs, kwargs, device_ids, dim=split_dim)
        return scattered