"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "4e36bb0d23a0450079560ac12d2858e2eb3f7e24"
Unverified commit f6db850d, authored by Muhammed Fatih Balin, committed by GitHub

[GraphBolt][CUDA] Add native `GPUCachedFeature` instead of using DGL (#6939)

parent 528b041c
@@ -530,6 +530,10 @@ if(BUILD_GRAPHBOLT)
   string(REPLACE ";" "\\;" CUDA_ARCHITECTURES_ESCAPED "${CUDA_ARCHITECTURES}")
   file(TO_NATIVE_PATH ${CMAKE_CURRENT_BINARY_DIR} BINDIR)
   file(TO_NATIVE_PATH ${CMAKE_COMMAND} CMAKE_CMD)
+  if(USE_CUDA)
+    get_target_property(GPU_CACHE_INCLUDE_DIRS gpu_cache INCLUDE_DIRECTORIES)
+  endif(USE_CUDA)
+  string(REPLACE ";" "\\;" GPU_CACHE_INCLUDE_DIRS_ESCAPED "${GPU_CACHE_INCLUDE_DIRS}")
   if(MSVC)
     file(TO_NATIVE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/graphbolt/build.bat BUILD_SCRIPT)
     add_custom_target(
@@ -540,6 +544,7 @@ if(BUILD_GRAPHBOLT)
       CUDA_TOOLKIT_ROOT_DIR=${CUDA_TOOLKIT_ROOT_DIR}
       USE_CUDA=${USE_CUDA}
       BINDIR=${BINDIR}
+      GPU_CACHE_INCLUDE_DIRS="${GPU_CACHE_INCLUDE_DIRS_ESCAPED}"
       CFLAGS=${CMAKE_C_FLAGS}
       CXXFLAGS=${CMAKE_CXX_FLAGS}
       CUDAARCHS="${CUDA_ARCHITECTURES_ESCAPED}"
@@ -557,6 +562,7 @@ if(BUILD_GRAPHBOLT)
       CUDA_TOOLKIT_ROOT_DIR=${CUDA_TOOLKIT_ROOT_DIR}
       USE_CUDA=${USE_CUDA}
       BINDIR=${CMAKE_CURRENT_BINARY_DIR}
+      GPU_CACHE_INCLUDE_DIRS="${GPU_CACHE_INCLUDE_DIRS_ESCAPED}"
       CFLAGS=${CMAKE_C_FLAGS}
       CXXFLAGS=${CMAKE_CXX_FLAGS}
       CUDAARCHS="${CUDA_ARCHITECTURES_ESCAPED}"
@@ -565,4 +571,7 @@ if(BUILD_GRAPHBOLT)
     DEPENDS ${BUILD_SCRIPT}
     WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/graphbolt)
   endif(MSVC)
+  if(USE_CUDA)
+    add_dependencies(graphbolt gpu_cache)
+  endif(USE_CUDA)
 endif(BUILD_GRAPHBOLT)
@@ -76,6 +76,11 @@ if(USE_CUDA)
     "../third_party/cccl/thrust"
     "../third_party/cccl/cub"
     "../third_party/cccl/libcudacxx/include")
+  message(STATUS "Use HugeCTR gpu_cache for graphbolt with INCLUDE_DIRS $ENV{GPU_CACHE_INCLUDE_DIRS}.")
+  target_include_directories(${LIB_GRAPHBOLT_NAME} PRIVATE $ENV{GPU_CACHE_INCLUDE_DIRS})
+  target_link_directories(${LIB_GRAPHBOLT_NAME} PRIVATE ${GPU_CACHE_BUILD_DIR})
+  target_link_libraries(${LIB_GRAPHBOLT_NAME} gpu_cache)
   get_property(archs TARGET ${LIB_GRAPHBOLT_NAME} PROPERTY CUDA_ARCHITECTURES)
   message(STATUS "CUDA_ARCHITECTURES for graphbolt: ${archs}")
......
@@ -11,7 +11,7 @@ IF x%1x == xx GOTO single
 FOR %%X IN (%*) DO (
 	DEL /S /Q *
-	"%CMAKE_COMMAND%" -DCMAKE_CONFIGURATION_TYPES=Release -DPYTHON_INTERP=%%X .. -G "Visual Studio 16 2019" || EXIT /B 1
+	"%CMAKE_COMMAND%" -DGPU_CACHE_BUILD_DIR=%BINDIR% -DCMAKE_CONFIGURATION_TYPES=Release -DPYTHON_INTERP=%%X .. -G "Visual Studio 16 2019" || EXIT /B 1
 	msbuild graphbolt.sln /m /nr:false || EXIT /B 1
 	COPY /Y Release\*.dll "%BINDIR%\graphbolt" || EXIT /B 1
 )
@@ -21,7 +21,7 @@ GOTO end
 :single
 DEL /S /Q *
-"%CMAKE_COMMAND%" -DCMAKE_CONFIGURATION_TYPES=Release .. -G "Visual Studio 16 2019" || EXIT /B 1
+"%CMAKE_COMMAND%" -DGPU_CACHE_BUILD_DIR=%BINDIR% -DCMAKE_CONFIGURATION_TYPES=Release .. -G "Visual Studio 16 2019" || EXIT /B 1
 msbuild graphbolt.sln /m /nr:false || EXIT /B 1
 COPY /Y Release\*.dll "%BINDIR%\graphbolt" || EXIT /B 1
......
@@ -12,7 +12,7 @@ else
 	CPSOURCE=*.so
 fi
-CMAKE_FLAGS="-DCUDA_TOOLKIT_ROOT_DIR=$CUDA_TOOLKIT_ROOT_DIR -DUSE_CUDA=$USE_CUDA"
+CMAKE_FLAGS="-DCUDA_TOOLKIT_ROOT_DIR=$CUDA_TOOLKIT_ROOT_DIR -DUSE_CUDA=$USE_CUDA -DGPU_CACHE_BUILD_DIR=$BINDIR"
 echo $CMAKE_FLAGS
 if [ $# -eq 0 ]; then
......
/**
 * Copyright (c) 2023 by Contributors
 * Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/gpu_cache.cu
 * @brief GPUCache implementation on CUDA.
 */
#include <numeric>

#include "./common.h"
#include "./gpu_cache.h"

namespace graphbolt {
namespace cuda {

GpuCache::GpuCache(const std::vector<int64_t> &shape, torch::ScalarType dtype) {
  TORCH_CHECK(shape.size() >= 2, "Shape must have at least 2 dimensions.");
  const auto num_items = shape[0];
  const int64_t num_feats =
      std::accumulate(shape.begin() + 1, shape.end(), 1ll, std::multiplies<>());
  const int element_size =
      torch::empty(1, torch::TensorOptions().dtype(dtype)).element_size();
  // Items are stored internally as whole float words, so round the item size
  // in bytes up to the nearest multiple of sizeof(float).
  num_bytes_ = num_feats * element_size;
  num_float_feats_ = (num_bytes_ + sizeof(float) - 1) / sizeof(float);
  cache_ = std::make_unique<gpu_cache_t>(
      (num_items + bucket_size - 1) / bucket_size, num_float_feats_);
  shape_ = shape;
  shape_[0] = -1;
  dtype_ = dtype;
  device_id_ = cuda::GetCurrentStream().device_index();
}

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> GpuCache::Query(
    torch::Tensor keys) {
  TORCH_CHECK(keys.device().is_cuda(), "Keys should be on a CUDA device.");
  TORCH_CHECK(
      keys.device().index() == device_id_,
      "Keys should be on the correct CUDA device.");
  TORCH_CHECK(keys.sizes().size() == 1, "Keys should be a 1D tensor.");
  keys = keys.to(torch::kLong);
  auto values = torch::empty(
      {keys.size(0), num_float_feats_}, keys.options().dtype(torch::kFloat));
  auto missing_index =
      torch::empty(keys.size(0), keys.options().dtype(torch::kLong));
  auto missing_keys =
      torch::empty(keys.size(0), keys.options().dtype(torch::kLong));
  cuda::CopyScalar<size_t> missing_len;
  auto stream = cuda::GetCurrentStream();
  cache_->Query(
      reinterpret_cast<const key_t *>(keys.data_ptr()), keys.size(0),
      values.data_ptr<float>(),
      reinterpret_cast<uint64_t *>(missing_index.data_ptr()),
      reinterpret_cast<key_t *>(missing_keys.data_ptr()), missing_len.get(),
      stream);
  // Reinterpret the float storage back into the requested dtype and shape.
  values = values.view(torch::kByte)
               .slice(1, 0, num_bytes_)
               .view(dtype_)
               .view(shape_);
  // To safely read missing_len, we synchronize.
  stream.synchronize();
  missing_index = missing_index.slice(0, 0, static_cast<size_t>(missing_len));
  missing_keys = missing_keys.slice(0, 0, static_cast<size_t>(missing_len));
  return std::make_tuple(values, missing_index, missing_keys);
}

void GpuCache::Replace(torch::Tensor keys, torch::Tensor values) {
  TORCH_CHECK(keys.device().is_cuda(), "Keys should be on a CUDA device.");
  TORCH_CHECK(
      keys.device().index() == device_id_,
      "Keys should be on the correct CUDA device.");
  TORCH_CHECK(values.device().is_cuda(), "Values should be on a CUDA device.");
  TORCH_CHECK(
      values.device().index() == device_id_,
      "Values should be on the correct CUDA device.");
  TORCH_CHECK(
      keys.size(0) == values.size(0),
      "The first dimensions of keys and values must match.");
  TORCH_CHECK(
      std::equal(shape_.begin() + 1, shape_.end(), values.sizes().begin() + 1),
      "Values should have the correct dimensions.");
  TORCH_CHECK(
      values.scalar_type() == dtype_, "Values should have the correct dtype.");
  keys = keys.to(torch::kLong);
  torch::Tensor float_values;
  if (num_bytes_ % sizeof(float) != 0) {
    // Pad each item to a whole number of float words before insertion.
    float_values = torch::empty(
        {values.size(0), num_float_feats_},
        values.options().dtype(torch::kFloat));
    float_values.view(torch::kByte)
        .slice(1, 0, num_bytes_)
        .copy_(values.view(torch::kByte).view({values.size(0), -1}));
  } else {
    float_values = values.view(torch::kByte)
                       .view({values.size(0), -1})
                       .view(torch::kFloat)
                       .contiguous();
  }
  cache_->Replace(
      reinterpret_cast<const key_t *>(keys.data_ptr()), keys.size(0),
      float_values.data_ptr<float>(), cuda::GetCurrentStream());
}

c10::intrusive_ptr<GpuCache> GpuCache::Create(
    const std::vector<int64_t> &shape, torch::ScalarType dtype) {
  return c10::make_intrusive<GpuCache>(shape, dtype);
}

}  // namespace cuda
}  // namespace graphbolt
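An aside (not part of the commit): the `view`/`slice` chains above pack each item of arbitrary dtype into whole `float` words on insertion and unpack them again on query. A minimal PyTorch sketch of the same round trip, with illustrative tensor names:

```python
import torch

# Pack rows of an arbitrary-dtype feature into float32 words, then unpack,
# mirroring the view()/slice() chain in GpuCache::Query and GpuCache::Replace.
x = torch.arange(6, dtype=torch.int16).reshape(2, 3)  # 6 bytes per row
num_bytes = x[0].numel() * x.element_size()           # 6
num_float_feats = (num_bytes + 3) // 4                # ceil(6 / 4) = 2 float words
packed = torch.zeros(x.size(0), num_float_feats, dtype=torch.float32)
packed.view(torch.uint8)[:, :num_bytes].copy_(
    x.view(torch.uint8).view(x.size(0), -1)
)
unpacked = packed.view(torch.uint8)[:, :num_bytes].contiguous().view(torch.int16)
assert torch.equal(unpacked.view(2, 3), x)
```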
/**
 * Copyright (c) 2023 by Contributors
 * Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/gpu_cache.h
 * @brief Header file of HugeCTR gpu_cache wrapper.
 */
#ifndef GRAPHBOLT_GPU_CACHE_H_
#define GRAPHBOLT_GPU_CACHE_H_

#include <torch/custom_class.h>
#include <torch/torch.h>

#include <limits>
#include <nv_gpu_cache.hpp>

namespace graphbolt {
namespace cuda {

class GpuCache : public torch::CustomClassHolder {
  using key_t = long long;
  constexpr static int set_associativity = 2;
  constexpr static int WARP_SIZE = 32;
  constexpr static int bucket_size = WARP_SIZE * set_associativity;
  using gpu_cache_t = ::gpu_cache::gpu_cache<
      key_t, uint64_t, std::numeric_limits<key_t>::max(), set_associativity,
      WARP_SIZE>;

 public:
  /**
   * @brief Constructor for the GpuCache struct.
   *
   * @param shape The shape of the GPU cache.
   * @param dtype The datatype of items to be stored.
   */
  GpuCache(const std::vector<int64_t>& shape, torch::ScalarType dtype);

  GpuCache() = default;

  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> Query(
      torch::Tensor keys);

  void Replace(torch::Tensor keys, torch::Tensor values);

  static c10::intrusive_ptr<GpuCache> Create(
      const std::vector<int64_t>& shape, torch::ScalarType dtype);

 private:
  std::vector<int64_t> shape_;
  torch::ScalarType dtype_;
  std::unique_ptr<gpu_cache_t> cache_;
  int64_t num_bytes_;
  int64_t num_float_feats_;
  torch::DeviceIndex device_id_;
};

// The cu file in HugeCTR gpu cache uses unsigned int and long long.
// Changing to int64_t results in a mismatch of template arguments.
static_assert(
    sizeof(long long) == sizeof(int64_t),
    "long long and int64_t need to have the same size.");  // NOLINT

}  // namespace cuda
}  // namespace graphbolt

#endif  // GRAPHBOLT_GPU_CACHE_H_
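For intuition, an illustrative sketch (not from the commit) of how the constructor sizes the underlying set-associative cache: `bucket_size` is `WARP_SIZE * set_associativity = 64`, and the requested capacity is rounded up to whole buckets:

```python
WARP_SIZE = 32
SET_ASSOCIATIVITY = 2
BUCKET_SIZE = WARP_SIZE * SET_ASSOCIATIVITY  # 64 slots per bucket

def num_buckets(num_items: int) -> int:
    # Same round-up as (num_items + bucket_size - 1) / bucket_size in C++.
    return (num_items + BUCKET_SIZE - 1) // BUCKET_SIZE

assert num_buckets(1000) == 16  # a 1000-item cache allocates 16 buckets
```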
@@ -15,6 +15,10 @@
 #include "./index_select.h"
 #include "./random.h"
+#ifdef GRAPHBOLT_USE_CUDA
+#include "./cuda/gpu_cache.h"
+#endif
 namespace graphbolt {
 namespace sampling {
@@ -70,6 +74,12 @@ TORCH_LIBRARY(graphbolt, m) {
     g->SetState(state);
     return g;
   });
+#ifdef GRAPHBOLT_USE_CUDA
+  m.class_<cuda::GpuCache>("GpuCache")
+      .def("query", &cuda::GpuCache::Query)
+      .def("replace", &cuda::GpuCache::Replace);
+  m.def("gpu_cache", &cuda::GpuCache::Create);
+#endif
   m.def("fused_csc_sampling_graph", &FusedCSCSamplingGraph::Create);
   m.def(
       "load_from_shared_memory", &FusedCSCSamplingGraph::LoadFromSharedMemory);
......
"""Implementation of GraphBolt.""" """Implementation of GraphBolt."""
from .basic_feature_store import * from .basic_feature_store import *
from .fused_csc_sampling_graph import * from .fused_csc_sampling_graph import *
from .gpu_cache import *
from .gpu_cached_feature import * from .gpu_cached_feature import *
from .in_subgraph_sampler import * from .in_subgraph_sampler import *
from .legacy_dataset import * from .legacy_dataset import *
......
"""HugeCTR gpu_cache wrapper for graphbolt."""
import torch
class GPUCache(object):
"""High-level wrapper for GPU embedding cache"""
def __init__(self, cache_shape, dtype):
major, _ = torch.cuda.get_device_capability()
assert (
major >= 7
), "GPUCache is supported only on CUDA compute capability >= 70 (Volta)."
self._cache = torch.ops.graphbolt.gpu_cache(cache_shape, dtype)
self.total_miss = 0
self.total_queries = 0
def query(self, keys):
"""Queries the GPU cache.
Parameters
----------
keys : Tensor
The keys to query the GPU cache with.
Returns
-------
tuple(Tensor, Tensor, Tensor)
A tuple containing (values, missing_indices, missing_keys) where
values[missing_indices] corresponds to cache misses that should be
filled by quering another source with missing_keys.
"""
self.total_queries += keys.shape[0]
values, missing_index, missing_keys = self._cache.query(keys)
self.total_miss += missing_keys.shape[0]
return values, missing_index, missing_keys
def replace(self, keys, values):
"""Inserts key-value pairs into the GPU cache using the Least-Recently
Used (LRU) algorithm to remove old key-value pairs if it is full.
Parameters
----------
keys: Tensor
The keys to insert to the GPU cache.
values: Tensor
The values to insert to the GPU cache.
"""
self._cache.replace(keys, values)
@property
def miss_rate(self):
"""Returns the cache miss rate since creation."""
return self.total_miss / self.total_queries
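A usage sketch for the wrapper (illustrative; `fallback` stands in for any slower feature source, and the import path assumes the `__init__.py` change above):

```python
import torch
from dgl.graphbolt import GPUCache

cache = GPUCache((1024, 16), torch.float32)
fallback = torch.randn(4096, 16)  # host-side source of truth

keys = torch.tensor([3, 42, 7], device="cuda")
values, missing_index, missing_keys = cache.query(keys)
# Fill the misses from the fallback and warm the cache for next time.
missing_values = fallback[missing_keys.cpu()].to("cuda")
values[missing_index] = missing_values
cache.replace(missing_keys, missing_values)
print(f"miss rate: {cache.miss_rate:.2f}")
```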
"""GPU cached feature for GraphBolt.""" """GPU cached feature for GraphBolt."""
import torch import torch
from dgl.cuda import GPUCache
from ..feature_store import Feature from ..feature_store import Feature
from .gpu_cache import GPUCache
__all__ = ["GPUCachedFeature"] __all__ = ["GPUCachedFeature"]
...@@ -52,10 +52,7 @@ class GPUCachedFeature(Feature): ...@@ -52,10 +52,7 @@ class GPUCachedFeature(Feature):
self.cache_size = cache_size self.cache_size = cache_size
# Fetching the feature dimension from the underlying feature. # Fetching the feature dimension from the underlying feature.
feat0 = fallback_feature.read(torch.tensor([0])) feat0 = fallback_feature.read(torch.tensor([0]))
self.item_shape = (-1,) + feat0.shape[1:] self._feature = GPUCache((cache_size,) + feat0.shape[1:], feat0.dtype)
feat0 = torch.reshape(feat0, (1, -1))
self.flat_shape = (-1, feat0.shape[1])
self._feature = GPUCache(cache_size, feat0.shape[1])
def read(self, ids: torch.Tensor = None): def read(self, ids: torch.Tensor = None):
"""Read the feature by index. """Read the feature by index.
...@@ -75,15 +72,12 @@ class GPUCachedFeature(Feature): ...@@ -75,15 +72,12 @@ class GPUCachedFeature(Feature):
The read feature. The read feature.
""" """
if ids is None: if ids is None:
return self._fallback_feature.read().to("cuda") return self._fallback_feature.read()
keys = ids.to("cuda") values, missing_index, missing_keys = self._feature.query(ids)
values, missing_index, missing_keys = self._feature.query(keys)
missing_values = self._fallback_feature.read(missing_keys).to("cuda") missing_values = self._fallback_feature.read(missing_keys).to("cuda")
missing_values = missing_values.reshape(self.flat_shape)
values = values.to(missing_values.dtype)
values[missing_index] = missing_values values[missing_index] = missing_values
self._feature.replace(missing_keys, missing_values) self._feature.replace(missing_keys, missing_values)
return torch.reshape(values, self.item_shape) return values
def size(self): def size(self):
"""Get the size of the feature. """Get the size of the feature.
...@@ -114,10 +108,8 @@ class GPUCachedFeature(Feature): ...@@ -114,10 +108,8 @@ class GPUCachedFeature(Feature):
size = min(self.cache_size, value.shape[0]) size = min(self.cache_size, value.shape[0])
self._feature.replace( self._feature.replace(
torch.arange(0, size, device="cuda"), torch.arange(0, size, device="cuda"),
value[:size].to("cuda").reshape(self.flat_shape), value[:size].to("cuda"),
) )
else: else:
self._fallback_feature.update(value, ids) self._fallback_feature.update(value, ids)
self._feature.replace( self._feature.replace(ids, value)
ids.to("cuda"), value.to("cuda").reshape(self.flat_shape)
)
@@ -2,34 +2,53 @@ import unittest
 import backend as F
+import pytest
 import torch
 from dgl import graphbolt as gb

 @unittest.skipIf(
-    F._default_context_str != "gpu",
-    reason="GPUCachedFeature requires a GPU.",
+    F._default_context_str != "gpu"
+    or torch.cuda.get_device_capability()[0] < 7,
+    reason="GPUCachedFeature requires a Volta or later generation NVIDIA GPU.",
 )
-def test_gpu_cached_feature():
-    a = torch.tensor([[1, 2, 3], [4, 5, 6]]).to("cuda").float()
-    b = torch.tensor([[[1, 2], [3, 4]], [[4, 5], [6, 7]]]).to("cuda").float()
+@pytest.mark.parametrize(
+    "dtype",
+    [
+        torch.bool,
+        torch.uint8,
+        torch.int8,
+        torch.int16,
+        torch.int32,
+        torch.int64,
+        torch.float16,
+        torch.bfloat16,
+        torch.float32,
+        torch.float64,
+    ],
+)
+def test_gpu_cached_feature(dtype):
+    a = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dtype, pin_memory=True)
+    b = torch.tensor(
+        [[[1, 2], [3, 4]], [[4, 5], [6, 7]]], dtype=dtype, pin_memory=True
+    )
     feat_store_a = gb.GPUCachedFeature(gb.TorchBasedFeature(a), 2)
     feat_store_b = gb.GPUCachedFeature(gb.TorchBasedFeature(b), 1)

     # Test read the entire feature.
-    assert torch.equal(feat_store_a.read(), a)
-    assert torch.equal(feat_store_b.read(), b)
+    assert torch.equal(feat_store_a.read(), a.to("cuda"))
+    assert torch.equal(feat_store_b.read(), b.to("cuda"))

     # Test read with ids.
     assert torch.equal(
         feat_store_a.read(torch.tensor([0]).to("cuda")),
-        torch.tensor([[1.0, 2.0, 3.0]]).to("cuda"),
+        torch.tensor([[1, 2, 3]], dtype=dtype).to("cuda"),
     )
     assert torch.equal(
         feat_store_b.read(torch.tensor([1, 1]).to("cuda")),
-        torch.tensor([[[4.0, 5.0], [6.0, 7.0]], [[4.0, 5.0], [6.0, 7.0]]]).to(
+        torch.tensor([[[4, 5], [6, 7]], [[4, 5], [6, 7]]], dtype=dtype).to(
             "cuda"
         ),
     )
@@ -40,18 +59,19 @@ def test_gpu_cached_feature():
     # Test update the entire feature.
     feat_store_a.update(
-        torch.tensor([[0.0, 1.0, 2.0], [3.0, 5.0, 2.0]]).to("cuda")
+        torch.tensor([[0, 1, 2], [3, 5, 2]], dtype=dtype).to("cuda")
     )
     assert torch.equal(
         feat_store_a.read(),
-        torch.tensor([[0.0, 1.0, 2.0], [3.0, 5.0, 2.0]]).to("cuda"),
+        torch.tensor([[0, 1, 2], [3, 5, 2]], dtype=dtype).to("cuda"),
     )

     # Test update with ids.
     feat_store_a.update(
-        torch.tensor([[2.0, 0.0, 1.0]]).to("cuda"), torch.tensor([0]).to("cuda")
+        torch.tensor([[2, 0, 1]], dtype=dtype).to("cuda"),
+        torch.tensor([0]).to("cuda"),
     )
     assert torch.equal(
         feat_store_a.read(),
-        torch.tensor([[2.0, 0.0, 1.0], [3.0, 5.0, 2.0]]).to("cuda"),
+        torch.tensor([[2, 0, 1], [3, 5, 2]], dtype=dtype).to("cuda"),
     )