launch.cuh 8.45 KB
Newer Older
lijian6's avatar
lijian6 committed
1
#include "hip/hip_runtime.h"
Chenggang Zhao's avatar
Chenggang Zhao committed
2
3
4
#pragma once

#include "configs.cuh"
5
#include "exception.cuh"
Chenggang Zhao's avatar
Chenggang Zhao committed
6

lijian6's avatar
lijian6 committed
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
// ROCm helper functions and structures
namespace rocm::experimental {
// Launch parameters read by the LAUNCH_KERNEL helpers in this header.
// Kept as a plain aggregate so it can be brace-initialized in member order
// (num_sms, num_threads, shared_mem_bytes, stream).
struct hipLaunchConfig_t {
    dim3         num_sms;          // forwarded as the grid dimensions of the launch
    dim3         num_threads;      // forwarded as the block dimensions of the launch
    unsigned int shared_mem_bytes; // forwarded as the shared-memory byte count
    hipStream_t  stream;           // stream handed to the launch call
};

// Stores the address of every argument in `args`, in order, into consecutive
// slots of `f` starting at `f[idx]`. This happens at run time; nothing is
// copied, so the stored pointers are only valid while the referenced
// arguments are alive.
//
// f    - destination array; the caller must provide at least
//        idx + sizeof...(args) slots.
// idx  - index of the first slot to write.
// args - zero or more arguments; an empty pack writes nothing.
template <typename... Args>
void fill_kernel_args(void **f, size_t idx, Args &&...args) {
    // The destination is typed void**, so constness has to be dropped to
    // accept const-qualified arguments.
    // NOTE(review): relies on the launch API only reading through these
    // pointers - confirm against the HIP documentation.
    ((f[idx++] = const_cast<void *>(
          static_cast<const volatile void *>(std::addressof(args)))),
     ...);
}
} // namespace rocm::experimental

Chenggang Zhao's avatar
Chenggang Zhao committed
28
#ifndef SETUP_LAUNCH_CONFIG
lijian6's avatar
lijian6 committed
29
30
31
32
33
34
35
// The code below is a workaround for ROCm. The extra indirection exists only
// to match the current macro signatures and should be reworked once a HIP
// alternative to cudaLaunchKernelEx() is available.
// Declares a local variable named `cfg` of type
// rocm::experimental::hipLaunchConfig_t at the expansion site, initialized in
// member order from (num_sms, num_threads, 0, stream); the third member is
// always 0. The expansion already ends with a semicolon.
#define SETUP_LAUNCH_CONFIG(num_sms, num_threads, stream)                                          \
    rocm::experimental::hipLaunchConfig_t cfg{(num_sms), (num_threads), 0, stream};

#endif // #ifndef SETUP_LAUNCH_CONFIG
Chenggang Zhao's avatar
Chenggang Zhao committed
36
37

#ifndef LAUNCH_KERNEL
lijian6's avatar
lijian6 committed
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
// Launches `kernel` through hipLaunchCooperativeKernel.
//
// config - pointer-like object (accessed with `->`) exposing num_sms,
//          num_threads, shared_mem_bytes and stream.
// kernel - kernel entry point, forwarded unchanged to the launch call.
// args   - kernel arguments; their addresses are collected into the
//          parameter array, so they must stay alive until this call returns.
//
// The status returned by the launch call is checked with CUDA_CHECK.
template <typename T, typename Kern, typename... Args>
inline void LAUNCH_KERNEL(T &&config, Kern &&kernel, Args &&...args) {
    constexpr size_t k_num_kernel_args = sizeof...(args);
    // Always reserve one slot: a zero-length array is not valid ISO C++, which
    // would otherwise make parameterless kernels impossible to launch.
    void *kernel_args[k_num_kernel_args > 0 ? k_num_kernel_args : 1] = {};
    if constexpr (k_num_kernel_args > 0) {
        rocm::experimental::fill_kernel_args(kernel_args, 0, std::forward<Args>(args)...);
    }
    CUDA_CHECK(hipLaunchCooperativeKernel(std::forward<Kern>(kernel), config->num_sms,
                                          config->num_threads, kernel_args,
                                          config->shared_mem_bytes, config->stream));
}

template <typename T, typename Kern, typename... Args>
inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...args) {
    hipLaunchKernelGGL((*kernel), dim3(config->num_sms), dim3(config->num_threads), config->shared_mem_bytes, config->stream, 
        std::forward<Args>(args)...);
}
53

lijian6's avatar
lijian6 committed
54
#endif // #ifndef LAUNCH_KERNEL
Chenggang Zhao's avatar
Chenggang Zhao committed
55

lijian6's avatar
lijian6 committed
56
57
58
59
60
61
62
63
64
65
66
67
// Dispatches on the variable `num_ranks` visible at the expansion site and
// expands `case_macro(N)` for the matching N in {2, 4, 8}. Any other value
// reaches EP_HOST_ASSERT(false and "Unsupported ranks").
//
// Every case ends with an explicit `break`, so a `case_macro` that does not
// itself end in `break`/`return` no longer falls through into the next case
// and, ultimately, the failing default branch. If `case_macro` already ends in
// `break`, the extra one is simply unreachable.
//
// The trailing `while (false)` absorbs the semicolon written after the macro
// invocation.
#define SWITCH_RANKS(case_macro)                                                                   \
    switch (num_ranks) {                                                                           \
    case 2:                                                                                        \
        case_macro(2);                                                                             \
        break;                                                                                     \
    case 4:                                                                                        \
        case_macro(4);                                                                             \
        break;                                                                                     \
    case 8:                                                                                        \
        case_macro(8);                                                                             \
        break;                                                                                     \
    default:                                                                                       \
        EP_HOST_ASSERT(false and "Unsupported ranks");                                             \
    }                                                                                              \
    while (false)
Chenggang Zhao's avatar
Chenggang Zhao committed
68

lijian6's avatar
lijian6 committed
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
// Dispatches on `num_ranks / NUM_MAX_NVL_PEERS` (both names are resolved at the
// expansion site) and expands `case_macro(N)` for the matching N in
// {2, 3, 4, 8, 16, 18, 20}. Any other quotient reaches
// EP_HOST_ASSERT(false and "Unsupported RDMA ranks").
//
// Every case ends with an explicit `break`, so a `case_macro` that does not
// itself end in `break`/`return` no longer falls through into the following
// cases and the failing default branch. If `case_macro` already ends in
// `break`, the extra one is simply unreachable.
//
// The trailing `while (false)` absorbs the semicolon written after the macro
// invocation.
#define SWITCH_RDMA_RANKS(case_macro)                                                              \
    switch (num_ranks / NUM_MAX_NVL_PEERS) {                                                       \
    case 2:                                                                                        \
        case_macro(2);                                                                             \
        break;                                                                                     \
    case 3:                                                                                        \
        case_macro(3);                                                                             \
        break;                                                                                     \
    case 4:                                                                                        \
        case_macro(4);                                                                             \
        break;                                                                                     \
    case 8:                                                                                        \
        case_macro(8);                                                                             \
        break;                                                                                     \
    case 16:                                                                                       \
        case_macro(16);                                                                            \
        break;                                                                                     \
    case 18:                                                                                       \
        case_macro(18);                                                                            \
        break;                                                                                     \
    case 20:                                                                                       \
        case_macro(20);                                                                            \
        break;                                                                                     \
    default:                                                                                       \
        EP_HOST_ASSERT(false and "Unsupported RDMA ranks");                                        \
    }                                                                                              \
    while (false)
Chenggang Zhao's avatar
Chenggang Zhao committed
89

lijian6's avatar
lijian6 committed
90
91
92
93
94
95
96
97
98
99
100
101
// Dispatches on the variable `num_ranks` visible at the expansion site and
// expands `case_macro(dtype, N)` for the matching N in {2, 4, 8}; `dtype` is
// passed through to `case_macro` untouched. Any other value reaches
// EP_HOST_ASSERT(false and "Unsupported ranks").
//
// Every case ends with an explicit `break`, so a `case_macro` that does not
// itself end in `break`/`return` no longer falls through into the following
// cases and the failing default branch. If `case_macro` already ends in
// `break`, the extra one is simply unreachable.
//
// The trailing `while (false)` absorbs the semicolon written after the macro
// invocation.
#define SWITCH_RANKS_WITH_DTYPE(dtype, case_macro)                                                 \
    switch (num_ranks) {                                                                           \
    case 2:                                                                                        \
        case_macro(dtype, 2);                                                                      \
        break;                                                                                     \
    case 4:                                                                                        \
        case_macro(dtype, 4);                                                                      \
        break;                                                                                     \
    case 8:                                                                                        \
        case_macro(dtype, 8);                                                                      \
        break;                                                                                     \
    default:                                                                                       \
        EP_HOST_ASSERT(false and "Unsupported ranks");                                             \
    }                                                                                              \
    while (false)
Chenggang Zhao's avatar
Chenggang Zhao committed
102

lijian6's avatar
lijian6 committed
103
104
105
106
107
108
109
110
111
112
// Dispatches on the variable `type` visible at the expansion site:
// HIP_R_16BF expands `case_macro(hip_bfloat16)` and HIP_R_32F expands
// `case_macro(float)`. Any other value reaches
// EP_HOST_ASSERT(false and "Unsupported type").
//
// Every case ends with an explicit `break`, so a `case_macro` that does not
// itself end in `break`/`return` no longer falls through into the following
// case and the failing default branch. If `case_macro` already ends in
// `break`, the extra one is simply unreachable.
//
// The trailing `while (false)` absorbs the semicolon written after the macro
// invocation.
#define SWITCH_TYPES(case_macro)                                                                   \
    switch (type) {                                                                                \
    case HIP_R_16BF:                                                                               \
        case_macro(hip_bfloat16);                                                                  \
        break;                                                                                     \
    case HIP_R_32F:                                                                                \
        case_macro(float);                                                                         \
        break;                                                                                     \
    default:                                                                                       \
        EP_HOST_ASSERT(false and "Unsupported type");                                              \
    }                                                                                              \
    while (false)
Chenggang Zhao's avatar
Chenggang Zhao committed
113

lijian6's avatar
lijian6 committed
114
115
116
117
118
119
120
121
122
123
124
125
126
127
// Dispatches on the variable `hidden` visible at the expansion site and
// expands `case_macro(N)` for the matching N in {2560, 5120, 4096, 7168}. Any
// other value reaches EP_HOST_ASSERT(false and "Unsupported hidden").
//
// Every case ends with an explicit `break`, so a `case_macro` that does not
// itself end in `break`/`return` no longer falls through into the following
// cases and the failing default branch. If `case_macro` already ends in
// `break`, the extra one is simply unreachable.
//
// The trailing `while (false)` absorbs the semicolon written after the macro
// invocation.
#define SWITCH_HIDDEN(case_macro)                                                                  \
    switch (hidden) {                                                                              \
    case 2560:                                                                                     \
        case_macro(2560);                                                                          \
        break;                                                                                     \
    case 5120:                                                                                     \
        case_macro(5120);                                                                          \
        break;                                                                                     \
    case 4096:                                                                                     \
        case_macro(4096);                                                                          \
        break;                                                                                     \
    case 7168:                                                                                     \
        case_macro(7168);                                                                          \
        break;                                                                                     \
    default:                                                                                       \
        EP_HOST_ASSERT(false and "Unsupported hidden");                                            \
    }                                                                                              \
    while (false)