"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "750bd7920622b3fe538d20035d3f03855c5d6621"
Commit db918cc2 authored by rusty1s
Browse files

update

parent 4bff7c3f
#include "hgt_sample_cpu.h" #include "hgt_sample_cpu.h"
#include <chrono> // TODO
#include <random> #include <random>
#include <ATen/Parallel.h>
edge_t split(const rel_t &rel_type) { edge_t split(const rel_t &rel_type) {
std::vector<std::string> result(3); std::vector<std::string> result(3);
int start = 0, end = 0; int start = 0, end = 0;
...@@ -14,8 +17,9 @@ edge_t split(const rel_t &rel_type) { ...@@ -14,8 +17,9 @@ edge_t split(const rel_t &rel_type) {
} }
// Copies the contents of `v` into a freshly-allocated 1-D int64 tensor.
//
// `torch::from_blob` wraps existing memory without taking ownership, so the
// trailing `.clone()` is essential: it materializes an owning copy before the
// caller's vector goes out of scope.
torch::Tensor vec_to_tensor(const std::vector<int64_t> &v) {
  // `from_blob` takes a non-const pointer even though we never write through
  // it; use an explicit const_cast instead of a C-style cast so the intent
  // (and the const-violation) is greppable. Safe because `.clone()` copies
  // the data immediately and the blob view never escapes.
  auto *data = const_cast<int64_t *>(v.data());
  const auto size = static_cast<int64_t>(v.size());
  return torch::from_blob(data, {size}, at::kLong).clone();
}
template <typename Container> template <typename Container>
...@@ -46,13 +50,30 @@ void update_budget( ...@@ -46,13 +50,30 @@ void update_budget(
for (const auto &v : sampled_nodes) { for (const auto &v : sampled_nodes) {
const int64_t col_start = colptr_data[v], col_end = colptr_data[v + 1]; const int64_t col_start = colptr_data[v], col_end = colptr_data[v + 1];
if (col_end != col_start) { const auto col_count = col_end - col_start;
if (col_count > 520) { // TODO
// There might be same neighbors with large neighborhood sizes.
// In order to prevent that we fill our budget stare with many values
// of low probability, we simply sample a subset without replacement.
std::unordered_set<int64_t> perm;
for (int64_t j = col_count - 520; j < col_count; j++) {
if (!perm.insert(rand() % j).second)
perm.insert(j);
}
const auto inv_deg = 1.f / 520.f;
for (const auto &p : perm) {
const auto w = row_data[col_start + p];
// Only add the neighbor in case we have not yet seen it before:
if (global_to_local_node.find(w) == global_to_local_node.end())
budget[w] += inv_deg;
}
} else if (col_count > 0) {
const auto inv_deg = 1.f / float(col_end - col_start); const auto inv_deg = 1.f / float(col_end - col_start);
for (int64_t j = col_start; j < col_end; j++) { for (int64_t j = col_start; j < col_end; j++) {
const auto w = row_data[j]; const auto w = row_data[j];
// Only add the neighbor in case we have not yet seen it before: // Only add the neighbor in case we have not yet seen it before:
if (global_to_local_node.find(w) == global_to_local_node.end()) if (global_to_local_node.find(w) == global_to_local_node.end())
budget[row_data[j]] += inv_deg; budget[w] += inv_deg;
} }
} }
} }
...@@ -91,7 +112,14 @@ sample_from(const std::unordered_map<int64_t, float> &budget, ...@@ -91,7 +112,14 @@ sample_from(const std::unordered_map<int64_t, float> &budget,
// The implementation assigns two iterators on budget and samples, // The implementation assigns two iterators on budget and samples,
// respectively, and then computes the node samples in linear time by // respectively, and then computes the node samples in linear time by
// alternatingly incrementing the two iterators based on their values. // alternatingly incrementing the two iterators based on their values.
// TODO
output.reserve(num_samples); output.reserve(num_samples);
for (const auto &kv : budget) {
output.insert(kv.first);
if (output.size() == num_samples)
break;
}
return output;
auto j = samples.begin(); auto j = samples.begin();
auto cum_prob = 0.f; auto cum_prob = 0.f;
...@@ -120,6 +148,7 @@ hgt_sample_cpu(const c10::Dict<rel_t, torch::Tensor> &colptr_dict, ...@@ -120,6 +148,7 @@ hgt_sample_cpu(const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
const c10::Dict<node_t, std::vector<int64_t>> &num_samples_dict, const c10::Dict<node_t, std::vector<int64_t>> &num_samples_dict,
int64_t num_hops) { int64_t num_hops) {
std::chrono::steady_clock::time_point a = std::chrono::steady_clock::now();
// Create mapping to convert single string relations to edge type triplets: // Create mapping to convert single string relations to edge type triplets:
std::unordered_map<rel_t, edge_t> rel_to_edge_type; std::unordered_map<rel_t, edge_t> rel_to_edge_type;
for (const auto &kv : colptr_dict) { for (const auto &kv : colptr_dict) {
...@@ -155,46 +184,79 @@ hgt_sample_cpu(const c10::Dict<rel_t, torch::Tensor> &colptr_dict, ...@@ -155,46 +184,79 @@ hgt_sample_cpu(const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
sampled_nodes.push_back(v); sampled_nodes.push_back(v);
global_to_local_node[v] = i; global_to_local_node[v] = i;
} }
}
// Update budget after input nodes have been added to the sampled output set
// (line 2-5):
for (const auto &kv : input_node_dict) {
const auto &node_type = kv.key();
const auto &sampled_nodes = sampled_nodes_dict.at(node_type);
// Update budget after input nodes have been added to the sampled output set
// (line 2-5):
update_budget<std::vector<int64_t>>( update_budget<std::vector<int64_t>>(
&budget_dict, node_type, sampled_nodes, global_to_local_node_dict, &budget_dict, node_type, sampled_nodes, global_to_local_node_dict,
rel_to_edge_type, colptr_dict, row_dict); rel_to_edge_type, colptr_dict, row_dict);
} }
std::chrono::steady_clock::time_point b = std::chrono::steady_clock::now();
std::cout
<< "[1] = "
<< std::chrono::duration_cast<std::chrono::microseconds>(b - a).count()
<< "[µs]" << std::endl;
a = std::chrono::steady_clock::now();
// Sample nodes for each node type in each layer (line 6 - 18): // Sample nodes for each node type in each layer (line 6 - 18):
for (int64_t ell = 0; ell < num_hops; ell++) { for (int64_t ell = 0; ell < num_hops; ell++) {
for (auto &kv : budget_dict) { std::vector<node_t> node_types; // Only iterate over non-empty budgets.
const auto &node_type = kv.first; for (const auto &kv : budget_dict) {
auto &budget = kv.second; if (kv.second.size() > 0)
const auto num_samples = num_samples_dict.at(node_type)[ell]; node_types.push_back(kv.first);
}
// Sample `num_samples` nodes of `node_type` according to the budget std::unordered_map<node_t, std::unordered_set<int64_t>>
// (line 9-11): tmp_sampled_nodes_dict;
const auto samples = sample_from(budget, num_samples); at::parallel_for(0, node_types.size(), 1, [&](int64_t begin, int64_t end) {
for (int64_t i = begin; i < end; i++) {
const auto &node_type = node_types[i];
const auto &budget = budget_dict.at(node_type);
const auto num_samples = num_samples_dict.at(node_type)[ell];
if (samples.size() > 0) { // Sample `num_samples` nodes of `node_type` according to the budget
// Add sampled nodes to the sampled output set (line 13): // (line 9-11):
const auto tmp_sampled_nodes = sample_from(budget, num_samples);
tmp_sampled_nodes_dict[node_type] = tmp_sampled_nodes;
// Add intermediate samples to the sampled output set (line 13):
auto &sampled_nodes = sampled_nodes_dict.at(node_type); auto &sampled_nodes = sampled_nodes_dict.at(node_type);
auto &global_to_local_node = global_to_local_node_dict.at(node_type); auto &global_to_local_node = global_to_local_node_dict.at(node_type);
for (const auto &v : samples) { for (const auto &v : tmp_sampled_nodes) {
sampled_nodes.push_back(v); sampled_nodes.push_back(v);
global_to_local_node[v] = sampled_nodes.size(); global_to_local_node[v] = sampled_nodes.size();
} }
// Add neighbors of newly sampled nodes to the bucket (line 14-15):
update_budget<std::unordered_set<int64_t>>(
&budget_dict, node_type, samples, global_to_local_node_dict,
rel_to_edge_type, colptr_dict, row_dict);
} }
});
for (const auto &kv : tmp_sampled_nodes_dict) {
// Add neighbors of newly sampled nodes to the bucket (line 14-15):
update_budget<std::unordered_set<int64_t>>(
&budget_dict, kv.first, kv.second, global_to_local_node_dict,
rel_to_edge_type, colptr_dict, row_dict);
} }
} }
b = std::chrono::steady_clock::now();
std::cout
<< "[2] = "
<< std::chrono::duration_cast<std::chrono::microseconds>(b - a).count()
<< "[µs]" << std::endl;
a = std::chrono::steady_clock::now();
// Reconstruct the sampled adjacency matrix among the sampled nodes (line 19): // Reconstruct the sampled adjacency matrix among the sampled nodes (line 19):
c10::Dict<rel_t, torch::Tensor> output_row_dict; c10::Dict<rel_t, torch::Tensor> output_row_dict;
c10::Dict<rel_t, torch::Tensor> output_col_dict; c10::Dict<rel_t, torch::Tensor> output_col_dict;
c10::Dict<rel_t, torch::Tensor> output_edge_dict; c10::Dict<rel_t, torch::Tensor> output_edge_dict;
// TODO: Parallelize across edge types?
//
// at::parallel_for(0, edge_types.size(), 1, [&](int64_t begin, int64_t end) {
for (const auto &kv : colptr_dict) { for (const auto &kv : colptr_dict) {
const auto &rel_type = kv.key(); const auto &rel_type = kv.key();
const auto &edge_type = rel_to_edge_type.at(rel_type); const auto &edge_type = rel_to_edge_type.at(rel_type);
...@@ -234,6 +296,11 @@ hgt_sample_cpu(const c10::Dict<rel_t, torch::Tensor> &colptr_dict, ...@@ -234,6 +296,11 @@ hgt_sample_cpu(const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
if (kv.second.size() > 0) if (kv.second.size() > 0)
output_node_dict.insert(kv.first, vec_to_tensor(kv.second)); output_node_dict.insert(kv.first, vec_to_tensor(kv.second));
} }
b = std::chrono::steady_clock::now();
std::cout
<< "[3] = "
<< std::chrono::duration_cast<std::chrono::microseconds>(b - a).count()
<< "[µs]" << std::endl;
return std::make_tuple(output_node_dict, output_row_dict, output_col_dict, return std::make_tuple(output_node_dict, output_row_dict, output_col_dict,
output_edge_dict); output_edge_dict);
......
# from typing import Dict, List
import torch
# from torch import Tensor
# from torch_sparse import SparseTensor
def test_hgt_sample():
    """Smoke-test the custom ``hgt_sample`` op on a tiny 3-node graph.

    Builds a single-relation CSC graph ('paper' -> 'paper'), then runs the
    sampler for two hops with a budget of 5 neighbors per hop. Only checks
    that the op executes without raising; the outputs are not inspected.
    """
    rel = 'paper__to__paper'
    # CSC layout: colptr-style offsets plus the flat neighbor index array.
    offsets = torch.tensor([0, 1, 3, 4])
    neighbors = torch.tensor([1, 0, 2, 1])

    num_nodes = offsets.numel() - 1
    sampler = torch.ops.torch_sparse.hgt_sample
    sampler(
        {rel: offsets},
        {rel: neighbors},
        {'paper': torch.arange(num_nodes)},  # seed every node
        {rel: [5, 5]},  # 5 samples per hop
        2,  # num_hops
    )
# import timeit
# print('Timeit', timeit.timeit(test_hgt_sample))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment