Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
44ec8bed
Commit
44ec8bed
authored
Feb 03, 2026
by
lishen
Browse files
支持更复杂的量化,包括fp8/int8/ue8m0,且支持per-group/per-channel
parent
81e56124
Changes
10
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
283 additions
and
160 deletions
+283
-160
csrc/config.hpp
csrc/config.hpp
+1
-1
csrc/deep_ep.cu
csrc/deep_ep.cu
+28
-14
csrc/deep_ep.hpp
csrc/deep_ep.hpp
+1
-1
csrc/kernels/api.cuh
csrc/kernels/api.cuh
+1
-1
csrc/kernels/configs.cuh
csrc/kernels/configs.cuh
+1
-1
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+159
-91
csrc/kernels/utils.cuh
csrc/kernels/utils.cuh
+30
-15
deep_ep/buffer.py
deep_ep/buffer.py
+38
-17
tests/test_low_latency_new.py
tests/test_low_latency_new.py
+16
-9
tests/test_low_latency_new_int8.py
tests/test_low_latency_new_int8.py
+8
-10
No files found.
csrc/config.hpp
View file @
44ec8bed
...
...
@@ -136,7 +136,7 @@ struct LowLatencyLayout {
LowLatencyLayout
(
void
*
rdma_buffer
,
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_ranks
,
int
num_experts
)
{
const
int
num_scales
=
hidden
/
FP8_
QUANTIZATION_
NUM_PER_CHANNEL
;
const
int
num_scales
=
hidden
/
QUANTIZATION_
GROUPSIZE
;
// Dispatch and combine layout:
// - 2 symmetric odd/even send buffer
...
...
csrc/deep_ep.cu
View file @
44ec8bed
...
...
@@ -1293,7 +1293,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
Buffer
::
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
use_int8
,
int
quant_type
,
int
quant_group_size
,
bool
fp8_round_scale
,
bool
async
,
bool
return_recv_hook
)
{
EP_HOST_ASSERT
(
low_latency_mode
);
...
...
@@ -1327,8 +1327,15 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
stream_wait
(
launch_stream
,
compute_stream
);
// Allocate packed tensors
auto
packed_recv_x
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
,
hidden
},
x
.
options
().
dtype
(
use_int8
?
torch
::
kInt8
:
use_fp8
?
torch
::
kFloat8_e4m3fnuz
:
torch
::
kBFloat16
));
auto
packed_recv_x_dtype
=
torch
::
kBFloat16
;
switch
(
quant_type
)
{
case
1
:
packed_recv_x_dtype
=
torch
::
kInt8
;
break
;
case
2
:
packed_recv_x_dtype
=
torch
::
kFloat8_e4m3fnuz
;
break
;
case
3
:
packed_recv_x_dtype
=
torch
::
kFloat8_e4m3fnuz
;
break
;
case
4
:
packed_recv_x_dtype
=
torch
::
kFloat8_e5m2fnuz
;
break
;
}
auto
packed_recv_x
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
,
hidden
},
x
.
options
().
dtype
(
packed_recv_x_dtype
));
auto
packed_recv_src_info
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
auto
packed_recv_layout_range
=
torch
::
empty
({
num_local_experts
,
num_ranks
},
torch
::
dtype
(
torch
::
kInt64
).
device
(
torch
::
kCUDA
));
auto
packed_recv_count
=
torch
::
empty
({
num_local_experts
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
...
...
@@ -1336,21 +1343,28 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
// Allocate column-majored scales
auto
packed_recv_x_scales
=
std
::
optional
<
torch
::
Tensor
>
();
void
*
packed_recv_x_scales_ptr
=
nullptr
;
if
(
use_fp8
)
{
if
(
quant_type
>
0
)
{
EP_HOST_ASSERT
((
num_ranks
*
num_max_dispatch_tokens_per_rank
)
%
4
==
0
and
"TMA requires the number of tokens to be multiple of 4"
);
// TODO: support unaligned cases
EP_HOST_ASSERT
(
hidden
%
(
FP8_QUANTIZATION_NUM_PER_CHANNEL
*
4
)
==
0
);
EP_HOST_ASSERT
(
!
(
use_ue8m0
&&
use_int8
));
EP_HOST_ASSERT
(
hidden
%
(
QUANTIZATION_GROUPSIZE
*
4
)
==
0
);
// 计算scale_col的大小
int
scales_col_size
=
1
;
// 默认为per-channel
if
(
quant_group_size
>
0
)
{
if
(
quant_type
==
3
)
{
// FP8_UE8M0比较特殊
scales_col_size
=
hidden
/
(
QUANTIZATION_GROUPSIZE
*
4
);
}
else
{
scales_col_size
=
hidden
/
QUANTIZATION_GROUPSIZE
;
}
}
if
(
use_ue8m0
)
{
EP_HOST_ASSERT
(
round_scale
);
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
hidden
/
(
FP8_QUANTIZATION_NUM_PER_CHANNEL
*
4
),
num_ranks
*
num_max_dispatch_tokens_per_rank
},
// 设置packed_recv_x_scales
if
(
quant_type
==
3
)
{
// FP8_UE8M0比较特殊,需要单独处理
EP_HOST_ASSERT
(
fp8_round_scale
&&
quant_group_size
==
128
);
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
scales_col_size
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kInt
).
device
(
torch
::
kCUDA
));
}
else
if
(
use_int8
)
{
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
1
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
));
}
else
{
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
hidden
/
FP8_QUANTIZATION_NUM_PER_CHANNEL
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
scales_col_size
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
));
}
packed_recv_x_scales
=
torch
::
transpose
(
packed_recv_x_scales
.
value
(),
1
,
2
);
...
...
@@ -1370,7 +1384,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
next_clean_meta
.
first
,
next_clean_meta
.
second
,
num_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
use_fp8
,
round_scale
,
use_ue8m0
,
use_int8
,
quant_type
,
quant_group_size
,
fp8_round_scale
,
workspace
,
num_device_sms
,
launch_stream
,
phases
);
};
launcher
(
return_recv_hook
?
LOW_LATENCY_SEND_PHASE
:
(
LOW_LATENCY_SEND_PHASE
|
LOW_LATENCY_RECV_PHASE
));
...
...
csrc/deep_ep.hpp
View file @
44ec8bed
...
...
@@ -177,7 +177,7 @@ public:
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
use_int8
,
int
quant_type
,
int
quant_group_size
,
bool
fp8_round_scale
,
bool
async
,
bool
return_recv_hook
);
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
...
...
csrc/kernels/api.cuh
View file @
44ec8bed
...
...
@@ -147,7 +147,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
use_int8
,
int
quant_type
,
int
group_size
,
bool
fp8_round_scale
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
);
void
combine
(
void
*
combined_x
,
...
...
csrc/kernels/configs.cuh
View file @
44ec8bed
...
...
@@ -23,7 +23,7 @@
#define LOW_LATENCY_SEND_PHASE 1
#define LOW_LATENCY_RECV_PHASE 2
#define NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL 3
#define
FP8_
QUANTIZATION_
NUM_PER_CHANNEL
128
#define QUANTIZATION_
GROUPSIZE
128
#define DEFAULT_NUM_CU 20
#define DEFAULT_NUM_MAX_XGMI_CHUNKED_SEND_TOKENS 6
...
...
csrc/kernels/internode_ll.cu
View file @
44ec8bed
This diff is collapsed.
Click to expand it.
csrc/kernels/utils.cuh
View file @
44ec8bed
...
...
@@ -341,10 +341,14 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
return
*
reinterpret_cast
<
dtype_t
*>
(
recv_int_values
);
}
constexpr
float
kFP8Margin
=
1e-4
;
// 设置不同的量化方式的最大值与相反数
constexpr
float
kFP8Margin
=
0.0
;
constexpr
float
kFinfoAmaxE4M3
=
240.0
f
;
constexpr
float
kFinfoAmaxInvE4M3
=
1.0
f
/
kFinfoAmaxE4M3
;
constexpr
float
kInt8Amax
=
127.0
f
;
constexpr
float
kFinfoAmaxE5M2
=
57344.0
f
;
constexpr
float
kFinfoAmaxInvE5M2
=
1.0
f
/
kFinfoAmaxE5M2
;
constexpr
float
kFinfoAmaxInt8
=
127.0
f
;
constexpr
float
kFinfoAmaxInvInt8
=
1.0
f
/
127.0
f
;
__forceinline__
__device__
float
fast_pow2
(
int
x
)
{
// We can ensure `-126 <= x and x <= 127`
...
...
@@ -359,22 +363,33 @@ __forceinline__ __device__ int fast_log2_ceil(float x) {
return
exp_x
-
127
+
(
man_bits
!=
0
);
}
__forceinline__
__device__
void
calculate_fp8_scales
(
float
amax
,
float
&
scale
,
float
&
scale_inv
,
bool
round_scale
)
{
if
(
round_scale
)
{
auto
exp_scale_inv
=
fast_log2_ceil
(
amax
*
kFinfoAmaxInvE4M3
);
scale
=
fast_pow2
(
-
exp_scale_inv
);
scale_inv
=
fast_pow2
(
exp_scale_inv
);
}
else
{
scale_inv
=
amax
*
kFinfoAmaxInvE4M3
;
scale
=
kFinfoAmaxE4M3
/
amax
;
template
<
int
kQuantType
>
__forceinline__
__device__
void
calculate_quant8bit_scales
(
float
amax
,
float
&
scale
,
float
&
scale_inv
,
bool
round_scale
=
0
)
{
amax
=
fmaxf
(
amax
,
1e-6
f
);
if
constexpr
(
kQuantType
==
1
)
{
// 使用 INT8 对称量化
scale_inv
=
kFinfoAmaxInvInt8
*
amax
;
scale
=
kFinfoAmaxInt8
/
amax
;
}
else
if
constexpr
(
kQuantType
==
2
||
kQuantType
==
3
)
{
// 使用 FP8_E4M3 或 FP8_UE8M0 非对称量化
if
(
round_scale
)
{
auto
exp_scale_inv
=
fast_log2_ceil
(
amax
*
kFinfoAmaxInvE4M3
);
scale
=
fast_pow2
(
-
exp_scale_inv
);
scale_inv
=
fast_pow2
(
exp_scale_inv
);
}
else
{
scale_inv
=
amax
*
kFinfoAmaxInvE4M3
;
scale
=
kFinfoAmaxE4M3
/
amax
;
}
}
else
if
constexpr
(
kQuantType
==
4
)
{
// 使用 FP8_E5M2 对称量化
if
(
round_scale
)
{
auto
exp_scale_inv
=
fast_log2_ceil
(
amax
*
kFinfoAmaxInvE5M2
);
scale
=
fast_pow2
(
-
exp_scale_inv
);
scale_inv
=
fast_pow2
(
exp_scale_inv
);
}
else
{
scale_inv
=
amax
*
kFinfoAmaxInvE5M2
;
scale
=
kFinfoAmaxE5M2
/
amax
;
}
}
}
__forceinline__
__device__
void
calculate_int8_scales
(
float
amax
,
float
&
scale
,
float
&
scale_inv
)
{
scale
=
kInt8Amax
/
amax
;
scale_inv
=
amax
/
kInt8Amax
;
}
template
<
bool
kIsUE8M0
,
typename
out_dtype_t
=
std
::
conditional_t
<
kIsUE8M0
,
uint8_t
,
float
>
>
__forceinline__
__device__
out_dtype_t
extract_required_scale_format
(
float
value
)
{
if
constexpr
(
kIsUE8M0
)
{
...
...
deep_ep/buffer.py
View file @
44ec8bed
...
...
@@ -841,7 +841,7 @@ class Buffer:
# noinspection PyTypeChecker
def
low_latency_dispatch
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
num_max_dispatch_tokens_per_rank
:
int
,
num_experts
:
int
,
use_fp8
:
bool
=
True
,
round_scale
:
bool
=
False
,
use_ue8m0
:
bool
=
False
,
use_int8
:
bool
=
False
,
quant_type
:
int
=
1
,
quant_group_size
:
int
=
0
,
fp8_round_scale
:
bool
=
False
,
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
)
->
\
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
torch
.
Tensor
,
Tuple
,
EventOverlap
,
Callable
]:
"""
...
...
@@ -858,10 +858,20 @@ class Buffer:
only several top-k shapes are supported. `-1` indices (not selecting any expert) are supported.
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
num_experts: the number of all experts.
use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
round_scale: whether round the scaling factors into power of 2.
use_ue8m0: whether use UE8M0 as scaling factor format (available only with `round_scale=True`).
use_int8: whether to enable INT8 casting.
量化配置
quant_type: int 量化类型枚举
0 -> None 不量化,保持原始精度
1 -> Int8 使用 INT8 对称量化
2 -> FP8_E4M3 使用 FP8 E4M3 格式 (__HIP_E4M3_FNUZ)
3 -> FP8_UE8M0 使用 DeepSeekV3.1 提出的 UE8M0 格式 (仅支持round_scale=True)
4 -> FP8_E5M2 使用 FP8 E5M2 格式 (__HIP_E5M2_FNUZ)
quant_group_size: int 量化分组大小
0 -> 逐token量化 (per-channel)
128-> 每 128 元素一组 (per-group) 量化
fp8_round_scale: bool 是否将 FP8 缩放因子取整为 2 的幂
true -> 缩放因子 = 2^k,硬件零开销
false -> 缩放因子 = 任意浮点,精度更高
异步配置
async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
...
...
@@ -869,15 +879,25 @@ class Buffer:
Returns:
recv_x: a tensor or tuple with received tokens for each expert.
With `use_fp8=True`: the first element is a `torch.Tensor` shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`.
The second tensor is the corresponding scales for the first element with shape
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`,
if `use_ue8m0=False`. With `use_ue8m0=True`, the second one is packed and shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 512]` with type `torch.int`.
Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility.
With `use_fp8=False`, the result would be a tensor shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`.
- packed_recv_x:
存储接收到的 Token 数据,形状为
`[num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden]`。
数据类型取决于 quant_type:
quant_type == 1 -> torch.int8
quant_type == 2 -> torch.float8_e4m3fnuz
quant_type == 3 -> torch.float8_e4m3fnuz (UE8M0 使用 E4M3 格式存储)
quant_type == 4 -> torch.float8_e5m2fnuz
其他 (非量化) -> torch.bfloat16
- packed_recv_x_scales (可选):
仅在 quant_type > 0 时存在,存储量化的 Scale 值。
形状为 `[num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, scales_col_size]`。
- 当 quant_type == 3 (UE8M0) 时:
scales_col_size = hidden // 512
数据类型为 torch.int (内部打包存储 4-bit scale)。
*注意:此模式强制要求 fp8_round_scale=True 且 group_size=128。
- 当 quant_type == 1, 2, 4 时:
scales_col_size = hidden // 128 (若使用 group_size) 或 1 (per-channel)。
数据类型为 torch.float32。
Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
as we do not synchronize CPU received count with GPU (also not incompatible with CUDA graph if synced).
recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
...
...
@@ -889,14 +909,15 @@ class Buffer:
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_src_info
,
packed_recv_layout_range
,
event
,
hook
=
\
self
.
runtime
.
low_latency_dispatch
(
x
,
topk_idx
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
use_fp8
,
round_scale
,
use_ue8m0
,
use_int8
,
quant_type
,
quant_group_size
,
fp8_round_scale
,
async_finish
,
return_recv_hook
)
handle
=
(
packed_recv_src_info
,
packed_recv_layout_range
,
num_max_dispatch_tokens_per_rank
,
x
.
size
(
1
),
num_experts
)
tensors_to_record
=
(
x
,
topk_idx
,
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_src_info
,
packed_recv_layout_range
)
return
(
packed_recv_x
,
packed_recv_x_scales
)
if
use_fp8
else
packed_recv_x
,
packed_recv_count
,
handle
,
\
EventOverlap
(
event
,
tensors_to_record
if
async_finish
else
None
),
hook
recv_x
=
(
packed_recv_x
,
packed_recv_x_scales
)
if
(
quant_type
>
0
)
else
packed_recv_x
return
recv_x
,
packed_recv_count
,
handle
,
EventOverlap
(
event
,
tensors_to_record
if
async_finish
else
None
),
hook
# noinspection PyTypeChecker
def
low_latency_combine
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
...
...
tests/test_low_latency_new.py
View file @
44ec8bed
...
...
@@ -58,7 +58,8 @@ def test_main(num_tokens: int,
x_list
=
[
x
]
# # NOTES: the last one is for performance testing
# # Most of the values in the perf case is lower than the threshold, casting most channels
# x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1)
# x_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1
# x_list = [x_rand]
scores
=
torch
.
randn
((
num_tokens
,
num_experts
),
dtype
=
torch
.
float32
,
device
=
'cuda'
).
abs
()
+
1
topk_idx
=
torch
.
topk
(
scores
,
num_topk
,
dim
=-
1
,
largest
=
True
,
sorted
=
True
)[
1
]
...
...
@@ -80,14 +81,19 @@ def test_main(num_tokens: int,
hash_value
,
num_times
=
0
,
0
for
current_x
in
x_list
:
for
return_recv_hook
in
(
False
,
True
):
for
dispatch_use_fp8
in
(
False
,
True
):
for
round_scale
in
(
False
,
True
)
if
dispatch_use_fp8
else
(
False
,):
for
use_ue8m0
in
(
False
,
True
)
if
round_scale
else
(
False
,):
for
quant_type
in
(
0
,
2
,
3
,
):
# 0: 不量化, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True)
dispatch_use_fp8
=
quant_type
>
0
for
fp8_round_scale
in
(
False
,
True
)
if
dispatch_use_fp8
else
(
False
,
):
for
quant_group_size
in
(
128
,
):
# 跳过不支持的情况
if
quant_type
==
3
and
fp8_round_scale
==
False
:
continue
num_times
+=
1
for
_
in
range
((
num_times
%
2
)
+
1
):
packed_recv_x
,
packed_recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
use_fp8
=
dispatch_use_
fp8
,
round_scale
=
round_scale
,
use_ue8m0
=
use_ue8m0
,
quant_type
=
quant_type
,
fp8
_
round_scale
=
fp8_
round_scale
,
quant_group_size
=
quant_group_size
,
async_finish
=
not
return_recv_hook
,
return_recv_hook
=
return_recv_hook
)
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
packed_recv_x
=
(
packed_recv_x
[
0
],
packed_recv_x
[
1
].
contiguous
())
if
dispatch_use_fp8
else
packed_recv_x
...
...
@@ -115,13 +121,13 @@ def test_main(num_tokens: int,
recv_x_amin
=
recv_x
[:,
:
-
128
].
amin
(
dim
=-
1
)
recv_src_info
=
recv_src_info
[:
num_valid_tokens
]
assert
torch
.
equal
(
recv_x_amin
,
recv_x
[:,
:
-
128
].
amax
(
dim
=-
1
))
if
round_scale
:
if
fp8_
round_scale
:
assert
calc_diff
(
recv_x
[:,
-
1
],
recv_src_info
.
view
(
-
1
))
<
0.007
else
:
assert
(
recv_x
[:,
-
128
:]
-
recv_src_info
.
view
(
-
1
,
1
)
%
num_tokens
).
sum
().
item
()
==
0
for
j
in
range
(
num_ranks
):
begin_idx
,
count
=
(
recv_layout_range
[
j
]
>>
32
).
item
(),
(
recv_layout_range
[
j
]
&
int_mask
).
item
()
if
not
round_scale
:
if
not
fp8_
round_scale
:
assert
(
recv_x_amin
==
j
-
rank_offset
).
sum
().
item
()
==
(
all_topk_idx
[
j
]
==
expert_id
).
sum
().
item
()
assert
(
recv_x
[
begin_idx
:
begin_idx
+
count
,
:
-
128
]
-
j
+
rank_offset
).
sum
().
item
()
==
0
if
dispatch_use_fp8
:
...
...
@@ -147,7 +153,7 @@ def test_main(num_tokens: int,
if
do_check
:
diff
=
calc_diff
(
current_x
*
topk_weights
.
masked_fill
(
topk_idx
==
-
1
,
0
).
sum
(
dim
=
1
).
view
(
-
1
,
1
),
combined_x
)
assert
torch
.
isnan
(
combined_x
).
sum
().
item
()
==
0
# if not round_scale:
# if not
fp8_
round_scale:
assert
diff
<
(
9e-4
if
dispatch_use_fp8
else
1e-5
),
f
'Error: diff=
{
diff
}
, dispatch_use_fp8=
{
dispatch_use_fp8
}
, zero_copy=
{
zero_copy
}
'
hash_value
^=
hash_tensor
(
combined_x
)
...
...
@@ -162,7 +168,8 @@ def test_main(num_tokens: int,
def
test_func
(
return_recv_hook
:
bool
):
recv_x
,
recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
use_fp8
=
True
,
async_finish
=
False
,
return_recv_hook
=
return_recv_hook
)
quant_type
=
2
,
quant_group_size
=
128
,
async_finish
=
False
,
return_recv_hook
=
return_recv_hook
)
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
topk_idx
,
...
...
tests/test_low_latency_new_int8.py
View file @
44ec8bed
...
...
@@ -54,16 +54,16 @@ def test_main(num_tokens: int,
do_check
=
True
hash_value
,
num_times
=
0
,
0
for
current_x
in
x_list
:
for
return_recv_hook
in
(
False
,
):
for
dispatch_use_fp8
in
(
True
,
):
for
round_scale
in
(
False
,
):
for
use_ue8m0
in
(
False
,
):
for
return_recv_hook
in
(
False
,
True
):
for
quant_type
in
(
1
,
):
for
fp8_round_scale
in
(
False
,
):
for
quant_group_size
in
(
0
,
):
dispatch_use_fp8
=
quant_type
>
0
num_times
+=
1
use_int8
=
True
for
_
in
range
(
1
):
packed_recv_x
,
packed_recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
use_fp8
=
dispatch_use_fp8
,
round_scale
=
round_scale
,
use_ue8m0
=
use_ue8m0
,
use_int8
=
use_int8
,
quant_type
=
quant_type
,
quant_group_size
=
quant_group_size
,
async_finish
=
not
return_recv_hook
,
return_recv_hook
=
return_recv_hook
)
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
...
...
@@ -97,9 +97,7 @@ def test_main(num_tokens: int,
assert
torch
.
equal
(
recv_x_amin
,
recv_x_amax
)
if
round_scale
:
assert
calc_diff
(
recv_x
[:,
-
1
],
recv_src_info
.
view
(
-
1
))
<
0.007
elif
use_int8
:
if
quant_type
==
1
:
assert
calc_diff
(
recv_x
[:,
-
1
],
recv_src_info
.
view
(
-
1
))
<
0.01
else
:
assert
(
recv_x
[:,
-
128
:]
-
recv_src_info
.
view
(
-
1
,
1
)
%
num_tokens
).
sum
().
item
()
==
0
...
...
@@ -131,7 +129,7 @@ def test_main(num_tokens: int,
def
test_func
(
return_recv_hook
:
bool
):
recv_x
,
recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
use_fp8
=
True
,
round_scale
=
False
,
use_ue8m0
=
False
,
use_int8
=
True
,
quant_type
=
1
,
quant_group_size
=
0
,
async_finish
=
False
,
return_recv_hook
=
return_recv_hook
)
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment