Unverified Commit 324fd976 authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt][CUDA] `Minibatch.to()` patch. (#7330)

parent 7de2e51b
...@@ -129,6 +129,21 @@ def copy_or_convert_data( ...@@ -129,6 +129,21 @@ def copy_or_convert_data(
save_data(data, output_path, output_format) save_data(data, output_path, output_format)
def get_nonproperty_attributes(_obj) -> list:
"""Get attributes of the class except for the properties."""
attributes = [
attribute
for attribute in dir(_obj)
if not attribute.startswith("__")
and (
not hasattr(type(_obj), attribute)
or not isinstance(getattr(type(_obj), attribute), property)
)
and not callable(getattr(_obj, attribute))
]
return attributes
def get_attributes(_obj) -> list: def get_attributes(_obj) -> list:
"""Get attributes of the class.""" """Get attributes of the class."""
attributes = [ attributes = [
......
...@@ -9,7 +9,7 @@ import dgl ...@@ -9,7 +9,7 @@ import dgl
from dgl.utils import recursive_apply from dgl.utils import recursive_apply
from .base import CSCFormatBase, etype_str_to_tuple, expand_indptr from .base import CSCFormatBase, etype_str_to_tuple, expand_indptr
from .internal import get_attributes from .internal import get_attributes, get_nonproperty_attributes
from .sampled_subgraph import SampledSubgraph from .sampled_subgraph import SampledSubgraph
__all__ = ["MiniBatch"] __all__ = ["MiniBatch"]
...@@ -556,23 +556,14 @@ class MiniBatch: ...@@ -556,23 +556,14 @@ class MiniBatch:
def to(self, device: torch.device): # pylint: disable=invalid-name def to(self, device: torch.device): # pylint: disable=invalid-name
"""Copy `MiniBatch` to the specified device using reflection.""" """Copy `MiniBatch` to the specified device using reflection."""
def _to(x, device): def _to(x):
return x.to(device) if hasattr(x, "to") else x return x.to(device) if hasattr(x, "to") else x
def apply_to(x, device): transfer_attrs = get_nonproperty_attributes(self)
return recursive_apply(x, lambda x: _to(x, device))
transfer_attrs = get_attributes(self)
for attr in transfer_attrs: for attr in transfer_attrs:
# Only copy member variables. # Only copy member variables.
try: setattr(self, attr, recursive_apply(getattr(self, attr), _to))
# For read-only attributes such as blocks , setattr will throw
# an AttributeError. We catch these exceptions and skip those
# attributes.
setattr(self, attr, apply_to(getattr(self, attr), device))
except AttributeError:
continue
return self return self
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment