Commit ad74f612 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev' of http://10.16.6.30/dcutoolkit/deeplearing/vllm into v0.9.2-dev

parents 7270e6d9 46006aee
...@@ -59,6 +59,144 @@ bool _is_weak_contiguous(torch::Tensor& t) { ...@@ -59,6 +59,144 @@ bool _is_weak_contiguous(torch::Tensor& t) {
* Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first * Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
* copied into _reg_buffer. * copied into _reg_buffer.
*/ */
/**
 * Custom all-reduce of `inp` fused with a residual add and RMSNorm; the
 * normalized result is written to `out`.
 *
 * If `_reg_buffer` is null, `inp` itself is handed to the all-reduce.
 * Otherwise `inp` is first copied (asynchronously, on the current stream)
 * into `_reg_buffer`, and that buffer is handed to the all-reduce instead.
 *
 * @param _fa                 Opaque handle to a vllm::CustomAllreduce.
 * @param inp                 Input tensor; numel must be a multiple of
 *                            hidden_size.
 * @param out                 Output tensor, same dtype and numel as inp.
 * @param hidden_size         Number of elements per token (must be > 0).
 * @param residual            Residual tensor, same dtype as inp, at least
 *                            inp.numel() elements.
 * @param rms_weight          RMSNorm weight, same dtype as inp, at least
 *                            hidden_size elements.
 * @param eps                 RMSNorm epsilon (narrowed to float).
 * @param _reg_buffer         Optional staging buffer address (0 for none).
 * @param reg_buffer_sz_bytes Capacity of the staging buffer in bytes.
 * @throws std::runtime_error for dtypes other than float32/float16/bfloat16.
 */
void all_reduce_fuse_norm(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
                          int64_t hidden_size, torch::Tensor& residual,
                          torch::Tensor& rms_weight, double eps,
                          fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();

  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  TORCH_CHECK(_is_weak_contiguous(out));
  TORCH_CHECK(_is_weak_contiguous(inp));
  TORCH_CHECK(_is_weak_contiguous(residual));
  TORCH_CHECK(_is_weak_contiguous(rms_weight));

  // The raw data pointers below are reinterpreted as the dtype of `out`, so
  // every operand must really have that dtype and be large enough for the
  // kernel, which reads one residual element per input element and
  // hidden_size weight elements.
  TORCH_CHECK_EQ(inp.scalar_type(), residual.scalar_type());
  TORCH_CHECK_EQ(inp.scalar_type(), rms_weight.scalar_type());
  TORCH_CHECK(hidden_size > 0, "hidden_size must be positive, got ",
              hidden_size);
  TORCH_CHECK(inp.numel() % hidden_size == 0, "input numel (", inp.numel(),
              ") must be a multiple of hidden_size (", hidden_size, ")");
  TORCH_CHECK(residual.numel() >= inp.numel(),
              "residual must hold at least as many elements as the input");
  TORCH_CHECK(rms_weight.numel() >= hidden_size,
              "rms_weight must hold at least hidden_size elements");

  const int token_num = static_cast<int>(inp.numel() / hidden_size);
  auto input_size = inp.numel() * inp.element_size();
  auto reg_buffer = reinterpret_cast<void*>(_reg_buffer);
  if (reg_buffer) {
    TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes);
    AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.data_ptr(), input_size,
                                  cudaMemcpyDeviceToDevice, stream));
  } else {
    reg_buffer = inp.data_ptr();
  }

  const float eps_f = static_cast<float>(eps);
  switch (out.scalar_type()) {
    case at::ScalarType::Float: {
      fa->allreduce_fuse_norm<float>(
          stream, reinterpret_cast<float*>(reg_buffer),
          reinterpret_cast<float*>(out.data_ptr()), out.numel(), token_num,
          hidden_size, reinterpret_cast<float*>(residual.data_ptr()),
          reinterpret_cast<float*>(rms_weight.data_ptr()), eps_f);
      break;
    }
    case at::ScalarType::Half: {
      fa->allreduce_fuse_norm<half>(
          stream, reinterpret_cast<half*>(reg_buffer),
          reinterpret_cast<half*>(out.data_ptr()), out.numel(), token_num,
          hidden_size, reinterpret_cast<half*>(residual.data_ptr()),
          reinterpret_cast<half*>(rms_weight.data_ptr()), eps_f);
      break;
    }
    // #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
    case at::ScalarType::BFloat16: {
      fa->allreduce_fuse_norm<nv_bfloat16>(
          stream, reinterpret_cast<nv_bfloat16*>(reg_buffer),
          reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel(),
          token_num, hidden_size,
          reinterpret_cast<nv_bfloat16*>(residual.data_ptr()),
          reinterpret_cast<nv_bfloat16*>(rms_weight.data_ptr()), eps_f);
      break;
    }
    // #endif
    default:
      throw std::runtime_error(
          "custom allreduce only supports float32, float16 and bfloat16");
  }
}
/**
 * Typed back end of all_reduce_fuse_norm_quant: stages the input, validates
 * the launch preconditions and forwards to
 * CustomAllreduce::allreduce_fuse_norm_quant.
 *
 * @tparam scalar_in_t  Element type of inp / rms_weight / norm_out / residual.
 * @tparam update_input Forwarded as the `update_input` template argument.
 *
 * All checks that may throw are done before the staging copy is enqueued, so
 * a rejected call leaves no work on the stream.
 *
 * @throws std::runtime_error if rms_weight is not 16-byte aligned or the
 *         communicator is not fully connected.
 */
template<typename scalar_in_t, bool update_input>
void allreduce_fuse_norm_quant_dispath(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
                                       int hidden_size, torch::Tensor& rms_weight, double eps,
                                       torch::Tensor& scales, torch::Tensor& norm_out,
                                       fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes,
                                       std::optional<at::Tensor> residual) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();
  TORCH_CHECK(_is_weak_contiguous(inp));
  TORCH_CHECK(hidden_size > 0, "hidden_size must be positive, got ",
              hidden_size);
  TORCH_CHECK(inp.numel() % hidden_size == 0, "input numel (", inp.numel(),
              ") must be a multiple of hidden_size (", hidden_size, ")");

  const int token_num = static_cast<int>(inp.numel() / hidden_size);
  auto input_size = inp.numel() * inp.element_size();
  auto reg_buffer = reinterpret_cast<void*>(_reg_buffer);
  if (reg_buffer) {
    TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes);
  }

  // The kernels load the weight as 16-byte packed vectors.
  const auto wt_misalign =
      reinterpret_cast<std::uintptr_t>(rms_weight.data_ptr()) % 16;
  if (wt_misalign != 0) {
    throw std::runtime_error(
        "custom allreduce requires rms_weight to be 16-byte aligned, but its "
        "address is off by " +
        std::to_string(wt_misalign) + " bytes");
  }
  if (!fa->fully_connected_) {
    throw std::runtime_error(
        "custom allreduce only supports fully_connected");
  }

  if (reg_buffer) {
    AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.data_ptr(), input_size,
                                  cudaMemcpyDeviceToDevice, stream));
  } else {
    reg_buffer = inp.data_ptr();
  }

  // The presence of a residual is a compile-time switch of the kernel, so the
  // two cases need separate instantiations.
  if (residual.has_value()) {
    VLLM_DISPATCH_QUANT_TYPES(
        out.scalar_type(), "fa->allreduce_fuse_norm_quant", [&] {
          fa->allreduce_fuse_norm_quant<scalar_in_t, scalar_t, true, update_input>(
              stream, reinterpret_cast<scalar_in_t*>(reg_buffer),
              out.data_ptr<scalar_t>(), out.numel(), token_num, hidden_size,
              residual->data_ptr<scalar_in_t>(),
              rms_weight.data_ptr<scalar_in_t>(),
              norm_out.data_ptr<scalar_in_t>(), eps,
              scales.data_ptr<float>());
        });
  } else {
    VLLM_DISPATCH_QUANT_TYPES(
        out.scalar_type(), "fa->allreduce_fuse_norm_quant", [&] {
          fa->allreduce_fuse_norm_quant<scalar_in_t, scalar_t, false, update_input>(
              stream, reinterpret_cast<scalar_in_t*>(reg_buffer),
              out.data_ptr<scalar_t>(), out.numel(), token_num, hidden_size,
              nullptr, rms_weight.data_ptr<scalar_in_t>(),
              norm_out.data_ptr<scalar_in_t>(), eps,
              scales.data_ptr<float>());
        });
  }
}
/**
 * Custom all-reduce fused with (optional) residual add, RMSNorm and
 * per-token quantization.
 *
 * Validates the tensors, then dispatches on the input dtype and on
 * `update_input` to allreduce_fuse_norm_quant_dispath.
 *
 * @param out          Quantized output (FP8 e4m3fn or int8), same numel as inp.
 * @param hidden_size  Elements per token; must be > 0 and divide inp.numel().
 * @param scales       float32 tensor receiving one scale per token.
 * @param norm_out     Receives the un-quantized normalized values when
 *                     update_input is true.
 * @param residual     Optional residual tensor.
 * @param update_input Selects the kernel variant that also writes norm_out.
 *
 * NOTE(review): the fused kernels in this change quantize with
 * float_to_int8_rn and qmax = 127 regardless of the output dtype, so the FP8
 * output path looks unsupported in practice -- confirm before relying on it.
 */
void all_reduce_fuse_norm_quant(fptr_t fa, torch::Tensor& inp, torch::Tensor& out,
                                int64_t hidden_size, torch::Tensor& rms_weight, double eps,
                                torch::Tensor& scales, torch::Tensor& norm_out,
                                fptr_t reg_buffer, int64_t reg_buffer_sz_bytes,
                                std::optional<at::Tensor> residual, bool update_input) {
  constexpr c10::ScalarType kFp8Type = c10::ScalarType::Float8_e4m3fn;
  TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
  TORCH_CHECK(scales.dtype() == torch::kFloat32);
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  TORCH_CHECK(out.is_contiguous() && inp.is_contiguous());

  TORCH_CHECK(hidden_size > 0, "hidden_size must be positive, got ",
              hidden_size);
  TORCH_CHECK(inp.numel() % hidden_size == 0, "input numel (", inp.numel(),
              ") must be a multiple of hidden_size (", hidden_size, ")");
  // The kernels write exactly one scale per token.
  const int64_t token_num = inp.numel() / hidden_size;
  TORCH_CHECK(scales.numel() >= token_num, "scales must hold at least ",
              token_num, " elements, got ", scales.numel());
  if (update_input) {
    // norm_out is written for every input element in this mode.
    TORCH_CHECK(norm_out.numel() >= inp.numel(),
                "norm_out must hold at least as many elements as the input");
  }

  // The typed back end takes hidden_size as int; make the narrowing explicit.
  const int hidden = static_cast<int>(hidden_size);
  VLLM_DISPATCH_FLOATING_TYPES(
      inp.scalar_type(), "allreduce_fuse_norm_quant_dispath", [&] {
        if (update_input) {
          allreduce_fuse_norm_quant_dispath<scalar_t, true>(
              fa, inp, out, hidden, rms_weight, eps, scales, norm_out,
              reg_buffer, reg_buffer_sz_bytes, residual);
        } else {
          allreduce_fuse_norm_quant_dispath<scalar_t, false>(
              fa, inp, out, hidden, rms_weight, eps, scales, norm_out,
              reg_buffer, reg_buffer_sz_bytes, residual);
        }
      });
}
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) { fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa); auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
......
#pragma once #pragma once
#include "type_convert.cuh"
#include "dispatch_utils.h"
#include <algorithm>
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <c10/cuda/CUDAMathCompat.h>
#include <ATen/AccumulateType.h>
#include <THC/THCDeviceUtils.cuh>
#include <cuda.h> #include <cuda.h>
#include <cuda_bf16.h> #include <cuda_bf16.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <hip/hip_bf16.h>
#if defined(USE_ROCM) // #if defined(USE_ROCM)
typedef __hip_bfloat16 nv_bfloat16; typedef __hip_bfloat16 nv_bfloat16;
#endif // #endif
#include <iostream> #include <iostream>
#include <array> #include <array>
...@@ -15,7 +23,11 @@ typedef __hip_bfloat16 nv_bfloat16; ...@@ -15,7 +23,11 @@ typedef __hip_bfloat16 nv_bfloat16;
#include <map> #include <map>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#ifndef USE_ROCM
#include <cub/cub.cuh>
#else
#include <hipcub/hipcub.hpp>
#endif
namespace vllm { namespace vllm {
#define CUDACHECK(cmd) \ #define CUDACHECK(cmd) \
do { \ do { \
...@@ -28,7 +40,7 @@ namespace vllm { ...@@ -28,7 +40,7 @@ namespace vllm {
} while (0) } while (0)
// Maximal number of blocks in allreduce kernel. // Maximal number of blocks in allreduce kernel.
constexpr int kMaxBlocks = 36; constexpr int kMaxBlocks = 128;
// Default number of blocks in allreduce kernel. // Default number of blocks in allreduce kernel.
#ifndef USE_ROCM #ifndef USE_ROCM
...@@ -80,6 +92,7 @@ struct packed_t { ...@@ -80,6 +92,7 @@ struct packed_t {
using P = array_t<T, 16 / sizeof(T)>; using P = array_t<T, 16 / sizeof(T)>;
// the (A)ccumulator type for reduction // the (A)ccumulator type for reduction
using A = array_t<float, 16 / sizeof(T)>; using A = array_t<float, 16 / sizeof(T)>;
using F = array_t<int8_t, 16 / sizeof(T)>;
}; };
#define DINLINE __device__ __forceinline__ #define DINLINE __device__ __forceinline__
...@@ -124,6 +137,117 @@ DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) { ...@@ -124,6 +137,117 @@ DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
return a; return a;
} }
/**********************************************************/
// Element-wise sum of two packed vectors. Each pair of elements is widened
// to float, added, and stored back into the element type of P.
// VEC_SIZE is not consulted; the element count comes from P itself.
template <typename P, uint32_t VEC_SIZE>
DINLINE P vec_add(const P& a, const P& b) {
  P result;
#pragma unroll
  for (int k = 0; k < a.size; ++k) {
    const float lhs = static_cast<float>(a.data[k]);
    const float rhs = static_cast<float>(b.data[k]);
    result.data[k] = lhs + rhs;
  }
  return result;
}
// Shuffle-down sum over `reducesize` consecutive lanes (default 64).
// After the last step the first lane of each group holds the group total;
// the other lanes hold partial sums.
template <typename T, int reducesize=64>
__inline__ __device__ T WarpReduceSum(T val) {
#pragma unroll
  for (int delta = reducesize / 2; delta > 0; delta >>= 1) {
    const T from_upper_lane = WARP_SHFL_DOWN(val, delta);
    val += from_upper_lane;
  }
  return val;
}
// Block-wide sum of `val`.
//
// `shared` must point to shared memory with room for one T per 64-lane
// wavefront of the block (at most 16 wavefronts are combined, i.e. up to
// 1024 threads). The total is only meaningful in thread 0; other threads
// receive partial sums or zero.
//
// Every thread of the block must call this function: it contains a
// __syncthreads() whenever the block has more than one wavefront.
template <typename T>
DINLINE T BlockReduce(T val, T* shared) {
  constexpr int kWaveSize = 64;
  constexpr int kMaxWaves = 16;
  const int lane = threadIdx.x % kWaveSize;
  const int wave = threadIdx.x / kWaveSize;
  const int block_size = blockDim.x;
  // Round up so that a trailing, partially filled wavefront is counted
  // instead of being dropped from the sum.
  const int num_waves = (block_size + kWaveSize - 1) / kWaveSize;

  val = WarpReduceSum<T>(val);
  // NOTE(review): for blocks that are not a multiple of 64 threads the
  // shuffle above reads lanes that do not exist; confirm the hardware returns
  // a harmless value there or keep block sizes a multiple of 64.
  if (num_waves == 1) return val;

  if (lane == 0) {
    shared[wave] = val;
  }
  __syncthreads();

  // Lanes without a partial sum contribute zero. The shuffle is issued by the
  // whole first wavefront rather than under a per-lane predicate, so every
  // lane it reads from is executing the same instruction.
  val = (wave == 0 && lane < num_waves) ? shared[lane] : static_cast<T>(0.f);
  if (wave == 0) {
    val = WarpReduceSum<T, kMaxWaves>(val);
  }
  return val;
}
// RMS-normalizes one packed vector of a token.
//
// The block cooperatively sums the squares of all `residual` elements held by
// its threads, thread 0 turns that into rsqrt(sum / hidden_dim + eps), and
// every thread scales its own elements by that factor and by `gamma`.
// Contains block-wide barriers, so all threads of the block must call it.
// The template parameter A is not used in the body.
template <typename T, typename P, typename A>
DINLINE P fused_add_rms_norm(P const& residual, P const& gamma, int hidden_dim, float eps) {
  static constexpr int VEC_SIZE = 16 / sizeof(T);
  __shared__ float s_val;
  __shared__ float r_sum[16];

  // Per-thread sum of squares, accumulated in float.
  float sq_sum = 0.0f;
#pragma unroll
  for (int k = 0; k < VEC_SIZE; ++k) {
    const float x = static_cast<float>(residual.data[k]);
    sq_sum += x * x;
  }

  sq_sum = BlockReduce(sq_sum, r_sum);
  if (threadIdx.x == 0) {
    s_val = rsqrtf(sq_sum / hidden_dim + eps);
  }
  __syncthreads();
  const float inv_rms = s_val;

  P normed;
#pragma unroll
  for (int k = 0; k < VEC_SIZE; ++k) {
    const float x = static_cast<float>(residual.data[k]);
    const float w = static_cast<float>(gamma.data[k]);
    normed.data[k] = static_cast<T>(x * inv_rms * w);
  }
  return normed;
}
// Converts a float to int8: round to an integer, then saturate to the int8
// range. The ROCm build uses std::nearbyint plus an explicit clamp; the CUDA
// build uses a single saturating PTX convert instruction.
static inline __device__ int8_t float_to_int8_rn(float x) {
#ifdef USE_ROCM
  static constexpr auto i8_min =
      static_cast<float>(std::numeric_limits<int8_t>::min());
  static constexpr auto i8_max =
      static_cast<float>(std::numeric_limits<int8_t>::max());
  float rounded = std::nearbyint(x);
  if (rounded < i8_min) {
    rounded = i8_min;
  } else if (rounded > i8_max) {
    rounded = i8_max;
  }
  return static_cast<int8_t>(rounded);
#else
  // CUDA path
  uint32_t packed;
  asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(packed) : "f"(x));
  return reinterpret_cast<const int8_t&>(packed);
#endif
}
// Shuffle-down maximum over `reducesize` consecutive lanes (default 64),
// combined with fmaxf. The first lane of each group ends up with the group
// maximum.
template <typename T, int reducesize=64>
__inline__ __device__ T WarpReduceMax(T val) {
#pragma unroll
  for (int delta = reducesize / 2; delta > 0; delta >>= 1) {
    const T from_upper_lane = WARP_SHFL_DOWN(val, delta);
    val = fmaxf(val, from_upper_lane);
  }
  return val;
}
// Block-wide maximum of `val`.
//
// `shared` must point to shared memory with room for one T per 64-lane
// wavefront of the block (at most 16 wavefronts are combined). The maximum is
// only meaningful in thread 0.
//
// Every thread of the block must call this function: it contains a
// __syncthreads() whenever the block has more than one wavefront.
template <typename T>
DINLINE T BlockReduceMax_ROW(T val, T* shared) {
  constexpr int kWaveSize = 64;
  constexpr int kMaxWaves = 16;
  const int lane = threadIdx.x % kWaveSize;
  const int wave = threadIdx.x / kWaveSize;
  const int block_size = blockDim.x;
  // Round up so that a trailing, partially filled wavefront is counted
  // instead of being dropped from the maximum.
  const int num_waves = (block_size + kWaveSize - 1) / kWaveSize;

  val = WarpReduceMax<T>(val);
  if (num_waves == 1) return val;

  if (lane == 0) {
    shared[wave] = val;
  }
  __syncthreads();

  if (wave == 0) {
    // Lanes beyond the number of wavefronts re-read slot 0: a duplicate does
    // not change a maximum, and it lets the whole first wavefront take part
    // in the shuffle instead of running it under a per-lane predicate.
    val = shared[lane < num_waves ? lane : 0];
    val = WarpReduceMax<T, kMaxWaves>(val);
  }
  return val;
}
template <typename T, int N> template <typename T, int N>
DINLINE array_t<float, N> upcast(array_t<T, N> val) { DINLINE array_t<float, N> upcast(array_t<T, N> val) {
if constexpr (std::is_same<T, float>::value) { if constexpr (std::is_same<T, float>::value) {
...@@ -132,7 +256,7 @@ DINLINE array_t<float, N> upcast(array_t<T, N> val) { ...@@ -132,7 +256,7 @@ DINLINE array_t<float, N> upcast(array_t<T, N> val) {
array_t<float, N> out; array_t<float, N> out;
#pragma unroll #pragma unroll
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
out.data[i] = upcast_s(val.data[i]); out.data[i] = static_cast<float>(val.data[i]);
} }
return out; return out;
} }
...@@ -146,13 +270,13 @@ DINLINE O downcast(array_t<float, O::size> val) { ...@@ -146,13 +270,13 @@ DINLINE O downcast(array_t<float, O::size> val) {
O out; O out;
#pragma unroll #pragma unroll
for (int i = 0; i < O::size; i++) { for (int i = 0; i < O::size; i++) {
out.data[i] = downcast_s<typename O::type>(val.data[i]); out.data[i] = static_cast<typename O::type>(val.data[i]);
} }
return out; return out;
} }
} }
#if !defined(USE_ROCM) #if 0
static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) { static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
...@@ -243,18 +367,20 @@ DINLINE void barrier_at_end(const RankSignals& sg, Signal* self_sg, int rank) { ...@@ -243,18 +367,20 @@ DINLINE void barrier_at_end(const RankSignals& sg, Signal* self_sg, int rank) {
template <int ngpus> template <int ngpus>
DINLINE void barrier_at_start(const RankSignals& sg, Signal* self_sg, DINLINE void barrier_at_start(const RankSignals& sg, Signal* self_sg,
int rank) { int rank) {
uint32_t flag = self_sg->_flag[blockIdx.x] + 1; uint32_t flag = self_sg->_flag[blockIdx.x] + 1; // this block's flag value, incremented by one
if (threadIdx.x < ngpus) { if (threadIdx.x < ngpus) {
// simultaneously write to the corresponding flag of all ranks. // simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write // Latency = 1 p2p write
// __scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], // __scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank],
// flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM); // flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
// write this rank's flag into the slot of the matching block on every peer GPU
__atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag,
__ATOMIC_RELAXED); __ATOMIC_RELAXED);
// wait until we got true from all ranks // wait until we got true from all ranks
// while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x], // while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
// __ATOMIC_RELAXED, // __ATOMIC_RELAXED,
// __MEMORY_SCOPE_DEVICE) < flag); // __MEMORY_SCOPE_DEVICE) < flag);
// wait until every peer GPU has arrived for the data handled by this blockIdx.x
while (__atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x], while (__atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
__ATOMIC_RELAXED) < flag); __ATOMIC_RELAXED) < flag);
} }
...@@ -274,6 +400,7 @@ DINLINE void barrier_at_end(const RankSignals& sg, Signal* self_sg, int rank) { ...@@ -274,6 +400,7 @@ DINLINE void barrier_at_end(const RankSignals& sg, Signal* self_sg, int rank) {
// flag, // flag,
// final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE, // final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
// __MEMORY_SCOPE_SYSTEM); // __MEMORY_SCOPE_SYSTEM);
// tell the other GPUs that this block has finished its reduce
__atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], flag, __atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], flag,
final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE); final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE);
// wait until we got true from all ranks // wait until we got true from all ranks
...@@ -281,6 +408,7 @@ DINLINE void barrier_at_end(const RankSignals& sg, Signal* self_sg, int rank) { ...@@ -281,6 +408,7 @@ DINLINE void barrier_at_end(const RankSignals& sg, Signal* self_sg, int rank) {
// __scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x], // __scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
// final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE, // final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
// __MEMORY_SCOPE_DEVICE) < flag); // __MEMORY_SCOPE_DEVICE) < flag);
// wait until the other GPUs have finished the data handled by this block
while (__atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x], while (__atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE) < final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE) <
flag); flag);
...@@ -290,6 +418,34 @@ DINLINE void barrier_at_end(const RankSignals& sg, Signal* self_sg, int rank) { ...@@ -290,6 +418,34 @@ DINLINE void barrier_at_end(const RankSignals& sg, Signal* self_sg, int rank) {
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag; if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
} }
// End-of-kernel barrier across GPUs for one thread block.
//
// After a block-wide __syncthreads(), the first `ngpus` threads each publish
// this rank's new flag value (own flag + 1) into the `end` slot of one peer's
// Signal and then spin until the corresponding peer has published at least
// the same value into this rank's Signal. A second __syncthreads() releases
// the rest of the block, and thread 0 stores the new flag value.
//
// final_sync selects the memory ordering of the atomics: relaxed when true,
// release (store) / acquire (load) when false.
template <int ngpus, bool final_sync = false>
DINLINE void barrier_at_end_fuse(const RankSignals& sg, Signal* self_sg, int rank) {
__syncthreads();
// Value every rank has to reach for this block.
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
if (threadIdx.x < ngpus) {
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
// __scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank],
// flag,
// final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
// __MEMORY_SCOPE_SYSTEM);
__atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], flag,
final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE);
// wait until we got true from all ranks
// while (
// __scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
// final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
// __MEMORY_SCOPE_DEVICE) < flag);
while (__atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE) <
flag);
}
__syncthreads();
// use one thread to update flag
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
}
#endif #endif
template <typename P, int ngpus, typename A> template <typename P, int ngpus, typename A>
...@@ -325,6 +481,264 @@ DINLINE P* get_tmp_buf(Signal* sg) { ...@@ -325,6 +481,264 @@ DINLINE P* get_tmp_buf(Signal* sg) {
return (P*)(((Signal*)sg) + 1); return (P*)(((Signal*)sg) + 1);
} }
// Two-stage cross-device all-reduce fused with a residual add and RMSNorm.
//
// Work layout: blockIdx.x selects a token, threadIdx.x selects one 16-byte
// packed vector inside that token, and blocks step over tokens with a stride
// of gridDim.x.
// NOTE(review): there is no bound check on threadIdx.x, so this assumes the
// launcher uses blockDim.x == hidden_dim / VEC_SIZE -- confirm at call sites.
//
// Stage 1: this rank reduces the tokens
//   [begin_tokens[rank], begin_tokens[rank] + token_num_per_ranks[rank])
// across all GPUs and stores the sums into the temp buffer of every rank.
// Stage 2: after the cross-GPU barrier, every token range is read back from
// this rank's temp buffer, the residual is added, the token is RMS-normalized
// with rms_gamma and the result is written to `result`.
//
// residual_in is only read here. The `size` parameter is not used.
template <typename T, int ngpus>
__global__ void __launch_bounds__(1024, 1)
cross_device_reduce_2stage_fuse_norm(RankData* _dp, RankSignals sg, Signal* self_sg,
T* __restrict__ result, int rank, int size,
int hidden_dim, T* residual_in, T* rms_gamma,
float eps, std::array<int, ngpus> begin_tokens,
std::array<int, ngpus> token_num_per_ranks) {
static constexpr int VEC_SIZE = 16 / sizeof(T);
// Number of packed vectors per token.
int H_D_word_num = hidden_dim / VEC_SIZE;
int token_id = blockIdx.x; // local token id
int access_id_in_token = threadIdx.x; // packed-vector index inside the current token
int token_stride = gridDim.x;
//
int access_id = token_id * H_D_word_num + access_id_in_token; // local token id * (vectors per token) + vector index
int access_stride = token_stride * H_D_word_num; // gridDim.x * (vectors per token)
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
const P* ptrs[ngpus];
P* tmps[ngpus];
// Rotate the rank order so that index 0 always refers to this rank.
#pragma unroll
for (int i = 0; i < ngpus; ++i) {
int target = (rank + i) % ngpus;
ptrs[i] = (const P*)_dp->ptrs[target];
tmps[i] = get_tmp_buf<P>(sg.signals[target]);
}
// Packed-vector range [start, part) owned by this rank in stage 1.
int start = begin_tokens[rank] * H_D_word_num;
int part = (begin_tokens[rank] + token_num_per_ranks[rank]) * H_D_word_num;
auto tmp_out = tmps[0]; // this rank's temp buffer (returned by get_tmp_buf for its own Signal)
barrier_at_start<ngpus>(sg, self_sg, rank);
#pragma unroll
for (int idx = access_id + start; idx < part; idx+=access_stride) {
tmp_out[idx] = packed_reduce<P, ngpus, A>(ptrs, idx);
#pragma unroll
for (int r = 0; r < ngpus; ++r)
tmps[r][idx] = tmp_out[idx]; // copy the reduced vector into the same slot of every rank's temp buffer
}
barrier_at_end<ngpus>(sg, self_sg, rank);
// debug --- verify the reduce result
// for (int r = 0; r < ngpus; ++r) {
// int cm_access_id = access_id + begin_tokens[r] * H_D_word_num;
// int cm_token_id = token_id + begin_tokens[r];
// int cm_token_access = (begin_tokens[r] + token_num_per_ranks[r]) * H_D_word_num;
// for (int idx = cm_access_id; idx < cm_token_access; idx += access_stride)
// ((P*)result)[idx] = tmp_out[idx];
// }
P m_residual_val, m_gamm_val;
// The weight vector depends only on the position inside the token.
m_gamm_val = ((P*)rms_gamma)[access_id_in_token];
// Stage 2: walk the token range of every rank.
#pragma unroll
for (int r = 0; r < ngpus; ++r) {
int cm_access_id = access_id + begin_tokens[r] * H_D_word_num;
int cm_token_id = token_id + begin_tokens[r];
int cm_tot_access = (begin_tokens[r] + token_num_per_ranks[r]) * H_D_word_num;
for (int idx = cm_access_id; idx < cm_tot_access; idx += access_stride) {
P sum_val;
sum_val = tmp_out[idx];
m_residual_val =((P*)residual_in)[idx];
sum_val = vec_add<P, VEC_SIZE>(sum_val, m_residual_val);
// Block-cooperative RMSNorm of the token (contains block barriers).
sum_val = fused_add_rms_norm<T, P, A>(sum_val, m_gamm_val, hidden_dim, eps);
((P*)result)[idx] = sum_val;
}
}
}
// One-stage cross-device all-reduce fused with optional residual add,
// RMSNorm and per-token int8 quantization.
//
// Work layout: blockIdx.x selects the token and threadIdx.x one packed vector
// inside it; each thread handles exactly one packed vector (no loop).
// NOTE(review): there is no bound check, so this assumes one block per token
// and blockDim.x == hidden_dim / VEC_SIZE -- confirm at call sites.
//
// Template flags:
//   isResidual   - add residual_in to the reduced sum and write the sum back
//                  into residual_in.
//   update_input - also store the normalized (un-quantized) values to norm_res.
//
// Outputs: quantized vectors in `result`, one scale per token in `scales`.
// NOTE(review): values are always quantized with float_to_int8_rn and stored
// as packed int8, independent of T_out.
template <typename T, typename T_out, int ngpus, bool isResidual=true, bool update_input=false>
__global__ void __launch_bounds__(1024, 1)
cross_device_reduce_1stage_norm_quant(RankData* _dp, RankSignals sg, Signal* self_sg,
T_out* __restrict__ result, int rank, int size,
int hidden_dim, T* residual_in, T* rms_gamma,
float* __restrict__ scales, float eps,
T* __restrict__ norm_res) {
// static constexpr int VEC_SIZE = 16 / sizeof(T);
static constexpr int VEC_SIZE = packed_t<T>::P::size;
// Number of packed vectors per token.
int H_D_word_num = hidden_dim / VEC_SIZE;
int token_id = blockIdx.x;
int access_id_in_token = threadIdx.x;
int token_stride = gridDim.x;
int access_id = token_id * H_D_word_num + access_id_in_token;
// Computed for symmetry with the two-stage kernels; not used below.
int access_stride = token_stride * H_D_word_num;
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
using F = typename packed_t<T>::F;
P m_residual_val, m_gamm_val;
m_gamm_val = reinterpret_cast<P*>(rms_gamma)[access_id_in_token];
auto dp = *_dp;
P sum_val;
// Reduce this thread's vector across all GPUs between the two barriers.
barrier_at_start<ngpus>(sg, self_sg, rank);
sum_val = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], access_id);
barrier_at_end<ngpus, true>(sg, self_sg, rank);
if constexpr(isResidual) {
m_residual_val = reinterpret_cast<P*>(residual_in)[access_id];
sum_val = vec_add<P, VEC_SIZE>(m_residual_val, sum_val);
// The residual buffer is updated in place with the new sum.
((P*)residual_in)[access_id] = sum_val;
}
__shared__ float s_val;
P norm_out;
// Per-thread sum of squares for the RMS statistic.
float acc = 0.f;
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
float v = static_cast<float>(sum_val.data[i]);
acc += v * v;
}
__shared__ float r_sum[16];
acc = BlockReduce<float>(acc, r_sum);
// Thread 0 holds the block total and publishes 1 / rms through shared memory.
if (threadIdx.x == 0)
s_val = rsqrt(acc / hidden_dim + eps);
__syncthreads();
// Normalize and track this thread's largest absolute value.
float block_absmax_val_maybe = 0.f;
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
norm_out.data[i] = static_cast<float>(sum_val.data[i]) * s_val * static_cast<float>(m_gamm_val.data[i]);
block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabs(norm_out.data[i]));
}
// r_sum is reused as scratch space for the max reduction.
block_absmax_val_maybe = BlockReduceMax_ROW(block_absmax_val_maybe,r_sum);
//
__shared__ float s_token_scale;
// Only thread 0 keeps a non-zero `scale`; the others read s_token_scale.
float scale = 0.0f;
if (threadIdx.x == 0) {
scale = block_absmax_val_maybe;
s_token_scale = scale;
}
__syncthreads();
// Multiplier mapping the token's absolute maximum to 127 (0 for an all-zero token).
float inv_s = (s_token_scale == 0.f) ? 0.f : 127.f / s_token_scale;
F out_vec;
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i)
out_vec.data[i] = float_to_int8_rn(norm_out.data[i] * inv_s);
constexpr float qmax = 127.0f;
constexpr float min_scale = 1.19209e-07f;
((F*)result)[access_id] = out_vec;
if constexpr (update_input)
((P*)norm_res)[access_id] = norm_out;
// One scale per token, clamped from below by min_scale.
if (threadIdx.x == 0)
scales[blockIdx.x] = fmaxf(scale/qmax, min_scale);
}
// Two-stage cross-device all-reduce fused with optional residual add,
// RMSNorm and per-token int8 quantization.
//
// Work layout: blockIdx.x selects a token, threadIdx.x one 16-byte packed
// vector inside it, and blocks step over tokens with a stride of gridDim.x.
// NOTE(review): there is no bound check on threadIdx.x, so this assumes
// blockDim.x == hidden_dim / VEC_SIZE -- confirm at call sites.
//
// Stage 1: this rank reduces its own token range (begin_tokens[rank],
// token_num_per_ranks[rank]) and stores the sums into the temp buffer of
// every rank.
// Stage 2: after the cross-GPU barrier every token range is read back from
// this rank's temp buffer, normalized and quantized.
//
// Template flags:
//   isResidual   - add residual_in to the sum and write the sum back into
//                  residual_in.
//   update_input - also store the normalized values to norm_res.
//
// NOTE(review): values are always quantized with float_to_int8_rn and stored
// as packed int8, independent of T_out.
template <typename T, typename T_out, int ngpus, bool isResidual=true, bool update_input=false>
__global__ void __launch_bounds__(1024, 1)
cross_device_reduce_2stage_fuse_norm_quant(RankData* _dp, RankSignals sg, Signal* self_sg,
T_out* __restrict__ result, int rank, int size,
int hidden_dim, T* residual_in, T* rms_gamma,
float* __restrict__ scales, float eps,
T* __restrict__ norm_res,
std::array<int, ngpus> begin_tokens,
std::array<int, ngpus> token_num_per_ranks) {
static constexpr int VEC_SIZE = 16 / sizeof(T);
// Number of packed vectors per token.
int H_D_word_num = hidden_dim / VEC_SIZE;
int token_id = blockIdx.x; // local token id
int access_id_in_token = threadIdx.x; // packed-vector index inside the current token
int token_stride = gridDim.x;
//
int access_id = token_id * H_D_word_num + access_id_in_token; // local token id * (vectors per token) + vector index
int access_stride = token_stride * H_D_word_num; // gridDim.x * (vectors per token)
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
using F = typename packed_t<T>::F;
const P* ptrs[ngpus];
P* tmps[ngpus];
// Rotate the rank order so that index 0 always refers to this rank.
#pragma unroll
for (int i = 0; i < ngpus; ++i) {
int target = (rank + i) % ngpus;
ptrs[i] = (const P*)_dp->ptrs[target];
tmps[i] = get_tmp_buf<P>(sg.signals[target]);
}
// Packed-vector range [start, part) owned by this rank in stage 1.
int start = begin_tokens[rank] * H_D_word_num;
int part = (begin_tokens[rank] + token_num_per_ranks[rank]) * H_D_word_num;
auto tmp_out = tmps[0]; // this rank's temp buffer (returned by get_tmp_buf for its own Signal)
// Not referenced below.
auto input = ptrs[0];
barrier_at_start<ngpus>(sg, self_sg, rank);
#pragma unroll
for (int idx = access_id + start; idx < part; idx+=access_stride) {
tmp_out[idx] = packed_reduce<P, ngpus, A>(ptrs, idx);
#pragma unroll
for (int r = 0; r < ngpus; ++r)
tmps[r][idx] = tmp_out[idx]; // copy the reduced vector into the same slot of every rank's temp buffer
}
barrier_at_end<ngpus>(sg, self_sg, rank);
P m_residual_val, m_gamm_val;
// The weight vector depends only on the position inside the token.
m_gamm_val = reinterpret_cast<P*>(rms_gamma)[access_id_in_token];
// Stage 2: walk the token range of every rank; tidx is the global token index.
#pragma unroll
for (int r = 0; r < ngpus; ++r) {
int cm_access_id = access_id + begin_tokens[r] * H_D_word_num;
int cm_token_id = token_id + begin_tokens[r];
int cm_tot_access = (begin_tokens[r] + token_num_per_ranks[r]) * H_D_word_num;
for (int idx = cm_access_id, tidx = cm_token_id; idx < cm_tot_access;
idx += access_stride, tidx += token_stride) {
P sum_val;
sum_val = tmp_out[idx];
if constexpr (isResidual) {
m_residual_val = reinterpret_cast<P*>(residual_in)[idx];
sum_val = vec_add<P, VEC_SIZE>(sum_val, m_residual_val);
// The residual buffer is updated in place with the new sum.
((P*)residual_in)[idx] = sum_val;
}
__shared__ float s_val;
P norm_out;
// Per-thread sum of squares for the RMS statistic.
float acc = 0.0f;
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
float v = static_cast<float>(sum_val.data[i]);
acc += v * v;
}
__shared__ float r_sum[16];
acc = BlockReduce(acc, r_sum);
// Thread 0 holds the block total and publishes 1 / rms through shared memory.
if (threadIdx.x == 0)
s_val = rsqrtf(acc / hidden_dim + eps);
__syncthreads();
// Normalize and track this thread's largest absolute value.
float block_absmax_val_maybe = 0.f;
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
norm_out.data[i] = static_cast<T>(static_cast<float>(sum_val.data[i]) * s_val * static_cast<float>(m_gamm_val.data[i]));
block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabs(norm_out.data[i]));
}
// r_sum is reused as scratch space for the max reduction.
block_absmax_val_maybe = BlockReduceMax_ROW(block_absmax_val_maybe, r_sum);
__shared__ float s_token_scale;
// Only thread 0 keeps a non-zero `scale`; the others read s_token_scale.
float scale = 0.0f;
if (threadIdx.x == 0) {
scale = block_absmax_val_maybe;
s_token_scale = scale;
}
__syncthreads();
// Multiplier mapping the token's absolute maximum to 127 (0 for an all-zero token).
float inv_s = (s_token_scale == 0.f) ? 0.f : 127.f / s_token_scale;
F out_vec;
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i)
out_vec.data[i] = float_to_int8_rn(norm_out.data[i] * inv_s);
constexpr float qmax = 127.0f;
constexpr float min_scale = 1.19209e-07f;
((F*)result)[idx] = out_vec;
if constexpr (update_input)
((P*)norm_res)[idx] = norm_out;
// One scale per token, clamped from below by min_scale.
if (threadIdx.x == 0)
scales[tidx] = fmaxf(scale/qmax, min_scale);
}
}
}
template <typename T, int ngpus> template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) __global__ void __launch_bounds__(512, 1)
cross_device_reduce_2stage(RankData* _dp, RankSignals sg, Signal* self_sg, cross_device_reduce_2stage(RankData* _dp, RankSignals sg, Signal* self_sg,
...@@ -685,6 +1099,177 @@ class CustomAllreduce { ...@@ -685,6 +1099,177 @@ class CustomAllreduce {
* only take a small amount of SMs. Not quite sure the underlying reason, * only take a small amount of SMs. Not quite sure the underlying reason,
* but my guess is that too many SMs will cause contention on NVLink bus. * but my guess is that too many SMs will cause contention on NVLink bus.
*/ */
/**
 * Launches the two-stage all-reduce fused with residual add + RMSNorm.
 *
 * Tokens are split as evenly as possible across the ranks (the first
 * `token_num % ngpus` ranks get one extra token). One thread block handles
 * one token at a time using `hidden_dim / d` threads, where d is the number
 * of elements in a 16-byte packed vector.
 *
 * @param input      Registered input buffer (or any buffer while the stream
 *                   is being captured; it is then queued for registration).
 * @param output     Receives the normalized result.
 * @param size       Total number of elements (forwarded to the kernel).
 * @param token_num  Number of tokens in the input.
 * @param hidden_dim Elements per token; must be a multiple of d and need at
 *                   most 1024 threads per block.
 * @param threads    Unused; kept for interface compatibility.
 * @throws std::runtime_error on unsupported shapes, an unregistered buffer,
 *         an unsupported world size or a topology that is not fully
 *         connected.
 */
template <typename T>
void allreduce_fuse_norm(cudaStream_t stream, T* input, T* output, int size,
                         int token_num, int hidden_dim, T* residual, T* rms_weight,
                         double eps, int threads = 512, int block_limit = defaultBlockLimit) {
  auto d = packed_t<T>::P::size;
  if (hidden_dim % d != 0)
    throw std::runtime_error(
        "custom allreduce currently requires input length to be multiple "
        "of " +
        std::to_string(d));
  if (block_limit > kMaxBlocks)
    throw std::runtime_error("max supported block limit is " +
                             std::to_string(kMaxBlocks) + ". Got " +
                             std::to_string(block_limit));

  // The kernel is compiled with __launch_bounds__(1024, 1) and uses one
  // thread per packed vector of a token.
  constexpr int kMaxThreadsPerBlock = 1024;
  const int thread_per_token = hidden_dim / d;
  if (thread_per_token < 1 || thread_per_token > kMaxThreadsPerBlock)
    throw std::runtime_error(
        "custom allreduce fused norm needs between 1 and " +
        std::to_string(kMaxThreadsPerBlock) +
        " threads per token. Got " + std::to_string(thread_per_token));

  // Nothing to reduce; a zero-sized grid would be an invalid launch.
  if (token_num <= 0) return;

  RankData* ptrs;
  cudaStreamCaptureStatus status;
  CUDACHECK(cudaStreamIsCapturing(stream, &status));
  if (status == cudaStreamCaptureStatusActive) {
    ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
    graph_unreg_buffers_.push_back(input);
  } else {
    auto it = buffers_.find(input);
    if (it == buffers_.end())
      throw std::runtime_error(
          "buffer address " +
          std::to_string(reinterpret_cast<uint64_t>(input)) +
          " is not registered!");
    ptrs = it->second;
  }

  const float eps_f = static_cast<float>(eps);

// Builds the per-rank token partition and launches `name`. The grid is capped
// at kMaxBlocks; the kernel strides over the remaining tokens.
#define KL(ngpus, name)                                                        \
  std::array<int, ngpus> begin_tokens, token_num_per_ranks;                    \
  const int base_tokens = token_num / ngpus;                                   \
  const int extra_tokens = token_num % ngpus;                                  \
  for (int i = 0; i < ngpus; ++i) {                                            \
    begin_tokens[i] = i * base_tokens + std::min(i, extra_tokens);             \
    token_num_per_ranks[i] = base_tokens + (i < extra_tokens ? 1 : 0);         \
  }                                                                            \
  const int block_num = base_tokens + (extra_tokens ? 1 : 0);                  \
  const int grid_size = std::min(kMaxBlocks, block_num);                       \
  name<T, ngpus><<<grid_size, thread_per_token, 0, stream>>>(                  \
      ptrs, sg_, self_sg_, output, rank_, size, hidden_dim, residual,          \
      rms_weight, eps_f, begin_tokens, token_num_per_ranks);

// Two ranks are always supported; larger groups need a fully connected
// topology. Anything else is rejected instead of silently leaving `output`
// untouched.
#define REDUCE_CASE(ngpus)                                                     \
  case ngpus: {                                                                \
    if (world_size_ == 2 || fully_connected_) {                                \
      KL(ngpus, cross_device_reduce_2stage_fuse_norm);                         \
    } else {                                                                   \
      throw std::runtime_error(                                                \
          "custom allreduce only supports fully_connected");                   \
    }                                                                          \
    break;                                                                     \
  }

  switch (world_size_) {
    REDUCE_CASE(2)
    REDUCE_CASE(4)
    REDUCE_CASE(6)
    REDUCE_CASE(8)
    default:
      throw std::runtime_error(
          "custom allreduce only supports num gpus in (2,4,6,8). Actual "
          "num "
          "gpus = " +
          std::to_string(world_size_));
  }
#undef REDUCE_CASE
#undef KL
}
/**
 * Launches the all-reduce fused with (optional) residual add, RMSNorm and
 * per-token quantization.
 *
 * Kernel selection:
 *  - one-stage kernel: one block per token, so it is only used when
 *    token_num <= kMaxBlocks (the barrier flags are indexed by blockIdx.x)
 *    and either world_size_ == 2 or the message is small;
 *  - two-stage kernel otherwise; its grid is capped at kMaxBlocks and the
 *    kernel strides over the remaining tokens.
 *
 * @param input      Registered input buffer (or any buffer while the stream
 *                   is being captured; it is then queued for registration).
 * @param output     Receives the quantized result.
 * @param size       Total number of input elements.
 * @param token_num  Number of tokens in the input.
 * @param hidden_dim Elements per token; must be a multiple of d (elements per
 *                   16-byte packed vector) and need at most 1024 threads.
 * @param residual   Residual buffer, or nullptr when isResidual is false.
 * @param norm_out   Receives the un-quantized normalized values when
 *                   update_input is true.
 * @param scales     Receives one scale per token.
 * @param threads    Unused; kept for interface compatibility.
 * @throws std::runtime_error on unsupported shapes, an unregistered buffer,
 *         an unsupported world size or a topology that is not fully
 *         connected.
 */
template<typename scalar_in_t, typename scalar_out_t, bool isResidual=true, bool update_input=false>
void allreduce_fuse_norm_quant(cudaStream_t stream, scalar_in_t* input, scalar_out_t* output, int size,
                               int token_num, int hidden_dim, scalar_in_t* residual, scalar_in_t* rms_weight,
                               scalar_in_t* norm_out,
                               double eps, float* scales, int threads = 512, int block_limit = defaultBlockLimit) {
  auto d = packed_t<scalar_in_t>::P::size;
  if (hidden_dim % d != 0)
    throw std::runtime_error(
        "custom allreduce currently requires input length to be multiple "
        "of " +
        std::to_string(d));
  if (block_limit > kMaxBlocks)
    throw std::runtime_error("max supported block limit is " +
                             std::to_string(kMaxBlocks) + ". Got " +
                             std::to_string(block_limit));

  // Both kernels are compiled with __launch_bounds__(1024, 1) and use one
  // thread per packed vector of a token.
  constexpr int kMaxThreadsPerBlock = 1024;
  const int thread_per_token = hidden_dim / d;
  if (thread_per_token < 1 || thread_per_token > kMaxThreadsPerBlock)
    throw std::runtime_error(
        "custom allreduce fused norm needs between 1 and " +
        std::to_string(kMaxThreadsPerBlock) +
        " threads per token. Got " + std::to_string(thread_per_token));

  // Nothing to reduce; a zero-sized grid would be an invalid launch.
  if (token_num <= 0) return;

  RankData* ptrs;
  cudaStreamCaptureStatus status;
  CUDACHECK(cudaStreamIsCapturing(stream, &status));
  if (status == cudaStreamCaptureStatusActive) {
    ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
    graph_unreg_buffers_.push_back(input);
  } else {
    auto it = buffers_.find(input);
    if (it == buffers_.end())
      throw std::runtime_error(
          "buffer address " +
          std::to_string(reinterpret_cast<uint64_t>(input)) +
          " is not registered!");
    ptrs = it->second;
  }

  const float eps_f = static_cast<float>(eps);
  const auto bytes = (size / d) * sizeof(typename packed_t<scalar_in_t>::P);
  const bool small_message = (world_size_ <= 4 && bytes < 1024 * 1024) ||
                             (world_size_ <= 8 && bytes < 512 * 1024);
  // The one-stage kernel has no token loop and needs one block per token, so
  // it cannot be used once token_num exceeds the number of barrier slots.
  const bool use_one_stage =
      token_num <= kMaxBlocks && (world_size_ == 2 || small_message);

// Two-stage launch: builds the per-rank token partition; the grid is capped
// at kMaxBlocks and the kernel strides over the remaining tokens.
#define KL1(ngpus, name)                                                       \
  std::array<int, ngpus> begin_tokens, token_num_per_ranks;                    \
  const int base_tokens = token_num / ngpus;                                   \
  const int extra_tokens = token_num % ngpus;                                  \
  for (int i = 0; i < ngpus; ++i) {                                            \
    begin_tokens[i] = i * base_tokens + std::min(i, extra_tokens);             \
    token_num_per_ranks[i] = base_tokens + (i < extra_tokens ? 1 : 0);         \
  }                                                                            \
  const int block_num = base_tokens + (extra_tokens ? 1 : 0);                  \
  const int grid_size = std::min(kMaxBlocks, block_num);                       \
  name<scalar_in_t, scalar_out_t, ngpus, isResidual, update_input>             \
      <<<grid_size, thread_per_token, 0, stream>>>(                            \
          ptrs, sg_, self_sg_, output, rank_, size, hidden_dim, residual,      \
          rms_weight, scales, eps_f, norm_out, begin_tokens,                   \
          token_num_per_ranks);

// One-stage launch: one block per token.
#define KL(ngpus, name)                                                        \
  name<scalar_in_t, scalar_out_t, ngpus, isResidual, update_input>             \
      <<<token_num, thread_per_token, 0, stream>>>(                            \
          ptrs, sg_, self_sg_, output, rank_, size, hidden_dim, residual,      \
          rms_weight, scales, eps_f, norm_out);

// Two ranks are always supported; larger groups need a fully connected
// topology. Anything else is rejected instead of silently leaving `output`
// untouched.
#define REDUCE_CASE(ngpus)                                                     \
  case ngpus: {                                                                \
    if (world_size_ != 2 && !fully_connected_) {                               \
      throw std::runtime_error(                                                \
          "custom allreduce only supports fully_connected");                   \
    }                                                                          \
    if (use_one_stage) {                                                       \
      KL(ngpus, cross_device_reduce_1stage_norm_quant);                        \
    } else {                                                                   \
      KL1(ngpus, cross_device_reduce_2stage_fuse_norm_quant);                  \
    }                                                                          \
    break;                                                                     \
  }

  switch (world_size_) {
    REDUCE_CASE(2)
    REDUCE_CASE(4)
    REDUCE_CASE(6)
    REDUCE_CASE(8)
    default:
      throw std::runtime_error(
          "custom allreduce only supports num gpus in (2,4,6,8). Actual "
          "num "
          "gpus = " +
          std::to_string(world_size_));
  }
#undef REDUCE_CASE
#undef KL
#undef KL1
}
template <typename T> template <typename T>
void allreduce(cudaStream_t stream, T* input, T* output, int size, void allreduce(cudaStream_t stream, T* input, T* output, int size,
int threads = 512, int block_limit = defaultBlockLimit) { int threads = 512, int block_limit = defaultBlockLimit) {
...@@ -766,4 +1351,4 @@ class CustomAllreduce { ...@@ -766,4 +1351,4 @@ class CustomAllreduce {
* template void vllm::CustomAllreduce::allreduce<half>(cudaStream_t, half *, * template void vllm::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
half *, int, int, int); half *, int, int, int);
*/ */
} // namespace vllm } // namespace vllm
\ No newline at end of file
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <limits> #include <limits>
#include <vector> #include <vector>
#include <random>
#include "cuda_profiler_api.h" #include "cuda_profiler_api.h"
#include "custom_all_reduce.cuh" #include "custom_all_reduce.cuh"
...@@ -117,16 +118,113 @@ __global__ void gen_data(curandState_t* state, T* data, double* ground_truth, ...@@ -117,16 +118,113 @@ __global__ void gen_data(curandState_t* state, T* data, double* ground_truth,
ground_truth[idx] = sum; ground_truth[idx] = sum;
} }
} }
/*************************************************/
// Shuffle-based partial sum: every participating lane repeatedly adds the
// value held `step` lanes above it, with `step` halving from reducesize/2
// down to 1. The caller only relies on the value that ends up in the lowest
// lane of the group.
// NOTE(review): the default of 64 presumably matches the hardware wavefront
// width of the target device - confirm before reusing on 32-lane hardware.
template <typename T, int reducesize = 64>
__inline__ __device__ T WarpReduceSum_NEW(T val) {
  T acc = val;
#pragma unroll
  for (int step = reducesize / 2; step >= 1; step >>= 1) {
    acc += __shfl_down(acc, step);
  }
  return acc;
}
// Block-wide sum built from two rounds of WarpReduceSum_NEW.
//
//   val     per-thread contribution
//   shared  scratch buffer with at least block_size / 64 elements
//
// Round one reduces inside each group of 64 threads; the first lane of every
// group publishes its partial sum to `shared`. After the barrier, the first
// block_size / 64 threads of group 0 reduce those partials. Only the value
// returned to thread 0 is the complete block sum; other threads get
// intermediate values.
template <typename T, int block_size = 512>
__inline__ __device__ T BlockReduceSum_NEW(T val, T* shared) {
  constexpr int kGroups = block_size / 64;
  T partial = WarpReduceSum_NEW<T>(val);
  if constexpr (block_size == 64) {
    // A single group: the first round already produced the answer.
    return partial;
  } else {
    const int group = threadIdx.x / 64;
    const int lane = threadIdx.x % 64;
    const bool is_group_leader = (lane == 0);
    if (is_group_leader && group < kGroups) {
      shared[group] = partial;
    }
    __syncthreads();
    if (group == 0 && lane < kGroups) {
      partial = WarpReduceSum_NEW<T, kGroups>(shared[lane]);
    }
    return partial;
  }
}
// Fused residual-add + RMS norm, one thread block per row (blockIdx.x) and
// one Vec-element chunk per thread.
//
//   input     [rows * cols] in/out: overwritten with
//             (input + residual) * rsqrt(mean((input + residual)^2) + eps) * gamma
//   residual  [rows * cols] in/out: overwritten with input + residual
//   gamma     [cols] per-column scale
//   cols      row length; only the first (cols / Vec) * Vec columns are
//             processed, and threads with threadIdx.x >= cols / Vec stay idle
//   eps       added to the mean of squares before rsqrtf
//
// The chunk type is sized from Vec. The previous version cast the Vec-element
// staging arrays to vllm::packed_t<scalar_t>::P, whose width is fixed by
// scalar_t alone, so any instantiation whose Vec did not match that width
// moved the wrong number of bytes through the staging arrays.
template <typename scalar_t, typename T_ACC, int Vec = 4, int block_size = 512>
__global__ void fused_add_rms_kernel_opt(scalar_t* input, scalar_t* residual,
                                         scalar_t* gamma, int cols, T_ACC eps) {
  constexpr int kChunkBytes = static_cast<int>(sizeof(scalar_t)) * Vec;
  static_assert((kChunkBytes & (kChunkBytes - 1)) == 0,
                "Vec * sizeof(scalar_t) must be a power of two");
  // Cap the alignment at 16 bytes; wider chunks are moved as several accesses.
  constexpr int kChunkAlign = kChunkBytes < 16 ? kChunkBytes : 16;
  struct alignas(kChunkAlign) Chunk {
    scalar_t v[Vec];
  };

  constexpr int share_size = block_size / 64;
  __shared__ T_ACC val_shared[share_size];
  __shared__ T_ACC s_rstd;

  const int row = blockIdx.x;
  const int tid = threadIdx.x;
  const int chunks_per_row = cols / Vec;
  const bool active = tid < chunks_per_row;
  // Widen before multiplying so large rows * cols cannot overflow int.
  const int64_t base =
      (static_cast<int64_t>(row) * chunks_per_row + tid) * Vec;

  Chunk in_chunk;
  Chunk res_chunk;
  T_ACC sq_sum = 0;
  if (active) {
    in_chunk = *reinterpret_cast<const Chunk*>(input + base);
    res_chunk = *reinterpret_cast<const Chunk*>(residual + base);
#pragma unroll
    for (int k = 0; k < Vec; ++k) {
      res_chunk.v[k] += in_chunk.v[k];
      const T_ACC x = static_cast<T_ACC>(res_chunk.v[k]);
      sq_sum += x * x;
    }
  }

  // Every thread must take part in the block reduction, active or not.
  sq_sum = BlockReduceSum_NEW<T_ACC, block_size>(sq_sum, val_shared);
  // Only thread 0 holds the full sum; broadcast the scale via shared memory.
  if (tid == 0) s_rstd = rsqrtf(sq_sum / cols + eps);
  __syncthreads();
  const T_ACC rstd = s_rstd;

  if (active) {
#pragma unroll
    for (int k = 0; k < Vec; ++k) {
      const int col = tid * Vec + k;
      in_chunk.v[k] = static_cast<T_ACC>(res_chunk.v[k]) * rstd *
                      static_cast<T_ACC>(gamma[col]);
    }
    *reinterpret_cast<Chunk*>(residual + base) = res_chunk;
    *reinterpret_cast<Chunk*>(input + base) = in_chunk;
  }
}
// Launches fused_add_rms_kernel_opt on `stream` with one block per token,
// picking the <Vec, block_size> instantiation from hidden_size (and, for the
// 2049..4096 range, from num_tokens):
//
//   hidden_size <= 1024                      -> Vec 8,  128 threads
//   hidden_size <= 2048                      -> Vec 8,  256 threads
//   hidden_size <= 4096, num_tokens > 1200   -> Vec 8,  512 threads
//   hidden_size <= 4096, otherwise           -> Vec 4,  1024 threads
//   hidden_size <= 8192                      -> Vec 8,  1024 threads
//   anything larger                          -> Vec 16, 1024 threads
//
// self_data and other_data are the kernel's `input` and `residual` buffers
// (both updated in place); weight_data is its `gamma`. `eps` is narrowed to
// float when it is passed to the kernel.
template <typename scalar_t>
void fused_add_rms_norm_choose(cudaStream_t stream, scalar_t* self_data, scalar_t* other_data,
                               scalar_t* weight_data, double eps, int hidden_size, int num_tokens) {
  if (hidden_size <= 1024) {
    fused_add_rms_kernel_opt<scalar_t, float, 8, 128>
        <<<num_tokens, 128, 0, stream>>>(self_data, other_data, weight_data, hidden_size, eps);
    return;
  }
  if (hidden_size <= 2048) {
    fused_add_rms_kernel_opt<scalar_t, float, 8, 256>
        <<<num_tokens, 256, 0, stream>>>(self_data, other_data, weight_data, hidden_size, eps);
    return;
  }
  if (hidden_size <= 4096) {
    const bool many_tokens = num_tokens > 1200;
    if (many_tokens) {
      fused_add_rms_kernel_opt<scalar_t, float, 8, 512>
          <<<num_tokens, 512, 0, stream>>>(self_data, other_data, weight_data, hidden_size, eps);
    } else {
      fused_add_rms_kernel_opt<scalar_t, float, 4, 1024>
          <<<num_tokens, 1024, 0, stream>>>(self_data, other_data, weight_data, hidden_size, eps);
    }
    return;
  }
  if (hidden_size <= 8192) {
    fused_add_rms_kernel_opt<scalar_t, float, 8, 1024>
        <<<num_tokens, 1024, 0, stream>>>(self_data, other_data, weight_data, hidden_size, eps);
    return;
  }
  fused_add_rms_kernel_opt<scalar_t, float, 16, 1024>
      <<<num_tokens, 1024, 0, stream>>>(self_data, other_data, weight_data, hidden_size, eps);
}
/*****************************************************************/
template <typename T> template <typename T>
void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
int data_size, bool performance_test) { int data_size, bool performance_test, int hidden_dim) {
T* result; T* result_ori, *result_fuse;
cudaStream_t stream; cudaStream_t stream;
CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
CUDACHECK(cudaMalloc(&result, data_size * sizeof(T))); CUDACHECK(cudaMalloc(&result_ori, data_size * sizeof(T)));
CUDACHECK(cudaMemset(result, 0, data_size * sizeof(T))); CUDACHECK(cudaMemset(result_ori, 0, data_size * sizeof(T)));
CUDACHECK(cudaMalloc(&result_fuse, data_size * sizeof(T)));
CUDACHECK(cudaMemset(result_fuse, 0, data_size * sizeof(T)));
cudaIpcMemHandle_t self_data_handle; cudaIpcMemHandle_t self_data_handle;
cudaIpcMemHandle_t data_handles[8]; cudaIpcMemHandle_t data_handles[8];
vllm::Signal* buffer; vllm::Signal* buffer;
...@@ -176,7 +274,7 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, ...@@ -176,7 +274,7 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
sizeof(vllm::Signal) + data_size * sizeof(T)); sizeof(vllm::Signal) + data_size * sizeof(T));
// hack buffer registration // hack buffer registration
{ {
void* data[8]; void* data[8]; //gpu数据部分
for (int i = 0; i < nRanks; i++) { for (int i = 0; i < nRanks; i++) {
data[i] = data[i] =
((char*)ipc_ptrs[i]) + sizeof(vllm::Signal) + data_size * sizeof(T); ((char*)ipc_ptrs[i]) + sizeof(vllm::Signal) + data_size * sizeof(T);
...@@ -196,7 +294,26 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, ...@@ -196,7 +294,26 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
cudaEvent_t start, stop; cudaEvent_t start, stop;
CUDACHECK(cudaEventCreate(&start)); CUDACHECK(cudaEventCreate(&start));
CUDACHECK(cudaEventCreate(&stop)); CUDACHECK(cudaEventCreate(&stop));
/*******************************/
int token_num = data_size / hidden_dim;
T* residual_h, *residual_d, *weight_h, *weight_d;
residual_h = (T*)malloc(data_size * sizeof(T));
std::random_device rd; // 用于获取随机数种子
std::mt19937 gen(7);
std::uniform_real_distribution<float> dis(-3.0f, 3.0f);
for (int i = 0; i < data_size; ++i)
residual_h[i] = static_cast<T>(dis(gen));
for (int i = 0; i < hidden_dim; ++i)
weight_h[i] = static_cast<T>(dis(gen));
cudaMalloc((void**)&residual_d, sizeof(T)*data_size);
cudaMalloc((void**)&weight_d, sizeof(T)*hidden_dim);
cudaMemcpyAsync(residual_d, residual_h, sizeof(T)*data_size, cudaMemcpyHostToDevice, stream);
cudaMemcpyAsync(weight_d, weight_h, sizeof(T)*hidden_dim, cudaMemcpyHostToDevice, stream);
float eps = 1.0f;
/*******************************/
ncclDataType_t ncclDtype; ncclDataType_t ncclDtype;
if (std::is_same<T, half>::value) { if (std::is_same<T, half>::value) {
ncclDtype = ncclFloat16; ncclDtype = ncclFloat16;
...@@ -211,16 +328,16 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, ...@@ -211,16 +328,16 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
if (performance_test) { if (performance_test) {
dummy_kernel<<<1, 1, 0, stream>>>(); dummy_kernel<<<1, 1, 0, stream>>>();
constexpr int warmup_iters = 5; constexpr int warmup_iters = 5;
constexpr int num_iters = 100; constexpr int num_iters = 10;
// warmup // warmup
for (int i = 0; i < warmup_iters; i++) { for (int i = 0; i < warmup_iters; i++) {
NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, fa.allreduce<T>(stream, self_data, result_ori, data_size, threads, block_limit);
comm, stream)); fused_add_rms_norm_choose<T>(stream, result_ori, residual_d, weight_d, 1.0, hidden_dim, token_num);
} }
CUDACHECK(cudaEventRecord(start, stream)); CUDACHECK(cudaEventRecord(start, stream));
for (int i = 0; i < num_iters; i++) { for (int i = 0; i < num_iters; i++) {
NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, fa.allreduce<T>(stream, self_data, result_ori, data_size, threads, block_limit);
comm, stream)); fused_add_rms_norm_choose<T>(stream, result_ori, residual_d, weight_d, 1.0, hidden_dim, token_num);
} }
CUDACHECK(cudaEventRecord(stop, stream)); CUDACHECK(cudaEventRecord(stop, stream));
CUDACHECK(cudaStreamSynchronize(stream)); CUDACHECK(cudaStreamSynchronize(stream));
...@@ -230,13 +347,15 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, ...@@ -230,13 +347,15 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
dummy_kernel<<<1, 1, 0, stream>>>(); dummy_kernel<<<1, 1, 0, stream>>>();
// warm up // warm up
for (int i = 0; i < warmup_iters; i++) { for (int i = 0; i < warmup_iters; i++) {
fa.allreduce<T>(stream, self_data, result, data_size, threads, fa.allreduce_fuse_norm<T>(stream, self_data, result_fuse, data_size, token_num,
block_limit); hidden_dim, residual_d, weight_d, eps,
threads, block_limit);
} }
CUDACHECK(cudaEventRecord(start, stream)); CUDACHECK(cudaEventRecord(start, stream));
for (int i = 0; i < num_iters; i++) { for (int i = 0; i < num_iters; i++) {
fa.allreduce<T>(stream, self_data, result, data_size, threads, fa.allreduce_fuse_norm<T>(stream, self_data, result_fuse, data_size, token_num,
block_limit); hidden_dim, residual_d, weight_d, eps,
threads, block_limit);
} }
CUDACHECK(cudaEventRecord(stop, stream)); CUDACHECK(cudaEventRecord(stop, stream));
CUDACHECK(cudaStreamSynchronize(stream)); CUDACHECK(cudaStreamSynchronize(stream));
...@@ -245,7 +364,7 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, ...@@ -245,7 +364,7 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
cudaEventElapsedTime(&duration_ms, start, stop); cudaEventElapsedTime(&duration_ms, start, stop);
if (myRank == 0) if (myRank == 0)
printf( printf(
"Rank %d done, nGPUs:%d, sz (kb): %d, %d, %d, my time:%.2fus, nccl " "Rank %d done, nGPUs:%d, sz (kb): %d, %d, %d, allreduse_fuse_norm time:%.2fus, allreduce+norm "
"time:%.2fus\n", "time:%.2fus\n",
myRank, nRanks, data_size * sizeof(T) / 1024, threads, block_limit, myRank, nRanks, data_size * sizeof(T) / 1024, threads, block_limit,
duration_ms * 1e3 / num_iters, allreduce_ms * 1e3 / num_iters); duration_ms * 1e3 / num_iters, allreduce_ms * 1e3 / num_iters);
...@@ -255,8 +374,9 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, ...@@ -255,8 +374,9 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
NCCLCHECK(ncclAllReduce(self_data_copy, self_data, data_size, ncclDtype, NCCLCHECK(ncclAllReduce(self_data_copy, self_data, data_size, ncclDtype,
ncclSum, comm, stream)); ncclSum, comm, stream));
fused_add_rms_norm_choose<T>(stream, self_data, residual_d, weight_d, 1.0, hidden_dim, token_num);
convert_data<T><<<108, 1024, 0, stream>>>(self_data, result, nccl_result, convert_data<T><<<108, 1024, 0, stream>>>(result_ori, result_fuse, nccl_result,
my_result, data_size); my_result, data_size);
CUDACHECK(cudaStreamSynchronize(stream)); CUDACHECK(cudaStreamSynchronize(stream));
...@@ -279,13 +399,13 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, ...@@ -279,13 +399,13 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
<< " me: " << my_diffs / data_size << std::endl; << " me: " << my_diffs / data_size << std::endl;
} else { } else {
for (int i = 0; i < 100; i++) { for (int i = 0; i < 100; i++) {
fa.allreduce<T>(stream, self_data, result, data_size, threads, fa.allreduce<T>(stream, self_data, result_ori, data_size, threads,
block_limit); block_limit);
CUDACHECK(cudaStreamSynchronize(stream)); CUDACHECK(cudaStreamSynchronize(stream));
NCCLCHECK(ncclAllReduce(self_data, self_data_copy, data_size, ncclDtype, NCCLCHECK(ncclAllReduce(self_data, self_data_copy, data_size, ncclDtype,
ncclSum, comm, stream)); ncclSum, comm, stream));
convert_data<T><<<108, 1024, 0, stream>>>( convert_data<T><<<108, 1024, 0, stream>>>(
self_data_copy, result, nccl_result, my_result, data_size); self_data_copy, result_ori, nccl_result, my_result, data_size);
CUDACHECK(cudaStreamSynchronize(stream)); CUDACHECK(cudaStreamSynchronize(stream));
for (unsigned long j = 0; j < data_size; j++) { for (unsigned long j = 0; j < data_size; j++) {
...@@ -312,7 +432,8 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, ...@@ -312,7 +432,8 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
// << " me: " << my_diffs / data_size << std::endl; // << " me: " << my_diffs / data_size << std::endl;
} }
CUDACHECK(cudaFree(result)); CUDACHECK(cudaFree(result_ori));
CUDACHECK(cudaFree(result_fuse));
CUDACHECK(cudaFree(self_data_copy)); CUDACHECK(cudaFree(self_data_copy));
CUDACHECK(cudaFree(rank_data)); CUDACHECK(cudaFree(rank_data));
CUDACHECK(cudaFree(buffer)); CUDACHECK(cudaFree(buffer));
...@@ -351,9 +472,7 @@ int main(int argc, char** argv) { ...@@ -351,9 +472,7 @@ int main(int argc, char** argv) {
const int block_limit = 36; const int block_limit = 36;
#endif #endif
// Scan through different sizes to test performance. // Scan through different sizes to test performance.
for (int sz = 512; sz <= (8 << 20); sz *= 2) { run<half>(myRank, nRanks, comm, 512, 36, 7168 * 80, performance_test, 7168);
run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 47, performance_test);
}
cudaProfilerStop(); cudaProfilerStop();
MPICHECK(MPI_Finalize()); MPICHECK(MPI_Finalize());
......
...@@ -487,6 +487,17 @@ fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs, ...@@ -487,6 +487,17 @@ fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
bool fully_connected); bool fully_connected);
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
fptr_t reg_buffer, int64_t reg_buffer_sz_bytes); fptr_t reg_buffer, int64_t reg_buffer_sz_bytes);
// Custom all-reduce fused with a residual / RMS-norm step.
//   _fa                  handle that is reinterpreted as a vllm::CustomAllreduce*
//   inp, out             must have the same dtype and element count
//   hidden_size          row length; the token count is inp.numel() / hidden_size
//   residual, rms_weight extra operands forwarded to the fused kernel
//   eps                  forwarded to the kernel as a float
//   reg_buffer           when non-zero, inp is first copied into this buffer
//                        (it must fit in reg_buffer_sz_bytes); when zero, inp
//                        itself is used as the communication buffer
void all_reduce_fuse_norm(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
int64_t hidden_size, torch::Tensor& residual, torch::Tensor& rms_weight,
double eps, fptr_t reg_buffer, int64_t reg_buffer_sz_bytes);
// Custom all-reduce fused with RMS norm and quantization.
// The registered op schema marks `out`, `scales` and `norm_out` as mutated
// tensors and `residual` as optional.
// NOTE(review): the definition is not visible next to this declaration;
// presumably `out` receives the quantized values, `scales` the per-row scales
// and `norm_out` the unquantized normalized result, with reg_buffer /
// reg_buffer_sz_bytes behaving as in all_reduce_fuse_norm - confirm.
void all_reduce_fuse_norm_quant(fptr_t fa, torch::Tensor& inp, torch::Tensor& out,
int64_t hidden_size,torch::Tensor& rms_weight, double eps,
torch::Tensor& scales, torch::Tensor& norm_out,
fptr_t reg_buffer, int64_t reg_buffer_sz_bytes,
std::optional<at::Tensor> residual, bool update_input);
void dispose(fptr_t _fa); void dispose(fptr_t _fa);
int64_t meta_size(); int64_t meta_size();
void register_buffer(fptr_t _fa, const std::vector<int64_t>& fake_ipc_ptrs); void register_buffer(fptr_t _fa, const std::vector<int64_t>& fake_ipc_ptrs);
......
...@@ -933,6 +933,17 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { ...@@ -933,6 +933,17 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
"all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, " "all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
"int reg_buffer_sz_bytes) -> ()"); "int reg_buffer_sz_bytes) -> ()");
custom_ar.impl("all_reduce", torch::kCUDA, &all_reduce); custom_ar.impl("all_reduce", torch::kCUDA, &all_reduce);
custom_ar.def(
"all_reduce_fuse_norm(int fa, Tensor inp, Tensor! out, int hidden_size, "
"Tensor residual, Tensor rms_weight, float eps, int reg_buffer, "
"int reg_buffer_sz_bytes) -> ()");
custom_ar.impl("all_reduce_fuse_norm", torch::kCUDA, &all_reduce_fuse_norm);
custom_ar.def(
"all_reduce_fuse_norm_quant(int fa, Tensor inp, Tensor! out, int hidden_size, "
"Tensor rms_weight, float eps, Tensor! scales, Tensor! norm_out, int reg_buffer, "
"int reg_buffer_sz_bytes, Tensor? residual, bool update_input) -> ()");
custom_ar.impl("all_reduce_fuse_norm_quant", torch::kCUDA, &all_reduce_fuse_norm_quant);
custom_ar.def("dispose", &dispose); custom_ar.def("dispose", &dispose);
custom_ar.def("meta_size", &meta_size); custom_ar.def("meta_size", &meta_size);
......
...@@ -2212,7 +2212,17 @@ def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor, reg_buffer: int, ...@@ -2212,7 +2212,17 @@ def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor, reg_buffer: int,
reg_buffer_sz_bytes: int) -> None: reg_buffer_sz_bytes: int) -> None:
torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer, torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer,
reg_buffer_sz_bytes) reg_buffer_sz_bytes)
def all_reduce_fuse_norm(fa: int, inp: torch.Tensor, out:torch.Tensor, hidden_size:int,
                         residual:torch.Tensor, rms_weight:torch.Tensor, eps:float,
                         reg_buffer: int, reg_buffer_sz_bytes: int) -> None:
    """Invoke the ``_C_custom_ar.all_reduce_fuse_norm`` custom op.

    All arguments are forwarded positionally and unchanged; the result is
    written into ``out`` by the op and nothing is returned.
    """
    fused_op = torch.ops._C_custom_ar.all_reduce_fuse_norm
    fused_op(fa, inp, out, hidden_size, residual, rms_weight, eps,
             reg_buffer, reg_buffer_sz_bytes)
def all_reduce_fuse_norm_quant(fa: int, inp: torch.Tensor, out:torch.Tensor, hidden_size:int,
                               rms_weight: torch.Tensor, eps: float, scales: torch.Tensor, norm_out:torch.Tensor,
                               reg_buffer: int, reg_buffer_sz_bytes: int, residual: torch.Tensor, update_input: bool) -> None:
    """Invoke the ``_C_custom_ar.all_reduce_fuse_norm_quant`` custom op.

    All arguments are forwarded positionally and unchanged. The op schema
    declares ``residual`` as an optional tensor, so ``None`` is accepted
    for it. Nothing is returned.
    """
    fused_op = torch.ops._C_custom_ar.all_reduce_fuse_norm_quant
    fused_op(fa, inp, out, hidden_size, rms_weight, eps, scales, norm_out,
             reg_buffer, reg_buffer_sz_bytes, residual, update_input)
def dispose(fa: int) -> None: def dispose(fa: int) -> None:
torch.ops._C_custom_ar.dispose(fa) torch.ops._C_custom_ar.dispose(fa)
......
...@@ -536,4 +536,4 @@ direct_register_custom_op( ...@@ -536,4 +536,4 @@ direct_register_custom_op(
mutates_args=["output"], mutates_args=["output"],
fake_impl=unified_attention_with_output_fake, fake_impl=unified_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key, dispatch_key=current_platform.dispatch_key,
) )
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional, Union from typing import Any, Optional, Union, Tuple
import torch import torch
import torch.distributed import torch.distributed
from .parallel_state import get_tp_group from .parallel_state import get_tp_group
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
"""All-reduce the input tensor across model parallel group.""" """All-reduce the input tensor across model parallel group."""
return get_tp_group().all_reduce(input_) return get_tp_group().all_reduce(input_)
def tensor_model_parallel_all_reduce_crp_m32(input_: torch.Tensor,
                                             pa_rms_weight: torch.Tensor,
                                             pa_residual: torch.Tensor,
                                             pa_rms_eps: float,
                                             pa_quant_dtype: Optional[torch.dtype] = torch.int8,
                                             update_input: Optional[bool] = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """All-reduce fused with RMS norm and quantization on the TP group.

    Delegates to ``all_reduce_crq_m32`` of the tensor-parallel group and
    returns its 4-tuple unchanged.
    """
    tp_group = get_tp_group()
    fused_kwargs = {
        "input_": input_,
        "pa_rms_weight": pa_rms_weight,
        "pa_residual": pa_residual,
        "pa_rms_eps": pa_rms_eps,
        "pa_quant_dtype": pa_quant_dtype,
        "update_input": update_input,
    }
    return tp_group.all_reduce_crq_m32(**fused_kwargs)
def tensor_model_parallel_all_gather(input_: torch.Tensor, def tensor_model_parallel_all_gather(input_: torch.Tensor,
dim: int = -1) -> torch.Tensor: dim: int = -1) -> torch.Tensor:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional from typing import Optional, Tuple
import torch import torch
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
...@@ -14,6 +14,7 @@ from .base_device_communicator import DeviceCommunicatorBase ...@@ -14,6 +14,7 @@ from .base_device_communicator import DeviceCommunicatorBase
logger = init_logger(__name__) logger = init_logger(__name__)
from lmslim.quantize.quant_ops import lm_faster_rmsquant
class CudaCommunicator(DeviceCommunicatorBase): class CudaCommunicator(DeviceCommunicatorBase):
...@@ -116,6 +117,37 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -116,6 +117,37 @@ class CudaCommunicator(DeviceCommunicatorBase):
out = input_.clone() out = input_.clone()
torch.distributed.all_reduce(out, group=self.device_group) torch.distributed.all_reduce(out, group=self.device_group)
return out return out
def all_reduce_rms_quant_m32(self, input_,
                             pa_rms_weight: torch.Tensor,
                             pa_residual: torch.Tensor,
                             pa_rms_eps: float,
                             pa_quant_dtype: torch.dtype,
                             update_input: Optional[bool] = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """All-reduce ``input_`` and apply residual RMS norm + quantization.

    For at most 16 rows the fused custom all-reduce kernel is tried first.
    Larger batches, or a fused call that declines the input by returning
    ``None``, use a plain all-reduce followed by ``lm_faster_rmsquant``.

    Args:
        input_: 2-D tensor ``(batch_size, hidden_dim)``.
        pa_rms_weight: RMS-norm weight; must not be ``None``.
        pa_residual: residual tensor; must not be ``None``.
        pa_rms_eps: RMS-norm epsilon.
        pa_quant_dtype: dtype of the quantized output.
        update_input: accepted for interface compatibility; both paths
            currently always run with ``update_input=True``.

    Returns:
        ``(input_, pa_residual, xq, xs)``. On the fused path ``input_`` is
        the kernel's ``norm_out``; on the fallback path it is the
        all-reduced tensor that was handed to ``lm_faster_rmsquant``.
    """
    batch_size, _ = input_.shape
    ca_comm = self.ca_comm
    assert ca_comm is not None and not ca_comm.disabled
    assert envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and \
        pa_rms_weight is not None and pa_residual is not None

    fused = None
    if batch_size <= 16:
        fused = ca_comm.custom_all_reduce_fuse_norm_quant(inp=input_,
                                                          rms_weight=pa_rms_weight,
                                                          residual=pa_residual,
                                                          eps=pa_rms_eps,
                                                          quant_type=pa_quant_dtype,
                                                          update_input=True)

    if fused is not None:
        xq, xs, norm_out = fused
        input_ = norm_out
    else:
        # Either the batch is large or the fused helper declined the input
        # (it returns None in that case); unpacking None used to raise.
        input_ = self.all_reduce(input_)
        xq, xs = lm_faster_rmsquant(input_,
                                    rms_weight=pa_rms_weight,
                                    residual=pa_residual,
                                    epsilon=pa_rms_eps,
                                    quant_dtype=pa_quant_dtype,
                                    update_input=True)

    assert input_ is not None
    assert xq is not None and xs is not None
    return input_, pa_residual, xq, xs
def reduce_scatter(self, input_: torch.Tensor, dim: int = -1): def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
world_size = self.world_size world_size = self.world_size
......
...@@ -275,6 +275,91 @@ class CustomAllreduce: ...@@ -275,6 +275,91 @@ class CustomAllreduce:
# latency) compared to the performance gain of using custom kernels # latency) compared to the performance gain of using custom kernels
return self.all_reduce(input, registered=False) return self.all_reduce(input, registered=False)
def allreduce_fuse_norm(self,
                        inp: torch.Tensor,
                        hidden_size: int,
                        residual: torch.Tensor,
                        rms_weight: torch.Tensor,
                        eps: float,
                        *,
                        out: torch.Tensor = None,
                        registered: bool = False):
    """Run the fused all-reduce + norm op and return the output tensor.

    ``out`` is allocated with ``torch.empty_like(inp)`` when not supplied.
    With ``registered=True`` the op gets a null staging buffer (0, 0);
    otherwise it gets this rank's entry of ``self.buffer_ptrs`` together
    with ``self.max_size``.
    """
    result = torch.empty_like(inp) if out is None else out
    if registered:
        staging_ptr, staging_bytes = 0, 0
    else:
        staging_ptr = self.buffer_ptrs[self.rank]
        staging_bytes = self.max_size
    ops.all_reduce_fuse_norm(self._ptr, inp, result,
                             hidden_size, residual, rms_weight, eps,
                             staging_ptr, staging_bytes)
    return result
def custom_all_reduce_fuse_norm(self,
                                input: torch.Tensor,
                                hidden_size: int,
                                residual: torch.Tensor,
                                rms_weight: torch.Tensor,
                                eps: float) -> Optional[torch.Tensor]:
    """Entry point for the fused all-reduce + norm path.

    Returns ``None`` when the communicator is disabled or
    ``should_custom_ar`` rejects ``input``. While ``_IS_CAPTURING`` is set
    but the current stream is not actually capturing, only an
    uninitialized tensor shaped like ``input`` is returned. In every other
    case the fused op runs through the unregistered-buffer path.
    """
    if self.disabled or not self.should_custom_ar(input):
        return None
    if self._IS_CAPTURING and not torch.cuda.is_current_stream_capturing():
        return torch.empty_like(input)
    return self.allreduce_fuse_norm(input, hidden_size,
                                    residual, rms_weight, eps, registered=False)
def allreduce_fuse_norm_quant(self,
                              inp: torch.Tensor,
                              hidden_size: int,
                              rms_weight,
                              eps,
                              quant_dtype,
                              residual,
                              update_input: bool = True,
                              registered: bool = False) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the fused all-reduce + norm + quant op.

    Allocates three outputs: ``xq`` (shape of ``inp``, dtype
    ``quant_dtype``), ``norm_out`` (like ``inp``) and float32 ``scales``
    with one row per ``inp.numel() // inp.shape[-1]``. With
    ``registered=True`` the op gets a null staging buffer (0, 0);
    otherwise it gets this rank's entry of ``self.buffer_ptrs`` and
    ``self.max_size``.

    Returns:
        ``(xq, scales, norm_out)``.
    """
    xq = torch.empty_like(inp, dtype=quant_dtype)
    norm_out = torch.empty_like(inp)
    num_rows = inp.numel() // inp.shape[-1]
    scales = torch.empty((num_rows, 1),
                         device=inp.device,
                         dtype=torch.float32)
    if registered:
        staging_ptr, staging_bytes = 0, 0
    else:
        staging_ptr = self.buffer_ptrs[self.rank]
        staging_bytes = self.max_size
    ops.all_reduce_fuse_norm_quant(self._ptr, inp, xq,
                                   hidden_size, rms_weight, eps, scales, norm_out,
                                   staging_ptr, staging_bytes, residual, update_input)
    return xq, scales, norm_out
def custom_all_reduce_fuse_norm_quant(self,
                                      inp: torch.Tensor,
                                      rms_weight: torch.Tensor,
                                      eps,
                                      quant_type,
                                      residual: Optional[torch.Tensor] = None,
                                      update_input = True
                                      ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Entry point for the fused all-reduce + norm + quant path.

    Returns ``None`` (despite the annotation) when the communicator is
    disabled or ``should_custom_ar`` rejects ``inp``. While
    ``_IS_CAPTURING`` is set but the current stream is not capturing,
    only uninitialized placeholders with the output shapes are returned.
    Otherwise the fused op runs through the unregistered-buffer path with
    ``hidden_size`` taken from the last dimension of ``inp``.
    """
    hidden_size = inp.shape[-1]
    if self.disabled or not self.should_custom_ar(inp):
        return None
    if self._IS_CAPTURING and not torch.cuda.is_current_stream_capturing():
        placeholder_xq = torch.empty_like(inp, dtype=quant_type)
        placeholder_scales = torch.empty((inp.numel() // inp.shape[-1], 1),
                                         dtype=torch.float32, device=inp.device)
        placeholder_norm = torch.empty_like(inp)
        return placeholder_xq, placeholder_scales, placeholder_norm
    return self.allreduce_fuse_norm_quant(inp, hidden_size, rms_weight,
                                          eps, quant_type, residual,
                                          update_input=update_input, registered=False)
def close(self): def close(self):
if not self.disabled and self._ptr: if not self.disabled and self._ptr:
if ops is not None: if ops is not None:
......
...@@ -30,7 +30,7 @@ from collections import namedtuple ...@@ -30,7 +30,7 @@ from collections import namedtuple
from contextlib import contextmanager, nullcontext from contextlib import contextmanager, nullcontext
from dataclasses import dataclass from dataclasses import dataclass
from multiprocessing import shared_memory from multiprocessing import shared_memory
from typing import Any, Callable, Optional, Union from typing import Any, Callable, Optional, Tuple, Union
from unittest.mock import patch from unittest.mock import patch
import torch import torch
...@@ -114,6 +114,37 @@ def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor: ...@@ -114,6 +114,37 @@ def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
return torch.empty_like(tensor) return torch.empty_like(tensor)
def all_reduce_rms_quant(input_: torch.Tensor, group_name: str,
                         pa_rms_weight: Optional[torch.Tensor] = None,
                         pa_residual: Optional[torch.Tensor] = None,
                         pa_rms_eps: Optional[float] = 1e-6,
                         pa_quant_dtype: Optional[torch.dtype] = torch.int8,
                         update_input: Optional[bool] = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Resolve ``group_name`` and run its fused all-reduce + RMS-quant.

    Looks the group up in ``_groups``, calls the stored entry to obtain
    the group object, and delegates to its ``_all_reduce_out_place_m32``.

    Raises:
        AssertionError: ``group_name`` is not registered.
        ValueError: the registered entry resolves to ``None``.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    coordinator = _groups[group_name]()
    if coordinator is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    fused_kwargs = {
        "pa_rms_weight": pa_rms_weight,
        "pa_residual": pa_residual,
        "pa_rms_eps": pa_rms_eps,
        "pa_quant_dtype": pa_quant_dtype,
        "update_input": update_input,
    }
    return coordinator._all_reduce_out_place_m32(input_, **fused_kwargs)
def all_reduce_rms_quant_fake(input_: torch.Tensor, group_name: str,
                              pa_rms_weight: Optional[torch.Tensor] = None,
                              pa_residual: Optional[torch.Tensor] = None,
                              pa_rms_eps: Optional[float] = 1e-6,
                              pa_quant_dtype: Optional[torch.dtype] = torch.int8,
                              update_input: Optional[bool] = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Shape-only stand-in for ``all_reduce_rms_quant``.

    Returns ``input_`` and ``pa_residual`` as passed in, a zero tensor
    shaped like ``input_`` with dtype ``pa_quant_dtype``, and float32 ones
    of shape ``(input_.numel() // input_.shape[-1], 1)``.
    """
    num_rows = input_.numel() // input_.shape[-1]
    fake_quantized = torch.zeros_like(input_, dtype=pa_quant_dtype)
    fake_scales = torch.ones((num_rows, 1),
                             device=input_.device,
                             dtype=torch.float32)
    return input_, pa_residual, fake_quantized, fake_scales
def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int, def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int,
group_name: str) -> torch.Tensor: group_name: str) -> torch.Tensor:
assert group_name in _groups, f"Group {group_name} is not found." assert group_name in _groups, f"Group {group_name} is not found."
...@@ -156,6 +187,14 @@ if supports_custom_op(): ...@@ -156,6 +187,14 @@ if supports_custom_op():
dispatch_key=current_platform.dispatch_key, dispatch_key=current_platform.dispatch_key,
) )
direct_register_custom_op(
op_name="all_reduce_rms_quant",
op_func=all_reduce_rms_quant,
mutates_args=["input_", "pa_residual"],
fake_impl=all_reduce_rms_quant_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op( direct_register_custom_op(
op_name="reduce_scatter", op_name="reduce_scatter",
op_func=reduce_scatter, op_func=reduce_scatter,
...@@ -358,9 +397,44 @@ class GroupCoordinator: ...@@ -358,9 +397,44 @@ class GroupCoordinator:
else: else:
return self._all_reduce_out_place(input_) return self._all_reduce_out_place(input_)
def all_reduce_crq_m32(self, input_: torch.Tensor,
                       pa_rms_weight: torch.Tensor,
                       pa_residual: torch.Tensor,
                       pa_rms_eps: float,
                       pa_quant_dtype: torch.dtype,
                       update_input: Optional[bool] = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Dispatch the fused all-reduce + RMS-quant custom op for this group.

    Requires more than one rank, the
    ``USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT`` env flag, and non-``None``
    weight and residual. The call goes through
    ``torch.ops.vllm.all_reduce_rms_quant`` keyed by ``self.unique_name``.
    """
    assert self.world_size > 1
    assert envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and pa_rms_weight is not None and pa_residual is not None
    fused_op = torch.ops.vllm.all_reduce_rms_quant
    return fused_op(input_,
                    group_name=self.unique_name,
                    pa_rms_weight=pa_rms_weight,
                    pa_residual=pa_residual,
                    pa_rms_eps=pa_rms_eps,
                    pa_quant_dtype=pa_quant_dtype,
                    update_input=update_input)
def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor: def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
return self.device_communicator.all_reduce(input_) return self.device_communicator.all_reduce(input_)
def _all_reduce_out_place_m32(self, input_: torch.Tensor,
                              pa_rms_weight: torch.Tensor,
                              pa_residual: torch.Tensor,
                              pa_rms_eps: float,
                              pa_quant_dtype: torch.dtype,
                              update_input: Optional[bool] = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Forward to the device communicator's ``all_reduce_rms_quant_m32``.

    Requires the ``USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT`` env flag and
    non-``None`` weight and residual. The communicator's four results are
    repacked into a tuple ``(input, residual, xq, xs)``.
    """
    assert envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and pa_rms_weight is not None \
        and pa_residual is not None
    communicator = self.device_communicator
    reduced, residual_out, xq, xs = communicator.all_reduce_rms_quant_m32(
        input_,
        pa_rms_weight=pa_rms_weight,
        pa_residual=pa_residual,
        pa_rms_eps=pa_rms_eps,
        pa_quant_dtype=pa_quant_dtype,
        update_input=update_input)
    return reduced, residual_out, xq, xs
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
world_size = self.world_size world_size = self.world_size
# Bypass the function if we are using only 1 GPU. # Bypass the function if we are using only 1 GPU.
...@@ -735,6 +809,8 @@ class GroupCoordinator: ...@@ -735,6 +809,8 @@ class GroupCoordinator:
torch.distributed.recv(tensor, torch.distributed.recv(tensor,
src=self.ranks[src], src=self.ranks[src],
group=group) group=group)
if envs.VLLM_USE_PP_SYNC:
torch.cuda.synchronize()
if use_all_gather: if use_all_gather:
# do the allgather # do the allgather
tensor = all_gather_group.all_gather( # type: ignore tensor = all_gather_group.all_gather( # type: ignore
......
...@@ -284,7 +284,13 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: ...@@ -284,7 +284,13 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
cached version. cached version.
""" """
return copy.deepcopy(_compute_kwargs(cls)) return copy.deepcopy(_compute_kwargs(cls))
class EnvironmentConfigError(Exception):
    """Raised when mutually exclusive environment flags are both enabled."""
def check_incompatible_config(env1: bool, env2: bool):
    """Reject the combination where both flags are exactly ``True``.

    The comparison is by identity, so truthy non-bool values such as ``1``
    do not trigger the error.

    Raises:
        EnvironmentConfigError: both ``env1`` and ``env2`` are ``True``.
    """
    both_enabled = (env1 is True) and (env2 is True)
    if not both_enabled:
        return
    _s = "USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and USE_FUSED_RMS_QUANT must not be enabled simultaneously!\n\n"
    raise EnvironmentConfigError(_s)
@dataclass @dataclass
class EngineArgs: class EngineArgs:
...@@ -1230,7 +1236,8 @@ class EngineArgs: ...@@ -1230,7 +1236,8 @@ class EngineArgs:
num_lookahead_slots = num_lookahead_slots \ num_lookahead_slots = num_lookahead_slots \
if speculative_config is None \ if speculative_config is None \
else speculative_config.num_lookahead_slots else speculative_config.num_lookahead_slots
check_incompatible_config(envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT, envs.USE_FUSED_RMS_QUANT)
scheduler_config = SchedulerConfig( scheduler_config = SchedulerConfig(
runner_type=model_config.runner_type, runner_type=model_config.runner_type,
max_num_batched_tokens=self.max_num_batched_tokens, max_num_batched_tokens=self.max_num_batched_tokens,
......
...@@ -178,6 +178,9 @@ if TYPE_CHECKING: ...@@ -178,6 +178,9 @@ if TYPE_CHECKING:
VLLM_P2P_BUF_TOKENS: int = 30000 VLLM_P2P_BUF_TOKENS: int = 30000
VLLM_SCHED_ENABLE_MINIMAL_INJECTION: bool = False VLLM_SCHED_ENABLE_MINIMAL_INJECTION: bool = False
VLLM_USE_PD_SPLIT: bool = False VLLM_USE_PD_SPLIT: bool = False
VLLM_USE_PP_SYNC: bool = False
VLLM_USE_LIGHTOP_FILL_MOE_ALIGN: bool = False
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1156,6 +1159,18 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1156,6 +1159,18 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_PD_SPLIT": "VLLM_USE_PD_SPLIT":
lambda: (os.environ.get("VLLM_USE_PD_SPLIT", "True").lower() in lambda: (os.environ.get("VLLM_USE_PD_SPLIT", "True").lower() in
("true", "1")), ("true", "1")),
# vLLM will sync to avoid pp vmfault
"VLLM_USE_PP_SYNC":
lambda: (os.environ.get("VLLM_USE_PP_SYNC", "False").lower() in
("true", "1")),
# vLLM will use lightop to fuse fill and moe align
"VLLM_USE_LIGHTOP_FILL_MOE_ALIGN":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_FILL_MOE_ALIGN", "False").lower() in
("true", "1")),
# vllm will use custom-allreduce rmsquant fused op
"USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT":
lambda: (os.getenv('USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT', '0').lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -216,7 +216,9 @@ def moe_align_block_size( ...@@ -216,7 +216,9 @@ def moe_align_block_size(
sorted_ids = torch.empty((max_num_tokens_padded, ), sorted_ids = torch.empty((max_num_tokens_padded, ),
dtype=torch.int32, dtype=torch.int32,
device=topk_ids.device) device=topk_ids.device)
sorted_ids.fill_(topk_ids.numel()) if not envs.VLLM_USE_LIGHTOP_FILL_MOE_ALIGN:
sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
# Expert ids must be zeroed out to prevent index out of bounds error while # Expert ids must be zeroed out to prevent index out of bounds error while
# mapping global expert ids to local expert ids in expert parallelism. # mapping global expert ids to local expert ids in expert parallelism.
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import itertools import itertools
from abc import abstractmethod from abc import abstractmethod
from typing import Any, Literal, Optional, Union from typing import Any, Literal, Optional, Union, Tuple
import vllm.envs as envs import vllm.envs as envs
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -14,7 +14,8 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank, ...@@ -14,7 +14,8 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
split_tensor_along_last_dim, split_tensor_along_last_dim,
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce,
tensor_model_parallel_all_reduce_crp_m32)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
...@@ -677,7 +678,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -677,7 +678,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self, input_, self, input_,
rms_weight: Optional[torch.Tensor] = None, rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
update_hd: Optional[bool] = True update_hd: Optional[bool] = True,
xqxs: Optional[tuple] = None
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if envs.USE_FUSED_RMS_QUANT and rms_weight is not None: if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
input_quant_args = None input_quant_args = None
...@@ -706,7 +708,22 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -706,7 +708,22 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
if not self.return_bias: if not self.return_bias:
return output return output
return output, new_residual, output_bias return output, new_residual, output_bias
else: # not USE_FUSED_RMS_QUANT elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias, input_quant_args=xqxs)
if self.gather_output:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
else:
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None assert self.quant_method is not None
...@@ -1495,46 +1512,94 @@ class RowParallelLinear(LinearBase): ...@@ -1495,46 +1512,94 @@ class RowParallelLinear(LinearBase):
def forward( def forward(
self, input_, self, input_,
use_fused_silu_mul_quant: Optional[bool] = False use_fused_silu_mul_quant: Optional[bool] = False,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: pa_rms_weight: Optional[torch.Tensor] = None,
if self.input_is_parallel: pa_residual: Optional[torch.Tensor] = None,
input_parallel = input_ pa_rms_eps: Optional[float] = 1e-6,
else: pa_quant_dtype: Optional[torch.dtype] = torch.int8,
tp_rank = get_tensor_model_parallel_rank() update_input: Optional[bool] = True
splitted_input = split_tensor_along_last_dim( ) -> Union[torch.Tensor,
input_, num_partitions=self.tp_size) tuple[torch.Tensor, Optional[Parameter]],
input_parallel = splitted_input[tp_rank].contiguous() tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[Parameter]]
]:
if envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and pa_rms_weight is not None and pa_residual is not None:
if self.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply. # Matrix multiply.
assert self.quant_method is not None assert self.quant_method is not None
# Only fuse bias add into GEMM for rank 0 (this ensures that # Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case) # bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
if use_fused_silu_mul_quant:
xq, xs = lm_fuse_silu_mul_quant(input_parallel)
silu_quant_args = [xq, xs]
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_,
silu_quant_args=silu_quant_args)
else:
output_parallel = self.quant_method.apply(self, output_parallel = self.quant_method.apply(self,
input_parallel, input_parallel,
bias=bias_) bias=bias_)
if self.reduce_results and self.tp_size > 1: if self.reduce_results and self.tp_size > 1:
if envs.VLLM_ENABLE_TBO: if envs.VLLM_ENABLE_TBO:
output = self.tbo_all_reduce(output_parallel) output = self.tbo_all_reduce(output_parallel)
packages_ = tensor_model_parallel_all_reduce_crp_m32(output_parallel,
pa_rms_weight=pa_rms_weight,
pa_residual=pa_residual,
pa_rms_eps=pa_rms_eps,
pa_quant_dtype=pa_quant_dtype,
update_input=update_input)
hs, resi, xq, xs = packages_
output = hs
else: else:
output = tensor_model_parallel_all_reduce(output_parallel) output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, resi, xq, xs, output_bias
else: else:
output = output_parallel if self.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
output_bias = self.bias if self.skip_bias_add else None # Matrix multiply.
assert self.quant_method is not None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
if use_fused_silu_mul_quant:
xq, xs = lm_fuse_silu_mul_quant(input_parallel)
silu_quant_args = [xq, xs]
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_,
silu_quant_args=silu_quant_args)
else:
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_)
if self.reduce_results and self.tp_size > 1:
if envs.VLLM_ENABLE_TBO:
output = self.tbo_all_reduce(output_parallel)
else:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias: if not self.return_bias:
return output return output
return output, output_bias return output, output_bias
def extra_repr(self) -> str: def extra_repr(self) -> str:
s = f"input_features={self.input_size_per_partition}" s = f"input_features={self.input_size_per_partition}"
......
...@@ -162,7 +162,11 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase): ...@@ -162,7 +162,11 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
assert len(input_quant_args) == 2 assert len(input_quant_args) == 2
x_q, x_scale = input_quant_args x_q, x_scale = input_quant_args
elif envs.USE_FUSED_SILU_MUL_QUANT and silu_quant_args is not None: elif envs.USE_FUSED_SILU_MUL_QUANT and silu_quant_args is not None:
assert len(silu_quant_args) == 2
x_q, x_scale = silu_quant_args x_q, x_scale = silu_quant_args
elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and input_quant_args is not None:
assert len(input_quant_args) == 2
x_q, x_scale = input_quant_args
else: else:
x_q, x_scale = per_token_quant_int8(x) x_q, x_scale = per_token_quant_int8(x)
...@@ -178,9 +182,9 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase): ...@@ -178,9 +182,9 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
if m<=16: if m<=16:
m_=m m_=m
elif m<=64: elif m<=64:
m_= (m + 3) & -4 #取值到最近的4的倍数 m_ = ((m + 3) // 4) * 4 #取值到最近的4的倍数
elif m<=160: elif m<=160:
m_=(m + 7) & -8 m_ = ((m + 7) // 8) * 8
elif m<200: #256 elif m<200: #256
m_=160 m_=160
......
...@@ -251,6 +251,8 @@ def get_model_architecture( ...@@ -251,6 +251,8 @@ def get_model_architecture(
os.environ['VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD'] = '1' os.environ['VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD'] = '1'
if not envs.is_set("VLLM_USE_OPT_CAT"): if not envs.is_set("VLLM_USE_OPT_CAT"):
os.environ['VLLM_USE_OPT_CAT'] = '1' os.environ['VLLM_USE_OPT_CAT'] = '1'
# if not envs.is_set("VLLM_USE_LIGHTOP_FILL_MOE_ALIGN"):
# os.environ['VLLM_USE_LIGHTOP_FILL_MOE_ALIGN'] = '1'
if os.getenv('GEMM_PAD') != '1': if os.getenv('GEMM_PAD') != '1':
os.environ['GEMM_PAD'] = '0' os.environ['GEMM_PAD'] = '0'
...@@ -264,6 +266,8 @@ def get_model_architecture( ...@@ -264,6 +266,8 @@ def get_model_architecture(
os.environ['VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD'] = '1' os.environ['VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD'] = '1'
if not envs.is_set("VLLM_USE_OPT_CAT"): if not envs.is_set("VLLM_USE_OPT_CAT"):
os.environ['VLLM_USE_OPT_CAT'] = '1' os.environ['VLLM_USE_OPT_CAT'] = '1'
# if not envs.is_set("VLLM_USE_LIGHTOP_FILL_MOE_ALIGN"):
# os.environ['VLLM_USE_LIGHTOP_FILL_MOE_ALIGN'] = '1'
# awq相关配置 # awq相关配置
try: try:
......
...@@ -29,7 +29,7 @@ import vllm.envs as envs ...@@ -29,7 +29,7 @@ import vllm.envs as envs
import typing import typing
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from typing import Any, Optional, Union from typing import Any, Optional, Union, Tuple
import torch import torch
from torch import nn from torch import nn
...@@ -96,8 +96,8 @@ class DeepseekV2MLP(nn.Module): ...@@ -96,8 +96,8 @@ class DeepseekV2MLP(nn.Module):
def forward(self, x, def forward(self, x,
rms_weight: Optional[torch.Tensor] = None, rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
update_hd: Optional[bool] = False update_hd: Optional[bool] = False,
): xqxs: Optional[tuple[torch.Tensor, torch.Tensor]] = None):
if envs.USE_FUSED_RMS_QUANT: if envs.USE_FUSED_RMS_QUANT:
gate_up, new_resi, _ = self.gate_up_proj(x, rms_weight, residual, update_hd=update_hd) gate_up, new_resi, _ = self.gate_up_proj(x, rms_weight, residual, update_hd=update_hd)
if envs.USE_FUSED_SILU_MUL_QUANT: if envs.USE_FUSED_SILU_MUL_QUANT:
...@@ -107,6 +107,11 @@ class DeepseekV2MLP(nn.Module): ...@@ -107,6 +107,11 @@ class DeepseekV2MLP(nn.Module):
x, _ = self.down_proj(x) x, _ = self.down_proj(x)
return x, new_resi return x, new_resi
elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
gate_up, _ = self.gate_up_proj(x, xqxs=xqxs)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
else: else:
gate_up, _ = self.gate_up_proj(x) gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up) x = self.act_fn(gate_up)
...@@ -200,57 +205,99 @@ class DeepseekV2MoE(nn.Module): ...@@ -200,57 +205,99 @@ class DeepseekV2MoE(nn.Module):
def forward(self, hidden_states: torch.Tensor, def forward(self, hidden_states: torch.Tensor,
rms_weight: Optional[torch.Tensor] = None, rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None residual: Optional[torch.Tensor] = None,
xqxs: Optional[tuple[torch.Tensor, torch.Tensor]] = None
) -> torch.Tensor: ) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape if envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
hidden_states = hidden_states.view(-1, hidden_dim) num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if self.n_shared_experts is not None: if self.n_shared_experts is not None:
if envs.USE_FUSED_RMS_QUANT: shared_output = self.shared_experts(hidden_states, xqxs=xqxs)
shared_output, new_resi = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True)
else: router_logits, _ = self.gate(hidden_states)
shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD: if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
shared_output=shared_output)
else:
if hidden_states.dtype != torch.float16:
final_hidden_states = self.experts( final_hidden_states = self.experts(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor router_logits=router_logits,
shared_output=shared_output)
else: else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if shared_output is not None:
if hidden_states.dtype != torch.float16: if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor
else: else:
# Fix FP16 overflow # Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details. # See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \ final_hidden_states = self.experts(hidden_states=hidden_states,
* (1. / self.routed_scaling_factor) router_logits=router_logits)
if shared_output is not None:
if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
if self.tp_size > 1:
if envs.VLLM_ENABLE_TBO:
final_hidden_states = self.tbo_all_reduce(final_hidden_states)
else:
final_hidden_states = (
self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states))
return final_hidden_states.view(num_tokens, hidden_dim)
else:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if self.n_shared_experts is not None:
if envs.USE_FUSED_RMS_QUANT:
shared_output, new_resi = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True)
else:
shared_output = self.shared_experts(hidden_states)
router_logits, _ = self.gate(hidden_states)
if self.tp_size > 1: if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD:
if envs.VLLM_ENABLE_TBO: final_hidden_states = self.experts(
final_hidden_states = self.tbo_all_reduce(final_hidden_states) hidden_states=hidden_states,
router_logits=router_logits,
shared_output=shared_output)
else: else:
final_hidden_states = ( if hidden_states.dtype != torch.float16:
self.experts.maybe_all_reduce_tensor_model_parallel( final_hidden_states = self.experts(
final_hidden_states)) hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if shared_output is not None:
if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
if self.tp_size > 1:
if envs.VLLM_ENABLE_TBO:
final_hidden_states = self.tbo_all_reduce(final_hidden_states)
else:
final_hidden_states = (
self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states))
if envs.USE_FUSED_RMS_QUANT: if envs.USE_FUSED_RMS_QUANT:
return final_hidden_states.view(num_tokens, hidden_dim), new_resi return final_hidden_states.view(num_tokens, hidden_dim), new_resi
else: else:
return final_hidden_states.view(num_tokens, hidden_dim) return final_hidden_states.view(num_tokens, hidden_dim)
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
...@@ -556,8 +603,16 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -556,8 +603,16 @@ class DeepseekV2MLAAttention(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
rms_weight: Optional[torch.Tensor] = None, rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None residual: Optional[torch.Tensor] = None,
) -> torch.Tensor: pa_rms_weight: Optional[torch.Tensor] = None,
pa_residual: Optional[torch.Tensor] = None,
pa_rms_eps: Optional[float] = 1e-6,
pa_quant_dtype: Optional[torch.dtype] = torch.int8,
update_input: Optional[bool] = True
) -> Union[torch.Tensor,
Tuple[torch.Tensor, torch.Tensor],
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
]:
if envs.USE_FUSED_RMS_QUANT and rms_weight is not None: if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
if self.q_lora_rank is not None: if self.q_lora_rank is not None:
q_c, new_residual, _, input_quant_args = self.q_a_proj(hidden_states, rms_weight=rms_weight, residual=residual, update_hd=False) q_c, new_residual, _, input_quant_args = self.q_a_proj(hidden_states, rms_weight=rms_weight, residual=residual, update_hd=False)
...@@ -587,6 +642,40 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -587,6 +642,40 @@ class DeepseekV2MLAAttention(nn.Module):
output_shape=(hidden_states.shape[0], output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim)) self.num_local_heads * self.v_head_dim))
return self.o_proj(attn_out)[0], new_residual return self.o_proj(attn_out)[0], new_residual
elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and pa_rms_weight is not None and pa_residual is not None:
if self.q_lora_rank is not None:
q_c = self.q_a_proj(hidden_states)[0]
q_c = self.q_a_layernorm(q_c)
q = self.q_b_proj(q_c)[0]
else:
q = self.q_proj(hidden_states)[0]
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
q = q.view(-1, self.num_local_heads, self.qk_head_dim)
k_pe = k_pe.unsqueeze(1)
q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
positions, q[..., self.qk_nope_head_dim:], k_pe)
attn_out = self.mla_attn(
q,
kv_c_normed,
k_pe,
output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim))
packages_ = self.o_proj(attn_out,
pa_rms_weight=pa_rms_weight,
pa_residual=pa_residual,
pa_rms_eps=pa_rms_eps,
pa_quant_dtype=pa_quant_dtype,
update_input=update_input)[:4]
assert len(packages_) == 4
hs, resi, xq, xs = packages_
assert xq is not None and xs is not None
return hs, resi, xq, xs
else: else:
if self.q_lora_rank is not None: if self.q_lora_rank is not None:
q_c = self.q_a_proj(hidden_states)[0] q_c = self.q_a_proj(hidden_states)[0]
...@@ -682,97 +771,162 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -682,97 +771,162 @@ class DeepseekV2DecoderLayer(nn.Module):
self.post_attention_layernorm = RMSNorm(config.hidden_size, self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
self.routed_scaling_factor = config.routed_scaling_factor self.routed_scaling_factor = config.routed_scaling_factor
self.use_fused_rms_quant = envs.USE_FUSED_RMS_QUANT
self.use_fused_custom_all_reduce = envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
def forward( def forward_fused_rmsquant(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor]
) -> torch.Tensor: ) -> Tuple[torch.Tensor, torch.Tensor]:
if envs.USE_FUSED_RMS_QUANT: # Fix residual FP16 overflow
# Fix residual FP16 overflow residual_fix_overflow = False
residual_fix_overflow = False
assert self.input_layernorm.has_weight is True
if residual is None:
residual = hidden_states
hidden_states, _ = self.self_attn(
positions = positions,
hidden_states = hidden_states,
rms_weight = self.input_layernorm.weight.data,
residual = None
)
residual_fix_overflow = True
else:
hidden_states, new_residual = self.self_attn(
positions = positions,
hidden_states = hidden_states,
rms_weight = self.input_layernorm.weight.data,
residual = residual
)
residual = new_residual
assert self.input_layernorm.has_weight is True if hidden_states.dtype == torch.float16:
if residual is None: # rmsnorm, and rmsnorm result would not affect by scale.
residual = hidden_states hidden_states *= 1. / self.routed_scaling_factor
hidden_states, _ = self.self_attn( if self.layer_idx == 0 or residual_fix_overflow:
positions = positions, # The residual is shared by all layers, we only scale it on
hidden_states = hidden_states, # first layer.
rms_weight = self.input_layernorm.weight.data, residual *= 1. / self.routed_scaling_factor
residual = None
) hidden_states, new_resi = self.mlp(hidden_states, self.post_attention_layernorm.weight.data, residual)
residual_fix_overflow = True
else: if isinstance(self.mlp,
hidden_states, new_residual = self.self_attn( DeepseekV2MLP) and hidden_states.dtype == torch.float16:
positions = positions, # Fix FP16 overflow
hidden_states = hidden_states, # Scaling the DeepseekV2MLP output, it is the input of
rms_weight = self.input_layernorm.weight.data, # input_layernorm of next decoder layer.
residual = residual # The scaling of DeepseekV2MOE output would be done in the forward
) # of DeepseekV2MOE
residual = new_residual hidden_states *= 1. / self.routed_scaling_factor
return hidden_states, new_resi
def forward_fused_CRQ(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Decoder-layer forward that hands post-attention norm inputs to attention.

    The attention module receives the post-attention layernorm weight and
    epsilon together with the residual, and returns four values: hidden
    states, residual, and a pair ``(xq, xs)`` that is forwarded to the MLP.

    Args:
        positions: Forwarded to ``self.self_attn``.
        hidden_states: Layer input.
        residual: Residual from the previous layer, or None; when None the
            incoming ``hidden_states`` is used as the residual.

    Returns:
        ``(hidden_states, residual)`` after attention and the MLP.
    """
    # With no incoming residual, the raw input becomes the residual and
    # only the hidden states go through the input layernorm.
    started_without_residual = residual is None
    if started_without_residual:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
    else:
        hidden_states, residual = self.input_layernorm(
            hidden_states, residual)

    attn_hidden, attn_residual, xq, xs = self.self_attn(
        positions=positions,
        hidden_states=hidden_states,
        pa_rms_weight=self.post_attention_layernorm.weight.data,
        pa_residual=residual,
        pa_rms_eps=self.post_attention_layernorm.variance_epsilon,
        pa_quant_dtype=torch.int8,
        update_input=True,
    )
    assert xq is not None and xs is not None

    # FP16 overflow handling: rescale in place. The residual is rescaled
    # only on layer 0 or when this call created the residual.
    if attn_hidden.dtype == torch.float16:
        attn_hidden *= 1. / self.routed_scaling_factor
        if self.layer_idx == 0 or started_without_residual:
            attn_residual *= 1. / self.routed_scaling_factor

    hidden_states = self.mlp(attn_hidden, xqxs=(xq, xs))

    # Only the dense MLP output is rescaled at this point.
    if isinstance(self.mlp,
                  DeepseekV2MLP) and hidden_states.dtype == torch.float16:
        hidden_states *= 1. / self.routed_scaling_factor

    return hidden_states, attn_residual
def forward_default(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Unfused decoder-layer forward: norm, attention, norm, MLP.

    Args:
        positions: Forwarded to ``self.self_attn``.
        hidden_states: Layer input.
        residual: Residual from the previous layer, or None; when None the
            incoming ``hidden_states`` is used as the residual.

    Returns:
        ``(hidden_states, residual)`` after the MLP.
    """
    # --- Self attention ---
    # Remember whether the residual was created here; this decides below
    # whether the residual takes part in the FP16 rescale.
    started_without_residual = residual is None
    if started_without_residual:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
    else:
        hidden_states, residual = self.input_layernorm(
            hidden_states, residual)

    hidden_states = self.self_attn(positions=positions,
                                   hidden_states=hidden_states)

    if hidden_states.dtype == torch.float16:
        # Fix FP16 overflow: scale hidden_states and residual before the
        # RMSNorm, whose result is not affected by the scale.
        inv_scale = 1. / self.routed_scaling_factor
        hidden_states *= inv_scale
        if self.layer_idx == 0 or started_without_residual:
            # The residual is shared by all layers, so it is scaled only
            # on the first layer (or when it was created in this call).
            residual *= inv_scale

    # --- Fully connected ---
    hidden_states, residual = self.post_attention_layernorm(
        hidden_states, residual)
    hidden_states = self.mlp(hidden_states)

    if isinstance(self.mlp,
                  DeepseekV2MLP) and hidden_states.dtype == torch.float16:
        # Fix FP16 overflow: the DeepseekV2MLP output feeds the next
        # layer's input_layernorm, so scale it here. The DeepseekV2MoE
        # output is scaled inside DeepseekV2MoE.forward instead.
        hidden_states *= 1. / self.routed_scaling_factor

    return hidden_states, residual
if hidden_states.dtype == torch.float16: def choose_forward(self):
# rmsnorm, and rmsnorm result would not affect by scale. if self.use_fused_rms_quant:
hidden_states *= 1. / self.routed_scaling_factor return self.forward_fused_rmsquant
if self.layer_idx == 0 or residual_fix_overflow:
# The residual is shared by all layers, we only scale it on
# first layer.
residual *= 1. / self.routed_scaling_factor
hidden_states, new_resi = self.mlp(hidden_states, self.post_attention_layernorm.weight.data, residual)
if isinstance(self.mlp,
DeepseekV2MLP) and hidden_states.dtype == torch.float16:
# Fix FP16 overflow
# Scaling the DeepseekV2MLP output, it is the input of
# input_layernorm of next decoder layer.
# The scaling of DeepseekV2MOE output would be done in the forward
# of DeepseekV2MOE
hidden_states *= 1. / self.routed_scaling_factor
return hidden_states, new_resi
elif self.use_fused_custom_all_reduce:
return self.forward_fused_CRQ
else: else:
# Self Attention return self.forward_default
# Fix residual FP16 overflow
residual_fix_overflow = False
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
residual_fix_overflow = True
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
if hidden_states.dtype == torch.float16: def forward(
# Fix FP16 overflow self,
# We scale both hidden_states and residual before positions: torch.Tensor,
# rmsnorm, and rmsnorm result would not affect by scale. hidden_states: torch.Tensor,
hidden_states *= 1. / self.routed_scaling_factor residual: Optional[torch.Tensor]
if self.layer_idx == 0 or residual_fix_overflow: ) -> Tuple[torch.Tensor, torch.Tensor]:
# The residual is shared by all layers, we only scale it on forward_func = self.choose_forward()
# first layer. return forward_func(positions=positions, hidden_states=hidden_states, residual=residual )
residual *= 1. / self.routed_scaling_factor
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
if isinstance(self.mlp,
DeepseekV2MLP) and hidden_states.dtype == torch.float16:
# Fix FP16 overflow
# Scaling the DeepseekV2MLP output, it is the input of
# input_layernorm of next decoder layer.
# The scaling of DeepseekV2MOE output would be done in the forward
# of DeepseekV2MOE
hidden_states *= 1. / self.routed_scaling_factor
return hidden_states, residual
@support_torch_compile @support_torch_compile
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment