Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
44ec8bed
Commit
44ec8bed
authored
Feb 03, 2026
by
lishen
Browse files
支持更复杂的量化,包括fp8/int8/ue8m0,且支持per-group/per-channel
parent
81e56124
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
283 additions
and
160 deletions
+283
-160
csrc/config.hpp
csrc/config.hpp
+1
-1
csrc/deep_ep.cu
csrc/deep_ep.cu
+28
-14
csrc/deep_ep.hpp
csrc/deep_ep.hpp
+1
-1
csrc/kernels/api.cuh
csrc/kernels/api.cuh
+1
-1
csrc/kernels/configs.cuh
csrc/kernels/configs.cuh
+1
-1
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+159
-91
csrc/kernels/utils.cuh
csrc/kernels/utils.cuh
+30
-15
deep_ep/buffer.py
deep_ep/buffer.py
+38
-17
tests/test_low_latency_new.py
tests/test_low_latency_new.py
+16
-9
tests/test_low_latency_new_int8.py
tests/test_low_latency_new_int8.py
+8
-10
No files found.
csrc/config.hpp
View file @
44ec8bed
...
...
@@ -136,7 +136,7 @@ struct LowLatencyLayout {
LowLatencyLayout
(
void
*
rdma_buffer
,
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_ranks
,
int
num_experts
)
{
const
int
num_scales
=
hidden
/
FP8_
QUANTIZATION_
NUM_PER_CHANNEL
;
const
int
num_scales
=
hidden
/
QUANTIZATION_
GROUPSIZE
;
// Dispatch and combine layout:
// - 2 symmetric odd/even send buffer
...
...
csrc/deep_ep.cu
View file @
44ec8bed
...
...
@@ -1293,7 +1293,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
Buffer
::
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
use_int8
,
int
quant_type
,
int
quant_group_size
,
bool
fp8_round_scale
,
bool
async
,
bool
return_recv_hook
)
{
EP_HOST_ASSERT
(
low_latency_mode
);
...
...
@@ -1327,8 +1327,15 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
stream_wait
(
launch_stream
,
compute_stream
);
// Allocate packed tensors
auto
packed_recv_x
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
,
hidden
},
x
.
options
().
dtype
(
use_int8
?
torch
::
kInt8
:
use_fp8
?
torch
::
kFloat8_e4m3fnuz
:
torch
::
kBFloat16
));
auto
packed_recv_x_dtype
=
torch
::
kBFloat16
;
switch
(
quant_type
)
{
case
1
:
packed_recv_x_dtype
=
torch
::
kInt8
;
break
;
case
2
:
packed_recv_x_dtype
=
torch
::
kFloat8_e4m3fnuz
;
break
;
case
3
:
packed_recv_x_dtype
=
torch
::
kFloat8_e4m3fnuz
;
break
;
case
4
:
packed_recv_x_dtype
=
torch
::
kFloat8_e5m2fnuz
;
break
;
}
auto
packed_recv_x
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
,
hidden
},
x
.
options
().
dtype
(
packed_recv_x_dtype
));
auto
packed_recv_src_info
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
auto
packed_recv_layout_range
=
torch
::
empty
({
num_local_experts
,
num_ranks
},
torch
::
dtype
(
torch
::
kInt64
).
device
(
torch
::
kCUDA
));
auto
packed_recv_count
=
torch
::
empty
({
num_local_experts
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
...
...
@@ -1336,21 +1343,28 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
// Allocate column-majored scales
auto
packed_recv_x_scales
=
std
::
optional
<
torch
::
Tensor
>
();
void
*
packed_recv_x_scales_ptr
=
nullptr
;
if
(
use_fp8
)
{
if
(
quant_type
>
0
)
{
EP_HOST_ASSERT
((
num_ranks
*
num_max_dispatch_tokens_per_rank
)
%
4
==
0
and
"TMA requires the number of tokens to be multiple of 4"
);
// TODO: support unaligned cases
EP_HOST_ASSERT
(
hidden
%
(
FP8_QUANTIZATION_NUM_PER_CHANNEL
*
4
)
==
0
);
EP_HOST_ASSERT
(
!
(
use_ue8m0
&&
use_int8
));
EP_HOST_ASSERT
(
hidden
%
(
QUANTIZATION_GROUPSIZE
*
4
)
==
0
);
// 计算scale_col的大小
int
scales_col_size
=
1
;
// 默认为per-channel
if
(
quant_group_size
>
0
)
{
if
(
quant_type
==
3
)
{
// FP8_UE8M0比较特殊
scales_col_size
=
hidden
/
(
QUANTIZATION_GROUPSIZE
*
4
);
}
else
{
scales_col_size
=
hidden
/
QUANTIZATION_GROUPSIZE
;
}
}
if
(
use_ue8m0
)
{
EP_HOST_ASSERT
(
round_scale
);
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
hidden
/
(
FP8_QUANTIZATION_NUM_PER_CHANNEL
*
4
),
num_ranks
*
num_max_dispatch_tokens_per_rank
},
// 设置packed_recv_x_scales
if
(
quant_type
==
3
)
{
// FP8_UE8M0比较特殊,需要单独处理
EP_HOST_ASSERT
(
fp8_round_scale
&&
quant_group_size
==
128
);
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
scales_col_size
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kInt
).
device
(
torch
::
kCUDA
));
}
else
if
(
use_int8
)
{
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
1
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
));
}
else
{
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
hidden
/
FP8_QUANTIZATION_NUM_PER_CHANNEL
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
scales_col_size
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
));
}
packed_recv_x_scales
=
torch
::
transpose
(
packed_recv_x_scales
.
value
(),
1
,
2
);
...
...
@@ -1370,7 +1384,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
next_clean_meta
.
first
,
next_clean_meta
.
second
,
num_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
use_fp8
,
round_scale
,
use_ue8m0
,
use_int8
,
quant_type
,
quant_group_size
,
fp8_round_scale
,
workspace
,
num_device_sms
,
launch_stream
,
phases
);
};
launcher
(
return_recv_hook
?
LOW_LATENCY_SEND_PHASE
:
(
LOW_LATENCY_SEND_PHASE
|
LOW_LATENCY_RECV_PHASE
));
...
...
csrc/deep_ep.hpp
View file @
44ec8bed
...
...
@@ -177,7 +177,7 @@ public:
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
use_int8
,
int
quant_type
,
int
quant_group_size
,
bool
fp8_round_scale
,
bool
async
,
bool
return_recv_hook
);
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
...
...
csrc/kernels/api.cuh
View file @
44ec8bed
...
...
@@ -147,7 +147,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
use_int8
,
int
quant_type
,
int
group_size
,
bool
fp8_round_scale
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
);
void
combine
(
void
*
combined_x
,
...
...
csrc/kernels/configs.cuh
View file @
44ec8bed
...
...
@@ -23,7 +23,7 @@
#define LOW_LATENCY_SEND_PHASE 1
#define LOW_LATENCY_RECV_PHASE 2
#define NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL 3
#define
FP8_
QUANTIZATION_
NUM_PER_CHANNEL
128
#define QUANTIZATION_
GROUPSIZE
128
#define DEFAULT_NUM_CU 20
#define DEFAULT_NUM_MAX_XGMI_CHUNKED_SEND_TOKENS 6
...
...
csrc/kernels/internode_ll.cu
View file @
44ec8bed
...
...
@@ -116,20 +116,74 @@ internode_ll_long_atomic_add(long* dest, const long &value,
#endif // defined(FORCE_DUSHMEM_API)
}
template
<
bool
kUseFP8
,
bool
kUseUE8M0
,
bool
kUseInt8
,
int
kHidden
>
/**
* @brief 将 K 个浮点数(BF16/FP32)量化并打包成 INT2(64位)存储
*
* @tparam kQuantType 量化类型 (1: Int8, 2/3: FP8_E4M3/UE8M0, 4: FP8_E5M2)
* @tparam kNumElemsPerRead 每次读取的元素数量 (通常为 2, 4, 8)
* @tparam SrcT 源数据类型 (float 或 __hip_bfloat16)
* @tparam DstT 目标数据类型 (int2 或 int4)
* @param src_values 源数据数组 (长度 >= kNumElemsPerRead)
* @param scale 缩放因子 (将 FP32 值映射到量化范围)
* @param[out] dst_vec 输出的 64 位向量 (int2 或 int4)
*/
template
<
int
kQuantType
,
int
kNumElemsPerRead
,
typename
SrcT
,
typename
DstT
>
__forceinline__
__device__
void
pack_quantized_values
(
const
SrcT
*
src_values
,
float
scale
,
DstT
&
dst_vec
)
{
if
constexpr
(
kQuantType
==
1
)
{
// INT8 量化
auto
int8_ptr
=
reinterpret_cast
<
int8_t
*>
(
&
dst_vec
);
#pragma unroll
for
(
int
j
=
0
;
j
<
kNumElemsPerRead
;
++
j
)
{
// 如果源是 bfloat16,先提升为 float
float
fp32_value_scaled
=
static_cast
<
float
>
(
src_values
[
j
])
*
scale
;
// 使用 nearbyintf 进行四舍五入
int8_ptr
[
j
]
=
static_cast
<
int8_t
>
(
nearbyintf
(
fp32_value_scaled
));
}
}
else
{
// FP8 量化 (E4M3, UE8M0, E5M2)
// 假设 dst_vec 能容纳 kNumElemsPerRead/2 个 fp8x2 元素
auto
fp8x2_ptr
=
reinterpret_cast
<
__hip_fp8x2_storage_t
*>
(
&
dst_vec
);
#pragma unroll
for
(
int
j
=
0
;
j
<
kNumElemsPerRead
;
j
+=
2
)
{
// 处理两个元素
float2
fp32x2
=
{
static_cast
<
float
>
(
src_values
[
j
])
*
scale
,
static_cast
<
float
>
(
src_values
[
j
+
1
])
*
scale
};
if
constexpr
(
kQuantType
==
4
)
{
// FP8 E5M2
fp8x2_ptr
[
j
/
2
]
=
__hip_cvt_float2_to_fp8x2
(
fp32x2
,
__HIP_SATFINITE
,
__HIP_E5M2_FNUZ
);
}
else
{
// FP8 E4M3 或 UE8M0
fp8x2_ptr
[
j
/
2
]
=
__hip_cvt_float2_to_fp8x2
(
fp32x2
,
__HIP_SATFINITE
,
__HIP_E4M3_FNUZ
);
}
}
}
}
template
<
int
kHidden
,
int
kQuantType
=
0
,
int
kQuantGroupSize
=
0
,
int
kMaxNumWarps
=
16
>
__global__
__launch_bounds__
(
16
*
kWarpSize
,
1
)
void
dispatch
(
void
*
packed_recv_x
,
void
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_count
,
int
*
global_atomic_counter
,
void
*
rdma_recv_x
,
int64_t
*
rdma_recv_count
,
void
*
rdma_x
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
int
*
atomic_counter_per_expert
,
int
*
atomic_finish_counter_per_expert
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_warp_groups
,
int
num_warps_per_group
,
bool
round_scale
,
int
phases
)
{
dispatch
(
void
*
packed_recv_x
,
void
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_count
,
int
*
global_atomic_counter
,
void
*
rdma_recv_x
,
int64_t
*
rdma_recv_count
,
void
*
rdma_x
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
int
*
atomic_counter_per_expert
,
int
*
atomic_finish_counter_per_expert
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_warp_groups
,
int
num_warps_per_group
,
bool
fp8_round_scale
,
int
phases
)
{
// 定义量化类型的枚举
enum
class
QuantType
{
None
=
0
,
// 不进行量化
Int8
=
1
,
// 采用 Int8 量化
FP8_E4M3
=
2
,
// 采用 FP8 量化 __HIP_E4M3_FNUZ
FP8_UE8M0
=
3
,
// 采用 FP8 量化 DeepseekV3.1的 UE8M0
FP8_E5M2
=
4
// 采用 FP8 量化 __HIP_E5M2_FNUZ
};
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
);
const
auto
warp_id
=
thread_id
/
kWarpSize
,
lane_id
=
get_lane_id
();
...
...
@@ -141,20 +195,22 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
const
auto
responsible_expert_idx
=
sm_id
*
num_warp_groups
+
warp_group_id
;
// May extract UE8M0 from the scales
constexpr
bool
kUseQuant8Bit
=
kQuantType
>
0
;
constexpr
bool
kUseUE8M0
=
kQuantType
==
3
;
// QuantType::FP8_UE8M0
using
scale_t
=
std
::
conditional_t
<
kUseUE8M0
,
uint8_t
,
float
>
;
using
packed_t
=
std
::
conditional_t
<
kUseUE8M0
,
uint32_t
,
float
>
;
EP_STATIC_ASSERT
(
sizeof
(
packed_t
)
%
sizeof
(
scale_t
)
==
0
,
"Invalid vector length"
);
// FP8 staffs
constexpr
int
kNumPerChannels
=
FP8_
QUANTIZATION_
NUM_PER_CHANNEL
;
constexpr
int
kNumPerChannels
=
QUANTIZATION_
GROUPSIZE
;
constexpr
int
kNumScales
=
kHidden
/
kNumPerChannels
;
const
size_t
hidden_bytes
=
kHidden
*
(
kUse
FP8
?
sizeof
(
__hip_fp8_storage_t
)
:
sizeof
(
hip_bfloat16
));
const
size_t
hidden_bytes
=
kHidden
*
(
kUse
Quant8Bit
?
sizeof
(
__hip_fp8_storage_t
)
:
sizeof
(
hip_bfloat16
));
const
size_t
hidden_int4
=
hidden_bytes
/
sizeof
(
int4
);
// Message package: hidden data, FP8 scales, index at source
// NOTES: currently we have 3 reserved int fields for future use
using
vec_t
=
typename
std
::
conditional
<
kUse
FP8
,
int2
,
int4
>::
type
;
constexpr
size_t
num_bytes_per_msg
=
sizeof
(
int4
)
+
(
kUse
FP8
?
(
kHidden
+
kNumScales
*
sizeof
(
float
))
:
(
kHidden
*
sizeof
(
hip_bfloat16
)));
using
vec_t
=
typename
std
::
conditional
<
kUse
Quant8Bit
,
int2
,
int4
>::
type
;
constexpr
size_t
num_bytes_per_msg
=
sizeof
(
int4
)
+
(
kUse
Quant8Bit
?
(
kHidden
+
kNumScales
*
sizeof
(
float
))
:
(
kHidden
*
sizeof
(
hip_bfloat16
)));
EP_STATIC_ASSERT
(
num_bytes_per_msg
%
sizeof
(
int4
)
==
0
,
"Invalid message size"
);
constexpr
size_t
num_int4_per_msg
=
num_bytes_per_msg
/
sizeof
(
int4
);
...
...
@@ -171,6 +227,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// 2. The last warp for reading `topk_idx` and count for per-expert information
if
(
warp_id
<
num_warps
)
{
constexpr
int
kNumElemsPerRead
=
sizeof
(
int4
)
/
sizeof
(
hip_bfloat16
);
constexpr
int
kNumThreadPerGroup
=
QUANTIZATION_GROUPSIZE
/
kNumElemsPerRead
;
// EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0);
EP_STATIC_ASSERT
(
kNumElemsPerRead
*
kWarpSize
%
kNumPerChannels
==
0
,
"Invalid vectorization"
);
const
auto
num_threads
=
(
num_warps
-
1
)
*
kWarpSize
;
...
...
@@ -186,10 +243,11 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
auto
dst_expert_idx
=
warp_id
<
num_topk
?
static_cast
<
int
>
(
__ldg
(
topk_idx
+
token_idx
*
num_topk
+
warp_id
))
:
-
1
;
thread_id
==
0
?
(
*
rdma_x_src_idx
=
token_idx
)
:
0
;
__shared__
float
int8_amaxf
[
kNumScales
];
if
constexpr
(
kUseInt8
)
{
// 用于记录per-channel量化的amax
__shared__
float
channel_amaxf
[
kNumScales
];
if
constexpr
(
kUseQuant8Bit
&&
kQuantGroupSize
==
0
)
{
if
(
thread_id
<
kNumScales
)
{
int8
_amaxf
[
thread_id
]
=
kFP8Margin
;
channel
_amaxf
[
thread_id
]
=
kFP8Margin
;
}
__syncthreads
();
}
...
...
@@ -200,7 +258,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// Read
auto
int4_value
=
__ldg
(
x_int4
+
i
);
if
constexpr
(
kUse
FP8
)
{
if
constexpr
(
kUse
Quant8Bit
)
{
// Calculate local amax
auto
bf16_values
=
reinterpret_cast
<
hip_bfloat16
*>
(
&
int4_value
);
float
fp32_values
[
kNumElemsPerRead
];
...
...
@@ -212,25 +270,20 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
}
// Reduce amax and scale
EP_STATIC_ASSERT
(
kNumElemsPerRead
*
kWarpSize
/
kNumPerChannels
==
4
,
"Invalid vectorization"
);
amax
=
warp_reduce_max
<
16
>
(
amax
);
const
int
scale_offset
=
i
*
kNumElemsPerRead
/
FP8_
QUANTIZATION_
NUM_PER_CHANNEL
;
amax
=
warp_reduce_max
<
kNumThreadPerGroup
>
(
amax
);
const
int
scale_offset
=
i
*
kNumElemsPerRead
/
QUANTIZATION_
GROUPSIZE
;
if
constexpr
(
k
UseInt8
)
{
if
constexpr
(
k
QuantGroupSize
==
0
)
{
// 记录每128个数的最大值
int8
_amaxf
[
scale_offset
]
=
fmaxf
(
amax
,
int8
_amaxf
[
scale_offset
]);
channel
_amaxf
[
scale_offset
]
=
fmaxf
(
amax
,
channel
_amaxf
[
scale_offset
]);
}
else
{
calculate_
fp8_scales
(
amax
,
scale
,
scale_inv
,
round_scale
);
if
(
lane_id
%
16
==
0
)
calculate_
quant8bit_scales
<
kQuantType
>
(
amax
,
scale
,
scale_inv
,
fp8_
round_scale
);
if
(
lane_id
%
kNumThreadPerGroup
==
0
)
rdma_x_scales
[
scale_offset
]
=
scale_inv
;
// Cast into send buffer
vec_t
int2_value
;
auto
fp8x2_values
=
reinterpret_cast
<
__hip_fp8x2_storage_t
*>
(
&
int2_value
);
#pragma unroll
for
(
int
j
=
0
;
j
<
kNumElemsPerRead
;
j
+=
2
)
{
float2
fp32x2
=
{
fp32_values
[
j
]
*
scale
,
fp32_values
[
j
+
1
]
*
scale
};
fp8x2_values
[
j
/
2
]
=
__hip_cvt_float2_to_fp8x2
(
fp32x2
,
__HIP_SATFINITE
,
__HIP_E4M3_FNUZ
);
}
pack_quantized_values
<
kQuantType
,
kNumElemsPerRead
>
(
fp32_values
,
scale
,
int2_value
);
rdma_x_vec
[
i
]
=
int2_value
;
}
}
else
{
...
...
@@ -240,24 +293,24 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
}
__syncthreads
();
if
constexpr
(
kUse
Int8
)
{
if
constexpr
(
kUse
Quant8Bit
&&
kQuantGroupSize
==
0
)
{
float
amax_per_token
=
kFP8Margin
;
// 并行规约,计算每个token的amax
for
(
int
s
=
0
;
s
<
kNumScales
;
s
+=
kWarpSize
)
{
int
src_idx
=
s
+
lane_id
;
float
tmp_amaxf
=
0
;
if
(
src_idx
<
kNumScales
)
{
tmp_amaxf
=
int8
_amaxf
[
src_idx
];
tmp_amaxf
=
channel
_amaxf
[
src_idx
];
}
tmp_amaxf
=
warp_reduce_max
<
kWarpSize
>
(
tmp_amaxf
);
int8
_amaxf
[
0
]
=
fmaxf
(
tmp_amaxf
,
int8
_amaxf
[
0
]);
channel
_amaxf
[
0
]
=
fmaxf
(
tmp_amaxf
,
channel
_amaxf
[
0
]);
__syncthreads
();
}
amax_per_token
=
int8
_amaxf
[
0
];
amax_per_token
=
channel
_amaxf
[
0
];
// 根据最大值计算scale
float
scale
,
scale_inv
;
calculate_
int8_scales
(
amax_per_token
,
scale
,
scale_inv
);
calculate_
quant8bit_scales
<
kQuantType
>
(
amax_per_token
,
scale
,
scale_inv
);
if
(
thread_id
==
0
)
{
rdma_x_scales
[
0
]
=
scale_inv
;
}
...
...
@@ -269,13 +322,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// Cast into send buffer
vec_t
int2_value
;
auto
int8_values
=
reinterpret_cast
<
int8_t
*>
(
&
int2_value
);
#pragma unroll
for
(
int
j
=
0
;
j
<
kNumElemsPerRead
;
++
j
)
{
auto
fp32_value
=
static_cast
<
float
>
(
bf16_values
[
j
]);
auto
fp32_value_scaled
=
fp32_value
*
scale
;
int8_values
[
j
]
=
static_cast
<
int8_t
>
(
nearbyintf
(
fp32_value_scaled
));
}
pack_quantized_values
<
kQuantType
,
kNumElemsPerRead
>
(
bf16_values
,
scale
,
int2_value
);
rdma_x_vec
[
i
]
=
int2_value
;
}
__syncthreads
();
...
...
@@ -392,11 +439,10 @@ LOW_LATENCY_DISPATCH_RECV:
}
// 16 is the max possible number of warps in AMD GPUs
constexpr
int
kMaxNumWarps
=
1024
/
kWarpSize
;
constexpr
int
num_sync_large_iteration
=
kMaxNumWarps
;
__shared__
volatile
int
sync_large_warp_counters
[
num_sync_large_iteration
];
#pragma unroll
#pragma unroll
for
(
int
i
=
thread_id
;
i
<
num_sync_large_iteration
;
i
+=
blockDim
.
x
)
{
sync_large_warp_counters
[
i
]
=
0
;
}
...
...
@@ -416,7 +462,7 @@ LOW_LATENCY_DISPATCH_RECV:
const
auto
num_aligned_scales
=
ALIGN
<
int
>
(
kNumScales
,
sizeof
(
float
)
/
sizeof
(
scale_t
));
const
auto
recv_x_scales
=
static_cast
<
scale_t
*>
(
packed_recv_x_scales
)
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
(
k
UseInt8
?
1
:
num_aligned_scales
);
(
k
QuantType
==
1
?
1
:
num_aligned_scales
);
// Shared between sub-warps in warp groups
__shared__
int
shared_num_recv_tokens
[
kNumMaxWarpGroups
],
shared_recv_token_begin_idx
[
kNumMaxWarpGroups
];
...
...
@@ -461,14 +507,14 @@ LOW_LATENCY_DISPATCH_RECV:
UNROLLED_WARP_COPY_LL
(
7
,
lane_id
,
hidden_int4
,
dst_data
,
src_data
,
ld_nc_global
,
st_na_global
);
// Copy scales
if
constexpr
(
kUse
FP8
)
{
if
constexpr
(
kUse
Quant8Bit
)
{
const
auto
src_scales
=
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
uint8_t
*>
(
src_data
)
+
hidden_bytes
);
const
auto
num_elems_per_pack
=
static_cast
<
int
>
(
sizeof
(
packed_t
)
/
sizeof
(
scale_t
));
const
auto
token_idx
=
recv_token_begin_idx
+
i
;
const
auto
token_stride
=
num_elems_per_pack
;
const
auto
pack_stride
=
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_elems_per_pack
;
if
constexpr
(
k
UseInt8
)
{
if
constexpr
(
k
QuantType
==
1
)
{
if
(
lane_id
==
0
)
{
recv_x_scales
[
token_idx
]
=
ld_nc_global
(
src_scales
);
}
...
...
@@ -500,12 +546,13 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
use_int8
,
int
quant_type
,
int
quant_group_size
,
bool
fp8_round_scale
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
)
{
constexpr
int
kMaxNumWarps
=
16
;
constexpr
int
kNumMaxTopK
=
11
;
const
int
num_warp_groups
=
ceil_div
(
num_experts
,
num_device_sms
);
const
int
num_warps_per_group
=
16
/
num_warp_groups
;
const
int
num_warps_per_group
=
kMaxNumWarps
/
num_warp_groups
;
EP_HOST_ASSERT
(
num_warp_groups
>
0
and
num_warps_per_group
>
0
);
EP_HOST_ASSERT
(
kNumMaxTopK
+
1
<=
num_warp_groups
*
num_warps_per_group
);
...
...
@@ -518,33 +565,54 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
auto
atomic_finish_counter_per_expert
=
atomic_counter_per_expert
+
num_experts
;
EP_HOST_ASSERT
(
num_experts
*
sizeof
(
int
)
*
2
<=
NUM_WORKSPACE_BYTES
);
#define DISPATCH_LAUNCH_CASE(hidden) { \
auto dispatch_func = dispatch<false, false, false, hidden>; \
if (use_fp8 and not use_ue8m0) \
dispatch_func = dispatch<true, false, false, hidden>; \
if (use_fp8 and use_ue8m0) \
dispatch_func = dispatch<true, true, false, hidden>; \
if (use_int8) \
dispatch_func = dispatch<true, false, true, hidden>; \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, dispatch_func, \
packed_recv_x, packed_recv_x_scales, \
packed_recv_src_info, packed_recv_layout_range, \
packed_recv_count, \
global_atomic_counter, \
rdma_recv_x, rdma_recv_count, rdma_x, \
x, topk_idx, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, \
next_clean, num_next_clean_int, \
num_tokens, num_max_dispatch_tokens_per_rank, \
num_topk, num_experts, rank, num_ranks, \
num_warp_groups, num_warps_per_group, round_scale, phases); } break
// 限制groupsize的大小
EP_HOST_ASSERT
(
quant_group_size
==
0
||
quant_group_size
==
128
);
/*量化类型枚举
0 -> None 不量化,保持原始精度
1 -> Int8 使用 INT8 对称量化
2 -> FP8_E4M3 使用 FP8 E4M3 格式 (__HIP_E4M3_FNUZ)
3 -> FP8_UE8M0 使用 DeepSeekV3.1 提出的 UE8M0 格式 (仅支持round_scale=True)
4 -> FP8_E5M2 使用 FP8 E5M2 格式 (__HIP_E5M2_FNUZ)
*/
#define DISPATCH_LAUNCH_CASE(hidden) \
{ \
auto dispatch_func = dispatch<hidden, 0, 0, kMaxNumWarps>; \
if (quant_group_size == 0) { \
switch (quant_type) { \
case 1: dispatch_func = dispatch<hidden, 1, 0, kMaxNumWarps>; break; \
case 2: dispatch_func = dispatch<hidden, 2, 0, kMaxNumWarps>; break; \
case 3: dispatch_func = dispatch<hidden, 3, 0, kMaxNumWarps>; break; \
case 4: dispatch_func = dispatch<hidden, 4, 0, kMaxNumWarps>; break; \
} \
} else { \
switch (quant_type) { \
case 1: dispatch_func = dispatch<hidden, 1, 128, kMaxNumWarps>; break; \
case 2: dispatch_func = dispatch<hidden, 2, 128, kMaxNumWarps>; break; \
case 3: dispatch_func = dispatch<hidden, 3, 128, kMaxNumWarps>; break; \
case 4: dispatch_func = dispatch<hidden, 4, 128, kMaxNumWarps>; break; \
} \
} \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, dispatch_func, \
packed_recv_x, packed_recv_x_scales, \
packed_recv_src_info, packed_recv_layout_range, packed_recv_count, \
global_atomic_counter, \
rdma_recv_x, rdma_recv_count, rdma_x, x, topk_idx, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, \
next_clean, num_next_clean_int, \
num_tokens, num_max_dispatch_tokens_per_rank, \
num_topk, num_experts, rank, num_ranks, \
num_warp_groups, num_warps_per_group, fp8_round_scale, phases); \
} \
break
SETUP_LAUNCH_CONFIG
(
num_sms
,
num_warps
*
kWarpSize
,
stream
);
SWITCH_HIDDEN
(
DISPATCH_LAUNCH_CASE
);
#undef DISPATCH_LAUNCH_CASE
}
template
<
int
kHidden
,
int
kNumMaxTopk
>
template
<
int
kHidden
,
int
kNumMaxTopk
,
int
kMaxNumWarps
=
16
>
__global__
__launch_bounds__
(
16
*
kWarpSize
,
1
)
void
combine
(
void
*
combined_x
,
void
*
rdma_recv_x
,
int64_t
*
rdma_recv_flag
,
void
*
rdma_send_x
,
...
...
@@ -574,12 +642,11 @@ combine(void* combined_x,
const
size_t
hidden_bf16_int4
=
kHidden
/
kNumElemsPerInt4
;
// Message package
EP_STATIC_ASSERT
(
kHidden
%
FP8_
QUANTIZATION_
NUM_PER_CHANNEL
==
0
,
"Invalid hidden"
);
EP_STATIC_ASSERT
(
kHidden
%
QUANTIZATION_
GROUPSIZE
==
0
,
"Invalid hidden"
);
constexpr
size_t
num_bytes_per_slot
=
kHidden
*
sizeof
(
hip_bfloat16
);
EP_STATIC_ASSERT
(
num_bytes_per_slot
%
sizeof
(
int4
)
==
0
,
"Invalid vectorization"
);
// 16 is the max possible number of warps in AMD GPUs
constexpr
int
kMaxNumWarps
=
1024
/
kWarpSize
;
// 初始化用于细粒度warp间同步的计数器数组
__shared__
volatile
int
sync_large_warp_counters
[
kMaxNumWarps
];
if
(
threadIdx
.
x
==
0
){
#pragma unroll
...
...
@@ -755,9 +822,10 @@ void combine(void* combined_x,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
,
bool
zero_copy
)
{
constexpr
int
kMaxNumWarps
=
16
;
constexpr
int
kNumMaxTopk
=
11
;
const
int
num_warp_groups
=
ceil_div
(
num_experts
,
num_device_sms
);
const
int
num_warps_per_group
=
16
/
num_warp_groups
;
// num_warps_per_group>1, "Requires more than one warp per group"
const
int
num_warps_per_group
=
kMaxNumWarps
/
num_warp_groups
;
// num_warps_per_group>1, "Requires more than one warp per group"
const
int
num_recv_per_sm
=
ceil_div
(
num_combined_tokens
,
num_device_sms
);
EP_HOST_ASSERT
(
num_warp_groups
>
0
and
num_warps_per_group
>
0
and
num_recv_per_sm
>=
0
);
...
...
@@ -770,20 +838,20 @@ void combine(void* combined_x,
EP_HOST_ASSERT
(
sizeof
(
int
)
<=
NUM_WORKSPACE_BYTES
);
EP_HOST_ASSERT
(
num_topk
<=
kNumMaxTopk
);
#define COMBINE_LAUNCH_CASE(hidden)
{
\
auto combine_func = combine<hidden, kNumMaxTopk>;
\
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func,
\
combined_x,
\
rdma_recv_x, rdma_recv_flag, rdma_send_x, \
x, topk_idx, topk_weights, src_info, layout_range, \
global_atomic_counter, \
combine_wait_recv_cost_stats,
\
next
_clean, num_
next_clean_int,
\
atomic_clean_flag,
\
num_combined_tokens, hidden, num_topk,
\
num_max_dispatch_tokens_per_rank,
\
num_experts, rank, num_ranks,
\
num_warp_groups, num_warps_per_group, phases, zero_copy); }
break
#define COMBINE_LAUNCH_CASE(hidden)
\
{
\
auto combine_func = combine<hidden, kNumMaxTopk, kMaxNumWarps>;
\
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func,
\
combined_x,
rdma_recv_x, rdma_recv_flag, rdma_send_x,
\
x, topk_idx, topk_weights, src_info, layout_range,
\
global_atomic_counter,
combine_wait_recv_cost_stats,
\
next_clean, num_next_clean_int,
\
atomic
_clean
_flag
, num_
combined_tokens, hidden,
\
num_topk, num_max_dispatch_tokens_per_rank,
\
num_experts, rank, num_ranks,
\
num_warp_groups, num_warps_per_group, phases, zero_copy);
\
}
\
break
SETUP_LAUNCH_CONFIG
(
num_sms
,
num_warps
*
kWarpSize
,
stream
);
SWITCH_HIDDEN
(
COMBINE_LAUNCH_CASE
);
...
...
csrc/kernels/utils.cuh
View file @
44ec8bed
...
...
@@ -341,10 +341,14 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
return
*
reinterpret_cast
<
dtype_t
*>
(
recv_int_values
);
}
constexpr
float
kFP8Margin
=
1e-4
;
// 设置不同的量化方式的最大值与相反数
constexpr
float
kFP8Margin
=
0.0
;
constexpr
float
kFinfoAmaxE4M3
=
240.0
f
;
constexpr
float
kFinfoAmaxInvE4M3
=
1.0
f
/
kFinfoAmaxE4M3
;
constexpr
float
kInt8Amax
=
127.0
f
;
constexpr
float
kFinfoAmaxE5M2
=
57344.0
f
;
constexpr
float
kFinfoAmaxInvE5M2
=
1.0
f
/
kFinfoAmaxE5M2
;
constexpr
float
kFinfoAmaxInt8
=
127.0
f
;
constexpr
float
kFinfoAmaxInvInt8
=
1.0
f
/
127.0
f
;
__forceinline__
__device__
float
fast_pow2
(
int
x
)
{
// We can ensure `-126 <= x and x <= 127`
...
...
@@ -359,22 +363,33 @@ __forceinline__ __device__ int fast_log2_ceil(float x) {
return
exp_x
-
127
+
(
man_bits
!=
0
);
}
__forceinline__
__device__
void
calculate_fp8_scales
(
float
amax
,
float
&
scale
,
float
&
scale_inv
,
bool
round_scale
)
{
if
(
round_scale
)
{
auto
exp_scale_inv
=
fast_log2_ceil
(
amax
*
kFinfoAmaxInvE4M3
);
scale
=
fast_pow2
(
-
exp_scale_inv
);
scale_inv
=
fast_pow2
(
exp_scale_inv
);
}
else
{
scale_inv
=
amax
*
kFinfoAmaxInvE4M3
;
scale
=
kFinfoAmaxE4M3
/
amax
;
template
<
int
kQuantType
>
__forceinline__
__device__
void
calculate_quant8bit_scales
(
float
amax
,
float
&
scale
,
float
&
scale_inv
,
bool
round_scale
=
0
)
{
amax
=
fmaxf
(
amax
,
1e-6
f
);
if
constexpr
(
kQuantType
==
1
)
{
// 使用 INT8 对称量化
scale_inv
=
kFinfoAmaxInvInt8
*
amax
;
scale
=
kFinfoAmaxInt8
/
amax
;
}
else
if
constexpr
(
kQuantType
==
2
||
kQuantType
==
3
)
{
// 使用 FP8_E4M3 或 FP8_UE8M0 非对称量化
if
(
round_scale
)
{
auto
exp_scale_inv
=
fast_log2_ceil
(
amax
*
kFinfoAmaxInvE4M3
);
scale
=
fast_pow2
(
-
exp_scale_inv
);
scale_inv
=
fast_pow2
(
exp_scale_inv
);
}
else
{
scale_inv
=
amax
*
kFinfoAmaxInvE4M3
;
scale
=
kFinfoAmaxE4M3
/
amax
;
}
}
else
if
constexpr
(
kQuantType
==
4
)
{
// 使用 FP8_E5M2 对称量化
if
(
round_scale
)
{
auto
exp_scale_inv
=
fast_log2_ceil
(
amax
*
kFinfoAmaxInvE5M2
);
scale
=
fast_pow2
(
-
exp_scale_inv
);
scale_inv
=
fast_pow2
(
exp_scale_inv
);
}
else
{
scale_inv
=
amax
*
kFinfoAmaxInvE5M2
;
scale
=
kFinfoAmaxE5M2
/
amax
;
}
}
}
__forceinline__
__device__
void
calculate_int8_scales
(
float
amax
,
float
&
scale
,
float
&
scale_inv
)
{
scale
=
kInt8Amax
/
amax
;
scale_inv
=
amax
/
kInt8Amax
;
}
template
<
bool
kIsUE8M0
,
typename
out_dtype_t
=
std
::
conditional_t
<
kIsUE8M0
,
uint8_t
,
float
>
>
__forceinline__
__device__
out_dtype_t
extract_required_scale_format
(
float
value
)
{
if
constexpr
(
kIsUE8M0
)
{
...
...
deep_ep/buffer.py
View file @
44ec8bed
...
...
@@ -841,7 +841,7 @@ class Buffer:
# noinspection PyTypeChecker
def
low_latency_dispatch
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
num_max_dispatch_tokens_per_rank
:
int
,
num_experts
:
int
,
use_fp8
:
bool
=
True
,
round_scale
:
bool
=
False
,
use_ue8m0
:
bool
=
False
,
use_int8
:
bool
=
False
,
quant_type
:
int
=
1
,
quant_group_size
:
int
=
0
,
fp8_round_scale
:
bool
=
False
,
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
)
->
\
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
torch
.
Tensor
,
Tuple
,
EventOverlap
,
Callable
]:
"""
...
...
@@ -858,10 +858,20 @@ class Buffer:
only several top-k shapes are supported. `-1` indices (not selecting any expert) are supported.
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
num_experts: the number of all experts.
use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
round_scale: whether round the scaling factors into power of 2.
use_ue8m0: whether use UE8M0 as scaling factor format (available only with `round_scale=True`).
use_int8: whether to enable INT8 casting.
量化配置
quant_type: int 量化类型枚举
0 -> None 不量化,保持原始精度
1 -> Int8 使用 INT8 对称量化
2 -> FP8_E4M3 使用 FP8 E4M3 格式 (__HIP_E4M3_FNUZ)
3 -> FP8_UE8M0 使用 DeepSeekV3.1 提出的 UE8M0 格式 (仅支持round_scale=True)
4 -> FP8_E5M2 使用 FP8 E5M2 格式 (__HIP_E5M2_FNUZ)
quant_group_size: int 量化分组大小
0 -> 逐token量化 (per-channel)
128-> 每 128 元素一组 (per-group) 量化
fp8_round_scale: bool 是否将 FP8 缩放因子取整为 2 的幂
true -> 缩放因子 = 2^k,硬件零开销
false -> 缩放因子 = 任意浮点,精度更高
异步配置
async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
...
...
@@ -869,15 +879,25 @@ class Buffer:
Returns:
recv_x: a tensor or tuple with received tokens for each expert.
With `use_fp8=True`: the first element is a `torch.Tensor` shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`.
The second tensor is the corresponding scales for the first element with shape
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`,
if `use_ue8m0=False`. With `use_ue8m0=True`, the second one is packed and shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 512]` with type `torch.int`.
Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility.
With `use_fp8=False`, the result would be a tensor shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`.
- packed_recv_x:
存储接收到的 Token 数据,形状为
`[num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden]`。
数据类型取决于 quant_type:
quant_type == 1 -> torch.int8
quant_type == 2 -> torch.float8_e4m3fnuz
quant_type == 3 -> torch.float8_e4m3fnuz (UE8M0 使用 E4M3 格式存储)
quant_type == 4 -> torch.float8_e5m2fnuz
其他 (非量化) -> torch.bfloat16
- packed_recv_x_scales (可选):
仅在 quant_type > 0 时存在,存储量化的 Scale 值。
形状为 `[num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, scales_col_size]`。
- 当 quant_type == 3 (UE8M0) 时:
scales_col_size = hidden // 512
数据类型为 torch.int (内部打包存储 4-bit scale)。
*注意:此模式强制要求 fp8_round_scale=True 且 group_size=128。
- 当 quant_type == 1, 2, 4 时:
scales_col_size = hidden // 128 (若使用 group_size) 或 1 (per-channel)。
数据类型为 torch.float32。
Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
as we do not synchronize CPU received count with GPU (also not incompatible with CUDA graph if synced).
recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
...
...
@@ -889,14 +909,15 @@ class Buffer:
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_src_info
,
packed_recv_layout_range
,
event
,
hook
=
\
self
.
runtime
.
low_latency_dispatch
(
x
,
topk_idx
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
use_fp8
,
round_scale
,
use_ue8m0
,
use_int8
,
quant_type
,
quant_group_size
,
fp8_round_scale
,
async_finish
,
return_recv_hook
)
handle
=
(
packed_recv_src_info
,
packed_recv_layout_range
,
num_max_dispatch_tokens_per_rank
,
x
.
size
(
1
),
num_experts
)
tensors_to_record
=
(
x
,
topk_idx
,
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_src_info
,
packed_recv_layout_range
)
return
(
packed_recv_x
,
packed_recv_x_scales
)
if
use_fp8
else
packed_recv_x
,
packed_recv_count
,
handle
,
\
EventOverlap
(
event
,
tensors_to_record
if
async_finish
else
None
),
hook
recv_x
=
(
packed_recv_x
,
packed_recv_x_scales
)
if
(
quant_type
>
0
)
else
packed_recv_x
return
recv_x
,
packed_recv_count
,
handle
,
EventOverlap
(
event
,
tensors_to_record
if
async_finish
else
None
),
hook
# noinspection PyTypeChecker
def
low_latency_combine
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
...
...
tests/test_low_latency_new.py
View file @
44ec8bed
...
...
@@ -58,7 +58,8 @@ def test_main(num_tokens: int,
x_list
=
[
x
]
# # NOTES: the last one is for performance testing
# # Most of the values in the perf case is lower than the threshold, casting most channels
# x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1)
# x_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1
# x_list = [x_rand]
scores
=
torch
.
randn
((
num_tokens
,
num_experts
),
dtype
=
torch
.
float32
,
device
=
'cuda'
).
abs
()
+
1
topk_idx
=
torch
.
topk
(
scores
,
num_topk
,
dim
=-
1
,
largest
=
True
,
sorted
=
True
)[
1
]
...
...
@@ -80,14 +81,19 @@ def test_main(num_tokens: int,
hash_value
,
num_times
=
0
,
0
for
current_x
in
x_list
:
for
return_recv_hook
in
(
False
,
True
):
for
dispatch_use_fp8
in
(
False
,
True
):
for
round_scale
in
(
False
,
True
)
if
dispatch_use_fp8
else
(
False
,):
for
use_ue8m0
in
(
False
,
True
)
if
round_scale
else
(
False
,):
for
quant_type
in
(
0
,
2
,
3
,
):
# 0: 不量化, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True)
dispatch_use_fp8
=
quant_type
>
0
for
fp8_round_scale
in
(
False
,
True
)
if
dispatch_use_fp8
else
(
False
,
):
for
quant_group_size
in
(
128
,
):
# 跳过不支持的情况
if
quant_type
==
3
and
fp8_round_scale
==
False
:
continue
num_times
+=
1
for
_
in
range
((
num_times
%
2
)
+
1
):
packed_recv_x
,
packed_recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
use_fp8
=
dispatch_use_
fp8
,
round_scale
=
round_scale
,
use_ue8m0
=
use_ue8m0
,
quant_type
=
quant_type
,
fp8
_
round_scale
=
fp8_
round_scale
,
quant_group_size
=
quant_group_size
,
async_finish
=
not
return_recv_hook
,
return_recv_hook
=
return_recv_hook
)
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
packed_recv_x
=
(
packed_recv_x
[
0
],
packed_recv_x
[
1
].
contiguous
())
if
dispatch_use_fp8
else
packed_recv_x
...
...
@@ -115,13 +121,13 @@ def test_main(num_tokens: int,
recv_x_amin
=
recv_x
[:,
:
-
128
].
amin
(
dim
=-
1
)
recv_src_info
=
recv_src_info
[:
num_valid_tokens
]
assert
torch
.
equal
(
recv_x_amin
,
recv_x
[:,
:
-
128
].
amax
(
dim
=-
1
))
if
round_scale
:
if
fp8_
round_scale
:
assert
calc_diff
(
recv_x
[:,
-
1
],
recv_src_info
.
view
(
-
1
))
<
0.007
else
:
assert
(
recv_x
[:,
-
128
:]
-
recv_src_info
.
view
(
-
1
,
1
)
%
num_tokens
).
sum
().
item
()
==
0
for
j
in
range
(
num_ranks
):
begin_idx
,
count
=
(
recv_layout_range
[
j
]
>>
32
).
item
(),
(
recv_layout_range
[
j
]
&
int_mask
).
item
()
if
not
round_scale
:
if
not
fp8_
round_scale
:
assert
(
recv_x_amin
==
j
-
rank_offset
).
sum
().
item
()
==
(
all_topk_idx
[
j
]
==
expert_id
).
sum
().
item
()
assert
(
recv_x
[
begin_idx
:
begin_idx
+
count
,
:
-
128
]
-
j
+
rank_offset
).
sum
().
item
()
==
0
if
dispatch_use_fp8
:
...
...
@@ -147,7 +153,7 @@ def test_main(num_tokens: int,
if
do_check
:
diff
=
calc_diff
(
current_x
*
topk_weights
.
masked_fill
(
topk_idx
==
-
1
,
0
).
sum
(
dim
=
1
).
view
(
-
1
,
1
),
combined_x
)
assert
torch
.
isnan
(
combined_x
).
sum
().
item
()
==
0
# if not round_scale:
# if not
fp8_
round_scale:
assert
diff
<
(
9e-4
if
dispatch_use_fp8
else
1e-5
),
f
'Error: diff=
{
diff
}
, dispatch_use_fp8=
{
dispatch_use_fp8
}
, zero_copy=
{
zero_copy
}
'
hash_value
^=
hash_tensor
(
combined_x
)
...
...
@@ -162,7 +168,8 @@ def test_main(num_tokens: int,
def
test_func
(
return_recv_hook
:
bool
):
recv_x
,
recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
use_fp8
=
True
,
async_finish
=
False
,
return_recv_hook
=
return_recv_hook
)
quant_type
=
2
,
quant_group_size
=
128
,
async_finish
=
False
,
return_recv_hook
=
return_recv_hook
)
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
topk_idx
,
...
...
tests/test_low_latency_new_int8.py
View file @
44ec8bed
...
...
@@ -54,16 +54,16 @@ def test_main(num_tokens: int,
do_check
=
True
hash_value
,
num_times
=
0
,
0
for
current_x
in
x_list
:
for
return_recv_hook
in
(
False
,
):
for
dispatch_use_fp8
in
(
True
,
):
for
round_scale
in
(
False
,
):
for
use_ue8m0
in
(
False
,
):
for
return_recv_hook
in
(
False
,
True
):
for
quant_type
in
(
1
,
):
for
fp8_round_scale
in
(
False
,
):
for
quant_group_size
in
(
0
,
):
dispatch_use_fp8
=
quant_type
>
0
num_times
+=
1
use_int8
=
True
for
_
in
range
(
1
):
packed_recv_x
,
packed_recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
use_fp8
=
dispatch_use_fp8
,
round_scale
=
round_scale
,
use_ue8m0
=
use_ue8m0
,
use_int8
=
use_int8
,
quant_type
=
quant_type
,
quant_group_size
=
quant_group_size
,
async_finish
=
not
return_recv_hook
,
return_recv_hook
=
return_recv_hook
)
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
...
...
@@ -97,9 +97,7 @@ def test_main(num_tokens: int,
assert
torch
.
equal
(
recv_x_amin
,
recv_x_amax
)
if
round_scale
:
assert
calc_diff
(
recv_x
[:,
-
1
],
recv_src_info
.
view
(
-
1
))
<
0.007
elif
use_int8
:
if
quant_type
==
1
:
assert
calc_diff
(
recv_x
[:,
-
1
],
recv_src_info
.
view
(
-
1
))
<
0.01
else
:
assert
(
recv_x
[:,
-
128
:]
-
recv_src_info
.
view
(
-
1
,
1
)
%
num_tokens
).
sum
().
item
()
==
0
...
...
@@ -131,7 +129,7 @@ def test_main(num_tokens: int,
def
test_func
(
return_recv_hook
:
bool
):
recv_x
,
recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
use_fp8
=
True
,
round_scale
=
False
,
use_ue8m0
=
False
,
use_int8
=
True
,
quant_type
=
1
,
quant_group_size
=
0
,
async_finish
=
False
,
return_recv_hook
=
return_recv_hook
)
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment