ft_attention.cpp 7.45 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
#include <torch/extension.h>
#include "ATen/cuda/CUDAContext.h"

#include "decoder_masked_multihead_attention.h"

// Argument-validation helpers. Each expands to a TORCH_CHECK whose failure message
// begins with the stringified macro argument, so the offending tensor is named.

// Fails unless `x` reports a CUDA device type.
#define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA")
// Fails unless `x.sizes()` equals exactly the dimension list given after `x`.
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
// Fails unless `x.is_contiguous()` is true.
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

// Invokes the callable passed as the trailing argument with a block-local alias
// `scalar_t` bound to the C++ type matching TYPE (at::Half, at::BFloat16 or float).
// Any other scalar type raises an error whose message starts with the stringified
// NAME argument.
//
// The expansion is wrapped in do { ... } while (0) so that it behaves as a single
// statement (safe inside an unbraced if/else, and requires the usual trailing ';').
// TYPE is evaluated exactly once and cached in a local, so passing an expression
// with side effects or a non-trivial cost is safe.
#define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, NAME, ...)                    \
  do {                                                                       \
    const at::ScalarType dispatch_scalar_type_ = (TYPE);                     \
    if (dispatch_scalar_type_ == at::ScalarType::Half) {                     \
      using scalar_t = at::Half;                                             \
      __VA_ARGS__();                                                         \
    } else if (dispatch_scalar_type_ == at::ScalarType::BFloat16) {          \
      using scalar_t = at::BFloat16;                                         \
      __VA_ARGS__();                                                         \
    } else if (dispatch_scalar_type_ == at::ScalarType::Float) {             \
      using scalar_t = float;                                                \
      __VA_ARGS__();                                                         \
    } else {                                                                 \
      AT_ERROR(#NAME, " not implemented for type '",                         \
               toString(dispatch_scalar_type_), "'");                        \
    }                                                                        \
  } while (0)

// #define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, NAME, ...)                  \
//   if (TYPE == at::ScalarType::Half) {                                      \
//     using scalar_t = at::Half;                                             \
//     __VA_ARGS__();                                                         \
//   } else if (TYPE == at::ScalarType::Float)  {                             \
//     using scalar_t = float;                                                \
//     __VA_ARGS__();                                                         \
//   } else {                                                                 \
//     AT_ERROR(#NAME, " not implemented for type '", toString(TYPE), "'"); \
//   }

// Forward declaration only: no definition is visible in this file (presumably it is
// provided by a separately compiled CUDA source -- TODO confirm). Called by
// single_query_attention() with a populated parameter struct and the current CUDA stream.
template<typename T>
void masked_multihead_attention(const Masked_multihead_attention_params<T>& params,
                                const cudaStream_t& stream);

// Forward declaration only; takes the same parameter struct and stream as
// masked_multihead_attention(). Not referenced anywhere else in this file.
template<typename T>
void cross_multihead_attention(const Masked_multihead_attention_params<T>& params,
                               const cudaStream_t& stream);

// Type trait translating a PyTorch scalar type into the element type used to
// instantiate Masked_multihead_attention_params. The primary template is the
// identity mapping (so float stays float).
template<typename Scalar>
struct SATypeConverter {
    typedef Scalar Type;
};

// at::Half is handed over as a raw 16-bit unsigned integer.
template<>
struct SATypeConverter<at::Half> {
    typedef uint16_t Type;
};

// at::BFloat16 is handed over as CUDA's native __nv_bfloat16.
template<>
struct SATypeConverter<at::BFloat16> {
    typedef __nv_bfloat16 Type;
};

// Fills `params` for one call of the attention kernel.
//
// The struct is first zeroed byte-wise, then every field this wrapper cares about is
// assigned explicitly. Features this wrapper does not expose (biases, cache
// indirection, prefix prompts, relative attention bias, cross-attention output,
// `finished` flags, ...) are left as nullptr / 0 / false.
//
//   batch_size, nheads, memory_max_seqlen, headdim  -- problem dimensions
//   timestep                                        -- stored verbatim in params.timestep
//   rotary_embedding_dim, neox_rotary_style         -- rotary-embedding configuration
//   q_ptr, k_ptr, v_ptr                             -- inputs for the current step
//   k_cache_ptr, v_cache_ptr                        -- key / value caches
//   length_per_sample                               -- optional per-sample lengths (may be nullptr)
//   out_ptr                                         -- output buffer
template <typename T>
void set_params(Masked_multihead_attention_params<T> &params,
                const size_t batch_size,
                const size_t nheads,
                const size_t memory_max_seqlen,
                const size_t headdim,
                const int timestep,
                const int rotary_embedding_dim,
                const bool neox_rotary_style,
                T *q_ptr,
                T *k_ptr,
                T *v_ptr,
                T *k_cache_ptr,
                T *v_cache_ptr,
                int *length_per_sample,
                T *out_ptr) {
    // Begin from an all-zero struct so that no field is left uninitialized.
    memset(&params, 0, sizeof(params));

    // Problem dimensions. Beam search is not exposed here: beam width is fixed to 1.
    params.batch_size = batch_size;
    params.beam_width = 1;
    params.num_heads = nheads;
    params.hidden_size_per_head = headdim;
    params.memory_max_len = memory_max_seqlen;
    params.timestep = timestep;
    params.stride = 0;

    // Softmax scaling factor 1 / sqrt(head dimension).
    params.inv_sqrt_dh = 1.f / sqrt(static_cast<float>(headdim));

    // Rotary embedding configuration.
    params.rotary_embedding_dim = rotary_embedding_dim;
    params.neox_rotary_style = neox_rotary_style;

    // Inputs of the current step, the caches and the output buffer.
    params.q = q_ptr;
    params.k = k_ptr;
    params.v = v_ptr;
    params.k_cache = k_cache_ptr;
    params.v_cache = v_cache_ptr;
    params.out = out_ptr;

    // Per-sample lengths supplied by the caller (nullptr when absent).
    params.length_per_sample = length_per_sample;
    params.memory_length_per_sample = nullptr;

    // No projection biases.
    params.q_bias = nullptr;
    params.k_bias = nullptr;
    params.v_bias = nullptr;

    // No cache indirection, padding / mask information or finished flags.
    params.cache_indir = nullptr;
    params.total_padding_tokens = nullptr;
    params.masked_tokens = nullptr;
    params.finished = nullptr;

    // No prefix prompts.
    // TODO: what should max_prefix_prompt_length be? (memory_max_seqlen was tried before.)
    params.prefix_prompt_lengths = nullptr;
    params.max_prefix_prompt_length = 0;

    // No relative attention bias.
    params.relative_attention_bias = nullptr;
    params.relative_attention_bias_stride = 0;

    // Cross-attention outputs are not requested.
    params.cross_attention_out = nullptr;
    params.max_decoder_seq_len = 0;
    params.is_return_cross_attentions = false;
}

// Attention for a single query position per sequence, backed by key / value caches.
//
// Arguments
//   q, k, v             CUDA tensors of shape (batch_size, nheads, headdim), contiguous,
//                       all of the same dtype (float32, float16 or bfloat16).
//   k_cache             CUDA tensor, contiguous, same dtype as q and the same number of
//                       elements as v_cache. Its exact layout is not validated here
//                       (expected: [B, H, Dh/x, L, x] with x=8 for fp16 and x=4 for
//                       fp32 -- see TODO below).
//   v_cache             CUDA tensor of shape (batch_size, nheads, memory_max_seqlen,
//                       headdim), contiguous; all dimensions are derived from it.
//   length_per_sample_  optional int32 CUDA tensor of shape (batch_size,).
//   timestep            forwarded unchanged to the kernel parameters.
//   rotary_embedding_dim, neox_rotary_style
//                       rotary-embedding configuration, forwarded unchanged.
//
// Returns a new tensor with the shape and dtype of q, handed to the kernel as its
// output buffer. Raises (via TORCH_CHECK) on any violated precondition or on an
// unsupported dtype.
torch::Tensor single_query_attention(const torch::Tensor q,
                                     const torch::Tensor k,
                                     const torch::Tensor v,
                                     torch::Tensor k_cache,
                                     torch::Tensor v_cache,
                                     c10::optional<const torch::Tensor> length_per_sample_,
                                     const int timestep,
                                     const int rotary_embedding_dim = 0,
                                     const bool neox_rotary_style=true) {
    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(k_cache); CHECK_DEVICE(v_cache);
    // Raw pointers of all tensors are handed to one kernel launch, so they must live
    // on the same device.
    TORCH_CHECK(k.device() == q.device() && v.device() == q.device()
                    && k_cache.device() == q.device() && v_cache.device() == q.device(),
                "q, k, v, k_cache and v_cache must be on the same device");

    TORCH_CHECK(v_cache.dim() == 4, "v_cache must have 4 dimensions, got ", v_cache.dim());
    int batch_size = v_cache.size(0);
    int nheads = v_cache.size(1);
    int memory_max_seqlen = v_cache.size(2);
    int headdim = v_cache.size(3);
    CHECK_SHAPE(q, batch_size, nheads, headdim);
    CHECK_SHAPE(k, batch_size, nheads, headdim);
    CHECK_SHAPE(v, batch_size, nheads, headdim);

    // Every buffer is reinterpret_cast to the element type selected from q's dtype
    // below; a mismatching dtype would make the kernel read / write with the wrong
    // element size.
    const auto dtype = q.scalar_type();
    TORCH_CHECK(k.scalar_type() == dtype && v.scalar_type() == dtype,
                "q, k and v must have the same dtype");
    TORCH_CHECK(k_cache.scalar_type() == dtype && v_cache.scalar_type() == dtype,
                "k_cache and v_cache must have the same dtype as q");

    // TODO: Check shape of k_cache: [B, H, Dh/x, L, x] where x=8 for fp16 and x=4 for fp32
    // Until the exact layout is checked, at least make sure the key cache holds as
    // many elements as the value cache.
    TORCH_CHECK(k_cache.numel() == v_cache.numel(),
                "k_cache must have the same number of elements as v_cache");
    // TODO: avoid contiguous requirement by storing the stride
    CHECK_CONTIGUOUS(q); CHECK_CONTIGUOUS(k); CHECK_CONTIGUOUS(v);
    CHECK_CONTIGUOUS(k_cache); CHECK_CONTIGUOUS(v_cache);

    if (length_per_sample_.has_value()) {
        auto length_per_sample = length_per_sample_.value();
        CHECK_DEVICE(length_per_sample);
        TORCH_CHECK(length_per_sample.device() == q.device(),
                    "length_per_sample must be on the same device as q");
        CHECK_SHAPE(length_per_sample, batch_size);
        CHECK_CONTIGUOUS(length_per_sample);
        TORCH_CHECK(length_per_sample.dtype() == torch::kInt32,
                    "length_per_sample must have dtype int32");
    }

    torch::Tensor out = torch::empty_like(q);

    // The second macro argument is the name used in the "not implemented" error.
    DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), "single_query_attention", [&] {
        using DataType = typename SATypeConverter<scalar_t>::Type;
        Masked_multihead_attention_params<DataType> params;
        set_params(params, batch_size, nheads, memory_max_seqlen, headdim, timestep,
                   rotary_embedding_dim, neox_rotary_style,
                   reinterpret_cast<DataType*>(q.data_ptr()),
                   reinterpret_cast<DataType*>(k.data_ptr()),
                   reinterpret_cast<DataType*>(v.data_ptr()),
                   reinterpret_cast<DataType*>(k_cache.data_ptr()),
                   reinterpret_cast<DataType*>(v_cache.data_ptr()),
                   length_per_sample_.has_value()
                       ? length_per_sample_.value().data_ptr<int>() : nullptr,
                   reinterpret_cast<DataType*>(out.data_ptr()));
        // Launch on PyTorch's current stream so the work is ordered with other ops.
        auto stream = at::cuda::getCurrentCUDAStream();
        masked_multihead_attention(params, stream);
    });
    return out;
}

// Python bindings. Exposes single_query_attention with keyword-capable arguments;
// only rotary_embedding_dim and neox_rotary_style carry defaults on the Python side.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def(
        "single_query_attention",
        &single_query_attention,
        "Attention with a single query",
        py::arg("q"),
        py::arg("k"),
        py::arg("v"),
        py::arg("k_cache"),
        py::arg("v_cache"),
        py::arg("length_per_sample_"),
        py::arg("timestep"),
        py::arg("rotary_embedding_dim") = 0,
        py::arg("neox_rotary_style") = true
    );
}