Commit 9e8a8c05 authored by jerrrrry's avatar jerrrrry
Browse files

Initial commit

parents
def __bootstrap__():
    # setuptools-style loader stub: replaces this pure-Python placeholder by
    # executing the bundled pre-compiled extension
    # 'batch_C_v0p5.cpython-310-x86_64-linux-gnu.so'.
    global __bootstrap__, __loader__, __file__
    import sys, pkg_resources, importlib.util
    # Resolve the absolute path of the shared object inside the installed package.
    __file__ = pkg_resources.resource_filename(__name__, 'batch_C_v0p5.cpython-310-x86_64-linux-gnu.so')
    # Tear down this stub's loader machinery before handing off to the extension.
    __loader__ = None; del __bootstrap__, __loader__
    spec = importlib.util.spec_from_file_location(__name__,__file__)
    mod = importlib.util.module_from_spec(spec)
    # NOTE(review): the loaded module is executed but never installed into
    # sys.modules, so its attributes may not replace this stub's namespace —
    # presumably relying on exec_module side effects; confirm against the
    # setuptools version that generated this file.
    spec.loader.exec_module(mod)
__bootstrap__()
def __bootstrap__():
    # setuptools-style loader stub for the pre-compiled extension
    # 'batch_C_v0p5_better.cpython-310-x86_64-linux-gnu.so' (see the sibling
    # stubs in this commit for the v0p5 and v0p6 variants).
    global __bootstrap__, __loader__, __file__
    import sys, pkg_resources, importlib.util
    # Locate the shared object shipped inside the installed package.
    __file__ = pkg_resources.resource_filename(__name__, 'batch_C_v0p5_better.cpython-310-x86_64-linux-gnu.so')
    # Drop the stub's loader machinery so the extension takes over.
    __loader__ = None; del __bootstrap__, __loader__
    spec = importlib.util.spec_from_file_location(__name__,__file__)
    mod = importlib.util.module_from_spec(spec)
    # NOTE(review): `mod` is not registered in sys.modules — verify the
    # extension's attributes are actually reachable from importers.
    spec.loader.exec_module(mod)
__bootstrap__()
def __bootstrap__():
    # setuptools-style loader stub for the pre-compiled extension
    # 'batch_C_v0p6.cpython-310-x86_64-linux-gnu.so'.
    global __bootstrap__, __loader__, __file__
    import sys, pkg_resources, importlib.util
    # Locate the shared object shipped inside the installed package.
    __file__ = pkg_resources.resource_filename(__name__, 'batch_C_v0p6.cpython-310-x86_64-linux-gnu.so')
    # Drop the stub's loader machinery so the extension takes over.
    __loader__ = None; del __bootstrap__, __loader__
    spec = importlib.util.spec_from_file_location(__name__,__file__)
    mod = importlib.util.module_from_spec(spec)
    # NOTE(review): `mod` is not registered in sys.modules — verify the
    # extension's attributes are actually reachable from importers.
    spec.loader.exec_module(mod)
__bootstrap__()
#include <algorithm>  // std::copy, std::max, std::max_element
#include <cstdint>
#include <vector>

#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <torch/extension.h>
namespace at { namespace native {
namespace {
bool is_batch_full(int64_t num_tokens, int64_t max_tokens, int64_t max_sentences, int64_t batch_length){
    // An empty batch is never full, regardless of the token count.
    if (batch_length == 0) {
        return false;
    }
    // Otherwise the batch is full once it holds the maximum number of
    // sentences or the prospective token count exceeds the budget.
    return batch_length == max_sentences || num_tokens > max_tokens;
}
}
std::vector<std::vector<int64_t> > make_batches_v0p5(py::array_t<int64_t> src_lengths, py::array_t<int64_t> tgt_lengths, py::array_t<int64_t> idx_list, int64_t max_tokens, int64_t max_sentences, uint64_t bsz_mult, int64_t max_len){
    // Greedily packs sample ids (visited in idx_list order) into batches
    // under a token budget (max_tokens), a sentence cap (max_sentences) and
    // a per-sample length cap (max_len; longer samples are dropped).  When
    // a batch fills, its size is trimmed to be bsz_mult-friendly and the
    // overflow seeds the next batch.  Returns the list of batches, each a
    // list of sample ids.
    std::vector<std::vector<int64_t> > batches;
    auto src_l = src_lengths.unchecked<1>();
    auto tgt_l = tgt_lengths.unchecked<1>();
    auto idx_l = idx_list.unchecked<1>();
    AT_ASSERTM(src_l.shape(0) == tgt_l.shape(0), "tgt_list and src_list should have the same shape");
    AT_ASSERTM(idx_l.shape(0) == tgt_l.shape(0), "idx_list and tgt_list should have the same shape");
    ssize_t nelem = src_l.shape(0);
    int64_t sample_len =0;             // running max length within the current batch
    std::vector<int64_t> sample_lens;  // lengths of samples not yet emitted
    std::vector<int64_t> batch;        // sample ids accumulated for the current batch
    for (ssize_t i=0; i < nelem; i++){
        int64_t idx = idx_l(i);
        // A sample's cost is the longer of its source/target lengths.
        int64_t sample_num_tokens = std::max(src_l(idx), tgt_l(idx));
        if (sample_num_tokens > max_len) continue;
        sample_len = std::max(sample_len, sample_num_tokens);
        sample_lens.push_back(sample_num_tokens);
        // Padded token count if this sample were added to the batch.
        int64_t num_tokens = (batch.size() + 1) * sample_len;
        if (is_batch_full(num_tokens, max_tokens, max_sentences, batch.size())){
            // Emit the largest bsz_mult-aligned prefix (or the remainder
            // when the batch is smaller than bsz_mult); the tail plus the
            // current sample carry over into the next batch.
            int64_t mode_len = std::max(batch.size() / bsz_mult * bsz_mult, batch.size() % bsz_mult);
            std::vector<int64_t> new_batch;
            new_batch.reserve(mode_len);
            std::copy(batch.begin()+mode_len, batch.end(), std::back_inserter(new_batch));
            batch.erase(batch.begin()+mode_len, batch.end());
            sample_lens.erase(sample_lens.begin(), sample_lens.begin()+mode_len);
            //sample_len always contains at least one element
            // (the current sample was pushed above, so sample_lens is non-empty)
            sample_len = *std::max_element(sample_lens.begin(), sample_lens.end());
            batches.push_back(batch);
            batch = new_batch;
        }
        batch.push_back(idx);
    }
    // Flush the final partial batch.
    if (batch.size() > 0) batches.push_back(batch);
    return batches;
}
}}
// Python bindings: expose the batching routine as `make_batches_v0p5`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
    m.def("make_batches_v0p5", &at::native::make_batches_v0p5);
}
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/extension.h>
namespace at {
namespace native {
// In lieu of a header file...
// Selects which make_batches<> specialization is instantiated.
enum BatchingScheme
{
    MAKE_BATCHES_V0P5_BETTER = 0,
    MAKE_BATCHES_V0P5_EVEN_BETTER = 1
};
// Round x up / down to the nearest multiple of `multiple`.
int64_t roundup(int64_t x, int64_t multiple);
int64_t rounddown(int64_t x, int64_t multiple);
// Builds batches of sequence indices under token/sentence budgets; the
// template parameter picks the strategy (see BatchingScheme above).
template<int BatchingScheme> std::vector<std::vector<int64_t> > make_batches(
    py::array_t<int64_t> src_lengths,
    py::array_t<int64_t> tgt_lengths,
    py::array_t<int64_t> idx_list,
    int64_t max_tokens,
    int64_t max_sentences,
    int64_t max_len,
    int64_t bsz_mult,
    int64_t pad_seq);
// Source starts here
int64_t roundup(int64_t x, int64_t multiple) {
    // Round x up to the next multiple of `multiple` (identity when already
    // aligned): bias by multiple-1, then truncate with integer division.
    const int64_t biased = x + multiple - 1;
    return biased / multiple * multiple;
} // roundup
int64_t rounddown(int64_t x, int64_t multiple)
{
    // Round x down to the nearest multiple via truncating division.
    const int64_t units = x / multiple;
    return units * multiple;
} // rounddown
bool is_batch_full(int64_t num_tokens, int64_t max_tokens, int64_t max_sentences, int64_t batch_length){
    // Decompose the predicate into named conditions for readability.
    const bool batch_is_empty = (batch_length == 0);     // empty is never full
    const bool at_sentence_cap = (batch_length == max_sentences);
    const bool over_token_budget = (num_tokens > max_tokens);
    return !batch_is_empty && (at_sentence_cap || over_token_budget);
} // is_batch_full
template <>
std::vector<std::vector<int64_t>> make_batches<MAKE_BATCHES_V0P5_BETTER>(
    py::array_t<int64_t> src_lengths,
    py::array_t<int64_t> tgt_lengths,
    py::array_t<int64_t> idx_list,
    int64_t max_tokens,
    int64_t max_sentences,
    int64_t max_len,
    int64_t bsz_mult,
    int64_t pad_seq)
{
    // Variant of make_batches_v0p5 that accounts for sequence-length padding
    // (pad_seq) when estimating a batch's token footprint, and that picks
    // the split point of an overflowing batch by comparing two alignment
    // strategies.  Leftovers are flushed in aligned chunks at the end.
    std::vector<std::vector<int64_t>> batches;
    const auto src_l = src_lengths.unchecked<1>();
    const auto tgt_l = tgt_lengths.unchecked<1>();
    const auto idx_l = idx_list.unchecked<1>();
    AT_ASSERTM(src_l.shape(0) == tgt_l.shape(0), "tgt_list and src_list should have the same shape");
    AT_ASSERTM(idx_l.shape(0) == tgt_l.shape(0), "idx_list and tgt_list should have the same shape");
    const auto nelem = src_l.shape(0);
    int64_t sample_len = 0;          // max raw length within the current batch
    int64_t padded_sample_len = 0;   // sample_len rounded up for padding
    // When bsz_mult is divisible by pad_seq, the batch size only needs to be
    // a multiple of bsz_mult / pad_seq to keep the padded total aligned.
    const auto num_seqs_mult = ((bsz_mult % pad_seq) == 0) ? bsz_mult / pad_seq : bsz_mult;
    std::vector<int64_t> sample_lens;  // lengths of samples not yet emitted
    std::vector<int64_t> batch;        // sample ids for the current batch
    for (ssize_t i = 0; i < nelem; ++i){
        const auto idx = idx_l(i);
        // A sample's cost is the longer of its source/target lengths.
        const auto sample_num_tokens = std::max(src_l(idx), tgt_l(idx));
        if (sample_num_tokens > max_len) continue;  // drop over-length samples
        sample_len = std::max(sample_len, sample_num_tokens);
        // Small batches pad length up to bsz_mult; larger ones only to pad_seq.
        padded_sample_len = (static_cast<int64_t>(batch.size()) < num_seqs_mult) ? roundup(sample_len, bsz_mult) : roundup(sample_len, pad_seq);
        sample_lens.emplace_back(sample_num_tokens);
        // Padded token count if this sample were added to the batch.
        int64_t num_tokens = (batch.size() + 1) * padded_sample_len;
        if (is_batch_full(num_tokens, max_tokens, max_sentences, batch.size()))
        {
            auto sequences = batch.size();
            if ( ((sequences % num_seqs_mult) != 0) && (sequences > num_seqs_mult) ) {
                // Option A: round the sequence count down to a multiple.
                auto pad_sequences_opt_seqs = rounddown(sequences, num_seqs_mult);
                auto total_tokens_opt_seqs = padded_sample_len * pad_sequences_opt_seqs;
                // Option B: pad the sequence length up to bsz_mult and fit
                // as many sequences as the token budget then allows.
                auto pad_seq_len_opt_seqlen = roundup(padded_sample_len, bsz_mult);
                auto pad_sequences_opt_seqlen= max_tokens / pad_seq_len_opt_seqlen;
                auto total_tokens_opt_seqlen = padded_sample_len * pad_sequences_opt_seqlen;
                // Keep whichever option retains more real tokens.
                // NOTE(review): option B's count is not clamped to the
                // current batch size — presumably max_tokens / padded length
                // cannot exceed it here; confirm.
                if(total_tokens_opt_seqs >= total_tokens_opt_seqlen) {
                    sequences = pad_sequences_opt_seqs;
                } else {
                    sequences = pad_sequences_opt_seqlen;
                }
            }
            //std::cout << "BATCH: Sentences: " << sequences << " Sent Length: " << sample_len << " Total: " << sample_len*sequences << " " << (static_cast<float>(sample_len * sequences) / static_cast<float>(max_tokens) * 100.0) << std::endl;
            // Emit the first `sequences` samples; the tail carries over.
            std::vector<int64_t> new_batch;
            new_batch.reserve(sequences);
            std::copy(batch.begin() + sequences, batch.end(), std::back_inserter(new_batch));
            batch.erase(batch.begin() + sequences, batch.end());
            sample_lens.erase(sample_lens.begin(), sample_lens.begin() + sequences);
            // sample_lens still holds the carried-over lengths plus the
            // current sample, so it is non-empty here.
            sample_len = *std::max_element(sample_lens.begin(), sample_lens.end());
            batches.emplace_back(batch);
            batch = new_batch;
        }
        batch.emplace_back(idx);
    }
    // Flush the remaining samples in aligned chunks (last may be smaller).
    while (batch.size() > 0)
    {
        const auto sequences = std::max(batch.size() / num_seqs_mult * num_seqs_mult, batch.size() % num_seqs_mult);
        std::vector<int64_t> new_batch;
        new_batch.reserve(sequences);
        std::copy(batch.begin() + sequences, batch.end(), std::back_inserter(new_batch));
        batch.erase(batch.begin() + sequences, batch.end());
        batches.emplace_back(batch);
        batch = new_batch;
    }
    return batches;
} // make_batches<MAKE_BATCHES_V0P5_BETTER>
template <>
std::vector<std::vector<int64_t>> make_batches<MAKE_BATCHES_V0P5_EVEN_BETTER>(
    py::array_t<int64_t> src_lengths,
    py::array_t<int64_t> tgt_lengths,
    py::array_t<int64_t> idx_list,
    int64_t max_tokens,
    int64_t max_sentences,
    int64_t max_len,
    int64_t bsz_mult,
    int64_t pad_seq)
{
    // Length-sorted batching: sequences are visited in descending order of
    // max(src, tgt) length, so each batch's padded length is fixed by its
    // first (longest) member.  A new batch starts when adding one more
    // padded sequence would reach max_tokens or the bsz_mult-aligned
    // sentence cap.  Returns batches of positions into the length arrays.
    //
    // Fixes vs. the original: the over-length skip loop could read past the
    // end of `perm` when every sequence exceeded max_len (out-of-bounds);
    // the unused `max_seq_in_batch` dead stores were removed.
    std::vector<std::vector<int64_t> > batches(1);
    const auto src_l = src_lengths.unchecked<1>();
    const auto tgt_l = tgt_lengths.unchecked<1>();
    const auto idx_l = idx_list.unchecked<1>();  // used only for shape validation
    AT_ASSERTM(src_l.shape(0) == tgt_l.shape(0), "tgt_list and src_list should have the same shape");
    AT_ASSERTM(idx_l.shape(0) == tgt_l.shape(0), "idx_list and tgt_list should have the same shape");
    // argsort (descending) over the per-sequence max length
    std::vector<int64_t> max_lengths(src_l.size());
    for (int64_t i = 0; i < src_l.size(); ++i)
    {
        max_lengths[i] = std::max(src_l[i], tgt_l[i]);
    }
    std::vector<int64_t> perm(src_l.size());
    iota(perm.begin(), perm.end(), 0);
    std::sort(
        perm.begin(),
        perm.end(),
        [&max_lengths](int64_t i1, int64_t i2) { return max_lengths[i1] > max_lengths[i2]; }); // descending order
    // Skip all sequences over the specified length, without running off the
    // end when everything is over-length (or the input is empty).
    const auto n_perm = static_cast<int64_t>(perm.size());
    int64_t offset = 0;
    while (offset < n_perm && max_lengths[perm[offset]] > max_len)
    {
        ++offset;
    }
    if (offset == n_perm)
    {
        // Nothing fits under max_len: no batches to build.
        return std::vector<std::vector<int64_t> >();
    }
    int64_t padded_seq_len = roundup(max_lengths[perm[offset]], pad_seq);
    int64_t n_seq_in_batch = 0;
    int64_t n_tok_in_batch = 0;
    for (auto && i : perm)
    {
        if (max_lengths[i] > max_len) continue;  // drop over-length sequences
        if (n_tok_in_batch + padded_seq_len < max_tokens && n_seq_in_batch < rounddown(max_sentences, bsz_mult))
        {
            batches.back().emplace_back(i);
            ++n_seq_in_batch;
            n_tok_in_batch += padded_seq_len;
        }
        else
        {
            // Start a new batch seeded with this sequence; its padded length
            // governs the whole batch since perm is length-sorted.
            batches.emplace_back(std::vector<int64_t>(1, i));
            padded_seq_len = roundup(max_lengths[i], pad_seq);
            n_seq_in_batch = 1;
            n_tok_in_batch = padded_seq_len;
        }
    }
    return batches;
} // make_batches<MAKE_BATCHES_V0P5_EVEN_BETTER>
} // namespace native
} // namespace at
// Python bindings.  NOTE(review): only the V0P5_BETTER specialization is
// exported here; the EVEN_BETTER specialization is not bound.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("make_batches_v0p5_better", &at::native::make_batches<at::native::MAKE_BATCHES_V0P5_BETTER>); // Relying on this line for instantiation
} // PYBIND11_MODULE
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include <torch/extension.h>
namespace at {
namespace native {
namespace {
int64_t roundup(int64_t x, int64_t multiple)
{
    // Smallest multiple of `multiple` that is >= x: add a bias of
    // multiple - 1, then let integer division truncate back down.
    const int64_t bumped = x + multiple - 1;
    return (bumped / multiple) * multiple;
} // roundup
int64_t rounddown(int64_t x, int64_t multiple)
{
    // Largest multiple of `multiple` that is <= x (non-negative inputs):
    // truncating division discards the remainder.
    const int64_t whole_units = x / multiple;
    return whole_units * multiple;
} // rounddown
std::pair<std::vector<int64_t>, std::vector<int64_t>> create_bucket_bounds_lists(
    int64_t max_allowable_seq_length,
    int64_t bucket_specify_min_boundary,
    float bucket_specify_growth_scale,
    bool use_efficient_last_pack)
{
    // Builds geometrically growing bucket boundaries starting at
    // bucket_specify_min_boundary, then derives per-bucket (min, max) bound
    // lists.  With use_efficient_last_pack the intervals are closed at the
    // top ([min, max]); otherwise they are half-open ([min, max)).
    std::vector<int64_t> boundaries;
    for (auto b = bucket_specify_min_boundary;
         b < max_allowable_seq_length;
         b = std::max(b + 1, static_cast<int64_t>(b * bucket_specify_growth_scale)))
    {
        boundaries.emplace_back(b);
    }
    std::vector<int64_t> mins{0};  // first bucket always starts at zero
    std::vector<int64_t> maxs;
    if (use_efficient_last_pack)
    {
        for (const auto bound : boundaries)
        {
            mins.emplace_back(bound + 1);
            maxs.emplace_back(bound);
        }
        maxs.emplace_back(max_allowable_seq_length);
    }
    else
    {
        for (const auto bound : boundaries)
        {
            mins.emplace_back(bound);
            maxs.emplace_back(bound);
        }
        maxs.emplace_back(max_allowable_seq_length + 1);
    }
    return {mins, maxs};
} // create_bucket_bounds_lists
int64_t seq_len_to_bucket_idx(
    int64_t seq_length,
    int64_t max_seq_length,
    const std::vector<int64_t> &buckets_min_list,
    const std::vector<int64_t> &buckets_max_list)
{
    // Maps a sequence length to the index of the bucket whose half-open
    // interval [min, max) contains it; returns -1 for over-length sequences.
    // If no interval matches (ill-formed bounds), the returned index equals
    // buckets_min_list.size(), matching the original behavior.
    // Fix: the bound lists are taken by const reference — the original
    // copied both vectors on every call (once per sequence).
    // TODO: Update to bisection if execution time actually matters (avoiding premature optimization)
    // TODO: Alternate is to make lookup table keys = [0...256] and values are the buckets (this loop just indexes in to avoid repeated traversals)
    if (seq_length > max_seq_length)
    {
        return -1;
    }
    int64_t idx = 0;
    while (idx < static_cast<int64_t>(buckets_min_list.size()) && !(buckets_min_list[idx] <= seq_length && seq_length < buckets_max_list[idx]))
    {
        ++idx;
    }
    return idx;
} // seq_len_to_bucket_idx
int64_t seq_len_to_bucket_idx_improved_pack(
    int64_t seq_length,
    int64_t max_seq_length,
    const std::vector<int64_t> &buckets_min_list,
    const std::vector<int64_t> &buckets_max_list)
{
    // Closed-interval variant of seq_len_to_bucket_idx: a sequence belongs
    // to the bucket whose [min, max] (inclusive) range contains it; returns
    // -1 for over-length sequences, and buckets_min_list.size() if nothing
    // matches (same as the original fall-through).
    // Fix: the bound lists are taken by const reference — the original
    // copied both vectors on every call (once per sequence).
    // TODO: Update to bisection if execution time actually matters (avoiding premature optimization)
    // TODO: Alternate is to make lookup table keys = [0...256] and values are the buckets (this loop just indexes in to avoid repeated traversals)
    if (seq_length > max_seq_length)
    {
        return -1;
    }
    int64_t idx = 0;
    while (idx < static_cast<int64_t>(buckets_min_list.size()) && !(buckets_min_list[idx] <= seq_length && seq_length <= buckets_max_list[idx]))
    {
        ++idx;
    }
    return idx;
} // seq_len_to_bucket_idx_improved_pack
std::pair<std::vector<int64_t>, std::vector<int64_t>> create_seq_to_bucket_id_list_and_n_seq_per_batch(
    const std::vector<int64_t> &n_tok_per_seq, // Number of tokens per sequence
    int64_t max_allowable_seq_length, // Maximum sequence length to be considered (rejected if over)
    int64_t max_tokens, // Maximum number of tokens allowed in the batch
    int64_t pad_seq_per_batch_to_multiple_of, // Padding multiple required, for number of sequences in batch
    int64_t pad_tok_per_seq_to_multiple_of, // Padding multiple required, for number of tokens for sequence
    int64_t bucket_specify_min_boundary, // This is the first non-zero beginning of a bucket (zero implicitly added)
    float bucket_specify_growth_scale, // The next bucket bound is determined from the previous based on this factor
    bool do_seq_len_padding_to_multiple, // Switch, enables padding sequence length to multiple
    bool do_batch_size_rounding_down_to_multiple, // Switch, enables making other dimension of batch a multiple, based on number of sequences
    bool do_dynamic_batch_size_choice, // Switch, enables choosing between methods on a batch-by-batch basis for efficiency
    bool use_efficient_last_pack) // Switch, modifies bucket bounds logic to improve batching
{
    // Returns a pair:
    //   first  - bucket index per sequence (-1 marks over-length rejects)
    //   second - how many sequences fit in one batch, per bucket
    // Fix: n_tok_per_seq is taken by const reference and the bucket-bound
    // vectors are bound by reference — the original copied the (potentially
    // large) length vector on entry and both bound lists locally.
    const auto min_max_bounds = create_bucket_bounds_lists(
        max_allowable_seq_length,
        bucket_specify_min_boundary,
        bucket_specify_growth_scale,
        use_efficient_last_pack);
    std::vector<int64_t> n_seq_per_batch;
    std::vector<int64_t> bucket_idx_list;
    bucket_idx_list.reserve(n_tok_per_seq.size());
    const auto &bucket_interval_min = min_max_bounds.first;
    const auto &bucket_interval_max = min_max_bounds.second;
    // Choose how many sequences each bucket's batch may hold.
    if (do_seq_len_padding_to_multiple)
    {
        // Budget assumes every sequence is padded up to a token multiple.
        for (auto && item : bucket_interval_max)
        {
            n_seq_per_batch.emplace_back(max_tokens / roundup(item, pad_tok_per_seq_to_multiple_of));
        }
    }
    else if (do_batch_size_rounding_down_to_multiple)
    {
        // Round the per-batch sequence count down to a multiple.
        for (auto && item : bucket_interval_max)
        {
            n_seq_per_batch.emplace_back(rounddown(max_tokens / item, pad_seq_per_batch_to_multiple_of));
        }
    }
    else if (do_dynamic_batch_size_choice)
    {
        // Per bucket, keep whichever of the two strategies packs more.
        for (auto && item : bucket_interval_max)
        {
            auto option1 = max_tokens / roundup(item, pad_tok_per_seq_to_multiple_of);
            auto option2 = rounddown(max_tokens / item, pad_seq_per_batch_to_multiple_of);
            n_seq_per_batch.emplace_back(std::max(option1, option2));
        }
    }
    else
    {
        // No alignment constraints: plain token budget.
        for (auto && item : bucket_interval_max)
        {
            n_seq_per_batch.emplace_back(max_tokens / item);
        }
    }
    // Assign each sequence to a bucket; the two helpers differ only in
    // whether the upper bound is inclusive (efficient pack) or exclusive.
    if (use_efficient_last_pack)
    {
        for (auto && seq_length : n_tok_per_seq)
        {
            bucket_idx_list.emplace_back(seq_len_to_bucket_idx_improved_pack(seq_length, max_allowable_seq_length, bucket_interval_min, bucket_interval_max));
        }
    }
    else
    {
        for (auto && seq_length : n_tok_per_seq)
        {
            bucket_idx_list.emplace_back(seq_len_to_bucket_idx(seq_length, max_allowable_seq_length, bucket_interval_min, bucket_interval_max));
        }
    }
    return std::make_pair(bucket_idx_list, n_seq_per_batch);
} // create_seq_to_bucket_id_list_and_n_seq_per_batch
std::vector<std::vector<int64_t> > make_batches_v0p6(
    py::array_t<int64_t> src_lengths,
    py::array_t<int64_t> tgt_lengths,
    py::array_t<int64_t> idx_list,
    int64_t max_tokens,
    int64_t max_sentences,
    int64_t bsz_mult,
    int64_t max_len,
    int64_t bucket_specify_min_boundary,
    float bucket_specify_growth_scale,
    int64_t batch_strategy,
    bool use_efficient_last_pack)
{
    // Bucketed batching: each sequence is assigned to a length bucket, then
    // each bucket is chopped into batches holding a bucket-specific number
    // of sequences.  batch_strategy selects how that count is computed
    // (0: round batch size down to bsz_mult, 1: pad sequence length to
    // bsz_mult, 2: per-bucket best of both).  Returns batches of positions
    // into the length arrays; prints a count of rejected sequences.
    //
    // Fixes vs. the original: removed dead code — an unused rejected-id
    // counter loop, an unused `dummy`, and an unused
    // `std::max(last_batch.begin(), last_batch.end())` that compared
    // iterators (not element values) and shadowed the max_len parameter —
    // and corrected an under-reserve before the tail copy on batch splits.
    auto src_l = src_lengths.unchecked<1>();
    auto tgt_l = tgt_lengths.unchecked<1>();
    auto idx_l = idx_list.unchecked<1>();
    std::vector<std::vector<int64_t> > batches(1);
    // A sequence's cost is the longer of its source/target lengths.
    std::vector<int64_t> n_tok_per_seq;
    n_tok_per_seq.reserve(src_l.shape(0));
    for (int64_t i = 0; i < src_l.shape(0); ++i)
    {
        const int64_t src_len = src_l(i);
        const int64_t tgt_len = tgt_l(i);
        n_tok_per_seq.emplace_back(std::max(src_len, tgt_len));
    }
    const bool do_seq_len_padding_to_multiple = batch_strategy == 1;
    const bool do_batch_size_rounding_down_to_multiple = batch_strategy == 0;
    const bool do_dynamic_batch_size_choice = batch_strategy == 2;
    // Get vector of bucket ids (one per seq) and per-bucket batch capacity.
    const auto bucket_ids_and_n_seq_per_batch = create_seq_to_bucket_id_list_and_n_seq_per_batch(
        n_tok_per_seq,
        max_len,
        max_tokens,
        bsz_mult,
        bsz_mult, // TODO: Make this independently varied (for now assumed to be 8 for both anyways)
        bucket_specify_min_boundary,
        bucket_specify_growth_scale,
        do_seq_len_padding_to_multiple,
        do_batch_size_rounding_down_to_multiple,
        do_dynamic_batch_size_choice,
        use_efficient_last_pack);
    const auto &bucket_ids = bucket_ids_and_n_seq_per_batch.first;
    const auto &n_seq_per_batch = bucket_ids_and_n_seq_per_batch.second;
    // Recompute the bucket bounds (needed for the padded-token check below).
    const auto min_max_bounds = create_bucket_bounds_lists(
        max_len,
        bucket_specify_min_boundary,
        bucket_specify_growth_scale,
        use_efficient_last_pack);
    const auto &bucket_interval_min = min_max_bounds.first;
    const auto &bucket_interval_max = min_max_bounds.second;
    // Scatter sequence positions into their buckets; id -1 marks rejects.
    std::vector<std::vector<int64_t> > buckets(bucket_interval_min.size(), std::vector<int64_t>());
    int64_t reject_count = 0;
    for (int64_t i = 0; i < static_cast<int64_t>(bucket_ids.size()); ++i)
    {
        if (bucket_ids[i] >= 0)
        {
            buckets[bucket_ids[i]].emplace_back(i);
        }
        else
        {
            ++reject_count;
        }
    }
    // Report the number of sequences rejected due to sequence length.
    std::cout << reject_count << " sequences were omitted due to containing over " << max_len << " tokens." << std::endl;
    int64_t batch_n_seq = 0;
    for (int64_t i = 0; i < static_cast<int64_t>(buckets.size()); ++i)
    {
        const auto &bucket = buckets[i];
        const auto nspb = n_seq_per_batch[i];
        const auto bkt_max_len = bucket_interval_max[i];
        // Chop the bucket into batches of at most nspb sequences each.
        for (auto && item : bucket)
        {
            if (batch_n_seq < nspb)
            {
                batches.back().emplace_back(item);
                ++batch_n_seq;
            }
            else
            {
                batches.emplace_back(std::vector<int64_t>(1, item));
                batch_n_seq = 1;
            }
        }
        // If the bucket's final batch is not bsz_mult-aligned and would
        // overflow the token budget at the padded length, split it in half.
        auto &last_batch = batches.back();
        if (last_batch.size() % bsz_mult != 0) {
            const auto batch_size = last_batch.size();
            const auto tokens = batch_size * roundup(bkt_max_len, bsz_mult);
            if (tokens > max_tokens) {
                const auto half_batch = batch_size / 2;
                std::vector<int64_t> new_batch;
                new_batch.reserve(batch_size - half_batch);
                std::copy(last_batch.begin() + half_batch, last_batch.end(), std::back_inserter(new_batch));
                last_batch.erase(last_batch.begin() + half_batch, last_batch.end());
                batches.emplace_back(std::move(new_batch));
            }
        }
        // Buckets never share a batch: start fresh for the next bucket.
        batches.emplace_back(std::vector<int64_t>());
        batch_n_seq = 0;
    }
    // Drop the empty separator batches (and any other empties), keeping order.
    batches.erase(
        std::remove_if(batches.begin(), batches.end(),
                       [](const std::vector<int64_t> &b) { return b.empty(); }),
        batches.end());
    return batches;
}
} // namespace (anonymous)
} // namespace native
} // namespace at
// Python bindings: expose the bucketed batching routine as `make_batches_v0p6`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
    m.def("make_batches_v0p6", &at::native::make_batches_v0p6);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment