#include "infiniccl_cuda.h"

#if defined(ENABLE_HYGON_API)
#include "infiniccl_custom_all_reduce.cuh"

#include <atomic>
#include <array>
#if defined(__HIP__) || defined(__HIPCC__)
#include <hip/hip_runtime_api.h>
#if __has_include(<hip/hip_ext.h>)
#include <hip/hip_ext.h>
#endif
#endif
#endif /* ENABLE_HYGON_API */

#include <cuda_runtime.h>
#include <cstddef>
#include <cstdio>  // std::fprintf
#include <cstdlib> // std::abort, std::getenv
#include <cstring>
#include <exception>
#include <limits>
#include <nccl.h>
#include <vector>

#include "../../utils.h"

#define CHECK_NCCL(API__) CHECK_INTERNAL(API__, ncclSuccess)

// Map infinirt stream handles to CUDA streams; a null handle means the
// legacy default stream.
inline cudaStream_t getCudaStream(infinirtStream_t stream) {
    if (stream == nullptr) {
        return nullptr;
    }
    return static_cast<cudaStream_t>(stream);
}

inline ncclDataType_t getNcclDtype(infiniDtype_t datatype) {
    switch (datatype) {
    case INFINI_DTYPE_F32:
        return ncclFloat;
    case INFINI_DTYPE_F16:
        return ncclHalf;
    case INFINI_DTYPE_BF16:
        return ncclBfloat16;
    default:
        std::abort();
        return ncclHalf;
    }
}

inline ncclRedOp_t getNcclRedOp(infinicclReduceOp_t op) {
    switch (op) {
    case INFINICCL_SUM:
        return ncclSum;
    case INFINICCL_PROD:
        return ncclProd;
    case INFINICCL_MAX:
        return ncclMax;
    case INFINICCL_MIN:
        return ncclMin;
    case INFINICCL_AVG:
        return ncclAvg;
    default:
        std::abort();
        return ncclSum;
    }
}

inline ncclComm_t getNcclComm(infinicclComm_t comm) {
    return static_cast<ncclComm_t>(comm->comm);
}

static size_t elemSizeBytes(infiniDtype_t datatype) {
    switch (datatype) {
    case INFINI_DTYPE_F32:
        return 4;
    case INFINI_DTYPE_F16:
    case INFINI_DTYPE_BF16:
        return 2;
    default:
        return 0;
    }
}
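// Note: the 0 returned for unsupported dtypes propagates into allReduce() as
// nbytes == 0, which fails the `nbytes > 0` gate on try_custom and routes the
// call to NCCL.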

// Threshold for choosing the hybrid custom allreduce over NCCL.
// Active value: 8192 * 64 = 524288 bytes (512 KiB), i.e. 262144 f16/bf16
// elements or 131072 f32 elements. The 8 MiB variant below matches the vLLM
// staging-buffer size but is currently disabled.
// static constexpr size_t kCustomAllreduceMaxBytes = size_t(8192) * 1024;
static constexpr size_t kCustomAllreduceMaxBytes = size_t(8192) * 64;

#if defined(ENABLE_HYGON_API)
// vLLM-style rank_data pool size (bytes), see custom_all_reduce.py torch.empty(8 * 1024 * 1024, uint8).
static constexpr size_t kHygonRankDataBytes = 8ull * 1024 * 1024;

// vLLM csrc/custom_all_reduce.cu allocate_shared_buffer_and_handle: on USE_ROCM the shared buffer
// uses hipExtMallocWithFlags(..., hipDeviceMallocUncached) so signal visibility is correct (e.g. MI200).
// rank_data stays plain cudaMalloc like torch.empty(device).
#if defined(__HIP__) || defined(__HIPCC__)
static cudaError_t hygonMallocUncachedShared(void **ptr, size_t nbytes) {
    hipError_t e = hipExtMallocWithFlags(ptr, nbytes, hipDeviceMallocUncached);
    return e == hipSuccess ? cudaSuccess : cudaErrorMemoryAllocation;
}
#endif

static cudaError_t hygonMallocStagingShared(void **ptr, size_t nbytes) {
    // vLLM allocate_shared_buffer_and_handle uses hipDeviceMallocUncached for
    // ALL shared buffers on ROCm (not just signal). IPC mappings of uncached
    // memory are fine-grained → cross-device kernel reads see latest data.
#if defined(__HIP__) || defined(__HIPCC__)
    return hygonMallocUncachedShared(ptr, nbytes);
#else
    return cudaMalloc(ptr, nbytes);
#endif
}

struct HygonArGroup {
    int ndevice;
    std::atomic<int> cars_remaining_to_destroy;
    std::array<int, 8> device_ids{};
    /** Per-rank 2stage scratch on device (peer-read via P2P). */
    std::array<void *, 8> scratch_base{};
    std::array<void *, 8> rank_data_base{};
    std::array<void *, 8> staging_base{};
    /** One portable host block: ndevice × Signal (barrier only; no scratch tail). */
    void *sig_host_base = nullptr;

    void freeAllDeviceAllocs() {
        if (sig_host_base != nullptr) {
#if defined(__HIP__) || defined(__HIPCC__)
            hipError_t he = hipHostFree(sig_host_base);
            if (he != hipSuccess) {
                std::fprintf(stderr, "[infiniccl] hipHostFree(Signal) failed: %s\n", hipGetErrorString(he));
            }
#else
            cudaError_t ce = cudaFreeHost(sig_host_base);
            if (ce != cudaSuccess) {
                std::fprintf(stderr, "[infiniccl] cudaFreeHost(Signal) failed: %s\n", cudaGetErrorString(ce));
            }
#endif
            sig_host_base = nullptr;
        }
        for (int j = 0; j < ndevice; ++j) {
            INFINICCL_AR_CUDA_CHECK(cudaSetDevice(device_ids[j]));
            if (scratch_base[j]) {
                INFINICCL_AR_CUDA_CHECK(cudaFree(scratch_base[j]));
            }
            if (rank_data_base[j]) {
                INFINICCL_AR_CUDA_CHECK(cudaFree(rank_data_base[j]));
            }
            if (staging_base[j]) {
                INFINICCL_AR_CUDA_CHECK(cudaFree(staging_base[j]));
            }
            scratch_base[j] = rank_data_base[j] = staging_base[j] = nullptr;
        }
    }
};
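// Lifetime protocol (a summary of what commDestroy() below does): every comm
// in the TP group shares one HygonArGroup. Each rank's commDestroy() deletes
// that rank's CustomAllreduce and then decrements cars_remaining_to_destroy;
// the rank that drops the counter to zero calls freeAllDeviceAllocs() and
// deletes the group, so the shared signal/scratch/staging memory outlives
// every rank's kernel use.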

static bool hygonCustomWorldSupported(int n) {
    return n == 2 || n == 4 || n == 6 || n == 8;
}
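// The custom kernels are specialized for even world sizes 2/4/6/8 only, as in
// the vLLM custom allreduce they are derived from.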

/** INFINICCL_CUSTOM_ALLREDUCE=0 or "off": skip custom-allreduce initialization;
 *  allReduce() likewise bypasses the custom kernel and uses NCCL instead. */
static bool hygonCustomAllreduceDisabledByEnv() {
    const char *env = std::getenv("INFINICCL_CUSTOM_ALLREDUCE");
    if (env == nullptr) {
        return false;
    }
    return std::strcmp(env, "0") == 0 || std::strcmp(env, "off") == 0;
}

/**
 * Hygon DCU / single-process InfiniLM: IPC is unusable; device-resident Signal
 * + P2P atomics deadlock on barrier. We use:
 *  - **host-mapped Signal** (hipHostMallocPortable|Mapped + hipHostGetDevicePointer
 *    per viewer GPU) so barrier flags are CPU-coherent across all cards (TP 2/4/6/8).
 *  - **Per-rank device scratch** for 2stage kernels (RankSignals.scratch[]), uncached VRAM.
 *  - **Staging** buffers unchanged (memcpy + kernel read).
 *  - **P2P** enabled for peer staging/scratch access.
 *
 * Set HIP_VISIBLE_DEVICES to the TP ranks only to reduce uncached VRAM side effects
 * on other GPUs in the box.
 */
static void hygonTryInitCommGroupCustomAllreduce(
    infinicclComm_t *comms, int ndevice, const int *device_ids, infiniDevice_t device_type) {
    if (device_type != INFINI_DEVICE_HYGON || ndevice <= 1 || !hygonCustomWorldSupported(ndevice) || ndevice > 8) {
        return;
    }

    if (hygonCustomAllreduceDisabledByEnv()) {
        const char *env = std::getenv("INFINICCL_CUSTOM_ALLREDUCE");
        std::fprintf(stderr, "[infiniccl] custom allreduce disabled by INFINICCL_CUSTOM_ALLREDUCE=%s\n",
                     env != nullptr ? env : "");
        return;
    }

    int total_visible = 0;
    if (cudaGetDeviceCount(&total_visible) == cudaSuccess && total_visible > ndevice) {
        std::fprintf(stderr,
            "[infiniccl] WARNING: %d GPUs visible but only %d used for custom allreduce.\n"
            "  hipDeviceMallocUncached causes ~2%% VRAM overhead on ALL visible GPUs.\n"
            "  Set HIP_VISIBLE_DEVICES to only the GPUs you need (e.g. HIP_VISIBLE_DEVICES=0,%d)\n"
            "  to avoid unnecessary VRAM usage on other devices.\n",
            total_visible, ndevice, ndevice - 1);
    }

    HygonArGroup *grp = nullptr;

    std::array<void *, 8> scratch_per_rank{};
    std::array<void *, 8> rank_base{};
    std::array<void *, 8> stg_base{};
    std::array<bool, 8> have_alloc{};
    std::array<std::array<void *, 8>, 8> sig_on_viewer{};

    // --- Phase 1: P2P check and enable peer access between every pair ---
    for (int a = 0; a < ndevice; ++a) {
        for (int b = a + 1; b < ndevice; ++b) {
            int can_ab = 0, can_ba = 0;
            INFINICCL_AR_CUDA_CHECK(cudaDeviceCanAccessPeer(&can_ab, device_ids[a], device_ids[b]));
            INFINICCL_AR_CUDA_CHECK(cudaDeviceCanAccessPeer(&can_ba, device_ids[b], device_ids[a]));
            if (!can_ab || !can_ba) {
                std::fprintf(stderr, "[infiniccl] P2P not supported between device %d and %d, custom allreduce disabled\n",
                             device_ids[a], device_ids[b]);
                return;
            }
        }
    }

    for (int a = 0; a < ndevice; ++a) {
        INFINICCL_AR_CUDA_CHECK(cudaSetDevice(device_ids[a]));
        for (int b = 0; b < ndevice; ++b) {
            if (a == b) {
                continue;
            }
            cudaError_t pe = cudaDeviceEnablePeerAccess(device_ids[b], 0);
            if (pe != cudaSuccess && pe != cudaErrorPeerAccessAlreadyEnabled) {
                std::fprintf(stderr, "[infiniccl] cudaDeviceEnablePeerAccess(%d -> %d) failed: %s\n",
                             device_ids[a], device_ids[b], cudaGetErrorString(pe));
                return;
            }
        }
    }

    // --- Phase 2: host-mapped Signal (barrier) + per-rank 2stage scratch + rank_data + staging ---
    // Environments such as DTK may compile this file with the CUDA front end (no __HIP__);
    // in that case use cudaHostAlloc/cudaHostGetDevicePointer. Calling hipHost* without the
    // HIP headers included would fail with "undeclared identifier".
    void *sig_host_base = nullptr;
    const size_t host_sig_bytes = sizeof(infiniccl_ar::Signal) * static_cast<size_t>(ndevice);
#if !(defined(__HIP__) || defined(__HIPCC__))
    cudaError_t ce = cudaSuccess;
#endif
#if defined(__HIP__) || defined(__HIPCC__)
    hipError_t he = hipHostMalloc(&sig_host_base, host_sig_bytes, hipHostMallocPortable | hipHostMallocMapped);
    if (he != hipSuccess || sig_host_base == nullptr) {
        std::fprintf(stderr, "[infiniccl] hipHostMalloc(Signal) failed: %s\n", hipGetErrorString(he));
        return;
    }
#else
    ce = cudaHostAlloc(&sig_host_base, host_sig_bytes, cudaHostAllocPortable | cudaHostAllocMapped);
    if (ce != cudaSuccess || sig_host_base == nullptr) {
        std::fprintf(stderr, "[infiniccl] cudaHostAlloc(Signal) failed: %s\n", cudaGetErrorString(ce));
        return;
    }
#endif
    std::memset(sig_host_base, 0, host_sig_bytes);
    for (int vi = 0; vi < ndevice; ++vi) {
        INFINICCL_AR_CUDA_CHECK(cudaSetDevice(device_ids[vi]));
        for (int j = 0; j < ndevice; ++j) {
            void *dp = nullptr;
#if defined(__HIP__) || defined(__HIPCC__)
            he = hipHostGetDevicePointer(
                &dp, reinterpret_cast<char *>(sig_host_base) + j * sizeof(infiniccl_ar::Signal), 0);
            if (he != hipSuccess) {
                std::fprintf(stderr, "[infiniccl] hipHostGetDevicePointer failed: %s\n", hipGetErrorString(he));
                hipHostFree(sig_host_base);
                return;
            }
#else
            ce = cudaHostGetDevicePointer(
                &dp, reinterpret_cast<char *>(sig_host_base) + j * sizeof(infiniccl_ar::Signal), 0);
            if (ce != cudaSuccess) {
                std::fprintf(stderr, "[infiniccl] cudaHostGetDevicePointer failed: %s\n", cudaGetErrorString(ce));
                cudaFreeHost(sig_host_base);
                return;
            }
#endif
            sig_on_viewer[static_cast<size_t>(vi)][static_cast<size_t>(j)] = dp;
        }
    }

    for (int j = 0; j < ndevice; ++j) {
        INFINICCL_AR_CUDA_CHECK(cudaSetDevice(device_ids[j]));
        void *sc = nullptr, *rd = nullptr, *st = nullptr;
        if (hygonMallocStagingShared(&sc, kCustomAllreduceMaxBytes) != cudaSuccess) {
            goto fail_alloc;
        }
        INFINICCL_AR_CUDA_CHECK(cudaMemset(sc, 0, kCustomAllreduceMaxBytes));
        if (cudaMalloc(&rd, kHygonRankDataBytes) != cudaSuccess) {
            INFINICCL_AR_CUDA_CHECK(cudaFree(sc));
            goto fail_alloc;
        }
        if (hygonMallocStagingShared(&st, kCustomAllreduceMaxBytes) != cudaSuccess) {
            INFINICCL_AR_CUDA_CHECK(cudaFree(sc));
            INFINICCL_AR_CUDA_CHECK(cudaFree(rd));
            goto fail_alloc;
        }
        scratch_per_rank[j] = sc;
        rank_base[j] = rd;
        stg_base[j] = st;
        have_alloc[j] = true;
    }

    grp = new HygonArGroup{};
    grp->ndevice = ndevice;
    grp->cars_remaining_to_destroy.store(ndevice, std::memory_order_relaxed);
    grp->sig_host_base = sig_host_base;
    for (int j = 0; j < ndevice; ++j) {
        grp->device_ids[j] = device_ids[j];
        grp->scratch_base[j] = scratch_per_rank[j];
        grp->rank_data_base[j] = rank_base[j];
        grp->staging_base[j] = stg_base[j];
    }

    // --- Phase 3: create CustomAllreduce per rank (direct P2P pointers) ---
    for (int i = 0; i < ndevice; ++i) {
        INFINICCL_AR_CUDA_CHECK(cudaSetDevice(device_ids[i]));
        infiniccl_ar::Signal *sig_ptrs[8]{};
        void *stg_ptrs[8]{};
        void *scratch_ptrs[8]{};
        for (int j = 0; j < ndevice; ++j) {
            sig_ptrs[j] = reinterpret_cast<infiniccl_ar::Signal *>(sig_on_viewer[static_cast<size_t>(i)][static_cast<size_t>(j)]);
            stg_ptrs[j] = stg_base[j];
            scratch_ptrs[j] = scratch_per_rank[j];
        }

        infiniccl_ar::CustomAllreduce *car = nullptr;
        try {
            car = new infiniccl_ar::CustomAllreduce(
                sig_ptrs, scratch_ptrs, rank_base[i], kHygonRankDataBytes, i, ndevice, true);
            car->register_buffer(stg_ptrs);
        } catch (...) {
            for (int k = 0; k < i; ++k) {
                if (comms[k]->custom_ar != nullptr) {
                    INFINICCL_AR_CUDA_CHECK(cudaSetDevice(comms[k]->device_id));
                    delete static_cast<infiniccl_ar::CustomAllreduce *>(comms[k]->custom_ar);
                    comms[k]->custom_ar = nullptr;
                    comms[k]->custom_ar_reg_buf = nullptr;
                    comms[k]->custom_ar_reg_sz = 0;
                    comms[k]->hygon_ar_group = nullptr;
                    comms[k]->hygon_custom_owned = false;
                }
            }
            grp->freeAllDeviceAllocs();
            delete grp;
            return;
        }

        comms[i]->custom_ar = car;
        comms[i]->custom_ar_reg_buf = stg_base[i];
        comms[i]->custom_ar_reg_sz = kCustomAllreduceMaxBytes;
        comms[i]->hygon_ar_group = grp;
        comms[i]->hygon_custom_owned = true;
    }
    std::fprintf(stderr,
                 "[infiniccl] custom allreduce enabled (host-mapped Signal + per-rank scratch + P2P staging, TP 2/4/6/8): "
                 "%d devices, threshold <= %zu bytes\n",
                 ndevice, kCustomAllreduceMaxBytes);
    return;

fail_alloc:
    if (sig_host_base != nullptr) {
#if defined(__HIP__) || defined(__HIPCC__)
        hipHostFree(sig_host_base);
#else
        cudaFreeHost(sig_host_base);
#endif
        sig_host_base = nullptr;
    }
    for (int j = 0; j < ndevice; ++j) {
        if (!have_alloc[j]) {
            continue;
        }
        INFINICCL_AR_CUDA_CHECK(cudaSetDevice(device_ids[j]));
        if (scratch_per_rank[j]) {
            INFINICCL_AR_CUDA_CHECK(cudaFree(scratch_per_rank[j]));
        }
        if (rank_base[j]) {
            INFINICCL_AR_CUDA_CHECK(cudaFree(rank_base[j]));
        }
        if (stg_base[j]) {
            INFINICCL_AR_CUDA_CHECK(cudaFree(stg_base[j]));
        }
    }
}
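
// For reference, the runtime knobs this file reads from the environment:
//   INFINICCL_CUSTOM_ALLREDUCE=0|off    disable the custom path entirely
//   INFINICCL_CUSTOM_ALLREDUCE_DEBUG=1  print once-per-bucket routing decisions
//   INFINICCL_CUSTOM_ALLREDUCE_TRACE=1  trace up to 128 custom AR invocations
//                                       and up to 48 NCCL fallbacks per process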
#endif // ENABLE_HYGON_API

namespace infiniccl::cuda {

infiniStatus_t commSetHygonCustomAllreduce(
    infinicclComm_t comm, void *custom_allreduce, void *reg_buffer, size_t reg_buffer_bytes) {
#if defined(ENABLE_HYGON_API)
    if (comm == nullptr) {
        return INFINI_STATUS_NULL_POINTER;
    }
    if (comm->device_type != INFINI_DEVICE_HYGON) {
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
    if (comm->hygon_custom_owned && comm->hygon_ar_group != nullptr) {
        return INFINI_STATUS_BAD_PARAM;
    }
    comm->custom_ar = custom_allreduce;
    comm->custom_ar_reg_buf = reg_buffer;
    comm->custom_ar_reg_sz = reg_buffer_bytes;
    return INFINI_STATUS_SUCCESS;
#else
    (void)comm;
    (void)custom_allreduce;
    (void)reg_buffer;
    (void)reg_buffer_bytes;
    return INFINI_STATUS_NOT_IMPLEMENTED;
#endif
}
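
// Usage sketch (hypothetical caller) for injecting an externally managed
// CustomAllreduce instead of the group-owned one; passing nullptr/0 detaches:
//
//   infiniccl_ar::CustomAllreduce *car = /* built by the host framework */;
//   infiniStatus_t s = infiniccl::cuda::commSetHygonCustomAllreduce(
//       comm, car, reg_buf, reg_bytes);
//   // s == INFINI_STATUS_BAD_PARAM if comm already owns a group-created
//   // instance (hygon_custom_owned), to avoid leaking it.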

infiniStatus_t commInitAll(
    infiniDevice_t device_type,
    infinicclComm_t *comms,
    int ndevice,
    const int *device_ids) {

    std::vector<ncclComm_t> nccl_comms(ndevice);
    CHECK_NCCL(ncclCommInitAll(nccl_comms.data(), ndevice, (int const *)device_ids));

    for (int i = 0; i < ndevice; i++) {
        comms[i] = new InfinicclComm{
            device_type, device_ids[i], (void *)(nccl_comms[i]), nullptr, nullptr, 0, nullptr, false};
    }

#if defined(ENABLE_HYGON_API)
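    // Best-effort: any failure inside leaves comms[i]->custom_ar == nullptr,
    // so the group silently stays NCCL-only.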
    hygonTryInitCommGroupCustomAllreduce(comms, ndevice, device_ids, device_type);
#endif

    return INFINI_STATUS_SUCCESS;
}

infiniStatus_t commDestroy(infinicclComm_t comm) {
#if defined(ENABLE_HYGON_API)
    if (comm->hygon_custom_owned && comm->custom_ar != nullptr) {
        HygonArGroup *g = static_cast<HygonArGroup *>(comm->hygon_ar_group);
        // Set device before delete: ~CustomAllreduce calls cudaIpcCloseMemHandle
        // which must run in the context of the device that opened the handles.
        INFINICCL_AR_CUDA_CHECK(cudaSetDevice(comm->device_id));
        delete static_cast<infiniccl_ar::CustomAllreduce *>(comm->custom_ar);
        comm->custom_ar = nullptr;
        comm->custom_ar_reg_buf = nullptr;
        comm->custom_ar_reg_sz = 0;
        if (g != nullptr) {
            // fetch_sub returns the pre-decrement value: the last destroy sees 1,
            // at which point the counter reaches 0.
            if (g->cars_remaining_to_destroy.fetch_sub(1, std::memory_order_acq_rel) == 1) {
                g->freeAllDeviceAllocs();
                delete g;
            }
            comm->hygon_ar_group = nullptr;
        }
        comm->hygon_custom_owned = false;
    }
#endif
    CHECK_NCCL(ncclCommDestroy(getNcclComm(comm)));
    delete comm;
    return INFINI_STATUS_SUCCESS;
}

#if defined(ENABLE_HYGON_API)
namespace {

bool customArTraceEnabled() {
    const char *v = std::getenv("INFINICCL_CUSTOM_ALLREDUCE_TRACE");
    return v != nullptr && v[0] != '\0' && v[0] != '0';
}

std::atomic<int> g_custom_ar_trace_exec{0};

} // namespace
#endif

infiniStatus_t allReduce(
    void *sendbuf,
    void *recvbuf,
    size_t count,
    infiniDtype_t datatype,
    infinicclReduceOp_t op,
    infinicclComm_t comm,
    infinirtStream_t stream) {

    CHECK_DTYPE(datatype, INFINI_DTYPE_F32, INFINI_DTYPE_F16, INFINI_DTYPE_BF16);

    cudaStream_t cuda_stream = getCudaStream(stream);

#if defined(ENABLE_HYGON_API)
    const size_t elem_sz = elemSizeBytes(datatype);
    const size_t nbytes = count * elem_sz;
    infiniccl_ar::CustomAllreduce *custom =
        comm->device_type == INFINI_DEVICE_HYGON && comm->custom_ar
            ? static_cast<infiniccl_ar::CustomAllreduce *>(comm->custom_ar)
            : nullptr;

    bool try_custom = custom != nullptr && op == INFINICCL_SUM && nbytes > 0 &&
                      nbytes <= kCustomAllreduceMaxBytes && count <= static_cast<size_t>(std::numeric_limits<int>::max());
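    // The kill switch is re-checked per call: custom_ar instances injected via
    // commSetHygonCustomAllreduce() bypass the init-time check, and this keeps
    // INFINICCL_CUSTOM_ALLREDUCE=0 effective for them too.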
    if (hygonCustomAllreduceDisabledByEnv()) {
        try_custom = false;
    }
    bool custom_ar_executed = false;

    // Opt-in diagnostic: set INFINICCL_CUSTOM_ALLREDUCE_DEBUG=1 to see which
    // path each size bucket takes (printed once per bucket). Useful for
    // verifying that decode path actually hits the custom kernel.
    {
        static bool debug = []() {
            const char *v = std::getenv("INFINICCL_CUSTOM_ALLREDUCE_DEBUG");
            return v != nullptr && v[0] != '0' && v[0] != '\0';
        }();
        if (debug) {
            static bool p_null = false, p_big = false, p_ok = false;
            if (custom == nullptr && !p_null) {
                std::fprintf(stderr, "[infiniccl] custom_ar not available, all allreduce use NCCL\n");
                p_null = true;
            } else if (custom != nullptr && nbytes > kCustomAllreduceMaxBytes && !p_big) {
                std::fprintf(stderr, "[infiniccl] large allreduce nbytes=%zu > %zu, use NCCL\n",
                             nbytes, kCustomAllreduceMaxBytes);
                p_big = true;
            } else if (try_custom && !p_ok) {
                std::fprintf(stderr, "[infiniccl] small allreduce nbytes=%zu, use custom AR\n", nbytes);
                p_ok = true;
            }
        }
    }
    if (customArTraceEnabled()) {
        static std::atomic<bool> trace_banner{false};
        if (!trace_banner.exchange(true, std::memory_order_relaxed)) {
            std::fprintf(stderr,
                         "[infiniccl] INFINICCL_CUSTOM_ALLREDUCE_TRACE is on: will print up to 128 custom AR invocations "
                         "and up to 48 NCCL fallbacks after try_custom (per process).\n");
        }
    }

    if (try_custom && comm->custom_ar_reg_buf != nullptr && nbytes > comm->custom_ar_reg_sz) {
        // An injected staging buffer may be smaller than kCustomAllreduceMaxBytes;
        // fall back to NCCL rather than failing the whole collective.
        try_custom = false;
    }
    if (try_custom) {
        void *input_ptr = sendbuf;
        if (comm->custom_ar_reg_buf != nullptr) {
            INFINICCL_AR_CUDA_CHECK(cudaMemcpyAsync(
                comm->custom_ar_reg_buf, sendbuf, nbytes, cudaMemcpyDeviceToDevice, cuda_stream));
            input_ptr = comm->custom_ar_reg_buf;
        }
        const int numel = static_cast<int>(count);
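        // The custom kernels issue packed vector loads (vLLM-style 16-byte
        // packed_t<T>), so a dtype is only eligible when count is a multiple
        // of the pack width (8 for f16/bf16, 4 for f32); otherwise we break
        // out to the NCCL path below.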
        try {
            switch (datatype) {
            case INFINI_DTYPE_F32: {
                constexpr int d = infiniccl_ar::packed_t<float>::P::size;
                if (numel % d == 0) {
                    custom->allreduce<float>(cuda_stream, static_cast<float *>(input_ptr),
                                             static_cast<float *>(recvbuf), numel);
                    custom_ar_executed = true;
                    if (customArTraceEnabled()) {
                        const int k = g_custom_ar_trace_exec.fetch_add(1, std::memory_order_relaxed);
                        if (k < 128) {
                            std::fprintf(stderr,
                                         "[infiniccl] custom AR exec #%d dev=%d nbytes=%zu count=%zu dtype=f32 "
                                         "staging=%d\n",
                                         k, comm->device_id, nbytes, count, comm->custom_ar_reg_buf != nullptr ? 1 : 0);
                        }
                    }
                    return INFINI_STATUS_SUCCESS;
                }
                break;
            }
            case INFINI_DTYPE_F16: {
                constexpr int d = infiniccl_ar::packed_t<half>::P::size;
                if (numel % d == 0) {
                    custom->allreduce<half>(cuda_stream, static_cast<half *>(input_ptr),
                                            static_cast<half *>(recvbuf), numel);
                    custom_ar_executed = true;
                    if (customArTraceEnabled()) {
                        const int k = g_custom_ar_trace_exec.fetch_add(1, std::memory_order_relaxed);
                        if (k < 128) {
                            std::fprintf(stderr,
                                         "[infiniccl] custom AR exec #%d dev=%d nbytes=%zu count=%zu dtype=f16 "
                                         "staging=%d\n",
                                         k, comm->device_id, nbytes, count, comm->custom_ar_reg_buf != nullptr ? 1 : 0);
                        }
                    }
                    return INFINI_STATUS_SUCCESS;
                }
                break;
            }
#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800) || defined(__HIP__) || defined(__HIPCC__) || defined(ENABLE_HYGON_API))
            case INFINI_DTYPE_BF16: {
                constexpr int d = infiniccl_ar::packed_t<nv_bfloat16>::P::size;
                if (numel % d == 0) {
                    custom->allreduce<nv_bfloat16>(cuda_stream, static_cast<nv_bfloat16 *>(input_ptr),
                                                   static_cast<nv_bfloat16 *>(recvbuf), numel);
                    custom_ar_executed = true;
                    if (customArTraceEnabled()) {
                        const int k = g_custom_ar_trace_exec.fetch_add(1, std::memory_order_relaxed);
                        if (k < 128) {
                            std::fprintf(stderr,
                                         "[infiniccl] custom AR exec #%d dev=%d nbytes=%zu count=%zu dtype=bf16 "
                                         "staging=%d\n",
                                         k, comm->device_id, nbytes, count, comm->custom_ar_reg_buf != nullptr ? 1 : 0);
                        }
                    }
                    return INFINI_STATUS_SUCCESS;
                }
                break;
            }
#endif
            default:
                break;
            }
        } catch (const std::exception &) {
            // Unregistered buffer, unsupported world size, etc.: fall back to NCCL.
        }
    }
    if (customArTraceEnabled() && try_custom && !custom_ar_executed) {
        static std::atomic<int> nfallback{0};
        const int f = nfallback.fetch_add(1, std::memory_order_relaxed);
        if (f < 48) {
            std::fprintf(stderr,
                         "[infiniccl] try_custom set but NCCL path dev=%d nbytes=%zu count=%zu dtype=%d "
                         "(alignment / unregistered / exception)\n",
                         comm->device_id, nbytes, count, static_cast<int>(datatype));
        }
    }
#endif
    CHECK_NCCL(ncclAllReduce(sendbuf, recvbuf, count, getNcclDtype(datatype),
                             getNcclRedOp(op), getNcclComm(comm), cuda_stream));

    return INFINI_STATUS_SUCCESS;
}
} // namespace infiniccl::cuda

#if defined(ENABLE_HYGON_API)
namespace infiniccl_ar {
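// Presumably forces the bf16 specialization (with the full six-parameter
// signature, including the defaulted threads/block-limit arguments) to be
// emitted in this translation unit for external callers.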
template void CustomAllreduce::allreduce<nv_bfloat16>(cudaStream_t, nv_bfloat16 *, nv_bfloat16 *, int, int, int);
} // namespace infiniccl_ar
#endif