OpenDAS / dgl · Commit d5b03bcb (Unverified)
Authored Feb 03, 2024 by Muhammed Fatih BALIN, committed by GitHub on Feb 02, 2024
[GraphBolt][CUDA] GPUCache performance fix. (#7073)
parent 85683869

Showing 1 changed file with 5 additions and 6 deletions
graphbolt/src/cuda/gpu_cache.cu (+5 -6, view file @ d5b03bcb)

@@ -43,20 +43,19 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> GpuCache::Query(
       torch::empty(keys.size(0), keys.options().dtype(torch::kLong));
   auto missing_keys =
       torch::empty(keys.size(0), keys.options().dtype(torch::kLong));
-  cuda::CopyScalar<size_t> missing_len;
-  auto stream = cuda::GetCurrentStream();
+  auto allocator = cuda::GetAllocator();
+  auto missing_len_device = allocator.AllocateStorage<size_t>(1);
   cache_->Query(
       reinterpret_cast<const key_t*>(keys.data_ptr()), keys.size(0),
       values.data_ptr<float>(),
       reinterpret_cast<uint64_t*>(missing_index.data_ptr()),
-      reinterpret_cast<key_t*>(missing_keys.data_ptr()), missing_len.get(),
-      stream);
+      reinterpret_cast<key_t*>(missing_keys.data_ptr()),
+      missing_len_device.get(), cuda::GetCurrentStream());
   values = values.view(torch::kByte)
                .slice(1, 0, num_bytes_)
                .view(dtype_)
                .view(shape_);
-  // To safely read missing_len, we synchronize
-  stream.synchronize();
+  cuda::CopyScalar<size_t> missing_len(missing_len_device.get());
   missing_index = missing_index.slice(0, 0, static_cast<size_t>(missing_len));
   missing_keys = missing_keys.slice(0, 0, static_cast<size_t>(missing_len));
   return std::make_tuple(values, missing_index, missing_keys);
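In the old code, cache_->Query() wrote its miss count straight into the host-side scalar owned by a default-constructed cuda::CopyScalar<size_t>, and the host then issued a full stream.synchronize() before reading it. The fix gives the kernel a device-side scalar (allocator.AllocateStorage<size_t>(1)) instead, and brings the value back by constructing cuda::CopyScalar from the device pointer, which appears to copy device-to-host asynchronously and defer synchronization until missing_len is actually converted to size_t in the slice() calls. Below is a minimal standalone sketch of that general pattern using plain CUDA runtime calls; it is an illustration under stated assumptions, not GraphBolt code, and the count_matches kernel and all buffer names are hypothetical.

// sketch.cu: kernel writes its scalar result to device memory; the host gets it
// via an async copy into pinned memory and waits only on that copy's event.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void count_matches(const int* keys, int n, int target, size_t* result) {
  // Single-thread toy kernel: count occurrences of `target` in `keys`.
  if (threadIdx.x == 0 && blockIdx.x == 0) {
    size_t count = 0;
    for (int i = 0; i < n; ++i) count += (keys[i] == target);
    *result = count;  // written to device memory, not to host-pinned memory
  }
}

int main() {
  const int n = 1 << 20;
  int* keys_d = nullptr;
  cudaMalloc(reinterpret_cast<void**>(&keys_d), n * sizeof(int));
  cudaMemset(keys_d, 0, n * sizeof(int));  // every key is 0

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Device-side scalar that receives the kernel's output.
  size_t* result_d = nullptr;
  cudaMalloc(reinterpret_cast<void**>(&result_d), sizeof(size_t));

  // Pinned host scalar that receives the value through an async D2H copy.
  size_t* result_h = nullptr;
  cudaMallocHost(reinterpret_cast<void**>(&result_h), sizeof(size_t));

  count_matches<<<1, 32, 0, stream>>>(keys_d, n, /*target=*/0, result_d);
  cudaMemcpyAsync(result_h, result_d, sizeof(size_t), cudaMemcpyDeviceToHost, stream);

  // Record an event after the copy; wait on it only at the point of use,
  // instead of synchronizing the whole stream right after the kernel.
  cudaEvent_t copy_done;
  cudaEventCreate(&copy_done);
  cudaEventRecord(copy_done, stream);

  // ... other work could still be enqueued on `stream` here ...

  cudaEventSynchronize(copy_done);
  std::printf("matches: %zu\n", *result_h);

  cudaEventDestroy(copy_done);
  cudaFreeHost(result_h);
  cudaFree(result_d);
  cudaFree(keys_d);
  cudaStreamDestroy(stream);
  return 0;
}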