Unverified Commit 90c3c761 authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Add `to` to CSCFormatBase. (#6746)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 109aed56
...@@ -152,6 +152,22 @@ class CSCFormatBase: ...@@ -152,6 +152,22 @@ class CSCFormatBase:
def __repr__(self) -> str: def __repr__(self) -> str:
return _csc_format_base_str(self) return _csc_format_base_str(self)
def to(self, device: torch.device) -> None: # pylint: disable=invalid-name
"""Copy `CSCFormatBase` to the specified device using reflection."""
for attr in dir(self):
# Only copy member variables.
if not callable(getattr(self, attr)) and not attr.startswith("__"):
setattr(
self,
attr,
recursive_apply(
getattr(self, attr), lambda x: apply_to(x, device)
),
)
return self
def _csc_format_base_str(csc_format_base: CSCFormatBase) -> str: def _csc_format_base_str(csc_format_base: CSCFormatBase) -> str:
final_str = "CSCFormatBase(" final_str = "CSCFormatBase("
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment