Unverified Commit 968c52dd authored by Muhammed Fatih BALIN, committed by GitHub

[GraphBolt] GPU cache feature store (#6005)

parent cbfc8085
@@ -6,4 +6,5 @@
 from .ondisk_dataset import *
 from .ondisk_metadata import *
 from .sampled_subgraph_impl import *
 from .torch_based_feature_store import *
+from .gpu_cached_feature import *
 from .uniform_negative_sampler import *
"""GPU cached feature for GraphBolt."""
import torch
from dgl.cuda import GPUCache
from ..feature_store import Feature
__all__ = ["GPUCachedFeature"]
class GPUCachedFeature(Feature):
    r"""GPU cached feature wrapping a fallback feature."""

    def __init__(self, fallback_feature: Feature, cache_size: int):
        """Initialize GPU cached feature with a given fallback.

        Places the GPU cache on torch.cuda.current_device().

        Parameters
        ----------
        fallback_feature : Feature
            The fallback feature.
        cache_size : int
            The capacity of the GPU cache, the number of features to store.

        Examples
        --------
        >>> import torch
        >>> torch_feat = torch.arange(0, 8)
        >>> cache_size = 5
        >>> fallback_feature = TorchBasedFeature(torch_feat)
        >>> feature = GPUCachedFeature(fallback_feature, cache_size)
        >>> feature.read()
        tensor([0, 1, 2, 3, 4, 5, 6, 7])
        >>> feature.read(torch.tensor([0, 1, 2]))
        tensor([0, 1, 2])
        >>> feature.update(torch.ones(3, dtype=torch.long),
        ...                torch.tensor([0, 1, 2]))
        >>> feature.read(torch.tensor([0, 1, 2, 3]))
        tensor([1, 1, 1, 3])
        """
        super(GPUCachedFeature, self).__init__()
        assert isinstance(fallback_feature, Feature), (
            f"The fallback_feature must be an instance of Feature, but got "
            f"{type(fallback_feature)}."
        )
        self._fallback_feature = fallback_feature
        self.cache_size = cache_size
        # Fetch the feature shape from the underlying feature.
        feat0 = fallback_feature.read(torch.tensor([0]))
        self.item_shape = (-1,) + feat0.shape[1:]
        feat0 = torch.reshape(feat0, (1, -1))
        self.flat_shape = (-1, feat0.shape[1])
        # The cache stores flattened rows of width feat0.shape[1].
        self._feature = GPUCache(cache_size, feat0.shape[1])

    def read(self, ids: torch.Tensor = None):
        """Read the feature by index.

        The returned tensor is always in GPU memory, no matter whether the
        fallback feature is in memory or on disk.

        Parameters
        ----------
        ids : torch.Tensor, optional
            The index of the feature. If specified, only the specified indices
            of the feature are read. If None, the entire feature is returned.

        Returns
        -------
        torch.Tensor
            The read feature.
        """
        if ids is None:
            return self._fallback_feature.read().to("cuda")
        keys = ids.to("cuda")
        # Query the cache; missing_keys are the ids not currently cached.
        values, missing_index, missing_keys = self._feature.query(keys)
        # Fetch the missing rows from the fallback feature.
        missing_values = self._fallback_feature.read(missing_keys).to("cuda")
        missing_values = missing_values.reshape(self.flat_shape)
        values = values.to(missing_values.dtype)
        values[missing_index] = missing_values
        # Insert the freshly fetched rows into the cache for future reads.
        self._feature.replace(missing_keys, missing_values)
        return torch.reshape(values, self.item_shape)

    def update(self, value: torch.Tensor, ids: torch.Tensor = None):
        """Update the feature.

        Parameters
        ----------
        value : torch.Tensor
            The updated value of the feature.
        ids : torch.Tensor, optional
            The indices of the feature to update. If specified, only the
            specified indices of the feature will be updated. For the feature,
            the `ids[i]` row is updated to `value[i]`. So the indices and
            value must have the same length. If None, the entire feature will
            be updated.
        """
        if ids is None:
            self._fallback_feature.update(value)
            # Refresh the cache with as many of the new rows as fit.
            size = min(self.cache_size, value.shape[0])
            self._feature.replace(
                torch.arange(0, size, device="cuda"),
                value[:size].to("cuda").reshape(self.flat_shape),
            )
        else:
            assert ids.shape[0] == value.shape[0], (
                f"ids and value must have the same length, "
                f"but got {ids.shape[0]} and {value.shape[0]}."
            )
            self._fallback_feature.update(value, ids)
            # Keep the cached copies of the updated rows consistent.
            self._feature.replace(
                ids.to("cuda"), value.to("cuda").reshape(self.flat_shape)
            )
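
For convenience, here is a minimal usage sketch mirroring the docstring example above. It assumes a CUDA device is available and uses float features, matching the unit test below; the variable names are illustrative:

import torch
from dgl import graphbolt as gb

torch_feat = torch.arange(0, 8).float()
fallback_feature = gb.TorchBasedFeature(torch_feat)
# Cache up to 5 feature rows on the current CUDA device.
feature = gb.GPUCachedFeature(fallback_feature, cache_size=5)

ids = torch.tensor([0, 1, 2], device="cuda")
print(feature.read(ids))  # First read: cache misses, served by the fallback.
print(feature.read(ids))  # Second read: served from the GPU cache.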
import unittest

import backend as F
import torch

from dgl import graphbolt as gb


@unittest.skipIf(
    F._default_context_str != "gpu",
    reason="GPUCachedFeature requires a GPU.",
)
def test_gpu_cached_feature():
    a = torch.tensor([1, 2, 3]).to("cuda").float()
    b = torch.tensor([[1, 2, 3], [4, 5, 6]]).to("cuda").float()

    feat_store_a = gb.GPUCachedFeature(gb.TorchBasedFeature(a), 2)
    feat_store_b = gb.GPUCachedFeature(gb.TorchBasedFeature(b), 1)

    # Test read the entire feature.
    assert torch.equal(feat_store_a.read(), a)
    assert torch.equal(feat_store_b.read(), b)

    # Test read with ids.
    assert torch.equal(
        feat_store_a.read(torch.tensor([0, 2]).to("cuda")),
        torch.tensor([1.0, 3.0]).to("cuda"),
    )
    assert torch.equal(
        feat_store_a.read(torch.tensor([1, 1]).to("cuda")),
        torch.tensor([2.0, 2.0]).to("cuda"),
    )
    assert torch.equal(
        feat_store_b.read(torch.tensor([1]).to("cuda")),
        torch.tensor([[4.0, 5.0, 6.0]]).to("cuda"),
    )

    # Test update the entire feature.
    feat_store_a.update(torch.tensor([0.0, 1.0, 2.0]).to("cuda"))
    assert torch.equal(
        feat_store_a.read(), torch.tensor([0.0, 1.0, 2.0]).to("cuda")
    )

    # Test update with ids.
    feat_store_a.update(
        torch.tensor([2.0, 0.0]).to("cuda"), torch.tensor([0, 2]).to("cuda")
    )
    assert torch.equal(
        feat_store_a.read(), torch.tensor([2.0, 1.0, 0.0]).to("cuda")
    )
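
As a reference for what the test exercises, the sketch below walks through the query/replace cycle that GPUCachedFeature.read performs on dgl.cuda.GPUCache. This is a minimal sketch assuming a CUDA device; the keys and fill values are made up for illustration:

import torch
from dgl.cuda import GPUCache

cache = GPUCache(2, 3)  # Capacity of 2 rows, each of width 3.
keys = torch.tensor([0, 1], device="cuda")
values, missing_index, missing_keys = cache.query(keys)
# On a cold cache every key misses; fetch the rows (here: dummy ones) and
# write them back so the next query for the same keys hits.
missing_values = torch.ones(missing_keys.shape[0], 3, device="cuda")
values = values.to(missing_values.dtype)
values[missing_index] = missing_values
cache.replace(missing_keys, missing_values)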