// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <cassert>
#include "custom_cuda_layers.h"
#include "memory_access_utils.h"

namespace cg = cooperative_groups;

namespace td_sort {
constexpr int threads = 512;
constexpr int granularity = 16;
constexpr int mem_vals = granularity / sizeof(int32_t);
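// The buffer carries one extra vector of padding so that vectorized stores may
// safely run a few entries past original_tokens.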
constexpr int max_buffer_size = (threads + 1) * mem_vals;

#ifdef __HIP_PLATFORM_HCC__
constexpr int warp_size = 64;
#else
constexpr int warp_size = 32;
#endif

constexpr int max_warps = threads / warp_size;
}  // namespace td_sort

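// scan_sort performs an in-place ascending sort of one row of `reserved_tokens`
// int32 ids, one thread block per flattened (layer, batch) row. Because the ids
// are distinct values in [0, original_tokens] (with 0 treated as an empty
// sentinel that may repeat), a counting sort suffices: build a 0/1 histogram in
// shared memory, take a block-wide inclusive prefix sum over it, and scatter
// each id to its rank. VALS_PER_THREAD is the unroll factor chosen by the
// launcher so the block covers all reserved_tokens values.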
template <int VALS_PER_THREAD>
__global__ void scan_sort(int32_t* data, int reserved_tokens, int original_tokens)
{
    cg::thread_block tb = cg::this_thread_block();
    cg::thread_block_tile<td_sort::warp_size> warp = cg::tiled_partition<td_sort::warp_size>(tb);

    __shared__ int32_t indices_buffer[td_sort::max_buffer_size];
    __shared__ int32_t intermediate_buffer[td_sort::max_warps];
    __shared__ int32_t sorted_indices_buffer[td_sort::max_buffer_size];

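    // Phase 0: clear the histogram slots for ids 0..original_tokens
    // (vectorized, mem_vals entries per store).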
    for (int i = tb.thread_index().x * td_sort::mem_vals; i < original_tokens + 1;
         i += tb.group_dim().x * td_sort::mem_vals) {
        int32_t zeros[td_sort::mem_vals] = {};
        mem_access::store_shared<td_sort::granularity>(indices_buffer + i, zeros);
    }

    int32_t local_vals[VALS_PER_THREAD];

    // We flatten layers/batch into a single indexing dimension
    int32_t* data_block = data + tb.group_index().x * reserved_tokens;

    // The next two loops (load, then histogram) could be fused for a more logical
    // code layout, but we don't want to move the barrier forward
#pragma unroll
    for (int i = 0; i < VALS_PER_THREAD; i++) {
        const int iter_idx = i * td_sort::threads + tb.thread_index().x;
        if (iter_idx < reserved_tokens) {
            mem_access::load_global<sizeof(int32_t)>(local_vals + i, data_block + iter_idx);
        } else {
            local_vals[i] = 0;
        }
    }

    tb.sync();

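    // Phase 1: histogram. Mark each id that occurs; nonzero ids must be unique
    // for the scatter below to be collision-free, so a single flag per bucket
    // suffices.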
#pragma unroll
    for (int i = 0; i < VALS_PER_THREAD; i++) {
        const int iter_idx = i * td_sort::threads + tb.thread_index().x;
        if (iter_idx < reserved_tokens) {
            const int32_t one = 1;
            mem_access::store_shared<sizeof(int32_t)>(indices_buffer + local_vals[i], &one);
        }
    }

    tb.sync();

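    // Phase 2: block-wide inclusive prefix sum over the histogram. Each thread
    // first scans its mem_vals entries serially; warp shuffles then combine the
    // per-thread totals, and a second warp scan combines the per-warp totals.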
    int32_t local_input[td_sort::mem_vals];
    mem_access::load_shared<td_sort::granularity>(
        local_input, indices_buffer + tb.thread_index().x * td_sort::mem_vals);

    int32_t reduce_vals[td_sort::mem_vals];
    reduce_vals[0] = local_input[0];

#pragma unroll
    for (int i = 1; i < td_sort::mem_vals; i++) {
        reduce_vals[i] = local_input[i] + reduce_vals[i - 1];
    }

    int32_t step_1_val = reduce_vals[td_sort::mem_vals - 1];
    // Short-span Kogge-Stone warp scan: low latency but less work-efficient. This
    // produces an inclusive scan; it is made exclusive below by subtracting the
    // thread's own total.
#pragma unroll
    for (int i = 1; i < td_sort::warp_size; i *= 2) {
        int32_t step_val = warp.shfl_up(step_1_val, i);
        step_1_val = (warp.thread_rank() < i) ? step_1_val : step_1_val + step_val;
    }

    if (warp.thread_rank() == td_sort::warp_size - 1) {
        mem_access::store_shared<sizeof(int32_t)>(intermediate_buffer + warp.meta_group_rank(),
                                                  &step_1_val);
    }

    tb.sync();

    if (warp.meta_group_rank() == 0) {
        int32_t step_2_val = 0;
        if (warp.thread_rank() < td_sort::max_warps) {
            mem_access::load_shared<sizeof(int32_t)>(&step_2_val,
                                                     intermediate_buffer + warp.thread_rank());
        }

#pragma unroll
        for (int i = 1; i < td_sort::warp_size; i *= 2) {
            int32_t step_val = warp.shfl_up(step_2_val, i);
            step_2_val = (warp.thread_rank() < i) ? step_2_val : step_2_val + step_val;
        }

        if (warp.thread_rank() < td_sort::max_warps) {
            mem_access::store_shared<sizeof(int32_t)>(intermediate_buffer + warp.thread_rank(),
                                                      &step_2_val);
        }
    }

    tb.sync();

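    // Every warp except the first adds the inclusive total of all preceding
    // warps as its block-level offset.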
    int32_t step_2_val = 0;
    if (warp.meta_group_rank() > 0) {
        mem_access::load_shared<sizeof(int32_t)>(&step_2_val,
                                                 intermediate_buffer + warp.meta_group_rank() - 1);
    }

    const int thread_offset = reduce_vals[td_sort::mem_vals - 1];

#pragma unroll
    for (int i = 0; i < td_sort::mem_vals; i++) {
        reduce_vals[i] += step_1_val + step_2_val - thread_offset;
    }
    mem_access::store_shared<td_sort::granularity>(
        indices_buffer + tb.thread_index().x * td_sort::mem_vals, reduce_vals);

    // Make all threads' scan results visible before thread 0 reads an entry of
    // indices_buffer that another thread wrote.
    tb.sync();

    if (tb.thread_index().x == 0) {
        indices_buffer[original_tokens] = original_tokens - indices_buffer[original_tokens];
    }
    tb.sync();

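    // Phase 3: scatter. A nonzero id v lands at its rank: the scan value at
    // v - 1, i.e. the number of present ids smaller than v. Threads holding the
    // empty sentinel (v == 0) all write 0 to slot 0, which is benign.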
    for (int i = 0; i < VALS_PER_THREAD; i++) {
        const int iter_idx = i * td_sort::threads + tb.thread_index().x;
        if (iter_idx < reserved_tokens) {
            if (local_vals[i] == 0) {
                int zero = 0;
                mem_access::store_shared<sizeof(int32_t)>(sorted_indices_buffer, &zero);
            } else {
                int sorted_idx;
                mem_access::load_shared<sizeof(int32_t)>(&sorted_idx,
                                                         indices_buffer + local_vals[i] - 1);
                mem_access::store_shared<sizeof(int32_t)>(sorted_indices_buffer + sorted_idx,
                                                          local_vals + i);
            }
        }
    }

    tb.sync();

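    // Phase 4: copy the sorted ids from shared memory back to global memory.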
#pragma unroll
    for (int i = 0; i < VALS_PER_THREAD; i++) {
        const int iter_idx = i * td_sort::threads + tb.thread_index().x;
        if (iter_idx < reserved_tokens) {
            int32_t store_val;
            mem_access::load_shared<sizeof(int32_t)>(&store_val, sorted_indices_buffer + iter_idx);
            mem_access::store_global<sizeof(int32_t)>(data_block + iter_idx, &store_val);
        }
    }
}

void launch_token_sort(int32_t* indices,
                       int layers,
                       int batch_size,
                       int reserved_size,
                       int original_tokens,
                       cudaStream_t stream)
{
    // Each sort is completely independent, so we can flatten layers and batch
    // into a single grid dimension
    dim3 grid(layers * batch_size);
    dim3 block(td_sort::threads);

    const int vals_per_thread = (reserved_size + td_sort::threads - 1) / td_sort::threads;

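    // Dispatch on a compile-time unroll factor; 512 threads at up to 4 values
    // per thread bounds reserved_size at 2048.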
    if (vals_per_thread == 1) {
        scan_sort<1><<<grid, block, 0, stream>>>(indices, reserved_size, original_tokens);
    } else if (vals_per_thread == 2) {
        scan_sort<2><<<grid, block, 0, stream>>>(indices, reserved_size, original_tokens);
    } else if (vals_per_thread == 3) {
        scan_sort<3><<<grid, block, 0, stream>>>(indices, reserved_size, original_tokens);
    } else if (vals_per_thread == 4) {
        scan_sort<4><<<grid, block, 0, stream>>>(indices, reserved_size, original_tokens);
    } else {
        assert(false);
    }
}
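
// Example host-side usage (a sketch; the sizes below are illustrative, and each
// row of `indices` is assumed to hold distinct nonzero token ids with 0 marking
// empty slots):
//
//     const int layers = 2, batch_size = 4;
//     const int reserved_size = 1024, original_tokens = 2048;
//     int32_t* indices;
//     cudaMalloc(&indices, sizeof(int32_t) * layers * batch_size * reserved_size);
//     // ... fill each row of `indices` with its (unsorted) token ids ...
//     launch_token_sort(indices, layers, batch_size, reserved_size, original_tokens, stream);
//     // Each row of `indices` now holds its ids in ascending order.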