"src/geometry/vscode:/vscode.git/clone" did not exist on "aaaecbc9030acbab35ec55db4becde9ca8b765b4"
Unverified Commit 57f3a614 authored by ZhangShilong's avatar ZhangShilong Committed by GitHub
Browse files

fix scatter in pytorch18 (#882)

* fix scatter in pytorch18

* remove blanks
parent 6fc6f75a
......@@ -18,6 +18,11 @@ class MMDistributedDataParallel(DistributedDataParallel):
- It implement two APIs ``train_step()`` and ``val_step()``.
"""
def to_kwargs(self, inputs, kwargs, device_id):
# Use `self.to_kwargs` instead of `self.scatter` in pytorch1.8
# to move all tensors to device_id
return scatter_kwargs(inputs, kwargs, [device_id], dim=self.dim)
def scatter(self, inputs, kwargs, device_ids):
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment