"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "bda6a816beb9dbb32ddc57d1206af9537feb119a"
Unverified Commit 760426e4 authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Graphbolt] Add pin_memory_() support for TorrchBasedFeatureStore. (#6452)

parent 02443df1
...@@ -151,6 +151,10 @@ class TorchBasedFeature(Feature): ...@@ -151,6 +151,10 @@ class TorchBasedFeature(Feature):
) )
self._tensor[ids] = value self._tensor[ids] = value
def pin_memory_(self):
"""In-place operation to copy the feature to pinned memory."""
self._tensor = self._tensor.pin_memory()
class TorchBasedFeatureStore(BasicFeatureStore): class TorchBasedFeatureStore(BasicFeatureStore):
r"""A store to manage multiple pytorch based feature for access. r"""A store to manage multiple pytorch based feature for access.
...@@ -205,3 +209,8 @@ class TorchBasedFeatureStore(BasicFeatureStore): ...@@ -205,3 +209,8 @@ class TorchBasedFeatureStore(BasicFeatureStore):
else: else:
raise ValueError(f"Unknown feature format {spec.format}") raise ValueError(f"Unknown feature format {spec.format}")
super().__init__(features) super().__init__(features)
def pin_memory_(self):
"""In-place operation to copy the feature store to pinned memory."""
for feature in self._features.values():
feature.pin_memory_()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment