"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "072d8b2280569a2d13b91d3ed51546d201a57366"
Unverified Commit 91fe0c90 authored by LastWhisper's avatar LastWhisper Committed by GitHub
Browse files

[GraphBolt] Refactor CSCSamplingGraph and update the corresponding method. (#6515)

parent 3f3652e0
...@@ -56,7 +56,7 @@ Standard Implementations ...@@ -56,7 +56,7 @@ Standard Implementations
OnDiskDataset OnDiskDataset
BuiltinDataset BuiltinDataset
CSCSamplingGraph FusedCSCSamplingGraph
UniformNegativeSampler UniformNegativeSampler
NeighborSampler NeighborSampler
LayerNeighborSampler LayerNeighborSampler
......
...@@ -106,7 +106,7 @@ def create_dataloader( ...@@ -106,7 +106,7 @@ def create_dataloader(
# Sample neighbors for each seed node in the mini-batch. # Sample neighbors for each seed node in the mini-batch.
# `graph`: # `graph`:
# The graph(CSCSamplingGraph) from which to sample neighbors. # The graph(FusedCSCSamplingGraph) from which to sample neighbors.
# `fanouts`: # `fanouts`:
# The number of neighbors to sample for each node in each layer. # The number of neighbors to sample for each node in each layer.
datapipe = datapipe.sample_neighbor(graph, fanouts=fanouts) datapipe = datapipe.sample_neighbor(graph, fanouts=fanouts)
...@@ -166,7 +166,7 @@ def rel_graph_embed(graph, embed_size): ...@@ -166,7 +166,7 @@ def rel_graph_embed(graph, embed_size):
Parameters Parameters
---------- ----------
graph : CSCSamplingGraph graph : FusedCSCSamplingGraph
The graph for which to create the heterogeneous embedding layer. The graph for which to create the heterogeneous embedding layer.
embed_size : int embed_size : int
The size of the embedding vectors. The size of the embedding vectors.
......
/** /**
* Copyright (c) 2023 by Contributors * Copyright (c) 2023 by Contributors
* @file graphbolt/csc_sampling_graph.h * @file graphbolt/fused_csc_sampling_graph.h
* @brief Header file of csc sampling graph. * @brief Header file of csc sampling graph.
*/ */
#ifndef GRAPHBOLT_CSC_SAMPLING_GRAPH_H_ #ifndef GRAPHBOLT_CSC_SAMPLING_GRAPH_H_
...@@ -39,18 +39,18 @@ struct SamplerArgs<SamplerType::LABOR> { ...@@ -39,18 +39,18 @@ struct SamplerArgs<SamplerType::LABOR> {
* Suppose the graph has 3 node types, 3 edge types and 6 edges * Suppose the graph has 3 node types, 3 edge types and 6 edges
* auto node_type_offset = {0, 2, 4, 6} * auto node_type_offset = {0, 2, 4, 6}
* auto type_per_edge = {0, 1, 0, 2, 1, 2} * auto type_per_edge = {0, 1, 0, 2, 1, 2}
* auto graph = CSCSamplingGraph(..., ..., node_type_offset, type_per_edge) * auto graph = FusedCSCSamplingGraph(..., ..., node_type_offset, type_per_edge)
* *
* The `node_type_offset` tensor represents the offset array of node type, the * The `node_type_offset` tensor represents the offset array of node type, the
* given array indicates that node [0, 2) has type id 0, [2, 4) has type id 1, * given array indicates that node [0, 2) has type id 0, [2, 4) has type id 1,
* and [4, 6) has type id 2. And the `type_per_edge` tensor represents the type * and [4, 6) has type id 2. And the `type_per_edge` tensor represents the type
* id of each edge. * id of each edge.
*/ */
class CSCSamplingGraph : public torch::CustomClassHolder { class FusedCSCSamplingGraph : public torch::CustomClassHolder {
public: public:
using EdgeAttrMap = torch::Dict<std::string, torch::Tensor>; using EdgeAttrMap = torch::Dict<std::string, torch::Tensor>;
/** @brief Default constructor. */ /** @brief Default constructor. */
CSCSamplingGraph() = default; FusedCSCSamplingGraph() = default;
/** /**
* @brief Constructor for CSC with data. * @brief Constructor for CSC with data.
...@@ -61,7 +61,7 @@ class CSCSamplingGraph : public torch::CustomClassHolder { ...@@ -61,7 +61,7 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
* @param type_per_edge A tensor representing the type of each edge, if * @param type_per_edge A tensor representing the type of each edge, if
* present. * present.
*/ */
CSCSamplingGraph( FusedCSCSamplingGraph(
const torch::Tensor& indptr, const torch::Tensor& indices, const torch::Tensor& indptr, const torch::Tensor& indices,
const torch::optional<torch::Tensor>& node_type_offset, const torch::optional<torch::Tensor>& node_type_offset,
const torch::optional<torch::Tensor>& type_per_edge, const torch::optional<torch::Tensor>& type_per_edge,
...@@ -76,9 +76,9 @@ class CSCSamplingGraph : public torch::CustomClassHolder { ...@@ -76,9 +76,9 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
* @param type_per_edge A tensor representing the type of each edge, if * @param type_per_edge A tensor representing the type of each edge, if
* present. * present.
* *
* @return CSCSamplingGraph * @return FusedCSCSamplingGraph
*/ */
static c10::intrusive_ptr<CSCSamplingGraph> FromCSC( static c10::intrusive_ptr<FusedCSCSamplingGraph> FromCSC(
const torch::Tensor& indptr, const torch::Tensor& indices, const torch::Tensor& indptr, const torch::Tensor& indices,
const torch::optional<torch::Tensor>& node_type_offset, const torch::optional<torch::Tensor>& node_type_offset,
const torch::optional<torch::Tensor>& type_per_edge, const torch::optional<torch::Tensor>& type_per_edge,
...@@ -155,7 +155,7 @@ class CSCSamplingGraph : public torch::CustomClassHolder { ...@@ -155,7 +155,7 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
/** /**
* @brief Pickle method for deserializing. * @brief Pickle method for deserializing.
* @param state The state of serialized CSCSamplingGraph. * @param state The state of serialized FusedCSCSamplingGraph.
*/ */
void SetState( void SetState(
const torch::Dict<std::string, torch::Dict<std::string, torch::Tensor>>& const torch::Dict<std::string, torch::Dict<std::string, torch::Tensor>>&
...@@ -163,7 +163,7 @@ class CSCSamplingGraph : public torch::CustomClassHolder { ...@@ -163,7 +163,7 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
/** /**
* @brief Pickle method for serializing. * @brief Pickle method for serializing.
* @returns The state of this CSCSamplingGraph. * @returns The state of this FusedCSCSamplingGraph.
*/ */
torch::Dict<std::string, torch::Dict<std::string, torch::Tensor>> GetState() torch::Dict<std::string, torch::Dict<std::string, torch::Tensor>> GetState()
const; const;
...@@ -246,18 +246,18 @@ class CSCSamplingGraph : public torch::CustomClassHolder { ...@@ -246,18 +246,18 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
* @brief Copy the graph to shared memory. * @brief Copy the graph to shared memory.
* @param shared_memory_name The name of the shared memory. * @param shared_memory_name The name of the shared memory.
* *
* @return A new CSCSamplingGraph object on shared memory. * @return A new FusedCSCSamplingGraph object on shared memory.
*/ */
c10::intrusive_ptr<CSCSamplingGraph> CopyToSharedMemory( c10::intrusive_ptr<FusedCSCSamplingGraph> CopyToSharedMemory(
const std::string& shared_memory_name); const std::string& shared_memory_name);
/** /**
* @brief Load the graph from shared memory. * @brief Load the graph from shared memory.
* @param shared_memory_name The name of the shared memory. * @param shared_memory_name The name of the shared memory.
* *
* @return A new CSCSamplingGraph object on shared memory. * @return A new FusedCSCSamplingGraph object on shared memory.
*/ */
static c10::intrusive_ptr<CSCSamplingGraph> LoadFromSharedMemory( static c10::intrusive_ptr<FusedCSCSamplingGraph> LoadFromSharedMemory(
const std::string& shared_memory_name); const std::string& shared_memory_name);
/** /**
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#ifndef GRAPHBOLT_SERIALIZE_H_ #ifndef GRAPHBOLT_SERIALIZE_H_
#define GRAPHBOLT_SERIALIZE_H_ #define GRAPHBOLT_SERIALIZE_H_
#include <graphbolt/csc_sampling_graph.h> #include <graphbolt/fused_csc_sampling_graph.h>
#include <torch/torch.h> #include <torch/torch.h>
#include <string> #include <string>
...@@ -15,53 +15,55 @@ ...@@ -15,53 +15,55 @@
/** /**
* @brief Overload stream operator to enable `torch::save()` and `torch::load()` * @brief Overload stream operator to enable `torch::save()` and `torch::load()`
* for CSCSamplingGraph. * for FusedCSCSamplingGraph.
*/ */
namespace torch { namespace torch {
/** /**
* @brief Overload input stream operator for CSCSamplingGraph deserialization. * @brief Overload input stream operator for FusedCSCSamplingGraph
* deserialization.
* @param archive Input stream for deserializing. * @param archive Input stream for deserializing.
* @param graph CSCSamplingGraph. * @param graph FusedCSCSamplingGraph.
* *
* @return archive * @return archive
*/ */
inline serialize::InputArchive& operator>>( inline serialize::InputArchive& operator>>(
serialize::InputArchive& archive, serialize::InputArchive& archive,
graphbolt::sampling::CSCSamplingGraph& graph); graphbolt::sampling::FusedCSCSamplingGraph& graph);
/** /**
* @brief Overload output stream operator for CSCSamplingGraph serialization. * @brief Overload output stream operator for FusedCSCSamplingGraph
* serialization.
* @param archive Output stream for serializing. * @param archive Output stream for serializing.
* @param graph CSCSamplingGraph. * @param graph FusedCSCSamplingGraph.
* *
* @return archive * @return archive
*/ */
inline serialize::OutputArchive& operator<<( inline serialize::OutputArchive& operator<<(
serialize::OutputArchive& archive, serialize::OutputArchive& archive,
const graphbolt::sampling::CSCSamplingGraph& graph); const graphbolt::sampling::FusedCSCSamplingGraph& graph);
} // namespace torch } // namespace torch
namespace graphbolt { namespace graphbolt {
/** /**
* @brief Load CSCSamplingGraph from file. * @brief Load FusedCSCSamplingGraph from file.
* @param filename File name to read. * @param filename File name to read.
* *
* @return CSCSamplingGraph. * @return FusedCSCSamplingGraph.
*/ */
c10::intrusive_ptr<sampling::CSCSamplingGraph> LoadCSCSamplingGraph( c10::intrusive_ptr<sampling::FusedCSCSamplingGraph> LoadFusedCSCSamplingGraph(
const std::string& filename); const std::string& filename);
/** /**
* @brief Save CSCSamplingGraph to file. * @brief Save FusedCSCSamplingGraph to file.
* @param graph CSCSamplingGraph to save. * @param graph FusedCSCSamplingGraph to save.
* @param filename File name to save. * @param filename File name to save.
* *
*/ */
void SaveCSCSamplingGraph( void SaveFusedCSCSamplingGraph(
c10::intrusive_ptr<sampling::CSCSamplingGraph> graph, c10::intrusive_ptr<sampling::FusedCSCSamplingGraph> graph,
const std::string& filename); const std::string& filename);
/** /**
......
/** /**
* Copyright (c) 2023 by Contributors * Copyright (c) 2023 by Contributors
* @file csc_sampling_graph.cc * @file fused_csc_sampling_graph.cc
* @brief Source file of sampling graph. * @brief Source file of sampling graph.
*/ */
#include <graphbolt/csc_sampling_graph.h> #include <graphbolt/fused_csc_sampling_graph.h>
#include <graphbolt/serialize.h> #include <graphbolt/serialize.h>
#include <torch/torch.h> #include <torch/torch.h>
...@@ -24,7 +24,7 @@ namespace sampling { ...@@ -24,7 +24,7 @@ namespace sampling {
static const int kPickleVersion = 6199; static const int kPickleVersion = 6199;
CSCSamplingGraph::CSCSamplingGraph( FusedCSCSamplingGraph::FusedCSCSamplingGraph(
const torch::Tensor& indptr, const torch::Tensor& indices, const torch::Tensor& indptr, const torch::Tensor& indices,
const torch::optional<torch::Tensor>& node_type_offset, const torch::optional<torch::Tensor>& node_type_offset,
const torch::optional<torch::Tensor>& type_per_edge, const torch::optional<torch::Tensor>& type_per_edge,
...@@ -39,7 +39,7 @@ CSCSamplingGraph::CSCSamplingGraph( ...@@ -39,7 +39,7 @@ CSCSamplingGraph::CSCSamplingGraph(
TORCH_CHECK(indptr.device() == indices.device()); TORCH_CHECK(indptr.device() == indices.device());
} }
c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSC( c10::intrusive_ptr<FusedCSCSamplingGraph> FusedCSCSamplingGraph::FromCSC(
const torch::Tensor& indptr, const torch::Tensor& indices, const torch::Tensor& indptr, const torch::Tensor& indices,
const torch::optional<torch::Tensor>& node_type_offset, const torch::optional<torch::Tensor>& node_type_offset,
const torch::optional<torch::Tensor>& type_per_edge, const torch::optional<torch::Tensor>& type_per_edge,
...@@ -57,37 +57,40 @@ c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSC( ...@@ -57,37 +57,40 @@ c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSC(
TORCH_CHECK(pair.value().size(0) == indices.size(0)); TORCH_CHECK(pair.value().size(0) == indices.size(0));
} }
} }
return c10::make_intrusive<CSCSamplingGraph>( return c10::make_intrusive<FusedCSCSamplingGraph>(
indptr, indices, node_type_offset, type_per_edge, edge_attributes); indptr, indices, node_type_offset, type_per_edge, edge_attributes);
} }
void CSCSamplingGraph::Load(torch::serialize::InputArchive& archive) { void FusedCSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
const int64_t magic_num = const int64_t magic_num =
read_from_archive(archive, "CSCSamplingGraph/magic_num").toInt(); read_from_archive(archive, "FusedCSCSamplingGraph/magic_num").toInt();
TORCH_CHECK( TORCH_CHECK(
magic_num == kCSCSamplingGraphSerializeMagic, magic_num == kCSCSamplingGraphSerializeMagic,
"Magic numbers mismatch when loading CSCSamplingGraph."); "Magic numbers mismatch when loading FusedCSCSamplingGraph.");
indptr_ = read_from_archive(archive, "CSCSamplingGraph/indptr").toTensor(); indptr_ =
indices_ = read_from_archive(archive, "CSCSamplingGraph/indices").toTensor(); read_from_archive(archive, "FusedCSCSamplingGraph/indptr").toTensor();
if (read_from_archive(archive, "CSCSamplingGraph/has_node_type_offset") indices_ =
read_from_archive(archive, "FusedCSCSamplingGraph/indices").toTensor();
if (read_from_archive(archive, "FusedCSCSamplingGraph/has_node_type_offset")
.toBool()) { .toBool()) {
node_type_offset_ = node_type_offset_ =
read_from_archive(archive, "CSCSamplingGraph/node_type_offset") read_from_archive(archive, "FusedCSCSamplingGraph/node_type_offset")
.toTensor(); .toTensor();
} }
if (read_from_archive(archive, "CSCSamplingGraph/has_type_per_edge") if (read_from_archive(archive, "FusedCSCSamplingGraph/has_type_per_edge")
.toBool()) { .toBool()) {
type_per_edge_ = type_per_edge_ =
read_from_archive(archive, "CSCSamplingGraph/type_per_edge").toTensor(); read_from_archive(archive, "FusedCSCSamplingGraph/type_per_edge")
.toTensor();
} }
// Optional edge attributes. // Optional edge attributes.
torch::IValue has_edge_attributes; torch::IValue has_edge_attributes;
if (archive.try_read( if (archive.try_read(
"CSCSamplingGraph/has_edge_attributes", has_edge_attributes) && "FusedCSCSamplingGraph/has_edge_attributes", has_edge_attributes) &&
has_edge_attributes.toBool()) { has_edge_attributes.toBool()) {
torch::Dict<torch::IValue, torch::IValue> generic_dict = torch::Dict<torch::IValue, torch::IValue> generic_dict =
read_from_archive(archive, "CSCSamplingGraph/edge_attributes") read_from_archive(archive, "FusedCSCSamplingGraph/edge_attributes")
.toGenericDict(); .toGenericDict();
EdgeAttrMap target_dict; EdgeAttrMap target_dict;
for (const auto& pair : generic_dict) { for (const auto& pair : generic_dict) {
...@@ -101,29 +104,35 @@ void CSCSamplingGraph::Load(torch::serialize::InputArchive& archive) { ...@@ -101,29 +104,35 @@ void CSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
} }
} }
void CSCSamplingGraph::Save(torch::serialize::OutputArchive& archive) const { void FusedCSCSamplingGraph::Save(
archive.write("CSCSamplingGraph/magic_num", kCSCSamplingGraphSerializeMagic); torch::serialize::OutputArchive& archive) const {
archive.write("CSCSamplingGraph/indptr", indptr_); archive.write(
archive.write("CSCSamplingGraph/indices", indices_); "FusedCSCSamplingGraph/magic_num", kCSCSamplingGraphSerializeMagic);
archive.write("FusedCSCSamplingGraph/indptr", indptr_);
archive.write("FusedCSCSamplingGraph/indices", indices_);
archive.write( archive.write(
"CSCSamplingGraph/has_node_type_offset", node_type_offset_.has_value()); "FusedCSCSamplingGraph/has_node_type_offset",
node_type_offset_.has_value());
if (node_type_offset_) { if (node_type_offset_) {
archive.write( archive.write(
"CSCSamplingGraph/node_type_offset", node_type_offset_.value()); "FusedCSCSamplingGraph/node_type_offset", node_type_offset_.value());
} }
archive.write( archive.write(
"CSCSamplingGraph/has_type_per_edge", type_per_edge_.has_value()); "FusedCSCSamplingGraph/has_type_per_edge", type_per_edge_.has_value());
if (type_per_edge_) { if (type_per_edge_) {
archive.write("CSCSamplingGraph/type_per_edge", type_per_edge_.value()); archive.write(
"FusedCSCSamplingGraph/type_per_edge", type_per_edge_.value());
} }
archive.write( archive.write(
"CSCSamplingGraph/has_edge_attributes", edge_attributes_.has_value()); "FusedCSCSamplingGraph/has_edge_attributes",
edge_attributes_.has_value());
if (edge_attributes_) { if (edge_attributes_) {
archive.write("CSCSamplingGraph/edge_attributes", edge_attributes_.value()); archive.write(
"FusedCSCSamplingGraph/edge_attributes", edge_attributes_.value());
} }
} }
void CSCSamplingGraph::SetState( void FusedCSCSamplingGraph::SetState(
const torch::Dict<std::string, torch::Dict<std::string, torch::Tensor>>& const torch::Dict<std::string, torch::Dict<std::string, torch::Tensor>>&
state) { state) {
// State is a dict of dicts. The tensor-type attributes are stored in the dict // State is a dict of dicts. The tensor-type attributes are stored in the dict
...@@ -133,7 +142,7 @@ void CSCSamplingGraph::SetState( ...@@ -133,7 +142,7 @@ void CSCSamplingGraph::SetState(
TORCH_CHECK( TORCH_CHECK(
independent_tensors.at("version_number") independent_tensors.at("version_number")
.equal(torch::tensor({kPickleVersion})), .equal(torch::tensor({kPickleVersion})),
"Version number mismatches when loading pickled CSCSamplingGraph.") "Version number mismatches when loading pickled FusedCSCSamplingGraph.")
indptr_ = independent_tensors.at("indptr"); indptr_ = independent_tensors.at("indptr");
indices_ = independent_tensors.at("indices"); indices_ = independent_tensors.at("indices");
if (independent_tensors.find("node_type_offset") != if (independent_tensors.find("node_type_offset") !=
...@@ -149,7 +158,7 @@ void CSCSamplingGraph::SetState( ...@@ -149,7 +158,7 @@ void CSCSamplingGraph::SetState(
} }
torch::Dict<std::string, torch::Dict<std::string, torch::Tensor>> torch::Dict<std::string, torch::Dict<std::string, torch::Tensor>>
CSCSamplingGraph::GetState() const { FusedCSCSamplingGraph::GetState() const {
// State is a dict of dicts. The tensor-type attributes are stored in the dict // State is a dict of dicts. The tensor-type attributes are stored in the dict
// with key "independent_tensors". The dict-type attributes (edge_attributes) // with key "independent_tensors". The dict-type attributes (edge_attributes)
// are stored directly with their name as the key. // are stored directly with their name as the key.
...@@ -173,7 +182,7 @@ CSCSamplingGraph::GetState() const { ...@@ -173,7 +182,7 @@ CSCSamplingGraph::GetState() const {
return state; return state;
} }
c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph( c10::intrusive_ptr<SampledSubgraph> FusedCSCSamplingGraph::InSubgraph(
const torch::Tensor& nodes) const { const torch::Tensor& nodes) const {
using namespace torch::indexing; using namespace torch::indexing;
const int32_t kDefaultGrainSize = 100; const int32_t kDefaultGrainSize = 100;
...@@ -296,7 +305,7 @@ auto GetPickFn( ...@@ -296,7 +305,7 @@ auto GetPickFn(
} }
template <typename NumPickFn, typename PickFn> template <typename NumPickFn, typename PickFn>
c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighborsImpl( c10::intrusive_ptr<SampledSubgraph> FusedCSCSamplingGraph::SampleNeighborsImpl(
const torch::Tensor& nodes, bool return_eids, NumPickFn num_pick_fn, const torch::Tensor& nodes, bool return_eids, NumPickFn num_pick_fn,
PickFn pick_fn) const { PickFn pick_fn) const {
const int64_t num_nodes = nodes.size(0); const int64_t num_nodes = nodes.size(0);
...@@ -413,7 +422,7 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighborsImpl( ...@@ -413,7 +422,7 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighborsImpl(
subgraph_reverse_edge_ids, subgraph_type_per_edge); subgraph_reverse_edge_ids, subgraph_type_per_edge);
} }
c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors( c10::intrusive_ptr<SampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
const torch::Tensor& nodes, const std::vector<int64_t>& fanouts, const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
bool replace, bool layer, bool return_eids, bool replace, bool layer, bool return_eids,
torch::optional<std::string> probs_name) const { torch::optional<std::string> probs_name) const {
...@@ -451,7 +460,7 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors( ...@@ -451,7 +460,7 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
} }
std::tuple<torch::Tensor, torch::Tensor> std::tuple<torch::Tensor, torch::Tensor>
CSCSamplingGraph::SampleNegativeEdgesUniform( FusedCSCSamplingGraph::SampleNegativeEdgesUniform(
const std::tuple<torch::Tensor, torch::Tensor>& node_pairs, const std::tuple<torch::Tensor, torch::Tensor>& node_pairs,
int64_t negative_ratio, int64_t max_node_id) const { int64_t negative_ratio, int64_t max_node_id) const {
torch::Tensor pos_src; torch::Tensor pos_src;
...@@ -462,15 +471,15 @@ CSCSamplingGraph::SampleNegativeEdgesUniform( ...@@ -462,15 +471,15 @@ CSCSamplingGraph::SampleNegativeEdgesUniform(
return std::make_tuple(neg_src, neg_dst); return std::make_tuple(neg_src, neg_dst);
} }
static c10::intrusive_ptr<CSCSamplingGraph> BuildGraphFromSharedMemoryHelper( static c10::intrusive_ptr<FusedCSCSamplingGraph>
SharedMemoryHelper&& helper) { BuildGraphFromSharedMemoryHelper(SharedMemoryHelper&& helper) {
helper.InitializeRead(); helper.InitializeRead();
auto indptr = helper.ReadTorchTensor(); auto indptr = helper.ReadTorchTensor();
auto indices = helper.ReadTorchTensor(); auto indices = helper.ReadTorchTensor();
auto node_type_offset = helper.ReadTorchTensor(); auto node_type_offset = helper.ReadTorchTensor();
auto type_per_edge = helper.ReadTorchTensor(); auto type_per_edge = helper.ReadTorchTensor();
auto edge_attributes = helper.ReadTorchTensorDict(); auto edge_attributes = helper.ReadTorchTensorDict();
auto graph = c10::make_intrusive<CSCSamplingGraph>( auto graph = c10::make_intrusive<FusedCSCSamplingGraph>(
indptr.value(), indices.value(), node_type_offset, type_per_edge, indptr.value(), indices.value(), node_type_offset, type_per_edge,
edge_attributes); edge_attributes);
auto shared_memory = helper.ReleaseSharedMemory(); auto shared_memory = helper.ReleaseSharedMemory();
...@@ -479,7 +488,8 @@ static c10::intrusive_ptr<CSCSamplingGraph> BuildGraphFromSharedMemoryHelper( ...@@ -479,7 +488,8 @@ static c10::intrusive_ptr<CSCSamplingGraph> BuildGraphFromSharedMemoryHelper(
return graph; return graph;
} }
c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::CopyToSharedMemory( c10::intrusive_ptr<FusedCSCSamplingGraph>
FusedCSCSamplingGraph::CopyToSharedMemory(
const std::string& shared_memory_name) { const std::string& shared_memory_name) {
SharedMemoryHelper helper(shared_memory_name, SERIALIZED_METAINFO_SIZE_MAX); SharedMemoryHelper helper(shared_memory_name, SERIALIZED_METAINFO_SIZE_MAX);
helper.WriteTorchTensor(indptr_); helper.WriteTorchTensor(indptr_);
...@@ -491,13 +501,14 @@ c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::CopyToSharedMemory( ...@@ -491,13 +501,14 @@ c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::CopyToSharedMemory(
return BuildGraphFromSharedMemoryHelper(std::move(helper)); return BuildGraphFromSharedMemoryHelper(std::move(helper));
} }
c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::LoadFromSharedMemory( c10::intrusive_ptr<FusedCSCSamplingGraph>
FusedCSCSamplingGraph::LoadFromSharedMemory(
const std::string& shared_memory_name) { const std::string& shared_memory_name) {
SharedMemoryHelper helper(shared_memory_name, SERIALIZED_METAINFO_SIZE_MAX); SharedMemoryHelper helper(shared_memory_name, SERIALIZED_METAINFO_SIZE_MAX);
return BuildGraphFromSharedMemoryHelper(std::move(helper)); return BuildGraphFromSharedMemoryHelper(std::move(helper));
} }
void CSCSamplingGraph::HoldSharedMemoryObject( void FusedCSCSamplingGraph::HoldSharedMemoryObject(
SharedMemoryPtr tensor_metadata_shm, SharedMemoryPtr tensor_data_shm) { SharedMemoryPtr tensor_metadata_shm, SharedMemoryPtr tensor_data_shm) {
tensor_metadata_shm_ = std::move(tensor_metadata_shm); tensor_metadata_shm_ = std::move(tensor_metadata_shm);
tensor_data_shm_ = std::move(tensor_data_shm); tensor_data_shm_ = std::move(tensor_data_shm);
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
* @brief Graph bolt library Python binding. * @brief Graph bolt library Python binding.
*/ */
#include <graphbolt/csc_sampling_graph.h> #include <graphbolt/fused_csc_sampling_graph.h>
#include <graphbolt/isin.h> #include <graphbolt/isin.h>
#include <graphbolt/serialize.h> #include <graphbolt/serialize.h>
#include <graphbolt/unique_and_compact.h> #include <graphbolt/unique_and_compact.h>
...@@ -26,43 +26,44 @@ TORCH_LIBRARY(graphbolt, m) { ...@@ -26,43 +26,44 @@ TORCH_LIBRARY(graphbolt, m) {
&SampledSubgraph::original_column_node_ids) &SampledSubgraph::original_column_node_ids)
.def_readwrite("original_edge_ids", &SampledSubgraph::original_edge_ids) .def_readwrite("original_edge_ids", &SampledSubgraph::original_edge_ids)
.def_readwrite("type_per_edge", &SampledSubgraph::type_per_edge); .def_readwrite("type_per_edge", &SampledSubgraph::type_per_edge);
m.class_<CSCSamplingGraph>("CSCSamplingGraph") m.class_<FusedCSCSamplingGraph>("FusedCSCSamplingGraph")
.def("num_nodes", &CSCSamplingGraph::NumNodes) .def("num_nodes", &FusedCSCSamplingGraph::NumNodes)
.def("num_edges", &CSCSamplingGraph::NumEdges) .def("num_edges", &FusedCSCSamplingGraph::NumEdges)
.def("csc_indptr", &CSCSamplingGraph::CSCIndptr) .def("csc_indptr", &FusedCSCSamplingGraph::CSCIndptr)
.def("indices", &CSCSamplingGraph::Indices) .def("indices", &FusedCSCSamplingGraph::Indices)
.def("node_type_offset", &CSCSamplingGraph::NodeTypeOffset) .def("node_type_offset", &FusedCSCSamplingGraph::NodeTypeOffset)
.def("type_per_edge", &CSCSamplingGraph::TypePerEdge) .def("type_per_edge", &FusedCSCSamplingGraph::TypePerEdge)
.def("edge_attributes", &CSCSamplingGraph::EdgeAttributes) .def("edge_attributes", &FusedCSCSamplingGraph::EdgeAttributes)
.def("set_csc_indptr", &CSCSamplingGraph::SetCSCIndptr) .def("set_csc_indptr", &FusedCSCSamplingGraph::SetCSCIndptr)
.def("set_indices", &CSCSamplingGraph::SetIndices) .def("set_indices", &FusedCSCSamplingGraph::SetIndices)
.def("set_node_type_offset", &CSCSamplingGraph::SetNodeTypeOffset) .def("set_node_type_offset", &FusedCSCSamplingGraph::SetNodeTypeOffset)
.def("set_type_per_edge", &CSCSamplingGraph::SetTypePerEdge) .def("set_type_per_edge", &FusedCSCSamplingGraph::SetTypePerEdge)
.def("set_edge_attributes", &CSCSamplingGraph::SetEdgeAttributes) .def("set_edge_attributes", &FusedCSCSamplingGraph::SetEdgeAttributes)
.def("in_subgraph", &CSCSamplingGraph::InSubgraph) .def("in_subgraph", &FusedCSCSamplingGraph::InSubgraph)
.def("sample_neighbors", &CSCSamplingGraph::SampleNeighbors) .def("sample_neighbors", &FusedCSCSamplingGraph::SampleNeighbors)
.def( .def(
"sample_negative_edges_uniform", "sample_negative_edges_uniform",
&CSCSamplingGraph::SampleNegativeEdgesUniform) &FusedCSCSamplingGraph::SampleNegativeEdgesUniform)
.def("copy_to_shared_memory", &CSCSamplingGraph::CopyToSharedMemory) .def("copy_to_shared_memory", &FusedCSCSamplingGraph::CopyToSharedMemory)
.def_pickle( .def_pickle(
// __getstate__ // __getstate__
[](const c10::intrusive_ptr<CSCSamplingGraph>& self) [](const c10::intrusive_ptr<FusedCSCSamplingGraph>& self)
-> torch::Dict< -> torch::Dict<
std::string, torch::Dict<std::string, torch::Tensor>> { std::string, torch::Dict<std::string, torch::Tensor>> {
return self->GetState(); return self->GetState();
}, },
// __setstate__ // __setstate__
[](torch::Dict<std::string, torch::Dict<std::string, torch::Tensor>> [](torch::Dict<std::string, torch::Dict<std::string, torch::Tensor>>
state) -> c10::intrusive_ptr<CSCSamplingGraph> { state) -> c10::intrusive_ptr<FusedCSCSamplingGraph> {
auto g = c10::make_intrusive<CSCSamplingGraph>(); auto g = c10::make_intrusive<FusedCSCSamplingGraph>();
g->SetState(state); g->SetState(state);
return g; return g;
}); });
m.def("from_csc", &CSCSamplingGraph::FromCSC); m.def("from_fused_csc", &FusedCSCSamplingGraph::FromCSC);
m.def("load_csc_sampling_graph", &LoadCSCSamplingGraph); m.def("load_fused_csc_sampling_graph", &LoadFusedCSCSamplingGraph);
m.def("save_csc_sampling_graph", &SaveCSCSamplingGraph); m.def("save_fused_csc_sampling_graph", &SaveFusedCSCSamplingGraph);
m.def("load_from_shared_memory", &CSCSamplingGraph::LoadFromSharedMemory); m.def(
"load_from_shared_memory", &FusedCSCSamplingGraph::LoadFromSharedMemory);
m.def("unique_and_compact", &UniqueAndCompact); m.def("unique_and_compact", &UniqueAndCompact);
m.def("isin", &IsIn); m.def("isin", &IsIn);
m.def("index_select", &ops::IndexSelect); m.def("index_select", &ops::IndexSelect);
......
...@@ -11,14 +11,14 @@ namespace torch { ...@@ -11,14 +11,14 @@ namespace torch {
serialize::InputArchive& operator>>( serialize::InputArchive& operator>>(
serialize::InputArchive& archive, serialize::InputArchive& archive,
graphbolt::sampling::CSCSamplingGraph& graph) { graphbolt::sampling::FusedCSCSamplingGraph& graph) {
graph.Load(archive); graph.Load(archive);
return archive; return archive;
} }
serialize::OutputArchive& operator<<( serialize::OutputArchive& operator<<(
serialize::OutputArchive& archive, serialize::OutputArchive& archive,
const graphbolt::sampling::CSCSamplingGraph& graph) { const graphbolt::sampling::FusedCSCSamplingGraph& graph) {
graph.Save(archive); graph.Save(archive);
return archive; return archive;
} }
...@@ -27,15 +27,15 @@ serialize::OutputArchive& operator<<( ...@@ -27,15 +27,15 @@ serialize::OutputArchive& operator<<(
namespace graphbolt { namespace graphbolt {
c10::intrusive_ptr<sampling::CSCSamplingGraph> LoadCSCSamplingGraph( c10::intrusive_ptr<sampling::FusedCSCSamplingGraph> LoadFusedCSCSamplingGraph(
const std::string& filename) { const std::string& filename) {
auto&& graph = c10::make_intrusive<sampling::CSCSamplingGraph>(); auto&& graph = c10::make_intrusive<sampling::FusedCSCSamplingGraph>();
torch::load(*graph, filename); torch::load(*graph, filename);
return graph; return graph;
} }
void SaveCSCSamplingGraph( void SaveFusedCSCSamplingGraph(
c10::intrusive_ptr<sampling::CSCSamplingGraph> graph, c10::intrusive_ptr<sampling::FusedCSCSamplingGraph> graph,
const std::string& filename) { const std::string& filename) {
torch::save(*graph, filename); torch::save(*graph, filename);
} }
......
{ {
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"collapsed_sections": [
"BjkAK37xopp1"
],
"gpuType": "T4",
"private_outputs": true,
"authorship_tag": "ABX9TyOCdFtYQweXnIR1/5oWDSGq"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [ "cells": [
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {
"id": "e1qfiZMOJYYv"
},
"source": [ "source": [
"# Graphbolt Quick Walkthrough\n", "# Graphbolt Quick Walkthrough\n",
"\n", "\n",
"The tutorial provides a quick walkthrough of operators provided by the `dgl.graphbolt` package, and illustrates how to create a GNN datapipe with the package. To learn more details about Stochastic Training of GNNs, please read the [materials](https://docs.dgl.ai/tutorials/large/index.html) provided by DGL.\n", "The tutorial provides a quick walkthrough of operators provided by the `dgl.graphbolt` package, and illustrates how to create a GNN datapipe with the package. To learn more details about Stochastic Training of GNNs, please read the [materials](https://docs.dgl.ai/tutorials/large/index.html) provided by DGL.\n",
"\n", "\n",
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dmlc/dgl/blob/master/notebooks/graphbolt/walkthrough.ipynb) [![GitHub](https://img.shields.io/badge/-View%20on%20GitHub-181717?logo=github&logoColor=ffffff)](https://github.com/dmlc/dgl/blob/master/notebooks/graphbolt/walkthrough.ipynb)" "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dmlc/dgl/blob/master/notebooks/graphbolt/walkthrough.ipynb) [![GitHub](https://img.shields.io/badge/-View%20on%20GitHub-181717?logo=github&logoColor=ffffff)](https://github.com/dmlc/dgl/blob/master/notebooks/graphbolt/walkthrough.ipynb)"
], ]
"metadata": {
"id": "e1qfiZMOJYYv"
}
}, },
{ {
"cell_type": "code", "cell_type": "code",
...@@ -64,19 +43,24 @@ ...@@ -64,19 +43,24 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {
"id": "8O7PfsY4sPoN"
},
"source": [ "source": [
"## Dataset\n", "## Dataset\n",
"\n", "\n",
"The dataset has three primary components. *1*. An itemset, which can be iterated over as the training target. *2*. A sampling graph, which is used by the subgraph sampling algorithm to generate a subgraph. *3*. A feature store, which stores node, edge, and graph features.\n", "The dataset has three primary components. *1*. An itemset, which can be iterated over as the training target. *2*. A sampling graph, which is used by the subgraph sampling algorithm to generate a subgraph. *3*. A feature store, which stores node, edge, and graph features.\n",
"\n", "\n",
"* The **Itemset** is created from iterable data or a tuple of iterable data." "* The **Itemset** is created from iterable data or a tuple of iterable data."
], ]
"metadata": {
"id": "8O7PfsY4sPoN"
}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": {
"id": "g73ZAbMQsSgV"
},
"outputs": [],
"source": [ "source": [
"node_pairs = torch.tensor(\n", "node_pairs = torch.tensor(\n",
" [[7, 0], [6, 0], [1, 3], [3, 3], [2, 4], [8, 4], [1, 4], [2, 4], [1, 5],\n", " [[7, 0], [6, 0], [1, 3], [3, 3], [2, 4], [8, 4], [1, 4], [2, 4], [1, 5],\n",
...@@ -85,24 +69,24 @@ ...@@ -85,24 +69,24 @@
")\n", ")\n",
"item_set = gb.ItemSet(node_pairs, names=\"node_pairs\")\n", "item_set = gb.ItemSet(node_pairs, names=\"node_pairs\")\n",
"print(list(item_set))" "print(list(item_set))"
], ]
"metadata": {
"id": "g73ZAbMQsSgV"
},
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"source": [
"* The **SamplingGraph** is used by the subgraph sampling algorithm to generate a subgraph. In graphbolt, we provide a canonical solution, the CSCSamplingGraph, which achieves state-of-the-art time and space efficiency on CPU sampling. However, this requires enough CPU memory to host all CSCSamplingGraph objects in memory."
],
"metadata": { "metadata": {
"id": "Lqty9p4cs0OR" "id": "Lqty9p4cs0OR"
} },
"source": [
"* The **SamplingGraph** is used by the subgraph sampling algorithm to generate a subgraph. In graphbolt, we provide a canonical solution, the FusedCSCSamplingGraph, which achieves state-of-the-art time and space efficiency on CPU sampling. However, this requires enough CPU memory to host all FusedCSCSamplingGraph objects in memory."
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jDjY149xs3PI"
},
"outputs": [],
"source": [ "source": [
"indptr = torch.tensor([0, 2, 2, 2, 4, 8, 9, 12, 15, 17, 25])\n", "indptr = torch.tensor([0, 2, 2, 2, 4, 8, 9, 12, 15, 17, 25])\n",
"indices = torch.tensor(\n", "indices = torch.tensor(\n",
...@@ -111,26 +95,26 @@ ...@@ -111,26 +95,26 @@
"num_edges = 25\n", "num_edges = 25\n",
"eid = torch.arange(num_edges)\n", "eid = torch.arange(num_edges)\n",
"edge_attributes = {gb.ORIGINAL_EDGE_ID: eid}\n", "edge_attributes = {gb.ORIGINAL_EDGE_ID: eid}\n",
"graph = gb.from_csc(indptr, indices, None, None, edge_attributes, None)\n", "graph = gb.from_fused_csc(indptr, indices, None, None, edge_attributes, None)\n",
"print(graph)" "print(graph)"
], ]
"metadata": {
"id": "jDjY149xs3PI"
},
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"source": [
"* The **FeatureStore** is used to store node, edge, and graph features. In graphbolt, we provide the TorchBasedFeature and related optimizations, such as the GPUCachedFeature, for different use cases."
],
"metadata": { "metadata": {
"id": "mNp2S2_Vs8af" "id": "mNp2S2_Vs8af"
} },
"source": [
"* The **FeatureStore** is used to store node, edge, and graph features. In graphbolt, we provide the TorchBasedFeature and related optimizations, such as the GPUCachedFeature, for different use cases."
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zIU6KWe1Sm2g"
},
"outputs": [],
"source": [ "source": [
"num_nodes = 10\n", "num_nodes = 10\n",
"num_edges = 25\n", "num_edges = 25\n",
...@@ -144,159 +128,159 @@ ...@@ -144,159 +128,159 @@
"}\n", "}\n",
"feature_store = gb.BasicFeatureStore(features)\n", "feature_store = gb.BasicFeatureStore(features)\n",
"print(feature_store)" "print(feature_store)"
], ]
"metadata": {
"id": "zIU6KWe1Sm2g"
},
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {
"id": "Oh2ockWWoXQ0"
},
"source": [ "source": [
"## DataPipe\n", "## DataPipe\n",
"\n", "\n",
"The DataPipe in Graphbolt is an extension of the PyTorch DataPipe, but it is specifically designed to address the challenges of training graph neural networks (GNNs). Each stage of the data pipeline loads data from different sources and can be combined with other stages to create more complex data pipelines. The intermediate data will be stored in **MiniBatch** data packs.\n", "The DataPipe in Graphbolt is an extension of the PyTorch DataPipe, but it is specifically designed to address the challenges of training graph neural networks (GNNs). Each stage of the data pipeline loads data from different sources and can be combined with other stages to create more complex data pipelines. The intermediate data will be stored in **MiniBatch** data packs.\n",
"\n", "\n",
"* **ItemSampler** iterates over input **Itemset** and creates subsets." "* **ItemSampler** iterates over input **Itemset** and creates subsets."
], ]
"metadata": {
"id": "Oh2ockWWoXQ0"
}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"source": [ "execution_count": null,
"datapipe = gb.ItemSampler(item_set, batch_size=3, shuffle=False)\n",
"print(next(iter(datapipe)))"
],
"metadata": { "metadata": {
"id": "XtqPDprrogR7" "id": "XtqPDprrogR7"
}, },
"execution_count": null, "outputs": [],
"outputs": [] "source": [
"datapipe = gb.ItemSampler(item_set, batch_size=3, shuffle=False)\n",
"print(next(iter(datapipe)))"
]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"source": [
"* **NegativeSampler** generates negative samples and returns a mix of positive and negative samples."
],
"metadata": { "metadata": {
"id": "BjkAK37xopp1" "id": "BjkAK37xopp1"
} },
"source": [
"* **NegativeSampler** generates negative samples and returns a mix of positive and negative samples."
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"source": [ "execution_count": null,
"datapipe = datapipe.sample_uniform_negative(graph, 1)\n",
"print(next(iter(datapipe)))"
],
"metadata": { "metadata": {
"id": "PrFpGoOGopJy" "id": "PrFpGoOGopJy"
}, },
"execution_count": null, "outputs": [],
"outputs": [] "source": [
"datapipe = datapipe.sample_uniform_negative(graph, 1)\n",
"print(next(iter(datapipe)))"
]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"source": [
"* **SubgraphSampler** samples a subgraph from a given set of nodes from a larger graph."
],
"metadata": { "metadata": {
"id": "fYO_oIwkpmb3" "id": "fYO_oIwkpmb3"
} },
"source": [
"* **SubgraphSampler** samples a subgraph from a given set of nodes from a larger graph."
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4UsY3PL3ppYV"
},
"outputs": [],
"source": [ "source": [
"fanouts = torch.tensor([1])\n", "fanouts = torch.tensor([1])\n",
"datapipe = datapipe.sample_neighbor(graph, [fanouts])\n", "datapipe = datapipe.sample_neighbor(graph, [fanouts])\n",
"print(next(iter(datapipe)))" "print(next(iter(datapipe)))"
], ]
"metadata": {
"id": "4UsY3PL3ppYV"
},
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"source": [
"* **FeatureFetcher** fetches features for node/edge in graphbolt."
],
"metadata": { "metadata": {
"id": "0uIydsjUqMA0" "id": "0uIydsjUqMA0"
} },
"source": [
"* **FeatureFetcher** fetches features for node/edge in graphbolt."
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"source": [ "execution_count": null,
"datapipe = datapipe.fetch_feature(feature_store, node_feature_keys=[\"feat\"], edge_feature_keys=[\"feat\"])\n",
"print(next(iter(datapipe)))"
],
"metadata": { "metadata": {
"id": "YAj8G7YBqO6G" "id": "YAj8G7YBqO6G"
}, },
"execution_count": null, "outputs": [],
"outputs": [] "source": [
"datapipe = datapipe.fetch_feature(feature_store, node_feature_keys=[\"feat\"], edge_feature_keys=[\"feat\"])\n",
"print(next(iter(datapipe)))"
]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {
"id": "Gt059n1xrmj-"
},
"source": [ "source": [
"After retrieving the required data, Graphbolt provides helper methods to convert it to the output format needed for subsequent GNN training.\n", "After retrieving the required data, Graphbolt provides helper methods to convert it to the output format needed for subsequent GNN training.\n",
"\n", "\n",
"* Convert to **DGLMiniBatch** format for training with DGL." "* Convert to **DGLMiniBatch** format for training with DGL."
], ]
"metadata": {
"id": "Gt059n1xrmj-"
}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"source": [ "execution_count": null,
"datapipe = datapipe.to_dgl()\n",
"print(next(iter(datapipe)))"
],
"metadata": { "metadata": {
"id": "o8Yoi8BeqSdu" "id": "o8Yoi8BeqSdu"
}, },
"execution_count": null, "outputs": [],
"outputs": [] "source": [
"datapipe = datapipe.to_dgl()\n",
"print(next(iter(datapipe)))"
]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"source": [
"* Copy the data to the GPU for training on the GPU."
],
"metadata": { "metadata": {
"id": "hjBSLPRPrsD2" "id": "hjBSLPRPrsD2"
} },
"source": [
"* Copy the data to the GPU for training on the GPU."
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"source": [ "execution_count": null,
"datapipe = datapipe.copy_to(device=\"cuda\")\n",
"print(next(iter(datapipe)))"
],
"metadata": { "metadata": {
"id": "RofiZOUMqt_u" "id": "RofiZOUMqt_u"
}, },
"execution_count": null, "outputs": [],
"outputs": [] "source": [
"datapipe = datapipe.copy_to(device=\"cuda\")\n",
"print(next(iter(datapipe)))"
]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {
"id": "xm9HnyHRvxXj"
},
"source": [ "source": [
"## Exercise: Node classification\n", "## Exercise: Node classification\n",
"\n", "\n",
"Similarly, the following Dataset is created for node classification. Can you implement the data pipeline for the dataset?" "Similarly, the following Dataset is created for node classification. Can you implement the data pipeline for the dataset?"
], ]
"metadata": {
"id": "xm9HnyHRvxXj"
}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YV-mk-xAv78v"
},
"outputs": [],
"source": [ "source": [
"# Dataset for node classification.\n", "# Dataset for node classification.\n",
"num_nodes = 10\n", "num_nodes = 10\n",
...@@ -311,7 +295,7 @@ ...@@ -311,7 +295,7 @@
"eid = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,\n", "eid = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,\n",
" 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24])\n", " 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24])\n",
"edge_attributes = {gb.ORIGINAL_EDGE_ID: eid}\n", "edge_attributes = {gb.ORIGINAL_EDGE_ID: eid}\n",
"graph = gb.from_csc(indptr, indices, None, None, edge_attributes, None)\n", "graph = gb.from_fused_csc(indptr, indices, None, None, edge_attributes, None)\n",
"\n", "\n",
"num_nodes = 10\n", "num_nodes = 10\n",
"num_edges = 25\n", "num_edges = 25\n",
...@@ -328,12 +312,28 @@ ...@@ -328,12 +312,28 @@
"# Datapipe.\n", "# Datapipe.\n",
"...\n", "...\n",
"print(next(iter(datapipe)))" "print(next(iter(datapipe)))"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"authorship_tag": "ABX9TyOCdFtYQweXnIR1/5oWDSGq",
"collapsed_sections": [
"BjkAK37xopp1"
], ],
"metadata": { "gpuType": "T4",
"id": "YV-mk-xAv78v" "private_outputs": true,
}, "provenance": []
"execution_count": null, },
"outputs": [] "kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
} }
] },
} "nbformat": 4,
\ No newline at end of file "nbformat_minor": 0
}
...@@ -1222,14 +1222,14 @@ def partition_graph( ...@@ -1222,14 +1222,14 @@ def partition_graph(
def convert_dgl_partition_to_csc_sampling_graph(part_config): def convert_dgl_partition_to_csc_sampling_graph(part_config):
"""Convert partitions of dgl to CSCSamplingGraph of GraphBolt. """Convert partitions of dgl to FusedCSCSamplingGraph of GraphBolt.
This API converts `DGLGraph` partitions to `CSCSamplingGraph` which is This API converts `DGLGraph` partitions to `FusedCSCSamplingGraph` which is
dedicated for sampling in `GraphBolt`. New graphs will be stored alongside dedicated for sampling in `GraphBolt`. New graphs will be stored alongside
original graph as `csc_sampling_graph.tar`. original graph as `fused_csc_sampling_graph.tar`.
In the near future, partitions are supposed to be saved as In the near future, partitions are supposed to be saved as
`CSCSamplingGraph` directly. At that time, this API should be deprecated. `FusedCSCSamplingGraph` directly. At that time, this API should be deprecated.
Parameters Parameters
---------- ----------
...@@ -1262,7 +1262,7 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config): ...@@ -1262,7 +1262,7 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config):
type_per_edge = type_per_edge.to(RESERVED_FIELD_DTYPE[ETYPE]) type_per_edge = type_per_edge.to(RESERVED_FIELD_DTYPE[ETYPE])
# Sanity check. # Sanity check.
assert len(type_per_edge) == graph.num_edges() assert len(type_per_edge) == graph.num_edges()
csc_graph = graphbolt.from_csc( csc_graph = graphbolt.from_fused_csc(
indptr, indices, None, type_per_edge, metadata=metadata indptr, indices, None, type_per_edge, metadata=metadata
) )
orig_graph_path = os.path.join( orig_graph_path = os.path.join(
...@@ -1270,6 +1270,6 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config): ...@@ -1270,6 +1270,6 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config):
part_meta[f"part-{part_id}"]["part_graph"], part_meta[f"part-{part_id}"]["part_graph"],
) )
csc_graph_path = os.path.join( csc_graph_path = os.path.join(
os.path.dirname(orig_graph_path), "csc_sampling_graph.tar" os.path.dirname(orig_graph_path), "fused_csc_sampling_graph.tar"
) )
graphbolt.save_csc_sampling_graph(csc_graph, csc_graph_path) graphbolt.save_fused_csc_sampling_graph(csc_graph, csc_graph_path)
"""Implementation of GraphBolt.""" """Implementation of GraphBolt."""
from .basic_feature_store import * from .basic_feature_store import *
from .csc_sampling_graph import * from .fused_csc_sampling_graph import *
from .neighbor_sampler import * from .neighbor_sampler import *
from .ondisk_dataset import * from .ondisk_dataset import *
from .ondisk_metadata import * from .ondisk_metadata import *
......
...@@ -20,11 +20,11 @@ from .sampled_subgraph_impl import SampledSubgraphImpl ...@@ -20,11 +20,11 @@ from .sampled_subgraph_impl import SampledSubgraphImpl
__all__ = [ __all__ = [
"GraphMetadata", "GraphMetadata",
"CSCSamplingGraph", "FusedCSCSamplingGraph",
"from_csc", "from_fused_csc",
"load_from_shared_memory", "load_from_shared_memory",
"load_csc_sampling_graph", "load_fused_csc_sampling_graph",
"save_csc_sampling_graph", "save_fused_csc_sampling_graph",
"from_dglgraph", "from_dglgraph",
] ]
...@@ -88,7 +88,7 @@ class GraphMetadata: ...@@ -88,7 +88,7 @@ class GraphMetadata:
self.edge_type_to_id = edge_type_to_id self.edge_type_to_id = edge_type_to_id
class CSCSamplingGraph(SamplingGraph): class FusedCSCSamplingGraph(SamplingGraph):
r"""A sampling graph in CSC format.""" r"""A sampling graph in CSC format."""
def __repr__(self): def __repr__(self):
...@@ -150,7 +150,7 @@ class CSCSamplingGraph(SamplingGraph): ...@@ -150,7 +150,7 @@ class CSCSamplingGraph(SamplingGraph):
>>> type_per_edge = torch.LongTensor( >>> type_per_edge = torch.LongTensor(
... [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3]) ... [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
>>> metadata = gb.GraphMetadata(ntypes, etypes) >>> metadata = gb.GraphMetadata(ntypes, etypes)
>>> graph = gb.from_csc(indptr, indices, node_type_offset, >>> graph = gb.from_fused_csc(indptr, indices, node_type_offset,
... type_per_edge, None, metadata) ... type_per_edge, None, metadata)
>>> print(graph.num_nodes) >>> print(graph.num_nodes)
{'N0': 2, 'N1': 3} {'N0': 2, 'N1': 3}
...@@ -425,7 +425,7 @@ class CSCSamplingGraph(SamplingGraph): ...@@ -425,7 +425,7 @@ class CSCSamplingGraph(SamplingGraph):
>>> indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1]) >>> indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
>>> node_type_offset = torch.LongTensor([0, 2, 5]) >>> node_type_offset = torch.LongTensor([0, 2, 5])
>>> type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0]) >>> type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
>>> graph = gb.from_csc(indptr, indices, type_per_edge=type_per_edge, >>> graph = gb.from_fused_csc(indptr, indices, type_per_edge=type_per_edge,
... node_type_offset=node_type_offset, metadata=metadata) ... node_type_offset=node_type_offset, metadata=metadata)
>>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])} >>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])}
>>> fanouts = torch.tensor([1, 1]) >>> fanouts = torch.tensor([1, 1])
...@@ -605,7 +605,7 @@ class CSCSamplingGraph(SamplingGraph): ...@@ -605,7 +605,7 @@ class CSCSamplingGraph(SamplingGraph):
>>> indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1]) >>> indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
>>> node_type_offset = torch.LongTensor([0, 2, 5]) >>> node_type_offset = torch.LongTensor([0, 2, 5])
>>> type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0]) >>> type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
>>> graph = gb.from_csc(indptr, indices, type_per_edge=type_per_edge, >>> graph = gb.from_fused_csc(indptr, indices, type_per_edge=type_per_edge,
... node_type_offset=node_type_offset, metadata=metadata) ... node_type_offset=node_type_offset, metadata=metadata)
>>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])} >>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])}
>>> fanouts = torch.tensor([1, 1]) >>> fanouts = torch.tensor([1, 1])
...@@ -697,16 +697,16 @@ class CSCSamplingGraph(SamplingGraph): ...@@ -697,16 +697,16 @@ class CSCSamplingGraph(SamplingGraph):
Returns Returns
------- -------
CSCSamplingGraph FusedCSCSamplingGraph
The copied CSCSamplingGraph object on shared memory. The copied FusedCSCSamplingGraph object on shared memory.
""" """
return CSCSamplingGraph( return FusedCSCSamplingGraph(
self._c_csc_graph.copy_to_shared_memory(shared_memory_name), self._c_csc_graph.copy_to_shared_memory(shared_memory_name),
self._metadata, self._metadata,
) )
def to(self, device: torch.device) -> None: # pylint: disable=invalid-name def to(self, device: torch.device) -> None: # pylint: disable=invalid-name
"""Copy `CSCSamplingGraph` to the specified device.""" """Copy `FusedCSCSamplingGraph` to the specified device."""
def _to(x, device): def _to(x, device):
return x.to(device) if hasattr(x, "to") else x return x.to(device) if hasattr(x, "to") else x
...@@ -728,15 +728,15 @@ class CSCSamplingGraph(SamplingGraph): ...@@ -728,15 +728,15 @@ class CSCSamplingGraph(SamplingGraph):
return self return self
def from_csc( def from_fused_csc(
csc_indptr: torch.Tensor, csc_indptr: torch.Tensor,
indices: torch.Tensor, indices: torch.Tensor,
node_type_offset: Optional[torch.tensor] = None, node_type_offset: Optional[torch.tensor] = None,
type_per_edge: Optional[torch.tensor] = None, type_per_edge: Optional[torch.tensor] = None,
edge_attributes: Optional[Dict[str, torch.tensor]] = None, edge_attributes: Optional[Dict[str, torch.tensor]] = None,
metadata: Optional[GraphMetadata] = None, metadata: Optional[GraphMetadata] = None,
) -> CSCSamplingGraph: ) -> FusedCSCSamplingGraph:
"""Create a CSCSamplingGraph object from a CSC representation. """Create a FusedCSCSamplingGraph object from a CSC representation.
Parameters Parameters
---------- ----------
...@@ -756,8 +756,8 @@ def from_csc( ...@@ -756,8 +756,8 @@ def from_csc(
Metadata of the graph, by default None. Metadata of the graph, by default None.
Returns Returns
------- -------
CSCSamplingGraph FusedCSCSamplingGraph
The created CSCSamplingGraph object. The created FusedCSCSamplingGraph object.
Examples Examples
-------- --------
...@@ -768,13 +768,13 @@ def from_csc( ...@@ -768,13 +768,13 @@ def from_csc(
>>> indices = torch.tensor([1, 3, 0, 1, 2, 0, 3]) >>> indices = torch.tensor([1, 3, 0, 1, 2, 0, 3])
>>> node_type_offset = torch.tensor([0, 1, 2, 3]) >>> node_type_offset = torch.tensor([0, 1, 2, 3])
>>> type_per_edge = torch.tensor([0, 1, 0, 1, 1, 0, 0]) >>> type_per_edge = torch.tensor([0, 1, 0, 1, 1, 0, 0])
>>> graph = graphbolt.from_csc(csc_indptr, indices, >>> graph = graphbolt.from_fused_csc(csc_indptr, indices,
... node_type_offset=node_type_offset, ... node_type_offset=node_type_offset,
... type_per_edge=type_per_edge, ... type_per_edge=type_per_edge,
... edge_attributes=None, metadata=metadata) ... edge_attributes=None, metadata=metadata)
None, metadata) None, metadata)
>>> print(graph) >>> print(graph)
CSCSamplingGraph(csc_indptr=tensor([0, 2, 5, 7]), FusedCSCSamplingGraph(csc_indptr=tensor([0, 2, 5, 7]),
indices=tensor([1, 3, 0, 1, 2, 0, 3]), indices=tensor([1, 3, 0, 1, 2, 0, 3]),
total_num_nodes=3, total_num_edges=7) total_num_nodes=3, total_num_edges=7)
""" """
...@@ -782,8 +782,8 @@ def from_csc( ...@@ -782,8 +782,8 @@ def from_csc(
assert len(metadata.node_type_to_id) + 1 == node_type_offset.size( assert len(metadata.node_type_to_id) + 1 == node_type_offset.size(
0 0
), "node_type_offset length should be |ntypes| + 1." ), "node_type_offset length should be |ntypes| + 1."
return CSCSamplingGraph( return FusedCSCSamplingGraph(
torch.ops.graphbolt.from_csc( torch.ops.graphbolt.from_fused_csc(
csc_indptr, csc_indptr,
indices, indices,
node_type_offset, node_type_offset,
...@@ -797,8 +797,8 @@ def from_csc( ...@@ -797,8 +797,8 @@ def from_csc(
def load_from_shared_memory( def load_from_shared_memory(
shared_memory_name: str, shared_memory_name: str,
metadata: Optional[GraphMetadata] = None, metadata: Optional[GraphMetadata] = None,
) -> CSCSamplingGraph: ) -> FusedCSCSamplingGraph:
"""Load a CSCSamplingGraph object from shared memory. """Load a FusedCSCSamplingGraph object from shared memory.
Parameters Parameters
---------- ----------
...@@ -807,16 +807,16 @@ def load_from_shared_memory( ...@@ -807,16 +807,16 @@ def load_from_shared_memory(
Returns Returns
------- -------
CSCSamplingGraph FusedCSCSamplingGraph
The loaded CSCSamplingGraph object on shared memory. The loaded FusedCSCSamplingGraph object on shared memory.
""" """
return CSCSamplingGraph( return FusedCSCSamplingGraph(
torch.ops.graphbolt.load_from_shared_memory(shared_memory_name), torch.ops.graphbolt.load_from_shared_memory(shared_memory_name),
metadata, metadata,
) )
def _csc_sampling_graph_str(graph: CSCSamplingGraph) -> str: def _csc_sampling_graph_str(graph: FusedCSCSamplingGraph) -> str:
"""Internal function for converting a csc sampling graph to string """Internal function for converting a csc sampling graph to string
representation. representation.
""" """
...@@ -848,24 +848,24 @@ def _csc_sampling_graph_str(graph: CSCSamplingGraph) -> str: ...@@ -848,24 +848,24 @@ def _csc_sampling_graph_str(graph: CSCSamplingGraph) -> str:
return final_str return final_str
def load_csc_sampling_graph(filename): def load_fused_csc_sampling_graph(filename):
"""Load CSCSamplingGraph from tar file.""" """Load FusedCSCSamplingGraph from tar file."""
with tempfile.TemporaryDirectory() as temp_dir: with tempfile.TemporaryDirectory() as temp_dir:
with tarfile.open(filename, "r") as archive: with tarfile.open(filename, "r") as archive:
archive.extractall(temp_dir) archive.extractall(temp_dir)
graph_filename = os.path.join(temp_dir, "csc_sampling_graph.pt") graph_filename = os.path.join(temp_dir, "fused_csc_sampling_graph.pt")
metadata_filename = os.path.join(temp_dir, "metadata.pt") metadata_filename = os.path.join(temp_dir, "metadata.pt")
return CSCSamplingGraph( return FusedCSCSamplingGraph(
torch.ops.graphbolt.load_csc_sampling_graph(graph_filename), torch.ops.graphbolt.load_fused_csc_sampling_graph(graph_filename),
torch.load(metadata_filename), torch.load(metadata_filename),
) )
def save_csc_sampling_graph(graph, filename): def save_fused_csc_sampling_graph(graph, filename):
"""Save CSCSamplingGraph to tar file.""" """Save FusedCSCSamplingGraph to tar file."""
with tempfile.TemporaryDirectory() as temp_dir: with tempfile.TemporaryDirectory() as temp_dir:
graph_filename = os.path.join(temp_dir, "csc_sampling_graph.pt") graph_filename = os.path.join(temp_dir, "fused_csc_sampling_graph.pt")
torch.ops.graphbolt.save_csc_sampling_graph( torch.ops.graphbolt.save_fused_csc_sampling_graph(
graph._c_csc_graph, graph_filename graph._c_csc_graph, graph_filename
) )
metadata_filename = os.path.join(temp_dir, "metadata.pt") metadata_filename = os.path.join(temp_dir, "metadata.pt")
...@@ -877,15 +877,15 @@ def save_csc_sampling_graph(graph, filename): ...@@ -877,15 +877,15 @@ def save_csc_sampling_graph(graph, filename):
archive.add( archive.add(
metadata_filename, arcname=os.path.basename(metadata_filename) metadata_filename, arcname=os.path.basename(metadata_filename)
) )
print(f"CSCSamplingGraph has been saved to {filename}.") print(f"FusedCSCSamplingGraph has been saved to {filename}.")
def from_dglgraph( def from_dglgraph(
g: DGLGraph, g: DGLGraph,
is_homogeneous: bool = False, is_homogeneous: bool = False,
include_original_edge_id: bool = False, include_original_edge_id: bool = False,
) -> CSCSamplingGraph: ) -> FusedCSCSamplingGraph:
"""Convert a DGLGraph to CSCSamplingGraph.""" """Convert a DGLGraph to FusedCSCSamplingGraph."""
homo_g, ntype_count, _ = to_homogeneous(g, return_count=True) homo_g, ntype_count, _ = to_homogeneous(g, return_count=True)
...@@ -913,8 +913,8 @@ def from_dglgraph( ...@@ -913,8 +913,8 @@ def from_dglgraph(
# Assign edge attributes according to the original eids mapping. # Assign edge attributes according to the original eids mapping.
edge_attributes[ORIGINAL_EDGE_ID] = homo_g.edata[EID][edge_ids] edge_attributes[ORIGINAL_EDGE_ID] = homo_g.edata[EID][edge_ids]
return CSCSamplingGraph( return FusedCSCSamplingGraph(
torch.ops.graphbolt.from_csc( torch.ops.graphbolt.from_fused_csc(
indptr, indptr,
indices, indices,
node_type_offset, node_type_offset,
......
...@@ -28,7 +28,7 @@ class NeighborSampler(SubgraphSampler): ...@@ -28,7 +28,7 @@ class NeighborSampler(SubgraphSampler):
---------- ----------
datapipe : DataPipe datapipe : DataPipe
The datapipe. The datapipe.
graph : CSCSamplingGraph graph : FusedCSCSamplingGraph
The graph on which to perform subgraph sampling. The graph on which to perform subgraph sampling.
fanouts: list[torch.Tensor] or list[int] fanouts: list[torch.Tensor] or list[int]
The number of edges to be sampled for each node with or without The number of edges to be sampled for each node with or without
...@@ -59,7 +59,7 @@ class NeighborSampler(SubgraphSampler): ...@@ -59,7 +59,7 @@ class NeighborSampler(SubgraphSampler):
>>> from dgl import graphbolt as gb >>> from dgl import graphbolt as gb
>>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8]) >>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8])
>>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5]) >>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5])
>>> graph = gb.from_csc(indptr, indices) >>> graph = gb.from_fused_csc(indptr, indices)
>>> node_pairs = torch.LongTensor([[0, 1], [1, 2]]) >>> node_pairs = torch.LongTensor([[0, 1], [1, 2]])
>>> item_set = gb.ItemSet(node_pairs, names="node_pairs") >>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
>>> item_sampler = gb.ItemSampler( >>> item_sampler = gb.ItemSampler(
...@@ -165,7 +165,7 @@ class LayerNeighborSampler(NeighborSampler): ...@@ -165,7 +165,7 @@ class LayerNeighborSampler(NeighborSampler):
---------- ----------
datapipe : DataPipe datapipe : DataPipe
The datapipe. The datapipe.
graph : CSCSamplingGraph graph : FusedCSCSamplingGraph
The graph on which to perform subgraph sampling. The graph on which to perform subgraph sampling.
fanouts: list[torch.Tensor] fanouts: list[torch.Tensor]
The number of edges to be sampled for each node with or without The number of edges to be sampled for each node with or without
...@@ -192,7 +192,7 @@ class LayerNeighborSampler(NeighborSampler): ...@@ -192,7 +192,7 @@ class LayerNeighborSampler(NeighborSampler):
>>> from dgl import graphbolt as gb >>> from dgl import graphbolt as gb
>>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8]) >>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8])
>>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5]) >>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5])
>>> graph = gb.from_csc(indptr, indices) >>> graph = gb.from_fused_csc(indptr, indices)
>>> data_format = gb.LinkPredictionEdgeFormat.INDEPENDENT >>> data_format = gb.LinkPredictionEdgeFormat.INDEPENDENT
>>> node_pairs = torch.LongTensor([[0, 1], [1, 2]]) >>> node_pairs = torch.LongTensor([[0, 1], [1, 2]])
>>> item_set = gb.ItemSet(node_pairs, names="node_pairs") >>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
......
...@@ -17,11 +17,11 @@ from ..dataset import Dataset, Task ...@@ -17,11 +17,11 @@ from ..dataset import Dataset, Task
from ..itemset import ItemSet, ItemSetDict from ..itemset import ItemSet, ItemSetDict
from ..sampling_graph import SamplingGraph from ..sampling_graph import SamplingGraph
from ..utils import copy_or_convert_data, read_data from ..utils import copy_or_convert_data, read_data
from .csc_sampling_graph import ( from .fused_csc_sampling_graph import (
CSCSamplingGraph,
from_dglgraph, from_dglgraph,
load_csc_sampling_graph, FusedCSCSamplingGraph,
save_csc_sampling_graph, load_fused_csc_sampling_graph,
save_fused_csc_sampling_graph,
) )
from .ondisk_metadata import ( from .ondisk_metadata import (
OnDiskGraphTopology, OnDiskGraphTopology,
...@@ -45,7 +45,7 @@ def preprocess_ondisk_dataset( ...@@ -45,7 +45,7 @@ def preprocess_ondisk_dataset(
dataset_dir : str dataset_dir : str
The path to the dataset directory. The path to the dataset directory.
include_original_edge_id : bool, optional include_original_edge_id : bool, optional
Whether to include the original edge id in the CSCSamplingGraph. Whether to include the original edge id in the FusedCSCSamplingGraph.
Returns Returns
------- -------
...@@ -138,20 +138,20 @@ def preprocess_ondisk_dataset( ...@@ -138,20 +138,20 @@ def preprocess_ondisk_dataset(
) )
g.edata[graph_feature["name"]] = edge_data g.edata[graph_feature["name"]] = edge_data
# 4. Convert the DGLGraph to a CSCSamplingGraph. # 4. Convert the DGLGraph to a FusedCSCSamplingGraph.
csc_sampling_graph = from_dglgraph( fused_csc_sampling_graph = from_dglgraph(
g, is_homogeneous, include_original_edge_id g, is_homogeneous, include_original_edge_id
) )
# 5. Save the CSCSamplingGraph and modify the output_config. # 5. Save the FusedCSCSamplingGraph and modify the output_config.
output_config["graph_topology"] = {} output_config["graph_topology"] = {}
output_config["graph_topology"]["type"] = "CSCSamplingGraph" output_config["graph_topology"]["type"] = "FusedCSCSamplingGraph"
output_config["graph_topology"]["path"] = os.path.join( output_config["graph_topology"]["path"] = os.path.join(
processed_dir_prefix, "csc_sampling_graph.tar" processed_dir_prefix, "fused_csc_sampling_graph.tar"
) )
save_csc_sampling_graph( save_fused_csc_sampling_graph(
csc_sampling_graph, fused_csc_sampling_graph,
os.path.join( os.path.join(
dataset_dir, dataset_dir,
output_config["graph_topology"]["path"], output_config["graph_topology"]["path"],
...@@ -283,8 +283,8 @@ class OnDiskDataset(Dataset): ...@@ -283,8 +283,8 @@ class OnDiskDataset(Dataset):
dataset_name: graphbolt_test dataset_name: graphbolt_test
graph_topology: graph_topology:
type: CSCSamplingGraph type: FusedCSCSamplingGraph
path: graph_topology/csc_sampling_graph.tar path: graph_topology/fused_csc_sampling_graph.tar
feature_data: feature_data:
- domain: node - domain: node
type: paper type: paper
...@@ -340,7 +340,7 @@ class OnDiskDataset(Dataset): ...@@ -340,7 +340,7 @@ class OnDiskDataset(Dataset):
path: str path: str
The YAML file path. The YAML file path.
include_original_edge_id: bool, optional include_original_edge_id: bool, optional
Whether to include the original edge id in the CSCSamplingGraph. Whether to include the original edge id in the FusedCSCSamplingGraph.
""" """
def __init__( def __init__(
...@@ -434,12 +434,12 @@ class OnDiskDataset(Dataset): ...@@ -434,12 +434,12 @@ class OnDiskDataset(Dataset):
def _load_graph( def _load_graph(
self, graph_topology: OnDiskGraphTopology self, graph_topology: OnDiskGraphTopology
) -> CSCSamplingGraph: ) -> FusedCSCSamplingGraph:
"""Load the graph topology.""" """Load the graph topology."""
if graph_topology is None: if graph_topology is None:
return None return None
if graph_topology.type == "CSCSamplingGraph": if graph_topology.type == "FusedCSCSamplingGraph":
return load_csc_sampling_graph(graph_topology.path) return load_fused_csc_sampling_graph(graph_topology.path)
raise NotImplementedError( raise NotImplementedError(
f"Graph topology type {graph_topology.type} is not supported." f"Graph topology type {graph_topology.type} is not supported."
) )
......
...@@ -64,7 +64,7 @@ class OnDiskFeatureData(pydantic.BaseModel): ...@@ -64,7 +64,7 @@ class OnDiskFeatureData(pydantic.BaseModel):
class OnDiskGraphTopologyType(str, Enum): class OnDiskGraphTopologyType(str, Enum):
"""Enum of graph topology type.""" """Enum of graph topology type."""
CSC_SAMPLING = "CSCSamplingGraph" FUSED_CSC_SAMPLING = "FusedCSCSamplingGraph"
class OnDiskGraphTopology(pydantic.BaseModel): class OnDiskGraphTopology(pydantic.BaseModel):
......
"""Sampled subgraph for CSCSamplingGraph.""" """Sampled subgraph for FusedCSCSamplingGraph."""
# pylint: disable= invalid-name # pylint: disable= invalid-name
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Tuple, Union from typing import Dict, Tuple, Union
...@@ -13,7 +13,7 @@ __all__ = ["SampledSubgraphImpl"] ...@@ -13,7 +13,7 @@ __all__ = ["SampledSubgraphImpl"]
@dataclass @dataclass
class SampledSubgraphImpl(SampledSubgraph): class SampledSubgraphImpl(SampledSubgraph):
r"""Sampled subgraph of CSCSamplingGraph. r"""Sampled subgraph of FusedCSCSamplingGraph.
Examples Examples
-------- --------
......
...@@ -22,7 +22,7 @@ class UniformNegativeSampler(NegativeSampler): ...@@ -22,7 +22,7 @@ class UniformNegativeSampler(NegativeSampler):
---------- ----------
datapipe : DataPipe datapipe : DataPipe
The datapipe. The datapipe.
graph : CSCSamplingGraph graph : FusedCSCSamplingGraph
The graph on which to perform negative sampling. The graph on which to perform negative sampling.
negative_ratio : int negative_ratio : int
The proportion of negative samples to positive samples. The proportion of negative samples to positive samples.
...@@ -32,7 +32,7 @@ class UniformNegativeSampler(NegativeSampler): ...@@ -32,7 +32,7 @@ class UniformNegativeSampler(NegativeSampler):
>>> from dgl import graphbolt as gb >>> from dgl import graphbolt as gb
>>> indptr = torch.LongTensor([0, 2, 4, 5]) >>> indptr = torch.LongTensor([0, 2, 4, 5])
>>> indices = torch.LongTensor([1, 2, 0, 2, 0]) >>> indices = torch.LongTensor([1, 2, 0, 2, 0])
>>> graph = gb.from_csc(indptr, indices) >>> graph = gb.from_fused_csc(indptr, indices)
>>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2])) >>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2]))
>>> item_set = gb.ItemSet(node_pairs, names="node_pairs") >>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
>>> item_sampler = gb.ItemSampler( >>> item_sampler = gb.ItemSampler(
......
...@@ -694,8 +694,10 @@ def test_convert_dgl_partition_to_csc_sampling_graph_homo( ...@@ -694,8 +694,10 @@ def test_convert_dgl_partition_to_csc_sampling_graph_homo(
orig_g = dgl.load_graphs( orig_g = dgl.load_graphs(
os.path.join(test_dir, f"part{part_id}/graph.dgl") os.path.join(test_dir, f"part{part_id}/graph.dgl")
)[0][0] )[0][0]
new_g = dgl.graphbolt.load_csc_sampling_graph( new_g = dgl.graphbolt.load_fused_csc_sampling_graph(
os.path.join(test_dir, f"part{part_id}/csc_sampling_graph.tar") os.path.join(
test_dir, f"part{part_id}/fused_csc_sampling_graph.tar"
)
) )
orig_indptr, orig_indices, _ = orig_g.adj().csc() orig_indptr, orig_indices, _ = orig_g.adj().csc()
assert th.equal(orig_indptr, new_g.csc_indptr) assert th.equal(orig_indptr, new_g.csc_indptr)
...@@ -725,8 +727,10 @@ def test_convert_dgl_partition_to_csc_sampling_graph_hetero( ...@@ -725,8 +727,10 @@ def test_convert_dgl_partition_to_csc_sampling_graph_hetero(
orig_g = dgl.load_graphs( orig_g = dgl.load_graphs(
os.path.join(test_dir, f"part{part_id}/graph.dgl") os.path.join(test_dir, f"part{part_id}/graph.dgl")
)[0][0] )[0][0]
new_g = dgl.graphbolt.load_csc_sampling_graph( new_g = dgl.graphbolt.load_fused_csc_sampling_graph(
os.path.join(test_dir, f"part{part_id}/csc_sampling_graph.tar") os.path.join(
test_dir, f"part{part_id}/fused_csc_sampling_graph.tar"
)
) )
orig_indptr, orig_indices, _ = orig_g.adj().csc() orig_indptr, orig_indices, _ = orig_g.adj().csc()
assert th.equal(orig_indptr, new_g.csc_indptr) assert th.equal(orig_indptr, new_g.csc_indptr)
......
...@@ -17,7 +17,7 @@ def rand_csc_graph(N, density): ...@@ -17,7 +17,7 @@ def rand_csc_graph(N, density):
indptr = torch.LongTensor(adj.indptr) indptr = torch.LongTensor(adj.indptr)
indices = torch.LongTensor(adj.indices) indices = torch.LongTensor(adj.indices)
graph = gb.from_csc(indptr, indices) graph = gb.from_fused_csc(indptr, indices)
return graph return graph
......
...@@ -26,7 +26,7 @@ mp.set_sharing_strategy("file_system") ...@@ -26,7 +26,7 @@ mp.set_sharing_strategy("file_system")
def test_empty_graph(total_num_nodes): def test_empty_graph(total_num_nodes):
csc_indptr = torch.zeros((total_num_nodes + 1,), dtype=int) csc_indptr = torch.zeros((total_num_nodes + 1,), dtype=int)
indices = torch.tensor([]) indices = torch.tensor([])
graph = gb.from_csc(csc_indptr, indices) graph = gb.from_fused_csc(csc_indptr, indices)
assert graph.total_num_edges == 0 assert graph.total_num_edges == 0
assert graph.total_num_nodes == total_num_nodes assert graph.total_num_nodes == total_num_nodes
assert torch.equal(graph.csc_indptr, csc_indptr) assert torch.equal(graph.csc_indptr, csc_indptr)
...@@ -52,7 +52,7 @@ def test_hetero_empty_graph(total_num_nodes): ...@@ -52,7 +52,7 @@ def test_hetero_empty_graph(total_num_nodes):
node_type_offset[0] = 0 node_type_offset[0] = 0
node_type_offset[-1] = total_num_nodes node_type_offset[-1] = total_num_nodes
type_per_edge = torch.tensor([]) type_per_edge = torch.tensor([])
graph = gb.from_csc( graph = gb.from_fused_csc(
csc_indptr, csc_indptr,
indices, indices,
node_type_offset, node_type_offset,
...@@ -119,7 +119,9 @@ def test_homo_graph(total_num_nodes, total_num_edges): ...@@ -119,7 +119,9 @@ def test_homo_graph(total_num_nodes, total_num_edges):
"A1": torch.randn(total_num_edges), "A1": torch.randn(total_num_edges),
"A2": torch.randn(total_num_edges), "A2": torch.randn(total_num_edges),
} }
graph = gb.from_csc(csc_indptr, indices, edge_attributes=edge_attributes) graph = gb.from_fused_csc(
csc_indptr, indices, edge_attributes=edge_attributes
)
assert graph.total_num_nodes == total_num_nodes assert graph.total_num_nodes == total_num_nodes
assert graph.total_num_edges == total_num_edges assert graph.total_num_edges == total_num_edges
...@@ -156,7 +158,7 @@ def test_hetero_graph(total_num_nodes, total_num_edges, num_ntypes, num_etypes): ...@@ -156,7 +158,7 @@ def test_hetero_graph(total_num_nodes, total_num_edges, num_ntypes, num_etypes):
"A1": torch.randn(total_num_edges), "A1": torch.randn(total_num_edges),
"A2": torch.randn(total_num_edges), "A2": torch.randn(total_num_edges),
} }
graph = gb.from_csc( graph = gb.from_fused_csc(
csc_indptr, csc_indptr,
indices, indices,
node_type_offset, node_type_offset,
...@@ -193,7 +195,9 @@ def test_num_nodes_homo(total_num_nodes, total_num_edges): ...@@ -193,7 +195,9 @@ def test_num_nodes_homo(total_num_nodes, total_num_edges):
"A1": torch.randn(total_num_edges), "A1": torch.randn(total_num_edges),
"A2": torch.randn(total_num_edges), "A2": torch.randn(total_num_edges),
} }
graph = gb.from_csc(csc_indptr, indices, edge_attributes=edge_attributes) graph = gb.from_fused_csc(
csc_indptr, indices, edge_attributes=edge_attributes
)
assert graph.num_nodes == total_num_nodes assert graph.num_nodes == total_num_nodes
...@@ -239,9 +243,9 @@ def test_num_nodes_hetero(): ...@@ -239,9 +243,9 @@ def test_num_nodes_hetero():
assert node_type_offset[-1] == total_num_nodes assert node_type_offset[-1] == total_num_nodes
assert all(type_per_edge < len(etypes)) assert all(type_per_edge < len(etypes))
# Construct CSCSamplingGraph. # Construct FusedCSCSamplingGraph.
metadata = gb.GraphMetadata(ntypes, etypes) metadata = gb.GraphMetadata(ntypes, etypes)
graph = gb.from_csc( graph = gb.from_fused_csc(
indptr, indices, node_type_offset, type_per_edge, None, metadata indptr, indices, node_type_offset, type_per_edge, None, metadata
) )
...@@ -273,7 +277,7 @@ def test_node_type_offset_wrong_legnth(node_type_offset): ...@@ -273,7 +277,7 @@ def test_node_type_offset_wrong_legnth(node_type_offset):
10, 50, num_ntypes, 5 10, 50, num_ntypes, 5
) )
with pytest.raises(Exception): with pytest.raises(Exception):
gb.from_csc( gb.from_fused_csc(
csc_indptr, indices, node_type_offset, type_per_edge, None, metadata csc_indptr, indices, node_type_offset, type_per_edge, None, metadata
) )
...@@ -290,12 +294,12 @@ def test_load_save_homo_graph(total_num_nodes, total_num_edges): ...@@ -290,12 +294,12 @@ def test_load_save_homo_graph(total_num_nodes, total_num_edges):
csc_indptr, indices = gbt.random_homo_graph( csc_indptr, indices = gbt.random_homo_graph(
total_num_nodes, total_num_edges total_num_nodes, total_num_edges
) )
graph = gb.from_csc(csc_indptr, indices) graph = gb.from_fused_csc(csc_indptr, indices)
with tempfile.TemporaryDirectory() as test_dir: with tempfile.TemporaryDirectory() as test_dir:
filename = os.path.join(test_dir, "csc_sampling_graph.tar") filename = os.path.join(test_dir, "fused_csc_sampling_graph.tar")
gb.save_csc_sampling_graph(graph, filename) gb.save_fused_csc_sampling_graph(graph, filename)
graph2 = gb.load_csc_sampling_graph(filename) graph2 = gb.load_fused_csc_sampling_graph(filename)
assert graph.total_num_nodes == graph2.total_num_nodes assert graph.total_num_nodes == graph2.total_num_nodes
assert graph.total_num_edges == graph2.total_num_edges assert graph.total_num_edges == graph2.total_num_edges
...@@ -329,14 +333,14 @@ def test_load_save_hetero_graph( ...@@ -329,14 +333,14 @@ def test_load_save_hetero_graph(
) = gbt.random_hetero_graph( ) = gbt.random_hetero_graph(
total_num_nodes, total_num_edges, num_ntypes, num_etypes total_num_nodes, total_num_edges, num_ntypes, num_etypes
) )
graph = gb.from_csc( graph = gb.from_fused_csc(
csc_indptr, indices, node_type_offset, type_per_edge, None, metadata csc_indptr, indices, node_type_offset, type_per_edge, None, metadata
) )
with tempfile.TemporaryDirectory() as test_dir: with tempfile.TemporaryDirectory() as test_dir:
filename = os.path.join(test_dir, "csc_sampling_graph.tar") filename = os.path.join(test_dir, "fused_csc_sampling_graph.tar")
gb.save_csc_sampling_graph(graph, filename) gb.save_fused_csc_sampling_graph(graph, filename)
graph2 = gb.load_csc_sampling_graph(filename) graph2 = gb.load_fused_csc_sampling_graph(filename)
assert graph.total_num_nodes == graph2.total_num_nodes assert graph.total_num_nodes == graph2.total_num_nodes
assert graph.total_num_edges == graph2.total_num_edges assert graph.total_num_edges == graph2.total_num_edges
...@@ -361,7 +365,7 @@ def test_pickle_homo_graph(total_num_nodes, total_num_edges): ...@@ -361,7 +365,7 @@ def test_pickle_homo_graph(total_num_nodes, total_num_edges):
csc_indptr, indices = gbt.random_homo_graph( csc_indptr, indices = gbt.random_homo_graph(
total_num_nodes, total_num_edges total_num_nodes, total_num_edges
) )
graph = gb.from_csc(csc_indptr, indices) graph = gb.from_fused_csc(csc_indptr, indices)
serialized = pickle.dumps(graph) serialized = pickle.dumps(graph)
graph2 = pickle.loads(serialized) graph2 = pickle.loads(serialized)
...@@ -402,7 +406,7 @@ def test_pickle_hetero_graph( ...@@ -402,7 +406,7 @@ def test_pickle_hetero_graph(
"a": torch.randn((total_num_edges,)), "a": torch.randn((total_num_edges,)),
"b": torch.randint(1, 10, (total_num_edges,)), "b": torch.randint(1, 10, (total_num_edges,)),
} }
graph = gb.from_csc( graph = gb.from_fused_csc(
csc_indptr, csc_indptr,
indices, indices,
node_type_offset, node_type_offset,
...@@ -453,7 +457,7 @@ def test_multiprocessing(): ...@@ -453,7 +457,7 @@ def test_multiprocessing():
edge_attributes = { edge_attributes = {
"a": torch.randn((total_num_edges,)), "a": torch.randn((total_num_edges,)),
} }
graph = gb.from_csc( graph = gb.from_fused_csc(
csc_indptr, csc_indptr,
indices, indices,
node_type_offset, node_type_offset,
...@@ -489,8 +493,8 @@ def test_in_subgraph_homogeneous(): ...@@ -489,8 +493,8 @@ def test_in_subgraph_homogeneous():
assert indptr[-1] == total_num_edges assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices) assert indptr[-1] == len(indices)
# Construct CSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_csc(indptr, indices) graph = gb.from_fused_csc(indptr, indices)
# Extract in subgraph. # Extract in subgraph.
nodes = torch.LongTensor([1, 3, 4]) nodes = torch.LongTensor([1, 3, 4])
...@@ -552,9 +556,9 @@ def test_in_subgraph_heterogeneous(): ...@@ -552,9 +556,9 @@ def test_in_subgraph_heterogeneous():
assert node_type_offset[-1] == total_num_nodes assert node_type_offset[-1] == total_num_nodes
assert all(type_per_edge < len(etypes)) assert all(type_per_edge < len(etypes))
# Construct CSCSamplingGraph. # Construct FusedCSCSamplingGraph.
metadata = gb.GraphMetadata(ntypes, etypes) metadata = gb.GraphMetadata(ntypes, etypes)
graph = gb.from_csc( graph = gb.from_fused_csc(
indptr, indices, node_type_offset, type_per_edge, None, metadata indptr, indices, node_type_offset, type_per_edge, None, metadata
) )
...@@ -599,8 +603,8 @@ def test_sample_neighbors_homo(): ...@@ -599,8 +603,8 @@ def test_sample_neighbors_homo():
assert indptr[-1] == total_num_edges assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices) assert indptr[-1] == len(indices)
# Construct CSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_csc(indptr, indices) graph = gb.from_fused_csc(indptr, indices)
# Generate subgraph via sample neighbors. # Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4]) nodes = torch.LongTensor([1, 3, 4])
...@@ -642,8 +646,8 @@ def test_sample_neighbors_hetero(labor): ...@@ -642,8 +646,8 @@ def test_sample_neighbors_hetero(labor):
assert indptr[-1] == total_num_edges assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices) assert indptr[-1] == len(indices)
# Construct CSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_csc( graph = gb.from_fused_csc(
indptr, indptr,
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
...@@ -748,8 +752,8 @@ def test_sample_neighbors_fanouts( ...@@ -748,8 +752,8 @@ def test_sample_neighbors_fanouts(
assert indptr[-1] == total_num_edges assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices) assert indptr[-1] == len(indices)
# Construct CSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_csc( graph = gb.from_fused_csc(
indptr, indptr,
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
...@@ -806,8 +810,8 @@ def test_sample_neighbors_replace( ...@@ -806,8 +810,8 @@ def test_sample_neighbors_replace(
assert indptr[-1] == total_num_edges assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices) assert indptr[-1] == len(indices)
# Construct CSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_csc( graph = gb.from_fused_csc(
indptr, indptr,
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
...@@ -849,8 +853,8 @@ def test_sample_neighbors_return_eids_homo(labor): ...@@ -849,8 +853,8 @@ def test_sample_neighbors_return_eids_homo(labor):
# Add edge id mapping from CSC graph -> original graph. # Add edge id mapping from CSC graph -> original graph.
edge_attributes = {gb.ORIGINAL_EDGE_ID: torch.randperm(total_num_edges)} edge_attributes = {gb.ORIGINAL_EDGE_ID: torch.randperm(total_num_edges)}
# Construct CSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_csc(indptr, indices, edge_attributes=edge_attributes) graph = gb.from_fused_csc(indptr, indices, edge_attributes=edge_attributes)
# Generate subgraph via sample neighbors. # Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4]) nodes = torch.LongTensor([1, 3, 4])
...@@ -897,8 +901,8 @@ def test_sample_neighbors_return_eids_hetero(labor): ...@@ -897,8 +901,8 @@ def test_sample_neighbors_return_eids_hetero(labor):
assert indptr[-1] == total_num_edges assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices) assert indptr[-1] == len(indices)
# Construct CSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_csc( graph = gb.from_fused_csc(
indptr, indptr,
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
...@@ -956,8 +960,8 @@ def test_sample_neighbors_probs(replace, labor, probs_name): ...@@ -956,8 +960,8 @@ def test_sample_neighbors_probs(replace, labor, probs_name):
"mask": torch.BoolTensor([1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1]), "mask": torch.BoolTensor([1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1]),
} }
# Construct CSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_csc(indptr, indices, edge_attributes=edge_attributes) graph = gb.from_fused_csc(indptr, indices, edge_attributes=edge_attributes)
# Generate subgraph via sample neighbors. # Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4]) nodes = torch.LongTensor([1, 3, 4])
...@@ -1002,8 +1006,8 @@ def test_sample_neighbors_zero_probs(replace, labor, probs_or_mask): ...@@ -1002,8 +1006,8 @@ def test_sample_neighbors_zero_probs(replace, labor, probs_or_mask):
edge_attributes = {"probs_or_mask": probs_or_mask} edge_attributes = {"probs_or_mask": probs_or_mask}
# Construct CSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_csc(indptr, indices, edge_attributes=edge_attributes) graph = gb.from_fused_csc(indptr, indices, edge_attributes=edge_attributes)
# Generate subgraph via sample neighbors. # Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4]) nodes = torch.LongTensor([1, 3, 4])
...@@ -1038,7 +1042,7 @@ def check_tensors_on_the_same_shared_memory(t1: torch.Tensor, t2: torch.Tensor): ...@@ -1038,7 +1042,7 @@ def check_tensors_on_the_same_shared_memory(t1: torch.Tensor, t2: torch.Tensor):
@unittest.skipIf( @unittest.skipIf(
F._default_context_str == "gpu", F._default_context_str == "gpu",
reason="CSCSamplingGraph is only supported on CPU.", reason="FusedCSCSamplingGraph is only supported on CPU.",
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"total_num_nodes, total_num_edges", "total_num_nodes, total_num_edges",
...@@ -1058,7 +1062,9 @@ def test_homo_graph_on_shared_memory( ...@@ -1058,7 +1062,9 @@ def test_homo_graph_on_shared_memory(
} }
else: else:
edge_attributes = None edge_attributes = None
graph = gb.from_csc(csc_indptr, indices, edge_attributes=edge_attributes) graph = gb.from_fused_csc(
csc_indptr, indices, edge_attributes=edge_attributes
)
shm_name = "test_homo_g" shm_name = "test_homo_g"
graph1 = graph.copy_to_shared_memory(shm_name) graph1 = graph.copy_to_shared_memory(shm_name)
...@@ -1099,7 +1105,7 @@ def test_homo_graph_on_shared_memory( ...@@ -1099,7 +1105,7 @@ def test_homo_graph_on_shared_memory(
@unittest.skipIf( @unittest.skipIf(
F._default_context_str == "gpu", F._default_context_str == "gpu",
reason="CSCSamplingGraph is only supported on CPU.", reason="FusedCSCSamplingGraph is only supported on CPU.",
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"total_num_nodes, total_num_edges", "total_num_nodes, total_num_edges",
...@@ -1127,7 +1133,7 @@ def test_hetero_graph_on_shared_memory( ...@@ -1127,7 +1133,7 @@ def test_hetero_graph_on_shared_memory(
} }
else: else:
edge_attributes = None edge_attributes = None
graph = gb.from_csc( graph = gb.from_fused_csc(
csc_indptr, csc_indptr,
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
...@@ -1250,7 +1256,7 @@ def test_multiprocessing_with_shared_memory(): ...@@ -1250,7 +1256,7 @@ def test_multiprocessing_with_shared_memory():
node_type_offset.share_memory_() node_type_offset.share_memory_()
type_per_edge.share_memory_() type_per_edge.share_memory_()
graph = gb.from_csc( graph = gb.from_fused_csc(
csc_indptr, csc_indptr,
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
...@@ -1308,7 +1314,7 @@ def test_from_dglgraph_homogeneous(): ...@@ -1308,7 +1314,7 @@ def test_from_dglgraph_homogeneous():
gb_g = gb.from_dglgraph( gb_g = gb.from_dglgraph(
dgl_g, is_homogeneous=True, include_original_edge_id=True dgl_g, is_homogeneous=True, include_original_edge_id=True
) )
# Get the COO representation of the CSCSamplingGraph. # Get the COO representation of the FusedCSCSamplingGraph.
num_columns = gb_g.csc_indptr[1:] - gb_g.csc_indptr[:-1] num_columns = gb_g.csc_indptr[1:] - gb_g.csc_indptr[:-1]
rows = gb_g.indices rows = gb_g.indices
columns = torch.arange(gb_g.total_num_nodes).repeat_interleave(num_columns) columns = torch.arange(gb_g.total_num_nodes).repeat_interleave(num_columns)
...@@ -1360,21 +1366,21 @@ def test_from_dglgraph_heterogeneous(): ...@@ -1360,21 +1366,21 @@ def test_from_dglgraph_heterogeneous():
dgl_g, is_homogeneous=False, include_original_edge_id=True dgl_g, is_homogeneous=False, include_original_edge_id=True
) )
# `reverse_node_id` is used to map the node id in CSCSamplingGraph to the # `reverse_node_id` is used to map the node id in FusedCSCSamplingGraph to the
# node id in Hetero-DGLGraph. # node id in Hetero-DGLGraph.
num_ntypes = gb_g.node_type_offset[1:] - gb_g.node_type_offset[:-1] num_ntypes = gb_g.node_type_offset[1:] - gb_g.node_type_offset[:-1]
reverse_node_id = torch.cat([torch.arange(num) for num in num_ntypes]) reverse_node_id = torch.cat([torch.arange(num) for num in num_ntypes])
# Get the COO representation of the CSCSamplingGraph. # Get the COO representation of the FusedCSCSamplingGraph.
num_columns = gb_g.csc_indptr[1:] - gb_g.csc_indptr[:-1] num_columns = gb_g.csc_indptr[1:] - gb_g.csc_indptr[:-1]
rows = reverse_node_id[gb_g.indices] rows = reverse_node_id[gb_g.indices]
columns = reverse_node_id[ columns = reverse_node_id[
torch.arange(gb_g.total_num_nodes).repeat_interleave(num_columns) torch.arange(gb_g.total_num_nodes).repeat_interleave(num_columns)
] ]
# Check the order of etypes in DGLGraph is the same as CSCSamplingGraph. # Check the order of etypes in DGLGraph is the same as FusedCSCSamplingGraph.
assert ( assert (
# Since the etypes in CSCSamplingGraph is "srctype:etype:dsttype", # Since the etypes in FusedCSCSamplingGraph is "srctype:etype:dsttype",
# we need to split the string and get the middle part. # we need to split the string and get the middle part.
list( list(
map( map(
...@@ -1463,8 +1469,8 @@ def test_sample_neighbors_homo_pick_number(fanouts, replace, labor, probs_name): ...@@ -1463,8 +1469,8 @@ def test_sample_neighbors_homo_pick_number(fanouts, replace, labor, probs_name):
"zero": torch.BoolTensor([0, 0, 0, 0, 0, 0]), "zero": torch.BoolTensor([0, 0, 0, 0, 0, 0]),
} }
# Construct CSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_csc(indptr, indices, edge_attributes=edge_attributes) graph = gb.from_fused_csc(indptr, indices, edge_attributes=edge_attributes)
# Generate subgraph via sample neighbors. # Generate subgraph via sample neighbors.
nodes = torch.LongTensor([0, 1]) nodes = torch.LongTensor([0, 1])
...@@ -1547,8 +1553,8 @@ def test_sample_neighbors_hetero_pick_number( ...@@ -1547,8 +1553,8 @@ def test_sample_neighbors_hetero_pick_number(
"zero": torch.BoolTensor([0, 0, 0, 0, 0, 0, 0, 0, 0]), "zero": torch.BoolTensor([0, 0, 0, 0, 0, 0, 0, 0, 0]),
} }
# Construct CSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_csc( graph = gb.from_fused_csc(
indptr, indptr,
indices, indices,
edge_attributes=edge_attributes, edge_attributes=edge_attributes,
...@@ -1636,8 +1642,8 @@ def test_csc_sampling_graph_to_device(): ...@@ -1636,8 +1642,8 @@ def test_csc_sampling_graph_to_device():
"zero": torch.BoolTensor([0, 0, 0, 0, 0, 0, 0, 0, 0]), "zero": torch.BoolTensor([0, 0, 0, 0, 0, 0, 0, 0, 0]),
} }
# Construct CSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_csc( graph = gb.from_fused_csc(
indptr, indptr,
indices, indices,
edge_attributes=edge_attributes, edge_attributes=edge_attributes,
......
...@@ -68,7 +68,7 @@ def test_UniformNegativeSampler_invoke(): ...@@ -68,7 +68,7 @@ def test_UniformNegativeSampler_invoke():
@pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20]) @pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20])
def test_Uniform_NegativeSampler(negative_ratio): def test_Uniform_NegativeSampler(negative_ratio):
# Construct CSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb_test_utils.rand_csc_graph(100, 0.05) graph = gb_test_utils.rand_csc_graph(100, 0.05)
num_seeds = 30 num_seeds = 30
item_set = gb.ItemSet( item_set = gb.ItemSet(
...@@ -110,7 +110,7 @@ def get_hetero_graph(): ...@@ -110,7 +110,7 @@ def get_hetero_graph():
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1]) indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0]) type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
node_type_offset = torch.LongTensor([0, 2, 5]) node_type_offset = torch.LongTensor([0, 2, 5])
return gb.from_csc( return gb.from_fused_csc(
indptr, indptr,
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment