flash_memory.h 2.4 KB
Newer Older
zhangshao's avatar
zhangshao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41

#define DEFINE_SPLITKV_MEMORY_MANAGER \
class SplitkvMemoryManager { \
private: \
    int num_splits, batch_size, num_heads, seqlen_q, head_size; \
    bool is_split_empty = true; \
public: \
    at::Tensor scores_sum, scores_max, out_accum; \
public: \
    void set_split_memory( \
        int __num_splits, int __batch_size, int __num_heads, \
        int __seqlen_q, int __head_size, const torch::TensorOptions opts) { \
        if (!this->is_split_empty and this->is_same_split(__num_splits, __batch_size, __num_heads, __seqlen_q, __head_size)) { \
        } \
        else if (!this->is_split_empty and this->is_compatible_split(__num_splits, __batch_size, __num_heads, __seqlen_q, __head_size)) { \
            this->scores_sum = this->scores_sum.view({__num_splits, __batch_size, __num_heads, __seqlen_q}).contiguous(); \
            this->scores_max = this->scores_max.view({__num_splits, __batch_size, __num_heads, __seqlen_q}).contiguous(); \
            this->out_accum  = this->out_accum.view({__num_splits, __batch_size, __num_heads, __seqlen_q, __head_size}).contiguous(); \
        } \
        else { \
            auto raw_memory  = torch::empty({2, __num_splits, __batch_size, __num_heads, __seqlen_q}, opts.dtype(at::kFloat)); \
            this->scores_sum = raw_memory.index({0}); \
            this->scores_max = raw_memory.index({1}); \
            this->out_accum  = torch::empty({__num_splits, __batch_size, __num_heads, __seqlen_q, __head_size}, opts.dtype(at::kHalf)); \
            this->num_splits = __num_splits; \
            this->batch_size = __batch_size; \
            this->num_heads  = __num_heads; \
            this->seqlen_q   = __seqlen_q; \
            this->head_size  = __head_size; \
            this->is_split_empty = false; \
        } \
    } \
    bool is_same_split(int __num_splits, int __batch_size, int __num_heads, int __seqlen_q, int __head_size) { \
        return (this->num_splits == __num_splits) and (this->batch_size == __batch_size) \
            and (this->num_heads == __num_heads) and (this->seqlen_q == __seqlen_q) and (this->head_size == __head_size); \
    } \
    bool is_compatible_split(int __num_splits, int __batch_size, int __num_heads, int __seqlen_q, int __head_size) { \
        return (__num_splits * __batch_size * __num_heads * __seqlen_q * __head_size) \
            == (this->num_splits * this->batch_size * this->num_heads * this->seqlen_q * this->head_size); \
    } \
};