#pragma once /* * Copyright (C) 2024-2025, The vLLM team. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "aiter_hip_common.h" #include "ck_tile/core.hpp" #include "communication_asm.h" #include "hip_float8.h" #include #include #include #include #include #include #include #include namespace aiter { constexpr int kMaxBlocks = 80; // note: we don't want to use atomics for signals because peer atomics are no // supported on PCIe links struct Signal { alignas(128) uint32_t start[kMaxBlocks][8]; alignas(128) uint32_t end[kMaxBlocks][8]; alignas(128) uint32_t _flag[kMaxBlocks]; // incremental flags for each rank }; #ifdef USE_ROCM struct __align__(16) RankData { const void *ptrs[8]; }; #else struct __align__(16) RankData { const void *__restrict__ ptrs[8]; }; #endif struct __align__(16) RankSignals { #ifndef USE_ROCM volatile #endif Signal *signals[8]; }; // like std::array, but aligned template struct __align__(alignof(T) * sz) array_t { T data[sz]; using type = T; static constexpr int size = sz; }; // use packed type to maximize memory efficiency // goal: generate ld.128 and st.128 instructions template struct packed_t { // the (P)acked type for load/store using P = array_t; // the (A)ccumulator type for reduction using A = array_t; }; #define DINLINE __device__ __forceinline__ // scalar cast functions DINLINE float upcast_s(half val) { return __half2float(val); } template DINLINE T downcast_s(float val); template <> DINLINE half downcast_s(float val) { return __float2half(val); } // scalar add functions // for some reason when compiling with Pytorch, the + operator for half and // bfloat is disabled so we call the intrinsics directly DINLINE half &assign_add(half &a, half b) { a = __hadd(a, b); return a; } DINLINE float &assign_add(float &a, float b) { return a += b; } #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) DINLINE float upcast_s(__hip_bfloat16 val) { return __bfloat162float(val); } template <> DINLINE __hip_bfloat16 downcast_s(float val) { return __float2bfloat16(val); } DINLINE __hip_bfloat16 &assign_add(__hip_bfloat16 &a, __hip_bfloat16 b) { a = __hadd(a, b); return a; } #endif template DINLINE array_t &packed_assign_add(array_t &a, array_t b) { #pragma unroll for (int i = 0; i < N; i++) { assign_add(a.data[i], b.data[i]); } return a; } template DINLINE array_t upcast(array_t val) { if constexpr (std::is_same::value) { return val; } else { array_t out; #pragma unroll for (int i = 0; i < N; i++) { out.data[i] = upcast_s(val.data[i]); } return out; } } template DINLINE O downcast(array_t val) { if constexpr (std::is_same::value) { return val; } // else if constexpr (std::is_same::value) // { // O out; // #pragma unroll // for (int i = 0; i < O::size; i++) // { // union fcvt { // uint32_t i32; // float f32; // } u; // u.f32 = val.data[i]; // out.data[i] = __builtin_bit_cast(__hip_bfloat16, uint16_t(u.i32 >> 16)); // } // return out; // } else { O out; #pragma unroll for (int i = 0; i < O::size; i++) { out.data[i] = downcast_s(val.data[i]); } return out; } } // This function is meant to be used as the first synchronization in the all // reduce kernel. Thus, it doesn't need to make any visibility guarantees for // prior memory accesses. Note: volatile writes will not be reordered against // other volatile writes. template DINLINE void start_sync(const RankSignals &sg, #ifndef USE_ROCM volatile #endif Signal *self_sg, int rank) { #ifdef USE_ROCM uint32_t flag = self_sg->_flag[blockIdx.x] + 1; if (threadIdx.x < ngpus) { // simultaneously write to the corresponding flag of all ranks. // Latency = 1 p2p write __scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM); // wait until we got true from all ranks while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE) < flag) ; } __syncthreads(); // use one thread to update flag if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag; #else if (threadIdx.x < ngpus) { // reset flag for next time self_sg->end[blockIdx.x][threadIdx.x] = 0; // simultaneously write to the corresponding flag of all ranks. // Latency = 1 p2p write sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1; // wait until we got true from all ranks while (!self_sg->start[blockIdx.x][threadIdx.x]) ; } __syncthreads(); #endif } // This function is meant to be used as the second or the final synchronization // barrier in the all reduce kernel. If it's the final synchronization barrier, // we don't need to make any visibility guarantees for prior memory accesses. template DINLINE void end_sync(const RankSignals &sg, #ifndef USE_ROCM volatile #endif Signal *self_sg, int rank) { #ifdef USE_ROCM __syncthreads(); // eliminate the case that prior writes are not visible after signals become // visible. Note that I did not managed to make this happen through a lot of // testing. Might be the case that hardware provides stronger guarantee than // the memory model. uint32_t flag = self_sg->_flag[blockIdx.x] + 1; if (threadIdx.x < ngpus) { // simultaneously write to the corresponding flag of all ranks. // Latency = 1 p2p write __scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], flag, final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE, __MEMORY_SCOPE_SYSTEM); // wait until we got true from all ranks while ( __scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x], final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE, __MEMORY_SCOPE_DEVICE) < flag) ; } __syncthreads(); // use one thread to update flag if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag; #else __syncthreads(); // eliminate the case that prior writes are not visible after signals become // visible. Note that I did not managed to make this happen through a lot of // testing. Might be the case that hardware provides stronger guarantee than // the memory model. if constexpr (!final_sync) __threadfence_system(); if (threadIdx.x < ngpus) { // reset flag for next time self_sg->start[blockIdx.x][threadIdx.x] = 0; // simultaneously write to the corresponding flag of all ranks. // Latency = 1 p2p write sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1; // wait until we got true from all ranks while (!self_sg->end[blockIdx.x][threadIdx.x]) ; } if constexpr (!final_sync) __syncthreads(); #endif } template DINLINE P packed_reduce(const P *ptrs[], int idx) { A tmp = upcast(ptrs[0][idx]); #pragma unroll for (int i = 1; i < ngpus; i++) { packed_assign_add(tmp, upcast(ptrs[i][idx])); } return downcast

(tmp); } template __global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage_naive(RankData *_dp, RankSignals sg, #ifndef USE_ROCM volatile #endif Signal *self_sg, T *__restrict__ result, int rank, int size) { using P = typename packed_t::P; using A = typename packed_t::A; // note: we don't reorder the address so the accumulation order is the same // for all ranks, ensuring bitwise identical results auto dp = *_dp; start_sync(sg, self_sg, rank); // do the actual reduction for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { ((P *)result)[idx] = packed_reduce((const P **)&dp.ptrs[0], idx); } end_sync(sg, self_sg, rank); // // Step-2 consumes data written by peers in step-1, so we need // // visibility guarantees from this barrier. // end_sync(sg, self_sg, rank); } template #ifdef USE_ROCM DINLINE P *get_tmp_buf(Signal *sg) { #else DINLINE P *get_tmp_buf(volatile Signal *sg) { #endif return (P *)(((Signal *)sg) + 1); } template __global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage_naive(RankData *_dp, RankSignals sg, #ifndef USE_ROCM volatile #endif Signal *self_sg, T *__restrict__ result, int rank, int size) { int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = gridDim.x * blockDim.x; using P = typename packed_t::P; using A = typename packed_t::A; int part = size / ngpus; int start = rank * part; int end = rank == ngpus - 1 ? size : start + part; int largest_part = part + size % ngpus; const P *ptrs[ngpus]; P *tmps[ngpus]; #pragma unroll for (int i = 0; i < ngpus; i++) { int target = (rank + i) % ngpus; ptrs[i] = (const P *)_dp->ptrs[target]; tmps[i] = get_tmp_buf

(sg.signals[target]); } auto tmp_out = tmps[0]; start_sync(sg, self_sg, rank); // stage 1: reduce scatter for (int idx = start + tid; idx < end; idx += stride) { tmp_out[idx - start] = packed_reduce(ptrs, idx); } end_sync(sg, self_sg, rank); // stage 2: allgather. Note: it's important to match the tid between // the two stages, because visibility across devices is only guaranteed // between threads that have the same tid. If thread i computes the sum of // start + i in the first stage, then thread i also gathers start + i from all // ranks. for (int idx = tid; idx < largest_part; idx += stride) { #pragma unroll for (int i = 0; i < ngpus; i++) { int gather_from_rank = ((rank + i) % ngpus); if (gather_from_rank == ngpus - 1 || idx < part) { int dst_idx = gather_from_rank * part + idx; ((P *)result)[dst_idx] = tmps[i][idx]; } } } } #define THREAD_NUM 512 // Toggle whether fused allreduce+rmsnorm keeps per-element rms input in float // before the final cast to output dtype. #ifndef AITER_FUSED_AR_RMS_KEEP_RMS_INP_F32 #define AITER_FUSED_AR_RMS_KEEP_RMS_INP_F32 1 #endif template __global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage(RankData *_dp, RankSignals sg, #ifndef USE_ROCM volatile #endif Signal *self_sg, T *__restrict__ result, int rank, int size) { using P = typename packed_t::P; using A = typename packed_t::A; constexpr int pack_size = packed_t::P::size; constexpr int tnum_gpu = THREAD_NUM / ngpus; __shared__ T tmp_smem[tnum_gpu * ngpus * pack_size]; // note: we don't reorder the address so the accumulation order is the same // for all ranks, ensuring bitwise identical results auto dp = *_dp; // load one gpu data each wave int warp_id = threadIdx.x / tnum_gpu; int lane_id = threadIdx.x % tnum_gpu; start_sync(sg, self_sg, rank); // do the actual reduction for (int idx = blockIdx.x * tnum_gpu + lane_id; idx < size; idx += gridDim.x * tnum_gpu) { *(reinterpret_cast(&tmp_smem[0]) + threadIdx.x) = ((const P**)&dp.ptrs[0])[warp_id][idx]; __syncthreads(); if (warp_id == 0) { A add_reg; #pragma unroll for (int i = 0; i < pack_size; ++i) { add_reg.data[i] = ck_tile::type_convert(tmp_smem[threadIdx.x * pack_size + i]); } constexpr int smem_gpu_loop_stride = tnum_gpu * pack_size; #pragma unroll for (int i = 1; i < ngpus; ++i) { #pragma unroll for (int j = 0; j < pack_size; ++j) { add_reg.data[j] += ck_tile::type_convert(tmp_smem[smem_gpu_loop_stride * i + threadIdx.x * pack_size + j]); } } P write_reg; #pragma unroll for (int i = 0; i < pack_size; ++i) { write_reg.data[i] = ck_tile::type_convert(add_reg.data[i]); } ((P *)result)[idx] = write_reg; } __syncthreads(); } // maybe do not need device sync // end_sync(sg, self_sg, rank); } template __global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage(RankData *_dp, RankSignals sg, #ifndef USE_ROCM volatile #endif Signal *self_sg, T *__restrict__ result, int rank, int size) { constexpr int pack_size = packed_t::P::size; constexpr int tnum_gpu = THREAD_NUM / ngpus; using P = typename packed_t::P; using A = typename packed_t::A; __shared__ T tmp_smem[tnum_gpu * ngpus * pack_size]; int warp_id = threadIdx.x / tnum_gpu; int lane_id = threadIdx.x % tnum_gpu; int tid = blockIdx.x * tnum_gpu + lane_id; int stride = gridDim.x * tnum_gpu; int part = size / ngpus; int start = rank * part; int end = rank == ngpus - 1 ? size : start + part; int largest_part = part + size % ngpus; const P *ptrs[ngpus]; P *tmps[ngpus]; #pragma unroll for (int i = 0; i < ngpus; i++) { int target = (rank + i) % ngpus; ptrs[i] = (const P *)_dp->ptrs[target]; tmps[i] = get_tmp_buf

(sg.signals[target]); } auto tmp_out = tmps[0]; start_sync(sg, self_sg, rank); // stage 1: reduce scatter for (int idx = start + tid; idx < end; idx += stride) { *(reinterpret_cast(&tmp_smem[0]) + threadIdx.x) = ptrs[warp_id][idx]; __syncthreads(); // cal add in first 64 threads if (warp_id == 0) { A add_reg; #pragma unroll for (int i = 0; i < pack_size; ++i) { add_reg.data[i] = ck_tile::type_convert(tmp_smem[pack_size * threadIdx.x + i]); } constexpr int smem_gpu_loop_stride = tnum_gpu * pack_size; #pragma unroll for (int i = 1; i < ngpus; ++i) { #pragma unroll for (int j = 0; j < pack_size; ++j) { add_reg.data[j] += ck_tile::type_convert(tmp_smem[i * smem_gpu_loop_stride + pack_size * threadIdx.x + j]); } } P write_reg; #pragma unroll for (int i = 0; i < pack_size; ++i) { write_reg.data[i] = ck_tile::type_convert(add_reg.data[i]); } tmp_out[idx - start] = write_reg; } __syncthreads(); } end_sync(sg, self_sg, rank); // stage 2: allgather. Note: it's important to match the tid between // the two stages, because visibility across devices is only guaranteed // between threads that have the same tid. If thread i computes the sum of // start + i in the first stage, then thread i also gathers start + i from all // ranks. for (int idx = tid; idx < largest_part; idx += stride) { int dst_idx = (warp_id + rank) % ngpus * part + idx; ((P *)result)[dst_idx] = tmps[warp_id][idx]; } } /* * naive allgather * for case: input(1345,) * */ template __global__ void __launch_bounds__(512, 1) allgather_naive( RankData* _dp, RankSignals sg, Signal* self_sg, T* __restrict__ result, int rank, int size ) { constexpr int tnum_gpu = THREAD_NUM / ngpus; int warp_id = threadIdx.x / tnum_gpu; int lane_id = threadIdx.x % tnum_gpu; int tid = blockIdx.x * tnum_gpu + lane_id; int stride = gridDim.x * tnum_gpu; const T* ptrs[ngpus]; #pragma unroll for (int i = 0; i < ngpus; ++i) { ptrs[i] = (const T*)_dp->ptrs[i]; } start_sync(sg, self_sg, rank); for (int idx = tid; idx < size; idx += stride) { int write_idx = warp_id * size + idx; result[write_idx] = ptrs[warp_id][idx]; } } template __global__ void __launch_bounds__(512, 1) allgather_vec( RankData* _dp, RankSignals sg, Signal* self_sg, T* __restrict__ result, int rank, int size ) { constexpr int tnum_gpu = THREAD_NUM / ngpus; using P = typename packed_t::P; int warp_id = threadIdx.x / tnum_gpu; int lane_id = threadIdx.x % tnum_gpu; int tid = blockIdx.x * tnum_gpu + lane_id; int stride = gridDim.x * tnum_gpu; const P* ptrs[ngpus]; #pragma unroll for (int i = 0; i < ngpus; ++i) { ptrs[i] = (const P*)_dp->ptrs[i]; } start_sync(sg, self_sg, rank); for (int idx = tid; idx < size; idx += stride) { int write_idx = warp_id * size + idx; *(reinterpret_cast(&result[0]) + write_idx) = ptrs[warp_id][idx]; } } // fp8 quant all-reduce code start template struct Fp16Filter { static const bool value = false; }; template <> struct Fp16Filter { static const bool value = true; }; template struct Bf16Filter { static const bool value = false; }; template <> struct Bf16Filter<__hip_bfloat16> { static const bool value = true; }; // dtypes only support half and bf16 now #define FP16_FILTER \ typename std::enable_if::value, void>::type* = nullptr #define BF16_FILTER \ typename std::enable_if::value, void>::type* = nullptr template