Commit ad74f612 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev' of http://10.16.6.30/dcutoolkit/deeplearing/vllm into v0.9.2-dev

parents 7270e6d9 46006aee
......@@ -59,6 +59,144 @@ bool _is_weak_contiguous(torch::Tensor& t) {
* Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
* copied into _reg_buffer.
*/
// Torch-facing entry point for the fused custom all-reduce + RMSNorm kernel.
//
// Validates the tensors, optionally stages `inp` into the registered buffer,
// and then dispatches CustomAllreduce::allreduce_fuse_norm<T> on the element
// type of `out`.
//
// _fa                  CustomAllreduce instance, passed as an integer handle.
// inp / out            Input and output tensors; same dtype and numel.
// hidden_size          Row width; the token count is inp.numel() / hidden_size.
// residual, rms_weight Forwarded to the kernel and reinterpreted as the same
//                      element type as `out`, so their dtypes must match it.
// eps                  RMSNorm epsilon (narrowed to float for the kernel).
// _reg_buffer          If non-zero, `inp` is copied into this buffer (which
//                      must hold at least the input's byte size) and the
//                      kernel reads from it; if zero, the kernel reads
//                      directly from inp.data_ptr().
// reg_buffer_sz_bytes  Capacity of `_reg_buffer` in bytes.
//
// Throws std::runtime_error for element types other than float32, float16
// and bfloat16.
void all_reduce_fuse_norm(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
                          int64_t hidden_size, torch::Tensor& residual,
                          torch::Tensor& rms_weight, double eps,
                          fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();

  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  TORCH_CHECK(_is_weak_contiguous(out));
  TORCH_CHECK(_is_weak_contiguous(inp));
  TORCH_CHECK(_is_weak_contiguous(residual));
  TORCH_CHECK(_is_weak_contiguous(rms_weight));

  // The kernel receives residual/rms_weight as raw pointers of out's element
  // type, so a dtype or size mismatch would silently corrupt memory.
  TORCH_CHECK(residual.scalar_type() == out.scalar_type(),
              "residual must have the same dtype as out");
  TORCH_CHECK(rms_weight.scalar_type() == out.scalar_type(),
              "rms_weight must have the same dtype as out");
  TORCH_CHECK(hidden_size > 0, "hidden_size must be positive, got ",
              hidden_size);
  TORCH_CHECK(inp.numel() % hidden_size == 0,
              "inp.numel() must be a multiple of hidden_size");
  TORCH_CHECK_EQ(residual.numel(), inp.numel());
  TORCH_CHECK_EQ(rms_weight.numel(), hidden_size);

  const int token_num = static_cast<int>(inp.numel() / hidden_size);
  auto input_size = inp.numel() * inp.element_size();
  auto reg_buffer = reinterpret_cast<void*>(_reg_buffer);
  if (reg_buffer) {
    // Stage the input into the registered buffer on the current stream.
    TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes);
    AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.data_ptr(), input_size,
                                  cudaMemcpyDeviceToDevice, stream));
  } else {
    reg_buffer = inp.data_ptr();
  }

  const float eps_f = static_cast<float>(eps);
  switch (out.scalar_type()) {
    case at::ScalarType::Float: {
      fa->allreduce_fuse_norm<float>(
          stream, static_cast<float*>(reg_buffer),
          static_cast<float*>(out.data_ptr()), out.numel(), token_num,
          hidden_size, static_cast<float*>(residual.data_ptr()),
          static_cast<float*>(rms_weight.data_ptr()), eps_f);
      break;
    }
    case at::ScalarType::Half: {
      fa->allreduce_fuse_norm<half>(
          stream, static_cast<half*>(reg_buffer),
          static_cast<half*>(out.data_ptr()), out.numel(), token_num,
          hidden_size, static_cast<half*>(residual.data_ptr()),
          static_cast<half*>(rms_weight.data_ptr()), eps_f);
      break;
    }
    // #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
    case at::ScalarType::BFloat16: {
      fa->allreduce_fuse_norm<nv_bfloat16>(
          stream, static_cast<nv_bfloat16*>(reg_buffer),
          static_cast<nv_bfloat16*>(out.data_ptr()), out.numel(), token_num,
          hidden_size, static_cast<nv_bfloat16*>(residual.data_ptr()),
          static_cast<nv_bfloat16*>(rms_weight.data_ptr()), eps_f);
      break;
    }
    // #endif
    default:
      throw std::runtime_error(
          "custom allreduce only supports float32, float16 and bfloat16");
  }
}
// Launches the fused all-reduce + RMSNorm + quantization kernel for a fixed
// input element type.
//
// Template parameters:
//   scalar_in_t   Element type of inp / rms_weight / norm_out / residual.
//   update_input  Forwarded unchanged as the last template argument of
//                 CustomAllreduce::allreduce_fuse_norm_quant.
//
// The output element type is selected at run time from out.scalar_type() via
// VLLM_DISPATCH_QUANT_TYPES. When `residual` holds a tensor its data pointer
// is passed to the kernel; otherwise the kernel receives nullptr.
//
// All preconditions are checked before the staging copy is enqueued, so a
// rejected call leaves no work on the stream.
//
// Throws std::runtime_error if rms_weight is not 16-byte aligned or the
// CustomAllreduce instance is not fully connected.
template <typename scalar_in_t, bool update_input>
void allreduce_fuse_norm_quant_dispath(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
                                       int hidden_size, torch::Tensor& rms_weight, double eps,
                                       torch::Tensor& scales, torch::Tensor& norm_out,
                                       fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes,
                                       std::optional<at::Tensor> residual) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();

  TORCH_CHECK(_is_weak_contiguous(inp));
  TORCH_CHECK(hidden_size > 0, "hidden_size must be positive, got ",
              hidden_size);

  const auto wt_addr = reinterpret_cast<std::uintptr_t>(rms_weight.data_ptr());
  if (wt_addr % 16 != 0) {
    throw std::runtime_error(
        "custom allreduce currently requires rms_weight to be 16-byte "
        "aligned, but its address % 16 is " +
        std::to_string(wt_addr % 16));
  }
  if (!fa->fully_connected_) {
    throw std::runtime_error(
        "custom allreduce only supports fully_connected");
  }

  const int token_num = static_cast<int>(inp.numel() / hidden_size);
  auto input_size = inp.numel() * inp.element_size();
  auto reg_buffer = reinterpret_cast<void*>(_reg_buffer);
  if (reg_buffer) {
    // Stage the input into the registered buffer on the current stream.
    TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes);
    AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.data_ptr(), input_size,
                                  cudaMemcpyDeviceToDevice, stream));
  } else {
    reg_buffer = inp.data_ptr();
  }

  if (residual.has_value()) {
    VLLM_DISPATCH_QUANT_TYPES(
        out.scalar_type(), "fa->allreduce_fuse_norm_quant", [&] {
          fa->allreduce_fuse_norm_quant<scalar_in_t, scalar_t, true, update_input>(
              stream, reinterpret_cast<scalar_in_t*>(reg_buffer),
              out.data_ptr<scalar_t>(), out.numel(), token_num, hidden_size,
              residual->data_ptr<scalar_in_t>(),
              rms_weight.data_ptr<scalar_in_t>(),
              norm_out.data_ptr<scalar_in_t>(), eps,
              scales.data_ptr<float>());
        });
  } else {
    VLLM_DISPATCH_QUANT_TYPES(
        out.scalar_type(), "fa->allreduce_fuse_norm_quant", [&] {
          fa->allreduce_fuse_norm_quant<scalar_in_t, scalar_t, false, update_input>(
              stream, reinterpret_cast<scalar_in_t*>(reg_buffer),
              out.data_ptr<scalar_t>(), out.numel(), token_num, hidden_size,
              nullptr, rms_weight.data_ptr<scalar_in_t>(),
              norm_out.data_ptr<scalar_in_t>(), eps,
              scales.data_ptr<float>());
        });
  }
}
// Torch-facing entry point for the fused all-reduce + RMSNorm + quantization
// op. Checks the output/scale dtypes and contiguity, then dispatches on the
// floating-point type of `inp` and on the run-time `update_input` flag to the
// matching allreduce_fuse_norm_quant_dispath instantiation.
//
// `out` must be Float8_e4m3fn or int8, `scales` must be float32, and `inp` /
// `out` must be contiguous with equal numel.
void all_reduce_fuse_norm_quant(fptr_t fa, torch::Tensor& inp, torch::Tensor& out,
                                int64_t hidden_size, torch::Tensor& rms_weight, double eps,
                                torch::Tensor& scales, torch::Tensor& norm_out,
                                fptr_t reg_buffer, int64_t reg_buffer_sz_bytes,
                                std::optional<at::Tensor> residual, bool update_input) {
  constexpr c10::ScalarType kFp8Type = c10::ScalarType::Float8_e4m3fn;

  TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
  TORCH_CHECK(scales.dtype() == torch::kFloat32);
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  TORCH_CHECK(out.is_contiguous() && inp.is_contiguous());

  VLLM_DISPATCH_FLOATING_TYPES(
      inp.scalar_type(), "allreduce_fuse_norm_quant_dispath", [&] {
        // `update_input` is a template parameter of the dispatcher, so the
        // run-time flag has to be turned into a compile-time branch here.
        if (update_input) {
          allreduce_fuse_norm_quant_dispath<scalar_t, true>(
              fa, inp, out, hidden_size, rms_weight, eps, scales, norm_out,
              reg_buffer, reg_buffer_sz_bytes, residual);
        } else {
          allreduce_fuse_norm_quant_dispath<scalar_t, false>(
              fa, inp, out, hidden_size, rms_weight, eps, scales, norm_out,
              reg_buffer, reg_buffer_sz_bytes, residual);
        }
      });
}
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
......
This diff is collapsed.
......@@ -18,6 +18,7 @@
#include <limits>
#include <vector>
#include <random>
#include "cuda_profiler_api.h"
#include "custom_all_reduce.cuh"
......@@ -117,16 +118,113 @@ __global__ void gen_data(curandState_t* state, T* data, double* ground_truth,
ground_truth[idx] = sum;
}
}
/*************************************************/
// Shuffle-based tree reduction over `reducesize` adjacent lanes.
//
// Each step adds the value held by the lane `delta` positions higher, with
// `delta` halving from reducesize/2 down to 1. After the loop the first lane
// of the group holds the sum of the group's inputs; the values left in the
// other lanes are partial sums and should not be relied on.
// The default of 64 presumably matches the wavefront width of the target
// device -- TODO confirm.
template <typename T, int reducesize = 64>
__inline__ __device__ T WarpReduceSum_NEW(T val) {
#pragma unroll
  for (int delta = reducesize >> 1; delta >= 1; delta /= 2) {
    val += __shfl_down(val, delta);
  }
  return val;
}
// Two-stage sum over a thread block of `block_size` threads, organised as
// groups of 64 lanes.
//
// Stage 1: every 64-lane group reduces its values with WarpReduceSum_NEW.
// Stage 2: lane 0 of each group stores its partial sum in `shared`, and after
//          a block-wide barrier the first group reduces those partials.
//
// `shared` must provide at least block_size / 64 elements.
// NOTE: only thread 0 is guaranteed to return the complete block sum; every
// other thread returns an intermediate value.
template <typename T,int block_size=512>
__inline__ __device__ T BlockReduceSum_NEW(T val, T* shared) {
// Number of 64-lane groups in the block == number of partial sums.
constexpr int share_size=block_size/64;
val = WarpReduceSum_NEW<T>(val);
if constexpr(block_size==64)
{
// Single group: the stage-1 result is already the block result.
return val;
}
else{
const int lid = threadIdx.x % 64;
const int wid = threadIdx.x / 64;
if (lid == 0&&wid<share_size) {
shared[wid] = val;
}
// All partial sums must be visible before the first group reads them.
__syncthreads();
// NOTE(review): the shuffle below runs with only the first `share_size`
// lanes of group 0 active; lane 0's result depends only on those lanes.
if (wid == 0&&lid<share_size) {
val = WarpReduceSum_NEW<T,share_size>(shared[lid]);
}
return val;
}
}
// Reference kernel: in-place residual add followed by RMS normalisation.
//
// One thread block handles one row (blockIdx.x) of `cols` elements, and each
// thread handles one chunk of `Vec` consecutive elements of that row:
//   residual[row] += input[row]
//   rstd          = rsqrt(mean(residual[row]^2) + eps)
//   input[row]    = residual[row] * rstd * gamma
// Both `input` and `residual` are overwritten.
//
// Requirements visible from the code:
//   * cols / Vec is computed with integer division, so cols is assumed to be
//     a multiple of Vec.
//   * There is no loop over chunks: only the first blockDim.x * Vec columns
//     of a row are processed, so the launch needs blockDim.x >= cols / Vec.
//   * `block_size` must match the launch's blockDim.x (it sizes the shared
//     scratch used by BlockReduceSum_NEW).
// NOTE(review): the vector load/store type comes from
// vllm::packed_t<scalar_t>::P, not from `Vec`. If that packed type is not
// exactly Vec elements wide, the accesses through LoadT over- or under-run
// the Vec-sized local arrays -- confirm for every (scalar_t, Vec) pair that
// is instantiated.
// NOTE(review): the local arrays are accessed through LoadT* without an
// explicit alignment specifier -- confirm this is safe for LoadT.
template <typename scalar_t,typename T_ACC,int Vec=4,int block_size=512>
__global__ void fused_add_rms_kernel_opt(scalar_t* input,scalar_t* residual,scalar_t* gamma,int cols,T_ACC eps)
{
constexpr int share_size=block_size/64;
// Scratch for the block reduction and the broadcast reciprocal RMS.
__shared__ T_ACC val_shared[share_size];
__shared__ T_ACC s_rstd;
// Per-thread partial sum of squares.
T_ACC val=0;
// i: row handled by this block, j: chunk index handled by this thread.
int i=blockIdx.x;
int j=threadIdx.x;
// Number of Vec-wide chunks per row.
int tcol=cols/Vec;
using LoadT = typename vllm::packed_t<scalar_t>::P;
scalar_t intput_vec[Vec];
scalar_t residual_vec[Vec];
T_ACC trstd;
// Element offset of this thread's chunk within the flattened tensor.
int64_t idx = i * tcol + j;
idx*=Vec;
if (j < tcol) {
*(LoadT*)intput_vec = *(LoadT*)(input+idx);
*(LoadT*)residual_vec = *(LoadT*)(residual+idx);
#pragma unroll
for (int ii = 0; ii < Vec; ii++) {
// The add happens in scalar_t; only the square is accumulated in T_ACC.
residual_vec[ii]+=intput_vec[ii];
val += static_cast<T_ACC>(residual_vec[ii])*static_cast<T_ACC>(residual_vec[ii]);
}
}
// Threads with j >= tcol contribute 0 to the sum.
val = BlockReduceSum_NEW<T_ACC,block_size>(val,val_shared);
// Only thread 0 holds the full sum; it publishes rstd through shared memory.
if (j == 0) s_rstd=rsqrtf(val/cols + eps);
__syncthreads();
trstd=s_rstd;
if (j < tcol) {
#pragma unroll
for(int ii=0;ii<Vec;ii++){
// Column index within the row, used to pick the matching gamma element.
int jj=j*Vec+ii;
intput_vec[ii] = static_cast<T_ACC>(residual_vec[ii]) *trstd* static_cast<T_ACC>(gamma[jj]);
}
// Write back the updated residual and the normalised output.
*(LoadT*)(residual+idx)=*(LoadT*)residual_vec;
*(LoadT*)(input+idx)=*(LoadT*)intput_vec;
}
}
// Picks a fused_add_rms_kernel_opt instantiation (vector width and block
// size) from the row width and, for rows of up to 4096 elements, from the
// number of tokens, then launches it on `stream` with one block per token.
//
//   hidden_size <= 1024                  -> Vec 8,  128 threads
//   hidden_size <= 2048                  -> Vec 8,  256 threads
//   hidden_size <= 4096, tokens >  1200  -> Vec 8,  512 threads
//   hidden_size <= 4096, tokens <= 1200  -> Vec 4,  1024 threads
//   hidden_size <= 8192                  -> Vec 8,  1024 threads
//   otherwise                            -> Vec 16, 1024 threads
//
// Accumulation is always done in float; `eps` is converted on the call.
template <typename scalar_t>
void fused_add_rms_norm_choose(cudaStream_t stream, scalar_t* self_data, scalar_t* other_data,
                               scalar_t* weight_data, double eps, int hidden_size, int num_tokens) {
  if (hidden_size <= 1024) {
    fused_add_rms_kernel_opt<scalar_t, float, 8, 128>
        <<<num_tokens, 128, 0, stream>>>(self_data, other_data, weight_data, hidden_size, eps);
    return;
  }
  if (hidden_size <= 2048) {
    fused_add_rms_kernel_opt<scalar_t, float, 8, 256>
        <<<num_tokens, 256, 0, stream>>>(self_data, other_data, weight_data, hidden_size, eps);
    return;
  }
  if (hidden_size <= 4096) {
    const bool many_tokens = num_tokens > 1200;
    if (many_tokens) {
      fused_add_rms_kernel_opt<scalar_t, float, 8, 512>
          <<<num_tokens, 512, 0, stream>>>(self_data, other_data, weight_data, hidden_size, eps);
    } else {
      fused_add_rms_kernel_opt<scalar_t, float, 4, 1024>
          <<<num_tokens, 1024, 0, stream>>>(self_data, other_data, weight_data, hidden_size, eps);
    }
    return;
  }
  if (hidden_size <= 8192) {
    fused_add_rms_kernel_opt<scalar_t, float, 8, 1024>
        <<<num_tokens, 1024, 0, stream>>>(self_data, other_data, weight_data, hidden_size, eps);
    return;
  }
  fused_add_rms_kernel_opt<scalar_t, float, 16, 1024>
      <<<num_tokens, 1024, 0, stream>>>(self_data, other_data, weight_data, hidden_size, eps);
}
/*****************************************************************/
template <typename T>
void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
int data_size, bool performance_test) {
T* result;
int data_size, bool performance_test, int hidden_dim) {
T* result_ori, *result_fuse;
cudaStream_t stream;
CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
CUDACHECK(cudaMalloc(&result, data_size * sizeof(T)));
CUDACHECK(cudaMemset(result, 0, data_size * sizeof(T)));
CUDACHECK(cudaMalloc(&result_ori, data_size * sizeof(T)));
CUDACHECK(cudaMemset(result_ori, 0, data_size * sizeof(T)));
CUDACHECK(cudaMalloc(&result_fuse, data_size * sizeof(T)));
CUDACHECK(cudaMemset(result_fuse, 0, data_size * sizeof(T)));
cudaIpcMemHandle_t self_data_handle;
cudaIpcMemHandle_t data_handles[8];
vllm::Signal* buffer;
......@@ -176,7 +274,7 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
sizeof(vllm::Signal) + data_size * sizeof(T));
// hack buffer registration
{
void* data[8];
void* data[8]; //gpu数据部分
for (int i = 0; i < nRanks; i++) {
data[i] =
((char*)ipc_ptrs[i]) + sizeof(vllm::Signal) + data_size * sizeof(T);
......@@ -196,7 +294,26 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
cudaEvent_t start, stop;
CUDACHECK(cudaEventCreate(&start));
CUDACHECK(cudaEventCreate(&stop));
/*******************************/
int token_num = data_size / hidden_dim;
T* residual_h, *residual_d, *weight_h, *weight_d;
residual_h = (T*)malloc(data_size * sizeof(T));
std::random_device rd; // 用于获取随机数种子
std::mt19937 gen(7);
std::uniform_real_distribution<float> dis(-3.0f, 3.0f);
for (int i = 0; i < data_size; ++i)
residual_h[i] = static_cast<T>(dis(gen));
for (int i = 0; i < hidden_dim; ++i)
weight_h[i] = static_cast<T>(dis(gen));
cudaMalloc((void**)&residual_d, sizeof(T)*data_size);
cudaMalloc((void**)&weight_d, sizeof(T)*hidden_dim);
cudaMemcpyAsync(residual_d, residual_h, sizeof(T)*data_size, cudaMemcpyHostToDevice, stream);
cudaMemcpyAsync(weight_d, weight_h, sizeof(T)*hidden_dim, cudaMemcpyHostToDevice, stream);
float eps = 1.0f;
/*******************************/
ncclDataType_t ncclDtype;
if (std::is_same<T, half>::value) {
ncclDtype = ncclFloat16;
......@@ -211,16 +328,16 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
if (performance_test) {
dummy_kernel<<<1, 1, 0, stream>>>();
constexpr int warmup_iters = 5;
constexpr int num_iters = 100;
constexpr int num_iters = 10;
// warmup
for (int i = 0; i < warmup_iters; i++) {
NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum,
comm, stream));
fa.allreduce<T>(stream, self_data, result_ori, data_size, threads, block_limit);
fused_add_rms_norm_choose<T>(stream, result_ori, residual_d, weight_d, 1.0, hidden_dim, token_num);
}
CUDACHECK(cudaEventRecord(start, stream));
for (int i = 0; i < num_iters; i++) {
NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum,
comm, stream));
fa.allreduce<T>(stream, self_data, result_ori, data_size, threads, block_limit);
fused_add_rms_norm_choose<T>(stream, result_ori, residual_d, weight_d, 1.0, hidden_dim, token_num);
}
CUDACHECK(cudaEventRecord(stop, stream));
CUDACHECK(cudaStreamSynchronize(stream));
......@@ -230,13 +347,15 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
dummy_kernel<<<1, 1, 0, stream>>>();
// warm up
for (int i = 0; i < warmup_iters; i++) {
fa.allreduce<T>(stream, self_data, result, data_size, threads,
block_limit);
fa.allreduce_fuse_norm<T>(stream, self_data, result_fuse, data_size, token_num,
hidden_dim, residual_d, weight_d, eps,
threads, block_limit);
}
CUDACHECK(cudaEventRecord(start, stream));
for (int i = 0; i < num_iters; i++) {
fa.allreduce<T>(stream, self_data, result, data_size, threads,
block_limit);
fa.allreduce_fuse_norm<T>(stream, self_data, result_fuse, data_size, token_num,
hidden_dim, residual_d, weight_d, eps,
threads, block_limit);
}
CUDACHECK(cudaEventRecord(stop, stream));
CUDACHECK(cudaStreamSynchronize(stream));
......@@ -245,7 +364,7 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
cudaEventElapsedTime(&duration_ms, start, stop);
if (myRank == 0)
printf(
"Rank %d done, nGPUs:%d, sz (kb): %d, %d, %d, my time:%.2fus, nccl "
"Rank %d done, nGPUs:%d, sz (kb): %d, %d, %d, allreduse_fuse_norm time:%.2fus, allreduce+norm "
"time:%.2fus\n",
myRank, nRanks, data_size * sizeof(T) / 1024, threads, block_limit,
duration_ms * 1e3 / num_iters, allreduce_ms * 1e3 / num_iters);
......@@ -255,8 +374,9 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
NCCLCHECK(ncclAllReduce(self_data_copy, self_data, data_size, ncclDtype,
ncclSum, comm, stream));
fused_add_rms_norm_choose<T>(stream, self_data, residual_d, weight_d, 1.0, hidden_dim, token_num);
convert_data<T><<<108, 1024, 0, stream>>>(self_data, result, nccl_result,
convert_data<T><<<108, 1024, 0, stream>>>(result_ori, result_fuse, nccl_result,
my_result, data_size);
CUDACHECK(cudaStreamSynchronize(stream));
......@@ -279,13 +399,13 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
<< " me: " << my_diffs / data_size << std::endl;
} else {
for (int i = 0; i < 100; i++) {
fa.allreduce<T>(stream, self_data, result, data_size, threads,
fa.allreduce<T>(stream, self_data, result_ori, data_size, threads,
block_limit);
CUDACHECK(cudaStreamSynchronize(stream));
NCCLCHECK(ncclAllReduce(self_data, self_data_copy, data_size, ncclDtype,
ncclSum, comm, stream));
convert_data<T><<<108, 1024, 0, stream>>>(
self_data_copy, result, nccl_result, my_result, data_size);
self_data_copy, result_ori, nccl_result, my_result, data_size);
CUDACHECK(cudaStreamSynchronize(stream));
for (unsigned long j = 0; j < data_size; j++) {
......@@ -312,7 +432,8 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
// << " me: " << my_diffs / data_size << std::endl;
}
CUDACHECK(cudaFree(result));
CUDACHECK(cudaFree(result_ori));
CUDACHECK(cudaFree(result_fuse));
CUDACHECK(cudaFree(self_data_copy));
CUDACHECK(cudaFree(rank_data));
CUDACHECK(cudaFree(buffer));
......@@ -351,9 +472,7 @@ int main(int argc, char** argv) {
const int block_limit = 36;
#endif
// Scan through different sizes to test performance.
for (int sz = 512; sz <= (8 << 20); sz *= 2) {
run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 47, performance_test);
}
run<half>(myRank, nRanks, comm, 512, 36, 7168 * 80, performance_test, 7168);
cudaProfilerStop();
MPICHECK(MPI_Finalize());
......
......@@ -487,6 +487,17 @@ fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
bool fully_connected);
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
fptr_t reg_buffer, int64_t reg_buffer_sz_bytes);
// Custom all-reduce fused with RMSNorm. `reg_buffer` is an integer-encoded
// pointer to a staging buffer of `reg_buffer_sz_bytes` bytes (0 means the
// kernel reads from `inp` directly).
void all_reduce_fuse_norm(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
int64_t hidden_size, torch::Tensor& residual, torch::Tensor& rms_weight,
double eps, fptr_t reg_buffer, int64_t reg_buffer_sz_bytes);
// Custom all-reduce fused with RMSNorm and quantization. `out` receives the
// quantized values, `scales` the float32 scales and `norm_out` the
// unquantized result; `residual` is optional.
void all_reduce_fuse_norm_quant(fptr_t fa, torch::Tensor& inp, torch::Tensor& out,
int64_t hidden_size,torch::Tensor& rms_weight, double eps,
torch::Tensor& scales, torch::Tensor& norm_out,
fptr_t reg_buffer, int64_t reg_buffer_sz_bytes,
std::optional<at::Tensor> residual, bool update_input);
void dispose(fptr_t _fa);
int64_t meta_size();
void register_buffer(fptr_t _fa, const std::vector<int64_t>& fake_ipc_ptrs);
......
......@@ -933,6 +933,17 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
"all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
"int reg_buffer_sz_bytes) -> ()");
custom_ar.impl("all_reduce", torch::kCUDA, &all_reduce);
custom_ar.def(
"all_reduce_fuse_norm(int fa, Tensor inp, Tensor! out, int hidden_size, "
"Tensor residual, Tensor rms_weight, float eps, int reg_buffer, "
"int reg_buffer_sz_bytes) -> ()");
custom_ar.impl("all_reduce_fuse_norm", torch::kCUDA, &all_reduce_fuse_norm);
custom_ar.def(
"all_reduce_fuse_norm_quant(int fa, Tensor inp, Tensor! out, int hidden_size, "
"Tensor rms_weight, float eps, Tensor! scales, Tensor! norm_out, int reg_buffer, "
"int reg_buffer_sz_bytes, Tensor? residual, bool update_input) -> ()");
custom_ar.impl("all_reduce_fuse_norm_quant", torch::kCUDA, &all_reduce_fuse_norm_quant);
custom_ar.def("dispose", &dispose);
custom_ar.def("meta_size", &meta_size);
......
......@@ -2212,7 +2212,17 @@ def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor, reg_buffer: int,
reg_buffer_sz_bytes: int) -> None:
torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer,
reg_buffer_sz_bytes)
def all_reduce_fuse_norm(fa: int, inp: torch.Tensor, out: torch.Tensor,
                         hidden_size: int, residual: torch.Tensor,
                         rms_weight: torch.Tensor, eps: float,
                         reg_buffer: int, reg_buffer_sz_bytes: int) -> None:
    """Invoke the `_C_custom_ar.all_reduce_fuse_norm` custom op.

    All arguments are forwarded positionally and unchanged; the result is
    written into ``out``.
    """
    fused_op = torch.ops._C_custom_ar.all_reduce_fuse_norm
    fused_op(fa, inp, out, hidden_size, residual, rms_weight, eps,
             reg_buffer, reg_buffer_sz_bytes)
def all_reduce_fuse_norm_quant(fa: int, inp: torch.Tensor, out: torch.Tensor,
                               hidden_size: int, rms_weight: torch.Tensor,
                               eps: float, scales: torch.Tensor,
                               norm_out: torch.Tensor, reg_buffer: int,
                               reg_buffer_sz_bytes: int,
                               residual: torch.Tensor,
                               update_input: bool) -> None:
    """Invoke the `_C_custom_ar.all_reduce_fuse_norm_quant` custom op.

    All arguments are forwarded positionally and unchanged. The op writes
    into ``out``, ``scales`` and ``norm_out``.

    NOTE(review): the registered op schema declares ``Tensor? residual``, so
    ``None`` is accepted here even though the annotation says ``torch.Tensor``.
    """
    fused_op = torch.ops._C_custom_ar.all_reduce_fuse_norm_quant
    fused_op(fa, inp, out, hidden_size, rms_weight, eps, scales, norm_out,
             reg_buffer, reg_buffer_sz_bytes, residual, update_input)
def dispose(fa: int) -> None:
    """Invoke the `_C_custom_ar.dispose` custom op for the handle ``fa``."""
    dispose_op = torch.ops._C_custom_ar.dispose
    dispose_op(fa)
......
......@@ -536,4 +536,4 @@ direct_register_custom_op(
mutates_args=["output"],
fake_impl=unified_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key,
)
)
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional, Union
from typing import Any, Optional, Union, Tuple
import torch
import torch.distributed
from .parallel_state import get_tp_group
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    """All-reduce ``input_`` over the tensor-parallel group and return it."""
    tp_group = get_tp_group()
    return tp_group.all_reduce(input_)
def tensor_model_parallel_all_reduce_crp_m32(input_: torch.Tensor,
                                             pa_rms_weight: torch.Tensor,
                                             pa_residual: torch.Tensor,
                                             pa_rms_eps: float,
                                             pa_quant_dtype: Optional[torch.dtype] = torch.int8,
                                             update_input: Optional[bool] = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """All-reduce fused with RMSNorm and quantization over the TP group.

    Delegates to the tensor-parallel group's ``all_reduce_crq_m32`` with the
    same arguments and returns its 4-tuple result unchanged.
    """
    tp_group = get_tp_group()
    fused_result = tp_group.all_reduce_crq_m32(
        input_=input_,
        pa_rms_weight=pa_rms_weight,
        pa_residual=pa_residual,
        pa_rms_eps=pa_rms_eps,
        pa_quant_dtype=pa_quant_dtype,
        update_input=update_input,
    )
    return fused_result
def tensor_model_parallel_all_gather(input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import Optional, Tuple
import torch
from torch.distributed import ProcessGroup
......@@ -14,6 +14,7 @@ from .base_device_communicator import DeviceCommunicatorBase
logger = init_logger(__name__)
from lmslim.quantize.quant_ops import lm_faster_rmsquant
class CudaCommunicator(DeviceCommunicatorBase):
......@@ -116,6 +117,37 @@ class CudaCommunicator(DeviceCommunicatorBase):
out = input_.clone()
torch.distributed.all_reduce(out, group=self.device_group)
return out
def all_reduce_rms_quant_m32(self, input_,
                             pa_rms_weight: torch.Tensor,
                             pa_residual: torch.Tensor,
                             pa_rms_eps: float,
                             pa_quant_dtype: torch.dtype,
                             update_input: Optional[bool] = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """All-reduce ``input_`` and produce its RMS-normalised, quantized form.

    For batches of at most 16 rows the fused custom all-reduce kernel is
    tried first. Larger batches, and small batches the custom all-reduce
    declines (it returns ``None``), use a regular all-reduce followed by
    ``lm_faster_rmsquant``.

    Args:
        input_: 2-D tensor ``(batch_size, hidden_dim)``.
        pa_rms_weight: RMSNorm weight; must not be ``None``.
        pa_residual: Residual tensor; must not be ``None``.
        pa_rms_eps: RMSNorm epsilon.
        pa_quant_dtype: Target dtype of the quantized output.
        update_input: Accepted for interface compatibility.
            NOTE(review): currently ignored -- both paths pass
            ``update_input=True``.

    Returns:
        ``(input_, pa_residual, xq, xs)`` where ``input_`` is the fused
        kernel's ``norm_out`` on the fused path or the all-reduced tensor on
        the fallback path, and ``xq`` / ``xs`` are the quantized values and
        their scales.
    """
    # Unpacking also enforces that the input is 2-D.
    batch_size, _ = input_.shape
    ca_comm = self.ca_comm
    assert ca_comm is not None and not ca_comm.disabled
    assert envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and \
        pa_rms_weight is not None and pa_residual is not None

    fused = None
    if batch_size <= 16:
        fused = ca_comm.custom_all_reduce_fuse_norm_quant(
            inp=input_,
            rms_weight=pa_rms_weight,
            residual=pa_residual,
            eps=pa_rms_eps,
            quant_type=pa_quant_dtype,
            update_input=True)

    if fused is not None:
        xq, xs, norm_out = fused
        input_ = norm_out
    else:
        # Either the batch is too large for the fused kernel or the custom
        # all-reduce declined this input; use the unfused sequence instead.
        input_ = self.all_reduce(input_)
        xq, xs = lm_faster_rmsquant(input_,
                                    rms_weight=pa_rms_weight,
                                    residual=pa_residual,
                                    epsilon=pa_rms_eps,
                                    quant_dtype=pa_quant_dtype,
                                    update_input=True)

    assert input_ is not None
    assert xq is not None and xs is not None
    return input_, pa_residual, xq, xs
def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
world_size = self.world_size
......
......@@ -275,6 +275,91 @@ class CustomAllreduce:
# latency) compared to the performance gain of using custom kernels
return self.all_reduce(input, registered=False)
def allreduce_fuse_norm(self,
                        inp: torch.Tensor,
                        hidden_size: int,
                        residual: torch.Tensor,
                        rms_weight: torch.Tensor,
                        eps: float,
                        *,
                        out: torch.Tensor = None,
                        registered: bool = False):
    """Run the fused all-reduce + RMSNorm op and return the output tensor.

    Args:
        inp: Tensor to reduce.
        hidden_size: Row width forwarded to the op.
        residual: Residual tensor forwarded to the op.
        rms_weight: RMSNorm weight forwarded to the op.
        eps: RMSNorm epsilon.
        out: Optional destination; a tensor like ``inp`` is allocated when
            omitted.
        registered: When true, a null staging buffer (0, 0) is passed, so the
            op reads from ``inp`` directly. Otherwise this rank's buffer
            pointer and ``self.max_size`` are passed, and the op copies
            ``inp`` into that buffer first.
    """
    result = torch.empty_like(inp) if out is None else out
    if registered:
        staging_ptr, staging_bytes = 0, 0
    else:
        staging_ptr, staging_bytes = self.buffer_ptrs[self.rank], self.max_size
    ops.all_reduce_fuse_norm(self._ptr, inp, result, hidden_size, residual,
                             rms_weight, eps, staging_ptr, staging_bytes)
    return result
def custom_all_reduce_fuse_norm(self,
                                input: torch.Tensor,
                                hidden_size: int,
                                residual: torch.Tensor,
                                rms_weight: torch.Tensor,
                                eps: float) -> Optional[torch.Tensor]:
    """Fused all-reduce + RMSNorm, or ``None`` if custom all-reduce is unusable.

    Returns ``None`` when this communicator is disabled or
    ``should_custom_ar`` rejects ``input``. While ``_IS_CAPTURING`` is set but
    the current stream is not actually capturing, only an uninitialised tensor
    shaped like ``input`` is returned and no kernel is launched. In every
    other case the fused op runs through the staging buffer
    (``registered=False``).
    """
    if self.disabled or not self.should_custom_ar(input):
        return None
    if self._IS_CAPTURING and not torch.cuda.is_current_stream_capturing():
        # Capture mode without an active capture: allocate the output only.
        return torch.empty_like(input)
    return self.allreduce_fuse_norm(input, hidden_size, residual, rms_weight,
                                    eps, registered=False)
def allreduce_fuse_norm_quant(self,
                              inp: torch.Tensor,
                              hidden_size: int,
                              rms_weight,
                              eps,
                              quant_dtype,
                              residual,
                              update_input: bool = True,
                              registered: bool = False) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the fused all-reduce + RMSNorm + quantization op.

    Allocates the three outputs, launches the op and returns
    ``(xq, scales, norm_out)``:

    * ``xq``: like ``inp`` but with dtype ``quant_dtype``.
    * ``scales``: float32, one row per row of ``inp``, shape ``(rows, 1)``.
    * ``norm_out``: like ``inp``.

    When ``registered`` is true a null staging buffer (0, 0) is passed, so
    the op reads from ``inp`` directly. Otherwise this rank's buffer pointer
    and ``self.max_size`` are passed.
    """
    num_rows = inp.numel() // inp.shape[-1]
    xq = torch.empty_like(inp, dtype=quant_dtype)
    norm_out = torch.empty_like(inp)
    scales = torch.empty((num_rows, 1),
                         device=inp.device,
                         dtype=torch.float32)
    if registered:
        staging_ptr, staging_bytes = 0, 0
    else:
        staging_ptr, staging_bytes = self.buffer_ptrs[self.rank], self.max_size
    ops.all_reduce_fuse_norm_quant(self._ptr, inp, xq, hidden_size, rms_weight,
                                   eps, scales, norm_out, staging_ptr,
                                   staging_bytes, residual, update_input)
    return xq, scales, norm_out
def custom_all_reduce_fuse_norm_quant(self,
                                      inp: torch.Tensor,
                                      rms_weight: torch.Tensor,
                                      eps,
                                      quant_type,
                                      residual: Optional[torch.Tensor] = None,
                                      update_input = True
                                      ) -> Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Fused all-reduce + RMSNorm + quantization, if custom all-reduce applies.

    Args:
        inp: Tensor to reduce; its last dimension is used as the hidden size.
        rms_weight: RMSNorm weight.
        eps: RMSNorm epsilon.
        quant_type: dtype of the quantized output.
        residual: Optional residual tensor.
        update_input: Forwarded to the fused op.

    Returns:
        ``None`` when this communicator is disabled or ``should_custom_ar``
        rejects ``inp`` -- callers must handle that case. Otherwise
        ``(xq, scales, norm_out)``. While ``_IS_CAPTURING`` is set but the
        current stream is not actually capturing, the three tensors are only
        allocated (uninitialised) and no kernel is launched.
    """
    hidden_size = inp.shape[-1]
    if self.disabled or not self.should_custom_ar(inp):
        return None
    if self._IS_CAPTURING and not torch.cuda.is_current_stream_capturing():
        # Capture mode without an active capture: allocate the outputs only.
        return (torch.empty_like(inp, dtype=quant_type),
                torch.empty((inp.numel() // inp.shape[-1], 1),
                            dtype=torch.float32,
                            device=inp.device),
                torch.empty_like(inp))
    return self.allreduce_fuse_norm_quant(inp, hidden_size, rms_weight, eps,
                                          quant_type, residual,
                                          update_input=update_input,
                                          registered=False)
def close(self):
if not self.disabled and self._ptr:
if ops is not None:
......
......@@ -30,7 +30,7 @@ from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from multiprocessing import shared_memory
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, Tuple, Union
from unittest.mock import patch
import torch
......@@ -114,6 +114,37 @@ def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
return torch.empty_like(tensor)
def all_reduce_rms_quant(input_: torch.Tensor, group_name: str,
                         pa_rms_weight: Optional[torch.Tensor] = None,
                         pa_residual: Optional[torch.Tensor] = None,
                         pa_rms_eps: Optional[float] = 1e-6,
                         pa_quant_dtype: Optional[torch.dtype] = torch.int8,
                         update_input: Optional[bool] = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Custom-op body: fused all-reduce + RMSNorm + quantization for a group.

    Looks up the group registered under ``group_name`` and delegates to its
    ``_all_reduce_out_place_m32``.

    Raises:
        AssertionError: if ``group_name`` is not registered.
        ValueError: if the registered group reference resolves to ``None``.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    resolve_group = _groups[group_name]
    group = resolve_group()
    if group is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    fused_kwargs = dict(pa_rms_weight=pa_rms_weight,
                        pa_residual=pa_residual,
                        pa_rms_eps=pa_rms_eps,
                        pa_quant_dtype=pa_quant_dtype,
                        update_input=update_input)
    return group._all_reduce_out_place_m32(input_, **fused_kwargs)
def all_reduce_rms_quant_fake(input_: torch.Tensor, group_name: str,
                              pa_rms_weight: Optional[torch.Tensor] = None,
                              pa_residual: Optional[torch.Tensor] = None,
                              pa_rms_eps: Optional[float] = 1e-6,
                              pa_quant_dtype: Optional[torch.dtype] = torch.int8,
                              update_input: Optional[bool] = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fake implementation of ``all_reduce_rms_quant``.

    Performs no communication. Returns ``input_`` and ``pa_residual``
    unchanged, together with a zero tensor shaped like ``input_`` in
    ``pa_quant_dtype`` and a float32 tensor of ones with one row per row of
    ``input_`` (shape ``(rows, 1)``).
    """
    num_rows = input_.numel() // input_.shape[-1]
    xq = torch.zeros_like(input_, dtype=pa_quant_dtype)
    xs = torch.ones((num_rows, 1),
                    dtype=torch.float32,
                    device=input_.device)
    return input_, pa_residual, xq, xs
def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int,
group_name: str) -> torch.Tensor:
assert group_name in _groups, f"Group {group_name} is not found."
......@@ -156,6 +187,14 @@ if supports_custom_op():
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="all_reduce_rms_quant",
op_func=all_reduce_rms_quant,
mutates_args=["input_", "pa_residual"],
fake_impl=all_reduce_rms_quant_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="reduce_scatter",
op_func=reduce_scatter,
......@@ -358,9 +397,44 @@ class GroupCoordinator:
else:
return self._all_reduce_out_place(input_)
def all_reduce_crq_m32(self, input_: torch.Tensor,
                       pa_rms_weight: torch.Tensor,
                       pa_residual: torch.Tensor,
                       pa_rms_eps: float,
                       pa_quant_dtype: torch.dtype,
                       update_input: Optional[bool] = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fused all-reduce + RMSNorm + quantization through the registered op.

    Requires a multi-rank group, the
    ``USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT`` environment flag, and non-None
    ``pa_rms_weight`` / ``pa_residual``. Dispatches to
    ``torch.ops.vllm.all_reduce_rms_quant`` with this group's unique name.
    """
    assert self.world_size > 1
    assert envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and pa_rms_weight is not None and pa_residual is not None
    fused_op = torch.ops.vllm.all_reduce_rms_quant
    return fused_op(input_,
                    group_name=self.unique_name,
                    pa_rms_weight=pa_rms_weight,
                    pa_residual=pa_residual,
                    pa_rms_eps=pa_rms_eps,
                    pa_quant_dtype=pa_quant_dtype,
                    update_input=update_input)
def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
return self.device_communicator.all_reduce(input_)
def _all_reduce_out_place_m32(self, input_: torch.Tensor,
                              pa_rms_weight: torch.Tensor,
                              pa_residual: torch.Tensor,
                              pa_rms_eps: float,
                              pa_quant_dtype: torch.dtype,
                              update_input: Optional[bool] = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the device communicator's fused all-reduce + RMSNorm + quant path.

    Requires the ``USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT`` environment flag
    and non-None ``pa_rms_weight`` / ``pa_residual``. Returns the
    communicator's ``(input_, pa_residual, xq, xs)`` result.
    """
    assert envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and pa_rms_weight is not None \
        and pa_residual is not None
    communicator = self.device_communicator
    reduced, residual_out, xq, xs = communicator.all_reduce_rms_quant_m32(
        input_,
        pa_rms_weight=pa_rms_weight,
        pa_residual=pa_residual,
        pa_rms_eps=pa_rms_eps,
        pa_quant_dtype=pa_quant_dtype,
        update_input=update_input)
    return reduced, residual_out, xq, xs
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
world_size = self.world_size
# Bypass the function if we are using only 1 GPU.
......@@ -735,6 +809,8 @@ class GroupCoordinator:
torch.distributed.recv(tensor,
src=self.ranks[src],
group=group)
if envs.VLLM_USE_PP_SYNC:
torch.cuda.synchronize()
if use_all_gather:
# do the allgather
tensor = all_gather_group.all_gather( # type: ignore
......
......@@ -284,7 +284,13 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
cached version.
"""
return copy.deepcopy(_compute_kwargs(cls))
class EnvironmentConfigError(Exception):
    """Raised when environment-variable settings conflict with each other."""
def check_incompatible_config(env1: bool, env2: bool):
    """Reject enabling two mutually exclusive fused-kernel flags together.

    Raises:
        EnvironmentConfigError: if both ``env1`` and ``env2`` are exactly
            ``True`` (identity check, so other truthy values do not trigger).
    """
    both_enabled = env1 is True and env2 is True
    if not both_enabled:
        return
    message = "USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and USE_FUSED_RMS_QUANT must not be enabled simultaneously!\n\n"
    raise EnvironmentConfigError(message)
@dataclass
class EngineArgs:
......@@ -1230,7 +1236,8 @@ class EngineArgs:
num_lookahead_slots = num_lookahead_slots \
if speculative_config is None \
else speculative_config.num_lookahead_slots
check_incompatible_config(envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT, envs.USE_FUSED_RMS_QUANT)
scheduler_config = SchedulerConfig(
runner_type=model_config.runner_type,
max_num_batched_tokens=self.max_num_batched_tokens,
......
......@@ -178,6 +178,9 @@ if TYPE_CHECKING:
VLLM_P2P_BUF_TOKENS: int = 30000
VLLM_SCHED_ENABLE_MINIMAL_INJECTION: bool = False
VLLM_USE_PD_SPLIT: bool = False
VLLM_USE_PP_SYNC: bool = False
VLLM_USE_LIGHTOP_FILL_MOE_ALIGN: bool = False
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT: bool = False
def get_default_cache_root():
return os.getenv(
......@@ -1156,6 +1159,18 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_PD_SPLIT":
lambda: (os.environ.get("VLLM_USE_PD_SPLIT", "True").lower() in
("true", "1")),
# vLLM will sync to avoid pp vmfault
"VLLM_USE_PP_SYNC":
lambda: (os.environ.get("VLLM_USE_PP_SYNC", "False").lower() in
("true", "1")),
# vLLM will use lightop to fuse fill and moe align
"VLLM_USE_LIGHTOP_FILL_MOE_ALIGN":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_FILL_MOE_ALIGN", "False").lower() in
("true", "1")),
# vllm will use custom-allreduce rmsquant fused op
"USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT":
lambda: (os.getenv('USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT', '0').lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -216,7 +216,9 @@ def moe_align_block_size(
sorted_ids = torch.empty((max_num_tokens_padded, ),
dtype=torch.int32,
device=topk_ids.device)
sorted_ids.fill_(topk_ids.numel())
if not envs.VLLM_USE_LIGHTOP_FILL_MOE_ALIGN:
sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
# Expert ids must be zeroed out to prevent index out of bounds error while
# mapping global expert ids to local expert ids in expert parallelism.
......
......@@ -3,7 +3,7 @@
import itertools
from abc import abstractmethod
from typing import Any, Literal, Optional, Union
from typing import Any, Literal, Optional, Union, Tuple
import vllm.envs as envs
import torch
import torch.nn as nn
......@@ -14,7 +14,8 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
tensor_model_parallel_all_reduce,
tensor_model_parallel_all_reduce_crp_m32)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
......@@ -677,7 +678,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self, input_,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
update_hd: Optional[bool] = True
update_hd: Optional[bool] = True,
xqxs: Optional[tuple] = None
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
input_quant_args = None
......@@ -706,7 +708,22 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
if not self.return_bias:
return output
return output, new_residual, output_bias
else: # not USE_FUSED_RMS_QUANT
elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias, input_quant_args=xqxs)
if self.gather_output:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
else:
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
......@@ -1495,46 +1512,94 @@ class RowParallelLinear(LinearBase):
def forward(
self, input_,
use_fused_silu_mul_quant: Optional[bool] = False
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if self.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
use_fused_silu_mul_quant: Optional[bool] = False,
pa_rms_weight: Optional[torch.Tensor] = None,
pa_residual: Optional[torch.Tensor] = None,
pa_rms_eps: Optional[float] = 1e-6,
pa_quant_dtype: Optional[torch.dtype] = torch.int8,
update_input: Optional[bool] = True
) -> Union[torch.Tensor,
tuple[torch.Tensor, Optional[Parameter]],
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[Parameter]]
]:
if envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and pa_rms_weight is not None and pa_residual is not None:
if self.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply.
assert self.quant_method is not None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
if use_fused_silu_mul_quant:
xq, xs = lm_fuse_silu_mul_quant(input_parallel)
silu_quant_args = [xq, xs]
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_,
silu_quant_args=silu_quant_args)
else:
# Matrix multiply.
assert self.quant_method is not None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_)
if self.reduce_results and self.tp_size > 1:
if envs.VLLM_ENABLE_TBO:
output = self.tbo_all_reduce(output_parallel)
if self.reduce_results and self.tp_size > 1:
if envs.VLLM_ENABLE_TBO:
output = self.tbo_all_reduce(output_parallel)
packages_ = tensor_model_parallel_all_reduce_crp_m32(output_parallel,
pa_rms_weight=pa_rms_weight,
pa_residual=pa_residual,
pa_rms_eps=pa_rms_eps,
pa_quant_dtype=pa_quant_dtype,
update_input=update_input)
hs, resi, xq, xs = packages_
output = hs
else:
output = tensor_model_parallel_all_reduce(output_parallel)
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, resi, xq, xs, output_bias
else:
output = output_parallel
if self.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
output_bias = self.bias if self.skip_bias_add else None
# Matrix multiply.
assert self.quant_method is not None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
if use_fused_silu_mul_quant:
xq, xs = lm_fuse_silu_mul_quant(input_parallel)
silu_quant_args = [xq, xs]
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_,
silu_quant_args=silu_quant_args)
else:
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_)
if self.reduce_results and self.tp_size > 1:
if envs.VLLM_ENABLE_TBO:
output = self.tbo_all_reduce(output_parallel)
else:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
if not self.return_bias:
return output
return output, output_bias
def extra_repr(self) -> str:
s = f"input_features={self.input_size_per_partition}"
......
......@@ -162,7 +162,11 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
assert len(input_quant_args) == 2
x_q, x_scale = input_quant_args
elif envs.USE_FUSED_SILU_MUL_QUANT and silu_quant_args is not None:
assert len(silu_quant_args) == 2
x_q, x_scale = silu_quant_args
elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and input_quant_args is not None:
assert len(input_quant_args) == 2
x_q, x_scale = input_quant_args
else:
x_q, x_scale = per_token_quant_int8(x)
......@@ -178,9 +182,9 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
if m<=16:
m_=m
elif m<=64:
m_= (m + 3) & -4 #取值到最近的4的倍数
m_ = ((m + 3) // 4) * 4 #取值到最近的4的倍数
elif m<=160:
m_=(m + 7) & -8
m_ = ((m + 7) // 8) * 8
elif m<200: #256
m_=160
......
......@@ -251,6 +251,8 @@ def get_model_architecture(
os.environ['VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD'] = '1'
if not envs.is_set("VLLM_USE_OPT_CAT"):
os.environ['VLLM_USE_OPT_CAT'] = '1'
# if not envs.is_set("VLLM_USE_LIGHTOP_FILL_MOE_ALIGN"):
# os.environ['VLLM_USE_LIGHTOP_FILL_MOE_ALIGN'] = '1'
if os.getenv('GEMM_PAD') != '1':
os.environ['GEMM_PAD'] = '0'
......@@ -264,6 +266,8 @@ def get_model_architecture(
os.environ['VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD'] = '1'
if not envs.is_set("VLLM_USE_OPT_CAT"):
os.environ['VLLM_USE_OPT_CAT'] = '1'
# if not envs.is_set("VLLM_USE_LIGHTOP_FILL_MOE_ALIGN"):
# os.environ['VLLM_USE_LIGHTOP_FILL_MOE_ALIGN'] = '1'
# AWQ-related configuration
try:
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment