Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
317e70b4
"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "c104482b9c0ef18624fc4de4d2f0093122abc3c1"
Unverified
Commit
317e70b4
authored
May 17, 2023
by
peizhou001
Committed by
GitHub
May 17, 2023
Browse files
[GraphBolt] Add sampling graph C side code (#5697)
parent
29df6ec4
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
215 additions
and
0 deletions
+215
-0
graphbolt/include/csc_sampling_graph.h
graphbolt/include/csc_sampling_graph.h
+163
-0
graphbolt/include/graph.h
graphbolt/include/graph.h
+0
-0
graphbolt/src/csc_sampling_graph.cc
graphbolt/src/csc_sampling_graph.cc
+52
-0
graphbolt/src/graph.cc
graphbolt/src/graph.cc
+0
-0
No files found.
graphbolt/include/csc_sampling_graph.h
0 → 100644
View file @
317e70b4
/**
* Copyright (c) 2023 by Contributors
* @file graphbolt/include/csc_sampling_graph.h
* @brief Header file of csc sampling graph.
*/
#include <torch/custom_class.h>
#include <torch/torch.h>
#include <string>
#include <vector>
namespace
graphbolt
{
namespace
sampling
{
using
StringList
=
std
::
vector
<
std
::
string
>
;
/**
* @brief Structure representing heterogeneous information about a graph.
*
* Example usage:
*
* Suppose the graph has 6 edges
* node_offset = [0, 2, 4]
* edge_types = [0, 1, 0, 2, 1, 2]
* HeteroInfo info({"A", "B", "C"}, {"X", "Y", "Z"}, node_offset, edge_types);
*
* This example creates a `HeteroInfo` object with three node types ("A", "B",
* "C") and three edge types ("X", "Y", "Z"). The `node_offset` tensor
* represents the offset array of node type, the given array indicates that node
* [0, 2) has type "A", [2, 4) has type "B", and [4, 6) has type "C". And the
* `edge_types` tensor represents the type id of each edge.
*/
struct
HeteroInfo
{
/**
* @brief Constructs a new `HeteroInfo` object.
* @param ntypes List of node types in the graph.
* @param etypes List of edge types in the graph.
* @param node_type_offset Offset array of node type. It is assumed that nodes
* of same type have consecutive ids.
* @param type_per_edge Type id of each edge, where type id is the
* corresponding index of `edge_types`.
*/
HeteroInfo
(
const
StringList
&
ntypes
,
const
StringList
&
etypes
,
torch
::
Tensor
&
node_type_offset
,
torch
::
Tensor
&
type_per_edge
)
:
node_types
(
ntypes
),
edge_types
(
etypes
),
node_type_offset
(
node_type_offset
),
type_per_edge
(
type_per_edge
)
{}
/** @brief List of node types in the graph.*/
StringList
node_types
;
/** @brief List of edge types in the graph. */
StringList
edge_types
;
/**
* @brief Offset array of node type. The length of it is equal to number of
* node types.
*/
torch
::
Tensor
node_type_offset
;
/**
* @brief Type id of each edge, where type id is the corresponding index of
* edge_types. The length of it is equal to the number of edges.
*/
torch
::
Tensor
type_per_edge
;
};
/**
* @brief A sampling oriented csc format graph.
*/
class
CSCSamplingGraph
:
public
torch
::
CustomClassHolder
{
public:
/**
* @brief Constructor for CSC with data.
* @param num_rows The number of nodes in the graph.
* @param indptr The CSC format index pointer array.
* @param indices The CSC format index array.
* @param hetero_info Heterogeneous graph information, if present. Nullptr
* means it is a homogeneous graph.
*/
CSCSamplingGraph
(
int64_t
num_nodes
,
torch
::
Tensor
&
indptr
,
torch
::
Tensor
&
indices
,
const
std
::
shared_ptr
<
HeteroInfo
>&
hetero_info
);
/**
* @brief Create a homogeneous CSC graph from tensors of CSC format.
* @param num_nodes The number of nodes in the graph.
* @param indptr Index pointer array of the CSC.
* @param indices Indices array of the CSC.
*
* @return CSCSamplingGraph
*/
static
c10
::
intrusive_ptr
<
CSCSamplingGraph
>
FromCSC
(
int64_t
num_nodes
,
torch
::
Tensor
indptr
,
torch
::
Tensor
indices
);
/**
* @brief Create a heterogeneous CSC graph from tensors of CSC format.
* @param num_rows The number of nodes in the graph.
* @param indptr Index pointer array of the CSC.
* @param indices Indices array of the CSC.
* @param ntypes A list of node types, if present.
* @param etypes A list of edge types, if present.
* @param node_type_offset_ A tensor representing the offset of node types, if
* present.
* @param type_per_edge_ A tensor representing the type of each edge, if
* present.
*
* @return CSCSamplingGraph
*/
static
c10
::
intrusive_ptr
<
CSCSamplingGraph
>
FromCSCWithHeteroInfo
(
int64_t
num_nodes
,
torch
::
Tensor
indptr
,
torch
::
Tensor
indices
,
const
StringList
&
ntypes
,
const
StringList
&
etypes
,
torch
::
Tensor
node_type_offset
,
torch
::
Tensor
type_per_edge
);
/** @brief Get the number of nodes. */
int64_t
NumNodes
()
const
{
return
num_nodes_
;
}
/** @brief Get the number of edges. */
int64_t
NumEdges
()
const
{
return
indices_
.
size
(
0
);
}
/** @brief Get the csc index pointer tensor. */
const
torch
::
Tensor
CSCIndptr
()
const
{
return
indptr_
;
}
/** @brief Get the index tensor. */
const
torch
::
Tensor
Indices
()
const
{
return
indices_
;
}
/** @brief Check if the graph is heterogeneous. */
inline
bool
IsHeterogeneous
()
const
{
return
hetero_info_
!=
nullptr
;
}
/** @brief Get the node type offset tensor for a heterogeneous graph. */
inline
const
torch
::
Tensor
NodeTypeOffset
()
const
{
return
hetero_info_
->
node_type_offset
;
}
/** @brief Get the list of node types for a heterogeneous graph. */
inline
StringList
&
NodeTypes
()
const
{
return
hetero_info_
->
node_types
;
}
/** @brief Get the list of edge types for a heterogeneous graph. */
inline
const
StringList
&
EdgeTypes
()
const
{
return
hetero_info_
->
edge_types
;
}
/** @brief Get the edge type tensor for a heterogeneous graph. */
inline
const
torch
::
Tensor
TypePerEdge
()
const
{
return
hetero_info_
->
type_per_edge
;
}
private:
/** @brief The number of nodes of the graph. */
int64_t
num_nodes_
=
0
;
/** @brief CSC format index pointer array. */
torch
::
Tensor
indptr_
;
/** @brief CSC format index array. */
torch
::
Tensor
indices_
;
/** @brief Heterogeneous graph information, if present. */
std
::
shared_ptr
<
HeteroInfo
>
hetero_info_
;
};
}
// namespace sampling
}
// namespace graphbolt
graphbolt/include/graph.h
deleted
100644 → 0
View file @
29df6ec4
graphbolt/src/csc_sampling_graph.cc
0 → 100644
View file @
317e70b4
/**
* Copyright (c) 2023 by Contributors
* @file graphbolt/include/csc_sampling_graph.cc
* @brief Source file of sampling graph.
*/
#include "csc_sampling_graph.h"
namespace
graphbolt
{
namespace
sampling
{
CSCSamplingGraph
::
CSCSamplingGraph
(
int64_t
num_nodes
,
torch
::
Tensor
&
indptr
,
torch
::
Tensor
&
indices
,
const
std
::
shared_ptr
<
HeteroInfo
>&
hetero_info
)
:
num_nodes_
(
num_nodes
),
indptr_
(
indptr
),
indices_
(
indices
),
hetero_info_
(
hetero_info
)
{
TORCH_CHECK
(
num_nodes
>=
0
);
TORCH_CHECK
(
indptr
.
dim
()
==
1
);
TORCH_CHECK
(
indices
.
dim
()
==
1
);
TORCH_CHECK
(
indptr
.
size
(
0
)
==
num_nodes
+
1
);
TORCH_CHECK
(
indptr
.
device
()
==
indices
.
device
());
}
c10
::
intrusive_ptr
<
CSCSamplingGraph
>
CSCSamplingGraph
::
FromCSC
(
int64_t
num_nodes
,
torch
::
Tensor
indptr
,
torch
::
Tensor
indices
)
{
return
c10
::
make_intrusive
<
CSCSamplingGraph
>
(
num_nodes
,
indptr
,
indices
,
nullptr
);
}
c10
::
intrusive_ptr
<
CSCSamplingGraph
>
CSCSamplingGraph
::
FromCSCWithHeteroInfo
(
int64_t
num_nodes
,
torch
::
Tensor
indptr
,
torch
::
Tensor
indices
,
const
StringList
&
ntypes
,
const
StringList
&
etypes
,
torch
::
Tensor
node_type_offset
,
torch
::
Tensor
type_per_edge
)
{
TORCH_CHECK
(
node_type_offset
.
size
(
0
)
>
0
);
TORCH_CHECK
(
node_type_offset
.
dim
()
==
1
);
TORCH_CHECK
(
type_per_edge
.
size
(
0
)
>
0
);
TORCH_CHECK
(
type_per_edge
.
dim
()
==
1
);
TORCH_CHECK
(
node_type_offset
.
device
()
==
type_per_edge
.
device
());
TORCH_CHECK
(
type_per_edge
.
device
()
==
indices
.
device
());
TORCH_CHECK
(
!
ntypes
.
empty
());
TORCH_CHECK
(
!
etypes
.
empty
());
auto
hetero_info
=
std
::
make_shared
<
HeteroInfo
>
(
ntypes
,
etypes
,
node_type_offset
,
type_per_edge
);
return
c10
::
make_intrusive
<
CSCSamplingGraph
>
(
num_nodes
,
indptr
,
indices
,
hetero_info
);
}
}
// namespace sampling
}
// namespace graphbolt
graphbolt/src/graph.cc
deleted
100644 → 0
View file @
29df6ec4
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment