#pragma once #include "../extensions.h" #include "parallel_hashmap/phmap.h" #define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor") #define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch") #define AT_DISPATCH_HAS_VALUE(optional_value, ...) \ [&] { \ if (optional_value.has_value()) { \ const bool HAS_VALUE = true; \ return __VA_ARGS__(); \ } else { \ const bool HAS_VALUE = false; \ return __VA_ARGS__(); \ } \ }() template inline torch::Tensor from_vector(const std::vector &vec, bool inplace = false) { const auto size = (int64_t)vec.size(); const auto out = torch::from_blob((scalar_t *)vec.data(), {size}, c10::CppTypeToScalarType::value); return inplace ? out : out.clone(); } template inline c10::Dict from_vector(const phmap::flat_hash_map> &vec_dict, bool inplace = false) { c10::Dict out_dict; for (const auto &kv : vec_dict) out_dict.insert(kv.first, from_vector(kv.second, inplace)); return out_dict; } inline int64_t uniform_randint(int64_t low, int64_t high) { CHECK_LT(low, high); auto options = torch::TensorOptions().dtype(torch::kInt64); auto ret = torch::randint(low, high, {1}, options); auto ptr = ret.data_ptr(); return *ptr; } inline int64_t uniform_randint(int64_t high) { return uniform_randint(0, high); } inline torch::Tensor choice(int64_t population, int64_t num_samples, bool replace = false, torch::optional weight = torch::nullopt) { if (population == 0 || num_samples == 0) return torch::empty({0}, at::kLong); if (!replace && num_samples >= population) return torch::arange(population, at::kLong); if (weight.has_value()) return torch::multinomial(weight.value(), num_samples, replace); if (replace) { const auto out = torch::empty({num_samples}, at::kLong); auto *out_data = out.data_ptr(); for (int64_t i = 0; i < num_samples; i++) { out_data[i] = uniform_randint(population); } return out; } else { // Sample without replacement via Robert Floyd algorithm: // https://www.nowherenearithaca.com/2013/05/ // robert-floyds-tiny-and-beautiful.html const auto out = torch::empty({num_samples}, at::kLong); auto *out_data = out.data_ptr(); std::unordered_set samples; for (int64_t i = population - num_samples; i < population; i++) { int64_t sample = uniform_randint(i); if (!samples.insert(sample).second) { sample = i; samples.insert(sample); } out_data[i - population + num_samples] = sample; } return out; } } template inline void uniform_choice(const int64_t population, const int64_t num_samples, const int64_t *idx_data, std::vector *samples, phmap::flat_hash_map *to_local_node) { if (population == 0 || num_samples == 0) return; if (replace) { for (int64_t i = 0; i < num_samples; i++) { const int64_t &v = idx_data[uniform_randint(population)]; if (to_local_node->insert({v, samples->size()}).second) samples->push_back(v); } } else if (num_samples >= population) { for (int64_t i = 0; i < population; i++) { const int64_t &v = idx_data[i]; if (to_local_node->insert({v, samples->size()}).second) samples->push_back(v); } } else { std::unordered_set indices; for (int64_t i = population - num_samples; i < population; i++) { int64_t j = uniform_randint(i); if (!indices.insert(j).second) { j = i; indices.insert(j); } const int64_t &v = idx_data[j]; if (to_local_node->insert({v, samples->size()}).second) samples->push_back(v); } } }