// jiuge_kv_cache.cpp
#include "jiuge_impl.hpp"

/// Allocates a new KVCache holding per-device, per-layer key and value buffers.
///
/// For every entry of `model->dev_resources` the matching device is made
/// current via `infinirtSetDevice`, then `model->meta.nlayer` key tensors and
/// the same number of value tensors are created with `Tensor::buffer`, each
/// with shape {meta.nkvh / ndev, meta.dctx, meta.dh} and data type
/// `meta.dt_mat`.
///
/// @param model Model whose metadata and device list determine the cache size.
/// @return Pointer to a cache allocated with `new`; dropKVCache deletes it.
__C struct KVCache *createKVCache(const JiugeModel *model) {
    KVCache *cache = new KVCache();
    const auto ndev = model->dev_resources.size();
    const auto nlayer = model->meta.nlayer;
    // The KV head count is divided by the device count
    // (presumably each device holds only its own share of the heads).
    auto nkvh = model->meta.nkvh / ndev;
    auto max_len = model->meta.dctx;
    auto dh = model->meta.dh;
    auto shape = std::vector<size_t>{nkvh, max_len, dh};
    for (unsigned int idev = 0; idev < ndev; idev++) {
        // Select the device before creating its buffers.
        RUN_INFINI(infinirtSetDevice(model->device, model->dev_ids[idev]));
        auto kcache = std::vector<std::shared_ptr<Tensor>>();
        auto vcache = std::vector<std::shared_ptr<Tensor>>();
        // The layer count is known up front: allocate once instead of regrowing.
        kcache.reserve(nlayer);
        vcache.reserve(nlayer);
        for (unsigned int layer = 0; layer < nlayer; layer++) {
            // The call result is already a temporary, so no std::move is needed.
            kcache.push_back(Tensor::buffer(model->meta.dt_mat, shape));
            vcache.push_back(Tensor::buffer(model->meta.dt_mat, shape));
        }
        // Hand the per-device vectors over instead of copying them: avoids a
        // second allocation and a refcount increment per tensor.
        cache->k.push_back(std::move(kcache));
        cache->v.push_back(std::move(vcache));
    }

    return cache;
}

/// Creates a fresh cache for `model` and copies a prefix of an existing cache
/// into it.
///
/// A new cache is obtained from createKVCache. For every device and layer, the
/// `slice(1, 0, seq_len)` view of each source key/value tensor is copied into
/// the same view of the new tensor. Dimension 1 is the `meta.dctx` axis of the
/// shape built in createKVCache.
///
/// @param model    Model that determines cache layout and supplies the
///                 per-device handles passed to `copyFrom`.
/// @param kv_cache Source cache; indexed with the same device/layer counts as
///                 the new cache, so it is assumed to come from the same model
///                 (TODO confirm with callers).
/// @param seq_len  Third argument to `slice` (presumably the number of
///                 positions to copy, starting at 0). No bounds check is done
///                 here against `meta.dctx`.
/// @return Pointer to the newly created cache; release it with dropKVCache.
__C struct KVCache *duplicateKVCache(const JiugeModel *model,
                                     const KVCache *kv_cache,
                                     unsigned int seq_len) {
    auto new_kv_cache = createKVCache(model);
    auto ndev = model->dev_resources.size();
    for (unsigned int idev = 0; idev < ndev; idev++) {
        // Select the device whose tensors are about to be copied.
        RUN_INFINI(infinirtSetDevice(model->device, model->dev_ids[idev]));
        const auto &resource = model->dev_resources[idev];
        for (unsigned int layer = 0; layer < model->meta.nlayer; layer++) {
            // Keys: copy the leading `seq_len` slice along dimension 1.
            new_kv_cache->k[idev][layer]
                ->slice(1, 0, seq_len)
                ->copyFrom(kv_cache->k[idev][layer]->slice(1, 0, seq_len),
                           resource.handle);

            // Values: same slice, same handle.
            new_kv_cache->v[idev][layer]
                ->slice(1, 0, seq_len)
                ->copyFrom(kv_cache->v[idev][layer]->slice(1, 0, seq_len),
                           resource.handle);
        }
    }
    return new_kv_cache;
}

/// Releases a cache previously returned by createKVCache / duplicateKVCache.
///
/// For each device the device is made current, then every per-layer key and
/// value `shared_ptr` is reset, and finally the KVCache object is deleted.
/// Resetting with the owning device selected presumably ensures that any
/// buffer freed by the last reference is released on the right device.
///
/// @param model    Model supplying the device list and layer count used to
///                 walk the cache.
/// @param kv_cache Cache to destroy; a null pointer is accepted and ignored.
__C void dropKVCache(JiugeModel const *model, KVCache *kv_cache) {
    // Mirror `delete`/`free` semantics: destroying nothing is a no-op rather
    // than a null dereference in the loops below.
    if (kv_cache == nullptr) {
        return;
    }
    auto ndev = model->dev_resources.size();
    for (unsigned int idev = 0; idev < ndev; idev++) {
        RUN_INFINI(infinirtSetDevice(model->device, model->dev_ids[idev]));
        for (unsigned int layer = 0; layer < model->meta.nlayer; layer++) {
            kv_cache->k[idev][layer].reset();
            kv_cache->v[idev][layer].reset();
        }
    }
    delete kv_cache;
}