Unverified Commit e70138bb, authored by nv-dlasalle and committed by GitHub

[Performance][GPU] Enable GPU uniform edge sampling (#2716)



* Start on uniform GPU sampling

* Save more work

* Get cu file compiling

* Update sampling

* More changes

* Get GPU sampling for uniform probabilities solved

* Fix batch tensor migration

* Fix

* update kernels

* expand blocking

* Undo testing change

* Cut down on sampling overhead

* Fix replacement

* Update unit tests

* Add option to gpu sample in graphsage

* Copy only csc to gpu

* Add ogbn support

* Fix linting

* Remove nvtx from sample

* Improve documentation and error checking

* Expand documentation

* Update assert checking

* delete extra space

* Use standard dataloader when dataset is a dictionary

* ogb -> ogbn

* Fix edge selection determinism

* Fix typos

* Remove nvtx

* Add comment for self.fanout_arrays and assert

* Fix linting

* Migrate to scalarbatcher

* Fix indentation

* Fix batcher

* Fix indexing

* Only use databatcher for GPU

* Convert DGL NDArray to PyTorch Tensor

* Add optimization for PyTorch's F.tensor() for list of GPU tensors
Co-authored-by: Da Zheng <zhengda1936@gmail.com>
parent 195f9936
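The overall usage this commit enables is sketched below: restrict the graph to its CSC format, move it and the seed node IDs to the GPU, and pass the device to `NodeDataLoader` so that `MultiLayerNeighborSampler` draws uniform neighbor samples on the device. This is a minimal, hypothetical sketch; the toy graph, fan-outs, and batch size are placeholders and not part of the change.

```python
import torch as th
import dgl

# toy graph: 1000 nodes with random edges (placeholder data)
src = th.randint(0, 1000, (5000,))
dst = th.randint(0, 1000, (5000,))
g = dgl.graph((src, dst), num_nodes=1000)
train_nid = th.arange(1000)

device = th.device('cuda' if th.cuda.is_available() else 'cpu')

# keep only the CSC format and move graph + seed nodes to the sampling device
g = g.formats(['csc']).to(device)
train_nid = train_nid.to(device)

sampler = dgl.dataloading.MultiLayerNeighborSampler([10, 25])
dataloader = dgl.dataloading.NodeDataLoader(
    g, train_nid, sampler,
    device=device,      # blocks are sampled and returned on this device
    batch_size=1024,
    shuffle=True,
    drop_last=False,
    num_workers=0)      # GPU-side sampling requires zero worker processes

for input_nodes, output_nodes, blocks in dataloader:
    pass  # model forward/backward would go here
```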
@@ -9,7 +9,7 @@ import time
import argparse
import tqdm
from load_graph import load_reddit, inductive_split, load_ogb
class SAGE(nn.Module):
def __init__(self,
@@ -122,6 +122,14 @@ def run(args, device, data):
val_nid = th.nonzero(val_g.ndata['val_mask'], as_tuple=True)[0]
test_nid = th.nonzero(~(test_g.ndata['train_mask'] | test_g.ndata['val_mask']), as_tuple=True)[0]
dataloader_device = th.device('cpu')
if args.sample_gpu:
train_nid = train_nid.to(device)
# copy only the csc to the GPU
train_g = train_g.formats(['csc'])
train_g = train_g.to(device)
dataloader_device = device
# Create PyTorch DataLoader for constructing blocks
sampler = dgl.dataloading.MultiLayerNeighborSampler(
[int(fanout) for fanout in args.fan_out.split(',')])
@@ -129,6 +137,7 @@ def run(args, device, data):
train_g,
train_nid,
sampler,
device=dataloader_device,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
@@ -198,6 +207,8 @@ if __name__ == '__main__':
argparser.add_argument('--dropout', type=float, default=0.5)
argparser.add_argument('--num-workers', type=int, default=4,
help="Number of sampling processes. Use 0 for no extra process.")
argparser.add_argument('--sample-gpu', action='store_true',
help="Perform the sampling process on the GPU. Must have 0 workers.")
argparser.add_argument('--inductive', action='store_true',
help="Inductive learning setting")
argparser.add_argument('--data-cpu', action='store_true',
@@ -214,6 +225,8 @@ if __name__ == '__main__':
if args.dataset == 'reddit':
g, n_classes = load_reddit()
elif args.dataset == 'ogbn-products':
g, n_classes = load_ogb('ogbn-products')
else:
raise Exception('unknown dataset')
@@ -234,11 +247,6 @@ if __name__ == '__main__':
train_nfeat = train_nfeat.to(device)
train_labels = train_labels.to(device)
# Create csr/coo/csc formats before launching training processes with multi-gpu.
# This avoids creating certain formats in each sub-process, which saves memory and CPU.
train_g.create_formats_()
val_g.create_formats_()
test_g.create_formats_()
# Pack data
data = n_classes, train_g, val_g, test_g, train_nfeat, train_labels, \
val_nfeat, val_labels, test_nfeat, test_labels
......
@@ -34,6 +34,11 @@ def cpu():
def tensor(data, dtype=None):
if isinstance(data, numbers.Number):
data = [data]
if isinstance(data, list) and len(data) > 0 and isinstance(data[0], th.Tensor):
# prevent GPU->CPU->GPU copies
if data[0].ndim == 0:
# zero-dimension scalar tensors
return th.stack(data)
if isinstance(data, th.Tensor):
return th.as_tensor(data, dtype=dtype, device=data.device)
else:
......
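The `F.tensor()` fast path added above exists to keep a list of zero-dimension GPU tensors on the device. A small illustration of the round trip it avoids, assuming a CUDA device is available (the variable names are made up for the sketch):

```python
import torch as th

device = 'cuda' if th.cuda.is_available() else 'cpu'

# `scalars` stands in for the list of zero-dimension tensors the collator
# receives per mini-batch; on CUDA each lives on the GPU.
scalars = [th.tensor(i, device=device) for i in range(1024)]

# Slow path: reading each scalar back to the host and rebuilding the tensor
# forces a GPU -> CPU -> GPU round trip when the data is on the GPU.
slow = th.tensor([s.item() for s in scalars], device=device)

# Fast path used by the patched F.tensor(): stack the scalars directly on
# their device, avoiding the host round trip.
fast = th.stack(scalars)

assert th.equal(slow, fast)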
"""Data loading components for neighbor sampling""" """Data loading components for neighbor sampling"""
from .dataloader import BlockSampler from .dataloader import BlockSampler
from .. import sampling, subgraph, distributed from .. import sampling, subgraph, distributed
from .. import ndarray as nd
from .. import backend as F
class MultiLayerNeighborSampler(BlockSampler):
"""Sampler that builds computational dependency of node representations via
@@ -63,6 +65,11 @@ class MultiLayerNeighborSampler(BlockSampler):
self.fanouts = fanouts
self.replace = replace
# used to cache computations and memory allocations
# list[dgl.nd.NDArray]; each array stores the fan-outs of all edge types
self.fanout_arrays = []
self.prob_arrays = None
def sample_frontier(self, block_id, g, seed_nodes):
fanout = self.fanouts[block_id]
if isinstance(g, distributed.DistGraph):
@@ -76,9 +83,39 @@ class MultiLayerNeighborSampler(BlockSampler):
if fanout is None:
frontier = subgraph.in_subgraph(g, seed_nodes)
else:
self._build_fanout(block_id, g)
self._build_prob_arrays(g)
frontier = sampling.sample_neighbors(
g, seed_nodes, self.fanout_arrays[block_id],
replace=self.replace, prob=self.prob_arrays)
return frontier
def _build_prob_arrays(self, g):
# build prob_arrays only once
if self.prob_arrays is None:
self.prob_arrays = [nd.array([], ctx=nd.cpu())] * len(g.etypes)
def _build_fanout(self, block_id, g):
assert self.fanouts is not None, \
"_build_fanout() should only be called when fanouts is not None"
# build fanout_arrays only once for each layer
while block_id >= len(self.fanout_arrays):
for i in range(len(self.fanouts)):
fanout = self.fanouts[i]
if not isinstance(fanout, dict):
fanout_array = [int(fanout)] * len(g.etypes)
else:
if len(fanout) != len(g.etypes):
raise DGLError('Fan-out must be specified for each edge type '
'if a dict is provided.')
fanout_array = [None] * len(g.etypes)
for etype, value in fanout.items():
fanout_array[g.get_etype_id(etype)] = value
self.fanout_arrays.append(
F.to_dgl_nd(F.tensor(fanout_array, dtype=F.int64)))
class MultiLayerFullNeighborSampler(MultiLayerNeighborSampler):
"""Sampler that builds computational dependency of node representations by taking messages
from all neighbors for multilayer GNN.
......
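For reference, the fan-out caching above boils down to mapping a per-layer fan-out (an int, or a dict keyed by edge type) onto a tensor indexed by edge-type ID; the real sampler then converts that tensor to a DGL NDArray once and reuses it every iteration. A standalone sketch of that mapping, with an illustrative toy graph and helper name:

```python
import torch as th
import dgl

# heterogeneous toy graph with two edge types
g = dgl.heterograph({
    ('user', 'follows', 'user'): (th.tensor([0, 1]), th.tensor([1, 2])),
    ('user', 'plays', 'game'):   (th.tensor([0, 2]), th.tensor([0, 1])),
})

def build_fanout_array(g, fanout):
    """Sketch of the per-layer logic in _build_fanout()."""
    if not isinstance(fanout, dict):
        # a single integer applies to every edge type
        return th.tensor([int(fanout)] * len(g.etypes), dtype=th.int64)
    if len(fanout) != len(g.etypes):
        raise ValueError('Fan-out must be specified for each edge type '
                         'if a dict is provided.')
    arr = [None] * len(g.etypes)
    for etype, value in fanout.items():
        arr[g.get_etype_id(etype)] = value   # order by edge-type ID
    return th.tensor(arr, dtype=th.int64)

print(build_fanout_array(g, 5))                           # tensor([5, 5])
print(build_fanout_array(g, {'follows': 5, 'plays': 2}))  # tensor([5, 2])
```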
"""DGL PyTorch DataLoaders""" """DGL PyTorch DataLoaders"""
import inspect import inspect
import torch as th
from torch.utils.data import DataLoader
from ..dataloader import NodeCollator, EdgeCollator, GraphCollator
from ...distributed import DistGraph
from ...distributed import DistDataLoader
from ...ndarray import NDArray as DGLNDArray
from ... import backend as F
class _ScalarDataBatcherIter:
def __init__(self, dataset, batch_size, drop_last):
self.dataset = dataset
self.batch_size = batch_size
self.index = 0
self.drop_last = drop_last
def __next__(self):
num_items = self.dataset.shape[0]
if self.index >= num_items:
raise StopIteration
end_idx = self.index + self.batch_size
if end_idx > num_items:
if self.drop_last:
raise StopIteration
end_idx = num_items
batch = self.dataset[self.index:end_idx]
self.index += self.batch_size
return batch
class _ScalarDataBatcher(th.utils.data.IterableDataset):
"""Custom Dataset wrapper to return mini-batches as tensors, rather than as
lists. When the dataset is on the GPU, this significantly reduces
the overhead. For a batch size of 1024, instead of giving a
list of 1024 scalar tensors to the collator, a single tensor of length 1024
is passed in.
"""
def __init__(self, dataset, shuffle=False, batch_size=1,
drop_last=False):
super(_ScalarDataBatcher, self).__init__()
self.dataset = dataset
self.batch_size = batch_size
self.shuffle = shuffle
self.drop_last = drop_last
def __iter__(self):
worker_info = th.utils.data.get_worker_info()
dataset = self.dataset
if worker_info:
# worker gets only a fraction of the dataset
chunk_size = dataset.shape[0] // worker_info.num_workers
left_over = dataset.shape[0] % worker_info.num_workers
start = (chunk_size*worker_info.id) + min(left_over, worker_info.id)
end = start + chunk_size + (worker_info.id < left_over)
assert worker_info.id < worker_info.num_workers-1 or \
end == dataset.shape[0]
dataset = dataset[start:end]
if self.shuffle:
# permute the dataset
perm = th.randperm(dataset.shape[0], device=dataset.device)
dataset = dataset[perm]
return _ScalarDataBatcherIter(dataset, self.batch_size, self.drop_last)
def _remove_kwargs_dist(kwargs):
if 'num_workers' in kwargs:
@@ -242,7 +301,38 @@ class NodeDataLoader:
self.is_distributed = True
else:
self.collator = _NodeCollator(g, nids, block_sampler, **collator_kwargs)
dataset = self.collator.dataset
if th.device(device) != th.device('cpu'):
# Only use the '_ScalarDataBatcher' for the GPU, as it
# doesn't seem to have a performance benefit on the CPU.
assert 'num_workers' not in dataloader_kwargs or \
dataloader_kwargs['num_workers'] == 0, \
'When performing dataloading from the GPU, num_workers ' \
'must be zero.'
batch_size = dataloader_kwargs.get('batch_size', 0)
if batch_size > 1:
if isinstance(dataset, DGLNDArray):
# the dataset needs to be a torch tensor for the
# _ScalarDataBatcher
dataset = F.zerocopy_from_dgl_ndarray(dataset)
if isinstance(dataset, th.Tensor):
shuffle = dataloader_kwargs.get('shuffle', False)
drop_last = dataloader_kwargs.get('drop_last', False)
# manually batch into tensors
dataset = _ScalarDataBatcher(dataset,
batch_size=batch_size,
shuffle=shuffle,
drop_last=drop_last)
# need to overwrite things that will be handled by the batcher
dataloader_kwargs['batch_size'] = None
dataloader_kwargs['shuffle'] = False
dataloader_kwargs['drop_last'] = False
self.dataloader = DataLoader(
dataset,
collate_fn=self.collator.collate,
**dataloader_kwargs)
self.is_distributed = False
......
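A plain-Python sketch of what the `_ScalarDataBatcher`/`_ScalarDataBatcherIter` pair above does per epoch: shuffle a 1-D ID tensor once, then hand out contiguous tensor slices instead of Python lists, so GPU-resident IDs never leave the device. The helper name and sizes below are illustrative only.

```python
import torch as th

def scalar_batches(ids, batch_size, shuffle=True, drop_last=False):
    """Yield mini-batches of a 1-D id tensor as tensor slices (sketch)."""
    if shuffle:
        # permute on the tensor's own device, as the batcher does
        ids = ids[th.randperm(ids.shape[0], device=ids.device)]
    n = ids.shape[0]
    for start in range(0, n, batch_size):
        end = start + batch_size
        if end > n:
            if drop_last:
                return
            end = n
        yield ids[start:end]

device = 'cuda' if th.cuda.is_available() else 'cpu'
ids = th.arange(10_000, device=device)
for batch in scalar_batches(ids, 1024):
    assert batch.device == ids.device  # slices stay on the original device
```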
@@ -119,7 +119,6 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False,
if len(g.ntypes) > 1:
raise DGLError("Must specify node type when the graph is not homogeneous.")
nodes = {g.ntypes[0] : nodes}
assert g.device == F.cpu(), "Graph must be on CPU."
nodes = utils.prepare_tensor_dict(g, nodes, 'nodes')
nodes_all_types = []
@@ -129,6 +128,9 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False,
else:
nodes_all_types.append(nd.array([], ctx=nd.cpu()))
if isinstance(fanout, nd.NDArray):
fanout_array = fanout
else:
if not isinstance(fanout, dict):
fanout_array = [int(fanout)] * len(g.etypes)
else:
@@ -140,7 +142,10 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False,
fanout_array[g.get_etype_id(etype)] = value
fanout_array = F.to_dgl_nd(F.tensor(fanout_array, dtype=F.int64))
if isinstance(prob, list) and len(prob) > 0 and \
isinstance(prob[0], nd.NDArray):
prob_arrays = prob
elif prob is None:
prob_arrays = [nd.array([], ctx=nd.cpu())] * len(g.etypes)
else:
prob_arrays = []
......
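With the CPU-only assert removed and the NDArray fast paths above, uniform sampling (`prob=None`) can now be called on a graph that lives on the GPU. A hedged end-to-end sketch with an arbitrary toy graph, guarded so it only exercises the GPU path when CUDA is available:

```python
import torch as th
import dgl

# toy graph; sizes are arbitrary
src = th.randint(0, 1000, (5000,))
dst = th.randint(0, 1000, (5000,))
g = dgl.graph((src, dst), num_nodes=1000)

if th.cuda.is_available():
    # move only the CSC format to the GPU, as the example script does
    g = g.formats(['csc']).to('cuda')
    seeds = th.arange(100, device='cuda')
    # uniform sampling (prob=None) now runs on the GPU
    sub = dgl.sampling.sample_neighbors(g, seeds, 5, replace=False)
    print(sub.num_edges())
```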
@@ -507,16 +507,18 @@ CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries) {
COOMatrix CSRRowWiseSampling(
CSRMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob, bool replace) {
COOMatrix ret;
ATEN_CSR_SWITCH(mat, XPU, IdType, "CSRRowWiseSampling", {
if (IsNullArray(prob)) {
ATEN_CSR_SWITCH_CUDA(mat, XPU, IdType, "CSRRowWiseSampling", {
ret = impl::CSRRowWiseSamplingUniform<XPU, IdType>(mat, rows, num_samples, replace);
});
} else {
ATEN_CSR_SWITCH(mat, XPU, IdType, "CSRRowWiseSampling", {
ATEN_FLOAT_TYPE_SWITCH(prob->dtype, FloatType, "probability", {
ret = impl::CSRRowWiseSampling<XPU, IdType, FloatType>(
mat, rows, num_samples, prob, replace);
});
}
});
}
return ret;
}
......
/*!
* Copyright (c) 2021 by Contributors
* \file array/cuda/rowwise_sampling.cu
* \brief rowwise sampling
*/
#include <dgl/random.h>
#include <dgl/runtime/device_api.h>
#include <curand_kernel.h>
#include <cub/cub.cuh>
#include <numeric>
#include "../../kernel/cuda/atomic.cuh"
#include "../../runtime/cuda/cuda_common.h"
using namespace dgl::kernel::cuda;
namespace dgl {
namespace aten {
namespace impl {
namespace {
constexpr int WARP_SIZE = 32;
/**
* @brief Compute the size of each row in the sampled CSR, without replacement.
*
* @tparam IdType The type of node and edge indexes.
* @param num_picks The number of non-zero entries to pick per row.
* @param num_rows The number of rows to pick.
* @param in_rows The set of rows to pick.
* @param in_ptr The index where each row's edges start.
* @param out_deg The size of each row in the sampled matrix, as indexed by
* `in_rows` (output).
*/
template<typename IdType>
__global__ void _CSRRowWiseSampleDegreeKernel(
const int64_t num_picks,
const int64_t num_rows,
const IdType * const in_rows,
const IdType * const in_ptr,
IdType * const out_deg) {
const int tIdx = threadIdx.x + blockIdx.x*blockDim.x;
if (tIdx < num_rows) {
const int in_row = in_rows[tIdx];
const int out_row = tIdx;
out_deg[out_row] = min(static_cast<IdType>(num_picks), in_ptr[in_row+1]-in_ptr[in_row]);
if (out_row == num_rows-1) {
// make the prefixsum work
out_deg[num_rows] = 0;
}
}
}
/**
* @brief Compute the size of each row in the sampled CSR, with replacement.
*
* @tparam IdType The type of node and edge indexes.
* @param num_picks The number of non-zero entries to pick per row.
* @param num_rows The number of rows to pick.
* @param in_rows The set of rows to pick.
* @param in_ptr The index where each row's edges start.
* @param out_deg The size of each row in the sampled matrix, as indexed by
* `in_rows` (output).
*/
template<typename IdType>
__global__ void _CSRRowWiseSampleDegreeReplaceKernel(
const int64_t num_picks,
const int64_t num_rows,
const IdType * const in_rows,
const IdType * const in_ptr,
IdType * const out_deg) {
const int tIdx = threadIdx.x + blockIdx.x*blockDim.x;
if (tIdx < num_rows) {
const int64_t in_row = in_rows[tIdx];
const int64_t out_row = tIdx;
if (in_ptr[in_row+1]-in_ptr[in_row] == 0) {
out_deg[out_row] = 0;
} else {
out_deg[out_row] = static_cast<IdType>(num_picks);
}
if (out_row == num_rows-1) {
// make the prefixsum work
out_deg[num_rows] = 0;
}
}
}
/**
* @brief Perform row-wise sampling on a CSR matrix, and generate a COO matrix,
* without replacement.
*
* @tparam IdType The ID type used for matrices.
* @tparam BLOCK_ROWS The number of rows covered by each threadblock.
* @param rand_seed The random seed to use.
* @param num_picks The number of non-zeros to pick per row.
* @param num_rows The number of rows to pick.
* @param in_rows The set of rows to pick.
* @param in_ptr The indptr array of the input CSR.
* @param in_index The indices array of the input CSR.
* @param data The data array of the input CSR.
* @param out_ptr The offset to write each row to in the output COO.
* @param out_rows The rows of the output COO (output).
* @param out_cols The columns of the output COO (output).
* @param out_idxs The data array of the output COO (output).
*/
template<typename IdType, int BLOCK_ROWS>
__global__ void _CSRRowWiseSampleKernel(
const uint64_t rand_seed,
const int64_t num_picks,
const int64_t num_rows,
const IdType * const in_rows,
const IdType * const in_ptr,
const IdType * const in_index,
const IdType * const data,
const IdType * const out_ptr,
IdType * const out_rows,
IdType * const out_cols,
IdType * const out_idxs) {
// we assign one warp per row
assert(blockDim.x == WARP_SIZE);
assert(blockDim.y == BLOCK_ROWS);
// we need one state per 256 threads
constexpr int NUM_RNG = ((WARP_SIZE*BLOCK_ROWS)+255)/256;
__shared__ curandState rng_array[NUM_RNG];
assert(blockDim.x >= NUM_RNG);
if (threadIdx.y == 0 && threadIdx.x < NUM_RNG) {
curand_init(rand_seed, 0, threadIdx.x, rng_array+threadIdx.x);
}
__syncthreads();
curandState * const rng = rng_array+((threadIdx.x+WARP_SIZE*threadIdx.y)/256);
int64_t out_row = blockIdx.x*BLOCK_ROWS+threadIdx.y;
while (out_row < num_rows) {
const int64_t row = in_rows[out_row];
const int64_t in_row_start = in_ptr[row];
const int64_t deg = in_ptr[row+1] - in_row_start;
const int64_t out_row_start = out_ptr[out_row];
if (deg <= num_picks) {
// just copy row
for (int idx = threadIdx.x; idx < deg; idx += WARP_SIZE) {
const IdType in_idx = in_row_start+idx;
out_rows[out_row_start+idx] = row;
out_cols[out_row_start+idx] = in_index[in_idx];
out_idxs[out_row_start+idx] = data ? data[in_idx] : in_idx;
}
} else {
// generate permutation list via reservoir algorithm
for (int idx = threadIdx.x; idx < num_picks; idx+=WARP_SIZE) {
out_idxs[out_row_start+idx] = idx;
}
__syncwarp();
for (int idx = num_picks+threadIdx.x; idx < deg; idx+=WARP_SIZE) {
const int num = curand(rng)%(idx+1);
if (num < num_picks) {
// use max so as to achieve the replacement order the serial
// algorithm would have
AtomicMax(out_idxs+out_row_start+num, idx);
}
}
__syncwarp();
// copy permutation over
for (int idx = threadIdx.x; idx < num_picks; idx += WARP_SIZE) {
const IdType perm_idx = out_idxs[out_row_start+idx]+in_row_start;
out_rows[out_row_start+idx] = row;
out_cols[out_row_start+idx] = in_index[perm_idx];
if (data) {
out_idxs[out_row_start+idx] = data[perm_idx];
}
}
}
out_row += gridDim.x*BLOCK_ROWS;
}
}
/**
* @brief Perform row-wise sampling on a CSR matrix, and generate a COO matrix,
* with replacement.
*
* @tparam IdType The ID type used for matrices.
* @tparam BLOCK_ROWS The number of rows covered by each threadblock.
* @param rand_seed The random seed to use.
* @param num_picks The number of non-zeros to pick per row.
* @param num_rows The number of rows to pick.
* @param in_rows The set of rows to pick.
* @param in_ptr The indptr array of the input CSR.
* @param in_index The indices array of the input CSR.
* @param data The data array of the input CSR.
* @param out_ptr The offset to write each row to in the output COO.
* @param out_rows The rows of the output COO (output).
* @param out_cols The columns of the output COO (output).
* @param out_idxs The data array of the output COO (output).
*/
template<typename IdType, int BLOCK_ROWS>
__global__ void _CSRRowWiseSampleReplaceKernel(
const uint64_t rand_seed,
const int64_t num_picks,
const int64_t num_rows,
const IdType * const in_rows,
const IdType * const in_ptr,
const IdType * const in_index,
const IdType * const data,
const IdType * const out_ptr,
IdType * const out_rows,
IdType * const out_cols,
IdType * const out_idxs) {
// we assign one warp per row
assert(blockDim.x == WARP_SIZE);
// we need one state per 256 threads
constexpr int NUM_RNG = ((WARP_SIZE*BLOCK_ROWS)+255)/256;
__shared__ curandState rng_array[NUM_RNG];
assert(blockDim.x >= NUM_RNG);
if (threadIdx.y == 0 && threadIdx.x < NUM_RNG) {
curand_init(rand_seed, 0, threadIdx.x, rng_array+threadIdx.x);
}
__syncthreads();
curandState * const rng = rng_array+((threadIdx.x+WARP_SIZE*threadIdx.y)/256);
int64_t out_row = blockIdx.x*blockDim.y+threadIdx.y;
while (out_row < num_rows) {
const int64_t row = in_rows[out_row];
const int64_t in_row_start = in_ptr[row];
const int64_t out_row_start = out_ptr[out_row];
const int64_t deg = in_ptr[row+1] - in_row_start;
// each thread then blindly copies in rows
for (int idx = threadIdx.x; idx < num_picks; idx += blockDim.x) {
const int64_t edge = curand(rng) % deg;
const int64_t out_idx = out_row_start+idx;
out_rows[out_idx] = row;
out_cols[out_idx] = in_index[in_row_start+edge];
out_idxs[out_idx] = data ? data[in_row_start+edge] : in_row_start+edge;
}
out_row += gridDim.x*blockDim.y;
}
}
} // namespace
/////////////////////////////// CSR ///////////////////////////////
template <DLDeviceType XPU, typename IdType>
COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat,
IdArray rows,
const int64_t num_picks,
const bool replace) {
const auto& ctx = mat.indptr->ctx;
auto device = runtime::DeviceAPI::Get(ctx);
// TODO(dlasalle): Once the device api supports getting the stream from the
// context, that should be used instead of the default stream here.
cudaStream_t stream = 0;
const int64_t num_rows = rows->shape[0];
const IdType * const slice_rows = static_cast<const IdType*>(rows->data);
IdArray picked_row = NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);
IdArray picked_col = NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);
IdArray picked_idx = NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);
const IdType * const in_ptr = static_cast<const IdType*>(mat.indptr->data);
const IdType * const in_cols = static_cast<const IdType*>(mat.indices->data);
IdType* const out_rows = static_cast<IdType*>(picked_row->data);
IdType* const out_cols = static_cast<IdType*>(picked_col->data);
IdType* const out_idxs = static_cast<IdType*>(picked_idx->data);
const IdType* const data = CSRHasData(mat) ?
static_cast<IdType*>(mat.data->data) : nullptr;
// compute degree
IdType * out_deg = static_cast<IdType*>(
device->AllocWorkspace(ctx, (num_rows+1)*sizeof(IdType)));
if (replace) {
const dim3 block(512);
const dim3 grid((num_rows+block.x-1)/block.x);
_CSRRowWiseSampleDegreeReplaceKernel<<<grid, block, 0, stream>>>(
num_picks, num_rows, slice_rows, in_ptr, out_deg);
} else {
const dim3 block(512);
const dim3 grid((num_rows+block.x-1)/block.x);
_CSRRowWiseSampleDegreeKernel<<<grid, block, 0, stream>>>(
num_picks, num_rows, slice_rows, in_ptr, out_deg);
}
// fill out_ptr
IdType * out_ptr = static_cast<IdType*>(
device->AllocWorkspace(ctx, (num_rows+1)*sizeof(IdType)));
size_t prefix_temp_size = 0;
CUDA_CALL(cub::DeviceScan::ExclusiveSum(nullptr, prefix_temp_size,
out_deg,
out_ptr,
num_rows+1,
stream));
void * prefix_temp = device->AllocWorkspace(ctx, prefix_temp_size);
CUDA_CALL(cub::DeviceScan::ExclusiveSum(prefix_temp, prefix_temp_size,
out_deg,
out_ptr,
num_rows+1,
stream));
device->FreeWorkspace(ctx, prefix_temp);
device->FreeWorkspace(ctx, out_deg);
cudaEvent_t copyEvent;
CUDA_CALL(cudaEventCreate(&copyEvent));
// TODO(dlasalle): use pinned memory to overlap with the actual sampling, and wait on
// a cudaevent
IdType new_len;
device->CopyDataFromTo(out_ptr, num_rows*sizeof(new_len), &new_len, 0,
sizeof(new_len),
ctx,
DGLContext{kDLCPU, 0},
mat.indptr->dtype,
stream);
CUDA_CALL(cudaEventRecord(copyEvent, stream));
const uint64_t random_seed = RandomEngine::ThreadLocal()->RandInt(1000000000);
// select edges
if (replace) {
constexpr int BLOCK_ROWS = 128/WARP_SIZE;
const dim3 block(WARP_SIZE, BLOCK_ROWS);
const dim3 grid((num_rows+block.y-1)/block.y);
_CSRRowWiseSampleReplaceKernel<IdType, BLOCK_ROWS><<<grid, block, 0, stream>>>(
random_seed,
num_picks,
num_rows,
slice_rows,
in_ptr,
in_cols,
data,
out_ptr,
out_rows,
out_cols,
out_idxs);
} else {
constexpr int BLOCK_ROWS = 128/WARP_SIZE;
const dim3 block(WARP_SIZE, BLOCK_ROWS);
const dim3 grid((num_rows+block.y-1)/block.y);
_CSRRowWiseSampleKernel<IdType, BLOCK_ROWS><<<grid, block, 0, stream>>>(
random_seed,
num_picks,
num_rows,
slice_rows,
in_ptr,
in_cols,
data,
out_ptr,
out_rows,
out_cols,
out_idxs);
}
device->FreeWorkspace(ctx, out_ptr);
// wait for copying `new_len` to finish
CUDA_CALL(cudaEventSynchronize(copyEvent));
CUDA_CALL(cudaEventDestroy(copyEvent));
picked_row = picked_row.CreateView({new_len}, picked_row->dtype);
picked_col = picked_col.CreateView({new_len}, picked_col->dtype);
picked_idx = picked_idx.CreateView({new_len}, picked_idx->dtype);
return COOMatrix(mat.num_rows, mat.num_cols, picked_row,
picked_col, picked_idx);
}
template COOMatrix CSRRowWiseSamplingUniform<kDLGPU, int32_t>(
CSRMatrix, IdArray, int64_t, bool);
template COOMatrix CSRRowWiseSamplingUniform<kDLGPU, int64_t>(
CSRMatrix, IdArray, int64_t, bool);
} // namespace impl
} // namespace aten
} // namespace dgl
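For readers of the kernels above, here is a serial Python reference of the without-replacement pick that `_CSRRowWiseSampleKernel` parallelizes: initialize the output with the first `num_picks` indices, then let every later index overwrite a random slot. The `AtomicMax` in the kernel reproduces the "largest index wins" outcome of this sequential loop. This is a sketch of the algorithm, not DGL code.

```python
import random

def reservoir_pick(deg, num_picks, rng=random):
    """Pick `num_picks` distinct local edge indices out of `deg` (serial sketch)."""
    if deg <= num_picks:
        return list(range(deg))          # just copy the whole row
    out = list(range(num_picks))         # initial permutation 0..k-1
    for j in range(num_picks, deg):
        slot = rng.randrange(j + 1)      # mirrors curand(rng) % (idx + 1)
        if slot < num_picks:
            out[slot] = j                # serially, the largest j written last wins
    return out

print(reservoir_pick(20, 5))
```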
@@ -170,6 +170,30 @@ inline __device__ int32_t AtomicCAS(
static_cast<Type>(val));
}
inline __device__ int64_t AtomicMax(
int64_t * const address,
const int64_t val) {
// match the type of "::atomicMax", so ignore lint warning
using Type = unsigned long long int; // NOLINT
static_assert(sizeof(Type) == sizeof(*address), "Type width must match");
return atomicMax(reinterpret_cast<Type*>(address),
static_cast<Type>(val));
}
inline __device__ int32_t AtomicMax(
int32_t * const address,
const int32_t val) {
// match the type of "::atomicMax", so ignore lint warning
using Type = int; // NOLINT
static_assert(sizeof(Type) == sizeof(*address), "Type width must match");
return atomicMax(reinterpret_cast<Type*>(address),
static_cast<Type>(val));
}
} // namespace cuda
} // namespace kernel
} // namespace dgl
......
@@ -157,6 +157,7 @@ def _gen_neighbor_sampling_test_graph(hypersparse, reverse):
g = dgl.heterograph({
('user', 'follow', 'user'): ([0, 0, 0, 1, 1, 1, 2], [1, 2, 3, 0, 2, 3, 0])
}, {'user': card if card is not None else 4})
g = g.to(F.ctx())
g.edata['prob'] = F.tensor([.5, .5, 0., .5, .5, 0., 1.], dtype=F.float32)
hg = dgl.heterograph({
('user', 'follow', 'user'): ([0, 0, 0, 1, 1, 1, 2],
@@ -165,10 +166,12 @@ def _gen_neighbor_sampling_test_graph(hypersparse, reverse):
('user', 'liked-by', 'game'): ([0, 1, 2, 0, 3, 0], [2, 2, 2, 1, 1, 0]),
('coin', 'flips', 'user'): ([0, 0, 0, 0], [0, 1, 2, 3])
}, num_nodes_dict)
hg = hg.to(F.ctx())
else:
g = dgl.heterograph({
('user', 'follow', 'user'): ([1, 2, 3, 0, 2, 3, 0], [0, 0, 0, 1, 1, 1, 2])
}, {'user': card if card is not None else 4})
g = g.to(F.ctx())
g.edata['prob'] = F.tensor([.5, .5, 0., .5, .5, 0., 1.], dtype=F.float32)
hg = dgl.heterograph({
('user', 'follow', 'user'): ([1, 2, 3, 0, 2, 3, 0],
@@ -177,6 +180,7 @@ def _gen_neighbor_sampling_test_graph(hypersparse, reverse):
('game', 'liked-by', 'user'): ([2, 2, 2, 1, 1, 0], [0, 1, 2, 0, 3, 0]),
('user', 'flips', 'coin'): ([0, 1, 2, 3], [0, 0, 0, 0])
}, num_nodes_dict)
hg = hg.to(F.ctx())
hg.edges['follow'].data['prob'] = F.tensor([.5, .5, 0., .5, .5, 0., 1.], dtype=F.float32)
hg.edges['play'].data['prob'] = F.tensor([.8, .5, .5, .5], dtype=F.float32)
hg.edges['liked-by'].data['prob'] = F.tensor([.3, .5, .2, .5, .1, .1], dtype=F.float32)
@@ -220,7 +224,7 @@ def _gen_neighbor_topk_test_graph(hypersparse, reverse):
hg.edges['flips'].data['weight'] = F.tensor([10, 2, 13, -1], dtype=F.float32)
return g, hg
def _test_sample_neighbors(hypersparse, prob):
g, hg = _gen_neighbor_sampling_test_graph(hypersparse, False)
def _test1(p, replace):
@@ -247,10 +251,8 @@ def _test_sample_neighbors(hypersparse):
if p is not None:
assert not (3, 0) in edge_set
assert not (3, 1) in edge_set
_test1(prob, True) # w/ replacement, uniform
_test1(prob, False) # w/o replacement, uniform
_test1('prob', True) # w/ replacement
_test1('prob', False) # w/o replacement
def _test2(p, replace): # fanout > #neighbors
subg = dgl.sampling.sample_neighbors(g, [0, 2], -1, prob=p, replace=replace)
@@ -276,10 +278,8 @@ def _test_sample_neighbors(hypersparse):
assert len(edge_set) == num_edges
if p is not None:
assert not (3, 0) in edge_set
_test2(prob, True) # w/ replacement, uniform
_test2(prob, False) # w/o replacement, uniform
_test2('prob', True) # w/ replacement
_test2('prob', False) # w/o replacement
def _test3(p, replace):
subg = dgl.sampling.sample_neighbors(hg, {'user': [0, 1], 'game': 0}, -1, prob=p, replace=replace)
@@ -299,10 +299,8 @@ def _test_sample_neighbors(hypersparse):
assert subg['liked-by'].number_of_edges() == 4 if replace else 3
assert subg['flips'].number_of_edges() == 0
_test3(prob, True) # w/ replacement, uniform
_test3(prob, False) # w/o replacement, uniform
_test3('prob', True) # w/ replacement
_test3('prob', False) # w/o replacement
# test different fanouts for different relations
for i in range(10):
@@ -528,9 +526,13 @@ def _test_sample_neighbors_topk_outedge(hypersparse):
assert subg['flips'].number_of_edges() == 0
_test3()
def test_sample_neighbors_noprob():
_test_sample_neighbors(False, None)
#_test_sample_neighbors(True)
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sample neighbors with probability is not implemented")
def test_sample_neighbors_prob():
_test_sample_neighbors(False, 'prob')
#_test_sample_neighbors(True)
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sample neighbors not implemented")
......