dispatch_utils.h 2.07 KB
Newer Older
Zhekai Zhang's avatar
Zhekai Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#pragma once

#include "common.h"
#include "Tensor.h"
#include <cuda_fp16.h>
#include <stdexcept>

/**
 * Calls `func` with the floating-point C++ type that corresponds to `scalarType`.
 *
 * `func` must expose a templated, zero-argument call operator (for example a
 * C++20 template lambda `[&]<typename scalar_t>() { ... }`). It is invoked as
 * `func.template operator()<T>()` with:
 *   - Tensor::BF16 -> __nv_bfloat16
 *   - Tensor::FP16 -> half
 *   - Tensor::FP32 -> float
 *
 * @param scalarType runtime type tag selecting the instantiation.
 * @param func       callable to instantiate and invoke.
 * @return whatever the selected instantiation of `func` returns.
 * @throws std::runtime_error if `scalarType` is not one of the tags above.
 */
template<typename F>
inline auto dispatchFloat(Tensor::ScalarType scalarType, F &&func) {
    switch (scalarType) {
    case Tensor::BF16:
        return func.template operator()<__nv_bfloat16>();
    case Tensor::FP16:
        return func.template operator()<half>();
    case Tensor::FP32:
        return func.template operator()<float>();
    default:
        // A bare assert(false) disappears when NDEBUG is defined, after which
        // control would run off the end of this function: undefined behaviour
        // when `func` returns a value, and a silently skipped call otherwise.
        // Throwing keeps the failure visible in every build configuration.
        throw std::runtime_error("Unsupported floating-point scalar type");
    }
}

/**
 * Calls `func` with the C++ type that corresponds to `scalarType`.
 *
 * `func` must expose a templated, zero-argument call operator; it is invoked
 * as `func.template operator()<T>()` where T is picked from the tag:
 * BF16 -> __nv_bfloat16, FP16 -> half, FP32 -> float,
 * INT8 -> int8_t, INT32 -> int32_t, INT64 -> int64_t.
 *
 * @return whatever the selected instantiation of `func` returns.
 * @throws std::runtime_error for any other tag.
 */
template<typename F>
inline auto dispatch(Tensor::ScalarType scalarType, F &&func) {
    // Floating-point tags.
    if (scalarType == Tensor::BF16) {
        return func.template operator()<__nv_bfloat16>();
    }
    if (scalarType == Tensor::FP16) {
        return func.template operator()<half>();
    }
    if (scalarType == Tensor::FP32) {
        return func.template operator()<float>();
    }

    // Integer tags.
    if (scalarType == Tensor::INT8) {
        return func.template operator()<int8_t>();
    }
    if (scalarType == Tensor::INT32) {
        return func.template operator()<int32_t>();
    }
    if (scalarType == Tensor::INT64) {
        return func.template operator()<int64_t>();
    }

    throw std::runtime_error("Unsupported scalar type");
}

#pragma nv_diagnostic push
// warning #445-D: template parameter "scalar_t" is not used in declaring the parameter types of function template "lambda []()->auto::operator auto (*)()"
#pragma nv_diag_suppress 445
/**
 * Reports whether the runtime tag `scalarType` maps to the compile-time type T.
 *
 * The tag is resolved through dispatch(), so a tag that dispatch() rejects
 * propagates its exception instead of yielding false.
 */
template<typename T>
inline bool isTypeMatch(Tensor::ScalarType scalarType) {
    // Instantiated by dispatch() with the C++ type selected by the tag.
    const auto matchesT = []<typename scalar_t>() {
        return std::is_same<scalar_t, T>::value;
    };
    return dispatch(scalarType, matchesT);
}
#pragma nv_diagnostic pop

/**
 * Turns the runtime integer `val` into a compile-time constant.
 *
 * Each candidate in the `std::integer_sequence<int, N...>` is compared with
 * `val`; for every candidate that is equal, `func.template operator()<N>()`
 * is invoked. Any value returned by `func` is discarded, and nothing happens
 * when no candidate matches.
 */
template<typename F, int ...N>
inline auto dispatchVal(int val, std::integer_sequence<int, N...>, F &&func) {
    // Comma fold: every candidate is tested in order, matches call `func`.
    ((val == N ? static_cast<void>(func.template operator()<N>())
               : static_cast<void>(0)),
     ...);
}

/**
 * Turns the runtime flag `val` into a compile-time constant by invoking
 * `func.template operator()<true>()` or `func.template operator()<false>()`.
 * Any value returned by `func` is discarded.
 */
template<typename F>
inline auto dispatchBool(bool val, F &&func) {
    if (!val) {
        func.template operator()<false>();
        return;
    }
    func.template operator()<true>();
}


// Wraps dispatchFloat(): the variadic arguments are expected to form a callable
// expression, which is invoked inside a template lambda where `scalar_t` names
// the type selected from TYPE. NAME is accepted but not used by the expansion.
// The expansion already ends with a semicolon.
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
    dispatchFloat(TYPE, [&]<typename scalar_t>() {    \
        __VA_ARGS__();                                \
    });