Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
3cdc37cc
Unverified
Commit
3cdc37cc
authored
Dec 05, 2023
by
Rhett Ying
Committed by
GitHub
Dec 05, 2023
Browse files
[GraphBolt] add ntype/etype_to_id into graph and save/load() (#6687)
parent
93b39729
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
153 additions
and
12 deletions
+153
-12
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
+65
-0
graphbolt/src/fused_csc_sampling_graph.cc
graphbolt/src/fused_csc_sampling_graph.cc
+56
-2
python/dgl/distributed/partition.py
python/dgl/distributed/partition.py
+11
-2
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
+15
-3
tests/distributed/test_partition.py
tests/distributed/test_partition.py
+3
-2
tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
...n/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
+1
-1
tests/python/pytorch/graphbolt/impl/test_ondisk_dataset.py
tests/python/pytorch/graphbolt/impl/test_ondisk_dataset.py
+2
-2
No files found.
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
View file @
3cdc37cc
...
@@ -48,6 +48,8 @@ struct SamplerArgs<SamplerType::LABOR> {
...
@@ -48,6 +48,8 @@ struct SamplerArgs<SamplerType::LABOR> {
*/
*/
class
FusedCSCSamplingGraph
:
public
torch
::
CustomClassHolder
{
class
FusedCSCSamplingGraph
:
public
torch
::
CustomClassHolder
{
public:
public:
using
NodeTypeToIDMap
=
torch
::
Dict
<
std
::
string
,
int64_t
>
;
using
EdgeTypeToIDMap
=
torch
::
Dict
<
std
::
string
,
int64_t
>
;
using
EdgeAttrMap
=
torch
::
Dict
<
std
::
string
,
torch
::
Tensor
>
;
using
EdgeAttrMap
=
torch
::
Dict
<
std
::
string
,
torch
::
Tensor
>
;
/** @brief Default constructor. */
/** @brief Default constructor. */
FusedCSCSamplingGraph
()
=
default
;
FusedCSCSamplingGraph
()
=
default
;
...
@@ -60,11 +62,19 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
...
@@ -60,11 +62,19 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* present.
* present.
* @param type_per_edge A tensor representing the type of each edge, if
* @param type_per_edge A tensor representing the type of each edge, if
* present.
* present.
* @param node_type_to_id A dictionary mapping node type names to type IDs, if
* present.
* @param edge_type_to_id A dictionary mapping edge type names to type IDs, if
* present.
* @param edge_attributes A dictionary of edge attributes, if present.
*
*/
*/
FusedCSCSamplingGraph
(
FusedCSCSamplingGraph
(
const
torch
::
Tensor
&
indptr
,
const
torch
::
Tensor
&
indices
,
const
torch
::
Tensor
&
indptr
,
const
torch
::
Tensor
&
indices
,
const
torch
::
optional
<
torch
::
Tensor
>&
node_type_offset
,
const
torch
::
optional
<
torch
::
Tensor
>&
node_type_offset
,
const
torch
::
optional
<
torch
::
Tensor
>&
type_per_edge
,
const
torch
::
optional
<
torch
::
Tensor
>&
type_per_edge
,
const
torch
::
optional
<
NodeTypeToIDMap
>&
node_type_to_id
,
const
torch
::
optional
<
EdgeTypeToIDMap
>&
edge_type_to_id
,
const
torch
::
optional
<
EdgeAttrMap
>&
edge_attributes
);
const
torch
::
optional
<
EdgeAttrMap
>&
edge_attributes
);
/**
/**
...
@@ -75,6 +85,11 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
...
@@ -75,6 +85,11 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* present.
* present.
* @param type_per_edge A tensor representing the type of each edge, if
* @param type_per_edge A tensor representing the type of each edge, if
* present.
* present.
* @param node_type_to_id A dictionary mapping node type names to type IDs, if
* present.
* @param edge_type_to_id A dictionary mapping edge type names to type IDs, if
* present.
* @param edge_attributes A dictionary of edge attributes, if present.
*
*
* @return FusedCSCSamplingGraph
* @return FusedCSCSamplingGraph
*/
*/
...
@@ -82,6 +97,8 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
...
@@ -82,6 +97,8 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
const
torch
::
Tensor
&
indptr
,
const
torch
::
Tensor
&
indices
,
const
torch
::
Tensor
&
indptr
,
const
torch
::
Tensor
&
indices
,
const
torch
::
optional
<
torch
::
Tensor
>&
node_type_offset
,
const
torch
::
optional
<
torch
::
Tensor
>&
node_type_offset
,
const
torch
::
optional
<
torch
::
Tensor
>&
type_per_edge
,
const
torch
::
optional
<
torch
::
Tensor
>&
type_per_edge
,
const
torch
::
optional
<
NodeTypeToIDMap
>&
node_type_to_id
,
const
torch
::
optional
<
EdgeTypeToIDMap
>&
edge_type_to_id
,
const
torch
::
optional
<
EdgeAttrMap
>&
edge_attributes
);
const
torch
::
optional
<
EdgeAttrMap
>&
edge_attributes
);
/** @brief Get the number of nodes. */
/** @brief Get the number of nodes. */
...
@@ -106,6 +123,22 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
...
@@ -106,6 +123,22 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
return
type_per_edge_
;
return
type_per_edge_
;
}
}
/**
* @brief Get the node type to id map for a heterogeneous graph.
* @note The map is a dictionary mapping node type names to type IDs.
*/
inline
const
torch
::
optional
<
NodeTypeToIDMap
>
NodeTypeToID
()
const
{
return
node_type_to_id_
;
}
/**
* @brief Get the edge type to id map for a heterogeneous graph.
* @note The map is a dictionary mapping edge type names to type IDs.
*/
inline
const
torch
::
optional
<
EdgeTypeToIDMap
>
EdgeTypeToID
()
const
{
return
edge_type_to_id_
;
}
/** @brief Get the edge attributes dictionary. */
/** @brief Get the edge attributes dictionary. */
inline
const
torch
::
optional
<
EdgeAttrMap
>
EdgeAttributes
()
const
{
inline
const
torch
::
optional
<
EdgeAttrMap
>
EdgeAttributes
()
const
{
return
edge_attributes_
;
return
edge_attributes_
;
...
@@ -129,6 +162,24 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
...
@@ -129,6 +162,24 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
type_per_edge_
=
type_per_edge
;
type_per_edge_
=
type_per_edge
;
}
}
/**
* @brief Set the node type to id map for a heterogeneous graph.
* @note The map is a dictionary mapping node type names to type IDs.
*/
inline
void
SetNodeTypeToID
(
const
torch
::
optional
<
NodeTypeToIDMap
>&
node_type_to_id
)
{
node_type_to_id_
=
node_type_to_id
;
}
/**
* @brief Set the edge type to id map for a heterogeneous graph.
* @note The map is a dictionary mapping edge type names to type IDs.
*/
inline
void
SetEdgeTypeToID
(
const
torch
::
optional
<
EdgeTypeToIDMap
>&
edge_type_to_id
)
{
edge_type_to_id_
=
edge_type_to_id
;
}
/** @brief Set the edge attributes dictionary. */
/** @brief Set the edge attributes dictionary. */
inline
void
SetEdgeAttributes
(
inline
void
SetEdgeAttributes
(
const
torch
::
optional
<
EdgeAttrMap
>&
edge_attributes
)
{
const
torch
::
optional
<
EdgeAttrMap
>&
edge_attributes
)
{
...
@@ -302,6 +353,20 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
...
@@ -302,6 +353,20 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
*/
*/
torch
::
optional
<
torch
::
Tensor
>
type_per_edge_
;
torch
::
optional
<
torch
::
Tensor
>
type_per_edge_
;
/**
* @brief A dictionary mapping node type names to type IDs. The length of it
* is equal to the number of node types. The key is the node type name, and
* the value is the corresponding type ID.
*/
torch
::
optional
<
NodeTypeToIDMap
>
node_type_to_id_
;
/**
* @brief A dictionary mapping edge type names to type IDs. The length of it
* is equal to the number of edge types. The key is the edge type name, and
* the value is the corresponding type ID.
*/
torch
::
optional
<
EdgeTypeToIDMap
>
edge_type_to_id_
;
/**
/**
* @brief A dictionary of edge attributes. Each key represents the attribute's
* @brief A dictionary of edge attributes. Each key represents the attribute's
* name, while the corresponding value holds the attribute's specific value.
* name, while the corresponding value holds the attribute's specific value.
...
...
graphbolt/src/fused_csc_sampling_graph.cc
View file @
3cdc37cc
...
@@ -28,11 +28,15 @@ FusedCSCSamplingGraph::FusedCSCSamplingGraph(
...
@@ -28,11 +28,15 @@ FusedCSCSamplingGraph::FusedCSCSamplingGraph(
const
torch
::
Tensor
&
indptr
,
const
torch
::
Tensor
&
indices
,
const
torch
::
Tensor
&
indptr
,
const
torch
::
Tensor
&
indices
,
const
torch
::
optional
<
torch
::
Tensor
>&
node_type_offset
,
const
torch
::
optional
<
torch
::
Tensor
>&
node_type_offset
,
const
torch
::
optional
<
torch
::
Tensor
>&
type_per_edge
,
const
torch
::
optional
<
torch
::
Tensor
>&
type_per_edge
,
const
torch
::
optional
<
NodeTypeToIDMap
>&
node_type_to_id
,
const
torch
::
optional
<
EdgeTypeToIDMap
>&
edge_type_to_id
,
const
torch
::
optional
<
EdgeAttrMap
>&
edge_attributes
)
const
torch
::
optional
<
EdgeAttrMap
>&
edge_attributes
)
:
indptr_
(
indptr
),
:
indptr_
(
indptr
),
indices_
(
indices
),
indices_
(
indices
),
node_type_offset_
(
node_type_offset
),
node_type_offset_
(
node_type_offset
),
type_per_edge_
(
type_per_edge
),
type_per_edge_
(
type_per_edge
),
node_type_to_id_
(
node_type_to_id
),
edge_type_to_id_
(
edge_type_to_id
),
edge_attributes_
(
edge_attributes
)
{
edge_attributes_
(
edge_attributes
)
{
TORCH_CHECK
(
indptr
.
dim
()
==
1
);
TORCH_CHECK
(
indptr
.
dim
()
==
1
);
TORCH_CHECK
(
indices
.
dim
()
==
1
);
TORCH_CHECK
(
indices
.
dim
()
==
1
);
...
@@ -43,14 +47,21 @@ c10::intrusive_ptr<FusedCSCSamplingGraph> FusedCSCSamplingGraph::FromCSC(
...
@@ -43,14 +47,21 @@ c10::intrusive_ptr<FusedCSCSamplingGraph> FusedCSCSamplingGraph::FromCSC(
const
torch
::
Tensor
&
indptr
,
const
torch
::
Tensor
&
indices
,
const
torch
::
Tensor
&
indptr
,
const
torch
::
Tensor
&
indices
,
const
torch
::
optional
<
torch
::
Tensor
>&
node_type_offset
,
const
torch
::
optional
<
torch
::
Tensor
>&
node_type_offset
,
const
torch
::
optional
<
torch
::
Tensor
>&
type_per_edge
,
const
torch
::
optional
<
torch
::
Tensor
>&
type_per_edge
,
const
torch
::
optional
<
NodeTypeToIDMap
>&
node_type_to_id
,
const
torch
::
optional
<
EdgeTypeToIDMap
>&
edge_type_to_id
,
const
torch
::
optional
<
EdgeAttrMap
>&
edge_attributes
)
{
const
torch
::
optional
<
EdgeAttrMap
>&
edge_attributes
)
{
if
(
node_type_offset
.
has_value
())
{
if
(
node_type_offset
.
has_value
())
{
auto
&
offset
=
node_type_offset
.
value
();
auto
&
offset
=
node_type_offset
.
value
();
TORCH_CHECK
(
offset
.
dim
()
==
1
);
TORCH_CHECK
(
offset
.
dim
()
==
1
);
TORCH_CHECK
(
node_type_to_id
.
has_value
());
TORCH_CHECK
(
offset
.
size
(
0
)
==
static_cast
<
int64_t
>
(
node_type_to_id
.
value
().
size
()
+
1
));
}
}
if
(
type_per_edge
.
has_value
())
{
if
(
type_per_edge
.
has_value
())
{
TORCH_CHECK
(
type_per_edge
.
value
().
dim
()
==
1
);
TORCH_CHECK
(
type_per_edge
.
value
().
dim
()
==
1
);
TORCH_CHECK
(
type_per_edge
.
value
().
size
(
0
)
==
indices
.
size
(
0
));
TORCH_CHECK
(
type_per_edge
.
value
().
size
(
0
)
==
indices
.
size
(
0
));
TORCH_CHECK
(
edge_type_to_id
.
has_value
());
}
}
if
(
edge_attributes
.
has_value
())
{
if
(
edge_attributes
.
has_value
())
{
for
(
const
auto
&
pair
:
edge_attributes
.
value
())
{
for
(
const
auto
&
pair
:
edge_attributes
.
value
())
{
...
@@ -58,7 +69,8 @@ c10::intrusive_ptr<FusedCSCSamplingGraph> FusedCSCSamplingGraph::FromCSC(
...
@@ -58,7 +69,8 @@ c10::intrusive_ptr<FusedCSCSamplingGraph> FusedCSCSamplingGraph::FromCSC(
}
}
}
}
return
c10
::
make_intrusive
<
FusedCSCSamplingGraph
>
(
return
c10
::
make_intrusive
<
FusedCSCSamplingGraph
>
(
indptr
,
indices
,
node_type_offset
,
type_per_edge
,
edge_attributes
);
indptr
,
indices
,
node_type_offset
,
type_per_edge
,
node_type_to_id
,
edge_type_to_id
,
edge_attributes
);
}
}
void
FusedCSCSamplingGraph
::
Load
(
torch
::
serialize
::
InputArchive
&
archive
)
{
void
FusedCSCSamplingGraph
::
Load
(
torch
::
serialize
::
InputArchive
&
archive
)
{
...
@@ -84,6 +96,34 @@ void FusedCSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
...
@@ -84,6 +96,34 @@ void FusedCSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
.
toTensor
();
.
toTensor
();
}
}
if
(
read_from_archive
(
archive
,
"FusedCSCSamplingGraph/has_node_type_to_id"
)
.
toBool
())
{
torch
::
Dict
<
torch
::
IValue
,
torch
::
IValue
>
generic_dict
=
read_from_archive
(
archive
,
"FusedCSCSamplingGraph/node_type_to_id"
)
.
toGenericDict
();
NodeTypeToIDMap
node_type_to_id
;
for
(
const
auto
&
pair
:
generic_dict
)
{
std
::
string
key
=
pair
.
key
().
toStringRef
();
int64_t
value
=
pair
.
value
().
toInt
();
node_type_to_id
.
insert
(
std
::
move
(
key
),
value
);
}
node_type_to_id_
=
std
::
move
(
node_type_to_id
);
}
if
(
read_from_archive
(
archive
,
"FusedCSCSamplingGraph/has_edge_type_to_id"
)
.
toBool
())
{
torch
::
Dict
<
torch
::
IValue
,
torch
::
IValue
>
generic_dict
=
read_from_archive
(
archive
,
"FusedCSCSamplingGraph/edge_type_to_id"
)
.
toGenericDict
();
EdgeTypeToIDMap
edge_type_to_id
;
for
(
const
auto
&
pair
:
generic_dict
)
{
std
::
string
key
=
pair
.
key
().
toStringRef
();
int64_t
value
=
pair
.
value
().
toInt
();
edge_type_to_id
.
insert
(
std
::
move
(
key
),
value
);
}
edge_type_to_id_
=
std
::
move
(
edge_type_to_id
);
}
// Optional edge attributes.
// Optional edge attributes.
torch
::
IValue
has_edge_attributes
;
torch
::
IValue
has_edge_attributes
;
if
(
archive
.
try_read
(
if
(
archive
.
try_read
(
...
@@ -123,6 +163,20 @@ void FusedCSCSamplingGraph::Save(
...
@@ -123,6 +163,20 @@ void FusedCSCSamplingGraph::Save(
archive
.
write
(
archive
.
write
(
"FusedCSCSamplingGraph/type_per_edge"
,
type_per_edge_
.
value
());
"FusedCSCSamplingGraph/type_per_edge"
,
type_per_edge_
.
value
());
}
}
archive
.
write
(
"FusedCSCSamplingGraph/has_node_type_to_id"
,
node_type_to_id_
.
has_value
());
if
(
node_type_to_id_
)
{
archive
.
write
(
"FusedCSCSamplingGraph/node_type_to_id"
,
node_type_to_id_
.
value
());
}
archive
.
write
(
"FusedCSCSamplingGraph/has_edge_type_to_id"
,
edge_type_to_id_
.
has_value
());
if
(
edge_type_to_id_
)
{
archive
.
write
(
"FusedCSCSamplingGraph/edge_type_to_id"
,
edge_type_to_id_
.
value
());
}
archive
.
write
(
archive
.
write
(
"FusedCSCSamplingGraph/has_edge_attributes"
,
"FusedCSCSamplingGraph/has_edge_attributes"
,
edge_attributes_
.
has_value
());
edge_attributes_
.
has_value
());
...
@@ -505,7 +559,7 @@ BuildGraphFromSharedMemoryHelper(SharedMemoryHelper&& helper) {
...
@@ -505,7 +559,7 @@ BuildGraphFromSharedMemoryHelper(SharedMemoryHelper&& helper) {
auto
edge_attributes
=
helper
.
ReadTorchTensorDict
();
auto
edge_attributes
=
helper
.
ReadTorchTensorDict
();
auto
graph
=
c10
::
make_intrusive
<
FusedCSCSamplingGraph
>
(
auto
graph
=
c10
::
make_intrusive
<
FusedCSCSamplingGraph
>
(
indptr
.
value
(),
indices
.
value
(),
node_type_offset
,
type_per_edge
,
indptr
.
value
(),
indices
.
value
(),
node_type_offset
,
type_per_edge
,
edge_attributes
);
torch
::
nullopt
,
torch
::
nullopt
,
edge_attributes
);
auto
shared_memory
=
helper
.
ReleaseSharedMemory
();
auto
shared_memory
=
helper
.
ReleaseSharedMemory
();
graph
->
HoldSharedMemoryObject
(
graph
->
HoldSharedMemoryObject
(
std
::
move
(
shared_memory
.
first
),
std
::
move
(
shared_memory
.
second
));
std
::
move
(
shared_memory
.
first
),
std
::
move
(
shared_memory
.
second
));
...
...
python/dgl/distributed/partition.py
View file @
3cdc37cc
...
@@ -1254,7 +1254,12 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config):
...
@@ -1254,7 +1254,12 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config):
)
)
# Construct GraphMetadata.
# Construct GraphMetadata.
_
,
_
,
ntypes
,
etypes
=
load_partition_book
(
part_config
,
part_id
)
_
,
_
,
ntypes
,
etypes
=
load_partition_book
(
part_config
,
part_id
)
metadata
=
graphbolt
.
GraphMetadata
(
ntypes
,
etypes
)
node_type_to_id
=
{
ntype
:
ntid
for
ntid
,
ntype
in
enumerate
(
ntypes
)}
edge_type_to_id
=
{
_etype_tuple_to_str
(
etype
):
etid
for
etid
,
etype
in
enumerate
(
etypes
)
}
metadata
=
graphbolt
.
GraphMetadata
(
node_type_to_id
,
edge_type_to_id
)
# Obtain CSC indtpr and indices.
# Obtain CSC indtpr and indices.
indptr
,
indices
,
_
=
graph
.
adj
().
csc
()
indptr
,
indices
,
_
=
graph
.
adj
().
csc
()
# Initalize type per edge.
# Initalize type per edge.
...
@@ -1263,7 +1268,11 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config):
...
@@ -1263,7 +1268,11 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config):
# Sanity check.
# Sanity check.
assert
len
(
type_per_edge
)
==
graph
.
num_edges
()
assert
len
(
type_per_edge
)
==
graph
.
num_edges
()
csc_graph
=
graphbolt
.
from_fused_csc
(
csc_graph
=
graphbolt
.
from_fused_csc
(
indptr
,
indices
,
None
,
type_per_edge
,
metadata
=
metadata
indptr
,
indices
,
node_type_offset
=
None
,
type_per_edge
=
type_per_edge
,
metadata
=
metadata
,
)
)
orig_graph_path
=
os
.
path
.
join
(
orig_graph_path
=
os
.
path
.
join
(
os
.
path
.
dirname
(
part_config
),
os
.
path
.
dirname
(
part_config
),
...
...
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
View file @
3cdc37cc
...
@@ -456,7 +456,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
...
@@ -456,7 +456,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
dst_ntype_id
=
self
.
metadata
.
node_type_to_id
[
dst_ntype
]
dst_ntype_id
=
self
.
metadata
.
node_type_to_id
[
dst_ntype
]
node_edge_type
[
dst_ntype_id
].
append
((
etype
,
etype_id
))
node_edge_type
[
dst_ntype_id
].
append
((
etype
,
etype_id
))
# construct subgraphs
# construct subgraphs
for
(
i
,
seed
)
in
enumerate
(
column
):
for
i
,
seed
in
enumerate
(
column
):
l
=
indptr
[
i
].
item
()
l
=
indptr
[
i
].
item
()
r
=
indptr
[
i
+
1
].
item
()
r
=
indptr
[
i
+
1
].
item
()
node_type
=
(
node_type
=
(
...
@@ -465,7 +465,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
...
@@ -465,7 +465,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
).
item
()
).
item
()
-
1
-
1
)
)
for
(
etype
,
etype_id
)
in
node_edge_type
[
node_type
]:
for
etype
,
etype_id
in
node_edge_type
[
node_type
]:
src_ntype
,
_
,
_
=
etype_str_to_tuple
(
etype
)
src_ntype
,
_
,
_
=
etype_str_to_tuple
(
etype
)
src_ntype_id
=
self
.
metadata
.
node_type_to_id
[
src_ntype
]
src_ntype_id
=
self
.
metadata
.
node_type_to_id
[
src_ntype
]
num_edges
=
torch
.
searchsorted
(
num_edges
=
torch
.
searchsorted
(
...
@@ -925,12 +925,16 @@ def from_fused_csc(
...
@@ -925,12 +925,16 @@ def from_fused_csc(
assert
len
(
metadata
.
node_type_to_id
)
+
1
==
node_type_offset
.
size
(
assert
len
(
metadata
.
node_type_to_id
)
+
1
==
node_type_offset
.
size
(
0
0
),
"node_type_offset length should be |ntypes| + 1."
),
"node_type_offset length should be |ntypes| + 1."
node_type_to_id
=
metadata
.
node_type_to_id
if
metadata
else
None
edge_type_to_id
=
metadata
.
edge_type_to_id
if
metadata
else
None
return
FusedCSCSamplingGraph
(
return
FusedCSCSamplingGraph
(
torch
.
ops
.
graphbolt
.
from_fused_csc
(
torch
.
ops
.
graphbolt
.
from_fused_csc
(
csc_indptr
,
csc_indptr
,
indices
,
indices
,
node_type_offset
,
node_type_offset
,
type_per_edge
,
type_per_edge
,
node_type_to_id
,
edge_type_to_id
,
edge_attributes
,
edge_attributes
,
),
),
metadata
,
metadata
,
...
@@ -1046,7 +1050,11 @@ def from_dglgraph(
...
@@ -1046,7 +1050,11 @@ def from_dglgraph(
# Obtain CSC matrix.
# Obtain CSC matrix.
indptr
,
indices
,
edge_ids
=
homo_g
.
adj_tensors
(
"csc"
)
indptr
,
indices
,
edge_ids
=
homo_g
.
adj_tensors
(
"csc"
)
ntype_count
.
insert
(
0
,
0
)
ntype_count
.
insert
(
0
,
0
)
node_type_offset
=
torch
.
cumsum
(
torch
.
LongTensor
(
ntype_count
),
0
)
node_type_offset
=
(
None
if
is_homogeneous
else
torch
.
cumsum
(
torch
.
LongTensor
(
ntype_count
),
0
)
)
# Assign edge type according to the order of CSC matrix.
# Assign edge type according to the order of CSC matrix.
type_per_edge
=
None
if
is_homogeneous
else
homo_g
.
edata
[
ETYPE
][
edge_ids
]
type_per_edge
=
None
if
is_homogeneous
else
homo_g
.
edata
[
ETYPE
][
edge_ids
]
...
@@ -1056,12 +1064,16 @@ def from_dglgraph(
...
@@ -1056,12 +1064,16 @@ def from_dglgraph(
# Assign edge attributes according to the original eids mapping.
# Assign edge attributes according to the original eids mapping.
edge_attributes
[
ORIGINAL_EDGE_ID
]
=
homo_g
.
edata
[
EID
][
edge_ids
]
edge_attributes
[
ORIGINAL_EDGE_ID
]
=
homo_g
.
edata
[
EID
][
edge_ids
]
node_type_to_id
=
metadata
.
node_type_to_id
if
metadata
else
None
edge_type_to_id
=
metadata
.
edge_type_to_id
if
metadata
else
None
return
FusedCSCSamplingGraph
(
return
FusedCSCSamplingGraph
(
torch
.
ops
.
graphbolt
.
from_fused_csc
(
torch
.
ops
.
graphbolt
.
from_fused_csc
(
indptr
,
indptr
,
indices
,
indices
,
node_type_offset
,
node_type_offset
,
type_per_edge
,
type_per_edge
,
node_type_to_id
,
edge_type_to_id
,
edge_attributes
,
edge_attributes
,
),
),
metadata
,
metadata
,
...
...
tests/distributed/test_partition.py
View file @
3cdc37cc
...
@@ -16,6 +16,7 @@ from dgl.distributed import (
...
@@ -16,6 +16,7 @@ from dgl.distributed import (
partition_graph
,
partition_graph
,
)
)
from
dgl.distributed.graph_partition_book
import
(
from
dgl.distributed.graph_partition_book
import
(
_etype_str_to_tuple
,
_etype_tuple_to_str
,
_etype_tuple_to_str
,
DEFAULT_ETYPE
,
DEFAULT_ETYPE
,
DEFAULT_NTYPE
,
DEFAULT_NTYPE
,
...
@@ -707,7 +708,7 @@ def test_convert_dgl_partition_to_csc_sampling_graph_homo(
...
@@ -707,7 +708,7 @@ def test_convert_dgl_partition_to_csc_sampling_graph_homo(
for
node_type
,
type_id
in
new_g
.
metadata
.
node_type_to_id
.
items
():
for
node_type
,
type_id
in
new_g
.
metadata
.
node_type_to_id
.
items
():
assert
g
.
get_ntype_id
(
node_type
)
==
type_id
assert
g
.
get_ntype_id
(
node_type
)
==
type_id
for
edge_type
,
type_id
in
new_g
.
metadata
.
edge_type_to_id
.
items
():
for
edge_type
,
type_id
in
new_g
.
metadata
.
edge_type_to_id
.
items
():
assert
g
.
get_etype_id
(
edge_type
)
==
type_id
assert
g
.
get_etype_id
(
_etype_str_to_tuple
(
edge_type
)
)
==
type_id
@
pytest
.
mark
.
parametrize
(
"part_method"
,
[
"metis"
,
"random"
])
@
pytest
.
mark
.
parametrize
(
"part_method"
,
[
"metis"
,
"random"
])
...
@@ -738,7 +739,7 @@ def test_convert_dgl_partition_to_csc_sampling_graph_hetero(
...
@@ -738,7 +739,7 @@ def test_convert_dgl_partition_to_csc_sampling_graph_hetero(
for
node_type
,
type_id
in
new_g
.
metadata
.
node_type_to_id
.
items
():
for
node_type
,
type_id
in
new_g
.
metadata
.
node_type_to_id
.
items
():
assert
g
.
get_ntype_id
(
node_type
)
==
type_id
assert
g
.
get_ntype_id
(
node_type
)
==
type_id
for
edge_type
,
type_id
in
new_g
.
metadata
.
edge_type_to_id
.
items
():
for
edge_type
,
type_id
in
new_g
.
metadata
.
edge_type_to_id
.
items
():
assert
g
.
get_etype_id
(
edge_type
)
==
type_id
assert
g
.
get_etype_id
(
_etype_str_to_tuple
(
edge_type
)
)
==
type_id
assert
new_g
.
node_type_offset
is
None
assert
new_g
.
node_type_offset
is
None
assert
th
.
equal
(
orig_g
.
edata
[
dgl
.
ETYPE
],
new_g
.
type_per_edge
)
assert
th
.
equal
(
orig_g
.
edata
[
dgl
.
ETYPE
],
new_g
.
type_per_edge
)
...
...
tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
View file @
3cdc37cc
...
@@ -1354,7 +1354,7 @@ def test_from_dglgraph_homogeneous():
...
@@ -1354,7 +1354,7 @@ def test_from_dglgraph_homogeneous():
assert
gb_g
.
total_num_nodes
==
dgl_g
.
num_nodes
()
assert
gb_g
.
total_num_nodes
==
dgl_g
.
num_nodes
()
assert
gb_g
.
total_num_edges
==
dgl_g
.
num_edges
()
assert
gb_g
.
total_num_edges
==
dgl_g
.
num_edges
()
assert
torch
.
equal
(
gb_g
.
node_type_offset
,
torch
.
tensor
([
0
,
1000
]))
assert
gb_g
.
node_type_offset
is
None
assert
gb_g
.
type_per_edge
is
None
assert
gb_g
.
type_per_edge
is
None
assert
gb_g
.
metadata
is
None
assert
gb_g
.
metadata
is
None
...
...
tests/python/pytorch/graphbolt/impl/test_ondisk_dataset.py
View file @
3cdc37cc
...
@@ -1999,8 +1999,8 @@ def test_BuiltinDataset():
...
@@ -1999,8 +1999,8 @@ def test_BuiltinDataset():
"""Test BuiltinDataset."""
"""Test BuiltinDataset."""
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
# Case 1: download from DGL S3 storage.
# Case 1: download from DGL S3 storage.
dataset_name
=
"test-
only
"
dataset_name
=
"test-
dataset-231204
"
# Add
test-only
dataset to the builtin dataset list for testing only.
# Add dataset to the builtin dataset list for testing only.
gb
.
BuiltinDataset
.
_all_datasets
.
append
(
dataset_name
)
gb
.
BuiltinDataset
.
_all_datasets
.
append
(
dataset_name
)
dataset
=
gb
.
BuiltinDataset
(
name
=
dataset_name
,
root
=
test_dir
).
load
()
dataset
=
gb
.
BuiltinDataset
(
name
=
dataset_name
,
root
=
test_dir
).
load
()
assert
dataset
.
graph
is
not
None
assert
dataset
.
graph
is
not
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment