Unverified Commit 90f10b31 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Feature] Negative sampling (#3599)

* first commit

* a bunch of fixes

* add unique

* lint

* lint

* lint

* address comments

* Update negative_sampler.py

* fix

* description

* address comments and fix

* fix

* replace unique with replace

* test pylint

* Update negative_sampler.py
parent 01bec4a3
......@@ -7,6 +7,8 @@
#define DGL_ARRAY_CUDA_UTILS_H_
#include <dmlc/logging.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/ndarray.h>
#include <dlpack/dlpack.h>
#include "../../runtime/cuda/cuda_common.h"
......@@ -172,6 +174,24 @@ __global__ void _LinearSearchKernel(
}
}
template <typename DType>
inline DType GetCUDAScalar(
runtime::DeviceAPI* device_api,
DLContext ctx,
const DType* cuda_ptr,
cudaStream_t stream) {
DType result;
device_api->CopyDataFromTo(
cuda_ptr, 0,
&result, 0,
sizeof(result),
ctx,
DLContext{kDLCPU, 0},
DLDataTypeTraits<DType>::dtype,
stream);
return result;
}
} // namespace cuda
} // namespace dgl
......
/*!
* Copyright (c) 2021 by Contributors
* \file graph/sampling/negative/global_uniform.cc
* \brief Global uniform negative sampling.
*/
#include <dgl/array.h>
#include <dgl/sampling/negative.h>
#include <dgl/base_heterograph.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <utility>
#include "../../../c_api_common.h"
using namespace dgl::runtime;
using namespace dgl::aten;
namespace dgl {
namespace sampling {
std::pair<IdArray, IdArray> GlobalUniformNegativeSampling(
HeteroGraphPtr hg,
dgl_type_t etype,
int64_t num_samples,
int num_trials,
bool exclude_self_loops,
bool replace,
double redundancy) {
dgl_format_code_t allowed = hg->GetAllowedFormats();
auto format = hg->SelectFormat(etype, CSC_CODE | CSR_CODE);
if (format == SparseFormat::kCSC) {
CSRMatrix csc = hg->GetCSCMatrix(etype);
CSRSort_(&csc);
std::pair<IdArray, IdArray> result = CSRGlobalUniformNegativeSampling(
csc, num_samples, num_trials, exclude_self_loops, replace, redundancy);
// reverse the pair since it is CSC
return {result.second, result.first};
} else if (format == SparseFormat::kCSR) {
CSRMatrix csr = hg->GetCSRMatrix(etype);
CSRSort_(&csr);
return CSRGlobalUniformNegativeSampling(
csr, num_samples, num_trials, exclude_self_loops, replace, redundancy);
} else {
LOG(FATAL) << "COO format is not supported in global uniform negative sampling";
return {IdArray(), IdArray()};
}
}
DGL_REGISTER_GLOBAL("sampling.negative._CAPI_DGLGlobalUniformNegativeSampling")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
CHECK_LE(etype, hg->NumEdgeTypes()) << "invalid edge type " << etype;
int64_t num_samples = args[2];
int num_trials = args[3];
bool exclude_self_loops = args[4];
bool replace = args[5];
double redundancy = args[6];
List<Value> result;
std::pair<IdArray, IdArray> ret = GlobalUniformNegativeSampling(
hg.sptr(), etype, num_samples, num_trials, exclude_self_loops, replace, redundancy);
result.push_back(Value(MakeValue(ret.first)));
result.push_back(Value(MakeValue(ret.second)));
*rv = result;
});
}; // namespace sampling
}; // namespace dgl
......@@ -18,6 +18,8 @@ extern "C" void NDArrayDLPackDeleter(DLManagedTensor* tensor);
namespace dgl {
constexpr DLDataType DLDataTypeTraits<int8_t>::dtype;
constexpr DLDataType DLDataTypeTraits<int16_t>::dtype;
constexpr DLDataType DLDataTypeTraits<int32_t>::dtype;
constexpr DLDataType DLDataTypeTraits<int64_t>::dtype;
constexpr DLDataType DLDataTypeTraits<uint32_t>::dtype;
......
......@@ -890,6 +890,36 @@ def test_sample_neighbors_exclude_edges_homoG(dtype):
assert not np.any(F.asnumpy(sg.has_edges_between(excluded_nodes_U,excluded_nodes_V)))
@pytest.mark.parametrize('dtype', ['int32', 'int64'])
def test_global_uniform_negative_sampling(dtype):
g = dgl.graph((np.random.randint(0, 20, (300,)), np.random.randint(0, 20, (300,)))).to(F.ctx())
src, dst = dgl.sampling.global_uniform_negative_sampling(g, 20, False, True)
assert not F.asnumpy(g.has_edges_between(src, dst)).any()
src, dst = dgl.sampling.global_uniform_negative_sampling(g, 20, False, False)
assert not F.asnumpy(g.has_edges_between(src, dst)).any()
src = F.asnumpy(src)
dst = F.asnumpy(dst)
s = set(zip(src.tolist(), dst.tolist()))
assert len(s) == len(src)
g = dgl.graph(([0], [1])).to(F.ctx())
src, dst = dgl.sampling.global_uniform_negative_sampling(g, 20, True, False, redundancy=10)
src = F.asnumpy(src)
dst = F.asnumpy(dst)
# should have either no element or (1, 0)
assert len(src) < 2
assert len(dst) < 2
if len(src) == 1:
assert src[0] == 1
assert dst[0] == 0
g = dgl.heterograph({
('A', 'AB', 'B'): (np.random.randint(0, 20, (300,)), np.random.randint(0, 40, (300,))),
('B', 'BA', 'A'): (np.random.randint(0, 40, (200,)), np.random.randint(0, 20, (200,)))}).to(F.ctx())
src, dst = dgl.sampling.global_uniform_negative_sampling(g, 20, False, etype='AB')
assert not F.asnumpy(g.has_edges_between(src, dst, etype='AB')).any()
if __name__ == '__main__':
from itertools import product
......@@ -906,3 +936,5 @@ if __name__ == '__main__':
test_sample_neighbors_biased_bipartite()
test_sample_neighbors_exclude_edges_heteroG('int32')
test_sample_neighbors_exclude_edges_homoG('int32')
test_global_uniform_negative_sampling('int32')
test_global_uniform_negative_sampling('int64')
......@@ -356,9 +356,11 @@ def test_node_dataloader(sampler_name):
@pytest.mark.parametrize('sampler_name', ['full', 'neighbor', 'shadow'])
def test_edge_dataloader(sampler_name):
neg_sampler = dgl.dataloading.negative_sampler.Uniform(2)
@pytest.mark.parametrize('neg_sampler', [
dgl.dataloading.negative_sampler.Uniform(2),
dgl.dataloading.negative_sampler.GlobalUniform(15, False, 3),
dgl.dataloading.negative_sampler.GlobalUniform(15, True, 3)])
def test_edge_dataloader(sampler_name, neg_sampler):
g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4]))
g1.ndata['feat'] = F.copy_to(F.randn((5, 8)), F.cpu())
......@@ -428,4 +430,8 @@ if __name__ == '__main__':
test_neighbor_nonuniform(0)
for sampler in ['full', 'neighbor', 'shadow']:
test_node_dataloader(sampler)
test_edge_dataloader(sampler)
for neg_sampler in [
dgl.dataloading.negative_sampler.Uniform(2),
dgl.dataloading.negative_sampler.GlobalUniform(2, False),
dgl.dataloading.negative_sampler.GlobalUniform(2, True)]:
test_edge_dataloader(sampler, neg_sampler)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment