launch.cuh 8.42 KB
Newer Older
Chenggang Zhao's avatar
Chenggang Zhao committed
1
2
3
#pragma once

#include "configs.cuh"
4
#include "exception.cuh"
Chenggang Zhao's avatar
Chenggang Zhao committed
5

lijian6's avatar
lijian6 committed
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
// ROCm helper functions and structures
namespace rocm::experimental {
// Host-side launch parameters read by LAUNCH_KERNEL / LAUNCH_KERNEL_NON_COOPERATIVE
// (they access the fields below through a pointer). Stand-in for the CUDA launch
// config until HIP offers an equivalent of cudaLaunchKernelEx().
typedef struct {
    dim3         num_sms;          // grid dimensions (number of blocks) used at launch
    dim3         num_threads;      // block dimensions (threads per block)
    unsigned int shared_mem_bytes; // dynamic shared memory per block, in bytes
    hipStream_t  stream;           // stream the kernel is enqueued on
} hipLaunchConfig_t;

// Returns the type-erased address of one kernel argument for a void** parameter array.
// `Args&&` deduces `const T&` for const lvalues, and static_cast<void *> cannot drop that
// qualifier, so go through `const volatile void *` and strip cv explicitly. The launch API
// reads the argument values through these pointers; it is not expected to write to them.
template <typename T> inline void *kernel_arg_address(T &arg) {
    return const_cast<void *>(static_cast<const volatile void *>(std::addressof(arg)));
}

// Base case: no arguments left to store. Also makes a zero-argument call well-formed.
inline void fill_kernel_args(void ** /*f*/, size_t /*idx*/) {}

// Stores the address of `arg` into f[idx].
// f must have room for at least idx + 1 entries.
template <typename T> void fill_kernel_args(void **f, size_t idx, T &&arg) {
    f[idx] = kernel_arg_address(arg);
}

// Stores the addresses of head, tail... into f[idx], f[idx + 1], ... in order.
// The stored pointers refer to the caller's objects, so those objects must outlive every
// use of `f` (e.g. the launch call that consumes it).
template <typename Head, typename... Tail>
void fill_kernel_args(void **f, size_t idx, Head &&head, Tail &&...tail) {
    f[idx] = kernel_arg_address(head);
    fill_kernel_args(f, idx + 1, std::forward<Tail>(tail)...);
}
} // namespace rocm::experimental

Chenggang Zhao's avatar
Chenggang Zhao committed
27
#ifndef SETUP_LAUNCH_CONFIG
lijian6's avatar
lijian6 committed
28
29
30
31
32
33
34
// ROCm workaround: HIP has no counterpart of cudaLaunchKernelEx() yet, so this macro
// only records the launch parameters in a local variable named `cfg`, with dynamic
// shared memory fixed at 0 bytes. The indirection exists purely to keep the CUDA-side
// macro signatures; rework it once the HIP API is available.
// LAUNCH_KERNEL* read the config through `->`, so callers pass `&cfg`.
#define SETUP_LAUNCH_CONFIG(num_sms, num_threads, stream)                                          \
    rocm::experimental::hipLaunchConfig_t cfg = {                                                  \
        (num_sms), (num_threads), 0, (stream)};

#endif // #ifndef SETUP_LAUNCH_CONFIG
Chenggang Zhao's avatar
Chenggang Zhao committed
35
36

#ifndef LAUNCH_KERNEL
lijian6's avatar
lijian6 committed
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
// Cooperatively launches `kernel` via hipLaunchCooperativeKernel, taking the grid
// (num_sms), block (num_threads), dynamic shared memory and stream from `config`.
//
// config : pointer-like object exposing those fields through `->`
//          (e.g. `&cfg` from SETUP_LAUNCH_CONFIG).
// kernel : the __global__ function to launch.
// args   : kernel arguments. Only their addresses are collected (as kernelParams),
//          so they need to outlive this call only.
// A failing launch status is reported through CUDA_CHECK.
template <typename T, typename Kern, typename... Args>
inline void LAUNCH_KERNEL(T &&config, Kern &&kernel, Args &&...args) {
    constexpr size_t k_num_kernel_args = sizeof...(args);
    // A zero-length array is ill-formed in standard C++. Keep at least one slot so
    // parameterless kernels compile; the unused slot stays nullptr.
    void *kernel_args[k_num_kernel_args > 0 ? k_num_kernel_args : 1] = {};
    if constexpr (k_num_kernel_args > 0) {
        rocm::experimental::fill_kernel_args(kernel_args, 0, std::forward<Args>(args)...);
    }
    CUDA_CHECK(hipLaunchCooperativeKernel(std::forward<Kern>(kernel), config->num_sms,
                                          config->num_threads, kernel_args,
                                          config->shared_mem_bytes, config->stream));
}

template <typename T, typename Kern, typename... Args>
inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...args) {
    hipLaunchKernelGGL((*kernel), dim3(config->num_sms), dim3(config->num_threads), config->shared_mem_bytes, config->stream, 
        std::forward<Args>(args)...);
}
52

lijian6's avatar
lijian6 committed
53
#endif // #ifndef LAUNCH_KERNEL
Chenggang Zhao's avatar
Chenggang Zhao committed
54

lijian6's avatar
lijian6 committed
55
56
57
58
59
60
61
62
63
64
65
66
// Dispatches on the runtime value `num_ranks` (must be in scope where this expands),
// expanding `case_macro(N)` for N in {2, 4, 8}. Any other value trips EP_HOST_ASSERT.
// case_macro is expanded directly under each label with nothing in between, so it
// presumably ends its own case (break/return) — otherwise control falls through.
// The trailing `while (false)` requires, and consumes, the caller's semicolon.
#define SWITCH_RANKS(case_macro)                                                                   \
    switch (num_ranks) {                                                                           \
        case 2: case_macro(2);                                                                     \
        case 4: case_macro(4);                                                                     \
        case 8: case_macro(8);                                                                     \
        default: EP_HOST_ASSERT(false && "Unsupported ranks");                                     \
    }                                                                                              \
    while (false)
Chenggang Zhao's avatar
Chenggang Zhao committed
67

lijian6's avatar
lijian6 committed
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
// Dispatches on the number of RDMA ranks, computed as `num_ranks / NUM_MAX_NVL_PEERS`
// (integer division; both names must be in scope where this expands). Expands
// `case_macro(N)` for N in {2, 3, 4, 8, 16, 18, 20}; anything else trips EP_HOST_ASSERT.
// NOTE(review): a num_ranks that is not a multiple of NUM_MAX_NVL_PEERS truncates
// silently here — confirm callers guarantee divisibility.
// case_macro presumably ends its own case (break/return); there is no separator.
// The trailing `while (false)` requires, and consumes, the caller's semicolon.
#define SWITCH_RDMA_RANKS(case_macro)                                                              \
    switch (num_ranks / NUM_MAX_NVL_PEERS) {                                                       \
        case 2: case_macro(2);                                                                     \
        case 3: case_macro(3);                                                                     \
        case 4: case_macro(4);                                                                     \
        case 8: case_macro(8);                                                                     \
        case 16: case_macro(16);                                                                   \
        case 18: case_macro(18);                                                                   \
        case 20: case_macro(20);                                                                   \
        default: EP_HOST_ASSERT(false && "Unsupported RDMA ranks");                                \
    }                                                                                              \
    while (false)
Chenggang Zhao's avatar
Chenggang Zhao committed
88

lijian6's avatar
lijian6 committed
89
90
91
92
93
94
95
96
97
98
99
100
// Like SWITCH_RANKS, but forwards an extra `dtype` token: expands
// `case_macro(dtype, N)` for N in {2, 4, 8} according to the in-scope `num_ranks`.
// Any other rank count trips EP_HOST_ASSERT.
// case_macro presumably ends its own case (break/return); there is no separator.
// The trailing `while (false)` requires, and consumes, the caller's semicolon.
#define SWITCH_RANKS_WITH_DTYPE(dtype, case_macro)                                                 \
    switch (num_ranks) {                                                                           \
        case 2: case_macro(dtype, 2);                                                              \
        case 4: case_macro(dtype, 4);                                                              \
        case 8: case_macro(dtype, 8);                                                              \
        default: EP_HOST_ASSERT(false && "Unsupported ranks");                                     \
    }                                                                                              \
    while (false)
Chenggang Zhao's avatar
Chenggang Zhao committed
101

lijian6's avatar
lijian6 committed
102
103
104
105
106
107
108
109
110
111
// Maps the runtime HIP data-type tag `type` (must be in scope where this expands)
// to a C++ element type: HIP_R_16BF -> hip_bfloat16, HIP_R_32F -> float, then
// expands `case_macro(<type>)`. Any other tag trips EP_HOST_ASSERT.
// case_macro presumably ends its own case (break/return); there is no separator.
// The trailing `while (false)` requires, and consumes, the caller's semicolon.
#define SWITCH_TYPES(case_macro)                                                                   \
    switch (type) {                                                                                \
        case HIP_R_16BF: case_macro(hip_bfloat16);                                                 \
        case HIP_R_32F: case_macro(float);                                                         \
        default: EP_HOST_ASSERT(false && "Unsupported type");                                      \
    }                                                                                              \
    while (false)
Chenggang Zhao's avatar
Chenggang Zhao committed
112

lijian6's avatar
lijian6 committed
113
114
115
116
117
118
119
120
121
122
123
124
125
126
// Dispatches on the runtime hidden size `hidden` (must be in scope where this expands),
// expanding `case_macro(H)` for H in {2560, 5120, 4096, 7168}. Any other size trips
// EP_HOST_ASSERT. Label order is kept as-is because case_macro is expanded with no
// separator; it presumably ends its own case (break/return).
// The trailing `while (false)` requires, and consumes, the caller's semicolon.
#define SWITCH_HIDDEN(case_macro)                                                                  \
    switch (hidden) {                                                                              \
        case 2560: case_macro(2560);                                                               \
        case 5120: case_macro(5120);                                                               \
        case 4096: case_macro(4096);                                                               \
        case 7168: case_macro(7168);                                                               \
        default: EP_HOST_ASSERT(false && "Unsupported hidden");                                    \
    }                                                                                              \
    while (false)