Unverified Commit 91fe0c90 authored by LastWhisper's avatar LastWhisper Committed by GitHub
Browse files

[GraphBolt] Refactor CSCSamplingGraph and update the corresponding method. (#6515)

parent 3f3652e0
......@@ -56,7 +56,7 @@ Standard Implementations
OnDiskDataset
BuiltinDataset
CSCSamplingGraph
FusedCSCSamplingGraph
UniformNegativeSampler
NeighborSampler
LayerNeighborSampler
......
......@@ -106,7 +106,7 @@ def create_dataloader(
# Sample neighbors for each seed node in the mini-batch.
# `graph`:
# The graph(CSCSamplingGraph) from which to sample neighbors.
# The graph(FusedCSCSamplingGraph) from which to sample neighbors.
# `fanouts`:
# The number of neighbors to sample for each node in each layer.
datapipe = datapipe.sample_neighbor(graph, fanouts=fanouts)
......@@ -166,7 +166,7 @@ def rel_graph_embed(graph, embed_size):
Parameters
----------
graph : CSCSamplingGraph
graph : FusedCSCSamplingGraph
The graph for which to create the heterogenous embedding layer.
embed_size : int
The size of the embedding vectors.
......
/**
* Copyright (c) 2023 by Contributors
* @file graphbolt/csc_sampling_graph.h
* @file graphbolt/fused_csc_sampling_graph.h
* @brief Header file of csc sampling graph.
*/
#ifndef GRAPHBOLT_CSC_SAMPLING_GRAPH_H_
......@@ -39,18 +39,18 @@ struct SamplerArgs<SamplerType::LABOR> {
* Suppose the graph has 3 node types, 3 edge types and 6 edges
* auto node_type_offset = {0, 2, 4, 6}
* auto type_per_edge = {0, 1, 0, 2, 1, 2}
* auto graph = CSCSamplingGraph(..., ..., node_type_offset, type_per_edge)
* auto graph = FusedCSCSamplingGraph(..., ..., node_type_offset, type_per_edge)
*
* The `node_type_offset` tensor represents the offset array of node type, the
* given array indicates that node [0, 2) has type id 0, [2, 4) has type id 1,
* and [4, 6) has type id 2. And the `type_per_edge` tensor represents the type
* id of each edge.
*/
class CSCSamplingGraph : public torch::CustomClassHolder {
class FusedCSCSamplingGraph : public torch::CustomClassHolder {
public:
using EdgeAttrMap = torch::Dict<std::string, torch::Tensor>;
/** @brief Default constructor. */
CSCSamplingGraph() = default;
FusedCSCSamplingGraph() = default;
/**
* @brief Constructor for CSC with data.
......@@ -61,7 +61,7 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
* @param type_per_edge A tensor representing the type of each edge, if
* present.
*/
CSCSamplingGraph(
FusedCSCSamplingGraph(
const torch::Tensor& indptr, const torch::Tensor& indices,
const torch::optional<torch::Tensor>& node_type_offset,
const torch::optional<torch::Tensor>& type_per_edge,
......@@ -76,9 +76,9 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
* @param type_per_edge A tensor representing the type of each edge, if
* present.
*
* @return CSCSamplingGraph
* @return FusedCSCSamplingGraph
*/
static c10::intrusive_ptr<CSCSamplingGraph> FromCSC(
static c10::intrusive_ptr<FusedCSCSamplingGraph> FromCSC(
const torch::Tensor& indptr, const torch::Tensor& indices,
const torch::optional<torch::Tensor>& node_type_offset,
const torch::optional<torch::Tensor>& type_per_edge,
......@@ -155,7 +155,7 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
/**
* @brief Pickle method for deserializing.
* @param state The state of serialized CSCSamplingGraph.
* @param state The state of serialized FusedCSCSamplingGraph.
*/
void SetState(
const torch::Dict<std::string, torch::Dict<std::string, torch::Tensor>>&
......@@ -163,7 +163,7 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
/**
* @brief Pickle method for serializing.
* @returns The state of this CSCSamplingGraph.
* @returns The state of this FusedCSCSamplingGraph.
*/
torch::Dict<std::string, torch::Dict<std::string, torch::Tensor>> GetState()
const;
......@@ -246,18 +246,18 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
* @brief Copy the graph to shared memory.
* @param shared_memory_name The name of the shared memory.
*
* @return A new CSCSamplingGraph object on shared memory.
* @return A new FusedCSCSamplingGraph object on shared memory.
*/
c10::intrusive_ptr<CSCSamplingGraph> CopyToSharedMemory(
c10::intrusive_ptr<FusedCSCSamplingGraph> CopyToSharedMemory(
const std::string& shared_memory_name);
/**
* @brief Load the graph from shared memory.
* @param shared_memory_name The name of the shared memory.
*
* @return A new CSCSamplingGraph object on shared memory.
* @return A new FusedCSCSamplingGraph object on shared memory.
*/
static c10::intrusive_ptr<CSCSamplingGraph> LoadFromSharedMemory(
static c10::intrusive_ptr<FusedCSCSamplingGraph> LoadFromSharedMemory(
const std::string& shared_memory_name);
/**
......
......@@ -7,7 +7,7 @@
#ifndef GRAPHBOLT_SERIALIZE_H_
#define GRAPHBOLT_SERIALIZE_H_
#include <graphbolt/csc_sampling_graph.h>
#include <graphbolt/fused_csc_sampling_graph.h>
#include <torch/torch.h>
#include <string>
......@@ -15,53 +15,55 @@
/**
* @brief Overload stream operator to enable `torch::save()` and `torch.load()`
* for CSCSamplingGraph.
* for FusedCSCSamplingGraph.
*/
namespace torch {
/**
* @brief Overload input stream operator for CSCSamplingGraph deserialization.
* @brief Overload input stream operator for FusedCSCSamplingGraph
* deserialization.
* @param archive Input stream for deserializing.
* @param graph CSCSamplingGraph.
* @param graph FusedCSCSamplingGraph.
*
* @return archive
*/
inline serialize::InputArchive& operator>>(
serialize::InputArchive& archive,
graphbolt::sampling::CSCSamplingGraph& graph);
graphbolt::sampling::FusedCSCSamplingGraph& graph);
/**
* @brief Overload output stream operator for CSCSamplingGraph serialization.
* @brief Overload output stream operator for FusedCSCSamplingGraph
* serialization.
* @param archive Output stream for serializing.
* @param graph CSCSamplingGraph.
* @param graph FusedCSCSamplingGraph.
*
* @return archive
*/
inline serialize::OutputArchive& operator<<(
serialize::OutputArchive& archive,
const graphbolt::sampling::CSCSamplingGraph& graph);
const graphbolt::sampling::FusedCSCSamplingGraph& graph);
} // namespace torch
namespace graphbolt {
/**
* @brief Load CSCSamplingGraph from file.
* @brief Load FusedCSCSamplingGraph from file.
* @param filename File name to read.
*
* @return CSCSamplingGraph.
* @return FusedCSCSamplingGraph.
*/
c10::intrusive_ptr<sampling::CSCSamplingGraph> LoadCSCSamplingGraph(
c10::intrusive_ptr<sampling::FusedCSCSamplingGraph> LoadFusedCSCSamplingGraph(
const std::string& filename);
/**
* @brief Save CSCSamplingGraph to file.
* @param graph CSCSamplingGraph to save.
* @brief Save FusedCSCSamplingGraph to file.
* @param graph FusedCSCSamplingGraph to save.
* @param filename File name to save.
*
*/
void SaveCSCSamplingGraph(
c10::intrusive_ptr<sampling::CSCSamplingGraph> graph,
void SaveFusedCSCSamplingGraph(
c10::intrusive_ptr<sampling::FusedCSCSamplingGraph> graph,
const std::string& filename);
/**
......
/**
* Copyright (c) 2023 by Contributors
* @file csc_sampling_graph.cc
* @file fused_csc_sampling_graph.cc
* @brief Source file of sampling graph.
*/
#include <graphbolt/csc_sampling_graph.h>
#include <graphbolt/fused_csc_sampling_graph.h>
#include <graphbolt/serialize.h>
#include <torch/torch.h>
......@@ -24,7 +24,7 @@ namespace sampling {
static const int kPickleVersion = 6199;
CSCSamplingGraph::CSCSamplingGraph(
FusedCSCSamplingGraph::FusedCSCSamplingGraph(
const torch::Tensor& indptr, const torch::Tensor& indices,
const torch::optional<torch::Tensor>& node_type_offset,
const torch::optional<torch::Tensor>& type_per_edge,
......@@ -39,7 +39,7 @@ CSCSamplingGraph::CSCSamplingGraph(
TORCH_CHECK(indptr.device() == indices.device());
}
c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSC(
c10::intrusive_ptr<FusedCSCSamplingGraph> FusedCSCSamplingGraph::FromCSC(
const torch::Tensor& indptr, const torch::Tensor& indices,
const torch::optional<torch::Tensor>& node_type_offset,
const torch::optional<torch::Tensor>& type_per_edge,
......@@ -57,37 +57,40 @@ c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSC(
TORCH_CHECK(pair.value().size(0) == indices.size(0));
}
}
return c10::make_intrusive<CSCSamplingGraph>(
return c10::make_intrusive<FusedCSCSamplingGraph>(
indptr, indices, node_type_offset, type_per_edge, edge_attributes);
}
void CSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
void FusedCSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
const int64_t magic_num =
read_from_archive(archive, "CSCSamplingGraph/magic_num").toInt();
read_from_archive(archive, "FusedCSCSamplingGraph/magic_num").toInt();
TORCH_CHECK(
magic_num == kCSCSamplingGraphSerializeMagic,
"Magic numbers mismatch when loading CSCSamplingGraph.");
indptr_ = read_from_archive(archive, "CSCSamplingGraph/indptr").toTensor();
indices_ = read_from_archive(archive, "CSCSamplingGraph/indices").toTensor();
if (read_from_archive(archive, "CSCSamplingGraph/has_node_type_offset")
"Magic numbers mismatch when loading FusedCSCSamplingGraph.");
indptr_ =
read_from_archive(archive, "FusedCSCSamplingGraph/indptr").toTensor();
indices_ =
read_from_archive(archive, "FusedCSCSamplingGraph/indices").toTensor();
if (read_from_archive(archive, "FusedCSCSamplingGraph/has_node_type_offset")
.toBool()) {
node_type_offset_ =
read_from_archive(archive, "CSCSamplingGraph/node_type_offset")
read_from_archive(archive, "FusedCSCSamplingGraph/node_type_offset")
.toTensor();
}
if (read_from_archive(archive, "CSCSamplingGraph/has_type_per_edge")
if (read_from_archive(archive, "FusedCSCSamplingGraph/has_type_per_edge")
.toBool()) {
type_per_edge_ =
read_from_archive(archive, "CSCSamplingGraph/type_per_edge").toTensor();
read_from_archive(archive, "FusedCSCSamplingGraph/type_per_edge")
.toTensor();
}
// Optional edge attributes.
torch::IValue has_edge_attributes;
if (archive.try_read(
"CSCSamplingGraph/has_edge_attributes", has_edge_attributes) &&
"FusedCSCSamplingGraph/has_edge_attributes", has_edge_attributes) &&
has_edge_attributes.toBool()) {
torch::Dict<torch::IValue, torch::IValue> generic_dict =
read_from_archive(archive, "CSCSamplingGraph/edge_attributes")
read_from_archive(archive, "FusedCSCSamplingGraph/edge_attributes")
.toGenericDict();
EdgeAttrMap target_dict;
for (const auto& pair : generic_dict) {
......@@ -101,29 +104,35 @@ void CSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
}
}
void CSCSamplingGraph::Save(torch::serialize::OutputArchive& archive) const {
archive.write("CSCSamplingGraph/magic_num", kCSCSamplingGraphSerializeMagic);
archive.write("CSCSamplingGraph/indptr", indptr_);
archive.write("CSCSamplingGraph/indices", indices_);
void FusedCSCSamplingGraph::Save(
torch::serialize::OutputArchive& archive) const {
archive.write(
"FusedCSCSamplingGraph/magic_num", kCSCSamplingGraphSerializeMagic);
archive.write("FusedCSCSamplingGraph/indptr", indptr_);
archive.write("FusedCSCSamplingGraph/indices", indices_);
archive.write(
"CSCSamplingGraph/has_node_type_offset", node_type_offset_.has_value());
"FusedCSCSamplingGraph/has_node_type_offset",
node_type_offset_.has_value());
if (node_type_offset_) {
archive.write(
"CSCSamplingGraph/node_type_offset", node_type_offset_.value());
"FusedCSCSamplingGraph/node_type_offset", node_type_offset_.value());
}
archive.write(
"CSCSamplingGraph/has_type_per_edge", type_per_edge_.has_value());
"FusedCSCSamplingGraph/has_type_per_edge", type_per_edge_.has_value());
if (type_per_edge_) {
archive.write("CSCSamplingGraph/type_per_edge", type_per_edge_.value());
archive.write(
"FusedCSCSamplingGraph/type_per_edge", type_per_edge_.value());
}
archive.write(
"CSCSamplingGraph/has_edge_attributes", edge_attributes_.has_value());
"FusedCSCSamplingGraph/has_edge_attributes",
edge_attributes_.has_value());
if (edge_attributes_) {
archive.write("CSCSamplingGraph/edge_attributes", edge_attributes_.value());
archive.write(
"FusedCSCSamplingGraph/edge_attributes", edge_attributes_.value());
}
}
void CSCSamplingGraph::SetState(
void FusedCSCSamplingGraph::SetState(
const torch::Dict<std::string, torch::Dict<std::string, torch::Tensor>>&
state) {
// State is a dict of dicts. The tensor-type attributes are stored in the dict
......@@ -133,7 +142,7 @@ void CSCSamplingGraph::SetState(
TORCH_CHECK(
independent_tensors.at("version_number")
.equal(torch::tensor({kPickleVersion})),
"Version number mismatches when loading pickled CSCSamplingGraph.")
"Version number mismatches when loading pickled FusedCSCSamplingGraph.")
indptr_ = independent_tensors.at("indptr");
indices_ = independent_tensors.at("indices");
if (independent_tensors.find("node_type_offset") !=
......@@ -149,7 +158,7 @@ void CSCSamplingGraph::SetState(
}
torch::Dict<std::string, torch::Dict<std::string, torch::Tensor>>
CSCSamplingGraph::GetState() const {
FusedCSCSamplingGraph::GetState() const {
// State is a dict of dicts. The tensor-type attributes are stored in the dict
// with key "independent_tensors". The dict-type attributes (edge_attributes)
// are stored directly with the their name as the key.
......@@ -173,7 +182,7 @@ CSCSamplingGraph::GetState() const {
return state;
}
c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph(
c10::intrusive_ptr<SampledSubgraph> FusedCSCSamplingGraph::InSubgraph(
const torch::Tensor& nodes) const {
using namespace torch::indexing;
const int32_t kDefaultGrainSize = 100;
......@@ -296,7 +305,7 @@ auto GetPickFn(
}
template <typename NumPickFn, typename PickFn>
c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighborsImpl(
c10::intrusive_ptr<SampledSubgraph> FusedCSCSamplingGraph::SampleNeighborsImpl(
const torch::Tensor& nodes, bool return_eids, NumPickFn num_pick_fn,
PickFn pick_fn) const {
const int64_t num_nodes = nodes.size(0);
......@@ -413,7 +422,7 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighborsImpl(
subgraph_reverse_edge_ids, subgraph_type_per_edge);
}
c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
c10::intrusive_ptr<SampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
bool replace, bool layer, bool return_eids,
torch::optional<std::string> probs_name) const {
......@@ -451,7 +460,7 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
}
std::tuple<torch::Tensor, torch::Tensor>
CSCSamplingGraph::SampleNegativeEdgesUniform(
FusedCSCSamplingGraph::SampleNegativeEdgesUniform(
const std::tuple<torch::Tensor, torch::Tensor>& node_pairs,
int64_t negative_ratio, int64_t max_node_id) const {
torch::Tensor pos_src;
......@@ -462,15 +471,15 @@ CSCSamplingGraph::SampleNegativeEdgesUniform(
return std::make_tuple(neg_src, neg_dst);
}
static c10::intrusive_ptr<CSCSamplingGraph> BuildGraphFromSharedMemoryHelper(
SharedMemoryHelper&& helper) {
static c10::intrusive_ptr<FusedCSCSamplingGraph>
BuildGraphFromSharedMemoryHelper(SharedMemoryHelper&& helper) {
helper.InitializeRead();
auto indptr = helper.ReadTorchTensor();
auto indices = helper.ReadTorchTensor();
auto node_type_offset = helper.ReadTorchTensor();
auto type_per_edge = helper.ReadTorchTensor();
auto edge_attributes = helper.ReadTorchTensorDict();
auto graph = c10::make_intrusive<CSCSamplingGraph>(
auto graph = c10::make_intrusive<FusedCSCSamplingGraph>(
indptr.value(), indices.value(), node_type_offset, type_per_edge,
edge_attributes);
auto shared_memory = helper.ReleaseSharedMemory();
......@@ -479,7 +488,8 @@ static c10::intrusive_ptr<CSCSamplingGraph> BuildGraphFromSharedMemoryHelper(
return graph;
}
c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::CopyToSharedMemory(
c10::intrusive_ptr<FusedCSCSamplingGraph>
FusedCSCSamplingGraph::CopyToSharedMemory(
const std::string& shared_memory_name) {
SharedMemoryHelper helper(shared_memory_name, SERIALIZED_METAINFO_SIZE_MAX);
helper.WriteTorchTensor(indptr_);
......@@ -491,13 +501,14 @@ c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::CopyToSharedMemory(
return BuildGraphFromSharedMemoryHelper(std::move(helper));
}
c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::LoadFromSharedMemory(
c10::intrusive_ptr<FusedCSCSamplingGraph>
FusedCSCSamplingGraph::LoadFromSharedMemory(
const std::string& shared_memory_name) {
SharedMemoryHelper helper(shared_memory_name, SERIALIZED_METAINFO_SIZE_MAX);
return BuildGraphFromSharedMemoryHelper(std::move(helper));
}
void CSCSamplingGraph::HoldSharedMemoryObject(
void FusedCSCSamplingGraph::HoldSharedMemoryObject(
SharedMemoryPtr tensor_metadata_shm, SharedMemoryPtr tensor_data_shm) {
tensor_metadata_shm_ = std::move(tensor_metadata_shm);
tensor_data_shm_ = std::move(tensor_data_shm);
......
......@@ -4,7 +4,7 @@
* @brief Graph bolt library Python binding.
*/
#include <graphbolt/csc_sampling_graph.h>
#include <graphbolt/fused_csc_sampling_graph.h>
#include <graphbolt/isin.h>
#include <graphbolt/serialize.h>
#include <graphbolt/unique_and_compact.h>
......@@ -26,43 +26,44 @@ TORCH_LIBRARY(graphbolt, m) {
&SampledSubgraph::original_column_node_ids)
.def_readwrite("original_edge_ids", &SampledSubgraph::original_edge_ids)
.def_readwrite("type_per_edge", &SampledSubgraph::type_per_edge);
m.class_<CSCSamplingGraph>("CSCSamplingGraph")
.def("num_nodes", &CSCSamplingGraph::NumNodes)
.def("num_edges", &CSCSamplingGraph::NumEdges)
.def("csc_indptr", &CSCSamplingGraph::CSCIndptr)
.def("indices", &CSCSamplingGraph::Indices)
.def("node_type_offset", &CSCSamplingGraph::NodeTypeOffset)
.def("type_per_edge", &CSCSamplingGraph::TypePerEdge)
.def("edge_attributes", &CSCSamplingGraph::EdgeAttributes)
.def("set_csc_indptr", &CSCSamplingGraph::SetCSCIndptr)
.def("set_indices", &CSCSamplingGraph::SetIndices)
.def("set_node_type_offset", &CSCSamplingGraph::SetNodeTypeOffset)
.def("set_type_per_edge", &CSCSamplingGraph::SetTypePerEdge)
.def("set_edge_attributes", &CSCSamplingGraph::SetEdgeAttributes)
.def("in_subgraph", &CSCSamplingGraph::InSubgraph)
.def("sample_neighbors", &CSCSamplingGraph::SampleNeighbors)
m.class_<FusedCSCSamplingGraph>("FusedCSCSamplingGraph")
.def("num_nodes", &FusedCSCSamplingGraph::NumNodes)
.def("num_edges", &FusedCSCSamplingGraph::NumEdges)
.def("csc_indptr", &FusedCSCSamplingGraph::CSCIndptr)
.def("indices", &FusedCSCSamplingGraph::Indices)
.def("node_type_offset", &FusedCSCSamplingGraph::NodeTypeOffset)
.def("type_per_edge", &FusedCSCSamplingGraph::TypePerEdge)
.def("edge_attributes", &FusedCSCSamplingGraph::EdgeAttributes)
.def("set_csc_indptr", &FusedCSCSamplingGraph::SetCSCIndptr)
.def("set_indices", &FusedCSCSamplingGraph::SetIndices)
.def("set_node_type_offset", &FusedCSCSamplingGraph::SetNodeTypeOffset)
.def("set_type_per_edge", &FusedCSCSamplingGraph::SetTypePerEdge)
.def("set_edge_attributes", &FusedCSCSamplingGraph::SetEdgeAttributes)
.def("in_subgraph", &FusedCSCSamplingGraph::InSubgraph)
.def("sample_neighbors", &FusedCSCSamplingGraph::SampleNeighbors)
.def(
"sample_negative_edges_uniform",
&CSCSamplingGraph::SampleNegativeEdgesUniform)
.def("copy_to_shared_memory", &CSCSamplingGraph::CopyToSharedMemory)
&FusedCSCSamplingGraph::SampleNegativeEdgesUniform)
.def("copy_to_shared_memory", &FusedCSCSamplingGraph::CopyToSharedMemory)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<CSCSamplingGraph>& self)
[](const c10::intrusive_ptr<FusedCSCSamplingGraph>& self)
-> torch::Dict<
std::string, torch::Dict<std::string, torch::Tensor>> {
return self->GetState();
},
// __setstate__
[](torch::Dict<std::string, torch::Dict<std::string, torch::Tensor>>
state) -> c10::intrusive_ptr<CSCSamplingGraph> {
auto g = c10::make_intrusive<CSCSamplingGraph>();
state) -> c10::intrusive_ptr<FusedCSCSamplingGraph> {
auto g = c10::make_intrusive<FusedCSCSamplingGraph>();
g->SetState(state);
return g;
});
m.def("from_csc", &CSCSamplingGraph::FromCSC);
m.def("load_csc_sampling_graph", &LoadCSCSamplingGraph);
m.def("save_csc_sampling_graph", &SaveCSCSamplingGraph);
m.def("load_from_shared_memory", &CSCSamplingGraph::LoadFromSharedMemory);
m.def("from_fused_csc", &FusedCSCSamplingGraph::FromCSC);
m.def("load_fused_csc_sampling_graph", &LoadFusedCSCSamplingGraph);
m.def("save_fused_csc_sampling_graph", &SaveFusedCSCSamplingGraph);
m.def(
"load_from_shared_memory", &FusedCSCSamplingGraph::LoadFromSharedMemory);
m.def("unique_and_compact", &UniqueAndCompact);
m.def("isin", &IsIn);
m.def("index_select", &ops::IndexSelect);
......
......@@ -11,14 +11,14 @@ namespace torch {
serialize::InputArchive& operator>>(
serialize::InputArchive& archive,
graphbolt::sampling::CSCSamplingGraph& graph) {
graphbolt::sampling::FusedCSCSamplingGraph& graph) {
graph.Load(archive);
return archive;
}
serialize::OutputArchive& operator<<(
serialize::OutputArchive& archive,
const graphbolt::sampling::CSCSamplingGraph& graph) {
const graphbolt::sampling::FusedCSCSamplingGraph& graph) {
graph.Save(archive);
return archive;
}
......@@ -27,15 +27,15 @@ serialize::OutputArchive& operator<<(
namespace graphbolt {
c10::intrusive_ptr<sampling::CSCSamplingGraph> LoadCSCSamplingGraph(
c10::intrusive_ptr<sampling::FusedCSCSamplingGraph> LoadFusedCSCSamplingGraph(
const std::string& filename) {
auto&& graph = c10::make_intrusive<sampling::CSCSamplingGraph>();
auto&& graph = c10::make_intrusive<sampling::FusedCSCSamplingGraph>();
torch::load(*graph, filename);
return graph;
}
void SaveCSCSamplingGraph(
c10::intrusive_ptr<sampling::CSCSamplingGraph> graph,
void SaveFusedCSCSamplingGraph(
c10::intrusive_ptr<sampling::FusedCSCSamplingGraph> graph,
const std::string& filename) {
torch::save(*graph, filename);
}
......
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"collapsed_sections": [
"BjkAK37xopp1"
],
"gpuType": "T4",
"private_outputs": true,
"authorship_tag": "ABX9TyOCdFtYQweXnIR1/5oWDSGq"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "e1qfiZMOJYYv"
},
"source": [
"# Graphbolt Quick Walkthrough\n",
"\n",
"The tutorial provides a quick walkthrough of operators provided by the `dgl.graphbolt` package, and illustrates how to create a GNN datapipe with the package. To learn more details about Stochastic Training of GNNs, please read the [materials](https://docs.dgl.ai/tutorials/large/index.html) provided by DGL.\n",
"\n",
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dmlc/dgl/blob/master/notebooks/graphbolt/walkthrough.ipynb) [![GitHub](https://img.shields.io/badge/-View%20on%20GitHub-181717?logo=github&logoColor=ffffff)](https://github.com/dmlc/dgl/blob/master/notebooks/graphbolt/walkthrough.ipynb)"
],
"metadata": {
"id": "e1qfiZMOJYYv"
}
]
},
{
"cell_type": "code",
......@@ -64,19 +43,24 @@
},
{
"cell_type": "markdown",
"metadata": {
"id": "8O7PfsY4sPoN"
},
"source": [
"## Dataset\n",
"\n",
"The dataset has three primary components. *1*. An itemset, which can be iterated over as the training target. *2*. A sampling graph, which is used by the subgraph sampling algorithm to generate a subgraph. *3*. A feature store, which stores node, edge, and graph features.\n",
"\n",
"* The **Itemset** is created from iterable data or tuple of iterable data."
],
"metadata": {
"id": "8O7PfsY4sPoN"
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "g73ZAbMQsSgV"
},
"outputs": [],
"source": [
"node_pairs = torch.tensor(\n",
" [[7, 0], [6, 0], [1, 3], [3, 3], [2, 4], [8, 4], [1, 4], [2, 4], [1, 5],\n",
......@@ -85,24 +69,24 @@
")\n",
"item_set = gb.ItemSet(node_pairs, names=\"node_pairs\")\n",
"print(list(item_set))"
],
"metadata": {
"id": "g73ZAbMQsSgV"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "markdown",
"source": [
"* The **SamplingGraph** is used by the subgraph sampling algorithm to generate a subgraph. In graphbolt, we provide a canonical solution, the CSCSamplingGraph, which achieves state-of-the-art time and space efficiency on CPU sampling. However, this requires enough CPU memory to host all CSCSamplingGraph objects in memory."
],
"metadata": {
"id": "Lqty9p4cs0OR"
}
},
"source": [
"* The **SamplingGraph** is used by the subgraph sampling algorithm to generate a subgraph. In graphbolt, we provide a canonical solution, the FusedCSCSamplingGraph, which achieves state-of-the-art time and space efficiency on CPU sampling. However, this requires enough CPU memory to host all FusedCSCSamplingGraph objects in memory."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jDjY149xs3PI"
},
"outputs": [],
"source": [
"indptr = torch.tensor([0, 2, 2, 2, 4, 8, 9, 12, 15, 17, 25])\n",
"indices = torch.tensor(\n",
......@@ -111,26 +95,26 @@
"num_edges = 25\n",
"eid = torch.arange(num_edges)\n",
"edge_attributes = {gb.ORIGINAL_EDGE_ID: eid}\n",
"graph = gb.from_csc(indptr, indices, None, None, edge_attributes, None)\n",
"graph = gb.from_fused_csc(indptr, indices, None, None, edge_attributes, None)\n",
"print(graph)"
],
"metadata": {
"id": "jDjY149xs3PI"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "markdown",
"source": [
"* The **FeatureStore** is used to store node, edge, and graph features. In graphbolt, we provide the TorchBasedFeature and related optimizations, such as the GPUCachedFeature, for different use cases."
],
"metadata": {
"id": "mNp2S2_Vs8af"
}
},
"source": [
"* The **FeatureStore** is used to store node, edge, and graph features. In graphbolt, we provide the TorchBasedFeature and related optimizations, such as the GPUCachedFeature, for different use cases."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zIU6KWe1Sm2g"
},
"outputs": [],
"source": [
"num_nodes = 10\n",
"num_edges = 25\n",
......@@ -144,159 +128,159 @@
"}\n",
"feature_store = gb.BasicFeatureStore(features)\n",
"print(feature_store)"
],
"metadata": {
"id": "zIU6KWe1Sm2g"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Oh2ockWWoXQ0"
},
"source": [
"## DataPipe\n",
"\n",
"The DataPipe in Graphbolt is an extension of the PyTorch DataPipe, but it is specifically designed to address the challenges of training graph neural networks (GNNs). Each stage of the data pipeline loads data from different sources and can be combined with other stages to create more complex data pipelines. The intermediate data will be stored in **MiniBatch** data packs.\n",
"\n",
"* **ItemSampler** iterates over input **Itemset** and create subsets."
],
"metadata": {
"id": "Oh2ockWWoXQ0"
}
]
},
{
"cell_type": "code",
"source": [
"datapipe = gb.ItemSampler(item_set, batch_size=3, shuffle=False)\n",
"print(next(iter(datapipe)))"
],
"execution_count": null,
"metadata": {
"id": "XtqPDprrogR7"
},
"execution_count": null,
"outputs": []
"outputs": [],
"source": [
"datapipe = gb.ItemSampler(item_set, batch_size=3, shuffle=False)\n",
"print(next(iter(datapipe)))"
]
},
{
"cell_type": "markdown",
"source": [
"* **NegativeSampler** generate negative samples and return a mix of positive and negative samples."
],
"metadata": {
"id": "BjkAK37xopp1"
}
},
"source": [
"* **NegativeSampler** generate negative samples and return a mix of positive and negative samples."
]
},
{
"cell_type": "code",
"source": [
"datapipe = datapipe.sample_uniform_negative(graph, 1)\n",
"print(next(iter(datapipe)))"
],
"execution_count": null,
"metadata": {
"id": "PrFpGoOGopJy"
},
"execution_count": null,
"outputs": []
"outputs": [],
"source": [
"datapipe = datapipe.sample_uniform_negative(graph, 1)\n",
"print(next(iter(datapipe)))"
]
},
{
"cell_type": "markdown",
"source": [
"* **SubgraphSampler** samples a subgraph from a given set of nodes from a larger graph."
],
"metadata": {
"id": "fYO_oIwkpmb3"
}
},
"source": [
"* **SubgraphSampler** samples a subgraph from a given set of nodes from a larger graph."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4UsY3PL3ppYV"
},
"outputs": [],
"source": [
"fanouts = torch.tensor([1])\n",
"datapipe = datapipe.sample_neighbor(graph, [fanouts])\n",
"print(next(iter(datapipe)))"
],
"metadata": {
"id": "4UsY3PL3ppYV"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "markdown",
"source": [
"* **FeatureFetcher** fetchs features for node/edge in graphbolt."
],
"metadata": {
"id": "0uIydsjUqMA0"
}
},
"source": [
"* **FeatureFetcher** fetchs features for node/edge in graphbolt."
]
},
{
"cell_type": "code",
"source": [
"datapipe = datapipe.fetch_feature(feature_store, node_feature_keys=[\"feat\"], edge_feature_keys=[\"feat\"])\n",
"print(next(iter(datapipe)))"
],
"execution_count": null,
"metadata": {
"id": "YAj8G7YBqO6G"
},
"execution_count": null,
"outputs": []
"outputs": [],
"source": [
"datapipe = datapipe.fetch_feature(feature_store, node_feature_keys=[\"feat\"], edge_feature_keys=[\"feat\"])\n",
"print(next(iter(datapipe)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Gt059n1xrmj-"
},
"source": [
"After retrieving the required data, Graphbolt provides helper methods to convert it to the output format needed for subsequent GNN training.\n",
"\n",
"* Convert to **DGLMiniBatch** format for training with DGL."
],
"metadata": {
"id": "Gt059n1xrmj-"
}
]
},
{
"cell_type": "code",
"source": [
"datapipe = datapipe.to_dgl()\n",
"print(next(iter(datapipe)))"
],
"execution_count": null,
"metadata": {
"id": "o8Yoi8BeqSdu"
},
"execution_count": null,
"outputs": []
"outputs": [],
"source": [
"datapipe = datapipe.to_dgl()\n",
"print(next(iter(datapipe)))"
]
},
{
"cell_type": "markdown",
"source": [
"* Copy the data to the GPU for training on the GPU."
],
"metadata": {
"id": "hjBSLPRPrsD2"
}
},
"source": [
"* Copy the data to the GPU for training on the GPU."
]
},
{
"cell_type": "code",
"source": [
"datapipe = datapipe.copy_to(device=\"cuda\")\n",
"print(next(iter(datapipe)))"
],
"execution_count": null,
"metadata": {
"id": "RofiZOUMqt_u"
},
"execution_count": null,
"outputs": []
"outputs": [],
"source": [
"datapipe = datapipe.copy_to(device=\"cuda\")\n",
"print(next(iter(datapipe)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xm9HnyHRvxXj"
},
"source": [
"## Exercise: Node classification\n",
"\n",
"Similarly, the following Dataset is created for node classification, can you implement the data pipeline for the dataset?"
],
"metadata": {
"id": "xm9HnyHRvxXj"
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YV-mk-xAv78v"
},
"outputs": [],
"source": [
"# Dataset for node classification.\n",
"num_nodes = 10\n",
......@@ -311,7 +295,7 @@
"eid = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,\n",
" 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24])\n",
"edge_attributes = {gb.ORIGINAL_EDGE_ID: eid}\n",
"graph = gb.from_csc(indptr, indices, None, None, edge_attributes, None)\n",
"graph = gb.from_fused_csc(indptr, indices, None, None, edge_attributes, None)\n",
"\n",
"num_nodes = 10\n",
"num_edges = 25\n",
......@@ -328,12 +312,28 @@
"# Datapipe.\n",
"...\n",
"print(next(iter(datapipe)))"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"authorship_tag": "ABX9TyOCdFtYQweXnIR1/5oWDSGq",
"collapsed_sections": [
"BjkAK37xopp1"
],
"metadata": {
"id": "YV-mk-xAv78v"
},
"execution_count": null,
"outputs": []
"gpuType": "T4",
"private_outputs": true,
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
]
}
\ No newline at end of file
},
"nbformat": 4,
"nbformat_minor": 0
}
......@@ -1222,14 +1222,14 @@ def partition_graph(
def convert_dgl_partition_to_csc_sampling_graph(part_config):
"""Convert partitions of dgl to CSCSamplingGraph of GraphBolt.
"""Convert partitions of dgl to FusedCSCSamplingGraph of GraphBolt.
This API converts `DGLGraph` partitions to `CSCSamplingGraph` which is
This API converts `DGLGraph` partitions to `FusedCSCSamplingGraph` which is
dedicated for sampling in `GraphBolt`. New graphs will be stored alongside
original graph as `csc_sampling_graph.tar`.
original graph as `fused_csc_sampling_graph.tar`.
In the near future, partitions are supposed to be saved as
`CSCSamplingGraph` directly. At that time, this API should be deprecated.
`FusedCSCSamplingGraph` directly. At that time, this API should be deprecated.
Parameters
----------
......@@ -1262,7 +1262,7 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config):
type_per_edge = type_per_edge.to(RESERVED_FIELD_DTYPE[ETYPE])
# Sanity check.
assert len(type_per_edge) == graph.num_edges()
csc_graph = graphbolt.from_csc(
csc_graph = graphbolt.from_fused_csc(
indptr, indices, None, type_per_edge, metadata=metadata
)
orig_graph_path = os.path.join(
......@@ -1270,6 +1270,6 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config):
part_meta[f"part-{part_id}"]["part_graph"],
)
csc_graph_path = os.path.join(
os.path.dirname(orig_graph_path), "csc_sampling_graph.tar"
os.path.dirname(orig_graph_path), "fused_csc_sampling_graph.tar"
)
graphbolt.save_csc_sampling_graph(csc_graph, csc_graph_path)
graphbolt.save_fused_csc_sampling_graph(csc_graph, csc_graph_path)
"""Implementation of GraphBolt."""
from .basic_feature_store import *
from .csc_sampling_graph import *
from .fused_csc_sampling_graph import *
from .neighbor_sampler import *
from .ondisk_dataset import *
from .ondisk_metadata import *
......
......@@ -20,11 +20,11 @@ from .sampled_subgraph_impl import SampledSubgraphImpl
__all__ = [
"GraphMetadata",
"CSCSamplingGraph",
"from_csc",
"FusedCSCSamplingGraph",
"from_fused_csc",
"load_from_shared_memory",
"load_csc_sampling_graph",
"save_csc_sampling_graph",
"load_fused_csc_sampling_graph",
"save_fused_csc_sampling_graph",
"from_dglgraph",
]
......@@ -88,7 +88,7 @@ class GraphMetadata:
self.edge_type_to_id = edge_type_to_id
class CSCSamplingGraph(SamplingGraph):
class FusedCSCSamplingGraph(SamplingGraph):
r"""A sampling graph in CSC format."""
def __repr__(self):
......@@ -150,7 +150,7 @@ class CSCSamplingGraph(SamplingGraph):
>>> type_per_edge = torch.LongTensor(
... [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
>>> metadata = gb.GraphMetadata(ntypes, etypes)
>>> graph = gb.from_csc(indptr, indices, node_type_offset,
>>> graph = gb.from_fused_csc(indptr, indices, node_type_offset,
... type_per_edge, None, metadata)
>>> print(graph.num_nodes)
{'N0': 2, 'N1': 3}
......@@ -425,7 +425,7 @@ class CSCSamplingGraph(SamplingGraph):
>>> indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
>>> node_type_offset = torch.LongTensor([0, 2, 5])
>>> type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
>>> graph = gb.from_csc(indptr, indices, type_per_edge=type_per_edge,
>>> graph = gb.from_fused_csc(indptr, indices, type_per_edge=type_per_edge,
... node_type_offset=node_type_offset, metadata=metadata)
>>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])}
>>> fanouts = torch.tensor([1, 1])
......@@ -605,7 +605,7 @@ class CSCSamplingGraph(SamplingGraph):
>>> indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
>>> node_type_offset = torch.LongTensor([0, 2, 5])
>>> type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
>>> graph = gb.from_csc(indptr, indices, type_per_edge=type_per_edge,
>>> graph = gb.from_fused_csc(indptr, indices, type_per_edge=type_per_edge,
... node_type_offset=node_type_offset, metadata=metadata)
>>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])}
>>> fanouts = torch.tensor([1, 1])
......@@ -697,16 +697,16 @@ class CSCSamplingGraph(SamplingGraph):
Returns
-------
CSCSamplingGraph
The copied CSCSamplingGraph object on shared memory.
FusedCSCSamplingGraph
The copied FusedCSCSamplingGraph object on shared memory.
"""
return CSCSamplingGraph(
return FusedCSCSamplingGraph(
self._c_csc_graph.copy_to_shared_memory(shared_memory_name),
self._metadata,
)
def to(self, device: torch.device) -> None: # pylint: disable=invalid-name
"""Copy `CSCSamplingGraph` to the specified device."""
"""Copy `FusedCSCSamplingGraph` to the specified device."""
def _to(x, device):
return x.to(device) if hasattr(x, "to") else x
......@@ -728,15 +728,15 @@ class CSCSamplingGraph(SamplingGraph):
return self
def from_csc(
def from_fused_csc(
csc_indptr: torch.Tensor,
indices: torch.Tensor,
node_type_offset: Optional[torch.tensor] = None,
type_per_edge: Optional[torch.tensor] = None,
edge_attributes: Optional[Dict[str, torch.tensor]] = None,
metadata: Optional[GraphMetadata] = None,
) -> CSCSamplingGraph:
"""Create a CSCSamplingGraph object from a CSC representation.
) -> FusedCSCSamplingGraph:
"""Create a FusedCSCSamplingGraph object from a CSC representation.
Parameters
----------
......@@ -756,8 +756,8 @@ def from_csc(
Metadata of the graph, by default None.
Returns
-------
CSCSamplingGraph
The created CSCSamplingGraph object.
FusedCSCSamplingGraph
The created FusedCSCSamplingGraph object.
Examples
--------
......@@ -768,13 +768,13 @@ def from_csc(
>>> indices = torch.tensor([1, 3, 0, 1, 2, 0, 3])
>>> node_type_offset = torch.tensor([0, 1, 2, 3])
>>> type_per_edge = torch.tensor([0, 1, 0, 1, 1, 0, 0])
>>> graph = graphbolt.from_csc(csc_indptr, indices,
>>> graph = graphbolt.from_fused_csc(csc_indptr, indices,
... node_type_offset=node_type_offset,
... type_per_edge=type_per_edge,
... edge_attributes=None, metadata=metadata)
None, metadata)
>>> print(graph)
CSCSamplingGraph(csc_indptr=tensor([0, 2, 5, 7]),
FusedCSCSamplingGraph(csc_indptr=tensor([0, 2, 5, 7]),
indices=tensor([1, 3, 0, 1, 2, 0, 3]),
total_num_nodes=3, total_num_edges=7)
"""
......@@ -782,8 +782,8 @@ def from_csc(
assert len(metadata.node_type_to_id) + 1 == node_type_offset.size(
0
), "node_type_offset length should be |ntypes| + 1."
return CSCSamplingGraph(
torch.ops.graphbolt.from_csc(
return FusedCSCSamplingGraph(
torch.ops.graphbolt.from_fused_csc(
csc_indptr,
indices,
node_type_offset,
......@@ -797,8 +797,8 @@ def from_csc(
def load_from_shared_memory(
shared_memory_name: str,
metadata: Optional[GraphMetadata] = None,
) -> CSCSamplingGraph:
"""Load a CSCSamplingGraph object from shared memory.
) -> FusedCSCSamplingGraph:
"""Load a FusedCSCSamplingGraph object from shared memory.
Parameters
----------
......@@ -807,16 +807,16 @@ def load_from_shared_memory(
Returns
-------
CSCSamplingGraph
The loaded CSCSamplingGraph object on shared memory.
FusedCSCSamplingGraph
The loaded FusedCSCSamplingGraph object on shared memory.
"""
return CSCSamplingGraph(
return FusedCSCSamplingGraph(
torch.ops.graphbolt.load_from_shared_memory(shared_memory_name),
metadata,
)
def _csc_sampling_graph_str(graph: CSCSamplingGraph) -> str:
def _csc_sampling_graph_str(graph: FusedCSCSamplingGraph) -> str:
"""Internal function for converting a csc sampling graph to string
representation.
"""
......@@ -848,24 +848,24 @@ def _csc_sampling_graph_str(graph: CSCSamplingGraph) -> str:
return final_str
def load_csc_sampling_graph(filename):
"""Load CSCSamplingGraph from tar file."""
def load_fused_csc_sampling_graph(filename):
"""Load FusedCSCSamplingGraph from tar file."""
with tempfile.TemporaryDirectory() as temp_dir:
with tarfile.open(filename, "r") as archive:
archive.extractall(temp_dir)
graph_filename = os.path.join(temp_dir, "csc_sampling_graph.pt")
graph_filename = os.path.join(temp_dir, "fused_csc_sampling_graph.pt")
metadata_filename = os.path.join(temp_dir, "metadata.pt")
return CSCSamplingGraph(
torch.ops.graphbolt.load_csc_sampling_graph(graph_filename),
return FusedCSCSamplingGraph(
torch.ops.graphbolt.load_fused_csc_sampling_graph(graph_filename),
torch.load(metadata_filename),
)
def save_csc_sampling_graph(graph, filename):
"""Save CSCSamplingGraph to tar file."""
def save_fused_csc_sampling_graph(graph, filename):
"""Save FusedCSCSamplingGraph to tar file."""
with tempfile.TemporaryDirectory() as temp_dir:
graph_filename = os.path.join(temp_dir, "csc_sampling_graph.pt")
torch.ops.graphbolt.save_csc_sampling_graph(
graph_filename = os.path.join(temp_dir, "fused_csc_sampling_graph.pt")
torch.ops.graphbolt.save_fused_csc_sampling_graph(
graph._c_csc_graph, graph_filename
)
metadata_filename = os.path.join(temp_dir, "metadata.pt")
......@@ -877,15 +877,15 @@ def save_csc_sampling_graph(graph, filename):
archive.add(
metadata_filename, arcname=os.path.basename(metadata_filename)
)
print(f"CSCSamplingGraph has been saved to {filename}.")
print(f"FusedCSCSamplingGraph has been saved to {filename}.")
def from_dglgraph(
g: DGLGraph,
is_homogeneous: bool = False,
include_original_edge_id: bool = False,
) -> CSCSamplingGraph:
"""Convert a DGLGraph to CSCSamplingGraph."""
) -> FusedCSCSamplingGraph:
"""Convert a DGLGraph to FusedCSCSamplingGraph."""
homo_g, ntype_count, _ = to_homogeneous(g, return_count=True)
......@@ -913,8 +913,8 @@ def from_dglgraph(
# Assign edge attributes according to the original eids mapping.
edge_attributes[ORIGINAL_EDGE_ID] = homo_g.edata[EID][edge_ids]
return CSCSamplingGraph(
torch.ops.graphbolt.from_csc(
return FusedCSCSamplingGraph(
torch.ops.graphbolt.from_fused_csc(
indptr,
indices,
node_type_offset,
......
......@@ -28,7 +28,7 @@ class NeighborSampler(SubgraphSampler):
----------
datapipe : DataPipe
The datapipe.
graph : CSCSamplingGraph
graph : FusedCSCSamplingGraph
The graph on which to perform subgraph sampling.
fanouts: list[torch.Tensor] or list[int]
The number of edges to be sampled for each node with or without
......@@ -59,7 +59,7 @@ class NeighborSampler(SubgraphSampler):
>>> from dgl import graphbolt as gb
>>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8])
>>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5])
>>> graph = gb.from_csc(indptr, indices)
>>> graph = gb.from_fused_csc(indptr, indices)
>>> node_pairs = torch.LongTensor([[0, 1], [1, 2]])
>>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
>>> item_sampler = gb.ItemSampler(
......@@ -165,7 +165,7 @@ class LayerNeighborSampler(NeighborSampler):
----------
datapipe : DataPipe
The datapipe.
graph : CSCSamplingGraph
graph : FusedCSCSamplingGraph
The graph on which to perform subgraph sampling.
fanouts: list[torch.Tensor]
The number of edges to be sampled for each node with or without
......@@ -192,7 +192,7 @@ class LayerNeighborSampler(NeighborSampler):
>>> from dgl import graphbolt as gb
>>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8])
>>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5])
>>> graph = gb.from_csc(indptr, indices)
>>> graph = gb.from_fused_csc(indptr, indices)
>>> data_format = gb.LinkPredictionEdgeFormat.INDEPENDENT
>>> node_pairs = torch.LongTensor([[0, 1], [1, 2]])
>>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
......
......@@ -17,11 +17,11 @@ from ..dataset import Dataset, Task
from ..itemset import ItemSet, ItemSetDict
from ..sampling_graph import SamplingGraph
from ..utils import copy_or_convert_data, read_data
from .csc_sampling_graph import (
CSCSamplingGraph,
from .fused_csc_sampling_graph import (
from_dglgraph,
load_csc_sampling_graph,
save_csc_sampling_graph,
FusedCSCSamplingGraph,
load_fused_csc_sampling_graph,
save_fused_csc_sampling_graph,
)
from .ondisk_metadata import (
OnDiskGraphTopology,
......@@ -45,7 +45,7 @@ def preprocess_ondisk_dataset(
dataset_dir : str
The path to the dataset directory.
include_original_edge_id : bool, optional
Whether to include the original edge id in the CSCSamplingGraph.
Whether to include the original edge id in the FusedCSCSamplingGraph.
Returns
-------
......@@ -138,20 +138,20 @@ def preprocess_ondisk_dataset(
)
g.edata[graph_feature["name"]] = edge_data
# 4. Convert the DGLGraph to a CSCSamplingGraph.
csc_sampling_graph = from_dglgraph(
# 4. Convert the DGLGraph to a FusedCSCSamplingGraph.
fused_csc_sampling_graph = from_dglgraph(
g, is_homogeneous, include_original_edge_id
)
# 5. Save the CSCSamplingGraph and modify the output_config.
# 5. Save the FusedCSCSamplingGraph and modify the output_config.
output_config["graph_topology"] = {}
output_config["graph_topology"]["type"] = "CSCSamplingGraph"
output_config["graph_topology"]["type"] = "FusedCSCSamplingGraph"
output_config["graph_topology"]["path"] = os.path.join(
processed_dir_prefix, "csc_sampling_graph.tar"
processed_dir_prefix, "fused_csc_sampling_graph.tar"
)
save_csc_sampling_graph(
csc_sampling_graph,
save_fused_csc_sampling_graph(
fused_csc_sampling_graph,
os.path.join(
dataset_dir,
output_config["graph_topology"]["path"],
......@@ -283,8 +283,8 @@ class OnDiskDataset(Dataset):
dataset_name: graphbolt_test
graph_topology:
type: CSCSamplingGraph
path: graph_topology/csc_sampling_graph.tar
type: FusedCSCSamplingGraph
path: graph_topology/fused_csc_sampling_graph.tar
feature_data:
- domain: node
type: paper
......@@ -340,7 +340,7 @@ class OnDiskDataset(Dataset):
path: str
The YAML file path.
include_original_edge_id: bool, optional
Whether to include the original edge id in the CSCSamplingGraph.
Whether to include the original edge id in the FusedCSCSamplingGraph.
"""
def __init__(
......@@ -434,12 +434,12 @@ class OnDiskDataset(Dataset):
def _load_graph(
self, graph_topology: OnDiskGraphTopology
) -> CSCSamplingGraph:
) -> FusedCSCSamplingGraph:
"""Load the graph topology."""
if graph_topology is None:
return None
if graph_topology.type == "CSCSamplingGraph":
return load_csc_sampling_graph(graph_topology.path)
if graph_topology.type == "FusedCSCSamplingGraph":
return load_fused_csc_sampling_graph(graph_topology.path)
raise NotImplementedError(
f"Graph topology type {graph_topology.type} is not supported."
)
......
......@@ -64,7 +64,7 @@ class OnDiskFeatureData(pydantic.BaseModel):
class OnDiskGraphTopologyType(str, Enum):
"""Enum of graph topology type."""
CSC_SAMPLING = "CSCSamplingGraph"
FUSED_CSC_SAMPLING = "FusedCSCSamplingGraph"
class OnDiskGraphTopology(pydantic.BaseModel):
......
"""Sampled subgraph for CSCSamplingGraph."""
"""Sampled subgraph for FusedCSCSamplingGraph."""
# pylint: disable= invalid-name
from dataclasses import dataclass
from typing import Dict, Tuple, Union
......@@ -13,7 +13,7 @@ __all__ = ["SampledSubgraphImpl"]
@dataclass
class SampledSubgraphImpl(SampledSubgraph):
r"""Sampled subgraph of CSCSamplingGraph.
r"""Sampled subgraph of FusedCSCSamplingGraph.
Examples
--------
......
......@@ -22,7 +22,7 @@ class UniformNegativeSampler(NegativeSampler):
----------
datapipe : DataPipe
The datapipe.
graph : CSCSamplingGraph
graph : FusedCSCSamplingGraph
The graph on which to perform negative sampling.
negative_ratio : int
The proportion of negative samples to positive samples.
......@@ -32,7 +32,7 @@ class UniformNegativeSampler(NegativeSampler):
>>> from dgl import graphbolt as gb
>>> indptr = torch.LongTensor([0, 2, 4, 5])
>>> indices = torch.LongTensor([1, 2, 0, 2, 0])
>>> graph = gb.from_csc(indptr, indices)
>>> graph = gb.from_fused_csc(indptr, indices)
>>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2]))
>>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
>>> item_sampler = gb.ItemSampler(
......
......@@ -694,8 +694,10 @@ def test_convert_dgl_partition_to_csc_sampling_graph_homo(
orig_g = dgl.load_graphs(
os.path.join(test_dir, f"part{part_id}/graph.dgl")
)[0][0]
new_g = dgl.graphbolt.load_csc_sampling_graph(
os.path.join(test_dir, f"part{part_id}/csc_sampling_graph.tar")
new_g = dgl.graphbolt.load_fused_csc_sampling_graph(
os.path.join(
test_dir, f"part{part_id}/fused_csc_sampling_graph.tar"
)
)
orig_indptr, orig_indices, _ = orig_g.adj().csc()
assert th.equal(orig_indptr, new_g.csc_indptr)
......@@ -725,8 +727,10 @@ def test_convert_dgl_partition_to_csc_sampling_graph_hetero(
orig_g = dgl.load_graphs(
os.path.join(test_dir, f"part{part_id}/graph.dgl")
)[0][0]
new_g = dgl.graphbolt.load_csc_sampling_graph(
os.path.join(test_dir, f"part{part_id}/csc_sampling_graph.tar")
new_g = dgl.graphbolt.load_fused_csc_sampling_graph(
os.path.join(
test_dir, f"part{part_id}/fused_csc_sampling_graph.tar"
)
)
orig_indptr, orig_indices, _ = orig_g.adj().csc()
assert th.equal(orig_indptr, new_g.csc_indptr)
......
......@@ -17,7 +17,7 @@ def rand_csc_graph(N, density):
indptr = torch.LongTensor(adj.indptr)
indices = torch.LongTensor(adj.indices)
graph = gb.from_csc(indptr, indices)
graph = gb.from_fused_csc(indptr, indices)
return graph
......
......@@ -26,7 +26,7 @@ mp.set_sharing_strategy("file_system")
def test_empty_graph(total_num_nodes):
csc_indptr = torch.zeros((total_num_nodes + 1,), dtype=int)
indices = torch.tensor([])
graph = gb.from_csc(csc_indptr, indices)
graph = gb.from_fused_csc(csc_indptr, indices)
assert graph.total_num_edges == 0
assert graph.total_num_nodes == total_num_nodes
assert torch.equal(graph.csc_indptr, csc_indptr)
......@@ -52,7 +52,7 @@ def test_hetero_empty_graph(total_num_nodes):
node_type_offset[0] = 0
node_type_offset[-1] = total_num_nodes
type_per_edge = torch.tensor([])
graph = gb.from_csc(
graph = gb.from_fused_csc(
csc_indptr,
indices,
node_type_offset,
......@@ -119,7 +119,9 @@ def test_homo_graph(total_num_nodes, total_num_edges):
"A1": torch.randn(total_num_edges),
"A2": torch.randn(total_num_edges),
}
graph = gb.from_csc(csc_indptr, indices, edge_attributes=edge_attributes)
graph = gb.from_fused_csc(
csc_indptr, indices, edge_attributes=edge_attributes
)
assert graph.total_num_nodes == total_num_nodes
assert graph.total_num_edges == total_num_edges
......@@ -156,7 +158,7 @@ def test_hetero_graph(total_num_nodes, total_num_edges, num_ntypes, num_etypes):
"A1": torch.randn(total_num_edges),
"A2": torch.randn(total_num_edges),
}
graph = gb.from_csc(
graph = gb.from_fused_csc(
csc_indptr,
indices,
node_type_offset,
......@@ -193,7 +195,9 @@ def test_num_nodes_homo(total_num_nodes, total_num_edges):
"A1": torch.randn(total_num_edges),
"A2": torch.randn(total_num_edges),
}
graph = gb.from_csc(csc_indptr, indices, edge_attributes=edge_attributes)
graph = gb.from_fused_csc(
csc_indptr, indices, edge_attributes=edge_attributes
)
assert graph.num_nodes == total_num_nodes
......@@ -239,9 +243,9 @@ def test_num_nodes_hetero():
assert node_type_offset[-1] == total_num_nodes
assert all(type_per_edge < len(etypes))
# Construct CSCSamplingGraph.
# Construct FusedCSCSamplingGraph.
metadata = gb.GraphMetadata(ntypes, etypes)
graph = gb.from_csc(
graph = gb.from_fused_csc(
indptr, indices, node_type_offset, type_per_edge, None, metadata
)
......@@ -273,7 +277,7 @@ def test_node_type_offset_wrong_legnth(node_type_offset):
10, 50, num_ntypes, 5
)
with pytest.raises(Exception):
gb.from_csc(
gb.from_fused_csc(
csc_indptr, indices, node_type_offset, type_per_edge, None, metadata
)
......@@ -290,12 +294,12 @@ def test_load_save_homo_graph(total_num_nodes, total_num_edges):
csc_indptr, indices = gbt.random_homo_graph(
total_num_nodes, total_num_edges
)
graph = gb.from_csc(csc_indptr, indices)
graph = gb.from_fused_csc(csc_indptr, indices)
with tempfile.TemporaryDirectory() as test_dir:
filename = os.path.join(test_dir, "csc_sampling_graph.tar")
gb.save_csc_sampling_graph(graph, filename)
graph2 = gb.load_csc_sampling_graph(filename)
filename = os.path.join(test_dir, "fused_csc_sampling_graph.tar")
gb.save_fused_csc_sampling_graph(graph, filename)
graph2 = gb.load_fused_csc_sampling_graph(filename)
assert graph.total_num_nodes == graph2.total_num_nodes
assert graph.total_num_edges == graph2.total_num_edges
......@@ -329,14 +333,14 @@ def test_load_save_hetero_graph(
) = gbt.random_hetero_graph(
total_num_nodes, total_num_edges, num_ntypes, num_etypes
)
graph = gb.from_csc(
graph = gb.from_fused_csc(
csc_indptr, indices, node_type_offset, type_per_edge, None, metadata
)
with tempfile.TemporaryDirectory() as test_dir:
filename = os.path.join(test_dir, "csc_sampling_graph.tar")
gb.save_csc_sampling_graph(graph, filename)
graph2 = gb.load_csc_sampling_graph(filename)
filename = os.path.join(test_dir, "fused_csc_sampling_graph.tar")
gb.save_fused_csc_sampling_graph(graph, filename)
graph2 = gb.load_fused_csc_sampling_graph(filename)
assert graph.total_num_nodes == graph2.total_num_nodes
assert graph.total_num_edges == graph2.total_num_edges
......@@ -361,7 +365,7 @@ def test_pickle_homo_graph(total_num_nodes, total_num_edges):
csc_indptr, indices = gbt.random_homo_graph(
total_num_nodes, total_num_edges
)
graph = gb.from_csc(csc_indptr, indices)
graph = gb.from_fused_csc(csc_indptr, indices)
serialized = pickle.dumps(graph)
graph2 = pickle.loads(serialized)
......@@ -402,7 +406,7 @@ def test_pickle_hetero_graph(
"a": torch.randn((total_num_edges,)),
"b": torch.randint(1, 10, (total_num_edges,)),
}
graph = gb.from_csc(
graph = gb.from_fused_csc(
csc_indptr,
indices,
node_type_offset,
......@@ -453,7 +457,7 @@ def test_multiprocessing():
edge_attributes = {
"a": torch.randn((total_num_edges,)),
}
graph = gb.from_csc(
graph = gb.from_fused_csc(
csc_indptr,
indices,
node_type_offset,
......@@ -489,8 +493,8 @@ def test_in_subgraph_homogeneous():
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)
# Construct CSCSamplingGraph.
graph = gb.from_csc(indptr, indices)
# Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc(indptr, indices)
# Extract in subgraph.
nodes = torch.LongTensor([1, 3, 4])
......@@ -552,9 +556,9 @@ def test_in_subgraph_heterogeneous():
assert node_type_offset[-1] == total_num_nodes
assert all(type_per_edge < len(etypes))
# Construct CSCSamplingGraph.
# Construct FusedCSCSamplingGraph.
metadata = gb.GraphMetadata(ntypes, etypes)
graph = gb.from_csc(
graph = gb.from_fused_csc(
indptr, indices, node_type_offset, type_per_edge, None, metadata
)
......@@ -599,8 +603,8 @@ def test_sample_neighbors_homo():
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)
# Construct CSCSamplingGraph.
graph = gb.from_csc(indptr, indices)
# Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc(indptr, indices)
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4])
......@@ -642,8 +646,8 @@ def test_sample_neighbors_hetero(labor):
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)
# Construct CSCSamplingGraph.
graph = gb.from_csc(
# Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc(
indptr,
indices,
node_type_offset=node_type_offset,
......@@ -748,8 +752,8 @@ def test_sample_neighbors_fanouts(
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)
# Construct CSCSamplingGraph.
graph = gb.from_csc(
# Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc(
indptr,
indices,
node_type_offset=node_type_offset,
......@@ -806,8 +810,8 @@ def test_sample_neighbors_replace(
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)
# Construct CSCSamplingGraph.
graph = gb.from_csc(
# Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc(
indptr,
indices,
node_type_offset=node_type_offset,
......@@ -849,8 +853,8 @@ def test_sample_neighbors_return_eids_homo(labor):
# Add edge id mapping from CSC graph -> original graph.
edge_attributes = {gb.ORIGINAL_EDGE_ID: torch.randperm(total_num_edges)}
# Construct CSCSamplingGraph.
graph = gb.from_csc(indptr, indices, edge_attributes=edge_attributes)
# Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc(indptr, indices, edge_attributes=edge_attributes)
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4])
......@@ -897,8 +901,8 @@ def test_sample_neighbors_return_eids_hetero(labor):
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)
# Construct CSCSamplingGraph.
graph = gb.from_csc(
# Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc(
indptr,
indices,
node_type_offset=node_type_offset,
......@@ -956,8 +960,8 @@ def test_sample_neighbors_probs(replace, labor, probs_name):
"mask": torch.BoolTensor([1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1]),
}
# Construct CSCSamplingGraph.
graph = gb.from_csc(indptr, indices, edge_attributes=edge_attributes)
# Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc(indptr, indices, edge_attributes=edge_attributes)
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4])
......@@ -1002,8 +1006,8 @@ def test_sample_neighbors_zero_probs(replace, labor, probs_or_mask):
edge_attributes = {"probs_or_mask": probs_or_mask}
# Construct CSCSamplingGraph.
graph = gb.from_csc(indptr, indices, edge_attributes=edge_attributes)
# Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc(indptr, indices, edge_attributes=edge_attributes)
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4])
......@@ -1038,7 +1042,7 @@ def check_tensors_on_the_same_shared_memory(t1: torch.Tensor, t2: torch.Tensor):
@unittest.skipIf(
F._default_context_str == "gpu",
reason="CSCSamplingGraph is only supported on CPU.",
reason="FusedCSCSamplingGraph is only supported on CPU.",
)
@pytest.mark.parametrize(
"total_num_nodes, total_num_edges",
......@@ -1058,7 +1062,9 @@ def test_homo_graph_on_shared_memory(
}
else:
edge_attributes = None
graph = gb.from_csc(csc_indptr, indices, edge_attributes=edge_attributes)
graph = gb.from_fused_csc(
csc_indptr, indices, edge_attributes=edge_attributes
)
shm_name = "test_homo_g"
graph1 = graph.copy_to_shared_memory(shm_name)
......@@ -1099,7 +1105,7 @@ def test_homo_graph_on_shared_memory(
@unittest.skipIf(
F._default_context_str == "gpu",
reason="CSCSamplingGraph is only supported on CPU.",
reason="FusedCSCSamplingGraph is only supported on CPU.",
)
@pytest.mark.parametrize(
"total_num_nodes, total_num_edges",
......@@ -1127,7 +1133,7 @@ def test_hetero_graph_on_shared_memory(
}
else:
edge_attributes = None
graph = gb.from_csc(
graph = gb.from_fused_csc(
csc_indptr,
indices,
node_type_offset=node_type_offset,
......@@ -1250,7 +1256,7 @@ def test_multiprocessing_with_shared_memory():
node_type_offset.share_memory_()
type_per_edge.share_memory_()
graph = gb.from_csc(
graph = gb.from_fused_csc(
csc_indptr,
indices,
node_type_offset=node_type_offset,
......@@ -1308,7 +1314,7 @@ def test_from_dglgraph_homogeneous():
gb_g = gb.from_dglgraph(
dgl_g, is_homogeneous=True, include_original_edge_id=True
)
# Get the COO representation of the CSCSamplingGraph.
# Get the COO representation of the FusedCSCSamplingGraph.
num_columns = gb_g.csc_indptr[1:] - gb_g.csc_indptr[:-1]
rows = gb_g.indices
columns = torch.arange(gb_g.total_num_nodes).repeat_interleave(num_columns)
......@@ -1360,21 +1366,21 @@ def test_from_dglgraph_heterogeneous():
dgl_g, is_homogeneous=False, include_original_edge_id=True
)
# `reverse_node_id` is used to map the node id in CSCSamplingGraph to the
# `reverse_node_id` is used to map the node id in FusedCSCSamplingGraph to the
# node id in Hetero-DGLGraph.
num_ntypes = gb_g.node_type_offset[1:] - gb_g.node_type_offset[:-1]
reverse_node_id = torch.cat([torch.arange(num) for num in num_ntypes])
# Get the COO representation of the CSCSamplingGraph.
# Get the COO representation of the FusedCSCSamplingGraph.
num_columns = gb_g.csc_indptr[1:] - gb_g.csc_indptr[:-1]
rows = reverse_node_id[gb_g.indices]
columns = reverse_node_id[
torch.arange(gb_g.total_num_nodes).repeat_interleave(num_columns)
]
# Check the order of etypes in DGLGraph is the same as CSCSamplingGraph.
# Check the order of etypes in DGLGraph is the same as FusedCSCSamplingGraph.
assert (
# Since the etypes in CSCSamplingGraph is "srctype:etype:dsttype",
# Since the etypes in FusedCSCSamplingGraph is "srctype:etype:dsttype",
# we need to split the string and get the middle part.
list(
map(
......@@ -1463,8 +1469,8 @@ def test_sample_neighbors_homo_pick_number(fanouts, replace, labor, probs_name):
"zero": torch.BoolTensor([0, 0, 0, 0, 0, 0]),
}
# Construct CSCSamplingGraph.
graph = gb.from_csc(indptr, indices, edge_attributes=edge_attributes)
# Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc(indptr, indices, edge_attributes=edge_attributes)
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([0, 1])
......@@ -1547,8 +1553,8 @@ def test_sample_neighbors_hetero_pick_number(
"zero": torch.BoolTensor([0, 0, 0, 0, 0, 0, 0, 0, 0]),
}
# Construct CSCSamplingGraph.
graph = gb.from_csc(
# Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc(
indptr,
indices,
edge_attributes=edge_attributes,
......@@ -1636,8 +1642,8 @@ def test_csc_sampling_graph_to_device():
"zero": torch.BoolTensor([0, 0, 0, 0, 0, 0, 0, 0, 0]),
}
# Construct CSCSamplingGraph.
graph = gb.from_csc(
# Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc(
indptr,
indices,
edge_attributes=edge_attributes,
......
......@@ -68,7 +68,7 @@ def test_UniformNegativeSampler_invoke():
@pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20])
def test_Uniform_NegativeSampler(negative_ratio):
# Construct CSCSamplingGraph.
# Construct FusedCSCSamplingGraph.
graph = gb_test_utils.rand_csc_graph(100, 0.05)
num_seeds = 30
item_set = gb.ItemSet(
......@@ -110,7 +110,7 @@ def get_hetero_graph():
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
node_type_offset = torch.LongTensor([0, 2, 5])
return gb.from_csc(
return gb.from_fused_csc(
indptr,
indices,
node_type_offset=node_type_offset,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment