Unverified Commit f051a5cc authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt][CUDA] Add `is_pinned` method to `TorchBasedFeatureStore`. (#7041)

parent e7c3454f
...@@ -198,6 +198,10 @@ class TorchBasedFeature(Feature): ...@@ -198,6 +198,10 @@ class TorchBasedFeature(Feature):
self._is_inplace_pinned.add(x) self._is_inplace_pinned.add(x)
def is_pinned(self):
"""Returns True if the stored feature is pinned."""
return self._tensor.is_pinned()
def to(self, device): # pylint: disable=invalid-name def to(self, device): # pylint: disable=invalid-name
"""Copy `TorchBasedFeature` to the specified device.""" """Copy `TorchBasedFeature` to the specified device."""
# copy.copy is a shallow copy so it does not copy tensor memory. # copy.copy is a shallow copy so it does not copy tensor memory.
...@@ -293,6 +297,10 @@ class TorchBasedFeatureStore(BasicFeatureStore): ...@@ -293,6 +297,10 @@ class TorchBasedFeatureStore(BasicFeatureStore):
for feature in self._features.values(): for feature in self._features.values():
feature.pin_memory_() feature.pin_memory_()
def is_pinned(self):
"""Returns True if all the stored features are pinned."""
return all(feature.is_pinned() for feature in self._features.values())
def to(self, device): # pylint: disable=invalid-name def to(self, device): # pylint: disable=invalid-name
"""Copy `TorchBasedFeatureStore` to the specified device.""" """Copy `TorchBasedFeatureStore` to the specified device."""
# copy.copy is a shallow copy so it does not copy tensor memory. # copy.copy is a shallow copy so it does not copy tensor memory.
......
...@@ -136,11 +136,6 @@ def test_torch_based_feature(in_memory): ...@@ -136,11 +136,6 @@ def test_torch_based_feature(in_memory):
feature_a = feature_b = None feature_a = feature_b = None
def is_feature_store_pinned(store):
for feature in store._features.values():
assert feature._tensor.is_pinned()
def is_feature_store_on_cuda(store): def is_feature_store_on_cuda(store):
for feature in store._features.values(): for feature in store._features.values():
assert feature._tensor.is_cuda assert feature._tensor.is_cuda
...@@ -181,7 +176,7 @@ def test_feature_store_to_device(device): ...@@ -181,7 +176,7 @@ def test_feature_store_to_device(device):
feature_store = gb.TorchBasedFeatureStore(feature_data) feature_store = gb.TorchBasedFeatureStore(feature_data)
feature_store2 = feature_store.to(device) feature_store2 = feature_store.to(device)
if device == "pinned": if device == "pinned":
is_feature_store_pinned(feature_store2) assert feature_store2.is_pinned()
elif device == "cuda": elif device == "cuda":
is_feature_store_on_cuda(feature_store2) is_feature_store_on_cuda(feature_store2)
...@@ -228,6 +223,8 @@ def test_torch_based_pinned_feature(dtype, idtype, shape, in_place): ...@@ -228,6 +223,8 @@ def test_torch_based_pinned_feature(dtype, idtype, shape, in_place):
else: else:
feature = feature.to("pinned") feature = feature.to("pinned")
assert feature.is_pinned()
# Test read entire pinned feature, the result should be on cuda. # Test read entire pinned feature, the result should be on cuda.
assert torch.equal(feature.read(), test_tensor_cuda) assert torch.equal(feature.read(), test_tensor_cuda)
assert feature.read().is_cuda assert feature.read().is_cuda
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment