Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
44ec8bed
Commit
44ec8bed
authored
Feb 03, 2026
by
lishen
Browse files
支持更复杂的量化,包括fp8/int8/ue8m0,且支持per-group/per-channel
parent
81e56124
Changes
10
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
283 additions
and
160 deletions
+283
-160
csrc/config.hpp
csrc/config.hpp
+1
-1
csrc/deep_ep.cu
csrc/deep_ep.cu
+28
-14
csrc/deep_ep.hpp
csrc/deep_ep.hpp
+1
-1
csrc/kernels/api.cuh
csrc/kernels/api.cuh
+1
-1
csrc/kernels/configs.cuh
csrc/kernels/configs.cuh
+1
-1
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+159
-91
csrc/kernels/utils.cuh
csrc/kernels/utils.cuh
+30
-15
deep_ep/buffer.py
deep_ep/buffer.py
+38
-17
tests/test_low_latency_new.py
tests/test_low_latency_new.py
+16
-9
tests/test_low_latency_new_int8.py
tests/test_low_latency_new_int8.py
+8
-10
No files found.
csrc/config.hpp
View file @
44ec8bed
...
@@ -136,7 +136,7 @@ struct LowLatencyLayout {
...
@@ -136,7 +136,7 @@ struct LowLatencyLayout {
LowLatencyLayout
(
void
*
rdma_buffer
,
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
LowLatencyLayout
(
void
*
rdma_buffer
,
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_ranks
,
int
num_experts
)
{
int
num_ranks
,
int
num_experts
)
{
const
int
num_scales
=
hidden
/
FP8_
QUANTIZATION_
NUM_PER_CHANNEL
;
const
int
num_scales
=
hidden
/
QUANTIZATION_
GROUPSIZE
;
// Dispatch and combine layout:
// Dispatch and combine layout:
// - 2 symmetric odd/even send buffer
// - 2 symmetric odd/even send buffer
...
...
csrc/deep_ep.cu
View file @
44ec8bed
...
@@ -1293,7 +1293,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
...
@@ -1293,7 +1293,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
Buffer
::
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
Buffer
::
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
use_int8
,
int
quant_type
,
int
quant_group_size
,
bool
fp8_round_scale
,
bool
async
,
bool
return_recv_hook
)
{
bool
async
,
bool
return_recv_hook
)
{
EP_HOST_ASSERT
(
low_latency_mode
);
EP_HOST_ASSERT
(
low_latency_mode
);
...
@@ -1327,8 +1327,15 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
...
@@ -1327,8 +1327,15 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
stream_wait
(
launch_stream
,
compute_stream
);
stream_wait
(
launch_stream
,
compute_stream
);
// Allocate packed tensors
// Allocate packed tensors
auto
packed_recv_x
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
,
hidden
},
auto
packed_recv_x_dtype
=
torch
::
kBFloat16
;
x
.
options
().
dtype
(
use_int8
?
torch
::
kInt8
:
use_fp8
?
torch
::
kFloat8_e4m3fnuz
:
torch
::
kBFloat16
));
switch
(
quant_type
)
{
case
1
:
packed_recv_x_dtype
=
torch
::
kInt8
;
break
;
case
2
:
packed_recv_x_dtype
=
torch
::
kFloat8_e4m3fnuz
;
break
;
case
3
:
packed_recv_x_dtype
=
torch
::
kFloat8_e4m3fnuz
;
break
;
case
4
:
packed_recv_x_dtype
=
torch
::
kFloat8_e5m2fnuz
;
break
;
}
auto
packed_recv_x
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
,
hidden
},
x
.
options
().
dtype
(
packed_recv_x_dtype
));
auto
packed_recv_src_info
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
auto
packed_recv_src_info
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
auto
packed_recv_layout_range
=
torch
::
empty
({
num_local_experts
,
num_ranks
},
torch
::
dtype
(
torch
::
kInt64
).
device
(
torch
::
kCUDA
));
auto
packed_recv_layout_range
=
torch
::
empty
({
num_local_experts
,
num_ranks
},
torch
::
dtype
(
torch
::
kInt64
).
device
(
torch
::
kCUDA
));
auto
packed_recv_count
=
torch
::
empty
({
num_local_experts
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
auto
packed_recv_count
=
torch
::
empty
({
num_local_experts
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
...
@@ -1336,21 +1343,28 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
...
@@ -1336,21 +1343,28 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
// Allocate column-majored scales
// Allocate column-majored scales
auto
packed_recv_x_scales
=
std
::
optional
<
torch
::
Tensor
>
();
auto
packed_recv_x_scales
=
std
::
optional
<
torch
::
Tensor
>
();
void
*
packed_recv_x_scales_ptr
=
nullptr
;
void
*
packed_recv_x_scales_ptr
=
nullptr
;
if
(
use_fp8
)
{
if
(
quant_type
>
0
)
{
EP_HOST_ASSERT
((
num_ranks
*
num_max_dispatch_tokens_per_rank
)
%
4
==
0
and
"TMA requires the number of tokens to be multiple of 4"
);
EP_HOST_ASSERT
((
num_ranks
*
num_max_dispatch_tokens_per_rank
)
%
4
==
0
and
"TMA requires the number of tokens to be multiple of 4"
);
// TODO: support unaligned cases
// TODO: support unaligned cases
EP_HOST_ASSERT
(
hidden
%
(
FP8_QUANTIZATION_NUM_PER_CHANNEL
*
4
)
==
0
);
EP_HOST_ASSERT
(
hidden
%
(
QUANTIZATION_GROUPSIZE
*
4
)
==
0
);
EP_HOST_ASSERT
(
!
(
use_ue8m0
&&
use_int8
));
// 计算scale_col的大小
int
scales_col_size
=
1
;
// 默认为per-channel
if
(
quant_group_size
>
0
)
{
if
(
quant_type
==
3
)
{
// FP8_UE8M0比较特殊
scales_col_size
=
hidden
/
(
QUANTIZATION_GROUPSIZE
*
4
);
}
else
{
scales_col_size
=
hidden
/
QUANTIZATION_GROUPSIZE
;
}
}
if
(
use_ue8m0
)
{
// 设置packed_recv_x_scales
EP_HOST_ASSERT
(
round_scale
);
if
(
quant_type
==
3
)
{
// FP8_UE8M0比较特殊,需要单独处理
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
hidden
/
(
FP8_QUANTIZATION_NUM_PER_CHANNEL
*
4
),
num_ranks
*
num_max_dispatch_tokens_per_rank
},
EP_HOST_ASSERT
(
fp8_round_scale
&&
quant_group_size
==
128
);
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
scales_col_size
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kInt
).
device
(
torch
::
kCUDA
));
torch
::
dtype
(
torch
::
kInt
).
device
(
torch
::
kCUDA
));
}
else
if
(
use_int8
)
{
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
1
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
));
}
else
{
}
else
{
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
hidden
/
FP8_QUANTIZATION_NUM_PER_CHANNEL
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
scales_col_size
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
));
torch
::
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
));
}
}
packed_recv_x_scales
=
torch
::
transpose
(
packed_recv_x_scales
.
value
(),
1
,
2
);
packed_recv_x_scales
=
torch
::
transpose
(
packed_recv_x_scales
.
value
(),
1
,
2
);
...
@@ -1370,7 +1384,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
...
@@ -1370,7 +1384,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
next_clean_meta
.
first
,
next_clean_meta
.
second
,
next_clean_meta
.
first
,
next_clean_meta
.
second
,
num_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
use_fp8
,
round_scale
,
use_ue8m0
,
use_int8
,
quant_type
,
quant_group_size
,
fp8_round_scale
,
workspace
,
num_device_sms
,
launch_stream
,
phases
);
workspace
,
num_device_sms
,
launch_stream
,
phases
);
};
};
launcher
(
return_recv_hook
?
LOW_LATENCY_SEND_PHASE
:
(
LOW_LATENCY_SEND_PHASE
|
LOW_LATENCY_RECV_PHASE
));
launcher
(
return_recv_hook
?
LOW_LATENCY_SEND_PHASE
:
(
LOW_LATENCY_SEND_PHASE
|
LOW_LATENCY_RECV_PHASE
));
...
...
csrc/deep_ep.hpp
View file @
44ec8bed
...
@@ -177,7 +177,7 @@ public:
...
@@ -177,7 +177,7 @@ public:
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
use_int8
,
int
quant_type
,
int
quant_group_size
,
bool
fp8_round_scale
,
bool
async
,
bool
return_recv_hook
);
bool
async
,
bool
return_recv_hook
);
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
...
...
csrc/kernels/api.cuh
View file @
44ec8bed
...
@@ -147,7 +147,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -147,7 +147,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
use_int8
,
int
quant_type
,
int
group_size
,
bool
fp8_round_scale
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
);
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
);
void
combine
(
void
*
combined_x
,
void
combine
(
void
*
combined_x
,
...
...
csrc/kernels/configs.cuh
View file @
44ec8bed
...
@@ -23,7 +23,7 @@
...
@@ -23,7 +23,7 @@
#define LOW_LATENCY_SEND_PHASE 1
#define LOW_LATENCY_SEND_PHASE 1
#define LOW_LATENCY_RECV_PHASE 2
#define LOW_LATENCY_RECV_PHASE 2
#define NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL 3
#define NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL 3
#define
FP8_
QUANTIZATION_
NUM_PER_CHANNEL
128
#define QUANTIZATION_
GROUPSIZE
128
#define DEFAULT_NUM_CU 20
#define DEFAULT_NUM_CU 20
#define DEFAULT_NUM_MAX_XGMI_CHUNKED_SEND_TOKENS 6
#define DEFAULT_NUM_MAX_XGMI_CHUNKED_SEND_TOKENS 6
...
...
csrc/kernels/internode_ll.cu
View file @
44ec8bed
This diff is collapsed.
Click to expand it.
csrc/kernels/utils.cuh
View file @
44ec8bed
...
@@ -341,10 +341,14 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
...
@@ -341,10 +341,14 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
return
*
reinterpret_cast
<
dtype_t
*>
(
recv_int_values
);
return
*
reinterpret_cast
<
dtype_t
*>
(
recv_int_values
);
}
}
constexpr
float
kFP8Margin
=
1e-4
;
// Maximum representable magnitude (amax) of each quantization format, and its reciprocal
constexpr
float
kFP8Margin
=
0.0
;
constexpr
float
kFinfoAmaxE4M3
=
240.0
f
;
constexpr
float
kFinfoAmaxE4M3
=
240.0
f
;
constexpr
float
kFinfoAmaxInvE4M3
=
1.0
f
/
kFinfoAmaxE4M3
;
constexpr
float
kFinfoAmaxInvE4M3
=
1.0
f
/
kFinfoAmaxE4M3
;
constexpr
float
kInt8Amax
=
127.0
f
;
constexpr
float
kFinfoAmaxE5M2
=
57344.0
f
;
constexpr
float
kFinfoAmaxInvE5M2
=
1.0
f
/
kFinfoAmaxE5M2
;
constexpr
float
kFinfoAmaxInt8
=
127.0
f
;
constexpr
float
kFinfoAmaxInvInt8
=
1.0
f
/
127.0
f
;
__forceinline__
__device__
float
fast_pow2
(
int
x
)
{
__forceinline__
__device__
float
fast_pow2
(
int
x
)
{
// We can ensure `-126 <= x and x <= 127`
// We can ensure `-126 <= x and x <= 127`
...
@@ -359,22 +363,33 @@ __forceinline__ __device__ int fast_log2_ceil(float x) {
...
@@ -359,22 +363,33 @@ __forceinline__ __device__ int fast_log2_ceil(float x) {
return
exp_x
-
127
+
(
man_bits
!=
0
);
return
exp_x
-
127
+
(
man_bits
!=
0
);
}
}
__forceinline__
__device__
void
calculate_fp8_scales
(
float
amax
,
float
&
scale
,
float
&
scale_inv
,
bool
round_scale
)
{
template
<
int
kQuantType
>
if
(
round_scale
)
{
__forceinline__
__device__
void
calculate_quant8bit_scales
(
float
amax
,
float
&
scale
,
float
&
scale_inv
,
bool
round_scale
=
0
)
{
auto
exp_scale_inv
=
fast_log2_ceil
(
amax
*
kFinfoAmaxInvE4M3
);
amax
=
fmaxf
(
amax
,
1e-6
f
);
scale
=
fast_pow2
(
-
exp_scale_inv
);
if
constexpr
(
kQuantType
==
1
)
{
// 使用 INT8 对称量化
scale_inv
=
fast_pow2
(
exp_scale_inv
);
scale_inv
=
kFinfoAmaxInvInt8
*
amax
;
}
else
{
scale
=
kFinfoAmaxInt8
/
amax
;
scale_inv
=
amax
*
kFinfoAmaxInvE4M3
;
}
else
if
constexpr
(
kQuantType
==
2
||
kQuantType
==
3
)
{
// FP8_E4M3 or FP8_UE8M0: symmetric (scale-only) quantization, like the INT8 and FP8_E5M2 branches
scale
=
kFinfoAmaxE4M3
/
amax
;
if
(
round_scale
)
{
auto
exp_scale_inv
=
fast_log2_ceil
(
amax
*
kFinfoAmaxInvE4M3
);
scale
=
fast_pow2
(
-
exp_scale_inv
);
scale_inv
=
fast_pow2
(
exp_scale_inv
);
}
else
{
scale_inv
=
amax
*
kFinfoAmaxInvE4M3
;
scale
=
kFinfoAmaxE4M3
/
amax
;
}
}
else
if
constexpr
(
kQuantType
==
4
)
{
// 使用 FP8_E5M2 对称量化
if
(
round_scale
)
{
auto
exp_scale_inv
=
fast_log2_ceil
(
amax
*
kFinfoAmaxInvE5M2
);
scale
=
fast_pow2
(
-
exp_scale_inv
);
scale_inv
=
fast_pow2
(
exp_scale_inv
);
}
else
{
scale_inv
=
amax
*
kFinfoAmaxInvE5M2
;
scale
=
kFinfoAmaxE5M2
/
amax
;
}
}
}
}
}
__forceinline__
__device__
void
calculate_int8_scales
(
float
amax
,
float
&
scale
,
float
&
scale_inv
)
{
scale
=
kInt8Amax
/
amax
;
scale_inv
=
amax
/
kInt8Amax
;
}
template
<
bool
kIsUE8M0
,
typename
out_dtype_t
=
std
::
conditional_t
<
kIsUE8M0
,
uint8_t
,
float
>
>
template
<
bool
kIsUE8M0
,
typename
out_dtype_t
=
std
::
conditional_t
<
kIsUE8M0
,
uint8_t
,
float
>
>
__forceinline__
__device__
out_dtype_t
extract_required_scale_format
(
float
value
)
{
__forceinline__
__device__
out_dtype_t
extract_required_scale_format
(
float
value
)
{
if
constexpr
(
kIsUE8M0
)
{
if
constexpr
(
kIsUE8M0
)
{
...
...
deep_ep/buffer.py
View file @
44ec8bed
...
@@ -841,7 +841,7 @@ class Buffer:
...
@@ -841,7 +841,7 @@ class Buffer:
# noinspection PyTypeChecker
# noinspection PyTypeChecker
def
low_latency_dispatch
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
def
low_latency_dispatch
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
num_max_dispatch_tokens_per_rank
:
int
,
num_experts
:
int
,
num_max_dispatch_tokens_per_rank
:
int
,
num_experts
:
int
,
use_fp8
:
bool
=
True
,
round_scale
:
bool
=
False
,
use_ue8m0
:
bool
=
False
,
use_int8
:
bool
=
False
,
quant_type
:
int
=
1
,
quant_group_size
:
int
=
0
,
fp8_round_scale
:
bool
=
False
,
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
)
->
\
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
)
->
\
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
torch
.
Tensor
,
Tuple
,
EventOverlap
,
Callable
]:
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
torch
.
Tensor
,
Tuple
,
EventOverlap
,
Callable
]:
"""
"""
...
@@ -858,10 +858,20 @@ class Buffer:
...
@@ -858,10 +858,20 @@ class Buffer:
only several top-k shapes are supported. `-1` indices (not selecting any expert) are supported.
only several top-k shapes are supported. `-1` indices (not selecting any expert) are supported.
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
num_experts: the number of all experts.
num_experts: the number of all experts.
use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
量化配置
round_scale: whether round the scaling factors into power of 2.
quant_type: int 量化类型枚举
use_ue8m0: whether use UE8M0 as scaling factor format (available only with `round_scale=True`).
0 -> None 不量化,保持原始精度
use_int8: whether to enable INT8 casting.
1 -> Int8 使用 INT8 对称量化
2 -> FP8_E4M3 使用 FP8 E4M3 格式 (__HIP_E4M3_FNUZ)
3 -> FP8_UE8M0: the UE8M0 format proposed by DeepSeekV3.1 (requires `fp8_round_scale=True` and `quant_group_size=128`)
4 -> FP8_E5M2 使用 FP8 E5M2 格式 (__HIP_E5M2_FNUZ)
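quant_group_size: int, quantization group size
0 -> one scale per token, covering the whole hidden vector (this project calls it per-channel)
128 -> per-group quantization, one scale per group of 128 elements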
fp8_round_scale: bool 是否将 FP8 缩放因子取整为 2 的幂
true -> 缩放因子 = 2^k,硬件零开销
false -> 缩放因子 = 任意浮点,精度更高
异步配置
async_finish: the current stream will not wait for the communication kernels to be finished if set.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
...
@@ -869,15 +879,25 @@ class Buffer:
...
@@ -869,15 +879,25 @@ class Buffer:
Returns:
Returns:
recv_x: a tensor or tuple with received tokens for each expert.
recv_x: a tensor or tuple with received tokens for each expert.
With `use_fp8=True`: the first element is a `torch.Tensor` shaped as
- packed_recv_x:
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`.
存储接收到的 Token 数据,形状为
The second tensor is the corresponding scales for the first element with shape
`[num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden]`。
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`,
数据类型取决于 quant_type:
if `use_ue8m0=False`. With `use_ue8m0=True`, the second one is packed and shaped as
quant_type == 1 -> torch.int8
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 512]` with type `torch.int`.
quant_type == 2 -> torch.float8_e4m3fnuz
Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility.
quant_type == 3 -> torch.float8_e4m3fnuz (token data is stored as E4M3; UE8M0 applies to the scales only)
With `use_fp8=False`, the result would be a tensor shaped as
quant_type == 4 -> torch.float8_e5m2fnuz
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`.
其他 (非量化) -> torch.bfloat16
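- packed_recv_x_scales (optional):
Present only when quant_type > 0; holds the quantization scales.
Shaped as `[num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, scales_col_size]`; the last two dimensions are column-major (allocated as `[num_local_experts, scales_col_size, num_ranks * num_max_dispatch_tokens_per_rank]`, then transposed).
- When quant_type == 3 (UE8M0):
scales_col_size = hidden // 512
Data type is torch.int (each 32-bit element packs four 8-bit UE8M0 scales).
* Note: this mode requires `fp8_round_scale=True` and `quant_group_size=128`.
- When quant_type is 1, 2 or 4:
scales_col_size = hidden // 128 (when `quant_group_size > 0`) or 1 (one scale per token, when `quant_group_size == 0`).
Data type is torch.float32.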
Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
as we do not synchronize CPU received count with GPU (synchronizing would also be incompatible with CUDA graph).
as we do not synchronize CPU received count with GPU (synchronizing would also be incompatible with CUDA graph).
recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
...
@@ -889,14 +909,15 @@ class Buffer:
...
@@ -889,14 +909,15 @@ class Buffer:
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_src_info
,
packed_recv_layout_range
,
event
,
hook
=
\
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_src_info
,
packed_recv_layout_range
,
event
,
hook
=
\
self
.
runtime
.
low_latency_dispatch
(
x
,
topk_idx
,
self
.
runtime
.
low_latency_dispatch
(
x
,
topk_idx
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
use_fp8
,
round_scale
,
use_ue8m0
,
use_int8
,
quant_type
,
quant_group_size
,
fp8_round_scale
,
async_finish
,
return_recv_hook
)
async_finish
,
return_recv_hook
)
handle
=
(
packed_recv_src_info
,
packed_recv_layout_range
,
num_max_dispatch_tokens_per_rank
,
x
.
size
(
1
),
num_experts
)
handle
=
(
packed_recv_src_info
,
packed_recv_layout_range
,
num_max_dispatch_tokens_per_rank
,
x
.
size
(
1
),
num_experts
)
tensors_to_record
=
(
x
,
topk_idx
,
tensors_to_record
=
(
x
,
topk_idx
,
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_src_info
,
packed_recv_layout_range
)
packed_recv_src_info
,
packed_recv_layout_range
)
return
(
packed_recv_x
,
packed_recv_x_scales
)
if
use_fp8
else
packed_recv_x
,
packed_recv_count
,
handle
,
\
EventOverlap
(
event
,
tensors_to_record
if
async_finish
else
None
),
hook
recv_x
=
(
packed_recv_x
,
packed_recv_x_scales
)
if
(
quant_type
>
0
)
else
packed_recv_x
return
recv_x
,
packed_recv_count
,
handle
,
EventOverlap
(
event
,
tensors_to_record
if
async_finish
else
None
),
hook
# noinspection PyTypeChecker
# noinspection PyTypeChecker
def
low_latency_combine
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
def
low_latency_combine
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
...
...
tests/test_low_latency_new.py
View file @
44ec8bed
...
@@ -58,7 +58,8 @@ def test_main(num_tokens: int,
...
@@ -58,7 +58,8 @@ def test_main(num_tokens: int,
x_list
=
[
x
]
x_list
=
[
x
]
# # NOTES: the last one is for performance testing
# # NOTES: the last one is for performance testing
# # Most of the values in the perf case are lower than the threshold, casting most channels
# # Most of the values in the perf case are lower than the threshold, casting most channels
# x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1)
# x_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1
# x_list = [x_rand]
scores
=
torch
.
randn
((
num_tokens
,
num_experts
),
dtype
=
torch
.
float32
,
device
=
'cuda'
).
abs
()
+
1
scores
=
torch
.
randn
((
num_tokens
,
num_experts
),
dtype
=
torch
.
float32
,
device
=
'cuda'
).
abs
()
+
1
topk_idx
=
torch
.
topk
(
scores
,
num_topk
,
dim
=-
1
,
largest
=
True
,
sorted
=
True
)[
1
]
topk_idx
=
torch
.
topk
(
scores
,
num_topk
,
dim
=-
1
,
largest
=
True
,
sorted
=
True
)[
1
]
...
@@ -80,14 +81,19 @@ def test_main(num_tokens: int,
...
@@ -80,14 +81,19 @@ def test_main(num_tokens: int,
hash_value
,
num_times
=
0
,
0
hash_value
,
num_times
=
0
,
0
for
current_x
in
x_list
:
for
current_x
in
x_list
:
for
return_recv_hook
in
(
False
,
True
):
for
return_recv_hook
in
(
False
,
True
):
for
dispatch_use_fp8
in
(
False
,
True
):
for
quant_type
in
(
0
,
2
,
3
,
):
# 0: no quantization, 2: FP8_E4M3, 3: FP8_UE8M0 (requires fp8_round_scale=True)
for
round_scale
in
(
False
,
True
)
if
dispatch_use_fp8
else
(
False
,):
dispatch_use_fp8
=
quant_type
>
0
for
use_ue8m0
in
(
False
,
True
)
if
round_scale
else
(
False
,):
for
fp8_round_scale
in
(
False
,
True
)
if
dispatch_use_fp8
else
(
False
,
):
for
quant_group_size
in
(
128
,
):
# 跳过不支持的情况
if
quant_type
==
3
and
fp8_round_scale
==
False
:
continue
num_times
+=
1
num_times
+=
1
for
_
in
range
((
num_times
%
2
)
+
1
):
for
_
in
range
((
num_times
%
2
)
+
1
):
packed_recv_x
,
packed_recv_count
,
handle
,
event
,
hook
=
\
packed_recv_x
,
packed_recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
use_fp8
=
dispatch_use_
fp8
,
round_scale
=
round_scale
,
use_ue8m0
=
use_ue8m0
,
quant_type
=
quant_type
,
fp8
_
round_scale
=
fp8_
round_scale
,
quant_group_size
=
quant_group_size
,
async_finish
=
not
return_recv_hook
,
return_recv_hook
=
return_recv_hook
)
async_finish
=
not
return_recv_hook
,
return_recv_hook
=
return_recv_hook
)
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
packed_recv_x
=
(
packed_recv_x
[
0
],
packed_recv_x
[
1
].
contiguous
())
if
dispatch_use_fp8
else
packed_recv_x
packed_recv_x
=
(
packed_recv_x
[
0
],
packed_recv_x
[
1
].
contiguous
())
if
dispatch_use_fp8
else
packed_recv_x
...
@@ -115,13 +121,13 @@ def test_main(num_tokens: int,
...
@@ -115,13 +121,13 @@ def test_main(num_tokens: int,
recv_x_amin
=
recv_x
[:,
:
-
128
].
amin
(
dim
=-
1
)
recv_x_amin
=
recv_x
[:,
:
-
128
].
amin
(
dim
=-
1
)
recv_src_info
=
recv_src_info
[:
num_valid_tokens
]
recv_src_info
=
recv_src_info
[:
num_valid_tokens
]
assert
torch
.
equal
(
recv_x_amin
,
recv_x
[:,
:
-
128
].
amax
(
dim
=-
1
))
assert
torch
.
equal
(
recv_x_amin
,
recv_x
[:,
:
-
128
].
amax
(
dim
=-
1
))
if
round_scale
:
if
fp8_
round_scale
:
assert
calc_diff
(
recv_x
[:,
-
1
],
recv_src_info
.
view
(
-
1
))
<
0.007
assert
calc_diff
(
recv_x
[:,
-
1
],
recv_src_info
.
view
(
-
1
))
<
0.007
else
:
else
:
assert
(
recv_x
[:,
-
128
:]
-
recv_src_info
.
view
(
-
1
,
1
)
%
num_tokens
).
sum
().
item
()
==
0
assert
(
recv_x
[:,
-
128
:]
-
recv_src_info
.
view
(
-
1
,
1
)
%
num_tokens
).
sum
().
item
()
==
0
for
j
in
range
(
num_ranks
):
for
j
in
range
(
num_ranks
):
begin_idx
,
count
=
(
recv_layout_range
[
j
]
>>
32
).
item
(),
(
recv_layout_range
[
j
]
&
int_mask
).
item
()
begin_idx
,
count
=
(
recv_layout_range
[
j
]
>>
32
).
item
(),
(
recv_layout_range
[
j
]
&
int_mask
).
item
()
if
not
round_scale
:
if
not
fp8_
round_scale
:
assert
(
recv_x_amin
==
j
-
rank_offset
).
sum
().
item
()
==
(
all_topk_idx
[
j
]
==
expert_id
).
sum
().
item
()
assert
(
recv_x_amin
==
j
-
rank_offset
).
sum
().
item
()
==
(
all_topk_idx
[
j
]
==
expert_id
).
sum
().
item
()
assert
(
recv_x
[
begin_idx
:
begin_idx
+
count
,
:
-
128
]
-
j
+
rank_offset
).
sum
().
item
()
==
0
assert
(
recv_x
[
begin_idx
:
begin_idx
+
count
,
:
-
128
]
-
j
+
rank_offset
).
sum
().
item
()
==
0
if
dispatch_use_fp8
:
if
dispatch_use_fp8
:
...
@@ -147,7 +153,7 @@ def test_main(num_tokens: int,
...
@@ -147,7 +153,7 @@ def test_main(num_tokens: int,
if
do_check
:
if
do_check
:
diff
=
calc_diff
(
current_x
*
topk_weights
.
masked_fill
(
topk_idx
==
-
1
,
0
).
sum
(
dim
=
1
).
view
(
-
1
,
1
),
combined_x
)
diff
=
calc_diff
(
current_x
*
topk_weights
.
masked_fill
(
topk_idx
==
-
1
,
0
).
sum
(
dim
=
1
).
view
(
-
1
,
1
),
combined_x
)
assert
torch
.
isnan
(
combined_x
).
sum
().
item
()
==
0
assert
torch
.
isnan
(
combined_x
).
sum
().
item
()
==
0
# if not round_scale:
# if not
fp8_
round_scale:
assert
diff
<
(
9e-4
if
dispatch_use_fp8
else
1e-5
),
f
'Error: diff=
{
diff
}
, dispatch_use_fp8=
{
dispatch_use_fp8
}
, zero_copy=
{
zero_copy
}
'
assert
diff
<
(
9e-4
if
dispatch_use_fp8
else
1e-5
),
f
'Error: diff=
{
diff
}
, dispatch_use_fp8=
{
dispatch_use_fp8
}
, zero_copy=
{
zero_copy
}
'
hash_value
^=
hash_tensor
(
combined_x
)
hash_value
^=
hash_tensor
(
combined_x
)
...
@@ -162,7 +168,8 @@ def test_main(num_tokens: int,
...
@@ -162,7 +168,8 @@ def test_main(num_tokens: int,
def
test_func
(
return_recv_hook
:
bool
):
def
test_func
(
return_recv_hook
:
bool
):
recv_x
,
recv_count
,
handle
,
event
,
hook
=
\
recv_x
,
recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
use_fp8
=
True
,
async_finish
=
False
,
return_recv_hook
=
return_recv_hook
)
quant_type
=
2
,
quant_group_size
=
128
,
async_finish
=
False
,
return_recv_hook
=
return_recv_hook
)
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
topk_idx
,
topk_idx
,
...
...
tests/test_low_latency_new_int8.py
View file @
44ec8bed
...
@@ -54,16 +54,16 @@ def test_main(num_tokens: int,
...
@@ -54,16 +54,16 @@ def test_main(num_tokens: int,
do_check
=
True
do_check
=
True
hash_value
,
num_times
=
0
,
0
hash_value
,
num_times
=
0
,
0
for
current_x
in
x_list
:
for
current_x
in
x_list
:
for
return_recv_hook
in
(
False
,
):
for
return_recv_hook
in
(
False
,
True
):
for
dispatch_use_fp8
in
(
True
,
):
for
quant_type
in
(
1
,
):
for
round_scale
in
(
False
,
):
for
fp8_round_scale
in
(
False
,
):
for
use_ue8m0
in
(
False
,
):
for
quant_group_size
in
(
0
,
):
dispatch_use_fp8
=
quant_type
>
0
num_times
+=
1
num_times
+=
1
use_int8
=
True
for
_
in
range
(
1
):
for
_
in
range
(
1
):
packed_recv_x
,
packed_recv_count
,
handle
,
event
,
hook
=
\
packed_recv_x
,
packed_recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
use_fp8
=
dispatch_use_fp8
,
round_scale
=
round_scale
,
use_ue8m0
=
use_ue8m0
,
use_int8
=
use_int8
,
quant_type
=
quant_type
,
quant_group_size
=
quant_group_size
,
async_finish
=
not
return_recv_hook
,
return_recv_hook
=
return_recv_hook
)
async_finish
=
not
return_recv_hook
,
return_recv_hook
=
return_recv_hook
)
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
...
@@ -97,9 +97,7 @@ def test_main(num_tokens: int,
...
@@ -97,9 +97,7 @@ def test_main(num_tokens: int,
assert
torch
.
equal
(
recv_x_amin
,
recv_x_amax
)
assert
torch
.
equal
(
recv_x_amin
,
recv_x_amax
)
if
round_scale
:
if
quant_type
==
1
:
assert
calc_diff
(
recv_x
[:,
-
1
],
recv_src_info
.
view
(
-
1
))
<
0.007
elif
use_int8
:
assert
calc_diff
(
recv_x
[:,
-
1
],
recv_src_info
.
view
(
-
1
))
<
0.01
assert
calc_diff
(
recv_x
[:,
-
1
],
recv_src_info
.
view
(
-
1
))
<
0.01
else
:
else
:
assert
(
recv_x
[:,
-
128
:]
-
recv_src_info
.
view
(
-
1
,
1
)
%
num_tokens
).
sum
().
item
()
==
0
assert
(
recv_x
[:,
-
128
:]
-
recv_src_info
.
view
(
-
1
,
1
)
%
num_tokens
).
sum
().
item
()
==
0
...
@@ -131,7 +129,7 @@ def test_main(num_tokens: int,
...
@@ -131,7 +129,7 @@ def test_main(num_tokens: int,
def
test_func
(
return_recv_hook
:
bool
):
def
test_func
(
return_recv_hook
:
bool
):
recv_x
,
recv_count
,
handle
,
event
,
hook
=
\
recv_x
,
recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
use_fp8
=
True
,
round_scale
=
False
,
use_ue8m0
=
False
,
use_int8
=
True
,
quant_type
=
1
,
quant_group_size
=
0
,
async_finish
=
False
,
return_recv_hook
=
return_recv_hook
)
async_finish
=
False
,
return_recv_hook
=
return_recv_hook
)
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment