common.h 9.88 KB
Newer Older
1
2
#pragma once

Jiashi Li's avatar
Jiashi Li committed
3
4
#include <span>

5
6
7
8
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <kerutils/supplemental/torch_tensors.h>
zhanghj2's avatar
zhanghj2 committed
9
#include <string>
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
#include <cutlass/bfloat16.h>

// Single-precision value of log2(e), the base-2 logarithm of Euler's number (1 / ln 2).
// NOTE(review): presumably used to rewrite exp(x) as exp2(x * LOG_2_E); the users are not visible in this header.
static constexpr float LOG_2_E = 1.44269504f;

// Explicit specialization that makes tensor.data_ptr<cutlass::bfloat16_t>() available.
// It fetches the untyped data pointer of the tensor and reinterprets it as a
// pointer to cutlass::bfloat16_t; no dtype check is performed here.
template<>
inline cutlass::bfloat16_t* at::TensorBase::data_ptr<cutlass::bfloat16_t>() const {
    using element_t = cutlass::bfloat16_t;
    auto* untyped = this->data_ptr();
    return reinterpret_cast<element_t*>(untyped);
}

// A struct that holds the architecture information of the current GPU.
struct Arch {
    int major;
    int minor;
    int num_sms;
zhanghj2's avatar
zhanghj2 committed
25
    std::string archName;
26
27
28
29
30
31
32
    cudaDeviceProp* device_prop;

    Arch() {
        device_prop = at::cuda::getCurrentDeviceProperties();
        major = device_prop->major;
        minor = device_prop->minor;
        num_sms = device_prop->multiProcessorCount;
zhanghj2's avatar
zhanghj2 committed
33
        archName = device_prop->gcnArchName;
34
35
36
    }

    bool is_sm90a() const {
zhanghj2's avatar
zhanghj2 committed
37
38
        return major == 9 && minor == 3;
        // return major == 9 && minor == 0;
39
40
41
42
    }

    bool is_sm100f() const {
        return major == 10;
zhanghj2's avatar
zhanghj2 committed
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
    }   

    bool is_gfx938() const {
        return archName.substr(0, archName.find(':')) == "gfx938";
    }

    bool is_gfx936() const {
        return archName.substr(0, archName.find(':')) == "gfx936";
    }

    bool is_gfx928() const {
        return archName.substr(0, archName.find(':')) == "gfx928";
    }

    bool is_gfx93x() const {
        return is_gfx936() || is_gfx938();
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
    }
};

// Convert int64_t stride to int32_t, with overflow check.
//
// Fails through TORCH_CHECK when `orig_stride` cannot be represented as an `int`.
// Both ends of the range are checked: a value below INT_MIN would otherwise be
// truncated silently by the cast, exactly like a value above INT_MAX.
inline int int64_stride_to_int(int64_t orig_stride) {
    const bool fits_in_int =
        orig_stride >= std::numeric_limits<int>::min() &&
        orig_stride <= std::numeric_limits<int>::max();
    TORCH_CHECK(fits_in_int, "[FlashMLA] Stride exceeds int32 limit: ", orig_stride);
    return static_cast<int>(orig_stride);
}

// Dispatch a runtime head count to a compile-time constant.
//
// Expands to an immediately-invoked `[&]` lambda that declares
// `static constexpr int CONSTEXPR_NAME` and then calls the callable passed as the
// trailing argument(s), returning its result:
//   - NUM_HEADS == 128  -> CONSTEXPR_NAME = 128
//   - NUM_HEADS == 64   -> CONSTEXPR_NAME = 64
//   - NUM_HEADS <= 16   -> CONSTEXPR_NAME = 16 (every value up to 16 shares this case)
//   - anything else     -> TORCH_CHECK(false, ...)
//
// NUM_HEADS is parenthesized at every use so that expression arguments such as
// `a | b` are compared as a whole. It may be evaluated more than once, so it should
// not have side effects. The expansion already ends with ';'.
#define DISPATCH_NUM_HEADS(NUM_HEADS, CONSTEXPR_NAME, ...) \
    [&] () { \
        if ((NUM_HEADS) == 128) { \
            static constexpr int CONSTEXPR_NAME = 128; \
            return __VA_ARGS__(); \
        } else if ((NUM_HEADS) == 64) { \
            static constexpr int CONSTEXPR_NAME = 64; \
            return __VA_ARGS__(); \
        } else if ((NUM_HEADS) <= 16) { \
            static constexpr int CONSTEXPR_NAME = 16; \
            return __VA_ARGS__(); \
        } else { \
            TORCH_CHECK(false, "Unsupported num_heads_q: ", (NUM_HEADS)); \
        } \
    } ();

// Dispatch a runtime head dimension to a compile-time constant.
//
// Expands to an immediately-invoked `[&]` lambda that declares
// `static constexpr int CONSTEXPR_NAME` (576 or 512) and then calls the callable
// passed as the trailing argument(s), returning its result. Any other value of
// HEAD_DIM ends in TORCH_CHECK(false, ...).
//
// HEAD_DIM is parenthesized at every use so that expression arguments such as
// `a | b` are compared as a whole. It may be evaluated more than once, so it should
// not have side effects. The expansion already ends with ';'.
#define DISPATCH_HEAD_DIM(HEAD_DIM, CONSTEXPR_NAME, ...) \
[&] () { \
    if ((HEAD_DIM) == 576) { \
        static constexpr int CONSTEXPR_NAME = 576; \
        return __VA_ARGS__(); \
    } else if ((HEAD_DIM) == 512) { \
        static constexpr int CONSTEXPR_NAME = 512; \
        return __VA_ARGS__(); \
    } else { \
        TORCH_CHECK(false, "Unsupported head_dim_qk: ", (HEAD_DIM)); \
    } \
} ();

// Dispatch a runtime boolean to a compile-time constant.
//
// Expands to an immediately-invoked `[&]` lambda that declares
// `static constexpr bool CONSTEXPR_NAME` (true when FLAG holds, false otherwise)
// and then calls the callable passed as the trailing argument(s), returning its
// result. The expansion already ends with ';'.
#define DISPATCH_BOOLEAN_FLAG(FLAG, CONSTEXPR_NAME, ...) \
    [&] { \
        if (FLAG) { \
            static constexpr bool CONSTEXPR_NAME = true; \
            return __VA_ARGS__(); \
        } \
        static constexpr bool CONSTEXPR_NAME = false; \
        return __VA_ARGS__(); \
    } ();

// Dispatch a runtime ModelType to a compile-time constant.
//
// Expands to an immediately-invoked `[&]` lambda that declares
// `static constexpr ModelType CONSTEXPR_NAME` (ModelType::V32 or ModelType::MODEL1)
// and then calls the callable passed as the trailing argument(s), returning its
// result. Any other value ends in TORCH_CHECK(false, ...), which reports the value
// converted to int.
//
// MODEL_TYPE is parenthesized at every use, and the conversion for the error message
// is a static_cast over the whole argument, so expression arguments (for example a
// conditional expression) behave as a single value. It may be evaluated more than
// once, so it should not have side effects. The expansion already ends with ';'.
#define DISPATCH_MODEL_TYPE(MODEL_TYPE, CONSTEXPR_NAME, ...) \
[&] () { \
    if ((MODEL_TYPE) == ModelType::V32) { \
        static constexpr ModelType CONSTEXPR_NAME = ModelType::V32; \
        return __VA_ARGS__(); \
    } else if ((MODEL_TYPE) == ModelType::MODEL1) { \
        static constexpr ModelType CONSTEXPR_NAME = ModelType::MODEL1; \
        return __VA_ARGS__(); \
    } else { \
        TORCH_CHECK(false, "Unsupported model type: ", static_cast<int>(MODEL_TYPE)); \
    } \
} ();

// The following code is adapted from https://ykiko.me/en/articles/680412313/, which converts enum values to string names.
//
// Returns, as a std::string_view, the spelling of the template argument `value` that
// the compiler embeds in its "pretty function" string for this instantiation:
//   - GCC / Clang: the text after "= " up to (excluding) the last character of
//     __PRETTY_FUNCTION__, with everything up to and including the FIRST "::" removed.
//   - MSVC: the text between '<' and the last ">(" of __FUNCSIG__, with everything up
//     to and including the LAST "::" removed.
// When no "::" is present the text is returned unchanged.
template<auto value>
constexpr auto get_static_enum_name(){
    std::string_view text;
#if __GNUC__ || __clang__
    text = __PRETTY_FUNCTION__;
    const std::size_t value_begin = text.find('=') + 2;
    const std::size_t value_end = text.size() - 1;
    text = text.substr(value_begin, value_end - value_begin);
    const std::size_t scope_pos = text.find("::");
#elif _MSC_VER
    text = __FUNCSIG__;
    const std::size_t value_begin = text.find('<') + 1;
    const std::size_t value_end = text.rfind(">(");
    text = text.substr(value_begin, value_end - value_begin);
    const std::size_t scope_pos = text.rfind("::");
#endif
    if (scope_pos == std::string_view::npos) {
        return text;
    }
    // Skip the two characters of the "::" separator itself.
    return text.substr(scope_pos + 2);
}

// Counts upwards from N and returns the first N for which the name produced by
// get_static_enum_name<static_cast<T>(N)>() contains ")".
// NOTE(review): presumably ")" shows up because the compiler prints values without a
// matching enumerator as a cast such as "(T)5" - confirm for the compilers in use.
template<typename T, std::size_t N = 0> 
static constexpr std::size_t get_enum_max(){
    constexpr auto candidate_name = get_static_enum_name<static_cast<T>(N)>();
    constexpr bool has_paren = candidate_name.find(")") != std::string_view::npos;
    if constexpr (has_paren) {
        return N;
    } else {
        // Still a named enumerator: keep scanning with the next value.
        return get_enum_max<T, N + 1>();
    }
}

zhanghj2's avatar
zhanghj2 committed
154
template<typename T>
155
static constexpr std::string get_dynamic_enum_name(T value){
zhanghj2's avatar
zhanghj2 committed
156
157
    static_assert(std::is_enum<T>::value, 
                "Template parameter T must be an enumeration type");
158
159
160
161
162
163
164
165
166
    constexpr std::size_t num = get_enum_max<T>();
    constexpr auto names = []<std::size_t... Is>(std::index_sequence<Is...>){
        return std::array<std::string_view, num>{ 
            get_static_enum_name<static_cast<T>(Is)>()... 
        };
    }(std::make_index_sequence<num>{});
    return (std::string)names[static_cast<std::size_t>(value)];
}

zhanghj2's avatar
zhanghj2 committed
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
// A minimal read-only view over a contiguous run of `T` objects: a pointer to the first
// element plus an element count. The view never owns or copies the elements, and no
// bounds checking is done by operator[].
template<typename T>
class SimpleSpan {
public:
    // View over `size` elements starting at `data`.
    constexpr SimpleSpan(const T* data, size_t size) : data_{data}, size_{size} {}
    // View over the half-open pointer range [begin, end).
    constexpr SimpleSpan(const T* begin, const T* end)
        : data_{begin}, size_{static_cast<size_t>(end - begin)} {}

    constexpr const T* begin() const { return data_; }
    constexpr const T* end() const { return begin() + size(); }
    constexpr const T* data() const { return begin(); }
    constexpr size_t size() const { return size_; }
    constexpr const T& operator[](size_t index) const { return *(data_ + index); }

private:
    const T* data_;
    size_t size_;
};


185
186
187
188
// A shortcut macro to declare supported features in an implementation class.
//
// Meant to be placed inside a class where `FeatureT` names the feature type and
// `get_supported_features()` is a virtual function to override. It expands to:
//   - a `static constexpr FeatureT features[]` array holding the macro arguments, and
//   - an override of `get_supported_features()` returning a SimpleSpan over that array.
// Note that the expansion switches the access level to `protected:` and leaves it there
// for whatever follows the macro in the class body.
#define DECLARE_SUPPORTED_FEATURES(...) \
protected: \
    static constexpr FeatureT features[] = { __VA_ARGS__ }; \
    constexpr inline SimpleSpan<const FeatureT> get_supported_features() const override { \
        return SimpleSpan<const FeatureT>(features, std::size(features)); \
    }

/*
ImplBase - The base class for every implementation.

Every implementation should inherit from this class and implement the pure virtual functions, including:
- `run_`: The function that runs the implementation.
- `get_supported_features`: The function that returns the supported features of the implementation. You may use `DECLARE_SUPPORTED_FEATURES` to declare the supported features in a concise way.

The dispatcher will invoke `ImplBase::run()`, which checks if all required features are supported by the implementation, and then calls `run_`.
*/
template<
    typename RunArgT_,
    typename FeatureT_
>
class ImplBase {
protected:
    using RunArgT = RunArgT_;
    using FeatureT = FeatureT_;

    virtual inline void run_(const RunArgT &params, const std::vector<FeatureT> &required_features) = 0;

zhanghj2's avatar
zhanghj2 committed
213
    constexpr virtual inline SimpleSpan<const FeatureT> get_supported_features() const = 0;
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268

    virtual ~ImplBase() = default;

public:
    inline bool check_if_all_features_are_supported(const std::vector<FeatureT> &required_features) {
        for (const auto &required_feature : required_features) {
            bool is_supported = false;
            for (const auto &supported_feature : get_supported_features()) {
                if (required_feature == supported_feature) {
                    is_supported = true;
                    break;
                }
            }
            if (!is_supported) {
                return false;
            }
        }
        return true;
    }

    inline void check_if_all_features_are_supported_and_abort(const std::vector<FeatureT> &required_features) {
        if (!check_if_all_features_are_supported(required_features)) {
            fprintf(stderr, "[FlashMLA] Error: The chosen implementation does not support all required features.\n");
            fprintf(stderr, "Required features:\n");
            for (const auto &f : required_features) {
                fprintf(stderr, "  - %3d: %s\n", static_cast<int>(f), get_dynamic_enum_name(f).c_str());
            }
            fprintf(stderr, "\n");
            fprintf(stderr, "Supported features:\n");
            for (const auto &supported_feature : get_supported_features()) {
                fprintf(stderr, "  - %3d: %s\n", static_cast<int>(supported_feature), get_dynamic_enum_name(supported_feature).c_str());
            }
            fprintf(stderr, "\n");
            fprintf(stderr, "Features that are required but not supported:\n");
            for (const auto &required_feature : required_features) {
                bool is_supported = false;
                for (const auto &supported_feature : get_supported_features()) {
                    if (required_feature == supported_feature) {
                        is_supported = true;
                        break;
                    }
                }
                if (!is_supported) {
                    fprintf(stderr, "  - %3d: %s\n", static_cast<int>(required_feature), get_dynamic_enum_name(required_feature).c_str());
                }
            }
            fprintf(stderr, "\n");
            Arch cur_gpu_arch = Arch();
            fprintf(stderr, "Current GPU: %s, SM %d.%d with %d SMs\n", cur_gpu_arch.device_prop->name, cur_gpu_arch.major, cur_gpu_arch.minor, cur_gpu_arch.num_sms);
            fprintf(stderr, "This means that the dispatcher has chosen an implementation that does not support all required features. Maybe there is a bug in the dispatcher, or you have requested an invalid combination of features.\n");
            TORCH_CHECK(false, "The chosen implementation does not support all required features. See message above for details.");
        }
    }

    inline void run(const RunArgT &params, const std::vector<FeatureT> &required_features) {
269
        check_if_all_features_are_supported_and_abort(required_features);
270
271
272
273
        run_(params, required_features);
    }
};