// dispatch_utils.h
// Helpers that turn runtime values (Tensor scalar types, ints, bools) into template arguments.
#pragma once

#include "common.h"
#include "Tensor.h"
// CUDA FP16 header: declares `half`, which the dispatchers in this file use.
#include <cuda_fp16.h>
// Scalar-type dispatchers: each selects a C++ type from a Tensor::ScalarType tag and
// calls func.template operator()<T>() with it.

/// Calls `func.template operator()<T>()` with the C++ type that corresponds to a
/// floating-point `scalarType`, and returns whatever that call returns.
///
/// Mapping: BF16 -> __nv_bfloat16, FP16 -> half, FP32 -> float.
///
/// The return type is deduced, so every instantiation of `func` must return the same type.
///
/// @param scalarType runtime tag selecting the element type.
/// @param func       callable exposing a templated, zero-argument `operator()`.
/// @throws std::invalid_argument for any other tag. When assertions are enabled the
///         `assert(false)` fires first.
template<typename F>
inline auto dispatchFloat(Tensor::ScalarType scalarType, F &&func) {
    switch (scalarType) {
    case Tensor::BF16:
        return func.template operator()<__nv_bfloat16>();
    case Tensor::FP16:
        return func.template operator()<half>();
    case Tensor::FP32:
        return func.template operator()<float>();
    default:
        // Trap in assertion-enabled builds; still report the error when asserts are compiled out.
        assert(false);
        throw std::invalid_argument("scalarType is not a floating type");
    }
}

/// Calls `func.template operator()<T>()` with the C++ type that corresponds to a
/// 16-bit floating-point `scalarType`, and returns whatever that call returns.
///
/// Mapping: BF16 -> __nv_bfloat16, FP16 -> half.
///
/// The return type is deduced, so both instantiations of `func` must return the same type.
///
/// @param scalarType runtime tag selecting the element type.
/// @param func       callable exposing a templated, zero-argument `operator()`.
/// @throws std::invalid_argument for any other tag. When assertions are enabled the
///         `assert(false)` fires first.
template<typename F>
inline auto dispatchFloat16(Tensor::ScalarType scalarType, F &&func) {
    switch (scalarType) {
    case Tensor::BF16:
        return func.template operator()<__nv_bfloat16>();
    case Tensor::FP16:
        return func.template operator()<half>();
    default:
        // Trap in assertion-enabled builds; still report the error when asserts are compiled out.
        assert(false);
        throw std::invalid_argument("scalarType is not a float16 type");
    }
}

/// Calls `func.template operator()<T>()` with the C++ type that corresponds to
/// `scalarType`, and returns whatever that call returns.
///
/// Mapping: BF16 -> __nv_bfloat16, FP16 -> half, FP32 -> float,
///          INT8 -> int8_t, INT32 -> int32_t, INT64 -> int64_t.
///
/// The return type is deduced, so every instantiation of `func` must return the same type.
///
/// @param scalarType runtime tag selecting the element type.
/// @param func       callable exposing a templated, zero-argument `operator()`.
/// @throws std::runtime_error for any tag not listed above.
template<typename F>
inline auto dispatch(Tensor::ScalarType scalarType, F &&func) {
    switch (scalarType) {
    case Tensor::BF16:
        return func.template operator()<__nv_bfloat16>();
    case Tensor::FP16:
        return func.template operator()<half>();
    case Tensor::FP32:
        return func.template operator()<float>();
    case Tensor::INT8:
        return func.template operator()<int8_t>();
    case Tensor::INT32:
        return func.template operator()<int32_t>();
    case Tensor::INT64:
        return func.template operator()<int64_t>();
    default:
        throw std::runtime_error("Unsupported scalar type");
    }
}

#pragma nv_diagnostic push
Muyang Li's avatar
Muyang Li committed
56
57
58
// warning #445-D: template parameter "scalar_t" is not used in declaring the parameter types of function template
// "lambda []()->auto::operator auto (*)()"
#pragma nv_diag_suppress 445
Zhekai Zhang's avatar
Zhekai Zhang committed
59
60
template<typename T>
inline bool isTypeMatch(Tensor::ScalarType scalarType) {
Muyang Li's avatar
Muyang Li committed
61
    return dispatch(scalarType, []<typename scalar_t>() { return std::is_same_v<scalar_t, T>; });
Zhekai Zhang's avatar
Zhekai Zhang committed
62
63
64
}
#pragma nv_diagnostic pop

// Value dispatchers: map a runtime int or bool onto a non-type template argument.
/// Calls `func.template operator()<K>()` for every candidate `K` in `N...` that equals `val`.
///
/// Candidates are tested in pack order and the scan does not stop at the first hit, so a
/// value listed more than once triggers one call per occurrence. If `val` equals none of
/// the candidates, nothing is called and no error is raised. Any value returned by `func`
/// is discarded; this function returns void.
///
/// @param val  runtime value to match.
/// @param func callable exposing `template<int K> operator()()`; it is instantiated for
///             every candidate, not only the one that matches.
template<typename F, int... N>
inline auto dispatchVal(int val, std::integer_sequence<int, N...>, F &&func) {
    // Comma fold: one guarded call per candidate. Both ternary arms are cast to void so
    // the expression is well-formed whatever `func` returns.
    ((val == N ? static_cast<void>(func.template operator()<N>()) : static_cast<void>(0)), ...);
}

/// Calls `func.template operator()<true>()` when `val` is true and
/// `func.template operator()<false>()` otherwise.
///
/// Any value returned by `func` is discarded; this function returns void.
///
/// @param val  runtime flag to lift into a template argument.
/// @param func callable exposing `template<bool B> operator()()`; both instantiations
///             are compiled.
template<typename F>
inline auto dispatchBool(bool val, F &&func) {
    if (!val) {
        func.template operator()<false>();
        return;
    }
    func.template operator()<true>();
}

// Convenience macro wrapping dispatchFloat.
// Expands to a dispatchFloat call on TYPE. The generated callback names the selected
// type `scalar_t` and then invokes the trailing macro argument(s) as a nullary callable.
// NAME is accepted but does not appear in the expansion. The expansion already ends
// with ';'.
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)                                                                 \
    dispatchFloat(TYPE, [&]<typename scalar_t>() { __VA_ARGS__(); });