Commit 6a59259c authored by zhangyue

Initial version of custom allreduce

parent 71cac971
......@@ -15,6 +15,15 @@ struct InfinicclComm;
typedef struct InfinicclComm *infinicclComm_t;
/**
* Initialize NCCL communicators (one per device). On Hygon DCU builds (ENABLE_HYGON_API), when
* device_type is INFINI_DEVICE_HYGON and ndevice is 2/4/6/8, this also allocates per-GPU shared
* buffers (vLLM-style staging + Signal + rank_data) and wires up infiniccl_ar::CustomAllreduce
* automatically; otherwise the custom path stays disabled until infinicclCommSetHygonCustomAllreduce
* is called.
*
* Hygon switch: INFINICCL_CUSTOM_ALLREDUCE=0 or off disables that wiring; infinicclAllReduce then
* uses NCCL only in this process (see infinicclAllReduce).
*/
__INFINI_C __export infiniStatus_t infinicclCommInitAll(
infiniDevice_t device_type,
infinicclComm_t *comms,
......@@ -23,6 +32,29 @@ __INFINI_C __export infiniStatus_t infinicclCommInitAll(
__INFINI_C __export infiniStatus_t infinicclCommDestroy(infinicclComm_t comm);
/**
* Hygon DCU only: attach an optional custom allreduce (opaque infiniccl_ar::CustomAllreduce*).
* Other device types receive INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED.
* When set, infinicclAllReduce may use it for SUM on f32/f16/bf16 payloads up to
* kCustomAllreduceMaxBytes (currently 8192 * 64 bytes = 512 KiB; the 8 MiB vLLM limit is kept
* commented out in the implementation); larger or unsupported cases fall back to NCCL.
*
* If reg_buffer is non-null, sendbuf is copied to reg_buffer on the same stream before the custom
* kernel runs (vLLM-style): this provides a fixed, pre-registered buffer for CUDA graph capture or
* for an unregistered sendbuf. reg_buffer_bytes must be >= the payload size whenever reg_buffer is
* used. Pass custom_allreduce == nullptr to clear the custom path.
* Do not call this after commInitAll has already auto-wired the Hygon custom allreduce
* (it returns INFINI_STATUS_BAD_PARAM).
*/
__INFINI_C __export infiniStatus_t infinicclCommSetHygonCustomAllreduce(
infinicclComm_t comm,
void *custom_allreduce,
void *reg_buffer,
size_t reg_buffer_bytes);
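/* Illustrative sketch (not part of the API surface above): manually attaching an already-built
 * custom allreduce object. `car`, `staging_buf` and `staging_bytes` are hypothetical values the
 * caller prepared; see infiniccl_cuda.cu for how commInitAll builds the real ones. */
static inline infiniStatus_t exampleAttachHygonCustomAr(
    infinicclComm_t comm, void *car, void *staging_buf, size_t staging_bytes) {
    infiniStatus_t st = infinicclCommSetHygonCustomAllreduce(comm, car, staging_buf, staging_bytes);
    if (st == INFINI_STATUS_BAD_PARAM) {
        /* commInitAll already auto-wired this comm on a Hygon build; keep that wiring. */
        return INFINI_STATUS_SUCCESS;
    }
    return st; /* pass car == NULL (with 0 bytes) later to detach the custom path */
}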
/**
* Hygon: optional custom allreduce for small SUM payloads (see comm init / setHygon).
* Runtime switch: INFINICCL_CUSTOM_ALLREDUCE=0 or off forces NCCL even if custom objects were initialized.
* Diagnostics to stderr: INFINICCL_CUSTOM_ALLREDUCE_DEBUG=1 (coarse path hints);
* INFINICCL_CUSTOM_ALLREDUCE_TRACE=1 (first 128 custom kernel invocations and up to 48 NCCL fallbacks after try_custom, per OS process).
*/
__INFINI_C __export infiniStatus_t infinicclAllReduce(
void *sendbuf,
void *recvbuf,
......
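/* Illustrative sketch: bringing Hygon communicators up with the custom allreduce diagnostics
 * enabled. The device count and IDs are placeholders; for non-Hygon device types the same call
 * simply skips the custom-allreduce wiring. */
static inline infiniStatus_t exampleInitHygonComms(infinicclComm_t comms[4]) {
    const int device_ids[4] = {0, 1, 2, 3};
    /* Optional runtime switches read by the implementation:
     *   INFINICCL_CUSTOM_ALLREDUCE=0        -> NCCL only, no custom wiring
     *   INFINICCL_CUSTOM_ALLREDUCE_DEBUG=1  -> coarse path hints on stderr
     *   INFINICCL_CUSTOM_ALLREDUCE_TRACE=1  -> bounded per-invocation trace on stderr */
    /* With 4 Hygon DCUs this also allocates the shared staging/Signal/rank_data buffers and
     * wires infiniccl_ar::CustomAllreduce into every comm. */
    return infinicclCommInitAll(INFINI_DEVICE_HYGON, comms, 4, device_ids);
}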
......@@ -55,6 +55,7 @@ inline HcclReduceOp getHcclRedOp(infinicclReduceOp_t op) {
namespace infiniccl::ascend {
infiniStatus_t commInitAll(
infiniDevice_t device_type,
infinicclComm_t *comms,
int ndevice,
const int *device_ids) {
......@@ -67,7 +68,7 @@ infiniStatus_t commInitAll(
CHECK_HCCL(HcclCommInitAll(ndevice, (int32_t *)device_ids, hccl_comms.data()));
for (int i = 0; i < ndevice; i++) {
comms[i] = new InfinicclComm{INFINI_DEVICE_ASCEND, device_ids[i], (void *)(hccl_comms[i])};
comms[i] = new InfinicclComm{device_type, device_ids[i], (void *)(hccl_comms[i]), nullptr, nullptr, 0, nullptr, false};
}
return INFINI_STATUS_SUCCESS;
......
......@@ -53,6 +53,7 @@ inline cnclReduceOp_t getCnclRedOp(infinicclReduceOp_t op) {
namespace infiniccl::cambricon {
infiniStatus_t commInitAll(
infiniDevice_t device_type,
infinicclComm_t *comms,
int ndevice,
const int *device_ids) {
......@@ -70,7 +71,7 @@ infiniStatus_t commInitAll(
ndevice, nullptr));
for (int i = 0; i < ndevice; i++) {
comms[i] = new InfinicclComm{INFINI_DEVICE_CAMBRICON, device_ids[i], (void *)(cncl_comms[i])};
comms[i] = new InfinicclComm{device_type, device_ids[i], (void *)(cncl_comms[i]), nullptr, nullptr, 0, nullptr, false};
}
return INFINI_STATUS_SUCCESS;
......
#include "infiniccl_cuda.h"
#if defined(ENABLE_HYGON_API)
#include "infiniccl_custom_all_reduce.cuh"
#include <atomic>
#include <array>
#include <cstdio>
#include <cstdlib>
#if defined(__HIP__) || defined(__HIPCC__)
#include <hip/hip_runtime_api.h>
#if __has_include(<hip/hip_ext.h>)
#include <hip/hip_ext.h>
#endif
#endif
#endif /* ENABLE_HYGON_API */
#include <cuda_runtime.h>
#include <iostream>
#include <cstddef>
#include <cstring>
#include <exception>
#include <limits>
#include <nccl.h>
#include <vector>
......@@ -52,9 +68,356 @@ inline ncclComm_t getNcclComm(infinicclComm_t comm) {
return static_cast<ncclComm_t>(comm->comm);
}
static size_t elemSizeBytes(infiniDtype_t datatype) {
switch (datatype) {
case INFINI_DTYPE_F32:
return 4;
case INFINI_DTYPE_F16:
case INFINI_DTYPE_BF16:
return 2;
default:
return 0;
}
}
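// Example: a 4096-element f16 tensor is 4096 * elemSizeBytes(INFINI_DTYPE_F16) = 8192 bytes, well
// under kCustomAllreduceMaxBytes below, so it is a candidate for the custom kernel path.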
// Threshold for the hybrid custom-allreduce-vs-NCCL decision. The vLLM-style 8 MiB limit
// (size_t(8192) * 1024) is kept commented out; this version caps the custom path at
// 8192 * 64 = 512 KiB.
// static constexpr size_t kCustomAllreduceMaxBytes = size_t(8192) * 1024;
static constexpr size_t kCustomAllreduceMaxBytes = size_t(8192) * 64;
#if defined(ENABLE_HYGON_API)
// vLLM-style rank_data pool size (bytes), see custom_all_reduce.py torch.empty(8 * 1024 * 1024, uint8).
static constexpr size_t kHygonRankDataBytes = 8ull * 1024 * 1024;
// vLLM csrc/custom_all_reduce.cu allocate_shared_buffer_and_handle: on USE_ROCM the shared buffer
// uses hipExtMallocWithFlags(..., hipDeviceMallocUncached) so signal visibility is correct (e.g. MI200).
// rank_data stays plain cudaMalloc like torch.empty(device).
#if defined(__HIP__) || defined(__HIPCC__)
static cudaError_t hygonMallocUncachedShared(void **ptr, size_t nbytes) {
hipError_t e = hipExtMallocWithFlags(ptr, nbytes, hipDeviceMallocUncached);
return e == hipSuccess ? cudaSuccess : cudaErrorMemoryAllocation;
}
#endif
static cudaError_t hygonMallocStagingShared(void **ptr, size_t nbytes) {
// vLLM allocate_shared_buffer_and_handle uses hipDeviceMallocUncached for
// ALL shared buffers on ROCm (not just signal). IPC mappings of uncached
// memory are fine-grained → cross-device kernel reads see latest data.
#if defined(__HIP__) || defined(__HIPCC__)
return hygonMallocUncachedShared(ptr, nbytes);
#else
return cudaMalloc(ptr, nbytes);
#endif
}
struct HygonArGroup {
int ndevice;
std::atomic<int> cars_remaining_to_destroy;
std::array<int, 8> device_ids{};
/** Per-rank 2stage scratch on device (peer-read via P2P). */
std::array<void *, 8> scratch_base{};
std::array<void *, 8> rank_data_base{};
std::array<void *, 8> staging_base{};
/** One portable host block: ndevice × Signal (barrier only; no scratch tail). */
void *sig_host_base = nullptr;
void freeAllDeviceAllocs() {
if (sig_host_base != nullptr) {
#if defined(__HIP__) || defined(__HIPCC__)
hipError_t he = hipHostFree(sig_host_base);
if (he != hipSuccess) {
std::fprintf(stderr, "[infiniccl] hipHostFree(Signal) failed: %s\n", hipGetErrorString(he));
}
#else
cudaError_t ce = cudaFreeHost(sig_host_base);
if (ce != cudaSuccess) {
std::fprintf(stderr, "[infiniccl] cudaFreeHost(Signal) failed: %s\n", cudaGetErrorString(ce));
}
#endif
sig_host_base = nullptr;
}
for (int j = 0; j < ndevice; ++j) {
INFINICCL_AR_CUDA_CHECK(cudaSetDevice(device_ids[j]));
if (scratch_base[j]) {
INFINICCL_AR_CUDA_CHECK(cudaFree(scratch_base[j]));
}
if (rank_data_base[j]) {
INFINICCL_AR_CUDA_CHECK(cudaFree(rank_data_base[j]));
}
if (staging_base[j]) {
INFINICCL_AR_CUDA_CHECK(cudaFree(staging_base[j]));
}
scratch_base[j] = rank_data_base[j] = staging_base[j] = nullptr;
}
}
};
static bool hygonCustomWorldSupported(int n) {
return n == 2 || n == 4 || n == 6 || n == 8;
}
/** INFINICCL_CUSTOM_ALLREDUCE=0 or off: do not initialize the custom allreduce, and allReduce never takes the custom kernel (NCCL is still used). */
static bool hygonCustomAllreduceDisabledByEnv() {
const char *env = std::getenv("INFINICCL_CUSTOM_ALLREDUCE");
if (env == nullptr) {
return false;
}
return std::strcmp(env, "0") == 0 || std::strcmp(env, "off") == 0;
}
/**
* Hygon DCU / single-process InfiniLM: IPC is unusable; device-resident Signal
* + P2P atomics deadlock on barrier. We use:
* - **host-mapped Signal** (hipHostMallocPortable|Mapped + hipHostGetDevicePointer
* per viewer GPU) so barrier flags are CPU-coherent across all cards (TP 2/4/6/8).
* - **Per-rank device scratch** for 2stage kernels (RankSignals.scratch[]), uncached VRAM.
* - **Staging** buffers unchanged (memcpy + kernel read).
* - **P2P** enabled for peer staging/scratch access.
*
* Set HIP_VISIBLE_DEVICES to the TP ranks only to reduce uncached VRAM side effects
* on other GPUs in the box.
*/
static void hygonTryInitCommGroupCustomAllreduce(
infinicclComm_t *comms, int ndevice, const int *device_ids, infiniDevice_t device_type) {
if (device_type != INFINI_DEVICE_HYGON || ndevice <= 1 || !hygonCustomWorldSupported(ndevice) || ndevice > 8) {
return;
}
if (hygonCustomAllreduceDisabledByEnv()) {
const char *env = std::getenv("INFINICCL_CUSTOM_ALLREDUCE");
std::fprintf(stderr, "[infiniccl] custom allreduce disabled by INFINICCL_CUSTOM_ALLREDUCE=%s\n",
env != nullptr ? env : "");
return;
}
int total_visible = 0;
if (cudaGetDeviceCount(&total_visible) == cudaSuccess && total_visible > ndevice) {
std::fprintf(stderr,
"[infiniccl] WARNING: %d GPUs visible but only %d used for custom allreduce.\n"
" hipDeviceMallocUncached causes ~2%% VRAM overhead on ALL visible GPUs.\n"
" Set HIP_VISIBLE_DEVICES to only the GPUs you need (e.g. HIP_VISIBLE_DEVICES=0,%d)\n"
" to avoid unnecessary VRAM usage on other devices.\n",
total_visible, ndevice, ndevice - 1);
}
HygonArGroup *grp = nullptr;
std::array<void *, 8> scratch_per_rank{};
std::array<void *, 8> rank_base{};
std::array<void *, 8> stg_base{};
std::array<bool, 8> have_alloc{};
std::array<std::array<void *, 8>, 8> sig_on_viewer{};
// --- Phase 1: P2P check and enable peer access between every pair ---
for (int a = 0; a < ndevice; ++a) {
for (int b = a + 1; b < ndevice; ++b) {
int can_ab = 0, can_ba = 0;
INFINICCL_AR_CUDA_CHECK(cudaDeviceCanAccessPeer(&can_ab, device_ids[a], device_ids[b]));
INFINICCL_AR_CUDA_CHECK(cudaDeviceCanAccessPeer(&can_ba, device_ids[b], device_ids[a]));
if (!can_ab || !can_ba) {
std::fprintf(stderr, "[infiniccl] P2P not supported between device %d and %d, custom allreduce disabled\n",
device_ids[a], device_ids[b]);
return;
}
}
}
for (int a = 0; a < ndevice; ++a) {
INFINICCL_AR_CUDA_CHECK(cudaSetDevice(device_ids[a]));
for (int b = 0; b < ndevice; ++b) {
if (a == b) {
continue;
}
cudaError_t pe = cudaDeviceEnablePeerAccess(device_ids[b], 0);
if (pe != cudaSuccess && pe != cudaErrorPeerAccessAlreadyEnabled) {
std::fprintf(stderr, "[infiniccl] cudaDeviceEnablePeerAccess(%d -> %d) failed: %s\n",
device_ids[a], device_ids[b], cudaGetErrorString(pe));
return;
}
}
}
// --- Phase 2: host-mapped Signal (barrier) + per-rank 2stage scratch + rank_data + staging ---
// Environments such as DTK may compile this with the CUDA front end (no __HIP__); in that case use
// cudaHostAlloc/cudaHostGetDevicePointer instead of hipHost* (which would be undeclared identifiers
// when the hip headers are not included).
void *sig_host_base = nullptr;
const size_t host_sig_bytes = sizeof(infiniccl_ar::Signal) * static_cast<size_t>(ndevice);
#if !(defined(__HIP__) || defined(__HIPCC__))
cudaError_t ce = cudaSuccess;
#endif
#if defined(__HIP__) || defined(__HIPCC__)
hipError_t he = hipHostMalloc(&sig_host_base, host_sig_bytes, hipHostMallocPortable | hipHostMallocMapped);
if (he != hipSuccess || sig_host_base == nullptr) {
std::fprintf(stderr, "[infiniccl] hipHostMalloc(Signal) failed: %s\n", hipGetErrorString(he));
return;
}
#else
ce = cudaHostAlloc(&sig_host_base, host_sig_bytes, cudaHostAllocPortable | cudaHostAllocMapped);
if (ce != cudaSuccess || sig_host_base == nullptr) {
std::fprintf(stderr, "[infiniccl] cudaHostAlloc(Signal) failed: %s\n", cudaGetErrorString(ce));
return;
}
#endif
std::memset(sig_host_base, 0, host_sig_bytes);
for (int vi = 0; vi < ndevice; ++vi) {
INFINICCL_AR_CUDA_CHECK(cudaSetDevice(device_ids[vi]));
for (int j = 0; j < ndevice; ++j) {
void *dp = nullptr;
#if defined(__HIP__) || defined(__HIPCC__)
he = hipHostGetDevicePointer(
&dp, reinterpret_cast<char *>(sig_host_base) + j * sizeof(infiniccl_ar::Signal), 0);
if (he != hipSuccess) {
std::fprintf(stderr, "[infiniccl] hipHostGetDevicePointer failed: %s\n", hipGetErrorString(he));
hipHostFree(sig_host_base);
return;
}
#else
ce = cudaHostGetDevicePointer(
&dp, reinterpret_cast<char *>(sig_host_base) + j * sizeof(infiniccl_ar::Signal), 0);
if (ce != cudaSuccess) {
std::fprintf(stderr, "[infiniccl] cudaHostGetDevicePointer failed: %s\n", cudaGetErrorString(ce));
cudaFreeHost(sig_host_base);
return;
}
#endif
sig_on_viewer[static_cast<size_t>(vi)][static_cast<size_t>(j)] = dp;
}
}
for (int j = 0; j < ndevice; ++j) {
INFINICCL_AR_CUDA_CHECK(cudaSetDevice(device_ids[j]));
void *sc = nullptr, *rd = nullptr, *st = nullptr;
if (hygonMallocStagingShared(&sc, kCustomAllreduceMaxBytes) != cudaSuccess) {
goto fail_alloc;
}
INFINICCL_AR_CUDA_CHECK(cudaMemset(sc, 0, kCustomAllreduceMaxBytes));
if (cudaMalloc(&rd, kHygonRankDataBytes) != cudaSuccess) {
INFINICCL_AR_CUDA_CHECK(cudaFree(sc));
goto fail_alloc;
}
if (hygonMallocStagingShared(&st, kCustomAllreduceMaxBytes) != cudaSuccess) {
INFINICCL_AR_CUDA_CHECK(cudaFree(sc));
INFINICCL_AR_CUDA_CHECK(cudaFree(rd));
goto fail_alloc;
}
scratch_per_rank[j] = sc;
rank_base[j] = rd;
stg_base[j] = st;
have_alloc[j] = true;
}
grp = new HygonArGroup{};
grp->ndevice = ndevice;
grp->cars_remaining_to_destroy.store(ndevice, std::memory_order_relaxed);
grp->sig_host_base = sig_host_base;
for (int j = 0; j < ndevice; ++j) {
grp->device_ids[j] = device_ids[j];
grp->scratch_base[j] = scratch_per_rank[j];
grp->rank_data_base[j] = rank_base[j];
grp->staging_base[j] = stg_base[j];
}
// --- Phase 3: create CustomAllreduce per rank (direct P2P pointers) ---
for (int i = 0; i < ndevice; ++i) {
INFINICCL_AR_CUDA_CHECK(cudaSetDevice(device_ids[i]));
infiniccl_ar::Signal *sig_ptrs[8]{};
void *stg_ptrs[8]{};
void *scratch_ptrs[8]{};
for (int j = 0; j < ndevice; ++j) {
sig_ptrs[j] = reinterpret_cast<infiniccl_ar::Signal *>(sig_on_viewer[static_cast<size_t>(i)][static_cast<size_t>(j)]);
stg_ptrs[j] = stg_base[j];
scratch_ptrs[j] = scratch_per_rank[j];
}
infiniccl_ar::CustomAllreduce *car = nullptr;
try {
car = new infiniccl_ar::CustomAllreduce(
sig_ptrs, scratch_ptrs, rank_base[i], kHygonRankDataBytes, i, ndevice, true);
car->register_buffer(stg_ptrs);
} catch (...) {
for (int k = 0; k < i; ++k) {
if (comms[k]->custom_ar != nullptr) {
INFINICCL_AR_CUDA_CHECK(cudaSetDevice(comms[k]->device_id));
delete static_cast<infiniccl_ar::CustomAllreduce *>(comms[k]->custom_ar);
comms[k]->custom_ar = nullptr;
comms[k]->custom_ar_reg_buf = nullptr;
comms[k]->custom_ar_reg_sz = 0;
comms[k]->hygon_ar_group = nullptr;
comms[k]->hygon_custom_owned = false;
}
}
grp->freeAllDeviceAllocs();
delete grp;
return;
}
comms[i]->custom_ar = car;
comms[i]->custom_ar_reg_buf = stg_base[i];
comms[i]->custom_ar_reg_sz = kCustomAllreduceMaxBytes;
comms[i]->hygon_ar_group = grp;
comms[i]->hygon_custom_owned = true;
}
std::fprintf(stderr,
"[infiniccl] custom allreduce enabled (host-mapped Signal + per-rank scratch + P2P staging, TP 2/4/6/8): "
"%d devices, threshold <= %zu bytes\n",
ndevice, kCustomAllreduceMaxBytes);
return;
fail_alloc:
if (sig_host_base != nullptr) {
#if defined(__HIP__) || defined(__HIPCC__)
hipHostFree(sig_host_base);
#else
cudaFreeHost(sig_host_base);
#endif
sig_host_base = nullptr;
}
for (int j = 0; j < ndevice; ++j) {
if (!have_alloc[j]) {
continue;
}
INFINICCL_AR_CUDA_CHECK(cudaSetDevice(device_ids[j]));
if (scratch_per_rank[j]) {
INFINICCL_AR_CUDA_CHECK(cudaFree(scratch_per_rank[j]));
}
if (rank_base[j]) {
INFINICCL_AR_CUDA_CHECK(cudaFree(rank_base[j]));
}
if (stg_base[j]) {
INFINICCL_AR_CUDA_CHECK(cudaFree(stg_base[j]));
}
}
}
#endif // ENABLE_HYGON_API
namespace infiniccl::cuda {
infiniStatus_t commSetHygonCustomAllreduce(
infinicclComm_t comm, void *custom_allreduce, void *reg_buffer, size_t reg_buffer_bytes) {
#if defined(ENABLE_HYGON_API)
if (comm == nullptr) {
return INFINI_STATUS_NULL_POINTER;
}
if (comm->device_type != INFINI_DEVICE_HYGON) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
if (comm->hygon_custom_owned && comm->hygon_ar_group != nullptr) {
return INFINI_STATUS_BAD_PARAM;
}
comm->custom_ar = custom_allreduce;
comm->custom_ar_reg_buf = reg_buffer;
comm->custom_ar_reg_sz = reg_buffer_bytes;
return INFINI_STATUS_SUCCESS;
#else
(void)comm;
(void)custom_allreduce;
(void)reg_buffer;
(void)reg_buffer_bytes;
return INFINI_STATUS_NOT_IMPLEMENTED;
#endif
}
infiniStatus_t commInitAll(
infiniDevice_t device_type,
infinicclComm_t *comms,
int ndevice,
const int *device_ids) {
......@@ -63,18 +426,57 @@ infiniStatus_t commInitAll(
CHECK_NCCL(ncclCommInitAll(nccl_comms.data(), ndevice, (int const *)device_ids));
for (int i = 0; i < ndevice; i++) {
comms[i] = new InfinicclComm{INFINI_DEVICE_NVIDIA, device_ids[i], (void *)(nccl_comms[i])};
comms[i] = new InfinicclComm{
device_type, device_ids[i], (void *)(nccl_comms[i]), nullptr, nullptr, 0, nullptr, false};
}
#if defined(ENABLE_HYGON_API)
hygonTryInitCommGroupCustomAllreduce(comms, ndevice, device_ids, device_type);
#endif
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t commDestroy(infinicclComm_t comm) {
#if defined(ENABLE_HYGON_API)
if (comm->hygon_custom_owned && comm->custom_ar != nullptr) {
HygonArGroup *g = static_cast<HygonArGroup *>(comm->hygon_ar_group);
// Set device before delete: ~CustomAllreduce calls cudaIpcCloseMemHandle
// which must run in the context of the device that opened the handles.
INFINICCL_AR_CUDA_CHECK(cudaSetDevice(comm->device_id));
delete static_cast<infiniccl_ar::CustomAllreduce *>(comm->custom_ar);
comm->custom_ar = nullptr;
comm->custom_ar_reg_buf = nullptr;
comm->custom_ar_reg_sz = 0;
if (g != nullptr) {
// fetch_sub returns the pre-decrement value; the last destroy sees 1, and the counter drops to 0.
if (g->cars_remaining_to_destroy.fetch_sub(1, std::memory_order_acq_rel) == 1) {
g->freeAllDeviceAllocs();
delete g;
}
comm->hygon_ar_group = nullptr;
}
comm->hygon_custom_owned = false;
}
#endif
CHECK_NCCL(ncclCommDestroy(getNcclComm(comm)));
delete comm;
return INFINI_STATUS_SUCCESS;
}
#if defined(ENABLE_HYGON_API)
namespace {
bool customArTraceEnabled() {
const char *v = std::getenv("INFINICCL_CUSTOM_ALLREDUCE_TRACE");
return v != nullptr && v[0] != '\0' && v[0] != '0';
}
std::atomic<int> g_custom_ar_trace_exec{0};
} // namespace
#endif
infiniStatus_t allReduce(
void *sendbuf,
void *recvbuf,
......@@ -86,9 +488,154 @@ infiniStatus_t allReduce(
CHECK_DTYPE(datatype, INFINI_DTYPE_F32, INFINI_DTYPE_F16, INFINI_DTYPE_BF16);
cudaStream_t cuda_stream = getCudaStream(stream);
#if defined(ENABLE_HYGON_API)
const size_t elem_sz = elemSizeBytes(datatype);
const size_t nbytes = count * elem_sz;
infiniccl_ar::CustomAllreduce *custom =
comm->device_type == INFINI_DEVICE_HYGON && comm->custom_ar
? static_cast<infiniccl_ar::CustomAllreduce *>(comm->custom_ar)
: nullptr;
bool try_custom = custom != nullptr && op == INFINICCL_SUM && nbytes > 0 &&
nbytes <= kCustomAllreduceMaxBytes && count <= static_cast<size_t>(std::numeric_limits<int>::max());
if (hygonCustomAllreduceDisabledByEnv()) {
try_custom = false;
}
bool custom_ar_executed = false;
// Opt-in diagnostic: set INFINICCL_CUSTOM_ALLREDUCE_DEBUG=1 to see which
// path each size bucket takes (printed once per bucket). Useful for
// verifying that the decode path actually hits the custom kernel.
{
static bool debug = []() {
const char *v = std::getenv("INFINICCL_CUSTOM_ALLREDUCE_DEBUG");
return v != nullptr && v[0] != '0' && v[0] != '\0';
}();
if (debug) {
static bool p_null = false, p_big = false, p_ok = false;
if (custom == nullptr && !p_null) {
std::fprintf(stderr, "[infiniccl] custom_ar not available, all allreduce use NCCL\n");
p_null = true;
} else if (custom != nullptr && nbytes > kCustomAllreduceMaxBytes && !p_big) {
std::fprintf(stderr, "[infiniccl] large allreduce nbytes=%zu > %zu, use NCCL\n",
nbytes, kCustomAllreduceMaxBytes);
p_big = true;
} else if (try_custom && !p_ok) {
std::fprintf(stderr, "[infiniccl] small allreduce nbytes=%zu, use custom AR\n", nbytes);
p_ok = true;
}
}
}
if (customArTraceEnabled()) {
static std::atomic<bool> trace_banner{false};
if (!trace_banner.exchange(true, std::memory_order_relaxed)) {
std::fprintf(stderr,
"[infiniccl] INFINICCL_CUSTOM_ALLREDUCE_TRACE is on: will print up to 128 custom AR invocations "
"and up to 48 NCCL fallbacks after try_custom (per process).\n");
}
}
if (try_custom) {
void *input_ptr = sendbuf;
if (comm->custom_ar_reg_buf != nullptr) {
if (nbytes > comm->custom_ar_reg_sz) {
return INFINI_STATUS_BAD_PARAM;
}
INFINICCL_AR_CUDA_CHECK(cudaMemcpyAsync(
comm->custom_ar_reg_buf, sendbuf, nbytes, cudaMemcpyDeviceToDevice, cuda_stream));
input_ptr = comm->custom_ar_reg_buf;
}
const int numel = static_cast<int>(count);
try {
switch (datatype) {
case INFINI_DTYPE_F32: {
constexpr int d = infiniccl_ar::packed_t<float>::P::size;
if (numel % d == 0) {
custom->allreduce<float>(cuda_stream, static_cast<float *>(input_ptr),
static_cast<float *>(recvbuf), numel);
custom_ar_executed = true;
if (customArTraceEnabled()) {
const int k = g_custom_ar_trace_exec.fetch_add(1, std::memory_order_relaxed);
if (k < 128) {
std::fprintf(stderr,
"[infiniccl] custom AR exec #%d dev=%d nbytes=%zu count=%zu dtype=f32 "
"staging=%d\n",
k, comm->device_id, nbytes, count, comm->custom_ar_reg_buf != nullptr ? 1 : 0);
}
}
return INFINI_STATUS_SUCCESS;
}
break;
}
case INFINI_DTYPE_F16: {
constexpr int d = infiniccl_ar::packed_t<half>::P::size;
if (numel % d == 0) {
custom->allreduce<half>(cuda_stream, static_cast<half *>(input_ptr),
static_cast<half *>(recvbuf), numel);
custom_ar_executed = true;
if (customArTraceEnabled()) {
const int k = g_custom_ar_trace_exec.fetch_add(1, std::memory_order_relaxed);
if (k < 128) {
std::fprintf(stderr,
"[infiniccl] custom AR exec #%d dev=%d nbytes=%zu count=%zu dtype=f16 "
"staging=%d\n",
k, comm->device_id, nbytes, count, comm->custom_ar_reg_buf != nullptr ? 1 : 0);
}
}
return INFINI_STATUS_SUCCESS;
}
break;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__) || defined(__HIP__) || defined(__HIPCC__) || defined(ENABLE_HYGON_API))
case INFINI_DTYPE_BF16: {
constexpr int d = infiniccl_ar::packed_t<nv_bfloat16>::P::size;
if (numel % d == 0) {
custom->allreduce<nv_bfloat16>(cuda_stream, static_cast<nv_bfloat16 *>(input_ptr),
static_cast<nv_bfloat16 *>(recvbuf), numel);
custom_ar_executed = true;
if (customArTraceEnabled()) {
const int k = g_custom_ar_trace_exec.fetch_add(1, std::memory_order_relaxed);
if (k < 128) {
std::fprintf(stderr,
"[infiniccl] custom AR exec #%d dev=%d nbytes=%zu count=%zu dtype=bf16 "
"staging=%d\n",
k, comm->device_id, nbytes, count, comm->custom_ar_reg_buf != nullptr ? 1 : 0);
}
}
return INFINI_STATUS_SUCCESS;
}
break;
}
#endif
default:
break;
}
} catch (const std::exception &) {
// Unregistered buffer, unsupported world size, etc.: fall back to NCCL.
}
}
if (customArTraceEnabled() && try_custom && !custom_ar_executed) {
static std::atomic<int> nfallback{0};
const int f = nfallback.fetch_add(1, std::memory_order_relaxed);
if (f < 48) {
std::fprintf(stderr,
"[infiniccl] try_custom set but NCCL path dev=%d nbytes=%zu count=%zu dtype=%d "
"(alignment / unregistered / exception)\n",
comm->device_id, nbytes, count, static_cast<int>(datatype));
}
}
#endif
CHECK_NCCL(ncclAllReduce(sendbuf, recvbuf, count, getNcclDtype(datatype),
getNcclRedOp(op), getNcclComm(comm), getCudaStream(stream)));
getNcclRedOp(op), getNcclComm(comm), cuda_stream));
return INFINI_STATUS_SUCCESS;
}
} // namespace infiniccl::cuda
#if defined(ENABLE_HYGON_API)
namespace infiniccl_ar {
template void CustomAllreduce::allreduce<nv_bfloat16>(cudaStream_t, nv_bfloat16 *, nv_bfloat16 *, int, int, int);
} // namespace infiniccl_ar
#endif
// InfiniCore: adapted from sglang (vLLM-style) custom allreduce.
// Adapted from https://github.com/vllm-project/vllm/blob/v0.8.2/csrc/custom_all_reduce.cuh
#pragma once
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
// When DTK nvcc compiles for gfx targets, the device pass often lacks __HIP__; Hygon builds define ENABLE_HYGON_API.
// This must match the HIP-side Signal / barrier / 2stage (scratch) layout, otherwise it disagrees with the
// Signal layout allocated on the host side → VMFault.
#if defined(__HIP__) || defined(__HIPCC__) || defined(ENABLE_HYGON_API)
#define INFINICCL_AR_USE_HIP_SIGNAL_PATH 1
#endif
// Pull in the hip headers only under a real HIP compiler. With DTK nvcc + ENABLE_HYGON_API the CUDA side
// already ships hip* type shims, so including <hip/hip_runtime_api.h> again would clash with the typedefs in
// cuda_device_runtime_internal.h.
#if defined(__HIP__) || defined(__HIPCC__)
#include <hip/hip_runtime_api.h>
#if __has_include(<hip/hip_ext.h>)
#include <hip/hip_ext.h>
#endif
#endif
#include <array>
#include <iostream>
#include <limits>
#include <map>
#include <unordered_map>
#include <vector>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#define INFINICCL_AR_CUDA_CHECK(cmd) \
do { \
cudaError_t __e = (cmd); \
if (__e != cudaSuccess) { \
std::fprintf(stderr, "infiniccl_ar CUDA error %s at %s:%d\n", cudaGetErrorString(__e), __FILE__, __LINE__); \
std::abort(); \
} \
} while (0)
namespace infiniccl_ar {
constexpr int kMaxBlocks = 36;
#if defined(INFINICCL_AR_USE_HIP_SIGNAL_PATH)
constexpr int kDefaultBlockLimit = 16;
#else
constexpr int kDefaultBlockLimit = 36;
#endif
// Counter may overflow, but it's fine since unsigned int overflow is
// well-defined behavior.
using FlagType = uint32_t;
#if defined(INFINICCL_AR_USE_HIP_SIGNAL_PATH)
// ROCm/HIP (Hygon DCU): vLLM USE_ROCM path — device atomics + this layout.
// The CUDA-only PTX (st.volatile.global / ld.*.sys) path below is not valid for
// cross-device visibility here and can deadlock.
struct Signal {
alignas(128) FlagType start[kMaxBlocks][8];
alignas(128) FlagType end[kMaxBlocks][8];
alignas(128) FlagType _flag[kMaxBlocks];
};
#else
struct Signal {
alignas(128) FlagType self_counter[kMaxBlocks][8];
// Two sets of peer counters are needed for two syncs. The reason is that
// it's possible for peer GPU block to arrive at the second sync point while
// the current GPU block haven't passed the first sync point. Thus, peer GPU
// may write counter+1 while current GPU is busy waiting for counter. We use
// alternating counter array to avoid this possibility.
alignas(128) FlagType peer_counter[2][kMaxBlocks][8];
};
#endif
struct __align__(16) RankData {
// No __restrict__ on members: it breaks implicit copy-assignment used by std::vector<RankData>.
const void* ptrs[8];
};
struct __align__(16) RankSignals {
Signal* signals[8];
#if defined(INFINICCL_AR_USE_HIP_SIGNAL_PATH)
/** Per-rank 2stage scratch (device VRAM); InfiniCore Hygon: separate from host-mapped Signal. */
void* scratch[8];
#endif
};
// like std::array, but aligned
template <typename T, int sz>
struct __align__(alignof(T) * sz) array_t {
T data[sz];
using type = T;
static constexpr int size = sz;
};
// use packed type to maximize memory efficiency
// goal: generate ld.128 and st.128 instructions
template <typename T>
struct packed_t {
// the (P)acked type for load/store
using P = array_t<T, 16 / sizeof(T)>;
// the (A)ccumulator type for reduction
using A = array_t<float, 16 / sizeof(T)>;
};
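// Sketch of what packed_t expands to: every packed access moves 16 bytes, which is why
// CustomAllreduce::allreduce below requires the element count to be a multiple of P::size.
static_assert(packed_t<float>::P::size == 4, "float: 4 x 4B per 16B access");
static_assert(packed_t<half>::P::size == 8, "half: 8 x 2B per 16B access");
static_assert(sizeof(packed_t<half>::P) == 16 && alignof(packed_t<half>::P) == 16,
              "array_t is sized/aligned for 16-byte vector loads and stores");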
#define DINLINE __device__ __forceinline__
// scalar cast functions
DINLINE float upcast_s(half val) {
return __half2float(val);
}
template <typename T>
DINLINE T downcast_s(float val);
template <>
DINLINE half downcast_s(float val) {
return __float2half(val);
}
// scalar add functions
// for some reason when compiling with Pytorch, the + operator for half and
// bfloat is disabled so we call the intrinsics directly
DINLINE half& assign_add(half& a, half b) {
a = __hadd(a, b);
return a;
}
DINLINE float& assign_add(float& a, float b) {
return a += b;
}
// Host: __CUDA_ARCH__ is undefined → include. CUDA sm >= 80: include.
// HIP/DTK: macros such as __HIP__ are absent in some device compilation passes; the Hygon infiniccl-hygon
// build uses ENABLE_HYGON_API to force BF16 device code generation, otherwise the
// cross_device_reduce_*<nv_bfloat16,*> symbols go missing.
// #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__) || defined(__HIP__) || defined(__HIPCC__) || defined(ENABLE_HYGON_API))
DINLINE float upcast_s(nv_bfloat16 val) {
return __bfloat162float(val);
}
template <>
DINLINE nv_bfloat16 downcast_s(float val) {
return __float2bfloat16(val);
}
DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
a = __hadd(a, b);
return a;
}
// #endif
template <typename T, int N>
DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
#pragma unroll
for (int i = 0; i < N; i++) {
assign_add(a.data[i], b.data[i]);
}
return a;
}
template <typename T, int N>
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
if constexpr (std::is_same<T, float>::value) {
return val;
} else {
array_t<float, N> out;
#pragma unroll
for (int i = 0; i < N; i++) {
out.data[i] = upcast_s(val.data[i]);
}
return out;
}
}
template <typename O>
DINLINE O downcast(array_t<float, O::size> val) {
if constexpr (std::is_same<typename O::type, float>::value) {
return val;
} else {
O out;
#pragma unroll
for (int i = 0; i < O::size; i++) {
out.data[i] = downcast_s<typename O::type>(val.data[i]);
}
return out;
}
}
#if !defined(INFINICCL_AR_USE_HIP_SIGNAL_PATH)
static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
#else
asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
#endif
}
static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) {
FlagType flag;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm volatile("ld.acquire.sys.global.u32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
#else
asm volatile("ld.volatile.global.u32 %0, [%1]; membar.gl;" : "=r"(flag) : "l"(flag_addr));
#endif
return flag;
}
static DINLINE void st_flag_volatile(FlagType* flag_addr, FlagType flag) {
asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
}
static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
FlagType flag;
asm volatile("ld.volatile.global.u32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
return flag;
}
// is_start: whether this is the very first synchronization barrier.
// need_fence: whether a memory fence is needed. If true, a release-acquire
// semantic is used to enforce memory access order before and after this
// barrier.
template <int ngpus, bool is_start, bool need_fence = false>
DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg, int rank) {
if constexpr (!is_start) __syncthreads();
static_assert(!(is_start && need_fence)); // Start barrier shouldn't need fence.
if (threadIdx.x < ngpus) {
// Increment the counter. Technically we only need one counter, but we use
// multiple per block to eliminate the need to share the counter via smem.
auto val = self_sg->self_counter[blockIdx.x][threadIdx.x] += 1;
// Write the expected counter value to peer and wait for correct value from
// peer.
auto peer_counter_ptr = &sg.signals[threadIdx.x]->peer_counter[val % 2][blockIdx.x][rank];
auto self_counter_ptr = &self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x];
if constexpr (need_fence) {
st_flag_release(peer_counter_ptr, val);
while (ld_flag_acquire(self_counter_ptr) != val)
;
} else {
st_flag_volatile(peer_counter_ptr, val);
while (ld_flag_volatile(self_counter_ptr) != val)
;
}
}
if constexpr (is_start || need_fence) __syncthreads();
}
#else
// ROCm/HIP barrier — matches vLLM's USE_ROCM path exactly.
// Needs memory with cross-device visibility: vLLM shares hipDeviceMallocUncached buffers via IPC
// handles (cudaIpcOpenMemHandle); the InfiniCore single-process Hygon path uses a host-mapped
// pinned Signal instead (see infiniccl_cuda.cu).
template <int ngpus>
DINLINE void barrier_at_start(const RankSignals& sg, Signal* self_sg,
int rank) {
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
if (threadIdx.x < ngpus) {
__atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag,
__ATOMIC_RELAXED);
while (__atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
__ATOMIC_RELAXED) < flag);
}
__syncthreads();
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
}
template <int ngpus, bool final_sync = false>
DINLINE void barrier_at_end(const RankSignals& sg, Signal* self_sg, int rank) {
__syncthreads();
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
if (threadIdx.x < ngpus) {
__atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], flag,
final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE);
while (__atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE) <
flag);
}
if constexpr (!final_sync) __syncthreads();
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
}
#endif
template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P* ptrs[], int idx) {
A tmp = upcast(ptrs[0][idx]);
#pragma unroll
for (int i = 1; i < ngpus; i++) {
packed_assign_add(tmp, upcast(ptrs[i][idx]));
}
return downcast<P>(tmp);
}
#if !defined(INFINICCL_AR_USE_HIP_SIGNAL_PATH)
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage(
RankData* _dp, RankSignals sg, Signal* self_sg, T* __restrict__ result, int rank, int size) {
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
// note: we don't reorder the address so the accumulation order is the same
// for all ranks, ensuring bitwise identical results
auto dp = *_dp;
multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
// do the actual reduction
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) {
((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
}
multi_gpu_barrier<ngpus, false>(sg, self_sg, rank);
}
#else
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage(
RankData* _dp, RankSignals sg, Signal* self_sg, T* __restrict__ result, int rank, int size) {
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
auto dp = *_dp;
barrier_at_start<ngpus>(sg, self_sg, rank);
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) {
((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
}
barrier_at_end<ngpus, true>(sg, self_sg, rank);
}
#endif
template <typename P>
DINLINE P* get_tmp_buf(Signal* sg) {
return (P*)(((Signal*)sg) + 1);
}
#if !defined(INFINICCL_AR_USE_HIP_SIGNAL_PATH)
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage(
RankData* _dp, RankSignals sg, Signal* self_sg, T* __restrict__ result, int rank, int size) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = gridDim.x * blockDim.x;
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
int part = size / ngpus;
int start = rank * part;
int end = rank == ngpus - 1 ? size : start + part;
int largest_part = part + size % ngpus;
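// Worked example: size = 10 packed elements, ngpus = 4 -> part = 2; ranks 0..2 reduce
// [0,2), [2,4), [4,6) and rank 3 takes the remainder [6,10); largest_part = 4 keeps the
// allgather loop below long enough for the biggest slice.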
const P* ptrs[ngpus];
P* tmps[ngpus];
#pragma unroll
for (int i = 0; i < ngpus; i++) {
int target = (rank + i) % ngpus;
ptrs[i] = (const P*)_dp->ptrs[target];
tmps[i] = get_tmp_buf<P>(sg.signals[target]);
}
auto tmp_out = tmps[0];
multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
// stage 1: reduce scatter
for (int idx = start + tid; idx < end; idx += stride) {
tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
}
multi_gpu_barrier<ngpus, false, true>(sg, self_sg, rank);
// stage 2: allgather. Note: it's important to match the tid between
// the two stages, because visibility across devices is only guaranteed
// between threads that have the same tid. If thread i computes the sum of
// start + i in the first stage, then thread i also gathers start + i from all
// ranks.
for (int idx = tid; idx < largest_part; idx += stride) {
#pragma unroll
for (int i = 0; i < ngpus; i++) {
int gather_from_rank = ((rank + i) % ngpus);
if (gather_from_rank == ngpus - 1 || idx < part) {
int dst_idx = gather_from_rank * part + idx;
((P*)result)[dst_idx] = tmps[i][idx];
}
}
}
}
#else
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage(
RankData* _dp, RankSignals sg, Signal* self_sg, T* __restrict__ result, int rank, int size) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = gridDim.x * blockDim.x;
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
int part = size / ngpus;
int start = rank * part;
int end = rank == ngpus - 1 ? size : start + part;
int largest_part = part + size % ngpus;
const P* ptrs[ngpus];
P* tmps[ngpus];
#pragma unroll
for (int i = 0; i < ngpus; i++) {
int target = (rank + i) % ngpus;
ptrs[i] = (const P*)_dp->ptrs[target];
tmps[i] = reinterpret_cast<P*>(sg.scratch[target]);
}
auto tmp_out = tmps[0];
barrier_at_start<ngpus>(sg, self_sg, rank);
for (int idx = start + tid; idx < end; idx += stride) {
tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
}
barrier_at_end<ngpus>(sg, self_sg, rank);
for (int idx = tid; idx < largest_part; idx += stride) {
#pragma unroll
for (int i = 0; i < ngpus; i++) {
int gather_from_rank = ((rank + i) % ngpus);
if (gather_from_rank == ngpus - 1 || idx < part) {
int dst_idx = gather_from_rank * part + idx;
((P*)result)[dst_idx] = tmps[i][idx];
}
}
}
}
#endif
using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t));
static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t));
class CustomAllreduce {
public:
int rank_;
int world_size_;
bool full_nvlink_;
RankSignals sg_;
// Stores a map from a pointer to its peer pointers from all ranks.
std::unordered_map<void*, RankData*> buffers_;
Signal* self_sg_;
// Stores rank data from all ranks. This is mainly for cuda graph purposes.
// For cuda graph to work, all kernel arguments must be fixed during graph
// capture time. However, the peer pointers are not known during graph capture
// time. Therefore, during capture, we increment the rank data pointer and use
// that as the argument to the kernel. The kernel arguments are stored in
// graph_unreg_buffers_. The actual peer pointers will be filled in at the
// memory pointed to by the pointers in graph_unreg_buffers_ when
// the IPC handles are exchanged between ranks.
//
// The overall process looks like this:
// 1. Graph capture.
// 2. Each rank obtains the IPC handles for each addresses used during cuda
// graph capture using get_graph_buffer_ipc_meta.
// 3. (In Python) all gather the IPC handles.
// 4. Obtain the peer pointers by opening the IPC handles, and store them in
// the rank data array at corresponding positions.
RankData *d_rank_data_base_, *d_rank_data_end_;
std::vector<void*> graph_unreg_buffers_;
// a map from IPC handles to opened IPC pointers
std::map<IPC_KEY, char*> ipc_handles_;
/**
* @param scratch_bufs HIP/InfiniCore Hygon: per-rank device pointers to 2stage scratch (kCustomAllreduceMaxBytes each).
* CUDA: may be nullptr (scratch lives after Signal in device allocation).
* Note: this class does not own any device memory; buffers are passed in.
*/
CustomAllreduce(Signal** signals, void** scratch_bufs, void* rank_data, size_t rank_data_sz, int rank, int world_size,
bool full_nvlink = true)
: rank_(rank),
world_size_(world_size),
full_nvlink_(full_nvlink),
self_sg_(signals[rank]),
d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
for (int i = 0; i < world_size_; i++) {
sg_.signals[i] = signals[i];
}
#if defined(INFINICCL_AR_USE_HIP_SIGNAL_PATH)
for (int i = 0; i < world_size_; i++) {
sg_.scratch[i] = scratch_bufs != nullptr ? scratch_bufs[i] : nullptr;
}
#else
(void)scratch_bufs;
#endif
}
char* open_ipc_handle(const void* ipc_handle) {
auto [it, new_handle] = ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
if (new_handle) {
char* ipc_ptr;
INFINICCL_AR_CUDA_CHECK(cudaIpcOpenMemHandle(
(void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)ipc_handle), cudaIpcMemLazyEnablePeerAccess));
it->second = ipc_ptr;
}
return it->second;
}
std::pair<std::string, std::vector<int64_t>> get_graph_buffer_ipc_meta() {
auto num_buffers = graph_unreg_buffers_.size();
auto handle_sz = sizeof(cudaIpcMemHandle_t);
std::string handles(handle_sz * num_buffers, static_cast<char>(0));
std::vector<int64_t> offsets(num_buffers);
for (int i = 0; i < num_buffers; i++) {
auto ptr = graph_unreg_buffers_[i];
void* base_ptr;
// note: must share the base address of each allocation, or we get wrong
// address
if (cuPointerGetAttribute(&base_ptr, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, (CUdeviceptr)ptr) != CUDA_SUCCESS)
throw std::runtime_error("failed to get pointer attr");
INFINICCL_AR_CUDA_CHECK(cudaIpcGetMemHandle((cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
offsets[i] = ((char*)ptr) - ((char*)base_ptr);
}
return std::make_pair(handles, offsets);
}
void check_rank_data_capacity(size_t num = 1) {
if (d_rank_data_base_ + num > d_rank_data_end_)
throw std::runtime_error(
"Rank data buffer is overflowed by " + std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
}
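// For scale: the Hygon auto-wiring passes an 8 MiB rank_data pool (kHygonRankDataBytes); with
// 8-byte device pointers sizeof(RankData) is 64 bytes, i.e. 131072 slots before this check throws.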
/**
* Register already-shared IPC pointers.
*/
void register_buffer(void** ptrs) {
check_rank_data_capacity();
RankData data;
for (int i = 0; i < world_size_; i++) {
data.ptrs[i] = ptrs[i];
}
auto d_data = d_rank_data_base_++;
INFINICCL_AR_CUDA_CHECK(cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
buffers_[ptrs[rank_]] = d_data;
}
// Note: when registering graph buffers, we intentionally choose to not
// deduplicate the addresses. That means if the allocator reuses some
// addresses, they will be registered again. This is to account for the remote
// possibility of different allocation patterns between ranks. For example,
// rank 1 may get the same input address for the second allreduce, but rank 2
// got a different address. IPC handles have internal reference counting
// mechanism so overhead should be small.
void
register_graph_buffers(const std::vector<std::string>& handles, const std::vector<std::vector<int64_t>>& offsets) {
auto num_buffers = graph_unreg_buffers_.size();
check_rank_data_capacity(num_buffers);
std::vector<RankData> rank_data(num_buffers);
for (int i = 0; i < num_buffers; i++) {
auto self_ptr = graph_unreg_buffers_[i];
auto& rd = rank_data[i];
for (int j = 0; j < world_size_; j++) {
if (j != rank_) {
char* handle = open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
handle += offsets[j][i];
rd.ptrs[j] = handle;
} else {
rd.ptrs[j] = self_ptr;
}
}
}
INFINICCL_AR_CUDA_CHECK(
cudaMemcpy(d_rank_data_base_, rank_data.data(), sizeof(RankData) * num_buffers, cudaMemcpyHostToDevice));
d_rank_data_base_ += num_buffers;
graph_unreg_buffers_.clear();
}
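// Illustrative sketch of the capture/registration flow described above, for a single-process
// launcher that can reach every rank's CustomAllreduce directly (a multi-process launcher would
// all-gather the handle strings instead). The helper below is hypothetical:
//
//   void register_captured_buffers(std::vector<CustomAllreduce*>& cars) {
//     std::vector<std::string> handles;
//     std::vector<std::vector<int64_t>> offsets;
//     for (auto* car : cars) {                        // step 2: per-rank IPC meta
//       auto [h, off] = car->get_graph_buffer_ipc_meta();
//       handles.push_back(std::move(h));
//       offsets.push_back(std::move(off));
//     }
//     for (auto* car : cars) {                        // step 4: open peer handles, fill RankData
//       car->register_graph_buffers(handles, offsets);
//     }
//   }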
/**
* Performs allreduce, assuming input has already been registered.
*
* Block and grid default configs are the result of a careful grid search. Using
* 36 blocks gives the best or close to the best runtime on the devices I
* tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also only
* take a small number of SMs. Not quite sure of the underlying reason, but my
* guess is that too many SMs will cause contention on the NVLink bus.
*/
template <typename T>
void allreduce(
cudaStream_t stream, T* input, T* output, int size, int threads = 512, int block_limit = kDefaultBlockLimit) {
auto d = packed_t<T>::P::size;
if (size % d != 0)
throw std::runtime_error(
"custom allreduce currently requires input length to be multiple "
"of " +
std::to_string(d));
if (block_limit > kMaxBlocks)
throw std::runtime_error(
"max supported block limit is " + std::to_string(kMaxBlocks) + ". Got " + std::to_string(block_limit));
RankData* ptrs;
auto it = buffers_.find(input);
if (it != buffers_.end()) {
// Pre-registered buffer (e.g. staging buffer in InfiniCore single-process P2P).
// RankData on device already has correct peer pointers — safe for both
// eager execution and CUDA graph capture/replay.
ptrs = it->second;
} else {
// Unregistered buffer — only valid during graph capture.
// Peer pointers must be filled later via register_graph_buffers().
cudaStreamCaptureStatus status;
INFINICCL_AR_CUDA_CHECK(cudaStreamIsCapturing(stream, &status));
if (status != cudaStreamCaptureStatusActive)
throw std::runtime_error(
"buffer address " + std::to_string(reinterpret_cast<uint64_t>(input)) + " is not registered!");
ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
graph_unreg_buffers_.push_back(input);
}
size /= d;
auto bytes = size * sizeof(typename packed_t<T>::P);
int blocks = std::min(block_limit, (size + threads - 1) / threads);
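// e.g. a 256 KiB f16 payload: 131072 halves, d = 8 -> size = 16384 packed elements; with
// threads = 512 that asks for 32 blocks, then clamped to block_limit (default kDefaultBlockLimit:
// 16 on the HIP Signal path, so 16 blocks there; 36 otherwise, leaving 32).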
// Optional algorithm override from the environment (read on every call).
const char* env_algo = std::getenv("INFINICCL_CUSTOM_ALLREDUCE_ALGO");
bool force_1stage = false;
bool force_2stage = false;
if (env_algo != nullptr) {
if (std::strcmp(env_algo, "1stage") == 0 || std::strcmp(env_algo, "oneshot") == 0) {
force_1stage = true;
} else if (std::strcmp(env_algo, "2stage") == 0 || std::strcmp(env_algo, "twoshot") == 0) {
force_2stage = true;
} else {
throw std::runtime_error(
"Invalid INFINICCL_CUSTOM_ALLREDUCE_ALGO: " + std::string(env_algo) +
". Valid values: 1stage, oneshot, 2stage, twoshot");
}
}
#define KL(ngpus, name) name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, rank_, size);
// TODO(hanzhi713): Threshold is different for A100 and H100.
// Add per device threshold.
#define REDUCE_CASE(ngpus) \
case ngpus: { \
if (force_1stage) { \
KL(ngpus, cross_device_reduce_1stage); \
} else if (force_2stage) { \
KL(ngpus, cross_device_reduce_2stage); \
} else { \
if (world_size_ == 2) { \
KL(ngpus, cross_device_reduce_1stage); \
} else if (full_nvlink_) { \
if ((world_size_ <= 4 && bytes < 512 * 1024) || (world_size_ <= 8 && bytes < 256 * 1024)) { \
KL(ngpus, cross_device_reduce_1stage); \
} else { \
KL(ngpus, cross_device_reduce_2stage); \
} \
} else { \
KL(ngpus, cross_device_reduce_2stage); \
} \
} \
break; \
}
switch (world_size_) {
REDUCE_CASE(2)
REDUCE_CASE(4)
REDUCE_CASE(6)
REDUCE_CASE(8)
default:
throw std::runtime_error(
"custom allreduce only supports num gpus in (2,4,6,8). Actual num "
"gpus = " +
std::to_string(world_size_));
}
#undef REDUCE_CASE
#undef KL
}
~CustomAllreduce() {
for (auto [_, ptr] : ipc_handles_) {
INFINICCL_AR_CUDA_CHECK(cudaIpcCloseMemHandle(ptr));
}
}
};
/**
* To inspect PTX/SASS, copy-paste this header file into Compiler Explorer and add a
* template instantiation:
* template void infiniccl_ar::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
* half *, int, int, int);
*/
} // namespace infiniccl_ar
......@@ -7,6 +7,11 @@
#include "./metax/infiniccl_metax.h"
#include "./moore/infiniccl_moore.h"
namespace infiniccl::cuda {
infiniStatus_t commSetHygonCustomAllreduce(
infinicclComm_t comm, void *custom_allreduce, void *reg_buffer, size_t reg_buffer_bytes);
}
__INFINI_C infiniStatus_t infinicclCommInitAll(
infiniDevice_t device_type,
infinicclComm_t *comms,
......@@ -15,7 +20,7 @@ __INFINI_C infiniStatus_t infinicclCommInitAll(
#define COMM_INIT_ALL(CASE_, NAMESPACE_) \
case CASE_: \
return infiniccl::NAMESPACE_::commInitAll(comms, ndevice, device_ids)
return infiniccl::NAMESPACE_::commInitAll(device_type, comms, ndevice, device_ids)
switch (device_type) {
COMM_INIT_ALL(INFINI_DEVICE_NVIDIA, cuda);
......@@ -61,6 +66,23 @@ __INFINI_C infiniStatus_t infinicclCommDestroy(infinicclComm_t comm) {
#undef COMM_DESTROY
}
__INFINI_C infiniStatus_t infinicclCommSetHygonCustomAllreduce(
infinicclComm_t comm,
void *custom_allreduce,
void *reg_buffer,
size_t reg_buffer_bytes) {
if (comm == nullptr) {
return INFINI_STATUS_NULL_POINTER;
}
if (comm->device_type != INFINI_DEVICE_HYGON) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
return infiniccl::cuda::commSetHygonCustomAllreduce(comm, custom_allreduce, reg_buffer, reg_buffer_bytes);
}
__INFINI_C infiniStatus_t infinicclAllReduce(
void *sendbuf,
void *recvbuf,
......
......@@ -3,15 +3,29 @@
#include "infiniccl.h"
#include <cstddef>
struct InfinicclComm {
infiniDevice_t device_type;
int device_id; // the actual device ID, not rank number
void *comm; // the actual communicator
/** Optional infiniccl_ar::CustomAllreduce* (Hygon DCU build only); nullptr disables hybrid path. */
void *custom_ar;
/** Optional staging buffer: sendbuf is copied here before custom AR (graph / unregistered send). */
void *custom_ar_reg_buf;
size_t custom_ar_reg_sz;
/**
* Hygon: when commInitAll auto-wires custom allreduce, all ranks share this group for teardown order
* (last destroy frees cudaMalloc bases). Opaque HygonArGroup* in cuda .cu.
*/
void *hygon_ar_group;
bool hygon_custom_owned;
};
#define INFINICCL_DEVICE_API(NAMSPACE, IMPL) \
namespace infiniccl::NAMSPACE { \
infiniStatus_t commInitAll( \
infiniDevice_t device_type, \
infinicclComm_t *comms, \
int ndevice, \
const int *device_ids) IMPL; \
......
......@@ -57,6 +57,7 @@ inline BKCLOp getBkclRedOp(infinicclReduceOp_t op) {
namespace infiniccl::kunlun {
infiniStatus_t commInitAll(
infiniDevice_t device_type,
infinicclComm_t *comms,
int ndevice,
const int *device_ids) {
......@@ -64,7 +65,7 @@ infiniStatus_t commInitAll(
CHECK_BKCL(bkcl_comm_init_all(bkcl_comms.data(), ndevice, device_ids));
for (int i = 0; i < ndevice; i++) {
comms[i] = new InfinicclComm{INFINI_DEVICE_KUNLUN, device_ids[i], (void *)(bkcl_comms[i])};
comms[i] = new InfinicclComm{device_type, device_ids[i], (void *)(bkcl_comms[i]), nullptr, nullptr, 0, nullptr, false};
}
return INFINI_STATUS_SUCCESS;
......
......@@ -61,6 +61,7 @@ inline hcclComm_t getHcclComm(infinicclComm_t comm) {
namespace infiniccl::metax {
infiniStatus_t commInitAll(
infiniDevice_t device_type,
infinicclComm_t *comms,
int ndevice,
const int *device_ids) {
......@@ -69,7 +70,7 @@ infiniStatus_t commInitAll(
CHECK_HCCL(hcclCommInitAll(hccl_comms.data(), ndevice, (int const *)device_ids));
for (int i = 0; i < ndevice; i++) {
comms[i] = new InfinicclComm{INFINI_DEVICE_METAX, device_ids[i], (void *)(hccl_comms[i])};
comms[i] = new InfinicclComm{device_type, device_ids[i], (void *)(hccl_comms[i]), nullptr, nullptr, 0, nullptr, false};
}
return INFINI_STATUS_SUCCESS;
......
......@@ -60,6 +60,7 @@ inline mcclComm_t getMcclComm(infinicclComm_t comm) {
namespace infiniccl::moore {
infiniStatus_t commInitAll(
infiniDevice_t device_type,
infinicclComm_t *comms,
int ndevice,
const int *device_ids) {
......@@ -68,7 +69,7 @@ infiniStatus_t commInitAll(
CHECK_MCCL(mcclCommInitAll(mccl_comms.data(), ndevice, (int const *)device_ids));
for (int i = 0; i < ndevice; i++) {
comms[i] = new InfinicclComm{INFINI_DEVICE_MOORE, device_ids[i], (void *)(mccl_comms[i])};
comms[i] = new InfinicclComm{device_type, device_ids[i], (void *)(mccl_comms[i]), nullptr, nullptr, 0, nullptr, false};
}
return INFINI_STATUS_SUCCESS;
......
// Hygon DCU decode attention backend.
//
// Overrides the ALLDEVICE flashattn registration so that Hygon uses the
// correct HIP stream guard (TorchStreamGuard) and calls mha_fwd_kvcache.
#if defined(ENABLE_FLASH_ATTN) && defined(ENABLE_HYGON_API) && !defined(ENABLE_NVIDIA_API)
#include "infinicore/ops/mha_kvcache.hpp"
#include "infinicore/adaptor/flash_attention_adaptor.hpp"
#include <stdexcept>
namespace infinicore::op::mha_kvcache_impl::hygon_paged {
struct PlannedMeta {
graph::GraphTensor out, q, k_cache, v_cache, seqlens_k, block_table;
std::optional<graph::GraphTensor> alibi_slopes;
float scale;
};
void *plan(Tensor out,
const Tensor &q,
const Tensor &k_cache,
const Tensor &v_cache,
const Tensor &seqlens_k,
const Tensor &block_table,
std::optional<Tensor> alibi_slopes,
float scale) {
return new PlannedMeta{
graph::GraphTensor(out),
graph::GraphTensor(q),
graph::GraphTensor(k_cache),
graph::GraphTensor(v_cache),
graph::GraphTensor(seqlens_k),
graph::GraphTensor(block_table),
alibi_slopes ? std::optional<graph::GraphTensor>(graph::GraphTensor(*alibi_slopes)) : std::nullopt,
scale};
}
void run(void *planned_meta) {
infinicore::adaptor::TorchStreamGuard guard(infinicore::adaptor::get_cuda_stream());
auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);
auto out_tensor = infinicore::adaptor::to_aten_tensor(p->out);
auto q = infinicore::adaptor::to_aten_tensor(p->q);
auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache);
auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache);
auto seqlens_k = std::optional<const at::Tensor>(infinicore::adaptor::to_aten_tensor(p->seqlens_k));
auto block_table = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->block_table));
auto alibi_slopes = p->alibi_slopes
? std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes))
: std::nullopt;
std::optional<const at::Tensor> k_new = std::nullopt;
std::optional<const at::Tensor> v_new = std::nullopt;
std::optional<const at::Tensor> rotary_cos = std::nullopt;
std::optional<const at::Tensor> rotary_sin = std::nullopt;
std::optional<const at::Tensor> cache_batch_idx = std::nullopt;
std::optional<const at::Tensor> leftpad_k = std::nullopt;
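// The condition below appears to target the GQA/MQA single-token decode case: q is 4-D with one
// query token, more query heads than KV-cache heads, head_dim a multiple of 8, and no ALiBi; in
// that case mha_fwd_kvcache allocates its own output, which is copied back afterwards.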
const bool use_dynamic_out = q.dim() == 4 && k_cache.dim() == 4
&& q.size(1) == 1 && q.size(2) > k_cache.size(2)
&& q.size(3) % 8 == 0 && !alibi_slopes.has_value();
auto out = use_dynamic_out ? std::optional<at::Tensor>(std::nullopt)
: std::optional<at::Tensor>(out_tensor);
auto result = flash::mha_fwd_kvcache(
q,
k_cache,
v_cache,
k_new,
v_new,
seqlens_k,
rotary_cos,
rotary_sin,
cache_batch_idx,
leftpad_k,
block_table,
alibi_slopes,
out,
p->scale,
true,
-1,
-1,
0.0f,
false,
0);
if (use_dynamic_out) {
out_tensor.copy_(result[0]);
}
}
void cleanup(void **planned_meta_ptr) {
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
*planned_meta_ptr = nullptr;
}
// Register for Hygon device only, overriding the ALLDEVICE flashattn registration.
static bool registered = []() {
MhaKVCache::plan_dispatcher().registerDevice(Device::Type::HYGON, &plan, true);
MhaKVCache::run_dispatcher().registerDevice(Device::Type::HYGON, &run, true);
MhaKVCache::cleanup_dispatcher().registerDevice(Device::Type::HYGON, &cleanup, true);
return true;
}();
} // namespace infinicore::op::mha_kvcache_impl::hygon_paged
#endif // ENABLE_FLASH_ATTN && ENABLE_HYGON_API && !ENABLE_NVIDIA_API
// Hygon DCU prefill attention backend.
//
// Overrides the ALLDEVICE flashattn registration so that Hygon uses the
// correct HIP stream guard (TorchStreamGuard) and calls mha_varlen_fwd.
#if defined(ENABLE_FLASH_ATTN) && defined(ENABLE_HYGON_API) && !defined(ENABLE_NVIDIA_API)
#include "infinicore/ops/mha_varlen.hpp"
#include "infinicore/adaptor/flash_attention_adaptor.hpp"
#include <stdexcept>
namespace infinicore::op::mha_varlen_impl::hygon_vllm {
struct PlannedMeta {
graph::GraphTensor out, q, k, v, cum_seqlens_q, cum_seqlens_k, block_table;
int max_seqlen_q, max_seqlen_k;
std::optional<graph::GraphTensor> alibi_slopes;
float scale;
};
void *plan(Tensor out,
const Tensor &q,
const Tensor &k,
const Tensor &v,
const Tensor &cum_seqlens_q,
const Tensor &cum_seqlens_k,
const Tensor &block_table,
int max_seqlen_q,
int max_seqlen_k,
std::optional<Tensor> alibi_slopes,
float scale) {
return new PlannedMeta{
graph::GraphTensor(out),
graph::GraphTensor(q),
graph::GraphTensor(k),
graph::GraphTensor(v),
graph::GraphTensor(cum_seqlens_q),
graph::GraphTensor(cum_seqlens_k),
graph::GraphTensor(block_table),
max_seqlen_q,
max_seqlen_k,
alibi_slopes ? std::optional<graph::GraphTensor>(graph::GraphTensor(*alibi_slopes)) : std::nullopt,
scale};
}
void run(void *planned_meta) {
infinicore::adaptor::TorchStreamGuard guard(infinicore::adaptor::get_cuda_stream());
auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);
auto q = infinicore::adaptor::to_aten_tensor(p->q);
auto k = infinicore::adaptor::to_aten_tensor(p->k);
auto v = infinicore::adaptor::to_aten_tensor(p->v);
auto out = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->out));
auto cu_seqlens_q = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_q);
auto cu_seqlens_kv = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_k);
auto block_table = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->block_table));
auto device = q.device();
if (!cu_seqlens_q.is_cuda()) cu_seqlens_q = cu_seqlens_q.to(device);
if (!cu_seqlens_kv.is_cuda()) cu_seqlens_kv = cu_seqlens_kv.to(device);
if (block_table.has_value() && !block_table->is_cuda()) block_table = block_table->to(device);
std::optional<at::Tensor> seqused_k = std::nullopt;
std::optional<const at::Tensor> leftpad_k = std::nullopt;
auto alibi_slopes = p->alibi_slopes
? std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes))
: std::nullopt;
flash::mha_varlen_fwd(
q, k, v, out,
cu_seqlens_q, cu_seqlens_kv,
seqused_k, leftpad_k, block_table, alibi_slopes,
p->max_seqlen_q, p->max_seqlen_k,
0.0f, p->scale, false, true,
-1, -1, 0.0f, false,
std::nullopt);
}
void cleanup(void **planned_meta_ptr) {
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
*planned_meta_ptr = nullptr;
}
// Register for Hygon device only, overriding the ALLDEVICE flashattn registration.
static bool registered = []() {
MultiheadAttentionVarlen::plan_dispatcher().registerDevice(Device::Type::HYGON, &plan, true);
MultiheadAttentionVarlen::run_dispatcher().registerDevice(Device::Type::HYGON, &run, true);
MultiheadAttentionVarlen::cleanup_dispatcher().registerDevice(Device::Type::HYGON, &cleanup, true);
return true;
}();
} // namespace infinicore::op::mha_varlen_impl::hygon_vllm
#endif // ENABLE_FLASH_ATTN && ENABLE_HYGON_API && !ENABLE_NVIDIA_API