Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3e6729e0
Commit
3e6729e0
authored
Nov 18, 2025
by
wujl5
Browse files
deepseekv2-w4a8支持custom-rms-quant融合
parent
813f81fb
Changes
15
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
1518 additions
and
205 deletions
+1518
-205
csrc/custom_all_reduce.cu
csrc/custom_all_reduce.cu
+138
-0
csrc/custom_all_reduce.cuh
csrc/custom_all_reduce.cuh
+596
-11
csrc/custom_all_reduce_test.cu
csrc/custom_all_reduce_test.cu
+142
-23
csrc/ops.h
csrc/ops.h
+11
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+11
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+11
-1
vllm/distributed/communication_op.py
vllm/distributed/communication_op.py
+15
-2
vllm/distributed/device_communicators/cuda_communicator.py
vllm/distributed/device_communicators/cuda_communicator.py
+33
-1
vllm/distributed/device_communicators/custom_all_reduce.py
vllm/distributed/device_communicators/custom_all_reduce.py
+85
-0
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+75
-1
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+8
-1
vllm/envs.py
vllm/envs.py
+6
-1
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+101
-36
vllm/model_executor/layers/quantization/slimquant_w4a8.py
vllm/model_executor/layers/quantization/slimquant_w4a8.py
+6
-2
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+280
-126
No files found.
csrc/custom_all_reduce.cu
View file @
3e6729e0
...
@@ -59,6 +59,144 @@ bool _is_weak_contiguous(torch::Tensor& t) {
...
@@ -59,6 +59,144 @@ bool _is_weak_contiguous(torch::Tensor& t) {
* Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
* Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
* copied into _reg_buffer.
* copied into _reg_buffer.
*/
*/
/**
 * Torch entry point for the fused custom all-reduce + RMSNorm kernel.
 *
 * If _reg_buffer is non-null, inp is first copied (device-to-device, on the
 * current stream) into _reg_buffer and the kernel reads from that buffer;
 * otherwise the kernel reads directly from inp's storage.
 *
 * @param _fa                 Opaque handle; reinterpreted as
 *                            vllm::CustomAllreduce*.
 * @param inp                 Input tensor (weakly contiguous).
 * @param out                 Output tensor; same dtype and numel as inp.
 * @param hidden_size         Row width; token count is numel / hidden_size.
 * @param residual            Residual tensor; must share out's dtype because
 *                            its storage is handed to the kernel as that type.
 * @param rms_weight          Norm weight; must share out's dtype for the same
 *                            reason.
 * @param eps                 Epsilon, narrowed to float for the kernel.
 * @param _reg_buffer         Optional staging buffer address (0 = none).
 * @param reg_buffer_sz_bytes Capacity of the staging buffer in bytes.
 * @throws std::runtime_error for dtypes other than float32/float16/bfloat16.
 */
void all_reduce_fuse_norm(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
                          int64_t hidden_size, torch::Tensor& residual,
                          torch::Tensor& rms_weight, double eps,
                          fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();

  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  TORCH_CHECK(_is_weak_contiguous(out));
  TORCH_CHECK(_is_weak_contiguous(inp));
  TORCH_CHECK(_is_weak_contiguous(residual));
  TORCH_CHECK(_is_weak_contiguous(rms_weight));

  // hidden_size is a divisor below; reject values that would divide by zero
  // or leave a partial row.
  TORCH_CHECK(hidden_size > 0,
              "all_reduce_fuse_norm: hidden_size must be positive");
  TORCH_CHECK(inp.numel() % hidden_size == 0,
              "all_reduce_fuse_norm: inp.numel() must be a multiple of "
              "hidden_size");
  // The raw storage of residual / rms_weight is passed to the kernel as the
  // output element type, so a dtype mismatch would be silently misread.
  TORCH_CHECK(residual.scalar_type() == out.scalar_type(),
              "all_reduce_fuse_norm: residual dtype must match out dtype");
  TORCH_CHECK(rms_weight.scalar_type() == out.scalar_type(),
              "all_reduce_fuse_norm: rms_weight dtype must match out dtype");

  const int token_num = static_cast<int>(inp.numel() / hidden_size);
  const auto input_size = inp.numel() * inp.element_size();

  auto reg_buffer = reinterpret_cast<void*>(_reg_buffer);
  if (reg_buffer) {
    TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes);
    AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.data_ptr(), input_size,
                                  cudaMemcpyDeviceToDevice, stream));
  } else {
    reg_buffer = inp.data_ptr();
  }

  const float eps_f = static_cast<float>(eps);
  switch (out.scalar_type()) {
    case at::ScalarType::Float: {
      fa->allreduce_fuse_norm<float>(
          stream, static_cast<float*>(reg_buffer),
          static_cast<float*>(out.data_ptr()), out.numel(), token_num,
          hidden_size, static_cast<float*>(residual.data_ptr()),
          static_cast<float*>(rms_weight.data_ptr()), eps_f);
      break;
    }
    case at::ScalarType::Half: {
      fa->allreduce_fuse_norm<half>(
          stream, static_cast<half*>(reg_buffer),
          static_cast<half*>(out.data_ptr()), out.numel(), token_num,
          hidden_size, static_cast<half*>(residual.data_ptr()),
          static_cast<half*>(rms_weight.data_ptr()), eps_f);
      break;
    }
      // #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
    case at::ScalarType::BFloat16: {
      fa->allreduce_fuse_norm<nv_bfloat16>(
          stream, static_cast<nv_bfloat16*>(reg_buffer),
          static_cast<nv_bfloat16*>(out.data_ptr()), out.numel(), token_num,
          hidden_size, static_cast<nv_bfloat16*>(residual.data_ptr()),
          static_cast<nv_bfloat16*>(rms_weight.data_ptr()), eps_f);
      break;
    }
      // #endif
    default:
      throw std::runtime_error(
          "custom allreduce only supports float32, float16 and bfloat16");
  }
}
/**
 * Launches the fused all-reduce + RMSNorm + quantization kernel for one input
 * element type (scalar_in_t); the quantized output type is selected at run
 * time from out's dtype via VLLM_DISPATCH_QUANT_TYPES.
 *
 * All argument validation happens before any work is enqueued on the stream.
 * If _reg_buffer is non-null, inp is copied into it (device-to-device, on the
 * current stream) and the kernel reads from that buffer; otherwise the kernel
 * reads directly from inp's storage.
 *
 * @tparam scalar_in_t   Element type of inp / rms_weight / norm_out / residual.
 * @tparam update_input  Forwarded unchanged as the kernel's last template
 *                       argument.
 * @param residual       Optional residual; when absent the kernel receives a
 *                       null residual pointer and `false` as its third
 *                       template argument.
 * @throws std::runtime_error if the communicator is not fully connected or
 *         rms_weight's storage is not 16-byte aligned.
 */
template <typename scalar_in_t, bool update_input>
void allreduce_fuse_norm_quant_dispath(
    fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, int hidden_size,
    torch::Tensor& rms_weight, double eps, torch::Tensor& scales,
    torch::Tensor& norm_out, fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes,
    std::optional<at::Tensor> residual) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);

  // Reject unsupported topologies up front instead of after the staging copy
  // has already been enqueued.
  if (!fa->fully_connected_) {
    throw std::runtime_error("custom allreduce only supports fully_connected");
  }

  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();

  TORCH_CHECK(_is_weak_contiguous(inp));
  // hidden_size is a divisor below.
  TORCH_CHECK(hidden_size > 0,
              "allreduce_fuse_norm_quant: hidden_size must be positive");
  TORCH_CHECK(inp.numel() % hidden_size == 0,
              "allreduce_fuse_norm_quant: inp.numel() must be a multiple of "
              "hidden_size");

  const auto wt_ptr = reinterpret_cast<std::uintptr_t>(rms_weight.data_ptr());
  if (wt_ptr % 16 != 0) {
    throw std::runtime_error(
        "custom allreduce currently requires rms_weight to be 16-byte "
        "aligned, but its address modulo 16 is " +
        std::to_string(wt_ptr % 16));
  }

  const int token_num = static_cast<int>(inp.numel() / hidden_size);
  const auto input_size = inp.numel() * inp.element_size();

  auto reg_buffer = reinterpret_cast<void*>(_reg_buffer);
  if (reg_buffer) {
    TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes);
    AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.data_ptr(), input_size,
                                  cudaMemcpyDeviceToDevice, stream));
  } else {
    reg_buffer = inp.data_ptr();
  }

  if (residual.has_value()) {
    VLLM_DISPATCH_QUANT_TYPES(
        out.scalar_type(), "fa->allreduce_fuse_norm_quant", [&] {
          fa->allreduce_fuse_norm_quant<scalar_in_t, scalar_t, true,
                                        update_input>(
              stream, static_cast<scalar_in_t*>(reg_buffer),
              out.data_ptr<scalar_t>(), out.numel(), token_num, hidden_size,
              residual->data_ptr<scalar_in_t>(),
              rms_weight.data_ptr<scalar_in_t>(),
              norm_out.data_ptr<scalar_in_t>(), eps, scales.data_ptr<float>());
        });
  } else {
    VLLM_DISPATCH_QUANT_TYPES(
        out.scalar_type(), "fa->allreduce_fuse_norm_quant", [&] {
          fa->allreduce_fuse_norm_quant<scalar_in_t, scalar_t, false,
                                        update_input>(
              stream, static_cast<scalar_in_t*>(reg_buffer),
              out.data_ptr<scalar_t>(), out.numel(), token_num, hidden_size,
              nullptr, rms_weight.data_ptr<scalar_in_t>(),
              norm_out.data_ptr<scalar_in_t>(), eps, scales.data_ptr<float>());
        });
  }
}
/**
 * Torch entry point for the fused all-reduce + RMSNorm + quantization op.
 *
 * Validates the tensors, then dispatches on inp's floating-point dtype and on
 * update_input to the matching allreduce_fuse_norm_quant_dispath
 * instantiation.
 *
 * @param out          Quantized output; must be Float8_e4m3fn or int8.
 * @param hidden_size  Row width; must be positive and fit in an int because
 *                     the dispatch helper takes it as int.
 * @param scales       Must be float32.
 * @param residual     Optional residual tensor, forwarded as-is.
 * @param update_input Selects the helper's compile-time update_input flag.
 */
void all_reduce_fuse_norm_quant(fptr_t fa, torch::Tensor& inp,
                                torch::Tensor& out, int64_t hidden_size,
                                torch::Tensor& rms_weight, double eps,
                                torch::Tensor& scales, torch::Tensor& norm_out,
                                fptr_t reg_buffer, int64_t reg_buffer_sz_bytes,
                                std::optional<at::Tensor> residual,
                                bool update_input) {
  constexpr c10::ScalarType kFp8Type = c10::ScalarType::Float8_e4m3fn;

  TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8,
              "all_reduce_fuse_norm_quant: out must be float8_e4m3fn or int8");
  TORCH_CHECK(scales.dtype() == torch::kFloat32,
              "all_reduce_fuse_norm_quant: scales must be float32");
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  TORCH_CHECK(out.is_contiguous() && inp.is_contiguous(),
              "all_reduce_fuse_norm_quant: inp and out must be contiguous");

  // The helper takes hidden_size as int; make the narrowing explicit and
  // refuse values that would be truncated or are not positive.
  const int hidden_dim = static_cast<int>(hidden_size);
  TORCH_CHECK(hidden_size > 0 && static_cast<int64_t>(hidden_dim) == hidden_size,
              "all_reduce_fuse_norm_quant: hidden_size must be positive and "
              "fit in a 32-bit int");

  VLLM_DISPATCH_FLOATING_TYPES(
      inp.scalar_type(), "allreduce_fuse_norm_quant_dispath", [&] {
        if (update_input) {
          allreduce_fuse_norm_quant_dispath<scalar_t, true>(
              fa, inp, out, hidden_dim, rms_weight, eps, scales, norm_out,
              reg_buffer, reg_buffer_sz_bytes, residual);
        } else {
          allreduce_fuse_norm_quant_dispath<scalar_t, false>(
              fa, inp, out, hidden_dim, rms_weight, eps, scales, norm_out,
              reg_buffer, reg_buffer_sz_bytes, residual);
        }
      });
}
void
all_reduce
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
,
void
all_reduce
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
,
fptr_t
_reg_buffer
,
int64_t
reg_buffer_sz_bytes
)
{
fptr_t
_reg_buffer
,
int64_t
reg_buffer_sz_bytes
)
{
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
...
...
csrc/custom_all_reduce.cuh
View file @
3e6729e0
This diff is collapsed.
Click to expand it.
csrc/custom_all_reduce_test.cu
View file @
3e6729e0
...
@@ -18,6 +18,7 @@
...
@@ -18,6 +18,7 @@
#include <limits>
#include <limits>
#include <vector>
#include <vector>
#include <random>
#include "cuda_profiler_api.h"
#include "cuda_profiler_api.h"
#include "custom_all_reduce.cuh"
#include "custom_all_reduce.cuh"
...
@@ -117,16 +118,113 @@ __global__ void gen_data(curandState_t* state, T* data, double* ground_truth,
...
@@ -117,16 +118,113 @@ __global__ void gen_data(curandState_t* state, T* data, double* ground_truth,
ground_truth
[
idx
]
=
sum
;
ground_truth
[
idx
]
=
sum
;
}
}
}
}
/*************************************************/
// Shuffle-based tree reduction over `reducesize` lanes (default 64).
// Each step adds the value fetched by __shfl_down from the lane `stride`
// positions away, halving the stride until it reaches 1.
// NOTE(review): with the usual shuffle-down semantics the complete sum ends up
// in the first lane of the group only - confirm for the target platform.
template <typename T, int reducesize = 64>
__inline__ __device__ T WarpReduceSum_NEW(T val) {
#pragma unroll
  for (int stride = reducesize / 2; stride >= 1; stride /= 2) {
    val += __shfl_down(val, stride);
  }
  return val;
}
// Block-wide sum built from two rounds of WarpReduceSum_NEW.
//
// Round 1: every group of 64 consecutive threads reduces its values; lane 0
//          of each group writes its partial sum to shared[group index].
// Round 2: the first `share_size` threads of group 0 reduce those partials.
//
// `shared` must hold at least block_size / 64 elements of T.
// Only threads with wid == 0 && lid < share_size overwrite `val` in round 2,
// so callers should read the block total from thread 0 (as
// fused_add_rms_kernel_opt does); other threads return partial values.
template <typename T, int block_size = 512>
__inline__ __device__ T BlockReduceSum_NEW(T val, T* shared) {
  constexpr int share_size = block_size / 64;
  // The indexing below assumes whole 64-thread groups, and the second
  // halving reduction only covers every partial when the group count is a
  // power of two. Fail at compile time instead of summing incorrectly.
  static_assert(block_size >= 64 && block_size % 64 == 0,
                "block_size must be a positive multiple of 64");
  static_assert((share_size & (share_size - 1)) == 0,
                "block_size / 64 must be a power of two");

  val = WarpReduceSum_NEW<T>(val);
  if constexpr (block_size == 64) {
    // Single group: round 1 already produced the result.
    return val;
  } else {
    const int lid = threadIdx.x % 64;  // lane within the 64-thread group
    const int wid = threadIdx.x / 64;  // index of the 64-thread group
    if (lid == 0 && wid < share_size) {
      shared[wid] = val;
    }
    __syncthreads();  // partial sums must be visible before round 2
    if (wid == 0 && lid < share_size) {
      val = WarpReduceSum_NEW<T, share_size>(shared[lid]);
    }
    return val;
  }
}
// Reference kernel: in-place fused residual add + RMSNorm, one block per row.
//
// For row i = blockIdx.x, thread j handles the Vec consecutive columns
// starting at j * Vec:
//   residual <- residual + input                    (summed in scalar_t)
//   input    <- residual * rsqrt(mean(residual^2) + eps) * gamma
// Both `input` and `residual` are overwritten.
//
// Preconditions (not checked here):
//   * cols is a multiple of Vec and cols / Vec <= blockDim.x; columns beyond
//     that are neither read nor written.
//   * blockDim.x == block_size.
//   * input / residual rows are aligned for Vec-element vector accesses.
//
// The vector load/store type is sized by Vec. The previous version used
// vllm::packed_t<scalar_t>::P, whose width depends only on scalar_t, while
// this kernel is instantiated with Vec = 4, 8 and 16 for the same scalar_t,
// so its width could not match the Vec-element scratch arrays in every case.
template <typename scalar_t, typename T_ACC, int Vec = 4, int block_size = 512>
__global__ void fused_add_rms_kernel_opt(scalar_t* input, scalar_t* residual,
                                         scalar_t* gamma, int cols,
                                         T_ACC eps) {
  static_assert(Vec > 0 && (Vec & (Vec - 1)) == 0,
                "Vec must be a power of two");
  static_assert((sizeof(scalar_t) & (sizeof(scalar_t) - 1)) == 0,
                "sizeof(scalar_t) must be a power of two");
  // Exactly Vec elements; alignment capped at 16 bytes.
  struct alignas(sizeof(scalar_t) * Vec > 16 ? 16 : sizeof(scalar_t) * Vec)
      VecT {
    scalar_t data[Vec];
  };

  constexpr int share_size = block_size / 64;
  __shared__ T_ACC val_shared[share_size];
  __shared__ T_ACC s_rstd;

  T_ACC val = 0;
  const int i = blockIdx.x;
  const int j = threadIdx.x;
  const int tcol = cols / Vec;  // vectors per row

  VecT input_vec;
  VecT residual_vec;
  // Widen before multiplying so large row indices cannot overflow int.
  const int64_t idx = (static_cast<int64_t>(i) * tcol + j) * Vec;

  if (j < tcol) {
    input_vec = *reinterpret_cast<const VecT*>(input + idx);
    residual_vec = *reinterpret_cast<const VecT*>(residual + idx);
#pragma unroll
    for (int ii = 0; ii < Vec; ii++) {
      residual_vec.data[ii] += input_vec.data[ii];
      val += static_cast<T_ACC>(residual_vec.data[ii]) *
             static_cast<T_ACC>(residual_vec.data[ii]);
    }
  }

  // Every thread takes part in the reduction (idle ones contribute 0);
  // thread 0 holds the row's sum of squares afterwards.
  val = BlockReduceSum_NEW<T_ACC, block_size>(val, val_shared);
  if (j == 0) s_rstd = rsqrtf(val / cols + eps);
  __syncthreads();
  const T_ACC trstd = s_rstd;

  if (j < tcol) {
#pragma unroll
    for (int ii = 0; ii < Vec; ii++) {
      const int jj = j * Vec + ii;  // column index into gamma
      input_vec.data[ii] = static_cast<T_ACC>(residual_vec.data[ii]) * trstd *
                           static_cast<T_ACC>(gamma[jj]);
    }
    *reinterpret_cast<VecT*>(residual + idx) = residual_vec;
    *reinterpret_cast<VecT*>(input + idx) = input_vec;
  }
}
// Picks a (Vec, block_size) configuration for fused_add_rms_kernel_opt from
// hidden_size (and, for the 2049..4096 range, num_tokens) and launches one
// block per token on `stream`.
//
//   hidden_size <= 1024                      : Vec 8,  128 threads
//   hidden_size <= 2048                      : Vec 8,  256 threads
//   hidden_size <= 4096, num_tokens >  1200  : Vec 8,  512 threads
//   hidden_size <= 4096, num_tokens <= 1200  : Vec 4,  1024 threads
//   hidden_size <= 8192                      : Vec 8,  1024 threads
//   otherwise                                : Vec 16, 1024 threads
//
// self_data and other_data are passed as the kernel's `input` and `residual`
// arguments, weight_data as `gamma`. eps is narrowed to float, the kernel's
// accumulator type.
// NOTE(review): the last branch covers at most 16 * 1024 columns per row;
// larger hidden sizes would leave trailing columns untouched.
template <typename scalar_t>
void fused_add_rms_norm_choose(cudaStream_t stream, scalar_t* self_data,
                               scalar_t* other_data, scalar_t* weight_data,
                               double eps, int hidden_size, int num_tokens) {
  // A zero-sized grid is an invalid launch configuration; nothing to do.
  if (num_tokens <= 0) {
    return;
  }
  const float eps_f = static_cast<float>(eps);

  if (hidden_size <= 1024) {
    fused_add_rms_kernel_opt<scalar_t, float, 8, 128>
        <<<num_tokens, 128, 0, stream>>>(self_data, other_data, weight_data,
                                         hidden_size, eps_f);
  } else if (hidden_size <= 2048) {
    fused_add_rms_kernel_opt<scalar_t, float, 8, 256>
        <<<num_tokens, 256, 0, stream>>>(self_data, other_data, weight_data,
                                         hidden_size, eps_f);
  } else if (hidden_size <= 4096) {
    if (num_tokens > 1200) {
      fused_add_rms_kernel_opt<scalar_t, float, 8, 512>
          <<<num_tokens, 512, 0, stream>>>(self_data, other_data, weight_data,
                                           hidden_size, eps_f);
    } else {
      fused_add_rms_kernel_opt<scalar_t, float, 4, 1024>
          <<<num_tokens, 1024, 0, stream>>>(self_data, other_data, weight_data,
                                            hidden_size, eps_f);
    }
  } else if (hidden_size <= 8192) {
    fused_add_rms_kernel_opt<scalar_t, float, 8, 1024>
        <<<num_tokens, 1024, 0, stream>>>(self_data, other_data, weight_data,
                                          hidden_size, eps_f);
  } else {
    fused_add_rms_kernel_opt<scalar_t, float, 16, 1024>
        <<<num_tokens, 1024, 0, stream>>>(self_data, other_data, weight_data,
                                          hidden_size, eps_f);
  }
}
/*****************************************************************/
template
<
typename
T
>
template
<
typename
T
>
void
run
(
int
myRank
,
int
nRanks
,
ncclComm_t
&
comm
,
int
threads
,
int
block_limit
,
void
run
(
int
myRank
,
int
nRanks
,
ncclComm_t
&
comm
,
int
threads
,
int
block_limit
,
int
data_size
,
bool
performance_test
)
{
int
data_size
,
bool
performance_test
,
int
hidden_dim
)
{
T
*
result
;
T
*
result
_ori
,
*
result_fuse
;
cudaStream_t
stream
;
cudaStream_t
stream
;
CUDACHECK
(
cudaStreamCreateWithFlags
(
&
stream
,
cudaStreamNonBlocking
));
CUDACHECK
(
cudaStreamCreateWithFlags
(
&
stream
,
cudaStreamNonBlocking
));
CUDACHECK
(
cudaMalloc
(
&
result
,
data_size
*
sizeof
(
T
)));
CUDACHECK
(
cudaMalloc
(
&
result_ori
,
data_size
*
sizeof
(
T
)));
CUDACHECK
(
cudaMemset
(
result
,
0
,
data_size
*
sizeof
(
T
)));
CUDACHECK
(
cudaMemset
(
result_ori
,
0
,
data_size
*
sizeof
(
T
)));
CUDACHECK
(
cudaMalloc
(
&
result_fuse
,
data_size
*
sizeof
(
T
)));
CUDACHECK
(
cudaMemset
(
result_fuse
,
0
,
data_size
*
sizeof
(
T
)));
cudaIpcMemHandle_t
self_data_handle
;
cudaIpcMemHandle_t
self_data_handle
;
cudaIpcMemHandle_t
data_handles
[
8
];
cudaIpcMemHandle_t
data_handles
[
8
];
vllm
::
Signal
*
buffer
;
vllm
::
Signal
*
buffer
;
...
@@ -176,7 +274,7 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
...
@@ -176,7 +274,7 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
sizeof
(
vllm
::
Signal
)
+
data_size
*
sizeof
(
T
));
sizeof
(
vllm
::
Signal
)
+
data_size
*
sizeof
(
T
));
// hack buffer registration
// hack buffer registration
{
{
void
*
data
[
8
];
void
*
data
[
8
];
//gpu数据部分
for
(
int
i
=
0
;
i
<
nRanks
;
i
++
)
{
for
(
int
i
=
0
;
i
<
nRanks
;
i
++
)
{
data
[
i
]
=
data
[
i
]
=
((
char
*
)
ipc_ptrs
[
i
])
+
sizeof
(
vllm
::
Signal
)
+
data_size
*
sizeof
(
T
);
((
char
*
)
ipc_ptrs
[
i
])
+
sizeof
(
vllm
::
Signal
)
+
data_size
*
sizeof
(
T
);
...
@@ -196,7 +294,26 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
...
@@ -196,7 +294,26 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
cudaEvent_t
start
,
stop
;
cudaEvent_t
start
,
stop
;
CUDACHECK
(
cudaEventCreate
(
&
start
));
CUDACHECK
(
cudaEventCreate
(
&
start
));
CUDACHECK
(
cudaEventCreate
(
&
stop
));
CUDACHECK
(
cudaEventCreate
(
&
stop
));
/*******************************/
int
token_num
=
data_size
/
hidden_dim
;
T
*
residual_h
,
*
residual_d
,
*
weight_h
,
*
weight_d
;
residual_h
=
(
T
*
)
malloc
(
data_size
*
sizeof
(
T
));
std
::
random_device
rd
;
// 用于获取随机数种子
std
::
mt19937
gen
(
7
);
std
::
uniform_real_distribution
<
float
>
dis
(
-
3.0
f
,
3.0
f
);
for
(
int
i
=
0
;
i
<
data_size
;
++
i
)
residual_h
[
i
]
=
static_cast
<
T
>
(
dis
(
gen
));
for
(
int
i
=
0
;
i
<
hidden_dim
;
++
i
)
weight_h
[
i
]
=
static_cast
<
T
>
(
dis
(
gen
));
cudaMalloc
((
void
**
)
&
residual_d
,
sizeof
(
T
)
*
data_size
);
cudaMalloc
((
void
**
)
&
weight_d
,
sizeof
(
T
)
*
hidden_dim
);
cudaMemcpyAsync
(
residual_d
,
residual_h
,
sizeof
(
T
)
*
data_size
,
cudaMemcpyHostToDevice
,
stream
);
cudaMemcpyAsync
(
weight_d
,
weight_h
,
sizeof
(
T
)
*
hidden_dim
,
cudaMemcpyHostToDevice
,
stream
);
float
eps
=
1.0
f
;
/*******************************/
ncclDataType_t
ncclDtype
;
ncclDataType_t
ncclDtype
;
if
(
std
::
is_same
<
T
,
half
>::
value
)
{
if
(
std
::
is_same
<
T
,
half
>::
value
)
{
ncclDtype
=
ncclFloat16
;
ncclDtype
=
ncclFloat16
;
...
@@ -211,16 +328,16 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
...
@@ -211,16 +328,16 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
if
(
performance_test
)
{
if
(
performance_test
)
{
dummy_kernel
<<<
1
,
1
,
0
,
stream
>>>
();
dummy_kernel
<<<
1
,
1
,
0
,
stream
>>>
();
constexpr
int
warmup_iters
=
5
;
constexpr
int
warmup_iters
=
5
;
constexpr
int
num_iters
=
10
0
;
constexpr
int
num_iters
=
10
;
// warmup
// warmup
for
(
int
i
=
0
;
i
<
warmup_iters
;
i
++
)
{
for
(
int
i
=
0
;
i
<
warmup_iters
;
i
++
)
{
NCCLCHECK
(
ncclA
ll
R
educe
(
result
,
result
,
data_size
,
ncclDtype
,
ncclSum
,
fa
.
a
ll
r
educe
<
T
>
(
stream
,
self_data
,
result
_ori
,
data_size
,
threads
,
block_limit
);
comm
,
stream
)
);
fused_add_rms_norm_choose
<
T
>
(
stream
,
result_ori
,
residual_d
,
weight_d
,
1.0
,
hidden_dim
,
token_num
);
}
}
CUDACHECK
(
cudaEventRecord
(
start
,
stream
));
CUDACHECK
(
cudaEventRecord
(
start
,
stream
));
for
(
int
i
=
0
;
i
<
num_iters
;
i
++
)
{
for
(
int
i
=
0
;
i
<
num_iters
;
i
++
)
{
NCCLCHECK
(
ncclA
ll
R
educe
(
result
,
result
,
data_size
,
ncclDtype
,
ncclSum
,
fa
.
a
ll
r
educe
<
T
>
(
stream
,
self_data
,
result
_ori
,
data_size
,
threads
,
block_limit
);
comm
,
stream
)
);
fused_add_rms_norm_choose
<
T
>
(
stream
,
result_ori
,
residual_d
,
weight_d
,
1.0
,
hidden_dim
,
token_num
);
}
}
CUDACHECK
(
cudaEventRecord
(
stop
,
stream
));
CUDACHECK
(
cudaEventRecord
(
stop
,
stream
));
CUDACHECK
(
cudaStreamSynchronize
(
stream
));
CUDACHECK
(
cudaStreamSynchronize
(
stream
));
...
@@ -230,13 +347,15 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
...
@@ -230,13 +347,15 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
dummy_kernel
<<<
1
,
1
,
0
,
stream
>>>
();
dummy_kernel
<<<
1
,
1
,
0
,
stream
>>>
();
// warm up
// warm up
for
(
int
i
=
0
;
i
<
warmup_iters
;
i
++
)
{
for
(
int
i
=
0
;
i
<
warmup_iters
;
i
++
)
{
fa
.
allreduce
<
T
>
(
stream
,
self_data
,
result
,
data_size
,
threads
,
fa
.
allreduce_fuse_norm
<
T
>
(
stream
,
self_data
,
result_fuse
,
data_size
,
token_num
,
block_limit
);
hidden_dim
,
residual_d
,
weight_d
,
eps
,
threads
,
block_limit
);
}
}
CUDACHECK
(
cudaEventRecord
(
start
,
stream
));
CUDACHECK
(
cudaEventRecord
(
start
,
stream
));
for
(
int
i
=
0
;
i
<
num_iters
;
i
++
)
{
for
(
int
i
=
0
;
i
<
num_iters
;
i
++
)
{
fa
.
allreduce
<
T
>
(
stream
,
self_data
,
result
,
data_size
,
threads
,
fa
.
allreduce_fuse_norm
<
T
>
(
stream
,
self_data
,
result_fuse
,
data_size
,
token_num
,
block_limit
);
hidden_dim
,
residual_d
,
weight_d
,
eps
,
threads
,
block_limit
);
}
}
CUDACHECK
(
cudaEventRecord
(
stop
,
stream
));
CUDACHECK
(
cudaEventRecord
(
stop
,
stream
));
CUDACHECK
(
cudaStreamSynchronize
(
stream
));
CUDACHECK
(
cudaStreamSynchronize
(
stream
));
...
@@ -245,7 +364,7 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
...
@@ -245,7 +364,7 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
cudaEventElapsedTime
(
&
duration_ms
,
start
,
stop
);
cudaEventElapsedTime
(
&
duration_ms
,
start
,
stop
);
if
(
myRank
==
0
)
if
(
myRank
==
0
)
printf
(
printf
(
"Rank %d done, nGPUs:%d, sz (kb): %d, %d, %d,
my time:%.2fus, nccl
"
"Rank %d done, nGPUs:%d, sz (kb): %d, %d, %d,
allreduse_fuse_norm time:%.2fus, allreduce+norm
"
"time:%.2fus
\n
"
,
"time:%.2fus
\n
"
,
myRank
,
nRanks
,
data_size
*
sizeof
(
T
)
/
1024
,
threads
,
block_limit
,
myRank
,
nRanks
,
data_size
*
sizeof
(
T
)
/
1024
,
threads
,
block_limit
,
duration_ms
*
1e3
/
num_iters
,
allreduce_ms
*
1e3
/
num_iters
);
duration_ms
*
1e3
/
num_iters
,
allreduce_ms
*
1e3
/
num_iters
);
...
@@ -255,8 +374,9 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
...
@@ -255,8 +374,9 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
NCCLCHECK
(
ncclAllReduce
(
self_data_copy
,
self_data
,
data_size
,
ncclDtype
,
NCCLCHECK
(
ncclAllReduce
(
self_data_copy
,
self_data
,
data_size
,
ncclDtype
,
ncclSum
,
comm
,
stream
));
ncclSum
,
comm
,
stream
));
fused_add_rms_norm_choose
<
T
>
(
stream
,
self_data
,
residual_d
,
weight_d
,
1.0
,
hidden_dim
,
token_num
);
convert_data
<
T
><<<
108
,
1024
,
0
,
stream
>>>
(
self_data
,
result
,
nccl_result
,
convert_data
<
T
><<<
108
,
1024
,
0
,
stream
>>>
(
result_ori
,
result
_fuse
,
nccl_result
,
my_result
,
data_size
);
my_result
,
data_size
);
CUDACHECK
(
cudaStreamSynchronize
(
stream
));
CUDACHECK
(
cudaStreamSynchronize
(
stream
));
...
@@ -279,13 +399,13 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
...
@@ -279,13 +399,13 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
<<
" me: "
<<
my_diffs
/
data_size
<<
std
::
endl
;
<<
" me: "
<<
my_diffs
/
data_size
<<
std
::
endl
;
}
else
{
}
else
{
for
(
int
i
=
0
;
i
<
100
;
i
++
)
{
for
(
int
i
=
0
;
i
<
100
;
i
++
)
{
fa
.
allreduce
<
T
>
(
stream
,
self_data
,
result
,
data_size
,
threads
,
fa
.
allreduce
<
T
>
(
stream
,
self_data
,
result
_ori
,
data_size
,
threads
,
block_limit
);
block_limit
);
CUDACHECK
(
cudaStreamSynchronize
(
stream
));
CUDACHECK
(
cudaStreamSynchronize
(
stream
));
NCCLCHECK
(
ncclAllReduce
(
self_data
,
self_data_copy
,
data_size
,
ncclDtype
,
NCCLCHECK
(
ncclAllReduce
(
self_data
,
self_data_copy
,
data_size
,
ncclDtype
,
ncclSum
,
comm
,
stream
));
ncclSum
,
comm
,
stream
));
convert_data
<
T
><<<
108
,
1024
,
0
,
stream
>>>
(
convert_data
<
T
><<<
108
,
1024
,
0
,
stream
>>>
(
self_data_copy
,
result
,
nccl_result
,
my_result
,
data_size
);
self_data_copy
,
result
_ori
,
nccl_result
,
my_result
,
data_size
);
CUDACHECK
(
cudaStreamSynchronize
(
stream
));
CUDACHECK
(
cudaStreamSynchronize
(
stream
));
for
(
unsigned
long
j
=
0
;
j
<
data_size
;
j
++
)
{
for
(
unsigned
long
j
=
0
;
j
<
data_size
;
j
++
)
{
...
@@ -312,7 +432,8 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
...
@@ -312,7 +432,8 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
// << " me: " << my_diffs / data_size << std::endl;
// << " me: " << my_diffs / data_size << std::endl;
}
}
CUDACHECK
(
cudaFree
(
result
));
CUDACHECK
(
cudaFree
(
result_ori
));
CUDACHECK
(
cudaFree
(
result_fuse
));
CUDACHECK
(
cudaFree
(
self_data_copy
));
CUDACHECK
(
cudaFree
(
self_data_copy
));
CUDACHECK
(
cudaFree
(
rank_data
));
CUDACHECK
(
cudaFree
(
rank_data
));
CUDACHECK
(
cudaFree
(
buffer
));
CUDACHECK
(
cudaFree
(
buffer
));
...
@@ -351,9 +472,7 @@ int main(int argc, char** argv) {
...
@@ -351,9 +472,7 @@ int main(int argc, char** argv) {
const
int
block_limit
=
36
;
const
int
block_limit
=
36
;
#endif
#endif
// Scan through different sizes to test performance.
// Scan through different sizes to test performance.
for
(
int
sz
=
512
;
sz
<=
(
8
<<
20
);
sz
*=
2
)
{
run
<
half
>
(
myRank
,
nRanks
,
comm
,
512
,
36
,
7168
*
80
,
performance_test
,
7168
);
run
<
half
>
(
myRank
,
nRanks
,
comm
,
512
,
36
,
sz
+
8
*
47
,
performance_test
);
}
cudaProfilerStop
();
cudaProfilerStop
();
MPICHECK
(
MPI_Finalize
());
MPICHECK
(
MPI_Finalize
());
...
...
csrc/ops.h
View file @
3e6729e0
...
@@ -487,6 +487,17 @@ fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
...
@@ -487,6 +487,17 @@ fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
bool
fully_connected
);
bool
fully_connected
);
void
all_reduce
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
,
void
all_reduce
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
,
fptr_t
reg_buffer
,
int64_t
reg_buffer_sz_bytes
);
fptr_t
reg_buffer
,
int64_t
reg_buffer_sz_bytes
);
void
all_reduce_fuse_norm
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
,
int64_t
hidden_size
,
torch
::
Tensor
&
residual
,
torch
::
Tensor
&
rms_weight
,
double
eps
,
fptr_t
reg_buffer
,
int64_t
reg_buffer_sz_bytes
);
void
all_reduce_fuse_norm_quant
(
fptr_t
fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
,
int64_t
hidden_size
,
torch
::
Tensor
&
rms_weight
,
double
eps
,
torch
::
Tensor
&
scales
,
torch
::
Tensor
&
norm_out
,
fptr_t
reg_buffer
,
int64_t
reg_buffer_sz_bytes
,
std
::
optional
<
at
::
Tensor
>
residual
,
bool
update_input
);
void
dispose
(
fptr_t
_fa
);
void
dispose
(
fptr_t
_fa
);
int64_t
meta_size
();
int64_t
meta_size
();
void
register_buffer
(
fptr_t
_fa
,
const
std
::
vector
<
int64_t
>&
fake_ipc_ptrs
);
void
register_buffer
(
fptr_t
_fa
,
const
std
::
vector
<
int64_t
>&
fake_ipc_ptrs
);
...
...
csrc/torch_bindings.cpp
View file @
3e6729e0
...
@@ -933,6 +933,17 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
...
@@ -933,6 +933,17 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
"all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
"all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
"int reg_buffer_sz_bytes) -> ()"
);
"int reg_buffer_sz_bytes) -> ()"
);
custom_ar
.
impl
(
"all_reduce"
,
torch
::
kCUDA
,
&
all_reduce
);
custom_ar
.
impl
(
"all_reduce"
,
torch
::
kCUDA
,
&
all_reduce
);
custom_ar
.
def
(
"all_reduce_fuse_norm(int fa, Tensor inp, Tensor! out, int hidden_size, "
"Tensor residual, Tensor rms_weight, float eps, int reg_buffer, "
"int reg_buffer_sz_bytes) -> ()"
);
custom_ar
.
impl
(
"all_reduce_fuse_norm"
,
torch
::
kCUDA
,
&
all_reduce_fuse_norm
);
custom_ar
.
def
(
"all_reduce_fuse_norm_quant(int fa, Tensor inp, Tensor! out, int hidden_size, "
"Tensor rms_weight, float eps, Tensor! scales, Tensor! norm_out, int reg_buffer, "
"int reg_buffer_sz_bytes, Tensor? residual, bool update_input) -> ()"
);
custom_ar
.
impl
(
"all_reduce_fuse_norm_quant"
,
torch
::
kCUDA
,
&
all_reduce_fuse_norm_quant
);
custom_ar
.
def
(
"dispose"
,
&
dispose
);
custom_ar
.
def
(
"dispose"
,
&
dispose
);
custom_ar
.
def
(
"meta_size"
,
&
meta_size
);
custom_ar
.
def
(
"meta_size"
,
&
meta_size
);
...
...
vllm/_custom_ops.py
View file @
3e6729e0
...
@@ -2212,7 +2212,17 @@ def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor, reg_buffer: int,
...
@@ -2212,7 +2212,17 @@ def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor, reg_buffer: int,
reg_buffer_sz_bytes
:
int
)
->
None
:
reg_buffer_sz_bytes
:
int
)
->
None
:
torch
.
ops
.
_C_custom_ar
.
all_reduce
(
fa
,
inp
,
out
,
reg_buffer
,
torch
.
ops
.
_C_custom_ar
.
all_reduce
(
fa
,
inp
,
out
,
reg_buffer
,
reg_buffer_sz_bytes
)
reg_buffer_sz_bytes
)
def all_reduce_fuse_norm(fa: int, inp: torch.Tensor, out: torch.Tensor,
                         hidden_size: int, residual: torch.Tensor,
                         rms_weight: torch.Tensor, eps: float,
                         reg_buffer: int, reg_buffer_sz_bytes: int) -> None:
    """Invoke the `_C_custom_ar.all_reduce_fuse_norm` custom op.

    Thin pass-through: every argument is forwarded positionally, in the same
    order, to the registered op. Nothing is returned; the op is given `out`
    to write into.
    """
    fused_op = torch.ops._C_custom_ar.all_reduce_fuse_norm
    fused_op(fa, inp, out, hidden_size, residual, rms_weight, eps,
             reg_buffer, reg_buffer_sz_bytes)
def all_reduce_fuse_norm_quant(fa: int, inp: torch.Tensor, out: torch.Tensor,
                               hidden_size: int, rms_weight: torch.Tensor,
                               eps: float, scales: torch.Tensor,
                               norm_out: torch.Tensor, reg_buffer: int,
                               reg_buffer_sz_bytes: int,
                               residual: torch.Tensor,
                               update_input: bool) -> None:
    """Invoke the `_C_custom_ar.all_reduce_fuse_norm_quant` custom op.

    Thin pass-through: every argument is forwarded positionally, in the same
    order, to the registered op. Nothing is returned.
    """
    fused_quant_op = torch.ops._C_custom_ar.all_reduce_fuse_norm_quant
    fused_quant_op(fa, inp, out, hidden_size, rms_weight, eps, scales,
                   norm_out, reg_buffer, reg_buffer_sz_bytes, residual,
                   update_input)
def
dispose
(
fa
:
int
)
->
None
:
def
dispose
(
fa
:
int
)
->
None
:
torch
.
ops
.
_C_custom_ar
.
dispose
(
fa
)
torch
.
ops
.
_C_custom_ar
.
dispose
(
fa
)
...
...
vllm/distributed/communication_op.py
View file @
3e6729e0
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
,
Optional
,
Union
from
typing
import
Any
,
Optional
,
Union
,
Tuple
import
torch
import
torch
import
torch.distributed
import
torch.distributed
from
.parallel_state
import
get_tp_group
from
.parallel_state
import
get_tp_group
def
tensor_model_parallel_all_reduce
(
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
tensor_model_parallel_all_reduce
(
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""All-reduce the input tensor across model parallel group."""
"""All-reduce the input tensor across model parallel group."""
return
get_tp_group
().
all_reduce
(
input_
)
return
get_tp_group
().
all_reduce
(
input_
)
def tensor_model_parallel_all_reduce_crp_m32(
    input_: torch.Tensor,
    pa_rms_weight: torch.Tensor,
    pa_residual: torch.Tensor,
    pa_rms_eps: float,
    pa_quant_dtype: Optional[torch.dtype] = torch.int8,
    update_input: Optional[bool] = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """All-reduce `input_` across the tensor-parallel group with RMSNorm and
    quantization fused in.

    Delegates to `all_reduce_crq_m32` on the group returned by
    `get_tp_group()`, forwarding every argument by keyword, and returns that
    call's result unchanged.
    """
    tp_group = get_tp_group()
    return tp_group.all_reduce_crq_m32(input_=input_,
                                       pa_rms_weight=pa_rms_weight,
                                       pa_residual=pa_residual,
                                       pa_rms_eps=pa_rms_eps,
                                       pa_quant_dtype=pa_quant_dtype,
                                       update_input=update_input)
def
tensor_model_parallel_all_gather
(
input_
:
torch
.
Tensor
,
def
tensor_model_parallel_all_gather
(
input_
:
torch
.
Tensor
,
dim
:
int
=
-
1
)
->
torch
.
Tensor
:
dim
:
int
=
-
1
)
->
torch
.
Tensor
:
...
...
vllm/distributed/device_communicators/cuda_communicator.py
View file @
3e6729e0
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
from
typing
import
Optional
,
Tuple
import
torch
import
torch
from
torch.distributed
import
ProcessGroup
from
torch.distributed
import
ProcessGroup
...
@@ -14,6 +14,7 @@ from .base_device_communicator import DeviceCommunicatorBase
...
@@ -14,6 +14,7 @@ from .base_device_communicator import DeviceCommunicatorBase
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
from
lmslim.quantize.quant_ops
import
lm_faster_rmsquant
class
CudaCommunicator
(
DeviceCommunicatorBase
):
class
CudaCommunicator
(
DeviceCommunicatorBase
):
...
@@ -117,6 +118,37 @@ class CudaCommunicator(DeviceCommunicatorBase):
...
@@ -117,6 +118,37 @@ class CudaCommunicator(DeviceCommunicatorBase):
torch
.
distributed
.
all_reduce
(
out
,
group
=
self
.
device_group
)
torch
.
distributed
.
all_reduce
(
out
,
group
=
self
.
device_group
)
return
out
return
out
def all_reduce_rms_quant_m32(
    self,
    input_,
    pa_rms_weight: torch.Tensor,
    pa_residual: torch.Tensor,
    pa_rms_eps: float,
    pa_quant_dtype: torch.dtype,
    update_input: Optional[bool] = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """All-reduce `input_`, then RMS-normalize and quantize it.

    `input_` must be 2-D (its shape is unpacked into two values). With at
    most 16 rows the fused custom all-reduce kernel does everything in one
    call; larger inputs use a plain `all_reduce` followed by
    `lm_faster_rmsquant`.

    Returns `(input_, pa_residual, xq, xs)`, where `input_` is the fused
    kernel's `norm_out` on the small-batch path and the all-reduced tensor
    otherwise.

    NOTE(review): the `update_input` argument is not consulted; both paths
    pass `update_input=True` downstream. Confirm that is intended.
    """
    num_rows, _hidden_dim = input_.shape
    fused_comm = self.ca_comm
    assert fused_comm is not None and not fused_comm.disabled
    assert envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and \
        pa_rms_weight is not None and pa_residual is not None

    if batch_is_large := num_rows > 16:
        # Unfused fallback: reduce first, then norm + quantize.
        input_ = self.all_reduce(input_)
        quantized, quant_scales = lm_faster_rmsquant(
            input_,
            rms_weight=pa_rms_weight,
            residual=pa_residual,
            epsilon=pa_rms_eps,
            quant_dtype=pa_quant_dtype,
            update_input=True)
    else:
        quantized, quant_scales, norm_out = \
            fused_comm.custom_all_reduce_fuse_norm_quant(
                inp=input_,
                rms_weight=pa_rms_weight,
                residual=pa_residual,
                eps=pa_rms_eps,
                quant_type=pa_quant_dtype,
                update_input=True)
        input_ = norm_out

    assert input_ is not None
    assert quantized is not None and quant_scales is not None
    return input_, pa_residual, quantized, quant_scales
def
reduce_scatter
(
self
,
input_
:
torch
.
Tensor
,
dim
:
int
=
-
1
):
def
reduce_scatter
(
self
,
input_
:
torch
.
Tensor
,
dim
:
int
=
-
1
):
world_size
=
self
.
world_size
world_size
=
self
.
world_size
pynccl_comm
=
self
.
pynccl_comm
pynccl_comm
=
self
.
pynccl_comm
...
...
vllm/distributed/device_communicators/custom_all_reduce.py
View file @
3e6729e0
...
@@ -275,6 +275,91 @@ class CustomAllreduce:
...
@@ -275,6 +275,91 @@ class CustomAllreduce:
# latency) compared to the performance gain of using custom kernels
# latency) compared to the performance gain of using custom kernels
return
self
.
all_reduce
(
input
,
registered
=
False
)
return
self
.
all_reduce
(
input
,
registered
=
False
)
def allreduce_fuse_norm(self,
                        inp: torch.Tensor,
                        hidden_size: int,
                        residual: torch.Tensor,
                        rms_weight: torch.Tensor,
                        eps: float,
                        *,
                        out: torch.Tensor = None,
                        registered: bool = False):
    """Launch ``ops.all_reduce_fuse_norm`` and return the output tensor.

    Args:
        inp: Tensor to reduce.
        hidden_size, residual, rms_weight, eps: Forwarded unchanged to the
            kernel.
        out: Optional destination; when None a tensor shaped like ``inp``
            is allocated.
        registered: When True the kernel receives ``0, 0`` as its last two
            arguments; otherwise it receives this rank's entry of
            ``self.buffer_ptrs`` together with ``self.max_size``.
            NOTE(review): presumably ``0, 0`` means ``inp`` is already
            IPC-registered and needs no staging copy -- confirm in the
            C++ op.

    Returns:
        The tensor the kernel wrote into (``out`` or the fresh allocation).
    """
    result = torch.empty_like(inp) if out is None else out

    if registered:
        staging_ptr, staging_bytes = 0, 0
    else:
        staging_ptr, staging_bytes = self.buffer_ptrs[self.rank], self.max_size

    ops.all_reduce_fuse_norm(self._ptr, inp, result, hidden_size, residual,
                             rms_weight, eps, staging_ptr, staging_bytes)
    return result
def custom_all_reduce_fuse_norm(
        self, input: torch.Tensor, hidden_size: int, residual: torch.Tensor,
        rms_weight: torch.Tensor, eps: float) -> Optional[torch.Tensor]:
    """Entry point for the fused all-reduce + norm kernel.

    Returns None when the communicator is disabled or
    ``should_custom_ar`` rejects ``input``. While ``_IS_CAPTURING`` is set
    but the current stream is not actually capturing, an uninitialised
    tensor shaped like ``input`` is returned and no kernel is launched.
    In every other case the result of ``allreduce_fuse_norm`` (called with
    ``registered=False``) is returned.
    """
    eligible = not self.disabled and self.should_custom_ar(input)
    if not eligible:
        return None

    # Capture mode is on but this stream is not capturing: skip the kernel
    # and hand back a placeholder of the right shape.
    if self._IS_CAPTURING and not torch.cuda.is_current_stream_capturing():
        return torch.empty_like(input)

    return self.allreduce_fuse_norm(input,
                                    hidden_size,
                                    residual,
                                    rms_weight,
                                    eps,
                                    registered=False)
def allreduce_fuse_norm_quant(
    self,
    inp: torch.Tensor,
    hidden_size: int,
    rms_weight,
    eps,
    quant_dtype,
    residual,
    update_input: bool = True,
    registered: bool = False
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Launch ``ops.all_reduce_fuse_norm_quant`` into fresh output tensors.

    Three outputs are allocated before the kernel runs:

    * ``xq``: shaped like ``inp`` with dtype ``quant_dtype``;
    * ``norm_out``: shaped like ``inp`` with ``inp``'s dtype;
    * ``scales``: float32, one row per last-dimension slice of ``inp``
      and a single column.

    When ``registered`` is True the kernel receives ``0, 0`` for its
    buffer arguments; otherwise it receives this rank's entry of
    ``self.buffer_ptrs`` and ``self.max_size``.

    Returns:
        ``(xq, scales, norm_out)``.
    """
    num_rows = inp.numel() // inp.shape[-1]
    quantized = torch.empty_like(inp, dtype=quant_dtype)
    normed = torch.empty_like(inp)
    row_scales = torch.empty((num_rows, 1),
                             device=inp.device,
                             dtype=torch.float32)

    if registered:
        staging_ptr, staging_bytes = 0, 0
    else:
        staging_ptr, staging_bytes = self.buffer_ptrs[self.rank], self.max_size

    ops.all_reduce_fuse_norm_quant(self._ptr, inp, quantized, hidden_size,
                                   rms_weight, eps, row_scales, normed,
                                   staging_ptr, staging_bytes, residual,
                                   update_input)
    return quantized, row_scales, normed
def custom_all_reduce_fuse_norm_quant(
    self,
    inp: torch.Tensor,
    rms_weight: torch.Tensor,
    eps,
    quant_type,
    residual: Optional[torch.Tensor] = None,
    update_input=True
) -> Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Entry point for the fused all-reduce + norm + quant kernel.

    Args:
        inp: Tensor to reduce; its last dimension is used as the hidden
            size passed to the kernel.
        rms_weight, eps, residual, update_input: Forwarded unchanged to
            ``allreduce_fuse_norm_quant``.
        quant_type: Dtype of the quantized output.

    Returns:
        ``None`` when the communicator is disabled or ``should_custom_ar``
        rejects ``inp`` -- callers must check for this before unpacking.
        Otherwise a tuple ``(xq, scales, norm_out)``. While
        ``_IS_CAPTURING`` is set but the current stream is not capturing,
        the three tensors are uninitialised placeholders of the right
        shape/dtype and no kernel is launched.
    """
    hidden_size = inp.shape[-1]
    if self.disabled or not self.should_custom_ar(inp):
        return None

    if self._IS_CAPTURING and not torch.cuda.is_current_stream_capturing():
        # Capture mode without an active capture: return placeholders
        # that mimic the real outputs.
        num_rows = inp.numel() // inp.shape[-1]
        return (torch.empty_like(inp, dtype=quant_type),
                torch.empty((num_rows, 1),
                            dtype=torch.float32,
                            device=inp.device),
                torch.empty_like(inp))

    return self.allreduce_fuse_norm_quant(inp,
                                          hidden_size,
                                          rms_weight,
                                          eps,
                                          quant_type,
                                          residual,
                                          update_input=update_input,
                                          registered=False)
def
close
(
self
):
def
close
(
self
):
if
not
self
.
disabled
and
self
.
_ptr
:
if
not
self
.
disabled
and
self
.
_ptr
:
if
ops
is
not
None
:
if
ops
is
not
None
:
...
...
vllm/distributed/parallel_state.py
View file @
3e6729e0
...
@@ -30,7 +30,7 @@ from collections import namedtuple
...
@@ -30,7 +30,7 @@ from collections import namedtuple
from
contextlib
import
contextmanager
,
nullcontext
from
contextlib
import
contextmanager
,
nullcontext
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
multiprocessing
import
shared_memory
from
multiprocessing
import
shared_memory
from
typing
import
Any
,
Callable
,
Optional
,
Union
from
typing
import
Any
,
Callable
,
Optional
,
Tuple
,
Union
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
torch
import
torch
...
@@ -114,6 +114,37 @@ def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
...
@@ -114,6 +114,37 @@ def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
return
torch
.
empty_like
(
tensor
)
return
torch
.
empty_like
(
tensor
)
def all_reduce_rms_quant(
    input_: torch.Tensor,
    group_name: str,
    pa_rms_weight: Optional[torch.Tensor] = None,
    pa_residual: Optional[torch.Tensor] = None,
    pa_rms_eps: Optional[float] = 1e-6,
    pa_quant_dtype: Optional[torch.dtype] = torch.int8,
    update_input: Optional[bool] = True
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Resolve ``group_name`` and run its fused all-reduce/rms/quant path.

    The name is looked up in the module-level ``_groups`` registry, whose
    entries are callables returning the group object or None. All
    remaining arguments are forwarded by keyword to the group's
    ``_all_reduce_out_place_m32``.

    Raises:
        AssertionError: if ``group_name`` is not registered.
        ValueError: if the registry entry resolves to None.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    coordinator = _groups[group_name]()
    if coordinator is None:
        raise ValueError(f"Group {group_name} is destroyed.")

    fused_kwargs = dict(
        pa_rms_weight=pa_rms_weight,
        pa_residual=pa_residual,
        pa_rms_eps=pa_rms_eps,
        pa_quant_dtype=pa_quant_dtype,
        update_input=update_input,
    )
    return coordinator._all_reduce_out_place_m32(input_, **fused_kwargs)
def all_reduce_rms_quant_fake(
    input_: torch.Tensor,
    group_name: str,
    pa_rms_weight: Optional[torch.Tensor] = None,
    pa_residual: Optional[torch.Tensor] = None,
    pa_rms_eps: Optional[float] = 1e-6,
    pa_quant_dtype: Optional[torch.dtype] = torch.int8,
    update_input: Optional[bool] = True
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Shape-only stand-in for ``all_reduce_rms_quant``.

    No communication or normalisation is performed. ``input_`` and
    ``pa_residual`` are returned as-is, accompanied by:

    * a zero tensor shaped like ``input_`` with dtype ``pa_quant_dtype``;
    * a float32 tensor of ones with one row per last-dimension slice of
      ``input_`` and a single column.
    """
    num_rows = input_.numel() // input_.shape[-1]
    fake_quantized = torch.zeros_like(input_, dtype=pa_quant_dtype)
    fake_scales = torch.ones((num_rows, 1),
                             device=input_.device,
                             dtype=torch.float32)
    return input_, pa_residual, fake_quantized, fake_scales
def
reduce_scatter
(
tensor
:
torch
.
Tensor
,
dim
:
int
,
world_size
:
int
,
def
reduce_scatter
(
tensor
:
torch
.
Tensor
,
dim
:
int
,
world_size
:
int
,
group_name
:
str
)
->
torch
.
Tensor
:
group_name
:
str
)
->
torch
.
Tensor
:
assert
group_name
in
_groups
,
f
"Group
{
group_name
}
is not found."
assert
group_name
in
_groups
,
f
"Group
{
group_name
}
is not found."
...
@@ -156,6 +187,14 @@ if supports_custom_op():
...
@@ -156,6 +187,14 @@ if supports_custom_op():
dispatch_key
=
current_platform
.
dispatch_key
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
)
direct_register_custom_op
(
op_name
=
"all_reduce_rms_quant"
,
op_func
=
all_reduce_rms_quant
,
mutates_args
=
[
"input_"
,
"pa_residual"
],
fake_impl
=
all_reduce_rms_quant_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
direct_register_custom_op
(
direct_register_custom_op
(
op_name
=
"reduce_scatter"
,
op_name
=
"reduce_scatter"
,
op_func
=
reduce_scatter
,
op_func
=
reduce_scatter
,
...
@@ -358,9 +397,44 @@ class GroupCoordinator:
...
@@ -358,9 +397,44 @@ class GroupCoordinator:
else
:
else
:
return
self
.
_all_reduce_out_place
(
input_
)
return
self
.
_all_reduce_out_place
(
input_
)
def all_reduce_crq_m32(
    self,
    input_: torch.Tensor,
    pa_rms_weight: torch.Tensor,
    pa_residual: torch.Tensor,
    pa_rms_eps: float,
    pa_quant_dtype: torch.dtype,
    update_input: Optional[bool] = True
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Dispatch to the registered ``vllm.all_reduce_rms_quant`` custom op.

    Preconditions (asserted): the group has more than one rank, the
    ``USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT`` env flag is set, and both
    ``pa_rms_weight`` and ``pa_residual`` are provided.

    Returns:
        Whatever the custom op returns; its Python implementation yields
        ``(input_, pa_residual, xq, xs)``.
    """
    assert self.world_size > 1
    assert envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
    assert pa_rms_weight is not None
    assert pa_residual is not None

    fused_op = torch.ops.vllm.all_reduce_rms_quant
    return fused_op(input_,
                    group_name=self.unique_name,
                    pa_rms_weight=pa_rms_weight,
                    pa_residual=pa_residual,
                    pa_rms_eps=pa_rms_eps,
                    pa_quant_dtype=pa_quant_dtype,
                    update_input=update_input)
def
_all_reduce_out_place
(
self
,
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
_all_reduce_out_place
(
self
,
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
device_communicator
.
all_reduce
(
input_
)
return
self
.
device_communicator
.
all_reduce
(
input_
)
def _all_reduce_out_place_m32(
    self,
    input_: torch.Tensor,
    pa_rms_weight: torch.Tensor,
    pa_residual: torch.Tensor,
    pa_rms_eps: float,
    pa_quant_dtype: torch.dtype,
    update_input: Optional[bool] = True
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Forward to the device communicator's ``all_reduce_rms_quant_m32``.

    Asserts that the ``USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT`` env flag is
    set and that ``pa_rms_weight`` / ``pa_residual`` are not None. The
    communicator's result is unpacked into exactly four values, so a
    result of any other length raises here.

    Returns:
        ``(reduced, residual, xq, xs)`` as produced by the communicator.
    """
    assert envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
    assert pa_rms_weight is not None
    assert pa_residual is not None

    communicator = self.device_communicator
    reduced, residual_out, quantized, scales = \
        communicator.all_reduce_rms_quant_m32(
            input_,
            pa_rms_weight=pa_rms_weight,
            pa_residual=pa_residual,
            pa_rms_eps=pa_rms_eps,
            pa_quant_dtype=pa_quant_dtype,
            update_input=update_input)
    return reduced, residual_out, quantized, scales
def
all_gather
(
self
,
input_
:
torch
.
Tensor
,
dim
:
int
=
-
1
)
->
torch
.
Tensor
:
def
all_gather
(
self
,
input_
:
torch
.
Tensor
,
dim
:
int
=
-
1
)
->
torch
.
Tensor
:
world_size
=
self
.
world_size
world_size
=
self
.
world_size
# Bypass the function if we are using only 1 GPU.
# Bypass the function if we are using only 1 GPU.
...
...
vllm/engine/arg_utils.py
View file @
3e6729e0
...
@@ -284,7 +284,13 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
...
@@ -284,7 +284,13 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
cached version.
cached version.
"""
"""
return
copy
.
deepcopy
(
_compute_kwargs
(
cls
))
return
copy
.
deepcopy
(
_compute_kwargs
(
cls
))
class EnvironmentConfigError(Exception):
    """Raised when mutually exclusive environment flags are both enabled."""


def check_incompatible_config(env1: bool, env2: bool):
    """Reject configurations where both fused-quant flags are enabled.

    Args:
        env1: Value of ``USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT``.
        env2: Value of ``USE_FUSED_RMS_QUANT``.

    Raises:
        EnvironmentConfigError: if both values are truthy.
    """
    # Truthiness instead of ``is True``: a flag that arrives as ``1`` (or
    # any other truthy non-bool) must still trip the check rather than
    # slipping past an identity comparison.
    if env1 and env2:
        _s = "USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and USE_FUSED_RMS_QUANT must not be enabled simultaneously!\n\n"
        raise EnvironmentConfigError(_s)
@
dataclass
@
dataclass
class
EngineArgs
:
class
EngineArgs
:
...
@@ -1230,6 +1236,7 @@ class EngineArgs:
...
@@ -1230,6 +1236,7 @@ class EngineArgs:
num_lookahead_slots
=
num_lookahead_slots
\
num_lookahead_slots
=
num_lookahead_slots
\
if
speculative_config
is
None
\
if
speculative_config
is
None
\
else
speculative_config
.
num_lookahead_slots
else
speculative_config
.
num_lookahead_slots
check_incompatible_config
(
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
,
envs
.
USE_FUSED_RMS_QUANT
)
scheduler_config
=
SchedulerConfig
(
scheduler_config
=
SchedulerConfig
(
runner_type
=
model_config
.
runner_type
,
runner_type
=
model_config
.
runner_type
,
...
...
vllm/envs.py
View file @
3e6729e0
...
@@ -180,6 +180,7 @@ if TYPE_CHECKING:
...
@@ -180,6 +180,7 @@ if TYPE_CHECKING:
VLLM_USE_PD_SPLIT
:
bool
=
False
VLLM_USE_PD_SPLIT
:
bool
=
False
VLLM_USE_PP_SYNC
:
bool
=
False
VLLM_USE_PP_SYNC
:
bool
=
False
VLLM_USE_LIGHTOP_FILL_MOE_ALIGN
:
bool
=
False
VLLM_USE_LIGHTOP_FILL_MOE_ALIGN
:
bool
=
False
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
:
bool
=
False
def
get_default_cache_root
():
def
get_default_cache_root
():
return
os
.
getenv
(
return
os
.
getenv
(
...
@@ -1166,6 +1167,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1166,6 +1167,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_LIGHTOP_FILL_MOE_ALIGN"
:
"VLLM_USE_LIGHTOP_FILL_MOE_ALIGN"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_LIGHTOP_FILL_MOE_ALIGN"
,
"False"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_LIGHTOP_FILL_MOE_ALIGN"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# vllm will use custom-allreduce rmsquant fused op
"USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT"
:
lambda
:
(
os
.
getenv
(
'USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT'
,
'0'
).
lower
()
in
(
"true"
,
"1"
)),
}
}
# --8<-- [end:env-vars-definition]
# --8<-- [end:env-vars-definition]
...
...
vllm/model_executor/layers/linear.py
View file @
3e6729e0
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
import
itertools
import
itertools
from
abc
import
abstractmethod
from
abc
import
abstractmethod
from
typing
import
Any
,
Literal
,
Optional
,
Union
from
typing
import
Any
,
Literal
,
Optional
,
Union
,
Tuple
import
vllm.envs
as
envs
import
vllm.envs
as
envs
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -14,7 +14,8 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
...
@@ -14,7 +14,8 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
split_tensor_along_last_dim
,
split_tensor_along_last_dim
,
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_reduce
)
tensor_model_parallel_all_reduce
,
tensor_model_parallel_all_reduce_crp_m32
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
QuantizationConfig
,
QuantizeMethodBase
)
...
@@ -677,7 +678,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -677,7 +678,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self
,
input_
,
self
,
input_
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
update_hd
:
Optional
[
bool
]
=
True
update_hd
:
Optional
[
bool
]
=
True
,
xqxs
:
Optional
[
tuple
]
=
None
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]]:
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]]:
if
envs
.
USE_FUSED_RMS_QUANT
and
rms_weight
is
not
None
:
if
envs
.
USE_FUSED_RMS_QUANT
and
rms_weight
is
not
None
:
input_quant_args
=
None
input_quant_args
=
None
...
@@ -706,7 +708,22 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -706,7 +708,22 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
if
not
self
.
return_bias
:
if
not
self
.
return_bias
:
return
output
return
output
return
output
,
new_residual
,
output_bias
return
output
,
new_residual
,
output_bias
else
:
# not USE_FUSED_RMS_QUANT
elif
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
and
xqxs
is
not
None
:
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
assert
self
.
quant_method
is
not
None
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
,
input_quant_args
=
xqxs
)
if
self
.
gather_output
:
# All-gather across the partitions.
output
=
tensor_model_parallel_all_gather
(
output_parallel
)
else
:
output
=
output_parallel
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
if
not
self
.
return_bias
:
return
output
return
output
,
output_bias
else
:
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
assert
self
.
quant_method
is
not
None
assert
self
.
quant_method
is
not
None
...
@@ -1495,8 +1512,56 @@ class RowParallelLinear(LinearBase):
...
@@ -1495,8 +1512,56 @@ class RowParallelLinear(LinearBase):
def
forward
(
def
forward
(
self
,
input_
,
self
,
input_
,
use_fused_silu_mul_quant
:
Optional
[
bool
]
=
False
use_fused_silu_mul_quant
:
Optional
[
bool
]
=
False
,
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]]:
pa_rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
pa_residual
:
Optional
[
torch
.
Tensor
]
=
None
,
pa_rms_eps
:
Optional
[
float
]
=
1e-6
,
pa_quant_dtype
:
Optional
[
torch
.
dtype
]
=
torch
.
int8
,
update_input
:
Optional
[
bool
]
=
True
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]],
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
Parameter
]]
]:
if
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
and
pa_rms_weight
is
not
None
and
pa_residual
is
not
None
:
if
self
.
input_is_parallel
:
input_parallel
=
input_
else
:
tp_rank
=
get_tensor_model_parallel_rank
()
splitted_input
=
split_tensor_along_last_dim
(
input_
,
num_partitions
=
self
.
tp_size
)
input_parallel
=
splitted_input
[
tp_rank
].
contiguous
()
# Matrix multiply.
assert
self
.
quant_method
is
not
None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_
=
None
if
(
self
.
tp_rank
>
0
or
self
.
skip_bias_add
)
else
self
.
bias
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_parallel
,
bias
=
bias_
)
if
self
.
reduce_results
and
self
.
tp_size
>
1
:
if
envs
.
VLLM_ENABLE_TBO
:
output
=
self
.
tbo_all_reduce
(
output_parallel
)
packages_
=
tensor_model_parallel_all_reduce_crp_m32
(
output_parallel
,
pa_rms_weight
=
pa_rms_weight
,
pa_residual
=
pa_residual
,
pa_rms_eps
=
pa_rms_eps
,
pa_quant_dtype
=
pa_quant_dtype
,
update_input
=
update_input
)
hs
,
resi
,
xq
,
xs
=
packages_
output
=
hs
else
:
output
=
output_parallel
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
if
not
self
.
return_bias
:
return
output
return
output
,
resi
,
xq
,
xs
,
output_bias
else
:
if
self
.
input_is_parallel
:
if
self
.
input_is_parallel
:
input_parallel
=
input_
input_parallel
=
input_
else
:
else
:
...
...
vllm/model_executor/layers/quantization/slimquant_w4a8.py
View file @
3e6729e0
...
@@ -162,7 +162,11 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
...
@@ -162,7 +162,11 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
assert
len
(
input_quant_args
)
==
2
assert
len
(
input_quant_args
)
==
2
x_q
,
x_scale
=
input_quant_args
x_q
,
x_scale
=
input_quant_args
elif
envs
.
USE_FUSED_SILU_MUL_QUANT
and
silu_quant_args
is
not
None
:
elif
envs
.
USE_FUSED_SILU_MUL_QUANT
and
silu_quant_args
is
not
None
:
assert
len
(
silu_quant_args
)
==
2
x_q
,
x_scale
=
silu_quant_args
x_q
,
x_scale
=
silu_quant_args
elif
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
and
input_quant_args
is
not
None
:
assert
len
(
input_quant_args
)
==
2
x_q
,
x_scale
=
input_quant_args
else
:
else
:
x_q
,
x_scale
=
per_token_quant_int8
(
x
)
x_q
,
x_scale
=
per_token_quant_int8
(
x
)
...
@@ -178,9 +182,9 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
...
@@ -178,9 +182,9 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
if
m
<=
16
:
if
m
<=
16
:
m_
=
m
m_
=
m
elif
m
<=
64
:
elif
m
<=
64
:
m_
=
(
m
+
3
)
&
-
4
#取值到最近的4的倍数
m_
=
(
(
m
+
3
)
//
4
)
*
4
#取值到最近的4的倍数
elif
m
<=
160
:
elif
m
<=
160
:
m_
=
(
m
+
7
)
&
-
8
m_
=
(
(
m
+
7
)
//
8
)
*
8
elif
m
<
200
:
#256
elif
m
<
200
:
#256
m_
=
160
m_
=
160
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
3e6729e0
...
@@ -29,7 +29,7 @@ import vllm.envs as envs
...
@@ -29,7 +29,7 @@ import vllm.envs as envs
import
typing
import
typing
from
collections.abc
import
Callable
,
Iterable
from
collections.abc
import
Callable
,
Iterable
from
typing
import
Any
,
Optional
,
Union
from
typing
import
Any
,
Optional
,
Union
,
Tuple
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -96,8 +96,8 @@ class DeepseekV2MLP(nn.Module):
...
@@ -96,8 +96,8 @@ class DeepseekV2MLP(nn.Module):
def
forward
(
self
,
x
,
def
forward
(
self
,
x
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
update_hd
:
Optional
[
bool
]
=
False
update_hd
:
Optional
[
bool
]
=
False
,
):
xqxs
:
Optional
[
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
=
None
):
if
envs
.
USE_FUSED_RMS_QUANT
:
if
envs
.
USE_FUSED_RMS_QUANT
:
gate_up
,
new_resi
,
_
=
self
.
gate_up_proj
(
x
,
rms_weight
,
residual
,
update_hd
=
update_hd
)
gate_up
,
new_resi
,
_
=
self
.
gate_up_proj
(
x
,
rms_weight
,
residual
,
update_hd
=
update_hd
)
if
envs
.
USE_FUSED_SILU_MUL_QUANT
:
if
envs
.
USE_FUSED_SILU_MUL_QUANT
:
...
@@ -107,6 +107,11 @@ class DeepseekV2MLP(nn.Module):
...
@@ -107,6 +107,11 @@ class DeepseekV2MLP(nn.Module):
x
,
_
=
self
.
down_proj
(
x
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
,
new_resi
return
x
,
new_resi
elif
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
and
xqxs
is
not
None
:
gate_up
,
_
=
self
.
gate_up_proj
(
x
,
xqxs
=
xqxs
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
else
:
else
:
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate_up
)
x
=
self
.
act_fn
(
gate_up
)
...
@@ -200,20 +205,62 @@ class DeepseekV2MoE(nn.Module):
...
@@ -200,20 +205,62 @@ class DeepseekV2MoE(nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
xqxs
:
Optional
[
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
=
None
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
and
xqxs
is
not
None
:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
if
self
.
n_shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
hidden_states
,
xqxs
=
xqxs
)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
if
envs
.
VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD
:
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
,
shared_output
=
shared_output
)
else
:
if
hidden_states
.
dtype
!=
torch
.
float16
:
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
*
self
.
routed_scaling_factor
else
:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
if
shared_output
is
not
None
:
if
hidden_states
.
dtype
!=
torch
.
float16
:
final_hidden_states
=
final_hidden_states
+
shared_output
else
:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states
=
final_hidden_states
+
shared_output
\
*
(
1.
/
self
.
routed_scaling_factor
)
if
self
.
tp_size
>
1
:
if
envs
.
VLLM_ENABLE_TBO
:
final_hidden_states
=
self
.
tbo_all_reduce
(
final_hidden_states
)
else
:
final_hidden_states
=
(
self
.
experts
.
maybe_all_reduce_tensor_model_parallel
(
final_hidden_states
))
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
)
else
:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
if
self
.
n_shared_experts
is
not
None
:
if
self
.
n_shared_experts
is
not
None
:
if
envs
.
USE_FUSED_RMS_QUANT
:
if
envs
.
USE_FUSED_RMS_QUANT
:
shared_output
,
new_resi
=
self
.
shared_experts
(
hidden_states
,
rms_weight
,
residual
,
update_hd
=
True
)
shared_output
,
new_resi
=
self
.
shared_experts
(
hidden_states
,
rms_weight
,
residual
,
update_hd
=
True
)
else
:
else
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
shared_output
=
self
.
shared_experts
(
hidden_states
)
# router_logits: (num_tokens, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
if
envs
.
VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD
:
if
envs
.
VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD
:
final_hidden_states
=
self
.
experts
(
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
...
@@ -556,8 +603,16 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -556,8 +603,16 @@ class DeepseekV2MLAAttention(nn.Module):
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
pa_rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
pa_residual
:
Optional
[
torch
.
Tensor
]
=
None
,
pa_rms_eps
:
Optional
[
float
]
=
1e-6
,
pa_quant_dtype
:
Optional
[
torch
.
dtype
]
=
torch
.
int8
,
update_input
:
Optional
[
bool
]
=
True
)
->
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
],
]:
if
envs
.
USE_FUSED_RMS_QUANT
and
rms_weight
is
not
None
:
if
envs
.
USE_FUSED_RMS_QUANT
and
rms_weight
is
not
None
:
if
self
.
q_lora_rank
is
not
None
:
if
self
.
q_lora_rank
is
not
None
:
q_c
,
new_residual
,
_
,
input_quant_args
=
self
.
q_a_proj
(
hidden_states
,
rms_weight
=
rms_weight
,
residual
=
residual
,
update_hd
=
False
)
q_c
,
new_residual
,
_
,
input_quant_args
=
self
.
q_a_proj
(
hidden_states
,
rms_weight
=
rms_weight
,
residual
=
residual
,
update_hd
=
False
)
...
@@ -587,6 +642,40 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -587,6 +642,40 @@ class DeepseekV2MLAAttention(nn.Module):
output_shape
=
(
hidden_states
.
shape
[
0
],
output_shape
=
(
hidden_states
.
shape
[
0
],
self
.
num_local_heads
*
self
.
v_head_dim
))
self
.
num_local_heads
*
self
.
v_head_dim
))
return
self
.
o_proj
(
attn_out
)[
0
],
new_residual
return
self
.
o_proj
(
attn_out
)[
0
],
new_residual
elif
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
and
pa_rms_weight
is
not
None
and
pa_residual
is
not
None
:
if
self
.
q_lora_rank
is
not
None
:
q_c
=
self
.
q_a_proj
(
hidden_states
)[
0
]
q_c
=
self
.
q_a_layernorm
(
q_c
)
q
=
self
.
q_b_proj
(
q_c
)[
0
]
else
:
q
=
self
.
q_proj
(
hidden_states
)[
0
]
kv_c
,
k_pe
=
self
.
kv_a_proj_with_mqa
(
hidden_states
)[
0
].
split
(
[
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
kv_c_normed
=
self
.
kv_a_layernorm
(
kv_c
.
contiguous
())
q
=
q
.
view
(
-
1
,
self
.
num_local_heads
,
self
.
qk_head_dim
)
k_pe
=
k_pe
.
unsqueeze
(
1
)
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
=
self
.
rotary_emb
(
positions
,
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
)
attn_out
=
self
.
mla_attn
(
q
,
kv_c_normed
,
k_pe
,
output_shape
=
(
hidden_states
.
shape
[
0
],
self
.
num_local_heads
*
self
.
v_head_dim
))
packages_
=
self
.
o_proj
(
attn_out
,
pa_rms_weight
=
pa_rms_weight
,
pa_residual
=
pa_residual
,
pa_rms_eps
=
pa_rms_eps
,
pa_quant_dtype
=
pa_quant_dtype
,
update_input
=
update_input
)[:
4
]
assert
len
(
packages_
)
==
4
hs
,
resi
,
xq
,
xs
=
packages_
assert
xq
is
not
None
and
xs
is
not
None
return
hs
,
resi
,
xq
,
xs
else
:
else
:
if
self
.
q_lora_rank
is
not
None
:
if
self
.
q_lora_rank
is
not
None
:
q_c
=
self
.
q_a_proj
(
hidden_states
)[
0
]
q_c
=
self
.
q_a_proj
(
hidden_states
)[
0
]
...
@@ -682,14 +771,15 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -682,14 +771,15 @@ class DeepseekV2DecoderLayer(nn.Module):
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
eps
=
config
.
rms_norm_eps
)
self
.
routed_scaling_factor
=
config
.
routed_scaling_factor
self
.
routed_scaling_factor
=
config
.
routed_scaling_factor
self
.
use_fused_rms_quant
=
envs
.
USE_FUSED_RMS_QUANT
self
.
use_fused_custom_all_reduce
=
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
def
forward
(
def
forward
_fused_rmsquant
(
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
]
)
->
torch
.
Tensor
:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
envs
.
USE_FUSED_RMS_QUANT
:
# Fix residual FP16 overflow
# Fix residual FP16 overflow
residual_fix_overflow
=
False
residual_fix_overflow
=
False
...
@@ -732,7 +822,51 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -732,7 +822,51 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states
*=
1.
/
self
.
routed_scaling_factor
hidden_states
*=
1.
/
self
.
routed_scaling_factor
return
hidden_states
,
new_resi
return
hidden_states
,
new_resi
def forward_fused_CRQ(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Decoder-layer forward for the fused custom-all-reduce/rms/quant path.

    Steps, in order:

    1. Apply ``input_layernorm``. When ``residual`` is None the incoming
       hidden states become the residual; otherwise the layernorm call
       returns the updated residual as its second output.
    2. Call ``self_attn`` with the post-attention layernorm weight and
       epsilon plus the residual; it returns
       ``(hidden, residual, xq, xs)``.
    3. For float16 activations, scale the hidden states (and, for layer 0
       or when the residual was created in step 1, the residual) by
       ``1 / routed_scaling_factor`` in place.
    4. Run ``self.mlp`` on the hidden states, passing ``(xq, xs)`` as
       ``xqxs``.
    5. If the MLP is a ``DeepseekV2MLP`` and its output is float16, scale
       that output by ``1 / routed_scaling_factor`` in place.

    Returns:
        ``(hidden_states, residual)`` for the next layer.
    """
    # True only when this call created the residual from the raw input.
    residual_created_here = residual is None
    if residual_created_here:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
    else:
        hidden_states, residual = self.input_layernorm(hidden_states, residual)

    post_norm = self.post_attention_layernorm
    attn_hidden, attn_residual, xq, xs = self.self_attn(
        positions=positions,
        hidden_states=hidden_states,
        pa_rms_weight=post_norm.weight.data,
        pa_residual=residual,
        pa_rms_eps=post_norm.variance_epsilon,
        pa_quant_dtype=torch.int8,
        update_input=True)
    assert xq is not None and xs is not None

    if attn_hidden.dtype == torch.float16:
        # FP16 overflow handling: shrink activations in place.
        attn_hidden *= 1. / self.routed_scaling_factor
        if self.layer_idx == 0 or residual_created_here:
            attn_residual *= 1. / self.routed_scaling_factor

    mlp_out = self.mlp(attn_hidden, xqxs=(xq, xs))

    dense_mlp = isinstance(self.mlp, DeepseekV2MLP)
    if dense_mlp and mlp_out.dtype == torch.float16:
        mlp_out *= 1. / self.routed_scaling_factor

    return mlp_out, attn_residual
def
forward_default
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
# Self Attention
# Fix residual FP16 overflow
# Fix residual FP16 overflow
residual_fix_overflow
=
False
residual_fix_overflow
=
False
...
@@ -774,6 +908,26 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -774,6 +908,26 @@ class DeepseekV2DecoderLayer(nn.Module):
return
hidden_states
,
residual
return
hidden_states
,
residual
def choose_forward(self):
    """Pick the forward implementation matching the configured fusion mode.

    ``use_fused_rms_quant`` takes precedence over
    ``use_fused_custom_all_reduce``; with neither set the default path is
    returned.

    Returns:
        One of ``self.forward_fused_rmsquant``, ``self.forward_fused_CRQ``
        or ``self.forward_default`` (not called here).
    """
    if self.use_fused_rms_quant:
        return self.forward_fused_rmsquant
    if self.use_fused_custom_all_reduce:
        return self.forward_fused_CRQ
    return self.forward_default
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the layer through whichever path ``choose_forward`` selects.

    The three arguments are passed to the selected callable by keyword and
    its result is returned unchanged.
    """
    return self.choose_forward()(positions=positions,
                                 hidden_states=hidden_states,
                                 residual=residual)
@
support_torch_compile
@
support_torch_compile
class
DeepseekV2Model
(
nn
.
Module
):
class
DeepseekV2Model
(
nn
.
Module
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment