Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
08b60eb1
Unverified
Commit
08b60eb1
authored
Dec 15, 2022
by
czkkkkkk
Committed by
GitHub
Dec 15, 2022
Browse files
[Sparse] Add SpMM and SDDMM on CSR and COO in dgl include headers (#5016)
parent
0038a29b
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
187 additions
and
23 deletions
+187
-23
include/dgl/aten/coo.h
include/dgl/aten/coo.h
+46
-8
include/dgl/aten/csr.h
include/dgl/aten/csr.h
+48
-9
src/array/array.cc
src/array/array.cc
+93
-6
No files found.
include/dgl/aten/coo.h
View file @
08b60eb1
...
...
@@ -431,13 +431,9 @@ COOMatrix COOReorder(
* value array.
*/
std
::
pair
<
COOMatrix
,
FloatArray
>
COOLaborSampling
(
COOMatrix
mat
,
IdArray
rows
,
int64_t
num_samples
,
FloatArray
prob
=
NullArray
(),
int
importance_sampling
=
0
,
IdArray
random_seed
=
NullArray
(),
IdArray
NIDs
=
NullArray
());
COOMatrix
mat
,
IdArray
rows
,
int64_t
num_samples
,
FloatArray
prob
=
NullArray
(),
int
importance_sampling
=
0
,
IdArray
random_seed
=
NullArray
(),
IdArray
NIDs
=
NullArray
());
/**
* @brief Randomly select a fixed number of non-zero entries along each given
...
...
@@ -785,6 +781,48 @@ COOMatrix COOSliceContiguousChunk(
*/
COOMatrix
COOLineGraph
(
const
COOMatrix
&
coo
,
bool
backtracking
);
/**
* @brief Generalized Sparse Matrix-Matrix Multiplication on COO.
* @param op The binary operator, could be `add`, `sub', `mul`, 'div',
* `copy_u`, `copy_e'.
* @param op The reduce operator, could be `sum`, `min`, `max'.
* @param coo The COO we apply SpMM on.
* @param ufeat The source node feature.
* @param efeat The edge feature.
* @param out The output feature on destination nodes.
* @param out_aux A list of NDArray's that contains auxiliary information such
* as the argmax on source nodes and edges for reduce operators such as
* `min` and `max`.
*/
void
COOSpMM
(
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
COOMatrix
&
coo
,
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
std
::
vector
<
NDArray
>
out_aux
);
/** @brief COOSpMM C interface without std::string. */
void
COOSpMM
(
const
char
*
op
,
const
char
*
reduce
,
const
COOMatrix
&
coo
,
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
std
::
vector
<
NDArray
>
out_aux
);
/**
* @brief Generalized Sampled Dense-Dense Matrix Multiplication on COO.
* @param op The binary operator, could be `add`, `sub', `mul`, 'div',
* `dot`, `copy_u`, `copy_e'.
* @param coo The COO we apply SpMM on.
* @param ufeat The source node feature.
* @param vfeat The destination node feature.
* @param out The output feature on edge.
* @param lhs_target Type of `ufeat` (0: source, 1: edge, 2: destination).
* @param rhs_target Type of `ufeat` (0: source, 1: edge, 2: destination).
*/
void
COOSDDMM
(
const
std
::
string
&
op
,
const
COOMatrix
&
coo
,
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
int
lhs_target
,
int
rhs_target
);
/** @brief COOSDDMM C interface without std::string. */
void
COOSDDMM
(
const
char
*
op
,
const
COOMatrix
&
coo
,
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
int
lhs_target
,
int
rhs_target
);
}
// namespace aten
}
// namespace dgl
...
...
include/dgl/aten/csr.h
View file @
08b60eb1
...
...
@@ -459,16 +459,13 @@ CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries);
* array.
*/
std
::
pair
<
COOMatrix
,
FloatArray
>
CSRLaborSampling
(
CSRMatrix
mat
,
IdArray
rows
,
int64_t
num_samples
,
FloatArray
prob
=
NullArray
(),
int
importance_sampling
=
0
,
IdArray
random_seed
=
NullArray
(),
IdArray
NIDs
=
NullArray
());
CSRMatrix
mat
,
IdArray
rows
,
int64_t
num_samples
,
FloatArray
prob
=
NullArray
(),
int
importance_sampling
=
0
,
IdArray
random_seed
=
NullArray
(),
IdArray
NIDs
=
NullArray
());
/*!
* @brief Randomly select a fixed number of non-zero entries along each given row independently.
* @brief Randomly select a fixed number of non-zero entries along each given
* row independently.
*
* The function performs random choices along each row independently.
* The picked indices are returned in the form of a COO matrix.
...
...
@@ -895,6 +892,48 @@ CSRMatrix CSRSliceContiguousChunk(
const
std
::
vector
<
uint64_t
>&
src_vertex_range
,
const
std
::
vector
<
uint64_t
>&
dst_vertex_range
);
/**
* @brief Generalized Sparse Matrix-Matrix Multiplication on CSR.
* @param op The binary operator, could be `add`, `sub', `mul`, 'div',
* `copy_u`, `copy_e'.
* @param op The reduce operator, could be `sum`, `min`, `max'.
* @param csr The CSR we apply SpMM on.
* @param ufeat The source node feature.
* @param efeat The edge feature.
* @param out The output feature on destination nodes.
* @param out_aux A list of NDArray's that contains auxiliary information such
* as the argmax on source nodes and edges for reduce operators such as
* `min` and `max`.
*/
void
CSRSpMM
(
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
CSRMatrix
&
csr
,
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
std
::
vector
<
NDArray
>
out_aux
);
/** @brief CSRSpMM C interface without std::string. */
void
CSRSpMM
(
const
char
*
op
,
const
char
*
reduce
,
const
CSRMatrix
&
csr
,
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
std
::
vector
<
NDArray
>
out_aux
);
/**
* @brief Generalized Sampled Dense-Dense Matrix Multiplication on CSR.
* @param op The binary operator, could be `add`, `sub', `mul`, 'div',
* `dot`, `copy_u`, `copy_e'.
* @param csr The CSR we apply SpMM on.
* @param ufeat The source node feature.
* @param vfeat The destination node feature.
* @param out The output feature on edge.
* @param lhs_target Type of `ufeat` (0: source, 1: edge, 2: destination).
* @param rhs_target Type of `ufeat` (0: source, 1: edge, 2: destination).
*/
void
CSRSDDMM
(
const
std
::
string
&
op
,
const
CSRMatrix
&
csr
,
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
int
lhs_target
,
int
rhs_target
);
/** @brief CSRSDDMM C interface without std::string. */
void
CSRSDDMM
(
const
char
*
op
,
const
CSRMatrix
&
csr
,
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
int
lhs_target
,
int
rhs_target
);
}
// namespace aten
}
// namespace dgl
...
...
src/array/array.cc
View file @
08b60eb1
...
...
@@ -4,6 +4,7 @@
* @brief DGL array utilities implementation
*/
#include <dgl/array.h>
#include <dgl/bcast.h>
#include <dgl/graph_traversal.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
...
...
@@ -15,6 +16,7 @@
#include "../c_api_common.h"
#include "./arith.h"
#include "./array_op.h"
#include "./kernel_decl.h"
using
namespace
dgl
::
runtime
;
...
...
@@ -545,9 +547,8 @@ std::pair<COOMatrix, FloatArray> CSRLaborSampling(
int
importance_sampling
,
IdArray
random_seed
,
IdArray
NIDs
)
{
std
::
pair
<
COOMatrix
,
FloatArray
>
ret
;
ATEN_CSR_SWITCH_CUDA_UVA
(
mat
,
rows
,
XPU
,
IdType
,
"CSRLaborSampling"
,
{
const
auto
dtype
=
IsNullArray
(
prob
)
?
DGLDataTypeTraits
<
float
>::
dtype
:
prob
->
dtype
;
const
auto
dtype
=
IsNullArray
(
prob
)
?
DGLDataTypeTraits
<
float
>::
dtype
:
prob
->
dtype
;
ATEN_FLOAT_TYPE_SWITCH
(
dtype
,
FloatType
,
"probability"
,
{
ret
=
impl
::
CSRLaborSampling
<
XPU
,
IdType
,
FloatType
>
(
mat
,
rows
,
num_samples
,
prob
,
importance_sampling
,
random_seed
,
NIDs
);
...
...
@@ -819,9 +820,8 @@ std::pair<COOMatrix, FloatArray> COOLaborSampling(
int
importance_sampling
,
IdArray
random_seed
,
IdArray
NIDs
)
{
std
::
pair
<
COOMatrix
,
FloatArray
>
ret
;
ATEN_COO_SWITCH
(
mat
,
XPU
,
IdType
,
"COOLaborSampling"
,
{
const
auto
dtype
=
IsNullArray
(
prob
)
?
DGLDataTypeTraits
<
float
>::
dtype
:
prob
->
dtype
;
const
auto
dtype
=
IsNullArray
(
prob
)
?
DGLDataTypeTraits
<
float
>::
dtype
:
prob
->
dtype
;
ATEN_FLOAT_TYPE_SWITCH
(
dtype
,
FloatType
,
"probability"
,
{
ret
=
impl
::
COOLaborSampling
<
XPU
,
IdType
,
FloatType
>
(
mat
,
rows
,
num_samples
,
prob
,
importance_sampling
,
random_seed
,
NIDs
);
...
...
@@ -1088,6 +1088,93 @@ Frontiers DGLDFSLabeledEdges(
return
ret
;
}
void
CSRSpMM
(
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
CSRMatrix
&
csr
,
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
std
::
vector
<
NDArray
>
out_aux
)
{
const
auto
&
bcast
=
CalcBcastOff
(
op
,
ufeat
,
efeat
);
ATEN_XPU_SWITCH_CUDA
(
csr
.
indptr
->
ctx
.
device_type
,
XPU
,
"SpMM"
,
{
ATEN_ID_TYPE_SWITCH
(
csr
.
indptr
->
dtype
,
IdType
,
{
ATEN_FLOAT_TYPE_SWITCH_16BITS
(
out
->
dtype
,
Dtype
,
XPU
,
"Feature data"
,
{
SpMMCsr
<
XPU
,
IdType
,
Dtype
>
(
op
,
reduce
,
bcast
,
csr
,
ufeat
,
efeat
,
out
,
out_aux
);
});
});
});
}
void
CSRSpMM
(
const
char
*
op
,
const
char
*
reduce
,
const
CSRMatrix
&
csr
,
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
std
::
vector
<
NDArray
>
out_aux
)
{
CSRSpMM
(
std
::
string
(
op
),
std
::
string
(
reduce
),
csr
,
ufeat
,
efeat
,
out
,
out_aux
);
}
void
CSRSDDMM
(
const
std
::
string
&
op
,
const
CSRMatrix
&
csr
,
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
int
lhs_target
,
int
rhs_target
)
{
const
auto
&
bcast
=
CalcBcastOff
(
op
,
ufeat
,
efeat
);
ATEN_XPU_SWITCH_CUDA
(
csr
.
indptr
->
ctx
.
device_type
,
XPU
,
"SDDMM"
,
{
ATEN_ID_TYPE_SWITCH
(
csr
.
indptr
->
dtype
,
IdType
,
{
ATEN_FLOAT_TYPE_SWITCH_16BITS
(
out
->
dtype
,
Dtype
,
XPU
,
"Feature data"
,
{
SDDMMCsr
<
XPU
,
IdType
,
Dtype
>
(
op
,
bcast
,
csr
,
ufeat
,
efeat
,
out
,
lhs_target
,
rhs_target
);
});
});
});
}
void
CSRSDDMM
(
const
char
*
op
,
const
CSRMatrix
&
csr
,
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
int
lhs_target
,
int
rhs_target
)
{
return
CSRSDDMM
(
std
::
string
(
op
),
csr
,
ufeat
,
efeat
,
out
,
lhs_target
,
rhs_target
);
}
void
COOSpMM
(
const
std
::
string
&
op
,
const
std
::
string
&
reduce
,
const
COOMatrix
&
coo
,
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
std
::
vector
<
NDArray
>
out_aux
)
{
const
auto
&
bcast
=
CalcBcastOff
(
op
,
ufeat
,
efeat
);
ATEN_XPU_SWITCH_CUDA
(
coo
.
row
->
ctx
.
device_type
,
XPU
,
"SpMM"
,
{
ATEN_ID_TYPE_SWITCH
(
coo
.
row
->
dtype
,
IdType
,
{
ATEN_FLOAT_TYPE_SWITCH_16BITS
(
out
->
dtype
,
Dtype
,
XPU
,
"Feature data"
,
{
SpMMCoo
<
XPU
,
IdType
,
Dtype
>
(
op
,
reduce
,
bcast
,
coo
,
ufeat
,
efeat
,
out
,
out_aux
);
});
});
});
}
void
COOSpMM
(
const
char
*
op
,
const
char
*
reduce
,
const
COOMatrix
&
coo
,
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
std
::
vector
<
NDArray
>
out_aux
)
{
COOSpMM
(
std
::
string
(
op
),
std
::
string
(
reduce
),
coo
,
ufeat
,
efeat
,
out
,
out_aux
);
}
void
COOSDDMM
(
const
std
::
string
&
op
,
const
COOMatrix
&
coo
,
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
int
lhs_target
,
int
rhs_target
)
{
const
auto
&
bcast
=
CalcBcastOff
(
op
,
ufeat
,
efeat
);
ATEN_XPU_SWITCH_CUDA
(
coo
.
row
->
ctx
.
device_type
,
XPU
,
"SDDMM"
,
{
ATEN_ID_TYPE_SWITCH
(
coo
.
row
->
dtype
,
IdType
,
{
ATEN_FLOAT_TYPE_SWITCH_16BITS
(
out
->
dtype
,
Dtype
,
XPU
,
"Feature data"
,
{
SDDMMCoo
<
XPU
,
IdType
,
Dtype
>
(
op
,
bcast
,
coo
,
ufeat
,
efeat
,
out
,
lhs_target
,
rhs_target
);
});
});
});
}
void
COOSDDMM
(
const
char
*
op
,
const
COOMatrix
&
coo
,
NDArray
ufeat
,
NDArray
efeat
,
NDArray
out
,
int
lhs_target
,
int
rhs_target
)
{
COOSDDMM
(
std
::
string
(
op
),
coo
,
ufeat
,
efeat
,
out
,
lhs_target
,
rhs_target
);
}
///////////////////////// C APIs /////////////////////////
DGL_REGISTER_GLOBAL
(
"ndarray._CAPI_DGLSparseMatrixGetFormat"
)
.
set_body
([](
DGLArgs
args
,
DGLRetValue
*
rv
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment