Commit 9ae9e740 authored by Ziyue Jiang's avatar Ziyue Jiang
Browse files

fix diff device in some partition

parent 3a15b204
...@@ -789,6 +789,8 @@ class WorkerBase(ABC): ...@@ -789,6 +789,8 @@ class WorkerBase(ABC):
args_kwargs = pyobj_map(args_kwargs, fn=lambda x: x.to(self.device).detach(), args_kwargs = pyobj_map(args_kwargs, fn=lambda x: x.to(self.device).detach(),
process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU
args_kwargs = pyobj_map(args_kwargs, fn=lambda x: self.device,
process_types=torch.device) # change devices from last stage to current device
args, kwargs = data_process_func(args_kwargs) args, kwargs = data_process_func(args_kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment