Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3e6729e0
Commit
3e6729e0
authored
Nov 18, 2025
by
wujl5
Browse files
deepseekv2-w4a8支持custom-rms-quant融合
parent
813f81fb
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
1518 additions
and
205 deletions
+1518
-205
csrc/custom_all_reduce.cu
csrc/custom_all_reduce.cu
+138
-0
csrc/custom_all_reduce.cuh
csrc/custom_all_reduce.cuh
+596
-11
csrc/custom_all_reduce_test.cu
csrc/custom_all_reduce_test.cu
+142
-23
csrc/ops.h
csrc/ops.h
+11
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+11
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+11
-1
vllm/distributed/communication_op.py
vllm/distributed/communication_op.py
+15
-2
vllm/distributed/device_communicators/cuda_communicator.py
vllm/distributed/device_communicators/cuda_communicator.py
+33
-1
vllm/distributed/device_communicators/custom_all_reduce.py
vllm/distributed/device_communicators/custom_all_reduce.py
+85
-0
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+75
-1
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+8
-1
vllm/envs.py
vllm/envs.py
+6
-1
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+101
-36
vllm/model_executor/layers/quantization/slimquant_w4a8.py
vllm/model_executor/layers/quantization/slimquant_w4a8.py
+6
-2
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+280
-126
No files found.
csrc/custom_all_reduce.cu
View file @
3e6729e0
...
...
@@ -59,6 +59,144 @@ bool _is_weak_contiguous(torch::Tensor& t) {
* Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
* copied into _reg_buffer.
*/
void
all_reduce_fuse_norm
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
,
int64_t
hidden_size
,
torch
::
Tensor
&
residual
,
torch
::
Tensor
&
rms_weight
,
double
eps
,
fptr_t
_reg_buffer
,
int64_t
reg_buffer_sz_bytes
)
{
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
inp
));
auto
stream
=
c10
::
cuda
::
getCurrentCUDAStream
().
stream
();
TORCH_CHECK_EQ
(
inp
.
scalar_type
(),
out
.
scalar_type
());
TORCH_CHECK_EQ
(
inp
.
numel
(),
out
.
numel
());
TORCH_CHECK
(
_is_weak_contiguous
(
out
));
TORCH_CHECK
(
_is_weak_contiguous
(
inp
));
TORCH_CHECK
(
_is_weak_contiguous
(
residual
));
TORCH_CHECK
(
_is_weak_contiguous
(
rms_weight
));
int
token_num
=
inp
.
numel
()
/
hidden_size
;
auto
input_size
=
inp
.
numel
()
*
inp
.
element_size
();
auto
reg_buffer
=
reinterpret_cast
<
void
*>
(
_reg_buffer
);
if
(
reg_buffer
)
{
TORCH_CHECK_LE
(
input_size
,
reg_buffer_sz_bytes
);
AT_CUDA_CHECK
(
cudaMemcpyAsync
(
reg_buffer
,
inp
.
data_ptr
(),
input_size
,
cudaMemcpyDeviceToDevice
,
stream
));
}
else
{
reg_buffer
=
inp
.
data_ptr
();
}
switch
(
out
.
scalar_type
())
{
case
at
::
ScalarType
::
Float
:
{
fa
->
allreduce_fuse_norm
<
float
>
(
stream
,
reinterpret_cast
<
float
*>
(
reg_buffer
),
reinterpret_cast
<
float
*>
(
out
.
data_ptr
()),
out
.
numel
(),
token_num
,
hidden_size
,
reinterpret_cast
<
float
*>
(
residual
.
data_ptr
()),
reinterpret_cast
<
float
*>
(
rms_weight
.
data_ptr
()),
(
float
)
eps
);
break
;
}
case
at
::
ScalarType
::
Half
:
{
fa
->
allreduce_fuse_norm
<
half
>
(
stream
,
reinterpret_cast
<
half
*>
(
reg_buffer
),
reinterpret_cast
<
half
*>
(
out
.
data_ptr
()),
out
.
numel
(),
token_num
,
hidden_size
,
reinterpret_cast
<
half
*>
(
residual
.
data_ptr
()),
reinterpret_cast
<
half
*>
(
rms_weight
.
data_ptr
()),
(
float
)
eps
);
break
;
}
// #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case
at
::
ScalarType
::
BFloat16
:
{
fa
->
allreduce_fuse_norm
<
nv_bfloat16
>
(
stream
,
reinterpret_cast
<
nv_bfloat16
*>
(
reg_buffer
),
reinterpret_cast
<
nv_bfloat16
*>
(
out
.
data_ptr
()),
out
.
numel
(),
token_num
,
hidden_size
,
reinterpret_cast
<
nv_bfloat16
*>
(
residual
.
data_ptr
()),
reinterpret_cast
<
nv_bfloat16
*>
(
rms_weight
.
data_ptr
()),
(
float
)
eps
);
break
;
}
// #endif
default:
throw
std
::
runtime_error
(
"custom allreduce only supports float32, float16 and bfloat16"
);
}
}
template
<
typename
scalar_in_t
,
bool
update_input
>
void
allreduce_fuse_norm_quant_dispath
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
,
int
hidden_size
,
torch
::
Tensor
&
rms_weight
,
double
eps
,
torch
::
Tensor
&
scales
,
torch
::
Tensor
&
norm_out
,
fptr_t
_reg_buffer
,
int64_t
reg_buffer_sz_bytes
,
std
::
optional
<
at
::
Tensor
>
residual
)
{
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
inp
));
auto
stream
=
c10
::
cuda
::
getCurrentCUDAStream
().
stream
();
TORCH_CHECK
(
_is_weak_contiguous
(
inp
));
int
token_num
=
inp
.
numel
()
/
hidden_size
;
auto
input_size
=
inp
.
numel
()
*
inp
.
element_size
();
auto
reg_buffer
=
reinterpret_cast
<
void
*>
(
_reg_buffer
);
if
(
reg_buffer
)
{
TORCH_CHECK_LE
(
input_size
,
reg_buffer_sz_bytes
);
AT_CUDA_CHECK
(
cudaMemcpyAsync
(
reg_buffer
,
inp
.
data_ptr
(),
input_size
,
cudaMemcpyDeviceToDevice
,
stream
));
}
else
{
reg_buffer
=
inp
.
data_ptr
();
}
auto
wt_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
rms_weight
.
data_ptr
());
if
(
wt_ptr
%
16
!=
0
)
{
throw
std
::
runtime_error
(
"custom allreduce currently requires wt_ptr % 16 "
"of "
+
std
::
to_string
(
wt_ptr
%
16
));
}
if
(
fa
->
fully_connected_
)
{
if
(
residual
.
has_value
())
{
VLLM_DISPATCH_QUANT_TYPES
(
out
.
scalar_type
(),
"fa->allreduce_fuse_norm_quant"
,
[
&
]
{
fa
->
allreduce_fuse_norm_quant
<
scalar_in_t
,
scalar_t
,
true
,
update_input
>
(
stream
,
reinterpret_cast
<
scalar_in_t
*>
(
reg_buffer
),
out
.
data_ptr
<
scalar_t
>
(),
out
.
numel
(),
token_num
,
hidden_size
,
residual
->
data_ptr
<
scalar_in_t
>
(),
rms_weight
.
data_ptr
<
scalar_in_t
>
(),
norm_out
.
data_ptr
<
scalar_in_t
>
(),
eps
,
scales
.
data_ptr
<
float
>
());
});
}
else
{
VLLM_DISPATCH_QUANT_TYPES
(
out
.
scalar_type
(),
"fa->allreduce_fuse_norm_quant"
,
[
&
]
{
fa
->
allreduce_fuse_norm_quant
<
scalar_in_t
,
scalar_t
,
false
,
update_input
>
(
stream
,
reinterpret_cast
<
scalar_in_t
*>
(
reg_buffer
),
out
.
data_ptr
<
scalar_t
>
(),
out
.
numel
(),
token_num
,
hidden_size
,
nullptr
,
rms_weight
.
data_ptr
<
scalar_in_t
>
(),
norm_out
.
data_ptr
<
scalar_in_t
>
(),
eps
,
scales
.
data_ptr
<
float
>
());
});
}
}
else
{
throw
std
::
runtime_error
(
"custom allreduce only supports fully_connected"
);
}
}
void
all_reduce_fuse_norm_quant
(
fptr_t
fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
,
int64_t
hidden_size
,
torch
::
Tensor
&
rms_weight
,
double
eps
,
torch
::
Tensor
&
scales
,
torch
::
Tensor
&
norm_out
,
fptr_t
reg_buffer
,
int64_t
reg_buffer_sz_bytes
,
std
::
optional
<
at
::
Tensor
>
residual
,
bool
update_input
)
{
static
c10
::
ScalarType
kFp8Type
=
c10
::
ScalarType
::
Float8_e4m3fn
;
TORCH_CHECK
(
out
.
dtype
()
==
kFp8Type
||
out
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK_EQ
(
inp
.
numel
(),
out
.
numel
());
TORCH_CHECK
(
out
.
is_contiguous
()
&&
inp
.
is_contiguous
());
VLLM_DISPATCH_FLOATING_TYPES
(
inp
.
scalar_type
(),
"allreduce_fuse_norm_quant_dispath"
,
[
&
]
{
if
(
update_input
)
allreduce_fuse_norm_quant_dispath
<
scalar_t
,
true
>
(
fa
,
inp
,
out
,
hidden_size
,
rms_weight
,
eps
,
scales
,
norm_out
,
reg_buffer
,
reg_buffer_sz_bytes
,
residual
);
else
allreduce_fuse_norm_quant_dispath
<
scalar_t
,
false
>
(
fa
,
inp
,
out
,
hidden_size
,
rms_weight
,
eps
,
scales
,
norm_out
,
reg_buffer
,
reg_buffer_sz_bytes
,
residual
);
});
}
void
all_reduce
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
,
fptr_t
_reg_buffer
,
int64_t
reg_buffer_sz_bytes
)
{
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
...
...
csrc/custom_all_reduce.cuh
View file @
3e6729e0
#pragma once
#include "type_convert.cuh"
#include "dispatch_utils.h"
#include <algorithm>
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <c10/cuda/CUDAMathCompat.h>
#include <ATen/AccumulateType.h>
#include <THC/THCDeviceUtils.cuh>
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#if defined(USE_ROCM)
#include <hip/hip_bf16.h>
//
#if defined(USE_ROCM)
typedef
__hip_bfloat16
nv_bfloat16
;
#endif
//
#endif
#include <iostream>
#include <array>
...
...
@@ -15,7 +23,11 @@ typedef __hip_bfloat16 nv_bfloat16;
#include <map>
#include <unordered_map>
#include <vector>
#ifndef USE_ROCM
#include <cub/cub.cuh>
#else
#include <hipcub/hipcub.hpp>
#endif
namespace
vllm
{
#define CUDACHECK(cmd) \
do { \
...
...
@@ -28,7 +40,7 @@ namespace vllm {
} while (0)
// Maximal number of blocks in allreduce kernel.
constexpr
int
kMaxBlocks
=
36
;
constexpr
int
kMaxBlocks
=
128
;
// Default number of blocks in allreduce kernel.
#ifndef USE_ROCM
...
...
@@ -80,6 +92,7 @@ struct packed_t {
using
P
=
array_t
<
T
,
16
/
sizeof
(
T
)
>
;
// the (A)ccumulator type for reduction
using
A
=
array_t
<
float
,
16
/
sizeof
(
T
)
>
;
using
F
=
array_t
<
int8_t
,
16
/
sizeof
(
T
)
>
;
};
#define DINLINE __device__ __forceinline__
...
...
@@ -124,6 +137,117 @@ DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
return
a
;
}
/**********************************************************/
template
<
typename
P
,
uint32_t
VEC_SIZE
>
DINLINE
P
vec_add
(
const
P
&
a
,
const
P
&
b
)
{
P
sum_tmp
;
#pragma unroll
for
(
int
i
=
0
;
i
<
a
.
size
;
++
i
)
sum_tmp
.
data
[
i
]
=
static_cast
<
float
>
(
a
.
data
[
i
])
+
static_cast
<
float
>
(
b
.
data
[
i
]);
return
sum_tmp
;
}
template
<
typename
T
,
int
reducesize
=
64
>
__inline__
__device__
T
WarpReduceSum
(
T
val
)
{
#pragma unroll
for
(
int
offset
=
reducesize
/
2
;
offset
>
0
;
offset
>>=
1
)
{
val
+=
WARP_SHFL_DOWN
(
val
,
offset
);
}
return
val
;
}
template
<
typename
T
>
DINLINE
T
BlockReduce
(
T
val
,
T
*
shared
)
{
const
int
lid
=
threadIdx
.
x
%
64
;
const
int
wid
=
threadIdx
.
x
/
64
;
const
int
block_size
=
blockDim
.
x
;
const
int
shared_size
=
block_size
/
64
;
val
=
WarpReduceSum
<
T
>
(
val
);
if
(
block_size
==
64
)
return
val
;
if
(
lid
==
0
&&
wid
<
shared_size
)
{
shared
[
wid
]
=
val
;
}
__syncthreads
();
val
=
0.
f
;
if
(
wid
==
0
&&
lid
<
shared_size
)
{
val
=
shared
[
lid
];
val
=
WarpReduceSum
<
T
,
16
>
(
val
);
}
return
val
;
}
template
<
typename
T
,
typename
P
,
typename
A
>
DINLINE
P
fused_add_rms_norm
(
P
const
&
residual
,
P
const
&
gamma
,
int
hidden_dim
,
float
eps
)
{
static
constexpr
int
VEC_SIZE
=
16
/
sizeof
(
T
);
__shared__
float
s_val
;
float
trstd
;
P
norm_out
;
float
acc
=
0.0
f
;
#pragma unroll
for
(
int
i
=
0
;
i
<
VEC_SIZE
;
++
i
)
{
float
v
=
static_cast
<
float
>
(
residual
.
data
[
i
]);
acc
+=
v
*
v
;
}
__shared__
float
r_sum
[
16
];
acc
=
BlockReduce
(
acc
,
r_sum
);
if
(
threadIdx
.
x
==
0
)
s_val
=
rsqrtf
(
acc
/
hidden_dim
+
eps
);
__syncthreads
();
trstd
=
s_val
;
#pragma unroll
for
(
int
i
=
0
;
i
<
VEC_SIZE
;
++
i
)
{
norm_out
.
data
[
i
]
=
static_cast
<
T
>
(
static_cast
<
float
>
(
residual
.
data
[
i
])
*
trstd
*
static_cast
<
float
>
(
gamma
.
data
[
i
]));
}
return
norm_out
;
}
static
inline
__device__
int8_t
float_to_int8_rn
(
float
x
)
{
#ifdef USE_ROCM
static
constexpr
auto
i8_min
=
static_cast
<
float
>
(
std
::
numeric_limits
<
int8_t
>::
min
());
static
constexpr
auto
i8_max
=
static_cast
<
float
>
(
std
::
numeric_limits
<
int8_t
>::
max
());
float
dst
=
std
::
nearbyint
(
x
);
dst
=
(
dst
<
i8_min
)
?
i8_min
:
(
dst
>
i8_max
)
?
i8_max
:
dst
;
return
static_cast
<
int8_t
>
(
dst
);
#else
// CUDA path
uint32_t
dst
;
asm
volatile
(
"cvt.rni.sat.s8.f32 %0, %1;"
:
"=r"
(
dst
)
:
"f"
(
x
));
return
reinterpret_cast
<
const
int8_t
&>
(
dst
);
#endif
}
template
<
typename
T
,
int
reducesize
=
64
>
__inline__
__device__
T
WarpReduceMax
(
T
val
)
{
#pragma unroll
for
(
int
offset
=
reducesize
/
2
;
offset
>
0
;
offset
>>=
1
)
{
val
=
fmaxf
(
val
,
WARP_SHFL_DOWN
(
val
,
offset
));
}
return
val
;
}
template
<
typename
T
>
DINLINE
T
BlockReduceMax_ROW
(
T
val
,
T
*
shared
)
{
const
int
lid
=
threadIdx
.
x
%
64
;
const
int
wid
=
threadIdx
.
x
/
64
;
const
int
block_size
=
blockDim
.
x
;
const
int
shared_size
=
block_size
/
64
;
val
=
WarpReduceMax
<
T
>
(
val
);
if
(
block_size
==
64
)
return
val
;
if
(
lid
==
0
&&
wid
<
shared_size
)
{
shared
[
wid
]
=
val
;
}
__syncthreads
();
if
(
wid
==
0
&&
lid
<
shared_size
)
{
val
=
shared
[
lid
];
val
=
WarpReduceMax
<
T
,
16
>
(
val
);
}
return
val
;
}
template
<
typename
T
,
int
N
>
DINLINE
array_t
<
float
,
N
>
upcast
(
array_t
<
T
,
N
>
val
)
{
if
constexpr
(
std
::
is_same
<
T
,
float
>::
value
)
{
...
...
@@ -132,7 +256,7 @@ DINLINE array_t<float, N> upcast(array_t<T, N> val) {
array_t
<
float
,
N
>
out
;
#pragma unroll
for
(
int
i
=
0
;
i
<
N
;
i
++
)
{
out
.
data
[
i
]
=
upcast_s
(
val
.
data
[
i
]);
out
.
data
[
i
]
=
static_cast
<
float
>
(
val
.
data
[
i
]);
}
return
out
;
}
...
...
@@ -146,13 +270,13 @@ DINLINE O downcast(array_t<float, O::size> val) {
O
out
;
#pragma unroll
for
(
int
i
=
0
;
i
<
O
::
size
;
i
++
)
{
out
.
data
[
i
]
=
down
cast
_s
<
typename
O
::
type
>
(
val
.
data
[
i
]);
out
.
data
[
i
]
=
static_
cast
<
typename
O
::
type
>
(
val
.
data
[
i
]);
}
return
out
;
}
}
#if
!defined(USE_ROCM)
#if
0
static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
...
...
@@ -243,18 +367,20 @@ DINLINE void barrier_at_end(const RankSignals& sg, Signal* self_sg, int rank) {
template
<
int
ngpus
>
DINLINE
void
barrier_at_start
(
const
RankSignals
&
sg
,
Signal
*
self_sg
,
int
rank
)
{
uint32_t
flag
=
self_sg
->
_flag
[
blockIdx
.
x
]
+
1
;
uint32_t
flag
=
self_sg
->
_flag
[
blockIdx
.
x
]
+
1
;
//当前线程块标记+1
if
(
threadIdx
.
x
<
ngpus
)
{
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
// __scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank],
// flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
// 将每个peer GPU对应线程块的本rank flag填入
__atomic_store_n
(
&
sg
.
signals
[
threadIdx
.
x
]
->
start
[
blockIdx
.
x
][
rank
],
flag
,
__ATOMIC_RELAXED
);
// wait until we got true from all ranks
// while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
// __ATOMIC_RELAXED,
// __MEMORY_SCOPE_DEVICE) < flag);
//等待对应blockidx.x处理的数据的peer gpu到达
while
(
__atomic_load_n
(
&
self_sg
->
start
[
blockIdx
.
x
][
threadIdx
.
x
],
__ATOMIC_RELAXED
)
<
flag
);
}
...
...
@@ -274,6 +400,7 @@ DINLINE void barrier_at_end(const RankSignals& sg, Signal* self_sg, int rank) {
// flag,
// final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
// __MEMORY_SCOPE_SYSTEM);
// 告诉其他GPU 本block Reduce完毕
__atomic_store_n
(
&
sg
.
signals
[
threadIdx
.
x
]
->
end
[
blockIdx
.
x
][
rank
],
flag
,
final_sync
?
__ATOMIC_RELAXED
:
__ATOMIC_RELEASE
);
// wait until we got true from all ranks
...
...
@@ -281,6 +408,7 @@ DINLINE void barrier_at_end(const RankSignals& sg, Signal* self_sg, int rank) {
// __scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
// final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
// __MEMORY_SCOPE_DEVICE) < flag);
// 当前block处理的 hs的其他GPU处理完毕
while
(
__atomic_load_n
(
&
self_sg
->
end
[
blockIdx
.
x
][
threadIdx
.
x
],
final_sync
?
__ATOMIC_RELAXED
:
__ATOMIC_ACQUIRE
)
<
flag
);
...
...
@@ -290,6 +418,34 @@ DINLINE void barrier_at_end(const RankSignals& sg, Signal* self_sg, int rank) {
if
(
threadIdx
.
x
==
0
)
self_sg
->
_flag
[
blockIdx
.
x
]
=
flag
;
}
template
<
int
ngpus
,
bool
final_sync
=
false
>
DINLINE
void
barrier_at_end_fuse
(
const
RankSignals
&
sg
,
Signal
*
self_sg
,
int
rank
)
{
__syncthreads
();
uint32_t
flag
=
self_sg
->
_flag
[
blockIdx
.
x
]
+
1
;
if
(
threadIdx
.
x
<
ngpus
)
{
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
// __scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank],
// flag,
// final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
// __MEMORY_SCOPE_SYSTEM);
__atomic_store_n
(
&
sg
.
signals
[
threadIdx
.
x
]
->
end
[
blockIdx
.
x
][
rank
],
flag
,
final_sync
?
__ATOMIC_RELAXED
:
__ATOMIC_RELEASE
);
// wait until we got true from all ranks
// while (
// __scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
// final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
// __MEMORY_SCOPE_DEVICE) < flag);
while
(
__atomic_load_n
(
&
self_sg
->
end
[
blockIdx
.
x
][
threadIdx
.
x
],
final_sync
?
__ATOMIC_RELAXED
:
__ATOMIC_ACQUIRE
)
<
flag
);
}
__syncthreads
();
// use one thread to update flag
if
(
threadIdx
.
x
==
0
)
self_sg
->
_flag
[
blockIdx
.
x
]
=
flag
;
}
#endif
template
<
typename
P
,
int
ngpus
,
typename
A
>
...
...
@@ -325,6 +481,264 @@ DINLINE P* get_tmp_buf(Signal* sg) {
return
(
P
*
)(((
Signal
*
)
sg
)
+
1
);
}
template
<
typename
T
,
int
ngpus
>
__global__
void
__launch_bounds__
(
1024
,
1
)
cross_device_reduce_2stage_fuse_norm
(
RankData
*
_dp
,
RankSignals
sg
,
Signal
*
self_sg
,
T
*
__restrict__
result
,
int
rank
,
int
size
,
int
hidden_dim
,
T
*
residual_in
,
T
*
rms_gamma
,
float
eps
,
std
::
array
<
int
,
ngpus
>
begin_tokens
,
std
::
array
<
int
,
ngpus
>
token_num_per_ranks
)
{
static
constexpr
int
VEC_SIZE
=
16
/
sizeof
(
T
);
int
H_D_word_num
=
hidden_dim
/
VEC_SIZE
;
int
token_id
=
blockIdx
.
x
;
// local token id
int
access_id_in_token
=
threadIdx
.
x
;
// 当前token内数据部分
int
token_stride
=
gridDim
.
x
;
//
int
access_id
=
token_id
*
H_D_word_num
+
access_id_in_token
;
// local token id * (token in size)
int
access_stride
=
token_stride
*
H_D_word_num
;
// gridDim.x * (token in size)
using
P
=
typename
packed_t
<
T
>::
P
;
using
A
=
typename
packed_t
<
T
>::
A
;
const
P
*
ptrs
[
ngpus
];
P
*
tmps
[
ngpus
];
#pragma unroll
for
(
int
i
=
0
;
i
<
ngpus
;
++
i
)
{
int
target
=
(
rank
+
i
)
%
ngpus
;
ptrs
[
i
]
=
(
const
P
*
)
_dp
->
ptrs
[
target
];
tmps
[
i
]
=
get_tmp_buf
<
P
>
(
sg
.
signals
[
target
]);
}
int
start
=
begin_tokens
[
rank
]
*
H_D_word_num
;
int
part
=
(
begin_tokens
[
rank
]
+
token_num_per_ranks
[
rank
])
*
H_D_word_num
;
auto
tmp_out
=
tmps
[
0
];
// 当前rank的 (meta_data + sizeof(signal)) 偏移
barrier_at_start
<
ngpus
>
(
sg
,
self_sg
,
rank
);
#pragma unroll
for
(
int
idx
=
access_id
+
start
;
idx
<
part
;
idx
+=
access_stride
)
{
tmp_out
[
idx
]
=
packed_reduce
<
P
,
ngpus
,
A
>
(
ptrs
,
idx
);
#pragma unroll
for
(
int
r
=
0
;
r
<
ngpus
;
++
r
)
tmps
[
r
][
idx
]
=
tmp_out
[
idx
];
//将当前GPU处理的数据--->其他GPU的对应问题
}
barrier_at_end
<
ngpus
>
(
sg
,
self_sg
,
rank
);
//debug --- 验证reduce结果
// for (int r = 0; r < ngpus; ++r) {
// int cm_access_id = access_id + begin_tokens[r] * H_D_word_num;
// int cm_token_id = token_id + begin_tokens[r];
// int cm_token_access = (begin_tokens[r] + token_num_per_ranks[r]) * H_D_word_num;
// for (int idx = cm_access_id; idx < cm_token_access; idx += access_stride)
// ((P*)result)[idx] = tmp_out[idx];
// }
P
m_residual_val
,
m_gamm_val
;
m_gamm_val
=
((
P
*
)
rms_gamma
)[
access_id_in_token
];
#pragma unroll
for
(
int
r
=
0
;
r
<
ngpus
;
++
r
)
{
int
cm_access_id
=
access_id
+
begin_tokens
[
r
]
*
H_D_word_num
;
int
cm_token_id
=
token_id
+
begin_tokens
[
r
];
int
cm_tot_access
=
(
begin_tokens
[
r
]
+
token_num_per_ranks
[
r
])
*
H_D_word_num
;
for
(
int
idx
=
cm_access_id
;
idx
<
cm_tot_access
;
idx
+=
access_stride
)
{
P
sum_val
;
sum_val
=
tmp_out
[
idx
];
m_residual_val
=
((
P
*
)
residual_in
)[
idx
];
sum_val
=
vec_add
<
P
,
VEC_SIZE
>
(
sum_val
,
m_residual_val
);
sum_val
=
fused_add_rms_norm
<
T
,
P
,
A
>
(
sum_val
,
m_gamm_val
,
hidden_dim
,
eps
);
((
P
*
)
result
)[
idx
]
=
sum_val
;
}
}
}
template
<
typename
T
,
typename
T_out
,
int
ngpus
,
bool
isResidual
=
true
,
bool
update_input
=
false
>
__global__
void
__launch_bounds__
(
1024
,
1
)
cross_device_reduce_1stage_norm_quant
(
RankData
*
_dp
,
RankSignals
sg
,
Signal
*
self_sg
,
T_out
*
__restrict__
result
,
int
rank
,
int
size
,
int
hidden_dim
,
T
*
residual_in
,
T
*
rms_gamma
,
float
*
__restrict__
scales
,
float
eps
,
T
*
__restrict__
norm_res
)
{
// static constexpr int VEC_SIZE = 16 / sizeof(T);
static
constexpr
int
VEC_SIZE
=
packed_t
<
T
>::
P
::
size
;
int
H_D_word_num
=
hidden_dim
/
VEC_SIZE
;
int
token_id
=
blockIdx
.
x
;
int
access_id_in_token
=
threadIdx
.
x
;
int
token_stride
=
gridDim
.
x
;
int
access_id
=
token_id
*
H_D_word_num
+
access_id_in_token
;
int
access_stride
=
token_stride
*
H_D_word_num
;
using
P
=
typename
packed_t
<
T
>::
P
;
using
A
=
typename
packed_t
<
T
>::
A
;
using
F
=
typename
packed_t
<
T
>::
F
;
P
m_residual_val
,
m_gamm_val
;
m_gamm_val
=
reinterpret_cast
<
P
*>
(
rms_gamma
)[
access_id_in_token
];
auto
dp
=
*
_dp
;
P
sum_val
;
barrier_at_start
<
ngpus
>
(
sg
,
self_sg
,
rank
);
sum_val
=
packed_reduce
<
P
,
ngpus
,
A
>
((
const
P
**
)
&
dp
.
ptrs
[
0
],
access_id
);
barrier_at_end
<
ngpus
,
true
>
(
sg
,
self_sg
,
rank
);
if
constexpr
(
isResidual
)
{
m_residual_val
=
reinterpret_cast
<
P
*>
(
residual_in
)[
access_id
];
sum_val
=
vec_add
<
P
,
VEC_SIZE
>
(
m_residual_val
,
sum_val
);
((
P
*
)
residual_in
)[
access_id
]
=
sum_val
;
}
__shared__
float
s_val
;
P
norm_out
;
float
acc
=
0.
f
;
#pragma unroll
for
(
int
i
=
0
;
i
<
VEC_SIZE
;
++
i
)
{
float
v
=
static_cast
<
float
>
(
sum_val
.
data
[
i
]);
acc
+=
v
*
v
;
}
__shared__
float
r_sum
[
16
];
acc
=
BlockReduce
<
float
>
(
acc
,
r_sum
);
if
(
threadIdx
.
x
==
0
)
s_val
=
rsqrt
(
acc
/
hidden_dim
+
eps
);
__syncthreads
();
float
block_absmax_val_maybe
=
0.
f
;
#pragma unroll
for
(
int
i
=
0
;
i
<
VEC_SIZE
;
++
i
)
{
norm_out
.
data
[
i
]
=
static_cast
<
float
>
(
sum_val
.
data
[
i
])
*
s_val
*
static_cast
<
float
>
(
m_gamm_val
.
data
[
i
]);
block_absmax_val_maybe
=
fmaxf
(
block_absmax_val_maybe
,
fabs
(
norm_out
.
data
[
i
]));
}
block_absmax_val_maybe
=
BlockReduceMax_ROW
(
block_absmax_val_maybe
,
r_sum
);
//
__shared__
float
s_token_scale
;
float
scale
=
0.0
f
;
if
(
threadIdx
.
x
==
0
)
{
scale
=
block_absmax_val_maybe
;
s_token_scale
=
scale
;
}
__syncthreads
();
float
inv_s
=
(
s_token_scale
==
0.
f
)
?
0.
f
:
127.
f
/
s_token_scale
;
F
out_vec
;
#pragma unroll
for
(
int
i
=
0
;
i
<
VEC_SIZE
;
++
i
)
out_vec
.
data
[
i
]
=
float_to_int8_rn
(
norm_out
.
data
[
i
]
*
inv_s
);
constexpr
float
qmax
=
127.0
f
;
constexpr
float
min_scale
=
1.19209e-07
f
;
((
F
*
)
result
)[
access_id
]
=
out_vec
;
if
constexpr
(
update_input
)
((
P
*
)
norm_res
)[
access_id
]
=
norm_out
;
if
(
threadIdx
.
x
==
0
)
scales
[
blockIdx
.
x
]
=
fmaxf
(
scale
/
qmax
,
min_scale
);
}
template
<
typename
T
,
typename
T_out
,
int
ngpus
,
bool
isResidual
=
true
,
bool
update_input
=
false
>
__global__
void
__launch_bounds__
(
1024
,
1
)
cross_device_reduce_2stage_fuse_norm_quant
(
RankData
*
_dp
,
RankSignals
sg
,
Signal
*
self_sg
,
T_out
*
__restrict__
result
,
int
rank
,
int
size
,
int
hidden_dim
,
T
*
residual_in
,
T
*
rms_gamma
,
float
*
__restrict__
scales
,
float
eps
,
T
*
__restrict__
norm_res
,
std
::
array
<
int
,
ngpus
>
begin_tokens
,
std
::
array
<
int
,
ngpus
>
token_num_per_ranks
)
{
static
constexpr
int
VEC_SIZE
=
16
/
sizeof
(
T
);
int
H_D_word_num
=
hidden_dim
/
VEC_SIZE
;
int
token_id
=
blockIdx
.
x
;
// local token id
int
access_id_in_token
=
threadIdx
.
x
;
// 当前token内数据部分
int
token_stride
=
gridDim
.
x
;
//
int
access_id
=
token_id
*
H_D_word_num
+
access_id_in_token
;
// local token id * (token in size)
int
access_stride
=
token_stride
*
H_D_word_num
;
// gridDim.x * (token in size)
using
P
=
typename
packed_t
<
T
>::
P
;
using
A
=
typename
packed_t
<
T
>::
A
;
using
F
=
typename
packed_t
<
T
>::
F
;
const
P
*
ptrs
[
ngpus
];
P
*
tmps
[
ngpus
];
#pragma unroll
for
(
int
i
=
0
;
i
<
ngpus
;
++
i
)
{
int
target
=
(
rank
+
i
)
%
ngpus
;
ptrs
[
i
]
=
(
const
P
*
)
_dp
->
ptrs
[
target
];
tmps
[
i
]
=
get_tmp_buf
<
P
>
(
sg
.
signals
[
target
]);
}
int
start
=
begin_tokens
[
rank
]
*
H_D_word_num
;
int
part
=
(
begin_tokens
[
rank
]
+
token_num_per_ranks
[
rank
])
*
H_D_word_num
;
auto
tmp_out
=
tmps
[
0
];
// 当前rank的 (meta_data + sizeof(signal)) 偏移
auto
input
=
ptrs
[
0
];
barrier_at_start
<
ngpus
>
(
sg
,
self_sg
,
rank
);
#pragma unroll
for
(
int
idx
=
access_id
+
start
;
idx
<
part
;
idx
+=
access_stride
)
{
tmp_out
[
idx
]
=
packed_reduce
<
P
,
ngpus
,
A
>
(
ptrs
,
idx
);
#pragma unroll
for
(
int
r
=
0
;
r
<
ngpus
;
++
r
)
tmps
[
r
][
idx
]
=
tmp_out
[
idx
];
//将当前GPU处理的数据--->其他GPU的对应问题
}
barrier_at_end
<
ngpus
>
(
sg
,
self_sg
,
rank
);
P
m_residual_val
,
m_gamm_val
;
m_gamm_val
=
reinterpret_cast
<
P
*>
(
rms_gamma
)[
access_id_in_token
];
#pragma unroll
for
(
int
r
=
0
;
r
<
ngpus
;
++
r
)
{
int
cm_access_id
=
access_id
+
begin_tokens
[
r
]
*
H_D_word_num
;
int
cm_token_id
=
token_id
+
begin_tokens
[
r
];
int
cm_tot_access
=
(
begin_tokens
[
r
]
+
token_num_per_ranks
[
r
])
*
H_D_word_num
;
for
(
int
idx
=
cm_access_id
,
tidx
=
cm_token_id
;
idx
<
cm_tot_access
;
idx
+=
access_stride
,
tidx
+=
token_stride
)
{
P
sum_val
;
sum_val
=
tmp_out
[
idx
];
if
constexpr
(
isResidual
)
{
m_residual_val
=
reinterpret_cast
<
P
*>
(
residual_in
)[
idx
];
sum_val
=
vec_add
<
P
,
VEC_SIZE
>
(
sum_val
,
m_residual_val
);
((
P
*
)
residual_in
)[
idx
]
=
sum_val
;
}
__shared__
float
s_val
;
P
norm_out
;
float
acc
=
0.0
f
;
#pragma unroll
for
(
int
i
=
0
;
i
<
VEC_SIZE
;
++
i
)
{
float
v
=
static_cast
<
float
>
(
sum_val
.
data
[
i
]);
acc
+=
v
*
v
;
}
__shared__
float
r_sum
[
16
];
acc
=
BlockReduce
(
acc
,
r_sum
);
if
(
threadIdx
.
x
==
0
)
s_val
=
rsqrtf
(
acc
/
hidden_dim
+
eps
);
__syncthreads
();
float
block_absmax_val_maybe
=
0.
f
;
#pragma unroll
for
(
int
i
=
0
;
i
<
VEC_SIZE
;
++
i
)
{
norm_out
.
data
[
i
]
=
static_cast
<
T
>
(
static_cast
<
float
>
(
sum_val
.
data
[
i
])
*
s_val
*
static_cast
<
float
>
(
m_gamm_val
.
data
[
i
]));
block_absmax_val_maybe
=
fmaxf
(
block_absmax_val_maybe
,
fabs
(
norm_out
.
data
[
i
]));
}
block_absmax_val_maybe
=
BlockReduceMax_ROW
(
block_absmax_val_maybe
,
r_sum
);
__shared__
float
s_token_scale
;
float
scale
=
0.0
f
;
if
(
threadIdx
.
x
==
0
)
{
scale
=
block_absmax_val_maybe
;
s_token_scale
=
scale
;
}
__syncthreads
();
float
inv_s
=
(
s_token_scale
==
0.
f
)
?
0.
f
:
127.
f
/
s_token_scale
;
F
out_vec
;
#pragma unroll
for
(
int
i
=
0
;
i
<
VEC_SIZE
;
++
i
)
out_vec
.
data
[
i
]
=
float_to_int8_rn
(
norm_out
.
data
[
i
]
*
inv_s
);
constexpr
float
qmax
=
127.0
f
;
constexpr
float
min_scale
=
1.19209e-07
f
;
((
F
*
)
result
)[
idx
]
=
out_vec
;
if
constexpr
(
update_input
)
((
P
*
)
norm_res
)[
idx
]
=
norm_out
;
if
(
threadIdx
.
x
==
0
)
scales
[
tidx
]
=
fmaxf
(
scale
/
qmax
,
min_scale
);
}
}
}
template
<
typename
T
,
int
ngpus
>
__global__
void
__launch_bounds__
(
512
,
1
)
cross_device_reduce_2stage
(
RankData
*
_dp
,
RankSignals
sg
,
Signal
*
self_sg
,
...
...
@@ -685,6 +1099,177 @@ class CustomAllreduce {
* only take a small amount of SMs. Not quite sure the underlying reason,
* but my guess is that too many SMs will cause contention on NVLink bus.
*/
template
<
typename
T
>
void
allreduce_fuse_norm
(
cudaStream_t
stream
,
T
*
input
,
T
*
output
,
int
size
,
int
token_num
,
int
hidden_dim
,
T
*
residual
,
T
*
rms_weight
,
double
eps
,
int
threads
=
512
,
int
block_limit
=
defaultBlockLimit
)
{
auto
d
=
packed_t
<
T
>::
P
::
size
;
if
(
hidden_dim
%
d
!=
0
)
throw
std
::
runtime_error
(
"custom allreduce currently requires input length to be multiple "
"of "
+
std
::
to_string
(
d
));
if
(
block_limit
>
kMaxBlocks
)
throw
std
::
runtime_error
(
"max supported block limit is "
+
std
::
to_string
(
kMaxBlocks
)
+
". Got "
+
std
::
to_string
(
block_limit
));
RankData
*
ptrs
;
cudaStreamCaptureStatus
status
;
CUDACHECK
(
cudaStreamIsCapturing
(
stream
,
&
status
));
if
(
status
==
cudaStreamCaptureStatusActive
)
{
ptrs
=
d_rank_data_base_
+
graph_unreg_buffers_
.
size
();
graph_unreg_buffers_
.
push_back
(
input
);
}
else
{
auto
it
=
buffers_
.
find
(
input
);
if
(
it
==
buffers_
.
end
())
throw
std
::
runtime_error
(
"buffer address "
+
std
::
to_string
(
reinterpret_cast
<
uint64_t
>
(
input
))
+
" is not registered!"
);
ptrs
=
it
->
second
;
}
int
block_num
=
token_num
;
#define KL(ngpus, name) \
std::array<int, ngpus> begin_tokens, token_num_per_ranks; \
int remaining_token = token_num % ngpus; \
int token_num_per_rank = token_num / ngpus; \
block_num = token_num_per_rank; \
if (remaining_token) \
block_num++; \
for (int i = 0; i < ngpus; ++i) { \
begin_tokens[i] = i * token_num_per_rank + (remaining_token > i ? i : remaining_token); \
token_num_per_ranks[i] = token_num_per_rank + (remaining_token > i ? 1 : 0); \
} \
int thread_per_token = hidden_dim / d; \
int grid_size = std::min(kMaxBlocks, block_num); \
int threads_in_block = thread_per_token; \
name<T, ngpus><<<grid_size, threads_in_block, 0, stream>>>(ptrs, sg_, self_sg_, output, \
rank_, size, hidden_dim, residual, \
rms_weight, eps, begin_tokens, token_num_per_ranks);
#define REDUCE_CASE(ngpus) \
case ngpus: { \
if (world_size_ == 2) { \
KL(ngpus, cross_device_reduce_2stage_fuse_norm); \
} else if (fully_connected_) { \
if ((world_size_ <= 4) || \
(world_size_ <= 8 )) { \
KL(ngpus, cross_device_reduce_2stage_fuse_norm); \
} else { \
KL(ngpus, cross_device_reduce_2stage_fuse_norm); \
} \
} \
break; \
}
switch
(
world_size_
)
{
REDUCE_CASE
(
2
)
REDUCE_CASE
(
4
)
REDUCE_CASE
(
6
)
REDUCE_CASE
(
8
)
default:
throw
std
::
runtime_error
(
"custom allreduce only supports num gpus in (2,4,6,8). Actual "
"num "
"gpus = "
+
std
::
to_string
(
world_size_
));
}
#undef REDUCE_CASE
#undef KL
}
template
<
typename
scalar_in_t
,
typename
scalar_out_t
,
bool
isResidual
=
true
,
bool
update_input
=
false
>
void
allreduce_fuse_norm_quant
(
cudaStream_t
stream
,
scalar_in_t
*
input
,
scalar_out_t
*
output
,
int
size
,
int
token_num
,
int
hidden_dim
,
scalar_in_t
*
residual
,
scalar_in_t
*
rms_weight
,
scalar_in_t
*
norm_out
,
double
eps
,
float
*
scales
,
int
threads
=
512
,
int
block_limit
=
defaultBlockLimit
)
{
auto
d
=
packed_t
<
scalar_in_t
>::
P
::
size
;
if
(
hidden_dim
%
d
!=
0
)
throw
std
::
runtime_error
(
"custom allreduce currently requires input length to be multiple "
"of "
+
std
::
to_string
(
d
));
if
(
block_limit
>
kMaxBlocks
)
throw
std
::
runtime_error
(
"max supported block limit is "
+
std
::
to_string
(
kMaxBlocks
)
+
". Got "
+
std
::
to_string
(
block_limit
));
RankData
*
ptrs
;
cudaStreamCaptureStatus
status
;
CUDACHECK
(
cudaStreamIsCapturing
(
stream
,
&
status
));
if
(
status
==
cudaStreamCaptureStatusActive
)
{
ptrs
=
d_rank_data_base_
+
graph_unreg_buffers_
.
size
();
graph_unreg_buffers_
.
push_back
(
input
);
}
else
{
auto
it
=
buffers_
.
find
(
input
);
if
(
it
==
buffers_
.
end
())
throw
std
::
runtime_error
(
"buffer address "
+
std
::
to_string
(
reinterpret_cast
<
uint64_t
>
(
input
))
+
" is not registered!"
);
ptrs
=
it
->
second
;
}
int
block_num
=
token_num
;
int
thread_per_token
=
hidden_dim
/
d
;
auto
bytes
=
(
size
/
d
)
*
sizeof
(
typename
packed_t
<
scalar_in_t
>::
P
);
#define KL1(ngpus, name) \
std::array<int, ngpus> begin_tokens, token_num_per_ranks; \
int remaining_token = token_num % ngpus; \
int token_num_per_rank = token_num / ngpus; \
block_num = token_num_per_rank; \
if (remaining_token) \
block_num++; \
for (int i = 0; i < ngpus; ++i) { \
begin_tokens[i] = i * token_num_per_rank + (remaining_token > i ? i : remaining_token); \
token_num_per_ranks[i] = token_num_per_rank + (remaining_token > i ? 1 : 0); \
} \
int
grid_size
=
std
::
min
(
kMaxBlocks
,
block_num
);
\
int
threads_in_block
=
thread_per_token
;
\
name
<
scalar_in_t
,
scalar_out_t
,
ngpus
,
isResidual
,
update_input
><<<
block_num
,
threads_in_block
,
0
,
stream
>>>
(
ptrs
,
sg_
,
\
self_sg_
,
output
,
rank_
,
size
,
hidden_dim
,
residual
,
\
rms_weight
,
scales
,
eps
,
norm_out
,
begin_tokens
,
token_num_per_ranks
);
#define KL(ngpus, name) \
name<scalar_in_t, scalar_out_t, ngpus, isResidual, update_input><<<block_num, thread_per_token, 0, stream>>>(ptrs, sg_, \
self_sg_, output, rank_, size, hidden_dim, residual, rms_weight, \
scales, eps, norm_out);
#define REDUCE_CASE(ngpus) \
case ngpus: { \
if (world_size_ == 2) { \
KL(ngpus, cross_device_reduce_1stage_norm_quant); \
} else if (fully_connected_) { \
if ((world_size_ <= 4 && bytes < 1024 * 1024) || \
(world_size_ <= 8 && bytes < 512 * 1024)) { \
KL(ngpus, cross_device_reduce_1stage_norm_quant); \
} else { \
KL1(ngpus, cross_device_reduce_2stage_fuse_norm_quant); \
} \
} \
break; \
}
switch
(
world_size_
)
{
REDUCE_CASE
(
2
)
REDUCE_CASE
(
4
)
REDUCE_CASE
(
6
)
REDUCE_CASE
(
8
)
default:
throw
std
::
runtime_error
(
"custom allreduce only supports num gpus in (2,4,6,8). Actual "
"num "
"gpus = "
+
std
::
to_string
(
world_size_
));
}
#undef REDUCE_CASE
#undef KL
}
template
<
typename
T
>
void
allreduce
(
cudaStream_t
stream
,
T
*
input
,
T
*
output
,
int
size
,
int
threads
=
512
,
int
block_limit
=
defaultBlockLimit
)
{
...
...
@@ -766,4 +1351,4 @@ class CustomAllreduce {
* template void vllm::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
half *, int, int, int);
*/
}
// namespace vllm
\ No newline at end of file
}
// namespace vllm
csrc/custom_all_reduce_test.cu
View file @
3e6729e0
...
...
@@ -18,6 +18,7 @@
#include <limits>
#include <vector>
#include <random>
#include "cuda_profiler_api.h"
#include "custom_all_reduce.cuh"
...
...
@@ -117,16 +118,113 @@ __global__ void gen_data(curandState_t* state, T* data, double* ground_truth,
ground_truth
[
idx
]
=
sum
;
}
}
/*************************************************/
template
<
typename
T
,
int
reducesize
=
64
>
__inline__
__device__
T
WarpReduceSum_NEW
(
T
val
)
{
#pragma unroll
for
(
int
offset
=
reducesize
/
2
;
offset
>
0
;
offset
>>=
1
)
{
val
+=
__shfl_down
(
val
,
offset
);
}
return
val
;
}
template
<
typename
T
,
int
block_size
=
512
>
__inline__
__device__
T
BlockReduceSum_NEW
(
T
val
,
T
*
shared
)
{
constexpr
int
share_size
=
block_size
/
64
;
val
=
WarpReduceSum_NEW
<
T
>
(
val
);
if
constexpr
(
block_size
==
64
)
{
return
val
;
}
else
{
const
int
lid
=
threadIdx
.
x
%
64
;
const
int
wid
=
threadIdx
.
x
/
64
;
if
(
lid
==
0
&&
wid
<
share_size
)
{
shared
[
wid
]
=
val
;
}
__syncthreads
();
if
(
wid
==
0
&&
lid
<
share_size
)
{
val
=
WarpReduceSum_NEW
<
T
,
share_size
>
(
shared
[
lid
]);
}
return
val
;
}
}
template
<
typename
scalar_t
,
typename
T_ACC
,
int
Vec
=
4
,
int
block_size
=
512
>
__global__
void
fused_add_rms_kernel_opt
(
scalar_t
*
input
,
scalar_t
*
residual
,
scalar_t
*
gamma
,
int
cols
,
T_ACC
eps
)
{
constexpr
int
share_size
=
block_size
/
64
;
__shared__
T_ACC
val_shared
[
share_size
];
__shared__
T_ACC
s_rstd
;
T_ACC
val
=
0
;
int
i
=
blockIdx
.
x
;
int
j
=
threadIdx
.
x
;
int
tcol
=
cols
/
Vec
;
using
LoadT
=
typename
vllm
::
packed_t
<
scalar_t
>::
P
;
scalar_t
intput_vec
[
Vec
];
scalar_t
residual_vec
[
Vec
];
T_ACC
trstd
;
int64_t
idx
=
i
*
tcol
+
j
;
idx
*=
Vec
;
if
(
j
<
tcol
)
{
*
(
LoadT
*
)
intput_vec
=
*
(
LoadT
*
)(
input
+
idx
);
*
(
LoadT
*
)
residual_vec
=
*
(
LoadT
*
)(
residual
+
idx
);
#pragma unroll
for
(
int
ii
=
0
;
ii
<
Vec
;
ii
++
)
{
residual_vec
[
ii
]
+=
intput_vec
[
ii
];
val
+=
static_cast
<
T_ACC
>
(
residual_vec
[
ii
])
*
static_cast
<
T_ACC
>
(
residual_vec
[
ii
]);
}
}
val
=
BlockReduceSum_NEW
<
T_ACC
,
block_size
>
(
val
,
val_shared
);
if
(
j
==
0
)
s_rstd
=
rsqrtf
(
val
/
cols
+
eps
);
__syncthreads
();
trstd
=
s_rstd
;
if
(
j
<
tcol
)
{
#pragma unroll
for
(
int
ii
=
0
;
ii
<
Vec
;
ii
++
){
int
jj
=
j
*
Vec
+
ii
;
intput_vec
[
ii
]
=
static_cast
<
T_ACC
>
(
residual_vec
[
ii
])
*
trstd
*
static_cast
<
T_ACC
>
(
gamma
[
jj
]);
}
*
(
LoadT
*
)(
residual
+
idx
)
=*
(
LoadT
*
)
residual_vec
;
*
(
LoadT
*
)(
input
+
idx
)
=*
(
LoadT
*
)
intput_vec
;
}
}
template
<
typename
scalar_t
>
void
fused_add_rms_norm_choose
(
cudaStream_t
stream
,
scalar_t
*
self_data
,
scalar_t
*
other_data
,
scalar_t
*
weight_data
,
double
eps
,
int
hidden_size
,
int
num_tokens
)
{
if
(
hidden_size
<=
1024
){
fused_add_rms_kernel_opt
<
scalar_t
,
float
,
8
,
128
><<<
num_tokens
,
128
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
}
else
if
(
hidden_size
<=
2048
){
fused_add_rms_kernel_opt
<
scalar_t
,
float
,
8
,
256
><<<
num_tokens
,
256
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
}
else
if
(
hidden_size
<=
4096
){
if
(
num_tokens
>
1200
){
fused_add_rms_kernel_opt
<
scalar_t
,
float
,
8
,
512
><<<
num_tokens
,
512
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
}
else
{
fused_add_rms_kernel_opt
<
scalar_t
,
float
,
4
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
}
}
else
if
(
hidden_size
<=
8192
){
fused_add_rms_kernel_opt
<
scalar_t
,
float
,
8
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
}
else
{
fused_add_rms_kernel_opt
<
scalar_t
,
float
,
16
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
}
}
/*****************************************************************/
template
<
typename
T
>
void
run
(
int
myRank
,
int
nRanks
,
ncclComm_t
&
comm
,
int
threads
,
int
block_limit
,
int
data_size
,
bool
performance_test
)
{
T
*
result
;
int
data_size
,
bool
performance_test
,
int
hidden_dim
)
{
T
*
result
_ori
,
*
result_fuse
;
cudaStream_t
stream
;
CUDACHECK
(
cudaStreamCreateWithFlags
(
&
stream
,
cudaStreamNonBlocking
));
CUDACHECK
(
cudaMalloc
(
&
result
,
data_size
*
sizeof
(
T
)));
CUDACHECK
(
cudaMemset
(
result
,
0
,
data_size
*
sizeof
(
T
)));
CUDACHECK
(
cudaMalloc
(
&
result_ori
,
data_size
*
sizeof
(
T
)));
CUDACHECK
(
cudaMemset
(
result_ori
,
0
,
data_size
*
sizeof
(
T
)));
CUDACHECK
(
cudaMalloc
(
&
result_fuse
,
data_size
*
sizeof
(
T
)));
CUDACHECK
(
cudaMemset
(
result_fuse
,
0
,
data_size
*
sizeof
(
T
)));
cudaIpcMemHandle_t
self_data_handle
;
cudaIpcMemHandle_t
data_handles
[
8
];
vllm
::
Signal
*
buffer
;
...
...
@@ -176,7 +274,7 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
sizeof
(
vllm
::
Signal
)
+
data_size
*
sizeof
(
T
));
// hack buffer registration
{
void
*
data
[
8
];
void
*
data
[
8
];
//gpu数据部分
for
(
int
i
=
0
;
i
<
nRanks
;
i
++
)
{
data
[
i
]
=
((
char
*
)
ipc_ptrs
[
i
])
+
sizeof
(
vllm
::
Signal
)
+
data_size
*
sizeof
(
T
);
...
...
@@ -196,7 +294,26 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
cudaEvent_t
start
,
stop
;
CUDACHECK
(
cudaEventCreate
(
&
start
));
CUDACHECK
(
cudaEventCreate
(
&
stop
));
/*******************************/
int
token_num
=
data_size
/
hidden_dim
;
T
*
residual_h
,
*
residual_d
,
*
weight_h
,
*
weight_d
;
residual_h
=
(
T
*
)
malloc
(
data_size
*
sizeof
(
T
));
std
::
random_device
rd
;
// 用于获取随机数种子
std
::
mt19937
gen
(
7
);
std
::
uniform_real_distribution
<
float
>
dis
(
-
3.0
f
,
3.0
f
);
for
(
int
i
=
0
;
i
<
data_size
;
++
i
)
residual_h
[
i
]
=
static_cast
<
T
>
(
dis
(
gen
));
for
(
int
i
=
0
;
i
<
hidden_dim
;
++
i
)
weight_h
[
i
]
=
static_cast
<
T
>
(
dis
(
gen
));
cudaMalloc
((
void
**
)
&
residual_d
,
sizeof
(
T
)
*
data_size
);
cudaMalloc
((
void
**
)
&
weight_d
,
sizeof
(
T
)
*
hidden_dim
);
cudaMemcpyAsync
(
residual_d
,
residual_h
,
sizeof
(
T
)
*
data_size
,
cudaMemcpyHostToDevice
,
stream
);
cudaMemcpyAsync
(
weight_d
,
weight_h
,
sizeof
(
T
)
*
hidden_dim
,
cudaMemcpyHostToDevice
,
stream
);
float
eps
=
1.0
f
;
/*******************************/
ncclDataType_t
ncclDtype
;
if
(
std
::
is_same
<
T
,
half
>::
value
)
{
ncclDtype
=
ncclFloat16
;
...
...
@@ -211,16 +328,16 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
if
(
performance_test
)
{
dummy_kernel
<<<
1
,
1
,
0
,
stream
>>>
();
constexpr
int
warmup_iters
=
5
;
constexpr
int
num_iters
=
10
0
;
constexpr
int
num_iters
=
10
;
// warmup
for
(
int
i
=
0
;
i
<
warmup_iters
;
i
++
)
{
NCCLCHECK
(
ncclA
ll
R
educe
(
result
,
result
,
data_size
,
ncclDtype
,
ncclSum
,
comm
,
stream
)
);
fa
.
a
ll
r
educe
<
T
>
(
stream
,
self_data
,
result
_ori
,
data_size
,
threads
,
block_limit
);
fused_add_rms_norm_choose
<
T
>
(
stream
,
result_ori
,
residual_d
,
weight_d
,
1.0
,
hidden_dim
,
token_num
);
}
CUDACHECK
(
cudaEventRecord
(
start
,
stream
));
for
(
int
i
=
0
;
i
<
num_iters
;
i
++
)
{
NCCLCHECK
(
ncclA
ll
R
educe
(
result
,
result
,
data_size
,
ncclDtype
,
ncclSum
,
comm
,
stream
)
);
fa
.
a
ll
r
educe
<
T
>
(
stream
,
self_data
,
result
_ori
,
data_size
,
threads
,
block_limit
);
fused_add_rms_norm_choose
<
T
>
(
stream
,
result_ori
,
residual_d
,
weight_d
,
1.0
,
hidden_dim
,
token_num
);
}
CUDACHECK
(
cudaEventRecord
(
stop
,
stream
));
CUDACHECK
(
cudaStreamSynchronize
(
stream
));
...
...
@@ -230,13 +347,15 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
dummy_kernel
<<<
1
,
1
,
0
,
stream
>>>
();
// warm up
for
(
int
i
=
0
;
i
<
warmup_iters
;
i
++
)
{
fa
.
allreduce
<
T
>
(
stream
,
self_data
,
result
,
data_size
,
threads
,
block_limit
);
fa
.
allreduce_fuse_norm
<
T
>
(
stream
,
self_data
,
result_fuse
,
data_size
,
token_num
,
hidden_dim
,
residual_d
,
weight_d
,
eps
,
threads
,
block_limit
);
}
CUDACHECK
(
cudaEventRecord
(
start
,
stream
));
for
(
int
i
=
0
;
i
<
num_iters
;
i
++
)
{
fa
.
allreduce
<
T
>
(
stream
,
self_data
,
result
,
data_size
,
threads
,
block_limit
);
fa
.
allreduce_fuse_norm
<
T
>
(
stream
,
self_data
,
result_fuse
,
data_size
,
token_num
,
hidden_dim
,
residual_d
,
weight_d
,
eps
,
threads
,
block_limit
);
}
CUDACHECK
(
cudaEventRecord
(
stop
,
stream
));
CUDACHECK
(
cudaStreamSynchronize
(
stream
));
...
...
@@ -245,7 +364,7 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
cudaEventElapsedTime
(
&
duration_ms
,
start
,
stop
);
if
(
myRank
==
0
)
printf
(
"Rank %d done, nGPUs:%d, sz (kb): %d, %d, %d,
my time:%.2fus, nccl
"
"Rank %d done, nGPUs:%d, sz (kb): %d, %d, %d,
allreduse_fuse_norm time:%.2fus, allreduce+norm
"
"time:%.2fus
\n
"
,
myRank
,
nRanks
,
data_size
*
sizeof
(
T
)
/
1024
,
threads
,
block_limit
,
duration_ms
*
1e3
/
num_iters
,
allreduce_ms
*
1e3
/
num_iters
);
...
...
@@ -255,8 +374,9 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
NCCLCHECK
(
ncclAllReduce
(
self_data_copy
,
self_data
,
data_size
,
ncclDtype
,
ncclSum
,
comm
,
stream
));
fused_add_rms_norm_choose
<
T
>
(
stream
,
self_data
,
residual_d
,
weight_d
,
1.0
,
hidden_dim
,
token_num
);
convert_data
<
T
><<<
108
,
1024
,
0
,
stream
>>>
(
self_data
,
result
,
nccl_result
,
convert_data
<
T
><<<
108
,
1024
,
0
,
stream
>>>
(
result_ori
,
result
_fuse
,
nccl_result
,
my_result
,
data_size
);
CUDACHECK
(
cudaStreamSynchronize
(
stream
));
...
...
@@ -279,13 +399,13 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
<<
" me: "
<<
my_diffs
/
data_size
<<
std
::
endl
;
}
else
{
for
(
int
i
=
0
;
i
<
100
;
i
++
)
{
fa
.
allreduce
<
T
>
(
stream
,
self_data
,
result
,
data_size
,
threads
,
fa
.
allreduce
<
T
>
(
stream
,
self_data
,
result
_ori
,
data_size
,
threads
,
block_limit
);
CUDACHECK
(
cudaStreamSynchronize
(
stream
));
NCCLCHECK
(
ncclAllReduce
(
self_data
,
self_data_copy
,
data_size
,
ncclDtype
,
ncclSum
,
comm
,
stream
));
convert_data
<
T
><<<
108
,
1024
,
0
,
stream
>>>
(
self_data_copy
,
result
,
nccl_result
,
my_result
,
data_size
);
self_data_copy
,
result
_ori
,
nccl_result
,
my_result
,
data_size
);
CUDACHECK
(
cudaStreamSynchronize
(
stream
));
for
(
unsigned
long
j
=
0
;
j
<
data_size
;
j
++
)
{
...
...
@@ -312,7 +432,8 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
// << " me: " << my_diffs / data_size << std::endl;
}
CUDACHECK
(
cudaFree
(
result
));
CUDACHECK
(
cudaFree
(
result_ori
));
CUDACHECK
(
cudaFree
(
result_fuse
));
CUDACHECK
(
cudaFree
(
self_data_copy
));
CUDACHECK
(
cudaFree
(
rank_data
));
CUDACHECK
(
cudaFree
(
buffer
));
...
...
@@ -351,9 +472,7 @@ int main(int argc, char** argv) {
const
int
block_limit
=
36
;
#endif
// Scan through different sizes to test performance.
for
(
int
sz
=
512
;
sz
<=
(
8
<<
20
);
sz
*=
2
)
{
run
<
half
>
(
myRank
,
nRanks
,
comm
,
512
,
36
,
sz
+
8
*
47
,
performance_test
);
}
run
<
half
>
(
myRank
,
nRanks
,
comm
,
512
,
36
,
7168
*
80
,
performance_test
,
7168
);
cudaProfilerStop
();
MPICHECK
(
MPI_Finalize
());
...
...
csrc/ops.h
View file @
3e6729e0
...
...
@@ -487,6 +487,17 @@ fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
bool
fully_connected
);
void
all_reduce
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
,
fptr_t
reg_buffer
,
int64_t
reg_buffer_sz_bytes
);
void
all_reduce_fuse_norm
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
,
int64_t
hidden_size
,
torch
::
Tensor
&
residual
,
torch
::
Tensor
&
rms_weight
,
double
eps
,
fptr_t
reg_buffer
,
int64_t
reg_buffer_sz_bytes
);
void
all_reduce_fuse_norm_quant
(
fptr_t
fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
,
int64_t
hidden_size
,
torch
::
Tensor
&
rms_weight
,
double
eps
,
torch
::
Tensor
&
scales
,
torch
::
Tensor
&
norm_out
,
fptr_t
reg_buffer
,
int64_t
reg_buffer_sz_bytes
,
std
::
optional
<
at
::
Tensor
>
residual
,
bool
update_input
);
void
dispose
(
fptr_t
_fa
);
int64_t
meta_size
();
void
register_buffer
(
fptr_t
_fa
,
const
std
::
vector
<
int64_t
>&
fake_ipc_ptrs
);
...
...
csrc/torch_bindings.cpp
View file @
3e6729e0
...
...
@@ -933,6 +933,17 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
"all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
"int reg_buffer_sz_bytes) -> ()"
);
custom_ar
.
impl
(
"all_reduce"
,
torch
::
kCUDA
,
&
all_reduce
);
custom_ar
.
def
(
"all_reduce_fuse_norm(int fa, Tensor inp, Tensor! out, int hidden_size, "
"Tensor residual, Tensor rms_weight, float eps, int reg_buffer, "
"int reg_buffer_sz_bytes) -> ()"
);
custom_ar
.
impl
(
"all_reduce_fuse_norm"
,
torch
::
kCUDA
,
&
all_reduce_fuse_norm
);
custom_ar
.
def
(
"all_reduce_fuse_norm_quant(int fa, Tensor inp, Tensor! out, int hidden_size, "
"Tensor rms_weight, float eps, Tensor! scales, Tensor! norm_out, int reg_buffer, "
"int reg_buffer_sz_bytes, Tensor? residual, bool update_input) -> ()"
);
custom_ar
.
impl
(
"all_reduce_fuse_norm_quant"
,
torch
::
kCUDA
,
&
all_reduce_fuse_norm_quant
);
custom_ar
.
def
(
"dispose"
,
&
dispose
);
custom_ar
.
def
(
"meta_size"
,
&
meta_size
);
...
...
vllm/_custom_ops.py
View file @
3e6729e0
...
...
@@ -2212,7 +2212,17 @@ def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor, reg_buffer: int,
reg_buffer_sz_bytes
:
int
)
->
None
:
torch
.
ops
.
_C_custom_ar
.
all_reduce
(
fa
,
inp
,
out
,
reg_buffer
,
reg_buffer_sz_bytes
)
def
all_reduce_fuse_norm
(
fa
:
int
,
inp
:
torch
.
Tensor
,
out
:
torch
.
Tensor
,
hidden_size
:
int
,
residual
:
torch
.
Tensor
,
rms_weight
:
torch
.
Tensor
,
eps
:
float
,
reg_buffer
:
int
,
reg_buffer_sz_bytes
:
int
)
->
None
:
torch
.
ops
.
_C_custom_ar
.
all_reduce_fuse_norm
(
fa
,
inp
,
out
,
hidden_size
,
residual
,
rms_weight
,
eps
,
reg_buffer
,
reg_buffer_sz_bytes
)
def
all_reduce_fuse_norm_quant
(
fa
:
int
,
inp
:
torch
.
Tensor
,
out
:
torch
.
Tensor
,
hidden_size
:
int
,
rms_weight
:
torch
.
Tensor
,
eps
:
float
,
scales
:
torch
.
Tensor
,
norm_out
:
torch
.
Tensor
,
reg_buffer
:
int
,
reg_buffer_sz_bytes
:
int
,
residual
:
torch
.
Tensor
,
update_input
:
bool
)
->
None
:
torch
.
ops
.
_C_custom_ar
.
all_reduce_fuse_norm_quant
(
fa
,
inp
,
out
,
hidden_size
,
rms_weight
,
eps
,
scales
,
norm_out
,
reg_buffer
,
reg_buffer_sz_bytes
,
residual
,
update_input
)
def
dispose
(
fa
:
int
)
->
None
:
torch
.
ops
.
_C_custom_ar
.
dispose
(
fa
)
...
...
vllm/distributed/communication_op.py
View file @
3e6729e0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
,
Optional
,
Union
from
typing
import
Any
,
Optional
,
Union
,
Tuple
import
torch
import
torch.distributed
from
.parallel_state
import
get_tp_group
def
tensor_model_parallel_all_reduce
(
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""All-reduce the input tensor across model parallel group."""
return
get_tp_group
().
all_reduce
(
input_
)
def
tensor_model_parallel_all_reduce_crp_m32
(
input_
:
torch
.
Tensor
,
pa_rms_weight
:
torch
.
Tensor
,
pa_residual
:
torch
.
Tensor
,
pa_rms_eps
:
float
,
pa_quant_dtype
:
Optional
[
torch
.
dtype
]
=
torch
.
int8
,
update_input
:
Optional
[
bool
]
=
True
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""All-reduce the input tensor across model parallel group."""
# allreduce fused rms and quant
return
get_tp_group
().
all_reduce_crq_m32
(
input_
=
input_
,
pa_rms_weight
=
pa_rms_weight
,
pa_residual
=
pa_residual
,
pa_rms_eps
=
pa_rms_eps
,
pa_quant_dtype
=
pa_quant_dtype
,
update_input
=
update_input
)
def
tensor_model_parallel_all_gather
(
input_
:
torch
.
Tensor
,
dim
:
int
=
-
1
)
->
torch
.
Tensor
:
...
...
vllm/distributed/device_communicators/cuda_communicator.py
View file @
3e6729e0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
from
typing
import
Optional
,
Tuple
import
torch
from
torch.distributed
import
ProcessGroup
...
...
@@ -14,6 +14,7 @@ from .base_device_communicator import DeviceCommunicatorBase
logger
=
init_logger
(
__name__
)
from
lmslim.quantize.quant_ops
import
lm_faster_rmsquant
class
CudaCommunicator
(
DeviceCommunicatorBase
):
...
...
@@ -116,6 +117,37 @@ class CudaCommunicator(DeviceCommunicatorBase):
out
=
input_
.
clone
()
torch
.
distributed
.
all_reduce
(
out
,
group
=
self
.
device_group
)
return
out
def
all_reduce_rms_quant_m32
(
self
,
input_
,
pa_rms_weight
:
torch
.
Tensor
,
pa_residual
:
torch
.
Tensor
,
pa_rms_eps
:
float
,
pa_quant_dtype
:
torch
.
dtype
,
update_input
:
Optional
[
bool
]
=
True
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
batch_size
,
hidden_dim
=
input_
.
shape
ca_comm
=
self
.
ca_comm
assert
ca_comm
is
not
None
and
not
ca_comm
.
disabled
assert
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
and
\
pa_rms_weight
is
not
None
and
pa_residual
is
not
None
if
batch_size
<=
16
:
xq
,
xs
,
norm_out
=
ca_comm
.
custom_all_reduce_fuse_norm_quant
(
inp
=
input_
,
rms_weight
=
pa_rms_weight
,
residual
=
pa_residual
,
eps
=
pa_rms_eps
,
quant_type
=
pa_quant_dtype
,
update_input
=
True
)
input_
=
norm_out
else
:
input_
=
self
.
all_reduce
(
input_
)
xq
,
xs
=
lm_faster_rmsquant
(
input_
,
rms_weight
=
pa_rms_weight
,
residual
=
pa_residual
,
epsilon
=
pa_rms_eps
,
quant_dtype
=
pa_quant_dtype
,
update_input
=
True
)
assert
input_
is
not
None
assert
xq
is
not
None
and
xs
is
not
None
return
input_
,
pa_residual
,
xq
,
xs
def
reduce_scatter
(
self
,
input_
:
torch
.
Tensor
,
dim
:
int
=
-
1
):
world_size
=
self
.
world_size
...
...
vllm/distributed/device_communicators/custom_all_reduce.py
View file @
3e6729e0
...
...
@@ -275,6 +275,91 @@ class CustomAllreduce:
# latency) compared to the performance gain of using custom kernels
return
self
.
all_reduce
(
input
,
registered
=
False
)
def
allreduce_fuse_norm
(
self
,
inp
:
torch
.
Tensor
,
hidden_size
:
int
,
residual
:
torch
.
Tensor
,
rms_weight
:
torch
.
Tensor
,
eps
:
float
,
*
,
out
:
torch
.
Tensor
=
None
,
registered
:
bool
=
False
):
if
out
is
None
:
out
=
torch
.
empty_like
(
inp
)
if
registered
:
ops
.
all_reduce_fuse_norm
(
self
.
_ptr
,
inp
,
out
,
hidden_size
,
residual
,
rms_weight
,
eps
,
0
,
0
)
else
:
ops
.
all_reduce_fuse_norm
(
self
.
_ptr
,
inp
,
out
,
hidden_size
,
residual
,
rms_weight
,
eps
,
self
.
buffer_ptrs
[
self
.
rank
],
self
.
max_size
)
return
out
def
custom_all_reduce_fuse_norm
(
self
,
input
:
torch
.
Tensor
,
hidden_size
:
int
,
residual
:
torch
.
Tensor
,
rms_weight
:
torch
.
Tensor
,
eps
:
float
)
->
Optional
[
torch
.
Tensor
]:
if
self
.
disabled
or
not
self
.
should_custom_ar
(
input
):
return
None
if
self
.
_IS_CAPTURING
:
if
torch
.
cuda
.
is_current_stream_capturing
():
return
self
.
allreduce_fuse_norm
(
input
,
hidden_size
,
residual
,
rms_weight
,
eps
,
registered
=
False
)
else
:
return
torch
.
empty_like
(
input
)
else
:
return
self
.
allreduce_fuse_norm
(
input
,
hidden_size
,
residual
,
rms_weight
,
eps
,
registered
=
False
)
def
allreduce_fuse_norm_quant
(
self
,
inp
:
torch
.
Tensor
,
hidden_size
:
int
,
rms_weight
,
eps
,
quant_dtype
,
residual
,
update_input
:
bool
=
True
,
registered
:
bool
=
False
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
xq
=
torch
.
empty_like
(
inp
,
dtype
=
quant_dtype
)
norm_out
=
torch
.
empty_like
(
inp
)
scales
=
torch
.
empty
((
inp
.
numel
()
//
inp
.
shape
[
-
1
],
1
),
device
=
inp
.
device
,
dtype
=
torch
.
float32
)
if
registered
:
ops
.
all_reduce_fuse_norm_quant
(
self
.
_ptr
,
inp
,
xq
,
hidden_size
,
rms_weight
,
eps
,
scales
,
norm_out
,
0
,
0
,
residual
,
update_input
)
else
:
ops
.
all_reduce_fuse_norm_quant
(
self
.
_ptr
,
inp
,
xq
,
hidden_size
,
rms_weight
,
eps
,
scales
,
norm_out
,
self
.
buffer_ptrs
[
self
.
rank
],
self
.
max_size
,
residual
,
update_input
)
return
xq
,
scales
,
norm_out
def
custom_all_reduce_fuse_norm_quant
(
self
,
inp
:
torch
.
Tensor
,
rms_weight
:
torch
.
Tensor
,
eps
,
quant_type
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
update_input
=
True
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
hidden_size
=
inp
.
shape
[
-
1
]
if
self
.
disabled
or
not
self
.
should_custom_ar
(
inp
):
return
None
if
self
.
_IS_CAPTURING
:
if
torch
.
cuda
.
is_current_stream_capturing
():
return
self
.
allreduce_fuse_norm_quant
(
inp
,
hidden_size
,
rms_weight
,
eps
,
quant_type
,
residual
,
update_input
=
update_input
,
registered
=
False
)
else
:
return
torch
.
empty_like
(
inp
,
dtype
=
quant_type
),
\
torch
.
empty
((
inp
.
numel
()
//
inp
.
shape
[
-
1
],
1
),
dtype
=
torch
.
float32
,
device
=
inp
.
device
),
\
torch
.
empty_like
(
inp
)
else
:
return
self
.
allreduce_fuse_norm_quant
(
inp
,
hidden_size
,
rms_weight
,
eps
,
quant_type
,
residual
,
update_input
=
update_input
,
registered
=
False
)
def
close
(
self
):
if
not
self
.
disabled
and
self
.
_ptr
:
if
ops
is
not
None
:
...
...
vllm/distributed/parallel_state.py
View file @
3e6729e0
...
...
@@ -30,7 +30,7 @@ from collections import namedtuple
from
contextlib
import
contextmanager
,
nullcontext
from
dataclasses
import
dataclass
from
multiprocessing
import
shared_memory
from
typing
import
Any
,
Callable
,
Optional
,
Union
from
typing
import
Any
,
Callable
,
Optional
,
Tuple
,
Union
from
unittest.mock
import
patch
import
torch
...
...
@@ -114,6 +114,37 @@ def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
return
torch
.
empty_like
(
tensor
)
def
all_reduce_rms_quant
(
input_
:
torch
.
Tensor
,
group_name
:
str
,
pa_rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
pa_residual
:
Optional
[
torch
.
Tensor
]
=
None
,
pa_rms_eps
:
Optional
[
float
]
=
1e-6
,
pa_quant_dtype
:
Optional
[
torch
.
dtype
]
=
torch
.
int8
,
update_input
:
Optional
[
bool
]
=
True
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
assert
group_name
in
_groups
,
f
"Group
{
group_name
}
is not found."
group
=
_groups
[
group_name
]()
if
group
is
None
:
raise
ValueError
(
f
"Group
{
group_name
}
is destroyed."
)
return
group
.
_all_reduce_out_place_m32
(
input_
,
pa_rms_weight
=
pa_rms_weight
,
pa_residual
=
pa_residual
,
pa_rms_eps
=
pa_rms_eps
,
pa_quant_dtype
=
pa_quant_dtype
,
update_input
=
update_input
)
def
all_reduce_rms_quant_fake
(
input_
:
torch
.
Tensor
,
group_name
:
str
,
pa_rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
pa_residual
:
Optional
[
torch
.
Tensor
]
=
None
,
pa_rms_eps
:
Optional
[
float
]
=
1e-6
,
pa_quant_dtype
:
Optional
[
torch
.
dtype
]
=
torch
.
int8
,
update_input
:
Optional
[
bool
]
=
True
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
xq
=
torch
.
zeros_like
(
input_
,
dtype
=
pa_quant_dtype
)
xs
=
torch
.
ones
((
input_
.
numel
()
//
input_
.
shape
[
-
1
],
1
),
device
=
input_
.
device
,
dtype
=
torch
.
float32
)
return
input_
,
pa_residual
,
xq
,
xs
def
reduce_scatter
(
tensor
:
torch
.
Tensor
,
dim
:
int
,
world_size
:
int
,
group_name
:
str
)
->
torch
.
Tensor
:
assert
group_name
in
_groups
,
f
"Group
{
group_name
}
is not found."
...
...
@@ -156,6 +187,14 @@ if supports_custom_op():
dispatch_key
=
current_platform
.
dispatch_key
,
)
direct_register_custom_op
(
op_name
=
"all_reduce_rms_quant"
,
op_func
=
all_reduce_rms_quant
,
mutates_args
=
[
"input_"
,
"pa_residual"
],
fake_impl
=
all_reduce_rms_quant_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
direct_register_custom_op
(
op_name
=
"reduce_scatter"
,
op_func
=
reduce_scatter
,
...
...
@@ -358,9 +397,44 @@ class GroupCoordinator:
else
:
return
self
.
_all_reduce_out_place
(
input_
)
def
all_reduce_crq_m32
(
self
,
input_
:
torch
.
Tensor
,
pa_rms_weight
:
torch
.
Tensor
,
pa_residual
:
torch
.
Tensor
,
pa_rms_eps
:
float
,
pa_quant_dtype
:
torch
.
dtype
,
update_input
:
Optional
[
bool
]
=
True
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
assert
self
.
world_size
>
1
assert
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
and
pa_rms_weight
is
not
None
and
pa_residual
is
not
None
return
torch
.
ops
.
vllm
.
all_reduce_rms_quant
(
input_
,
group_name
=
self
.
unique_name
,
pa_rms_weight
=
pa_rms_weight
,
pa_residual
=
pa_residual
,
pa_rms_eps
=
pa_rms_eps
,
pa_quant_dtype
=
pa_quant_dtype
,
update_input
=
update_input
)
def
_all_reduce_out_place
(
self
,
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
device_communicator
.
all_reduce
(
input_
)
def
_all_reduce_out_place_m32
(
self
,
input_
:
torch
.
Tensor
,
pa_rms_weight
:
torch
.
Tensor
,
pa_residual
:
torch
.
Tensor
,
pa_rms_eps
:
float
,
pa_quant_dtype
:
torch
.
dtype
,
update_input
:
Optional
[
bool
]
=
True
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
assert
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
and
pa_rms_weight
is
not
None
\
and
pa_residual
is
not
None
input_
,
pa_residual
,
xq
,
xs
=
self
.
device_communicator
.
all_reduce_rms_quant_m32
(
input_
,
pa_rms_weight
=
pa_rms_weight
,
pa_residual
=
pa_residual
,
pa_rms_eps
=
pa_rms_eps
,
pa_quant_dtype
=
pa_quant_dtype
,
update_input
=
update_input
)
return
input_
,
pa_residual
,
xq
,
xs
def
all_gather
(
self
,
input_
:
torch
.
Tensor
,
dim
:
int
=
-
1
)
->
torch
.
Tensor
:
world_size
=
self
.
world_size
# Bypass the function if we are using only 1 GPU.
...
...
vllm/engine/arg_utils.py
View file @
3e6729e0
...
...
@@ -284,7 +284,13 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
cached version.
"""
return
copy
.
deepcopy
(
_compute_kwargs
(
cls
))
class
EnvironmentConfigError
(
Exception
):
pass
def
check_incompatible_config
(
env1
:
bool
,
env2
:
bool
):
if
env1
is
True
and
env2
is
True
:
_s
=
"USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and USE_FUSED_RMS_QUANT must not be enabled simultaneously!
\n\n
"
raise
EnvironmentConfigError
(
_s
)
@
dataclass
class
EngineArgs
:
...
...
@@ -1230,7 +1236,8 @@ class EngineArgs:
num_lookahead_slots
=
num_lookahead_slots
\
if
speculative_config
is
None
\
else
speculative_config
.
num_lookahead_slots
check_incompatible_config
(
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
,
envs
.
USE_FUSED_RMS_QUANT
)
scheduler_config
=
SchedulerConfig
(
runner_type
=
model_config
.
runner_type
,
max_num_batched_tokens
=
self
.
max_num_batched_tokens
,
...
...
vllm/envs.py
View file @
3e6729e0
...
...
@@ -180,6 +180,7 @@ if TYPE_CHECKING:
VLLM_USE_PD_SPLIT
:
bool
=
False
VLLM_USE_PP_SYNC
:
bool
=
False
VLLM_USE_LIGHTOP_FILL_MOE_ALIGN
:
bool
=
False
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
:
bool
=
False
def
get_default_cache_root
():
return
os
.
getenv
(
...
...
@@ -1161,11 +1162,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
# vLLM will sync to avoid pp vmfault
"VLLM_USE_PP_SYNC"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_PP_SYNC"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# vLLM will use lightop to fuse fill and moe align
"VLLM_USE_LIGHTOP_FILL_MOE_ALIGN"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_LIGHTOP_FILL_MOE_ALIGN"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# vllm will use custom-allreduce rmsquant fused op
"USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT"
:
lambda
:
(
os
.
getenv
(
'USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT'
,
'0'
).
lower
()
in
(
"true"
,
"1"
)),
}
# --8<-- [end:env-vars-definition]
...
...
vllm/model_executor/layers/linear.py
View file @
3e6729e0
...
...
@@ -3,7 +3,7 @@
import
itertools
from
abc
import
abstractmethod
from
typing
import
Any
,
Literal
,
Optional
,
Union
from
typing
import
Any
,
Literal
,
Optional
,
Union
,
Tuple
import
vllm.envs
as
envs
import
torch
import
torch.nn
as
nn
...
...
@@ -14,7 +14,8 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
,
split_tensor_along_last_dim
,
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_reduce
)
tensor_model_parallel_all_reduce
,
tensor_model_parallel_all_reduce_crp_m32
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
...
...
@@ -677,7 +678,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self
,
input_
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
update_hd
:
Optional
[
bool
]
=
True
update_hd
:
Optional
[
bool
]
=
True
,
xqxs
:
Optional
[
tuple
]
=
None
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]]:
if
envs
.
USE_FUSED_RMS_QUANT
and
rms_weight
is
not
None
:
input_quant_args
=
None
...
...
@@ -706,7 +708,22 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
if
not
self
.
return_bias
:
return
output
return
output
,
new_residual
,
output_bias
else
:
# not USE_FUSED_RMS_QUANT
elif
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
and
xqxs
is
not
None
:
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
assert
self
.
quant_method
is
not
None
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
,
input_quant_args
=
xqxs
)
if
self
.
gather_output
:
# All-gather across the partitions.
output
=
tensor_model_parallel_all_gather
(
output_parallel
)
else
:
output
=
output_parallel
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
if
not
self
.
return_bias
:
return
output
return
output
,
output_bias
else
:
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
assert
self
.
quant_method
is
not
None
...
...
@@ -1495,46 +1512,94 @@ class RowParallelLinear(LinearBase):
def
forward
(
self
,
input_
,
use_fused_silu_mul_quant
:
Optional
[
bool
]
=
False
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]]:
if
self
.
input_is_parallel
:
input_parallel
=
input_
else
:
tp_rank
=
get_tensor_model_parallel_rank
()
splitted_input
=
split_tensor_along_last_dim
(
input_
,
num_partitions
=
self
.
tp_size
)
input_parallel
=
splitted_input
[
tp_rank
].
contiguous
()
use_fused_silu_mul_quant
:
Optional
[
bool
]
=
False
,
pa_rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
pa_residual
:
Optional
[
torch
.
Tensor
]
=
None
,
pa_rms_eps
:
Optional
[
float
]
=
1e-6
,
pa_quant_dtype
:
Optional
[
torch
.
dtype
]
=
torch
.
int8
,
update_input
:
Optional
[
bool
]
=
True
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]],
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
Parameter
]]
]:
if
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
and
pa_rms_weight
is
not
None
and
pa_residual
is
not
None
:
if
self
.
input_is_parallel
:
input_parallel
=
input_
else
:
tp_rank
=
get_tensor_model_parallel_rank
()
splitted_input
=
split_tensor_along_last_dim
(
input_
,
num_partitions
=
self
.
tp_size
)
input_parallel
=
splitted_input
[
tp_rank
].
contiguous
()
# Matrix multiply.
assert
self
.
quant_method
is
not
None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_
=
None
if
(
self
.
tp_rank
>
0
or
self
.
skip_bias_add
)
else
self
.
bias
if
use_fused_silu_mul_quant
:
xq
,
xs
=
lm_fuse_silu_mul_quant
(
input_parallel
)
silu_quant_args
=
[
xq
,
xs
]
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_parallel
,
bias
=
bias_
,
silu_quant_args
=
silu_quant_args
)
else
:
# Matrix multiply.
assert
self
.
quant_method
is
not
None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_
=
None
if
(
self
.
tp_rank
>
0
or
self
.
skip_bias_add
)
else
self
.
bias
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_parallel
,
bias
=
bias_
)
if
self
.
reduce_results
and
self
.
tp_size
>
1
:
if
envs
.
VLLM_ENABLE_TBO
:
output
=
self
.
tbo_all_reduce
(
output_parallel
)
if
self
.
reduce_results
and
self
.
tp_size
>
1
:
if
envs
.
VLLM_ENABLE_TBO
:
output
=
self
.
tbo_all_reduce
(
output_parallel
)
packages_
=
tensor_model_parallel_all_reduce_crp_m32
(
output_parallel
,
pa_rms_weight
=
pa_rms_weight
,
pa_residual
=
pa_residual
,
pa_rms_eps
=
pa_rms_eps
,
pa_quant_dtype
=
pa_quant_dtype
,
update_input
=
update_input
)
hs
,
resi
,
xq
,
xs
=
packages_
output
=
hs
else
:
output
=
tensor_model_parallel_all_reduce
(
output_parallel
)
output
=
output_parallel
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
if
not
self
.
return_bias
:
return
output
return
output
,
resi
,
xq
,
xs
,
output_bias
else
:
output
=
output_parallel
if
self
.
input_is_parallel
:
input_parallel
=
input_
else
:
tp_rank
=
get_tensor_model_parallel_rank
()
splitted_input
=
split_tensor_along_last_dim
(
input_
,
num_partitions
=
self
.
tp_size
)
input_parallel
=
splitted_input
[
tp_rank
].
contiguous
()
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
# Matrix multiply.
assert
self
.
quant_method
is
not
None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_
=
None
if
(
self
.
tp_rank
>
0
or
self
.
skip_bias_add
)
else
self
.
bias
if
use_fused_silu_mul_quant
:
xq
,
xs
=
lm_fuse_silu_mul_quant
(
input_parallel
)
silu_quant_args
=
[
xq
,
xs
]
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_parallel
,
bias
=
bias_
,
silu_quant_args
=
silu_quant_args
)
else
:
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_parallel
,
bias
=
bias_
)
if
self
.
reduce_results
and
self
.
tp_size
>
1
:
if
envs
.
VLLM_ENABLE_TBO
:
output
=
self
.
tbo_all_reduce
(
output_parallel
)
else
:
output
=
tensor_model_parallel_all_reduce
(
output_parallel
)
else
:
output
=
output_parallel
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
if
not
self
.
return_bias
:
return
output
return
output
,
output_bias
if
not
self
.
return_bias
:
return
output
return
output
,
output_bias
def
extra_repr
(
self
)
->
str
:
s
=
f
"input_features=
{
self
.
input_size_per_partition
}
"
...
...
vllm/model_executor/layers/quantization/slimquant_w4a8.py
View file @
3e6729e0
...
...
@@ -162,7 +162,11 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
assert
len
(
input_quant_args
)
==
2
x_q
,
x_scale
=
input_quant_args
elif
envs
.
USE_FUSED_SILU_MUL_QUANT
and
silu_quant_args
is
not
None
:
assert
len
(
silu_quant_args
)
==
2
x_q
,
x_scale
=
silu_quant_args
elif
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
and
input_quant_args
is
not
None
:
assert
len
(
input_quant_args
)
==
2
x_q
,
x_scale
=
input_quant_args
else
:
x_q
,
x_scale
=
per_token_quant_int8
(
x
)
...
...
@@ -178,9 +182,9 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
if
m
<=
16
:
m_
=
m
elif
m
<=
64
:
m_
=
(
m
+
3
)
&
-
4
#取值到最近的4的倍数
m_
=
(
(
m
+
3
)
//
4
)
*
4
#取值到最近的4的倍数
elif
m
<=
160
:
m_
=
(
m
+
7
)
&
-
8
m_
=
(
(
m
+
7
)
//
8
)
*
8
elif
m
<
200
:
#256
m_
=
160
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
3e6729e0
...
...
@@ -29,7 +29,7 @@ import vllm.envs as envs
import
typing
from
collections.abc
import
Callable
,
Iterable
from
typing
import
Any
,
Optional
,
Union
from
typing
import
Any
,
Optional
,
Union
,
Tuple
import
torch
from
torch
import
nn
...
...
@@ -96,8 +96,8 @@ class DeepseekV2MLP(nn.Module):
def
forward
(
self
,
x
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
update_hd
:
Optional
[
bool
]
=
False
):
update_hd
:
Optional
[
bool
]
=
False
,
xqxs
:
Optional
[
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
=
None
):
if
envs
.
USE_FUSED_RMS_QUANT
:
gate_up
,
new_resi
,
_
=
self
.
gate_up_proj
(
x
,
rms_weight
,
residual
,
update_hd
=
update_hd
)
if
envs
.
USE_FUSED_SILU_MUL_QUANT
:
...
...
@@ -107,6 +107,11 @@ class DeepseekV2MLP(nn.Module):
x
,
_
=
self
.
down_proj
(
x
)
return
x
,
new_resi
elif
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
and
xqxs
is
not
None
:
gate_up
,
_
=
self
.
gate_up_proj
(
x
,
xqxs
=
xqxs
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
else
:
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate_up
)
...
...
@@ -200,57 +205,99 @@ class DeepseekV2MoE(nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
xqxs
:
Optional
[
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
=
None
)
->
torch
.
Tensor
:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
if
self
.
n_shared_experts
is
not
None
:
if
envs
.
USE_FUSED_RMS_QUANT
:
shared_output
,
new_resi
=
self
.
shared_experts
(
hidden_states
,
rms_weight
,
residual
,
update_hd
=
True
)
else
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
# router_logits: (num_tokens, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
if
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
and
xqxs
is
not
None
:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
if
self
.
n_shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
hidden_states
,
xqxs
=
xqxs
)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
if
envs
.
VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD
:
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
,
shared_output
=
shared_output
)
else
:
if
hidden_states
.
dtype
!=
torch
.
float16
:
if
envs
.
VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD
:
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
*
self
.
routed_scaling_factor
router_logits
=
router_logits
,
shared_output
=
shared_output
)
else
:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
if
shared_output
is
not
None
:
if
hidden_states
.
dtype
!=
torch
.
float16
:
final_hidden_states
=
final_hidden_states
+
shared_output
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
*
self
.
routed_scaling_factor
else
:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states
=
final_hidden_states
+
shared_output
\
*
(
1.
/
self
.
routed_scaling_factor
)
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
if
shared_output
is
not
None
:
if
hidden_states
.
dtype
!=
torch
.
float16
:
final_hidden_states
=
final_hidden_states
+
shared_output
else
:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states
=
final_hidden_states
+
shared_output
\
*
(
1.
/
self
.
routed_scaling_factor
)
if
self
.
tp_size
>
1
:
if
envs
.
VLLM_ENABLE_TBO
:
final_hidden_states
=
self
.
tbo_all_reduce
(
final_hidden_states
)
else
:
final_hidden_states
=
(
self
.
experts
.
maybe_all_reduce_tensor_model_parallel
(
final_hidden_states
))
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
)
else
:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
if
self
.
n_shared_experts
is
not
None
:
if
envs
.
USE_FUSED_RMS_QUANT
:
shared_output
,
new_resi
=
self
.
shared_experts
(
hidden_states
,
rms_weight
,
residual
,
update_hd
=
True
)
else
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
if
self
.
tp_size
>
1
:
if
envs
.
VLLM_ENABLE_TBO
:
final_hidden_states
=
self
.
tbo_all_reduce
(
final_hidden_states
)
if
envs
.
VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD
:
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
,
shared_output
=
shared_output
)
else
:
final_hidden_states
=
(
self
.
experts
.
maybe_all_reduce_tensor_model_parallel
(
final_hidden_states
))
if
hidden_states
.
dtype
!=
torch
.
float16
:
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
*
self
.
routed_scaling_factor
else
:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
if
shared_output
is
not
None
:
if
hidden_states
.
dtype
!=
torch
.
float16
:
final_hidden_states
=
final_hidden_states
+
shared_output
else
:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states
=
final_hidden_states
+
shared_output
\
*
(
1.
/
self
.
routed_scaling_factor
)
if
self
.
tp_size
>
1
:
if
envs
.
VLLM_ENABLE_TBO
:
final_hidden_states
=
self
.
tbo_all_reduce
(
final_hidden_states
)
else
:
final_hidden_states
=
(
self
.
experts
.
maybe_all_reduce_tensor_model_parallel
(
final_hidden_states
))
if
envs
.
USE_FUSED_RMS_QUANT
:
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
),
new_resi
else
:
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
)
if
envs
.
USE_FUSED_RMS_QUANT
:
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
),
new_resi
else
:
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
)
def
yarn_get_mscale
(
scale
:
float
=
1
,
mscale
:
float
=
1
)
->
float
:
...
...
@@ -556,8 +603,16 @@ class DeepseekV2MLAAttention(nn.Module):
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
pa_rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
pa_residual
:
Optional
[
torch
.
Tensor
]
=
None
,
pa_rms_eps
:
Optional
[
float
]
=
1e-6
,
pa_quant_dtype
:
Optional
[
torch
.
dtype
]
=
torch
.
int8
,
update_input
:
Optional
[
bool
]
=
True
)
->
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
],
]:
if
envs
.
USE_FUSED_RMS_QUANT
and
rms_weight
is
not
None
:
if
self
.
q_lora_rank
is
not
None
:
q_c
,
new_residual
,
_
,
input_quant_args
=
self
.
q_a_proj
(
hidden_states
,
rms_weight
=
rms_weight
,
residual
=
residual
,
update_hd
=
False
)
...
...
@@ -587,6 +642,40 @@ class DeepseekV2MLAAttention(nn.Module):
output_shape
=
(
hidden_states
.
shape
[
0
],
self
.
num_local_heads
*
self
.
v_head_dim
))
return
self
.
o_proj
(
attn_out
)[
0
],
new_residual
elif
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
and
pa_rms_weight
is
not
None
and
pa_residual
is
not
None
:
if
self
.
q_lora_rank
is
not
None
:
q_c
=
self
.
q_a_proj
(
hidden_states
)[
0
]
q_c
=
self
.
q_a_layernorm
(
q_c
)
q
=
self
.
q_b_proj
(
q_c
)[
0
]
else
:
q
=
self
.
q_proj
(
hidden_states
)[
0
]
kv_c
,
k_pe
=
self
.
kv_a_proj_with_mqa
(
hidden_states
)[
0
].
split
(
[
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
kv_c_normed
=
self
.
kv_a_layernorm
(
kv_c
.
contiguous
())
q
=
q
.
view
(
-
1
,
self
.
num_local_heads
,
self
.
qk_head_dim
)
k_pe
=
k_pe
.
unsqueeze
(
1
)
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
=
self
.
rotary_emb
(
positions
,
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
)
attn_out
=
self
.
mla_attn
(
q
,
kv_c_normed
,
k_pe
,
output_shape
=
(
hidden_states
.
shape
[
0
],
self
.
num_local_heads
*
self
.
v_head_dim
))
packages_
=
self
.
o_proj
(
attn_out
,
pa_rms_weight
=
pa_rms_weight
,
pa_residual
=
pa_residual
,
pa_rms_eps
=
pa_rms_eps
,
pa_quant_dtype
=
pa_quant_dtype
,
update_input
=
update_input
)[:
4
]
assert
len
(
packages_
)
==
4
hs
,
resi
,
xq
,
xs
=
packages_
assert
xq
is
not
None
and
xs
is
not
None
return
hs
,
resi
,
xq
,
xs
else
:
if
self
.
q_lora_rank
is
not
None
:
q_c
=
self
.
q_a_proj
(
hidden_states
)[
0
]
...
...
@@ -682,97 +771,162 @@ class DeepseekV2DecoderLayer(nn.Module):
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
routed_scaling_factor
=
config
.
routed_scaling_factor
self
.
use_fused_rms_quant
=
envs
.
USE_FUSED_RMS_QUANT
self
.
use_fused_custom_all_reduce
=
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
def
forward
(
def
forward
_fused_rmsquant
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
if
envs
.
USE_FUSED_RMS_QUANT
:
# Fix residual FP16 overflow
residual_fix_overflow
=
False
residual
:
Optional
[
torch
.
Tensor
]
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Fix residual FP16 overflow
residual_fix_overflow
=
False
assert
self
.
input_layernorm
.
has_weight
is
True
if
residual
is
None
:
residual
=
hidden_states
hidden_states
,
_
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
rms_weight
=
self
.
input_layernorm
.
weight
.
data
,
residual
=
None
)
residual_fix_overflow
=
True
else
:
hidden_states
,
new_residual
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
rms_weight
=
self
.
input_layernorm
.
weight
.
data
,
residual
=
residual
)
residual
=
new_residual
assert
self
.
input_layernorm
.
has_weight
is
True
if
residual
is
None
:
residual
=
hidden_states
hidden_states
,
_
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
rms_weight
=
self
.
input_layernorm
.
weight
.
data
,
residual
=
None
)
residual_fix_overflow
=
True
else
:
hidden_states
,
new_residual
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
rms_weight
=
self
.
input_layernorm
.
weight
.
data
,
residual
=
residual
)
residual
=
new_residual
if
hidden_states
.
dtype
==
torch
.
float16
:
# rmsnorm, and rmsnorm result would not affect by scale.
hidden_states
*=
1.
/
self
.
routed_scaling_factor
if
self
.
layer_idx
==
0
or
residual_fix_overflow
:
# The residual is shared by all layers, we only scale it on
# first layer.
residual
*=
1.
/
self
.
routed_scaling_factor
hidden_states
,
new_resi
=
self
.
mlp
(
hidden_states
,
self
.
post_attention_layernorm
.
weight
.
data
,
residual
)
if
isinstance
(
self
.
mlp
,
DeepseekV2MLP
)
and
hidden_states
.
dtype
==
torch
.
float16
:
# Fix FP16 overflow
# Scaling the DeepseekV2MLP output, it is the input of
# input_layernorm of next decoder layer.
# The scaling of DeepseekV2MOE output would be done in the forward
# of DeepseekV2MOE
hidden_states
*=
1.
/
self
.
routed_scaling_factor
return
hidden_states
,
new_resi
def
forward_fused_CRQ
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
residual_fix_overflow
=
False
if
residual
is
None
:
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
residual_fix_overflow
=
True
else
:
hidden_states
,
resi_new
=
self
.
input_layernorm
(
hidden_states
,
residual
)
residual
=
resi_new
new_hs
,
new_resi
,
xq
,
xs
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
pa_rms_weight
=
self
.
post_attention_layernorm
.
weight
.
data
,
pa_residual
=
residual
,
pa_rms_eps
=
self
.
post_attention_layernorm
.
variance_epsilon
,
pa_quant_dtype
=
torch
.
int8
,
update_input
=
True
)
assert
xq
is
not
None
and
xs
is
not
None
if
new_hs
.
dtype
==
torch
.
float16
:
# overflow处理逻辑
new_hs
*=
1.
/
self
.
routed_scaling_factor
if
self
.
layer_idx
==
0
or
residual_fix_overflow
:
new_resi
*=
1.
/
self
.
routed_scaling_factor
hidden_states
=
self
.
mlp
(
new_hs
,
xqxs
=
(
xq
,
xs
))
if
isinstance
(
self
.
mlp
,
DeepseekV2MLP
)
and
hidden_states
.
dtype
==
torch
.
float16
:
hidden_states
*=
1.
/
self
.
routed_scaling_factor
return
hidden_states
,
new_resi
def
forward_default
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
# Fix residual FP16 overflow
residual_fix_overflow
=
False
if
residual
is
None
:
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
residual_fix_overflow
=
True
else
:
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
)
if
hidden_states
.
dtype
==
torch
.
float16
:
# Fix FP16 overflow
# We scale both hidden_states and residual before
# rmsnorm, and rmsnorm result would not affect by scale.
hidden_states
*=
1.
/
self
.
routed_scaling_factor
if
self
.
layer_idx
==
0
or
residual_fix_overflow
:
# The residual is shared by all layers, we only scale it on
# first layer.
residual
*=
1.
/
self
.
routed_scaling_factor
# Fully Connected
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mlp
(
hidden_states
)
if
isinstance
(
self
.
mlp
,
DeepseekV2MLP
)
and
hidden_states
.
dtype
==
torch
.
float16
:
# Fix FP16 overflow
# Scaling the DeepseekV2MLP output, it is the input of
# input_layernorm of next decoder layer.
# The scaling of DeepseekV2MOE output would be done in the forward
# of DeepseekV2MOE
hidden_states
*=
1.
/
self
.
routed_scaling_factor
return
hidden_states
,
residual
if
hidden_states
.
dtype
==
torch
.
float16
:
# rmsnorm, and rmsnorm result would not affect by scale.
hidden_states
*=
1.
/
self
.
routed_scaling_factor
if
self
.
layer_idx
==
0
or
residual_fix_overflow
:
# The residual is shared by all layers, we only scale it on
# first layer.
residual
*=
1.
/
self
.
routed_scaling_factor
hidden_states
,
new_resi
=
self
.
mlp
(
hidden_states
,
self
.
post_attention_layernorm
.
weight
.
data
,
residual
)
if
isinstance
(
self
.
mlp
,
DeepseekV2MLP
)
and
hidden_states
.
dtype
==
torch
.
float16
:
# Fix FP16 overflow
# Scaling the DeepseekV2MLP output, it is the input of
# input_layernorm of next decoder layer.
# The scaling of DeepseekV2MOE output would be done in the forward
# of DeepseekV2MOE
hidden_states
*=
1.
/
self
.
routed_scaling_factor
return
hidden_states
,
new_resi
def
choose_forward
(
self
):
if
self
.
use_fused_rms_quant
:
return
self
.
forward_fused_rmsquant
elif
self
.
use_fused_custom_all_reduce
:
return
self
.
forward_fused_CRQ
else
:
# Self Attention
# Fix residual FP16 overflow
residual_fix_overflow
=
False
if
residual
is
None
:
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
residual_fix_overflow
=
True
else
:
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
)
return
self
.
forward_default
if
hidden_states
.
dtype
==
torch
.
float16
:
# Fix FP16 overflow
# We scale both hidden_states and residual before
# rmsnorm, and rmsnorm result would not affect by scale.
hidden_states
*=
1.
/
self
.
routed_scaling_factor
if
self
.
layer_idx
==
0
or
residual_fix_overflow
:
# The residual is shared by all layers, we only scale it on
# first layer.
residual
*=
1.
/
self
.
routed_scaling_factor
# Fully Connected
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mlp
(
hidden_states
)
if
isinstance
(
self
.
mlp
,
DeepseekV2MLP
)
and
hidden_states
.
dtype
==
torch
.
float16
:
# Fix FP16 overflow
# Scaling the DeepseekV2MLP output, it is the input of
# input_layernorm of next decoder layer.
# The scaling of DeepseekV2MOE output would be done in the forward
# of DeepseekV2MOE
hidden_states
*=
1.
/
self
.
routed_scaling_factor
return
hidden_states
,
residual
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
forward_func
=
self
.
choose_forward
()
return
forward_func
(
positions
=
positions
,
hidden_states
=
hidden_states
,
residual
=
residual
)
@
support_torch_compile
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment