// fp16_switch.h
// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h

// Modified from static_switch.h because MSVC cannot handle
// std::conditional with a constexpr variable.

#pragma once

/// Dispatch on a runtime flag to one of two CUDA 16-bit floating-point types.
///
/// Expands to an immediately-invoked lambda. Inside it, the type alias
/// `elem_type` is declared as `__nv_bfloat16` when COND is true and as
/// `__half` when COND is false. The callable passed in `...` is then invoked
/// with no arguments. Because the callable's text is pasted inside that
/// scope, a lambda written inline at the call site can refer to `elem_type`
/// directly; a callable defined elsewhere cannot see it.
///
/// @param COND - a boolean expression, evaluated once; true selects
///               __nv_bfloat16, false selects __half
/// @param ...  - a callable expression (normally an inline lambda) invoked
///               once in the selected branch. The macro is variadic so that
///               top-level commas inside the lambda body (e.g. in template
///               argument lists) do not split it into several macro arguments.
/// @return whatever the callable returns. Both branches return from the same
///         outer lambda, so the callable must produce the same type for both
///         choices of elem_type, or return-type deduction fails.
///
/// The names __nv_bfloat16 and __half must already be declared where the
/// macro is expanded (CUDA provides them in <cuda_bf16.h> and <cuda_fp16.h>).
///
/// Usage:
/// ```
/// FP16_SWITCH(flag, [&] {
///     some_function<elem_type>(...);
/// });
/// ```
#define FP16_SWITCH(COND, ...)                 \
    [&] {                                      \
        if (COND) {                            \
            using elem_type = __nv_bfloat16;   \
            return __VA_ARGS__();              \
        } else {                               \
            using elem_type = __half;          \
            return __VA_ARGS__();              \
        }                                      \
    }()