Commit 3e6729e0 authored by wujl5's avatar wujl5
Browse files

deepseekv2-w4a8支持custom-rms-quant融合

parent 813f81fb
...@@ -59,6 +59,144 @@ bool _is_weak_contiguous(torch::Tensor& t) { ...@@ -59,6 +59,144 @@ bool _is_weak_contiguous(torch::Tensor& t) {
* Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first * Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
* copied into _reg_buffer. * copied into _reg_buffer.
*/ */
/**
 * Custom all-reduce fused with a normalisation step that consumes `residual`
 * and `rms_weight`; the result is written to `out`.
 *
 * If `_reg_buffer` is null, `inp` is passed to the kernel directly. Otherwise
 * `inp` is first copied (device-to-device, on the current stream) into
 * `_reg_buffer`, which must hold at least `inp.numel() * inp.element_size()`
 * bytes (`reg_buffer_sz_bytes` is checked against that).
 *
 * @param _fa          Opaque handle; reinterpreted as vllm::CustomAllreduce*.
 * @param inp          Input tensor (weakly contiguous).
 * @param out          Output tensor, same dtype and element count as `inp`.
 * @param hidden_size  Row width; `inp.numel()` must be a positive multiple of it.
 * @param residual     Residual tensor, same dtype as `inp`.
 * @param rms_weight   Norm weight tensor, same dtype as `inp`.
 * @param eps          Epsilon forwarded to the kernel as float.
 * @throws std::runtime_error for dtypes other than float32/float16/bfloat16;
 *         TORCH_CHECK failures for violated preconditions.
 */
void all_reduce_fuse_norm(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
                          int64_t hidden_size, torch::Tensor& residual,
                          torch::Tensor& rms_weight, double eps,
                          fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();

  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  TORCH_CHECK(_is_weak_contiguous(out));
  TORCH_CHECK(_is_weak_contiguous(inp));
  TORCH_CHECK(_is_weak_contiguous(residual));
  TORCH_CHECK(_is_weak_contiguous(rms_weight));
  // hidden_size is used as a divisor below, so reject zero/negative values
  // and inputs that do not split into whole rows.
  TORCH_CHECK(hidden_size > 0, "hidden_size must be positive, got ",
              hidden_size);
  TORCH_CHECK(inp.numel() % hidden_size == 0,
              "inp.numel() must be a multiple of hidden_size");
  // residual and rms_weight are reinterpreted with the output dtype below;
  // a mismatching dtype would make the kernel read the wrong bytes.
  TORCH_CHECK(residual.scalar_type() == out.scalar_type(),
              "residual must have the same dtype as inp/out");
  TORCH_CHECK(rms_weight.scalar_type() == out.scalar_type(),
              "rms_weight must have the same dtype as inp/out");

  int token_num = inp.numel() / hidden_size;
  auto input_size = inp.numel() * inp.element_size();
  auto reg_buffer = reinterpret_cast<void*>(_reg_buffer);
  if (reg_buffer) {
    TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes);
    AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.data_ptr(), input_size,
                                  cudaMemcpyDeviceToDevice, stream));
  } else {
    reg_buffer = inp.data_ptr();
  }

  const float eps_f = static_cast<float>(eps);
  switch (out.scalar_type()) {
    case at::ScalarType::Float: {
      fa->allreduce_fuse_norm<float>(
          stream, static_cast<float*>(reg_buffer),
          static_cast<float*>(out.data_ptr()), out.numel(), token_num,
          hidden_size, static_cast<float*>(residual.data_ptr()),
          static_cast<float*>(rms_weight.data_ptr()), eps_f);
      break;
    }
    case at::ScalarType::Half: {
      fa->allreduce_fuse_norm<half>(
          stream, static_cast<half*>(reg_buffer),
          static_cast<half*>(out.data_ptr()), out.numel(), token_num,
          hidden_size, static_cast<half*>(residual.data_ptr()),
          static_cast<half*>(rms_weight.data_ptr()), eps_f);
      break;
    }
    case at::ScalarType::BFloat16: {
      fa->allreduce_fuse_norm<nv_bfloat16>(
          stream, static_cast<nv_bfloat16*>(reg_buffer),
          static_cast<nv_bfloat16*>(out.data_ptr()), out.numel(), token_num,
          hidden_size, static_cast<nv_bfloat16*>(residual.data_ptr()),
          static_cast<nv_bfloat16*>(rms_weight.data_ptr()), eps_f);
      break;
    }
    default:
      throw std::runtime_error(
          "custom allreduce only supports float32, float16 and bfloat16");
  }
}
/**
 * Launches the fused all-reduce + norm + quantisation kernel for one input
 * dtype (`scalar_in_t`) and one compile-time `update_input` setting. The
 * quantised output dtype is dispatched from `out.scalar_type()`.
 *
 * All preconditions are validated before any work is enqueued on the stream:
 *   - `inp` is weakly contiguous and `hidden_size` is positive,
 *   - `rms_weight.data_ptr()` is 16-byte aligned,
 *   - the CustomAllreduce instance is fully connected.
 *
 * If `_reg_buffer` is non-null, `inp` is copied into it first (size-checked
 * against `reg_buffer_sz_bytes`); otherwise `inp` is used in place.
 *
 * @throws std::runtime_error on a misaligned weight pointer or a topology
 *         that is not fully connected.
 */
template <typename scalar_in_t, bool update_input>
void allreduce_fuse_norm_quant_dispath(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
                                       int hidden_size, torch::Tensor& rms_weight, double eps,
                                       torch::Tensor& scales, torch::Tensor& norm_out,
                                       fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes,
                                       std::optional<at::Tensor> residual) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();
  TORCH_CHECK(_is_weak_contiguous(inp));
  // hidden_size is a divisor below.
  TORCH_CHECK(hidden_size > 0, "hidden_size must be positive, got ",
              hidden_size);

  // Validate before enqueuing the staging copy so a rejected call leaves no
  // pending work on the stream.
  const auto wt_misalignment =
      reinterpret_cast<std::uintptr_t>(rms_weight.data_ptr()) % 16;
  if (wt_misalignment != 0) {
    throw std::runtime_error(
        "custom allreduce currently requires rms_weight to be 16-byte "
        "aligned, but data_ptr % 16 is " +
        std::to_string(wt_misalignment));
  }
  if (!fa->fully_connected_) {
    throw std::runtime_error(
        "custom allreduce only supports fully_connected");
  }

  int token_num = inp.numel() / hidden_size;
  auto input_size = inp.numel() * inp.element_size();
  auto reg_buffer = reinterpret_cast<void*>(_reg_buffer);
  if (reg_buffer) {
    TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes);
    AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.data_ptr(), input_size,
                                  cudaMemcpyDeviceToDevice, stream));
  } else {
    reg_buffer = inp.data_ptr();
  }

  // The presence of a residual is a compile-time template flag of the kernel,
  // hence the two otherwise identical dispatch sites.
  if (residual.has_value()) {
    VLLM_DISPATCH_QUANT_TYPES(
        out.scalar_type(), "fa->allreduce_fuse_norm_quant", [&] {
          fa->allreduce_fuse_norm_quant<scalar_in_t, scalar_t, true, update_input>
              (stream, reinterpret_cast<scalar_in_t*>(reg_buffer), out.data_ptr<scalar_t>(),
               out.numel(), token_num, hidden_size, residual->data_ptr<scalar_in_t>(),
               rms_weight.data_ptr<scalar_in_t>(),
               norm_out.data_ptr<scalar_in_t>(),
               eps, scales.data_ptr<float>());
        });
  } else {
    VLLM_DISPATCH_QUANT_TYPES(
        out.scalar_type(), "fa->allreduce_fuse_norm_quant", [&] {
          fa->allreduce_fuse_norm_quant<scalar_in_t, scalar_t, false, update_input>
              (stream, reinterpret_cast<scalar_in_t*>(reg_buffer), out.data_ptr<scalar_t>(),
               out.numel(), token_num, hidden_size, nullptr,
               rms_weight.data_ptr<scalar_in_t>(),
               norm_out.data_ptr<scalar_in_t>(),
               eps, scales.data_ptr<float>());
        });
  }
}
/**
 * Torch-facing entry point for the fused all-reduce + norm + quantisation op.
 *
 * Validates the tensors, then dispatches on the input floating dtype and on
 * the runtime `update_input` flag (which is a compile-time template argument
 * of the implementation).
 *
 * Requirements checked here:
 *   - `out` is Float8_e4m3fn or int8, `scales` is float32,
 *   - `inp` and `out` are contiguous and have the same element count,
 *   - `hidden_size` is positive and representable as `int` (the
 *     implementation takes an `int`).
 */
void all_reduce_fuse_norm_quant(fptr_t fa, torch::Tensor& inp, torch::Tensor& out,
                                int64_t hidden_size, torch::Tensor& rms_weight, double eps,
                                torch::Tensor& scales, torch::Tensor& norm_out,
                                fptr_t reg_buffer, int64_t reg_buffer_sz_bytes,
                                std::optional<at::Tensor> residual, bool update_input) {
  constexpr c10::ScalarType kFp8Type = c10::ScalarType::Float8_e4m3fn;
  TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
  TORCH_CHECK(scales.dtype() == torch::kFloat32);
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  TORCH_CHECK(out.is_contiguous() && inp.is_contiguous());

  // The implementation takes hidden_size as int; make the narrowing explicit
  // and refuse values that would be truncated or are not positive.
  const int hidden_dim = static_cast<int>(hidden_size);
  TORCH_CHECK(hidden_size > 0 &&
                  static_cast<int64_t>(hidden_dim) == hidden_size,
              "hidden_size must be positive and fit in int, got ",
              hidden_size);

  VLLM_DISPATCH_FLOATING_TYPES(
      inp.scalar_type(), "allreduce_fuse_norm_quant_dispath", [&] {
        if (update_input) {
          allreduce_fuse_norm_quant_dispath<scalar_t, true>(
              fa, inp, out, hidden_dim, rms_weight, eps, scales, norm_out,
              reg_buffer, reg_buffer_sz_bytes, residual);
        } else {
          allreduce_fuse_norm_quant_dispath<scalar_t, false>(
              fa, inp, out, hidden_dim, rms_weight, eps, scales, norm_out,
              reg_buffer, reg_buffer_sz_bytes, residual);
        }
      });
}
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) { fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa); auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
......
This diff is collapsed.
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <limits> #include <limits>
#include <vector> #include <vector>
#include <random>
#include "cuda_profiler_api.h" #include "cuda_profiler_api.h"
#include "custom_all_reduce.cuh" #include "custom_all_reduce.cuh"
...@@ -117,16 +118,113 @@ __global__ void gen_data(curandState_t* state, T* data, double* ground_truth, ...@@ -117,16 +118,113 @@ __global__ void gen_data(curandState_t* state, T* data, double* ground_truth,
ground_truth[idx] = sum; ground_truth[idx] = sum;
} }
} }
/*************************************************/
// Sums `val` across `reducesize` neighbouring lanes by repeatedly adding the
// value obtained from __shfl_down at a halving distance
// (reducesize/2, reducesize/4, ..., 1). Returns each lane's accumulated value.
// NOTE(review): the default of 64 presumably matches the hardware wavefront
// width this test targets — confirm.
template <typename T, int reducesize = 64>
__inline__ __device__ T WarpReduceSum_NEW(T val) {
  constexpr int first_stride = reducesize / 2;
#pragma unroll
  for (int stride = first_stride; stride >= 1; stride /= 2) {
    val += __shfl_down(val, stride);
  }
  return val;
}
// Block-wide sum built from WarpReduceSum_NEW in two stages.
//   val     this thread's contribution
//   shared  scratch array with at least block_size/64 elements
// Stage 1 reduces within each group of 64 consecutive threads; stage 2 reduces
// the per-group partial sums using the first block_size/64 threads of group 0.
// The caller in this file only consumes the value returned to thread 0.
template <typename T,int block_size=512>
__inline__ __device__ T BlockReduceSum_NEW(T val, T* shared) {
// Number of 64-thread groups in the block == number of partial sums.
constexpr int share_size=block_size/64;
val = WarpReduceSum_NEW<T>(val);
// A single group needs no second stage and no shared memory.
if constexpr(block_size==64)
{
return val;
}
else{
const int lid = threadIdx.x % 64;
const int wid = threadIdx.x / 64;
// The first thread of each group publishes that group's partial sum.
if (lid == 0&&wid<share_size) {
shared[wid] = val;
}
// All partial sums must be written before group 0 reads them.
__syncthreads();
// Only share_size threads of group 0 take part in the second reduction.
if (wid == 0&&lid<share_size) {
val = WarpReduceSum_NEW<T,share_size>(shared[lid]);
}
return val;
}
}
// Reference kernel: residual += input, then input = residual * rstd * gamma,
// where rstd = rsqrtf(sum(residual^2) / cols + eps) is computed per row.
// One block handles one row (blockIdx.x); each thread handles Vec consecutive
// columns, so a row is fully covered only when cols <= Vec * block_size and
// cols is a multiple of Vec.
//
// Vectorised accesses use a pack of exactly Vec elements so that the load
// width always equals the per-thread register array.
// NOTE(review): the earlier form loaded through vllm::packed_t<scalar_t>::P,
// whose width is fixed by custom_all_reduce.cuh and does not depend on Vec
// (16 bytes upstream — confirm for this tree); that only matches
// Vec == 16 / sizeof(scalar_t).
template <typename scalar_t, typename T_ACC, int Vec = 4, int block_size = 512>
__global__ void fused_add_rms_kernel_opt(scalar_t* input, scalar_t* residual,
                                         scalar_t* gamma, int cols, T_ACC eps) {
  static_assert(Vec > 0 && (Vec & (Vec - 1)) == 0,
                "Vec must be a power of two");
  struct alignas(sizeof(scalar_t) * Vec) VecPack {
    scalar_t data[Vec];
  };

  constexpr int share_size = block_size / 64;
  __shared__ T_ACC val_shared[share_size];
  __shared__ T_ACC s_rstd;

  const int row = blockIdx.x;
  const int j = threadIdx.x;
  const int tcol = cols / Vec;
  // Widen before multiplying so large row indices cannot overflow int.
  const int64_t idx = (static_cast<int64_t>(row) * tcol + j) * Vec;

  VecPack input_vec;
  VecPack residual_vec;
  T_ACC val = 0;
  if (j < tcol) {
    input_vec = *reinterpret_cast<const VecPack*>(input + idx);
    residual_vec = *reinterpret_cast<const VecPack*>(residual + idx);
#pragma unroll
    for (int ii = 0; ii < Vec; ii++) {
      residual_vec.data[ii] += input_vec.data[ii];
      val += static_cast<T_ACC>(residual_vec.data[ii]) *
             static_cast<T_ACC>(residual_vec.data[ii]);
    }
  }

  // Threads with j >= tcol contribute 0 but must still join the reduction,
  // which contains a block-wide barrier.
  val = BlockReduceSum_NEW<T_ACC, block_size>(val, val_shared);
  if (j == 0) s_rstd = rsqrtf(val / cols + eps);
  __syncthreads();
  const T_ACC trstd = s_rstd;

  if (j < tcol) {
#pragma unroll
    for (int ii = 0; ii < Vec; ii++) {
      const int jj = j * Vec + ii;
      input_vec.data[ii] = static_cast<T_ACC>(residual_vec.data[ii]) * trstd *
                           static_cast<T_ACC>(gamma[jj]);
    }
    *reinterpret_cast<VecPack*>(residual + idx) = residual_vec;
    *reinterpret_cast<VecPack*>(input + idx) = input_vec;
  }
}
// Picks a <Vec, block_size> instantiation of fused_add_rms_kernel_opt from
// hidden_size (and, for the 2049..4096 range, from num_tokens) and launches it
// on `stream` with one block per token. Accumulation type is always float.
template <typename scalar_t>
void fused_add_rms_norm_choose(cudaStream_t stream, scalar_t* self_data, scalar_t* other_data,
                               scalar_t* weight_data, double eps, int hidden_size, int num_tokens) {
  if (hidden_size <= 1024) {
    fused_add_rms_kernel_opt<scalar_t, float, 8, 128>
        <<<num_tokens, 128, 0, stream>>>(self_data, other_data, weight_data, hidden_size, eps);
    return;
  }
  if (hidden_size <= 2048) {
    fused_add_rms_kernel_opt<scalar_t, float, 8, 256>
        <<<num_tokens, 256, 0, stream>>>(self_data, other_data, weight_data, hidden_size, eps);
    return;
  }
  if (hidden_size <= 4096) {
    // Many tokens: fewer, wider threads; few tokens: more, narrower threads.
    if (num_tokens > 1200) {
      fused_add_rms_kernel_opt<scalar_t, float, 8, 512>
          <<<num_tokens, 512, 0, stream>>>(self_data, other_data, weight_data, hidden_size, eps);
    } else {
      fused_add_rms_kernel_opt<scalar_t, float, 4, 1024>
          <<<num_tokens, 1024, 0, stream>>>(self_data, other_data, weight_data, hidden_size, eps);
    }
    return;
  }
  if (hidden_size <= 8192) {
    fused_add_rms_kernel_opt<scalar_t, float, 8, 1024>
        <<<num_tokens, 1024, 0, stream>>>(self_data, other_data, weight_data, hidden_size, eps);
    return;
  }
  fused_add_rms_kernel_opt<scalar_t, float, 16, 1024>
      <<<num_tokens, 1024, 0, stream>>>(self_data, other_data, weight_data, hidden_size, eps);
}
/*****************************************************************/
template <typename T> template <typename T>
void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
int data_size, bool performance_test) { int data_size, bool performance_test, int hidden_dim) {
T* result; T* result_ori, *result_fuse;
cudaStream_t stream; cudaStream_t stream;
CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
CUDACHECK(cudaMalloc(&result, data_size * sizeof(T))); CUDACHECK(cudaMalloc(&result_ori, data_size * sizeof(T)));
CUDACHECK(cudaMemset(result, 0, data_size * sizeof(T))); CUDACHECK(cudaMemset(result_ori, 0, data_size * sizeof(T)));
CUDACHECK(cudaMalloc(&result_fuse, data_size * sizeof(T)));
CUDACHECK(cudaMemset(result_fuse, 0, data_size * sizeof(T)));
cudaIpcMemHandle_t self_data_handle; cudaIpcMemHandle_t self_data_handle;
cudaIpcMemHandle_t data_handles[8]; cudaIpcMemHandle_t data_handles[8];
vllm::Signal* buffer; vllm::Signal* buffer;
...@@ -176,7 +274,7 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, ...@@ -176,7 +274,7 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
sizeof(vllm::Signal) + data_size * sizeof(T)); sizeof(vllm::Signal) + data_size * sizeof(T));
// hack buffer registration // hack buffer registration
{ {
void* data[8]; void* data[8]; //gpu数据部分
for (int i = 0; i < nRanks; i++) { for (int i = 0; i < nRanks; i++) {
data[i] = data[i] =
((char*)ipc_ptrs[i]) + sizeof(vllm::Signal) + data_size * sizeof(T); ((char*)ipc_ptrs[i]) + sizeof(vllm::Signal) + data_size * sizeof(T);
...@@ -196,7 +294,26 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, ...@@ -196,7 +294,26 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
cudaEvent_t start, stop; cudaEvent_t start, stop;
CUDACHECK(cudaEventCreate(&start)); CUDACHECK(cudaEventCreate(&start));
CUDACHECK(cudaEventCreate(&stop)); CUDACHECK(cudaEventCreate(&stop));
/*******************************/
int token_num = data_size / hidden_dim;
T* residual_h, *residual_d, *weight_h, *weight_d;
residual_h = (T*)malloc(data_size * sizeof(T));
std::random_device rd; // 用于获取随机数种子
std::mt19937 gen(7);
std::uniform_real_distribution<float> dis(-3.0f, 3.0f);
for (int i = 0; i < data_size; ++i)
residual_h[i] = static_cast<T>(dis(gen));
for (int i = 0; i < hidden_dim; ++i)
weight_h[i] = static_cast<T>(dis(gen));
cudaMalloc((void**)&residual_d, sizeof(T)*data_size);
cudaMalloc((void**)&weight_d, sizeof(T)*hidden_dim);
cudaMemcpyAsync(residual_d, residual_h, sizeof(T)*data_size, cudaMemcpyHostToDevice, stream);
cudaMemcpyAsync(weight_d, weight_h, sizeof(T)*hidden_dim, cudaMemcpyHostToDevice, stream);
float eps = 1.0f;
/*******************************/
ncclDataType_t ncclDtype; ncclDataType_t ncclDtype;
if (std::is_same<T, half>::value) { if (std::is_same<T, half>::value) {
ncclDtype = ncclFloat16; ncclDtype = ncclFloat16;
...@@ -211,16 +328,16 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, ...@@ -211,16 +328,16 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
if (performance_test) { if (performance_test) {
dummy_kernel<<<1, 1, 0, stream>>>(); dummy_kernel<<<1, 1, 0, stream>>>();
constexpr int warmup_iters = 5; constexpr int warmup_iters = 5;
constexpr int num_iters = 100; constexpr int num_iters = 10;
// warmup // warmup
for (int i = 0; i < warmup_iters; i++) { for (int i = 0; i < warmup_iters; i++) {
NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, fa.allreduce<T>(stream, self_data, result_ori, data_size, threads, block_limit);
comm, stream)); fused_add_rms_norm_choose<T>(stream, result_ori, residual_d, weight_d, 1.0, hidden_dim, token_num);
} }
CUDACHECK(cudaEventRecord(start, stream)); CUDACHECK(cudaEventRecord(start, stream));
for (int i = 0; i < num_iters; i++) { for (int i = 0; i < num_iters; i++) {
NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, fa.allreduce<T>(stream, self_data, result_ori, data_size, threads, block_limit);
comm, stream)); fused_add_rms_norm_choose<T>(stream, result_ori, residual_d, weight_d, 1.0, hidden_dim, token_num);
} }
CUDACHECK(cudaEventRecord(stop, stream)); CUDACHECK(cudaEventRecord(stop, stream));
CUDACHECK(cudaStreamSynchronize(stream)); CUDACHECK(cudaStreamSynchronize(stream));
...@@ -230,13 +347,15 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, ...@@ -230,13 +347,15 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
dummy_kernel<<<1, 1, 0, stream>>>(); dummy_kernel<<<1, 1, 0, stream>>>();
// warm up // warm up
for (int i = 0; i < warmup_iters; i++) { for (int i = 0; i < warmup_iters; i++) {
fa.allreduce<T>(stream, self_data, result, data_size, threads, fa.allreduce_fuse_norm<T>(stream, self_data, result_fuse, data_size, token_num,
block_limit); hidden_dim, residual_d, weight_d, eps,
threads, block_limit);
} }
CUDACHECK(cudaEventRecord(start, stream)); CUDACHECK(cudaEventRecord(start, stream));
for (int i = 0; i < num_iters; i++) { for (int i = 0; i < num_iters; i++) {
fa.allreduce<T>(stream, self_data, result, data_size, threads, fa.allreduce_fuse_norm<T>(stream, self_data, result_fuse, data_size, token_num,
block_limit); hidden_dim, residual_d, weight_d, eps,
threads, block_limit);
} }
CUDACHECK(cudaEventRecord(stop, stream)); CUDACHECK(cudaEventRecord(stop, stream));
CUDACHECK(cudaStreamSynchronize(stream)); CUDACHECK(cudaStreamSynchronize(stream));
...@@ -245,7 +364,7 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, ...@@ -245,7 +364,7 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
cudaEventElapsedTime(&duration_ms, start, stop); cudaEventElapsedTime(&duration_ms, start, stop);
if (myRank == 0) if (myRank == 0)
printf( printf(
"Rank %d done, nGPUs:%d, sz (kb): %d, %d, %d, my time:%.2fus, nccl " "Rank %d done, nGPUs:%d, sz (kb): %d, %d, %d, allreduse_fuse_norm time:%.2fus, allreduce+norm "
"time:%.2fus\n", "time:%.2fus\n",
myRank, nRanks, data_size * sizeof(T) / 1024, threads, block_limit, myRank, nRanks, data_size * sizeof(T) / 1024, threads, block_limit,
duration_ms * 1e3 / num_iters, allreduce_ms * 1e3 / num_iters); duration_ms * 1e3 / num_iters, allreduce_ms * 1e3 / num_iters);
...@@ -255,8 +374,9 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, ...@@ -255,8 +374,9 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
NCCLCHECK(ncclAllReduce(self_data_copy, self_data, data_size, ncclDtype, NCCLCHECK(ncclAllReduce(self_data_copy, self_data, data_size, ncclDtype,
ncclSum, comm, stream)); ncclSum, comm, stream));
fused_add_rms_norm_choose<T>(stream, self_data, residual_d, weight_d, 1.0, hidden_dim, token_num);
convert_data<T><<<108, 1024, 0, stream>>>(self_data, result, nccl_result, convert_data<T><<<108, 1024, 0, stream>>>(result_ori, result_fuse, nccl_result,
my_result, data_size); my_result, data_size);
CUDACHECK(cudaStreamSynchronize(stream)); CUDACHECK(cudaStreamSynchronize(stream));
...@@ -279,13 +399,13 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, ...@@ -279,13 +399,13 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
<< " me: " << my_diffs / data_size << std::endl; << " me: " << my_diffs / data_size << std::endl;
} else { } else {
for (int i = 0; i < 100; i++) { for (int i = 0; i < 100; i++) {
fa.allreduce<T>(stream, self_data, result, data_size, threads, fa.allreduce<T>(stream, self_data, result_ori, data_size, threads,
block_limit); block_limit);
CUDACHECK(cudaStreamSynchronize(stream)); CUDACHECK(cudaStreamSynchronize(stream));
NCCLCHECK(ncclAllReduce(self_data, self_data_copy, data_size, ncclDtype, NCCLCHECK(ncclAllReduce(self_data, self_data_copy, data_size, ncclDtype,
ncclSum, comm, stream)); ncclSum, comm, stream));
convert_data<T><<<108, 1024, 0, stream>>>( convert_data<T><<<108, 1024, 0, stream>>>(
self_data_copy, result, nccl_result, my_result, data_size); self_data_copy, result_ori, nccl_result, my_result, data_size);
CUDACHECK(cudaStreamSynchronize(stream)); CUDACHECK(cudaStreamSynchronize(stream));
for (unsigned long j = 0; j < data_size; j++) { for (unsigned long j = 0; j < data_size; j++) {
...@@ -312,7 +432,8 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, ...@@ -312,7 +432,8 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
// << " me: " << my_diffs / data_size << std::endl; // << " me: " << my_diffs / data_size << std::endl;
} }
CUDACHECK(cudaFree(result)); CUDACHECK(cudaFree(result_ori));
CUDACHECK(cudaFree(result_fuse));
CUDACHECK(cudaFree(self_data_copy)); CUDACHECK(cudaFree(self_data_copy));
CUDACHECK(cudaFree(rank_data)); CUDACHECK(cudaFree(rank_data));
CUDACHECK(cudaFree(buffer)); CUDACHECK(cudaFree(buffer));
...@@ -351,9 +472,7 @@ int main(int argc, char** argv) { ...@@ -351,9 +472,7 @@ int main(int argc, char** argv) {
const int block_limit = 36; const int block_limit = 36;
#endif #endif
// Scan through different sizes to test performance. // Scan through different sizes to test performance.
for (int sz = 512; sz <= (8 << 20); sz *= 2) { run<half>(myRank, nRanks, comm, 512, 36, 7168 * 80, performance_test, 7168);
run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 47, performance_test);
}
cudaProfilerStop(); cudaProfilerStop();
MPICHECK(MPI_Finalize()); MPICHECK(MPI_Finalize());
......
...@@ -487,6 +487,17 @@ fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs, ...@@ -487,6 +487,17 @@ fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
bool fully_connected); bool fully_connected);
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
fptr_t reg_buffer, int64_t reg_buffer_sz_bytes); fptr_t reg_buffer, int64_t reg_buffer_sz_bytes);
// Custom all-reduce fused with a norm step: takes the input, a residual and a
// norm weight, and writes the result to `out`. `reg_buffer` /
// `reg_buffer_sz_bytes` describe an optional staging buffer (0 means "use
// `inp` directly").
void all_reduce_fuse_norm(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
int64_t hidden_size, torch::Tensor& residual, torch::Tensor& rms_weight,
double eps, fptr_t reg_buffer, int64_t reg_buffer_sz_bytes);
// Same as above with an additional quantisation step: `out` receives the
// quantised values, `scales` the float32 scales and `norm_out` a tensor in the
// input dtype. `residual` is optional; `update_input` selects a kernel variant.
void all_reduce_fuse_norm_quant(fptr_t fa, torch::Tensor& inp, torch::Tensor& out,
int64_t hidden_size,torch::Tensor& rms_weight, double eps,
torch::Tensor& scales, torch::Tensor& norm_out,
fptr_t reg_buffer, int64_t reg_buffer_sz_bytes,
std::optional<at::Tensor> residual, bool update_input);
void dispose(fptr_t _fa); void dispose(fptr_t _fa);
int64_t meta_size(); int64_t meta_size();
void register_buffer(fptr_t _fa, const std::vector<int64_t>& fake_ipc_ptrs); void register_buffer(fptr_t _fa, const std::vector<int64_t>& fake_ipc_ptrs);
......
...@@ -933,6 +933,17 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { ...@@ -933,6 +933,17 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
"all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, " "all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
"int reg_buffer_sz_bytes) -> ()"); "int reg_buffer_sz_bytes) -> ()");
custom_ar.impl("all_reduce", torch::kCUDA, &all_reduce); custom_ar.impl("all_reduce", torch::kCUDA, &all_reduce);
// Fused all-reduce + norm. Only `out` is declared mutable in the schema.
// NOTE(review): if the kernel also writes `residual`, the schema should mark
// it `Tensor!` — confirm against the kernel.
custom_ar.def(
"all_reduce_fuse_norm(int fa, Tensor inp, Tensor! out, int hidden_size, "
"Tensor residual, Tensor rms_weight, float eps, int reg_buffer, "
"int reg_buffer_sz_bytes) -> ()");
custom_ar.impl("all_reduce_fuse_norm", torch::kCUDA, &all_reduce_fuse_norm);
// Fused all-reduce + norm + quantisation. `out`, `scales` and `norm_out` are
// declared mutable; `residual` is optional.
// NOTE(review): `inp` and `residual` are declared read-only although the op
// takes an `update_input` flag — confirm whether either is written.
custom_ar.def(
"all_reduce_fuse_norm_quant(int fa, Tensor inp, Tensor! out, int hidden_size, "
"Tensor rms_weight, float eps, Tensor! scales, Tensor! norm_out, int reg_buffer, "
"int reg_buffer_sz_bytes, Tensor? residual, bool update_input) -> ()");
custom_ar.impl("all_reduce_fuse_norm_quant", torch::kCUDA, &all_reduce_fuse_norm_quant);
custom_ar.def("dispose", &dispose); custom_ar.def("dispose", &dispose);
custom_ar.def("meta_size", &meta_size); custom_ar.def("meta_size", &meta_size);
......
...@@ -2212,7 +2212,17 @@ def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor, reg_buffer: int, ...@@ -2212,7 +2212,17 @@ def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor, reg_buffer: int,
reg_buffer_sz_bytes: int) -> None: reg_buffer_sz_bytes: int) -> None:
torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer, torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer,
reg_buffer_sz_bytes) reg_buffer_sz_bytes)
def all_reduce_fuse_norm(fa: int, inp: torch.Tensor, out: torch.Tensor, hidden_size: int,
                         residual: torch.Tensor, rms_weight: torch.Tensor, eps: float,
                         reg_buffer: int, reg_buffer_sz_bytes: int) -> None:
    """Forward every argument, unchanged and in order, to the
    ``_C_custom_ar.all_reduce_fuse_norm`` op. Returns nothing; the op
    writes into ``out``."""
    fused_op = torch.ops._C_custom_ar.all_reduce_fuse_norm
    fused_op(fa, inp, out, hidden_size, residual, rms_weight, eps,
             reg_buffer, reg_buffer_sz_bytes)
def all_reduce_fuse_norm_quant(fa: int, inp: torch.Tensor, out: torch.Tensor, hidden_size: int,
                               rms_weight: torch.Tensor, eps: float, scales: torch.Tensor, norm_out: torch.Tensor,
                               reg_buffer: int, reg_buffer_sz_bytes: int, residual: torch.Tensor, update_input: bool) -> None:
    """Forward every argument, unchanged and in order, to the
    ``_C_custom_ar.all_reduce_fuse_norm_quant`` op. Returns nothing; results
    are written into ``out``, ``scales`` and ``norm_out``."""
    fused_quant_op = torch.ops._C_custom_ar.all_reduce_fuse_norm_quant
    fused_quant_op(fa, inp, out, hidden_size, rms_weight, eps, scales,
                   norm_out, reg_buffer, reg_buffer_sz_bytes, residual,
                   update_input)
def dispose(fa: int) -> None: def dispose(fa: int) -> None:
torch.ops._C_custom_ar.dispose(fa) torch.ops._C_custom_ar.dispose(fa)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional, Union from typing import Any, Optional, Union, Tuple
import torch import torch
import torch.distributed import torch.distributed
from .parallel_state import get_tp_group from .parallel_state import get_tp_group
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
"""All-reduce the input tensor across model parallel group.""" """All-reduce the input tensor across model parallel group."""
return get_tp_group().all_reduce(input_) return get_tp_group().all_reduce(input_)
def tensor_model_parallel_all_reduce_crp_m32(input_: torch.Tensor,
                                             pa_rms_weight: torch.Tensor,
                                             pa_residual: torch.Tensor,
                                             pa_rms_eps: float,
                                             pa_quant_dtype: Optional[torch.dtype] = torch.int8,
                                             update_input: Optional[bool] = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """All-reduce ``input_`` across the tensor-parallel group through the
    group's fused all-reduce + RMS-norm + quantisation entry point
    (``all_reduce_crq_m32``) and return whatever that method returns."""
    tp_group = get_tp_group()
    return tp_group.all_reduce_crq_m32(
        input_=input_,
        pa_rms_weight=pa_rms_weight,
        pa_residual=pa_residual,
        pa_rms_eps=pa_rms_eps,
        pa_quant_dtype=pa_quant_dtype,
        update_input=update_input,
    )
def tensor_model_parallel_all_gather(input_: torch.Tensor, def tensor_model_parallel_all_gather(input_: torch.Tensor,
dim: int = -1) -> torch.Tensor: dim: int = -1) -> torch.Tensor:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional from typing import Optional, Tuple
import torch import torch
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
...@@ -14,6 +14,7 @@ from .base_device_communicator import DeviceCommunicatorBase ...@@ -14,6 +14,7 @@ from .base_device_communicator import DeviceCommunicatorBase
logger = init_logger(__name__) logger = init_logger(__name__)
from lmslim.quantize.quant_ops import lm_faster_rmsquant
class CudaCommunicator(DeviceCommunicatorBase): class CudaCommunicator(DeviceCommunicatorBase):
...@@ -117,6 +118,37 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -117,6 +118,37 @@ class CudaCommunicator(DeviceCommunicatorBase):
torch.distributed.all_reduce(out, group=self.device_group) torch.distributed.all_reduce(out, group=self.device_group)
return out return out
def all_reduce_rms_quant_m32(self, input_,
                             pa_rms_weight: torch.Tensor,
                             pa_residual: torch.Tensor,
                             pa_rms_eps: float,
                             pa_quant_dtype: torch.dtype,
                             update_input: Optional[bool] = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """All-reduce ``input_`` and apply RMS-norm + quantisation.

    For at most 16 rows the fused custom all-reduce kernel is tried first.
    If there are more rows, or the custom all-reduce declines the input
    (it returns ``None``), the method falls back to a plain all-reduce
    followed by ``lm_faster_rmsquant``.

    Args:
        input_: 2-D tensor ``(batch_size, hidden_dim)``.
        pa_rms_weight: RMS-norm weight; must not be ``None``.
        pa_residual: residual tensor; must not be ``None``.
        pa_rms_eps: RMS-norm epsilon.
        pa_quant_dtype: dtype of the quantised output.
        update_input: accepted for interface compatibility.
            NOTE(review): both paths always pass ``update_input=True``;
            confirm whether this argument should be honoured.

    Returns:
        ``(input_, pa_residual, xq, xs)`` where ``input_`` is the fused
        kernel's ``norm_out`` on the fused path, or the all-reduced tensor
        on the fallback path.
    """
    batch_size, _hidden_dim = input_.shape
    ca_comm = self.ca_comm
    assert ca_comm is not None and not ca_comm.disabled
    assert envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and \
        pa_rms_weight is not None and pa_residual is not None

    fused = None
    if batch_size <= 16:
        fused = ca_comm.custom_all_reduce_fuse_norm_quant(inp=input_,
                                                          rms_weight=pa_rms_weight,
                                                          residual=pa_residual,
                                                          eps=pa_rms_eps,
                                                          quant_type=pa_quant_dtype,
                                                          update_input=True)

    if fused is not None:
        xq, xs, norm_out = fused
        input_ = norm_out
    else:
        # Either the batch is too large for the fused kernel or the custom
        # all-reduce rejected this input; use the unfused sequence instead
        # of failing on unpacking None.
        input_ = self.all_reduce(input_)
        xq, xs = lm_faster_rmsquant(input_,
                                    rms_weight=pa_rms_weight,
                                    residual=pa_residual,
                                    epsilon=pa_rms_eps,
                                    quant_dtype=pa_quant_dtype,
                                    update_input=True)
    assert input_ is not None
    assert xq is not None and xs is not None
    return input_, pa_residual, xq, xs
def reduce_scatter(self, input_: torch.Tensor, dim: int = -1): def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
world_size = self.world_size world_size = self.world_size
pynccl_comm = self.pynccl_comm pynccl_comm = self.pynccl_comm
......
...@@ -275,6 +275,91 @@ class CustomAllreduce: ...@@ -275,6 +275,91 @@ class CustomAllreduce:
# latency) compared to the performance gain of using custom kernels # latency) compared to the performance gain of using custom kernels
return self.all_reduce(input, registered=False) return self.all_reduce(input, registered=False)
def allreduce_fuse_norm(self,
                        inp: torch.Tensor,
                        hidden_size: int,
                        residual: torch.Tensor,
                        rms_weight: torch.Tensor,
                        eps: float,
                        *,
                        out: torch.Tensor = None,
                        registered: bool = False):
    """Run the fused all-reduce + norm op and return the output tensor.

    ``out`` is allocated with ``torch.empty_like(inp)`` when not supplied.
    With ``registered=True`` the op receives a null staging buffer (0, 0);
    otherwise it receives this rank's buffer pointer and ``self.max_size``.
    """
    result = torch.empty_like(inp) if out is None else out
    if registered:
        staging_ptr, staging_bytes = 0, 0
    else:
        staging_ptr, staging_bytes = self.buffer_ptrs[self.rank], self.max_size
    ops.all_reduce_fuse_norm(self._ptr, inp, result, hidden_size, residual,
                             rms_weight, eps, staging_ptr, staging_bytes)
    return result
def custom_all_reduce_fuse_norm(self,
                                input: torch.Tensor,
                                hidden_size: int,
                                residual: torch.Tensor,
                                rms_weight: torch.Tensor,
                                eps: float) -> Optional[torch.Tensor]:
    """Gatekeeper around ``allreduce_fuse_norm``.

    Returns ``None`` when the communicator is disabled or
    ``should_custom_ar`` rejects ``input``. While ``_IS_CAPTURING`` is set
    but the current stream is not capturing, only an uninitialised tensor
    shaped like ``input`` is returned. In every other case the fused op runs
    with ``registered=False``.
    """
    if self.disabled or not self.should_custom_ar(input):
        return None

    # Short-circuit keeps the CUDA query from running unless _IS_CAPTURING
    # is set.
    placeholder_only = (self._IS_CAPTURING
                        and not torch.cuda.is_current_stream_capturing())
    if placeholder_only:
        return torch.empty_like(input)

    return self.allreduce_fuse_norm(input, hidden_size,
                                    residual, rms_weight, eps, registered=False)
def allreduce_fuse_norm_quant(self,
                              inp: torch.Tensor,
                              hidden_size: int,
                              rms_weight,
                              eps,
                              quant_dtype,
                              residual,
                              update_input: bool = True,
                              registered: bool = False) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Allocate the three outputs and run the fused all-reduce + norm +
    quantisation op.

    Returns ``(xq, scales, norm_out)``: ``xq`` has ``inp``'s shape in
    ``quant_dtype``, ``scales`` is float32 with one row per
    ``inp.shape[-1]``-sized row of ``inp``, and ``norm_out`` has ``inp``'s
    shape and dtype. With ``registered=True`` the op receives a null staging
    buffer (0, 0); otherwise this rank's buffer pointer and ``self.max_size``.
    """
    num_rows = inp.numel() // inp.shape[-1]
    xq = torch.empty_like(inp, dtype=quant_dtype)
    norm_out = torch.empty_like(inp)
    scales = torch.empty((num_rows, 1),
                         device=inp.device,
                         dtype=torch.float32)

    if registered:
        staging_ptr, staging_bytes = 0, 0
    else:
        staging_ptr, staging_bytes = self.buffer_ptrs[self.rank], self.max_size

    ops.all_reduce_fuse_norm_quant(self._ptr, inp, xq,
                                   hidden_size, rms_weight, eps, scales, norm_out,
                                   staging_ptr, staging_bytes, residual, update_input)
    return xq, scales, norm_out
def custom_all_reduce_fuse_norm_quant(self,
                                      inp: torch.Tensor,
                                      rms_weight: torch.Tensor,
                                      eps,
                                      quant_type,
                                      residual: Optional[torch.Tensor] = None,
                                      update_input = True
                                      ) -> Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Gatekeeper around ``allreduce_fuse_norm_quant``.

    Returns:
        ``None`` when the communicator is disabled or ``should_custom_ar``
        rejects ``inp`` — callers must handle this case. Otherwise a
        ``(xq, scales, norm_out)`` tuple. While ``_IS_CAPTURING`` is set but
        the current stream is not capturing, the tuple holds uninitialised
        tensors of the right shapes/dtypes and the op is not run.
    """
    hidden_size = inp.shape[-1]
    if self.disabled or not self.should_custom_ar(inp):
        return None

    # Short-circuit keeps the CUDA query from running unless _IS_CAPTURING
    # is set.
    if self._IS_CAPTURING and not torch.cuda.is_current_stream_capturing():
        return torch.empty_like(inp, dtype=quant_type), \
            torch.empty((inp.numel() // inp.shape[-1], 1), dtype=torch.float32, device=inp.device), \
            torch.empty_like(inp)

    return self.allreduce_fuse_norm_quant(inp, hidden_size, rms_weight,
                                          eps, quant_type, residual,
                                          update_input=update_input, registered=False)
def close(self): def close(self):
if not self.disabled and self._ptr: if not self.disabled and self._ptr:
if ops is not None: if ops is not None:
......
...@@ -30,7 +30,7 @@ from collections import namedtuple ...@@ -30,7 +30,7 @@ from collections import namedtuple
from contextlib import contextmanager, nullcontext from contextlib import contextmanager, nullcontext
from dataclasses import dataclass from dataclasses import dataclass
from multiprocessing import shared_memory from multiprocessing import shared_memory
from typing import Any, Callable, Optional, Union from typing import Any, Callable, Optional, Tuple, Union
from unittest.mock import patch from unittest.mock import patch
import torch import torch
...@@ -114,6 +114,37 @@ def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor: ...@@ -114,6 +114,37 @@ def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
return torch.empty_like(tensor) return torch.empty_like(tensor)
def all_reduce_rms_quant(input_: torch.Tensor, group_name: str,
                         pa_rms_weight: Optional[torch.Tensor] = None,
                         pa_residual: Optional[torch.Tensor] = None,
                         pa_rms_eps: Optional[float] = 1e-6,
                         pa_quant_dtype: Optional[torch.dtype] = torch.int8,
                         update_input: Optional[bool] = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Look up the group registered under ``group_name`` and delegate to its
    ``_all_reduce_out_place_m32`` method.

    Raises:
        AssertionError: ``group_name`` is not in ``_groups``.
        ValueError: the registry entry no longer yields a group.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    # Registry entries are callables that yield the group or None.
    resolve_group = _groups[group_name]
    group = resolve_group()
    if group is None:
        raise ValueError(f"Group {group_name} is destroyed.")

    fused_kwargs = dict(pa_rms_weight=pa_rms_weight,
                        pa_residual=pa_residual,
                        pa_rms_eps=pa_rms_eps,
                        pa_quant_dtype=pa_quant_dtype,
                        update_input=update_input)
    return group._all_reduce_out_place_m32(input_, **fused_kwargs)
def all_reduce_rms_quant_fake(
        input_: torch.Tensor,
        group_name: str,
        pa_rms_weight: Optional[torch.Tensor] = None,
        pa_residual: Optional[torch.Tensor] = None,
        pa_rms_eps: Optional[float] = 1e-6,
        pa_quant_dtype: Optional[torch.dtype] = torch.int8,
        update_input: Optional[bool] = True
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fake counterpart of `all_reduce_rms_quant`; performs no communication.

    Returns:
        A 4-tuple `(input_, pa_residual, xq, xs)` where
        * `input_` and `pa_residual` are the caller's objects, returned as-is;
        * `xq` is an all-zeros tensor shaped like `input_` with dtype
          `pa_quant_dtype`;
        * `xs` is an all-ones float32 tensor of shape
          `(input_.numel() // input_.shape[-1], 1)` on `input_`'s device.
    """
    quantized = torch.zeros_like(input_, dtype=pa_quant_dtype)
    # One scale entry per slice along the last dimension of `input_`.
    scale_shape = (input_.numel() // input_.shape[-1], 1)
    scales = torch.ones(scale_shape,
                        device=input_.device,
                        dtype=torch.float32)
    return input_, pa_residual, quantized, scales
def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int, def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int,
group_name: str) -> torch.Tensor: group_name: str) -> torch.Tensor:
assert group_name in _groups, f"Group {group_name} is not found." assert group_name in _groups, f"Group {group_name} is not found."
...@@ -156,6 +187,14 @@ if supports_custom_op(): ...@@ -156,6 +187,14 @@ if supports_custom_op():
dispatch_key=current_platform.dispatch_key, dispatch_key=current_platform.dispatch_key,
) )
# Register `all_reduce_rms_quant` as a custom op, pairing the real
# implementation with `all_reduce_rms_quant_fake` as its fake implementation.
direct_register_custom_op(
op_name="all_reduce_rms_quant",
op_func=all_reduce_rms_quant,
# Arguments the op is declared to mutate.
mutates_args=["input_", "pa_residual"],
fake_impl=all_reduce_rms_quant_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op( direct_register_custom_op(
op_name="reduce_scatter", op_name="reduce_scatter",
op_func=reduce_scatter, op_func=reduce_scatter,
...@@ -358,9 +397,44 @@ class GroupCoordinator: ...@@ -358,9 +397,44 @@ class GroupCoordinator:
else: else:
return self._all_reduce_out_place(input_) return self._all_reduce_out_place(input_)
def all_reduce_crq_m32(
        self,
        input_: torch.Tensor,
        pa_rms_weight: torch.Tensor,
        pa_residual: torch.Tensor,
        pa_rms_eps: float,
        pa_quant_dtype: torch.dtype,
        update_input: Optional[bool] = True
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Invoke the registered `all_reduce_rms_quant` op for this group.

    Preconditions (enforced with asserts): `self.world_size > 1`,
    `envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT` is truthy, and both
    `pa_rms_weight` and `pa_residual` are not None.

    All arguments are forwarded to `torch.ops.vllm.all_reduce_rms_quant`,
    with `self.unique_name` supplied as `group_name`; the op's result is
    returned unchanged.
    """
    assert self.world_size > 1
    assert envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
    assert pa_rms_weight is not None
    assert pa_residual is not None
    fused_op = torch.ops.vllm.all_reduce_rms_quant
    return fused_op(input_,
                    group_name=self.unique_name,
                    pa_rms_weight=pa_rms_weight,
                    pa_residual=pa_residual,
                    pa_rms_eps=pa_rms_eps,
                    pa_quant_dtype=pa_quant_dtype,
                    update_input=update_input)
def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor: def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
return self.device_communicator.all_reduce(input_) return self.device_communicator.all_reduce(input_)
def _all_reduce_out_place_m32(
        self,
        input_: torch.Tensor,
        pa_rms_weight: torch.Tensor,
        pa_residual: torch.Tensor,
        pa_rms_eps: float,
        pa_quant_dtype: torch.dtype,
        update_input: Optional[bool] = True
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Delegate to the device communicator's `all_reduce_rms_quant_m32`.

    Asserts that `envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT` is truthy and
    that `pa_rms_weight` / `pa_residual` are not None, then forwards every
    argument to `self.device_communicator.all_reduce_rms_quant_m32`.

    Returns:
        The four values produced by the communicator, repacked as a tuple
        in the same order.
    """
    assert envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
    assert pa_rms_weight is not None
    assert pa_residual is not None
    communicator = self.device_communicator
    # Unpacking requires the communicator to yield exactly four items.
    reduced, new_residual, xq, xs = communicator.all_reduce_rms_quant_m32(
        input_,
        pa_rms_weight=pa_rms_weight,
        pa_residual=pa_residual,
        pa_rms_eps=pa_rms_eps,
        pa_quant_dtype=pa_quant_dtype,
        update_input=update_input)
    return reduced, new_residual, xq, xs
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
world_size = self.world_size world_size = self.world_size
# Bypass the function if we are using only 1 GPU. # Bypass the function if we are using only 1 GPU.
......
...@@ -284,7 +284,13 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: ...@@ -284,7 +284,13 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
cached version. cached version.
""" """
return copy.deepcopy(_compute_kwargs(cls)) return copy.deepcopy(_compute_kwargs(cls))
class EnvironmentConfigError(Exception):
    """Raised by `check_incompatible_config` when both flags are True."""


def check_incompatible_config(env1: bool, env2: bool):
    """Reject the case where both flags are exactly `True`.

    Args:
        env1: First flag.
        env2: Second flag.

    Raises:
        EnvironmentConfigError: Both `env1` and `env2` are the `True`
            singleton. The comparison is by identity, so other truthy
            values do not trigger the error.
    """
    both_enabled = env1 is True and env2 is True
    if not both_enabled:
        return
    raise EnvironmentConfigError(
        "USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and USE_FUSED_RMS_QUANT must not be enabled simultaneously!\n\n"
    )
@dataclass @dataclass
class EngineArgs: class EngineArgs:
...@@ -1230,6 +1236,7 @@ class EngineArgs: ...@@ -1230,6 +1236,7 @@ class EngineArgs:
num_lookahead_slots = num_lookahead_slots \ num_lookahead_slots = num_lookahead_slots \
if speculative_config is None \ if speculative_config is None \
else speculative_config.num_lookahead_slots else speculative_config.num_lookahead_slots
check_incompatible_config(envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT, envs.USE_FUSED_RMS_QUANT)
scheduler_config = SchedulerConfig( scheduler_config = SchedulerConfig(
runner_type=model_config.runner_type, runner_type=model_config.runner_type,
......
...@@ -180,6 +180,7 @@ if TYPE_CHECKING: ...@@ -180,6 +180,7 @@ if TYPE_CHECKING:
VLLM_USE_PD_SPLIT: bool = False VLLM_USE_PD_SPLIT: bool = False
VLLM_USE_PP_SYNC: bool = False VLLM_USE_PP_SYNC: bool = False
VLLM_USE_LIGHTOP_FILL_MOE_ALIGN: bool = False VLLM_USE_LIGHTOP_FILL_MOE_ALIGN: bool = False
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1166,6 +1167,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1166,6 +1167,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_LIGHTOP_FILL_MOE_ALIGN": "VLLM_USE_LIGHTOP_FILL_MOE_ALIGN":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_FILL_MOE_ALIGN", "False").lower() in lambda: (os.environ.get("VLLM_USE_LIGHTOP_FILL_MOE_ALIGN", "False").lower() in
("true", "1")), ("true", "1")),
# vllm will use custom-allreduce rmsquant fused op
"USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT":
lambda: (os.getenv('USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT', '0').lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import itertools import itertools
from abc import abstractmethod from abc import abstractmethod
from typing import Any, Literal, Optional, Union from typing import Any, Literal, Optional, Union, Tuple
import vllm.envs as envs import vllm.envs as envs
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -14,7 +14,8 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank, ...@@ -14,7 +14,8 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
split_tensor_along_last_dim, split_tensor_along_last_dim,
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce,
tensor_model_parallel_all_reduce_crp_m32)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
...@@ -677,7 +678,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -677,7 +678,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self, input_, self, input_,
rms_weight: Optional[torch.Tensor] = None, rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
update_hd: Optional[bool] = True update_hd: Optional[bool] = True,
xqxs: Optional[tuple] = None
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if envs.USE_FUSED_RMS_QUANT and rms_weight is not None: if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
input_quant_args = None input_quant_args = None
...@@ -706,7 +708,22 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -706,7 +708,22 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
if not self.return_bias: if not self.return_bias:
return output return output
return output, new_residual, output_bias return output, new_residual, output_bias
else: # not USE_FUSED_RMS_QUANT elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias, input_quant_args=xqxs)
if self.gather_output:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
else:
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None assert self.quant_method is not None
...@@ -1495,8 +1512,56 @@ class RowParallelLinear(LinearBase): ...@@ -1495,8 +1512,56 @@ class RowParallelLinear(LinearBase):
def forward( def forward(
self, input_, self, input_,
use_fused_silu_mul_quant: Optional[bool] = False use_fused_silu_mul_quant: Optional[bool] = False,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: pa_rms_weight: Optional[torch.Tensor] = None,
pa_residual: Optional[torch.Tensor] = None,
pa_rms_eps: Optional[float] = 1e-6,
pa_quant_dtype: Optional[torch.dtype] = torch.int8,
update_input: Optional[bool] = True
) -> Union[torch.Tensor,
tuple[torch.Tensor, Optional[Parameter]],
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[Parameter]]
]:
if envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and pa_rms_weight is not None and pa_residual is not None:
if self.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply.
assert self.quant_method is not None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_)
if self.reduce_results and self.tp_size > 1:
if envs.VLLM_ENABLE_TBO:
output = self.tbo_all_reduce(output_parallel)
packages_ = tensor_model_parallel_all_reduce_crp_m32(output_parallel,
pa_rms_weight=pa_rms_weight,
pa_residual=pa_residual,
pa_rms_eps=pa_rms_eps,
pa_quant_dtype=pa_quant_dtype,
update_input=update_input)
hs, resi, xq, xs = packages_
output = hs
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, resi, xq, xs, output_bias
else:
if self.input_is_parallel: if self.input_is_parallel:
input_parallel = input_ input_parallel = input_
else: else:
......
...@@ -162,7 +162,11 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase): ...@@ -162,7 +162,11 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
assert len(input_quant_args) == 2 assert len(input_quant_args) == 2
x_q, x_scale = input_quant_args x_q, x_scale = input_quant_args
elif envs.USE_FUSED_SILU_MUL_QUANT and silu_quant_args is not None: elif envs.USE_FUSED_SILU_MUL_QUANT and silu_quant_args is not None:
assert len(silu_quant_args) == 2
x_q, x_scale = silu_quant_args x_q, x_scale = silu_quant_args
elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and input_quant_args is not None:
assert len(input_quant_args) == 2
x_q, x_scale = input_quant_args
else: else:
x_q, x_scale = per_token_quant_int8(x) x_q, x_scale = per_token_quant_int8(x)
...@@ -178,9 +182,9 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase): ...@@ -178,9 +182,9 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
if m<=16: if m<=16:
m_=m m_=m
elif m<=64: elif m<=64:
m_= (m + 3) & -4 #取值到最近的4的倍数 m_ = ((m + 3) // 4) * 4 #取值到最近的4的倍数
elif m<=160: elif m<=160:
m_=(m + 7) & -8 m_ = ((m + 7) // 8) * 8
elif m<200: #256 elif m<200: #256
m_=160 m_=160
......
...@@ -29,7 +29,7 @@ import vllm.envs as envs ...@@ -29,7 +29,7 @@ import vllm.envs as envs
import typing import typing
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from typing import Any, Optional, Union from typing import Any, Optional, Union, Tuple
import torch import torch
from torch import nn from torch import nn
...@@ -96,8 +96,8 @@ class DeepseekV2MLP(nn.Module): ...@@ -96,8 +96,8 @@ class DeepseekV2MLP(nn.Module):
def forward(self, x, def forward(self, x,
rms_weight: Optional[torch.Tensor] = None, rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
update_hd: Optional[bool] = False update_hd: Optional[bool] = False,
): xqxs: Optional[tuple[torch.Tensor, torch.Tensor]] = None):
if envs.USE_FUSED_RMS_QUANT: if envs.USE_FUSED_RMS_QUANT:
gate_up, new_resi, _ = self.gate_up_proj(x, rms_weight, residual, update_hd=update_hd) gate_up, new_resi, _ = self.gate_up_proj(x, rms_weight, residual, update_hd=update_hd)
if envs.USE_FUSED_SILU_MUL_QUANT: if envs.USE_FUSED_SILU_MUL_QUANT:
...@@ -107,6 +107,11 @@ class DeepseekV2MLP(nn.Module): ...@@ -107,6 +107,11 @@ class DeepseekV2MLP(nn.Module):
x, _ = self.down_proj(x) x, _ = self.down_proj(x)
return x, new_resi return x, new_resi
elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
gate_up, _ = self.gate_up_proj(x, xqxs=xqxs)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
else: else:
gate_up, _ = self.gate_up_proj(x) gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up) x = self.act_fn(gate_up)
...@@ -200,20 +205,62 @@ class DeepseekV2MoE(nn.Module): ...@@ -200,20 +205,62 @@ class DeepseekV2MoE(nn.Module):
def forward(self, hidden_states: torch.Tensor, def forward(self, hidden_states: torch.Tensor,
rms_weight: Optional[torch.Tensor] = None, rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None residual: Optional[torch.Tensor] = None,
xqxs: Optional[tuple[torch.Tensor, torch.Tensor]] = None
) -> torch.Tensor: ) -> torch.Tensor:
if envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states, xqxs=xqxs)
router_logits, _ = self.gate(hidden_states)
if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
shared_output=shared_output)
else:
if hidden_states.dtype != torch.float16:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if shared_output is not None:
if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
if self.tp_size > 1:
if envs.VLLM_ENABLE_TBO:
final_hidden_states = self.tbo_all_reduce(final_hidden_states)
else:
final_hidden_states = (
self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states))
return final_hidden_states.view(num_tokens, hidden_dim)
else:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if self.n_shared_experts is not None: if self.n_shared_experts is not None:
if envs.USE_FUSED_RMS_QUANT: if envs.USE_FUSED_RMS_QUANT:
shared_output, new_resi = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True) shared_output, new_resi = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True)
else: else:
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD: if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD:
final_hidden_states = self.experts( final_hidden_states = self.experts(
hidden_states=hidden_states, hidden_states=hidden_states,
...@@ -556,8 +603,16 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -556,8 +603,16 @@ class DeepseekV2MLAAttention(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
rms_weight: Optional[torch.Tensor] = None, rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None residual: Optional[torch.Tensor] = None,
) -> torch.Tensor: pa_rms_weight: Optional[torch.Tensor] = None,
pa_residual: Optional[torch.Tensor] = None,
pa_rms_eps: Optional[float] = 1e-6,
pa_quant_dtype: Optional[torch.dtype] = torch.int8,
update_input: Optional[bool] = True
) -> Union[torch.Tensor,
Tuple[torch.Tensor, torch.Tensor],
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
]:
if envs.USE_FUSED_RMS_QUANT and rms_weight is not None: if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
if self.q_lora_rank is not None: if self.q_lora_rank is not None:
q_c, new_residual, _, input_quant_args = self.q_a_proj(hidden_states, rms_weight=rms_weight, residual=residual, update_hd=False) q_c, new_residual, _, input_quant_args = self.q_a_proj(hidden_states, rms_weight=rms_weight, residual=residual, update_hd=False)
...@@ -587,6 +642,40 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -587,6 +642,40 @@ class DeepseekV2MLAAttention(nn.Module):
output_shape=(hidden_states.shape[0], output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim)) self.num_local_heads * self.v_head_dim))
return self.o_proj(attn_out)[0], new_residual return self.o_proj(attn_out)[0], new_residual
elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and pa_rms_weight is not None and pa_residual is not None:
if self.q_lora_rank is not None:
q_c = self.q_a_proj(hidden_states)[0]
q_c = self.q_a_layernorm(q_c)
q = self.q_b_proj(q_c)[0]
else:
q = self.q_proj(hidden_states)[0]
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
q = q.view(-1, self.num_local_heads, self.qk_head_dim)
k_pe = k_pe.unsqueeze(1)
q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
positions, q[..., self.qk_nope_head_dim:], k_pe)
attn_out = self.mla_attn(
q,
kv_c_normed,
k_pe,
output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim))
packages_ = self.o_proj(attn_out,
pa_rms_weight=pa_rms_weight,
pa_residual=pa_residual,
pa_rms_eps=pa_rms_eps,
pa_quant_dtype=pa_quant_dtype,
update_input=update_input)[:4]
assert len(packages_) == 4
hs, resi, xq, xs = packages_
assert xq is not None and xs is not None
return hs, resi, xq, xs
else: else:
if self.q_lora_rank is not None: if self.q_lora_rank is not None:
q_c = self.q_a_proj(hidden_states)[0] q_c = self.q_a_proj(hidden_states)[0]
...@@ -682,14 +771,15 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -682,14 +771,15 @@ class DeepseekV2DecoderLayer(nn.Module):
self.post_attention_layernorm = RMSNorm(config.hidden_size, self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
self.routed_scaling_factor = config.routed_scaling_factor self.routed_scaling_factor = config.routed_scaling_factor
self.use_fused_rms_quant = envs.USE_FUSED_RMS_QUANT
self.use_fused_custom_all_reduce = envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
def forward( def forward_fused_rmsquant(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor]
) -> torch.Tensor: ) -> Tuple[torch.Tensor, torch.Tensor]:
if envs.USE_FUSED_RMS_QUANT:
# Fix residual FP16 overflow # Fix residual FP16 overflow
residual_fix_overflow = False residual_fix_overflow = False
...@@ -732,7 +822,51 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -732,7 +822,51 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states *= 1. / self.routed_scaling_factor hidden_states *= 1. / self.routed_scaling_factor
return hidden_states, new_resi return hidden_states, new_resi
# Decoder-layer forward used when the custom all-reduce RMS-quant path is
# selected (see `choose_forward`). Returns `(hidden_states, new_resi)`.
# NOTE(review): indentation is lost in this rendering, so block nesting below
# has to be confirmed against the real file.
def forward_fused_CRQ(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
# Set to True only when no residual was passed in.
residual_fix_overflow = False
if residual is None:
# No incoming residual: keep the un-normalized input as the residual.
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
residual_fix_overflow = True
# NOTE(review): the doubled `else:` looks like a side-by-side diff artifact
# (old and new column on one line) — confirm.
else: else:
# With a residual, input_layernorm returns (normed hidden_states, residual).
hidden_states, resi_new = self.input_layernorm(
hidden_states, residual)
residual = resi_new
# Attention is given the post-attention layernorm weight/eps and the residual,
# and returns four values: hidden states, residual, and the (xq, xs) pair.
new_hs, new_resi, xq, xs = self.self_attn(
positions=positions,
hidden_states=hidden_states,
pa_rms_weight=self.post_attention_layernorm.weight.data,
pa_residual=residual,
pa_rms_eps=self.post_attention_layernorm.variance_epsilon,
pa_quant_dtype = torch.int8,
update_input=True
)
assert xq is not None and xs is not None
if new_hs.dtype == torch.float16: # overflow handling logic
new_hs *= 1. / self.routed_scaling_factor
# NOTE(review): presumably nested under the float16 branch above — confirm.
if self.layer_idx == 0 or residual_fix_overflow:
new_resi *= 1. / self.routed_scaling_factor
# The (xq, xs) pair from attention is handed to the MLP via `xqxs`.
hidden_states = self.mlp(new_hs, xqxs=(xq, xs))
# Dense-MLP layers in float16 get the same 1/routed_scaling_factor rescale.
if isinstance(self.mlp,
DeepseekV2MLP) and hidden_states.dtype == torch.float16:
hidden_states *= 1. / self.routed_scaling_factor
return hidden_states, new_resi
def forward_default(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
# Fix residual FP16 overflow # Fix residual FP16 overflow
residual_fix_overflow = False residual_fix_overflow = False
...@@ -774,6 +908,26 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -774,6 +908,26 @@ class DeepseekV2DecoderLayer(nn.Module):
return hidden_states, residual return hidden_states, residual
def choose_forward(self):
    """Return the bound forward implementation for this layer.

    Priority: `forward_fused_rmsquant` when `self.use_fused_rms_quant` is
    truthy, otherwise `forward_fused_CRQ` when
    `self.use_fused_custom_all_reduce` is truthy, otherwise
    `forward_default`.
    """
    if self.use_fused_rms_quant:
        return self.forward_fused_rmsquant
    if self.use_fused_custom_all_reduce:
        return self.forward_fused_CRQ
    return self.forward_default
def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the forward variant chosen by `choose_forward`.

    The three arguments are passed on by keyword and the selected
    implementation's result is returned unchanged.
    """
    selected = self.choose_forward()
    call_kwargs = {
        "positions": positions,
        "hidden_states": hidden_states,
        "residual": residual,
    }
    return selected(**call_kwargs)
@support_torch_compile @support_torch_compile
class DeepseekV2Model(nn.Module): class DeepseekV2Model(nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment