Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
557a8f81
Unverified
Commit
557a8f81
authored
Jan 05, 2024
by
Muhammed Fatih BALIN
Committed by
GitHub
Jan 05, 2024
Browse files
[GraphBolt][CUDA] Multiple fanout hetero sampling support. (#6895)
parent
9286621c
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
280 additions
and
123 deletions
+280
-123
graphbolt/include/graphbolt/cuda_ops.h
graphbolt/include/graphbolt/cuda_ops.h
+16
-0
graphbolt/src/cuda/index_select_csc_impl.cu
graphbolt/src/cuda/index_select_csc_impl.cu
+0
-49
graphbolt/src/cuda/neighbor_sampler.cu
graphbolt/src/cuda/neighbor_sampler.cu
+33
-11
graphbolt/src/cuda/sampling_utils.cu
graphbolt/src/cuda/sampling_utils.cu
+135
-0
graphbolt/src/cuda/utils.h
graphbolt/src/cuda/utils.h
+25
-0
graphbolt/src/fused_csc_sampling_graph.cc
graphbolt/src/fused_csc_sampling_graph.cc
+3
-1
tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
...n/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
+67
-58
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
+1
-4
No files found.
graphbolt/include/graphbolt/cuda_ops.h
View file @
557a8f81
...
@@ -98,6 +98,22 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(
...
@@ -98,6 +98,22 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
SliceCSCIndptr
(
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
SliceCSCIndptr
(
torch
::
Tensor
indptr
,
torch
::
Tensor
nodes
);
torch
::
Tensor
indptr
,
torch
::
Tensor
nodes
);
/**
* @brief Given the compacted sub_indptr tensor, the edge type tensor and the
* sliced_indptr tensor of the original graph, returns the heterogeneous
* versions of sub_indptr, indegrees and sliced_indptr, in which every row is
* split into num_fanouts consecutive per-edge-type rows.
*
* @param sub_indptr The compacted indptr tensor.
* @param etypes The compacted type_per_edge tensor.
* @param sliced_indptr The sliced_indptr tensor of the original graph.
* @param num_fanouts The number of fanout values.
*
* @return Tuple of tensors (new_sub_indptr, new_indegrees, new_sliced_indptr).
* With R = (sub_indptr.size(0) - 1) * num_fanouts, new_sub_indptr and
* new_indegrees have R + 1 elements and new_sliced_indptr has R elements.
*/
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
>
SliceCSCIndptrHetero
(
torch
::
Tensor
sub_indptr
,
torch
::
Tensor
etypes
,
torch
::
Tensor
sliced_indptr
,
int64_t
num_fanouts
);
/**
/**
* @brief Computes the exclusive prefix sum of the given input.
* @brief Computes the exclusive prefix sum of the given input.
*
*
...
...
graphbolt/src/cuda/index_select_csc_impl.cu
View file @
557a8f81
...
@@ -72,24 +72,6 @@ __global__ void _CopyIndicesAlignedKernel(
...
@@ -72,24 +72,6 @@ __global__ void _CopyIndicesAlignedKernel(
}
}
}
}
// Given rows and indptr, computes:
// inrow_indptr[i] = indptr[rows[i]];
// in_degree[i] = indptr[rows[i] + 1] - indptr[rows[i]];
template
<
typename
indptr_t
,
typename
nodes_t
>
struct
SliceFunc
{
const
nodes_t
*
rows
;
const
indptr_t
*
indptr
;
indptr_t
*
in_degree
;
indptr_t
*
inrow_indptr
;
__host__
__device__
auto
operator
()(
int64_t
tIdx
)
{
const
auto
out_row
=
rows
[
tIdx
];
const
auto
indptr_val
=
indptr
[
out_row
];
const
auto
degree
=
indptr
[
out_row
+
1
]
-
indptr_val
;
in_degree
[
tIdx
]
=
degree
;
inrow_indptr
[
tIdx
]
=
indptr_val
;
}
};
struct
PairSum
{
struct
PairSum
{
template
<
typename
indptr_t
>
template
<
typename
indptr_t
>
__host__
__device__
auto
operator
()(
__host__
__device__
auto
operator
()(
...
@@ -101,37 +83,6 @@ struct PairSum {
...
@@ -101,37 +83,6 @@ struct PairSum {
};
};
};
};
// Returns (indptr[nodes + 1] - indptr[nodes], indptr[nodes])
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
SliceCSCIndptr
(
torch
::
Tensor
indptr
,
torch
::
Tensor
nodes
)
{
auto
allocator
=
cuda
::
GetAllocator
();
const
auto
exec_policy
=
thrust
::
cuda
::
par_nosync
(
allocator
).
on
(
cuda
::
GetCurrentStream
());
const
int64_t
num_nodes
=
nodes
.
size
(
0
);
// Read indptr only once in case it is pinned and access is slow.
auto
sliced_indptr
=
torch
::
empty
(
num_nodes
,
nodes
.
options
().
dtype
(
indptr
.
scalar_type
()));
// compute in-degrees
auto
in_degree
=
torch
::
empty
(
num_nodes
+
1
,
nodes
.
options
().
dtype
(
indptr
.
scalar_type
()));
thrust
::
counting_iterator
<
int64_t
>
iota
(
0
);
AT_DISPATCH_INTEGRAL_TYPES
(
indptr
.
scalar_type
(),
"IndexSelectCSCIndptr"
,
([
&
]
{
using
indptr_t
=
scalar_t
;
AT_DISPATCH_INDEX_TYPES
(
nodes
.
scalar_type
(),
"IndexSelectCSCNodes"
,
([
&
]
{
using
nodes_t
=
index_t
;
thrust
::
for_each
(
exec_policy
,
iota
,
iota
+
num_nodes
,
SliceFunc
<
indptr_t
,
nodes_t
>
{
nodes
.
data_ptr
<
nodes_t
>
(),
indptr
.
data_ptr
<
indptr_t
>
(),
in_degree
.
data_ptr
<
indptr_t
>
(),
sliced_indptr
.
data_ptr
<
indptr_t
>
()});
}));
}));
return
{
in_degree
,
sliced_indptr
};
}
template
<
typename
indptr_t
,
typename
indices_t
>
template
<
typename
indptr_t
,
typename
indices_t
>
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
UVAIndexSelectCSCCopyIndices
(
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
UVAIndexSelectCSCCopyIndices
(
torch
::
Tensor
indices
,
const
int64_t
num_nodes
,
torch
::
Tensor
indices
,
const
int64_t
num_nodes
,
...
...
graphbolt/src/cuda/neighbor_sampler.cu
View file @
557a8f81
...
@@ -80,10 +80,11 @@ __global__ void _ComputeRandoms(
...
@@ -80,10 +80,11 @@ __global__ void _ComputeRandoms(
template
<
typename
indptr_t
>
template
<
typename
indptr_t
>
struct
MinInDegreeFanout
{
struct
MinInDegreeFanout
{
const
indptr_t
*
in_degree
;
const
indptr_t
*
in_degree
;
int64_t
fanout
;
const
int64_t
*
fanouts
;
size_t
num_fanouts
;
__host__
__device__
auto
operator
()(
int64_t
i
)
{
__host__
__device__
auto
operator
()(
int64_t
i
)
{
return
static_cast
<
indptr_t
>
(
return
static_cast
<
indptr_t
>
(
min
(
static_cast
<
int64_t
>
(
in_degree
[
i
]),
fanout
));
min
(
static_cast
<
int64_t
>
(
in_degree
[
i
]),
fanout
s
[
i
%
num_fanouts
]
));
}
}
};
};
...
@@ -128,19 +129,38 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
...
@@ -128,19 +129,38 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
bool
layer
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
bool
layer
,
bool
return_eids
,
torch
::
optional
<
torch
::
Tensor
>
type_per_edge
,
bool
return_eids
,
torch
::
optional
<
torch
::
Tensor
>
type_per_edge
,
torch
::
optional
<
torch
::
Tensor
>
probs_or_mask
)
{
torch
::
optional
<
torch
::
Tensor
>
probs_or_mask
)
{
TORCH_CHECK
(
fanouts
.
size
()
==
1
,
"Heterogenous sampling is not supported yet!"
);
TORCH_CHECK
(
!
replace
,
"Sampling with replacement is not supported yet!"
);
TORCH_CHECK
(
!
replace
,
"Sampling with replacement is not supported yet!"
);
// Assume that indptr, indices, nodes, type_per_edge and probs_or_mask
// Assume that indptr, indices, nodes, type_per_edge and probs_or_mask
// are all resident on the GPU. If not, it is better to first extract them
// are all resident on the GPU. If not, it is better to first extract them
// before calling this function.
// before calling this function.
auto
allocator
=
cuda
::
GetAllocator
();
auto
allocator
=
cuda
::
GetAllocator
();
const
auto
stream
=
cuda
::
GetCurrentStream
();
const
auto
stream
=
cuda
::
GetCurrentStream
();
const
auto
num_rows
=
nodes
.
size
(
0
);
auto
num_rows
=
nodes
.
size
(
0
);
const
auto
fanout
=
auto
fanouts_pinned
=
torch
::
empty
(
fanouts
[
0
]
>=
0
?
fanouts
[
0
]
:
std
::
numeric_limits
<
int64_t
>::
max
();
fanouts
.
size
(),
c10
::
TensorOptions
().
dtype
(
torch
::
kLong
).
pinned_memory
(
true
));
auto
fanouts_pinned_ptr
=
fanouts_pinned
.
data_ptr
<
int64_t
>
();
for
(
size_t
i
=
0
;
i
<
fanouts
.
size
();
i
++
)
{
fanouts_pinned_ptr
[
i
]
=
fanouts
[
i
]
>=
0
?
fanouts
[
i
]
:
std
::
numeric_limits
<
int64_t
>::
max
();
}
// Finally, copy the adjusted fanout values to the device memory.
auto
fanouts_device
=
allocator
.
AllocateStorage
<
int64_t
>
(
fanouts
.
size
());
CUDA_CALL
(
cudaMemcpyAsync
(
fanouts_device
.
get
(),
fanouts_pinned_ptr
,
sizeof
(
int64_t
)
*
fanouts
.
size
(),
cudaMemcpyHostToDevice
,
stream
));
auto
in_degree_and_sliced_indptr
=
SliceCSCIndptr
(
indptr
,
nodes
);
auto
in_degree_and_sliced_indptr
=
SliceCSCIndptr
(
indptr
,
nodes
);
auto
in_degree
=
std
::
get
<
0
>
(
in_degree_and_sliced_indptr
);
auto
in_degree
=
std
::
get
<
0
>
(
in_degree_and_sliced_indptr
);
auto
sliced_indptr
=
std
::
get
<
1
>
(
in_degree_and_sliced_indptr
);
auto
sub_indptr
=
ExclusiveCumSum
(
in_degree
);
if
(
fanouts
.
size
()
>
1
)
{
torch
::
Tensor
sliced_type_per_edge
;
std
::
tie
(
sub_indptr
,
sliced_type_per_edge
)
=
IndexSelectCSCImpl
(
indptr
,
type_per_edge
.
value
(),
nodes
);
std
::
tie
(
sub_indptr
,
in_degree
,
sliced_indptr
)
=
SliceCSCIndptrHetero
(
sub_indptr
,
sliced_type_per_edge
,
sliced_indptr
,
fanouts
.
size
());
num_rows
=
sliced_indptr
.
size
(
0
);
}
auto
max_in_degree
=
torch
::
empty
(
auto
max_in_degree
=
torch
::
empty
(
1
,
1
,
c10
::
TensorOptions
().
dtype
(
in_degree
.
scalar_type
()).
pinned_memory
(
true
));
c10
::
TensorOptions
().
dtype
(
in_degree
.
scalar_type
()).
pinned_memory
(
true
));
...
@@ -155,13 +175,11 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
...
@@ -155,13 +175,11 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
tmp_storage
.
get
(),
tmp_storage_size
,
in_degree
.
data_ptr
<
scalar_t
>
(),
tmp_storage
.
get
(),
tmp_storage_size
,
in_degree
.
data_ptr
<
scalar_t
>
(),
max_in_degree
.
data_ptr
<
scalar_t
>
(),
num_rows
,
stream
);
max_in_degree
.
data_ptr
<
scalar_t
>
(),
num_rows
,
stream
);
}));
}));
auto
sliced_indptr
=
std
::
get
<
1
>
(
in_degree_and_sliced_indptr
);
auto
sub_indptr
=
ExclusiveCumSum
(
in_degree
);
auto
output_indptr
=
torch
::
empty_like
(
sub_indptr
);
auto
coo_rows
=
CSRToCOO
(
sub_indptr
,
indices
.
scalar_type
());
auto
coo_rows
=
CSRToCOO
(
sub_indptr
,
indices
.
scalar_type
());
const
auto
num_edges
=
coo_rows
.
size
(
0
);
const
auto
num_edges
=
coo_rows
.
size
(
0
);
const
auto
random_seed
=
RandomEngine
::
ThreadLocal
()
->
RandInt
(
const
auto
random_seed
=
RandomEngine
::
ThreadLocal
()
->
RandInt
(
static_cast
<
int64_t
>
(
0
),
std
::
numeric_limits
<
int64_t
>::
max
());
static_cast
<
int64_t
>
(
0
),
std
::
numeric_limits
<
int64_t
>::
max
());
auto
output_indptr
=
torch
::
empty_like
(
sub_indptr
);
torch
::
Tensor
picked_eids
;
torch
::
Tensor
picked_eids
;
torch
::
Tensor
output_indices
;
torch
::
Tensor
output_indices
;
torch
::
optional
<
torch
::
Tensor
>
output_type_per_edge
;
torch
::
optional
<
torch
::
Tensor
>
output_type_per_edge
;
...
@@ -172,7 +190,8 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
...
@@ -172,7 +190,8 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
thrust
::
counting_iterator
<
int64_t
>
iota
(
0
);
thrust
::
counting_iterator
<
int64_t
>
iota
(
0
);
auto
sampled_degree
=
thrust
::
make_transform_iterator
(
auto
sampled_degree
=
thrust
::
make_transform_iterator
(
iota
,
MinInDegreeFanout
<
indptr_t
>
{
iota
,
MinInDegreeFanout
<
indptr_t
>
{
in_degree
.
data_ptr
<
indptr_t
>
(),
fanout
});
in_degree
.
data_ptr
<
indptr_t
>
(),
fanouts_device
.
get
(),
fanouts
.
size
()});
{
// Compute output_indptr.
{
// Compute output_indptr.
size_t
tmp_storage_size
=
0
;
size_t
tmp_storage_size
=
0
;
...
@@ -362,6 +381,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
...
@@ -362,6 +381,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
}
}
}));
}));
// Convert output_indptr back to homo by discarding intermediate offsets.
output_indptr
=
output_indptr
.
slice
(
0
,
0
,
output_indptr
.
size
(
0
),
fanouts
.
size
());
torch
::
optional
<
torch
::
Tensor
>
subgraph_reverse_edge_ids
=
torch
::
nullopt
;
torch
::
optional
<
torch
::
Tensor
>
subgraph_reverse_edge_ids
=
torch
::
nullopt
;
if
(
return_eids
)
subgraph_reverse_edge_ids
=
std
::
move
(
picked_eids
);
if
(
return_eids
)
subgraph_reverse_edge_ids
=
std
::
move
(
picked_eids
);
...
...
graphbolt/src/cuda/sampling_utils.cu
0 → 100644
View file @
557a8f81
/**
* Copyright (c) 2023 by Contributors
* Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
* @file cuda/sampling_utils.cu
* @brief Sampling utility function implementations on CUDA.
*/
#include <thrust/execution_policy.h>
#include <thrust/iterator/counting_iterator.h>
#include <cub/cub.cuh>
#include "./common.h"
#include "./utils.h"
namespace
graphbolt
{
namespace
ops
{
/**
 * @brief Functor that slices a CSC indptr array at a list of rows.
 *
 * For each position tIdx, with row = rows[tIdx], it stores:
 *   in_degree[tIdx]    = indptr[row + 1] - indptr[row]
 *   inrow_indptr[tIdx] = indptr[row]
 *
 * @tparam indptr_t Element type of indptr and of both output arrays.
 * @tparam nodes_t Element type of the row ids.
 */
template <typename indptr_t, typename nodes_t>
struct SliceFunc {
  const nodes_t* rows;     // Row ids to look up.
  const indptr_t* indptr;  // indptr array being sliced.
  indptr_t* in_degree;     // Output: number of entries of each requested row.
  indptr_t* inrow_indptr;  // Output: start offset of each requested row.

  __host__ __device__ auto operator()(int64_t tIdx) {
    const auto row = rows[tIdx];
    // Load both boundaries of the row before storing anything.
    const auto row_begin = indptr[row];
    const auto row_end = indptr[row + 1];
    in_degree[tIdx] = row_end - row_begin;
    inrow_indptr[tIdx] = row_begin;
  }
};
/**
 * @brief Slices indptr at the given nodes.
 *
 * @param indptr The indptr tensor to slice.
 * @param nodes The node ids at which indptr is read.
 *
 * @return Tuple (in_degree, sliced_indptr) where, for i < nodes.size(0),
 * in_degree[i] = indptr[nodes[i] + 1] - indptr[nodes[i]] and
 * sliced_indptr[i] = indptr[nodes[i]]. in_degree is allocated with
 * nodes.size(0) + 1 elements; its last element is not written here.
 */
std::tuple<torch::Tensor, torch::Tensor> SliceCSCIndptr(
    torch::Tensor indptr, torch::Tensor nodes) {
  const int64_t num_nodes = nodes.size(0);
  // Both outputs reuse the tensor options of nodes, with the dtype of indptr.
  const auto output_options = nodes.options().dtype(indptr.scalar_type());
  // indptr[nodes] is kept next to the degrees so that indptr is read only
  // once, in case it is pinned and access is slow.
  auto sliced_indptr = torch::empty(num_nodes, output_options);
  auto in_degree = torch::empty(num_nodes + 1, output_options);

  auto allocator = cuda::GetAllocator();
  const auto exec_policy =
      thrust::cuda::par_nosync(allocator).on(cuda::GetCurrentStream());
  thrust::counting_iterator<int64_t> first(0);
  AT_DISPATCH_INTEGRAL_TYPES(
      indptr.scalar_type(), "IndexSelectCSCIndptr", ([&] {
        using indptr_t = scalar_t;
        AT_DISPATCH_INDEX_TYPES(
            nodes.scalar_type(), "IndexSelectCSCNodes", ([&] {
              using nodes_t = index_t;
              SliceFunc<indptr_t, nodes_t> slicer{
                  nodes.data_ptr<nodes_t>(), indptr.data_ptr<indptr_t>(),
                  in_degree.data_ptr<indptr_t>(),
                  sliced_indptr.data_ptr<indptr_t>()};
              // One functor invocation per requested node.
              thrust::for_each(
                  exec_policy, first, first + num_nodes, slicer);
            }));
      }));
  return {in_degree, sliced_indptr};
}
/**
 * @brief Functor that splits every row of a compacted CSC slice into
 * num_fanouts consecutive segments, one per edge type.
 *
 * Index i addresses the pair (row, etype) = (i / num_fanouts, i % num_fanouts).
 * The offset of that edge type inside the row is found with cuda::LowerBound
 * on the row's range of etypes, and is added to both the compacted and the
 * original row start. The invocation with i == num_rows - 1 also writes the
 * closing entry new_sub_indptr[num_rows].
 *
 * NOTE(review): the search presumably relies on etypes being ascending
 * within each row -- confirm against the callers.
 *
 * @tparam indptr_t Element type of the indptr arrays.
 * @tparam etype_t Element type of the edge type array.
 */
template <typename indptr_t, typename etype_t>
struct EdgeTypeSearch {
  const indptr_t* sub_indptr;     // Compacted indptr, read at row and row + 1.
  const indptr_t* sliced_indptr;  // Row start offsets, read at row.
  const etype_t* etypes;          // Edge types, indexed by compacted offset.
  int64_t num_fanouts;            // Number of segments each row is split into.
  int64_t num_rows;               // Number of output rows.
  indptr_t* new_sub_indptr;       // Output with num_rows + 1 entries.
  indptr_t* new_sliced_indptr;    // Output with num_rows entries.

  __host__ __device__ auto operator()(int64_t i) {
    const auto row = i / num_fanouts;
    const etype_t wanted_type = i % num_fanouts;
    const auto row_begin = sub_indptr[row];
    const auto row_degree = sub_indptr[row + 1] - row_begin;
    // Position, relative to the row start, of the first edge whose type is
    // not smaller than wanted_type.
    const auto type_offset =
        cuda::LowerBound(etypes + row_begin, row_degree, wanted_type);
    new_sub_indptr[i] = row_begin + type_offset;
    new_sliced_indptr[i] = sliced_indptr[row] + type_offset;
    // The last invocation also closes the indptr with the end of its row.
    if (i + 1 == num_rows) new_sub_indptr[num_rows] = row_begin + row_degree;
  }
};
/**
 * @brief Expands a compacted CSC slice so that every row is split into
 * num_fanouts consecutive rows, one per edge type.
 *
 * @param sub_indptr The compacted indptr tensor.
 * @param etypes The compacted type_per_edge tensor.
 * @param sliced_indptr The sliced_indptr tensor of the original graph.
 * @param num_fanouts The number of fanout values.
 *
 * @return Tuple (new_sub_indptr, new_indegree, new_sliced_indptr). With
 * R = (sub_indptr.size(0) - 1) * num_fanouts, new_sub_indptr and new_indegree
 * have R + 1 elements and new_sliced_indptr has R elements. The first R
 * elements of new_indegree are the adjacent differences of new_sub_indptr;
 * its last element is not written here.
 */
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> SliceCSCIndptrHetero(
    torch::Tensor sub_indptr, torch::Tensor etypes, torch::Tensor sliced_indptr,
    int64_t num_fanouts) {
  auto num_rows = (sub_indptr.size(0) - 1) * num_fanouts;
  auto new_sub_indptr = torch::empty(num_rows + 1, sub_indptr.options());
  auto new_indegree = torch::empty(num_rows + 2, sub_indptr.options());
  auto new_sliced_indptr = torch::empty(num_rows, sliced_indptr.options());
  // With zero rows the search below runs over an empty range, so nothing
  // would ever write the single entry of new_sub_indptr and it would be
  // returned uninitialized. The indptr of an empty slice is [0].
  if (num_rows == 0) new_sub_indptr.zero_();
  auto allocator = cuda::GetAllocator();
  auto stream = cuda::GetCurrentStream();
  const auto exec_policy = thrust::cuda::par_nosync(allocator).on(stream);
  thrust::counting_iterator<int64_t> iota(0);
  AT_DISPATCH_INTEGRAL_TYPES(
      sub_indptr.scalar_type(), "SliceCSCIndptrHeteroIndptr", ([&] {
        using indptr_t = scalar_t;
        AT_DISPATCH_INTEGRAL_TYPES(
            etypes.scalar_type(), "SliceCSCIndptrHeteroTypePerEdge", ([&] {
              using etype_t = scalar_t;
              // Fills new_sub_indptr and new_sliced_indptr, one (row, etype)
              // pair per index.
              thrust::for_each(
                  exec_policy, iota, iota + num_rows,
                  EdgeTypeSearch<indptr_t, etype_t>{
                      sub_indptr.data_ptr<indptr_t>(),
                      sliced_indptr.data_ptr<indptr_t>(),
                      etypes.data_ptr<etype_t>(), num_fanouts, num_rows,
                      new_sub_indptr.data_ptr<indptr_t>(),
                      new_sliced_indptr.data_ptr<indptr_t>()});
            }));
        // First call only queries the temporary storage size.
        size_t tmp_storage_size = 0;
        cub::DeviceAdjacentDifference::SubtractLeftCopy(
            nullptr, tmp_storage_size, new_sub_indptr.data_ptr<indptr_t>(),
            new_indegree.data_ptr<indptr_t>(), num_rows + 1, cub::Difference{},
            stream);
        auto tmp_storage = allocator.AllocateStorage<char>(tmp_storage_size);
        // Second call writes the differences of new_sub_indptr into the
        // first num_rows + 1 elements of new_indegree.
        cub::DeviceAdjacentDifference::SubtractLeftCopy(
            tmp_storage.get(), tmp_storage_size,
            new_sub_indptr.data_ptr<indptr_t>(),
            new_indegree.data_ptr<indptr_t>(), num_rows + 1, cub::Difference{},
            stream);
      }));
  // Discard the first element of the SubtractLeftCopy result and ensure that
  // new_indegree tensor has size num_rows + 1 so that its ExclusiveCumSum is
  // directly equivalent to new_sub_indptr.
  // Equivalent to new_indegree = new_indegree[1:] in Python.
  new_indegree = new_indegree.slice(0, 1);
  return {new_sub_indptr, new_indegree, new_sliced_indptr};
}
}
// namespace ops
}
// namespace graphbolt
graphbolt/src/cuda/utils.h
View file @
557a8f81
...
@@ -51,6 +51,31 @@ int NumberOfBits(const T& range) {
...
@@ -51,6 +51,31 @@ int NumberOfBits(const T& range) {
return
bits
;
return
bits
;
}
}
/**
 * @brief Binary search for the first element of a sorted array that is not
 * smaller than a given value.
 *
 * Indices are 0-based.
 * @param A: array sorted in ascending order
 * @param n: number of elements in A
 * @param x: value to look for in A
 * @return the smallest index i such that A[i] >= x. Returns n when every
 * element is smaller than x (including n == 0) and 0 when x <= A[0].
 */
template <typename indptr_t, typename indices_t>
__device__ indices_t LowerBound(const indptr_t* A, indices_t n, indptr_t x) {
  // The answer always lies in the half-open window [lo, hi].
  indices_t lo = 0;
  indices_t hi = n;
  while (lo < hi) {
    // Midpoint written this way so that lo + hi is never formed.
    const indices_t mid = lo + (hi - lo) / 2;
    if (A[mid] < x) {
      lo = mid + 1;  // A[mid] is too small; the answer is to its right.
    } else {
      hi = mid;  // A[mid] qualifies; keep it as a candidate.
    }
  }
  return lo;
}
/**
/**
* @brief Given a sorted array and a value this function returns the index
* @brief Given a sorted array and a value this function returns the index
* of the first element which compares greater than value.
* of the first element which compares greater than value.
...
...
graphbolt/src/fused_csc_sampling_graph.cc
View file @
557a8f81
...
@@ -618,7 +618,9 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
...
@@ -618,7 +618,9 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
utils
::
is_accessible_from_gpu
(
indices_
)
&&
utils
::
is_accessible_from_gpu
(
indices_
)
&&
utils
::
is_accessible_from_gpu
(
nodes
)
&&
utils
::
is_accessible_from_gpu
(
nodes
)
&&
(
!
probs_or_mask
.
has_value
()
||
(
!
probs_or_mask
.
has_value
()
||
utils
::
is_accessible_from_gpu
(
probs_or_mask
.
value
())))
{
utils
::
is_accessible_from_gpu
(
probs_or_mask
.
value
()))
&&
(
!
type_per_edge_
.
has_value
()
||
utils
::
is_accessible_from_gpu
(
type_per_edge_
.
value
())))
{
GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE
(
GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE
(
c10
::
DeviceType
::
CUDA
,
"SampleNeighbors"
,
{
c10
::
DeviceType
::
CUDA
,
"SampleNeighbors"
,
{
return
ops
::
SampleNeighbors
(
return
ops
::
SampleNeighbors
(
...
...
tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
View file @
557a8f81
...
@@ -683,10 +683,6 @@ def test_multiprocessing():
...
@@ -683,10 +683,6 @@ def test_multiprocessing():
p
.
join
()
p
.
join
()
@
unittest
.
skipIf
(
F
.
_default_context_str
==
"gpu"
,
reason
=
"Graph is CPU only at present."
,
)
def
test_in_subgraph_homo
():
def
test_in_subgraph_homo
():
"""Original graph in COO:
"""Original graph in COO:
1 0 1 0 1
1 0 1 0 1
...
@@ -704,30 +700,29 @@ def test_in_subgraph_homo():
...
@@ -704,30 +700,29 @@ def test_in_subgraph_homo():
assert
indptr
[
-
1
]
==
len
(
indices
)
assert
indptr
[
-
1
]
==
len
(
indices
)
# Construct FusedCSCSamplingGraph.
# Construct FusedCSCSamplingGraph.
graph
=
gb
.
fused_csc_sampling_graph
(
indptr
,
indices
)
graph
=
gb
.
fused_csc_sampling_graph
(
indptr
,
indices
)
.
to
(
F
.
ctx
())
# Extract in subgraph.
# Extract in subgraph.
nodes
=
torch
.
LongT
ensor
([
4
,
1
,
3
])
nodes
=
torch
.
t
ensor
([
4
,
1
,
3
]
,
device
=
F
.
ctx
()
)
in_subgraph
=
graph
.
in_subgraph
(
nodes
)
in_subgraph
=
graph
.
in_subgraph
(
nodes
)
# Verify in subgraph.
# Verify in subgraph.
assert
torch
.
equal
(
assert
torch
.
equal
(
in_subgraph
.
sampled_csc
.
indices
,
torch
.
LongTensor
([
0
,
3
,
4
,
2
,
3
,
1
,
2
])
in_subgraph
.
sampled_csc
.
indices
,
torch
.
tensor
([
0
,
3
,
4
,
2
,
3
,
1
,
2
],
device
=
F
.
ctx
()),
)
)
assert
torch
.
equal
(
assert
torch
.
equal
(
in_subgraph
.
sampled_csc
.
indptr
,
torch
.
LongTensor
([
0
,
3
,
5
,
7
])
in_subgraph
.
sampled_csc
.
indptr
,
torch
.
tensor
([
0
,
3
,
5
,
7
],
device
=
F
.
ctx
()),
)
)
assert
in_subgraph
.
original_column_node_ids
is
None
assert
in_subgraph
.
original_column_node_ids
is
None
assert
in_subgraph
.
original_row_node_ids
is
None
assert
in_subgraph
.
original_row_node_ids
is
None
assert
torch
.
equal
(
assert
torch
.
equal
(
in_subgraph
.
original_edge_ids
,
torch
.
LongTensor
([
9
,
10
,
11
,
3
,
4
,
7
,
8
])
in_subgraph
.
original_edge_ids
,
torch
.
tensor
([
9
,
10
,
11
,
3
,
4
,
7
,
8
],
device
=
F
.
ctx
()),
)
)
@
unittest
.
skipIf
(
F
.
_default_context_str
==
"gpu"
,
reason
=
"Graph is CPU only at present."
,
)
def
test_in_subgraph_hetero
():
def
test_in_subgraph_hetero
():
"""Original graph in COO:
"""Original graph in COO:
1 0 1 0 1
1 0 1 0 1
...
@@ -773,44 +768,53 @@ def test_in_subgraph_hetero():
...
@@ -773,44 +768,53 @@ def test_in_subgraph_hetero():
type_per_edge
=
type_per_edge
,
type_per_edge
=
type_per_edge
,
node_type_to_id
=
ntypes
,
node_type_to_id
=
ntypes
,
edge_type_to_id
=
etypes
,
edge_type_to_id
=
etypes
,
)
)
.
to
(
F
.
ctx
())
# Extract in subgraph.
# Extract in subgraph.
nodes
=
{
nodes
=
{
"N0"
:
torch
.
LongT
ensor
([
1
]),
"N0"
:
torch
.
t
ensor
([
1
]
,
device
=
F
.
ctx
()
),
"N1"
:
torch
.
LongT
ensor
([
2
,
1
]),
"N1"
:
torch
.
t
ensor
([
2
,
1
]
,
device
=
F
.
ctx
()
),
}
}
in_subgraph
=
graph
.
in_subgraph
(
nodes
)
in_subgraph
=
graph
.
in_subgraph
(
nodes
)
# Verify in subgraph.
# Verify in subgraph.
assert
torch
.
equal
(
assert
torch
.
equal
(
in_subgraph
.
sampled_csc
[
"N0:R0:N0"
].
indices
,
torch
.
LongTensor
([])
in_subgraph
.
sampled_csc
[
"N0:R0:N0"
].
indices
,
torch
.
tensor
([],
device
=
F
.
ctx
()),
)
)
assert
torch
.
equal
(
assert
torch
.
equal
(
in_subgraph
.
sampled_csc
[
"N0:R0:N0"
].
indptr
,
torch
.
LongTensor
([
0
,
0
])
in_subgraph
.
sampled_csc
[
"N0:R0:N0"
].
indptr
,
torch
.
tensor
([
0
,
0
],
device
=
F
.
ctx
()),
)
)
assert
torch
.
equal
(
assert
torch
.
equal
(
in_subgraph
.
sampled_csc
[
"N0:R1:N1"
].
indices
,
torch
.
LongTensor
([
0
,
1
])
in_subgraph
.
sampled_csc
[
"N0:R1:N1"
].
indices
,
torch
.
tensor
([
0
,
1
],
device
=
F
.
ctx
()),
)
)
assert
torch
.
equal
(
assert
torch
.
equal
(
in_subgraph
.
sampled_csc
[
"N0:R1:N1"
].
indptr
,
torch
.
LongTensor
([
0
,
1
,
2
])
in_subgraph
.
sampled_csc
[
"N0:R1:N1"
].
indptr
,
torch
.
tensor
([
0
,
1
,
2
],
device
=
F
.
ctx
()),
)
)
assert
torch
.
equal
(
assert
torch
.
equal
(
in_subgraph
.
sampled_csc
[
"N1:R2:N0"
].
indices
,
torch
.
LongTensor
([
0
,
1
])
in_subgraph
.
sampled_csc
[
"N1:R2:N0"
].
indices
,
torch
.
tensor
([
0
,
1
],
device
=
F
.
ctx
()),
)
)
assert
torch
.
equal
(
assert
torch
.
equal
(
in_subgraph
.
sampled_csc
[
"N1:R2:N0"
].
indptr
,
torch
.
LongTensor
([
0
,
2
])
in_subgraph
.
sampled_csc
[
"N1:R2:N0"
].
indptr
,
torch
.
tensor
([
0
,
2
],
device
=
F
.
ctx
()),
)
)
assert
torch
.
equal
(
assert
torch
.
equal
(
in_subgraph
.
sampled_csc
[
"N1:R3:N1"
].
indices
,
torch
.
LongTensor
([
1
,
2
,
0
])
in_subgraph
.
sampled_csc
[
"N1:R3:N1"
].
indices
,
torch
.
tensor
([
1
,
2
,
0
],
device
=
F
.
ctx
()),
)
)
assert
torch
.
equal
(
assert
torch
.
equal
(
in_subgraph
.
sampled_csc
[
"N1:R3:N1"
].
indptr
,
torch
.
LongTensor
([
0
,
2
,
3
])
in_subgraph
.
sampled_csc
[
"N1:R3:N1"
].
indptr
,
torch
.
tensor
([
0
,
2
,
3
],
device
=
F
.
ctx
()),
)
)
assert
in_subgraph
.
original_column_node_ids
is
None
assert
in_subgraph
.
original_column_node_ids
is
None
assert
in_subgraph
.
original_row_node_ids
is
None
assert
in_subgraph
.
original_row_node_ids
is
None
assert
torch
.
equal
(
assert
torch
.
equal
(
in_subgraph
.
original_edge_ids
,
torch
.
LongTensor
([
3
,
4
,
9
,
10
,
11
,
7
,
8
])
in_subgraph
.
original_edge_ids
,
torch
.
tensor
([
3
,
4
,
9
,
10
,
11
,
7
,
8
],
device
=
F
.
ctx
()),
)
)
...
@@ -1630,10 +1634,6 @@ def test_sample_neighbors_homo(labor, is_pinned):
...
@@ -1630,10 +1634,6 @@ def test_sample_neighbors_homo(labor, is_pinned):
assert
subgraph
.
original_edge_ids
is
None
assert
subgraph
.
original_edge_ids
is
None
@
unittest
.
skipIf
(
F
.
_default_context_str
==
"gpu"
,
reason
=
"Heterogenous sampling on gpu is not supported yet."
,
)
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
def
test_sample_neighbors_hetero
(
labor
):
def
test_sample_neighbors_hetero
(
labor
):
"""Original graph in COO:
"""Original graph in COO:
...
@@ -1664,10 +1664,13 @@ def test_sample_neighbors_hetero(labor):
...
@@ -1664,10 +1664,13 @@ def test_sample_neighbors_hetero(labor):
type_per_edge
=
type_per_edge
,
type_per_edge
=
type_per_edge
,
node_type_to_id
=
ntypes
,
node_type_to_id
=
ntypes
,
edge_type_to_id
=
etypes
,
edge_type_to_id
=
etypes
,
)
)
.
to
(
F
.
ctx
())
# Sample on both node types.
# Sample on both node types.
nodes
=
{
"n1"
:
torch
.
LongTensor
([
0
]),
"n2"
:
torch
.
LongTensor
([
0
])}
nodes
=
{
"n1"
:
torch
.
tensor
([
0
],
device
=
F
.
ctx
()),
"n2"
:
torch
.
tensor
([
0
],
device
=
F
.
ctx
()),
}
fanouts
=
torch
.
tensor
([
-
1
,
-
1
])
fanouts
=
torch
.
tensor
([
-
1
,
-
1
])
sampler
=
graph
.
sample_layer_neighbors
if
labor
else
graph
.
sample_neighbors
sampler
=
graph
.
sample_layer_neighbors
if
labor
else
graph
.
sample_neighbors
subgraph
=
sampler
(
nodes
,
fanouts
)
subgraph
=
sampler
(
nodes
,
fanouts
)
...
@@ -1675,24 +1678,26 @@ def test_sample_neighbors_hetero(labor):
...
@@ -1675,24 +1678,26 @@ def test_sample_neighbors_hetero(labor):
# Verify in subgraph.
# Verify in subgraph.
expected_sampled_csc
=
{
expected_sampled_csc
=
{
"n1:e1:n2"
:
gb
.
CSCFormatBase
(
"n1:e1:n2"
:
gb
.
CSCFormatBase
(
indptr
=
torch
.
LongT
ensor
([
0
,
2
]),
indptr
=
torch
.
t
ensor
([
0
,
2
]
,
device
=
F
.
ctx
()
),
indices
=
torch
.
LongT
ensor
([
0
,
1
]),
indices
=
torch
.
t
ensor
([
0
,
1
]
,
device
=
F
.
ctx
()
),
),
),
"n2:e2:n1"
:
gb
.
CSCFormatBase
(
"n2:e2:n1"
:
gb
.
CSCFormatBase
(
indptr
=
torch
.
LongT
ensor
([
0
,
2
]),
indptr
=
torch
.
t
ensor
([
0
,
2
]
,
device
=
F
.
ctx
()
),
indices
=
torch
.
LongT
ensor
([
0
,
2
]),
indices
=
torch
.
t
ensor
([
0
,
2
]
,
device
=
F
.
ctx
()
),
),
),
}
}
assert
len
(
subgraph
.
sampled_csc
)
==
2
assert
len
(
subgraph
.
sampled_csc
)
==
2
for
etype
,
pairs
in
expected_sampled_csc
.
items
():
for
etype
,
pairs
in
expected_sampled_csc
.
items
():
assert
torch
.
equal
(
subgraph
.
sampled_csc
[
etype
].
indptr
,
pairs
.
indptr
)
assert
torch
.
equal
(
subgraph
.
sampled_csc
[
etype
].
indptr
,
pairs
.
indptr
)
assert
torch
.
equal
(
subgraph
.
sampled_csc
[
etype
].
indices
,
pairs
.
indices
)
assert
torch
.
equal
(
subgraph
.
sampled_csc
[
etype
].
indices
.
sort
()[
0
],
pairs
.
indices
)
assert
subgraph
.
original_column_node_ids
is
None
assert
subgraph
.
original_column_node_ids
is
None
assert
subgraph
.
original_row_node_ids
is
None
assert
subgraph
.
original_row_node_ids
is
None
assert
subgraph
.
original_edge_ids
is
None
assert
subgraph
.
original_edge_ids
is
None
# Sample on single node type.
# Sample on single node type.
nodes
=
{
"n1"
:
torch
.
LongT
ensor
([
0
])}
nodes
=
{
"n1"
:
torch
.
t
ensor
([
0
]
,
device
=
F
.
ctx
()
)}
fanouts
=
torch
.
tensor
([
-
1
,
-
1
])
fanouts
=
torch
.
tensor
([
-
1
,
-
1
])
sampler
=
graph
.
sample_layer_neighbors
if
labor
else
graph
.
sample_neighbors
sampler
=
graph
.
sample_layer_neighbors
if
labor
else
graph
.
sample_neighbors
subgraph
=
sampler
(
nodes
,
fanouts
)
subgraph
=
sampler
(
nodes
,
fanouts
)
...
@@ -1700,27 +1705,25 @@ def test_sample_neighbors_hetero(labor):
...
@@ -1700,27 +1705,25 @@ def test_sample_neighbors_hetero(labor):
# Verify in subgraph.
# Verify in subgraph.
expected_sampled_csc
=
{
expected_sampled_csc
=
{
"n1:e1:n2"
:
gb
.
CSCFormatBase
(
"n1:e1:n2"
:
gb
.
CSCFormatBase
(
indptr
=
torch
.
LongT
ensor
([
0
]),
indptr
=
torch
.
t
ensor
([
0
]
,
device
=
F
.
ctx
()
),
indices
=
torch
.
LongT
ensor
([]),
indices
=
torch
.
t
ensor
([]
,
device
=
F
.
ctx
()
),
),
),
"n2:e2:n1"
:
gb
.
CSCFormatBase
(
"n2:e2:n1"
:
gb
.
CSCFormatBase
(
indptr
=
torch
.
LongT
ensor
([
0
,
2
]),
indptr
=
torch
.
t
ensor
([
0
,
2
]
,
device
=
F
.
ctx
()
),
indices
=
torch
.
LongT
ensor
([
0
,
2
]),
indices
=
torch
.
t
ensor
([
0
,
2
]
,
device
=
F
.
ctx
()
),
),
),
}
}
assert
len
(
subgraph
.
sampled_csc
)
==
2
assert
len
(
subgraph
.
sampled_csc
)
==
2
for
etype
,
pairs
in
expected_sampled_csc
.
items
():
for
etype
,
pairs
in
expected_sampled_csc
.
items
():
assert
torch
.
equal
(
subgraph
.
sampled_csc
[
etype
].
indptr
,
pairs
.
indptr
)
assert
torch
.
equal
(
subgraph
.
sampled_csc
[
etype
].
indptr
,
pairs
.
indptr
)
assert
torch
.
equal
(
subgraph
.
sampled_csc
[
etype
].
indices
,
pairs
.
indices
)
assert
torch
.
equal
(
subgraph
.
sampled_csc
[
etype
].
indices
.
sort
()[
0
],
pairs
.
indices
)
assert
subgraph
.
original_column_node_ids
is
None
assert
subgraph
.
original_column_node_ids
is
None
assert
subgraph
.
original_row_node_ids
is
None
assert
subgraph
.
original_row_node_ids
is
None
assert
subgraph
.
original_edge_ids
is
None
assert
subgraph
.
original_edge_ids
is
None
@
unittest
.
skipIf
(
F
.
_default_context_str
==
"gpu"
,
reason
=
"Heterogenous sampling on gpu is not supported yet."
,
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"fanouts, expected_sampled_num1, expected_sampled_num2"
,
"fanouts, expected_sampled_num1, expected_sampled_num2"
,
[
[
...
@@ -1769,9 +1772,12 @@ def test_sample_neighbors_fanouts(
...
@@ -1769,9 +1772,12 @@ def test_sample_neighbors_fanouts(
type_per_edge
=
type_per_edge
,
type_per_edge
=
type_per_edge
,
node_type_to_id
=
ntypes
,
node_type_to_id
=
ntypes
,
edge_type_to_id
=
etypes
,
edge_type_to_id
=
etypes
,
)
)
.
to
(
F
.
ctx
())
nodes
=
{
"n1"
:
torch
.
LongTensor
([
0
]),
"n2"
:
torch
.
LongTensor
([
0
])}
nodes
=
{
"n1"
:
torch
.
tensor
([
0
],
device
=
F
.
ctx
()),
"n2"
:
torch
.
tensor
([
0
],
device
=
F
.
ctx
()),
}
fanouts
=
torch
.
LongTensor
(
fanouts
)
fanouts
=
torch
.
LongTensor
(
fanouts
)
sampler
=
graph
.
sample_layer_neighbors
if
labor
else
graph
.
sample_neighbors
sampler
=
graph
.
sample_layer_neighbors
if
labor
else
graph
.
sample_neighbors
subgraph
=
sampler
(
nodes
,
fanouts
)
subgraph
=
sampler
(
nodes
,
fanouts
)
...
@@ -1897,10 +1903,6 @@ def test_sample_neighbors_return_eids_homo(labor, is_pinned):
...
@@ -1897,10 +1903,6 @@ def test_sample_neighbors_return_eids_homo(labor, is_pinned):
assert
subgraph
.
original_row_node_ids
is
None
assert
subgraph
.
original_row_node_ids
is
None
@
unittest
.
skipIf
(
F
.
_default_context_str
==
"gpu"
,
reason
=
"Heterogenous sampling on gpu is not supported yet."
,
)
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
def
test_sample_neighbors_return_eids_hetero
(
labor
):
def
test_sample_neighbors_return_eids_hetero
(
labor
):
"""
"""
...
@@ -1936,24 +1938,31 @@ def test_sample_neighbors_return_eids_hetero(labor):
...
@@ -1936,24 +1938,31 @@ def test_sample_neighbors_return_eids_hetero(labor):
edge_attributes
=
edge_attributes
,
edge_attributes
=
edge_attributes
,
node_type_to_id
=
ntypes
,
node_type_to_id
=
ntypes
,
edge_type_to_id
=
etypes
,
edge_type_to_id
=
etypes
,
)
)
.
to
(
F
.
ctx
())
# Sample on both node types.
# Sample on both node types.
nodes
=
{
"n1"
:
torch
.
LongTensor
([
0
]),
"n2"
:
torch
.
LongTensor
([
0
])}
nodes
=
{
"n1"
:
torch
.
LongTensor
([
0
]).
to
(
F
.
ctx
()),
"n2"
:
torch
.
LongTensor
([
0
]).
to
(
F
.
ctx
()),
}
fanouts
=
torch
.
tensor
([
-
1
,
-
1
])
fanouts
=
torch
.
tensor
([
-
1
,
-
1
])
sampler
=
graph
.
sample_layer_neighbors
if
labor
else
graph
.
sample_neighbors
sampler
=
graph
.
sample_layer_neighbors
if
labor
else
graph
.
sample_neighbors
subgraph
=
sampler
(
nodes
,
fanouts
)
subgraph
=
sampler
(
nodes
,
fanouts
)
# Verify in subgraph.
expected_reverse_edge_ids
=
{
expected_reverse_edge_ids
=
{
"n2:e2:n1"
:
edge_attributes
[
gb
.
ORIGINAL_EDGE_ID
][
torch
.
tensor
([
0
,
1
])],
"n2:e2:n1"
:
graph
.
edge_attributes
[
gb
.
ORIGINAL_EDGE_ID
][
"n1:e1:n2"
:
edge_attributes
[
gb
.
ORIGINAL_EDGE_ID
][
torch
.
tensor
([
4
,
5
])],
torch
.
tensor
([
0
,
1
],
device
=
F
.
ctx
())
],
"n1:e1:n2"
:
graph
.
edge_attributes
[
gb
.
ORIGINAL_EDGE_ID
][
torch
.
tensor
([
4
,
5
],
device
=
F
.
ctx
())
],
}
}
assert
subgraph
.
original_column_node_ids
is
None
assert
subgraph
.
original_column_node_ids
is
None
assert
subgraph
.
original_row_node_ids
is
None
assert
subgraph
.
original_row_node_ids
is
None
for
etype
in
etypes
.
keys
():
for
etype
in
etypes
.
keys
():
assert
torch
.
equal
(
assert
torch
.
equal
(
subgraph
.
original_edge_ids
[
etype
],
expected_reverse_edge_ids
[
etype
]
subgraph
.
original_edge_ids
[
etype
].
sort
()[
0
],
expected_reverse_edge_ids
[
etype
].
sort
()[
0
],
)
)
...
...
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
View file @
557a8f81
...
@@ -874,15 +874,12 @@ def test_SubgraphSampler_unique_csc_format_Hetero(labor):
...
@@ -874,15 +874,12 @@ def test_SubgraphSampler_unique_csc_format_Hetero(labor):
)
)
@
unittest
.
skipIf
(
F
.
_default_context_str
==
"gpu"
,
reason
=
"Heterogenous sampling is not supported on GPU yet."
,
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"sampler_type"
,
"sampler_type"
,
[
SamplerType
.
Normal
,
SamplerType
.
Layer
,
SamplerType
.
Temporal
],
[
SamplerType
.
Normal
,
SamplerType
.
Layer
,
SamplerType
.
Temporal
],
)
)
def
test_SubgraphSampler_Hetero_multifanout_per_layer
(
sampler_type
):
def
test_SubgraphSampler_Hetero_multifanout_per_layer
(
sampler_type
):
_check_sampler_type
(
sampler_type
)
graph
=
get_hetero_graph
().
to
(
F
.
ctx
())
graph
=
get_hetero_graph
().
to
(
F
.
ctx
())
items_n1
=
torch
.
tensor
([
0
])
items_n1
=
torch
.
tensor
([
0
])
items_n2
=
torch
.
tensor
([
1
])
items_n2
=
torch
.
tensor
([
1
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment