Commit f54d6b04 authored by rusty1s's avatar rusty1s
Browse files

choice

parent db918cc2
...@@ -15,3 +15,50 @@ ...@@ -15,3 +15,50 @@
return __VA_ARGS__(); \ return __VA_ARGS__(); \
} \ } \
}() }()
template <typename scalar_t>
torch::Tensor from_vector(const std::vector<scalar_t> &vec,
bool inplace = false) {
const auto size = (int64_t)vec.size();
const auto out = torch::from_blob((scalar_t *)vec.data(), {size},
c10::CppTypeToScalarType<scalar_t>::value);
return inplace ? out : out.clone();
}
torch::Tensor choice(int64_t population, int64_t num_samples,
bool replace = false,
torch::optional<torch::Tensor> weight = torch::nullopt) {
if (!replace && num_samples >= population)
return torch::arange(population, at::kLong);
if (weight.has_value())
return torch::multinomial(weight.value(), num_samples, replace);
if (replace) {
const auto out = torch::empty(num_samples, at::kLong);
auto *out_data = out.data_ptr<int64_t>();
for (int64_t i = 0; i < num_samples; i++) {
out_data[i] = rand() % population;
}
return out;
} else {
// Sample without replacement via Robert Floyd algorithm:
// https://www.nowherenearithaca.com/2013/05/
// robert-floyds-tiny-and-beautiful.html
std::unordered_set<int64_t> values;
for (int64_t i = population - num_samples; i < population; i++) {
if (!values.insert(rand() % i).second)
values.insert(i);
}
const auto out = torch::empty(num_samples, at::kLong);
auto *out_data = out.data_ptr<int64_t>();
int64_t i = 0;
for (const auto &value : values) {
out2_data[i] = value;
i++;
}
return out;
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment