Unverified Commit a54b16a2 authored by Li Zhang's avatar Li Zhang Committed by GitHub
Browse files

Simplify block manager (#812)

* simplify block manager

* fix lint
parent 2d5f5b30
......@@ -4,6 +4,7 @@
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/debug_utils.h"
#include "src/turbomind/utils/logger.h"
#include "src/turbomind/utils/string_utils.h"
#include <algorithm>
#include <iterator>
#include <stdexcept>
......@@ -70,7 +71,6 @@ bool BlockManager::Malloc()
for (int i = 0; i < chunk_size; ++i, ptr += block_size_) {
auto& block = blocks_.emplace_back();
block.use_count = 0;
block.ref_count = 0;
block.id = (int)blocks_.size() - 1;
block.timestamp = 0;
block.data = ptr;
......@@ -91,16 +91,23 @@ size_t BlockManager::GetBlockCount(size_t block_size, double ratio)
void BlockManager::Move(std::vector<int>& src, const std::vector<int>& delta, std::vector<int>& dst)
{
FT_CHECK(src.size() >= delta.size());
std::vector<int> src1(src.size() - delta.size());
std::set_difference(src.begin(), src.end(), delta.begin(), delta.end(), src1.begin());
{
auto end = std::set_difference(src.begin(), src.end(), delta.begin(), delta.end(), src1.begin());
FT_CHECK(end == src1.end());
}
src.swap(src1);
std::vector<int> dst1(dst.size() + delta.size());
std::set_union(dst.begin(), dst.end(), delta.begin(), delta.end(), dst1.begin());
{
auto end = std::set_union(dst.begin(), dst.end(), delta.begin(), delta.end(), dst1.begin());
FT_CHECK(end == dst1.end());
}
dst.swap(dst1);
}
std::vector<const Block*> BlockManager::Allocate(int count)
auto BlockManager::Allocate(int count) -> std::pair<BlockIds, UniqueIds>
{
while (free_ids_.size() < count) {
if (!Malloc()) {
......@@ -108,30 +115,30 @@ std::vector<const Block*> BlockManager::Allocate(int count)
}
}
std::vector<const Block*> ret;
std::vector<int> idxs(count);
BlockIds block_ids(count);
UniqueIds unique_ids(count);
for (int i = 0; i < count; ++i) {
int idx = free_ids_[i];
idxs[i] = idx;
auto& block = blocks_[idx];
FT_CHECK(is_free(block));
block.ref_count = 1;
block.use_count = 1;
block.unique_id = unique_id_++;
ret.push_back(&block);
int idx = free_ids_[i];
auto& b = blocks_[idx];
FT_CHECK(is_free(b)); // pre-condition: uc == 0 && ts == 0
b.use_count = 1;
b.unique_id = unique_id_++;
FT_CHECK(is_active(b)); // post-condition
block_ids[i] = idx;
unique_ids[i] = b.unique_id;
}
Move(free_ids_, idxs, active_ids_);
Move(free_ids_, block_ids, active_ids_);
dbg(free_ids_, active_ids_);
return ret;
return {block_ids, unique_ids};
}
void BlockManager::Evict(int count)
{
FT_CHECK(count <= cached_ids_.size());
std::vector<int> idxs(cached_ids_);
// get first `count` cached ids according to timestamp
std::nth_element(idxs.begin(), idxs.begin() + count, idxs.end(), [&](int i, int j) {
......@@ -146,9 +153,9 @@ void BlockManager::Evict(int count)
for (const auto& idx : idxs) {
auto& b = blocks_[idx];
FT_CHECK(is_cached(b));
b.ref_count = 0;
b.unique_id = 0;
b.timestamp = 0;
FT_CHECK(is_free(b));
}
Move(cached_ids_, idxs, free_ids_);
......@@ -156,79 +163,94 @@ void BlockManager::Evict(int count)
dbg(cached_ids_, free_ids_);
}
int BlockManager::Free(const std::vector<const Block*>& bs)
void BlockManager::Free(BlockIds ids)
{
std::vector<int> idxs;
std::sort(ids.begin(), ids.end());
for (const auto& p : bs) {
auto& b = blocks_[p->id];
FT_CHECK(is_cached(b));
if (--b.ref_count == 0) {
b.unique_id = 0;
b.timestamp = 0;
idxs.push_back(b.id);
}
for (const auto& i : ids) {
auto& b = blocks_[i];
FT_CHECK(is_cached(b)); // uc == 0 && ts != 0
b.unique_id = 0;
b.timestamp = 0;
FT_CHECK(is_free(b));
}
std::sort(idxs.begin(), idxs.end());
Move(cached_ids_, idxs, free_ids_);
dbg(cached_ids_, free_ids_);
return idxs.size();
Move(cached_ids_, ids, free_ids_);
}
int BlockManager::Unlock(const std::vector<const Block*>& bs)
int BlockManager::Unlock(const BlockIds& ids)
{
std::vector<int> idxs;
for (const auto& p : bs) {
auto& block = blocks_[p->id];
FT_CHECK(is_active(block));
if (--block.use_count == 0) {
idxs.push_back(block.id);
BlockIds unlock;
unlock.reserve(ids.size());
for (const auto& i : ids) {
auto& b = blocks_[i];
FT_CHECK(is_active(b)); // pre-condition: uc > 0
if (--b.use_count == 0) {
unlock.push_back(b.id);
FT_CHECK(is_cached(b)); // post-condition
}
}
std::sort(idxs.begin(), idxs.end());
std::sort(unlock.begin(), unlock.end());
Move(active_ids_, idxs, cached_ids_);
Move(active_ids_, unlock, cached_ids_);
dbg(active_ids_, cached_ids_);
return idxs.size();
return unlock.size();
}
int BlockManager::Lock(const std::vector<const Block*>& bs)
int BlockManager::Lock(const BlockIds& ids)
{
std::vector<int> idxs;
BlockIds lock;
lock.reserve(ids.size());
for (const auto& p : bs) {
auto& block = blocks_[p->id];
FT_CHECK(is_cached(block));
if (++block.use_count == 1) {
idxs.push_back(p->id);
for (const auto& i : ids) {
auto& b = blocks_[i];
FT_CHECK(is_cached(b));
if (++b.use_count == 1) {
lock.push_back(i);
FT_CHECK(is_active(b));
}
}
std::sort(idxs.begin(), idxs.end());
std::sort(lock.begin(), lock.end());
Move(cached_ids_, idxs, active_ids_);
Move(cached_ids_, lock, active_ids_);
// dbg(cached_ids_, active_ids_);
return idxs.size();
return lock.size();
}
void BlockManager::Touch(const std::vector<const Block*>& bs)
void BlockManager::Touch(const BlockIds& ids)
{
std::for_each(bs.crbegin(), bs.crend(), [this](const Block* p) {
FT_CHECK(is_active(*p));
const_cast<Block*>(p)->timestamp = timestamp_++;
std::for_each(ids.crbegin(), ids.crend(), [this](int i) {
FT_CHECK(is_active(blocks_[i]));
blocks_[i].timestamp = timestamp_++;
});
}
int BlockManager::Verify(const std::vector<int>& block_ids, const std::vector<uint64_t>& unique_ids)
{
FT_CHECK(block_ids.size() == unique_ids.size());
int valid = block_ids.size();
for (int i = 0; i < block_ids.size(); ++i) {
if (unique_id(block_ids[i]) != unique_ids[i]) {
valid = i;
break;
}
}
int miss = 0;
for (int i = valid; i < block_ids.size(); ++i) {
miss += (unique_id(block_ids[i]) != unique_ids[i]);
}
// All later blocks should have been invalidated
FT_CHECK_WITH_INFO(miss == (int)block_ids.size() - valid,
fmtstr("count = %d, valid = %d, miss = %d", (int)block_ids.size(), valid, miss));
return valid;
}
Snapshot BlockManager::TakeSnapshot()
{
std::vector<int> use_count(blocks_.size());
......
......@@ -11,6 +11,7 @@
#include <iterator>
#include <numeric>
#include <queue>
#include <sstream>
#include <unordered_map>
#include <vector>
......@@ -22,28 +23,37 @@ namespace turbomind {
struct Block {
int id; // fixed linear id in the pool
int ref_count; // all sequences referencing the block
int use_count; // active sequences using the block
uint64_t unique_id; // unique for every block allocation
uint64_t timestamp;
void* data;
friend std::ostream& operator<<(std::ostream& os, const Block& block);
friend std::string to_string(const Block& b)
{
std::stringstream ss;
ss << b;
return ss.str();
}
};
using BlockIds = std::vector<int>;
using UniqueIds = std::vector<uint64_t>;
inline bool is_active(const Block& block)
{
return block.ref_count > 0 && block.use_count > 0;
// timestamp may be 0 for newly allocated block that has not been written
return block.use_count > 0;
}
inline bool is_cached(const Block& block)
{
return block.ref_count > 0 && block.use_count == 0;
return block.use_count == 0 && block.timestamp != 0;
}
inline bool is_free(const Block& block)
{
return block.ref_count == 0 && block.use_count == 0 && block.timestamp == 0;
return block.use_count == 0 && block.timestamp == 0;
}
struct Snapshot {
......@@ -60,22 +70,24 @@ public:
~BlockManager();
// free -> active (use_count = 1, ref_count = 1)
[[nodiscard]] std::vector<const Block*> Allocate(int count);
[[nodiscard]] std::pair<BlockIds, UniqueIds> Allocate(int count);
// cached -> active (use_count += 1)
[[maybe_unused]] int Lock(const std::vector<const Block*>& bs);
[[maybe_unused]] int Lock(const BlockIds& ids);
// active -> cached (use_count -= 1)
[[maybe_unused]] int Unlock(const std::vector<const Block*>& bs);
[[maybe_unused]] int Unlock(const BlockIds& ids);
// cached -> free (ref_count = 0)
void Evict(int count);
// cached -> free (ref_count -= 1)
[[maybe_unused]] int Free(const std::vector<const Block*>& bs);
void Free(BlockIds bs);
// increase timestamp in reversed order
void Touch(const std::vector<const Block*>& bs);
void Touch(const BlockIds& bs);
[[nodiscard]] int Verify(const BlockIds& block_ids, const UniqueIds& unique_ids);
Snapshot TakeSnapshot();
......@@ -99,13 +111,23 @@ public:
return (max_block_count_ - blocks_.size()) + free_ids_.size();
}
Block& block(int idx)
{
return blocks_[idx];
}
int unique_id(int idx)
{
return blocks_[idx].unique_id;
}
friend std::ostream& operator<<(std::ostream& os, const BlockManager&);
private:
static size_t GetBlockCount(size_t block_size, double ratio);
// move indices between sets
static void Move(std::vector<int>& src, const std::vector<int>& delta, std::vector<int>& dst);
static void Move(BlockIds& src, const BlockIds& delta, BlockIds& dst);
// allocate a chunk of blocks
bool Malloc();
......@@ -118,13 +140,12 @@ private:
std::vector<void*> chunks_;
std::vector<int> active_ids_;
std::vector<int> cached_ids_;
std::vector<int> free_ids_;
BlockIds active_ids_;
BlockIds cached_ids_;
BlockIds free_ids_;
std::vector<Block> blocks_; // < 100k
// uint64_t unique_id_{1UL << 63};
uint64_t unique_id_{1};
uint64_t timestamp_{1};
};
......
......@@ -505,11 +505,11 @@ void LlamaBatch<T>::Initialize(GenerationState& g)
FT_CHECK_WITH_INFO(h_cu_block_counts_[i + 1] <= sequence_manager_->max_block_count(),
std::to_string(h_cu_block_counts_[i + 1]));
k_ptrs = std::transform(seq.blocks.cbegin(), seq.blocks.cend(), k_ptrs, [&](auto p) {
return reinterpret_cast<uintptr_t>(sequence_manager_->OffsetKey(p->data));
k_ptrs = std::transform(seq.blocks.cbegin(), seq.blocks.cend(), k_ptrs, [&](int block_id) {
return reinterpret_cast<uintptr_t>(sequence_manager_->GetKeyPtr(block_id));
});
v_ptrs = std::transform(seq.blocks.cbegin(), seq.blocks.cend(), v_ptrs, [&](auto p) {
return reinterpret_cast<uintptr_t>(sequence_manager_->OffsetVal(p->data));
v_ptrs = std::transform(seq.blocks.cbegin(), seq.blocks.cend(), v_ptrs, [&](int block_id) {
return reinterpret_cast<uintptr_t>(sequence_manager_->GetValPtr(block_id));
});
}
......
......@@ -37,29 +37,20 @@ SequenceManager::SequenceManager(size_t layer_num,
const Sequence* SequenceManager::Create(uint64_t id)
{
Sequence sequence{id};
auto it = sequences_.find(id);
auto it = sequences_.find(id);
if (it != sequences_.end()) {
if (rank_ == 0) {
TM_LOG_WARNING("[SequenceManager][Create] Removing conflicting ID %ld", (long)id);
}
auto& seq = it->second;
if (seq.status != Sequence::kCached) {
unlocked_.insert(unlocked_.end(), seq.blocks.begin(), seq.blocks.end());
}
seq = std::move(sequence);
Erase(it);
}
else {
it = sequences_.emplace_hint(it, id, std::move(sequence));
}
it = sequences_.emplace_hint(it, id, std::move(sequence));
return &it->second;
}
const Sequence* SequenceManager::Get(uint64_t id)
{
if (auto it = sequences_.find(id); it != sequences_.end()) {
auto& sequence = it->second;
return &it->second;
}
return nullptr;
......@@ -70,23 +61,24 @@ bool SequenceManager::Contains(uint64_t id)
return sequences_.find(id) != sequences_.end();
}
void SequenceManager::Erase(std::map<uint64_t, Sequence>::iterator it)
{
auto& seq = it->second;
if (seq.status == Sequence::kCached) {
const int count = block_manager_->Verify(seq.blocks, seq.block_unique_ids);
seq.blocks.resize(count);
}
else {
UpdateAndSetUnlock(seq);
}
freed_.insert(freed_.end(), seq.blocks.begin(), seq.blocks.end());
sequences_.erase(it);
}
bool SequenceManager::Erase(uint64_t id)
{
if (auto it = sequences_.find(id); it != sequences_.end()) {
auto& seq = it->second;
if (seq.status != Sequence::kCached) {
unlocked_.insert(unlocked_.end(), seq.blocks.begin(), seq.blocks.end());
freed_.insert(freed_.end(), seq.blocks.begin(), seq.blocks.end());
}
else {
for (int i = 0; i < seq.blocks.size(); ++i) {
// filter invalidated blocks
if (seq.blocks[i]->unique_id == seq.block_unique_ids[i]) {
freed_.push_back(seq.blocks[i]);
}
}
}
sequences_.erase(it);
Erase(it);
return true;
}
return false;
......@@ -94,7 +86,7 @@ bool SequenceManager::Erase(uint64_t id)
void SequenceManager::VerifyAndLockCached(const Sequences& sequences)
{
std::vector<const Block*> blocks;
BlockIds blocks;
for (const auto& p : sequences) {
auto& seq = const_cast<Sequence&>(*p);
if (seq.status != Sequence::kCached) {
......@@ -102,13 +94,9 @@ void SequenceManager::VerifyAndLockCached(const Sequences& sequences)
}
FT_CHECK(seq.blocks.size() == seq.block_unique_ids.size());
if (need_verify_) {
for (int i = 0; i < seq.blocks.size(); ++i) {
if (seq.blocks[i]->unique_id != seq.block_unique_ids[i]) {
seq.blocks.resize(i);
seq.block_unique_ids.resize(i);
break;
}
}
const int count = block_manager_->Verify(seq.blocks, seq.block_unique_ids);
seq.blocks.resize(count);
seq.block_unique_ids.resize(count);
}
blocks.insert(blocks.end(), seq.blocks.begin(), seq.blocks.end());
seq.cache_len = std::min<int>(seq.cache_len, seq.blocks.size() * block_seq_len_);
......@@ -177,8 +165,8 @@ struct Schedule {
while (vidx < it_) {
const auto& blocks = seqs[--it_]->blocks;
int count = 0;
for (const auto& p : blocks) {
count += static_cast<int>(--use_count_[p->id] == 0);
for (const auto& bid : blocks) {
count += static_cast<int>(--use_count_[bid] == 0);
}
unlocked_[it_] = count;
}
......@@ -354,21 +342,20 @@ std::vector<int> SequenceManager::CountRequiredBlocks(const Sequences& se
return required;
}
void SequenceManager::AssignAndActivate(const Sequences& sequences, //
const std::vector<int>& counts,
const std::vector<const Block*>& blocks)
void SequenceManager::AssignAndActivate(const Sequences& sequences, //
const std::vector<int>& counts,
const BlockIds& blocks,
const UniqueIds& unique_ids)
{
FT_CHECK(sequences.size() == counts.size());
auto first = blocks.begin();
int first = 0;
for (int i = 0; i < sequences.size(); ++i) {
auto& s = const_cast<Sequence&>(*sequences[i]);
auto count = counts[i];
// dbg(count);
auto last = first + count;
std::for_each(first, last, [&](const Block* b) {
s.blocks.push_back(b);
s.block_unique_ids.push_back(b->unique_id);
});
int last = first + count;
FT_CHECK(last <= blocks.size());
s.blocks.insert(s.blocks.end(), blocks.begin() + first, blocks.begin() + last);
s.block_unique_ids.insert(s.block_unique_ids.end(), unique_ids.begin() + first, unique_ids.begin() + last);
s.status = Sequence::kActive;
first = last;
}
......@@ -453,11 +440,12 @@ auto SequenceManager::Materialize(Sequences sequences,
// allocate & assign blocks
{
std::vector<const Block*> blocks;
BlockIds block_ids;
UniqueIds unique_ids;
if (schedule.allocate) {
blocks = block_manager_->Allocate(schedule.allocate);
std::tie(block_ids, unique_ids) = block_manager_->Allocate(schedule.allocate);
}
AssignAndActivate(schedule.active, schedule.block_counts, blocks);
AssignAndActivate(schedule.active, schedule.block_counts, block_ids, unique_ids);
}
// active -> locked
......@@ -467,6 +455,11 @@ auto SequenceManager::Materialize(Sequences sequences,
}
}
// TM_LOG_ERROR("active: %4d, cached: %4d, free: %4d",
// block_manager_->active_count(),
// block_manager_->cached_count(),
// block_manager_->free_count());
return outcome;
}
......
......@@ -19,8 +19,8 @@ struct Sequence {
uint64_t id;
Status status = kCached;
std::vector<const Block*> blocks;
std::vector<uint64_t> block_unique_ids;
BlockIds blocks;
UniqueIds block_unique_ids;
int input_length = 0;
......@@ -33,7 +33,7 @@ struct Sequence {
mutable float rope_theta = 0.f;
Sequence(uint64_t _id): id(_id) {}
explicit Sequence(uint64_t _id): id(_id) {}
friend std::ostream& operator<<(std::ostream& os, const Sequence& seq);
};
......@@ -87,14 +87,14 @@ public:
int step_length,
AdjustInputCount adjust);
void* OffsetKey(void* block_ptr)
[[nodiscard]] void* GetKeyPtr(int block_id)
{
return block_ptr;
return block_manager_->block(block_id).data;
}
void* OffsetVal(void* block_ptr)
[[nodiscard]] void* GetValPtr(int block_id)
{
return (std::byte*)block_ptr + val_offset_;
return (std::byte*)GetKeyPtr(block_id) + val_offset_;
}
int max_block_count() const noexcept
......@@ -103,6 +103,8 @@ public:
}
private:
void Erase(std::map<uint64_t, Sequence>::iterator it);
void CommitUnlockAndFree();
void VerifyAndLockCached(const Sequences& sequences);
......@@ -115,9 +117,10 @@ private:
std::vector<int>& context_lengths,
const std::vector<uint64_t>& priorities);
static void AssignAndActivate(const Sequences& sequences, //
const std::vector<int>& block_counts,
const std::vector<const Block*>& blocks);
static void AssignAndActivate(const Sequences& sequences, //
const std::vector<int>& counts,
const BlockIds& blocks,
const UniqueIds& unique_ids);
private:
int block_seq_len_;
......@@ -131,8 +134,8 @@ private:
std::unique_ptr<BlockManager> block_manager_;
std::vector<const Block*> unlocked_;
std::vector<const Block*> freed_;
BlockIds unlocked_;
BlockIds freed_;
};
inline std::ostream& operator<<(std::ostream& os, const SequenceManager::Outcome& oc)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment