Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
d0fcf024
Commit
d0fcf024
authored
Feb 04, 2026
by
lishen
Browse files
Merge branch 'quant_master' into 'main'
And quant. See merge request dcutoolkit/deeplearing/DeepEP!19
parents
81e56124
ace6e18e
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
326 additions
and
407 deletions
+326
-407
csrc/config.hpp
csrc/config.hpp
+1
-1
csrc/deep_ep.cu
csrc/deep_ep.cu
+28
-14
csrc/deep_ep.hpp
csrc/deep_ep.hpp
+1
-1
csrc/kernels/api.cuh
csrc/kernels/api.cuh
+1
-1
csrc/kernels/configs.cuh
csrc/kernels/configs.cuh
+1
-1
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+163
-95
csrc/kernels/launch.cuh
csrc/kernels/launch.cuh
+7
-5
csrc/kernels/utils.cuh
csrc/kernels/utils.cuh
+30
-16
deep_ep/buffer.py
deep_ep/buffer.py
+38
-17
tests/test_low_latency_new.py
tests/test_low_latency_new.py
+45
-23
tests/test_low_latency_new_int8.py
tests/test_low_latency_new_int8.py
+0
-213
tests/utils.py
tests/utils.py
+11
-20
No files found.
csrc/config.hpp
View file @
d0fcf024
...
@@ -136,7 +136,7 @@ struct LowLatencyLayout {
...
@@ -136,7 +136,7 @@ struct LowLatencyLayout {
LowLatencyLayout
(
void
*
rdma_buffer
,
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
LowLatencyLayout
(
void
*
rdma_buffer
,
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_ranks
,
int
num_experts
)
{
int
num_ranks
,
int
num_experts
)
{
const
int
num_scales
=
hidden
/
FP8_
QUANTIZATION_
NUM_PER_CHANNEL
;
const
int
num_scales
=
hidden
/
QUANTIZATION_
GROUPSIZE
;
// Dispatch and combine layout:
// Dispatch and combine layout:
// - 2 symmetric odd/even send buffer
// - 2 symmetric odd/even send buffer
...
...
csrc/deep_ep.cu
View file @
d0fcf024
...
@@ -1293,7 +1293,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
...
@@ -1293,7 +1293,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
Buffer
::
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
Buffer
::
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
use_int8
,
int
quant_type
,
int
quant_group_size
,
bool
fp8_round_scale
,
bool
async
,
bool
return_recv_hook
)
{
bool
async
,
bool
return_recv_hook
)
{
EP_HOST_ASSERT
(
low_latency_mode
);
EP_HOST_ASSERT
(
low_latency_mode
);
...
@@ -1327,8 +1327,15 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
...
@@ -1327,8 +1327,15 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
stream_wait
(
launch_stream
,
compute_stream
);
stream_wait
(
launch_stream
,
compute_stream
);
// Allocate packed tensors
// Allocate packed tensors
auto
packed_recv_x
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
,
hidden
},
auto
packed_recv_x_dtype
=
torch
::
kBFloat16
;
x
.
options
().
dtype
(
use_int8
?
torch
::
kInt8
:
use_fp8
?
torch
::
kFloat8_e4m3fnuz
:
torch
::
kBFloat16
));
switch
(
quant_type
)
{
case
1
:
packed_recv_x_dtype
=
torch
::
kInt8
;
break
;
case
2
:
packed_recv_x_dtype
=
torch
::
kFloat8_e4m3fn
;
break
;
case
3
:
packed_recv_x_dtype
=
torch
::
kFloat8_e4m3fn
;
break
;
case
4
:
packed_recv_x_dtype
=
torch
::
kFloat8_e5m2
;
break
;
}
auto
packed_recv_x
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
,
hidden
},
x
.
options
().
dtype
(
packed_recv_x_dtype
));
auto
packed_recv_src_info
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
auto
packed_recv_src_info
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
auto
packed_recv_layout_range
=
torch
::
empty
({
num_local_experts
,
num_ranks
},
torch
::
dtype
(
torch
::
kInt64
).
device
(
torch
::
kCUDA
));
auto
packed_recv_layout_range
=
torch
::
empty
({
num_local_experts
,
num_ranks
},
torch
::
dtype
(
torch
::
kInt64
).
device
(
torch
::
kCUDA
));
auto
packed_recv_count
=
torch
::
empty
({
num_local_experts
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
auto
packed_recv_count
=
torch
::
empty
({
num_local_experts
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
...
@@ -1336,21 +1343,28 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
...
@@ -1336,21 +1343,28 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
// Allocate column-majored scales
// Allocate column-majored scales
auto
packed_recv_x_scales
=
std
::
optional
<
torch
::
Tensor
>
();
auto
packed_recv_x_scales
=
std
::
optional
<
torch
::
Tensor
>
();
void
*
packed_recv_x_scales_ptr
=
nullptr
;
void
*
packed_recv_x_scales_ptr
=
nullptr
;
if
(
use_fp8
)
{
if
(
quant_type
>
0
)
{
EP_HOST_ASSERT
((
num_ranks
*
num_max_dispatch_tokens_per_rank
)
%
4
==
0
and
"TMA requires the number of tokens to be multiple of 4"
);
EP_HOST_ASSERT
((
num_ranks
*
num_max_dispatch_tokens_per_rank
)
%
4
==
0
and
"TMA requires the number of tokens to be multiple of 4"
);
// TODO: support unaligned cases
// TODO: support unaligned cases
EP_HOST_ASSERT
(
hidden
%
(
FP8_QUANTIZATION_NUM_PER_CHANNEL
*
4
)
==
0
);
EP_HOST_ASSERT
(
hidden
%
(
QUANTIZATION_GROUPSIZE
*
4
)
==
0
);
EP_HOST_ASSERT
(
!
(
use_ue8m0
&&
use_int8
));
// 计算scale_col的大小
int
scales_col_size
=
1
;
// 默认为per-channel
if
(
quant_group_size
>
0
)
{
if
(
quant_type
==
3
)
{
// FP8_UE8M0比较特殊
scales_col_size
=
hidden
/
(
QUANTIZATION_GROUPSIZE
*
4
);
}
else
{
scales_col_size
=
hidden
/
QUANTIZATION_GROUPSIZE
;
}
}
if
(
use_ue8m0
)
{
// 设置packed_recv_x_scales
EP_HOST_ASSERT
(
round_scale
);
if
(
quant_type
==
3
)
{
// FP8_UE8M0比较特殊,需要单独处理
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
hidden
/
(
FP8_QUANTIZATION_NUM_PER_CHANNEL
*
4
),
num_ranks
*
num_max_dispatch_tokens_per_rank
},
EP_HOST_ASSERT
(
fp8_round_scale
&&
quant_group_size
==
128
);
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
scales_col_size
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kInt
).
device
(
torch
::
kCUDA
));
torch
::
dtype
(
torch
::
kInt
).
device
(
torch
::
kCUDA
));
}
else
if
(
use_int8
)
{
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
1
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
));
}
else
{
}
else
{
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
hidden
/
FP8_QUANTIZATION_NUM_PER_CHANNEL
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
scales_col_size
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
));
torch
::
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
));
}
}
packed_recv_x_scales
=
torch
::
transpose
(
packed_recv_x_scales
.
value
(),
1
,
2
);
packed_recv_x_scales
=
torch
::
transpose
(
packed_recv_x_scales
.
value
(),
1
,
2
);
...
@@ -1370,7 +1384,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
...
@@ -1370,7 +1384,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
next_clean_meta
.
first
,
next_clean_meta
.
second
,
next_clean_meta
.
first
,
next_clean_meta
.
second
,
num_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
use_fp8
,
round_scale
,
use_ue8m0
,
use_int8
,
quant_type
,
quant_group_size
,
fp8_round_scale
,
workspace
,
num_device_sms
,
launch_stream
,
phases
);
workspace
,
num_device_sms
,
launch_stream
,
phases
);
};
};
launcher
(
return_recv_hook
?
LOW_LATENCY_SEND_PHASE
:
(
LOW_LATENCY_SEND_PHASE
|
LOW_LATENCY_RECV_PHASE
));
launcher
(
return_recv_hook
?
LOW_LATENCY_SEND_PHASE
:
(
LOW_LATENCY_SEND_PHASE
|
LOW_LATENCY_RECV_PHASE
));
...
...
csrc/deep_ep.hpp
View file @
d0fcf024
...
@@ -177,7 +177,7 @@ public:
...
@@ -177,7 +177,7 @@ public:
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
use_int8
,
int
quant_type
,
int
quant_group_size
,
bool
fp8_round_scale
,
bool
async
,
bool
return_recv_hook
);
bool
async
,
bool
return_recv_hook
);
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
...
...
csrc/kernels/api.cuh
View file @
d0fcf024
...
@@ -147,7 +147,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -147,7 +147,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
use_int8
,
int
quant_type
,
int
group_size
,
bool
fp8_round_scale
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
);
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
);
void
combine
(
void
*
combined_x
,
void
combine
(
void
*
combined_x
,
...
...
csrc/kernels/configs.cuh
View file @
d0fcf024
...
@@ -23,7 +23,7 @@
...
@@ -23,7 +23,7 @@
#define LOW_LATENCY_SEND_PHASE 1
#define LOW_LATENCY_SEND_PHASE 1
#define LOW_LATENCY_RECV_PHASE 2
#define LOW_LATENCY_RECV_PHASE 2
#define NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL 3
#define NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL 3
#define
FP8_
QUANTIZATION_
NUM_PER_CHANNEL
128
#define QUANTIZATION_
GROUPSIZE
128
#define DEFAULT_NUM_CU 20
#define DEFAULT_NUM_CU 20
#define DEFAULT_NUM_MAX_XGMI_CHUNKED_SEND_TOKENS 6
#define DEFAULT_NUM_MAX_XGMI_CHUNKED_SEND_TOKENS 6
...
...
csrc/kernels/internode_ll.cu
View file @
d0fcf024
...
@@ -116,20 +116,74 @@ internode_ll_long_atomic_add(long* dest, const long &value,
...
@@ -116,20 +116,74 @@ internode_ll_long_atomic_add(long* dest, const long &value,
#endif // defined(FORCE_DUSHMEM_API)
#endif // defined(FORCE_DUSHMEM_API)
}
}
template
<
bool
kUseFP8
,
bool
kUseUE8M0
,
bool
kUseInt8
,
int
kHidden
>
/**
* @brief 将 K 个浮点数(BF16/FP32)量化并打包成 INT2(64位)存储
*
* @tparam kQuantType 量化类型 (1: Int8, 2/3: FP8_E4M3/UE8M0, 4: FP8_E5M2)
* @tparam kNumElemsPerRead 每次读取的元素数量 (通常为 2, 4, 8)
* @tparam SrcT 源数据类型 (float 或 __hip_bfloat16)
* @tparam DstT 目标数据类型 (int2 或 int4)
* @param src_values 源数据数组 (长度 >= kNumElemsPerRead)
* @param scale 缩放因子 (将 FP32 值映射到量化范围)
* @param[out] dst_vec 输出的 64 位向量 (int2 或 int4)
*/
template
<
int
kQuantType
,
int
kNumElemsPerRead
,
typename
SrcT
,
typename
DstT
>
__forceinline__
__device__
void
pack_quantized_values
(
const
SrcT
*
src_values
,
float
scale
,
DstT
&
dst_vec
)
{
if
constexpr
(
kQuantType
==
1
)
{
// INT8 量化
auto
int8_ptr
=
reinterpret_cast
<
int8_t
*>
(
&
dst_vec
);
#pragma unroll
for
(
int
j
=
0
;
j
<
kNumElemsPerRead
;
++
j
)
{
// 如果源是 bfloat16,先提升为 float
float
fp32_value_scaled
=
static_cast
<
float
>
(
src_values
[
j
])
*
scale
;
// 使用 nearbyintf 进行四舍五入
int8_ptr
[
j
]
=
static_cast
<
int8_t
>
(
nearbyintf
(
fp32_value_scaled
));
}
}
else
{
// FP8 量化 (E4M3, UE8M0, E5M2)
// 假设 dst_vec 能容纳 kNumElemsPerRead/2 个 fp8x2 元素
auto
fp8x2_ptr
=
reinterpret_cast
<
__hip_fp8x2_storage_t
*>
(
&
dst_vec
);
#pragma unroll
for
(
int
j
=
0
;
j
<
kNumElemsPerRead
;
j
+=
2
)
{
// 处理两个元素
float2
fp32x2
=
{
static_cast
<
float
>
(
src_values
[
j
])
*
scale
,
static_cast
<
float
>
(
src_values
[
j
+
1
])
*
scale
};
if
constexpr
(
kQuantType
==
4
)
{
// FP8 E5M2
fp8x2_ptr
[
j
/
2
]
=
__hip_cvt_float2_to_fp8x2
(
fp32x2
,
__HIP_SATFINITE
,
__HIP_E5M2
);
}
else
{
// FP8 E4M3 或 UE8M0
fp8x2_ptr
[
j
/
2
]
=
__hip_cvt_float2_to_fp8x2
(
fp32x2
,
__HIP_SATFINITE
,
__HIP_E4M3
);
}
}
}
}
template
<
int
kHidden
,
int
kQuantType
=
0
,
int
kQuantGroupSize
=
0
,
int
kMaxNumWarps
=
16
>
__global__
__launch_bounds__
(
16
*
kWarpSize
,
1
)
void
__global__
__launch_bounds__
(
16
*
kWarpSize
,
1
)
void
dispatch
(
void
*
packed_recv_x
,
void
*
packed_recv_x_scales
,
dispatch
(
void
*
packed_recv_x
,
void
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_count
,
int
*
packed_recv_count
,
int
*
global_atomic_counter
,
int
*
global_atomic_counter
,
void
*
rdma_recv_x
,
int64_t
*
rdma_recv_count
,
void
*
rdma_x
,
void
*
rdma_recv_x
,
int64_t
*
rdma_recv_count
,
void
*
rdma_x
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
int
*
atomic_counter_per_expert
,
int
*
atomic_finish_counter_per_expert
,
int
*
atomic_counter_per_expert
,
int
*
atomic_finish_counter_per_expert
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
num_max_dispatch_tokens_per_rank
,
int
num_tokens
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_warp_groups
,
int
num_warps_per_group
,
int
num_warp_groups
,
int
num_warps_per_group
,
bool
round_scale
,
int
phases
)
{
bool
fp8_round_scale
,
int
phases
)
{
// 定义量化类型的枚举
enum
class
QuantType
{
None
=
0
,
// 不进行量化
Int8
=
1
,
// 采用 Int8 量化
FP8_E4M3
=
2
,
// 采用 FP8 量化 __HIP_E4M3
FP8_UE8M0
=
3
,
// 采用 FP8 量化 DeepseekV3.1的 UE8M0
FP8_E5M2
=
4
// 采用 FP8 量化 __HIP_E5M2
};
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
);
const
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
);
const
auto
warp_id
=
thread_id
/
kWarpSize
,
lane_id
=
get_lane_id
();
const
auto
warp_id
=
thread_id
/
kWarpSize
,
lane_id
=
get_lane_id
();
...
@@ -141,20 +195,22 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -141,20 +195,22 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
const
auto
responsible_expert_idx
=
sm_id
*
num_warp_groups
+
warp_group_id
;
const
auto
responsible_expert_idx
=
sm_id
*
num_warp_groups
+
warp_group_id
;
// May extract UE8M0 from the scales
// May extract UE8M0 from the scales
constexpr
bool
kUseQuant8Bit
=
kQuantType
>
0
;
constexpr
bool
kUseUE8M0
=
kQuantType
==
3
;
// QuantType::FP8_UE8M0
using
scale_t
=
std
::
conditional_t
<
kUseUE8M0
,
uint8_t
,
float
>
;
using
scale_t
=
std
::
conditional_t
<
kUseUE8M0
,
uint8_t
,
float
>
;
using
packed_t
=
std
::
conditional_t
<
kUseUE8M0
,
uint32_t
,
float
>
;
using
packed_t
=
std
::
conditional_t
<
kUseUE8M0
,
uint32_t
,
float
>
;
EP_STATIC_ASSERT
(
sizeof
(
packed_t
)
%
sizeof
(
scale_t
)
==
0
,
"Invalid vector length"
);
EP_STATIC_ASSERT
(
sizeof
(
packed_t
)
%
sizeof
(
scale_t
)
==
0
,
"Invalid vector length"
);
// FP8 staffs
// FP8 staffs
constexpr
int
kNumPerChannels
=
FP8_
QUANTIZATION_
NUM_PER_CHANNEL
;
constexpr
int
kNumPerChannels
=
QUANTIZATION_
GROUPSIZE
;
constexpr
int
kNumScales
=
kHidden
/
kNumPerChannels
;
constexpr
int
kNumScales
=
kHidden
/
kNumPerChannels
;
const
size_t
hidden_bytes
=
kHidden
*
(
kUse
FP8
?
sizeof
(
__hip_fp8_storage_t
)
:
sizeof
(
hip_bfloat16
));
const
size_t
hidden_bytes
=
kHidden
*
(
kUse
Quant8Bit
?
sizeof
(
__hip_fp8_storage_t
)
:
sizeof
(
hip_bfloat16
));
const
size_t
hidden_int4
=
hidden_bytes
/
sizeof
(
int4
);
const
size_t
hidden_int4
=
hidden_bytes
/
sizeof
(
int4
);
// Message package: hidden data, FP8 scales, index at source
// Message package: hidden data, FP8 scales, index at source
// NOTES: currently we have 3 reserved int fields for future use
// NOTES: currently we have 3 reserved int fields for future use
using
vec_t
=
typename
std
::
conditional
<
kUse
FP8
,
int2
,
int4
>::
type
;
using
vec_t
=
typename
std
::
conditional
<
kUse
Quant8Bit
,
int2
,
int4
>::
type
;
constexpr
size_t
num_bytes_per_msg
=
sizeof
(
int4
)
+
(
kUse
FP8
?
(
kHidden
+
kNumScales
*
sizeof
(
float
))
:
(
kHidden
*
sizeof
(
hip_bfloat16
)));
constexpr
size_t
num_bytes_per_msg
=
sizeof
(
int4
)
+
(
kUse
Quant8Bit
?
(
kHidden
+
kNumScales
*
sizeof
(
float
))
:
(
kHidden
*
sizeof
(
hip_bfloat16
)));
EP_STATIC_ASSERT
(
num_bytes_per_msg
%
sizeof
(
int4
)
==
0
,
"Invalid message size"
);
EP_STATIC_ASSERT
(
num_bytes_per_msg
%
sizeof
(
int4
)
==
0
,
"Invalid message size"
);
constexpr
size_t
num_int4_per_msg
=
num_bytes_per_msg
/
sizeof
(
int4
);
constexpr
size_t
num_int4_per_msg
=
num_bytes_per_msg
/
sizeof
(
int4
);
...
@@ -171,6 +227,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -171,6 +227,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// 2. The last warp for reading `topk_idx` and count for per-expert information
// 2. The last warp for reading `topk_idx` and count for per-expert information
if
(
warp_id
<
num_warps
)
{
if
(
warp_id
<
num_warps
)
{
constexpr
int
kNumElemsPerRead
=
sizeof
(
int4
)
/
sizeof
(
hip_bfloat16
);
constexpr
int
kNumElemsPerRead
=
sizeof
(
int4
)
/
sizeof
(
hip_bfloat16
);
constexpr
int
kNumThreadPerGroup
=
QUANTIZATION_GROUPSIZE
/
kNumElemsPerRead
;
// EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0);
// EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0);
EP_STATIC_ASSERT
(
kNumElemsPerRead
*
kWarpSize
%
kNumPerChannels
==
0
,
"Invalid vectorization"
);
EP_STATIC_ASSERT
(
kNumElemsPerRead
*
kWarpSize
%
kNumPerChannels
==
0
,
"Invalid vectorization"
);
const
auto
num_threads
=
(
num_warps
-
1
)
*
kWarpSize
;
const
auto
num_threads
=
(
num_warps
-
1
)
*
kWarpSize
;
...
@@ -186,10 +243,11 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -186,10 +243,11 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
auto
dst_expert_idx
=
warp_id
<
num_topk
?
static_cast
<
int
>
(
__ldg
(
topk_idx
+
token_idx
*
num_topk
+
warp_id
))
:
-
1
;
auto
dst_expert_idx
=
warp_id
<
num_topk
?
static_cast
<
int
>
(
__ldg
(
topk_idx
+
token_idx
*
num_topk
+
warp_id
))
:
-
1
;
thread_id
==
0
?
(
*
rdma_x_src_idx
=
token_idx
)
:
0
;
thread_id
==
0
?
(
*
rdma_x_src_idx
=
token_idx
)
:
0
;
__shared__
float
int8_amaxf
[
kNumScales
];
// 用于记录per-channel量化的amax
if
constexpr
(
kUseInt8
)
{
__shared__
float
channel_amaxf
[
kNumScales
];
if
constexpr
(
kUseQuant8Bit
&&
kQuantGroupSize
==
0
)
{
if
(
thread_id
<
kNumScales
)
{
if
(
thread_id
<
kNumScales
)
{
int8
_amaxf
[
thread_id
]
=
kFP8Margin
;
channel
_amaxf
[
thread_id
]
=
0.0
;
}
}
__syncthreads
();
__syncthreads
();
}
}
...
@@ -200,11 +258,11 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -200,11 +258,11 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// Read
// Read
auto
int4_value
=
__ldg
(
x_int4
+
i
);
auto
int4_value
=
__ldg
(
x_int4
+
i
);
if
constexpr
(
kUse
FP8
)
{
if
constexpr
(
kUse
Quant8Bit
)
{
// Calculate local amax
// Calculate local amax
auto
bf16_values
=
reinterpret_cast
<
hip_bfloat16
*>
(
&
int4_value
);
auto
bf16_values
=
reinterpret_cast
<
hip_bfloat16
*>
(
&
int4_value
);
float
fp32_values
[
kNumElemsPerRead
];
float
fp32_values
[
kNumElemsPerRead
];
float
amax
=
kFP8Margin
,
scale
,
scale_inv
;
float
amax
=
0.0
,
scale
,
scale_inv
;
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
kNumElemsPerRead
;
++
j
)
{
for
(
int
j
=
0
;
j
<
kNumElemsPerRead
;
++
j
)
{
fp32_values
[
j
]
=
static_cast
<
float
>
(
bf16_values
[
j
]);
fp32_values
[
j
]
=
static_cast
<
float
>
(
bf16_values
[
j
]);
...
@@ -212,25 +270,20 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -212,25 +270,20 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
}
}
// Reduce amax and scale
// Reduce amax and scale
EP_STATIC_ASSERT
(
kNumElemsPerRead
*
kWarpSize
/
kNumPerChannels
==
4
,
"Invalid vectorization"
);
EP_STATIC_ASSERT
(
kNumElemsPerRead
*
kWarpSize
/
kNumPerChannels
==
4
,
"Invalid vectorization"
);
amax
=
warp_reduce_max
<
16
>
(
amax
);
amax
=
warp_reduce_max
<
kNumThreadPerGroup
>
(
amax
);
const
int
scale_offset
=
i
*
kNumElemsPerRead
/
FP8_
QUANTIZATION_
NUM_PER_CHANNEL
;
const
int
scale_offset
=
i
*
kNumElemsPerRead
/
QUANTIZATION_
GROUPSIZE
;
if
constexpr
(
k
UseInt8
)
{
if
constexpr
(
k
QuantGroupSize
==
0
)
{
// 记录每128个数的最大值
// 记录每128个数的最大值
int8
_amaxf
[
scale_offset
]
=
fmaxf
(
amax
,
int8
_amaxf
[
scale_offset
]);
channel
_amaxf
[
scale_offset
]
=
fmaxf
(
amax
,
channel
_amaxf
[
scale_offset
]);
}
else
{
}
else
{
calculate_
fp8_scales
(
amax
,
scale
,
scale_inv
,
round_scale
);
calculate_
quant8bit_scales
<
kQuantType
>
(
amax
,
scale
,
scale_inv
,
fp8_
round_scale
);
if
(
lane_id
%
16
==
0
)
if
(
lane_id
%
kNumThreadPerGroup
==
0
)
rdma_x_scales
[
scale_offset
]
=
scale_inv
;
rdma_x_scales
[
scale_offset
]
=
scale_inv
;
// Cast into send buffer
// Cast into send buffer
vec_t
int2_value
;
vec_t
int2_value
;
auto
fp8x2_values
=
reinterpret_cast
<
__hip_fp8x2_storage_t
*>
(
&
int2_value
);
pack_quantized_values
<
kQuantType
,
kNumElemsPerRead
>
(
fp32_values
,
scale
,
int2_value
);
#pragma unroll
for
(
int
j
=
0
;
j
<
kNumElemsPerRead
;
j
+=
2
)
{
float2
fp32x2
=
{
fp32_values
[
j
]
*
scale
,
fp32_values
[
j
+
1
]
*
scale
};
fp8x2_values
[
j
/
2
]
=
__hip_cvt_float2_to_fp8x2
(
fp32x2
,
__HIP_SATFINITE
,
__HIP_E4M3_FNUZ
);
}
rdma_x_vec
[
i
]
=
int2_value
;
rdma_x_vec
[
i
]
=
int2_value
;
}
}
}
else
{
}
else
{
...
@@ -240,24 +293,24 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -240,24 +293,24 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
}
}
__syncthreads
();
__syncthreads
();
if
constexpr
(
kUse
Int8
)
{
if
constexpr
(
kUse
Quant8Bit
&&
kQuantGroupSize
==
0
)
{
float
amax_per_token
=
kFP8Margin
;
float
amax_per_token
=
0.0
;
// 并行规约,计算每个token的amax
// 并行规约,计算每个token的amax
for
(
int
s
=
0
;
s
<
kNumScales
;
s
+=
kWarpSize
)
{
for
(
int
s
=
0
;
s
<
kNumScales
;
s
+=
kWarpSize
)
{
int
src_idx
=
s
+
lane_id
;
int
src_idx
=
s
+
lane_id
;
float
tmp_amaxf
=
0
;
float
tmp_amaxf
=
0
;
if
(
src_idx
<
kNumScales
)
{
if
(
src_idx
<
kNumScales
)
{
tmp_amaxf
=
int8
_amaxf
[
src_idx
];
tmp_amaxf
=
channel
_amaxf
[
src_idx
];
}
}
tmp_amaxf
=
warp_reduce_max
<
kWarpSize
>
(
tmp_amaxf
);
tmp_amaxf
=
warp_reduce_max
<
kWarpSize
>
(
tmp_amaxf
);
int8
_amaxf
[
0
]
=
fmaxf
(
tmp_amaxf
,
int8
_amaxf
[
0
]);
channel
_amaxf
[
0
]
=
fmaxf
(
tmp_amaxf
,
channel
_amaxf
[
0
]);
__syncthreads
();
__syncthreads
();
}
}
amax_per_token
=
int8
_amaxf
[
0
];
amax_per_token
=
channel
_amaxf
[
0
];
// 根据最大值计算scale
// 根据最大值计算scale
float
scale
,
scale_inv
;
float
scale
,
scale_inv
;
calculate_
int8_scales
(
amax_per_token
,
scale
,
scale_inv
);
calculate_
quant8bit_scales
<
kQuantType
>
(
amax_per_token
,
scale
,
scale_inv
,
fp8_round_scale
);
if
(
thread_id
==
0
)
{
if
(
thread_id
==
0
)
{
rdma_x_scales
[
0
]
=
scale_inv
;
rdma_x_scales
[
0
]
=
scale_inv
;
}
}
...
@@ -269,13 +322,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -269,13 +322,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// Cast into send buffer
// Cast into send buffer
vec_t
int2_value
;
vec_t
int2_value
;
auto
int8_values
=
reinterpret_cast
<
int8_t
*>
(
&
int2_value
);
pack_quantized_values
<
kQuantType
,
kNumElemsPerRead
>
(
bf16_values
,
scale
,
int2_value
);
#pragma unroll
for
(
int
j
=
0
;
j
<
kNumElemsPerRead
;
++
j
)
{
auto
fp32_value
=
static_cast
<
float
>
(
bf16_values
[
j
]);
auto
fp32_value_scaled
=
fp32_value
*
scale
;
int8_values
[
j
]
=
static_cast
<
int8_t
>
(
nearbyintf
(
fp32_value_scaled
));
}
rdma_x_vec
[
i
]
=
int2_value
;
rdma_x_vec
[
i
]
=
int2_value
;
}
}
__syncthreads
();
__syncthreads
();
...
@@ -297,8 +344,8 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -297,8 +344,8 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
uint64_t
p2p_ptr
=
internode
::
shmem_get_p2p_ptr
((
void
*
)
dst_ptr
,
rank
,
dst_rank
);
uint64_t
p2p_ptr
=
internode
::
shmem_get_p2p_ptr
((
void
*
)
dst_ptr
,
rank
,
dst_rank
);
if
(
p2p_ptr
==
0
)
{
// RDMA
if
(
p2p_ptr
==
0
)
{
// RDMA
internode_ll_putmem_nbi
((
void
*
)
dst_ptr
,
(
void
*
)
src_ptr
,
internode_ll_putmem_nbi
((
void
*
)
dst_ptr
,
(
void
*
)
src_ptr
,
num_ranks
,
dst_rank
,
dst_expert_local_idx
,
num_ranks
,
dst_rank
,
dst_expert_local_idx
,
num_bytes_per_msg
);
num_bytes_per_msg
);
}
else
{
// 本地 GPU 和 同一计算节点的 其他 GPU 地址
}
else
{
// 本地 GPU 和 同一计算节点的 其他 GPU 地址
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
src_ptr
);
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
src_ptr
);
...
@@ -392,11 +439,10 @@ LOW_LATENCY_DISPATCH_RECV:
...
@@ -392,11 +439,10 @@ LOW_LATENCY_DISPATCH_RECV:
}
}
// 16 is the max possible number of warps in AMD GPUs
// 16 is the max possible number of warps in AMD GPUs
constexpr
int
kMaxNumWarps
=
1024
/
kWarpSize
;
constexpr
int
num_sync_large_iteration
=
kMaxNumWarps
;
constexpr
int
num_sync_large_iteration
=
kMaxNumWarps
;
__shared__
volatile
int
sync_large_warp_counters
[
num_sync_large_iteration
];
__shared__
volatile
int
sync_large_warp_counters
[
num_sync_large_iteration
];
#pragma unroll
#pragma unroll
for
(
int
i
=
thread_id
;
i
<
num_sync_large_iteration
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
thread_id
;
i
<
num_sync_large_iteration
;
i
+=
blockDim
.
x
)
{
sync_large_warp_counters
[
i
]
=
0
;
sync_large_warp_counters
[
i
]
=
0
;
}
}
...
@@ -416,7 +462,7 @@ LOW_LATENCY_DISPATCH_RECV:
...
@@ -416,7 +462,7 @@ LOW_LATENCY_DISPATCH_RECV:
const
auto
num_aligned_scales
=
ALIGN
<
int
>
(
kNumScales
,
sizeof
(
float
)
/
sizeof
(
scale_t
));
const
auto
num_aligned_scales
=
ALIGN
<
int
>
(
kNumScales
,
sizeof
(
float
)
/
sizeof
(
scale_t
));
const
auto
recv_x_scales
=
static_cast
<
scale_t
*>
(
packed_recv_x_scales
)
+
const
auto
recv_x_scales
=
static_cast
<
scale_t
*>
(
packed_recv_x_scales
)
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
(
k
UseInt8
?
1
:
num_aligned_scales
);
(
k
QuantGroupSize
==
0
?
1
:
num_aligned_scales
);
// Shared between sub-warps in warp groups
// Shared between sub-warps in warp groups
__shared__
int
shared_num_recv_tokens
[
kNumMaxWarpGroups
],
shared_recv_token_begin_idx
[
kNumMaxWarpGroups
];
__shared__
int
shared_num_recv_tokens
[
kNumMaxWarpGroups
],
shared_recv_token_begin_idx
[
kNumMaxWarpGroups
];
...
@@ -461,14 +507,14 @@ LOW_LATENCY_DISPATCH_RECV:
...
@@ -461,14 +507,14 @@ LOW_LATENCY_DISPATCH_RECV:
UNROLLED_WARP_COPY_LL
(
7
,
lane_id
,
hidden_int4
,
dst_data
,
src_data
,
ld_nc_global
,
st_na_global
);
UNROLLED_WARP_COPY_LL
(
7
,
lane_id
,
hidden_int4
,
dst_data
,
src_data
,
ld_nc_global
,
st_na_global
);
// Copy scales
// Copy scales
if
constexpr
(
kUse
FP8
)
{
if
constexpr
(
kUse
Quant8Bit
)
{
const
auto
src_scales
=
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
uint8_t
*>
(
src_data
)
+
hidden_bytes
);
const
auto
src_scales
=
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
uint8_t
*>
(
src_data
)
+
hidden_bytes
);
const
auto
num_elems_per_pack
=
static_cast
<
int
>
(
sizeof
(
packed_t
)
/
sizeof
(
scale_t
));
const
auto
num_elems_per_pack
=
static_cast
<
int
>
(
sizeof
(
packed_t
)
/
sizeof
(
scale_t
));
const
auto
token_idx
=
recv_token_begin_idx
+
i
;
const
auto
token_idx
=
recv_token_begin_idx
+
i
;
const
auto
token_stride
=
num_elems_per_pack
;
const
auto
token_stride
=
num_elems_per_pack
;
const
auto
pack_stride
=
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_elems_per_pack
;
const
auto
pack_stride
=
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_elems_per_pack
;
if
constexpr
(
k
UseInt8
)
{
if
constexpr
(
k
QuantGroupSize
==
0
)
{
if
(
lane_id
==
0
)
{
if
(
lane_id
==
0
)
{
recv_x_scales
[
token_idx
]
=
ld_nc_global
(
src_scales
);
recv_x_scales
[
token_idx
]
=
ld_nc_global
(
src_scales
);
}
}
...
@@ -500,12 +546,13 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -500,12 +546,13 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
use_int8
,
int
quant_type
,
int
quant_group_size
,
bool
fp8_round_scale
,
void
*
workspace
,
int
num_device_sms
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
)
{
hipStream_t
stream
,
int
phases
)
{
constexpr
int
kMaxNumWarps
=
16
;
constexpr
int
kNumMaxTopK
=
11
;
constexpr
int
kNumMaxTopK
=
11
;
const
int
num_warp_groups
=
ceil_div
(
num_experts
,
num_device_sms
);
const
int
num_warp_groups
=
ceil_div
(
num_experts
,
num_device_sms
);
const
int
num_warps_per_group
=
16
/
num_warp_groups
;
const
int
num_warps_per_group
=
kMaxNumWarps
/
num_warp_groups
;
EP_HOST_ASSERT
(
num_warp_groups
>
0
and
num_warps_per_group
>
0
);
EP_HOST_ASSERT
(
num_warp_groups
>
0
and
num_warps_per_group
>
0
);
EP_HOST_ASSERT
(
kNumMaxTopK
+
1
<=
num_warp_groups
*
num_warps_per_group
);
EP_HOST_ASSERT
(
kNumMaxTopK
+
1
<=
num_warp_groups
*
num_warps_per_group
);
...
@@ -518,33 +565,54 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -518,33 +565,54 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
auto
atomic_finish_counter_per_expert
=
atomic_counter_per_expert
+
num_experts
;
auto
atomic_finish_counter_per_expert
=
atomic_counter_per_expert
+
num_experts
;
EP_HOST_ASSERT
(
num_experts
*
sizeof
(
int
)
*
2
<=
NUM_WORKSPACE_BYTES
);
EP_HOST_ASSERT
(
num_experts
*
sizeof
(
int
)
*
2
<=
NUM_WORKSPACE_BYTES
);
#define DISPATCH_LAUNCH_CASE(hidden) { \
// 限制groupsize的大小
auto dispatch_func = dispatch<false, false, false, hidden>; \
EP_HOST_ASSERT
(
quant_group_size
==
0
||
quant_group_size
==
128
);
if (use_fp8 and not use_ue8m0) \
dispatch_func = dispatch<true, false, false, hidden>; \
/*量化类型枚举
if (use_fp8 and use_ue8m0) \
0 -> None 不量化,保持原始精度
dispatch_func = dispatch<true, true, false, hidden>; \
1 -> Int8 使用 INT8 对称量化
if (use_int8) \
2 -> FP8_E4M3 使用 FP8 E4M3 格式 (__HIP_E4M3)
dispatch_func = dispatch<true, false, true, hidden>; \
3 -> FP8_UE8M0 使用 DeepSeekV3.1 提出的 UE8M0 格式 (仅支持round_scale=True)
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, dispatch_func, \
4 -> FP8_E5M2 使用 FP8 E5M2 格式 (__HIP_E5M2)
packed_recv_x, packed_recv_x_scales, \
*/
packed_recv_src_info, packed_recv_layout_range, \
packed_recv_count, \
#define DISPATCH_LAUNCH_CASE(hidden) \
global_atomic_counter, \
{ \
rdma_recv_x, rdma_recv_count, rdma_x, \
auto dispatch_func = dispatch<hidden, 0, 0, kMaxNumWarps>; \
x, topk_idx, \
if (quant_group_size == 0) { \
atomic_counter_per_expert, atomic_finish_counter_per_expert, \
switch (quant_type) { \
next_clean, num_next_clean_int, \
case 1: dispatch_func = dispatch<hidden, 1, 0, kMaxNumWarps>; break; \
num_tokens, num_max_dispatch_tokens_per_rank, \
case 2: dispatch_func = dispatch<hidden, 2, 0, kMaxNumWarps>; break; \
num_topk, num_experts, rank, num_ranks, \
case 3: dispatch_func = dispatch<hidden, 3, 0, kMaxNumWarps>; break; \
num_warp_groups, num_warps_per_group, round_scale, phases); } break
case 4: dispatch_func = dispatch<hidden, 4, 0, kMaxNumWarps>; break; \
} \
} else { \
switch (quant_type) { \
case 1: dispatch_func = dispatch<hidden, 1, 128, kMaxNumWarps>; break; \
case 2: dispatch_func = dispatch<hidden, 2, 128, kMaxNumWarps>; break; \
case 3: dispatch_func = dispatch<hidden, 3, 128, kMaxNumWarps>; break; \
case 4: dispatch_func = dispatch<hidden, 4, 128, kMaxNumWarps>; break; \
} \
} \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, dispatch_func, \
packed_recv_x, packed_recv_x_scales, \
packed_recv_src_info, packed_recv_layout_range, packed_recv_count, \
global_atomic_counter, \
rdma_recv_x, rdma_recv_count, rdma_x, x, topk_idx, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, \
next_clean, num_next_clean_int, \
num_tokens, num_max_dispatch_tokens_per_rank, \
num_topk, num_experts, rank, num_ranks, \
num_warp_groups, num_warps_per_group, fp8_round_scale, phases); \
} \
break
SETUP_LAUNCH_CONFIG
(
num_sms
,
num_warps
*
kWarpSize
,
stream
);
SETUP_LAUNCH_CONFIG
(
num_sms
,
num_warps
*
kWarpSize
,
stream
);
SWITCH_HIDDEN
(
DISPATCH_LAUNCH_CASE
);
SWITCH_HIDDEN
(
DISPATCH_LAUNCH_CASE
);
#undef DISPATCH_LAUNCH_CASE
#undef DISPATCH_LAUNCH_CASE
}
}
template
<
int
kHidden
,
int
kNumMaxTopk
>
template
<
int
kHidden
,
int
kNumMaxTopk
,
int
kMaxNumWarps
=
16
>
__global__
__launch_bounds__
(
16
*
kWarpSize
,
1
)
void
__global__
__launch_bounds__
(
16
*
kWarpSize
,
1
)
void
combine
(
void
*
combined_x
,
combine
(
void
*
combined_x
,
void
*
rdma_recv_x
,
int64_t
*
rdma_recv_flag
,
void
*
rdma_send_x
,
void
*
rdma_recv_x
,
int64_t
*
rdma_recv_flag
,
void
*
rdma_send_x
,
...
@@ -574,12 +642,11 @@ combine(void* combined_x,
...
@@ -574,12 +642,11 @@ combine(void* combined_x,
const
size_t
hidden_bf16_int4
=
kHidden
/
kNumElemsPerInt4
;
const
size_t
hidden_bf16_int4
=
kHidden
/
kNumElemsPerInt4
;
// Message package
// Message package
EP_STATIC_ASSERT
(
kHidden
%
FP8_
QUANTIZATION_
NUM_PER_CHANNEL
==
0
,
"Invalid hidden"
);
EP_STATIC_ASSERT
(
kHidden
%
QUANTIZATION_
GROUPSIZE
==
0
,
"Invalid hidden"
);
constexpr
size_t
num_bytes_per_slot
=
kHidden
*
sizeof
(
hip_bfloat16
);
constexpr
size_t
num_bytes_per_slot
=
kHidden
*
sizeof
(
hip_bfloat16
);
EP_STATIC_ASSERT
(
num_bytes_per_slot
%
sizeof
(
int4
)
==
0
,
"Invalid vectorization"
);
EP_STATIC_ASSERT
(
num_bytes_per_slot
%
sizeof
(
int4
)
==
0
,
"Invalid vectorization"
);
// 16 is the max possible number of warps in AMD GPUs
// 初始化用于细粒度warp间同步的计数器数组
constexpr
int
kMaxNumWarps
=
1024
/
kWarpSize
;
__shared__
volatile
int
sync_large_warp_counters
[
kMaxNumWarps
];
__shared__
volatile
int
sync_large_warp_counters
[
kMaxNumWarps
];
if
(
threadIdx
.
x
==
0
){
if
(
threadIdx
.
x
==
0
){
#pragma unroll
#pragma unroll
...
@@ -755,9 +822,10 @@ void combine(void* combined_x,
...
@@ -755,9 +822,10 @@ void combine(void* combined_x,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
,
bool
zero_copy
)
{
int
phases
,
bool
zero_copy
)
{
constexpr
int
kMaxNumWarps
=
16
;
constexpr
int
kNumMaxTopk
=
11
;
constexpr
int
kNumMaxTopk
=
11
;
const
int
num_warp_groups
=
ceil_div
(
num_experts
,
num_device_sms
);
const
int
num_warp_groups
=
ceil_div
(
num_experts
,
num_device_sms
);
const
int
num_warps_per_group
=
16
/
num_warp_groups
;
// num_warps_per_group>1, "Requires more than one warp per group"
const
int
num_warps_per_group
=
kMaxNumWarps
/
num_warp_groups
;
// num_warps_per_group>1, "Requires more than one warp per group"
const
int
num_recv_per_sm
=
ceil_div
(
num_combined_tokens
,
num_device_sms
);
const
int
num_recv_per_sm
=
ceil_div
(
num_combined_tokens
,
num_device_sms
);
EP_HOST_ASSERT
(
num_warp_groups
>
0
and
num_warps_per_group
>
0
and
num_recv_per_sm
>=
0
);
EP_HOST_ASSERT
(
num_warp_groups
>
0
and
num_warps_per_group
>
0
and
num_recv_per_sm
>=
0
);
...
@@ -770,20 +838,20 @@ void combine(void* combined_x,
...
@@ -770,20 +838,20 @@ void combine(void* combined_x,
EP_HOST_ASSERT
(
sizeof
(
int
)
<=
NUM_WORKSPACE_BYTES
);
EP_HOST_ASSERT
(
sizeof
(
int
)
<=
NUM_WORKSPACE_BYTES
);
EP_HOST_ASSERT
(
num_topk
<=
kNumMaxTopk
);
EP_HOST_ASSERT
(
num_topk
<=
kNumMaxTopk
);
#define COMBINE_LAUNCH_CASE(hidden)
{
\
#define COMBINE_LAUNCH_CASE(hidden)
\
auto combine_func = combine<hidden, kNumMaxTopk>;
\
{
\
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func,
\
auto combine_func = combine<hidden, kNumMaxTopk, kMaxNumWarps>;
\
combined_x,
\
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func,
\
rdma_recv_x, rdma_recv_flag, rdma_send_x, \
combined_x,
rdma_recv_x, rdma_recv_flag, rdma_send_x,
\
x, topk_idx, topk_weights, src_info, layout_range, \
x, topk_idx, topk_weights, src_info, layout_range,
\
global_atomic_counter, \
global_atomic_counter,
combine_wait_recv_cost_stats,
\
combine_wait_recv_cost_stats,
\
next_clean, num_next_clean_int,
\
next
_clean, num_
next_clean_int,
\
atomic
_clean
_flag
, num_
combined_tokens, hidden,
\
atomic_clean_flag,
\
num_topk, num_max_dispatch_tokens_per_rank,
\
num_combined_tokens, hidden, num_topk,
\
num_experts, rank, num_ranks,
\
num_max_dispatch_tokens_per_rank,
\
num_warp_groups, num_warps_per_group, phases, zero_copy);
\
num_experts, rank, num_ranks,
\
}
\
num_warp_groups, num_warps_per_group, phases, zero_copy); }
break
break
SETUP_LAUNCH_CONFIG
(
num_sms
,
num_warps
*
kWarpSize
,
stream
);
SETUP_LAUNCH_CONFIG
(
num_sms
,
num_warps
*
kWarpSize
,
stream
);
SWITCH_HIDDEN
(
COMBINE_LAUNCH_CASE
);
SWITCH_HIDDEN
(
COMBINE_LAUNCH_CASE
);
...
...
csrc/kernels/launch.cuh
View file @
d0fcf024
...
@@ -62,7 +62,7 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...a
...
@@ -62,7 +62,7 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...a
case 8: \
case 8: \
case_macro(8); \
case_macro(8); \
default: \
default: \
EP_HOST_ASSERT(false and "Unsupported ranks"); \
EP_HOST_ASSERT(false and "Unsupported ranks");
\
} \
} \
while (false)
while (false)
...
@@ -83,7 +83,7 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...a
...
@@ -83,7 +83,7 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...a
case 20: \
case 20: \
case_macro(20); \
case_macro(20); \
default: \
default: \
EP_HOST_ASSERT(false and "Unsupported RDMA ranks"); \
EP_HOST_ASSERT(false and "Unsupported RDMA ranks");
\
} \
} \
while (false)
while (false)
...
@@ -96,7 +96,7 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...a
...
@@ -96,7 +96,7 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...a
case 8: \
case 8: \
case_macro(dtype, 8); \
case_macro(dtype, 8); \
default: \
default: \
EP_HOST_ASSERT(false and "Unsupported ranks"); \
EP_HOST_ASSERT(false and "Unsupported ranks");
\
} \
} \
while (false)
while (false)
...
@@ -107,7 +107,7 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...a
...
@@ -107,7 +107,7 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...a
case HIP_R_32F: \
case HIP_R_32F: \
case_macro(float); \
case_macro(float); \
default: \
default: \
EP_HOST_ASSERT(false and "Unsupported type"); \
EP_HOST_ASSERT(false and "Unsupported type");
\
} \
} \
while (false)
while (false)
...
@@ -121,7 +121,9 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...a
...
@@ -121,7 +121,9 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...a
case_macro(4096); \
case_macro(4096); \
case 7168: \
case 7168: \
case_macro(7168); \
case_macro(7168); \
case 8192: \
case_macro(8192); \
default: \
default: \
EP_HOST_ASSERT(false and "Unsupported hidden"); \
EP_HOST_ASSERT(false and "Unsupported hidden");
\
} \
} \
while (false)
while (false)
csrc/kernels/utils.cuh
View file @
d0fcf024
...
@@ -341,10 +341,13 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
...
@@ -341,10 +341,13 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
return
*
reinterpret_cast
<
dtype_t
*>
(
recv_int_values
);
return
*
reinterpret_cast
<
dtype_t
*>
(
recv_int_values
);
}
}
constexpr
float
kFP8Margin
=
1e-4
;
// 设置不同的量化方式的最大值与相反数
constexpr
float
kFinfoAmaxE4M3
=
240
.0
f
;
constexpr
float
kFinfoAmaxE4M3
=
448
.0
f
;
constexpr
float
kFinfoAmaxInvE4M3
=
1.0
f
/
kFinfoAmaxE4M3
;
constexpr
float
kFinfoAmaxInvE4M3
=
1.0
f
/
kFinfoAmaxE4M3
;
constexpr
float
kInt8Amax
=
127.0
f
;
constexpr
float
kFinfoAmaxE5M2
=
57344.0
f
;
constexpr
float
kFinfoAmaxInvE5M2
=
1.0
f
/
kFinfoAmaxE5M2
;
constexpr
float
kFinfoAmaxInt8
=
127.0
f
;
constexpr
float
kFinfoAmaxInvInt8
=
1.0
f
/
127.0
f
;
__forceinline__
__device__
float
fast_pow2
(
int
x
)
{
__forceinline__
__device__
float
fast_pow2
(
int
x
)
{
// We can ensure `-126 <= x and x <= 127`
// We can ensure `-126 <= x and x <= 127`
...
@@ -359,22 +362,33 @@ __forceinline__ __device__ int fast_log2_ceil(float x) {
...
@@ -359,22 +362,33 @@ __forceinline__ __device__ int fast_log2_ceil(float x) {
return
exp_x
-
127
+
(
man_bits
!=
0
);
return
exp_x
-
127
+
(
man_bits
!=
0
);
}
}
__forceinline__
__device__
void
calculate_fp8_scales
(
float
amax
,
float
&
scale
,
float
&
scale_inv
,
bool
round_scale
)
{
template
<
int
kQuantType
>
if
(
round_scale
)
{
__forceinline__
__device__
void
calculate_quant8bit_scales
(
float
amax
,
float
&
scale
,
float
&
scale_inv
,
bool
round_scale
=
0
)
{
auto
exp_scale_inv
=
fast_log2_ceil
(
amax
*
kFinfoAmaxInvE4M3
);
amax
=
fmaxf
(
amax
,
1e-6
f
);
scale
=
fast_pow2
(
-
exp_scale_inv
);
if
constexpr
(
kQuantType
==
1
)
{
// 使用 INT8 对称量化
scale_inv
=
fast_pow2
(
exp_scale_inv
);
scale_inv
=
kFinfoAmaxInvInt8
*
amax
;
}
else
{
scale
=
kFinfoAmaxInt8
/
amax
;
scale_inv
=
amax
*
kFinfoAmaxInvE4M3
;
}
else
if
constexpr
(
kQuantType
==
2
||
kQuantType
==
3
)
{
// 使用 FP8_E4M3 或 FP8_UE8M0 非对称量化
scale
=
kFinfoAmaxE4M3
/
amax
;
if
(
round_scale
)
{
auto
exp_scale_inv
=
fast_log2_ceil
(
amax
*
kFinfoAmaxInvE4M3
);
scale
=
fast_pow2
(
-
exp_scale_inv
);
scale_inv
=
fast_pow2
(
exp_scale_inv
);
}
else
{
scale_inv
=
amax
*
kFinfoAmaxInvE4M3
;
scale
=
kFinfoAmaxE4M3
/
amax
;
}
}
else
if
constexpr
(
kQuantType
==
4
)
{
// 使用 FP8_E5M2 对称量化
if
(
round_scale
)
{
auto
exp_scale_inv
=
fast_log2_ceil
(
amax
*
kFinfoAmaxInvE5M2
);
scale
=
fast_pow2
(
-
exp_scale_inv
);
scale_inv
=
fast_pow2
(
exp_scale_inv
);
}
else
{
scale_inv
=
amax
*
kFinfoAmaxInvE5M2
;
scale
=
kFinfoAmaxE5M2
/
amax
;
}
}
}
}
}
__forceinline__
__device__
void
calculate_int8_scales
(
float
amax
,
float
&
scale
,
float
&
scale_inv
)
{
scale
=
kInt8Amax
/
amax
;
scale_inv
=
amax
/
kInt8Amax
;
}
template
<
bool
kIsUE8M0
,
typename
out_dtype_t
=
std
::
conditional_t
<
kIsUE8M0
,
uint8_t
,
float
>
>
template
<
bool
kIsUE8M0
,
typename
out_dtype_t
=
std
::
conditional_t
<
kIsUE8M0
,
uint8_t
,
float
>
>
__forceinline__
__device__
out_dtype_t
extract_required_scale_format
(
float
value
)
{
__forceinline__
__device__
out_dtype_t
extract_required_scale_format
(
float
value
)
{
if
constexpr
(
kIsUE8M0
)
{
if
constexpr
(
kIsUE8M0
)
{
...
...
deep_ep/buffer.py
View file @
d0fcf024
...
@@ -841,7 +841,7 @@ class Buffer:
...
@@ -841,7 +841,7 @@ class Buffer:
# noinspection PyTypeChecker
# noinspection PyTypeChecker
def
low_latency_dispatch
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
def
low_latency_dispatch
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
num_max_dispatch_tokens_per_rank
:
int
,
num_experts
:
int
,
num_max_dispatch_tokens_per_rank
:
int
,
num_experts
:
int
,
use_fp8
:
bool
=
True
,
round_scale
:
bool
=
False
,
use_ue8m0
:
bool
=
False
,
use_int8
:
bool
=
False
,
quant_type
:
int
=
1
,
quant_group_size
:
int
=
0
,
fp8_round_scale
:
bool
=
False
,
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
)
->
\
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
)
->
\
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
torch
.
Tensor
,
Tuple
,
EventOverlap
,
Callable
]:
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
torch
.
Tensor
,
Tuple
,
EventOverlap
,
Callable
]:
"""
"""
...
@@ -858,10 +858,20 @@ class Buffer:
...
@@ -858,10 +858,20 @@ class Buffer:
only several top-k shapes are supported. `-1` indices (not selecting any expert) are supported.
only several top-k shapes are supported. `-1` indices (not selecting any expert) are supported.
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
num_experts: the number of all experts.
num_experts: the number of all experts.
use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
量化配置
round_scale: whether round the scaling factors into power of 2.
quant_type: int 量化类型枚举
use_ue8m0: whether use UE8M0 as scaling factor format (available only with `round_scale=True`).
0 -> None 不量化,保持原始精度
use_int8: whether to enable INT8 casting.
1 -> Int8 使用 INT8 对称量化
2 -> FP8_E4M3 使用 FP8 E4M3 格式 (__HIP_E4M3_FNUZ)
3 -> FP8_UE8M0 使用 DeepSeekV3.1 提出的 UE8M0 格式 (仅支持round_scale=True)
4 -> FP8_E5M2 使用 FP8 E5M2 格式 (__HIP_E5M2_FNUZ)
quant_group_size: int 量化分组大小
0 -> 逐token量化 (per-channel)
128-> 每 128 元素一组 (per-group) 量化
fp8_round_scale: bool 是否将 FP8 缩放因子取整为 2 的幂
true -> 缩放因子 = 2^k,硬件零开销
false -> 缩放因子 = 任意浮点,精度更高
异步配置
async_finish: the current stream will not wait for the communication kernels to be finished if set.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
...
@@ -869,15 +879,25 @@ class Buffer:
...
@@ -869,15 +879,25 @@ class Buffer:
Returns:
Returns:
recv_x: a tensor or tuple with received tokens for each expert.
recv_x: a tensor or tuple with received tokens for each expert.
With `use_fp8=True`: the first element is a `torch.Tensor` shaped as
- packed_recv_x:
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`.
存储接收到的 Token 数据,形状为
The second tensor is the corresponding scales for the first element with shape
`[num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden]`。
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`,
数据类型取决于 quant_type:
if `use_ue8m0=False`. With `use_ue8m0=True`, the second one is packed and shaped as
quant_type == 1 -> torch.int8
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 512]` with type `torch.int`.
quant_type == 2 -> torch.float8_e4m3fnuz
Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility.
quant_type == 3 -> torch.float8_e4m3fnuz (UE8M0 使用 E4M3 格式存储)
With `use_fp8=False`, the result would be a tensor shaped as
quant_type == 4 -> torch.float8_e5m2fnuz
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`.
其他 (非量化) -> torch.bfloat16
- packed_recv_x_scales (可选):
仅在 quant_type > 0 时存在,存储量化的 Scale 值。
形状为 `[num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, scales_col_size]`。
- 当 quant_type == 3 (UE8M0) 时:
scales_col_size = hidden // 512
数据类型为 torch.int (内部打包存储 4-bit scale)。
*注意:此模式强制要求 fp8_round_scale=True 且 group_size=128。
- 当 quant_type == 1, 2, 4 时:
scales_col_size = hidden // 128 (若使用 group_size) 或 1 (per-channel)。
数据类型为 torch.float32。
Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
as we do not synchronize CPU received count with GPU (also not incompatible with CUDA graph if synced).
as we do not synchronize CPU received count with GPU (also not incompatible with CUDA graph if synced).
recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
...
@@ -889,14 +909,15 @@ class Buffer:
...
@@ -889,14 +909,15 @@ class Buffer:
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_src_info
,
packed_recv_layout_range
,
event
,
hook
=
\
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_src_info
,
packed_recv_layout_range
,
event
,
hook
=
\
self
.
runtime
.
low_latency_dispatch
(
x
,
topk_idx
,
self
.
runtime
.
low_latency_dispatch
(
x
,
topk_idx
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
use_fp8
,
round_scale
,
use_ue8m0
,
use_int8
,
quant_type
,
quant_group_size
,
fp8_round_scale
,
async_finish
,
return_recv_hook
)
async_finish
,
return_recv_hook
)
handle
=
(
packed_recv_src_info
,
packed_recv_layout_range
,
num_max_dispatch_tokens_per_rank
,
x
.
size
(
1
),
num_experts
)
handle
=
(
packed_recv_src_info
,
packed_recv_layout_range
,
num_max_dispatch_tokens_per_rank
,
x
.
size
(
1
),
num_experts
)
tensors_to_record
=
(
x
,
topk_idx
,
tensors_to_record
=
(
x
,
topk_idx
,
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_src_info
,
packed_recv_layout_range
)
packed_recv_src_info
,
packed_recv_layout_range
)
return
(
packed_recv_x
,
packed_recv_x_scales
)
if
use_fp8
else
packed_recv_x
,
packed_recv_count
,
handle
,
\
EventOverlap
(
event
,
tensors_to_record
if
async_finish
else
None
),
hook
recv_x
=
(
packed_recv_x
,
packed_recv_x_scales
)
if
(
quant_type
>
0
)
else
packed_recv_x
return
recv_x
,
packed_recv_count
,
handle
,
EventOverlap
(
event
,
tensors_to_record
if
async_finish
else
None
),
hook
# noinspection PyTypeChecker
# noinspection PyTypeChecker
def
low_latency_combine
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
def
low_latency_combine
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
...
...
tests/test_low_latency_new.py
View file @
d0fcf024
...
@@ -6,7 +6,7 @@ from functools import partial
...
@@ -6,7 +6,7 @@ from functools import partial
from
typing
import
Literal
,
Set
from
typing
import
Literal
,
Set
import
deep_ep
import
deep_ep
from
utils
import
init_dist
,
bench
,
bench_kineto
,
calc_diff
,
hash_tensor
,
per_token_cast_back
from
utils
import
init_dist
,
bench
,
bench_kineto
,
calc_diff
,
hash_tensor
,
per_token_cast_
pg_back
,
per_token_cast_pc_
back
def
simulate_failure_and_skip
(
rank
:
int
,
api
:
Literal
[
"dispatch"
,
"combine"
,
"clean"
],
expected_masked_ranks
:
Set
[
int
]):
def
simulate_failure_and_skip
(
rank
:
int
,
api
:
Literal
[
"dispatch"
,
"combine"
,
"clean"
],
expected_masked_ranks
:
Set
[
int
]):
...
@@ -58,7 +58,8 @@ def test_main(num_tokens: int,
...
@@ -58,7 +58,8 @@ def test_main(num_tokens: int,
x_list
=
[
x
]
x_list
=
[
x
]
# # NOTES: the last one is for performance testing
# # NOTES: the last one is for performance testing
# # Most of the values in the perf case is lower than the threshold, casting most channels
# # Most of the values in the perf case is lower than the threshold, casting most channels
# x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1)
# x_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1
# x_list = [x_rand]
scores
=
torch
.
randn
((
num_tokens
,
num_experts
),
dtype
=
torch
.
float32
,
device
=
'cuda'
).
abs
()
+
1
scores
=
torch
.
randn
((
num_tokens
,
num_experts
),
dtype
=
torch
.
float32
,
device
=
'cuda'
).
abs
()
+
1
topk_idx
=
torch
.
topk
(
scores
,
num_topk
,
dim
=-
1
,
largest
=
True
,
sorted
=
True
)[
1
]
topk_idx
=
torch
.
topk
(
scores
,
num_topk
,
dim
=-
1
,
largest
=
True
,
sorted
=
True
)[
1
]
...
@@ -80,22 +81,35 @@ def test_main(num_tokens: int,
...
@@ -80,22 +81,35 @@ def test_main(num_tokens: int,
hash_value
,
num_times
=
0
,
0
hash_value
,
num_times
=
0
,
0
for
current_x
in
x_list
:
for
current_x
in
x_list
:
for
return_recv_hook
in
(
False
,
True
):
for
return_recv_hook
in
(
False
,
True
):
for
dispatch_use_fp8
in
(
False
,
True
):
for
quant_type
in
(
0
,
1
,
2
,
3
,
):
# 0: 不量化, 1: int8, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True), 4: FP8_E5M2
for
round_scale
in
(
False
,
True
)
if
dispatch_use_fp8
else
(
False
,):
dispatch_use_quant
=
quant_type
>
0
for
use_ue8m0
in
(
False
,
True
)
if
round_scale
else
(
False
,):
for
fp8_round_scale
in
(
False
,
True
)
if
quant_type
!=
3
else
(
True
,
):
for
quant_group_size
in
(
0
,
128
,)
if
quant_type
>=
2
else
(
0
,
):
if
quant_type
==
3
and
(
fp8_round_scale
==
False
or
quant_group_size
==
0
):
continue
num_times
+=
1
num_times
+=
1
for
_
in
range
((
num_times
%
2
)
+
1
):
for
_
in
range
((
num_times
%
2
)
+
1
):
packed_recv_x
,
packed_recv_count
,
handle
,
event
,
hook
=
\
packed_recv_x
,
packed_recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
use_fp8
=
dispatch_use_
fp8
,
round_scale
=
round_scale
,
use_ue8m0
=
use_ue8m0
,
quant_type
=
quant_type
,
fp8
_
round_scale
=
fp8_
round_scale
,
quant_group_size
=
quant_group_size
,
async_finish
=
not
return_recv_hook
,
return_recv_hook
=
return_recv_hook
)
async_finish
=
not
return_recv_hook
,
return_recv_hook
=
return_recv_hook
)
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
packed_recv_x
=
(
packed_recv_x
[
0
],
packed_recv_x
[
1
].
contiguous
())
if
dispatch_use_fp8
else
packed_recv_x
packed_recv_x
=
(
packed_recv_x
[
0
],
packed_recv_x
[
1
].
contiguous
())
if
dispatch_use_quant
else
packed_recv_x
simulated_gemm_x
=
per_token_cast_back
(
packed_recv_x
[
0
].
view
(
-
1
,
hidden
),
packed_recv_x
[
1
].
view
(
-
1
,
hidden
//
128
)).
view
(
packed_recv_x
[
0
].
shape
)
\
if
not
dispatch_use_quant
:
if
dispatch_use_fp8
else
packed_recv_x
.
clone
()
simulated_gemm_x
=
packed_recv_x
.
clone
()
elif
quant_group_size
==
0
:
simulated_gemm_x
=
per_token_cast_pc_back
(
packed_recv_x
[
0
].
view
(
-
1
,
hidden
),
packed_recv_x
[
1
].
reshape
(
-
1
)).
view
(
packed_recv_x
[
0
].
shape
)
elif
quant_group_size
==
128
:
simulated_gemm_x
=
per_token_cast_pg_back
(
packed_recv_x
[
0
].
view
(
-
1
,
hidden
),
packed_recv_x
[
1
].
view
(
-
1
,
hidden
//
128
)).
view
(
packed_recv_x
[
0
].
shape
)
for
i
in
range
(
num_local_experts
if
do_check
else
0
):
for
i
in
range
(
num_local_experts
if
do_check
else
0
):
expert_id
=
rank
*
num_local_experts
+
i
expert_id
=
rank
*
num_local_experts
+
i
recv_x
=
per_token_cast_back
(
packed_recv_x
[
0
][
i
],
packed_recv_x
[
1
][
i
])
if
dispatch_use_fp8
else
packed_recv_x
[
i
]
if
not
dispatch_use_quant
:
recv_x
=
packed_recv_x
[
i
]
elif
quant_group_size
==
0
:
recv_x
=
per_token_cast_pc_back
(
packed_recv_x
[
0
][
i
],
packed_recv_x
[
1
][
i
])
elif
quant_group_size
==
128
:
recv_x
=
per_token_cast_pg_back
(
packed_recv_x
[
0
][
i
],
packed_recv_x
[
1
][
i
])
recv_count
,
recv_src_info
,
recv_layout_range
=
packed_recv_count
[
i
],
handle
[
0
][
i
],
handle
[
1
][
i
]
recv_count
,
recv_src_info
,
recv_layout_range
=
packed_recv_count
[
i
],
handle
[
0
][
i
],
handle
[
1
][
i
]
# Check expert indices
# Check expert indices
...
@@ -113,18 +127,25 @@ def test_main(num_tokens: int,
...
@@ -113,18 +127,25 @@ def test_main(num_tokens: int,
if
current_x
is
x
:
if
current_x
is
x
:
recv_x
=
recv_x
[:
num_valid_tokens
]
recv_x
=
recv_x
[:
num_valid_tokens
]
recv_x_amin
=
recv_x
[:,
:
-
128
].
amin
(
dim
=-
1
)
recv_x_amin
=
recv_x
[:,
:
-
128
].
amin
(
dim
=-
1
)
recv_x_amax
=
recv_x
[:,
:
-
128
].
amax
(
dim
=-
1
)
recv_src_info
=
recv_src_info
[:
num_valid_tokens
]
recv_src_info
=
recv_src_info
[:
num_valid_tokens
]
assert
torch
.
equal
(
recv_x_amin
,
recv_x
[:,
:
-
128
].
amax
(
dim
=-
1
))
assert
torch
.
equal
(
recv_x_amin
,
recv_x_amax
)
if
round_scale
:
if
dispatch_use_quant
:
assert
calc_diff
(
recv_x
[:,
-
1
],
recv_src_info
.
view
(
-
1
))
<
0.007
assert
calc_diff
(
recv_x
[:,
-
1
],
recv_src_info
.
view
(
-
1
))
<
0.007
else
:
assert
(
recv_x
[:,
-
128
:]
-
recv_src_info
.
view
(
-
1
,
1
)
%
num_tokens
).
sum
().
item
()
==
0
assert
torch
.
equal
(
recv_x_amin
,
recv_x
[:,
:
-
128
].
amax
(
dim
=-
1
))
for
j
in
range
(
num_ranks
):
if
quant_group_size
!=
0
:
begin_idx
,
count
=
(
recv_layout_range
[
j
]
>>
32
).
item
(),
(
recv_layout_range
[
j
]
&
int_mask
).
item
()
if
fp8_round_scale
:
if
not
round_scale
:
assert
calc_diff
(
recv_x
[:,
-
1
],
recv_src_info
.
view
(
-
1
))
<
0.007
assert
(
recv_x_amin
==
j
-
rank_offset
).
sum
().
item
()
==
(
all_topk_idx
[
j
]
==
expert_id
).
sum
().
item
()
else
:
assert
(
recv_x
[
begin_idx
:
begin_idx
+
count
,
:
-
128
]
-
j
+
rank_offset
).
sum
().
item
()
==
0
assert
(
recv_x
[:,
-
128
:]
-
recv_src_info
.
view
(
-
1
,
1
)
%
num_tokens
).
sum
().
item
()
==
0
if
dispatch_use_fp8
:
for
j
in
range
(
num_ranks
):
begin_idx
,
count
=
(
recv_layout_range
[
j
]
>>
32
).
item
(),
(
recv_layout_range
[
j
]
&
int_mask
).
item
()
if
not
fp8_round_scale
:
assert
(
recv_x_amin
==
j
-
rank_offset
).
sum
().
item
()
==
(
all_topk_idx
[
j
]
==
expert_id
).
sum
().
item
()
assert
(
recv_x
[
begin_idx
:
begin_idx
+
count
,
:
-
128
]
-
j
+
rank_offset
).
sum
().
item
()
==
0
if
dispatch_use_quant
:
hash_value
^=
hash_tensor
(
packed_recv_x
[
0
][
i
,
:
num_valid_tokens
])
hash_value
^=
hash_tensor
(
packed_recv_x
[
0
][
i
,
:
num_valid_tokens
])
hash_value
^=
hash_tensor
(
packed_recv_x
[
1
][
i
,
:
num_valid_tokens
])
hash_value
^=
hash_tensor
(
packed_recv_x
[
1
][
i
,
:
num_valid_tokens
])
else
:
else
:
...
@@ -147,8 +168,8 @@ def test_main(num_tokens: int,
...
@@ -147,8 +168,8 @@ def test_main(num_tokens: int,
if
do_check
:
if
do_check
:
diff
=
calc_diff
(
current_x
*
topk_weights
.
masked_fill
(
topk_idx
==
-
1
,
0
).
sum
(
dim
=
1
).
view
(
-
1
,
1
),
combined_x
)
diff
=
calc_diff
(
current_x
*
topk_weights
.
masked_fill
(
topk_idx
==
-
1
,
0
).
sum
(
dim
=
1
).
view
(
-
1
,
1
),
combined_x
)
assert
torch
.
isnan
(
combined_x
).
sum
().
item
()
==
0
assert
torch
.
isnan
(
combined_x
).
sum
().
item
()
==
0
# if not round_scale:
# if not
fp8_
round_scale:
assert
diff
<
(
9e-4
if
dispatch_use_
fp8
else
1e-5
),
f
'Error: diff=
{
diff
}
, dispatch_use_
fp8
=
{
dispatch_use_
fp8
}
, zero_copy=
{
zero_copy
}
'
assert
diff
<
(
9e-4
if
dispatch_use_
quant
else
1e-5
),
f
'Error: diff=
{
diff
}
, dispatch_use_
quant
=
{
dispatch_use_
quant
}
, zero_copy=
{
zero_copy
}
'
hash_value
^=
hash_tensor
(
combined_x
)
hash_value
^=
hash_tensor
(
combined_x
)
# noinspection PyShadowingNames
# noinspection PyShadowingNames
...
@@ -162,7 +183,8 @@ def test_main(num_tokens: int,
...
@@ -162,7 +183,8 @@ def test_main(num_tokens: int,
def
test_func
(
return_recv_hook
:
bool
):
def
test_func
(
return_recv_hook
:
bool
):
recv_x
,
recv_count
,
handle
,
event
,
hook
=
\
recv_x
,
recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
use_fp8
=
True
,
async_finish
=
False
,
return_recv_hook
=
return_recv_hook
)
quant_type
=
2
,
quant_group_size
=
0
,
async_finish
=
False
,
return_recv_hook
=
return_recv_hook
)
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
topk_idx
,
topk_idx
,
...
...
tests/test_low_latency_new_int8.py
deleted
100644 → 0
View file @
81e56124
import
argparse
import
random
import
os
import
torch
import
torch.distributed
as
dist
from
functools
import
partial
from
typing
import
Literal
,
Set
import
deep_ep
from
utils
import
init_dist
,
bench
,
bench_kineto
,
calc_diff
,
hash_tensor
,
per_token_cast_back_int8
def
test_main
(
num_tokens
:
int
,
hidden
:
int
,
num_experts
:
int
,
num_topk
:
int
,
rank
:
int
,
num_ranks
:
int
,
group
:
dist
.
ProcessGroup
,
buffer
:
deep_ep
.
Buffer
,
seed
:
int
=
0
):
torch
.
manual_seed
(
seed
+
rank
)
random
.
seed
(
seed
+
rank
)
assert
num_experts
%
num_ranks
==
0
num_local_experts
=
num_experts
//
num_ranks
# NOTES: the integers greater than 256 exceed the BF16 precision limit
rank_offset
=
128
assert
num_ranks
-
rank_offset
<
257
,
'Too many ranks (exceeding test precision limit)'
x
=
torch
.
ones
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
*
(
rank
-
rank_offset
)
x
[:,
-
128
:]
=
torch
.
arange
(
num_tokens
,
device
=
'cuda'
).
to
(
torch
.
bfloat16
).
view
(
-
1
,
1
)
x_list
=
[
x
]
# # NOTES: the last one is for performance testing
# # Most of the values in the perf case is lower than the threshold, casting most channels
# x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1)
scores
=
torch
.
randn
((
num_tokens
,
num_experts
),
dtype
=
torch
.
float32
,
device
=
'cuda'
).
abs
()
+
1
topk_idx
=
torch
.
topk
(
scores
,
num_topk
,
dim
=-
1
,
largest
=
True
,
sorted
=
True
)[
1
]
topk_weights
=
torch
.
randn
((
num_tokens
,
num_topk
),
dtype
=
torch
.
float32
,
device
=
'cuda'
).
abs
()
# Randomly mask some positions
for
_
in
range
(
10
):
topk_idx
[
random
.
randint
(
0
,
num_tokens
-
1
),
random
.
randint
(
0
,
num_topk
-
1
)]
=
-
1
all_topk_idx
=
torch
.
empty
((
num_ranks
,
num_tokens
,
num_topk
),
dtype
=
topk_idx
.
dtype
,
device
=
'cuda'
)
dist
.
all_gather_into_tensor
(
all_topk_idx
,
topk_idx
,
group
=
group
)
# For failure simulation and shrink testing
mask_status
=
torch
.
zeros
((
num_ranks
,),
dtype
=
torch
.
int
,
device
=
'cuda'
)
# Check dispatch correctness
do_check
=
True
hash_value
,
num_times
=
0
,
0
for
current_x
in
x_list
:
for
return_recv_hook
in
(
False
,
):
for
dispatch_use_fp8
in
(
True
,
):
for
round_scale
in
(
False
,
):
for
use_ue8m0
in
(
False
,
):
num_times
+=
1
use_int8
=
True
for
_
in
range
(
1
):
packed_recv_x
,
packed_recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
use_fp8
=
dispatch_use_fp8
,
round_scale
=
round_scale
,
use_ue8m0
=
use_ue8m0
,
use_int8
=
use_int8
,
async_finish
=
not
return_recv_hook
,
return_recv_hook
=
return_recv_hook
)
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
packed_recv_x
=
(
packed_recv_x
[
0
],
packed_recv_x
[
1
].
contiguous
())
if
dispatch_use_fp8
else
packed_recv_x
simulated_gemm_x
=
per_token_cast_back_int8
(
packed_recv_x
[
0
].
view
(
-
1
,
hidden
),
packed_recv_x
[
1
].
view
(
-
1
,
1
)).
view
(
packed_recv_x
[
0
].
shape
)
for
i
in
range
(
num_local_experts
if
do_check
else
0
):
expert_id
=
rank
*
num_local_experts
+
i
recv_x
=
per_token_cast_back_int8
(
packed_recv_x
[
0
][
i
],
packed_recv_x
[
1
][
i
])
if
dispatch_use_fp8
else
packed_recv_x
[
i
]
recv_count
,
recv_src_info
,
recv_layout_range
=
packed_recv_count
[
i
],
handle
[
0
][
i
],
handle
[
1
][
i
]
# Check expert indices
int_mask
=
(
2
**
32
)
-
1
num_valid_tokens
=
recv_count
.
item
()
assert
num_valid_tokens
==
(
recv_layout_range
&
int_mask
).
sum
().
item
(),
f
'
{
num_valid_tokens
}
!=
{
recv_layout_range
&
int_mask
}
.sum().item()'
assert
num_valid_tokens
==
(
all_topk_idx
==
expert_id
).
sum
(
dim
=
[
1
,
2
])[
mask_status
==
0
].
sum
().
item
(
),
f
'
{
num_valid_tokens
}
!=
{
(
all_topk_idx
==
expert_id
).
sum
(
dim
=
[
1
,
2
])[
mask_status
==
0
].
sum
().
item
()
}
'
if
num_valid_tokens
==
0
:
continue
# Check received data
if
current_x
is
x
:
recv_x
=
recv_x
[:
num_valid_tokens
]
recv_x_amin
=
recv_x
[:,
:
-
128
].
amin
(
dim
=-
1
)
recv_x_amax
=
recv_x
[:,
:
-
128
].
amax
(
dim
=-
1
)
recv_src_info
=
recv_src_info
[:
num_valid_tokens
]
assert
torch
.
equal
(
recv_x_amin
,
recv_x_amax
)
if
round_scale
:
assert
calc_diff
(
recv_x
[:,
-
1
],
recv_src_info
.
view
(
-
1
))
<
0.007
elif
use_int8
:
assert
calc_diff
(
recv_x
[:,
-
1
],
recv_src_info
.
view
(
-
1
))
<
0.01
else
:
assert
(
recv_x
[:,
-
128
:]
-
recv_src_info
.
view
(
-
1
,
1
)
%
num_tokens
).
sum
().
item
()
==
0
# for j in range(num_ranks):
# if (not round_scale):
# check_tmp1 = (recv_x_amin == j - rank_offset).sum().item()
# check_tmp2 = (all_topk_idx[j] == expert_id).sum().item()
# print(f'rank: {rank}, j: {j}, check_tmp1: {check_tmp1}, check_tmp2: {check_tmp2}, diff: {abs(check_tmp1 - check_tmp2)}')
# assert abs(check_tmp1 - check_tmp2) < 3
# assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0
if
dispatch_use_fp8
:
hash_value
^=
hash_tensor
(
packed_recv_x
[
0
][
i
,
:
num_valid_tokens
])
hash_value
^=
hash_tensor
(
packed_recv_x
[
1
][
i
,
:
num_valid_tokens
])
else
:
hash_value
^=
hash_tensor
(
packed_recv_x
[
i
,
:
num_valid_tokens
])
print
(
"dispatch int 8 pass"
)
# noinspection PyShadowingNames
def
large_gemm_with_hook
(
hook
):
mat_0
=
torch
.
randn
((
8192
,
8192
),
dtype
=
torch
.
float
)
mat_1
=
torch
.
randn
((
8192
,
8192
),
dtype
=
torch
.
float
)
mat_0
@
mat_1
hook
()
# noinspection PyShadowingNames
def
test_func
(
return_recv_hook
:
bool
):
recv_x
,
recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
use_fp8
=
True
,
round_scale
=
False
,
use_ue8m0
=
False
,
use_int8
=
True
,
async_finish
=
False
,
return_recv_hook
=
return_recv_hook
)
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
topk_idx
,
topk_weights
,
handle
,
return_recv_hook
=
return_recv_hook
)
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
# Calculate bandwidth
scale_size
=
1
# hidden / 128
num_fp8_bytes
,
num_bf16_bytes
=
(
hidden
+
scale_size
*
4
+
16
),
hidden
*
2
num_dispatch_comm_bytes
,
num_combine_comm_bytes
=
0
,
0
for
i
in
range
(
num_tokens
):
num_selections
=
(
topk_idx
[
i
]
!=
-
1
).
sum
().
item
()
num_dispatch_comm_bytes
+=
num_fp8_bytes
*
num_selections
num_combine_comm_bytes
+=
num_bf16_bytes
*
num_selections
# Separate profiling
for
return_recv_hook
in
(
True
,
False
):
group
.
barrier
()
dispatch_t
,
combine_t
=
bench_kineto
(
partial
(
test_func
,
return_recv_hook
=
return_recv_hook
),
kernel_names
=
(
'dispatch'
,
'combine'
),
barrier_comm_profiling
=
True
,
suppress_kineto_output
=
True
,
num_kernels_per_period
=
2
if
return_recv_hook
else
1
)
if
not
return_recv_hook
:
print
(
f
'[rank
{
rank
}
] Dispatch bandwidth:
{
num_dispatch_comm_bytes
/
1e9
/
dispatch_t
:.
2
f
}
GB/s, avg_t=
{
dispatch_t
*
1e6
:.
2
f
}
us | '
f
'Combine bandwidth:
{
num_combine_comm_bytes
/
1e9
/
combine_t
:.
2
f
}
GB/s, avg_t=
{
combine_t
*
1e6
:.
2
f
}
us'
,
flush
=
True
)
else
:
print
(
f
'[rank
{
rank
}
] Dispatch send/recv time:
{
dispatch_t
[
0
]
*
1e6
:.
2
f
}
+
{
dispatch_t
[
1
]
*
1e6
:.
2
f
}
us | '
f
'Combine send/recv time:
{
combine_t
[
0
]
*
1e6
:.
2
f
}
+
{
combine_t
[
1
]
*
1e6
:.
2
f
}
us'
,
flush
=
True
)
return
hash_value
# noinspection PyUnboundLocalVariable,PyShadowingNames
def
test_loop
(
local_rank
:
int
,
num_local_ranks
:
int
,
args
:
argparse
.
Namespace
):
rank
,
num_ranks
,
group
=
init_dist
(
local_rank
,
num_local_ranks
)
num_tokens
,
hidden
=
args
.
num_tokens
,
args
.
hidden
num_topk
,
num_experts
=
args
.
num_topk
,
args
.
num_experts
num_rdma_bytes
=
deep_ep
.
Buffer
.
get_low_latency_rdma_size_hint
(
num_tokens
,
hidden
,
num_ranks
,
num_experts
)
if
local_rank
==
0
:
print
(
f
'Allocating buffer size:
{
num_rdma_bytes
/
1e6
}
MB ...'
,
flush
=
True
)
buffer
=
deep_ep
.
Buffer
(
group
,
num_rdma_bytes
=
num_rdma_bytes
,
low_latency_mode
=
True
,
num_qps_per_rank
=
num_experts
//
num_ranks
,
allow_nvlink_for_low_latency_mode
=
not
args
.
disable_nvlink
,
explicitly_destroy
=
True
,
allow_mnnvl
=
args
.
allow_mnnvl
)
test_main
(
num_tokens
,
hidden
,
num_experts
,
num_topk
,
rank
,
num_ranks
,
group
,
buffer
,
seed
=
1
)
# Destroy the buffer runtime and communication group
buffer
.
destroy
()
dist
.
barrier
()
dist
.
destroy_process_group
()
if
__name__
==
'__main__'
:
# TODO: you may modify NUMA binding for less CPU overhead
# TODO: buggy with `num_tokens=512`
parser
=
argparse
.
ArgumentParser
(
description
=
'Test low-latency EP kernels'
)
parser
.
add_argument
(
'--num-processes'
,
type
=
int
,
default
=
8
,
help
=
'Number of processes to spawn (default: 8)'
)
parser
.
add_argument
(
'--num-tokens'
,
type
=
int
,
default
=
128
,
help
=
'Number of tokens (default: 128)'
)
parser
.
add_argument
(
'--hidden'
,
type
=
int
,
default
=
7168
,
help
=
'Hidden dimension size (default: 7168)'
)
parser
.
add_argument
(
'--num-topk'
,
type
=
int
,
default
=
8
,
help
=
'Number of top-k experts (default: 8)'
)
parser
.
add_argument
(
'--num-experts'
,
type
=
int
,
default
=
256
,
help
=
'Number of experts (default: 288)'
)
parser
.
add_argument
(
'--allow-mnnvl'
,
action
=
"store_true"
,
help
=
'Allow MNNVL for communication'
)
parser
.
add_argument
(
'--disable-nvlink'
,
action
=
'store_true'
,
help
=
'Whether to disable NVLink for testing'
)
parser
.
add_argument
(
"--pressure-test"
,
action
=
'store_true'
,
help
=
'Whether to do pressure test'
)
parser
.
add_argument
(
"--shrink-test"
,
action
=
'store_true'
,
help
=
'Whether to simulate failure and test shrink mode'
)
parser
.
add_argument
(
'--use-logfmt'
,
action
=
'store_true'
,
help
=
'Whether to test LogFMT combine'
)
args
=
parser
.
parse_args
()
num_processes
=
args
.
num_processes
torch
.
multiprocessing
.
spawn
(
test_loop
,
args
=
(
num_processes
,
args
),
nprocs
=
num_processes
)
tests/utils.py
View file @
d0fcf024
...
@@ -57,28 +57,22 @@ def per_token_cast_to_fp8(x: torch.Tensor):
...
@@ -57,28 +57,22 @@ def per_token_cast_to_fp8(x: torch.Tensor):
return
(
x_padded_view
*
(
448.0
/
x_amax
.
unsqueeze
(
2
))).
to
(
torch
.
float8_e4m3fn
).
view
(
m
,
aligned_n
)[:,
:
n
].
contiguous
(),
(
x_amax
/
448.0
).
view
(
m
,
-
1
)
return
(
x_padded_view
*
(
448.0
/
x_amax
.
unsqueeze
(
2
))).
to
(
torch
.
float8_e4m3fn
).
view
(
m
,
aligned_n
)[:,
:
n
].
contiguous
(),
(
x_amax
/
448.0
).
view
(
m
,
-
1
)
def
per_token_cast_back
(
x
_fp8
:
torch
.
Tensor
,
x_scales
:
torch
.
Tensor
):
def
per_token_cast_
pg_
back
(
x
:
torch
.
Tensor
,
x_scales
:
torch
.
Tensor
):
if
x
_fp8
.
numel
()
==
0
:
if
x
.
numel
()
==
0
:
return
x
_fp8
.
to
(
torch
.
bfloat16
)
return
x
.
to
(
torch
.
bfloat16
)
assert
x
_fp8
.
dim
()
==
2
assert
x
.
dim
()
==
2
m
,
n
=
x
_fp8
.
shape
m
,
n
=
x
.
shape
aligned_n
=
align_up
(
n
,
128
)
aligned_n
=
align_up
(
n
,
128
)
x_
fp8_
padded
=
torch
.
nn
.
functional
.
pad
(
x
_fp8
,
(
0
,
aligned_n
-
n
),
mode
=
'constant'
,
value
=
0
)
x_padded
=
torch
.
nn
.
functional
.
pad
(
x
,
(
0
,
aligned_n
-
n
),
mode
=
'constant'
,
value
=
0
)
if
x_scales
.
dtype
==
torch
.
int
:
if
x_scales
.
dtype
==
torch
.
int
:
x_scales
=
x_scales
.
view
(
dtype
=
torch
.
uint8
).
to
(
torch
.
int
)
<<
23
x_scales
=
x_scales
.
view
(
dtype
=
torch
.
uint8
).
to
(
torch
.
int
)
<<
23
x_scales
=
x_scales
.
view
(
dtype
=
torch
.
float
)
x_scales
=
x_scales
.
view
(
dtype
=
torch
.
float
)
x_fp32_padded
=
x_fp8_padded
.
to
(
torch
.
float32
).
view
(
x_fp8
.
size
(
0
),
-
1
,
128
)
x_fp32_padded
=
x_padded
.
to
(
torch
.
float32
).
view
(
x
.
size
(
0
),
-
1
,
128
)
x_scales
=
x_scales
.
view
(
x_fp8
.
size
(
0
),
-
1
,
1
)
x_scales
=
x_scales
.
view
(
x
.
size
(
0
),
-
1
,
1
)
return
(
x_fp32_padded
*
x_scales
).
view
(
x_fp8_padded
.
shape
).
to
(
torch
.
bfloat16
)[:,:
n
].
contiguous
()
return
(
x_fp32_padded
*
x_scales
).
view
(
x_padded
.
shape
).
to
(
torch
.
bfloat16
)[:,:
n
].
contiguous
()
def
per_token_cast_back_int8
(
x_int8
:
torch
.
Tensor
,
x_scales
:
torch
.
Tensor
):
def
per_token_cast_pc_back
(
x_int8
:
torch
.
Tensor
,
x_scales
:
torch
.
Tensor
):
"""
x_int8: [m, n] int8 tensor
x_scales: [m, n] 或 [m, 1] 或 [m, n/128] 量化 scale float
return: [m, n] bf16 tensor
"""
if
x_int8
.
numel
()
==
0
:
if
x_int8
.
numel
()
==
0
:
return
x_int8
.
to
(
torch
.
bfloat16
)
return
x_int8
.
to
(
torch
.
bfloat16
)
...
@@ -86,12 +80,9 @@ def per_token_cast_back_int8(x_int8: torch.Tensor, x_scales: torch.Tensor):
...
@@ -86,12 +80,9 @@ def per_token_cast_back_int8(x_int8: torch.Tensor, x_scales: torch.Tensor):
m
,
n
=
x_int8
.
shape
m
,
n
=
x_int8
.
shape
aligned_n
=
align_up
(
n
,
128
)
aligned_n
=
align_up
(
n
,
128
)
x_int8_padded
=
torch
.
nn
.
functional
.
pad
(
x_int8_padded
=
torch
.
nn
.
functional
.
pad
(
x_int8
,
(
0
,
aligned_n
-
n
),
mode
=
'constant'
,
value
=
0
)
x_int8
,
(
0
,
aligned_n
-
n
),
mode
=
'constant'
,
value
=
0
)
x_fp32_padded
=
x_int8_padded
.
to
(
torch
.
float32
).
view
(
m
,
-
1
,
1
)
x_fp32_padded
=
x_int8_padded
.
to
(
torch
.
float32
).
view
(
m
,
-
1
,
1
)
x_scales
=
x_scales
.
view
(
m
,
-
1
,
1
).
to
(
torch
.
float32
)
x_scales
=
x_scales
.
view
(
m
,
-
1
,
1
).
to
(
torch
.
float32
)
# print(f'x_int8.shape: {x_int8.shape}, x_fp32_padded: {x_fp32_padded.shape}, x_scales: {x_scales.shape}')
x_deq
=
(
x_fp32_padded
*
x_scales
).
view
(
m
,
aligned_n
)
x_deq
=
(
x_fp32_padded
*
x_scales
).
view
(
m
,
aligned_n
)
return
x_deq
[:,
:
n
].
to
(
torch
.
bfloat16
).
contiguous
()
return
x_deq
[:,
:
n
].
to
(
torch
.
bfloat16
).
contiguous
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment