// sparse_fwd.h — sparse MLA prefill (forward) entry point and kernel dispatch for gfx93x.
#pragma once

#include <cstdio>
#include <cstdlib>
#include <optional>
#include <string>
#include <vector>

#include "common.h"
#include "params.h"

#include "gfx93/prefill/sparse/dsa_mls/fwd.h"
#include "gfx93/prefill/sparse/phase1.h"

/// Capabilities a sparse-attention forward call may require.
///
/// `sparse_attn_prefill_interface` selects these from the call's inputs:
/// one head-count feature, one head-dim feature, plus ATTN_SINK /
/// TOPK_LENGTH when the corresponding optional tensor is supplied.
///
/// NOTE(review): enumerator values may be relied upon by ImplBase's feature
/// bookkeeping — append new features at the end rather than reordering.
enum class FwdFeatures : int {
    HEAD_16,       // h_q <= 16
    HEAD_64,       // h_q == 64
    HEAD_128,      // h_q == 128

    HEAD_DIM_576,  // d_qk == 576
    HEAD_DIM_512,  // d_qk == 512

    ATTN_SINK,     // an attn_sink tensor is supplied
    SINK_LSE,      // NOTE(review): not requested anywhere in this file; semantics to confirm
    TOPK_LENGTH    // a per-query topk_length tensor is supplied
};

/// Common base for sparse-attention forward implementations: the generic
/// ImplBase template instantiated with SparseAttnFwdParams as the parameter
/// block and FwdFeatures as the feature vocabulary.
class FwdImplBase
    : public ImplBase<SparseAttnFwdParams, FwdFeatures> {};

class Fwd_Sm90_Impl : public FwdImplBase {
    DECLARE_SUPPORTED_FEATURES(
zhanghj2's avatar
zhanghj2 committed
33
        FwdFeatures::HEAD_16,
34
35
36
37
38
39
40
41
42
43
44
        FwdFeatures::HEAD_64,
        FwdFeatures::HEAD_128,
        FwdFeatures::HEAD_DIM_512,
        FwdFeatures::HEAD_DIM_576,
        FwdFeatures::ATTN_SINK,
        FwdFeatures::SINK_LSE,
        FwdFeatures::TOPK_LENGTH
    )

protected:
    void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override {
shenzhe's avatar
shenzhe committed
45
46
47
48
49
50
        if ((std::getenv("FLASH_MLA_FORCE_DSA_MLS_PREFILL") != nullptr && gfx93::fwd::dsa_mls::can_run(params)) ||
            gfx93::fwd::dsa_mls::should_run(params)) {
            gfx93::fwd::dsa_mls::run(params);
            return;
        }

51
52
        DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() {
            DISPATCH_BOOLEAN_FLAG(params.topk_length != nullptr, HAVE_TOPK_LENGTH, [&]() {
zhanghj2's avatar
zhanghj2 committed
53
                gfx93::fwd::run_fwd_phase1_kernel<HEAD_DIM_QK, HAVE_TOPK_LENGTH>(params);
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
            });
        });
    }
};

/// Sparse MLA prefill (forward) entry point for gfx93x devices.
///
/// Validates inputs, allocates outputs, fills a SparseAttnFwdParams and runs
/// Fwd_Sm90_Impl on the current stream.
///
/// @param q            bfloat16 tensor of shape [s_q, h_q, d_qk], d_qk in {576, 512};
///                     h_q must be <= 16, 64 or 128.
/// @param kv           bfloat16 tensor of shape [s_kv, h_kv, d_qk].
/// @param indices      int32 tensor of shape [s_q, h_kv, topk].
/// @param sm_scale     Scale forwarded as-is and as sm_scale * LOG_2_E.
/// @param d_v          Output head dim; only 512 is accepted.
/// @param attn_sink    Optional float32 tensor of shape [h_q].
/// @param topk_length  Optional int32 tensor of shape [s_q].
/// @return {out [s_q, h_q, d_v] (q's dtype), max_logits [s_q, h_q] float32,
///          lse [s_q, h_q] float32}.
/// @throws c10::Error (via TORCH_CHECK) on unsupported arch or invalid inputs.
///
/// Set FLASH_MLA_PRINT_PARAM=1 to print the problem shape to stderr.
static std::vector<at::Tensor> sparse_attn_prefill_interface(
    const at::Tensor &q,
    const at::Tensor &kv,
    const at::Tensor &indices,
    float sm_scale,
    int d_v,
    const std::optional<at::Tensor> &attn_sink,
    const std::optional<at::Tensor> &topk_length
) {
    using bf16 = cutlass::bfloat16_t;

    Arch arch = Arch();
    TORCH_CHECK(arch.is_gfx93x(),
                "Sparse prefill MLA is only supported on gfx936 or gfx938 architecture");

    // ---- Input validation ----
    KU_CHECK_NDIM(q, 3);
    KU_CHECK_NDIM(kv, 3);
    KU_CHECK_NDIM(indices, 3);
    KU_CHECK_NDIM(attn_sink, 1);
    KU_CHECK_NDIM(topk_length, 1);

    int s_q = q.size(0);
    int s_kv = kv.size(0);
    int h_q = q.size(1);
    int h_kv = kv.size(1);
    int d_qk = q.size(2);
    int topk = indices.size(2);
    bool have_topk_length = topk_length.has_value();

    TORCH_CHECK(d_qk == 576 || d_qk == 512, "Invalid d_qk: ", d_qk);
    TORCH_CHECK(d_v == 512, "Invalid d_v: ", d_v);

    KU_CHECK_DEVICE(q);
    KU_CHECK_DEVICE(kv);
    KU_CHECK_DEVICE(indices);
    KU_CHECK_DEVICE(attn_sink);
    KU_CHECK_DEVICE(topk_length);

    KU_CHECK_DTYPE(q, torch::kBFloat16);
    KU_CHECK_DTYPE(kv, torch::kBFloat16);
    KU_CHECK_DTYPE(indices, torch::kInt32);
    KU_CHECK_DTYPE(attn_sink, torch::kFloat32);
    KU_CHECK_DTYPE(topk_length, torch::kInt32);

    KU_CHECK_SHAPE(q, s_q, h_q, d_qk);
    KU_CHECK_SHAPE(kv, s_kv, h_kv, d_qk);
    KU_CHECK_SHAPE(indices, s_q, h_kv, topk);
    KU_CHECK_SHAPE(attn_sink, h_q);
    KU_CHECK_SHAPE(topk_length, s_q);

    KU_CHECK_LAST_DIM_CONTIGUOUS(q);
    KU_CHECK_LAST_DIM_CONTIGUOUS(kv);
    KU_CHECK_LAST_DIM_CONTIGUOUS(indices);
    KU_CHECK_LAST_DIM_CONTIGUOUS(attn_sink);
    KU_CHECK_LAST_DIM_CONTIGUOUS(topk_length);

    // ---- Allocate results and buffers on q's device ----
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};
    auto opts = q.options();

    at::Tensor out = torch::empty({s_q, h_q, d_v}, opts);
    at::Tensor lse = torch::empty({s_q, h_q}, opts.dtype(torch::kFloat));
    at::Tensor max_logits = torch::empty({s_q, h_q}, opts.dtype(torch::kFloat));
    KU_CHECK_CONTIGUOUS(out);
    KU_CHECK_CONTIGUOUS(lse);
    KU_CHECK_CONTIGUOUS(max_logits);

    // ---- Optional debug dump ----
    const char *print_env = std::getenv("FLASH_MLA_PRINT_PARAM");
    if (print_env != nullptr && std::string(print_env) == "1") {
        fprintf(stderr, "[FlashMLA] [sparse_attn_prefill_interface] [%s] "
            "s_q = %d s_kv = %d h_q = %d h_kv = %d d_qk = %d "
            "topk = %d have_topk_length = %d \n",
            arch.archName.c_str(),
            s_q, s_kv, h_q, h_kv, d_qk,
            topk, have_topk_length
        );
    }

    SparseAttnFwdParams params = {
        s_q, s_kv, h_q, h_kv, d_qk, d_v, topk,
        sm_scale, sm_scale * LOG_2_E,

        (bf16*)q.data_ptr(),
        (bf16*)kv.data_ptr(),
        (int*)indices.data_ptr(),
        ku::get_optional_tensor_ptr<float>(attn_sink),
        ku::get_optional_tensor_ptr<int>(topk_length),

        int64_stride_to_int(q.stride(0)), int64_stride_to_int(q.stride(1)),
        int64_stride_to_int(kv.stride(0)), int64_stride_to_int(kv.stride(1)),
        int64_stride_to_int(indices.stride(0)), int64_stride_to_int(indices.stride(1)),

        (bf16*)out.data_ptr(),
        (float*)max_logits.data_ptr(),
        (float*)lse.data_ptr(),

        arch.num_sms,
        at::cuda::getCurrentCUDAStream().stream()
    };

    // ---- Feature selection ----
    std::vector<FwdFeatures> required_features;
    if (h_q <= 16) {
        required_features.push_back(FwdFeatures::HEAD_16);
    } else if (h_q == 64) {
        required_features.push_back(FwdFeatures::HEAD_64);
    } else if (h_q == 128) {
        required_features.push_back(FwdFeatures::HEAD_128);
    } else {
        TORCH_CHECK(false, "Unsupported h_q: ", h_q);
    }
    // d_qk was already restricted to {576, 512} above.
    required_features.push_back(d_qk == 576 ? FwdFeatures::HEAD_DIM_576
                                            : FwdFeatures::HEAD_DIM_512);
    if (attn_sink.has_value()) {
        required_features.push_back(FwdFeatures::ATTN_SINK);
    }
    if (have_topk_length) {
        required_features.push_back(FwdFeatures::TOPK_LENGTH);
    }

    // Architecture was verified at entry, so no fallback branch is needed here.
    Fwd_Sm90_Impl fwd_impl;
    fwd_impl.run(params, required_features);

    return {out, max_logits, lse};
}