"docs/source/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "8005978e1eebd1b0dbbafb4dada66f5dd504d7ec"
Unverified Commit 7bab1365, authored by Zihao Ye and committed by GitHub

[feature] Supporting half precision floating data type (fp16). (#2552)



* add tvm as submodule

* compilation is ok but calling fails

* can call now

* pack multiple modules, change names

* upd

* upd

* upd

* fix cmake

* upd

* upd

* upd

* upd

* fix

* relative path

* upd

* upd

* upd

* singleton

* upd

* trigger

* fix

* upd

* count reducible

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* only keep related files

* upd

* upd

* upd

* upd

* lint

* lint

* lint

* lint

* pylint

* upd

* upd

* compilation

* fix

* upd

* upd

* upd

* upd

* upd

* upd

* upd doc

* refactor

* fix

* upd number
Co-authored-by: Zhi Lin <linzhilynn@gmail.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-42-78.us-east-2.compute.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-21-156.us-east-2.compute.internal>
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
parent a7e941c3
@@ -25,6 +25,7 @@ endif()
dgl_option(USE_CUDA "Build with CUDA" OFF)
dgl_option(USE_OPENMP "Build with OpenMP" ON)
dgl_option(USE_AVX "Build with AVX optimization" ON)
dgl_option(USE_FP16 "Build with fp16 support to enable mixed precision training" OFF)
dgl_option(USE_TVM "Build with TVM kernels" OFF)
dgl_option(BUILD_CPP_TEST "Build cpp unittest executables" OFF)
dgl_option(LIBCXX_ENABLE_PARALLEL_ALGORITHMS "Enable the parallel algorithms library. This requires the PSTL to be available." OFF)
@@ -101,13 +102,22 @@ if(USE_OPENMP)
set(CMAKE_C_FLAGS "${OpenMP_C_FLAGS} ${CMAKE_C_FLAGS}")
set(CMAKE_CXX_FLAGS "${OpenMP_CXX_FLAGS} ${CMAKE_CXX_FLAGS}")
endif(OPENMP_FOUND)
message(STATUS "Build with OpenMP.")
endif(USE_OPENMP)
if(USE_AVX)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DUSE_AVX")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_AVX")
message(STATUS "Build with AVX optimization.")
endif(USE_AVX)
# Build with fp16 to support mixed precision training.
if(USE_FP16)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DUSE_FP16")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_FP16")
message(STATUS "Build with fp16 to support mixed precision training")
endif(USE_FP16)
# To compile METIS correct for DGL.
if(MSVC)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /DIDXTYPEWIDTH=64 /DREALTYPEWIDTH=32")
@@ -194,6 +204,7 @@ if(USE_TVM)
target_include_directories(dgl PRIVATE "featgraph/include")
add_subdirectory("featgraph/")
list(APPEND DGL_LINKER_LIBS featgraph_runtime)
message(STATUS "Build with TVM runtime and featgraph kernels.")
endif(USE_TVM)
# support PARALLEL_ALGORITHMS
......
@@ -49,3 +49,5 @@ set(BUILD_TORCH ON)
# Whether to enable CUDA kernels compiled with TVM.
set(USE_TVM OFF)
# Whether to enable fp16 to support mixed precision training.
set(USE_FP16 OFF)
...@@ -136,7 +136,7 @@ function(dgl_select_nvcc_arch_flags out_variable) ...@@ -136,7 +136,7 @@ function(dgl_select_nvcc_arch_flags out_variable)
string(REGEX MATCHALL "[0-9]+" __cuda_arch_ptx "${__cuda_arch_ptx}") string(REGEX MATCHALL "[0-9]+" __cuda_arch_ptx "${__cuda_arch_ptx}")
mshadow_list_unique(__cuda_arch_bin __cuda_arch_ptx) mshadow_list_unique(__cuda_arch_bin __cuda_arch_ptx)
set(__nvcc_flags "") set(__nvcc_flags "--expt-relaxed-constexpr")
set(__nvcc_archs_readable "") set(__nvcc_archs_readable "")
# Tell NVCC to add binaries for the specified GPUs # Tell NVCC to add binaries for the specified GPUs
......
@@ -2,4 +2,5 @@ build
# tutorials are auto-generated
source/tutorials
source/new-tutorial
source/generated
@@ -246,6 +246,21 @@ DGL provide operators to reduce value tensor along the first dimension by segmen
segment_reduce

Supported Data types
--------------------
Operators defined in ``dgl.ops`` support floating point data types, i.e. the operands
must be ``half`` (``float16``)/``float``/``double`` tensors.
All input tensors must have the same data type: if one tensor is ``float16`` and
another is ``float32``, the user must cast one of them so that the types match.
``float16`` support is disabled by default because it requires a minimum GPU
compute capability of ``sm_53`` (satisfied by the Pascal, Volta, Turing and Ampere
architectures).
Users can enable ``float16`` for mixed precision training by compiling DGL from source
(see the :doc:`Mixed Precision Training </guide/mixed_precision>` guide for details).
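For example, to combine ``float16`` node features with ``float32`` edge features, cast one
side first. The following is a minimal sketch, assuming a CUDA build with ``USE_FP16=ON``;
``dgl.ops.u_mul_e`` is used purely for illustration:

>>> import torch
>>> import dgl
>>> import dgl.ops as ops
>>> g = dgl.rand_graph(10, 20).to(0)
>>> x = torch.rand(10, 4, device='cuda').half()   # float16 node features
>>> w = torch.rand(20, 4, device='cuda')          # float32 edge features
>>> y = ops.u_mul_e(g, x, w.half())               # cast w to float16 so the dtypes match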
Relation with Message Passing APIs
----------------------------------
......
@@ -97,3 +97,12 @@ features ``ft``, and finally multiply ``ft`` by 2 to get the result
The math formula for the above function is:
.. math:: {final\_ft}_i = 2 * \sum_{j\in\mathcal{N}(i)} ({ft}_j * a_{ij})

DGL's built-in functions support floating point data types, i.e. the feature must
be a ``half`` (``float16``)/``float``/``double`` tensor.
``float16`` support is disabled by default because it requires a minimum GPU
compute capability of ``sm_53`` (satisfied by the Pascal, Volta, Turing and Ampere
architectures).
Users can enable ``float16`` for mixed precision training by compiling DGL from source
(see the :doc:`Mixed Precision Training <mixed_precision>` guide for details).
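For instance, the example above also works on half-precision features. The following is a
minimal sketch, assuming a CUDA build with ``USE_FP16=ON``; it reuses the graph ``g``,
``dgl.function as fn``, and the feature names from the example above:

>>> g.ndata['ft'] = g.ndata['ft'].half()   # cast node features to float16
>>> g.edata['a'] = g.edata['a'].half()     # cast edge features to float16
>>> g.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 'ft'))
>>> final_ft = g.ndata['ft'] * 2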
.. _guide-mixed_precision:

Chapter 8: Mixed Precision Training
===================================

DGL is compatible with `PyTorch's automatic mixed precision package
<https://pytorch.org/docs/stable/amp.html>`_
for mixed precision training, which saves both training time and GPU memory
consumption. To enable this feature, users need to install PyTorch 1.6+ and
build DGL from source with support for the ``float16`` data type (this feature is
still in its beta stage and we do not provide official pre-built pip wheels).
Installation
------------

First, download DGL's source code from GitHub and build the shared library
with the flag ``USE_FP16=ON``.

.. code:: bash

   git clone --recurse-submodules https://github.com/dmlc/dgl.git
   cd dgl
   mkdir build
   cd build
   cmake -DUSE_CUDA=ON -DUSE_FP16=ON ..
   make -j

Then install the Python bindings.

.. code:: bash

   cd ../python
   python setup.py install
Message-Passing with Half Precision
-----------------------------------

DGL built with fp16 support allows message passing on ``float16`` features with both
UDFs (User Defined Functions) and built-in functions (e.g. ``dgl.function.sum``,
``dgl.function.copy_u``).

The following example shows how to use DGL's message-passing API on half-precision
features:
>>> import torch
>>> import dgl
>>> import dgl.function as fn
>>> g = dgl.rand_graph(30, 100).to(0) # Create a graph on GPU w/ 30 nodes and 100 edges.
>>> g.ndata['h'] = torch.rand(30, 16).to(0).half() # Create fp16 node features.
>>> g.edata['w'] = torch.rand(100, 1).to(0).half() # Create fp16 edge features.
>>> # Use DGL's built-in functions for message passing on fp16 features.
>>> g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'x'))
>>> g.ndata['x'][0]
tensor([0.3391, 0.2208, 0.7163, 0.6655, 0.7031, 0.5854, 0.9404, 0.7720, 0.6562,
0.4028, 0.6943, 0.5908, 0.9307, 0.5962, 0.7827, 0.5034],
device='cuda:0', dtype=torch.float16)
>>> g.apply_edges(fn.u_dot_v('h', 'x', 'hx'))
>>> g.edata['hx'][0]
tensor([5.4570], device='cuda:0', dtype=torch.float16)
>>> # Use UDF(User Defined Functions) for message passing on fp16 features.
>>> def message(edges):
... return {'m': edges.src['h'] * edges.data['w']}
...
>>> def reduce(nodes):
... return {'y': torch.sum(nodes.mailbox['m'], 1)}
...
>>> def dot(edges):
... return {'hy': (edges.src['h'] * edges.dst['y']).sum(-1, keepdim=True)}
...
>>> g.update_all(message, reduce)
>>> g.ndata['y'][0]
tensor([0.3394, 0.2209, 0.7168, 0.6655, 0.7026, 0.5854, 0.9404, 0.7720, 0.6562,
0.4028, 0.6943, 0.5908, 0.9307, 0.5967, 0.7827, 0.5039],
device='cuda:0', dtype=torch.float16)
>>> g.apply_edges(dot)
>>> g.edata['hy'][0]
tensor([5.4609], device='cuda:0', dtype=torch.float16)
End-to-End Mixed Precision Training
-----------------------------------

DGL relies on PyTorch's AMP package for mixed precision training, and the user
experience is exactly the same as
`PyTorch's <https://pytorch.org/docs/stable/notes/amp_examples.html>`_.

By wrapping the forward pass (including loss computation) of your GNN model with
``torch.cuda.amp.autocast()``, PyTorch automatically selects the appropriate data type
for each op and tensor. Half-precision tensors are memory efficient, and most
operators on them are faster because they leverage the GPU's Tensor Cores.

Small gradients in ``float16`` format suffer from underflow (they flush to zero), and
PyTorch provides a ``GradScaler`` module to address this issue: ``GradScaler`` multiplies
the loss by a factor, invokes the backward pass on the scaled loss, and unscales the
gradients before the optimizer updates the parameters, thus preventing underflow.
The scale factor is determined automatically.

The following is the training script of a 3-layer GAT on the Reddit dataset (with 114
million edges); note how the code differs depending on whether ``use_fp16`` is activated:
.. code::
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import dgl
from dgl.data import RedditDataset
from dgl.nn import GATConv
use_fp16 = True
class GAT(nn.Module):
def __init__(self,
in_feats,
n_hidden,
n_classes,
heads):
super().__init__()
self.layers = nn.ModuleList()
self.layers.append(GATConv(in_feats, n_hidden, heads[0], activation=F.elu))
self.layers.append(GATConv(n_hidden * heads[0], n_hidden, heads[1], activation=F.elu))
self.layers.append(GATConv(n_hidden * heads[1], n_classes, heads[2], activation=F.elu))
def forward(self, g, h):
for l, layer in enumerate(self.layers):
h = layer(g, h)
if l != len(self.layers) - 1:
h = h.flatten(1)
else:
h = h.mean(1)
return h
# Data loading
data = RedditDataset()
device = torch.device(0)
g = data[0]
g = dgl.add_self_loop(g)
g = g.int().to(device)
train_mask = g.ndata['train_mask']
features = g.ndata['feat']
labels = g.ndata['label']
in_feats = features.shape[1]
n_hidden = 256
n_classes = data.num_classes
n_edges = g.number_of_edges()
heads = [1, 1, 1]
model = GAT(in_feats, n_hidden, n_classes, heads)
model = model.to(device)
# Create optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)
# Create gradient scaler
scaler = GradScaler()
for epoch in range(100):
model.train()
optimizer.zero_grad()
# Wrap forward pass with autocast
with autocast(enabled=use_fp16):
logits = model(g, features)
loss = F.cross_entropy(logits[train_mask], labels[train_mask])
if use_fp16:
# Backprop w/ gradient scaling
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
else:
loss.backward()
optimizer.step()
print('Epoch {} | Loss {}'.format(epoch, loss.item()))
On an NVIDIA V100 (16GB) machine, training this model without fp16 consumes
15.2GB of GPU memory; with fp16 turned on, training consumes 12.8GB of GPU
memory, and the loss converges to similar values in both settings.
If we change the number of heads to ``[2, 2, 2]``, training without fp16
triggers a GPU out-of-memory (OOM) error, while training with fp16 consumes
15.7GB of GPU memory.

DGL is still improving its half-precision support and the compute kernels'
performance is far from optimal; please stay tuned for our future updates.
@@ -97,6 +97,7 @@ Getting Started
guide/training
guide/minibatch
guide/distributed
guide/mixed_precision
.. toctree::
   :maxdepth: 2
......
...@@ -124,6 +124,23 @@ ...@@ -124,6 +124,23 @@
} \ } \
} while (0) } while (0)
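/*
 * Dispatch according to the bit-width of a float data type (16, 32 or 64):
 * checks that the given data type is a float and binds the selected width to
 * a constexpr int with the name given by the second argument.
 */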
#define ATEN_FLOAT_BITS_SWITCH(val, bits, val_name, ...) do { \
CHECK_EQ((val).code, kDLFloat) \
<< (val_name) << " must be float type"; \
if ((val).bits == 16) { \
constexpr int bits = 16; \
{__VA_ARGS__} \
} else if ((val).bits == 32) { \
constexpr int bits = 32; \
{__VA_ARGS__} \
} else if ((val).bits == 64) { \
constexpr int bits = 64; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << (val_name) << " can only be float16, float32 or float64"; \
} \
} while (0)
/* /*
* Dispatch according to data type (int32, int64, float32 or float64): * Dispatch according to data type (int32, int64, float32 or float64):
* *
......
import torch as th import torch as th
from distutils.version import LooseVersion
from ...base import is_all, ALL from ...base import is_all, ALL
from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp, _reverse from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp, _reverse
if LooseVersion(th.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import custom_fwd, custom_bwd
else:
import functools
"""PyTorch natively supports automatic mixed precision in DGL 1.6, we redefine
the custom_fwd and custom_bwd function to be compatible with DGL 1.5.
"""
def custom_fwd(**kwargs):
def custom_fwd_inner(fwd):
@functools.wraps(fwd)
def decorate_fwd(*args, **kwargs):
return fwd(*args, **kwargs)
return decorate_fwd
return custom_fwd_inner
def custom_bwd(bwd):
@functools.wraps(bwd)
def decorate_bwd(*args, **kwargs):
return bwd(*args, **kwargs)
return decorate_bwd
__all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce'] __all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce']
...@@ -60,6 +82,7 @@ def _expand(x, shape): ...@@ -60,6 +82,7 @@ def _expand(x, shape):
class GSpMM(th.autograd.Function): class GSpMM(th.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, gidx, op, reduce_op, X, Y): def forward(ctx, gidx, op, reduce_op, X, Y):
out, (argX, argY) = _gspmm(gidx, op, reduce_op, X, Y) out, (argX, argY) = _gspmm(gidx, op, reduce_op, X, Y)
ctx.backward_cache = gidx, op, reduce_op ctx.backward_cache = gidx, op, reduce_op
...@@ -67,6 +90,7 @@ class GSpMM(th.autograd.Function): ...@@ -67,6 +90,7 @@ class GSpMM(th.autograd.Function):
return out return out
@staticmethod @staticmethod
@custom_bwd
def backward(ctx, dZ): def backward(ctx, dZ):
gidx, op, reduce_op = ctx.backward_cache gidx, op, reduce_op = ctx.backward_cache
X, Y, argX, argY = ctx.saved_tensors X, Y, argX, argY = ctx.saved_tensors
...@@ -120,6 +144,7 @@ class GSpMM(th.autograd.Function): ...@@ -120,6 +144,7 @@ class GSpMM(th.autograd.Function):
class GSDDMM(th.autograd.Function): class GSDDMM(th.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, gidx, op, X, Y, lhs_target, rhs_target): def forward(ctx, gidx, op, X, Y, lhs_target, rhs_target):
out = _gsddmm(gidx, op, X, Y, lhs_target, rhs_target) out = _gsddmm(gidx, op, X, Y, lhs_target, rhs_target)
ctx.backward_cache = gidx, op, lhs_target, rhs_target ctx.backward_cache = gidx, op, lhs_target, rhs_target
...@@ -127,6 +152,7 @@ class GSDDMM(th.autograd.Function): ...@@ -127,6 +152,7 @@ class GSDDMM(th.autograd.Function):
return out return out
@staticmethod @staticmethod
@custom_bwd
def backward(ctx, dZ): def backward(ctx, dZ):
gidx, op, lhs_target, rhs_target = ctx.backward_cache gidx, op, lhs_target, rhs_target = ctx.backward_cache
X, Y = ctx.saved_tensors X, Y = ctx.saved_tensors
...@@ -179,6 +205,7 @@ class GSDDMM(th.autograd.Function): ...@@ -179,6 +205,7 @@ class GSDDMM(th.autograd.Function):
class EdgeSoftmax(th.autograd.Function): class EdgeSoftmax(th.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, gidx, score, eids, norm_by): def forward(ctx, gidx, score, eids, norm_by):
"""Forward function. """Forward function.
...@@ -208,6 +235,7 @@ class EdgeSoftmax(th.autograd.Function): ...@@ -208,6 +235,7 @@ class EdgeSoftmax(th.autograd.Function):
return out return out
@staticmethod @staticmethod
@custom_bwd
def backward(ctx, grad_out): def backward(ctx, grad_out):
"""Backward function. """Backward function.
...@@ -233,6 +261,7 @@ class EdgeSoftmax(th.autograd.Function): ...@@ -233,6 +261,7 @@ class EdgeSoftmax(th.autograd.Function):
class SegmentReduce(th.autograd.Function): class SegmentReduce(th.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, op, x, offsets): def forward(ctx, op, x, offsets):
y, arg = _segment_reduce(op, x, offsets) y, arg = _segment_reduce(op, x, offsets)
ctx.save_for_backward(arg, offsets) ctx.save_for_backward(arg, offsets)
...@@ -240,6 +269,7 @@ class SegmentReduce(th.autograd.Function): ...@@ -240,6 +269,7 @@ class SegmentReduce(th.autograd.Function):
return y return y
@staticmethod @staticmethod
@custom_bwd
def backward(ctx, dy): def backward(ctx, dy):
op = ctx.backward_cache op = ctx.backward_cache
arg, offsets = ctx.saved_tensors arg, offsets = ctx.saved_tensors
......
...@@ -159,6 +159,7 @@ def _gspmm(gidx, op, reduce_op, u, e): ...@@ -159,6 +159,7 @@ def _gspmm(gidx, op, reduce_op, u, e):
if F.ndim(e) == 1: if F.ndim(e) == 1:
e = F.unsqueeze(e, -1) e = F.unsqueeze(e, -1)
expand_e = True expand_e = True
ctx = F.context(u) if use_u else F.context(e) ctx = F.context(u) if use_u else F.context(e)
dtype = F.dtype(u) if use_u else F.dtype(e) dtype = F.dtype(u) if use_u else F.dtype(e)
u_shp = F.shape(u) if use_u else (0,) u_shp = F.shape(u) if use_u else (0,)
......
...@@ -41,9 +41,22 @@ namespace aten { ...@@ -41,9 +41,22 @@ namespace aten {
} \ } \
} while (0) } while (0)
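/*! \brief Dispatch by float bit-width: 16- and 32-bit inputs are handled with float, 64-bit with double. */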
#define SWITCH_BITS(bits, DType, ...) \
do { \
if ((bits) == 16 || (bits) == 32) { \
typedef float DType; \
{ __VA_ARGS__ } \
} else if ((bits) == 64) { \
typedef double DType; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << "Data type not renogized with bits " << bits; \
} \
} while (0)
/*! \brief Generalized SDDMM on Csr format. */ /*! \brief Generalized SDDMM on Csr format. */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, int bits>
void SDDMMCsr(const std::string& op, void SDDMMCsr(const std::string& op,
const BcastOff& bcast, const BcastOff& bcast,
const CSRMatrix& csr, const CSRMatrix& csr,
...@@ -52,32 +65,43 @@ void SDDMMCsr(const std::string& op, ...@@ -52,32 +65,43 @@ void SDDMMCsr(const std::string& op,
NDArray out, NDArray out,
int lhs_target, int lhs_target,
int rhs_target) { int rhs_target) {
SWITCH_OP(op, Op, { SWITCH_BITS(bits, DType, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, { SWITCH_OP(op, Op, {
cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, csr, lhs, rhs, out); SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, csr, lhs, rhs, out);
});
}); });
}); });
} }
template void SDDMMCsr<kDLCPU, int32_t, float>( template void SDDMMCsr<kDLCPU, int32_t, 16>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDLCPU, int64_t, 16>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDLCPU, int32_t, 32>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCsr<kDLCPU, int64_t, float>( template void SDDMMCsr<kDLCPU, int64_t, 32>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCsr<kDLCPU, int32_t, double>( template void SDDMMCsr<kDLCPU, int32_t, 64>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCsr<kDLCPU, int64_t, double>( template void SDDMMCsr<kDLCPU, int64_t, 64>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
/*! \brief Generalized SDDMM on Coo format. */ /*! \brief Generalized SDDMM on Coo format. */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, int bits>
void SDDMMCoo(const std::string& op, void SDDMMCoo(const std::string& op,
const BcastOff& bcast, const BcastOff& bcast,
const COOMatrix& coo, const COOMatrix& coo,
...@@ -86,29 +110,40 @@ void SDDMMCoo(const std::string& op, ...@@ -86,29 +110,40 @@ void SDDMMCoo(const std::string& op,
NDArray out, NDArray out,
int lhs_target, int lhs_target,
int rhs_target) { int rhs_target) {
SWITCH_OP(op, Op, { SWITCH_BITS(bits, DType, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, { SWITCH_OP(op, Op, {
cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, coo, lhs, rhs, out); SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, coo, lhs, rhs, out);
});
}); });
}); });
} }
template void SDDMMCoo<kDLCPU, int32_t, float>( template void SDDMMCoo<kDLCPU, int32_t, 16>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDLCPU, int64_t, 16>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCoo<kDLCPU, int64_t, float>( template void SDDMMCoo<kDLCPU, int32_t, 32>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCoo<kDLCPU, int32_t, double>( template void SDDMMCoo<kDLCPU, int64_t, 32>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCoo<kDLCPU, int64_t, double>( template void SDDMMCoo<kDLCPU, int32_t, 64>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCoo<kDLCPU, int64_t, 64>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
...@@ -12,7 +12,7 @@ namespace dgl { ...@@ -12,7 +12,7 @@ namespace dgl {
namespace aten { namespace aten {
/*! \brief Segment Reduce operator. */ /*! \brief Segment Reduce operator. */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, int bits>
void SegmentReduce( void SegmentReduce(
const std::string& op, const std::string& op,
NDArray feat, NDArray feat,
...@@ -20,65 +20,94 @@ void SegmentReduce( ...@@ -20,65 +20,94 @@ void SegmentReduce(
NDArray out, NDArray out,
NDArray arg) { NDArray arg) {
if (op == "sum") { if (op == "sum") {
cpu::SegmentSum<IdType, DType>(feat, offsets, out); SWITCH_BITS(bits, DType, {
cpu::SegmentSum<IdType, DType>(feat, offsets, out);
});
} else if (op == "max" || op == "min") { } else if (op == "max" || op == "min") {
if (op == "max") if (op == "max") {
cpu::SegmentCmp<IdType, DType, cpu::op::Max<DType>>( SWITCH_BITS(bits, DType, {
feat, offsets, out, arg); cpu::SegmentCmp<IdType, DType, cpu::op::Max<DType>>(
else feat, offsets, out, arg);
cpu::SegmentCmp<IdType, DType, cpu::op::Min<DType>>( });
feat, offsets, out, arg); } else {
SWITCH_BITS(bits, DType, {
cpu::SegmentCmp<IdType, DType, cpu::op::Min<DType>>(
feat, offsets, out, arg);
});
}
} else { } else {
LOG(FATAL) << "Unsupported reduce function " << op; LOG(FATAL) << "Unsupported reduce function " << op;
} }
} }
/*! \brief Backward function of segment cmp.*/ /*! \brief Backward function of segment cmp.*/
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, int bits>
void BackwardSegmentCmp( void BackwardSegmentCmp(
NDArray feat, NDArray feat,
NDArray arg, NDArray arg,
NDArray out) { NDArray out) {
cpu::BackwardSegmentCmp<IdType, DType>(feat, arg, out); SWITCH_BITS(bits, DType, {
cpu::BackwardSegmentCmp<IdType, DType>(feat, arg, out);
});
} }
template void SegmentReduce<kDLCPU, int32_t, float>( template void SegmentReduce<kDLCPU, int32_t, 16>(
const std::string &op, const std::string &op,
NDArray feat, NDArray feat,
NDArray offsets, NDArray offsets,
NDArray out, NDArray out,
NDArray arg); NDArray arg);
template void SegmentReduce<kDLCPU, int64_t, float>( template void SegmentReduce<kDLCPU, int64_t, 16>(
const std::string &op, const std::string &op,
NDArray feat, NDArray feat,
NDArray offsets, NDArray offsets,
NDArray out, NDArray out,
NDArray arg); NDArray arg);
template void SegmentReduce<kDLCPU, int32_t, double>( template void SegmentReduce<kDLCPU, int32_t, 32>(
const std::string &op, const std::string &op,
NDArray feat, NDArray feat,
NDArray offsets, NDArray offsets,
NDArray out, NDArray out,
NDArray arg); NDArray arg);
template void SegmentReduce<kDLCPU, int64_t, double>( template void SegmentReduce<kDLCPU, int64_t, 32>(
const std::string &op, const std::string &op,
NDArray feat, NDArray feat,
NDArray offsets, NDArray offsets,
NDArray out, NDArray out,
NDArray arg); NDArray arg);
template void BackwardSegmentCmp<kDLCPU, int32_t, float>( template void SegmentReduce<kDLCPU, int32_t, 64>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDLCPU, int64_t, 64>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void BackwardSegmentCmp<kDLCPU, int32_t, 16>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDLCPU, int64_t, 16>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDLCPU, int32_t, 32>(
NDArray feat, NDArray feat,
NDArray arg, NDArray arg,
NDArray out); NDArray out);
template void BackwardSegmentCmp<kDLCPU, int64_t, float>( template void BackwardSegmentCmp<kDLCPU, int64_t, 32>(
NDArray feat, NDArray feat,
NDArray arg, NDArray arg,
NDArray out); NDArray out);
template void BackwardSegmentCmp<kDLCPU, int32_t, double>( template void BackwardSegmentCmp<kDLCPU, int32_t, 64>(
NDArray feat, NDArray feat,
NDArray arg, NDArray arg,
NDArray out); NDArray out);
template void BackwardSegmentCmp<kDLCPU, int64_t, double>( template void BackwardSegmentCmp<kDLCPU, int64_t, 64>(
NDArray feat, NDArray feat,
NDArray arg, NDArray arg,
NDArray out); NDArray out);
......
...@@ -10,7 +10,7 @@ namespace dgl { ...@@ -10,7 +10,7 @@ namespace dgl {
namespace aten { namespace aten {
/*! \brief Generalized SpMM on Csr format. */ /*! \brief Generalized SpMM on Csr format. */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, int bits>
void SpMMCsr(const std::string& op, const std::string& reduce, void SpMMCsr(const std::string& op, const std::string& reduce,
const BcastOff& bcast, const BcastOff& bcast,
const CSRMatrix& csr, const CSRMatrix& csr,
...@@ -19,42 +19,55 @@ void SpMMCsr(const std::string& op, const std::string& reduce, ...@@ -19,42 +19,55 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
NDArray out, NDArray out,
std::vector<NDArray> out_aux) { std::vector<NDArray> out_aux) {
if (reduce == "sum") { if (reduce == "sum") {
SWITCH_OP(op, Op, { SWITCH_BITS(bits, DType, {
cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out); SWITCH_OP(op, Op, {
cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
});
}); });
} else if (reduce == "max" || reduce == "min") { } else if (reduce == "max" || reduce == "min") {
SWITCH_OP(op, Op, { SWITCH_BITS(bits, DType, {
if (reduce == "max") SWITCH_OP(op, Op, {
cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Max<DType>>( if (reduce == "max")
bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]); cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Max<DType>>(
else bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Min<DType>>( else
bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]); cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Min<DType>>(
bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
});
}); });
} else { } else {
LOG(FATAL) << "Unsupported SpMM reducer: " << reduce; LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
} }
} }
template void SpMMCsr<kDLCPU, int32_t, float>( template void SpMMCsr<kDLCPU, int32_t, 16>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int64_t, float>( template void SpMMCsr<kDLCPU, int64_t, 16>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int32_t, double>( template void SpMMCsr<kDLCPU, int32_t, 32>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int64_t, double>( template void SpMMCsr<kDLCPU, int64_t, 32>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int32_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int64_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
/*! \brief Generalized SpMM on Coo format. */ /*! \brief Generalized SpMM on Coo format. */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, int bits>
void SpMMCoo(const std::string& op, const std::string& reduce, void SpMMCoo(const std::string& op, const std::string& reduce,
const BcastOff& bcast, const BcastOff& bcast,
const COOMatrix& coo, const COOMatrix& coo,
...@@ -63,39 +76,52 @@ void SpMMCoo(const std::string& op, const std::string& reduce, ...@@ -63,39 +76,52 @@ void SpMMCoo(const std::string& op, const std::string& reduce,
NDArray out, NDArray out,
std::vector<NDArray> out_aux) { std::vector<NDArray> out_aux) {
if (reduce == "sum") { if (reduce == "sum") {
SWITCH_OP(op, Op, { SWITCH_BITS(bits, DType, {
cpu::SpMMSumCoo<IdType, DType, Op>(bcast, coo, ufeat, efeat, out); SWITCH_OP(op, Op, {
cpu::SpMMSumCoo<IdType, DType, Op>(bcast, coo, ufeat, efeat, out);
});
}); });
} else if (reduce == "max" || reduce == "min") { } else if (reduce == "max" || reduce == "min") {
SWITCH_OP(op, Op, { SWITCH_BITS(bits, DType, {
if (reduce == "max") SWITCH_OP(op, Op, {
cpu::SpMMCmpCoo<IdType, DType, Op, cpu::op::Max<DType>>( if (reduce == "max")
bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]); cpu::SpMMCmpCoo<IdType, DType, Op, cpu::op::Max<DType>>(
else bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
cpu::SpMMCmpCoo<IdType, DType, Op, cpu::op::Min<DType>>( else
bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]); cpu::SpMMCmpCoo<IdType, DType, Op, cpu::op::Min<DType>>(
bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
});
}); });
} else { } else {
LOG(FATAL) << "Unsupported SpMM reducer: " << reduce; LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
} }
} }
template void SpMMCoo<kDLCPU, int32_t, float>( template void SpMMCoo<kDLCPU, int32_t, 16>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo, const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int64_t, float>( template void SpMMCoo<kDLCPU, int64_t, 16>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo, const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int32_t, double>( template void SpMMCoo<kDLCPU, int32_t, 32>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo, const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int64_t, double>( template void SpMMCoo<kDLCPU, int64_t, 32>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo, const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int32_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int64_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
...@@ -146,6 +146,19 @@ constexpr DType Min<DType>::zero; ...@@ -146,6 +146,19 @@ constexpr DType Min<DType>::zero;
} \ } \
} while (0) } while (0)
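/*! \brief Dispatch by float bit-width: 16- and 32-bit inputs are handled with float, 64-bit with double. */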
#define SWITCH_BITS(bits, DType, ...) \
do { \
if ((bits) == 16 || (bits) == 32) { \
typedef float DType; \
{ __VA_ARGS__ } \
} else if ((bits) == 64) { \
typedef double DType; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << "Data type not renogized with bits " << bits; \
} \
} while (0)
} // namespace op } // namespace op
} // namespace cpu } // namespace cpu
......
...@@ -7,9 +7,9 @@ ...@@ -7,9 +7,9 @@
#define DGL_ARRAY_CUDA_ATOMIC_H_ #define DGL_ARRAY_CUDA_ATOMIC_H_
#include <cuda_runtime.h> #include <cuda_runtime.h>
#if __CUDA_ARCH__ >= 600 #include <cassert>
#include <cuda_fp16.h> #include "fp16.cuh"
#endif
namespace dgl { namespace dgl {
namespace aten { namespace aten {
...@@ -18,6 +18,10 @@ namespace cuda { ...@@ -18,6 +18,10 @@ namespace cuda {
// Type trait for selecting code type // Type trait for selecting code type
template <int Bytes> struct Code { }; template <int Bytes> struct Code { };
template <> struct Code<2> {
typedef unsigned short int Type;
};
template <> struct Code<4> { template <> struct Code<4> {
typedef unsigned int Type; typedef unsigned int Type;
}; };
...@@ -37,6 +41,18 @@ template <typename T> struct Cast { ...@@ -37,6 +41,18 @@ template <typename T> struct Cast {
} }
}; };
#ifdef USE_FP16
template <> struct Cast<half> {
typedef Code<sizeof(half)>::Type Type;
static __device__ __forceinline__ Type Encode(half val) {
return __half_as_ushort(val);
}
static __device__ __forceinline__ half Decode(Type code) {
return __ushort_as_half(code);
}
};
#endif
template <> struct Cast<float> { template <> struct Cast<float> {
typedef Code<sizeof(float)>::Type Type; typedef Code<sizeof(float)>::Type Type;
static __device__ __forceinline__ Type Encode(float val) { static __device__ __forceinline__ Type Encode(float val) {
...@@ -57,6 +73,18 @@ template <> struct Cast<double> { ...@@ -57,6 +73,18 @@ template <> struct Cast<double> {
} }
}; };
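// 16-bit compare-and-swap helper: uses the native atomicCAS on unsigned short
// (available when compiled for sm_70+ with CUDA > 10.0); on older targets it
// is a no-op that simply returns val.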
static __device__ __forceinline__ unsigned short int atomicCASshort(
unsigned short int *address,
unsigned short int compare,
unsigned short int val) {
#if (defined(CUDART_VERSION) && (CUDART_VERSION > 10000))
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__) >= 700)
return atomicCAS(address, compare, val);
#endif // (defined(__CUDA_ARCH__) && (__CUDA_ARCH__) >= 700)
#endif // (defined(CUDART_VERSION) && (CUDART_VERSION > 10000))
return val;
}
#define DEFINE_ATOMIC(NAME) \ #define DEFINE_ATOMIC(NAME) \
template <typename T> \ template <typename T> \
__device__ __forceinline__ T Atomic##NAME(T* addr, T val) { \ __device__ __forceinline__ T Atomic##NAME(T* addr, T val) { \
...@@ -72,51 +100,70 @@ template <> struct Cast<double> { ...@@ -72,51 +100,70 @@ template <> struct Cast<double> {
return Cast<T>::Decode(old); \ return Cast<T>::Decode(old); \
} }
#define DEFINE_ATOMIC_HALF(NAME) \
template <> \
__device__ __forceinline__ half Atomic##NAME<half>(half* addr, half val) { \
typedef unsigned short int CT; \
CT* addr_as_ui = reinterpret_cast<CT*>(addr); \
CT old = *addr_as_ui; \
CT assumed = old; \
do { \
assumed = old; \
old = atomicCASshort(addr_as_ui, assumed, \
Cast<half>::Encode(OP(val, Cast<half>::Decode(old)))); \
} while (assumed != old); \
return Cast<half>::Decode(old); \
}
#define OP(a, b) max(a, b) #define OP(a, b) max(a, b)
DEFINE_ATOMIC(Max) DEFINE_ATOMIC(Max)
#ifdef USE_FP16
DEFINE_ATOMIC_HALF(Max)
#endif // USE_FP16
#undef OP #undef OP
#define OP(a, b) min(a, b) #define OP(a, b) min(a, b)
DEFINE_ATOMIC(Min) DEFINE_ATOMIC(Min)
#ifdef USE_FP16
DEFINE_ATOMIC_HALF(Min)
#endif // USE_FP16
#undef OP #undef OP
#define OP(a, b) a + b #define OP(a, b) a + b
DEFINE_ATOMIC(Add) DEFINE_ATOMIC(Add)
#undef OP #undef OP
#if __CUDA_ARCH__ >= 200
template <> template <>
__device__ __forceinline__ float AtomicAdd<float>(float* addr, float val) { __device__ __forceinline__ float AtomicAdd<float>(float* addr, float val) {
#if __CUDA_ARCH__ >= 200
return atomicAdd(addr, val); return atomicAdd(addr, val);
} #else
return *addr + val;
#endif // __CUDA_ARCH__ #endif // __CUDA_ARCH__
}
#if __CUDA_ARCH__ >= 600
template <> template <>
__device__ __forceinline__ double AtomicAdd<double>(double* addr, double val) { __device__ __forceinline__ double AtomicAdd<double>(double* addr, double val) {
#if __CUDA_ARCH__ >= 600
return atomicAdd(addr, val); return atomicAdd(addr, val);
} #else
return *addr + val;
#endif #endif
}
#ifdef USE_FP16
#if defined(CUDART_VERSION) && CUDART_VERSION >= 10000 #if defined(CUDART_VERSION) && CUDART_VERSION >= 10000
#if __CUDA_ARCH__ >= 600
template <> template <>
__device__ __forceinline__ __half2 AtomicAdd<__half2>(__half2* addr, __half2 val) { __device__ __forceinline__ half AtomicAdd<half>(half* addr, half val) {
return atomicAdd(addr, val);
}
#endif // __CUDA_ARCH__
#if __CUDA_ARCH__ >= 700 #if __CUDA_ARCH__ >= 700
template <>
__device__ __forceinline__ __half AtomicAdd<__half>(__half* addr, __half val) {
return atomicAdd(addr, val); return atomicAdd(addr, val);
} #else
return *addr + val;
#endif // __CUDA_ARCH__ #endif // __CUDA_ARCH__
#endif }
#endif // defined(CUDART_VERSION) && CUDART_VERSION >= 10000
#endif // USE_FP16
#define OP(a, b) a * b
DEFINE_ATOMIC(Mul)
#undef OP
} // namespace cuda } // namespace cuda
} // namespace aten } // namespace aten
......
/*!
* Copyright (c) 2020 by Contributors
* \file array/cuda/fp16.cuh
* \brief float16 related functions.
* \note this file is modified from TVM project:
* https://github.com/apache/tvm/blob/e561007f0c330e3d14c2bc8a3ef40fb741db9004/src/target/source/literal/cuda_half_t.h.
*/
#ifndef DGL_ARRAY_FP16_CUH_
#define DGL_ARRAY_FP16_CUH_
#ifdef USE_FP16
#include <cuda_fp16.h>
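// Half-precision max/min helpers. The half comparison intrinsics require
// sm_53 or newer; on older targets (and in the host compilation pass) these
// stubs simply return the first argument so that compilation still succeeds.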
static __device__ __forceinline__ half max(half a, half b)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return __hgt(__half(a), __half(b)) ? a : b;
#else
return a;
#endif
}
static __device__ __forceinline__ half min(half a, half b)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return __hlt(__half(a), __half(b)) ? a : b;
#else
return a;
#endif
}
#endif // USE_FP16
#endif // DGL_ARRAY_FP16_CUH_
...@@ -7,6 +7,8 @@ ...@@ -7,6 +7,8 @@
#define DGL_ARRAY_CUDA_FUNCTOR_CUH_ #define DGL_ARRAY_CUDA_FUNCTOR_CUH_
#include "./atomic.cuh" #include "./atomic.cuh"
#include "./fp16.cuh"
#include <cmath>
namespace dgl { namespace dgl {
namespace aten { namespace aten {
...@@ -122,9 +124,11 @@ template <typename DType> constexpr bool Dot<DType>::reduce_last_dim; ...@@ -122,9 +124,11 @@ template <typename DType> constexpr bool Dot<DType>::reduce_last_dim;
namespace reduce { namespace reduce {
template <typename Idx, template <typename Idx,
typename DType, typename DType,
bool atomic=false> bool atomic>
struct Sum { struct _Sum {
static constexpr DType zero = 0; static constexpr __host__ __device__ __forceinline__ DType zero() {
return 0.;
};
static constexpr bool require_arg = false; static constexpr bool require_arg = false;
static __device__ __forceinline__ void Call( static __device__ __forceinline__ void Call(
DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf, DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
...@@ -148,16 +152,28 @@ struct Sum { ...@@ -148,16 +152,28 @@ struct Sum {
Idx *arg_u_buf, Idx *arg_e_buf, Idx *arg_u_buf, Idx *arg_e_buf,
DType val, DType val_ref, Idx uid, Idx eid) {} DType val, DType val_ref, Idx uid, Idx eid) {}
}; };
template <typename Idx, typename DType, bool atomic>
constexpr DType Sum<Idx, DType, atomic>::zero;
template <typename Idx, typename DType, bool atomic>
constexpr bool Sum<Idx, DType, atomic>::require_arg;
template <typename Idx, template <typename Idx,
typename DType, typename DType,
bool atomic=false> bool atomic=false>
struct Max { struct Sum: _Sum<Idx, DType, atomic> { };
static constexpr DType zero = -std::numeric_limits<DType>::infinity();
#ifdef USE_FP16
template <typename Idx, bool atomic>
struct Sum<Idx, half, atomic>: _Sum<Idx, half, atomic> {
static constexpr __host__ __device__ __forceinline__ half zero() {
return __float2half_rn(0.);
};
};
#endif // USE_FP16
template <typename Idx,
typename DType,
bool atomic>
struct _Max {
static constexpr __host__ __device__ __forceinline__ DType zero() {
return -std::numeric_limits<DType>::infinity();
};
static constexpr bool require_arg = true; static constexpr bool require_arg = true;
static __device__ __forceinline__ void Call( static __device__ __forceinline__ void Call(
DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf, DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
...@@ -197,16 +213,29 @@ struct Max { ...@@ -197,16 +213,29 @@ struct Max {
} }
} }
}; };
template <typename Idx, typename DType, bool atomic>
constexpr DType Max<Idx, DType, atomic>::zero;
template <typename Idx, typename DType, bool atomic>
constexpr bool Max<Idx, DType, atomic>::require_arg;
template <typename Idx, template <typename Idx,
typename DType, typename DType,
bool atomic=false> bool atomic=false>
struct Min { struct Max : _Max<Idx, DType, atomic> { };
static constexpr DType zero = std::numeric_limits<DType>::infinity();
#ifdef USE_FP16
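// For half, -65504 / +65504 (the largest finite fp16 magnitudes) stand in for
// the -infinity / +infinity identity elements used by the float and double reductions.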
template <typename Idx,
bool atomic>
struct Max<Idx, half, atomic> : _Max<Idx, half, atomic> {
static constexpr __host__ __device__ __forceinline__ half zero() {
return __float2half_rn(-6.550400e+04f);
};
};
#endif
template <typename Idx,
typename DType,
bool atomic>
struct _Min {
static constexpr __host__ __device__ __forceinline__ DType zero() {
return std::numeric_limits<DType>::infinity();
};
static constexpr bool require_arg = true; static constexpr bool require_arg = true;
static __device__ __forceinline__ void Call( static __device__ __forceinline__ void Call(
DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf, DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
...@@ -246,10 +275,21 @@ struct Min { ...@@ -246,10 +275,21 @@ struct Min {
} }
} }
}; };
template <typename Idx, typename DType, bool atomic>
constexpr DType Min<Idx, DType, atomic>::zero; template <typename Idx,
template <typename Idx, typename DType, bool atomic> typename DType,
constexpr bool Min<Idx, DType, atomic>::require_arg; bool atomic=false>
struct Min : _Min<Idx, DType, atomic> { };
#ifdef USE_FP16
template <typename Idx,
bool atomic>
struct Min<Idx, half, atomic> : _Min<Idx, half, atomic> {
static constexpr __host__ __device__ __forceinline__ half zero() {
return __float2half_rn(6.550400e+04f);
};
};
#endif // USE_FP16
} // namespace reduce } // namespace reduce
......
...@@ -73,7 +73,7 @@ namespace aten { ...@@ -73,7 +73,7 @@ namespace aten {
/*! /*!
* \brief CUDA implementation of g-SDDMM on Csr format. * \brief CUDA implementation of g-SDDMM on Csr format.
*/ */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, int bits>
void SDDMMCsr(const std::string& op, void SDDMMCsr(const std::string& op,
const BcastOff& bcast, const BcastOff& bcast,
const CSRMatrix& csr, const CSRMatrix& csr,
...@@ -82,9 +82,11 @@ void SDDMMCsr(const std::string& op, ...@@ -82,9 +82,11 @@ void SDDMMCsr(const std::string& op,
NDArray out, NDArray out,
int lhs_target, int lhs_target,
int rhs_target) { int rhs_target) {
SWITCH_OP(op, Op, { SWITCH_BITS(bits, DType, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, { SWITCH_OP(op, Op, {
cuda::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, csr, lhs, rhs, out); SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
cuda::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, csr, lhs, rhs, out);
});
}); });
}); });
} }
...@@ -92,7 +94,7 @@ void SDDMMCsr(const std::string& op, ...@@ -92,7 +94,7 @@ void SDDMMCsr(const std::string& op,
/*! /*!
* \brief CUDA implementation of g-SDDMM on Coo format. * \brief CUDA implementation of g-SDDMM on Coo format.
*/ */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, int bits>
void SDDMMCoo(const std::string& op, void SDDMMCoo(const std::string& op,
const BcastOff& bcast, const BcastOff& bcast,
const COOMatrix& coo, const COOMatrix& coo,
...@@ -101,43 +103,61 @@ void SDDMMCoo(const std::string& op, ...@@ -101,43 +103,61 @@ void SDDMMCoo(const std::string& op,
NDArray out, NDArray out,
int lhs_target, int lhs_target,
int rhs_target) { int rhs_target) {
SWITCH_OP(op, Op, { SWITCH_BITS(bits, DType, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, { SWITCH_OP(op, Op, {
cuda::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, coo, lhs, rhs, out); SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
cuda::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, coo, lhs, rhs, out);
});
}); });
}); });
} }
template void SDDMMCsr<kDLGPU, int32_t, float>( template void SDDMMCsr<kDLGPU, int32_t, 16>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCsr<kDLGPU, int64_t, float>( template void SDDMMCsr<kDLGPU, int64_t, 16>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCsr<kDLGPU, int32_t, double>( template void SDDMMCsr<kDLGPU, int32_t, 32>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCsr<kDLGPU, int64_t, double>( template void SDDMMCsr<kDLGPU, int64_t, 32>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDLGPU, int32_t, 64>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDLGPU, int64_t, 64>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCoo<kDLGPU, int32_t, float>( template void SDDMMCoo<kDLGPU, int32_t, 16>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDLGPU, int64_t, 16>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDLGPU, int32_t, 32>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCoo<kDLGPU, int64_t, float>( template void SDDMMCoo<kDLGPU, int64_t, 32>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCoo<kDLGPU, int32_t, double>( template void SDDMMCoo<kDLGPU, int32_t, 64>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCoo<kDLGPU, int64_t, double>( template void SDDMMCoo<kDLGPU, int64_t, 64>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include "macro.cuh" #include "macro.cuh"
#include "atomic.cuh" #include "atomic.cuh"
#include "functor.cuh" #include "functor.cuh"
#include "fp16.cuh"
#include "./utils.h" #include "./utils.h"
#include "../selector.h" #include "../selector.h"
#include "../../runtime/cuda/cuda_common.h" #include "../../runtime/cuda/cuda_common.h"
...@@ -105,7 +106,7 @@ __global__ void SDDMMCooTreeReduceKernel( ...@@ -105,7 +106,7 @@ __global__ void SDDMMCooTreeReduceKernel(
for (int i = blockIdx.y; i < out_len; i += gridDim.y) { // over output feature dimension for (int i = blockIdx.y; i < out_len; i += gridDim.y) { // over output feature dimension
const Idx lhs_add = UseBcast ? __ldg(lhs_off + i) : i; const Idx lhs_add = UseBcast ? __ldg(lhs_off + i) : i;
const Idx rhs_add = UseBcast ? __ldg(rhs_off + i) : i; const Idx rhs_add = UseBcast ? __ldg(rhs_off + i) : i;
DType val = 0.; DType val = reduce::Sum<Idx, DType>::zero();;
for (int j = tx; j < reduce_size; j += 64) { for (int j = tx; j < reduce_size; j += 64) {
val += lhsoff[lhs_add * reduce_size + j] * rhsoff[rhs_add * reduce_size + j]; val += lhsoff[lhs_add * reduce_size + j] * rhsoff[rhs_add * reduce_size + j];
if (j + 32 < reduce_size) if (j + 32 < reduce_size)
......