SequenceManager.cc 14.7 KB
Newer Older
Li Zhang's avatar
Li Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/models/llama/SequenceManager.h"

#include <algorithm>
#include <cstddef>
#include <cstdlib>
#include <ctime>
#include <numeric>
#include <stdexcept>

#include "src/turbomind/models/llama/BlockManager.h"
#include "src/turbomind/utils/allocator.h"
#include "src/turbomind/utils/debug_utils.h"
#include "src/turbomind/utils/logger.h"

namespace turbomind {

/// Creates the block manager that backs the KV cache of all sequences.
///
/// One block holds `block_seq_len` tokens for all layers and heads, laid out as
/// [2, L, H, block_seq_len, D]; its byte size is derived from `elem_bits`.
/// @param rank  rank of this process; stored so that only rank 0 logs warnings in Create()
SequenceManager::SequenceManager(size_t      layer_num,
                                 size_t      head_num,
                                 size_t      head_dim,
                                 size_t      block_seq_len,
                                 double      block_count,
                                 int         chunk_size,
                                 size_t      elem_bits,
                                 int         rank,
                                 IAllocator* allocator):
    block_seq_len_(block_seq_len)
{
    constexpr int kBitsPerByte = 8;

    // `rank_` is read by Create(); previously the parameter was accepted but never stored.
    // Assigned here (not in the init list) to stay independent of member declaration order.
    rank_ = rank;

    // [2, L, H, block_seq_len, D]
    size_t block_size = 2UL * layer_num * head_num * block_seq_len * head_dim * elem_bits / kBitsPerByte;

    block_manager_ = std::make_unique<BlockManager>(block_size, block_count, chunk_size, allocator);

    // byte offset of the second half of the leading `2` dimension (presumably the value cache)
    val_offset_ = block_size / 2;
}

/// Creates a fresh sequence with the given id and returns a pointer to it.
///
/// If a sequence with the same id already exists it is erased first (a warning is logged on
/// rank 0); its blocks are queued for deferred unlock/free by Erase().
const Sequence* SequenceManager::Create(uint64_t id)
{
    Sequence sequence{id};
    if (auto it = sequences_.find(id); it != sequences_.end()) {
        if (rank_ == 0) {
            TM_LOG_WARNING("[SequenceManager][Create] Removing conflicting ID %ld", static_cast<long>(id));
        }
        // Erase() removes `it` from the map, invalidating it; it must not be used as a hint afterwards
        Erase(it);
    }
    // `id` is guaranteed to be absent at this point, so the insertion always takes place
    const auto pos = sequences_.emplace(id, std::move(sequence)).first;
    return &pos->second;
}

/// Looks up a sequence by id; returns nullptr when no sequence with that id is tracked.
const Sequence* SequenceManager::Get(uint64_t id)
{
    const auto found = sequences_.find(id);
    return found == sequences_.end() ? nullptr : &found->second;
}

/// Returns true when a sequence with the given id is tracked.
bool SequenceManager::Contains(uint64_t id)
{
    return sequences_.count(id) != 0;
}

Li Zhang's avatar
Li Zhang committed
64
65
66
67
68
69
70
71
72
73
74
75
76
77
void SequenceManager::Erase(std::map<uint64_t, Sequence>::iterator it)
{
    auto& seq = it->second;
    if (seq.status == Sequence::kCached) {
        const int count = block_manager_->Verify(seq.blocks, seq.block_unique_ids);
        seq.blocks.resize(count);
    }
    else {
        UpdateAndSetUnlock(seq);
    }
    freed_.insert(freed_.end(), seq.blocks.begin(), seq.blocks.end());
    sequences_.erase(it);
}

Li Zhang's avatar
Li Zhang committed
78
79
80
/// Erases the sequence with the given id, if present.
/// @return true when a sequence was removed, false when the id was unknown
bool SequenceManager::Erase(uint64_t id)
{
    const auto it = sequences_.find(id);
    if (it == sequences_.end()) {
        return false;
    }
    Erase(it);
    return true;
}

void SequenceManager::VerifyAndLockCached(const Sequences& sequences)
{
Li Zhang's avatar
Li Zhang committed
89
    BlockIds blocks;
Li Zhang's avatar
Li Zhang committed
90
91
92
93
94
95
    for (const auto& p : sequences) {
        auto& seq = const_cast<Sequence&>(*p);
        if (seq.status != Sequence::kCached) {
            continue;
        }
        FT_CHECK(seq.blocks.size() == seq.block_unique_ids.size());
Li Zhang's avatar
Li Zhang committed
96
97
98
99
100
        // Verify cache blocks that may be invalidated
        const int count = block_manager_->Verify(seq.blocks, seq.block_unique_ids);
        seq.blocks.resize(count);
        seq.block_unique_ids.resize(count);

Li Zhang's avatar
Li Zhang committed
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
        blocks.insert(blocks.end(), seq.blocks.begin(), seq.blocks.end());
        seq.cache_len = std::min<int>(seq.cache_len, seq.blocks.size() * block_seq_len_);
        seq.status    = Sequence::kLocked;
    }
    block_manager_->Lock(blocks);
}

/// Applies the deferred block operations queued in `unlocked_` and `freed_`, then clears them.
///
/// Unlock runs before free: Erase() of a non-cached sequence queues the same blocks in both lists.
void SequenceManager::CommitUnlockAndFree()
{
    if (!unlocked_.empty()) {
        block_manager_->Unlock(unlocked_);
        unlocked_.clear();
    }
    if (freed_.empty()) {
        return;
    }
    block_manager_->Free(freed_);
    freed_.clear();
}

void SequenceManager::UpdateAndSetUnlock(const Sequence& sequence)
{
    FT_CHECK(sequence.status != Sequence::kCached);
    auto& seq = const_cast<Sequence&>(sequence);
    block_manager_->Touch(seq.blocks);
    unlocked_.insert(unlocked_.end(), seq.blocks.begin(), seq.blocks.end());
    seq.status = Sequence::kCached;
}

namespace {

struct Schedule {
    int free;
    int cached;

    int allocate{};
    int evict{};
    int preempt{};

    int last;

142
143
144
    int input_count1;
    int input_count2;

Li Zhang's avatar
Li Zhang committed
145
146
147
148
149
    Sequences        active;
    std::vector<int> block_counts;
    Sequences        inactive;
    Sequences        victims;

150
    Schedule(Snapshot snapshot, int size, int _input_count1, int _input_count2):
Li Zhang's avatar
Li Zhang committed
151
152
153
154
155
        free(snapshot.free),
        cached(snapshot.cached),
        last(size),
        use_count_(std::move(snapshot.use_count)),
        unlocked_(size),
156
157
158
        it_(size),
        input_count1(_input_count1),
        input_count2(_input_count2)
Li Zhang's avatar
Li Zhang committed
159
160
161
162
163
164
165
166
    {
    }

    int Unlock(const Sequences& seqs, int vidx)
    {
        while (vidx < it_) {
            const auto& blocks = seqs[--it_]->blocks;
            int         count  = 0;
Li Zhang's avatar
Li Zhang committed
167
168
            for (const auto& bid : blocks) {
                count += static_cast<int>(--use_count_[bid] == 0);
Li Zhang's avatar
Li Zhang committed
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
            }
            unlocked_[it_] = count;
        }
        return unlocked_[vidx];
    }

private:
    std::vector<int> use_count_;
    std::vector<int> unlocked_;
    int              it_;
};

/// Debug printer for vectors: writes the elements as `[a,b,c]` (`[]` when empty).
template<typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
    os << "[";
    // size_t index: avoids the signed/unsigned comparison with v.size()
    for (size_t i = 0; i < v.size(); ++i) {
        os << (i ? "," : "") << v[i];
    }
    os << "]";
    return os;
}

/// Debug printer for Schedule: resource counters first, then the sequence lists.
std::ostream& operator<<(std::ostream& os, const Schedule& s)
{
    // counters
    os << "free=" << s.free;
    os << ", cached=" << s.cached;
    os << ", allocate=" << s.allocate;
    os << ", evict=" << s.evict;
    os << ", preempt=" << s.preempt;
    // sequence lists
    os << ", active=" << s.active;
    os << ", victims=" << s.victims;
    os << ", block_counts=" << s.block_counts;
    os << ", inactive=" << s.inactive;
    return os;
}

struct Transaction {
    int index_;
    int block_count_;
203
    int input_count_;
Li Zhang's avatar
Li Zhang committed
204
205
206
207
208
209
210
211
212
213

    int allocate_{};
    int evict_{};
    int preempt_{};

    Sequences victims_;

    const Sequences& sequences_;
    Schedule&        schedule_;

214
215
    explicit Transaction(const Sequences& sequences, int index, int block_count, int input_count, Schedule& sched):
        sequences_(sequences), schedule_(sched), index_(index), block_count_(block_count), input_count_(input_count)
Li Zhang's avatar
Li Zhang committed
216
217
218
219
220
    {
    }

    void Process()
    {
221
222
223
224
225
226
        if (schedule_.input_count1 > 0) {
            int count = block_count_;

            int tmp = std::min(schedule_.free, count);
            count -= tmp;
            allocate_ += tmp;
Li Zhang's avatar
Li Zhang committed
227

228
229
230
            tmp = std::min(schedule_.cached, count);
            count -= tmp;
            evict_ += tmp;
Li Zhang's avatar
Li Zhang committed
231

232
233
234
235
236
237
            for (int vidx = schedule_.last - 1; count && vidx > index_; --vidx) {
                if (sequences_[vidx]->status == Sequence::kCached) {
                    continue;
                }
                victims_.push_back(sequences_[vidx]);
                preempt_ += schedule_.Unlock(sequences_, vidx);
Li Zhang's avatar
Li Zhang committed
238

239
240
241
242
243
244
                if (count <= preempt_) {
                    evict_ += count;
                    count -= count;
                    schedule_.last = vidx;  // ! modifiying `sched_.last` is part of commit
                    break;
                }
Li Zhang's avatar
Li Zhang committed
245
            }
246
247
            if (count == 0) {
                return Commit();
Li Zhang's avatar
Li Zhang committed
248
249
250
            }
        }

251
252
        const_cast<Sequence*>(sequences_[index_])->input_length = 0;
        schedule_.inactive.push_back(sequences_[index_]);
Li Zhang's avatar
Li Zhang committed
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
    }

    void Commit()
    {
        // update available resources
        schedule_.free -= allocate_;
        FT_CHECK(schedule_.free >= 0);
        schedule_.cached += preempt_;
        schedule_.cached -= evict_;
        FT_CHECK(schedule_.cached >= 0);

        // update scheduled operations
        schedule_.allocate += allocate_;
        schedule_.evict += evict_;
        schedule_.preempt += preempt_;
        schedule_.victims.insert(schedule_.victims.end(), victims_.begin(), victims_.end());

        // update active sequences
        schedule_.active.push_back(sequences_[index_]);
        schedule_.block_counts.push_back(block_count_);
273
274
275
276
277
278
279

        if (input_count_ > schedule_.input_count2) {
            input_count_ = schedule_.input_count1;
        }
        schedule_.input_count1 -= input_count_;
        schedule_.input_count2 -= input_count_;
        const_cast<Sequence*>(sequences_[index_])->input_length = input_count_;
Li Zhang's avatar
Li Zhang committed
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
    }
};

/// Debug printer for Transaction: target index, requested blocks, then the tentative operations.
std::ostream& operator<<(std::ostream& os, const Transaction& trans)
{
    os << "index=" << trans.index_ << ", block_count=" << trans.block_count_;
    os << ", allocate=" << trans.allocate_ << ", evict=" << trans.evict_ << ", preempt=" << trans.preempt_;
    os << ", victims=" << trans.victims_;
    return os;
}

}  // namespace

/// Reorders `sequences` and `context_lengths` together by ascending `priorities`.
///
/// The sort is stable, so sequences with equal priority keep their input order and the
/// resulting schedule is deterministic.
void SequenceManager::SortByPriority(Sequences&                   sequences,
                                     std::vector<int>&            context_lengths,
                                     const std::vector<uint64_t>& priorities)
{
    // the three ranges are indexed in lock-step below
    FT_CHECK(context_lengths.size() == sequences.size());
    FT_CHECK(priorities.size() == sequences.size());

    // sort according to priority
    std::vector<int> idxs(sequences.size());
    std::iota(idxs.begin(), idxs.end(), 0);
    std::stable_sort(idxs.begin(), idxs.end(), [&](int i, int j) {
        return priorities[i] < priorities[j];  //
    });
    Sequences        tmp_sequences(sequences.size());
    std::vector<int> tmp_lengths(context_lengths.size());
    for (size_t i = 0; i < idxs.size(); ++i) {
        tmp_sequences[i] = sequences[idxs[i]];
        tmp_lengths[i]   = context_lengths[idxs[i]];
    }
    sequences.swap(tmp_sequences);
    context_lengths.swap(tmp_lengths);
}

312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
// template<class P, class... Ts>
// void SortByPriority(const std::vector<P>& priorities, Ts&... ranges)
// {
//     // sort according to priority
//     std::vector<int> idxs(priorities.size());
//     std::iota(idxs.begin(), idxs.end(), 0);
//     std::sort(idxs.begin(), idxs.end(), [&](int i, int j) {
//         return priorities[i] < priorities[j];  //
//     });
//     auto reorder = [&](auto& src) {
//         auto dst = src;
//         for (size_t i = 0; i < idxs.size(); ++i) {
//             dst[i] = src[idxs[i]];
//         }
//         src.swap(dst);
//     };
//     (reorder(ranges), ...);
// }

Li Zhang's avatar
Li Zhang committed
331
332
333
334
335
336
337
338
339
340
341
342
343
std::vector<int> SequenceManager::CountRequiredBlocks(const Sequences&        sequences,
                                                      const std::vector<int>& context_lengths,
                                                      int                     step_length)
{
    std::vector<int> required(sequences.size());
    for (int i = 0; i < sequences.size(); ++i) {
        int seq_len = context_lengths[i] + step_length;
        int count   = (seq_len + block_seq_len_ - 1) / block_seq_len_ - static_cast<int>(sequences[i]->blocks.size());
        required[i] = std::max(0, count);
    }
    return required;
}

Li Zhang's avatar
Li Zhang committed
344
345
346
347
void SequenceManager::AssignAndActivate(const Sequences&        sequences,  //
                                        const std::vector<int>& counts,
                                        const BlockIds&         blocks,
                                        const UniqueIds&        unique_ids)
Li Zhang's avatar
Li Zhang committed
348
349
{
    FT_CHECK(sequences.size() == counts.size());
Li Zhang's avatar
Li Zhang committed
350
    int first = 0;
Li Zhang's avatar
Li Zhang committed
351
352
353
    for (int i = 0; i < sequences.size(); ++i) {
        auto& s     = const_cast<Sequence&>(*sequences[i]);
        auto  count = counts[i];
Li Zhang's avatar
Li Zhang committed
354
355
356
357
        int   last  = first + count;
        FT_CHECK(last <= blocks.size());
        s.blocks.insert(s.blocks.end(), blocks.begin() + first, blocks.begin() + last);
        s.block_unique_ids.insert(s.block_unique_ids.end(), unique_ids.begin() + first, unique_ids.begin() + last);
Li Zhang's avatar
Li Zhang committed
358
359
360
361
362
363
364
365
        s.status = Sequence::kActive;
        first    = last;
    }
}

auto SequenceManager::Materialize(Sequences                    sequences,
                                  std::vector<int>             context_lengths,
                                  const std::vector<uint64_t>& priorities,
366
367
                                  int                          step_length,
                                  AdjustInputCount             adjust) -> Outcome
Li Zhang's avatar
Li Zhang committed
368
369
370
371
372
373
374
375
376
{
    ////////////////////////////////////////////////////////////////////////////////
    /// Schedule the assignment of blocks to sequences

    // process deferred unlock and free operations
    CommitUnlockAndFree();

    SortByPriority(sequences, context_lengths, priorities);

377
378
    // SortByPriority(priorities, sequences, context_lengths);

Li Zhang's avatar
Li Zhang committed
379
380
381
382
    // Verify and lock cache sequences to avoid their blocks being evicted unnoticed
    // the blocks can still be preempted later
    VerifyAndLockCached(sequences);

383
384
    auto [input_count1, input_count2] = adjust(sequences, context_lengths);

Li Zhang's avatar
Li Zhang committed
385
386
387
    std::vector<int> required = CountRequiredBlocks(sequences, context_lengths, step_length);
    // dbg(required);

388
    Schedule schedule(block_manager_->TakeSnapshot(), sequences.size(), input_count1, input_count2);
Li Zhang's avatar
Li Zhang committed
389
390
391

    // `schedule.last` is decreasing in the loop
    for (int i = 0; i < schedule.last; ++i) {
392
393
        const int input_length = context_lengths[i] - sequences[i]->cache_len;
        Transaction{sequences, i, required[i], input_length, schedule}.Process();
Li Zhang's avatar
Li Zhang committed
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
    }

    // mark remaining sequences invalid
    for (int i = schedule.last; i < sequences.size(); ++i) {
        schedule.inactive.push_back(sequences[i]);
    }

    ////////////////////////////////////////////////////////////////////////////////
    /// Schedule is ready, time to execute it. (locked -> cached -> free -> locked)

    // combine allocate and evict since evicted blocks are reused by allocation
    schedule.allocate += schedule.evict;

    if (schedule.allocate) {
        dbg(*block_manager_);
    }

    Outcome outcome{};
    outcome.allocation = schedule.allocate;
    outcome.swap_in    = std::count_if(schedule.active.begin(), schedule.active.end(), [](auto p) {
        if (p->status != Sequence::kActive) {
            dbg(*p);
        }
        return p->status != Sequence::kActive;  //
    });
    outcome.swap_out   = std::count_if(schedule.inactive.begin(), schedule.inactive.end(), [](auto p) {
        if (p->status == Sequence::kActive) {
            dbg(*p);
        }
        return p->status == Sequence::kActive;  //
    });

    // release preempted blocks -> cached
    if (!schedule.victims.empty()) {
        for (const auto& p : schedule.victims) {
            UpdateAndSetUnlock(*p);
        }
        CommitUnlockAndFree();
    }

    // evict cached blocks -> free
    if (schedule.evict) {
        block_manager_->Evict(schedule.evict);
    }

    // allocate & assign blocks
    {
Li Zhang's avatar
Li Zhang committed
441
442
        BlockIds  block_ids;
        UniqueIds unique_ids;
Li Zhang's avatar
Li Zhang committed
443
        if (schedule.allocate) {
Li Zhang's avatar
Li Zhang committed
444
            std::tie(block_ids, unique_ids) = block_manager_->Allocate(schedule.allocate);
Li Zhang's avatar
Li Zhang committed
445
        }
Li Zhang's avatar
Li Zhang committed
446
        AssignAndActivate(schedule.active, schedule.block_counts, block_ids, unique_ids);
Li Zhang's avatar
Li Zhang committed
447
448
449
450
451
452
453
454
455
    }

    // active -> locked
    for (const auto& p : schedule.inactive) {
        if (p->status == Sequence::kActive) {
            const_cast<Sequence*>(p)->status = Sequence::kLocked;
        }
    }

Li Zhang's avatar
Li Zhang committed
456
457
458
459
460
    // TM_LOG_ERROR("active: %4d, cached: %4d, free: %4d",
    //              block_manager_->active_count(),
    //              block_manager_->cached_count(),
    //              block_manager_->free_count());

Li Zhang's avatar
Li Zhang committed
461
462
463
464
    return outcome;
}

}  // namespace turbomind