Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
b3224ce8
Unverified
Commit
b3224ce8
authored
Jan 08, 2024
by
Muhammed Fatih BALIN
Committed by
GitHub
Jan 08, 2024
Browse files
[GraphBolt] Switch to using `AT_DISPATCH_INDEX_TYPES` for graph (#6912)
parent
61504ec5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
26 additions
and
26 deletions
+26
-26
graphbolt/src/cuda/neighbor_sampler.cu
graphbolt/src/cuda/neighbor_sampler.cu
+11
-11
graphbolt/src/fused_csc_sampling_graph.cc
graphbolt/src/fused_csc_sampling_graph.cc
+15
-15
No files found.
graphbolt/src/cuda/neighbor_sampler.cu
View file @
b3224ce8
...
@@ -164,16 +164,16 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
...
@@ -164,16 +164,16 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
auto
max_in_degree
=
torch
::
empty
(
auto
max_in_degree
=
torch
::
empty
(
1
,
1
,
c10
::
TensorOptions
().
dtype
(
in_degree
.
scalar_type
()).
pinned_memory
(
true
));
c10
::
TensorOptions
().
dtype
(
in_degree
.
scalar_type
()).
pinned_memory
(
true
));
AT_DISPATCH_IN
TEGRAL
_TYPES
(
AT_DISPATCH_IN
DEX
_TYPES
(
indptr
.
scalar_type
(),
"SampleNeighborsInDegree"
,
([
&
]
{
indptr
.
scalar_type
(),
"SampleNeighborsInDegree"
,
([
&
]
{
size_t
tmp_storage_size
=
0
;
size_t
tmp_storage_size
=
0
;
cub
::
DeviceReduce
::
Max
(
cub
::
DeviceReduce
::
Max
(
nullptr
,
tmp_storage_size
,
in_degree
.
data_ptr
<
scalar
_t
>
(),
nullptr
,
tmp_storage_size
,
in_degree
.
data_ptr
<
index
_t
>
(),
max_in_degree
.
data_ptr
<
scalar
_t
>
(),
num_rows
,
stream
);
max_in_degree
.
data_ptr
<
index
_t
>
(),
num_rows
,
stream
);
auto
tmp_storage
=
allocator
.
AllocateStorage
<
char
>
(
tmp_storage_size
);
auto
tmp_storage
=
allocator
.
AllocateStorage
<
char
>
(
tmp_storage_size
);
cub
::
DeviceReduce
::
Max
(
cub
::
DeviceReduce
::
Max
(
tmp_storage
.
get
(),
tmp_storage_size
,
in_degree
.
data_ptr
<
scalar
_t
>
(),
tmp_storage
.
get
(),
tmp_storage_size
,
in_degree
.
data_ptr
<
index
_t
>
(),
max_in_degree
.
data_ptr
<
scalar
_t
>
(),
num_rows
,
stream
);
max_in_degree
.
data_ptr
<
index
_t
>
(),
num_rows
,
stream
);
}));
}));
auto
coo_rows
=
CSRToCOO
(
sub_indptr
,
indices
.
scalar_type
());
auto
coo_rows
=
CSRToCOO
(
sub_indptr
,
indices
.
scalar_type
());
const
auto
num_edges
=
coo_rows
.
size
(
0
);
const
auto
num_edges
=
coo_rows
.
size
(
0
);
...
@@ -184,9 +184,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
...
@@ -184,9 +184,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
torch
::
Tensor
output_indices
;
torch
::
Tensor
output_indices
;
torch
::
optional
<
torch
::
Tensor
>
output_type_per_edge
;
torch
::
optional
<
torch
::
Tensor
>
output_type_per_edge
;
AT_DISPATCH_IN
TEGRAL
_TYPES
(
AT_DISPATCH_IN
DEX
_TYPES
(
indptr
.
scalar_type
(),
"SampleNeighborsIndptr"
,
([
&
]
{
indptr
.
scalar_type
(),
"SampleNeighborsIndptr"
,
([
&
]
{
using
indptr_t
=
scalar
_t
;
using
indptr_t
=
index
_t
;
thrust
::
counting_iterator
<
int64_t
>
iota
(
0
);
thrust
::
counting_iterator
<
int64_t
>
iota
(
0
);
auto
sampled_degree
=
thrust
::
make_transform_iterator
(
auto
sampled_degree
=
thrust
::
make_transform_iterator
(
iota
,
MinInDegreeFanout
<
indptr_t
>
{
iota
,
MinInDegreeFanout
<
indptr_t
>
{
...
@@ -234,9 +234,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
...
@@ -234,9 +234,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
allocator
.
AllocateStorage
<
edge_id_t
>
(
num_edges
);
allocator
.
AllocateStorage
<
edge_id_t
>
(
num_edges
);
auto
sorted_edge_id_segments
=
auto
sorted_edge_id_segments
=
allocator
.
AllocateStorage
<
edge_id_t
>
(
num_edges
);
allocator
.
AllocateStorage
<
edge_id_t
>
(
num_edges
);
AT_DISPATCH_IN
TEGRAL
_TYPES
(
AT_DISPATCH_IN
DEX
_TYPES
(
indices
.
scalar_type
(),
"SampleNeighborsIndices"
,
([
&
]
{
indices
.
scalar_type
(),
"SampleNeighborsIndices"
,
([
&
]
{
using
indices_t
=
scalar
_t
;
using
indices_t
=
index
_t
;
auto
probs_or_mask_scalar_type
=
torch
::
kFloat32
;
auto
probs_or_mask_scalar_type
=
torch
::
kFloat32
;
if
(
probs_or_mask
.
has_value
())
{
if
(
probs_or_mask
.
has_value
())
{
probs_or_mask_scalar_type
=
probs_or_mask_scalar_type
=
...
@@ -347,9 +347,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
...
@@ -347,9 +347,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
picked_eids
.
options
().
dtype
(
indices
.
scalar_type
()));
picked_eids
.
options
().
dtype
(
indices
.
scalar_type
()));
// Compute: output_indices = indices.gather(0, picked_eids);
// Compute: output_indices = indices.gather(0, picked_eids);
AT_DISPATCH_IN
TEGRAL
_TYPES
(
AT_DISPATCH_IN
DEX
_TYPES
(
indices
.
scalar_type
(),
"SampleNeighborsOutputIndices"
,
([
&
]
{
indices
.
scalar_type
(),
"SampleNeighborsOutputIndices"
,
([
&
]
{
using
indices_t
=
scalar
_t
;
using
indices_t
=
index
_t
;
const
auto
exec_policy
=
const
auto
exec_policy
=
thrust
::
cuda
::
par_nosync
(
allocator
).
on
(
stream
);
thrust
::
cuda
::
par_nosync
(
allocator
).
on
(
stream
);
thrust
::
gather
(
thrust
::
gather
(
...
...
graphbolt/src/fused_csc_sampling_graph.cc
View file @
b3224ce8
...
@@ -293,14 +293,14 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::InSubgraph(
...
@@ -293,14 +293,14 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::InSubgraph(
std
::
vector
<
torch
::
Tensor
>
edge_ids_arr
(
num_seeds
);
std
::
vector
<
torch
::
Tensor
>
edge_ids_arr
(
num_seeds
);
std
::
vector
<
torch
::
Tensor
>
type_per_edge_arr
(
num_seeds
);
std
::
vector
<
torch
::
Tensor
>
type_per_edge_arr
(
num_seeds
);
AT_DISPATCH_IN
TEGRAL
_TYPES
(
AT_DISPATCH_IN
DEX
_TYPES
(
indptr_
.
scalar_type
(),
"InSubgraph"
,
([
&
]
{
indptr_
.
scalar_type
(),
"InSubgraph"
,
([
&
]
{
torch
::
parallel_for
(
torch
::
parallel_for
(
0
,
num_seeds
,
kDefaultGrainSize
,
[
&
](
size_t
start
,
size_t
end
)
{
0
,
num_seeds
,
kDefaultGrainSize
,
[
&
](
size_t
start
,
size_t
end
)
{
for
(
size_t
i
=
start
;
i
<
end
;
++
i
)
{
for
(
size_t
i
=
start
;
i
<
end
;
++
i
)
{
const
auto
node_id
=
nodes
[
i
].
item
<
scalar
_t
>
();
const
auto
node_id
=
nodes
[
i
].
item
<
index
_t
>
();
const
auto
start_idx
=
indptr_
[
node_id
].
item
<
scalar
_t
>
();
const
auto
start_idx
=
indptr_
[
node_id
].
item
<
index
_t
>
();
const
auto
end_idx
=
indptr_
[
node_id
+
1
].
item
<
scalar
_t
>
();
const
auto
end_idx
=
indptr_
[
node_id
+
1
].
item
<
index
_t
>
();
indptr
[
i
+
1
]
=
end_idx
-
start_idx
;
indptr
[
i
+
1
]
=
end_idx
-
start_idx
;
original_column_node_ids
[
i
]
=
node_id
;
original_column_node_ids
[
i
]
=
node_id
;
indices_arr
[
i
]
=
indices_
.
slice
(
0
,
start_idx
,
end_idx
);
indices_arr
[
i
]
=
indices_
.
slice
(
0
,
start_idx
,
end_idx
);
...
@@ -490,12 +490,12 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
...
@@ -490,12 +490,12 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
torch
::
Tensor
subgraph_indices
;
torch
::
Tensor
subgraph_indices
;
torch
::
optional
<
torch
::
Tensor
>
subgraph_type_per_edge
=
torch
::
nullopt
;
torch
::
optional
<
torch
::
Tensor
>
subgraph_type_per_edge
=
torch
::
nullopt
;
AT_DISPATCH_IN
TEGRAL
_TYPES
(
AT_DISPATCH_IN
DEX
_TYPES
(
indptr_
.
scalar_type
(),
"SampleNeighborsImplWrappedWithIndptr"
,
([
&
]
{
indptr_
.
scalar_type
(),
"SampleNeighborsImplWrappedWithIndptr"
,
([
&
]
{
using
indptr_t
=
scalar
_t
;
using
indptr_t
=
index
_t
;
AT_DISPATCH_IN
TEGRAL
_TYPES
(
AT_DISPATCH_IN
DEX
_TYPES
(
nodes
.
scalar_type
(),
"SampleNeighborsImplWrappedWithNodes"
,
([
&
]
{
nodes
.
scalar_type
(),
"SampleNeighborsImplWrappedWithNodes"
,
([
&
]
{
using
nodes_t
=
scalar
_t
;
using
nodes_t
=
index
_t
;
const
auto
indptr_data
=
indptr_
.
data_ptr
<
indptr_t
>
();
const
auto
indptr_data
=
indptr_
.
data_ptr
<
indptr_t
>
();
auto
num_picked_neighbors_data_ptr
=
auto
num_picked_neighbors_data_ptr
=
num_picked_neighbors_per_node
.
data_ptr
<
indptr_t
>
();
num_picked_neighbors_per_node
.
data_ptr
<
indptr_t
>
();
...
@@ -563,13 +563,13 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
...
@@ -563,13 +563,13 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
// Step 5. Calculate other attributes and return the
// Step 5. Calculate other attributes and return the
// subgraph.
// subgraph.
AT_DISPATCH_IN
TEGRAL
_TYPES
(
AT_DISPATCH_IN
DEX
_TYPES
(
subgraph_indices
.
scalar_type
(),
subgraph_indices
.
scalar_type
(),
"IndexSelectSubgraphIndices"
,
([
&
]
{
"IndexSelectSubgraphIndices"
,
([
&
]
{
auto
subgraph_indices_data_ptr
=
auto
subgraph_indices_data_ptr
=
subgraph_indices
.
data_ptr
<
scalar
_t
>
();
subgraph_indices
.
data_ptr
<
index
_t
>
();
auto
indices_data_ptr
=
auto
indices_data_ptr
=
indices_
.
data_ptr
<
scalar
_t
>
();
indices_
.
data_ptr
<
index
_t
>
();
for
(
auto
i
=
picked_offset
;
for
(
auto
i
=
picked_offset
;
i
<
picked_offset
+
picked_number
;
++
i
)
{
i
<
picked_offset
+
picked_number
;
++
i
)
{
subgraph_indices_data_ptr
[
i
]
=
subgraph_indices_data_ptr
[
i
]
=
...
@@ -1394,10 +1394,10 @@ inline int64_t LaborPick(
...
@@ -1394,10 +1394,10 @@ inline int64_t LaborPick(
if
(
NonUniform
&&
probs_or_mask
.
value
().
size
(
0
)
<=
num_neighbors
)
{
if
(
NonUniform
&&
probs_or_mask
.
value
().
size
(
0
)
<=
num_neighbors
)
{
local_probs_data
-=
offset
;
local_probs_data
-=
offset
;
}
}
AT_DISPATCH_IN
TEGRAL
_TYPES
(
AT_DISPATCH_IN
DEX
_TYPES
(
args
.
indices
.
scalar_type
(),
"LaborPickMain"
,
([
&
]
{
args
.
indices
.
scalar_type
(),
"LaborPickMain"
,
([
&
]
{
const
scalar
_t
*
local_indices_data
=
const
index
_t
*
local_indices_data
=
args
.
indices
.
data_ptr
<
scalar
_t
>
()
+
offset
;
args
.
indices
.
data_ptr
<
index
_t
>
()
+
offset
;
if
constexpr
(
Replace
)
{
if
constexpr
(
Replace
)
{
// [Algorithm] @mfbalin
// [Algorithm] @mfbalin
// Use a max-heap to get rid of the big random numbers and filter the
// Use a max-heap to get rid of the big random numbers and filter the
...
@@ -1431,7 +1431,7 @@ inline int64_t LaborPick(
...
@@ -1431,7 +1431,7 @@ inline int64_t LaborPick(
auto
heap_end
=
heap_data
;
auto
heap_end
=
heap_data
;
const
auto
init_count
=
(
num_neighbors
+
fanout
-
1
)
/
num_neighbors
;
const
auto
init_count
=
(
num_neighbors
+
fanout
-
1
)
/
num_neighbors
;
auto
sample_neighbor_i_with_index_t_jth_time
=
auto
sample_neighbor_i_with_index_t_jth_time
=
[
&
](
scalar
_t
t
,
int64_t
j
,
uint32_t
i
)
{
[
&
](
index
_t
t
,
int64_t
j
,
uint32_t
i
)
{
auto
rnd
=
labor
::
jth_sorted_uniform_random
(
auto
rnd
=
labor
::
jth_sorted_uniform_random
(
args
.
random_seed
,
t
,
args
.
num_nodes
,
j
,
remaining_data
[
i
],
args
.
random_seed
,
t
,
args
.
num_nodes
,
j
,
remaining_data
[
i
],
fanout
-
j
);
// r_t
fanout
-
j
);
// r_t
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment