Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
0f3bfd7e
Unverified
Commit
0f3bfd7e
authored
Jan 12, 2024
by
Muhammed Fatih BALIN
Committed by
GitHub
Jan 12, 2024
Browse files
[GraphBolt][CUDA] Refactor `IndexSelectCSC` and add `output_size` argument (#6927)
parent
3795a006
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
110 additions
and
58 deletions
+110
-58
graphbolt/include/graphbolt/cuda_ops.h
graphbolt/include/graphbolt/cuda_ops.h
+24
-1
graphbolt/src/cuda/index_select_csc_impl.cu
graphbolt/src/cuda/index_select_csc_impl.cu
+51
-37
graphbolt/src/cuda/insubgraph.cu
graphbolt/src/cuda/insubgraph.cu
+7
-5
graphbolt/src/cuda/neighbor_sampler.cu
graphbolt/src/cuda/neighbor_sampler.cu
+13
-8
graphbolt/src/index_select.cc
graphbolt/src/index_select.cc
+3
-2
graphbolt/src/index_select.h
graphbolt/src/index_select.h
+3
-1
tests/python/pytorch/graphbolt/impl/test_in_subgraph_sampler.py
...python/pytorch/graphbolt/impl/test_in_subgraph_sampler.py
+9
-4
No files found.
graphbolt/include/graphbolt/cuda_ops.h
View file @
0f3bfd7e
...
@@ -68,6 +68,27 @@ Sort(torch::Tensor input, int num_bits = 0);
...
@@ -68,6 +68,27 @@ Sort(torch::Tensor input, int num_bits = 0);
*/
*/
torch
::
Tensor
IsIn
(
torch
::
Tensor
elements
,
torch
::
Tensor
test_elements
);
torch
::
Tensor
IsIn
(
torch
::
Tensor
elements
,
torch
::
Tensor
test_elements
);
/**
* @brief Select columns for a sparse matrix in a CSC format according to nodes
* tensor.
*
* NOTE: The shape of all tensors must be 1-D.
*
* @param in_degree Indegree tensor containing degrees of nodes being copied.
* @param sliced_indptr Sliced_indptr tensor containing indptr values of nodes
* being copied.
* @param indices Indices tensor with edge information of shape (indptr[N],).
* @param nodes Nodes tensor with shape (M,).
* @param nodes_max An upperbound on `nodes.max()`.
* @param output_size The total number of edges being copied.
* @return (torch::Tensor, torch::Tensor) Output indptr and indices tensors of
* shapes (M + 1,) and ((indptr[nodes + 1] - indptr[nodes]).sum(),).
*/
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
IndexSelectCSCImpl
(
torch
::
Tensor
in_degree
,
torch
::
Tensor
sliced_indptr
,
torch
::
Tensor
indices
,
torch
::
Tensor
nodes
,
int64_t
nodes_max
,
torch
::
optional
<
int64_t
>
output_size
=
torch
::
nullopt
);
/**
/**
* @brief Select columns for a sparse matrix in a CSC format according to nodes
* @brief Select columns for a sparse matrix in a CSC format according to nodes
* tensor.
* tensor.
...
@@ -77,11 +98,13 @@ torch::Tensor IsIn(torch::Tensor elements, torch::Tensor test_elements);
...
@@ -77,11 +98,13 @@ torch::Tensor IsIn(torch::Tensor elements, torch::Tensor test_elements);
* @param indptr Indptr tensor containing offsets with shape (N,).
* @param indptr Indptr tensor containing offsets with shape (N,).
* @param indices Indices tensor with edge information of shape (indptr[N],).
* @param indices Indices tensor with edge information of shape (indptr[N],).
* @param nodes Nodes tensor with shape (M,).
* @param nodes Nodes tensor with shape (M,).
* @param output_size The total number of edges being copied.
* @return (torch::Tensor, torch::Tensor) Output indptr and indices tensors of
* @return (torch::Tensor, torch::Tensor) Output indptr and indices tensors of
* shapes (M + 1,) and ((indptr[nodes + 1] - indptr[nodes]).sum(),).
* shapes (M + 1,) and ((indptr[nodes + 1] - indptr[nodes]).sum(),).
*/
*/
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
IndexSelectCSCImpl
(
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
IndexSelectCSCImpl
(
torch
::
Tensor
indptr
,
torch
::
Tensor
indices
,
torch
::
Tensor
nodes
);
torch
::
Tensor
indptr
,
torch
::
Tensor
indices
,
torch
::
Tensor
nodes
,
torch
::
optional
<
int64_t
>
output_size
=
torch
::
nullopt
);
/**
/**
* @brief Slices the indptr tensor with nodes and returns the indegrees of the
* @brief Slices the indptr tensor with nodes and returns the indegrees of the
...
...
graphbolt/src/cuda/index_select_csc_impl.cu
View file @
0f3bfd7e
...
@@ -86,14 +86,15 @@ template <typename indptr_t, typename indices_t>
...
@@ -86,14 +86,15 @@ template <typename indptr_t, typename indices_t>
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
UVAIndexSelectCSCCopyIndices
(
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
UVAIndexSelectCSCCopyIndices
(
torch
::
Tensor
indices
,
const
int64_t
num_nodes
,
torch
::
Tensor
indices
,
const
int64_t
num_nodes
,
const
indptr_t
*
const
in_degree
,
const
indptr_t
*
const
sliced_indptr
,
const
indptr_t
*
const
in_degree
,
const
indptr_t
*
const
sliced_indptr
,
const
int64_t
*
const
perm
,
torch
::
TensorOptions
nodes_options
,
const
int64_t
*
const
perm
,
torch
::
TensorOptions
options
,
torch
::
ScalarType
indptr_scalar_type
)
{
torch
::
ScalarType
indptr_scalar_type
,
torch
::
optional
<
int64_t
>
output_size
)
{
auto
allocator
=
cuda
::
GetAllocator
();
auto
allocator
=
cuda
::
GetAllocator
();
thrust
::
counting_iterator
<
int64_t
>
iota
(
0
);
thrust
::
counting_iterator
<
int64_t
>
iota
(
0
);
// Output indptr for the slice indexed by nodes.
// Output indptr for the slice indexed by nodes.
auto
output_indptr
=
auto
output_indptr
=
torch
::
empty
(
num_nodes
+
1
,
nodes_
options
.
dtype
(
indptr_scalar_type
));
torch
::
empty
(
num_nodes
+
1
,
options
.
dtype
(
indptr_scalar_type
));
auto
output_indptr_aligned
=
auto
output_indptr_aligned
=
allocator
.
AllocateStorage
<
indptr_t
>
(
num_nodes
+
1
);
allocator
.
AllocateStorage
<
indptr_t
>
(
num_nodes
+
1
);
...
@@ -114,16 +115,18 @@ std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCCopyIndices(
...
@@ -114,16 +115,18 @@ std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCCopyIndices(
}
}
// Copy the actual total number of edges.
// Copy the actual total number of edges.
auto
edge_count
=
if
(
!
output_size
.
has_value
())
{
cuda
::
CopyScalar
{
output_indptr
.
data_ptr
<
indptr_t
>
()
+
num_nodes
};
auto
edge_count
=
cuda
::
CopyScalar
{
output_indptr
.
data_ptr
<
indptr_t
>
()
+
num_nodes
};
output_size
=
static_cast
<
indptr_t
>
(
edge_count
);
}
// Copy the modified number of edges.
// Copy the modified number of edges.
auto
edge_count_aligned
=
auto
edge_count_aligned
=
cuda
::
CopyScalar
{
output_indptr_aligned
.
get
()
+
num_nodes
};
cuda
::
CopyScalar
{
output_indptr_aligned
.
get
()
+
num_nodes
};
// Allocate output array with actual number of edges.
// Allocate output array with actual number of edges.
torch
::
Tensor
output_indices
=
torch
::
empty
(
torch
::
Tensor
output_indices
=
static_cast
<
indptr_t
>
(
edge_count
),
torch
::
empty
(
output_size
.
value
(),
options
.
dtype
(
indices
.
scalar_type
()));
nodes_options
.
dtype
(
indices
.
scalar_type
()));
const
dim3
block
(
BLOCK_SIZE
);
const
dim3
block
(
BLOCK_SIZE
);
const
dim3
grid
(
const
dim3
grid
(
(
static_cast
<
indptr_t
>
(
edge_count_aligned
)
+
BLOCK_SIZE
-
1
)
/
(
static_cast
<
indptr_t
>
(
edge_count_aligned
)
+
BLOCK_SIZE
-
1
)
/
...
@@ -141,26 +144,22 @@ std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCCopyIndices(
...
@@ -141,26 +144,22 @@ std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCCopyIndices(
}
}
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
UVAIndexSelectCSCImpl
(
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
UVAIndexSelectCSCImpl
(
torch
::
Tensor
indptr
,
torch
::
Tensor
indices
,
torch
::
Tensor
nodes
)
{
torch
::
Tensor
in_degree
,
torch
::
Tensor
sliced_indptr
,
torch
::
Tensor
indices
,
torch
::
Tensor
nodes
,
int
num_bits
,
torch
::
optional
<
int64_t
>
output_size
)
{
// Sorting nodes so that accesses over PCI-e are more regular.
// Sorting nodes so that accesses over PCI-e are more regular.
const
auto
sorted_idx
=
const
auto
sorted_idx
=
Sort
(
nodes
,
num_bits
).
second
;
Sort
(
nodes
,
cuda
::
NumberOfBits
(
indptr
.
size
(
0
)
-
1
)).
second
;
const
int64_t
num_nodes
=
nodes
.
size
(
0
);
const
int64_t
num_nodes
=
nodes
.
size
(
0
);
auto
in_degree_and_sliced_indptr
=
SliceCSCIndptr
(
indptr
,
nodes
);
return
AT_DISPATCH_INTEGRAL_TYPES
(
return
AT_DISPATCH_INTEGRAL_TYPES
(
indptr
.
scalar_type
(),
"UVAIndexSelectCSCIndptr"
,
([
&
]
{
sliced_
indptr
.
scalar_type
(),
"UVAIndexSelectCSCIndptr"
,
([
&
]
{
using
indptr_t
=
scalar_t
;
using
indptr_t
=
scalar_t
;
auto
in_degree
=
std
::
get
<
0
>
(
in_degree_and_sliced_indptr
).
data_ptr
<
indptr_t
>
();
auto
sliced_indptr
=
std
::
get
<
1
>
(
in_degree_and_sliced_indptr
).
data_ptr
<
indptr_t
>
();
return
GRAPHBOLT_DISPATCH_ELEMENT_SIZES
(
return
GRAPHBOLT_DISPATCH_ELEMENT_SIZES
(
indices
.
element_size
(),
"UVAIndexSelectCSCCopyIndices"
,
([
&
]
{
indices
.
element_size
(),
"UVAIndexSelectCSCCopyIndices"
,
([
&
]
{
return
UVAIndexSelectCSCCopyIndices
<
indptr_t
,
element_size_t
>
(
return
UVAIndexSelectCSCCopyIndices
<
indptr_t
,
element_size_t
>
(
indices
,
num_nodes
,
in_degree
,
sliced_indptr
,
indices
,
num_nodes
,
in_degree
.
data_ptr
<
indptr_t
>
(),
sliced_indptr
.
data_ptr
<
indptr_t
>
(),
sorted_idx
.
data_ptr
<
int64_t
>
(),
nodes
.
options
(),
sorted_idx
.
data_ptr
<
int64_t
>
(),
nodes
.
options
(),
indptr
.
scalar_type
());
sliced_
indptr
.
scalar_type
()
,
output_size
);
}));
}));
}));
}));
}
}
...
@@ -204,38 +203,39 @@ void IndexSelectCSCCopyIndices(
...
@@ -204,38 +203,39 @@ void IndexSelectCSCCopyIndices(
}
}
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
DeviceIndexSelectCSCImpl
(
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
DeviceIndexSelectCSCImpl
(
torch
::
Tensor
in
dptr
,
torch
::
Tensor
indices
,
torch
::
Tensor
nodes
)
{
torch
::
Tensor
in
_degree
,
torch
::
Tensor
sliced_indptr
,
torch
::
Tensor
indices
,
const
int64_t
num_nodes
=
nodes
.
size
(
0
);
torch
::
TensorOptions
options
,
torch
::
optional
<
int64_t
>
output_size
)
{
auto
in_degree_and_sliced_indptr
=
SliceCSCIndptr
(
indptr
,
nodes
);
const
int64_t
num_nodes
=
sliced_indptr
.
size
(
0
);
return
AT_DISPATCH_INTEGRAL_TYPES
(
return
AT_DISPATCH_INTEGRAL_TYPES
(
indptr
.
scalar_type
(),
"IndexSelectCSCIndptr"
,
([
&
]
{
sliced_
indptr
.
scalar_type
(),
"IndexSelectCSCIndptr"
,
([
&
]
{
using
indptr_t
=
scalar_t
;
using
indptr_t
=
scalar_t
;
auto
in_degree
=
auto
in_degree_ptr
=
in_degree
.
data_ptr
<
indptr_t
>
();
std
::
get
<
0
>
(
in_degree_and_sliced_indptr
).
data_ptr
<
indptr_t
>
();
auto
sliced_indptr_ptr
=
sliced_indptr
.
data_ptr
<
indptr_t
>
();
auto
sliced_indptr
=
std
::
get
<
1
>
(
in_degree_and_sliced_indptr
).
data_ptr
<
indptr_t
>
();
// Output indptr for the slice indexed by nodes.
// Output indptr for the slice indexed by nodes.
torch
::
Tensor
output_indptr
=
torch
::
empty
(
torch
::
Tensor
output_indptr
=
torch
::
empty
(
num_nodes
+
1
,
nodes
.
options
()
.
dtype
(
indptr
.
scalar_type
()));
num_nodes
+
1
,
options
.
dtype
(
sliced_
indptr
.
scalar_type
()));
// Compute the output indptr, output_indptr.
// Compute the output indptr, output_indptr.
CUB_CALL
(
CUB_CALL
(
DeviceScan
::
ExclusiveSum
,
in_degree
,
DeviceScan
::
ExclusiveSum
,
in_degree
_ptr
,
output_indptr
.
data_ptr
<
indptr_t
>
(),
num_nodes
+
1
);
output_indptr
.
data_ptr
<
indptr_t
>
(),
num_nodes
+
1
);
// Number of edges being copied.
// Number of edges being copied.
auto
edge_count
=
if
(
!
output_size
.
has_value
())
{
cuda
::
CopyScalar
{
output_indptr
.
data_ptr
<
indptr_t
>
()
+
num_nodes
};
auto
edge_count
=
cuda
::
CopyScalar
{
output_indptr
.
data_ptr
<
indptr_t
>
()
+
num_nodes
};
output_size
=
static_cast
<
indptr_t
>
(
edge_count
);
}
// Allocate output array of size number of copied edges.
// Allocate output array of size number of copied edges.
torch
::
Tensor
output_indices
=
torch
::
empty
(
torch
::
Tensor
output_indices
=
torch
::
empty
(
static_cast
<
indptr_t
>
(
edge_count
),
output_size
.
value
(),
options
.
dtype
(
indices
.
scalar_type
()));
nodes
.
options
().
dtype
(
indices
.
scalar_type
()));
GRAPHBOLT_DISPATCH_ELEMENT_SIZES
(
GRAPHBOLT_DISPATCH_ELEMENT_SIZES
(
indices
.
element_size
(),
"IndexSelectCSCCopyIndices"
,
([
&
]
{
indices
.
element_size
(),
"IndexSelectCSCCopyIndices"
,
([
&
]
{
using
indices_t
=
element_size_t
;
using
indices_t
=
element_size_t
;
IndexSelectCSCCopyIndices
<
indptr_t
,
indices_t
>
(
IndexSelectCSCCopyIndices
<
indptr_t
,
indices_t
>
(
num_nodes
,
reinterpret_cast
<
indices_t
*>
(
indices
.
data_ptr
()),
num_nodes
,
reinterpret_cast
<
indices_t
*>
(
indices
.
data_ptr
()),
sliced_indptr
,
in_degree
,
output_indptr
.
data_ptr
<
indptr_t
>
(),
sliced_indptr_ptr
,
in_degree_ptr
,
output_indptr
.
data_ptr
<
indptr_t
>
(),
reinterpret_cast
<
indices_t
*>
(
output_indices
.
data_ptr
()));
reinterpret_cast
<
indices_t
*>
(
output_indices
.
data_ptr
()));
}));
}));
return
std
::
make_tuple
(
output_indptr
,
output_indices
);
return
std
::
make_tuple
(
output_indptr
,
output_indices
);
...
@@ -243,13 +243,27 @@ std::tuple<torch::Tensor, torch::Tensor> DeviceIndexSelectCSCImpl(
...
@@ -243,13 +243,27 @@ std::tuple<torch::Tensor, torch::Tensor> DeviceIndexSelectCSCImpl(
}
}
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
IndexSelectCSCImpl
(
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
IndexSelectCSCImpl
(
torch
::
Tensor
indptr
,
torch
::
Tensor
indices
,
torch
::
Tensor
nodes
)
{
torch
::
Tensor
in_degree
,
torch
::
Tensor
sliced_indptr
,
torch
::
Tensor
indices
,
torch
::
Tensor
nodes
,
int64_t
nodes_max
,
torch
::
optional
<
int64_t
>
output_size
)
{
if
(
indices
.
is_pinned
())
{
if
(
indices
.
is_pinned
())
{
return
UVAIndexSelectCSCImpl
(
indptr
,
indices
,
nodes
);
int
num_bits
=
cuda
::
NumberOfBits
(
nodes_max
+
1
);
return
UVAIndexSelectCSCImpl
(
in_degree
,
sliced_indptr
,
indices
,
nodes
,
num_bits
,
output_size
);
}
else
{
}
else
{
return
DeviceIndexSelectCSCImpl
(
indptr
,
indices
,
nodes
);
return
DeviceIndexSelectCSCImpl
(
in_degree
,
sliced_indptr
,
indices
,
nodes
.
options
(),
output_size
);
}
}
}
}
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
IndexSelectCSCImpl
(
torch
::
Tensor
indptr
,
torch
::
Tensor
indices
,
torch
::
Tensor
nodes
,
torch
::
optional
<
int64_t
>
output_size
)
{
auto
[
in_degree
,
sliced_indptr
]
=
SliceCSCIndptr
(
indptr
,
nodes
);
return
IndexSelectCSCImpl
(
in_degree
,
sliced_indptr
,
indices
,
nodes
,
indptr
.
size
(
0
)
-
2
,
output_size
);
}
}
// namespace ops
}
// namespace ops
}
// namespace graphbolt
}
// namespace graphbolt
graphbolt/src/cuda/insubgraph.cu
View file @
0f3bfd7e
...
@@ -16,15 +16,17 @@ namespace ops {
...
@@ -16,15 +16,17 @@ namespace ops {
c10
::
intrusive_ptr
<
sampling
::
FusedSampledSubgraph
>
InSubgraph
(
c10
::
intrusive_ptr
<
sampling
::
FusedSampledSubgraph
>
InSubgraph
(
torch
::
Tensor
indptr
,
torch
::
Tensor
indices
,
torch
::
Tensor
nodes
,
torch
::
Tensor
indptr
,
torch
::
Tensor
indices
,
torch
::
Tensor
nodes
,
torch
::
optional
<
torch
::
Tensor
>
type_per_edge
)
{
torch
::
optional
<
torch
::
Tensor
>
type_per_edge
)
{
auto
[
output_indptr
,
output_indices
]
=
auto
[
in_degree
,
sliced_indptr
]
=
SliceCSCIndptr
(
indptr
,
nodes
);
IndexSelectCSCImpl
(
indptr
,
indices
,
nodes
);
auto
[
output_indptr
,
output_indices
]
=
IndexSelectCSCImpl
(
in_degree
,
sliced_indptr
,
indices
,
nodes
,
indptr
.
size
(
0
)
-
2
);
const
int64_t
num_edges
=
output_indices
.
size
(
0
);
torch
::
optional
<
torch
::
Tensor
>
output_type_per_edge
;
torch
::
optional
<
torch
::
Tensor
>
output_type_per_edge
;
if
(
type_per_edge
)
{
if
(
type_per_edge
)
{
output_type_per_edge
=
output_type_per_edge
=
std
::
get
<
1
>
(
IndexSelectCSCImpl
(
std
::
get
<
1
>
(
IndexSelectCSCImpl
(
indptr
,
type_per_edge
.
value
(),
nodes
));
in_degree
,
sliced_indptr
,
type_per_edge
.
value
(),
nodes
,
indptr
.
size
(
0
)
-
2
,
num_edges
));
}
}
auto
rows
=
CSRToCOO
(
output_indptr
,
indices
.
scalar_type
());
auto
rows
=
CSRToCOO
(
output_indptr
,
indices
.
scalar_type
());
auto
[
in_degree
,
sliced_indptr
]
=
SliceCSCIndptr
(
indptr
,
nodes
);
auto
i
=
torch
::
arange
(
output_indices
.
size
(
0
),
output_indptr
.
options
());
auto
i
=
torch
::
arange
(
output_indices
.
size
(
0
),
output_indptr
.
options
());
auto
edge_ids
=
auto
edge_ids
=
i
-
output_indptr
.
gather
(
0
,
rows
)
+
sliced_indptr
.
gather
(
0
,
rows
);
i
-
output_indptr
.
gather
(
0
,
rows
)
+
sliced_indptr
.
gather
(
0
,
rows
);
...
...
graphbolt/src/cuda/neighbor_sampler.cu
View file @
0f3bfd7e
...
@@ -157,25 +157,30 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
...
@@ -157,25 +157,30 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
auto
in_degree_and_sliced_indptr
=
SliceCSCIndptr
(
indptr
,
nodes
);
auto
in_degree_and_sliced_indptr
=
SliceCSCIndptr
(
indptr
,
nodes
);
auto
in_degree
=
std
::
get
<
0
>
(
in_degree_and_sliced_indptr
);
auto
in_degree
=
std
::
get
<
0
>
(
in_degree_and_sliced_indptr
);
auto
sliced_indptr
=
std
::
get
<
1
>
(
in_degree_and_sliced_indptr
);
auto
sliced_indptr
=
std
::
get
<
1
>
(
in_degree_and_sliced_indptr
);
torch
::
optional
<
int64_t
>
num_edges_
;
torch
::
Tensor
sub_indptr
;
torch
::
Tensor
sub_indptr
;
// @todo mfbalin, refactor IndexSelectCSCImpl so that it does not have to take
// nodes as input
torch
::
optional
<
torch
::
Tensor
>
sliced_probs_or_mask
;
torch
::
optional
<
torch
::
Tensor
>
sliced_probs_or_mask
;
if
(
probs_or_mask
.
has_value
())
{
if
(
probs_or_mask
.
has_value
())
{
torch
::
Tensor
sliced_probs_or_mask_tensor
;
torch
::
Tensor
sliced_probs_or_mask_tensor
;
std
::
tie
(
sub_indptr
,
sliced_probs_or_mask_tensor
)
=
std
::
tie
(
sub_indptr
,
sliced_probs_or_mask_tensor
)
=
IndexSelectCSCImpl
(
IndexSelectCSCImpl
(
indptr
,
probs_or_mask
.
value
(),
nodes
);
in_degree
,
sliced_indptr
,
probs_or_mask
.
value
(),
nodes
,
indptr
.
size
(
0
)
-
2
,
num_edges_
);
sliced_probs_or_mask
=
sliced_probs_or_mask_tensor
;
sliced_probs_or_mask
=
sliced_probs_or_mask_tensor
;
}
else
{
num_edges_
=
sliced_probs_or_mask_tensor
.
size
(
0
);
sub_indptr
=
ExclusiveCumSum
(
in_degree
);
}
}
if
(
fanouts
.
size
()
>
1
)
{
if
(
fanouts
.
size
()
>
1
)
{
torch
::
Tensor
sliced_type_per_edge
;
torch
::
Tensor
sliced_type_per_edge
;
std
::
tie
(
sub_indptr
,
sliced_type_per_edge
)
=
std
::
tie
(
sub_indptr
,
sliced_type_per_edge
)
=
IndexSelectCSCImpl
(
IndexSelectCSCImpl
(
indptr
,
type_per_edge
.
value
(),
nodes
);
in_degree
,
sliced_indptr
,
type_per_edge
.
value
(),
nodes
,
indptr
.
size
(
0
)
-
2
,
num_edges_
);
std
::
tie
(
sub_indptr
,
in_degree
,
sliced_indptr
)
=
SliceCSCIndptrHetero
(
std
::
tie
(
sub_indptr
,
in_degree
,
sliced_indptr
)
=
SliceCSCIndptrHetero
(
sub_indptr
,
sliced_type_per_edge
,
sliced_indptr
,
fanouts
.
size
());
sub_indptr
,
sliced_type_per_edge
,
sliced_indptr
,
fanouts
.
size
());
num_rows
=
sliced_indptr
.
size
(
0
);
num_rows
=
sliced_indptr
.
size
(
0
);
num_edges_
=
sliced_type_per_edge
.
size
(
0
);
}
// If sub_indptr was not computed in the two code blocks above:
if
(
!
probs_or_mask
.
has_value
()
&&
fanouts
.
size
()
<=
1
)
{
sub_indptr
=
ExclusiveCumSum
(
in_degree
);
}
}
auto
max_in_degree
=
torch
::
empty
(
auto
max_in_degree
=
torch
::
empty
(
1
,
1
,
...
...
graphbolt/src/index_select.cc
View file @
0f3bfd7e
...
@@ -22,14 +22,15 @@ torch::Tensor IndexSelect(torch::Tensor input, torch::Tensor index) {
...
@@ -22,14 +22,15 @@ torch::Tensor IndexSelect(torch::Tensor input, torch::Tensor index) {
}
}
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
IndexSelectCSC
(
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
IndexSelectCSC
(
torch
::
Tensor
indptr
,
torch
::
Tensor
indices
,
torch
::
Tensor
nodes
)
{
torch
::
Tensor
indptr
,
torch
::
Tensor
indices
,
torch
::
Tensor
nodes
,
torch
::
optional
<
int64_t
>
output_size
)
{
TORCH_CHECK
(
TORCH_CHECK
(
indices
.
sizes
().
size
()
==
1
,
"IndexSelectCSC only supports 1d tensors"
);
indices
.
sizes
().
size
()
==
1
,
"IndexSelectCSC only supports 1d tensors"
);
if
(
utils
::
is_on_gpu
(
nodes
)
&&
utils
::
is_accessible_from_gpu
(
indptr
)
&&
if
(
utils
::
is_on_gpu
(
nodes
)
&&
utils
::
is_accessible_from_gpu
(
indptr
)
&&
utils
::
is_accessible_from_gpu
(
indices
))
{
utils
::
is_accessible_from_gpu
(
indices
))
{
GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE
(
GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE
(
c10
::
DeviceType
::
CUDA
,
"IndexSelectCSCImpl"
,
c10
::
DeviceType
::
CUDA
,
"IndexSelectCSCImpl"
,
{
return
IndexSelectCSCImpl
(
indptr
,
indices
,
nodes
);
});
{
return
IndexSelectCSCImpl
(
indptr
,
indices
,
nodes
,
output_size
);
});
}
}
// @todo: The CPU supports only integer dtypes for indices tensor.
// @todo: The CPU supports only integer dtypes for indices tensor.
TORCH_CHECK
(
TORCH_CHECK
(
...
...
graphbolt/src/index_select.h
View file @
0f3bfd7e
...
@@ -25,11 +25,13 @@ namespace ops {
...
@@ -25,11 +25,13 @@ namespace ops {
* @param indptr Indptr tensor containing offsets with shape (N,).
* @param indptr Indptr tensor containing offsets with shape (N,).
* @param indices Indices tensor with edge information of shape (indptr[N],).
* @param indices Indices tensor with edge information of shape (indptr[N],).
* @param nodes Nodes tensor with shape (M,).
* @param nodes Nodes tensor with shape (M,).
* @param output_size The total number of edges being copied.
* @return (torch::Tensor, torch::Tensor) Output indptr and indices tensors of
* @return (torch::Tensor, torch::Tensor) Output indptr and indices tensors of
* shapes (M + 1,) and ((indptr[nodes + 1] - indptr[nodes]).sum(),).
* shapes (M + 1,) and ((indptr[nodes + 1] - indptr[nodes]).sum(),).
*/
*/
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
IndexSelectCSC
(
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
IndexSelectCSC
(
torch
::
Tensor
indptr
,
torch
::
Tensor
indices
,
torch
::
Tensor
nodes
);
torch
::
Tensor
indptr
,
torch
::
Tensor
indices
,
torch
::
Tensor
nodes
,
torch
::
optional
<
int64_t
>
output_size
=
torch
::
nullopt
);
/**
/**
* @brief Select rows from input tensor according to index tensor.
* @brief Select rows from input tensor according to index tensor.
...
...
tests/python/pytorch/graphbolt/impl/test_in_subgraph_sampler.py
View file @
0f3bfd7e
...
@@ -22,7 +22,10 @@ from .. import gb_test_utils
...
@@ -22,7 +22,10 @@ from .. import gb_test_utils
)
)
@
pytest
.
mark
.
parametrize
(
"idtype"
,
[
torch
.
int32
,
torch
.
int64
])
@
pytest
.
mark
.
parametrize
(
"idtype"
,
[
torch
.
int32
,
torch
.
int64
])
@
pytest
.
mark
.
parametrize
(
"is_pinned"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"is_pinned"
,
[
False
,
True
])
def
test_index_select_csc
(
indptr_dtype
,
indices_dtype
,
idtype
,
is_pinned
):
@
pytest
.
mark
.
parametrize
(
"output_size"
,
[
None
,
True
])
def
test_index_select_csc
(
indptr_dtype
,
indices_dtype
,
idtype
,
is_pinned
,
output_size
):
"""Original graph in COO:
"""Original graph in COO:
1 0 1 0 1 0
1 0 1 0 1 0
1 0 0 1 0 1
1 0 0 1 0 1
...
@@ -38,7 +41,7 @@ def test_index_select_csc(indptr_dtype, indices_dtype, idtype, is_pinned):
...
@@ -38,7 +41,7 @@ def test_index_select_csc(indptr_dtype, indices_dtype, idtype, is_pinned):
index
=
torch
.
tensor
([
0
,
5
,
3
],
dtype
=
idtype
)
index
=
torch
.
tensor
([
0
,
5
,
3
],
dtype
=
idtype
)
cpu_indptr
,
cpu_indices
=
torch
.
ops
.
graphbolt
.
index_select_csc
(
cpu_indptr
,
cpu_indices
=
torch
.
ops
.
graphbolt
.
index_select_csc
(
indptr
,
indices
,
index
indptr
,
indices
,
index
,
None
)
)
if
is_pinned
:
if
is_pinned
:
indptr
=
indptr
.
pin_memory
()
indptr
=
indptr
.
pin_memory
()
...
@@ -48,10 +51,12 @@ def test_index_select_csc(indptr_dtype, indices_dtype, idtype, is_pinned):
...
@@ -48,10 +51,12 @@ def test_index_select_csc(indptr_dtype, indices_dtype, idtype, is_pinned):
indices
=
indices
.
cuda
()
indices
=
indices
.
cuda
()
index
=
index
.
cuda
()
index
=
index
.
cuda
()
if
output_size
:
output_size
=
len
(
cpu_indices
)
gpu_indptr
,
gpu_indices
=
torch
.
ops
.
graphbolt
.
index_select_csc
(
gpu_indptr
,
gpu_indices
=
torch
.
ops
.
graphbolt
.
index_select_csc
(
indptr
,
indices
,
index
indptr
,
indices
,
index
,
output_size
)
)
assert
not
cpu_indptr
.
is_cuda
assert
not
cpu_indptr
.
is_cuda
assert
not
cpu_indices
.
is_cuda
assert
not
cpu_indices
.
is_cuda
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment