Commit a2adde40 authored by Vadim Markovtsev's avatar Vadim Markovtsev
Browse files

Swivel: fastprep: replace pthread with std::thread

parent 89bccc63
......@@ -25,7 +25,6 @@
#include <assert.h>
#include <fcntl.h>
#include <pthread.h>
#include <stdio.h>
#include <sys/mman.h>
#include <sys/stat.h>
......@@ -36,7 +35,9 @@
#include <iomanip>
#include <iostream>
#include <map>
#include <mutex>
#include <string>
#include <thread>
#include <tuple>
#include <unordered_map>
#include <vector>
......@@ -250,15 +251,14 @@ class CoocBuffer {
std::vector<int> fds_;
// Ensures that only one buffer file is getting written at a time.
pthread_mutex_t writer_mutex_;
std::mutex writer_mutex_;
};
CoocBuffer::CoocBuffer(const std::string &output_dirname, const int num_shards,
const int shard_size)
: output_dirname_(output_dirname),
num_shards_(num_shards),
shard_size_(shard_size),
writer_mutex_(PTHREAD_MUTEX_INITIALIZER) {
shard_size_(shard_size) {
for (int row = 0; row < num_shards_; ++row) {
for (int col = 0; col < num_shards_; ++col) {
char filename[256];
......@@ -294,14 +294,11 @@ void CoocBuffer::AccumulateCoocs(const cooc_counts_t &coocs) {
bufs[bot_shard_idx].push_back(cooc_t{col_off, row_off, cnt});
}
// XXX TODO: lock
for (int i = 0; i < static_cast<int>(fds_.size()); ++i) {
int rv = pthread_mutex_lock(&writer_mutex_);
assert(rv == 0);
std::lock_guard<std::mutex> rv(writer_mutex_);
const int nbytes = bufs[i].size() * sizeof(cooc_t);
int nwritten = write(fds_[i], bufs[i].data(), nbytes);
assert(nwritten == nbytes);
pthread_mutex_unlock(&writer_mutex_);
}
}
......@@ -634,18 +631,13 @@ int main(int argc, char *argv[]) {
token_to_id_map[vocab[i]] = i;
// Compute the co-occurrences
std::vector<pthread_t> threads;
std::vector<std::thread> threads;
threads.reserve(num_threads);
std::vector<CoocCounter*> counters;
const off_t nbytes_per_thread = input_size / num_threads;
std::cout << "Running " << num_threads << " threads, each on "
<< nbytes_per_thread << " bytes" << std::endl;
pthread_attr_t attr;
if (pthread_attr_init(&attr) != 0) {
std::cerr << "unable to initalize pthreads" << std::endl;
return 1;
}
for (int i = 0; i < num_threads; ++i) {
// We could make this smarter and look around for newlines. But
// realistically that's not going to change things much.
......@@ -658,16 +650,16 @@ int main(int argc, char *argv[]) {
counters.push_back(counter);
pthread_t thread;
pthread_create(&thread, &attr, CoocCounter::Run, counter);
threads.push_back(thread);
threads.emplace_back(CoocCounter::Run, counter);
}
// Wait for threads to finish and collect marginals.
std::vector<double> marginals(vocab.size());
for (int i = 0; i < num_threads; ++i) {
pthread_join(threads[i], 0);
if (i > 0) {
std::cout << "joining thread #" << (i + 1) << std::endl;
}
threads[i].join();
const std::vector<double>& counter_marginals = counters[i]->Marginals();
for (int j = 0; j < static_cast<int>(vocab.size()); ++j)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment