Unverified Commit c3baf433 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bug] Fix multi-GPU edge classification crashing with pure GPU sampling (#3946)



* fix

* fix
Co-authored-by: default avatarXin Yao <xiny@nvidia.com>
parent 2eaa58e0
...@@ -513,8 +513,9 @@ class CollateWrapper(object): ...@@ -513,8 +513,9 @@ class CollateWrapper(object):
self.device = device self.device = device
def __call__(self, items): def __call__(self, items):
if self.use_uva: if self.use_uva or (self.g.device != torch.device('cpu')):
# Only copy the indices to the given device if in UVA mode. # Only copy the indices to the given device if in UVA mode or the graph is not on
# CPU.
items = recursive_apply(items, lambda x: x.to(self.device)) items = recursive_apply(items, lambda x: x.to(self.device))
batch = self.sample_func(self.g, items) batch = self.sample_func(self.g, items)
return recursive_apply(batch, remove_parent_storage_columns, self.g) return recursive_apply(batch, remove_parent_storage_columns, self.g)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment