"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "c44fba889965638f447d20f5730745c7963494d7"
Commit e4cac317 authored by rusty1s's avatar rusty1s
Browse files

fix

parent e2dca775
...@@ -119,19 +119,26 @@ hgt_sample_cpu(const c10::Dict<rel_t, torch::Tensor> &rowptr_dict, ...@@ -119,19 +119,26 @@ hgt_sample_cpu(const c10::Dict<rel_t, torch::Tensor> &rowptr_dict,
rel_to_edge_type[rel_type] = split(rel_type); rel_to_edge_type[rel_type] = split(rel_type);
} }
// Initialize various data structures for the sampling process, and add the // Initialize various data structures for the sampling process:
// input nodes to the final sampled output set (line 1):
std::unordered_map<node_t, std::vector<int64_t>> sampled_nodes_dict; std::unordered_map<node_t, std::vector<int64_t>> sampled_nodes_dict;
std::unordered_map<node_t, std::unordered_map<int64_t, int64_t>> std::unordered_map<node_t, std::unordered_map<int64_t, int64_t>>
global_to_local_node_dict; global_to_local_node_dict;
std::unordered_map<node_t, std::unordered_map<int64_t, float>> budget_dict;
for (const auto &kv : num_samples_dict) {
const auto &node_type = kv.key();
sampled_nodes_dict[node_type];
global_to_local_node_dict[node_type];
budget_dict[node_type];
}
// Add all input nodes of every node type to the sampled output set (line 1):
for (const auto &kv : input_node_dict) { for (const auto &kv : input_node_dict) {
const auto &node_type = kv.key(); const auto &node_type = kv.key();
const auto &input_node = kv.value(); const auto &input_node = kv.value();
const auto *input_node_data = input_node.data_ptr<int64_t>(); const auto *input_node_data = input_node.data_ptr<int64_t>();
auto &sampled_nodes = sampled_nodes_dict[node_type]; auto &sampled_nodes = sampled_nodes_dict.at(node_type);
auto &global_to_local_node = global_to_local_node_dict[node_type]; auto &global_to_local_node = global_to_local_node_dict.at(node_type);
// Add each origin node to the sampled output nodes: // Add each origin node to the sampled output nodes:
for (int64_t i = 0; i < input_node.numel(); i++) { for (int64_t i = 0; i < input_node.numel(); i++) {
...@@ -143,7 +150,6 @@ hgt_sample_cpu(const c10::Dict<rel_t, torch::Tensor> &rowptr_dict, ...@@ -143,7 +150,6 @@ hgt_sample_cpu(const c10::Dict<rel_t, torch::Tensor> &rowptr_dict,
// Update budget after all input nodes have been added to the sampled output // Update budget after all input nodes have been added to the sampled output
// set (line 2-5): // set (line 2-5):
std::unordered_map<node_t, std::unordered_map<int64_t, float>> budget_dict;
for (const auto &kv : sampled_nodes_dict) { for (const auto &kv : sampled_nodes_dict) {
update_budget(&budget_dict, kv.first, kv.second, global_to_local_node_dict, update_budget(&budget_dict, kv.first, kv.second, global_to_local_node_dict,
rel_to_edge_type, rowptr_dict, col_dict, false); rel_to_edge_type, rowptr_dict, col_dict, false);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment