custom_all_reduce.cu 15 KB
Newer Older
Xiaowei.zhang's avatar
Xiaowei.zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
/*
 
 * Copyright (C) 2024-2025, The vLLM team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>
#include <torch/all.h>

#include "custom_all_reduce.cuh"

// Opaque integer handle used to return a CustomAllreduce* to callers and to
// receive it back in later calls; must match the fptr_t type in ops.h.
typedef int64_t fptr_t;
static_assert(sizeof(fptr_t) == sizeof(void*), "fptr_t must be exactly pointer-sized");

namespace aiter {

/**
 * Create a CustomAllreduce for the local rank and return it as an opaque
 * fptr_t handle (release it with dispose()).
 *
 * @param meta            buffer whose data is reinterpreted as an aiter::Signal*.
 * @param rank_data       buffer forwarded (pointer and element count) to CustomAllreduce.
 * @param handles         one tensor per rank whose bytes hold a hipIpcMemHandle_t;
 *                        the bytes are read on the host, so each must be a CPU tensor.
 * @param offsets         one offset per rank; its length defines the world size.
 * @param rank            index of the local rank, in [0, world_size).
 * @param fully_connected forwarded unchanged to CustomAllreduce.
 * @throws std::invalid_argument on an unsupported/odd world size, mismatched
 *         handles/offsets lengths, an out-of-range rank, or a handle tensor that
 *         is not a contiguous CPU tensor holding a full hipIpcMemHandle_t.
 */
fptr_t init_custom_ar(torch::Tensor& meta,
                      torch::Tensor& rank_data,
                      const std::vector<torch::Tensor>& handles,
                      const std::vector<int64_t>& offsets,
                      int64_t rank,
                      bool fully_connected)
{
    // ipc_handles below is a fixed-size array, so the world size is capped here.
    constexpr size_t kMaxWorldSize = 8;
    const size_t world_size        = offsets.size();
    if(world_size > kMaxWorldSize)
        throw std::invalid_argument("world size > 8 is not supported");
    if(world_size % 2 != 0)
        throw std::invalid_argument("Odd num gpus is not supported for now");
    if(world_size != handles.size())
        throw std::invalid_argument("handles length should equal to offsets length");
    if(rank < 0 || rank >= static_cast<int64_t>(world_size))
        throw std::invalid_argument("invalid rank passed in");

    hipIpcMemHandle_t ipc_handles[kMaxWorldSize];
    for(size_t i = 0; i < world_size; i++)
    {
        const torch::Tensor& h = handles[i];
        // The handle bytes are copied with a host memcpy from data_ptr(), so the
        // tensor must live in host memory, be laid out contiguously, and hold at
        // least one full hipIpcMemHandle_t; otherwise the copy reads out of bounds.
        if(!h.device().is_cpu())
            throw std::invalid_argument("IPC handle tensors must be CPU tensors");
        if(!h.is_contiguous())
            throw std::invalid_argument("IPC handle tensors must be contiguous");
        if(static_cast<size_t>(h.numel()) * static_cast<size_t>(h.element_size()) <
           sizeof(hipIpcMemHandle_t))
            throw std::invalid_argument("IPC handle tensor is too small to hold a hipIpcMemHandle_t");
        std::memcpy(&ipc_handles[i], h.data_ptr(), sizeof(hipIpcMemHandle_t));
    }
    return (fptr_t) new aiter::CustomAllreduce(reinterpret_cast<aiter::Signal*>(meta.data_ptr()),
                                               rank_data.data_ptr(),
                                               rank_data.numel(),
                                               ipc_handles,
                                               offsets,
                                               rank,
                                               fully_connected);
}

/**
 * Check whether tensor `t` can be consumed as a flat array of t.numel()
 * elements starting at t.data_ptr(). Returns true when `t` is contiguous, or
 * when the bytes from its storage offset to the end of its storage are exactly
 * t.numel() * t.element_size().
 *
 * This is weaker than t.is_contiguous(): it also accepts permuted views whose
 * data runs from the storage offset to the end of the storage, e.g. a
 * transposed slice of the first dimension. It is required because stride
 * information is not passed to the kernels, which treat their inputs as flat.
 *
 * Examples, with A = torch.zeros(3, 3, 3):
 *   A                             -> OK
 *   A[1:]                         -> OK
 *   A.permute(2, 0, 1)            -> OK
 *   A[1:].permute(2, 0, 1)        -> OK
 *   A[None].expand(2, -1, -1, -1) -> not OK
 *   A[:, 1:, 1:]                  -> not OK
 */
bool _is_weak_contiguous(torch::Tensor& t)
{
    if(t.is_contiguous())
        return true;
    const auto elem_bytes         = t.element_size();
    const auto bytes_after_offset = t.storage().nbytes() - t.storage_offset() * elem_bytes;
    return bytes_after_offset == t.numel() * elem_bytes;
}

/**
 * Launch the all-reduce of out.numel() elements on `stream`, choosing the
 * kernel instantiation from out's dtype. `inp` is read with the same element
 * type; all_reduce() checks dtype/numel agreement before calling this.
 *
 * The FP8-quantized path (runFp8QuantKernel) is used only for float16 inputs,
 * only when `open_fp8_quant` is set, and only when there are at least 128 * 2048
 * elements. Every other case uses CustomAllreduce::allreduce.
 *
 * Throws std::runtime_error for dtypes other than float32/float16/bfloat16.
 */
void _all_reduce(
    fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, hipStream_t stream, bool open_fp8_quant)
{
    auto* ar = reinterpret_cast<aiter::CustomAllreduce*>(_fa);
    TORCH_CHECK(_is_weak_contiguous(out));
    const auto count = out.numel();

    switch(out.scalar_type())
    {
    case at::ScalarType::Float:
        ar->allreduce<float>(
            stream, static_cast<float*>(inp.data_ptr()), static_cast<float*>(out.data_ptr()), count);
        break;
    case at::ScalarType::Half: {
        auto* src = static_cast<half*>(inp.data_ptr());
        auto* dst = static_cast<half*>(out.data_ptr());
        // hidden_dim is a multiple of 128 by default; the FP8 path only shows a
        // clear benefit once the payload reaches a certain size.
        if(open_fp8_quant && count >= 128 * 2048)
            ar->runFp8QuantKernel<half>(stream, src, dst, count);
        else
            ar->allreduce<half>(stream, src, dst, count);
        break;
    }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
    case at::ScalarType::BFloat16:
        ar->allreduce<__hip_bfloat16>(stream,
                                      static_cast<__hip_bfloat16*>(inp.data_ptr()),
                                      static_cast<__hip_bfloat16*>(out.data_ptr()),
                                      count);
        break;
#endif
    default:
        throw std::runtime_error("custom allreduce only supports float32, float16 and bfloat16");
    }
}

/**
 * All-reduce `inp` into `out` on the current HIP stream of inp's device.
 *
 * @param _fa            handle returned by init_custom_ar().
 * @param inp            input; must match `out` in dtype and numel and be weakly
 *                       contiguous, because it is read as a flat buffer.
 * @param out            output; must be weakly contiguous (checked in _all_reduce).
 * @param open_fp8_quant request the FP8-quantized path (float16 and large inputs
 *                       only; see _all_reduce).
 * @param reg_buffer     optional buffer. When given, inp is first copied into it
 *                       on the stream and the reduction reads from the buffer.
 */
void all_reduce(fptr_t _fa,
                torch::Tensor& inp,
                torch::Tensor& out,
                bool open_fp8_quant,
                std::optional<torch::Tensor> reg_buffer)
{
    // Only dereferenced on DTK builds.
    [[maybe_unused]] auto fa = reinterpret_cast<aiter::CustomAllreduce*>(_fa);
    const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(inp));
    auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
    TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
    TORCH_CHECK_EQ(inp.numel(), out.numel());
    // Both the staging memcpy and the kernels treat inp as inp.numel() elements
    // starting at data_ptr(); any other layout would silently reduce wrong bytes.
    TORCH_CHECK(_is_weak_contiguous(inp), "custom allreduce requires a weakly contiguous input");

    if(reg_buffer.has_value())
    {
        torch::Tensor& buf    = reg_buffer.value();
        const auto input_size = inp.numel() * inp.element_size();
        TORCH_CHECK(input_size <= buf.numel() * buf.element_size(),
                    "registered buffer is too small to contain the input");
        HIP_CALL(hipMemcpyAsync(
            buf.data_ptr(), inp.data_ptr(), input_size, hipMemcpyDeviceToDevice, stream));
#ifdef DTK_ENV
        HIP_CALL(hipEventRecord(fa->event_, stream));
#endif
        _all_reduce(_fa, buf, out, stream, open_fp8_quant);
    }
    else
    {
#ifdef DTK_ENV
        // NOTE(review): this copies fa->buffer_size_ bytes starting at
        // out.data_ptr(); confirm `out` is always at least that large on DTK builds.
        HIP_CALL(hipMemcpyAsync(fa->buffer_ptr_, out.data_ptr(), fa->buffer_size_, hipMemcpyDeviceToHost, stream));
#endif
        _all_reduce(_fa, inp, out, stream, open_fp8_quant);
    }
}

/**
 * Launch the all-gather on `stream`, choosing the kernel instantiation from
 * out's dtype. `size` is the number of input elements (callers pass
 * inp.numel()), and `inp` is read with out's element type.
 *
 * Throws std::runtime_error for dtypes other than float32/float16/bfloat16.
 */
void _all_gather(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, int size, hipStream_t stream)
{
    auto* ar = reinterpret_cast<aiter::CustomAllreduce*>(_fa);
    TORCH_CHECK(_is_weak_contiguous(out));

    switch(out.scalar_type())
    {
    case at::ScalarType::Float:
        ar->dispatchAllGather<float>(
            stream, static_cast<float*>(inp.data_ptr()), static_cast<float*>(out.data_ptr()), size);
        break;
    case at::ScalarType::Half:
        ar->dispatchAllGather<half>(
            stream, static_cast<half*>(inp.data_ptr()), static_cast<half*>(out.data_ptr()), size);
        break;
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
    case at::ScalarType::BFloat16:
        ar->dispatchAllGather<__hip_bfloat16>(stream,
                                              static_cast<__hip_bfloat16*>(inp.data_ptr()),
                                              static_cast<__hip_bfloat16*>(out.data_ptr()),
                                              size);
        break;
#endif
    default:
        throw std::runtime_error("custom allreduce only supports float32, float16 and bfloat16");
    }
}

/**
 * All-gather `inp` into `out` on the current HIP stream of inp's device,
 * reading `inp` directly. Presumably `inp` already lives in a registered
 * buffer; contrast all_gather_unreg, which stages it first.
 *
 * `inp` and `out` must share a dtype, and both must be weakly contiguous,
 * because the kernel reads `inp` as a flat array of inp.numel() elements.
 */
void all_gather_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out)
{
    const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(inp));
    auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
    TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
    TORCH_CHECK(_is_weak_contiguous(inp), "custom allgather requires a weakly contiguous input");
    _all_gather(_fa, inp, out, inp.numel(), stream);
}

/**
 * Copy `inp` into `reg_buffer` on the current HIP stream, then all-gather from
 * that buffer into `out`.
 *
 * `inp` and `out` must share a dtype. `inp` must be weakly contiguous, and
 * `reg_buffer` must hold at least inp.numel() * inp.element_size() bytes.
 */
void all_gather_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, torch::Tensor& out)
{
    const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(inp));
    auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();

    TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
    // The staging copy moves numel * element_size bytes starting at data_ptr();
    // that is the tensor's data only when it is weakly contiguous.
    TORCH_CHECK(_is_weak_contiguous(inp), "custom allgather requires a weakly contiguous input");
    auto input_size = inp.numel() * inp.element_size();
    TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(),
                "registered buffer is too small to contain the input");
    HIP_CALL(hipMemcpyAsync(
        reg_buffer.data_ptr(), inp.data_ptr(), input_size, hipMemcpyDeviceToDevice, stream));
    _all_gather(_fa, reg_buffer, out, inp.numel(), stream);
}

/**
 * Launch the fused all-reduce + RMSNorm kernel (dispatchFusedAllReduceRMSNorm)
 * for an m x n problem on `stream`, choosing the instantiation from out's dtype.
 * `inp`, `residual_inp`, `residual_out`, `out` and `w` are all reinterpreted
 * with that element type.
 *
 * Throws std::runtime_error for dtypes other than float32/float16/bfloat16.
 */
void _fused_allreduce_rmsnorm(fptr_t _fa,
                              torch::Tensor& inp,
                              torch::Tensor& residual_inp,
                              torch::Tensor& residual_out,
                              torch::Tensor& out,
                              torch::Tensor& w,
                              float eps,
                              int m,
                              int n,
                              hipStream_t stream)
{
    auto* ar = reinterpret_cast<aiter::CustomAllreduce*>(_fa);
    TORCH_CHECK(_is_weak_contiguous(out));

    switch(out.scalar_type())
    {
    case at::ScalarType::Float:
        ar->dispatchFusedAllReduceRMSNorm<float>(stream,
                                                 static_cast<float*>(inp.data_ptr()),
                                                 static_cast<float*>(residual_inp.data_ptr()),
                                                 static_cast<float*>(residual_out.data_ptr()),
                                                 static_cast<float*>(out.data_ptr()),
                                                 static_cast<float*>(w.data_ptr()),
                                                 eps,
                                                 m,
                                                 n);
        break;
    case at::ScalarType::Half:
        ar->dispatchFusedAllReduceRMSNorm<half>(stream,
                                                static_cast<half*>(inp.data_ptr()),
                                                static_cast<half*>(residual_inp.data_ptr()),
                                                static_cast<half*>(residual_out.data_ptr()),
                                                static_cast<half*>(out.data_ptr()),
                                                static_cast<half*>(w.data_ptr()),
                                                eps,
                                                m,
                                                n);
        break;
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
    case at::ScalarType::BFloat16:
        ar->dispatchFusedAllReduceRMSNorm<__hip_bfloat16>(
            stream,
            static_cast<__hip_bfloat16*>(inp.data_ptr()),
            static_cast<__hip_bfloat16*>(residual_inp.data_ptr()),
            static_cast<__hip_bfloat16*>(residual_out.data_ptr()),
            static_cast<__hip_bfloat16*>(out.data_ptr()),
            static_cast<__hip_bfloat16*>(w.data_ptr()),
            eps,
            m,
            n);
        break;
#endif
    default:
        throw std::runtime_error("custom allreduce only supports float32, float16 and bfloat16");
    }
}

/**
 * All-reduce `inp` and apply the fused RMSNorm kernel on the current HIP stream
 * of inp's device. The problem is viewed as m rows of n elements, where
 * n = w.numel() and m = inp.numel() / n.
 *
 * @param inp        input; its numel must be a positive multiple of w.numel().
 * @param res_inp    residual input; same dtype and numel as inp.
 * @param res_out    residual output; same dtype and numel as inp.
 * @param out        output; same dtype and numel as inp.
 * @param w          weight of n elements; same dtype as inp.
 * @param eps        forwarded to the kernel.
 * @param reg_buffer optional buffer. When given, inp is first copied into it on
 *                   the stream and the kernel reads from the buffer.
 *
 * All tensors are reinterpreted with inp's element type and read/written as flat
 * arrays, so dtypes must agree and every operand must be weakly contiguous.
 */
void fused_allreduce_rmsnorm(fptr_t _fa,
                torch::Tensor& inp,
                torch::Tensor& res_inp,
                torch::Tensor& res_out,
                torch::Tensor& out,
                torch::Tensor& w,
                float eps,
                std::optional<torch::Tensor> reg_buffer)
{
    const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(inp));
    auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
    TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
    TORCH_CHECK_EQ(inp.scalar_type(), res_inp.scalar_type());
    // res_out and w are also cast to inp's element type by the dispatch.
    TORCH_CHECK_EQ(inp.scalar_type(), res_out.scalar_type());
    TORCH_CHECK_EQ(inp.scalar_type(), w.scalar_type());
    TORCH_CHECK_EQ(inp.numel(), out.numel());
    TORCH_CHECK_EQ(inp.numel(), res_inp.numel());
    TORCH_CHECK_EQ(inp.numel(), res_out.numel());
    TORCH_CHECK(_is_weak_contiguous(inp), "fused_allreduce_rmsnorm requires a weakly contiguous input");
    TORCH_CHECK(_is_weak_contiguous(res_inp), "fused_allreduce_rmsnorm requires a weakly contiguous residual input");
    TORCH_CHECK(_is_weak_contiguous(res_out), "fused_allreduce_rmsnorm requires a weakly contiguous residual output");
    TORCH_CHECK(_is_weak_contiguous(w), "fused_allreduce_rmsnorm requires a weakly contiguous weight");

    const int64_t n64 = w.numel();
    // n is the divisor for m: an empty weight would divide by zero, and a
    // numel that is not a multiple of n would silently drop the tail rows.
    TORCH_CHECK(n64 > 0, "fused_allreduce_rmsnorm: weight must not be empty");
    TORCH_CHECK(inp.numel() % n64 == 0,
                "fused_allreduce_rmsnorm: input numel must be a multiple of weight numel");
    int n = static_cast<int>(n64);
    int m = static_cast<int>(inp.numel() / n64);

    if(reg_buffer.has_value())
    {
        torch::Tensor& buf    = reg_buffer.value();
        const auto input_size = inp.numel() * inp.element_size();
        TORCH_CHECK(input_size <= buf.numel() * buf.element_size(),
                    "registered buffer is too small to contain the input");
        HIP_CALL(hipMemcpyAsync(
            buf.data_ptr(), inp.data_ptr(), input_size, hipMemcpyDeviceToDevice, stream));
        _fused_allreduce_rmsnorm(_fa, buf, res_inp, res_out, out, w, eps, m, n, stream);
    }
    else
    {
        _fused_allreduce_rmsnorm(_fa, inp, res_inp, res_out, out, w, eps, m, n, stream);
    }
}

/** Destroy the CustomAllreduce behind a handle from init_custom_ar(); the handle is invalid afterwards. */
void dispose(fptr_t _fa)
{
    delete reinterpret_cast<aiter::CustomAllreduce*>(_fa);
}

/** Number of bytes in one aiter::Signal (init_custom_ar reinterprets meta's data as an aiter::Signal*). */
int64_t meta_size()
{
    return static_cast<int64_t>(sizeof(aiter::Signal));
}

/**
 * Register the memory at t.data_ptr() with the allreduce instance by
 * forwarding it, along with `handles` and `offsets`, to
 * CustomAllreduce::register_buffer.
 */
void register_buffer(fptr_t _fa,
                     torch::Tensor& t,
                     const std::vector<torch::Tensor>& handles,
                     const std::vector<int64_t>& offsets)
{
    void* const local_ptr = t.data_ptr();
    reinterpret_cast<aiter::CustomAllreduce*>(_fa)->register_buffer(handles, offsets, local_ptr);
}

/**
 * Export the graph-buffer IPC metadata reported by
 * CustomAllreduce::get_graph_buffer_ipc_meta().
 *
 * @return (handles, offsets): a CPU uint8 tensor containing the raw handle
 *         bytes, and a CPU int64 tensor that owns a copy of the offsets.
 */
std::tuple<torch::Tensor, torch::Tensor> get_graph_buffer_ipc_meta(fptr_t _fa)
{
    auto* ar                         = reinterpret_cast<aiter::CustomAllreduce*>(_fa);
    auto [raw_handles, raw_offsets]  = ar->get_graph_buffer_ipc_meta();

    const auto cpu_u8       = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
    const auto handle_count = static_cast<int64_t>(raw_handles.size());
    torch::Tensor handles   = torch::empty({handle_count}, cpu_u8);
    std::memcpy(handles.data_ptr(), raw_handles.data(), raw_handles.size());

    // from_blob only wraps the local vector's storage; clone() so the returned
    // tensor owns its data after this function returns.
    const auto offset_count = static_cast<int64_t>(raw_offsets.size());
    torch::Tensor offset_tensor =
        torch::from_blob(raw_offsets.data(), {offset_count}, torch::kInt64).clone();
    return {handles, offset_tensor};
}

/** Forward graph-buffer IPC `handles` and `offsets` unchanged to CustomAllreduce::register_graph_buffers. */
void register_graph_buffers(fptr_t _fa,
                            const std::vector<torch::Tensor>& handles,
                            const std::vector<torch::Tensor>& offsets)
{
    reinterpret_cast<aiter::CustomAllreduce*>(_fa)->register_graph_buffers(handles, offsets);
}

#ifdef USE_ROCM

// Deleter passed to torch::from_blob by allocate_meta_buffer(): releases the HIP allocation.
void free_meta_buffer(void* buffer)
{
    HIP_CALL(hipFree(buffer));
}

/**
 * Return a CPU uint8 tensor of sizeof(hipIpcMemHandle_t) bytes, filled by
 * hipIpcGetMemHandle for the device memory at inp.data_ptr().
 */
torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp)
{
    constexpr auto handle_bytes = static_cast<int64_t>(sizeof(hipIpcMemHandle_t));
    torch::Tensor handle =
        torch::empty({handle_bytes}, torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU));
    auto* dst = static_cast<hipIpcMemHandle_t*>(handle.data_ptr());
    HIP_CALL(hipIpcGetMemHandle(dst, inp.data_ptr()));
    return handle;
}

/**
 * Allocate `size` bytes of device memory on the current device with
 * hipExtMallocWithFlags(hipDeviceMallocUncached), zero it, and wrap it in an
 * int8 tensor whose deleter is free_meta_buffer().
 *
 * Allocation, memset and stream synchronization run with the calling thread's
 * stream capture mode temporarily set to hipStreamCaptureModeRelaxed. The first
 * hipThreadExchangeStreamCaptureMode swaps it in, and the second restores the
 * previous mode. This is presumably so the function can be called while a
 * stream capture is active; confirm.
 *
 * NOTE(review): if a HIP_CALL between the two exchanges fails by throwing, the
 * previous capture mode is not restored and `buffer` is leaked. Confirm how
 * HIP_CALL reports failures.
 */
torch::Tensor allocate_meta_buffer(int64_t size)
{
    auto device_index = c10::hip::current_device();
    at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
    void* buffer;
    // Exchanged with the thread's current mode below, then exchanged back.
    hipStreamCaptureMode mode = hipStreamCaptureModeRelaxed;
    auto stream               = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
    HIP_CALL(hipThreadExchangeStreamCaptureMode(&mode));
    HIP_CALL(hipExtMallocWithFlags((void**)&buffer, size, hipDeviceMallocUncached));
    HIP_CALL(hipMemsetAsync(buffer, 0, size, stream));
    // Wait for the zero-fill to finish before the buffer is handed out.
    HIP_CALL(hipStreamSynchronize(stream));
    HIP_CALL(hipThreadExchangeStreamCaptureMode(&mode));
    auto options = torch::TensorOptions().dtype(torch::kI8).device(torch::kCUDA, device_index);
    // from_blob does not copy; free_meta_buffer is invoked to release `buffer`.
    return torch::from_blob(buffer, {size}, free_meta_buffer, options);
}

#endif

} // namespace aiter