"graphbolt/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "9632ab1d7c27a3aa63f4c2470ecb99ad85edc70a"
Unverified Commit 760426e4 authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Graphbolt] Add pin_memory_() support for TorrchBasedFeatureStore. (#6452)

parent 02443df1
......@@ -151,6 +151,10 @@ class TorchBasedFeature(Feature):
)
self._tensor[ids] = value
def pin_memory_(self):
"""In-place operation to copy the feature to pinned memory."""
self._tensor = self._tensor.pin_memory()
class TorchBasedFeatureStore(BasicFeatureStore):
r"""A store to manage multiple pytorch based feature for access.
......@@ -205,3 +209,8 @@ class TorchBasedFeatureStore(BasicFeatureStore):
else:
raise ValueError(f"Unknown feature format {spec.format}")
super().__init__(features)
def pin_memory_(self):
"""In-place operation to copy the feature store to pinned memory."""
for feature in self._features.values():
feature.pin_memory_()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment