#include "infiniccl_cuda.h"

#if defined(ENABLE_HYGON_API)
#include "infiniccl_custom_all_reduce.cuh"
#include <array>
#include <atomic>
#if defined(__HIP__) || defined(__HIPCC__)
#include <hip/hip_runtime.h>
#if __has_include(<hip/hip_ext.h>)
#include <hip/hip_ext.h>
#endif
#endif
#endif /* ENABLE_HYGON_API */

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cuda_runtime.h>
#include <exception>
#include <limits>
#include <nccl.h>
#include <vector>

#include "../../utils.h"

#define CHECK_NCCL(API__) CHECK_INTERNAL(API__, ncclSuccess)

/** Convert an InfiniRT stream handle to a CUDA stream; a null handle maps to stream 0. */
inline cudaStream_t getCudaStream(infinirtStream_t stream) {
    if (stream == nullptr) {
        return 0;
    }
    return static_cast<cudaStream_t>(stream);
}

/**
 * Map an InfiniCore dtype to the matching NCCL dtype.
 * Only F32/F16/BF16 are supported; anything else prints a diagnostic and aborts.
 */
inline ncclDataType_t getNcclDtype(infiniDtype_t datatype) {
    switch (datatype) {
    case INFINI_DTYPE_F32:
        return ncclFloat;
    case INFINI_DTYPE_F16:
        return ncclHalf;
    case INFINI_DTYPE_BF16:
        return ncclBfloat16;
    default:
        std::fprintf(stderr, "[infiniccl] unsupported dtype for NCCL: %d\n", static_cast<int>(datatype));
        std::abort();
        return ncclHalf;
    }
}

/**
 * Map an InfiniCCL reduction op to the matching NCCL op.
 * Unknown ops print a diagnostic and abort.
 */
inline ncclRedOp_t getNcclRedOp(infinicclReduceOp_t op) {
    switch (op) {
    case INFINICCL_SUM:
        return ncclSum;
    case INFINICCL_PROD:
        return ncclProd;
    case INFINICCL_MAX:
        return ncclMax;
    case INFINICCL_MIN:
        return ncclMin;
    case INFINICCL_AVG:
        return ncclAvg;
    default:
        std::fprintf(stderr, "[infiniccl] unsupported reduce op for NCCL: %d\n", static_cast<int>(op));
        std::abort();
        return ncclSum;
    }
}

/** Return the NCCL communicator stored (type-erased) in an InfiniCCL comm. */
inline ncclComm_t getNcclComm(infinicclComm_t comm) {
    return static_cast<ncclComm_t>(comm->comm);
}

/** Size in bytes of one element of a dtype this backend supports; 0 for any other dtype. */
static size_t elemSizeBytes(infiniDtype_t datatype) {
    switch (datatype) {
    case INFINI_DTYPE_F32:
        return 4;
    case INFINI_DTYPE_F16:
    case INFINI_DTYPE_BF16:
        return 2;
    default:
        return 0;
    }
}

// Largest message (in bytes) eligible for the hybrid custom allreduce; larger messages go to NCCL.
// The same value also sizes each rank's staging buffer and 2-stage scratch buffer.
// 8192 * 64 = 512 KiB (an earlier setting was 8192 * 1024 = 8 MiB).
static constexpr size_t kCustomAllreduceMaxBytes = size_t(8192) * 64;

#if defined(ENABLE_HYGON_API)
// vLLM-style rank_data pool size (bytes), see custom_all_reduce.py torch.empty(8 * 1024 * 1024, uint8).
static constexpr size_t kHygonRankDataBytes = 8ull * 1024 * 1024;

// vLLM csrc/custom_all_reduce.cu allocate_shared_buffer_and_handle: on USE_ROCM the shared buffer
// uses hipExtMallocWithFlags(..., hipDeviceMallocUncached) so signal visibility is correct (e.g. MI200).
// rank_data stays plain cudaMalloc like torch.empty(device).
#if defined(__HIP__) || defined(__HIPCC__)
/**
 * Allocate `nbytes` of device memory with hipDeviceMallocUncached.
 * Any HIP failure is reported to the caller as cudaErrorMemoryAllocation.
 */
static cudaError_t hygonMallocUncachedShared(void **ptr, size_t nbytes) {
    if (hipExtMallocWithFlags(ptr, nbytes, hipDeviceMallocUncached) != hipSuccess) {
        return cudaErrorMemoryAllocation;
    }
    return cudaSuccess;
}
#endif

/**
 * Allocate a buffer that peer devices read directly.
 *
 * Under the HIP front end this is uncached VRAM: vLLM's allocate_shared_buffer_and_handle uses
 * hipDeviceMallocUncached for ALL shared buffers on ROCm (not just the signal), because mappings
 * of uncached memory are fine-grained and cross-device kernel reads see the latest data.
 * Under the CUDA front end it is a plain cudaMalloc.
 */
static cudaError_t hygonMallocStagingShared(void **ptr, size_t nbytes) {
#if !(defined(__HIP__) || defined(__HIPCC__))
    return cudaMalloc(ptr, nbytes);
#else
    return hygonMallocUncachedShared(ptr, nbytes);
#endif
}

/**
 * Buffers shared by every rank of one group that uses the built-in Hygon custom allreduce.
 * All ranks' comms point at the same instance; the counter below tracks how many per-rank
 * CustomAllreduce objects still reference it.
 */
struct HygonArGroup {
    int ndevice;                                // number of ranks in the group
    std::atomic<int> cars_remaining_to_destroy; // CustomAllreduce objects not yet destroyed
    std::array<int, 8> device_ids{};            // device id of each rank
    /** Per-rank 2stage scratch on device (peer-read via P2P). */
    std::array<void *, 8> scratch_base{};
    std::array<void *, 8> rank_data_base{};     // per-rank rank_data pool
    std::array<void *, 8> staging_base{};       // per-rank staging buffer
    /** One portable host block: ndevice × Signal (barrier only; no scratch tail). */
    void *sig_host_base = nullptr;

    /**
     * Release the host Signal block, then every rank's scratch, rank_data and staging buffers.
     * Each rank's device is made current before its buffers are freed (the current device is
     * left at the last rank's device). All pointers are nulled afterwards.
     */
    void freeAllDeviceAllocs() {
        if (sig_host_base != nullptr) {
#if defined(__HIP__) || defined(__HIPCC__)
            const hipError_t rc = hipHostFree(sig_host_base);
            if (rc != hipSuccess) {
                std::fprintf(stderr, "[infiniccl] hipHostFree(Signal) failed: %s\n", hipGetErrorString(rc));
            }
#else
            const cudaError_t rc = cudaFreeHost(sig_host_base);
            if (rc != cudaSuccess) {
                std::fprintf(stderr, "[infiniccl] cudaFreeHost(Signal) failed: %s\n", cudaGetErrorString(rc));
            }
#endif
            sig_host_base = nullptr;
        }
        for (int r = 0; r < ndevice; ++r) {
            INFINICCL_AR_CUDA_CHECK(cudaSetDevice(device_ids[r]));
            void **const slots[] = {&scratch_base[r], &rank_data_base[r], &staging_base[r]};
            for (void **slot : slots) {
                if (*slot) {
                    INFINICCL_AR_CUDA_CHECK(cudaFree(*slot));
                }
            }
            for (void **slot : slots) {
                *slot = nullptr;
            }
        }
    }
};

/** The custom kernels support an even world size of 2, 4, 6 or 8 ranks. */
static bool hygonCustomWorldSupported(int n) {
    return n >= 2 && n <= 8 && n % 2 == 0;
}

/**
 * INFINICCL_CUSTOM_ALLREDUCE=0 or =off: the custom allreduce is not initialized, and allReduce
 * also skips the custom kernel (NCCL is still used).
 */
static bool hygonCustomAllreduceDisabledByEnv() {
    const char *value = std::getenv("INFINICCL_CUSTOM_ALLREDUCE");
    return value != nullptr && (std::strcmp(value, "0") == 0 || std::strcmp(value, "off") == 0);
}

/**
 * Try to set up the built-in custom allreduce for a freshly created Hygon comm group. On any
 * unsupported configuration or failure it leaves the comms untouched, so NCCL is used.
 *
 * Hygon DCU / single-process InfiniLM: IPC is unusable; device-resident Signal
 * + P2P atomics deadlock on barrier. We use:
 *  - **host-mapped Signal** (hipHostMallocPortable|Mapped + hipHostGetDevicePointer
 *    per viewer GPU) so barrier flags are CPU-coherent across all cards (TP 2/4/6/8).
 *  - **Per-rank device scratch** for 2stage kernels (RankSignals.scratch[]), uncached VRAM.
 *  - **Staging** buffers (memcpy of the input, then the kernel reads it).
 *  - **P2P** enabled for peer staging/scratch access.
 *
 * Set HIP_VISIBLE_DEVICES to the TP ranks only to reduce uncached VRAM side effects
 * on other GPUs in the box.
*/
static void hygonTryInitCommGroupCustomAllreduce(
    infinicclComm_t *comms, int ndevice, const int *device_ids, infiniDevice_t device_type) {
    // comms[i] / device_ids[i] describe rank i. The pointer tables below are fixed at 8 entries,
    // hence the explicit upper bound.
    if (device_type != INFINI_DEVICE_HYGON || ndevice <= 1 || !hygonCustomWorldSupported(ndevice) || ndevice > 8) {
        return;
    }
    if (hygonCustomAllreduceDisabledByEnv()) {
        const char *env = std::getenv("INFINICCL_CUSTOM_ALLREDUCE");
        std::fprintf(stderr, "[infiniccl] custom allreduce disabled by INFINICCL_CUSTOM_ALLREDUCE=%s\n",
                     env != nullptr ? env : "");
        return;
    }

    int total_visible = 0;
    if (cudaGetDeviceCount(&total_visible) == cudaSuccess && total_visible > ndevice) {
        std::fprintf(stderr,
                     "[infiniccl] WARNING: %d GPUs visible but only %d used for custom allreduce.\n"
                     " hipDeviceMallocUncached causes ~2%% VRAM overhead on ALL visible GPUs.\n"
                     " Set HIP_VISIBLE_DEVICES to only the GPUs you need (e.g. HIP_VISIBLE_DEVICES=0,%d)\n"
                     " to avoid unnecessary VRAM usage on other devices.\n",
                     total_visible, ndevice, ndevice - 1);
    }

    HygonArGroup *grp = nullptr;
    std::array<void *, 8> scratch_per_rank{}; // per-rank 2-stage scratch
    std::array<void *, 8> rank_base{};        // per-rank rank_data pool
    std::array<void *, 8> stg_base{};         // per-rank staging buffer
    std::array<bool, 8> have_alloc{};         // true once all three buffers of rank j exist
    // sig_on_viewer[v][j]: device pointer through which GPU v addresses rank j's host-mapped Signal.
    std::array<std::array<void *, 8>, 8> sig_on_viewer{};

    // --- Phase 1: P2P check and enable peer access between every pair ---
    for (int a = 0; a < ndevice; ++a) {
        for (int b = a + 1; b < ndevice; ++b) {
            int can_ab = 0, can_ba = 0;
            INFINICCL_AR_CUDA_CHECK(cudaDeviceCanAccessPeer(&can_ab, device_ids[a], device_ids[b]));
            INFINICCL_AR_CUDA_CHECK(cudaDeviceCanAccessPeer(&can_ba, device_ids[b], device_ids[a]));
            if (!can_ab || !can_ba) {
                std::fprintf(stderr,
                             "[infiniccl] P2P not supported between device %d and %d, custom allreduce disabled\n",
                             device_ids[a], device_ids[b]);
                return;
            }
        }
    }
    for (int a = 0; a < ndevice; ++a) {
        INFINICCL_AR_CUDA_CHECK(cudaSetDevice(device_ids[a]));
        for (int b = 0; b < ndevice; ++b) {
            if (a == b) {
                continue;
            }
            cudaError_t pe = cudaDeviceEnablePeerAccess(device_ids[b], 0);
            if (pe == cudaErrorPeerAccessAlreadyEnabled) {
                // Harmless, but the runtime records it as the last error. Clear it so a later
                // cudaGetLastError() (e.g. after a kernel launch) does not report it.
                (void)cudaGetLastError();
            } else if (pe != cudaSuccess) {
                std::fprintf(stderr, "[infiniccl] cudaDeviceEnablePeerAccess(%d -> %d) failed: %s\n",
                             device_ids[a], device_ids[b], cudaGetErrorString(pe));
                return;
            }
        }
    }

    // --- Phase 2: host-mapped Signal (barrier) + per-rank 2stage scratch + rank_data + staging ---
    // Toolchains such as DTK may compile this file with the CUDA front end (no __HIP__). In that
    // case cudaHostAlloc/cudaHostGetDevicePointer must be used: without the HIP headers the
    // hipHost* functions are undeclared identifiers.
    void *sig_host_base = nullptr;
    const size_t host_sig_bytes = sizeof(infiniccl_ar::Signal) * static_cast<size_t>(ndevice);
#if defined(__HIP__) || defined(__HIPCC__)
    hipError_t he = hipHostMalloc(&sig_host_base, host_sig_bytes, hipHostMallocPortable | hipHostMallocMapped);
    if (he != hipSuccess || sig_host_base == nullptr) {
        std::fprintf(stderr, "[infiniccl] hipHostMalloc(Signal) failed: %s\n", hipGetErrorString(he));
        return;
    }
#else
    cudaError_t ce = cudaHostAlloc(&sig_host_base, host_sig_bytes, cudaHostAllocPortable | cudaHostAllocMapped);
    if (ce != cudaSuccess || sig_host_base == nullptr) {
        std::fprintf(stderr, "[infiniccl] cudaHostAlloc(Signal) failed: %s\n", cudaGetErrorString(ce));
        return;
    }
#endif
    std::memset(sig_host_base, 0, host_sig_bytes);

    // For every viewer GPU, resolve the device address of each rank's Signal slot.
    for (int vi = 0; vi < ndevice; ++vi) {
        INFINICCL_AR_CUDA_CHECK(cudaSetDevice(device_ids[vi]));
        for (int j = 0; j < ndevice; ++j) {
            void *dp = nullptr;
            char *slot = reinterpret_cast<char *>(sig_host_base) + j * sizeof(infiniccl_ar::Signal);
#if defined(__HIP__) || defined(__HIPCC__)
            he = hipHostGetDevicePointer(&dp, slot, 0);
            if (he != hipSuccess) {
                std::fprintf(stderr, "[infiniccl] hipHostGetDevicePointer failed: %s\n", hipGetErrorString(he));
                (void)hipHostFree(sig_host_base);
                return;
            }
#else
            ce = cudaHostGetDevicePointer(&dp, slot, 0);
            if (ce != cudaSuccess) {
                std::fprintf(stderr, "[infiniccl] cudaHostGetDevicePointer failed: %s\n", cudaGetErrorString(ce));
                (void)cudaFreeHost(sig_host_base);
                return;
            }
#endif
            sig_on_viewer[static_cast<size_t>(vi)][static_cast<size_t>(j)] = dp;
        }
    }

    // Per-rank device buffers. A partial failure for rank j frees rank j's pieces inline; earlier
    // ranks (have_alloc) are released at fail_alloc.
    for (int j = 0; j < ndevice; ++j) {
        INFINICCL_AR_CUDA_CHECK(cudaSetDevice(device_ids[j]));
        void *sc = nullptr, *rd = nullptr, *st = nullptr;
        if (hygonMallocStagingShared(&sc, kCustomAllreduceMaxBytes) != cudaSuccess) {
            goto fail_alloc;
        }
        INFINICCL_AR_CUDA_CHECK(cudaMemset(sc, 0, kCustomAllreduceMaxBytes));
        if (cudaMalloc(&rd, kHygonRankDataBytes) != cudaSuccess) {
            INFINICCL_AR_CUDA_CHECK(cudaFree(sc));
            goto fail_alloc;
        }
        if (hygonMallocStagingShared(&st, kCustomAllreduceMaxBytes) != cudaSuccess) {
            INFINICCL_AR_CUDA_CHECK(cudaFree(sc));
            INFINICCL_AR_CUDA_CHECK(cudaFree(rd));
            goto fail_alloc;
        }
        scratch_per_rank[j] = sc;
        rank_base[j] = rd;
        stg_base[j] = st;
        have_alloc[j] = true;
    }

    // From here on the group owns every buffer; the last commDestroy releases them.
    grp = new HygonArGroup{};
    grp->ndevice = ndevice;
    grp->cars_remaining_to_destroy.store(ndevice, std::memory_order_relaxed);
    grp->sig_host_base = sig_host_base;
    for (int j = 0; j < ndevice; ++j) {
        grp->device_ids[j] = device_ids[j];
        grp->scratch_base[j] = scratch_per_rank[j];
        grp->rank_data_base[j] = rank_base[j];
        grp->staging_base[j] = stg_base[j];
    }

    // --- Phase 3: create CustomAllreduce per rank (direct P2P pointers) ---
    for (int i = 0; i < ndevice; ++i) {
        INFINICCL_AR_CUDA_CHECK(cudaSetDevice(device_ids[i]));
        infiniccl_ar::Signal *sig_ptrs[8]{};
        void *stg_ptrs[8]{};
        void *scratch_ptrs[8]{};
        for (int j = 0; j < ndevice; ++j) {
            sig_ptrs[j] = reinterpret_cast<infiniccl_ar::Signal *>(
                sig_on_viewer[static_cast<size_t>(i)][static_cast<size_t>(j)]);
            stg_ptrs[j] = stg_base[j];
            scratch_ptrs[j] = scratch_per_rank[j];
        }
        infiniccl_ar::CustomAllreduce *car = nullptr;
        try {
            car = new infiniccl_ar::CustomAllreduce(
                sig_ptrs, scratch_ptrs, rank_base[i], kHygonRankDataBytes, i, ndevice, true);
            car->register_buffer(stg_ptrs);
        } catch (...) {
            // If the constructor succeeded but register_buffer() threw, this rank's object must be
            // released too (it used to leak). `car` is still nullptr when the constructor threw.
            // The current device is still device_ids[i].
            delete car;
            // Roll back the ranks that were already wired up, then free the shared buffers.
            for (int k = 0; k < i; ++k) {
                if (comms[k]->custom_ar != nullptr) {
                    INFINICCL_AR_CUDA_CHECK(cudaSetDevice(comms[k]->device_id));
                    delete static_cast<infiniccl_ar::CustomAllreduce *>(comms[k]->custom_ar);
                    comms[k]->custom_ar = nullptr;
                    comms[k]->custom_ar_reg_buf = nullptr;
                    comms[k]->custom_ar_reg_sz = 0;
                    comms[k]->hygon_ar_group = nullptr;
                    comms[k]->hygon_custom_owned = false;
                }
            }
            grp->freeAllDeviceAllocs();
            delete grp;
            return;
        }
        comms[i]->custom_ar = car;
        comms[i]->custom_ar_reg_buf = stg_base[i];
        comms[i]->custom_ar_reg_sz = kCustomAllreduceMaxBytes;
        comms[i]->hygon_ar_group = grp;
        comms[i]->hygon_custom_owned = true;
    }
    std::fprintf(stderr,
                 "[infiniccl] custom allreduce enabled (host-mapped Signal + per-rank scratch + P2P staging, TP 2/4/6/8): "
                 "%d devices, threshold <= %zu bytes\n",
                 ndevice, kCustomAllreduceMaxBytes);
    return;

fail_alloc:
    if (sig_host_base != nullptr) {
#if defined(__HIP__) || defined(__HIPCC__)
        (void)hipHostFree(sig_host_base);
#else
        (void)cudaFreeHost(sig_host_base);
#endif
        sig_host_base = nullptr;
    }
    for (int j = 0; j < ndevice; ++j) {
        if (!have_alloc[j]) {
            continue;
        }
        INFINICCL_AR_CUDA_CHECK(cudaSetDevice(device_ids[j]));
        if (scratch_per_rank[j]) {
            INFINICCL_AR_CUDA_CHECK(cudaFree(scratch_per_rank[j]));
        }
        if (rank_base[j]) {
            INFINICCL_AR_CUDA_CHECK(cudaFree(rank_base[j]));
        }
        if (stg_base[j]) {
            INFINICCL_AR_CUDA_CHECK(cudaFree(stg_base[j]));
        }
    }
}
#endif // ENABLE_HYGON_API

namespace infiniccl::cuda {

/**
 * Attach an externally owned custom-allreduce object and its registered input buffer to a Hygon comm.
 * The comm does not take ownership (commDestroy only deletes objects it created itself).
 * Rejected with INFINI_STATUS_BAD_PARAM when the comm already uses the built-in, group-owned one.
 */
infiniStatus_t commSetHygonCustomAllreduce(
    infinicclComm_t comm,
    void *custom_allreduce,
    void *reg_buffer,
    size_t reg_buffer_bytes) {
#if defined(ENABLE_HYGON_API)
    if (comm == nullptr) {
        return INFINI_STATUS_NULL_POINTER;
    }
    if (comm->device_type != INFINI_DEVICE_HYGON) {
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
    if (comm->hygon_custom_owned && comm->hygon_ar_group != nullptr) {
        return INFINI_STATUS_BAD_PARAM;
    }
    comm->custom_ar = custom_allreduce;
    comm->custom_ar_reg_buf = reg_buffer;
    comm->custom_ar_reg_sz = reg_buffer_bytes;
    return INFINI_STATUS_SUCCESS;
#else
    (void)comm;
    (void)custom_allreduce;
    (void)reg_buffer;
    (void)reg_buffer_bytes;
    return INFINI_STATUS_NOT_IMPLEMENTED;
#endif
}

/**
 * Create one comm per device with ncclCommInitAll, then (Hygon builds only) try to enable the
 * built-in custom allreduce for the group. comms must have room for ndevice entries.
 */
infiniStatus_t commInitAll(
    infiniDevice_t device_type,
    infinicclComm_t *comms,
    int ndevice,
    const int *device_ids) {
    if (comms == nullptr || device_ids == nullptr) {
        return INFINI_STATUS_NULL_POINTER;
    }
    if (ndevice <= 0) {
        // A negative count would otherwise be converted to a huge std::vector size.
        return INFINI_STATUS_BAD_PARAM;
    }

    std::vector<ncclComm_t> nccl_comms(static_cast<size_t>(ndevice));
    CHECK_NCCL(ncclCommInitAll(nccl_comms.data(), ndevice, (int const *)device_ids));

    for (int i = 0; i < ndevice; i++) {
        comms[i] = new InfinicclComm{
            device_type, device_ids[i], (void *)(nccl_comms[i]),
            nullptr, nullptr, 0, nullptr, false};
    }
#if defined(ENABLE_HYGON_API)
    hygonTryInitCommGroupCustomAllreduce(comms, ndevice, device_ids, device_type);
#endif
    return INFINI_STATUS_SUCCESS;
}

/**
 * Destroy a comm. On Hygon, a built-in CustomAllreduce is deleted first, and the last rank of the
 * group also frees the group's shared buffers.
 */
infiniStatus_t commDestroy(infinicclComm_t comm) {
    if (comm == nullptr) {
        return INFINI_STATUS_NULL_POINTER;
    }
#if defined(ENABLE_HYGON_API)
    if (comm->hygon_custom_owned && comm->custom_ar != nullptr) {
        HygonArGroup *g = static_cast<HygonArGroup *>(comm->hygon_ar_group);
        // Make this rank's device current before deleting its CustomAllreduce so the destructor's
        // runtime calls target the device that owns its resources (vLLM's version closes IPC
        // handles there; confirm for this fork).
        INFINICCL_AR_CUDA_CHECK(cudaSetDevice(comm->device_id));
        delete static_cast<infiniccl_ar::CustomAllreduce *>(comm->custom_ar);
        comm->custom_ar = nullptr;
        comm->custom_ar_reg_buf = nullptr;
        comm->custom_ar_reg_sz = 0;
        if (g != nullptr) {
            // fetch_sub returns the value before the decrement. The last rank to be destroyed sees 1
            // (the counter becomes 0) and releases the shared group resources.
            if (g->cars_remaining_to_destroy.fetch_sub(1, std::memory_order_acq_rel) == 1) {
                g->freeAllDeviceAllocs();
                delete g;
            }
            comm->hygon_ar_group = nullptr;
        }
        comm->hygon_custom_owned = false;
    }
#endif
    CHECK_NCCL(ncclCommDestroy(getNcclComm(comm)));
    delete comm;
    return INFINI_STATUS_SUCCESS;
}

#if defined(ENABLE_HYGON_API)
namespace {

/** Tracing is on when INFINICCL_CUSTOM_ALLREDUCE_TRACE is set, non-empty and does not start with '0'. */
bool customArTraceEnabled() {
    const char *v = std::getenv("INFINICCL_CUSTOM_ALLREDUCE_TRACE");
    return v != nullptr && v[0] != '\0' && v[0] != '0';
}

// Number of custom-AR launches traced so far (process-wide; printing stops after 128).
std::atomic<int> g_custom_ar_trace_exec{0};

/**
 * The custom kernel consumes packed vectors of packed_t<T>::P::size elements, so the element
 * count must be a multiple of that width.
 */
template <typename T>
bool customArPackAligned(int numel) {
    return numel % infiniccl_ar::packed_t<T>::P::size == 0;
}

/** Whether the custom kernel can reduce `numel` elements of `datatype`. */
bool customArSupports(infiniDtype_t datatype, int numel) {
    switch (datatype) {
    case INFINI_DTYPE_F32:
        return customArPackAligned<float>(numel);
    case INFINI_DTYPE_F16:
        return customArPackAligned<half>(numel);
    case INFINI_DTYPE_BF16:
        return customArPackAligned<nv_bfloat16>(numel);
    default:
        return false;
    }
}

/**
 * Enqueue the custom allreduce for `datatype` on `stream`.
 * Returns false for dtypes the kernel does not handle. Exceptions thrown by
 * CustomAllreduce::allreduce propagate to the caller.
 */
bool launchCustomAr(infiniccl_ar::CustomAllreduce *custom, infiniDtype_t datatype, cudaStream_t stream,
                    void *input, void *output, int numel) {
    switch (datatype) {
    case INFINI_DTYPE_F32:
        custom->allreduce(stream, static_cast<float *>(input), static_cast<float *>(output), numel);
        return true;
    case INFINI_DTYPE_F16:
        custom->allreduce(stream, static_cast<half *>(input), static_cast<half *>(output), numel);
        return true;
    case INFINI_DTYPE_BF16:
        custom->allreduce(stream, static_cast<nv_bfloat16 *>(input), static_cast<nv_bfloat16 *>(output), numel);
        return true;
    default:
        return false;
    }
}

/** Short dtype name used in trace output. */
const char *customArDtypeName(infiniDtype_t datatype) {
    switch (datatype) {
    case INFINI_DTYPE_F32:
        return "f32";
    case INFINI_DTYPE_F16:
        return "f16";
    case INFINI_DTYPE_BF16:
        return "bf16";
    default:
        return "?";
    }
}

} // namespace
#endif

/**
 * Allreduce `count` elements from sendbuf into recvbuf on `stream`.
 *
 * On Hygon with a custom allreduce attached, small SUM reductions (<= kCustomAllreduceMaxBytes,
 * pack-aligned, fitting the registered staging buffer) use the custom kernel. Every other case,
 * and any exception thrown by the custom path, falls back to ncclAllReduce.
 */
infiniStatus_t allReduce(
    void *sendbuf,
    void *recvbuf,
    size_t count,
    infiniDtype_t datatype,
    infinicclReduceOp_t op,
    infinicclComm_t comm,
    infinirtStream_t stream) {

    CHECK_DTYPE(datatype, INFINI_DTYPE_F32, INFINI_DTYPE_F16, INFINI_DTYPE_BF16);
    if (comm == nullptr) {
        return INFINI_STATUS_NULL_POINTER;
    }

    cudaStream_t cuda_stream = getCudaStream(stream);

#if defined(ENABLE_HYGON_API)
    const size_t nbytes = count * elemSizeBytes(datatype);
    infiniccl_ar::CustomAllreduce *custom =
        comm->device_type == INFINI_DEVICE_HYGON && comm->custom_ar != nullptr
            ? static_cast<infiniccl_ar::CustomAllreduce *>(comm->custom_ar)
            : nullptr;

    // Eligible: SUM of a non-empty message within the buffer threshold whose element count fits
    // the kernel's `int` size parameter.
    bool try_custom = custom != nullptr && op == INFINICCL_SUM && nbytes > 0 && nbytes <= kCustomAllreduceMaxBytes
                   && count <= static_cast<size_t>(std::numeric_limits<int>::max());
    if (hygonCustomAllreduceDisabledByEnv()) {
        try_custom = false;
    }
    bool custom_ar_executed = false;

    // Opt-in diagnostic: set INFINICCL_CUSTOM_ALLREDUCE_DEBUG=1 to see which path each size
    // bucket takes (printed once per bucket). Useful for verifying that the decode path actually
    // hits the custom kernel.
    {
        static const bool debug = []() {
            const char *v = std::getenv("INFINICCL_CUSTOM_ALLREDUCE_DEBUG");
            return v != nullptr && v[0] != '0' && v[0] != '\0';
        }();
        if (debug) {
            // Atomic "printed" flags: allReduce may run concurrently for several ranks, so plain
            // static bools here would be a data race.
            static std::atomic<bool> p_null{false}, p_big{false}, p_ok{false};
            if (custom == nullptr) {
                if (!p_null.exchange(true, std::memory_order_relaxed)) {
                    std::fprintf(stderr, "[infiniccl] custom_ar not available, all allreduce use NCCL\n");
                }
            } else if (nbytes > kCustomAllreduceMaxBytes) {
                if (!p_big.exchange(true, std::memory_order_relaxed)) {
                    std::fprintf(stderr, "[infiniccl] large allreduce nbytes=%zu > %zu, use NCCL\n",
                                 nbytes, kCustomAllreduceMaxBytes);
                }
            } else if (try_custom) {
                if (!p_ok.exchange(true, std::memory_order_relaxed)) {
                    std::fprintf(stderr, "[infiniccl] small allreduce nbytes=%zu, use custom AR\n", nbytes);
                }
            }
        }
    }

    if (customArTraceEnabled()) {
        static std::atomic<bool> trace_banner{false};
        if (!trace_banner.exchange(true, std::memory_order_relaxed)) {
            std::fprintf(stderr,
                         "[infiniccl] INFINICCL_CUSTOM_ALLREDUCE_TRACE is on: will print up to 128 custom AR invocations "
                         "and up to 48 NCCL fallbacks after try_custom (per process).\n");
        }
    }

    if (try_custom) {
        const int numel = static_cast<int>(count);
        // A caller-registered staging buffer may be smaller than the threshold. Such messages
        // fall back to NCCL instead of failing the whole allreduce.
        const bool fits_staging = comm->custom_ar_reg_buf == nullptr || nbytes <= comm->custom_ar_reg_sz;
        // Check kernel support before staging so unsupported counts do not pay for a wasted D2D copy.
        if (fits_staging && customArSupports(datatype, numel)) {
            void *input_ptr = sendbuf;
            if (comm->custom_ar_reg_buf != nullptr) {
                // The kernel reads peers' registered buffers, so the input is copied into this rank's.
                INFINICCL_AR_CUDA_CHECK(cudaMemcpyAsync(
                    comm->custom_ar_reg_buf, sendbuf, nbytes, cudaMemcpyDeviceToDevice, cuda_stream));
                input_ptr = comm->custom_ar_reg_buf;
            }
            try {
                custom_ar_executed = launchCustomAr(custom, datatype, cuda_stream, input_ptr, recvbuf, numel);
            } catch (const std::exception &) {
                // Unregistered buffer, unsupported world size, etc.: fall back to NCCL.
            }
            if (custom_ar_executed) {
                if (customArTraceEnabled()) {
                    const int k = g_custom_ar_trace_exec.fetch_add(1, std::memory_order_relaxed);
                    if (k < 128) {
                        std::fprintf(stderr,
                                     "[infiniccl] custom AR exec #%d dev=%d nbytes=%zu count=%zu dtype=%s "
                                     "staging=%d\n",
                                     k, comm->device_id, nbytes, count, customArDtypeName(datatype),
                                     comm->custom_ar_reg_buf != nullptr ? 1 : 0);
                    }
                }
                return INFINI_STATUS_SUCCESS;
            }
        }
    }

    if (customArTraceEnabled() && try_custom && !custom_ar_executed) {
        static std::atomic<int> nfallback{0};
        const int f = nfallback.fetch_add(1, std::memory_order_relaxed);
        if (f < 48) {
            std::fprintf(stderr,
                         "[infiniccl] try_custom set but NCCL path dev=%d nbytes=%zu count=%zu dtype=%d "
                         "(alignment / staging too small / unregistered / exception)\n",
                         comm->device_id, nbytes, count, static_cast<int>(datatype));
        }
    }
#endif

    CHECK_NCCL(ncclAllReduce(sendbuf, recvbuf, count, getNcclDtype(datatype),
                             getNcclRedOp(op), getNcclComm(comm), cuda_stream));

    return INFINI_STATUS_SUCCESS;
}
} // namespace infiniccl::cuda

#if defined(ENABLE_HYGON_API)
namespace infiniccl_ar {
// Explicit instantiation of the bf16 allreduce entry point in this translation unit.
template void CustomAllreduce::allreduce(cudaStream_t, nv_bfloat16 *, nv_bfloat16 *, int, int, int);
} // namespace infiniccl_ar
#endif