Unverified Commit 223a3da5 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bugfix] Fix multiple bugs and code refactor (#3841)



* fix

* remove setcxx methods

* move pin flag to CSR and COO matrix
Co-authored-by: default avatarJinjing Zhou <VoVAllen@users.noreply.github.com>
parent 12e97c54
...@@ -47,6 +47,8 @@ struct COOMatrix { ...@@ -47,6 +47,8 @@ struct COOMatrix {
bool row_sorted = false; bool row_sorted = false;
/*! \brief whether the column indices per row are sorted */ /*! \brief whether the column indices per row are sorted */
bool col_sorted = false; bool col_sorted = false;
/*! \brief whether the matrix is in pinned memory */
bool is_pinned = false;
/*! \brief default constructor */ /*! \brief default constructor */
COOMatrix() = default; COOMatrix() = default;
/*! \brief constructor */ /*! \brief constructor */
...@@ -139,11 +141,14 @@ struct COOMatrix { ...@@ -139,11 +141,14 @@ struct COOMatrix {
* The context check is deferred to pinning the NDArray. * The context check is deferred to pinning the NDArray.
*/ */
inline void PinMemory_() { inline void PinMemory_() {
if (is_pinned)
return;
row.PinMemory_(); row.PinMemory_();
col.PinMemory_(); col.PinMemory_();
if (!aten::IsNullArray(data)) { if (!aten::IsNullArray(data)) {
data.PinMemory_(); data.PinMemory_();
} }
is_pinned = true;
} }
/*! /*!
...@@ -154,11 +159,14 @@ struct COOMatrix { ...@@ -154,11 +159,14 @@ struct COOMatrix {
* The context check is deferred to unpinning the NDArray. * The context check is deferred to unpinning the NDArray.
*/ */
inline void UnpinMemory_() { inline void UnpinMemory_() {
if (!is_pinned)
return;
row.UnpinMemory_(); row.UnpinMemory_();
col.UnpinMemory_(); col.UnpinMemory_();
if (!aten::IsNullArray(data)) { if (!aten::IsNullArray(data)) {
data.UnpinMemory_(); data.UnpinMemory_();
} }
is_pinned = false;
} }
}; };
......
...@@ -44,6 +44,8 @@ struct CSRMatrix { ...@@ -44,6 +44,8 @@ struct CSRMatrix {
IdArray data; IdArray data;
/*! \brief whether the column indices per row are sorted */ /*! \brief whether the column indices per row are sorted */
bool sorted = false; bool sorted = false;
/*! \brief whether the matrix is in pinned memory */
bool is_pinned = false;
/*! \brief default constructor */ /*! \brief default constructor */
CSRMatrix() = default; CSRMatrix() = default;
/*! \brief constructor */ /*! \brief constructor */
...@@ -132,11 +134,14 @@ struct CSRMatrix { ...@@ -132,11 +134,14 @@ struct CSRMatrix {
* The context check is deferred to pinning the NDArray. * The context check is deferred to pinning the NDArray.
*/ */
inline void PinMemory_() { inline void PinMemory_() {
if (is_pinned)
return;
indptr.PinMemory_(); indptr.PinMemory_();
indices.PinMemory_(); indices.PinMemory_();
if (!aten::IsNullArray(data)) { if (!aten::IsNullArray(data)) {
data.PinMemory_(); data.PinMemory_();
} }
is_pinned = true;
} }
/*! /*!
...@@ -147,11 +152,14 @@ struct CSRMatrix { ...@@ -147,11 +152,14 @@ struct CSRMatrix {
* The context check is deferred to unpinning the NDArray. * The context check is deferred to unpinning the NDArray.
*/ */
inline void UnpinMemory_() { inline void UnpinMemory_() {
if (!is_pinned)
return;
indptr.UnpinMemory_(); indptr.UnpinMemory_();
indices.UnpinMemory_(); indices.UnpinMemory_();
if (!aten::IsNullArray(data)) { if (!aten::IsNullArray(data)) {
data.UnpinMemory_(); data.UnpinMemory_();
} }
is_pinned = false;
} }
}; };
......
...@@ -437,21 +437,6 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -437,21 +437,6 @@ class BaseHeteroGraph : public runtime::Object {
*/ */
virtual aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const = 0; virtual aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const = 0;
/*!
* \brief Set the COO matrix representation for a given edge type.
*/
virtual void SetCOOMatrix(dgl_type_t etype, aten::COOMatrix coo) = 0;
/*!
* \brief Set the CSR matrix representation for a given edge type.
*/
virtual void SetCSRMatrix(dgl_type_t etype, aten::CSRMatrix csr) = 0;
/*!
* \brief Set the CSC matrix representation for a given edge type.
*/
virtual void SetCSCMatrix(dgl_type_t etype, aten::CSRMatrix csc) = 0;
/*! /*!
* \brief Extract the induced subgraph by the given vertices. * \brief Extract the induced subgraph by the given vertices.
* *
......
...@@ -30,7 +30,7 @@ def prepare_tensor(g, data, name): ...@@ -30,7 +30,7 @@ def prepare_tensor(g, data, name):
Data in tensor object. Data in tensor object.
""" """
if F.is_tensor(data): if F.is_tensor(data):
if not g.is_pinned() and (F.dtype(data) != g.idtype or F.context(data) != g.device): if (F.dtype(data) != g.idtype or F.context(data) != g.device) and not g.is_pinned():
raise DGLError('Expect argument "{}" to have data type {} and device ' raise DGLError('Expect argument "{}" to have data type {} and device '
'context {}. But got {} and {}.'.format( 'context {}. But got {} and {}.'.format(
name, g.idtype, g.device, F.dtype(data), F.context(data))) name, g.idtype, g.device, F.dtype(data), F.context(data)))
......
...@@ -309,6 +309,8 @@ HeteroGraphPtr HeteroGraph::CopyToSharedMem( ...@@ -309,6 +309,8 @@ HeteroGraphPtr HeteroGraph::CopyToSharedMem(
std::vector<HeteroGraphPtr> relgraphs(g->NumEdgeTypes()); std::vector<HeteroGraphPtr> relgraphs(g->NumEdgeTypes());
for (dgl_type_t etype = 0 ; etype < g->NumEdgeTypes() ; ++etype) { for (dgl_type_t etype = 0 ; etype < g->NumEdgeTypes() ; ++etype) {
auto src_dst_type = g->GetEndpointTypes(etype);
int num_vtypes = (src_dst_type.first == src_dst_type.second ? 1 : 2);
aten::COOMatrix coo; aten::COOMatrix coo;
aten::CSRMatrix csr, csc; aten::CSRMatrix csr, csc;
std::string prefix = name + "_" + std::to_string(etype); std::string prefix = name + "_" + std::to_string(etype);
...@@ -321,7 +323,8 @@ HeteroGraphPtr HeteroGraph::CopyToSharedMem( ...@@ -321,7 +323,8 @@ HeteroGraphPtr HeteroGraph::CopyToSharedMem(
if (has_csc) { if (has_csc) {
csc = shm.CopyToSharedMem(hg->GetCSCMatrix(etype), prefix + "_csc"); csc = shm.CopyToSharedMem(hg->GetCSCMatrix(etype), prefix + "_csc");
} }
relgraphs[etype] = UnitGraph::CreateHomographFrom(csc, csr, coo, has_csc, has_csr, has_coo); relgraphs[etype] = UnitGraph::CreateUnitGraphFrom(
num_vtypes, csc, csr, coo, has_csc, has_csr, has_coo);
} }
auto ret = std::shared_ptr<HeteroGraph>( auto ret = std::shared_ptr<HeteroGraph>(
...@@ -361,6 +364,8 @@ std::tuple<HeteroGraphPtr, std::vector<std::string>, std::vector<std::string>> ...@@ -361,6 +364,8 @@ std::tuple<HeteroGraphPtr, std::vector<std::string>, std::vector<std::string>>
std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges()); std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges());
for (dgl_type_t etype = 0 ; etype < metagraph->NumEdges() ; ++etype) { for (dgl_type_t etype = 0 ; etype < metagraph->NumEdges() ; ++etype) {
auto src_dst = metagraph->FindEdge(etype);
int num_vtypes = (src_dst.first == src_dst.second) ? 1 : 2;
aten::COOMatrix coo; aten::COOMatrix coo;
aten::CSRMatrix csr, csc; aten::CSRMatrix csr, csc;
std::string prefix = name + "_" + std::to_string(etype); std::string prefix = name + "_" + std::to_string(etype);
...@@ -374,7 +379,8 @@ std::tuple<HeteroGraphPtr, std::vector<std::string>, std::vector<std::string>> ...@@ -374,7 +379,8 @@ std::tuple<HeteroGraphPtr, std::vector<std::string>, std::vector<std::string>>
shm.CreateFromSharedMem(&csc, prefix + "_csc"); shm.CreateFromSharedMem(&csc, prefix + "_csc");
} }
relgraphs[etype] = UnitGraph::CreateHomographFrom(csc, csr, coo, has_csc, has_csr, has_coo); relgraphs[etype] = UnitGraph::CreateUnitGraphFrom(
num_vtypes, csc, csr, coo, has_csc, has_csr, has_coo);
} }
auto ret = std::make_shared<HeteroGraph>(metagraph, relgraphs, num_verts_per_type); auto ret = std::make_shared<HeteroGraph>(metagraph, relgraphs, num_verts_per_type);
......
...@@ -272,18 +272,6 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -272,18 +272,6 @@ class HeteroGraph : public BaseHeteroGraph {
return relation_graphs_; return relation_graphs_;
} }
void SetCOOMatrix(dgl_type_t etype, aten::COOMatrix coo) override {
GetRelationGraph(etype)->SetCOOMatrix(0, coo);
}
void SetCSRMatrix(dgl_type_t etype, aten::CSRMatrix csr) override {
GetRelationGraph(etype)->SetCSRMatrix(0, csr);
}
void SetCSCMatrix(dgl_type_t etype, aten::CSRMatrix csc) override {
GetRelationGraph(etype)->SetCSCMatrix(0, csc);
}
private: private:
// To create empty class // To create empty class
friend class Serializer; friend class Serializer;
......
...@@ -635,8 +635,8 @@ HeteroGraphPtr ImmutableGraph::AsHeteroGraph() const { ...@@ -635,8 +635,8 @@ HeteroGraphPtr ImmutableGraph::AsHeteroGraph() const {
if (coo_) if (coo_)
coo = GetCOO()->ToCOOMatrix(); coo = GetCOO()->ToCOOMatrix();
auto g = UnitGraph::CreateHomographFrom( auto g = UnitGraph::CreateUnitGraphFrom(
in_csr, out_csr, coo, 1, in_csr, out_csr, coo,
in_csr_ != nullptr, in_csr_ != nullptr,
out_csr_ != nullptr, out_csr_ != nullptr,
coo_ != nullptr); coo_ != nullptr);
......
...@@ -196,7 +196,12 @@ HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates &states) { ...@@ -196,7 +196,12 @@ HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates &states) {
dgl_format_code_t created_formats, allowed_formats; dgl_format_code_t created_formats, allowed_formats;
CHECK(strm->Read(&created_formats)) << "Invalid code for created formats"; CHECK(strm->Read(&created_formats)) << "Invalid code for created formats";
CHECK(strm->Read(&allowed_formats)) << "Invalid code for allowed formats"; CHECK(strm->Read(&allowed_formats)) << "Invalid code for allowed formats";
HeteroGraphPtr relgraph = nullptr; aten::COOMatrix coo;
aten::CSRMatrix csr;
aten::CSRMatrix csc;
bool has_coo = (created_formats & COO_CODE);
bool has_csr = (created_formats & CSR_CODE);
bool has_csc = (created_formats & CSC_CODE);
if (created_formats & COO_CODE) { if (created_formats & COO_CODE) {
CHECK_GE(states.arrays.end() - array_itr, 2); CHECK_GE(states.arrays.end() - array_itr, 2);
...@@ -206,11 +211,7 @@ HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates &states) { ...@@ -206,11 +211,7 @@ HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates &states) {
bool csorted; bool csorted;
CHECK(strm->Read(&rsorted)) << "Invalid flag 'rsorted'"; CHECK(strm->Read(&rsorted)) << "Invalid flag 'rsorted'";
CHECK(strm->Read(&csorted)) << "Invalid flag 'csorted'"; CHECK(strm->Read(&csorted)) << "Invalid flag 'csorted'";
auto coo = aten::COOMatrix(num_src, num_dst, row, col, aten::NullArray(), rsorted, csorted); coo = aten::COOMatrix(num_src, num_dst, row, col, aten::NullArray(), rsorted, csorted);
if (!relgraph)
relgraph = CreateFromCOO(num_vtypes, coo, allowed_formats);
else
relgraph->SetCOOMatrix(0, coo);
} }
if (created_formats & CSR_CODE) { if (created_formats & CSR_CODE) {
CHECK_GE(states.arrays.end() - array_itr, 3); CHECK_GE(states.arrays.end() - array_itr, 3);
...@@ -219,11 +220,7 @@ HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates &states) { ...@@ -219,11 +220,7 @@ HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates &states) {
const auto &edge_id = *(array_itr++); const auto &edge_id = *(array_itr++);
bool sorted; bool sorted;
CHECK(strm->Read(&sorted)) << "Invalid flag 'sorted'"; CHECK(strm->Read(&sorted)) << "Invalid flag 'sorted'";
auto csr = aten::CSRMatrix(num_src, num_dst, indptr, indices, edge_id, sorted); csr = aten::CSRMatrix(num_src, num_dst, indptr, indices, edge_id, sorted);
if (!relgraph)
relgraph = CreateFromCSR(num_vtypes, csr, allowed_formats);
else
relgraph->SetCSRMatrix(0, csr);
} }
if (created_formats & CSC_CODE) { if (created_formats & CSC_CODE) {
CHECK_GE(states.arrays.end() - array_itr, 3); CHECK_GE(states.arrays.end() - array_itr, 3);
...@@ -232,13 +229,10 @@ HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates &states) { ...@@ -232,13 +229,10 @@ HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates &states) {
const auto &edge_id = *(array_itr++); const auto &edge_id = *(array_itr++);
bool sorted; bool sorted;
CHECK(strm->Read(&sorted)) << "Invalid flag 'sorted'"; CHECK(strm->Read(&sorted)) << "Invalid flag 'sorted'";
auto csc = aten::CSRMatrix(num_dst, num_src, indptr, indices, edge_id, sorted); csc = aten::CSRMatrix(num_dst, num_src, indptr, indices, edge_id, sorted);
if (!relgraph)
relgraph = CreateFromCSC(num_vtypes, csc, allowed_formats);
else
relgraph->SetCSCMatrix(0, csc);
} }
relgraphs[etype] = relgraph; relgraphs[etype] = UnitGraph::CreateUnitGraphFrom(
num_vtypes, csc, csr, coo, has_csc, has_csr, has_coo, allowed_formats);
} }
return CreateHeteroGraph(metagraph, relgraphs, num_nodes_per_type); return CreateHeteroGraph(metagraph, relgraphs, num_nodes_per_type);
} }
......
...@@ -134,7 +134,7 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -134,7 +134,7 @@ class UnitGraph::COO : public BaseHeteroGraph {
} }
bool IsPinned() const override { bool IsPinned() const override {
return adj_.row.IsPinned(); return adj_.is_pinned;
} }
uint8_t NumBits() const override { uint8_t NumBits() const override {
...@@ -359,18 +359,6 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -359,18 +359,6 @@ class UnitGraph::COO : public BaseHeteroGraph {
return aten::CSRMatrix(); return aten::CSRMatrix();
} }
void SetCOOMatrix(dgl_type_t etype, aten::COOMatrix coo) override {
adj_ = coo;
}
void SetCSRMatrix(dgl_type_t etype, aten::CSRMatrix csr) override {
LOG(FATAL) << "Not enabled for COO graph";
}
void SetCSCMatrix(dgl_type_t etype, aten::CSRMatrix csc) override {
LOG(FATAL) << "Not enabled for COO graph";
}
SparseFormat SelectFormat(dgl_type_t etype, dgl_format_code_t preferred_formats) const override { SparseFormat SelectFormat(dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
LOG(FATAL) << "Not enabled for COO graph"; LOG(FATAL) << "Not enabled for COO graph";
return SparseFormat::kCOO; return SparseFormat::kCOO;
...@@ -548,7 +536,7 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -548,7 +536,7 @@ class UnitGraph::CSR : public BaseHeteroGraph {
} }
bool IsPinned() const override { bool IsPinned() const override {
return adj_.indices.IsPinned(); return adj_.is_pinned;
} }
uint8_t NumBits() const override { uint8_t NumBits() const override {
...@@ -791,18 +779,6 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -791,18 +779,6 @@ class UnitGraph::CSR : public BaseHeteroGraph {
return adj_; return adj_;
} }
void SetCOOMatrix(dgl_type_t etype, aten::COOMatrix coo) override {
LOG(FATAL) << "Not enabled for CSR graph";
}
void SetCSRMatrix(dgl_type_t etype, aten::CSRMatrix csr) override {
adj_ = csr;
}
void SetCSCMatrix(dgl_type_t etype, aten::CSRMatrix csc) override {
LOG(FATAL) << "Please use in_csr_->SetCSRMatrix(etype, csc) instead.";
}
SparseFormat SelectFormat(dgl_type_t etype, dgl_format_code_t preferred_formats) const override { SparseFormat SelectFormat(dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
LOG(FATAL) << "Not enabled for CSR graph"; LOG(FATAL) << "Not enabled for CSR graph";
return SparseFormat::kCSR; return SparseFormat::kCSR;
...@@ -1370,7 +1346,8 @@ UnitGraph::UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr c ...@@ -1370,7 +1346,8 @@ UnitGraph::UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr c
CHECK(GetAny()) << "At least one graph structure should exist."; CHECK(GetAny()) << "At least one graph structure should exist.";
} }
HeteroGraphPtr UnitGraph::CreateHomographFrom( HeteroGraphPtr UnitGraph::CreateUnitGraphFrom(
int num_vtypes,
const aten::CSRMatrix &in_csr, const aten::CSRMatrix &in_csr,
const aten::CSRMatrix &out_csr, const aten::CSRMatrix &out_csr,
const aten::COOMatrix &coo, const aten::COOMatrix &coo,
...@@ -1378,7 +1355,7 @@ HeteroGraphPtr UnitGraph::CreateHomographFrom( ...@@ -1378,7 +1355,7 @@ HeteroGraphPtr UnitGraph::CreateHomographFrom(
bool has_out_csr, bool has_out_csr,
bool has_coo, bool has_coo,
dgl_format_code_t formats) { dgl_format_code_t formats) {
auto mg = CreateUnitGraphMetaGraph1(); auto mg = CreateUnitGraphMetaGraph(num_vtypes);
CSRPtr in_csr_ptr = nullptr; CSRPtr in_csr_ptr = nullptr;
CSRPtr out_csr_ptr = nullptr; CSRPtr out_csr_ptr = nullptr;
...@@ -1512,54 +1489,6 @@ aten::COOMatrix UnitGraph::GetCOOMatrix(dgl_type_t etype) const { ...@@ -1512,54 +1489,6 @@ aten::COOMatrix UnitGraph::GetCOOMatrix(dgl_type_t etype) const {
return GetCOO()->adj(); return GetCOO()->adj();
} }
void UnitGraph::SetCOOMatrix(dgl_type_t etype, COOMatrix coo) {
if (!(formats_ & COO_CODE)) {
LOG(FATAL) << "The graph have restricted sparse format " <<
CodeToStr(formats_) << ", cannot set COO matrix.";
return;
}
if (IsPinned()) {
LOG(FATAL) << "Cannot set COOMatrix if the graph is pinned, please unpin the graph.";
return;
}
if (!coo_->defined())
*(const_cast<UnitGraph*>(this)->coo_) = COO(meta_graph(), coo);
else
coo_->SetCOOMatrix(0, coo);
}
void UnitGraph::SetCSRMatrix(dgl_type_t etype, CSRMatrix csr) {
if (!(formats_ & CSR_CODE)) {
LOG(FATAL) << "The graph have restricted sparse format " <<
CodeToStr(formats_) << ", cannot set CSR matrix.";
return;
}
if (IsPinned()) {
LOG(FATAL) << "Cannot set CSRMatrix if the graph is pinned, please unpin the graph.";
return;
}
if (!out_csr_->defined())
*(const_cast<UnitGraph*>(this)->out_csr_) = CSR(meta_graph(), csr);
else
out_csr_->SetCSRMatrix(0, csr);
}
void UnitGraph::SetCSCMatrix(dgl_type_t etype, CSRMatrix csc) {
if (!(formats_ & CSC_CODE)) {
LOG(FATAL) << "The graph have restricted sparse format " <<
CodeToStr(formats_) << ", cannot set CSC matrix.";
return;
}
if (IsPinned()) {
LOG(FATAL) << "Cannot set CSCMatrix if the graph is pinned, please unpin the graph.";
return;
}
if (!in_csr_->defined())
*(const_cast<UnitGraph*>(this)->in_csr_) = CSR(meta_graph(), csc);
else
in_csr_->SetCSRMatrix(0, csc);
}
HeteroGraphPtr UnitGraph::GetAny() const { HeteroGraphPtr UnitGraph::GetAny() const {
if (in_csr_->defined()) { if (in_csr_->defined()) {
return in_csr_; return in_csr_;
......
...@@ -305,14 +305,11 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -305,14 +305,11 @@ class UnitGraph : public BaseHeteroGraph {
void InvalidateCOO(); void InvalidateCOO();
void SetCOOMatrix(dgl_type_t etype, aten::COOMatrix coo) override;
void SetCSRMatrix(dgl_type_t etype, aten::CSRMatrix csr) override;
void SetCSCMatrix(dgl_type_t etype, aten::CSRMatrix csc) override;
private: private:
friend class Serializer; friend class Serializer;
friend class HeteroGraph; friend class HeteroGraph;
friend class ImmutableGraph; friend class ImmutableGraph;
friend HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates& states);
// private empty constructor // private empty constructor
UnitGraph() {} UnitGraph() {}
...@@ -329,6 +326,7 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -329,6 +326,7 @@ class UnitGraph : public BaseHeteroGraph {
/*! /*!
* \brief constructor * \brief constructor
* \param num_vtypes number of vertex types (1 or 2)
* \param metagraph metagraph * \param metagraph metagraph
* \param in_csr in edge csr * \param in_csr in edge csr
* \param out_csr out edge csr * \param out_csr out edge csr
...@@ -337,7 +335,8 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -337,7 +335,8 @@ class UnitGraph : public BaseHeteroGraph {
* \param has_out_csr whether out_csr is valid * \param has_out_csr whether out_csr is valid
* \param has_coo whether coo is valid * \param has_coo whether coo is valid
*/ */
static HeteroGraphPtr CreateHomographFrom( static HeteroGraphPtr CreateUnitGraphFrom(
int num_vtypes,
const aten::CSRMatrix &in_csr, const aten::CSRMatrix &in_csr,
const aten::CSRMatrix &out_csr, const aten::CSRMatrix &out_csr,
const aten::COOMatrix &coo, const aten::COOMatrix &coo,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment