Commit 3369b5f0 authored by rusty1s's avatar rusty1s
Browse files

typos

parent 0580c3f8
......@@ -7,14 +7,11 @@ std::tuple<at::Tensor, at::Tensor> neighbor_sampler(at::Tensor start,
at::Tensor cumdeg,
at::Tensor col, size_t size,
float factor) {
THGenerator *generator = THGenerator_new();
auto start_ptr = start.data<int64_t>();
auto cumdeg_ptr = cumdeg.data<int64_t>();
// TODO: size float/int, sampling
std::vector<int64_t> e_ids;
for (ptrdiff_t i = 0; i < start.size(0); i++) {
int64_t low = cumdeg_ptr[start_ptr[i]];
......@@ -47,11 +44,11 @@ std::tuple<at::Tensor, at::Tensor> neighbor_sampler(at::Tensor start,
THGenerator_free(generator);
auto e_id =
torch::from_blob(e_ids.data(), {(signed)e_ids.size()}, start.options());
int64_t len = e_ids.size();
auto e_id = torch::from_blob(e_ids.data(), {len}, start.options()).clone();
auto n_id = std::get<0>(at::_unique(col.index_select(0, e_id)));
return std::make_tuple(n_id, e_id.clone());
return std::make_tuple(n_id, e_id);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment