"vscode:/vscode.git/clone" did not exist on "c8f41cfc4df4cf89198b09787c87261473e55f99"
Unverified Commit 44b68641 authored by Xin Yao, committed by GitHub
Browse files

[Feature] Enable UVA for Weighted Samplers (#4314)

* enable UVA for weighted neighbor sampler and biased random walk

* add unit tests

* fix for mxnet/tf

* fix typo
parent 9a16a5e0
......@@ -30,7 +30,7 @@ def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob
If a random walk stops in advance, DGL pads the trace with -1 to have the same
length.
This function supports the graph on GPU.
This function supports the graph on GPU and UVA sampling.
Parameters
----------
......@@ -39,8 +39,9 @@ def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob
nodes : Tensor
Node ID tensor from which the random walk traces starts.
The tensor must be on the same device as the graph and have the same dtype as the ID type
of the graph.
The tensor must have the same dtype as the ID type of the graph.
The tensor must be on the same device as the graph or
on the GPU when the graph is pinned (UVA sampling).
metapath : list[str or tuple of str], optional
Metapath, specified as a list of edge types.
......@@ -69,6 +70,7 @@ def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob
Probability to terminate the current trace before each transition.
If a tensor is given, :attr:`restart_prob` should be on the same device as the graph
or on the GPU when the graph is pinned (UVA sampling),
and have the same length as :attr:`metapath` or :attr:`length`.
return_eids : bool, optional
If True, additionally return the edge IDs traversed.
......@@ -180,19 +182,16 @@ def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob
metapath = F.to_dgl_nd(F.astype(F.tensor(metapath), g.idtype))
# Load the probability tensor from the edge frames
ctx = utils.to_dgl_context(g.device)
if prob is None:
p_nd = [nd.array([], ctx=nodes.ctx) for _ in g.canonical_etypes]
p_nd = [nd.array([], ctx=ctx) for _ in g.canonical_etypes]
else:
p_nd = []
for etype in g.canonical_etypes:
if prob in g.edges[etype].data:
prob_nd = F.to_dgl_nd(g.edges[etype].data[prob])
if prob_nd.ctx != nodes.ctx:
raise ValueError(
'context of seed node array and edges[%s].data[%s] are different' %
(etype, prob))
else:
prob_nd = nd.array([], ctx=nodes.ctx)
prob_nd = nd.array([], ctx=ctx)
p_nd.append(prob_nd)
# Actual random walk
......@@ -202,9 +201,11 @@ def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob
restart_prob = F.to_dgl_nd(restart_prob)
traces, eids, types = _CAPI_DGLSamplingRandomWalkWithStepwiseRestart(
gidx, nodes, metapath, p_nd, restart_prob)
else:
elif isinstance(restart_prob, float):
traces, eids, types = _CAPI_DGLSamplingRandomWalkWithRestart(
gidx, nodes, metapath, p_nd, restart_prob)
else:
raise TypeError("restart_prob should be float or Tensor.")
traces = F.from_dgl_nd(traces)
types = F.from_dgl_nd(types)
......
......@@ -553,7 +553,8 @@ COOMatrix CSRRowWiseSampling(
ret = impl::CSRRowWiseSamplingUniform<XPU, IdType>(mat, rows, num_samples, replace);
});
} else {
CHECK_SAME_CONTEXT(rows, prob);
// prob is pinned and rows on GPU is valid
CHECK_VALID_CONTEXT(prob, rows);
ATEN_CSR_SWITCH_CUDA_UVA(mat, rows, XPU, IdType, "CSRRowWiseSampling", {
ATEN_FLOAT_TYPE_SWITCH(prob->dtype, FloatType, "probability", {
ret = impl::CSRRowWiseSampling<XPU, IdType, FloatType>(
......
......@@ -400,7 +400,6 @@ std::pair<IdArray, IdArray> RandomWalk(
if (!isUniform) {
std::pair<IdArray, IdArray> ret;
ATEN_FLOAT_TYPE_SWITCH(prob[0]->dtype, FloatType, "probability", {
CHECK(prob[0]->ctx.device_type == kDLGPU) << "prob should be in GPU.";
ret = RandomWalkBiased<XPU, FloatType, IdType>(hg, seeds, metapath, prob, restart_prob);
});
return ret;
......@@ -442,7 +441,6 @@ std::pair<IdArray, IdArray> RandomWalkWithRestart(
if (!isUniform) {
std::pair<IdArray, IdArray> ret;
ATEN_FLOAT_TYPE_SWITCH(prob[0]->dtype, FloatType, "probability", {
CHECK(prob[0]->ctx.device_type == kDLGPU) << "prob should be in GPU.";
ret = RandomWalkBiased<XPU, FloatType, IdType>(
hg, seeds, metapath, prob, restart_prob_array);
});
......@@ -471,7 +469,6 @@ std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart(
if (!isUniform) {
std::pair<IdArray, IdArray> ret;
ATEN_FLOAT_TYPE_SWITCH(prob[0]->dtype, FloatType, "probability", {
CHECK(prob[0]->ctx.device_type == kDLGPU) << "prob should be in GPU.";
ret = RandomWalkBiased<XPU, FloatType, IdType>(hg, seeds, metapath, prob, restart_prob);
});
return ret;
......
......@@ -44,10 +44,15 @@ void CheckRandomWalkInputs(
}
for (uint64_t i = 0; i < prob.size(); ++i) {
FloatArray p = prob[i];
CHECK_EQ(hg->Context(), p->ctx) << "Expected prob (" << p->ctx << ")" << " to have the same " \
<< "context as graph (" << hg->Context() << ").";
CHECK_FLOAT(p, "probability");
if (p.GetSize() != 0)
if (p.GetSize() != 0) {
CHECK_EQ(hg->IsPinned(), p.IsPinned())
<< "The prob array should have the same pinning status as the graph";
CHECK_NDIM(p, 1, "probability");
}
}
}
}; // namespace
......
......@@ -24,58 +24,74 @@ def check_random_walk(g, metapath, traces, ntypes, prob=None, trace_eids=None):
u, v = g.find_edges(trace_eids[i, j], etype=metapath[j])
assert (u == traces[i, j]) and (v == traces[i, j + 1])
def test_non_uniform_random_walk():
@pytest.mark.parametrize('use_uva', [True, False])
def test_non_uniform_random_walk(use_uva):
if use_uva:
if F.ctx() == F.cpu():
pytest.skip('UVA biased random walk requires a GPU.')
if dgl.backend.backend_name != 'pytorch':
pytest.skip('UVA biased random walk is only supported with PyTorch.')
g2 = dgl.heterograph({
('user', 'follow', 'user'): ([0, 1, 1, 2, 3], [1, 2, 3, 0, 0])
}).to(F.ctx())
})
g4 = dgl.heterograph({
('user', 'follow', 'user'): ([0, 1, 1, 2, 3], [1, 2, 3, 0, 0]),
('user', 'view', 'item'): ([0, 0, 1, 2, 3, 3], [0, 1, 1, 2, 2, 1]),
('item', 'viewed-by', 'user'): ([0, 1, 1, 2, 2, 1], [0, 0, 1, 2, 3, 3])
}).to(F.ctx())
})
g2.edata['p'] = F.tensor([3, 0, 3, 3, 3], dtype=F.float32)
g2.edata['p2'] = F.tensor([[3], [0], [3], [3], [3]], dtype=F.float32)
g4.edges['follow'].data['p'] = F.tensor([3, 0, 3, 3, 3], dtype=F.float32)
g4.edges['viewed-by'].data['p'] = F.tensor([1, 1, 1, 1, 1, 1], dtype=F.float32)
g2.edata['p'] = F.copy_to(F.tensor([3, 0, 3, 3, 3], dtype=F.float32), F.cpu())
g2.edata['p2'] = F.copy_to(F.tensor([[3], [0], [3], [3], [3]], dtype=F.float32), F.cpu())
g4.edges['follow'].data['p'] = F.copy_to(F.tensor([3, 0, 3, 3, 3], dtype=F.float32), F.cpu())
g4.edges['viewed-by'].data['p'] = F.copy_to(F.tensor([1, 1, 1, 1, 1, 1], dtype=F.float32), F.cpu())
if use_uva:
for g in (g2, g4):
g.create_formats_()
g.pin_memory_()
elif F._default_context_str == 'gpu':
g2 = g2.to(F.ctx())
g4 = g4.to(F.ctx())
try:
traces, eids, ntypes = dgl.sampling.random_walk(
g2, [0, 1, 2, 3, 0, 1, 2, 3], length=4, prob='p', return_eids=True)
g2, F.tensor([0, 1, 2, 3, 0, 1, 2, 3], dtype=g2.idtype),
length=4, prob='p', return_eids=True)
check_random_walk(g2, ['follow'] * 4, traces, ntypes, 'p', trace_eids=eids)
try:
with pytest.raises(dgl.DGLError):
traces, ntypes = dgl.sampling.random_walk(
g2, [0, 1, 2, 3, 0, 1, 2, 3], length=4, prob='p2')
fail = False
except dgl.DGLError:
fail = True
assert fail
g2, F.tensor([0, 1, 2, 3, 0, 1, 2, 3], dtype=g2.idtype),
length=4, prob='p2')
metapath = ['follow', 'view', 'viewed-by'] * 2
traces, eids, ntypes = dgl.sampling.random_walk(
g4, [0, 1, 2, 3, 0, 1, 2, 3], metapath=metapath, prob='p', return_eids=True)
g4, F.tensor([0, 1, 2, 3, 0, 1, 2, 3], dtype=g4.idtype),
metapath=metapath, prob='p', return_eids=True)
check_random_walk(g4, metapath, traces, ntypes, 'p', trace_eids=eids)
traces, eids, ntypes = dgl.sampling.random_walk(
g4, [0, 1, 2, 3, 0, 1, 2, 3], metapath=metapath, prob='p', restart_prob=0., return_eids=True)
g4, F.tensor([0, 1, 2, 3, 0, 1, 2, 3], dtype=g4.idtype),
metapath=metapath, prob='p', restart_prob=0., return_eids=True)
check_random_walk(g4, metapath, traces, ntypes, 'p', trace_eids=eids)
traces, eids, ntypes = dgl.sampling.random_walk(
g4, [0, 1, 2, 3, 0, 1, 2, 3], metapath=metapath, prob='p',
g4, F.tensor([0, 1, 2, 3, 0, 1, 2, 3], dtype=g4.idtype),
metapath=metapath, prob='p',
restart_prob=F.zeros((6,), F.float32, F.ctx()), return_eids=True)
check_random_walk(g4, metapath, traces, ntypes, 'p', trace_eids=eids)
traces, eids, ntypes = dgl.sampling.random_walk(
g4, [0, 1, 2, 3, 0, 1, 2, 3], metapath=metapath + ['follow'], prob='p',
g4, F.tensor([0, 1, 2, 3, 0, 1, 2, 3], dtype=g4.idtype),
metapath=metapath + ['follow'], prob='p',
restart_prob=F.tensor([0, 0, 0, 0, 0, 0, 1], F.float32), return_eids=True)
check_random_walk(g4, metapath, traces[:, :7], ntypes[:7], 'p', trace_eids=eids)
assert (F.asnumpy(traces[:, 7]) == -1).all()
finally:
for g in (g2, g4):
g.unpin_memory_()
def _use_uva():
if F._default_context_str == 'cpu':
return [False]
else:
return [True, False]
@pytest.mark.parametrize('use_uva', _use_uva())
@pytest.mark.parametrize('use_uva', [True, False])
def test_uniform_random_walk(use_uva):
if use_uva and F.ctx() == F.cpu():
pytest.skip('UVA random walk requires a GPU.')
g1 = dgl.heterograph({
('user', 'follow', 'user'): ([0, 1, 2], [1, 2, 0])
})
......@@ -178,8 +194,10 @@ def test_pack_traces():
assert F.array_equal(result[2], F.tensor([2, 7], dtype=F.int64))
assert F.array_equal(result[3], F.tensor([0, 2], dtype=F.int64))
@pytest.mark.parametrize('use_uva', _use_uva())
@pytest.mark.parametrize('use_uva', [True, False])
def test_pinsage_sampling(use_uva):
if use_uva and F.ctx() == F.cpu():
pytest.skip('UVA sampling requires a GPU.')
def _test_sampler(g, sampler, ntype):
seeds = F.copy_to(F.tensor([0, 2], dtype=g.idtype), F.ctx())
neighbor_g = sampler(seeds)
......
......@@ -5,6 +5,7 @@ import dgl.ops as OPS
import backend as F
import unittest
import torch
import torch.distributed as dist
from functools import partial
from torch.utils.data import DataLoader
from collections import defaultdict
......@@ -70,12 +71,33 @@ def test_saint(num_workers, mode):
for sg in dataloader:
pass
@pytest.mark.parametrize('num_workers', [0, 4])
def test_neighbor_nonuniform(num_workers):
g = dgl.graph(([1, 2, 3, 4, 5, 6, 7, 8], [0, 0, 0, 0, 1, 1, 1, 1]))
@parametrize_idtype
@pytest.mark.parametrize('mode', ['cpu', 'uva_cuda_indices', 'uva_cpu_indices', 'pure_gpu'])
@pytest.mark.parametrize('use_ddp', [False, True])
def test_neighbor_nonuniform(idtype, mode, use_ddp):
if mode != 'cpu' and F.ctx() == F.cpu():
pytest.skip('UVA and GPU sampling require a GPU.')
if use_ddp:
dist.init_process_group('gloo' if F.ctx() == F.cpu() else 'nccl',
'tcp://127.0.0.1:12347', world_size=1, rank=0)
g = dgl.graph(([1, 2, 3, 4, 5, 6, 7, 8], [0, 0, 0, 0, 1, 1, 1, 1])).astype(idtype)
g.edata['p'] = torch.FloatTensor([1, 1, 0, 0, 1, 1, 0, 0])
if mode in ('cpu', 'uva_cpu_indices'):
indices = F.copy_to(F.tensor([0, 1], idtype), F.cpu())
else:
indices = F.copy_to(F.tensor([0, 1], idtype), F.cuda())
if mode == 'pure_gpu':
g = g.to(F.cuda())
use_uva = mode.startswith('uva')
sampler = dgl.dataloading.MultiLayerNeighborSampler([2], prob='p')
dataloader = dgl.dataloading.NodeDataLoader(g, [0, 1], sampler, batch_size=1, device=F.ctx())
for num_workers in [0, 1, 2] if mode == 'cpu' else [0]:
dataloader = dgl.dataloading.NodeDataLoader(
g, indices, sampler,
batch_size=1, device=F.ctx(),
num_workers=num_workers,
use_uva=use_uva,
use_ddp=use_ddp)
for input_nodes, output_nodes, blocks in dataloader:
seed = output_nodes.item()
neighbors = set(input_nodes[1:].cpu().numpy())
......@@ -87,12 +109,18 @@ def test_neighbor_nonuniform(num_workers):
g = dgl.heterograph({
('B', 'BA', 'A'): ([1, 2, 3, 4, 5, 6, 7, 8], [0, 0, 0, 0, 1, 1, 1, 1]),
('C', 'CA', 'A'): ([1, 2, 3, 4, 5, 6, 7, 8], [0, 0, 0, 0, 1, 1, 1, 1]),
})
}).astype(idtype)
g.edges['BA'].data['p'] = torch.FloatTensor([1, 1, 0, 0, 1, 1, 0, 0])
g.edges['CA'].data['p'] = torch.FloatTensor([0, 0, 1, 1, 0, 0, 1, 1])
sampler = dgl.dataloading.MultiLayerNeighborSampler([2], prob='p')
if mode == 'pure_gpu':
g = g.to(F.cuda())
for num_workers in [0, 1, 2] if mode == 'cpu' else [0]:
dataloader = dgl.dataloading.NodeDataLoader(
g, {'A': [0, 1]}, sampler, batch_size=1, device=F.ctx())
g, {'A': indices}, sampler,
batch_size=1, device=F.ctx(),
num_workers=num_workers,
use_uva=use_uva,
use_ddp=use_ddp)
for input_nodes, output_nodes, blocks in dataloader:
seed = output_nodes['A'].item()
# Seed and neighbors are of different node types so slicing is not necessary here.
......@@ -108,6 +136,9 @@ def test_neighbor_nonuniform(num_workers):
elif seed == 0:
assert neighbors == {3, 4}
if use_ddp:
dist.destroy_process_group()
def _check_dtype(data, dtype, attr_name):
if isinstance(data, dict):
for k, v in data.items():
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment