Unverified Commit f051a5cc authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt][CUDA] Add `is_pinned` method to `TorchBasedFeatureStore`. (#7041)

parent e7c3454f
......@@ -198,6 +198,10 @@ class TorchBasedFeature(Feature):
self._is_inplace_pinned.add(x)
def is_pinned(self):
"""Returns True if the stored feature is pinned."""
return self._tensor.is_pinned()
def to(self, device): # pylint: disable=invalid-name
"""Copy `TorchBasedFeature` to the specified device."""
# copy.copy is a shallow copy so it does not copy tensor memory.
......@@ -293,6 +297,10 @@ class TorchBasedFeatureStore(BasicFeatureStore):
for feature in self._features.values():
feature.pin_memory_()
def is_pinned(self):
"""Returns True if all the stored features are pinned."""
return all(feature.is_pinned() for feature in self._features.values())
def to(self, device): # pylint: disable=invalid-name
"""Copy `TorchBasedFeatureStore` to the specified device."""
# copy.copy is a shallow copy so it does not copy tensor memory.
......
......@@ -136,11 +136,6 @@ def test_torch_based_feature(in_memory):
feature_a = feature_b = None
def is_feature_store_pinned(store):
for feature in store._features.values():
assert feature._tensor.is_pinned()
def is_feature_store_on_cuda(store):
for feature in store._features.values():
assert feature._tensor.is_cuda
......@@ -181,7 +176,7 @@ def test_feature_store_to_device(device):
feature_store = gb.TorchBasedFeatureStore(feature_data)
feature_store2 = feature_store.to(device)
if device == "pinned":
is_feature_store_pinned(feature_store2)
assert feature_store2.is_pinned()
elif device == "cuda":
is_feature_store_on_cuda(feature_store2)
......@@ -228,6 +223,8 @@ def test_torch_based_pinned_feature(dtype, idtype, shape, in_place):
else:
feature = feature.to("pinned")
assert feature.is_pinned()
# Test read entire pinned feature, the result should be on cuda.
assert torch.equal(feature.read(), test_tensor_cuda)
assert feature.read().is_cuda
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment