Unverified commit b3d3a2c4, authored by Israt Nisa and committed by GitHub

[Feature] Gather mm (#3641)



* init

* init

* working cublasGemm

* benchmark high-mem/low-mem, err gather_mm output

* cuda kernel for bmm like kernel

* removed cpu copy for E_per_Rel

* benchmark code from Minjie

* fixed cublas results in gathermm sorted

* use GPU shared mem in unsorted gather mm

* minor

* Added an optimal version of gather_mm_unsorted

* lint

* init gather_mm_scatter

* cublas transpose added

* fixed h_offset for multiple rel

* backward unittest

* cublas support to transpose W

* adding missed file

* forgot to add header file

* lint

* lint

* cleanup

* lint

* docstring

* lint

* added unittest

* lint

* lint

* unittest

* changed err type

* skip cpu test

* skip CPU code

* move in-len loop inside

* lint

* added check different dim length for B

* w_per_len is optional now

* moved gather_mm to pytorch/backend with backward support

* removed a_/b_trans support

* transpose op inside GEMM call

* removed out alloc from API, changed W 2D to 3D

* Added se_gather_mm, Separate API for sortedE

* Fixed gather_mm (unsorted) user interface

* unsorted gmm backward + separate CAPI for un/sorted A

* typecast to float to support atomicAdd

* lint typecast

* lint

* added gather_mm_scatter

* minor

* const

* design changes

* Added idx_a, idx_b support gmm_scatter

* dgl doc

* lint

* adding gather_mm in ops

* lint

* lint

* minor

* removed benchmark files

* minor

* empty commit
Co-authored-by: Israt Nisa <nisisrat@amazon.com>
parent ab50eb9b
...@@ -246,6 +246,19 @@ DGL provide operators to reduce value tensor along the first dimension by segmen
segment_reduce
GatherMM and SegmentMM Module
-----------------------------
SegmentMM: DGL provides operators to perform matrix multiplication according to segments.
GatherMM: DGL provides operators to gather data according to the given indices and perform matrix multiplication (see the reference sketch after the list below).
.. autosummary::
:toctree: ../../generated/
gather_mm
segment_mm
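As a point of reference, a minimal PyTorch sketch of what the two operators compute (illustrative code only, not DGL's implementation; ``segment_mm_ref`` and ``gather_mm_ref`` are hypothetical names):

import torch

def segment_mm_ref(A, B, seglen_A):
    # A: (N, D1), B: (R, D1, D2), seglen_A: (R,) with sum(seglen_A) == N.
    # Multiply each segment of A with the weight slice of its relation.
    out, offset = [], 0
    for r, seglen in enumerate(seglen_A.tolist()):
        out.append(A[offset:offset + seglen] @ B[r])
        offset += seglen
    return torch.cat(out, dim=0)                             # (N, D2)

def gather_mm_ref(A, B, idx_b):
    # A: (N, D1), B: (R, D1, D2), idx_b: (N,) relation id of each row of A.
    # Row i of the output is A[i] @ B[idx_b[i]].
    return torch.bmm(A.unsqueeze(1), B[idx_b]).squeeze(1)    # (N, D2)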
Supported Data types
--------------------
Operators defined in ``dgl.ops`` support floating point data types, i.e. the operands
...
...@@ -1827,6 +1827,51 @@ def csrmask(A, A_weights, B):
"""
pass
def gather_mm(A, B, idx_a, idx_b):
r""" Dense Matrix Multiplication interface. It multiplies 2D dense tensor A
and 3D dense tensor B according to their relation types. A is unsorted and
the relation type is fetched from idx_b.
Parameters
----------
A : tensor
2-D tensor of shape (N, D1)
B : tensor
3-D tensor of shape (R, D1, D2)
idx_a : Tensor, optional
If specified, must be a 1-D integer tensor of shape (K,).
idx_b : Tensor, optional
If specified, must be a 1-D integer tensor of shape (K,).
Returns
-------
Tensor
The output dense matrix of shape (N, D2)
"""
pass
def segment_mm(A, B, seglen_A):
r""" Dense Matrix Multiplication interface. It multiplies dense tensor A
and dense tensor B according to relation types. A is sorted and concatenated
according to relation types.
Parameters
----------
A : tensor
2-D tensor of shape (N, D1)
B : tensor
3-D tensor of shape (R, D1, D2)
seglen_A : Tensor
An integer tensor of shape (R,). Each element is the length of segments
of input ``A``. The summation of all elements must be equal to N.
Returns
-------
Tensor
The output dense matrix of shape (N, D2)
"""
pass
###############################################################################
# Other interfaces
...
import torch as th
from distutils.version import LooseVersion
from ...base import is_all, ALL
from ...sparse import _gspmm, _gspmm_hetero, _gsddmm, _gsddmm_hetero, _segment_reduce, _bwd_segment_cmp
from ...sparse import _csrmm, _csrsum, _csrmask, _scatter_add, _update_grad_minmax_hetero, _gather_mm, _gather_mm_scatter, _segment_mm
from ...sparse import _gspmm, _gspmm_hetero, _gsddmm, _gsddmm_hetero, _segment_reduce, _bwd_segment_cmp, _edge_softmax_forward, _edge_softmax_backward
from ...sparse import _csrmm, _csrsum, _csrmask, _scatter_add, _update_grad_minmax_hetero
from ...heterograph_index import create_unitgraph_from_csr
...@@ -27,7 +29,7 @@ else:
return decorate_bwd
__all__ = ['gspmm', 'gsddmm', 'gspmm_hetero', 'gsddmm_hetero', 'edge_softmax', 'edge_softmax_hetero',
'segment_reduce', 'scatter_add', 'csrmm', 'csrsum', 'csrmask', 'gather_mm', 'segment_mm']
def _reduce_grad(grad, shape):
...@@ -691,6 +693,70 @@ class CSRMask(th.autograd.Function):
return None, csrmask(gidxB, dB_weights, gidxA), None
class SEGMENTMM(th.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, A, B, seglen_A):
if A.shape[0] != th.sum(seglen_A):
raise Exception("The summation of the elements of seglen_A must be equal to "
"dimension 0 of A. Expected " + str(A.shape[0]) + ", got " + str(th.sum(seglen_A)))
if B.dim() != 3:
raise Exception("Expected dimension of B is 3. Got " + str(B.dim()))
# Reshape B from 3-D to 2-D
B_3D_shape = B.shape
B = B.reshape(B.shape[0] * B.shape[1], B.shape[2])
C = th.zeros((A.shape[0], B.shape[1]), device=A.device, dtype=A.dtype)
C = _segment_mm(A, B, C, seglen_A)
ctx.backward_cache = A, B, seglen_A, B_3D_shape
return C
@staticmethod
def backward(ctx, dZ):
A, B, seglen_A, B_3D_shape = ctx.backward_cache
A_grad = B_grad = None
if ctx.needs_input_grad[0]:
# Compute A_grad = Out_grad * B^T
A_grad = th.zeros(A.shape, device=A.device, dtype=A.dtype)
A_grad = _segment_mm(dZ, B, A_grad, seglen_A, b_trans=True)
if ctx.needs_input_grad[1]:
# Compute B_grad = A^T * Out_grad
B_grad = th.zeros(B.shape, device=B.device, dtype=B.dtype)
B_grad = _segment_mm(A, dZ, B_grad, seglen_A, a_trans=True)
B_grad = B_grad.reshape(B_3D_shape[0], B_3D_shape[1], B_3D_shape[2])
return A_grad, B_grad, None, None, None, None, None, None
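For reference, a standalone sketch of the segment-wise gradients that SEGMENTMM.backward computes through _segment_mm (illustrative code, not part of DGL; segment_mm_backward_ref is a hypothetical helper): for segment r, dA_r = dC_r @ B_r^T and dB_r = A_r^T @ dC_r.

import torch

def segment_mm_backward_ref(A, B, dC, seglen_A):
    # A: (N, D1), B: (R, D1, D2), dC: (N, D2), seglen_A: (R,).
    dA, dB, offset = torch.empty_like(A), torch.empty_like(B), 0
    for r, seglen in enumerate(seglen_A.tolist()):
        sl = slice(offset, offset + seglen)
        dA[sl] = dC[sl] @ B[r].T       # gradient w.r.t. the segment of A
        dB[r] = A[sl].T @ dC[sl]       # gradient w.r.t. the weight of relation r
        offset += seglen
    return dA, dB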
class GATHERMM(th.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, A, B, idx_a, idx_b):
if B.dim() != 3:
raise Exception("Expected dimension of B is 3. Got " + str(B.dim()))
# Reshape B from 3-D to 2-D
B_3D_shape = B.shape
B = B.reshape(B.shape[0] * B.shape[1], B.shape[2])
C = th.zeros((A.shape[0], B.shape[1]), device=A.device, dtype=A.dtype)
C = _gather_mm(A, B, C, B_3D_shape[0], idx_a, idx_b)
ctx.backward_cache = A, B, idx_a, idx_b, B_3D_shape
return C
@staticmethod
def backward(ctx, dZ):
A, B, idx_a, idx_b, B_3D_shape = ctx.backward_cache
A_grad = B_grad = None
if ctx.needs_input_grad[0]:
# Compute A_grad = Out_grad * B^T
A_grad = th.zeros(A.shape, device=A.device, dtype=A.dtype)
A_grad = _gather_mm_scatter(dZ, B, A_grad, B_3D_shape[0],
idx_b=idx_b, idx_c=idx_a, b_trans=True)
if ctx.needs_input_grad[1]:
# Compute B_grad = A^T * Out_grad
B_grad = th.zeros(B.shape, device=B.device, dtype=B.dtype)
B_grad = _gather_mm_scatter(A, dZ, B_grad, B_3D_shape[0],
idx_a=idx_a, idx_c=idx_b)
B_grad = B_grad.reshape(B_3D_shape[0], B_3D_shape[1], B_3D_shape[2])
return A_grad, B_grad, None, None, None, None, None, None
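Likewise, a hedged sketch of what GATHERMM.backward computes through _gather_mm_scatter (illustrative code, not DGL; gather_mm_backward_ref is a hypothetical helper). Several rows can share one relation, so the per-row outer products must be accumulated into the weight gradient, which is why the CUDA scatter kernel relies on atomicAdd.

import torch

def gather_mm_backward_ref(A, B, dC, idx_b):
    # A: (N, D1), B: (R, D1, D2), dC: (N, D2), idx_b: (N,) relation ids.
    dA = torch.bmm(dC.unsqueeze(1), B[idx_b].transpose(1, 2)).squeeze(1)  # dA[i] = dC[i] @ B[idx_b[i]]^T
    dB = torch.zeros_like(B)
    # dB[idx_b[i]] += outer(A[i], dC[i]), accumulated over all rows i.
    dB.index_add_(0, idx_b, torch.bmm(A.unsqueeze(2), dC.unsqueeze(1)))
    return dA, dB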
def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
if op == 'sub':
op = 'add'
...@@ -766,3 +832,9 @@ def csrsum(gidxs, weights):
def csrmask(gidxA, A_weights, gidxB):
return CSRMask.apply(gidxA, A_weights, gidxB)
def segment_mm(A, B, seglen_A):
return SEGMENTMM.apply(A, B, seglen_A)
def gather_mm(A, B, idx_a=None, idx_b=None):
return GATHERMM.apply(A, B, idx_a, idx_b)
...@@ -3,3 +3,4 @@ from .spmm import *
from .sddmm import *
from .edge_softmax import *
from .segment import *
from .gather_mm import *
"""dgl gather_mm operator module."""
from ..backend import gather_mm as gather_mm_internal
from ..backend import segment_mm as segment_mm_internal
__all__ = ['gather_mm', 'segment_mm']
def segment_mm(lhs_data, rhs_data, seglen_lhs):
r""" Performs matrix multiplication according to segments.
Suppose ``seglen_lhs == [10, 5, 0, 3]``, the operator will perform
four matrix multiplications:
lhs_data[0:10] @ rhs_data[0], lhs_data[10:15] @ rhs_data[1],
lhs_data[15:15] @ rhs_data[2], lhs_data[15:18] @ rhs_data[3]
Parameters
----------
lhs_data : tensor
The left operand, 2-D tensor of shape (N, D1)
rhs_data : tensor
The right operand, 3-D tensor of shape (R, D1, D2)
seglen_lhs : tensor
An integer tensor of shape (R,). Each element is the length of segments
of input ``lhs_data``. The summation of all elements must be equal to N.
Returns
-------
tensor
The output dense matrix of shape (N, D2)
"""
return segment_mm_internal(lhs_data, rhs_data, seglen_lhs)
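A minimal usage sketch (shapes follow the docstring above; a CUDA device is assumed since the CPU kernel is not implemented in this change):

import torch
import dgl

H = torch.randn(18, 16, device='cuda')      # rows sorted/concatenated by relation
W = torch.randn(4, 16, 8, device='cuda')    # one (16, 8) weight matrix per relation
seglen = torch.tensor([10, 5, 0, 3])        # rows per relation, sums to 18
out = dgl.ops.segment_mm(H, W, seglen)      # (18, 8)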
def gather_mm(lhs_data, rhs_data, idx_lhs=None, idx_rhs=None):
r"""Gather data according to the given indices and perform matrix multiplication.
Let the result tensor be C; the operator conducts the following computation:
If both idx_lhs and idx_rhs are not None:
c[i] = lhs_data[idx_lhs[i]] @ rhs_data[idx_rhs[i]]
, where len(C) == len(idx_lhs) == len(idx_rhs)
If idx_lhs is given but not idx_rhs:
c[i] = lhs_data[idx_lhs[i]] @ rhs_data[i]
, where len(C) == len(idx_lhs)
If idx_rhs is given but not idx_lhs:
c[i] = lhs_data[i] @ rhs_data[idx_rhs[i]]
, where len(C) == len(idx_rhs)
Parameters
----------
lhs_data : tensor
2-D tensor of shape (N, D1)
rhs_data : tensor
3-D tensor of shape (R, D1, D2)
idx_lhs : Tensor, optional
If specified, must be a 1-D integer tensor of shape (K,).
idx_rhs : Tensor, optional
If specified, must be a 1-D integer tensor of shape (K,).
Returns
-------
Tensor
The output dense matrix of shape (N, D2)
"""
return gather_mm_internal(lhs_data, rhs_data, idx_lhs, idx_rhs)
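A minimal usage sketch of the unsorted path (idx_rhs selects the weight slice for each row, mirroring how the unit test at the end of this diff calls the operator; a CUDA device is assumed):

import torch
import dgl

H = torch.randn(18, 16, device='cuda')                 # unsorted features
W = torch.randn(4, 16, 8, device='cuda')               # one (16, 8) weight per relation
etypes = torch.randint(0, 4, (18,), device='cuda')     # relation id of each row
out = dgl.ops.gather_mm(H, W, idx_rhs=etypes)          # out[i] = H[i] @ W[etypes[i]]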
...@@ -389,6 +389,111 @@ def _gspmm_hetero(gidx, op, reduce_op, u_len, u_and_e_tuple):
return out, (list_arg_u, list_arg_e, list_arg_u_ntype, list_arg_e_etype)
def _segment_mm(A, B, out, seglen_A, a_trans=False, b_trans=False):
r""" Dense Matrix Multiplication interface. It multiplies dense tensor A
and dense tensor B according to relation types. A is sorted and concatenated
according to relation types.
Parameters
----------
A : tensor
2-D tensor of shape (N, D1)
B : tensor
2-D tensor of shape (R * D1, D2)
out : tensor
2-D output tensor of shape (N, D2); the result is written into it.
seglen_A : Tensor
An integer tensor of shape (R,). Each element is the length of segments
of input ``A``. The summation of all elements must be equal to N.
a_trans : bool
Indicates whether matrix A needs to be transposed
b_trans : bool
Indicates whether matrix B needs to be transposed
Returns
-------
Tensor
The output dense matrix of shape (N, D2)
"""
# TODO(Israt): Add CPU support. Currently, only handles GPU code
_CAPI_DGLKernelSEGMENTMM(to_dgl_nd(A),
to_dgl_nd(B),
to_dgl_nd_for_write(out),
to_dgl_nd(seglen_A),
a_trans, b_trans)
return out
def _gather_mm(A, B, out, num_rel, idx_a=None, idx_b=None):
r""" Generalized Dense Matrix Multiplication interface. It multiplies
tensor A and B according to relation types and outputs in out. B is a
concatenated tensor across relation types. A is unsorted and the
relation type is fetched from param etypes.
Parameters
----------
A : tensor
2-D tensor of shape (N, D1)
B : tensor
2-D tensor of shape (R * D1, D2)
idx_a : Tensor, optional
If specified, must be a 1-D integer tensor of shape (K,)
idx_b : Tensor, optional
If specified, must be a 1-D integer tensor of shape (N,)
Returns
-------
Tensor
The output dense matrix of shape (N, D2)
"""
# TODO(Israt): Add CPU support. Currently, only handles GPU code
_CAPI_DGLKernelGATHERMM(to_dgl_nd(A),
to_dgl_nd(B),
to_dgl_nd_for_write(out),
to_dgl_nd(idx_a),
to_dgl_nd(idx_b),
num_rel)
return out
def _gather_mm_scatter(A, B, out, num_rel, idx_a=None, idx_b=None, idx_c=None,
a_trans=False, b_trans=False):
r""" Generalized Dense Matrix Multiplication interface. It multiplies
tensor A and B according to relation types and outputs in out. B is a
concatenated tensor across relation types. A is unsorted and the
relation type is fetched from param etypes.
Parameters
----------
A : tensor
2-D tensor of shape (N, D1)
B : tensor
2-D tensor of shape (R * D1, D2)
idx_a : Tensor, optional
If specified, must be a 1-D integer tensor of shape (K,)
idx_b : Tensor, optional
If specified, must be a 1-D integer tensor of shape (N,)
idx_c : Tensor, optional
If specified, must be a 1-D integer tensor of shape (N,)
A_trans : bool
Indicates whether matrix A needs to be tranposed
B_trans : bool
Indicates whether matrix B needs to be tranposed
Returns
-------
Tensor
The output dense matrix of shape (N, D2)
"""
# TODO(Israt): Add CPU support. Currently, only handles GPU code
_CAPI_DGLKernelGATHERMMSCATTER(to_dgl_nd(A),
to_dgl_nd(B),
to_dgl_nd_for_write(out),
to_dgl_nd(idx_a),
to_dgl_nd(idx_b),
to_dgl_nd(idx_c),
num_rel, a_trans, b_trans)
return out
def _gsddmm(gidx, op, lhs, rhs, lhs_target='u', rhs_target='v'):
r""" Generalized Sampled-Dense-Dense Matrix Multiplication interface. It
takes the result of :attr:`op` on source node feature and destination node
...
/*!
* Copyright (c) 2020 by Contributors
* \file array/cpu/gather_mm.cc
* \brief GatherMM C APIs and definitions.
*/
#include "./gather_mm.h"
#include <dgl/array.h>
namespace dgl {
namespace aten {
#define SWITCH_BITS(bits, DType, ...) \
do { \
if ((bits) == 16 || (bits) == 32) { \
typedef float DType; \
{ __VA_ARGS__ } \
} else if ((bits) == 64) { \
typedef double DType; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << "Data type not recognized with bits " << bits; \
} \
} while (0)
/*! \brief Generalized segmentMM. */
template <int XPU, typename IdType, int bits>
void segmentMM(const NDArray A,
const NDArray B,
NDArray C,
const NDArray seglen_A,
bool a_trans, bool b_trans) {
SWITCH_BITS(bits, DType, {
LOG(FATAL) << "Unsupported CPU kernel for SegmentMM.";
});
}
/*! \brief Generalized GatherMM. */
template <int XPU, typename IdType, int bits>
void gatherMM(const NDArray A,
const NDArray B,
NDArray C,
const NDArray idx_a,
const NDArray idx_b,
const int num_rel) {
SWITCH_BITS(bits, DType, {
LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
});
}
/*! \brief Generalized GatherMM_scatter. */
template <int XPU, typename IdType, int bits>
void gatherMM_scatter(const NDArray A,
const NDArray B,
NDArray C,
const NDArray idx_a,
const NDArray idx_b,
const NDArray idx_c,
const int num_rel,
bool a_trans, bool b_trans) {
SWITCH_BITS(bits, DType, {
LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
});
}
template void gatherMM<kDLCPU, int32_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const int num_rel);
template void gatherMM<kDLCPU, int64_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const int num_rel);
template void gatherMM<kDLCPU, int32_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const int num_rel);
template void gatherMM<kDLCPU, int64_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const int num_rel);
template void gatherMM<kDLCPU, int32_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const int num_rel);
template void gatherMM<kDLCPU, int64_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const int num_rel);
template void gatherMM_scatter<kDLCPU, int32_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c,
const int num_rel, bool a_trans, bool b_trans);
template void gatherMM_scatter<kDLCPU, int64_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c,
const int num_rel, bool a_trans, bool b_trans);
template void gatherMM_scatter<kDLCPU, int32_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c,
const int num_rel, bool a_trans, bool b_trans);
template void gatherMM_scatter<kDLCPU, int64_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c,
const int num_rel, bool a_trans, bool b_trans);
template void gatherMM_scatter<kDLCPU, int32_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c,
const int num_rel, bool a_trans, bool b_trans);
template void gatherMM_scatter<kDLCPU, int64_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c,
const int num_rel, bool a_trans, bool b_trans);
template void segmentMM<kDLCPU, int32_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void segmentMM<kDLCPU, int64_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void segmentMM<kDLCPU, int32_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void segmentMM<kDLCPU, int64_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void segmentMM<kDLCPU, int32_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void segmentMM<kDLCPU, int64_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
} // namespace aten
} // namespace dgl
/*!
* Copyright (c) 2022 by Contributors
* \file array/cpu/gather_mm.h
* \brief GATHER_MM CPU kernel function header.
*/
#ifndef DGL_ARRAY_CPU_GATHER_MM_H_
#define DGL_ARRAY_CPU_GATHER_MM_H_
#include <dgl/array.h>
#include <dgl/bcast.h>
#include <utility>
namespace dgl {
namespace aten {
namespace cpu {
template <typename DType>
void transpose(const DType *in, DType *out, const int N, const int M) {
#pragma omp parallel for
for (int n = 0; n < N * M; n++) {
int i = n / N;
int j = n % N;
out[n] = in[M * j + i];
}
}
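A quick NumPy cross-check of the index arithmetic above (illustrative only, not DGL code): the output buffer, read as an M x N row-major matrix, is the transpose of the N x M row-major input.

import numpy as np

N, M = 3, 5
a = np.arange(N * M).reshape(N, M)       # row-major N x M input
out = np.empty(N * M, dtype=a.dtype)
flat = a.ravel()
for n in range(N * M):
    i, j = n // N, n % N                 # same mapping as the loop above
    out[n] = flat[M * j + i]
assert (out.reshape(M, N) == a.T).all()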
template <typename DType>
void matmul(const DType *A, const DType *B,
DType *C, const int M, const int N, const int K) {
#pragma omp parallel
{
int i, j, k;
#pragma omp for
for (i = 0; i < M; i++) {
for (j = 0; j < N; j++) {
DType local_accum = 0;
for (k = 0; k < K; k++) {
local_accum += A[i * K + k] * B[k * N + j];
}
C[i * N + j] = local_accum;
}
}
}
}
/*!
* \brief CPU kernel of Gather_mm. The input matrix A is expected to be
* sorted according to relation type.
* \param A The input dense matrix of dimension m x k
* \param B The input dense matrix of dimension k x n
* \param C The output dense matrix of dimension m x n
* \param A_dim1_per_rel The number of rows in each relation in A
* \param B_dim1_per_rel The number of rows in each relation in B
* \param a_trans Matrix A to be transposed
* \param b_trans Matrix B to be transposed
*/
template <int XPU, typename IdType, typename DType>
void gatherMM_SortedEtype(const NDArray A,
const NDArray B,
NDArray C,
const NDArray A_dim1_per_rel,
const NDArray B_dim1_per_rel,
bool a_trans, bool b_trans) {
assert(A_dim1_per_rel.NumElements() == B_dim1_per_rel.NumElements());
int64_t num_rel = A_dim1_per_rel.NumElements();
const DType *A_data = A.Ptr<DType>();
const DType *B_data = B.Ptr<DType>();
const IdType* A_rel_data = A_dim1_per_rel.Ptr<IdType>();
const IdType* B_rel_data = B_dim1_per_rel.Ptr<IdType>();
DType *C_data = C.Ptr<DType>();
int64_t A_offset = 0, B_offset = 0, C_offset = 0;
int64_t m, n, k;
for (int etype = 0; etype < num_rel; ++etype) {
assert(((a_trans) ? A_rel_data[etype] : A->shape[1]) == \
((b_trans) ? B->shape[1] : B_rel_data[etype]));
m = A_rel_data[etype]; // rows of A
n = B->shape[1]; // cols of B
k = B_rel_data[etype]; // rows of B == cols of A
NDArray A_trans, B_trans;
if (a_trans) {
A_trans = NDArray::Empty({m * k}, A->dtype, A->ctx);
transpose<DType>(A_data + A_offset, static_cast<DType *>(A_trans->data), m, k);
}
if (b_trans) {
B_trans = NDArray::Empty({k * n}, B->dtype, B->ctx);
transpose<DType>(B_data + B_offset, static_cast<DType *>(B_trans->data), k, n);
}
if (a_trans || b_trans) {
int64_t tmp = k;
if (a_trans)
std::swap(m, k);
if (b_trans) {
k = tmp;
std::swap(n, k);
}
}
matmul<DType>(
(a_trans) ? static_cast<DType *>(A_trans->data) : A_data + A_offset,
(b_trans) ? static_cast<DType *>(B_trans->data) : B_data + B_offset,
C_data + C_offset, m, n, k);
A_offset += m * k;
B_offset += k * n;
C_offset += m * n;
}
}
} // namespace cpu
} // namespace aten
} // namespace dgl
#endif // DGL_ARRAY_CPU_GATHER_MM_H_
/*!
* Copyright (c) 2020 by Contributors
* \file array/cuda/gather_mm.cu
* \brief GatherMM C APIs and definitions.
*/
#include <dgl/array.h>
#include <algorithm> // std::swap
#include "./utils.h"
#include "./functor.cuh"
#include "./atomic.cuh"
namespace dgl {
using namespace cuda;
namespace aten {
namespace {
/*! \brief Call cuBLAS geam API for transpose operation for float and double. */
template <typename DType>
cublasStatus_t Xgeam(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n,
const DType* alpha, const DType* A, int lda,
const DType* beta, const DType* B, int ldb,
DType* C, int ldc) {
LOG(INFO) << "Not supported dtype";
return CUBLAS_STATUS_EXECUTION_FAILED;
}
template <>
cublasStatus_t Xgeam<float>(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n,
const float* alpha, const float* A, int lda,
const float* beta, const float* B, int ldb,
float* C, int ldc) {
return cublasSgeam(handle, transa, transb, m, n, alpha, A, lda,
beta, B, ldb, C, ldc);
}
template <>
cublasStatus_t Xgeam<double>(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n,
const double* alpha, const double* A, int lda,
const double* beta, const double* B, int ldb,
double* C, int ldc) {
return cublasDgeam(handle, transa, transb, m, n, alpha, A, lda,
beta, B, ldb, C, ldc);
}
/*! \brief Call cuBLAS GEMM API for dense matmul operation for float and double. */
template <typename DType>
cublasStatus_t cublasGemm(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const DType* alpha, const DType* A, int lda,
const DType* B, int ldb, const DType* beta,
DType* C, int ldc) {
LOG(INFO) << "Not supported dtype";
return CUBLAS_STATUS_EXECUTION_FAILED;
}
template <>
cublasStatus_t cublasGemm<float>(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const float* alpha, const float* A, int lda,
const float* B, int ldb, const float* beta,
float* C, int ldc) {
return cublasSgemm(handle, transa, transb, m, n, k, alpha, A, lda,
B, ldb, beta, C, ldc);
}
template <>
cublasStatus_t cublasGemm<double>(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const double* alpha, const double* A, int lda,
const double* B, int ldb, const double* beta,
double* C, int ldc) {
return cublasDgemm(handle, transa, transb, m, n, k, alpha, A, lda,
B, ldb, beta, C, ldc);
}
/*
* \brief Transpose the input matrix.
* \param row number of rows of input matrix.
* \param col number of columns of input matrix.
*/
template <typename DType>
void _Transpose(cublasHandle_t handle,
const DType* in, DType* out,
int row, int col) {
DType alpha = 1., beta = 0.;
CUBLAS_CALL(Xgeam<DType>(
handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
row, col,
&alpha, in, col,
&beta, nullptr, row,
out, row));
}
} // namespace
namespace cuda {
/* \note Each row of A multiplies a segment of matrix B of dimension in_len * out_len.
One warp is assigned to process one row of A. Each WARP sequentially multiplies
one element of A and a row of B to compute partial result of the output. A
is loaded in shared memory in a coalesced way. Output matrix is loaded in
registers. B should get benefit from L2 cache.
*/
template <typename Idx, typename DType>
__global__ void gatherMMKernel(
const DType* __restrict__ A,
const DType* __restrict__ B,
DType* __restrict__ C,
const Idx* __restrict__ idx_a,
const Idx* __restrict__ idx_b,
int64_t num_rows,
int64_t in_len, int64_t out_len) {
unsigned int tId = threadIdx.x;
unsigned int laneId = tId & 31;
unsigned int gId = (blockIdx.x * blockDim.x + threadIdx.x);
unsigned int warpId = gId >> 5;
unsigned int row = warpId;
if (row < num_rows) {
unsigned int local_row = row & 3; // hardcoded for TB size 128 (4 warps)
Idx cur_rowA = (idx_a) ? idx_a[row] : row;
Idx cur_rowB = (idx_b) ? idx_b[row] : row / in_len;
Idx B_offset = cur_rowB * in_len * out_len;
const int sh_a_tile = 64;
__shared__ DType sh_A[4 * sh_a_tile];
int a_tile = sh_a_tile;
for (unsigned int k_start = 0; k_start < in_len; k_start += 64) {
if ((in_len - k_start) < a_tile) a_tile = in_len - k_start;
/* Load A in shared mem in a coalesced way */
for (unsigned int l = laneId; l < a_tile; l += 32)
sh_A[local_row * sh_a_tile + l] = A[cur_rowA * in_len + (k_start + l)];
__syncwarp();
for (unsigned int outloop = 0; outloop < out_len; outloop +=32) {
DType out_reg = 0; // thread private
const unsigned int l = laneId;
if (l < out_len) {
/* iterate over elements of a row of A */
for (unsigned int i = 0; i < a_tile; i++) {
const DType a_val = sh_A[local_row * sh_a_tile + i];
/* iterate over elements of a row of B in parallel */
out_reg += a_val * B[B_offset + ((i + k_start) * out_len + (outloop + l))];
}
C[row * out_len + (outloop + l)] += out_reg;
}
}
}
}
}
/* \note Output matrix is accumulated via atomic operations. The rest of the strategy
is similar to gatherMMKernel. One warp is assigned to process one row of A. Each
WARP sequentially multiplies one element of A and a row of B to compute partial
result of the output. A is loaded in shared memory in a coalesced way. B should
get benefit from L2 cache.
*/
template <typename Idx, typename DType>
__global__ void gatherMMScatterKernel(
const DType* __restrict__ A,
const DType* __restrict__ B,
DType* __restrict__ C,
const Idx* __restrict__ idx_a,
const Idx* __restrict__ idx_b,
const Idx* __restrict__ idx_c,
int64_t num_rows,
int64_t in_len, int64_t out_len) {
unsigned int tId = threadIdx.x;
unsigned int laneId = tId & 31;
unsigned int gId = (blockIdx.x * blockDim.x + threadIdx.x);
unsigned int warpId = gId >> 5;
unsigned int row = warpId;
if (row < num_rows) {
unsigned int local_row = row & 3; // hardcoded for TB size 128 (4 warps)
unsigned int row_a = (idx_a) ? idx_a[row] : row;
unsigned int row_b = (idx_b) ? idx_b[row] : row;
Idx C_offset = (idx_c) ? idx_c[row] * in_len * out_len : 0;
const int sh_a_tile = 64;
__shared__ DType sh_A[4 * sh_a_tile];
int a_tile = sh_a_tile;
for (unsigned int k_start = 0; k_start < in_len; k_start += 64) {
if ((in_len - k_start) < a_tile) a_tile = in_len - k_start;
/* Load A in shared mem in a coalesced way */
for (unsigned int l = laneId; l < a_tile; l += 32)
sh_A[local_row * sh_a_tile + l] = A[row_a * in_len + (k_start + l)];
__syncwarp();
for (unsigned int outloop = 0; outloop < out_len; outloop +=32) {
DType out_reg = 0; // thread private
const unsigned int l = laneId;
if (l < out_len) {
const DType b_val = B[row_b * out_len + (outloop + l)];
/* iterate over elements of a row of A */
for (unsigned int i = 0; i < a_tile; i++) {
const DType a_val = sh_A[local_row * sh_a_tile + i];
const Idx C_idx = C_offset + ((i + k_start) * out_len + (outloop + l));
atomicAdd(reinterpret_cast<float*>(&C[C_idx]),
static_cast<float>(a_val * b_val));
}
}
}
}
}
}
/* \brief Implementation of GatherMM operator. The indices of A (or B)
* are looked up from idx_a (or idx_b) when defined.
*/
template <int XPU, typename IdType, int bits>
void gatherMM(const NDArray A,
const NDArray B,
NDArray C,
const NDArray idx_a,
const NDArray idx_b,
int64_t num_rel) {
SWITCH_BITS(bits, DType, {
auto device = runtime::DeviceAPI::Get(A->ctx);
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
const DType *A_data = A.Ptr<DType>();
const DType *B_data = B.Ptr<DType>();
int64_t out_len = B->shape[1]; // cols of B
int64_t in_len = A->shape[1]; // cols of A
if (!thr_entry->cublas_handle)
CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle)));
CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle,
thr_entry->stream));
int64_t tot_num_rows = A->shape[0];
const int ntx = 128;
const int warp_size = 32;
const int nbx = ((tot_num_rows * warp_size + ntx - 1) / ntx);
const dim3 nblks(nbx);
const dim3 nthrs(ntx);
CUDA_KERNEL_CALL((gatherMMKernel<IdType, DType>),
nblks, nthrs, 0, thr_entry->stream,
static_cast<DType*>(A->data),
static_cast<DType*>(B->data),
static_cast<DType*>(C->data),
static_cast<IdType*>(idx_a->data),
static_cast<IdType*>(idx_b->data),
tot_num_rows,
in_len, out_len);
});
}
/* \brief Implementation of GatherMM operator. The indices of A (or B or C)
* are looked up from idx_a (or idx_b or idx_c) when defined.
*/
template <int XPU, typename IdType, int bits>
void gatherMM_scatter(const NDArray A,
const NDArray B,
NDArray C,
const NDArray idx_a,
const NDArray idx_b,
const NDArray idx_c,
int num_rel, bool a_trans, bool b_trans) {
SWITCH_BITS(bits, DType, {
auto device = runtime::DeviceAPI::Get(A->ctx);
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
const IdType *idx_c_data = idx_c.Ptr<IdType>();
int64_t out_len = B->shape[1]; // cols of B
int64_t in_len = A->shape[1]; // cols of A
if (!thr_entry->cublas_handle)
CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle)));
CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle,
thr_entry->stream));
DType* B_trans_data = nullptr;
if (b_trans) {
int64_t B_offset = 0;
const DType *B_data = B.Ptr<DType>();
in_len = B->shape[0]/num_rel;
B_trans_data = static_cast<DType*>(device->AllocWorkspace \
(B->ctx, B->shape[0] * B->shape[1] * sizeof(DType)));
// transpose B per relation
for (int rel = 0; rel < num_rel; ++rel) {
_Transpose(thr_entry->cublas_handle, B_data + B_offset,
B_trans_data + B_offset, in_len, out_len);
B_offset += in_len * out_len;
}
std::swap(in_len, out_len);
}
int64_t tot_num_rows = A->shape[0];
const int ntx = 128;
const int warp_size = 32;
const int nbx = ((tot_num_rows * warp_size + ntx - 1) / ntx);
const dim3 nblks(nbx);
const dim3 nthrs(ntx);
if (idx_c_data) {
// Custom kernel for W_grad[idx_c[i]] = H^T[i] * C.grad[i]
// This kernel accesses rows of A in a transposed way w/o explicitly converting A
CUDA_KERNEL_CALL((gatherMMScatterKernel<IdType, DType>),
nblks, nthrs, 0, thr_entry->stream,
static_cast<DType*>(A->data),
static_cast<DType*>(B->data),
static_cast<DType*>(C->data),
static_cast<IdType*>(idx_a->data),
static_cast<IdType*>(idx_b->data),
static_cast<IdType*>(idx_c->data),
tot_num_rows,
in_len, out_len);
} else { // use generic gather_mm
CUDA_KERNEL_CALL((gatherMMKernel<IdType, DType>),
nblks, nthrs, 0, thr_entry->stream,
static_cast<DType*>(A->data),
(b_trans) ? B_trans_data : static_cast<DType*>(B->data),
static_cast<DType*>(C->data),
static_cast<IdType*>(idx_a->data),
static_cast<IdType*>(idx_b->data),
tot_num_rows,
in_len, out_len);
}
if (b_trans)
device->FreeWorkspace(B->ctx, B_trans_data);
});
}
} // namespace cuda
/* \brief Implementation of SegmentMM operator. Each segment calls the cuBLAS
* GEMM operator to multiply a segment of A and B. When A or B needs to be
* transposed, cuBLAS GEMM switches its transpose parameter (CUBLAS_OP_T).
*/
template <int XPU, typename IdType, int bits>
void segment_mm(const NDArray A,
const NDArray B,
NDArray C,
const NDArray seglen_A,
bool a_trans, bool b_trans) {
SWITCH_BITS(bits, DType, {
auto device = runtime::DeviceAPI::Get(A->ctx);
const DType *A_data = A.Ptr<DType>();
const DType *B_data = B.Ptr<DType>();
const IdType* seglen_A_data = seglen_A.Ptr<IdType>();
DType *C_data = C.Ptr<DType>();
int64_t A_offset = 0, B_offset = 0, C_offset = 0;
int64_t m, n, k;
int64_t num_rel = seglen_A.NumElements();
DType alpha = 1., beta = 0.;
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
if (!thr_entry->cublas_handle)
CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle)));
CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle,
thr_entry->stream));
for (int etype = 0; etype < num_rel; ++etype) {
IdType B_dim1 = B->shape[0] / num_rel;
assert(((a_trans) ? seglen_A_data[etype] : A->shape[1]) == \
((b_trans) ? B->shape[1] : B_dim1));
m = seglen_A_data[etype]; // rows of A
n = B->shape[1]; // cols of B
k = A->shape[1]; // cols of A == rows of B
int ldb = n, lda = k, ldc = n;
cublasOperation_t transB = CUBLAS_OP_N;
cublasOperation_t transA = CUBLAS_OP_N;
if (a_trans) {
transA = CUBLAS_OP_T;
ldb = n, lda = k, ldc = n;
std::swap(m, k);
}
if (b_trans) {
transB = CUBLAS_OP_T;
k = B_dim1;
ldb = n, lda = n, ldc = k;
std::swap(n, k);
}
CUBLAS_CALL(cublasGemm<DType>(
thr_entry->cublas_handle,
transB,
transA,
n, m, k,
&alpha,
B_data + B_offset, ldb,
A_data + A_offset, lda,
&beta,
C_data + C_offset, ldc));
A_offset += m * k;
B_offset += k * n;
C_offset += m * n;
}
});
}
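The GEMM call above passes the operands in (B, A) order with m and n swapped because cuBLAS is column-major while DGL's buffers are row-major. A small NumPy sketch of why that convention yields the row-major product C = A @ B (illustration only, not DGL code):

import numpy as np

m, k, n = 3, 4, 5
A = np.random.rand(m, k)                 # row-major (m, k) buffer
B = np.random.rand(k, n)                 # row-major (k, n) buffer
# Read column-major, the same buffers are A^T and B^T. The swapped call
# computes the column-major (n, m) matrix B^T @ A^T = (A @ B)^T, and that
# buffer read back row-major is exactly A @ B.
C_T = B.T.dot(A.T)
assert np.allclose(C_T.T, A.dot(B))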
/*!
* \brief Implementation of SegmentMM operator. The input matrix A is
* expected to be sorted according to relation type.
* \param A The input dense matrix of dimension m x k
* \param B The input dense matrix of dimension k x n
* \param C The output dense matrix of dimension m x n
* \param seglen_A The input vector of size R. Each element
* is the length of segments of input ``A``
* \param a_trans Matrix A to be transposed
* \param b_trans Matrix B to be transposed
*/
template <int XPU, typename IdType, int bits>
void segmentMM(const NDArray A,
const NDArray B,
NDArray C,
const NDArray seglen_A,
bool a_trans, bool b_trans) {
segment_mm<XPU, IdType, bits>(A, B, C, seglen_A, a_trans, b_trans);
}
/*!
* \brief Implementation of GatherMM operator. Rows of A and B are gathered
* according to idx_a and idx_b when they are given.
* \param A The input dense matrix of dimension m x k
* \param B The input dense matrix of dimension k x n
* \param C The output dense matrix of dimension m x n
* \param idx_a The input vector to gather left hand operand on
* \param idx_b The input vector to gather right hand operand on
* \param num_rel The number of idx types in idx_b
*/
template <int XPU, typename IdType, int bits>
void gatherMM(const NDArray A,
const NDArray B,
NDArray C,
const NDArray idx_a,
const NDArray idx_b,
const int num_rel) {
cuda::gatherMM<XPU, IdType, bits>(A, B, C, idx_a, idx_b, num_rel);
}
/*!
* \brief Implementation of GatherMM_scatter operator. Rows of A, B and the
* output are gathered/scattered according to idx_a, idx_b and idx_c when given.
* \param A The input dense matrix of dimension m x k
* \param B The input dense matrix of dimension k x n
* \param C The output dense matrix of dimension m x n
* \param idx_a The input vector to gather left hand operand on
* \param idx_b The input vector to gather right hand operand on
* \param idx_c The input vector to gather output operand on
* \param num_rel The number of idx types in idx_b
* \param a_trans Matrix A to be transposed
* \param b_trans Matrix B to be transposed
*/
template <int XPU, typename IdType, int bits>
void gatherMM_scatter(const NDArray A,
const NDArray B,
NDArray C,
const NDArray idx_a,
const NDArray idx_b,
const NDArray idx_c,
const int num_rel,
bool a_trans, bool b_trans) {
cuda::gatherMM_scatter<XPU, IdType, bits>(A, B, C, idx_a, idx_b, idx_c,
num_rel, a_trans, b_trans);
}
template void gatherMM<kDLGPU, int32_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const int num_rel);
template void gatherMM<kDLGPU, int64_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const int num_rel);
template void gatherMM<kDLGPU, int32_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const int num_rel);
template void gatherMM<kDLGPU, int64_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const int num_rel);
template void gatherMM<kDLGPU, int32_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const int num_rel);
template void gatherMM<kDLGPU, int64_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const int num_rel);
template void gatherMM_scatter<kDLGPU, int32_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c,
const int num_rel, bool a_trans, bool b_trans);
template void gatherMM_scatter<kDLGPU, int64_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c,
const int num_rel, bool a_trans, bool b_trans);
template void gatherMM_scatter<kDLGPU, int32_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c,
const int num_rel, bool a_trans, bool b_trans);
template void gatherMM_scatter<kDLGPU, int64_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c,
const int num_rel, bool a_trans, bool b_trans);
template void gatherMM_scatter<kDLGPU, int32_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c,
const int num_rel, bool a_trans, bool b_trans);
template void gatherMM_scatter<kDLGPU, int64_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c,
const int num_rel, bool a_trans, bool b_trans);
template void segmentMM<kDLGPU, int32_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void segmentMM<kDLGPU, int64_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void segmentMM<kDLGPU, int32_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void segmentMM<kDLGPU, int64_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void segmentMM<kDLGPU, int32_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void segmentMM<kDLGPU, int64_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
} // namespace aten
} // namespace dgl
...@@ -52,6 +52,60 @@ void SpMM(const std::string& op, const std::string& reduce,
});
}
/*! \brief Generalized segmented dense Matrix-Matrix Multiplication. */
void SegmentMM(const NDArray A,
const NDArray B,
NDArray C,
const NDArray seglen_A,
bool A_trans, bool B_trans) {
ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "GatherMM", {
ATEN_ID_TYPE_SWITCH(seglen_A->dtype, IdType, {
ATEN_FLOAT_BITS_SWITCH(A->dtype, bits, "Feature data", {
segmentMM<XPU, IdType, bits>(A, B, C, seglen_A, A_trans, B_trans);
});
});
});
}
/*! \brief Generalized Dense Matrix-Matrix Multiplication according to relation types. */
void GatherMM(const NDArray A,
const NDArray B,
NDArray C,
const NDArray idx_a,
const NDArray idx_b,
const int num_rel) {
ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "GatherMM", {
ATEN_ID_TYPE_SWITCH(idx_b->dtype, IdType, {
ATEN_FLOAT_BITS_SWITCH(A->dtype, bits, "Feature data", {
gatherMM<XPU, IdType, bits>(A, B, C, idx_a, idx_b, num_rel);
});
});
});
}
/*! \brief Generalized Dense Matrix-Matrix Multiplication according to relation types. */
void GatherMM_scatter(const NDArray A,
const NDArray B,
NDArray C,
const NDArray idx_a,
const NDArray idx_b,
const NDArray idx_c,
const int num_rel,
bool A_trans, bool B_trans) {
ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "GatherMM", {
ATEN_ID_TYPE_SWITCH(idx_b->dtype, IdType, {
ATEN_FLOAT_BITS_SWITCH(A->dtype, bits, "Feature data", {
gatherMM_scatter<XPU, IdType, bits>(A, B, C, idx_a, idx_b, idx_c,
num_rel, A_trans, B_trans);
});
});
});
}
/*! \brief Generalized Sparse Matrix-Matrix Multiplication with hetero-graph support. */
void SpMMHetero(const std::string& op, const std::string& reduce,
HeteroGraphPtr graph,
...@@ -390,6 +444,42 @@ DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSpMM")
SpMM(op, reduce_op, graph.sptr(), U, E, V, {ArgU, ArgE});
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelGATHERMM")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
NDArray A = args[0];
NDArray B = args[1];
NDArray C = args[2];
NDArray idx_a = args[3];
NDArray idx_b = args[4];
int num_rel = args[5];
GatherMM(A, B, C, idx_a, idx_b, num_rel);
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelGATHERMMSCATTER")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
NDArray A = args[0];
NDArray B = args[1];
NDArray C = args[2];
NDArray idx_a = args[3];
NDArray idx_b = args[4];
NDArray idx_c = args[5];
int num_rel = args[6];
bool A_trans = args[7];
bool B_trans = args[8];
GatherMM_scatter(A, B, C, idx_a, idx_b, idx_c, num_rel, A_trans, B_trans);
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSEGMENTMM")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
NDArray A = args[0];
NDArray B = args[1];
NDArray C = args[2];
NDArray seglen_A = args[3];
bool A_trans = args[4];
bool B_trans = args[5];
SegmentMM(A, B, C, seglen_A, A_trans, B_trans);
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelEdge_softmax_forward") DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelEdge_softmax_forward")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0]; HeteroGraphRef graph = args[0];
......
...@@ -112,6 +112,39 @@ void SDDMMCooHetero(const std::string& op, ...@@ -112,6 +112,39 @@ void SDDMMCooHetero(const std::string& op,
const std::vector<dgl_type_t>& lhs_eid, const std::vector<dgl_type_t>& lhs_eid,
const std::vector<dgl_type_t>& rhs_eid); const std::vector<dgl_type_t>& rhs_eid);
/*!
* \brief Generalized Dense Matrix-Matrix Multiplication according to relation types.
*/
template <int XPU, typename IdType, int bits>
void gatherMM(const NDArray A,
const NDArray B,
NDArray out,
const NDArray idx_a,
const NDArray idx_b,
const int num_rel);
/*!
* \brief Generalized Dense Matrix-Matrix Multiplication according to relation types.
*/
template <int XPU, typename IdType, int bits>
void gatherMM_scatter(const NDArray A,
const NDArray B,
NDArray out,
const NDArray idx_a,
const NDArray idx_b,
const NDArray idx_c,
const int num_rel, bool a_trans, bool b_trans);
/*!
* \brief Generalized segmented dense Matrix-Matrix Multiplication.
*/
template <int XPU, typename IdType, int bits>
void segmentMM(const NDArray A,
const NDArray B,
NDArray out,
const NDArray seglen_A,
bool a_trans, bool b_trans);
/*!
* \brief Segment reduce.
*/
...
from timeit import default_timer
import dgl
import backend as F
import dgl.function as fn
import time
import numpy as np
import unittest, pytest
from test_utils import parametrize_dtype, get_cases
iters = 5
n_edge_scale = 1
num_rel_scale = 1
@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
@unittest.skipIf(F._default_context_str == 'cpu', reason="Not implemented.")
@parametrize_dtype
def test_gathermm(idtype):
def _test(feat_scale):
in_feat = 16 * feat_scale
out_feat = 8 * feat_scale
print("in/out feat", in_feat, out_feat)
E_per_rel = F.copy_to(F.tensor([50, 100, 20, 284, 89, 10, 82, 9200, 10, 20, 30, 100,
128, 20, 284, 89, 10, 82, 92, 10, 20, 30, 100, 1280, 20, 284, 89, 1000, 82,
92, 10, 2000, 30, 100, 128, 20, 284, 89, 10, 82, 92, 10, 20, 30]), F.cpu())
E_per_rel *= n_edge_scale
num_rel = len(E_per_rel)
print('num_rel', num_rel)
W_per_len = F.copy_to(F.full((num_rel,), in_feat, dtype=F.dtype(E_per_rel)), F.cpu())
H_arr = []
W_arr = []
Out_arr = []
Out_grad_arr = []
for eid in range(num_rel):
H_arr.append(F.randn((E_per_rel[eid], in_feat)))
W_arr.append(F.randn((in_feat, out_feat)))
Out_arr.append(F.zeros((E_per_rel[eid], out_feat)))
Out_grad_arr.append(F.ones((E_per_rel[eid], out_feat)))
H = F.cat([h for h in H_arr], 0)
W = F.cat([w for w in W_arr], 0)
W_3D = W.reshape(num_rel, in_feat, out_feat)
Out = F.cat([out for out in Out_arr], 0)
Out_grad = F.cat([o for o in Out_grad_arr], 0)
print('H.shape', H.shape)
print('W.shape', W.shape)
print('W_3D.shape', W_3D.shape)
print('Out.shape', Out.shape)
etype_arr = []
for eid in range(num_rel):
etype_arr.append(F.full((E_per_rel[eid],), eid, dtype=F.dtype(E_per_rel)))
etypes = F.cat([etype for etype in etype_arr], 0)
#################################################################
# low-mem version using PyTorch operator
#################################################################
# forward pass
out = []
for i in range(len(E_per_rel)):
Hi = H_arr[i]
Wi = W_arr[i]
out.append(F.matmul(Hi, Wi))
out_low_mem = F.cat(out, 0)
# backward pass
H_grad = []
W_grad = []
for i in range(len(E_per_rel)):
Hi = H_arr[i]
Wi = W_arr[i]
Out_gradi = Out_grad_arr[i]
H_grad.append(F.matmul(Out_gradi, Wi.transpose(0,1)))
W_grad.append(F.matmul(Hi.transpose(0,1), Out_gradi))
Hgrad_low_mem = F.cat(H_grad, 0)
Wgrad_low_mem = F.cat(W_grad, 0)
Wgrad_low_mem = Wgrad_low_mem.reshape(num_rel, in_feat, out_feat)
#################################################################
# segment_mm where H is sorted according to etype
#################################################################
seglen_A = E_per_rel
F.attach_grad(H)
F.attach_grad(W_3D)
with F.record_grad():
out_gmm_sorted = dgl.ops.segment_mm(H, W_3D, seglen_A)
F.backward(F.reduce_sum(out_gmm_sorted))
Hgrad_gmm_sorted = H.grad
Wgrad_gmm_sorted = W_3D.grad
#################################################################
# gather_mm where H is not sorted according to etype
#################################################################
F.attach_grad(H)
F.attach_grad(W_3D)
with F.record_grad():
out_gmm_unsorted = dgl.ops.gather_mm(H, W_3D, idx_rhs=etypes)
F.backward(F.reduce_sum(out_gmm_unsorted))
Hgrad_gmm_unsorted = H.grad
Wgrad_gmm_unsorted = W_3D.grad
# correctness check
assert F.allclose(out_low_mem, out_gmm_sorted, atol=1e-3, rtol=1e-3)
assert F.allclose(Hgrad_low_mem, Hgrad_gmm_sorted, atol=1e-3, rtol=1e-3)
assert F.allclose(Wgrad_low_mem, Wgrad_gmm_sorted, atol=1e-3, rtol=1e-3)
assert F.allclose(out_low_mem, out_gmm_unsorted, atol=1e-3, rtol=1e-3)
assert F.allclose(Hgrad_low_mem, Hgrad_gmm_unsorted, atol=1e-3, rtol=1e-3)
assert F.allclose(Wgrad_low_mem, Wgrad_gmm_unsorted, atol=1e-3, rtol=1e-3)
_test(1)
_test(4)
_test(16)
_test(32)
if __name__ == '__main__':
test_gathermm()
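For context on the "high-mem/low-mem" benchmarks mentioned in the commit message: one memory-heavy baseline that gather_mm is meant to avoid materialises a weight matrix per row before a batched matmul. A hedged standalone sketch (illustrative shapes, plain PyTorch, not part of the test):

import torch

H = torch.randn(18, 16)
W_3D = torch.randn(4, 16, 8)
etypes = torch.randint(0, 4, (18,))
# High-mem baseline: gather a (16, 8) weight for every row, then batch-multiply.
out_high_mem = torch.bmm(H.unsqueeze(1), W_3D[etypes]).squeeze(1)   # (18, 8)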