launch.cuh 8.85 KB
Newer Older
lijian6's avatar
lijian6 committed
1
#include "hip/hip_runtime.h"
Chenggang Zhao's avatar
Chenggang Zhao committed
2
3
4
#pragma once

#include "configs.cuh"
5
#include "exception.cuh"
Chenggang Zhao's avatar
Chenggang Zhao committed
6

lijian6's avatar
lijian6 committed
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
// ROCm helper functions and structures
namespace rocm::experimental {
// Launch parameters handed to the LAUNCH_KERNEL* helpers in this file.
typedef struct {
    dim3         num_sms;          // forwarded as the grid dimensions of the launch
    dim3         num_threads;      // forwarded as the block dimensions of the launch
    unsigned int shared_mem_bytes; // forwarded as the shared-memory size argument of the launch
    hipStream_t  stream;           // stream the kernel is submitted to
} hipLaunchConfig_t;

// Empty-pack overload: a launch without kernel arguments leaves `f` untouched.
inline void fill_kernel_args(void ** /*f*/, size_t /*idx*/) {}

// Stores the address of `arg` in slot `f[idx]`.
//
// Only the address is recorded, so `arg` must stay alive until the launch call
// that consumes `f` has returned. The cv-qualifiers are stripped explicitly
// because the slot type is a plain `void *`; a bare static_cast would reject
// const (or volatile) arguments at compile time.
template <typename T> void fill_kernel_args(void **f, size_t idx, T &&arg) {
    f[idx] = const_cast<void *>(static_cast<const volatile void *>(std::addressof(arg)));
}

// Stores the addresses of `head, tail...` in consecutive slots starting at
// `f[idx]`. `f` must provide room for `idx + 1 + sizeof...(tail)` entries.
// The pack is peeled one element per recursion step; the stores themselves
// happen at run time.
template <typename Head, typename... Tail>
void fill_kernel_args(void **f, size_t idx, Head &&head, Tail &&...tail) {
    f[idx] = const_cast<void *>(static_cast<const volatile void *>(std::addressof(head)));
    fill_kernel_args(f, idx + 1, std::forward<Tail>(tail)...);
}
} // namespace rocm::experimental

Chenggang Zhao's avatar
Chenggang Zhao committed
28
#ifndef SETUP_LAUNCH_CONFIG
lijian6's avatar
lijian6 committed
29
30
31
32
33
34
35
// ROCm workaround. The extra plumbing only exists so that the macro keeps its
// current signature; it should be reworked once HIP offers an alternative to
// CUDA's cudaLaunchKernelEx().
//
// Expands to the declaration of a local variable named `cfg` (shared memory
// size fixed to 0), followed by a semicolon. The variable name is not
// configurable, so at most one expansion fits into a given scope.
#define SETUP_LAUNCH_CONFIG(num_sms, num_threads, stream)                                          \
    rocm::experimental::hipLaunchConfig_t cfg = {                                                  \
        (num_sms), (num_threads), 0, stream};

#endif // #ifndef SETUP_LAUNCH_CONFIG
Chenggang Zhao's avatar
Chenggang Zhao committed
36
37

#ifndef LAUNCH_KERNEL
lijian6's avatar
lijian6 committed
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
// Launches `kernel` cooperatively with the grid/block/shared-memory/stream
// settings read from `config` (accessed through `->`, i.e. a pointer-like
// handle to a launch configuration).
//
// The kernel arguments are handed over as an array of untyped addresses, so no
// implicit conversion to the kernel's parameter types takes place here.
// NOTE(review): callers presumably must pass each argument with exactly the
// type of the matching kernel parameter - confirm at the call sites.
//
// A failing launch is reported through CUDA_CHECK.
template <typename T, typename Kern, typename... Args>
inline void LAUNCH_KERNEL(T &&config, Kern &&kernel, Args &&...args) {
    constexpr size_t k_num_kernel_args = sizeof...(Args);
    // Zero-length arrays are not standard C++, so keep at least one (unused)
    // slot when the kernel takes no parameters.
    void *kernel_args[k_num_kernel_args > 0 ? k_num_kernel_args : 1] = {};
    if constexpr (k_num_kernel_args > 0) {
        rocm::experimental::fill_kernel_args(kernel_args, 0, std::forward<Args>(args)...);
    }
    CUDA_CHECK(hipLaunchCooperativeKernel(std::forward<Kern>(kernel), config->num_sms,
                                          config->num_threads, kernel_args,
                                          config->shared_mem_bytes, config->stream));
}

template <typename T, typename Kern, typename... Args>
inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...args) {
    hipLaunchKernelGGL((*kernel), dim3(config->num_sms), dim3(config->num_threads), config->shared_mem_bytes, config->stream, 
        std::forward<Args>(args)...);
}
53

lijian6's avatar
lijian6 committed
54
#endif // #ifndef LAUNCH_KERNEL
Chenggang Zhao's avatar
Chenggang Zhao committed
55

lijian6's avatar
lijian6 committed
56
57
58
59
60
61
62
63
64
// Dispatches on the variable `num_ranks` of the enclosing scope and expands
// `case_macro(N)` with the matching literal (2, 4 or 8). Any other value ends
// up in the EP_HOST_ASSERT of the `default` label.
//
// No `break` is emitted here: `case_macro` has to leave the switch on its own
// (presumably by ending in `break` - confirm at the call sites), otherwise
// control falls through into the following cases and finally into `default`.
// The trailing `while (false)` forces a semicolon after the invocation.
#define SWITCH_RANKS(case_macro)                                                                   \
    switch (num_ranks) {                                                                           \
    case 2: case_macro(2);                                                                         \
    case 4: case_macro(4);                                                                         \
    case 8: case_macro(8);                                                                         \
    default: EP_HOST_ASSERT(false and "Unsupported ranks");                                        \
    } while (false)
Chenggang Zhao's avatar
Chenggang Zhao committed
68

lijian6's avatar
lijian6 committed
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
// Dispatches on `num_ranks / NUM_MAX_NVL_PEERS` (both taken from the
// surrounding code) and expands `case_macro(N)` with the matching literal:
// 2, 3, 4, 8, 16, 18 or 20. Any other quotient ends up in the EP_HOST_ASSERT
// of the `default` label.
//
// No `break` is emitted here: `case_macro` has to leave the switch on its own
// (presumably by ending in `break` - confirm at the call sites), otherwise
// control falls through into the following cases and finally into `default`.
// The trailing `while (false)` forces a semicolon after the invocation.
#define SWITCH_RDMA_RANKS(case_macro)                                                              \
    switch (num_ranks / NUM_MAX_NVL_PEERS) {                                                       \
    case 2: case_macro(2);                                                                         \
    case 3: case_macro(3);                                                                         \
    case 4: case_macro(4);                                                                         \
    case 8: case_macro(8);                                                                         \
    case 16: case_macro(16);                                                                       \
    case 18: case_macro(18);                                                                       \
    case 20: case_macro(20);                                                                       \
    default: EP_HOST_ASSERT(false and "Unsupported RDMA ranks");                                   \
    } while (false)
Chenggang Zhao's avatar
Chenggang Zhao committed
89

lijian6's avatar
lijian6 committed
90
91
92
93
94
95
96
97
98
// Same dispatch as a plain switch over the enclosing scope's `num_ranks`
// (2, 4 or 8), but `dtype` is passed through untouched as the first argument:
// the selected expansion is `case_macro(dtype, N)`. Any other value ends up in
// the EP_HOST_ASSERT of the `default` label.
//
// No `break` is emitted here: `case_macro` has to leave the switch on its own
// (presumably by ending in `break` - confirm at the call sites), otherwise
// control falls through into the following cases and finally into `default`.
// The trailing `while (false)` forces a semicolon after the invocation.
#define SWITCH_RANKS_WITH_DTYPE(dtype, case_macro)                                                 \
    switch (num_ranks) {                                                                           \
    case 2: case_macro(dtype, 2);                                                                  \
    case 4: case_macro(dtype, 4);                                                                  \
    case 8: case_macro(dtype, 8);                                                                  \
    default: EP_HOST_ASSERT(false and "Unsupported ranks");                                        \
    } while (false)
Chenggang Zhao's avatar
Chenggang Zhao committed
102

lijian6's avatar
lijian6 committed
103
104
105
106
107
108
109
// Dispatches on the variable `type` of the enclosing scope and expands
// `case_macro(T)` with a C++ type: HIP_R_16BF selects `hip_bfloat16`,
// HIP_R_32F selects `float`. Any other value ends up in the EP_HOST_ASSERT of
// the `default` label.
//
// No `break` is emitted here: `case_macro` has to leave the switch on its own
// (presumably by ending in `break` - confirm at the call sites), otherwise
// control falls through into the following case and finally into `default`.
// The trailing `while (false)` forces a semicolon after the invocation.
#define SWITCH_TYPES(case_macro)                                                                   \
    switch (type) {                                                                                \
    case HIP_R_16BF: case_macro(hip_bfloat16);                                                     \
    case HIP_R_32F: case_macro(float);                                                             \
    default: EP_HOST_ASSERT(false and "Unsupported type");                                         \
    } while (false)
Chenggang Zhao's avatar
Chenggang Zhao committed
113

lijian6's avatar
lijian6 committed
114
115
// Dispatches on the variable `hidden` of the enclosing scope and expands
// `case_macro(N)` with the matching literal: 2048, 2560, 5120, 4096, 7168 or
// 8192 (listed in label order). Any other value ends up in the EP_HOST_ASSERT
// of the `default` label.
//
// No `break` is emitted here: `case_macro` has to leave the switch on its own
// (presumably by ending in `break` - confirm at the call sites), otherwise
// control falls through into the following cases and finally into `default`.
// The trailing `while (false)` forces a semicolon after the invocation.
#define SWITCH_HIDDEN(case_macro)                                                                  \
    switch (hidden) {                                                                              \
    case 2048: case_macro(2048);                                                                   \
    case 2560: case_macro(2560);                                                                   \
    case 5120: case_macro(5120);                                                                   \
    case 4096: case_macro(4096);                                                                   \
    case 7168: case_macro(7168);                                                                   \
    case 8192: case_macro(8192);                                                                   \
    default: EP_HOST_ASSERT(false and "Unsupported hidden");                                       \
    } while (false)