Unverified Commit b3d3a2c4 authored by Israt Nisa, committed by GitHub

[Feature] Gather mm (#3641)



* init

* init

* working cublasGemm

* benchmark high-mem/low-mem, err gather_mm output

* cuda kernel for bmm like kernel

* removed cpu copy for E_per_Rel

* benchmark code from Minjie

* fixed cublas results in gathermm sorted

* use GPU shared mem in unsorted gather mm

* minor

* Added an optimal version of gather_mm_unsorted

* lint

* init gather_mm_scatter

* cublas transpose added

* fixed h_offset for multiple rel

* backward unittest

* cublas support to transpose W

* adding missed file

* forgot to add header file

* lint

* lint

* cleanup

* lint

* docstring

* lint

* added unittest

* lint

* lint

* unittest

* changed err type

* skip cpu test

* skip CPU code

* move in-len loop inside

* lint

* added check different dim length for B

* w_per_len is optional now

* moved gather_mm to pytorch/backend with backward support

* removed a_/b_trans support

* transpose op inside GEMM call

* removed out alloc from API, changed W 2D to 3D

* Added se_gather_mm, Separate API for sortedE

* Fixed gather_mm (unsorted) user interface

* unsorted gmm backward + separate CAPI for un/sorted A

* typecast to float to support atomicAdd

* lint typecast

* lint

* added gather_mm_scatter

* minor

* const

* design changes

* Added idx_a, idx_b support gmm_scatter

* dgl doc

* lint

* adding gather_mm in ops

* lint

* lint

* minor

* removed benchmark files

* minor

* empty commit
Co-authored-by: Israt Nisa <nisisrat@amazon.com>
parent ab50eb9b
...@@ -246,6 +246,19 @@ DGL provides operators to reduce a value tensor along the first dimension by segments
    segment_reduce
GatherMM and SegmentMM Module
-----------------------------

SegmentMM: DGL provides operators to perform matrix multiplication according to segments.

GatherMM: DGL provides operators to gather data according to the given indices and perform matrix multiplication.

.. autosummary::
    :toctree: ../../generated/

    gather_mm
    segment_mm
Supported Data types
--------------------
Operators defined in ``dgl.ops`` support floating point data types, i.e. the operands
...
...@@ -1827,6 +1827,51 @@ def csrmask(A, A_weights, B):
    """
    pass
def gather_mm(A, B, idx_a, idx_b):
    r"""Dense matrix multiplication interface. It multiplies the 2-D dense
    tensor A by the 3-D dense tensor B according to their relation types.
    A is unsorted, and the relation type of each row is fetched from idx_b.

    Parameters
    ----------
    A : tensor
        2-D tensor of shape (N, D1)
    B : tensor
        3-D tensor of shape (R, D1, D2)
    idx_a : Tensor, optional
        If specified, must be a 1-D integer tensor of shape (K,).
    idx_b : Tensor, optional
        If specified, must be a 1-D integer tensor of shape (K,).

    Returns
    -------
    Tensor
        The output dense matrix of shape (N, D2)
    """
    pass
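For intuition, the semantics above can be written as an explicit per-row loop. The sketch below is illustrative only (plain PyTorch; the fused kernel avoids this Python-level loop):

import torch

def gather_mm_reference(A, B, idx_b):
    # A: (N, D1), B: (R, D1, D2), idx_b: (N,) relation id of each row of A.
    C = torch.empty(A.shape[0], B.shape[2], dtype=A.dtype, device=A.device)
    for i in range(A.shape[0]):
        C[i] = A[i] @ B[idx_b[i]]   # (D1,) @ (D1, D2) -> (D2,)
    return C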
def segment_mm(A, B, seglen_A):
    r"""Dense matrix multiplication interface. It multiplies dense tensor A
    by dense tensor B according to relation types. A is assumed to be sorted
    and concatenated by relation type.

    Parameters
    ----------
    A : tensor
        2-D tensor of shape (N, D1)
    B : tensor
        3-D tensor of shape (R, D1, D2)
    seglen_A : Tensor
        An integer tensor of shape (R,). Each element is the length of a segment
        of input ``A``. The summation of all elements must be equal to N.

    Returns
    -------
    Tensor
        The output dense matrix of shape (N, D2)
    """
    pass
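Likewise, a reference loop for the segmented variant (a sketch under the same assumptions; ``seglen_A`` partitions the rows of A into R consecutive segments):

import torch

def segment_mm_reference(A, B, seglen_A):
    # A: (N, D1), B: (R, D1, D2), seglen_A: (R,), sum(seglen_A) == N.
    out, offset = [], 0
    for r, seglen in enumerate(seglen_A.tolist()):
        out.append(A[offset:offset + seglen] @ B[r])  # segment r uses B[r]
        offset += seglen
    return torch.cat(out, dim=0)                      # (N, D2)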
###############################################################################
# Other interfaces
...
import torch as th
from distutils.version import LooseVersion

from ...base import is_all, ALL
from ...sparse import _gspmm, _gspmm_hetero, _gsddmm, _gsddmm_hetero, _segment_reduce, _bwd_segment_cmp, _edge_softmax_forward, _edge_softmax_backward
from ...sparse import _csrmm, _csrsum, _csrmask, _scatter_add, _update_grad_minmax_hetero, _gather_mm, _gather_mm_scatter, _segment_mm
from ...heterograph_index import create_unitgraph_from_csr
...@@ -27,7 +29,7 @@ else:
        return decorate_bwd

__all__ = ['gspmm', 'gsddmm', 'gspmm_hetero', 'gsddmm_hetero', 'edge_softmax', 'edge_softmax_hetero',
           'segment_reduce', 'scatter_add', 'csrmm', 'csrsum', 'csrmask', 'gather_mm', 'segment_mm']
def _reduce_grad(grad, shape):
...@@ -691,6 +693,70 @@ class CSRMask(th.autograd.Function):
        return None, csrmask(gidxB, dB_weights, gidxA), None
class SEGMENTMM(th.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=th.float16)
    def forward(ctx, A, B, seglen_A):
        if A.shape[0] != th.sum(seglen_A):
            raise Exception("The summation of the elements of seglen_A must be equal to "
                            "dimension 0 of A. Expected " + str(A.shape[0]) +
                            ", got " + str(th.sum(seglen_A)))
        if B.dim() != 3:
            raise Exception("Expected dimension of B is 3. Got " + str(B.dim()))
        # Reshape B from 3D to 2D
        B_3D_shape = B.shape
        B = B.reshape(B.shape[0] * B.shape[1], B.shape[2])
        C = th.zeros((A.shape[0], B.shape[1]), device=A.device, dtype=A.dtype)
        C = _segment_mm(A, B, C, seglen_A)
        ctx.backward_cache = A, B, seglen_A, B_3D_shape
        return C

    @staticmethod
    def backward(ctx, dZ):
        A, B, seglen_A, B_3D_shape = ctx.backward_cache
        A_grad = B_grad = None
        if ctx.needs_input_grad[0]:
            # Compute A_grad = Out_grad * B^T
            A_grad = th.zeros(A.shape, device=A.device, dtype=A.dtype)
            A_grad = _segment_mm(dZ, B, A_grad, seglen_A, b_trans=True)
        if ctx.needs_input_grad[1]:
            # Compute B_grad = A^T * Out_grad
            B_grad = th.zeros(B.shape, device=B.device, dtype=B.dtype)
            B_grad = _segment_mm(A, dZ, B_grad, seglen_A, a_trans=True)
            B_grad = B_grad.reshape(B_3D_shape[0], B_3D_shape[1], B_3D_shape[2])
        # forward() takes (A, B, seglen_A), so return exactly three gradients.
        return A_grad, B_grad, None
class GATHERMM(th.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=th.float16)
    def forward(ctx, A, B, idx_a, idx_b):
        if B.dim() != 3:
            raise Exception("Expected dimension of B is 3. Got " + str(B.dim()))
        # Reshape B from 3D to 2D
        B_3D_shape = B.shape
        B = B.reshape(B.shape[0] * B.shape[1], B.shape[2])
        C = th.zeros((A.shape[0], B.shape[1]), device=A.device, dtype=A.dtype)
        C = _gather_mm(A, B, C, B_3D_shape[0], idx_a, idx_b)
        ctx.backward_cache = A, B, idx_a, idx_b, B_3D_shape
        return C

    @staticmethod
    def backward(ctx, dZ):
        A, B, idx_a, idx_b, B_3D_shape = ctx.backward_cache
        A_grad = B_grad = None
        if ctx.needs_input_grad[0]:
            # Compute A_grad = Out_grad * B^T
            A_grad = th.zeros(A.shape, device=A.device, dtype=A.dtype)
            A_grad = _gather_mm_scatter(dZ, B, A_grad, B_3D_shape[0],
                                        idx_b=idx_b, idx_c=idx_a, b_trans=True)
        if ctx.needs_input_grad[1]:
            # Compute B_grad = A^T * Out_grad
            B_grad = th.zeros(B.shape, device=B.device, dtype=B.dtype)
            B_grad = _gather_mm_scatter(A, dZ, B_grad, B_3D_shape[0],
                                        idx_a=idx_a, idx_c=idx_b)
            B_grad = B_grad.reshape(B_3D_shape[0], B_3D_shape[1], B_3D_shape[2])
        # forward() takes (A, B, idx_a, idx_b), so return exactly four gradients.
        return A_grad, B_grad, None, None
def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
    if op == 'sub':
        op = 'add'
...@@ -766,3 +832,9 @@ def csrsum(gidxs, weights):
def csrmask(gidxA, A_weights, gidxB):
    return CSRMask.apply(gidxA, A_weights, gidxB)
def segment_mm(A, B, seglen_A):
    return SEGMENTMM.apply(A, B, seglen_A)

def gather_mm(A, B, idx_a=None, idx_b=None):
    return GATHERMM.apply(A, B, idx_a, idx_b)
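A quick end-to-end sanity check of the autograd path defined above (illustrative only; a CUDA device is assumed since this PR ships GPU kernels only, and the import path mirrors this file's location):

import torch as th
from dgl.backend.pytorch.sparse import segment_mm

A = th.randn(8, 4, device='cuda', requires_grad=True)     # (N, D1)
B = th.randn(2, 4, 3, device='cuda', requires_grad=True)  # (R, D1, D2)
seglen = th.tensor([5, 3])                                # sums to N
C = segment_mm(A, B, seglen)                              # (N, D2)
C.sum().backward()                                        # exercises SEGMENTMM.backward
assert A.grad.shape == A.shape and B.grad.shape == B.shape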
...@@ -3,3 +3,4 @@ from .spmm import *
from .sddmm import *
from .edge_softmax import *
from .segment import *
from .gather_mm import *
"""dgl gather_mm operator module."""
from ..backend import gather_mm as gather_mm_internal
from ..backend import segment_mm as segment_mm_internal
__all__ = ['gather_mm', 'segment_mm']
def segment_mm(lhs_data, rhs_data, seglen_lhs):
    r"""Performs matrix multiplication according to segments.

    Suppose ``seglen_lhs == [10, 5, 0, 3]``, the operator will perform
    four matrix multiplications:

        lhs_data[0:10] @ rhs_data[0], lhs_data[10:15] @ rhs_data[1],
        lhs_data[15:15] @ rhs_data[2], lhs_data[15:18] @ rhs_data[3]

    Parameters
    ----------
    lhs_data : tensor
        The left operand, 2-D tensor of shape (N, D1)
    rhs_data : tensor
        The right operand, 3-D tensor of shape (R, D1, D2)
    seglen_lhs : tensor
        An integer tensor of shape (R,). Each element is the length of a segment
        of input ``lhs_data``. The summation of all elements must be equal to N.

    Returns
    -------
    tensor
        The output dense matrix of shape (N, D2)
    """
    return segment_mm_internal(lhs_data, rhs_data, seglen_lhs)
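The ``[10, 5, 0, 3]`` example from the docstring, run end to end (a sketch; a GPU build is assumed, as only CUDA kernels are provided in this PR):

import torch
import dgl

lhs = torch.randn(18, 4, device='cuda')     # N = 10 + 5 + 0 + 3
rhs = torch.randn(4, 4, 3, device='cuda')   # R = 4 weight matrices
seglen = torch.tensor([10, 5, 0, 3])
out = dgl.ops.segment_mm(lhs, rhs, seglen)  # (18, 3); segment 2 is empty
assert torch.allclose(out[0:10], lhs[0:10] @ rhs[0], atol=1e-5)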
def gather_mm(lhs_data, rhs_data, idx_lhs=None, idx_rhs=None):
    r"""Gather data according to the given indices and perform matrix multiplication.

    Let the result tensor be C. The operator conducts the following computation:

    If both idx_lhs and idx_rhs are given:

        c[i] = lhs_data[idx_lhs[i]] @ rhs_data[idx_rhs[i]],
        where len(C) == len(idx_lhs) == len(idx_rhs)

    If idx_lhs is given but not idx_rhs:

        c[i] = lhs_data[idx_lhs[i]] @ rhs_data[i],
        where len(C) == len(idx_lhs)

    If idx_rhs is given but not idx_lhs:

        c[i] = lhs_data[i] @ rhs_data[idx_rhs[i]],
        where len(C) == len(idx_rhs)

    Parameters
    ----------
    lhs_data : tensor
        2-D tensor of shape (N, D1)
    rhs_data : tensor
        3-D tensor of shape (R, D1, D2)
    idx_lhs : Tensor, optional
        If specified, must be a 1-D integer tensor of shape (K,).
    idx_rhs : Tensor, optional
        If specified, must be a 1-D integer tensor of shape (K,).

    Returns
    -------
    Tensor
        The output dense matrix of shape (N, D2)
    """
    return gather_mm_internal(lhs_data, rhs_data, idx_lhs, idx_rhs)
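A typical use is the per-edge-type weight multiply in relational GNNs such as R-GCN; the snippet below sketches that pattern (GPU assumed, as above):

import torch
import dgl

feat = torch.randn(100, 16, device='cuda')             # one row per edge
weight = torch.randn(3, 16, 8, device='cuda')          # one matrix per edge type
etypes = torch.randint(0, 3, (100,), device='cuda')    # edge type of each row
out = dgl.ops.gather_mm(feat, weight, idx_rhs=etypes)  # out[i] = feat[i] @ weight[etypes[i]]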
...@@ -389,6 +389,111 @@ def _gspmm_hetero(gidx, op, reduce_op, u_len, u_and_e_tuple):
    return out, (list_arg_u, list_arg_e, list_arg_u_ntype, list_arg_e_etype)
def _segment_mm(A, B, out, seglen_A, a_trans=False, b_trans=False):
    r"""Dense matrix multiplication interface. It multiplies dense tensor A
    by dense tensor B according to relation types. A is assumed to be sorted
    and concatenated by relation type.

    Parameters
    ----------
    A : tensor
        2-D tensor of shape (N, D1)
    B : tensor
        2-D tensor of shape (R * D1, D2)
    seglen_A : Tensor
        An integer tensor of shape (R,). Each element is the length of a segment
        of input ``A``. The summation of all elements must be equal to N.
    a_trans : bool
        Indicates whether matrix A needs to be transposed
    b_trans : bool
        Indicates whether matrix B needs to be transposed

    Returns
    -------
    Tensor
        The output dense matrix of shape (N, D2)
    """
    # TODO(Israt): Add CPU support. Currently, only the GPU path is implemented.
    _CAPI_DGLKernelSEGMENTMM(to_dgl_nd(A),
                             to_dgl_nd(B),
                             to_dgl_nd_for_write(out),
                             to_dgl_nd(seglen_A),
                             a_trans, b_trans)
    return out
def _gather_mm(A, B, out, num_rel, idx_a=None, idx_b=None):
    r"""Generalized dense matrix multiplication interface. It multiplies
    tensors A and B according to relation types and writes the result to out.
    B is a tensor concatenated across relation types. A is unsorted, and the
    relation type of each row is fetched from ``idx_b``.

    Parameters
    ----------
    A : tensor
        2-D tensor of shape (N, D1)
    B : tensor
        2-D tensor of shape (R * D1, D2)
    idx_a : Tensor, optional
        If specified, must be a 1-D integer tensor of shape (K,)
    idx_b : Tensor, optional
        If specified, must be a 1-D integer tensor of shape (N,)

    Returns
    -------
    Tensor
        The output dense matrix of shape (N, D2)
    """
    # TODO(Israt): Add CPU support. Currently, only the GPU path is implemented.
    _CAPI_DGLKernelGATHERMM(to_dgl_nd(A),
                            to_dgl_nd(B),
                            to_dgl_nd_for_write(out),
                            to_dgl_nd(idx_a),
                            to_dgl_nd(idx_b),
                            num_rel)
    return out
def _gather_mm_scatter(A, B, out, num_rel, idx_a=None, idx_b=None, idx_c=None,
                       a_trans=False, b_trans=False):
    r"""Generalized dense matrix multiplication interface. It multiplies
    tensors A and B according to relation types and scatters the result to out
    via ``idx_c``. B is a tensor concatenated across relation types. A is
    unsorted, and the relation type of each row is fetched from ``idx_b``.

    Parameters
    ----------
    A : tensor
        2-D tensor of shape (N, D1)
    B : tensor
        2-D tensor of shape (R * D1, D2)
    idx_a : Tensor, optional
        If specified, must be a 1-D integer tensor of shape (K,)
    idx_b : Tensor, optional
        If specified, must be a 1-D integer tensor of shape (N,)
    idx_c : Tensor, optional
        If specified, must be a 1-D integer tensor of shape (N,)
    a_trans : bool
        Indicates whether matrix A needs to be transposed
    b_trans : bool
        Indicates whether matrix B needs to be transposed

    Returns
    -------
    Tensor
        The output dense matrix of shape (N, D2)
    """
    # TODO(Israt): Add CPU support. Currently, only the GPU path is implemented.
    _CAPI_DGLKernelGATHERMMSCATTER(to_dgl_nd(A),
                                   to_dgl_nd(B),
                                   to_dgl_nd_for_write(out),
                                   to_dgl_nd(idx_a),
                                   to_dgl_nd(idx_b),
                                   to_dgl_nd(idx_c),
                                   num_rel, a_trans, b_trans)
    return out
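For reference, the backward rules the autograd wrappers implement on top of these C APIs are the standard matmul gradients: for C = A @ B, dA = dC @ B^T and dB = A^T @ dC, with the gather/scatter indices applied in reverse. A minimal check against PyTorch autograd (illustrative, dense case only):

import torch

A = torch.randn(5, 3, requires_grad=True)
B = torch.randn(3, 4, requires_grad=True)
C = A @ B
C.backward(torch.ones_like(C))
assert torch.allclose(A.grad, torch.ones(5, 4) @ B.T)  # dA = dC @ B^T
assert torch.allclose(B.grad, A.T @ torch.ones(5, 4))  # dB = A^T @ dC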
def _gsddmm(gidx, op, lhs, rhs, lhs_target='u', rhs_target='v'):
    r""" Generalized Sampled-Dense-Dense Matrix Multiplication interface. It
    takes the result of :attr:`op` on source node feature and destination node
...
/*!
 * Copyright (c) 2020 by Contributors
 * \file array/cpu/gather_mm.cc
 * \brief GatherMM C APIs and definitions.
 */
#include "./gather_mm.h"
#include <dgl/array.h>
namespace dgl {
namespace aten {
#define SWITCH_BITS(bits, DType, ...)                                    \
  do {                                                                   \
    if ((bits) == 16 || (bits) == 32) {                                  \
      /* 16-bit inputs are promoted to float (see commit: typecast to    \
         float to support atomicAdd). */                                 \
      typedef float DType;                                               \
      { __VA_ARGS__ }                                                    \
    } else if ((bits) == 64) {                                           \
      typedef double DType;                                              \
      { __VA_ARGS__ }                                                    \
    } else {                                                             \
      LOG(FATAL) << "Data type not recognized with bits " << bits;       \
    }                                                                    \
  } while (0)
/*! \brief Generalized segmentMM. */
template <int XPU, typename IdType, int bits>
void segmentMM(const NDArray A,
               const NDArray B,
               NDArray C,
               const NDArray seglen_A,
               bool a_trans, bool b_trans) {
  SWITCH_BITS(bits, DType, {
    LOG(FATAL) << "Unsupported CPU kernel for SegmentMM.";
  });
}

/*! \brief Generalized gatherMM. */
template <int XPU, typename IdType, int bits>
void gatherMM(const NDArray A,
              const NDArray B,
              NDArray C,
              const NDArray idx_a,
              const NDArray idx_b,
              const int num_rel) {
  SWITCH_BITS(bits, DType, {
    LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
  });
}

/*! \brief Generalized gatherMM_scatter. */
template <int XPU, typename IdType, int bits>
void gatherMM_scatter(const NDArray A,
                      const NDArray B,
                      NDArray C,
                      const NDArray idx_a,
                      const NDArray idx_b,
                      const NDArray idx_c,
                      const int num_rel,
                      bool a_trans, bool b_trans) {
  SWITCH_BITS(bits, DType, {
    LOG(FATAL) << "Unsupported CPU kernel for GatherMM_scatter.";
  });
}
template void gatherMM<kDLCPU, int32_t, 16>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b, const int num_rel);
template void gatherMM<kDLCPU, int64_t, 16>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b, const int num_rel);
template void gatherMM<kDLCPU, int32_t, 32>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b, const int num_rel);
template void gatherMM<kDLCPU, int64_t, 32>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b, const int num_rel);
template void gatherMM<kDLCPU, int32_t, 64>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b, const int num_rel);
template void gatherMM<kDLCPU, int64_t, 64>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b, const int num_rel);

template void gatherMM_scatter<kDLCPU, int32_t, 16>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c,
    const int num_rel, bool a_trans, bool b_trans);
template void gatherMM_scatter<kDLCPU, int64_t, 16>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c,
    const int num_rel, bool a_trans, bool b_trans);
template void gatherMM_scatter<kDLCPU, int32_t, 32>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c,
    const int num_rel, bool a_trans, bool b_trans);
template void gatherMM_scatter<kDLCPU, int64_t, 32>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c,
    const int num_rel, bool a_trans, bool b_trans);
template void gatherMM_scatter<kDLCPU, int32_t, 64>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c,
    const int num_rel, bool a_trans, bool b_trans);
template void gatherMM_scatter<kDLCPU, int64_t, 64>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c,
    const int num_rel, bool a_trans, bool b_trans);

template void segmentMM<kDLCPU, int32_t, 16>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);
template void segmentMM<kDLCPU, int64_t, 16>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);
template void segmentMM<kDLCPU, int32_t, 32>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);
template void segmentMM<kDLCPU, int64_t, 32>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);
template void segmentMM<kDLCPU, int32_t, 64>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);
template void segmentMM<kDLCPU, int64_t, 64>(
    const NDArray A, const NDArray B, NDArray C,
    const NDArray seglen_A, bool a_trans, bool b_trans);
} // namespace aten
} // namespace dgl
/*!
* Copyright (c) 2022 by Contributors
* \file array/cpu/gather_mm.h
* \brief GATHER_MM CPU kernel function header.
*/
#ifndef DGL_ARRAY_CPU_GATHER_MM_H_
#define DGL_ARRAY_CPU_GATHER_MM_H_
#include <dgl/array.h>
#include <dgl/bcast.h>
#include <utility>
namespace dgl {
namespace aten {
namespace cpu {
template <typename DType>
void transpose(const DType *in, DType *out, const int N, const int M) {
  // Transpose a row-major N x M matrix `in` into the M x N matrix `out`.
#pragma omp parallel for
  for (int n = 0; n < N * M; n++) {
    int i = n / N;   // row index in the output
    int j = n % N;   // column index in the output
    out[n] = in[M * j + i];
  }
}
template <typename DType>
void matmul(const DType *A, const DType *B,
            DType *C, const int M, const int N, const int K) {
  // Naive row-major GEMM: C (M x N) = A (M x K) * B (K x N),
  // parallelized over the rows of C.
#pragma omp parallel
  {
    int i, j, k;
#pragma omp for
    for (i = 0; i < M; i++) {
      for (j = 0; j < N; j++) {
        DType local_accum = 0;
        for (k = 0; k < K; k++) {
          local_accum += A[i * K + k] * B[k * N + j];
        }
        C[i * N + j] = local_accum;
      }
    }
  }
}
/*!
 * \brief CPU kernel of Gather_mm. The input matrix A is expected to be
 *        sorted according to relation type.
 * \param A The input dense matrix of dimension m x k
 * \param B The input dense matrix of dimension k x n
 * \param C The output dense matrix of dimension m x n
 * \param A_dim1_per_rel The number of rows in each relation in A
 * \param B_dim1_per_rel The number of rows in each relation in B
 * \param a_trans Matrix A to be transposed
 * \param b_trans Matrix B to be transposed
 */
template <int XPU, typename IdType, typename DType>
void gatherMM_SortedEtype(const NDArray A,
                          const NDArray B,
                          NDArray C,
                          const NDArray A_dim1_per_rel,
                          const NDArray B_dim1_per_rel,
                          bool a_trans, bool b_trans) {
  assert(A_dim1_per_rel.NumElements() == B_dim1_per_rel.NumElements());
  int64_t num_rel = A_dim1_per_rel.NumElements();
  const DType *A_data = A.Ptr<DType>();
  const DType *B_data = B.Ptr<DType>();
  const IdType* A_rel_data = A_dim1_per_rel.Ptr<IdType>();
  const IdType* B_rel_data = B_dim1_per_rel.Ptr<IdType>();
  DType *C_data = C.Ptr<DType>();
  int64_t A_offset = 0, B_offset = 0, C_offset = 0;
  int64_t m, n, k;
  for (int etype = 0; etype < num_rel; ++etype) {
    // Inner dimensions must agree after the optional transposes.
    // (Parenthesized explicitly: the ternary binds looser than ==.)
    assert(((a_trans) ? A_rel_data[etype] : A->shape[1]) ==
           ((b_trans) ? B->shape[1] : B_rel_data[etype]));
    m = A_rel_data[etype];  // rows of A
    n = B->shape[1];        // cols of B
    k = B_rel_data[etype];  // rows of B == cols of A
    NDArray A_trans, B_trans;
    if (a_trans) {
      A_trans = NDArray::Empty({m * k}, A->dtype, A->ctx);
      transpose<DType>(A_data + A_offset, static_cast<DType *>(A_trans->data), m, k);
    }
    if (b_trans) {
      B_trans = NDArray::Empty({k * n}, B->dtype, B->ctx);
      transpose<DType>(B_data + B_offset, static_cast<DType *>(B_trans->data), k, n);
    }
    if (a_trans || b_trans) {
      // Swap the GEMM dimensions to match the transposed operands.
      int64_t tmp = k;
      if (a_trans)
        std::swap(m, k);
      if (b_trans) {
        k = tmp;
        std::swap(n, k);
      }
    }
    matmul<DType>(
        (a_trans) ? static_cast<DType *>(A_trans->data) : A_data + A_offset,
        (b_trans) ? static_cast<DType *>(B_trans->data) : B_data + B_offset,
        C_data + C_offset, m, n, k);
    A_offset += m * k;
    B_offset += k * n;
    C_offset += m * n;
  }
}
} // namespace cpu
} // namespace aten
} // namespace dgl
#endif // DGL_ARRAY_CPU_GATHER_MM_H_
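The per-relation offset walk in gatherMM_SortedEtype, restated in Python for readability (a sketch; the helper name is hypothetical and the real kernel works on raw pointers with optional transposes):

import torch

def gather_mm_sorted_reference(A, B, A_rows_per_rel, B_rows_per_rel):
    # A: (sum(A_rows), D1) and B: (R * D1, n) hold all relation segments
    # stacked row-wise; each relation is an independent GEMM.
    a_off = b_off = 0
    out = []
    for m, k in zip(A_rows_per_rel, B_rows_per_rel):
        out.append(A[a_off:a_off + m] @ B[b_off:b_off + k])
        a_off += m
        b_off += k
    return torch.cat(out, dim=0)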
...@@ -52,6 +52,60 @@ void SpMM(const std::string& op, const std::string& reduce,
  });
}
/*! \brief Generalized segmented dense Matrix-Matrix Multiplication. */
void SegmentMM(const NDArray A,
               const NDArray B,
               NDArray C,
               const NDArray seglen_A,
               bool A_trans, bool B_trans) {
  ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "SegmentMM", {
    ATEN_ID_TYPE_SWITCH(seglen_A->dtype, IdType, {
      ATEN_FLOAT_BITS_SWITCH(A->dtype, bits, "Feature data", {
        segmentMM<XPU, IdType, bits>(A, B, C, seglen_A, A_trans, B_trans);
      });
    });
  });
}
/*! \brief Generalized Dense Matrix-Matrix Multiplication according to relation types. */
void GatherMM(const NDArray A,
              const NDArray B,
              NDArray C,
              const NDArray idx_a,
              const NDArray idx_b,
              const int num_rel) {
  ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "GatherMM", {
    ATEN_ID_TYPE_SWITCH(idx_b->dtype, IdType, {
      ATEN_FLOAT_BITS_SWITCH(A->dtype, bits, "Feature data", {
        gatherMM<XPU, IdType, bits>(A, B, C, idx_a, idx_b, num_rel);
      });
    });
  });
}

/*! \brief Generalized Dense Matrix-Matrix Multiplication according to relation types. */
void GatherMM_scatter(const NDArray A,
                      const NDArray B,
                      NDArray C,
                      const NDArray idx_a,
                      const NDArray idx_b,
                      const NDArray idx_c,
                      const int num_rel,
                      bool A_trans, bool B_trans) {
  ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "GatherMM", {
    ATEN_ID_TYPE_SWITCH(idx_b->dtype, IdType, {
      ATEN_FLOAT_BITS_SWITCH(A->dtype, bits, "Feature data", {
        gatherMM_scatter<XPU, IdType, bits>(A, B, C, idx_a, idx_b, idx_c,
                                            num_rel, A_trans, B_trans);
      });
    });
  });
}
/*! \brief Generalized Sparse Matrix-Matrix Multiplication with hetero-graph support. */
void SpMMHetero(const std::string& op, const std::string& reduce,
                HeteroGraphPtr graph,
...@@ -390,6 +444,42 @@ DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSpMM")
      SpMM(op, reduce_op, graph.sptr(), U, E, V, {ArgU, ArgE});
    });
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelGATHERMM")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
NDArray A = args[0];
NDArray B = args[1];
NDArray C = args[2];
NDArray idx_a = args[3];
NDArray idx_b = args[4];
int num_rel = args[5];
GatherMM(A, B, C, idx_a, idx_b, num_rel);
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelGATHERMMSCATTER")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
NDArray A = args[0];
NDArray B = args[1];
NDArray C = args[2];
NDArray idx_a = args[3];
NDArray idx_b = args[4];
NDArray idx_c = args[5];
int num_rel = args[6];
bool A_trans = args[7];
bool B_trans = args[8];
GatherMM_scatter(A, B, C, idx_a, idx_b, idx_c, num_rel, A_trans, B_trans);
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSEGMENTMM")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
NDArray A = args[0];
NDArray B = args[1];
NDArray C = args[2];
NDArray seglen_A = args[3];
bool A_trans = args[4];
bool B_trans = args[5];
SegmentMM(A, B, C, seglen_A, A_trans, B_trans);
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelEdge_softmax_forward") DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelEdge_softmax_forward")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0]; HeteroGraphRef graph = args[0];
......
...@@ -112,6 +112,39 @@ void SDDMMCooHetero(const std::string& op,
                    const std::vector<dgl_type_t>& lhs_eid,
                    const std::vector<dgl_type_t>& rhs_eid);
/*!
 * \brief Generalized Dense Matrix-Matrix Multiplication according to relation types.
 */
template <int XPU, typename IdType, int bits>
void gatherMM(const NDArray A,
              const NDArray B,
              NDArray out,
              const NDArray idx_a,
              const NDArray idx_b,
              const int num_rel);

/*!
 * \brief Generalized Dense Matrix-Matrix Multiplication according to relation types.
 */
template <int XPU, typename IdType, int bits>
void gatherMM_scatter(const NDArray A,
                      const NDArray B,
                      NDArray out,
                      const NDArray idx_a,
                      const NDArray idx_b,
                      const NDArray idx_c,
                      const int num_rel, bool a_trans, bool b_trans);

/*!
 * \brief Generalized segmented dense Matrix-Matrix Multiplication.
 */
template <int XPU, typename IdType, int bits>
void segmentMM(const NDArray A,
               const NDArray B,
               NDArray out,
               const NDArray seglen_A,
               bool a_trans, bool b_trans);
/*!
 * \brief Segment reduce.
 */
...
from timeit import default_timer
import dgl
import backend as F
import dgl.function as fn
import time
import numpy as np
import unittest, pytest
from test_utils import parametrize_dtype, get_cases
iters = 5
n_edge_scale = 1
num_rel_scale = 1
@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
@unittest.skipIf(F._default_context_str == 'cpu', reason="Not implemented.")
@parametrize_dtype
def test_gathermm(idtype):
    def _test(feat_scale):
        in_feat = 16 * feat_scale
        out_feat = 8 * feat_scale
        print("in/out feat", in_feat, out_feat)
        E_per_rel = F.copy_to(F.tensor([50, 100, 20, 284, 89, 10, 82, 9200, 10, 20, 30, 100,
            128, 20, 284, 89, 10, 82, 92, 10, 20, 30, 100, 1280, 20, 284, 89, 1000, 82,
            92, 10, 2000, 30, 100, 128, 20, 284, 89, 10, 82, 92, 10, 20, 30]), F.cpu())
        E_per_rel *= n_edge_scale
        num_rel = len(E_per_rel)
        print('num_rel', num_rel)
        W_per_len = F.copy_to(F.full((num_rel,), in_feat, dtype=F.dtype(E_per_rel)), F.cpu())
        H_arr = []
        W_arr = []
        Out_arr = []
        Out_grad_arr = []
        for eid in range(num_rel):
            H_arr.append(F.randn((E_per_rel[eid], in_feat)))
            W_arr.append(F.randn((in_feat, out_feat)))
            Out_arr.append(F.zeros((E_per_rel[eid], out_feat)))
            Out_grad_arr.append(F.ones((E_per_rel[eid], out_feat)))
        H = F.cat(H_arr, 0)
        W = F.cat(W_arr, 0)
        W_3D = W.reshape(num_rel, in_feat, out_feat)
        Out = F.cat(Out_arr, 0)
        Out_grad = F.cat(Out_grad_arr, 0)
        print('H.shape', H.shape)
        print('W.shape', W.shape)
        print('W_3D.shape', W_3D.shape)
        print('Out.shape', Out.shape)
        etype_arr = []
        for eid in range(num_rel):
            etype_arr.append(F.full((E_per_rel[eid],), eid, dtype=F.dtype(E_per_rel)))
        etypes = F.cat(etype_arr, 0)

        #################################################################
        # low-mem version using PyTorch operator
        #################################################################
        # forward pass
        out = []
        for i in range(len(E_per_rel)):
            Hi = H_arr[i]
            Wi = W_arr[i]
            out.append(F.matmul(Hi, Wi))
        out_low_mem = F.cat(out, 0)

        # backward pass
        H_grad = []
        W_grad = []
        for i in range(len(E_per_rel)):
            Hi = H_arr[i]
            Wi = W_arr[i]
            Out_gradi = Out_grad_arr[i]
            H_grad.append(F.matmul(Out_gradi, Wi.transpose(0, 1)))
            W_grad.append(F.matmul(Hi.transpose(0, 1), Out_gradi))
        Hgrad_low_mem = F.cat(H_grad, 0)
        Wgrad_low_mem = F.cat(W_grad, 0)
        Wgrad_low_mem = Wgrad_low_mem.reshape(num_rel, in_feat, out_feat)

        #################################################################
        # segment_mm: H is sorted according to etype
        #################################################################
        seglen_A = E_per_rel
        F.attach_grad(H)
        F.attach_grad(W_3D)
        with F.record_grad():
            out_gmm_sorted = dgl.ops.segment_mm(H, W_3D, seglen_A)
            F.backward(F.reduce_sum(out_gmm_sorted))
            Hgrad_gmm_sorted = H.grad
            Wgrad_gmm_sorted = W_3D.grad

        #################################################################
        # gather_mm: H is not sorted according to etype
        #################################################################
        F.attach_grad(H)
        F.attach_grad(W_3D)
        with F.record_grad():
            out_gmm_unsorted = dgl.ops.gather_mm(H, W_3D, idx_rhs=etypes)
            F.backward(F.reduce_sum(out_gmm_unsorted))
            Hgrad_gmm_unsorted = H.grad
            Wgrad_gmm_unsorted = W_3D.grad

        # correctness check
        assert F.allclose(out_low_mem, out_gmm_sorted, atol=1e-3, rtol=1e-3)
        assert F.allclose(Hgrad_low_mem, Hgrad_gmm_sorted, atol=1e-3, rtol=1e-3)
        assert F.allclose(Wgrad_low_mem, Wgrad_gmm_sorted, atol=1e-3, rtol=1e-3)
        assert F.allclose(out_low_mem, out_gmm_unsorted, atol=1e-3, rtol=1e-3)
        assert F.allclose(Hgrad_low_mem, Hgrad_gmm_unsorted, atol=1e-3, rtol=1e-3)
        assert F.allclose(Wgrad_low_mem, Wgrad_gmm_unsorted, atol=1e-3, rtol=1e-3)

    _test(1)
    _test(4)
    _test(16)
    _test(32)

if __name__ == '__main__':
    test_gathermm()