// buffer.cuh: lightweight device-side views (Buffer, AsymBuffer, SymBuffer)
// that carve typed sub-regions out of raw byte workspaces.
#pragma once

#include "configs.cuh"
#include "exception.cuh"

namespace deep_ep {

// Typed view over one contiguous region claimed from a byte workspace.
// Constructing a Buffer claims `num_elems` elements starting at `gbl_ptr` and
// moves `gbl_ptr` past them, so consecutive constructions lay regions out
// back-to-back.
template <typename dtype_t> struct Buffer {
private:
    // Base of this view (already shifted by the constructor's `offset`).
    uint8_t *ptr;

public:
    // Size in bytes of the claimed region: num_elems * sizeof(dtype_t).
    int total_bytes;

    // Empty view: no backing storage, zero size.
    __device__ __forceinline__ Buffer() : ptr(nullptr), total_bytes(0) {}

    // Claim a region of `num_elems` elements.
    //   gbl_ptr   in/out: start of the free workspace; advanced by `total_bytes`.
    //   num_elems number of dtype_t elements in the region.
    //   offset    element offset applied to this view's base pointer only; it
    //             does not change how far `gbl_ptr` is advanced.
    __device__ __forceinline__ Buffer(void *&gbl_ptr, int num_elems, int offset = 0) {
        uint8_t *const base = static_cast<uint8_t *>(gbl_ptr);
        total_bytes         = num_elems * sizeof(dtype_t);
        ptr                 = base + offset * sizeof(dtype_t);
        gbl_ptr             = base + total_bytes;
    }

    // Advance `gbl_ptr` by this view's `total_bytes` (reserving another region
    // of the same size) and return a copy of this view.
    __device__ __forceinline__ Buffer advance_also(void *&gbl_ptr) {
        gbl_ptr = static_cast<uint8_t *>(gbl_ptr) + total_bytes;
        return *this;
    }

    // Typed base pointer of the view.
    __device__ __forceinline__ dtype_t *buffer() { return reinterpret_cast<dtype_t *>(ptr); }

    // Element access relative to the view's base; no bounds checking.
    __device__ __forceinline__ dtype_t &operator[](int idx) { return *(buffer() + idx); }
};

// View over one (kNumRanks == 1) or several (kNumRanks > 1) byte workspaces.
// Each workspace is treated as `num_sms` consecutive channels, each channel
// holding `num_ranks` slots of `num_elems` dtype_t elements. A view points at
// slot `offset` of channel `sm_id` in every workspace it covers.
template <typename dtype_t, int kNumRanks = 1> struct AsymBuffer {
private:
    // Base pointer of this view inside each covered workspace.
    uint8_t *ptrs[kNumRanks];
    // Bytes per slot: num_elems * sizeof(dtype_t).
    int      num_bytes;

public:
    // Bytes consumed from each workspace: num_bytes * num_ranks * num_sms.
    int total_bytes;

    // Single-workspace constructor (kNumRanks == 1).
    //   gbl_ptr   in/out: start of the free workspace; advanced by `total_bytes`.
    //   num_elems elements per slot.
    //   num_ranks slots per channel.
    //   sm_id     channel this view points into.
    //   num_sms   number of channels reserved.
    //   offset    slot index inside the channel.
    __device__ __forceinline__ AsymBuffer(void *&gbl_ptr, int num_elems, int num_ranks,
                                          int sm_id = 0, int num_sms = 1, int offset = 0) {
        EP_STATIC_ASSERT(kNumRanks == 1,
                         "single-pointer `AsymBuffer` constructor requires kNumRanks == 1");
        num_bytes = num_elems * sizeof(dtype_t);

        int per_channel_bytes = num_bytes * num_ranks;
        total_bytes           = per_channel_bytes * num_sms;
        ptrs[0] =
            reinterpret_cast<uint8_t *>(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset;
        gbl_ptr = reinterpret_cast<uint8_t *>(gbl_ptr) + total_bytes;
    }

    // Multi-workspace constructor (kNumRanks > 1). `gbl_ptrs` must hold at
    // least kNumRanks workspace pointers; each one gets the same layout as the
    // single-workspace constructor and is advanced by `total_bytes`.
    __device__ __forceinline__ AsymBuffer(void **gbl_ptrs, int num_elems, int num_ranks,
                                          int sm_id = 0, int num_sms = 1, int offset = 0) {
        EP_STATIC_ASSERT(kNumRanks > 1,
                         "pointer-array `AsymBuffer` constructor requires kNumRanks > 1");
        num_bytes = num_elems * sizeof(dtype_t);

        int per_channel_bytes = num_bytes * num_ranks;
        total_bytes           = per_channel_bytes * num_sms;
        for (int i = 0; i < kNumRanks; ++i) {
            ptrs[i] = reinterpret_cast<uint8_t *>(gbl_ptrs[i]) + per_channel_bytes * sm_id +
                      num_bytes * offset;
            gbl_ptrs[i] = reinterpret_cast<uint8_t *>(gbl_ptrs[i]) + total_bytes;
        }
    }

    // Move every per-workspace pointer by `shift` elements (may be negative).
    __device__ __forceinline__ void advance(int shift) {
        // Signed byte offset: `shift * sizeof(dtype_t)` would convert a negative
        // shift to a huge unsigned value before the pointer addition.
        const long long shift_bytes =
            static_cast<long long>(shift) * static_cast<long long>(sizeof(dtype_t));
#pragma unroll
        for (int i = 0; i < kNumRanks; ++i)
            ptrs[i] += shift_bytes;
    }

    // Advance one external workspace pointer by `total_bytes` (reserving an
    // equally sized region) and return a copy of this view.
    __device__ __forceinline__ AsymBuffer advance_also(void *&gbl_ptr) {
        gbl_ptr = reinterpret_cast<uint8_t *>(gbl_ptr) + total_bytes;
        return *this;
    }

    // Advance the first kNumAlsoRanks entries of `gbl_ptrs` by `total_bytes`
    // each and return a copy of this view.
    template <int kNumAlsoRanks>
    __device__ __forceinline__ AsymBuffer advance_also(void **gbl_ptrs) {
#pragma unroll
        for (int i = 0; i < kNumAlsoRanks; ++i)
            gbl_ptrs[i] = reinterpret_cast<uint8_t *>(gbl_ptrs[i]) + total_bytes;
        return *this;
    }

    // Typed pointer to slot `idx` (relative to this view) in the only workspace.
    __device__ __forceinline__ dtype_t *buffer(int idx = 0) {
        EP_STATIC_ASSERT(kNumRanks == 1, "`buffer` is only available for single rank case");
        return reinterpret_cast<dtype_t *>(ptrs[0] + num_bytes * idx);
    }

    // Typed pointer to slot `idx` (relative to this view) in workspace `rank_idx`.
    __device__ __forceinline__ dtype_t *buffer_by(int rank_idx, int idx = 0) {
        EP_STATIC_ASSERT(kNumRanks > 1, "`buffer_by` is only available for multi-rank case");
        return reinterpret_cast<dtype_t *>(ptrs[rank_idx] + num_bytes * idx);
    }
};

// View over a byte workspace split into `num_sms` channels of `num_ranks`
// slots of `num_elems` dtype_t elements.
// - kDecoupled == true: the workspace holds two such channel arrays back to
//   back (send half, then recv half), accessed via send_buffer/recv_buffer.
// - kDecoupled == false: a single channel array, accessed via buffer().
template <typename dtype_t, bool kDecoupled = true> struct SymBuffer {
private:
    // NOTES: for non-decoupled case, `recv_ptr` is not used
    uint8_t *send_ptr;
    uint8_t *recv_ptr;
    // Bytes per slot: num_elems * sizeof(dtype_t).
    int      num_bytes;

public:
    // Bytes consumed from the workspace: per-channel bytes * num_sms, doubled
    // in the decoupled case.
    int total_bytes;

    //   gbl_ptr   in/out: start of the free workspace; advanced by `total_bytes`.
    //   num_elems elements per slot.
    //   num_ranks slots per channel.
    //   sm_id     channel this view points into.
    //   num_sms   number of channels reserved (per half when decoupled).
    __device__ __forceinline__ SymBuffer(void *&gbl_ptr, int num_elems, int num_ranks,
                                         int sm_id = 0, int num_sms = 1) {
        num_bytes = num_elems * sizeof(dtype_t);

        int per_channel_bytes = num_bytes * num_ranks;
        total_bytes           = per_channel_bytes * num_sms * (static_cast<int>(kDecoupled) + 1);
        send_ptr              = reinterpret_cast<uint8_t *>(gbl_ptr) + per_channel_bytes * sm_id;
        // Channel `sm_id` of the second half; only meaningful when kDecoupled.
        recv_ptr = reinterpret_cast<uint8_t *>(gbl_ptr) + per_channel_bytes * (sm_id + num_sms);
        gbl_ptr  = reinterpret_cast<uint8_t *>(gbl_ptr) + total_bytes;
    }

    // Typed pointer to slot `idx` of this channel in the send half.
    __device__ __forceinline__ dtype_t *send_buffer(int idx = 0) {
        EP_STATIC_ASSERT(kDecoupled, "`send_buffer` is only available for decoupled case");
        return reinterpret_cast<dtype_t *>(send_ptr + num_bytes * idx);
    }

    // Typed pointer to slot `idx` of this channel in the recv half.
    __device__ __forceinline__ dtype_t *recv_buffer(int idx = 0) {
        EP_STATIC_ASSERT(kDecoupled, "`recv_buffer` is only available for decoupled case");
        return reinterpret_cast<dtype_t *>(recv_ptr + num_bytes * idx);
    }

    // Typed pointer to slot `idx` of this channel (non-decoupled layout).
    __device__ __forceinline__ dtype_t *buffer(int idx = 0) {
        EP_STATIC_ASSERT(not kDecoupled, "`buffer` is only available for non-decoupled case");
        return reinterpret_cast<dtype_t *>(send_ptr + num_bytes * idx);
    }
};

} // namespace deep_ep