gather_scatter.cu 8.21 KB
Newer Older
aiss's avatar
aiss committed
1
2
3
4
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team
aiss's avatar
aiss committed
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186

#include "custom_cuda_layers.h"
#include "memory_access_utils.h"

namespace cg = cooperative_groups;

namespace td_data {
// Width, in bytes, of every vectorized global load/store issued by the
// gather/scatter kernels below; each access moves granularity / sizeof(T)
// elements of type T.
constexpr int granularity{16};
}  // namespace td_data

template <typename T>
// Copy a selected subset of token rows out of `activations` into the dense
// `retained_tokens` buffer.
//
// Launch layout (as configured by launch_gather_tokens):
//   blockIdx.x -> batch entry
//   blockIdx.y -> position within the sampled set (0 .. sampled_tokens-1)
//   threads    -> stride across the row's `channels` elements, each thread
//                 moving td_data::granularity bytes per iteration.
//
// gather_indices is read as a flat table with `sampled_tokens` entries per
// batch entry; entry [b * sampled_tokens + t] is the source sequence position
// (in units of read_seq_stride) copied to output position t.
//
// Preconditions (not checked): `channels` should be a multiple of
// mem_vals_t, because the last vector access of a row starts at the final
// i < channels and spans mem_vals_t elements. mem_access::load_global/
// store_global presumably need granularity-byte aligned addresses, which in
// turn requires the strides and base pointers to be suitably aligned.
// TODO confirm against memory_access_utils.h.
__global__ void gather_tokens_impl(T* retained_tokens,
                                   const T* activations,
                                   int32_t* gather_indices,
                                   int32_t sampled_tokens,
                                   int32_t channels,
                                   int32_t read_batch_stride,
                                   int32_t read_seq_stride,
                                   int32_t write_batch_stride,
                                   int32_t write_seq_stride)
{
    constexpr int mem_vals_t = td_data::granularity / sizeof(T);

    cg::thread_block tb = cg::this_thread_block();

    const int64_t batch_idx = tb.group_index().x;
    const int64_t token_idx = tb.group_index().y;

    const int gather_idx = gather_indices[batch_idx * sampled_tokens + token_idx];

    // Offsets are formed in 64-bit: each stride fits in int32_t, but
    // stride * index can exceed INT32_MAX for large activation tensors.
    const int64_t read_offset =
        read_batch_stride * batch_idx + read_seq_stride * static_cast<int64_t>(gather_idx);
    const int64_t write_offset = write_batch_stride * batch_idx + write_seq_stride * token_idx;

    const T* src_row = activations + read_offset;
    T* dst_row = retained_tokens + write_offset;

    // Block-wide stride loop over the row so rows wider than
    // blockDim.x * mem_vals_t elements are still fully copied.
    for (int i = tb.thread_index().x * mem_vals_t; i < channels; i += blockDim.x * mem_vals_t) {
        T local_data[mem_vals_t];
        mem_access::load_global<td_data::granularity>(local_data, src_row + i);
        mem_access::store_global<td_data::granularity>(dst_row + i, local_data);
    }
}

template <typename T>
// Host-side launcher for gather_tokens_impl on `stream`.
//
// Uses one block per (batch entry, sampled token) pair. Each block has
// ceil(channels / mem_vals_t) threads, capped at 1024; the kernel loops over
// any remaining row elements. The launch is asynchronous and no CUDA error
// status is queried here.
//
// NOTE(review): sampled_tokens becomes gridDim.y, which CUDA limits to 65535.
// Larger values will fail to launch.
void launch_gather_tokens(T* retained_tokens,
                          T* activations,
                          int32_t* gather_indices,
                          int32_t batch_size,
                          int32_t sampled_tokens,
                          int32_t channels,
                          int32_t read_batch_stride,
                          int32_t read_seq_stride,
                          int32_t write_batch_stride,
                          int32_t write_seq_stride,
                          cudaStream_t stream)
{
    // Nothing to copy. A zero-sized grid or block is an invalid launch
    // configuration, so skip the launch instead of raising a CUDA error.
    if (batch_size == 0 || sampled_tokens == 0 || channels == 0) { return; }

    constexpr int mem_vals_t = td_data::granularity / sizeof(T);

    // One thread per vector access of a row, capped at the 1024-thread block limit.
    const int load_steps = (channels + mem_vals_t - 1) / mem_vals_t;
    const int threads = (load_steps >= 1024) ? 1024 : load_steps;

    dim3 block(threads);
    dim3 grid(batch_size, sampled_tokens);

    gather_tokens_impl<T><<<grid, block, 0, stream>>>(retained_tokens,
                                                      activations,
                                                      gather_indices,
                                                      sampled_tokens,
                                                      channels,
                                                      read_batch_stride,
                                                      read_seq_stride,
                                                      write_batch_stride,
                                                      write_seq_stride);
}

// Explicit instantiations of launch_gather_tokens for the element types this
// translation unit supports (fp32 and fp16).
template void launch_gather_tokens<float>(
    float*, float*, int32_t*, int32_t, int32_t, int32_t, int32_t, int32_t, int32_t, int32_t, cudaStream_t);

template void launch_gather_tokens<__half>(
    __half*, __half*, int32_t*, int32_t, int32_t, int32_t, int32_t, int32_t, int32_t, int32_t, cudaStream_t);

template <typename T>
// Write the dense `layer_activations` token rows back into `all_activations`
// at the positions listed in gather_indices. This is the inverse of
// gather_tokens_impl.
//
// Launch layout (as configured by launch_scatter_tokens):
//   blockIdx.x -> batch entry
//   blockIdx.y -> position within the retained set (0 .. retained_tokens-1)
//   threads    -> stride across the row's `channels` elements, each thread
//                 moving td_data::granularity bytes per iteration.
//
// gather_indices is read as a flat table with `retained_tokens` entries per
// batch entry; entry [b * retained_tokens + t] is the destination sequence
// position (in units of write_seq_stride) for input row t.
// If that table repeats a position within a batch entry, several blocks
// write the same destination row with no ordering between them.
//
// Preconditions (not checked): `channels` should be a multiple of
// mem_vals_t, because the last vector access of a row starts at the final
// i < channels and spans mem_vals_t elements. The granularity-byte
// mem_access helpers presumably require aligned addresses.
// TODO confirm against memory_access_utils.h.
__global__ void scatter_tokens_impl(T* all_activations,
                                    const T* layer_activations,
                                    int32_t* gather_indices,
                                    int32_t retained_tokens,
                                    int32_t channels,
                                    int32_t read_batch_stride,
                                    int32_t read_seq_stride,
                                    int32_t write_batch_stride,
                                    int32_t write_seq_stride)
{
    constexpr int mem_vals_t = td_data::granularity / sizeof(T);

    cg::thread_block tb = cg::this_thread_block();

    const int64_t batch_idx = tb.group_index().x;
    const int64_t token_idx = tb.group_index().y;

    const int gather_idx = gather_indices[batch_idx * retained_tokens + token_idx];

    // Offsets are formed in 64-bit: each stride fits in int32_t, but
    // stride * index can exceed INT32_MAX for large activation tensors.
    const int64_t read_offset = read_batch_stride * batch_idx + read_seq_stride * token_idx;
    const int64_t write_offset =
        write_batch_stride * batch_idx + write_seq_stride * static_cast<int64_t>(gather_idx);

    const T* src_row = layer_activations + read_offset;
    T* dst_row = all_activations + write_offset;

    // Block-wide stride loop over the row so rows wider than
    // blockDim.x * mem_vals_t elements are still fully copied.
    for (int i = tb.thread_index().x * mem_vals_t; i < channels; i += mem_vals_t * blockDim.x) {
        T local_data[mem_vals_t];
        mem_access::load_global<td_data::granularity>(local_data, src_row + i);
        mem_access::store_global<td_data::granularity>(dst_row + i, local_data);
    }
}

template <typename T>
// Host-side launcher for scatter_tokens_impl on `stream`.
//
// Uses one block per (batch entry, retained token) pair. Each block has
// ceil(channels / mem_vals_t) threads, capped at 1024; the kernel loops over
// any remaining row elements. The launch is asynchronous and no CUDA error
// status is queried here.
//
// NOTE(review): sampled_tokens becomes gridDim.y, which CUDA limits to 65535.
// Larger values will fail to launch.
void launch_scatter_tokens(T* all_activations,
                           T* layer_activations,
                           int32_t* gather_indices,
                           int32_t batch_size,
                           int32_t sampled_tokens,
                           int32_t channels,
                           int32_t read_batch_stride,
                           int32_t read_seq_stride,
                           int32_t write_batch_stride,
                           int32_t write_seq_stride,
                           cudaStream_t stream)
{
    // Nothing to copy. A zero-sized grid or block is an invalid launch
    // configuration, so skip the launch instead of raising a CUDA error.
    if (batch_size == 0 || sampled_tokens == 0 || channels == 0) { return; }

    constexpr int mem_vals_t = td_data::granularity / sizeof(T);

    // One thread per vector access of a row, capped at the 1024-thread block limit.
    const int load_steps = (channels + mem_vals_t - 1) / mem_vals_t;
    const int threads = (load_steps >= 1024) ? 1024 : load_steps;

    dim3 block(threads);
    dim3 grid(batch_size, sampled_tokens);

    scatter_tokens_impl<T><<<grid, block, 0, stream>>>(all_activations,
                                                       layer_activations,
                                                       gather_indices,
                                                       sampled_tokens,
                                                       channels,
                                                       read_batch_stride,
                                                       read_seq_stride,
                                                       write_batch_stride,
                                                       write_seq_stride);
}

// Explicit instantiations of launch_scatter_tokens for the element types this
// translation unit supports (fp32 and fp16).
template void launch_scatter_tokens<float>(
    float*, float*, int32_t*, int32_t, int32_t, int32_t, int32_t, int32_t, int32_t, int32_t, cudaStream_t);

template void launch_scatter_tokens<__half>(
    __half*, __half*, int32_t*, int32_t, int32_t, int32_t, int32_t, int32_t, int32_t, int32_t, cudaStream_t);