#include <torch/extension.h>
#include "ATen/cuda/CUDAContext.h"
#include <c10/cuda/CUDAGuard.h>

#include "decoder_masked_multihead_attention.h"

#define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

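// Run the given lambda with `scalar_t` bound to the tensor's element type
// (fp32, fp16, or bf16); any other dtype raises an error.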
#define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, NAME, ...)                  \
  if (TYPE == at::ScalarType::Half) {                                      \
    using scalar_t = at::Half;                                             \
    __VA_ARGS__();                                                         \
  } else if (TYPE == at::ScalarType::BFloat16) {                           \
    using scalar_t = at::BFloat16;                                         \
    __VA_ARGS__();                                                         \
  } else if (TYPE == at::ScalarType::Float)  {                             \
    using scalar_t = float;                                                \
    __VA_ARGS__();                                                         \
  } else {                                                                 \
    AT_ERROR(#NAME, " not implemented for type '", toString(TYPE), "'"); \
  }

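// Kernel launchers for single-query (masked) self-attention and cross-attention;
// their definitions are provided by the accompanying decoder_masked_multihead_attention sources.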
template<typename T>
void masked_multihead_attention(const Masked_multihead_attention_params<T>& params,
                                const cudaStream_t& stream);

template<typename T>
void cross_multihead_attention(const Masked_multihead_attention_params<T>& params,
                               const cudaStream_t& stream);

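// Map Torch scalar types to the element types the kernels operate on:
// at::Half -> uint16_t, at::BFloat16 -> __nv_bfloat16, float stays float.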
template<typename T>
struct SATypeConverter {
    using Type = T;
};

template<>
struct SATypeConverter<at::Half> {
    using Type = uint16_t;
};

template<>
struct SATypeConverter<at::BFloat16> {
    using Type = __nv_bfloat16;
};

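// Fill in the kernel parameter struct. Fields this binding does not use
// (QKV biases, beam search, padding/prefix-prompt masks, relative attention
// bias, cross-attention outputs) are zeroed or left as null pointers.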
template <typename T>
void set_params(Masked_multihead_attention_params<T> &params,
                const size_t batch_size,
                const size_t nheads,
                const size_t memory_max_seqlen,
                const size_t headdim,
                const int timestep,
                const int rotary_embedding_dim,
                const float rotary_base,
                const bool neox_rotary_style,
                const int qkv_batch_stride,
                const int nnz_heads,
                T *q_ptr,
                T *k_ptr,
                T *v_ptr,
                T *k_cache_ptr,
                T *v_cache_ptr,
                int *length_per_sample,
                T *rotary_cos,
                T *rotary_sin,
                T *out_ptr,
                int *nnz_head_idx) {
    // Reset the parameters
    memset(&params, 0, sizeof(params));
    params.q = q_ptr;
    params.k = k_ptr;
    params.v = v_ptr;
    params.q_bias = nullptr;
    params.k_bias = nullptr;
    params.v_bias = nullptr;
    params.k_cache = k_cache_ptr;
    params.v_cache = v_cache_ptr;
    params.out = out_ptr;
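    // Beam search is not used here (beam_width is fixed to 1 below), so no
    // cache indirection is needed.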
    params.cache_indir = nullptr;
    params.stride = qkv_batch_stride;
    params.batch_size = batch_size;
    params.beam_width = 1;
    params.memory_max_len = memory_max_seqlen;
    params.num_heads = nheads;
    params.nnz_heads = nnz_heads;
    params.hidden_size_per_head = headdim;
    params.rotary_embedding_dim = rotary_embedding_dim;
    params.rotary_base = rotary_base;
    params.neox_rotary_style = neox_rotary_style;
    params.timestep = timestep;
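    // Attention softmax scaling factor, 1 / sqrt(head dimension).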
    params.inv_sqrt_dh = 1.f / sqrt(float(headdim));
    params.total_padding_tokens = nullptr;
    params.masked_tokens = nullptr;
    params.prefix_prompt_lengths = nullptr;
    params.max_prefix_prompt_length = 0;
    params.relative_attention_bias = nullptr;
    params.relative_attention_bias_stride = 0;
    params.cross_attention_out = nullptr;
    params.max_decoder_seq_len = 0;
    params.is_return_cross_attentions = false;
    params.finished = nullptr;
    params.memory_length_per_sample = nullptr;
    params.length_per_sample = length_per_sample;
    params.rotary_cos = rotary_cos;
    params.rotary_sin = rotary_sin;
    params.nnz_head_idx = nnz_head_idx;
}

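// Attention for a single query token per sequence (decoding step `timestep`).
// q/k/v have shape (batch_size, nheads, headdim); each query attends over its
// K/V cache (the kernel is expected to write the step's k/v into the caches).
// Rotary position embeddings are applied when rotary_embedding_dim > 0, either
// set explicitly or derived from rotary_cos_/rotary_sin_. Returns a tensor with
// the same shape and dtype as q.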
torch::Tensor single_query_attention(const torch::Tensor q,
                                     const torch::Tensor k,
                                     const torch::Tensor v,
                                     torch::Tensor k_cache,
                                     torch::Tensor v_cache,
                                     c10::optional<const torch::Tensor> length_per_sample_,
                                     c10::optional<const torch::Tensor> rotary_cos_,
                                     c10::optional<const torch::Tensor> rotary_sin_,
                                     c10::optional<const torch::Tensor> nnz_head_idx_,
                                     const int timestep,
                                     int rotary_embedding_dim = 0,
                                     const float rotary_base = 10000.0f,
                                     const bool neox_rotary_style=true) {
    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(k_cache); CHECK_DEVICE(v_cache);
    int batch_size = v_cache.size(0);
    int nheads = v_cache.size(1);
    int memory_max_seqlen = v_cache.size(2);
    int headdim = v_cache.size(3);
    auto input_type = q.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);

    CHECK_SHAPE(q, batch_size, nheads, headdim);
    CHECK_SHAPE(k, batch_size, nheads, headdim);
    CHECK_SHAPE(v, batch_size, nheads, headdim);
    CHECK_SHAPE(v_cache, batch_size, nheads, memory_max_seqlen, headdim);
    // k_cache shape: [B, H, Dh/x, L, x] where x = 8 for fp16/bf16 and x = 4 for fp32
    int packsize = k_cache.dtype() == torch::kFloat32 ? 4 : 8;
    CHECK_SHAPE(k_cache, batch_size, nheads, headdim / packsize, memory_max_seqlen, packsize);
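    // q/k/v need not be fully contiguous (e.g. they may be views into a fused
    // QKV projection); the head dimension must be contiguous and the three
    // tensors must share the same batch stride.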
    TORCH_CHECK(q.stride(2) == 1 && q.stride(1) == headdim);
    TORCH_CHECK(k.stride(2) == 1 && k.stride(1) == headdim);
    TORCH_CHECK(v.stride(2) == 1 && v.stride(1) == headdim);
    TORCH_CHECK(q.stride(0) == k.stride(0) && q.stride(0) == v.stride(0));
    CHECK_CONTIGUOUS(v_cache); CHECK_CONTIGUOUS(k_cache);

    TORCH_CHECK(q.scalar_type() == input_type);
    TORCH_CHECK(k.scalar_type() == input_type);
    TORCH_CHECK(v.scalar_type() == input_type);
    TORCH_CHECK(k_cache.scalar_type() == input_type);
    TORCH_CHECK(v_cache.scalar_type() == input_type);

    if (length_per_sample_.has_value()) {
        auto length_per_sample = length_per_sample_.value();
        CHECK_DEVICE(length_per_sample);
        CHECK_SHAPE(length_per_sample, batch_size);
        CHECK_CONTIGUOUS(length_per_sample);
        TORCH_CHECK(length_per_sample.dtype() == torch::kInt32);
    }

    if (rotary_cos_.has_value()) {
        auto rotary_cos = rotary_cos_.value();
        CHECK_DEVICE(rotary_cos);
        rotary_embedding_dim = rotary_cos.size(0) * 2;
        CHECK_SHAPE(rotary_cos, rotary_embedding_dim / 2);
        CHECK_CONTIGUOUS(rotary_cos);
        TORCH_CHECK(rotary_cos.scalar_type() == input_type);

        TORCH_CHECK(rotary_sin_.has_value());
        auto rotary_sin = rotary_sin_.value();
        CHECK_DEVICE(rotary_sin);
        CHECK_SHAPE(rotary_sin, rotary_embedding_dim / 2);
        CHECK_CONTIGUOUS(rotary_sin);
        TORCH_CHECK(rotary_sin.scalar_type() == input_type);
    }

    if (nnz_head_idx_.has_value()) {
        auto nnz_head_idx = nnz_head_idx_.value();
        CHECK_DEVICE(nnz_head_idx);
        int nnz_heads = nnz_head_idx.size(0);
        CHECK_SHAPE(nnz_head_idx, nnz_heads);
        CHECK_CONTIGUOUS(nnz_head_idx);
        TORCH_CHECK(nnz_head_idx.dtype() == torch::kInt32);
    }

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

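    // One output row per (batch, head), same shape/dtype/layout as q.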
    torch::Tensor out = torch::empty_like(q);

    DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), "single_query_attention", [&] {
        using DataType = typename SATypeConverter<scalar_t>::Type;
        Masked_multihead_attention_params<DataType> params;
        set_params(params, batch_size, nheads, memory_max_seqlen, headdim, timestep,
                   rotary_embedding_dim, rotary_base, neox_rotary_style, q.stride(0),
                   nnz_head_idx_.has_value() ? nnz_head_idx_.value().size(0) : 0,
                   reinterpret_cast<DataType*>(q.data_ptr()),
                   reinterpret_cast<DataType*>(k.data_ptr()),
                   reinterpret_cast<DataType*>(v.data_ptr()),
                   reinterpret_cast<DataType*>(k_cache.data_ptr()),
                   reinterpret_cast<DataType*>(v_cache.data_ptr()),
                   length_per_sample_.has_value()
                       ? length_per_sample_.value().data_ptr<int>() : nullptr,
                   rotary_cos_.has_value()
                       ? reinterpret_cast<DataType*>(rotary_cos_.value().data_ptr()) : nullptr,
                   rotary_sin_.has_value()
                       ? reinterpret_cast<DataType*>(rotary_sin_.value().data_ptr()) : nullptr,
                   reinterpret_cast<DataType*>(out.data_ptr()),
                   nnz_head_idx_.has_value() ? nnz_head_idx_.value().data_ptr<int>() : nullptr
                   );
        auto stream = at::cuda::getCurrentCUDAStream();
        masked_multihead_attention(params, stream);
    });
    return out;
}

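// Illustrative Python usage (a sketch; assumes the extension is built under the
// name "ft_attention" and that k_cache/v_cache are laid out as checked above):
//
//   import ft_attention
//   out = ft_attention.single_query_attention(
//       q, k, v, k_cache, v_cache,
//       length_per_sample_=lengths,   # int32 (batch_size,) tensor or None
//       rotary_cos_=None, rotary_sin_=None, nnz_head_idx_=None,
//       timestep=t, rotary_embedding_dim=0,
//       rotary_base=10000.0, neox_rotary_style=True)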
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("single_query_attention", &single_query_attention, "Attention with a single query",
          py::arg("q"), py::arg("k"), py::arg("v"), py::arg("k_cache"), py::arg("v_cache"),
          py::arg("length_per_sample_"), py::arg("rotary_cos_"),
          py::arg("rotary_sin_"), py::arg("nnz_head_idx_"),
          py::arg("timestep"), py::arg("rotary_embedding_dim")=0,
          py::arg("rotary_base")=10000.0f, py::arg("neox_rotary_style")=true);
}