#pragma once /* * Copyright (C) Advanced Micro Devices, Inc. All rights reserved. * Copyright (C) 2024-2026, The vLLM team. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "aiter_hip_common.h" #include "hip_float8.h" #include "opus/opus.hpp" #include #include #include #include #include #include #include #include #include namespace aiter { constexpr int kMaxBlocks = 80; // note: we don't want to use atomics for signals because peer atomics are no // supported on PCIe links struct Signal { alignas(128) uint32_t start[kMaxBlocks][8]; alignas(128) uint32_t end[kMaxBlocks][8]; alignas(128) uint32_t _flag[kMaxBlocks]; // incremental flags for each rank }; #ifdef USE_ROCM struct __align__(16) RankData { const void* ptrs[8]; }; #else struct __align__(16) RankData { const void* __restrict__ ptrs[8]; }; #endif struct __align__(16) RankSignals { #ifndef USE_ROCM volatile #endif Signal* signals[8]; }; #define DINLINE __device__ __forceinline__ // scalar cast functions template DINLINE opus::fp32_t upcast_s(inp_dtype val) { return opus::cast(val); } template <> DINLINE opus::fp32_t upcast_s(opus::fp32_t val) { return val; } template DINLINE out_dtype downcast_s(opus::fp32_t val) { return opus::cast(val); } template <> DINLINE opus::fp32_t downcast_s(opus::fp32_t val) { return val; } // scalar add functions // for some reason when compiling with Pytorch, the + operator for half and // bfloat is disabled so we call the intrinsics directly template DINLINE opus::vector_t& packed_assign_add(opus::vector_t& a, opus::vector_t b) { if constexpr(std::is_same::value) { a += b; } else { #pragma unroll for(int i = 0; i < N; i++) { a[i] = downcast_s(upcast_s(a[i]) + upcast_s(b[i])); } } return a; } // not support fp8 pack convert template , bool> = true> DINLINE auto upcast(V val) -> opus::vector_t::size()> { using T = typename opus::vector_traits::dtype; constexpr int N = opus::vector_traits::size(); if constexpr(std::is_same::value) { return val; } else { opus::vector_t out; #pragma unroll for(int i = 0; i < N; i++) { out[i] = upcast_s(val[i]); } return out; } } template , bool> = true> DINLINE O downcast(V val) { using T = typename opus::vector_traits::dtype; constexpr int N = opus::vector_traits::size(); if constexpr(std::is_same::value) { return val; } else { O out; #pragma unroll for(int i = 0; i < N; i++) { out[i] = downcast_s(val[i]); } return out; } } // This function is meant to be used as the first synchronization in the all // reduce kernel. Thus, it doesn't need to make any visibility guarantees for // prior memory accesses. Note: volatile writes will not be reordered against // other volatile writes. template DINLINE void start_sync(const RankSignals& sg, #ifndef USE_ROCM volatile #endif Signal* self_sg, int rank) { #ifdef USE_ROCM uint32_t flag = self_sg->_flag[blockIdx.x] + 1; if(threadIdx.x < ngpus) { // simultaneously write to the corresponding flag of all ranks. // Latency = 1 p2p write __scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM); // wait until we got true from all ranks while(__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE) < flag) ; } __syncthreads(); // use one thread to update flag if(threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag; #else if(threadIdx.x < ngpus) { // reset flag for next time self_sg->end[blockIdx.x][threadIdx.x] = 0; // simultaneously write to the corresponding flag of all ranks. // Latency = 1 p2p write sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1; // wait until we got true from all ranks while(!self_sg->start[blockIdx.x][threadIdx.x]) ; } __syncthreads(); #endif } // This function is meant to be used as the second or the final synchronization // barrier in the all reduce kernel. If it's the final synchronization barrier, // we don't need to make any visibility guarantees for prior memory accesses. template DINLINE void end_sync(const RankSignals& sg, #ifndef USE_ROCM volatile #endif Signal* self_sg, int rank) { #ifdef USE_ROCM __syncthreads(); // eliminate the case that prior writes are not visible after signals become // visible. Note that I did not managed to make this happen through a lot of // testing. Might be the case that hardware provides stronger guarantee than // the memory model. uint32_t flag = self_sg->_flag[blockIdx.x] + 1; if(threadIdx.x < ngpus) { // simultaneously write to the corresponding flag of all ranks. // Latency = 1 p2p write __scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], flag, final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE, __MEMORY_SCOPE_SYSTEM); // wait until we got true from all ranks while(__scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x], final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE, __MEMORY_SCOPE_DEVICE) < flag) ; } __syncthreads(); // use one thread to update flag if(threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag; #else __syncthreads(); // eliminate the case that prior writes are not visible after signals become // visible. Note that I did not managed to make this happen through a lot of // testing. Might be the case that hardware provides stronger guarantee than // the memory model. if constexpr(!final_sync) __threadfence_system(); if(threadIdx.x < ngpus) { // reset flag for next time self_sg->start[blockIdx.x][threadIdx.x] = 0; // simultaneously write to the corresponding flag of all ranks. // Latency = 1 p2p write sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1; // wait until we got true from all ranks while(!self_sg->end[blockIdx.x][threadIdx.x]) ; } if constexpr(!final_sync) __syncthreads(); #endif } template DINLINE P packed_reduce(const P* ptrs[], int idx) { A tmp = upcast(ptrs[0][idx]); #pragma unroll for(int i = 1; i < ngpus; i++) { packed_assign_add::dtype, opus::vector_traits::size()>( tmp, upcast(ptrs[i][idx])); } return downcast

(tmp); } template __global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage_naive(RankData* _input_dp, RankData* _output_dp, RankSignals sg, #ifndef USE_ROCM volatile #endif Signal* self_sg, T* __restrict__ result, int rank, int size) { constexpr int pack_size = 16 / sizeof(T); using P = typename opus::vector_t; using A = typename opus::vector_t; // note: we don't reorder the address so the accumulation order is the same // for all ranks, ensuring bitwise identical results auto dp = *_input_dp; start_sync(sg, self_sg, rank); // do the actual reduction for(int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { ((P*)result)[idx] = packed_reduce((const P**)&dp.ptrs[0], idx); } end_sync(sg, self_sg, rank); } template #ifdef USE_ROCM DINLINE P* get_tmp_buf(Signal* sg) { #else DINLINE P* get_tmp_buf(volatile Signal* sg) { #endif return (P*)(((Signal*)sg) + 1); } template __global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage_naive(RankData* _input_dp, RankData* _output_dp, RankSignals sg, #ifndef USE_ROCM volatile #endif Signal* self_sg, T* __restrict__ result, int rank, int size) { constexpr int pack_size = 16 / sizeof(T); int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = gridDim.x * blockDim.x; using P = typename opus::vector_t; using A = typename opus::vector_t; int part = size / ngpus; int start = rank * part; int end = rank == ngpus - 1 ? size : start + part; int largest_part = part + size % ngpus; const P* ptrs[ngpus]; P* tmps[ngpus]; #pragma unroll for(int i = 0; i < ngpus; i++) { int target = (rank + i) % ngpus; ptrs[i] = (const P*)_input_dp->ptrs[target]; tmps[i] = get_tmp_buf

(sg.signals[target]); } auto tmp_out = tmps[0]; start_sync(sg, self_sg, rank); // stage 1: reduce scatter for(int idx = start + tid; idx < end; idx += stride) { tmp_out[idx - start] = packed_reduce(ptrs, idx); } end_sync(sg, self_sg, rank); // stage 2: allgather. Note: it's important to match the tid between // the two stages, because visibility across devices is only guaranteed // between threads that have the same tid. If thread i computes the sum of // start + i in the first stage, then thread i also gathers start + i from all // ranks. for(int idx = tid; idx < largest_part; idx += stride) { #pragma unroll for(int i = 0; i < ngpus; i++) { int gather_from_rank = ((rank + i) % ngpus); if(gather_from_rank == ngpus - 1 || idx < part) { int dst_idx = gather_from_rank * part + idx; ((P*)result)[dst_idx] = tmps[i][idx]; } } } } #define THREAD_NUM 512 template __global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage(RankData* _input_dp, RankData* _output_dp, RankSignals sg, #ifndef USE_ROCM volatile #endif Signal* self_sg, T* __restrict__ result, int rank, int size) { constexpr int pack_size = 16 / sizeof(T); using P = typename opus::vector_t; using A = typename opus::vector_t; constexpr int tnum_gpu = THREAD_NUM / ngpus; // note: we don't reorder the address so the accumulation order is the same // for all ranks, ensuring bitwise identical results auto dp = *_input_dp; int warp_id = threadIdx.x / tnum_gpu; int lane_id = threadIdx.x % tnum_gpu; // --- double buffer: tmp_smem[0] and tmp_smem[1] --- __shared__ P tmp_smem[2][tnum_gpu * ngpus]; const int step = gridDim.x * tnum_gpu; const int start = blockIdx.x * tnum_gpu + lane_id; start_sync(sg, self_sg, rank); // --- compute uniform iteration count (to keep barriers well-formed) --- const int first = blockIdx.x * tnum_gpu; int iters = 0; { int rem = size - first; iters = rem > 0 ? (rem + step - 1) / step : 0; } // ------------------------------- // fill buffer 0 // ------------------------------- int buf = 0; int idx0 = start; if(idx0 < size) { P val = ((const P**)&dp.ptrs[0])[warp_id][idx0]; tmp_smem[buf][warp_id * tnum_gpu + lane_id] = val; } __syncthreads(); for(int it = 0; it < iters; ++it) { const int cur_idx = idx0 + it * step; const int next_idx = cur_idx + step; const int next_buf = buf ^ 1; // ======================================================= // 1. Warp 0 REDUCES current buffer // ======================================================= if(warp_id == 0 && cur_idx < size) { // GPU 0 contribution P v0 = tmp_smem[buf][0 * tnum_gpu + lane_id]; A acc; #pragma unroll for(int j = 0; j < pack_size; ++j) acc[j] = upcast_s(v0[j]); // GPUs 1..(ngpus-1) #pragma unroll for(int g = 1; g < ngpus; ++g) { P vg = tmp_smem[buf][g * tnum_gpu + lane_id]; #pragma unroll for(int j = 0; j < pack_size; ++j) acc[j] += upcast_s(vg[j]); } // store result P out; #pragma unroll for(int j = 0; j < pack_size; ++j) out[j] = downcast_s(acc[j]); ((P*)result)[cur_idx] = out; } // ======================================================= // 2. ALL warps prefetch NEXT buffer // (including warp 0; safe to issue after reduction) // ======================================================= if(next_idx < size) { P nxt = ((const P**)&dp.ptrs[0])[warp_id][next_idx]; tmp_smem[next_buf][warp_id * tnum_gpu + lane_id] = nxt; } __syncthreads(); buf = next_buf; } } template __global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage(RankData* _input_dp, RankData* _output_dp, RankSignals sg, #ifndef USE_ROCM volatile #endif Signal* self_sg, T* __restrict__ result, int rank, int size) { constexpr int pack_size = 16 / sizeof(T); constexpr int tnum_gpu = THREAD_NUM / ngpus; using P = typename opus::vector_t; using A = typename opus::vector_t; int warp_id = threadIdx.x / tnum_gpu; int lane_id = threadIdx.x % tnum_gpu; int tid = blockIdx.x * tnum_gpu + lane_id; int stride = gridDim.x * tnum_gpu; int part = size / ngpus; int start = rank * part; int end = rank == ngpus - 1 ? size : start + part; int largest_part = part + size % ngpus; __shared__ T tmp_smem[tnum_gpu * ngpus * pack_size]; const P* ptrs[ngpus]; P* tmps[ngpus]; #pragma unroll for(int i = 0; i < ngpus; i++) { int target = (rank + i) % ngpus; ptrs[i] = (const P*)_input_dp->ptrs[target]; tmps[i] = get_tmp_buf

(sg.signals[target]); } auto tmp_out = tmps[0]; start_sync(sg, self_sg, rank); // stage 1: reduce scatter for(int idx = start + tid; idx < end; idx += stride) { *(reinterpret_cast(&tmp_smem[0]) + threadIdx.x) = ptrs[warp_id][idx]; __syncthreads(); // cal add in first 64 threads if(warp_id == 0) { A add_reg; #pragma unroll for(int i = 0; i < pack_size; ++i) { add_reg[i] = upcast_s(tmp_smem[pack_size * threadIdx.x + i]); } constexpr int smem_gpu_loop_stride = tnum_gpu * pack_size; #pragma unroll for(int i = 1; i < ngpus; ++i) { #pragma unroll for(int j = 0; j < pack_size; ++j) { add_reg[j] += upcast_s(tmp_smem[i * smem_gpu_loop_stride + pack_size * threadIdx.x + j]); } } P write_reg; #pragma unroll for(int i = 0; i < pack_size; ++i) { write_reg[i] = downcast_s(add_reg[i]); } tmp_out[idx - start] = write_reg; } __syncthreads(); } end_sync(sg, self_sg, rank); // stage 2: allgather. Note: it's important to match the tid between // the two stages, because visibility across devices is only guaranteed // between threads that have the same tid. If thread i computes the sum of // start + i in the first stage, then thread i also gathers start + i from all // ranks. for(int idx = tid; idx < largest_part; idx += stride) { int dst_idx = (warp_id + rank) % ngpus * part + idx; ((P*)result)[dst_idx] = tmps[warp_id][idx]; } } template __global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage_write_mode(RankData* _input_dp, RankData* _output_dp, RankSignals sg, #ifndef USE_ROCM volatile #endif Signal* self_sg, T* __restrict__ result, int rank, int size) { constexpr int pack_size = 16 / sizeof(T); constexpr int tnum_gpu = THREAD_NUM / ngpus; using P = typename opus::vector_t; using A = typename opus::vector_t; __shared__ T tmp_smem[tnum_gpu * ngpus * pack_size]; __shared__ T res_smem[tnum_gpu * pack_size]; int warp_id = threadIdx.x / tnum_gpu; int lane_id = threadIdx.x % tnum_gpu; int tid = blockIdx.x * tnum_gpu + lane_id; int stride = gridDim.x * tnum_gpu; int part = size / ngpus; P* output_ptrs[ngpus]; P* tmps[ngpus]; #pragma unroll for(int i = 0; i < ngpus; i++) { tmps[i] = get_tmp_buf

(sg.signals[i]); } if(is_broadcast_reg_outptr) { #pragma unroll for(int i = 0; i < ngpus; i++) { output_ptrs[i] = (P*)_output_dp->ptrs[i]; } } const P* input_ptr = (const P*)_input_dp->ptrs[rank]; auto tmp_out = tmps[rank]; int stage3_offset = size; // stage1: write local rank data to remote rank int start = warp_id * part; int end = warp_id == ngpus - 1 ? size : start + part; for(int idx = start + tid; idx < end; idx += stride) { tmps[warp_id][rank * part + idx - start] = input_ptr[idx]; } end_sync(sg, self_sg, rank); // stage 2: reduce scatter & write result to remote rank end = rank != ngpus - 1 ? part : size - part * (ngpus - 1); for(int idx = tid; idx < end; idx += stride) { *(reinterpret_cast(&tmp_smem[0]) + threadIdx.x) = tmp_out[warp_id * part + idx]; __syncthreads(); // cal add in first 64 threads if(warp_id == 0) { A add_reg; #pragma unroll for(int i = 0; i < pack_size; ++i) { add_reg[i] = upcast_s(tmp_smem[pack_size * threadIdx.x + i]); } constexpr int smem_gpu_loop_stride = tnum_gpu * pack_size; #pragma unroll for(int i = 1; i < ngpus; ++i) { #pragma unroll for(int j = 0; j < pack_size; ++j) { add_reg[j] += upcast_s(tmp_smem[i * smem_gpu_loop_stride + pack_size * threadIdx.x + j]); } } P write_reg; #pragma unroll for(int i = 0; i < pack_size; ++i) { write_reg[i] = downcast_s(add_reg[i]); } *(reinterpret_cast(&res_smem[0]) + lane_id) = write_reg; } __syncthreads(); // send data to remote rank if(is_broadcast_reg_outptr) { P temp_val = *(reinterpret_cast(&res_smem[0]) + lane_id); auto src_addr = (reinterpret_cast(&temp_val)); auto dst_addr = (reinterpret_cast(&output_ptrs[warp_id][rank * part + idx])); __builtin_nontemporal_store(*src_addr, dst_addr); __builtin_nontemporal_store(*(src_addr + 1), dst_addr + 1); __builtin_nontemporal_store(*(src_addr + 2), dst_addr + 2); __builtin_nontemporal_store(*(src_addr + 3), dst_addr + 3); } else { tmps[warp_id][rank * part + idx + stage3_offset] = *(reinterpret_cast(&res_smem[0]) + lane_id); } } end_sync(sg, self_sg, rank); if(!is_broadcast_reg_outptr) { // stage 3: get the output from tmp_buffer end = warp_id == ngpus - 1 ? size : start + part; for(int idx = start + tid; idx < end; idx += stride) { ((P*)result)[idx] = tmp_out[idx + stage3_offset]; } } } /* * naive allgather * for case: input(1345,) * */ template __global__ void __launch_bounds__(512, 1) allgather_naive( RankData* _dp, RankSignals sg, Signal* self_sg, T* __restrict__ result, int rank, int size) { constexpr int tnum_gpu = THREAD_NUM / ngpus; int warp_id = threadIdx.x / tnum_gpu; int lane_id = threadIdx.x % tnum_gpu; int tid = blockIdx.x * tnum_gpu + lane_id; int stride = gridDim.x * tnum_gpu; const T* ptrs[ngpus]; #pragma unroll for(int i = 0; i < ngpus; ++i) { ptrs[i] = (const T*)_dp->ptrs[i]; } start_sync(sg, self_sg, rank); for(int idx = tid; idx < size; idx += stride) { int write_idx = warp_id * size + idx; result[write_idx] = ptrs[warp_id][idx]; } } template __global__ void __launch_bounds__(512, 1) allgather_vec( RankData* _dp, RankSignals sg, Signal* self_sg, T* __restrict__ result, int rank, int size) { constexpr int tnum_gpu = THREAD_NUM / ngpus; constexpr int pack_size = 16 / sizeof(T); using P = typename opus::vector_t; int warp_id = threadIdx.x / tnum_gpu; int lane_id = threadIdx.x % tnum_gpu; int tid = blockIdx.x * tnum_gpu + lane_id; int stride = gridDim.x * tnum_gpu; const P* ptrs[ngpus]; #pragma unroll for(int i = 0; i < ngpus; ++i) { ptrs[i] = (const P*)_dp->ptrs[i]; } start_sync(sg, self_sg, rank); for(int idx = tid; idx < size; idx += stride) { int write_idx = warp_id * size + idx; *(reinterpret_cast(&result[0]) + write_idx) = ptrs[warp_id][idx]; } } template __global__ void __launch_bounds__(512, 1) allgather_lastdim(RankData* _dp, RankSignals sg, Signal* self_sg, T* __restrict__ result, int rank, int size, int last_dim_size) { constexpr int tnum_gpu = THREAD_NUM / ngpus; constexpr int pack_size = 16 / sizeof(T); using P = typename opus::vector_t; int warp_id = threadIdx.x / tnum_gpu; int lane_id = threadIdx.x % tnum_gpu; int tid = blockIdx.x * tnum_gpu + lane_id; int stride = gridDim.x * tnum_gpu; last_dim_size /= pack_size; const P* ptrs[ngpus]; #pragma unroll for(int i = 0; i < ngpus; ++i) { ptrs[i] = (const P*)_dp->ptrs[i]; } start_sync(sg, self_sg, rank); for(int idx = tid; idx < size; idx += stride) { int y = idx / last_dim_size; int x = idx % last_dim_size; int write_idx = (ngpus * y + warp_id) * last_dim_size + x; *(reinterpret_cast(&result[0]) + write_idx) = ptrs[warp_id][idx]; } } /* * reduce_scatter, at first dim * range = size / (pack_size * ngpu) * for case: * input:(ngpus * n) -> output:(n) * input:(ngpus * m, n, ...) -> output(m, n, ...) * cond: size % (pack_size * ngpus) == 0 * */ template __global__ void __launch_bounds__(512, 1) reduce_scatter_first_dim( RankData* _dp, RankSignals sg, Signal* self_sg, T* __restrict__ result, int rank, int range) { int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = blockDim.x * gridDim.x; constexpr int pack_size = 16 / sizeof(T); using P = typename opus::vector_t; using A = typename opus::vector_t; const P* ptrs[ngpus]; #pragma unroll for(int i = 0; i < ngpus; i++) { int target = (rank + i) % ngpus; ptrs[i] = (const P*)_dp->ptrs[target]; } start_sync(sg, self_sg, rank); for(int idx = tid; idx < range; idx += stride) { int load_index = rank * range + idx; int store_index = idx; *(reinterpret_cast(result) + store_index) = packed_reduce(ptrs, load_index); } } // fp8 quant all-reduce code start template struct Fp16Filter { static const bool value = false; }; template <> struct Fp16Filter { static const bool value = true; }; template struct Bf16Filter { static const bool value = false; }; template <> struct Bf16Filter { static const bool value = true; }; // dtypes only support half and bf16 now #define FP16_FILTER typename std::enable_if::value, void>::type* = nullptr #define BF16_FILTER typename std::enable_if::value, void>::type* = nullptr template