Unverified Commit a54b16a2 authored by Li Zhang's avatar Li Zhang Committed by GitHub
Browse files

Simplify block manager (#812)

* simplify block manager

* fix lint
parent 2d5f5b30
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include "src/turbomind/utils/cuda_utils.h" #include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/debug_utils.h" #include "src/turbomind/utils/debug_utils.h"
#include "src/turbomind/utils/logger.h" #include "src/turbomind/utils/logger.h"
#include "src/turbomind/utils/string_utils.h"
#include <algorithm> #include <algorithm>
#include <iterator> #include <iterator>
#include <stdexcept> #include <stdexcept>
...@@ -70,7 +71,6 @@ bool BlockManager::Malloc() ...@@ -70,7 +71,6 @@ bool BlockManager::Malloc()
for (int i = 0; i < chunk_size; ++i, ptr += block_size_) { for (int i = 0; i < chunk_size; ++i, ptr += block_size_) {
auto& block = blocks_.emplace_back(); auto& block = blocks_.emplace_back();
block.use_count = 0; block.use_count = 0;
block.ref_count = 0;
block.id = (int)blocks_.size() - 1; block.id = (int)blocks_.size() - 1;
block.timestamp = 0; block.timestamp = 0;
block.data = ptr; block.data = ptr;
...@@ -91,16 +91,23 @@ size_t BlockManager::GetBlockCount(size_t block_size, double ratio) ...@@ -91,16 +91,23 @@ size_t BlockManager::GetBlockCount(size_t block_size, double ratio)
void BlockManager::Move(std::vector<int>& src, const std::vector<int>& delta, std::vector<int>& dst) void BlockManager::Move(std::vector<int>& src, const std::vector<int>& delta, std::vector<int>& dst)
{ {
FT_CHECK(src.size() >= delta.size());
std::vector<int> src1(src.size() - delta.size()); std::vector<int> src1(src.size() - delta.size());
std::set_difference(src.begin(), src.end(), delta.begin(), delta.end(), src1.begin()); {
auto end = std::set_difference(src.begin(), src.end(), delta.begin(), delta.end(), src1.begin());
FT_CHECK(end == src1.end());
}
src.swap(src1); src.swap(src1);
std::vector<int> dst1(dst.size() + delta.size()); std::vector<int> dst1(dst.size() + delta.size());
std::set_union(dst.begin(), dst.end(), delta.begin(), delta.end(), dst1.begin()); {
auto end = std::set_union(dst.begin(), dst.end(), delta.begin(), delta.end(), dst1.begin());
FT_CHECK(end == dst1.end());
}
dst.swap(dst1); dst.swap(dst1);
} }
std::vector<const Block*> BlockManager::Allocate(int count) auto BlockManager::Allocate(int count) -> std::pair<BlockIds, UniqueIds>
{ {
while (free_ids_.size() < count) { while (free_ids_.size() < count) {
if (!Malloc()) { if (!Malloc()) {
...@@ -108,30 +115,30 @@ std::vector<const Block*> BlockManager::Allocate(int count) ...@@ -108,30 +115,30 @@ std::vector<const Block*> BlockManager::Allocate(int count)
} }
} }
std::vector<const Block*> ret; BlockIds block_ids(count);
UniqueIds unique_ids(count);
std::vector<int> idxs(count);
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
int idx = free_ids_[i]; int idx = free_ids_[i];
idxs[i] = idx; auto& b = blocks_[idx];
auto& block = blocks_[idx]; FT_CHECK(is_free(b)); // pre-condition: uc == 0 && ts == 0
FT_CHECK(is_free(block)); b.use_count = 1;
block.ref_count = 1; b.unique_id = unique_id_++;
block.use_count = 1; FT_CHECK(is_active(b)); // post-condition
block.unique_id = unique_id_++; block_ids[i] = idx;
ret.push_back(&block); unique_ids[i] = b.unique_id;
} }
Move(free_ids_, idxs, active_ids_); Move(free_ids_, block_ids, active_ids_);
dbg(free_ids_, active_ids_); dbg(free_ids_, active_ids_);
return ret; return {block_ids, unique_ids};
} }
void BlockManager::Evict(int count) void BlockManager::Evict(int count)
{ {
FT_CHECK(count <= cached_ids_.size());
std::vector<int> idxs(cached_ids_); std::vector<int> idxs(cached_ids_);
// get first `count` cached ids according to timestamp // get first `count` cached ids according to timestamp
std::nth_element(idxs.begin(), idxs.begin() + count, idxs.end(), [&](int i, int j) { std::nth_element(idxs.begin(), idxs.begin() + count, idxs.end(), [&](int i, int j) {
...@@ -146,9 +153,9 @@ void BlockManager::Evict(int count) ...@@ -146,9 +153,9 @@ void BlockManager::Evict(int count)
for (const auto& idx : idxs) { for (const auto& idx : idxs) {
auto& b = blocks_[idx]; auto& b = blocks_[idx];
FT_CHECK(is_cached(b)); FT_CHECK(is_cached(b));
b.ref_count = 0;
b.unique_id = 0; b.unique_id = 0;
b.timestamp = 0; b.timestamp = 0;
FT_CHECK(is_free(b));
} }
Move(cached_ids_, idxs, free_ids_); Move(cached_ids_, idxs, free_ids_);
...@@ -156,79 +163,94 @@ void BlockManager::Evict(int count) ...@@ -156,79 +163,94 @@ void BlockManager::Evict(int count)
dbg(cached_ids_, free_ids_); dbg(cached_ids_, free_ids_);
} }
int BlockManager::Free(const std::vector<const Block*>& bs) void BlockManager::Free(BlockIds ids)
{ {
std::vector<int> idxs; std::sort(ids.begin(), ids.end());
for (const auto& p : bs) { for (const auto& i : ids) {
auto& b = blocks_[p->id]; auto& b = blocks_[i];
FT_CHECK(is_cached(b)); FT_CHECK(is_cached(b)); // uc == 0 && ts != 0
if (--b.ref_count == 0) { b.unique_id = 0;
b.unique_id = 0; b.timestamp = 0;
b.timestamp = 0; FT_CHECK(is_free(b));
idxs.push_back(b.id);
}
} }
std::sort(idxs.begin(), idxs.end()); Move(cached_ids_, ids, free_ids_);
Move(cached_ids_, idxs, free_ids_);
dbg(cached_ids_, free_ids_);
return idxs.size();
} }
int BlockManager::Unlock(const std::vector<const Block*>& bs) int BlockManager::Unlock(const BlockIds& ids)
{ {
std::vector<int> idxs; BlockIds unlock;
unlock.reserve(ids.size());
for (const auto& p : bs) {
auto& block = blocks_[p->id]; for (const auto& i : ids) {
FT_CHECK(is_active(block)); auto& b = blocks_[i];
if (--block.use_count == 0) { FT_CHECK(is_active(b)); // pre-condition: uc > 0
idxs.push_back(block.id); if (--b.use_count == 0) {
unlock.push_back(b.id);
FT_CHECK(is_cached(b)); // post-condition
} }
} }
std::sort(idxs.begin(), idxs.end()); std::sort(unlock.begin(), unlock.end());
Move(active_ids_, idxs, cached_ids_); Move(active_ids_, unlock, cached_ids_);
dbg(active_ids_, cached_ids_); dbg(active_ids_, cached_ids_);
return unlock.size();
return idxs.size();
} }
int BlockManager::Lock(const std::vector<const Block*>& bs) int BlockManager::Lock(const BlockIds& ids)
{ {
std::vector<int> idxs; BlockIds lock;
lock.reserve(ids.size());
for (const auto& p : bs) { for (const auto& i : ids) {
auto& block = blocks_[p->id]; auto& b = blocks_[i];
FT_CHECK(is_cached(block)); FT_CHECK(is_cached(b));
if (++block.use_count == 1) { if (++b.use_count == 1) {
idxs.push_back(p->id); lock.push_back(i);
FT_CHECK(is_active(b));
} }
} }
std::sort(idxs.begin(), idxs.end()); std::sort(lock.begin(), lock.end());
Move(cached_ids_, idxs, active_ids_); Move(cached_ids_, lock, active_ids_);
// dbg(cached_ids_, active_ids_); // dbg(cached_ids_, active_ids_);
return idxs.size(); return lock.size();
} }
void BlockManager::Touch(const std::vector<const Block*>& bs) void BlockManager::Touch(const BlockIds& ids)
{ {
std::for_each(bs.crbegin(), bs.crend(), [this](const Block* p) { std::for_each(ids.crbegin(), ids.crend(), [this](int i) {
FT_CHECK(is_active(*p)); FT_CHECK(is_active(blocks_[i]));
const_cast<Block*>(p)->timestamp = timestamp_++; blocks_[i].timestamp = timestamp_++;
}); });
} }
int BlockManager::Verify(const std::vector<int>& block_ids, const std::vector<uint64_t>& unique_ids)
{
FT_CHECK(block_ids.size() == unique_ids.size());
int valid = block_ids.size();
for (int i = 0; i < block_ids.size(); ++i) {
if (unique_id(block_ids[i]) != unique_ids[i]) {
valid = i;
break;
}
}
int miss = 0;
for (int i = valid; i < block_ids.size(); ++i) {
miss += (unique_id(block_ids[i]) != unique_ids[i]);
}
// All later blocks should have been invalidated
FT_CHECK_WITH_INFO(miss == (int)block_ids.size() - valid,
fmtstr("count = %d, valid = %d, miss = %d", (int)block_ids.size(), valid, miss));
return valid;
}
Snapshot BlockManager::TakeSnapshot() Snapshot BlockManager::TakeSnapshot()
{ {
std::vector<int> use_count(blocks_.size()); std::vector<int> use_count(blocks_.size());
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <iterator> #include <iterator>
#include <numeric> #include <numeric>
#include <queue> #include <queue>
#include <sstream>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
...@@ -22,28 +23,37 @@ namespace turbomind { ...@@ -22,28 +23,37 @@ namespace turbomind {
struct Block { struct Block {
int id; // fixed linear id in the pool int id; // fixed linear id in the pool
int ref_count; // all sequences referencing the block
int use_count; // active sequences using the block int use_count; // active sequences using the block
uint64_t unique_id; // unique for every block allocation uint64_t unique_id; // unique for every block allocation
uint64_t timestamp; uint64_t timestamp;
void* data; void* data;
friend std::ostream& operator<<(std::ostream& os, const Block& block); friend std::ostream& operator<<(std::ostream& os, const Block& block);
friend std::string to_string(const Block& b)
{
std::stringstream ss;
ss << b;
return ss.str();
}
}; };
using BlockIds = std::vector<int>;
using UniqueIds = std::vector<uint64_t>;
inline bool is_active(const Block& block) inline bool is_active(const Block& block)
{ {
return block.ref_count > 0 && block.use_count > 0; // timestamp may be 0 for newly allocated block that has not been written
return block.use_count > 0;
} }
inline bool is_cached(const Block& block) inline bool is_cached(const Block& block)
{ {
return block.ref_count > 0 && block.use_count == 0; return block.use_count == 0 && block.timestamp != 0;
} }
inline bool is_free(const Block& block) inline bool is_free(const Block& block)
{ {
return block.ref_count == 0 && block.use_count == 0 && block.timestamp == 0; return block.use_count == 0 && block.timestamp == 0;
} }
struct Snapshot { struct Snapshot {
...@@ -60,22 +70,24 @@ public: ...@@ -60,22 +70,24 @@ public:
~BlockManager(); ~BlockManager();
// free -> active (use_count = 1, ref_count = 1) // free -> active (use_count = 1, ref_count = 1)
[[nodiscard]] std::vector<const Block*> Allocate(int count); [[nodiscard]] std::pair<BlockIds, UniqueIds> Allocate(int count);
// cached -> active (use_count += 1) // cached -> active (use_count += 1)
[[maybe_unused]] int Lock(const std::vector<const Block*>& bs); [[maybe_unused]] int Lock(const BlockIds& ids);
// active -> cached (use_count -= 1) // active -> cached (use_count -= 1)
[[maybe_unused]] int Unlock(const std::vector<const Block*>& bs); [[maybe_unused]] int Unlock(const BlockIds& ids);
// cached -> free (ref_count = 0) // cached -> free (ref_count = 0)
void Evict(int count); void Evict(int count);
// cached -> free (ref_count -= 1) // cached -> free (ref_count -= 1)
[[maybe_unused]] int Free(const std::vector<const Block*>& bs); void Free(BlockIds bs);
// increase timestamp in reversed order // increase timestamp in reversed order
void Touch(const std::vector<const Block*>& bs); void Touch(const BlockIds& bs);
[[nodiscard]] int Verify(const BlockIds& block_ids, const UniqueIds& unique_ids);
Snapshot TakeSnapshot(); Snapshot TakeSnapshot();
...@@ -99,13 +111,23 @@ public: ...@@ -99,13 +111,23 @@ public:
return (max_block_count_ - blocks_.size()) + free_ids_.size(); return (max_block_count_ - blocks_.size()) + free_ids_.size();
} }
Block& block(int idx)
{
return blocks_[idx];
}
int unique_id(int idx)
{
return blocks_[idx].unique_id;
}
friend std::ostream& operator<<(std::ostream& os, const BlockManager&); friend std::ostream& operator<<(std::ostream& os, const BlockManager&);
private: private:
static size_t GetBlockCount(size_t block_size, double ratio); static size_t GetBlockCount(size_t block_size, double ratio);
// move indices between sets // move indices between sets
static void Move(std::vector<int>& src, const std::vector<int>& delta, std::vector<int>& dst); static void Move(BlockIds& src, const BlockIds& delta, BlockIds& dst);
// allocate a chunk of blocks // allocate a chunk of blocks
bool Malloc(); bool Malloc();
...@@ -118,13 +140,12 @@ private: ...@@ -118,13 +140,12 @@ private:
std::vector<void*> chunks_; std::vector<void*> chunks_;
std::vector<int> active_ids_; BlockIds active_ids_;
std::vector<int> cached_ids_; BlockIds cached_ids_;
std::vector<int> free_ids_; BlockIds free_ids_;
std::vector<Block> blocks_; // < 100k std::vector<Block> blocks_; // < 100k
// uint64_t unique_id_{1UL << 63};
uint64_t unique_id_{1}; uint64_t unique_id_{1};
uint64_t timestamp_{1}; uint64_t timestamp_{1};
}; };
......
...@@ -505,11 +505,11 @@ void LlamaBatch<T>::Initialize(GenerationState& g) ...@@ -505,11 +505,11 @@ void LlamaBatch<T>::Initialize(GenerationState& g)
FT_CHECK_WITH_INFO(h_cu_block_counts_[i + 1] <= sequence_manager_->max_block_count(), FT_CHECK_WITH_INFO(h_cu_block_counts_[i + 1] <= sequence_manager_->max_block_count(),
std::to_string(h_cu_block_counts_[i + 1])); std::to_string(h_cu_block_counts_[i + 1]));
k_ptrs = std::transform(seq.blocks.cbegin(), seq.blocks.cend(), k_ptrs, [&](auto p) { k_ptrs = std::transform(seq.blocks.cbegin(), seq.blocks.cend(), k_ptrs, [&](int block_id) {
return reinterpret_cast<uintptr_t>(sequence_manager_->OffsetKey(p->data)); return reinterpret_cast<uintptr_t>(sequence_manager_->GetKeyPtr(block_id));
}); });
v_ptrs = std::transform(seq.blocks.cbegin(), seq.blocks.cend(), v_ptrs, [&](auto p) { v_ptrs = std::transform(seq.blocks.cbegin(), seq.blocks.cend(), v_ptrs, [&](int block_id) {
return reinterpret_cast<uintptr_t>(sequence_manager_->OffsetVal(p->data)); return reinterpret_cast<uintptr_t>(sequence_manager_->GetValPtr(block_id));
}); });
} }
......
...@@ -37,29 +37,20 @@ SequenceManager::SequenceManager(size_t layer_num, ...@@ -37,29 +37,20 @@ SequenceManager::SequenceManager(size_t layer_num,
const Sequence* SequenceManager::Create(uint64_t id) const Sequence* SequenceManager::Create(uint64_t id)
{ {
Sequence sequence{id}; Sequence sequence{id};
auto it = sequences_.find(id);
auto it = sequences_.find(id);
if (it != sequences_.end()) { if (it != sequences_.end()) {
if (rank_ == 0) { if (rank_ == 0) {
TM_LOG_WARNING("[SequenceManager][Create] Removing conflicting ID %ld", (long)id); TM_LOG_WARNING("[SequenceManager][Create] Removing conflicting ID %ld", (long)id);
} }
auto& seq = it->second; Erase(it);
if (seq.status != Sequence::kCached) {
unlocked_.insert(unlocked_.end(), seq.blocks.begin(), seq.blocks.end());
}
seq = std::move(sequence);
} }
else { it = sequences_.emplace_hint(it, id, std::move(sequence));
it = sequences_.emplace_hint(it, id, std::move(sequence));
}
return &it->second; return &it->second;
} }
const Sequence* SequenceManager::Get(uint64_t id) const Sequence* SequenceManager::Get(uint64_t id)
{ {
if (auto it = sequences_.find(id); it != sequences_.end()) { if (auto it = sequences_.find(id); it != sequences_.end()) {
auto& sequence = it->second;
return &it->second; return &it->second;
} }
return nullptr; return nullptr;
...@@ -70,23 +61,24 @@ bool SequenceManager::Contains(uint64_t id) ...@@ -70,23 +61,24 @@ bool SequenceManager::Contains(uint64_t id)
return sequences_.find(id) != sequences_.end(); return sequences_.find(id) != sequences_.end();
} }
void SequenceManager::Erase(std::map<uint64_t, Sequence>::iterator it)
{
auto& seq = it->second;
if (seq.status == Sequence::kCached) {
const int count = block_manager_->Verify(seq.blocks, seq.block_unique_ids);
seq.blocks.resize(count);
}
else {
UpdateAndSetUnlock(seq);
}
freed_.insert(freed_.end(), seq.blocks.begin(), seq.blocks.end());
sequences_.erase(it);
}
bool SequenceManager::Erase(uint64_t id) bool SequenceManager::Erase(uint64_t id)
{ {
if (auto it = sequences_.find(id); it != sequences_.end()) { if (auto it = sequences_.find(id); it != sequences_.end()) {
auto& seq = it->second; Erase(it);
if (seq.status != Sequence::kCached) {
unlocked_.insert(unlocked_.end(), seq.blocks.begin(), seq.blocks.end());
freed_.insert(freed_.end(), seq.blocks.begin(), seq.blocks.end());
}
else {
for (int i = 0; i < seq.blocks.size(); ++i) {
// filter invalidated blocks
if (seq.blocks[i]->unique_id == seq.block_unique_ids[i]) {
freed_.push_back(seq.blocks[i]);
}
}
}
sequences_.erase(it);
return true; return true;
} }
return false; return false;
...@@ -94,7 +86,7 @@ bool SequenceManager::Erase(uint64_t id) ...@@ -94,7 +86,7 @@ bool SequenceManager::Erase(uint64_t id)
void SequenceManager::VerifyAndLockCached(const Sequences& sequences) void SequenceManager::VerifyAndLockCached(const Sequences& sequences)
{ {
std::vector<const Block*> blocks; BlockIds blocks;
for (const auto& p : sequences) { for (const auto& p : sequences) {
auto& seq = const_cast<Sequence&>(*p); auto& seq = const_cast<Sequence&>(*p);
if (seq.status != Sequence::kCached) { if (seq.status != Sequence::kCached) {
...@@ -102,13 +94,9 @@ void SequenceManager::VerifyAndLockCached(const Sequences& sequences) ...@@ -102,13 +94,9 @@ void SequenceManager::VerifyAndLockCached(const Sequences& sequences)
} }
FT_CHECK(seq.blocks.size() == seq.block_unique_ids.size()); FT_CHECK(seq.blocks.size() == seq.block_unique_ids.size());
if (need_verify_) { if (need_verify_) {
for (int i = 0; i < seq.blocks.size(); ++i) { const int count = block_manager_->Verify(seq.blocks, seq.block_unique_ids);
if (seq.blocks[i]->unique_id != seq.block_unique_ids[i]) { seq.blocks.resize(count);
seq.blocks.resize(i); seq.block_unique_ids.resize(count);
seq.block_unique_ids.resize(i);
break;
}
}
} }
blocks.insert(blocks.end(), seq.blocks.begin(), seq.blocks.end()); blocks.insert(blocks.end(), seq.blocks.begin(), seq.blocks.end());
seq.cache_len = std::min<int>(seq.cache_len, seq.blocks.size() * block_seq_len_); seq.cache_len = std::min<int>(seq.cache_len, seq.blocks.size() * block_seq_len_);
...@@ -177,8 +165,8 @@ struct Schedule { ...@@ -177,8 +165,8 @@ struct Schedule {
while (vidx < it_) { while (vidx < it_) {
const auto& blocks = seqs[--it_]->blocks; const auto& blocks = seqs[--it_]->blocks;
int count = 0; int count = 0;
for (const auto& p : blocks) { for (const auto& bid : blocks) {
count += static_cast<int>(--use_count_[p->id] == 0); count += static_cast<int>(--use_count_[bid] == 0);
} }
unlocked_[it_] = count; unlocked_[it_] = count;
} }
...@@ -354,21 +342,20 @@ std::vector<int> SequenceManager::CountRequiredBlocks(const Sequences& se ...@@ -354,21 +342,20 @@ std::vector<int> SequenceManager::CountRequiredBlocks(const Sequences& se
return required; return required;
} }
void SequenceManager::AssignAndActivate(const Sequences& sequences, // void SequenceManager::AssignAndActivate(const Sequences& sequences, //
const std::vector<int>& counts, const std::vector<int>& counts,
const std::vector<const Block*>& blocks) const BlockIds& blocks,
const UniqueIds& unique_ids)
{ {
FT_CHECK(sequences.size() == counts.size()); FT_CHECK(sequences.size() == counts.size());
auto first = blocks.begin(); int first = 0;
for (int i = 0; i < sequences.size(); ++i) { for (int i = 0; i < sequences.size(); ++i) {
auto& s = const_cast<Sequence&>(*sequences[i]); auto& s = const_cast<Sequence&>(*sequences[i]);
auto count = counts[i]; auto count = counts[i];
// dbg(count); int last = first + count;
auto last = first + count; FT_CHECK(last <= blocks.size());
std::for_each(first, last, [&](const Block* b) { s.blocks.insert(s.blocks.end(), blocks.begin() + first, blocks.begin() + last);
s.blocks.push_back(b); s.block_unique_ids.insert(s.block_unique_ids.end(), unique_ids.begin() + first, unique_ids.begin() + last);
s.block_unique_ids.push_back(b->unique_id);
});
s.status = Sequence::kActive; s.status = Sequence::kActive;
first = last; first = last;
} }
...@@ -453,11 +440,12 @@ auto SequenceManager::Materialize(Sequences sequences, ...@@ -453,11 +440,12 @@ auto SequenceManager::Materialize(Sequences sequences,
// allocate & assign blocks // allocate & assign blocks
{ {
std::vector<const Block*> blocks; BlockIds block_ids;
UniqueIds unique_ids;
if (schedule.allocate) { if (schedule.allocate) {
blocks = block_manager_->Allocate(schedule.allocate); std::tie(block_ids, unique_ids) = block_manager_->Allocate(schedule.allocate);
} }
AssignAndActivate(schedule.active, schedule.block_counts, blocks); AssignAndActivate(schedule.active, schedule.block_counts, block_ids, unique_ids);
} }
// active -> locked // active -> locked
...@@ -467,6 +455,11 @@ auto SequenceManager::Materialize(Sequences sequences, ...@@ -467,6 +455,11 @@ auto SequenceManager::Materialize(Sequences sequences,
} }
} }
// TM_LOG_ERROR("active: %4d, cached: %4d, free: %4d",
// block_manager_->active_count(),
// block_manager_->cached_count(),
// block_manager_->free_count());
return outcome; return outcome;
} }
......
...@@ -19,8 +19,8 @@ struct Sequence { ...@@ -19,8 +19,8 @@ struct Sequence {
uint64_t id; uint64_t id;
Status status = kCached; Status status = kCached;
std::vector<const Block*> blocks; BlockIds blocks;
std::vector<uint64_t> block_unique_ids; UniqueIds block_unique_ids;
int input_length = 0; int input_length = 0;
...@@ -33,7 +33,7 @@ struct Sequence { ...@@ -33,7 +33,7 @@ struct Sequence {
mutable float rope_theta = 0.f; mutable float rope_theta = 0.f;
Sequence(uint64_t _id): id(_id) {} explicit Sequence(uint64_t _id): id(_id) {}
friend std::ostream& operator<<(std::ostream& os, const Sequence& seq); friend std::ostream& operator<<(std::ostream& os, const Sequence& seq);
}; };
...@@ -87,14 +87,14 @@ public: ...@@ -87,14 +87,14 @@ public:
int step_length, int step_length,
AdjustInputCount adjust); AdjustInputCount adjust);
void* OffsetKey(void* block_ptr) [[nodiscard]] void* GetKeyPtr(int block_id)
{ {
return block_ptr; return block_manager_->block(block_id).data;
} }
void* OffsetVal(void* block_ptr) [[nodiscard]] void* GetValPtr(int block_id)
{ {
return (std::byte*)block_ptr + val_offset_; return (std::byte*)GetKeyPtr(block_id) + val_offset_;
} }
int max_block_count() const noexcept int max_block_count() const noexcept
...@@ -103,6 +103,8 @@ public: ...@@ -103,6 +103,8 @@ public:
} }
private: private:
void Erase(std::map<uint64_t, Sequence>::iterator it);
void CommitUnlockAndFree(); void CommitUnlockAndFree();
void VerifyAndLockCached(const Sequences& sequences); void VerifyAndLockCached(const Sequences& sequences);
...@@ -115,9 +117,10 @@ private: ...@@ -115,9 +117,10 @@ private:
std::vector<int>& context_lengths, std::vector<int>& context_lengths,
const std::vector<uint64_t>& priorities); const std::vector<uint64_t>& priorities);
static void AssignAndActivate(const Sequences& sequences, // static void AssignAndActivate(const Sequences& sequences, //
const std::vector<int>& block_counts, const std::vector<int>& counts,
const std::vector<const Block*>& blocks); const BlockIds& blocks,
const UniqueIds& unique_ids);
private: private:
int block_seq_len_; int block_seq_len_;
...@@ -131,8 +134,8 @@ private: ...@@ -131,8 +134,8 @@ private:
std::unique_ptr<BlockManager> block_manager_; std::unique_ptr<BlockManager> block_manager_;
std::vector<const Block*> unlocked_; BlockIds unlocked_;
std::vector<const Block*> freed_; BlockIds freed_;
}; };
inline std::ostream& operator<<(std::ostream& os, const SequenceManager::Outcome& oc) inline std::ostream& operator<<(std::ostream& os, const SequenceManager::Outcome& oc)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment