Unverified commit bc8f8b0b authored by ranzhejiang, committed by GitHub

New fused edge_softmax op (#3650)



* [feature] edge softmax refact.

* delete file

* fix backward and cmake version

* fix backward

* format function

* fix setting

* refix

* refix

* refix

* refix

* refix

* refix

* refix

* refix

* refix

* refix

* refix

* refix

* add cuda kernel for backward and rename some function

* add benchmark for edge_softmax

* fix format

* remove cuda_backwrd

* fix code format and add comment for op on CPU

* fix lint
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
parent 45ac5726
import time

import dgl
import torch

from .. import utils


# Benchmarks for the fused edge_softmax op.
@utils.benchmark('time', timeout=600)
@utils.parametrize('graph', ['ogbn-arxiv', 'reddit', 'cora', 'pubmed'])
@utils.parametrize('num_heads', [1, 4, 8])
def track_time(graph, num_heads):
    device = utils.get_bench_device()
    graph = utils.get_graph(graph).to(device)
    score = torch.randn((graph.num_edges(), num_heads)).requires_grad_(True).float().to(device)
    # dry run
    for i in range(3):
        y = dgl.ops.edge_softmax(graph, score)
    # timing
    with utils.Timer(device) as t:
        for i in range(100):
            y = dgl.ops.edge_softmax(graph, score)
    return t.elapsed_secs / 100
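For context, edge_softmax normalizes edge scores with a softmax over each destination node's incoming edges. A minimal pure-PyTorch reference, handy for checking the fused op on small graphs (a sketch, not part of the patch; assumes norm_by='dst' and torch >= 1.12 for scatter_reduce):

import torch

def edge_softmax_reference(graph, score):
    # Softmax over each destination node's incoming edges (norm_by='dst').
    _, dst = graph.edges()
    idx = dst.long().unsqueeze(-1).expand_as(score)   # (E, H) destination index
    n = graph.num_nodes()
    # Per-destination max, for numerical stability.
    smax = torch.full((n, score.shape[1]), float('-inf'), device=score.device)
    smax = smax.scatter_reduce(0, idx, score, reduce='amax')
    e = (score - smax.gather(0, idx)).exp()
    # Per-destination sum, then normalize each edge by its destination's sum.
    ssum = torch.zeros((n, score.shape[1]), device=score.device)
    ssum = ssum.scatter_add(0, idx, e)
    return e / ssum.gather(0, idx)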
import torch as th
from distutils.version import LooseVersion
from ...base import is_all, ALL
from ...sparse import _gspmm, _gspmm_hetero, _gsddmm, _gsddmm_hetero, _segment_reduce, _bwd_segment_cmp, _edge_softmax_forward, _edge_softmax_backward
from ...sparse import _csrmm, _csrsum, _csrmask, _scatter_add, _update_grad_minmax_hetero
from ...heterograph_index import create_unitgraph_from_csr
@@ -470,10 +470,15 @@ class EdgeSoftmax(th.autograd.Function):
        gidx = gidx.edge_subgraph([eids], True).graph
        if norm_by == 'src':
            gidx = gidx.reverse()
        # Note: the fused _edge_softmax_forward op currently supports CPU only.
        # TODO(Zhejiang): GPU support will be added in the future.
        if score.is_cuda:
            score_max = _gspmm(gidx, 'copy_rhs', 'max', None, score)[0]
            score = th.exp(_gsddmm(gidx, 'sub', score, score_max, 'e', 'v'))
            score_sum = _gspmm(gidx, 'copy_rhs', 'sum', None, score)[0]
            out = _gsddmm(gidx, 'div', score, score_sum, 'e', 'v')
        else:
            out = _edge_softmax_forward(gidx, score, 'copy_rhs')
        ctx.backward_cache = gidx
        ctx.save_for_backward(out)
        return out
@@ -500,9 +505,14 @@ class EdgeSoftmax(th.autograd.Function):
        ctx.backward_cache = None
        out, = ctx.saved_tensors
        sds = out * grad_out
        # Note: the fused _edge_softmax_backward op currently supports CPU only.
        # TODO(Zhejiang): GPU support will be added in the future.
        if out.is_cuda:
            accum = gspmm(gidx, 'copy_rhs', 'sum', None, sds)
            grad_score = sds - gsddmm(gidx, 'mul', out, accum, 'e', 'v')
        else:
            grad_score = _edge_softmax_backward(gidx, out, sds)
        return None, grad_score, None, None
...
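For reference, both the CUDA branch and the fused CPU kernel implement the same math. For logits $s$ restricted to the incoming edges of one destination node, the forward pass computes

$$y_i = \frac{\exp(s_i - \max_j s_j)}{\sum_j \exp(s_j - \max_j s_j)},$$

and the backward pass is the softmax Jacobian-vector product

$$\frac{\partial L}{\partial s_i} = y_i\Big(g_i - \sum_j y_j g_j\Big) = (y \odot g)_i - y_i \sum_j (y \odot g)_j,$$

so with sds = y ⊙ g (computed above as out * grad_out) and accum its per-destination sum, grad_score = sds − out · accum, which is exactly what both branches compute.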
@@ -83,6 +83,67 @@ target_mapping = {
    'dst': 2
}
def _edge_softmax_backward(gidx, out, sds):
    r"""Edge_softmax backward interface.

    Parameters
    ----------
    gidx : HeteroGraphIndex
        The input graph index.
    out : tensor
        The result of edge_softmax during forward.
    sds : tensor
        The result of out * gradient.

    Returns
    -------
    tensor
        The result of edge_softmax during backward.

    Notes
    -----
    This op does not support GPU yet.
    """
    op = 'copy_rhs'
    back_out = F.zeros_like(out)
    _CAPI_DGLKernelEdge_softmax_backward(gidx, op,
                                         to_dgl_nd(out),
                                         to_dgl_nd(sds),
                                         to_dgl_nd_for_write(back_out),
                                         to_dgl_nd(None))
    return back_out
def _edge_softmax_forward(gidx, e, op):
    r"""Edge_softmax forward interface.

    Parameters
    ----------
    gidx : HeteroGraphIndex
        The input graph index.
    e : tensor or None
        The feature on edges.
    op : str
        The binary op's name, default as ``copy_rhs``.

    Returns
    -------
    tensor
        The result of edge_softmax during forward.

    Notes
    -----
    This op does not support GPU yet.
    """
    if F.ndim(e) == 1:
        e = F.unsqueeze(e, -1)
        expand = True
    else:
        expand = False
    myout = F.zeros_like(e)
    _CAPI_DGLKernelEdge_softmax_forward(gidx, op,
                                        to_dgl_nd(None),
                                        to_dgl_nd(e),
                                        to_dgl_nd_for_write(myout))
    myout = F.squeeze(myout, -1) if expand else myout
    return myout
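A quick way to exercise both helpers end to end is torch's numerical gradient checker on a small CPU graph (a sketch, not part of the patch; double precision keeps gradcheck's tolerances workable):

import dgl
import torch

g = dgl.rand_graph(50, 300)  # CPU graph, so the fused CPU kernels are hit
score = torch.randn(g.num_edges(), 4, dtype=torch.float64, requires_grad=True)
# Compares the analytic backward above against finite differences.
torch.autograd.gradcheck(lambda s: dgl.ops.edge_softmax(g, s), (score,))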
def _gspmm(gidx, op, reduce_op, u, e):
    r"""Generalized Sparse Matrix Multiplication interface. It takes the result of
...
@@ -192,6 +192,86 @@ template void SpMMCsrHetero<kDLCPU, int64_t, 64>(
    const std::vector<dgl_type_t>& ufeat_node_tids,
    const std::vector<dgl_type_t>& out_node_tids);
/*! \brief Edge_softmax_csr forward op on CSR format. */
template <int XPU, typename IdType, int bits>
void Edge_softmax_csr_forward(const std::string& op,
                              const BcastOff& bcast,
                              const CSRMatrix& csr,
                              NDArray ufeat,
                              NDArray efeat,
                              NDArray out) {
  SWITCH_BITS(bits, DType, {
    SWITCH_OP(op, Op, {
      cpu::Edge_softmax_csr_forward<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
    });
  });
}

/*! \brief Edge_softmax_csr backward op on CSR format. */
template <int XPU, typename IdType, int bits>
void Edge_softmax_csr_backward(const std::string& op,
                               const BcastOff& bcast,
                               const CSRMatrix& csr,
                               NDArray out,
                               NDArray sds,
                               NDArray back_out) {
  SWITCH_BITS(bits, DType, {
    SWITCH_OP(op, Op, {
      cpu::Edge_softmax_csr_backward<IdType, DType, Op>(bcast, csr, out, sds, back_out);
    });
  });
}
template void Edge_softmax_csr_forward<kDLCPU, int32_t, 16>(
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDLCPU, int64_t, 16>(
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDLCPU, int32_t, 32>(
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDLCPU, int64_t, 32>(
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDLCPU, int32_t, 64>(
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDLCPU, int64_t, 64>(
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out);

template void Edge_softmax_csr_backward<kDLCPU, int32_t, 16>(
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray out, NDArray sds, NDArray back_out);
template void Edge_softmax_csr_backward<kDLCPU, int64_t, 16>(
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray out, NDArray sds, NDArray back_out);
template void Edge_softmax_csr_backward<kDLCPU, int32_t, 32>(
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray out, NDArray sds, NDArray back_out);
template void Edge_softmax_csr_backward<kDLCPU, int64_t, 32>(
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray out, NDArray sds, NDArray back_out);
template void Edge_softmax_csr_backward<kDLCPU, int32_t, 64>(
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray out, NDArray sds, NDArray back_out);
template void Edge_softmax_csr_backward<kDLCPU, int64_t, 64>(
    const std::string& op,
    const BcastOff& bcast, const CSRMatrix& csr,
    NDArray out, NDArray sds, NDArray back_out);
/*! \brief Generalized SpMM on Coo format. */
template <int XPU, typename IdType, int bits>
void SpMMCoo(const std::string& op, const std::string& reduce,
...
@@ -9,9 +9,12 @@
#include <dgl/array.h>
#include <dgl/bcast.h>
#include <dgl/runtime/parallel_for.h>
#include <math.h>
#include <algorithm>
#include <limits>
#include <memory>
#include <vector>
#include "spmm_binary_ops.h"
#if !defined(_WIN32)
#ifdef USE_AVX
@@ -466,6 +469,99 @@ void SpMMCmpCoo(const BcastOff& bcast, const COOMatrix& coo, NDArray ufeat,
  }
}
/*!
 * \brief CPU kernel of Edge_softmax_csr_forward on CSR format.
 * \param bcast Broadcast information.
 * \param csr The CSR matrix.
 * \param ufeat The feature on source nodes.
 * \param efeat The feature on edges.
 * \param out The result of edge_softmax_forward.
 */
template <typename IdType, typename DType, typename Op>
void Edge_softmax_csr_forward(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
                              NDArray efeat, NDArray out) {
  const bool has_idx = !IsNullArray(csr.data);
  const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
  const IdType* edges =
      has_idx ? static_cast<IdType*>(csr.data->data) : nullptr;
  const DType* W = Op::use_rhs ? static_cast<DType*>(efeat->data) : nullptr;
  const int64_t dim = bcast.out_len, rhs_dim = bcast.rhs_len;
  runtime::parallel_for(0, csr.num_rows, [&](size_t b, size_t e) {
    for (auto rid = b; rid < e; ++rid) {
      const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
      // Per-row scratch: shifted/exponentiated scores and their output offsets.
      std::vector<DType> data_e(row_end - row_start, 0);
      std::vector<IdType> num(row_end - row_start, 0);
      for (int64_t k = 0; k < dim; ++k) {
        // Pass 1: find the per-row maximum for numerical stability.
        DType max_v = -std::numeric_limits<DType>::infinity();
        for (IdType j = row_start; j < row_end; ++j) {
          const IdType eid = has_idx ? edges[j] : j;
          const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
          const DType* rhs_off =
              Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
          data_e[j - row_start] = *rhs_off;
          num[j - row_start] = eid * rhs_dim + rhs_add;
          max_v = std::max<DType>(max_v, (*rhs_off));
        }
        // Pass 2: exponentiate the shifted scores and accumulate the sum.
        DType exp_sum = 0;
        for (auto& element : data_e) {
          element -= max_v;
          element = std::exp(element);
          exp_sum += element;
        }
        // Pass 3: normalize and scatter back to the output edge positions.
        for (IdType i = 0; i < row_end - row_start; i++) {
          out.Ptr<DType>()[num[i]] = data_e[i] / exp_sum;
        }
      }
    }
  });
}
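Stripped of the CSR indexing and broadcasting, the kernel body above is the classic three-pass stable softmax applied to one row (one destination node's incoming edges). A minimal sketch of that pattern:

import math

def row_softmax(scores):
    # Three-pass numerically stable softmax over one row of edge scores,
    # mirroring the kernel's structure: max -> exp/sum -> normalize.
    max_v = max(scores)                            # pass 1: row maximum
    exps = [math.exp(s - max_v) for s in scores]   # pass 2: shift and exponentiate
    total = sum(exps)
    return [e / total for e in exps]               # pass 3: normalize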
/*!
 * \brief CPU kernel of Edge_softmax_csr_backward on CSR format.
 * \param bcast Broadcast information.
 * \param csr The CSR matrix.
 * \param out The result of forward.
 * \param sds The result of gradient * out.
 * \param back_out The result of edge_softmax_backward.
 */
template <typename IdType, typename DType, typename Op>
void Edge_softmax_csr_backward(const BcastOff& bcast, const CSRMatrix& csr, NDArray out,
                               NDArray sds, NDArray back_out) {
  const bool has_idx = !IsNullArray(csr.data);
  const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
  const IdType* edges =
      has_idx ? static_cast<IdType*>(csr.data->data) : nullptr;
  const DType* W_out = Op::use_rhs ? static_cast<DType*>(out->data) : nullptr;
  const DType* W_sds = Op::use_rhs ? static_cast<DType*>(sds->data) : nullptr;
  const int64_t dim = bcast.out_len, rhs_dim = bcast.rhs_len;
  runtime::parallel_for(0, csr.num_rows, [&](size_t b, size_t e) {
    for (auto rid = b; rid < e; ++rid) {
      const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
      for (int64_t k = 0; k < dim; ++k) {
        // Pass 1: accumulate sum_j (out_j * grad_j) over the row.
        DType sum_sds = 0;
        for (IdType j = row_start; j < row_end; ++j) {
          const IdType eid = has_idx ? edges[j] : j;
          const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
          const DType* rhs_off_sds =
              Op::use_rhs ? W_sds + eid * rhs_dim + rhs_add : nullptr;
          sum_sds += (*rhs_off_sds);
        }
        // Pass 2: grad_score = sds - out * sum_sds for every edge in the row.
        for (IdType j = row_start; j < row_end; ++j) {
          const IdType eid = has_idx ? edges[j] : j;
          const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
          const DType* rhs_off_out =
              Op::use_rhs ? W_out + eid * rhs_dim + rhs_add : nullptr;
          const DType* rhs_off_sds =
              Op::use_rhs ? W_sds + eid * rhs_dim + rhs_add : nullptr;
          back_out.Ptr<DType>()[eid * rhs_dim + rhs_add] =
              (*rhs_off_sds) - sum_sds * (*rhs_off_out);
        }
      }
    }
  });
}
}  // namespace cpu
}  // namespace aten
}  // namespace dgl
...
@@ -131,7 +131,6 @@ void SDDMM(const std::string& op,
  });
}

/*!
 * \brief Find the src/dst/etype id based on the target 'u', 'v' or 'e'.
 *
@@ -148,7 +147,6 @@ int get_typeid_by_target(HeteroGraphPtr graph, int target, dgl_type_t etype) {
  return etype;
}

/*! \brief Generalized Sampled Dense-Dense Matrix Multiplication. */
void SDDMMHetero(const std::string& op,
                 HeteroGraphPtr graph,
@@ -196,6 +194,50 @@ void SDDMMHetero(const std::string& op,
  });
}
/*! \brief Generalized Edge_softmax op for forward */
void Edge_softmax_forward(const std::string& op,
                          HeteroGraphPtr graph,
                          NDArray ufeat,
                          NDArray efeat,
                          NDArray out) {
  // TODO(zhejiang): add gpu op for edge_softmax
  SparseFormat format = graph->SelectFormat(0, CSC_CODE);
  const auto& bcast = CalcBcastOff(op, ufeat, efeat);
  ATEN_XPU_SWITCH(graph->Context().device_type, XPU, "edge_softmax", {
    ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
      ATEN_FLOAT_BITS_SWITCH(out->dtype, bits, "edge_softmax out data", {
        Edge_softmax_csr_forward<XPU, IdType, bits>(
            op, bcast, graph->GetCSCMatrix(0), ufeat, efeat, out);
      });
    });
  });
}

/*! \brief Generalized Edge_softmax op for backward */
void Edge_softmax_backward(const std::string& op,
                           HeteroGraphPtr graph,
                           NDArray out,
                           NDArray sds,
                           NDArray back_out,
                           NDArray ufeat) {
  // TODO(zhejiang): add gpu op for edge_softmax
  SparseFormat format = graph->SelectFormat(0, CSC_CODE);
  const auto& bcast = CalcBcastOff(op, ufeat, sds);
  ATEN_XPU_SWITCH(graph->Context().device_type, XPU, "edge_softmax_back", {
    ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
      ATEN_FLOAT_BITS_SWITCH(out->dtype, bits, "edge_softmax out data_back", {
        Edge_softmax_csr_backward<XPU, IdType, bits>(
            op, bcast, graph->GetCSCMatrix(0), out, sds, back_out);
      });
    });
  });
}
NDArray GetEdgeMapping(HeteroGraphRef graph) {
  SparseFormat format = graph->SelectFormat(0, CSC_CODE);
  if (format == SparseFormat::kCSC) {
@@ -348,6 +390,27 @@ DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSpMM")
    SpMM(op, reduce_op, graph.sptr(), U, E, V, {ArgU, ArgE});
  });
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelEdge_softmax_forward")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0];
const std::string op = args[1];
NDArray U = args[2];
NDArray E = args[3];
NDArray V = args[4];
Edge_softmax_forward(op, graph.sptr(), U, E, V);
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelEdge_softmax_backward")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0];
const std::string op = args[1];
NDArray out = args[2];
NDArray sds = args[3];
NDArray back_out = args[4];
NDArray ufeat = args[5];
Edge_softmax_backward(op, graph.sptr(), out, sds, back_out, ufeat);
});
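These registered globals surface on the Python side through DGL's FFI. A sketch of the existing convention that makes _CAPI_DGLKernelEdge_softmax_forward/_backward callable from python/dgl/sparse.py (the exact call site is assumed here, not part of this diff):

# At the bottom of python/dgl/sparse.py: binds every registered
# "sparse._CAPI_*" global (including the two above) as a module attribute,
# which is how _edge_softmax_forward/_backward can call them directly.
from ._ffi.function import _init_api
_init_api("dgl.sparse")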
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSpMMHetero") DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSpMMHetero")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0]; HeteroGraphRef graph = args[0];
......
@@ -183,6 +183,26 @@ std::pair<CSRMatrix, NDArray> CSRSum(
    const std::vector<CSRMatrix>& A,
    const std::vector<NDArray>& A_weights);
/*!
 * \brief Edge_softmax_csr forward function on CSR format.
 */
template <int XPU, typename IdType, int bits>
void Edge_softmax_csr_forward(const std::string& op,
                              const BcastOff& bcast,
                              const aten::CSRMatrix& csr,
                              NDArray ufeat,
                              NDArray efeat,
                              NDArray out);

/*!
 * \brief Edge_softmax_csr backward function on CSR format.
 */
template <int XPU, typename IdType, int bits>
void Edge_softmax_csr_backward(const std::string& op,
                               const BcastOff& bcast,
                               const aten::CSRMatrix& csr,
                               NDArray out,
                               NDArray sds,
                               NDArray back_out);
}  // namespace aten
}  // namespace dgl
...