OpenDAS / dgl / Commits / f6db850d

Commit f6db850d (unverified), authored Jan 18, 2024 by Muhammed Fatih BALIN, committed by GitHub on Jan 18, 2024.
[GraphBolt][CUDA] Add native `GPUCachedFeature` instead of using DGL (#6939)
Parent: 528b041c
Showing 11 changed files with 296 additions and 32 deletions (+296 −32).
CMakeLists.txt                                                     +9   −0
graphbolt/CMakeLists.txt                                           +5   −0
graphbolt/build.bat                                                +2   −2
graphbolt/build.sh                                                 +1   −1
graphbolt/src/cuda/gpu_cache.cu                                    +108 −0
graphbolt/src/cuda/gpu_cache.h                                     +66  −0
graphbolt/src/python_binding.cc                                    +10  −0
python/dgl/graphbolt/impl/__init__.py                              +1   −0
python/dgl/graphbolt/impl/gpu_cache.py                             +53  −0
python/dgl/graphbolt/impl/gpu_cached_feature.py                    +8   −16
tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py    +33  −13
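At a high level, the commit replaces GPUCachedFeature's dependency on dgl.cuda.GPUCache with a GraphBolt-native GPU embedding cache (a HugeCTR gpu_cache wrapper), and the updated tests now cover a range of dtypes rather than float32 only. A minimal usage sketch, following the updated test file below (tensor values are illustrative, not part of the commit):

    import torch
    from dgl import graphbolt as gb

    # The fallback feature lives in (pinned) host memory; non-float dtypes work now.
    a = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.int16, pin_memory=True)

    # Cache up to 2 rows of the fallback feature on the GPU.
    feat = gb.GPUCachedFeature(gb.TorchBasedFeature(a), 2)

    full = feat.read()                              # read everything
    row0 = feat.read(torch.tensor([0]).to("cuda"))  # cached read by index
    feat.update(                                    # update fallback and cache
        torch.tensor([[2, 0, 1]], dtype=torch.int16).to("cuda"),
        torch.tensor([0]).to("cuda"),
    )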
CMakeLists.txt

@@ -530,6 +530,10 @@ if(BUILD_GRAPHBOLT)
   string(REPLACE ";" "\\;" CUDA_ARCHITECTURES_ESCAPED "${CUDA_ARCHITECTURES}")
   file(TO_NATIVE_PATH ${CMAKE_CURRENT_BINARY_DIR} BINDIR)
   file(TO_NATIVE_PATH ${CMAKE_COMMAND} CMAKE_CMD)
+  if(USE_CUDA)
+    get_target_property(GPU_CACHE_INCLUDE_DIRS gpu_cache INCLUDE_DIRECTORIES)
+  endif(USE_CUDA)
+  string(REPLACE ";" "\\;" GPU_CACHE_INCLUDE_DIRS_ESCAPED "${GPU_CACHE_INCLUDE_DIRS}")
   if(MSVC)
     file(TO_NATIVE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/graphbolt/build.bat BUILD_SCRIPT)
     add_custom_target(
@@ -540,6 +544,7 @@ if(BUILD_GRAPHBOLT)
       CUDA_TOOLKIT_ROOT_DIR=${CUDA_TOOLKIT_ROOT_DIR}
       USE_CUDA=${USE_CUDA}
       BINDIR=${BINDIR}
+      GPU_CACHE_INCLUDE_DIRS="${GPU_CACHE_INCLUDE_DIRS_ESCAPED}"
       CFLAGS=${CMAKE_C_FLAGS}
       CXXFLAGS=${CMAKE_CXX_FLAGS}
       CUDAARCHS="${CUDA_ARCHITECTURES_ESCAPED}"
@@ -557,6 +562,7 @@ if(BUILD_GRAPHBOLT)
       CUDA_TOOLKIT_ROOT_DIR=${CUDA_TOOLKIT_ROOT_DIR}
       USE_CUDA=${USE_CUDA}
       BINDIR=${CMAKE_CURRENT_BINARY_DIR}
+      GPU_CACHE_INCLUDE_DIRS="${GPU_CACHE_INCLUDE_DIRS_ESCAPED}"
       CFLAGS=${CMAKE_C_FLAGS}
       CXXFLAGS=${CMAKE_CXX_FLAGS}
       CUDAARCHS="${CUDA_ARCHITECTURES_ESCAPED}"
@@ -565,4 +571,7 @@ if(BUILD_GRAPHBOLT)
       DEPENDS ${BUILD_SCRIPT}
       WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/graphbolt)
   endif(MSVC)
+  if(USE_CUDA)
+    add_dependencies(graphbolt gpu_cache)
+  endif(USE_CUDA)
 endif(BUILD_GRAPHBOLT)
graphbolt/CMakeLists.txt

@@ -77,6 +77,11 @@ if(USE_CUDA)
     "../third_party/cccl/cub"
     "../third_party/cccl/libcudacxx/include")
+  message(STATUS "Use HugeCTR gpu_cache for graphbolt with INCLUDE_DIRS $ENV{GPU_CACHE_INCLUDE_DIRS}.")
+  target_include_directories(${LIB_GRAPHBOLT_NAME} PRIVATE $ENV{GPU_CACHE_INCLUDE_DIRS})
+  target_link_directories(${LIB_GRAPHBOLT_NAME} PRIVATE ${GPU_CACHE_BUILD_DIR})
+  target_link_libraries(${LIB_GRAPHBOLT_NAME} gpu_cache)
   get_property(archs TARGET ${LIB_GRAPHBOLT_NAME} PROPERTY CUDA_ARCHITECTURES)
   message(STATUS "CUDA_ARCHITECTURES for graphbolt: ${archs}")
 endif()
graphbolt/build.bat

@@ -11,7 +11,7 @@ IF x%1x == xx GOTO single
 FOR %%X IN (%*) DO (
   DEL /S /Q *
-  "%CMAKE_COMMAND%" -DCMAKE_CONFIGURATION_TYPES=Release -DPYTHON_INTERP=%%X .. -G "Visual Studio 16 2019" || EXIT /B 1
+  "%CMAKE_COMMAND%" -DGPU_CACHE_BUILD_DIR=%BINDIR% -DCMAKE_CONFIGURATION_TYPES=Release -DPYTHON_INTERP=%%X .. -G "Visual Studio 16 2019" || EXIT /B 1
   msbuild graphbolt.sln /m /nr:false || EXIT /B 1
   COPY /Y Release\*.dll "%BINDIR%\graphbolt" || EXIT /B 1
 )
@@ -21,7 +21,7 @@ GOTO end
 :single
 DEL /S /Q *
-"%CMAKE_COMMAND%" -DCMAKE_CONFIGURATION_TYPES=Release .. -G "Visual Studio 16 2019" || EXIT /B 1
+"%CMAKE_COMMAND%" -DGPU_CACHE_BUILD_DIR=%BINDIR% -DCMAKE_CONFIGURATION_TYPES=Release .. -G "Visual Studio 16 2019" || EXIT /B 1
 msbuild graphbolt.sln /m /nr:false || EXIT /B 1
 COPY /Y Release\*.dll "%BINDIR%\graphbolt" || EXIT /B 1
graphbolt/build.sh

@@ -12,7 +12,7 @@ else
   CPSOURCE=*.so
 fi
-CMAKE_FLAGS="-DCUDA_TOOLKIT_ROOT_DIR=$CUDA_TOOLKIT_ROOT_DIR -DUSE_CUDA=$USE_CUDA"
+CMAKE_FLAGS="-DCUDA_TOOLKIT_ROOT_DIR=$CUDA_TOOLKIT_ROOT_DIR -DUSE_CUDA=$USE_CUDA -DGPU_CACHE_BUILD_DIR=$BINDIR"
 echo $CMAKE_FLAGS

 if [ $# -eq 0 ]; then
graphbolt/src/cuda/gpu_cache.cu (new file, mode 100644)

/**
 *  Copyright (c) 2023 by Contributors
 *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/gpu_cache.cu
 * @brief GPUCache implementation on CUDA.
 */
#include <numeric>

#include "./common.h"
#include "./gpu_cache.h"

namespace graphbolt {
namespace cuda {

GpuCache::GpuCache(const std::vector<int64_t> &shape, torch::ScalarType dtype) {
  TORCH_CHECK(shape.size() >= 2, "Shape must at least have 2 dimensions.");
  const auto num_items = shape[0];
  const int64_t num_feats =
      std::accumulate(shape.begin() + 1, shape.end(), 1ll, std::multiplies<>());
  const int element_size =
      torch::empty(1, torch::TensorOptions().dtype(dtype)).element_size();
  num_bytes_ = num_feats * element_size;
  num_float_feats_ = (num_bytes_ + sizeof(float) - 1) / sizeof(float);
  cache_ = std::make_unique<gpu_cache_t>(
      (num_items + bucket_size - 1) / bucket_size, num_float_feats_);
  shape_ = shape;
  shape_[0] = -1;
  dtype_ = dtype;
  device_id_ = cuda::GetCurrentStream().device_index();
}

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> GpuCache::Query(
    torch::Tensor keys) {
  TORCH_CHECK(keys.device().is_cuda(), "Keys should be on a CUDA device.");
  TORCH_CHECK(
      keys.device().index() == device_id_,
      "Keys should be on the correct CUDA device.");
  TORCH_CHECK(keys.sizes().size() == 1, "Keys should be a 1D tensor.");
  keys = keys.to(torch::kLong);
  auto values = torch::empty(
      {keys.size(0), num_float_feats_}, keys.options().dtype(torch::kFloat));
  auto missing_index =
      torch::empty(keys.size(0), keys.options().dtype(torch::kLong));
  auto missing_keys =
      torch::empty(keys.size(0), keys.options().dtype(torch::kLong));
  cuda::CopyScalar<size_t> missing_len;
  auto stream = cuda::GetCurrentStream();
  cache_->Query(
      reinterpret_cast<const key_t *>(keys.data_ptr()), keys.size(0),
      values.data_ptr<float>(),
      reinterpret_cast<uint64_t *>(missing_index.data_ptr()),
      reinterpret_cast<key_t *>(missing_keys.data_ptr()), missing_len.get(),
      stream);
  values = values.view(torch::kByte)
               .slice(1, 0, num_bytes_)
               .view(dtype_)
               .view(shape_);
  // To safely read missing_len, we synchronize.
  stream.synchronize();
  missing_index = missing_index.slice(0, 0, static_cast<size_t>(missing_len));
  missing_keys = missing_keys.slice(0, 0, static_cast<size_t>(missing_len));
  return std::make_tuple(values, missing_index, missing_keys);
}

void GpuCache::Replace(torch::Tensor keys, torch::Tensor values) {
  TORCH_CHECK(keys.device().is_cuda(), "Keys should be on a CUDA device.");
  TORCH_CHECK(
      keys.device().index() == device_id_,
      "Keys should be on the correct CUDA device.");
  TORCH_CHECK(values.device().is_cuda(), "Values should be on a CUDA device.");
  TORCH_CHECK(
      values.device().index() == device_id_,
      "Values should be on the correct CUDA device.");
  TORCH_CHECK(
      keys.size(0) == values.size(0),
      "The first dimensions of keys and values must match.");
  TORCH_CHECK(
      std::equal(shape_.begin() + 1, shape_.end(), values.sizes().begin() + 1),
      "Values should have the correct dimensions.");
  TORCH_CHECK(
      values.scalar_type() == dtype_, "Values should have the correct dtype.");
  keys = keys.to(torch::kLong);
  torch::Tensor float_values;
  if (num_bytes_ % sizeof(float) != 0) {
    float_values = torch::empty(
        {values.size(0), num_float_feats_},
        values.options().dtype(torch::kFloat));
    float_values.view(torch::kByte)
        .slice(1, 0, num_bytes_)
        .copy_(values.view(torch::kByte).view({values.size(0), -1}));
  } else {
    float_values = values.view(torch::kByte)
                       .view({values.size(0), -1})
                       .view(torch::kFloat)
                       .contiguous();
  }
  cache_->Replace(
      reinterpret_cast<const key_t *>(keys.data_ptr()), keys.size(0),
      float_values.data_ptr<float>(), cuda::GetCurrentStream());
}

c10::intrusive_ptr<GpuCache> GpuCache::Create(
    const std::vector<int64_t> &shape, torch::ScalarType dtype) {
  return c10::make_intrusive<GpuCache>(shape, dtype);
}

}  // namespace cuda
}  // namespace graphbolt
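The cache stores each row as a float32 payload: num_bytes_ is the per-row payload in bytes and num_float_feats_ rounds it up to whole float slots, which is what lets arbitrary dtypes and item shapes ride on top of HugeCTR's float-based cache. A small Python sketch of that padding arithmetic (illustrative only, not part of the commit):

    import math

    def cache_row_geometry(item_shape, element_size):
        # Mirrors the math in GpuCache's constructor (sketch).
        num_feats = math.prod(item_shape)           # features per cached row
        num_bytes = num_feats * element_size        # payload bytes per row
        num_float_feats = (num_bytes + 4 - 1) // 4  # rows stored as float32 slots
        return num_bytes, num_float_feats

    # e.g. a (2, 2) int16 item: 4 features * 2 bytes = 8 bytes -> 2 float slots per row.
    print(cache_row_geometry((2, 2), 2))  # (8, 2)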
graphbolt/src/cuda/gpu_cache.h (new file, mode 100644)

/**
 *  Copyright (c) 2023 by Contributors
 *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/gpu_cache.h
 * @brief Header file of HugeCTR gpu_cache wrapper.
 */
#ifndef GRAPHBOLT_GPU_CACHE_H_
#define GRAPHBOLT_GPU_CACHE_H_

#include <torch/custom_class.h>
#include <torch/torch.h>

#include <limits>
#include <nv_gpu_cache.hpp>

namespace graphbolt {
namespace cuda {

class GpuCache : public torch::CustomClassHolder {
  using key_t = long long;
  constexpr static int set_associativity = 2;
  constexpr static int WARP_SIZE = 32;
  constexpr static int bucket_size = WARP_SIZE * set_associativity;
  using gpu_cache_t = ::gpu_cache::gpu_cache<
      key_t, uint64_t, std::numeric_limits<key_t>::max(), set_associativity,
      WARP_SIZE>;

 public:
  /**
   * @brief Constructor for the GpuCache struct.
   *
   * @param shape The shape of the GPU cache.
   * @param dtype The datatype of items to be stored.
   */
  GpuCache(const std::vector<int64_t>& shape, torch::ScalarType dtype);

  GpuCache() = default;

  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> Query(
      torch::Tensor keys);

  void Replace(torch::Tensor keys, torch::Tensor values);

  static c10::intrusive_ptr<GpuCache> Create(
      const std::vector<int64_t>& shape, torch::ScalarType dtype);

 private:
  std::vector<int64_t> shape_;
  torch::ScalarType dtype_;
  std::unique_ptr<gpu_cache_t> cache_;
  int64_t num_bytes_;
  int64_t num_float_feats_;
  torch::DeviceIndex device_id_;
};

// The cu file in HugeCTR gpu cache uses unsigned int and long long.
// Changing to int64_t results in a mismatch of template arguments.
static_assert(
    sizeof(long long) == sizeof(int64_t),
    "long long and int64_t needs to have the same size.");  // NOLINT

}  // namespace cuda
}  // namespace graphbolt

#endif  // GRAPHBOLT_GPU_CACHE_H_
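The underlying HugeCTR cache appears to be set-associative: with the constants above, each set holds set_associativity * WARP_SIZE = 64 slots, and the constructor sizes the cache as ceil(num_items / bucket_size) sets. A quick illustrative calculation (assumed interpretation, not part of the commit):

    WARP_SIZE = 32
    SET_ASSOCIATIVITY = 2
    bucket_size = WARP_SIZE * SET_ASSOCIATIVITY  # 64 slots per set

    def num_sets(cache_items):
        # Same rounding as (num_items + bucket_size - 1) / bucket_size in the constructor.
        return (cache_items + bucket_size - 1) // bucket_size

    print(num_sets(1000))  # 16 sets -> room for up to 1024 cached rows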
graphbolt/src/python_binding.cc

@@ -15,6 +15,10 @@
 #include "./index_select.h"
 #include "./random.h"

+#ifdef GRAPHBOLT_USE_CUDA
+#include "./cuda/gpu_cache.h"
+#endif
+
 namespace graphbolt {
 namespace sampling {
@@ -70,6 +74,12 @@ TORCH_LIBRARY(graphbolt, m) {
         g->SetState(state);
         return g;
       });
+#ifdef GRAPHBOLT_USE_CUDA
+  m.class_<cuda::GpuCache>("GpuCache")
+      .def("query", &cuda::GpuCache::Query)
+      .def("replace", &cuda::GpuCache::Replace);
+  m.def("gpu_cache", &cuda::GpuCache::Create);
+#endif
   m.def("fused_csc_sampling_graph", &FusedCSCSamplingGraph::Create);
   m.def(
       "load_from_shared_memory", &FusedCSCSamplingGraph::LoadFromSharedMemory);
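With this binding in place, the op and custom class are reachable from Python as torch.ops.graphbolt.gpu_cache, which is what the GPUCache wrapper below calls. A hedged sketch of exercising the raw binding directly (assumes a CUDA build of DGL so that importing dgl.graphbolt registers the ops; in practice the Python wrapper is the intended entry point):

    import torch
    import dgl.graphbolt  # noqa: F401  (loads the graphbolt library and its ops)

    # 1000 cacheable rows of 16 float16 features each.
    cache = torch.ops.graphbolt.gpu_cache([1000, 16], torch.float16)

    keys = torch.arange(8, device="cuda")
    values, missing_index, missing_keys = cache.query(keys)  # all misses at first
    cache.replace(
        missing_keys,
        torch.randn(missing_keys.shape[0], 16, dtype=torch.float16, device="cuda"),
    )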
python/dgl/graphbolt/impl/__init__.py

 """Implementation of GraphBolt."""
 from .basic_feature_store import *
 from .fused_csc_sampling_graph import *
+from .gpu_cache import *
 from .gpu_cached_feature import *
 from .in_subgraph_sampler import *
 from .legacy_dataset import *
python/dgl/graphbolt/impl/gpu_cache.py (new file, mode 100644)

"""HugeCTR gpu_cache wrapper for graphbolt."""
import torch


class GPUCache(object):
    """High-level wrapper for GPU embedding cache"""

    def __init__(self, cache_shape, dtype):
        major, _ = torch.cuda.get_device_capability()
        assert (
            major >= 7
        ), "GPUCache is supported only on CUDA compute capability >= 70 (Volta)."
        self._cache = torch.ops.graphbolt.gpu_cache(cache_shape, dtype)
        self.total_miss = 0
        self.total_queries = 0

    def query(self, keys):
        """Queries the GPU cache.

        Parameters
        ----------
        keys : Tensor
            The keys to query the GPU cache with.

        Returns
        -------
        tuple(Tensor, Tensor, Tensor)
            A tuple containing (values, missing_indices, missing_keys) where
            values[missing_indices] corresponds to cache misses that should be
            filled by querying another source with missing_keys.
        """
        self.total_queries += keys.shape[0]
        values, missing_index, missing_keys = self._cache.query(keys)
        self.total_miss += missing_keys.shape[0]
        return values, missing_index, missing_keys

    def replace(self, keys, values):
        """Inserts key-value pairs into the GPU cache using the Least-Recently
        Used (LRU) algorithm to remove old key-value pairs if it is full.

        Parameters
        ----------
        keys : Tensor
            The keys to insert to the GPU cache.
        values : Tensor
            The values to insert to the GPU cache.
        """
        self._cache.replace(keys, values)

    @property
    def miss_rate(self):
        """Returns the cache miss rate since creation."""
        return self.total_miss / self.total_queries
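A short sketch of the query/miss/replace cycle described by the docstrings above (names and sizes are illustrative; assumes a CUDA build of DGL and a Volta-or-newer GPU):

    import torch
    from dgl.graphbolt.impl.gpu_cache import GPUCache

    # Cache for up to 512 rows of 8 float features each.
    cache = GPUCache([512, 8], torch.float32)
    source = torch.randn(10_000, 8)  # slower source of truth, e.g. host memory

    keys = torch.randint(0, 10_000, (64,), device="cuda")
    values, missing_index, missing_keys = cache.query(keys)

    # Fill the misses from the fallback source and teach the cache about them.
    missing_values = source[missing_keys.cpu()].to("cuda")
    values[missing_index] = missing_values
    cache.replace(missing_keys, missing_values)

    print(f"miss rate so far: {cache.miss_rate:.2f}")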
python/dgl/graphbolt/impl/gpu_cached_feature.py

 """GPU cached feature for GraphBolt."""
 import torch

-from dgl.cuda import GPUCache
-
 from ..feature_store import Feature
+from .gpu_cache import GPUCache

 __all__ = ["GPUCachedFeature"]
@@ -52,10 +52,7 @@ class GPUCachedFeature(Feature):
         self.cache_size = cache_size
         # Fetching the feature dimension from the underlying feature.
         feat0 = fallback_feature.read(torch.tensor([0]))
-        self.item_shape = (-1,) + feat0.shape[1:]
-        feat0 = torch.reshape(feat0, (1, -1))
-        self.flat_shape = (-1, feat0.shape[1])
-        self._feature = GPUCache(cache_size, feat0.shape[1])
+        self._feature = GPUCache((cache_size,) + feat0.shape[1:], feat0.dtype)

     def read(self, ids: torch.Tensor = None):
         """Read the feature by index.
@@ -75,15 +72,12 @@ class GPUCachedFeature(Feature):
             The read feature.
         """
         if ids is None:
-            return self._fallback_feature.read().to("cuda")
-        keys = ids.to("cuda")
-        values, missing_index, missing_keys = self._feature.query(keys)
+            return self._fallback_feature.read()
+        values, missing_index, missing_keys = self._feature.query(ids)
         missing_values = self._fallback_feature.read(missing_keys).to("cuda")
-        missing_values = missing_values.reshape(self.flat_shape)
-        values = values.to(missing_values.dtype)
         values[missing_index] = missing_values
         self._feature.replace(missing_keys, missing_values)
-        return torch.reshape(values, self.item_shape)
+        return values

     def size(self):
         """Get the size of the feature.
@@ -114,10 +108,8 @@ class GPUCachedFeature(Feature):
             size = min(self.cache_size, value.shape[0])
             self._feature.replace(
                 torch.arange(0, size, device="cuda"),
-                value[:size].to("cuda").reshape(self.flat_shape),
+                value[:size].to("cuda"),
             )
         else:
             self._fallback_feature.update(value, ids)
-            self._feature.replace(
-                ids.to("cuda"),
-                value.to("cuda").reshape(self.flat_shape),
-            )
+            self._feature.replace(ids, value)
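The removed flat_shape/item_shape bookkeeping existed because dgl.cuda.GPUCache only stored flat float rows, so every read and update had to flatten the feature and reshape it back; the new GraphBolt GpuCache keeps the item shape and dtype internally, letting read() and update() pass tensors through unchanged. An illustrative sketch of the reshape round trip the old path needed (not part of the commit):

    import torch

    feat0 = torch.randn(1, 2, 2)                      # one item of a (N, 2, 2) feature
    item_shape = (-1,) + feat0.shape[1:]              # (-1, 2, 2): restore shape after a query
    flat_shape = (-1, feat0.reshape(1, -1).shape[1])  # (-1, 4): rows as the old cache stored them

    rows = torch.randn(3, 2, 2)
    flat = rows.reshape(flat_shape)      # what the old code handed to dgl.cuda.GPUCache
    restored = flat.reshape(item_shape)  # what the old read() had to undo
    assert torch.equal(rows, restored)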
tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py

@@ -2,34 +2,53 @@ import unittest
 import backend as F
+import pytest
 import torch

 from dgl import graphbolt as gb


 @unittest.skipIf(
-    F._default_context_str != "gpu",
-    reason="GPUCachedFeature requires a GPU.",
+    F._default_context_str != "gpu"
+    or torch.cuda.get_device_capability()[0] < 7,
+    reason="GPUCachedFeature requires a Volta or later generation NVIDIA GPU.",
 )
-def test_gpu_cached_feature():
-    a = torch.tensor([[1, 2, 3], [4, 5, 6]]).to("cuda").float()
-    b = torch.tensor([[[1, 2], [3, 4]], [[4, 5], [6, 7]]]).to("cuda").float()
+@pytest.mark.parametrize(
+    "dtype",
+    [
+        torch.bool,
+        torch.uint8,
+        torch.int8,
+        torch.int16,
+        torch.int32,
+        torch.int64,
+        torch.float16,
+        torch.bfloat16,
+        torch.float32,
+        torch.float64,
+    ],
+)
+def test_gpu_cached_feature(dtype):
+    a = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dtype, pin_memory=True)
+    b = torch.tensor(
+        [[[1, 2], [3, 4]], [[4, 5], [6, 7]]], dtype=dtype, pin_memory=True
+    )
     feat_store_a = gb.GPUCachedFeature(gb.TorchBasedFeature(a), 2)
     feat_store_b = gb.GPUCachedFeature(gb.TorchBasedFeature(b), 1)

     # Test read the entire feature.
-    assert torch.equal(feat_store_a.read(), a)
-    assert torch.equal(feat_store_b.read(), b)
+    assert torch.equal(feat_store_a.read(), a.to("cuda"))
+    assert torch.equal(feat_store_b.read(), b.to("cuda"))
     # Test read with ids.
     assert torch.equal(
         feat_store_a.read(torch.tensor([0]).to("cuda")),
-        torch.tensor([[1.0, 2.0, 3.0]]).to("cuda"),
+        torch.tensor([[1, 2, 3]], dtype=dtype).to("cuda"),
     )
     assert torch.equal(
         feat_store_b.read(torch.tensor([1, 1]).to("cuda")),
-        torch.tensor([[[4.0, 5.0], [6.0, 7.0]], [[4.0, 5.0], [6.0, 7.0]]]).to(
+        torch.tensor([[[4, 5], [6, 7]], [[4, 5], [6, 7]]], dtype=dtype).to(
             "cuda"
         ),
     )
@@ -40,18 +59,19 @@ def test_gpu_cached_feature():
     # Test update the entire feature.
     feat_store_a.update(
-        torch.tensor([[0.0, 1.0, 2.0], [3.0, 5.0, 2.0]]).to("cuda")
+        torch.tensor([[0, 1, 2], [3, 5, 2]], dtype=dtype).to("cuda")
     )
     assert torch.equal(
         feat_store_a.read(),
-        torch.tensor([[0.0, 1.0, 2.0], [3.0, 5.0, 2.0]]).to("cuda"),
+        torch.tensor([[0, 1, 2], [3, 5, 2]], dtype=dtype).to("cuda"),
     )
     # Test update with ids.
     feat_store_a.update(
-        torch.tensor([[2.0, 0.0, 1.0]]).to("cuda"),
-        torch.tensor([0]).to("cuda")
+        torch.tensor([[2, 0, 1]], dtype=dtype).to("cuda"),
+        torch.tensor([0]).to("cuda"),
     )
     assert torch.equal(
         feat_store_a.read(),
-        torch.tensor([[2.0, 0.0, 1.0], [3.0, 5.0, 2.0]]).to("cuda"),
+        torch.tensor([[2, 0, 1], [3, 5, 2]], dtype=dtype).to("cuda"),
     )