llama_kernels.h 5.93 KB
Newer Older
Li Zhang's avatar
Li Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/fastertransformer/kernels/gpt_kernels.h"
#include "src/fastertransformer/utils/cuda_bf16_wrapper.h"
#include "src/fastertransformer/utils/cuda_utils.h"
#include <algorithm>
#include <assert.h>
#include <cmath>
#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <memory>
#include <numeric>
#include <vector>

namespace fastertransformer {

// Host-side entry point for root-mean-square normalization; declared here, defined elsewhere.
// The last argument is the CUDA stream handed to the implementation.
// NOTE(review): the contract is not visible in this header. From the parameter names:
//   out / input - presumably m rows of n elements each - TODO confirm
//   scale       - presumably n per-element weights applied after normalization - TODO confirm
//   eps         - presumably a small stabilizer added before the reciprocal square root - TODO confirm
template<typename T>
void invokeRootMeanSquareNorm(T* out, const T* input, const T* scale, float eps, int m, int n, cudaStream_t stream);

// Host-side entry point for a residual addition; declared here, defined elsewhere.
// NOTE(review): semantics are not visible in this header. `out` is the only non-const pointer, so it
// is presumably both an operand and the destination (out += in over m * n elements) - TODO confirm.
template<typename T>
void invokeAddResidual(T* out, const T* in, int m, int n, cudaStream_t stream);

// Host-side entry point declared here and defined elsewhere; `st` is the CUDA stream argument.
// `ids` is the only writable buffer; `input_ids` and `input_lengths` are read-only inputs.
// NOTE(review): the exact rewrite performed on `ids`, and the layout implied by batch_size /
// seq_len / max_input_len, cannot be determined from this header - confirm against the definition.
void invokeFixInputIds(int*         ids,
                       const int*   input_ids,
                       const int*   input_lengths,
                       int          batch_size,
                       int          seq_len,
                       int          max_input_len,
                       cudaStream_t st);

// Host-side entry point that writes into `mask`; declared here, defined elsewhere.
// NOTE(review): presumably fills a causal attention mask for `batch_size` sequences with `seq_len`
// query positions against `key_len` keys at decoding step `step` - shape and the meaning of `step`
// are not visible here, confirm against the definition.
template<typename T>
void invokeSliceCausalMask(T* mask, int seq_len, int key_len, int step, int batch_size, cudaStream_t stream);

// Host-side entry point that writes into `mask` using per-sequence lengths; declared here, defined
// elsewhere.
// NOTE(review): `q_lens` / `k_lens` are presumably arrays of `batch_size` query / key lengths and
// `mask` presumably holds batch_size x max_q_len x max_k_len entries - TODO confirm, including
// whether the length arrays must be device pointers.
template<typename T>
void invokeCreateCausalMasks(
    T* mask, const int* q_lens, const int* k_lens, int max_q_len, int max_k_len, int batch_size, cudaStream_t stream);

// Host-side entry point that takes read-only K/V sources (`k_src`, `v_src`) and arrays of writable
// destination pointers (`k_dst`, `v_dst`); declared here, defined elsewhere.
// NOTE(review): inferred from names only, confirm against the definition:
//   k_dst / v_dst   - presumably one cache pointer per batch entry
//   layer_offset    - presumably an element offset applied to each cache pointer to select a layer
//   query_length    - presumably the number of new tokens per sequence (at most max_q_len)
//   history_length  - presumably the number of tokens already cached per sequence
//   max_seq_len     - presumably the per-sequence capacity of the cache
template<typename T>
void invokeExtendKVCache(T**          k_dst,
                         T**          v_dst,
                         size_t       layer_offset,
                         const T*     k_src,
                         const T*     v_src,
                         int          batch_size,
                         const int*   query_length,
                         int          max_q_len,
                         const int*   history_length,
                         int          max_seq_len,
                         int          size_per_head,
                         int          local_head_num,
                         cudaStream_t stream);

// Host-side entry point that reads from arrays of cache pointers (`key_cache`, `val_cache`) and
// writes into two flat output buffers (`key_cache_trans`, `val_cache_trans`); declared here, defined
// elsewhere.
// NOTE(review): inferred from names only, confirm against the definition:
//   layer_offset - presumably an element offset applied to each cache pointer to select a layer
//   key_length   - presumably the number of valid cached tokens per sequence (at most max_kv_len)
//   max_seq_len  - presumably the per-sequence capacity of the source caches
// The source and destination layouts are not visible in this header.
template<typename T>
void invokeTransposeKVCache(T*           key_cache_trans,
                            T*           val_cache_trans,
                            const T**    key_cache,
                            const T**    val_cache,
                            size_t       layer_offset,
                            int          batch_size,
                            const int*   key_length,
                            int          max_kv_len,
                            int          max_seq_len,
                            int          size_per_head,
                            int          head_num,
                            cudaStream_t stream);

// Host-side entry point that writes `output_ids` from the read-only `ids` and `context_length`
// arrays; declared here, defined elsewhere.
// NOTE(review): presumably gathers each sequence's context and generated tokens into a row of
// `max_output_len` entries - the layouts of `ids` and `output_ids` and the roles of
// max_context_len / max_gen_step are not visible here, confirm against the definition.
void invokeGatherOutput(int*         output_ids,
                        const int*   ids,
                        const int*   context_length,
                        int          max_context_len,
                        int          max_gen_step,
                        int          max_output_len,
                        int          batch_size,
                        cudaStream_t stream);

// Host-side entry point taking a writable `dst`, a read-only `src`, an element count and a stream;
// declared here, defined elsewhere.
// NOTE(review): presumably copies `count` ints from `src` to `dst` on stream `st` - confirm against
// the definition, including which memory spaces the pointers may refer to.
void invokeMyCopyInt(int* dst, const int* src, size_t count, cudaStream_t st);

// Callable attention operator for element type T. Only the interface is declared here: the
// implementation sits behind the forward-declared nested class `impl`, owned through `pimpl`, so
// nothing about the underlying kernel is visible in this header.
template<typename T>
class FlashAttentionOp {
public:
    // Addressing description for one tensor (see layout_q / layout_k / layout_v / layout_o below).
    // The three strides have no default initializer and must be set by the caller.
    // NOTE(review): units of the strides and the exact meaning of the `batch_seqs*` fields are not
    // defined in this header - confirm against the implementation.
    struct AttentionLayout {
        int  stride_batch;
        int  stride_seq;
        int  stride_head;
        // Off by default; presumably selects addressing via Params::cu_seqlens_* - TODO confirm.
        bool use_seqlens       = false;
        // Defaults to 0 / nullptr; presumably an optional per-sequence pointer table plus an offset
        // applied to each entry - TODO confirm.
        int  batch_seqs_offset = 0;
        T**  batch_seqs        = nullptr;
    };

    // Argument bundle passed to operator(). The first five pointers have no default initializer and
    // must be set by the caller; the remaining pointers default to nullptr.
    struct Params {
        T*              attn_out;
        T*              query;
        T*              key;
        T*              val;
        T*              mask;
        // NOTE(review): presumably optional float accumulation scratch - confirm how it relates to
        // get_workspace_size().
        float*          out_accum    = nullptr;
        // NOTE(review): presumably cumulative sequence-length arrays for Q and K - TODO confirm.
        int*            cu_seqlens_q = nullptr;
        int*            cu_seqlens_k = nullptr;
        // One layout each for query, key, value and output.
        AttentionLayout layout_q;
        AttentionLayout layout_k;
        AttentionLayout layout_v;
        AttentionLayout layout_o;
    };

public:
    // Constructor and destructor are declared only; they are defined where `impl` is complete.
    FlashAttentionOp(int batch_size, int head_num, int key_len, int seq_len, int size_per_head);
    ~FlashAttentionOp();

    // NOTE(review): the unit of the returned size (bytes vs. elements) is not visible here - confirm.
    int get_workspace_size() const;

    // Runs the operator with `params`, passing `st` as the CUDA stream. `params` is taken by
    // non-const reference.
    void operator()(Params& params, cudaStream_t st) const;

private:
    // Opaque implementation type and its owning pointer.
    class impl;
    std::unique_ptr<impl> pimpl;
};

template<typename T>
inline void dump(const T* x, int size, cudaStream_t st, const char* msg, bool full = false)
{
    std::vector<T> h_x(size);
    cudaMemcpyAsync(h_x.data(), x, sizeof(T) * size, cudaMemcpyDefault, st);
    cudaStreamSynchronize(st);
    fprintf(stderr, "\n%s:\n", msg);
    std::vector<float> h_y(h_x.begin(), h_x.end());
    float              asum = 0.f;
    for (const auto& x : h_y) {
        asum += std::fabs(x);
    }
    if (full) {
        for (int i = 0; i < size; ++i) {
            printf("%d %.8f\n", i, h_y[i]);
        }
    }
    else {
        for (int i = 0; i < 8; ++i) {
            fprintf(stderr, "%.8f\n", h_y[i]);
        }
        for (int i = size - 8; i < size; ++i) {
            fprintf(stderr, "%.8f\n", h_y[i]);
        }
    }
    fprintf(stderr, "\nasum = %f\n", asum);
    // getchar();
}

// Thin holder for a buffer obtained from deviceMalloc.
//
// The constructor forwards `size` and a literal `false` to deviceMalloc, which is expected to store
// the resulting pointer in `data`.
// NOTE(review): `size` is presumably an element count rather than a byte count, and the `false`
// presumably disables random initialization - confirm against deviceMalloc's declaration.
//
// There is no destructor: the allocation is never released by this type. Copies of a TempBuffer
// share the same pointer.
template<typename T>
struct TempBuffer {
    TempBuffer(size_t size)
    {
        deviceMalloc(&data, size, false);
    }
    // Starts as nullptr so the pointer is never indeterminate, even if the allocation helper
    // returns without writing it.
    T* data = nullptr;
};

// Transposes `key_cache` into a function-local static scratch buffer via invokeTransposeAxis01 on
// stream `st` and returns a pointer to that scratch buffer.
//
// The dimensions handed to invokeTransposeAxis01 are (head_num * size_per_head_by_x, mem_len, x),
// i.e. the first two axes of the input are swapped:
//   from: H Dx, S, x
//   to  : S, H Dx, x
//
// The scratch buffer is created once per element type T with a fixed size of 8192 * 8192 and is
// reused by every later call, so the returned pointer is overwritten by the next call for the same
// T and must not be freed by the caller.
// NOTE(review): the buffer is allocated on first use and then shared by all callers; if this can be
// reached from several threads or with different current devices, the sharing is unsafe - confirm
// against the call sites.
// NOTE(review): the capacity check assumes TempBuffer's size argument is an element count - confirm
// against deviceMalloc.
template<typename T>
inline T*
transpose_key_cache(T* key_cache, size_t head_num, size_t size_per_head_by_x, size_t mem_len, size_t x, cudaStream_t st)
{
    constexpr size_t scratch_capacity = 8192 * 8192;
    // The transpose writes one output element per input element; refuse inputs that would run past
    // the fixed-size scratch buffer (checked in builds where assert is active).
    assert(head_num * size_per_head_by_x * mem_len * x <= scratch_capacity);

    static TempBuffer<T> buf(scratch_capacity);
    // from: H Dx, S, x
    // to  : S, H Dx, x
    invokeTransposeAxis01(buf.data, key_cache, head_num * size_per_head_by_x, mem_len, x, st);
    return buf.data;
}

// Transposes `value_cache` into a function-local static scratch buffer via invokeTransposeAxis01 on
// stream `st` and returns a pointer to that scratch buffer.
//
// The dimensions handed to invokeTransposeAxis01 are (head_num, mem_len, size_per_head), i.e. the
// first two axes of the input are swapped.
//
// The scratch buffer is created once per element type T with a fixed size of 8192 * 8192 and is
// reused by every later call, so the returned pointer is overwritten by the next call for the same
// T and must not be freed by the caller. It is a different buffer from the one used by
// transpose_key_cache, because each function owns its own static.
// NOTE(review): the buffer is allocated on first use and then shared by all callers; if this can be
// reached from several threads or with different current devices, the sharing is unsafe - confirm
// against the call sites.
// NOTE(review): the capacity check assumes TempBuffer's size argument is an element count - confirm
// against deviceMalloc.
template<typename T>
inline T* transpose_value_cache(T* value_cache, size_t head_num, size_t mem_len, size_t size_per_head, cudaStream_t st)
{
    constexpr size_t scratch_capacity = 8192 * 8192;
    // The transpose writes one output element per input element; refuse inputs that would run past
    // the fixed-size scratch buffer (checked in builds where assert is active).
    assert(head_num * mem_len * size_per_head <= scratch_capacity);

    static TempBuffer<T> buf(scratch_capacity);
    invokeTransposeAxis01(buf.data, value_cache, head_num, mem_len, size_per_head, st);
    return buf.data;
}

// Debug helper: fetches a single int from `d_seq_len` into host memory on stream `st`, waits for the
// stream, then reports it together with `tp_rank` and `step` through FT_LOG_ERROR.
// The host value starts at -1, so -1 in the log means the copy did not deliver a value.
inline void dump_sequence_len(int* d_seq_len, int step, int tp_rank, cudaStream_t st)
{
    int host_value{-1};

    // Asynchronous copy of exactly one int; the direction is inferred by the runtime.
    cudaMemcpyAsync(&host_value, d_seq_len, sizeof(host_value), cudaMemcpyDefault, st);

    // The value may only be read after the stream has finished the copy.
    cudaStreamSynchronize(st);

    FT_LOG_ERROR("--------> rank = %d, step = %d, seq_len = %d <--------", tp_rank, step, host_value);
}

}  // namespace fastertransformer