OpenDAS / dgl / Commits / 84a01a16

Commit 84a01a16 (unverified), authored Nov 06, 2023 by czkkkkkk, committed via GitHub on Nov 06, 2023.

[Graphbolt] Enable cuda optimizations for UVAIndexSelect. (#6507)

parent 91fe0c90
Showing 3 changed files with 155 additions and 42 deletions (+155 −42):

- graphbolt/src/cuda/index_select_impl.cu (+98 −11)
- graphbolt/src/cuda/utils.h (+37 −0)
- tests/python/pytorch/graphbolt/impl/test_torch_based_feature_store.py (+20 −31)
graphbolt/src/cuda/index_select_impl.cu
```diff
@@ -9,15 +9,37 @@
 #include <numeric>
 
 #include "../index_select.h"
+#include "./utils.h"
 
 namespace graphbolt {
 namespace ops {
 
+/** @brief Index select operator implementation for feature size 1. */
+template <typename DType, typename IdType>
+__global__ void IndexSelectSingleKernel(
+    const DType* input, const int64_t input_len, const IdType* index,
+    const int64_t output_len, DType* output,
+    const int64_t* permutation = nullptr) {
+  int64_t out_row_index = blockIdx.x * blockDim.x + threadIdx.x;
+  int stride = gridDim.x * blockDim.x;
+  while (out_row_index < output_len) {
+    assert(index[out_row_index] >= 0 && index[out_row_index] < input_len);
+    const auto out_row =
+        permutation ? permutation[out_row_index] : out_row_index;
+    output[out_row] = input[index[out_row_index]];
+    out_row_index += stride;
+  }
+}
+
 /**
  * @brief Index select operator implementation for feature size > 1.
  */
 template <typename DType, typename IdType>
 __global__ void IndexSelectMultiKernel(
     const DType* const input, const int64_t input_len,
     const int64_t feature_size, const IdType* const index,
-    const int64_t output_len, DType* const output) {
+    const int64_t output_len, DType* const output,
+    const int64_t* permutation = nullptr) {
   int64_t out_row_index = blockIdx.x * blockDim.y + threadIdx.y;
   const int64_t stride = blockDim.y * gridDim.x;
```
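The optimization that threads through all the hunks below is index sorting: `UVAIndexSelectImpl_` now sorts `index` and hands the kernels the sorted copy plus the inverse `permutation`, so reads sweep `input` in ascending order while each result still lands at its original output position. A minimal host-side sketch of the identity this relies on (illustrative names, emulating `torch::sort` on the CPU):

```cpp
// Sketch (not from the commit): torch::sort(index) returns `sorted_index`
// and `permutation` with sorted_index[i] == index[permutation[i]], so
// writing output[permutation[i]] = input[sorted_index[i]] yields exactly
// output[j] = input[index[j]] for every j, while reads walk `input` in
// ascending order.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

int main() {
  const std::vector<float> input = {10, 20, 30, 40, 50};
  const std::vector<int64_t> index = {4, 0, 3, 0, 2};

  // Emulate torch::sort(index): argsort, then gather.
  std::vector<int64_t> permutation(index.size());
  std::iota(permutation.begin(), permutation.end(), 0);
  std::sort(permutation.begin(), permutation.end(),
            [&](int64_t a, int64_t b) { return index[a] < index[b]; });
  std::vector<int64_t> sorted_index(index.size());
  for (size_t i = 0; i < index.size(); ++i)
    sorted_index[i] = index[permutation[i]];

  // The kernels' access pattern: sequential reads, permuted writes.
  std::vector<float> output(index.size());
  for (size_t i = 0; i < index.size(); ++i)
    output[permutation[i]] = input[sorted_index[i]];

  // Matches the plain, unsorted gather.
  for (size_t j = 0; j < index.size(); ++j)
    assert(output[j] == input[index[j]]);
  return 0;
}
```

Sorted reads turn a scattered gather into a mostly monotonic address stream, which is what makes the PCIe traffic of UVA access cacheline-friendly.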
```diff
@@ -26,8 +48,10 @@ __global__ void IndexSelectMultiKernel(
     int64_t column = threadIdx.x;
     const int64_t in_row = index[out_row_index];
     assert(in_row >= 0 && in_row < input_len);
+    const auto out_row =
+        permutation ? permutation[out_row_index] : out_row_index;
     while (column < feature_size) {
-      output[out_row_index * feature_size + column] =
+      output[out_row * feature_size + column] =
           input[in_row * feature_size + column];
       column += blockDim.x;
     }
```
```diff
@@ -35,6 +59,43 @@ __global__ void IndexSelectMultiKernel(
   }
 }
 
+/**
+ * @brief Index select operator implementation for feature size > 1.
+ *
+ * @note This is a cross-device access version of IndexSelectMultiKernel. Since
+ * the memory access over PCIe is more sensitive to the data access alignment
+ * (cacheline), we need a separate version here.
+ */
+template <typename DType, typename IdType>
+__global__ void IndexSelectMultiKernelAligned(
+    const DType* const input, const int64_t input_len,
+    const int64_t feature_size, const IdType* const index,
+    const int64_t output_len, DType* const output,
+    const int64_t* permutation = nullptr) {
+  int64_t out_row_index = blockIdx.x * blockDim.y + threadIdx.y;
+  const int64_t stride = blockDim.y * gridDim.x;
+  while (out_row_index < output_len) {
+    int64_t col = threadIdx.x;
+    const int64_t in_row = index[out_row_index];
+    assert(in_row >= 0 && in_row < input_len);
+    const int64_t idx_offset =
+        ((uint64_t)(&input[in_row * feature_size]) % GPU_CACHE_LINE_SIZE) /
+        sizeof(DType);
+    col = col - idx_offset;
+    const auto out_row =
+        permutation ? permutation[out_row_index] : out_row_index;
+    while (col < feature_size) {
+      if (col >= 0)
+        output[out_row * feature_size + col] =
+            input[in_row * feature_size + col];
+      col += blockDim.x;
+    }
+    out_row_index += stride;
+  }
+}
+
 template <typename DType, typename IdType>
 torch::Tensor UVAIndexSelectImpl_(torch::Tensor input, torch::Tensor index) {
   const int64_t input_len = input.size(0);
```
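The `idx_offset` arithmetic deserves a worked example. Shifting every thread's starting column back by the row's misalignment means thread `t` touches address `row_start - misalignment + t * sizeof(DType)`, so each aligned 128-byte segment is covered by one contiguous group of 32 threads; the `if (col >= 0)` guard simply discards the columns before the row actually starts. A host-side sketch with assumed values (float features, a hypothetical address):

```cpp
// Host-side sketch (illustrative, not from the commit) of the offset
// computation in IndexSelectMultiKernelAligned, assuming float features.
#include <cstdint>
#include <cstdio>

constexpr uint64_t kGpuCacheLineSize = 128;  // stands in for GPU_CACHE_LINE_SIZE

int main() {
  // Suppose the selected row starts 32 bytes past a cache-line boundary.
  const uint64_t row_start = 0x7f0000000020ULL;  // hypothetical address
  const int64_t idx_offset =
      (row_start % kGpuCacheLineSize) / sizeof(float);  // 32 / 4 = 8
  // Thread t reads column t - idx_offset, i.e. address
  // row_start + (t - 8) * 4 = (row_start - 32) + 4 * t. Threads 0..7 get
  // negative columns (skipped by `if (col >= 0)`), and every aligned
  // 128-byte segment maps to one contiguous group of 32 threads, so each
  // cache line pulled over PCIe is requested in full.
  printf("idx_offset = %lld\n", static_cast<long long>(idx_offset));
  return 0;
}
```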
```diff
@@ -46,19 +107,45 @@ torch::Tensor UVAIndexSelectImpl_(torch::Tensor input, torch::Tensor index) {
           .dtype(input.dtype())
           .device(c10::DeviceType::CUDA));
   DType* input_ptr = input.data_ptr<DType>();
   IdType* index_ptr = index.data_ptr<IdType>();
   DType* ret_ptr = ret.data_ptr<DType>();
+  // Sort the index to improve the memory access pattern.
+  torch::Tensor sorted_index, permutation;
+  std::tie(sorted_index, permutation) = torch::sort(index);
+  const IdType* index_sorted_ptr = sorted_index.data_ptr<IdType>();
+  const int64_t* permutation_ptr = permutation.data_ptr<int64_t>();
   cudaStream_t stream = 0;
+  if (feature_size == 1) {
+    // Use a single thread to process each output row to avoid wasting threads.
+    const int num_threads = cuda::FindNumThreads(return_len);
+    const int num_blocks = (return_len + num_threads - 1) / num_threads;
+    IndexSelectSingleKernel<<<num_blocks, num_threads, 0, stream>>>(
+        input_ptr, input_len, index_sorted_ptr, return_len, ret_ptr,
+        permutation_ptr);
+  } else {
     dim3 block(512, 1);
     // Find the smallest block size that can fit the feature_size.
     while (static_cast<int64_t>(block.x) >= 2 * feature_size) {
       block.x >>= 1;
       block.y <<= 1;
     }
     const dim3 grid((return_len + block.y - 1) / block.y);
+    if (feature_size * sizeof(DType) <= GPU_CACHE_LINE_SIZE) {
+      // When feature size is smaller than GPU cache line size, use unaligned
+      // version for less SM usage, which is more resource efficient.
       IndexSelectMultiKernel<<<grid, block, 0, stream>>>(
-          input_ptr, input_len, feature_size, index_ptr, return_len, ret_ptr);
+          input_ptr, input_len, feature_size, index_sorted_ptr, return_len,
+          ret_ptr, permutation_ptr);
+    } else {
+      // Use aligned version to improve the memory access pattern.
+      IndexSelectMultiKernelAligned<<<grid, block, 0, stream>>>(
+          input_ptr, input_len, feature_size, index_sorted_ptr, return_len,
+          ret_ptr, permutation_ptr);
+    }
+  }
   C10_CUDA_KERNEL_LAUNCH_CHECK();
   auto return_shape = std::vector<int64_t>({return_len});
   return_shape.insert(
       return_shape.end(), input.sizes().begin() + 1, input.sizes().end());
```
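The launch configuration in this hunk fixes the block at 512 threads and only reshapes it: `block.x` (threads per row, along the feature dimension) is halved and `block.y` (rows per block) doubled until `block.x < 2 * feature_size`. A host-side replica (illustrative, not from the commit) showing the shapes it picks:

```cpp
// Replica of the block-shape loop in UVAIndexSelectImpl_: starting from
// (512, 1), x is halved and y doubled until x < 2 * feature_size, so the
// product x * y stays 512 and x is the smallest power of two that still
// covers feature_size with at most one wasted halving.
#include <cstdint>
#include <cstdio>

struct Dim2 { unsigned x, y; };

Dim2 PickBlock(int64_t feature_size) {
  Dim2 block{512, 1};
  while (static_cast<int64_t>(block.x) >= 2 * feature_size) {
    block.x >>= 1;
    block.y <<= 1;
  }
  return block;
}

int main() {
  for (int64_t fs : {2, 8, 100, 512}) {
    const Dim2 b = PickBlock(fs);
    // fs=2 -> (2, 256); fs=8 -> (8, 64); fs=100 -> (128, 4); fs=512 -> (512, 1)
    printf("feature_size=%lld -> block=(%u, %u)\n", (long long)fs, b.x, b.y);
  }
  return 0;
}
```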
graphbolt/src/cuda/utils.h
(new file, mode 100644)
```cpp
/**
 * Copyright (c) 2023 by Contributors
 *
 * @file utils.h
 * @brief CUDA utilities.
 */
#ifndef GRAPHBOLT_CUDA_UTILS_H_
#define GRAPHBOLT_CUDA_UTILS_H_

#include <algorithm>  // std::min

namespace graphbolt {
namespace cuda {

// The cache line size of GPU.
#define GPU_CACHE_LINE_SIZE 128
// The max number of threads per block.
#define CUDA_MAX_NUM_THREADS 1024

/**
 * @brief Calculate the number of threads needed given the size of the
 * dimension to be processed.
 *
 * It finds the largest power of two that is less than or equal to the minimum
 * of size and CUDA_MAX_NUM_THREADS.
 */
inline int FindNumThreads(int size) {
  int ret = 1;
  while ((ret << 1) <= std::min(size, CUDA_MAX_NUM_THREADS)) {
    ret <<= 1;
  }
  return ret;
}

}  // namespace cuda
}  // namespace graphbolt

#endif  // GRAPHBOLT_CUDA_UTILS_H_
```
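A quick sanity check of `FindNumThreads` (illustrative, not part of the commit; assumes the header above is on the include path):

```cpp
// FindNumThreads returns the largest power of two <= min(size, 1024).
#include <cassert>

#include "utils.h"  // the header above; the include path is assumed

int main() {
  assert(graphbolt::cuda::FindNumThreads(1) == 1);
  assert(graphbolt::cuda::FindNumThreads(100) == 64);     // 128 > 100
  assert(graphbolt::cuda::FindNumThreads(1024) == 1024);
  assert(graphbolt::cuda::FindNumThreads(5000) == 1024);  // capped by the max
  return 0;
}
```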
tests/python/pytorch/graphbolt/impl/test_torch_based_feature_store.py
```diff
 import os
 import tempfile
 import unittest
+from functools import reduce
+from operator import mul
 
 import backend as F
```
```diff
@@ -136,44 +138,31 @@ def test_torch_based_feature(in_memory):
     "dtype",
     [torch.float32, torch.float64, torch.int32, torch.int64],
 )
 @pytest.mark.parametrize("idtype", [torch.int32, torch.int64])
-def test_torch_based_pinned_feature(dtype, idtype):
-    a = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dtype).pin_memory()
-    b = torch.tensor(
-        [[[1, 2], [3, 4]], [[4, 5], [6, 7]]], dtype=dtype
-    ).pin_memory()
-    c = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dtype).pin_memory()
+@pytest.mark.parametrize("shape", [(2, 1), (2, 3), (2, 2, 2)])
+def test_torch_based_pinned_feature(dtype, idtype, shape):
+    tensor = torch.arange(0, reduce(mul, shape), dtype=dtype).reshape(shape)
+    test_tensor = tensor.clone().detach()
+    test_tensor_cuda = test_tensor.cuda()
 
-    feature_a = gb.TorchBasedFeature(a)
-    feature_b = gb.TorchBasedFeature(b)
-    feature_c = gb.TorchBasedFeature(c)
+    feature = gb.TorchBasedFeature(tensor)
+    feature.pin_memory_()
 
     # Test read entire pinned feature, the result should be on cuda.
+    assert torch.equal(feature.read(), test_tensor_cuda)
+    assert feature.read().is_cuda
-    assert torch.equal(
-        feature_a.read(),
-        torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dtype).cuda(),
-    )
-    assert feature_a.read().is_cuda
-    assert torch.equal(
-        feature_b.read(),
-        torch.tensor([[[1, 2], [3, 4]], [[4, 5], [6, 7]]], dtype=dtype).cuda(),
-    )
-    assert feature_b.read().is_cuda
 
+    # Test read pinned feature with idx on cuda, the result should be on cuda.
+    assert torch.equal(
+        feature.read(torch.tensor([0], dtype=idtype).cuda()),
+        test_tensor_cuda[[0]],
+    )
-    assert torch.equal(
-        feature_a.read(torch.tensor([0], dtype=idtype).cuda()),
-        torch.tensor([[1, 2, 3]], dtype=dtype).cuda(),
-    )
-    assert feature_a.read(torch.tensor([0], dtype=idtype).cuda()).is_cuda
-    assert torch.equal(
-        feature_b.read(torch.tensor([1], dtype=idtype).cuda()),
-        torch.tensor([[[4, 5], [6, 7]]], dtype=dtype).cuda(),
-    )
-    assert feature_b.read(torch.tensor([1], dtype=idtype).cuda()).is_cuda
-    assert feature_c.read().is_cuda
+    assert feature.read(torch.tensor([0], dtype=idtype).cuda()).is_cuda
 
     # Test read pinned feature with idx on cpu, the result should be on cpu.
     assert torch.equal(
-        feature_c.read(torch.tensor([0], dtype=idtype)),
-        torch.tensor([[1, 2, 3]], dtype=dtype),
+        feature.read(torch.tensor([0], dtype=idtype)),
+        test_tensor[[0]],
     )
-    assert not feature_c.read(torch.tensor([0], dtype=idtype)).is_cuda
+    assert not feature.read(torch.tensor([0], dtype=idtype)).is_cuda
 
 
 def write_tensor_to_disk(dir, name, t, fmt="torch"):
```
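For context on why these tests work at all: `pin_memory_()` page-locks the underlying tensor, and under CUDA's unified virtual addressing a kernel may dereference pinned host memory directly; that is the mechanism behind `UVAIndexSelect`. A minimal standalone sketch (hypothetical names, not GraphBolt code):

```cpp
// Minimal UVA sketch: a kernel gathers from page-locked (pinned) host
// memory over PCIe, the same access pattern the kernels above optimize.
#include <cstdint>
#include <cstdio>

#include <cuda_runtime.h>

__global__ void GatherFromHost(
    const float* host_input, const int64_t* index, float* output, int64_t n) {
  const int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) output[i] = host_input[index[i]];  // read travels over PCIe
}

int main() {
  float* host_input;
  cudaMallocHost((void**)&host_input, 1024 * sizeof(float));  // pinned
  for (int i = 0; i < 1024; ++i) host_input[i] = static_cast<float>(i);

  int64_t* index;
  float* output;
  cudaMalloc((void**)&index, 4 * sizeof(int64_t));
  cudaMalloc((void**)&output, 4 * sizeof(float));
  const int64_t host_index[4] = {3, 1, 1023, 0};
  cudaMemcpy(index, host_index, sizeof(host_index), cudaMemcpyHostToDevice);

  GatherFromHost<<<1, 4>>>(host_input, index, output, 4);

  float result[4];
  cudaMemcpy(result, output, sizeof(result), cudaMemcpyDeviceToHost);
  printf("%g %g %g %g\n", result[0], result[1], result[2], result[3]);  // 3 1 1023 0

  cudaFree(output);
  cudaFree(index);
  cudaFreeHost(host_input);
  return 0;
}
```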