Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
d0fcf024
Commit
d0fcf024
authored
Feb 04, 2026
by
lishen
Browse files
Merge branch 'quant_master' into 'main'
And quant. See merge request dcutoolkit/deeplearing/DeepEP!19
parents
81e56124
ace6e18e
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
326 additions
and
407 deletions
+326
-407
csrc/config.hpp
csrc/config.hpp
+1
-1
csrc/deep_ep.cu
csrc/deep_ep.cu
+28
-14
csrc/deep_ep.hpp
csrc/deep_ep.hpp
+1
-1
csrc/kernels/api.cuh
csrc/kernels/api.cuh
+1
-1
csrc/kernels/configs.cuh
csrc/kernels/configs.cuh
+1
-1
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+163
-95
csrc/kernels/launch.cuh
csrc/kernels/launch.cuh
+7
-5
csrc/kernels/utils.cuh
csrc/kernels/utils.cuh
+30
-16
deep_ep/buffer.py
deep_ep/buffer.py
+38
-17
tests/test_low_latency_new.py
tests/test_low_latency_new.py
+45
-23
tests/test_low_latency_new_int8.py
tests/test_low_latency_new_int8.py
+0
-213
tests/utils.py
tests/utils.py
+11
-20
No files found.
csrc/config.hpp
View file @
d0fcf024
...
...
@@ -136,7 +136,7 @@ struct LowLatencyLayout {
LowLatencyLayout
(
void
*
rdma_buffer
,
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_ranks
,
int
num_experts
)
{
const
int
num_scales
=
hidden
/
FP8_
QUANTIZATION_
NUM_PER_CHANNEL
;
const
int
num_scales
=
hidden
/
QUANTIZATION_
GROUPSIZE
;
// Dispatch and combine layout:
// - 2 symmetric odd/even send buffer
...
...
csrc/deep_ep.cu
View file @
d0fcf024
...
...
@@ -1293,7 +1293,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
Buffer
::
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
use_int8
,
int
quant_type
,
int
quant_group_size
,
bool
fp8_round_scale
,
bool
async
,
bool
return_recv_hook
)
{
EP_HOST_ASSERT
(
low_latency_mode
);
...
...
@@ -1327,8 +1327,15 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
stream_wait
(
launch_stream
,
compute_stream
);
// Allocate packed tensors
auto
packed_recv_x
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
,
hidden
},
x
.
options
().
dtype
(
use_int8
?
torch
::
kInt8
:
use_fp8
?
torch
::
kFloat8_e4m3fnuz
:
torch
::
kBFloat16
));
auto
packed_recv_x_dtype
=
torch
::
kBFloat16
;
switch
(
quant_type
)
{
case
1
:
packed_recv_x_dtype
=
torch
::
kInt8
;
break
;
case
2
:
packed_recv_x_dtype
=
torch
::
kFloat8_e4m3fn
;
break
;
case
3
:
packed_recv_x_dtype
=
torch
::
kFloat8_e4m3fn
;
break
;
case
4
:
packed_recv_x_dtype
=
torch
::
kFloat8_e5m2
;
break
;
}
auto
packed_recv_x
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
,
hidden
},
x
.
options
().
dtype
(
packed_recv_x_dtype
));
auto
packed_recv_src_info
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
auto
packed_recv_layout_range
=
torch
::
empty
({
num_local_experts
,
num_ranks
},
torch
::
dtype
(
torch
::
kInt64
).
device
(
torch
::
kCUDA
));
auto
packed_recv_count
=
torch
::
empty
({
num_local_experts
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
torch
::
kCUDA
));
...
...
@@ -1336,21 +1343,28 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
// Allocate column-majored scales
auto
packed_recv_x_scales
=
std
::
optional
<
torch
::
Tensor
>
();
void
*
packed_recv_x_scales_ptr
=
nullptr
;
if
(
use_fp8
)
{
if
(
quant_type
>
0
)
{
EP_HOST_ASSERT
((
num_ranks
*
num_max_dispatch_tokens_per_rank
)
%
4
==
0
and
"TMA requires the number of tokens to be multiple of 4"
);
// TODO: support unaligned cases
EP_HOST_ASSERT
(
hidden
%
(
FP8_QUANTIZATION_NUM_PER_CHANNEL
*
4
)
==
0
);
EP_HOST_ASSERT
(
!
(
use_ue8m0
&&
use_int8
));
EP_HOST_ASSERT
(
hidden
%
(
QUANTIZATION_GROUPSIZE
*
4
)
==
0
);
// 计算scale_col的大小
int
scales_col_size
=
1
;
// 默认为per-channel
if
(
quant_group_size
>
0
)
{
if
(
quant_type
==
3
)
{
// FP8_UE8M0比较特殊
scales_col_size
=
hidden
/
(
QUANTIZATION_GROUPSIZE
*
4
);
}
else
{
scales_col_size
=
hidden
/
QUANTIZATION_GROUPSIZE
;
}
}
if
(
use_ue8m0
)
{
EP_HOST_ASSERT
(
round_scale
);
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
hidden
/
(
FP8_QUANTIZATION_NUM_PER_CHANNEL
*
4
),
num_ranks
*
num_max_dispatch_tokens_per_rank
},
// 设置packed_recv_x_scales
if
(
quant_type
==
3
)
{
// FP8_UE8M0比较特殊,需要单独处理
EP_HOST_ASSERT
(
fp8_round_scale
&&
quant_group_size
==
128
);
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
scales_col_size
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kInt
).
device
(
torch
::
kCUDA
));
}
else
if
(
use_int8
)
{
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
1
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
));
}
else
{
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
hidden
/
FP8_QUANTIZATION_NUM_PER_CHANNEL
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
packed_recv_x_scales
=
torch
::
empty
({
num_local_experts
,
scales_col_size
,
num_ranks
*
num_max_dispatch_tokens_per_rank
},
torch
::
dtype
(
torch
::
kFloat32
).
device
(
torch
::
kCUDA
));
}
packed_recv_x_scales
=
torch
::
transpose
(
packed_recv_x_scales
.
value
(),
1
,
2
);
...
...
@@ -1370,7 +1384,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
next_clean_meta
.
first
,
next_clean_meta
.
second
,
num_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
use_fp8
,
round_scale
,
use_ue8m0
,
use_int8
,
quant_type
,
quant_group_size
,
fp8_round_scale
,
workspace
,
num_device_sms
,
launch_stream
,
phases
);
};
launcher
(
return_recv_hook
?
LOW_LATENCY_SEND_PHASE
:
(
LOW_LATENCY_SEND_PHASE
|
LOW_LATENCY_RECV_PHASE
));
...
...
csrc/deep_ep.hpp
View file @
d0fcf024
...
...
@@ -177,7 +177,7 @@ public:
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
low_latency_dispatch
(
const
torch
::
Tensor
&
x
,
const
torch
::
Tensor
&
topk_idx
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
use_int8
,
int
quant_type
,
int
quant_group_size
,
bool
fp8_round_scale
,
bool
async
,
bool
return_recv_hook
);
std
::
tuple
<
torch
::
Tensor
,
std
::
optional
<
EventHandle
>
,
std
::
optional
<
std
::
function
<
void
()
>>>
...
...
csrc/kernels/api.cuh
View file @
d0fcf024
...
...
@@ -147,7 +147,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
use_int8
,
int
quant_type
,
int
group_size
,
bool
fp8_round_scale
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
);
void
combine
(
void
*
combined_x
,
...
...
csrc/kernels/configs.cuh
View file @
d0fcf024
...
...
@@ -23,7 +23,7 @@
#define LOW_LATENCY_SEND_PHASE 1
#define LOW_LATENCY_RECV_PHASE 2
#define NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL 3
#define
FP8_
QUANTIZATION_
NUM_PER_CHANNEL
128
#define QUANTIZATION_
GROUPSIZE
128
#define DEFAULT_NUM_CU 20
#define DEFAULT_NUM_MAX_XGMI_CHUNKED_SEND_TOKENS 6
...
...
csrc/kernels/internode_ll.cu
View file @
d0fcf024
...
...
@@ -116,20 +116,74 @@ internode_ll_long_atomic_add(long* dest, const long &value,
#endif // defined(FORCE_DUSHMEM_API)
}
template
<
bool
kUseFP8
,
bool
kUseUE8M0
,
bool
kUseInt8
,
int
kHidden
>
/**
* @brief 将 K 个浮点数(BF16/FP32)量化并打包成 INT2(64位)存储
*
* @tparam kQuantType 量化类型 (1: Int8, 2/3: FP8_E4M3/UE8M0, 4: FP8_E5M2)
* @tparam kNumElemsPerRead 每次读取的元素数量 (通常为 2, 4, 8)
* @tparam SrcT 源数据类型 (float 或 __hip_bfloat16)
* @tparam DstT 目标数据类型 (int2 或 int4)
* @param src_values 源数据数组 (长度 >= kNumElemsPerRead)
* @param scale 缩放因子 (将 FP32 值映射到量化范围)
* @param[out] dst_vec 输出的 64 位向量 (int2 或 int4)
*/
template
<
int
kQuantType
,
int
kNumElemsPerRead
,
typename
SrcT
,
typename
DstT
>
__forceinline__
__device__
void
pack_quantized_values
(
const
SrcT
*
src_values
,
float
scale
,
DstT
&
dst_vec
)
{
if
constexpr
(
kQuantType
==
1
)
{
// INT8 量化
auto
int8_ptr
=
reinterpret_cast
<
int8_t
*>
(
&
dst_vec
);
#pragma unroll
for
(
int
j
=
0
;
j
<
kNumElemsPerRead
;
++
j
)
{
// 如果源是 bfloat16,先提升为 float
float
fp32_value_scaled
=
static_cast
<
float
>
(
src_values
[
j
])
*
scale
;
// 使用 nearbyintf 进行四舍五入
int8_ptr
[
j
]
=
static_cast
<
int8_t
>
(
nearbyintf
(
fp32_value_scaled
));
}
}
else
{
// FP8 量化 (E4M3, UE8M0, E5M2)
// 假设 dst_vec 能容纳 kNumElemsPerRead/2 个 fp8x2 元素
auto
fp8x2_ptr
=
reinterpret_cast
<
__hip_fp8x2_storage_t
*>
(
&
dst_vec
);
#pragma unroll
for
(
int
j
=
0
;
j
<
kNumElemsPerRead
;
j
+=
2
)
{
// 处理两个元素
float2
fp32x2
=
{
static_cast
<
float
>
(
src_values
[
j
])
*
scale
,
static_cast
<
float
>
(
src_values
[
j
+
1
])
*
scale
};
if
constexpr
(
kQuantType
==
4
)
{
// FP8 E5M2
fp8x2_ptr
[
j
/
2
]
=
__hip_cvt_float2_to_fp8x2
(
fp32x2
,
__HIP_SATFINITE
,
__HIP_E5M2
);
}
else
{
// FP8 E4M3 或 UE8M0
fp8x2_ptr
[
j
/
2
]
=
__hip_cvt_float2_to_fp8x2
(
fp32x2
,
__HIP_SATFINITE
,
__HIP_E4M3
);
}
}
}
}
template
<
int
kHidden
,
int
kQuantType
=
0
,
int
kQuantGroupSize
=
0
,
int
kMaxNumWarps
=
16
>
__global__
__launch_bounds__
(
16
*
kWarpSize
,
1
)
void
dispatch
(
void
*
packed_recv_x
,
void
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_count
,
int
*
global_atomic_counter
,
void
*
rdma_recv_x
,
int64_t
*
rdma_recv_count
,
void
*
rdma_x
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
int
*
atomic_counter_per_expert
,
int
*
atomic_finish_counter_per_expert
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_warp_groups
,
int
num_warps_per_group
,
bool
round_scale
,
int
phases
)
{
dispatch
(
void
*
packed_recv_x
,
void
*
packed_recv_x_scales
,
int
*
packed_recv_src_info
,
int64_t
*
packed_recv_layout_range
,
int
*
packed_recv_count
,
int
*
global_atomic_counter
,
void
*
rdma_recv_x
,
int64_t
*
rdma_recv_count
,
void
*
rdma_x
,
const
void
*
x
,
const
int64_t
*
topk_idx
,
int
*
atomic_counter_per_expert
,
int
*
atomic_finish_counter_per_expert
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_warp_groups
,
int
num_warps_per_group
,
bool
fp8_round_scale
,
int
phases
)
{
// 定义量化类型的枚举
enum
class
QuantType
{
None
=
0
,
// 不进行量化
Int8
=
1
,
// 采用 Int8 量化
FP8_E4M3
=
2
,
// 采用 FP8 量化 __HIP_E4M3
FP8_UE8M0
=
3
,
// 采用 FP8 量化 DeepseekV3.1的 UE8M0
FP8_E5M2
=
4
// 采用 FP8 量化 __HIP_E5M2
};
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
thread_id
=
static_cast
<
int
>
(
threadIdx
.
x
);
const
auto
warp_id
=
thread_id
/
kWarpSize
,
lane_id
=
get_lane_id
();
...
...
@@ -141,20 +195,22 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
const
auto
responsible_expert_idx
=
sm_id
*
num_warp_groups
+
warp_group_id
;
// May extract UE8M0 from the scales
constexpr
bool
kUseQuant8Bit
=
kQuantType
>
0
;
constexpr
bool
kUseUE8M0
=
kQuantType
==
3
;
// QuantType::FP8_UE8M0
using
scale_t
=
std
::
conditional_t
<
kUseUE8M0
,
uint8_t
,
float
>
;
using
packed_t
=
std
::
conditional_t
<
kUseUE8M0
,
uint32_t
,
float
>
;
EP_STATIC_ASSERT
(
sizeof
(
packed_t
)
%
sizeof
(
scale_t
)
==
0
,
"Invalid vector length"
);
// FP8 staffs
constexpr
int
kNumPerChannels
=
FP8_
QUANTIZATION_
NUM_PER_CHANNEL
;
constexpr
int
kNumPerChannels
=
QUANTIZATION_
GROUPSIZE
;
constexpr
int
kNumScales
=
kHidden
/
kNumPerChannels
;
const
size_t
hidden_bytes
=
kHidden
*
(
kUse
FP8
?
sizeof
(
__hip_fp8_storage_t
)
:
sizeof
(
hip_bfloat16
));
const
size_t
hidden_bytes
=
kHidden
*
(
kUse
Quant8Bit
?
sizeof
(
__hip_fp8_storage_t
)
:
sizeof
(
hip_bfloat16
));
const
size_t
hidden_int4
=
hidden_bytes
/
sizeof
(
int4
);
// Message package: hidden data, FP8 scales, index at source
// NOTES: currently we have 3 reserved int fields for future use
using
vec_t
=
typename
std
::
conditional
<
kUse
FP8
,
int2
,
int4
>::
type
;
constexpr
size_t
num_bytes_per_msg
=
sizeof
(
int4
)
+
(
kUse
FP8
?
(
kHidden
+
kNumScales
*
sizeof
(
float
))
:
(
kHidden
*
sizeof
(
hip_bfloat16
)));
using
vec_t
=
typename
std
::
conditional
<
kUse
Quant8Bit
,
int2
,
int4
>::
type
;
constexpr
size_t
num_bytes_per_msg
=
sizeof
(
int4
)
+
(
kUse
Quant8Bit
?
(
kHidden
+
kNumScales
*
sizeof
(
float
))
:
(
kHidden
*
sizeof
(
hip_bfloat16
)));
EP_STATIC_ASSERT
(
num_bytes_per_msg
%
sizeof
(
int4
)
==
0
,
"Invalid message size"
);
constexpr
size_t
num_int4_per_msg
=
num_bytes_per_msg
/
sizeof
(
int4
);
...
...
@@ -171,6 +227,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// 2. The last warp for reading `topk_idx` and count for per-expert information
if
(
warp_id
<
num_warps
)
{
constexpr
int
kNumElemsPerRead
=
sizeof
(
int4
)
/
sizeof
(
hip_bfloat16
);
constexpr
int
kNumThreadPerGroup
=
QUANTIZATION_GROUPSIZE
/
kNumElemsPerRead
;
// EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0);
EP_STATIC_ASSERT
(
kNumElemsPerRead
*
kWarpSize
%
kNumPerChannels
==
0
,
"Invalid vectorization"
);
const
auto
num_threads
=
(
num_warps
-
1
)
*
kWarpSize
;
...
...
@@ -186,10 +243,11 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
auto
dst_expert_idx
=
warp_id
<
num_topk
?
static_cast
<
int
>
(
__ldg
(
topk_idx
+
token_idx
*
num_topk
+
warp_id
))
:
-
1
;
thread_id
==
0
?
(
*
rdma_x_src_idx
=
token_idx
)
:
0
;
__shared__
float
int8_amaxf
[
kNumScales
];
if
constexpr
(
kUseInt8
)
{
// 用于记录per-channel量化的amax
__shared__
float
channel_amaxf
[
kNumScales
];
if
constexpr
(
kUseQuant8Bit
&&
kQuantGroupSize
==
0
)
{
if
(
thread_id
<
kNumScales
)
{
int8
_amaxf
[
thread_id
]
=
kFP8Margin
;
channel
_amaxf
[
thread_id
]
=
0.0
;
}
__syncthreads
();
}
...
...
@@ -200,11 +258,11 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// Read
auto
int4_value
=
__ldg
(
x_int4
+
i
);
if
constexpr
(
kUse
FP8
)
{
if
constexpr
(
kUse
Quant8Bit
)
{
// Calculate local amax
auto
bf16_values
=
reinterpret_cast
<
hip_bfloat16
*>
(
&
int4_value
);
float
fp32_values
[
kNumElemsPerRead
];
float
amax
=
kFP8Margin
,
scale
,
scale_inv
;
float
amax
=
0.0
,
scale
,
scale_inv
;
#pragma unroll
for
(
int
j
=
0
;
j
<
kNumElemsPerRead
;
++
j
)
{
fp32_values
[
j
]
=
static_cast
<
float
>
(
bf16_values
[
j
]);
...
...
@@ -212,25 +270,20 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
}
// Reduce amax and scale
EP_STATIC_ASSERT
(
kNumElemsPerRead
*
kWarpSize
/
kNumPerChannels
==
4
,
"Invalid vectorization"
);
amax
=
warp_reduce_max
<
16
>
(
amax
);
const
int
scale_offset
=
i
*
kNumElemsPerRead
/
FP8_
QUANTIZATION_
NUM_PER_CHANNEL
;
amax
=
warp_reduce_max
<
kNumThreadPerGroup
>
(
amax
);
const
int
scale_offset
=
i
*
kNumElemsPerRead
/
QUANTIZATION_
GROUPSIZE
;
if
constexpr
(
k
UseInt8
)
{
if
constexpr
(
k
QuantGroupSize
==
0
)
{
// 记录每128个数的最大值
int8
_amaxf
[
scale_offset
]
=
fmaxf
(
amax
,
int8
_amaxf
[
scale_offset
]);
channel
_amaxf
[
scale_offset
]
=
fmaxf
(
amax
,
channel
_amaxf
[
scale_offset
]);
}
else
{
calculate_
fp8_scales
(
amax
,
scale
,
scale_inv
,
round_scale
);
if
(
lane_id
%
16
==
0
)
calculate_
quant8bit_scales
<
kQuantType
>
(
amax
,
scale
,
scale_inv
,
fp8_
round_scale
);
if
(
lane_id
%
kNumThreadPerGroup
==
0
)
rdma_x_scales
[
scale_offset
]
=
scale_inv
;
// Cast into send buffer
vec_t
int2_value
;
auto
fp8x2_values
=
reinterpret_cast
<
__hip_fp8x2_storage_t
*>
(
&
int2_value
);
#pragma unroll
for
(
int
j
=
0
;
j
<
kNumElemsPerRead
;
j
+=
2
)
{
float2
fp32x2
=
{
fp32_values
[
j
]
*
scale
,
fp32_values
[
j
+
1
]
*
scale
};
fp8x2_values
[
j
/
2
]
=
__hip_cvt_float2_to_fp8x2
(
fp32x2
,
__HIP_SATFINITE
,
__HIP_E4M3_FNUZ
);
}
pack_quantized_values
<
kQuantType
,
kNumElemsPerRead
>
(
fp32_values
,
scale
,
int2_value
);
rdma_x_vec
[
i
]
=
int2_value
;
}
}
else
{
...
...
@@ -240,24 +293,24 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
}
__syncthreads
();
if
constexpr
(
kUse
Int8
)
{
float
amax_per_token
=
kFP8Margin
;
if
constexpr
(
kUse
Quant8Bit
&&
kQuantGroupSize
==
0
)
{
float
amax_per_token
=
0.0
;
// 并行规约,计算每个token的amax
for
(
int
s
=
0
;
s
<
kNumScales
;
s
+=
kWarpSize
)
{
int
src_idx
=
s
+
lane_id
;
float
tmp_amaxf
=
0
;
if
(
src_idx
<
kNumScales
)
{
tmp_amaxf
=
int8
_amaxf
[
src_idx
];
tmp_amaxf
=
channel
_amaxf
[
src_idx
];
}
tmp_amaxf
=
warp_reduce_max
<
kWarpSize
>
(
tmp_amaxf
);
int8
_amaxf
[
0
]
=
fmaxf
(
tmp_amaxf
,
int8
_amaxf
[
0
]);
channel
_amaxf
[
0
]
=
fmaxf
(
tmp_amaxf
,
channel
_amaxf
[
0
]);
__syncthreads
();
}
amax_per_token
=
int8
_amaxf
[
0
];
amax_per_token
=
channel
_amaxf
[
0
];
// 根据最大值计算scale
float
scale
,
scale_inv
;
calculate_
int8_scales
(
amax_per_token
,
scale
,
scale_inv
);
calculate_
quant8bit_scales
<
kQuantType
>
(
amax_per_token
,
scale
,
scale_inv
,
fp8_round_scale
);
if
(
thread_id
==
0
)
{
rdma_x_scales
[
0
]
=
scale_inv
;
}
...
...
@@ -269,13 +322,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// Cast into send buffer
vec_t
int2_value
;
auto
int8_values
=
reinterpret_cast
<
int8_t
*>
(
&
int2_value
);
#pragma unroll
for
(
int
j
=
0
;
j
<
kNumElemsPerRead
;
++
j
)
{
auto
fp32_value
=
static_cast
<
float
>
(
bf16_values
[
j
]);
auto
fp32_value_scaled
=
fp32_value
*
scale
;
int8_values
[
j
]
=
static_cast
<
int8_t
>
(
nearbyintf
(
fp32_value_scaled
));
}
pack_quantized_values
<
kQuantType
,
kNumElemsPerRead
>
(
bf16_values
,
scale
,
int2_value
);
rdma_x_vec
[
i
]
=
int2_value
;
}
__syncthreads
();
...
...
@@ -297,8 +344,8 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
uint64_t
p2p_ptr
=
internode
::
shmem_get_p2p_ptr
((
void
*
)
dst_ptr
,
rank
,
dst_rank
);
if
(
p2p_ptr
==
0
)
{
// RDMA
internode_ll_putmem_nbi
((
void
*
)
dst_ptr
,
(
void
*
)
src_ptr
,
num_ranks
,
dst_rank
,
dst_expert_local_idx
,
num_bytes_per_msg
);
num_ranks
,
dst_rank
,
dst_expert_local_idx
,
num_bytes_per_msg
);
}
else
{
// 本地 GPU 和 同一计算节点的 其他 GPU 地址
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
src_ptr
);
...
...
@@ -392,11 +439,10 @@ LOW_LATENCY_DISPATCH_RECV:
}
// 16 is the max possible number of warps in AMD GPUs
constexpr
int
kMaxNumWarps
=
1024
/
kWarpSize
;
constexpr
int
num_sync_large_iteration
=
kMaxNumWarps
;
__shared__
volatile
int
sync_large_warp_counters
[
num_sync_large_iteration
];
#pragma unroll
#pragma unroll
for
(
int
i
=
thread_id
;
i
<
num_sync_large_iteration
;
i
+=
blockDim
.
x
)
{
sync_large_warp_counters
[
i
]
=
0
;
}
...
...
@@ -416,7 +462,7 @@ LOW_LATENCY_DISPATCH_RECV:
const
auto
num_aligned_scales
=
ALIGN
<
int
>
(
kNumScales
,
sizeof
(
float
)
/
sizeof
(
scale_t
));
const
auto
recv_x_scales
=
static_cast
<
scale_t
*>
(
packed_recv_x_scales
)
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
(
k
UseInt8
?
1
:
num_aligned_scales
);
(
k
QuantGroupSize
==
0
?
1
:
num_aligned_scales
);
// Shared between sub-warps in warp groups
__shared__
int
shared_num_recv_tokens
[
kNumMaxWarpGroups
],
shared_recv_token_begin_idx
[
kNumMaxWarpGroups
];
...
...
@@ -461,14 +507,14 @@ LOW_LATENCY_DISPATCH_RECV:
UNROLLED_WARP_COPY_LL
(
7
,
lane_id
,
hidden_int4
,
dst_data
,
src_data
,
ld_nc_global
,
st_na_global
);
// Copy scales
if
constexpr
(
kUse
FP8
)
{
if
constexpr
(
kUse
Quant8Bit
)
{
const
auto
src_scales
=
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
uint8_t
*>
(
src_data
)
+
hidden_bytes
);
const
auto
num_elems_per_pack
=
static_cast
<
int
>
(
sizeof
(
packed_t
)
/
sizeof
(
scale_t
));
const
auto
token_idx
=
recv_token_begin_idx
+
i
;
const
auto
token_stride
=
num_elems_per_pack
;
const
auto
pack_stride
=
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_elems_per_pack
;
if
constexpr
(
k
UseInt8
)
{
if
constexpr
(
k
QuantGroupSize
==
0
)
{
if
(
lane_id
==
0
)
{
recv_x_scales
[
token_idx
]
=
ld_nc_global
(
src_scales
);
}
...
...
@@ -500,12 +546,13 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_fp8
,
bool
round_scale
,
bool
use_ue8m0
,
bool
use_int8
,
int
quant_type
,
int
quant_group_size
,
bool
fp8_round_scale
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
)
{
constexpr
int
kMaxNumWarps
=
16
;
constexpr
int
kNumMaxTopK
=
11
;
const
int
num_warp_groups
=
ceil_div
(
num_experts
,
num_device_sms
);
const
int
num_warps_per_group
=
16
/
num_warp_groups
;
const
int
num_warps_per_group
=
kMaxNumWarps
/
num_warp_groups
;
EP_HOST_ASSERT
(
num_warp_groups
>
0
and
num_warps_per_group
>
0
);
EP_HOST_ASSERT
(
kNumMaxTopK
+
1
<=
num_warp_groups
*
num_warps_per_group
);
...
...
@@ -518,33 +565,54 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
auto
atomic_finish_counter_per_expert
=
atomic_counter_per_expert
+
num_experts
;
EP_HOST_ASSERT
(
num_experts
*
sizeof
(
int
)
*
2
<=
NUM_WORKSPACE_BYTES
);
#define DISPATCH_LAUNCH_CASE(hidden) { \
auto dispatch_func = dispatch<false, false, false, hidden>; \
if (use_fp8 and not use_ue8m0) \
dispatch_func = dispatch<true, false, false, hidden>; \
if (use_fp8 and use_ue8m0) \
dispatch_func = dispatch<true, true, false, hidden>; \
if (use_int8) \
dispatch_func = dispatch<true, false, true, hidden>; \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, dispatch_func, \
packed_recv_x, packed_recv_x_scales, \
packed_recv_src_info, packed_recv_layout_range, \
packed_recv_count, \
global_atomic_counter, \
rdma_recv_x, rdma_recv_count, rdma_x, \
x, topk_idx, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, \
next_clean, num_next_clean_int, \
num_tokens, num_max_dispatch_tokens_per_rank, \
num_topk, num_experts, rank, num_ranks, \
num_warp_groups, num_warps_per_group, round_scale, phases); } break
// 限制groupsize的大小
EP_HOST_ASSERT
(
quant_group_size
==
0
||
quant_group_size
==
128
);
/*量化类型枚举
0 -> None 不量化,保持原始精度
1 -> Int8 使用 INT8 对称量化
2 -> FP8_E4M3 使用 FP8 E4M3 格式 (__HIP_E4M3)
3 -> FP8_UE8M0 使用 DeepSeekV3.1 提出的 UE8M0 格式 (仅支持round_scale=True)
4 -> FP8_E5M2 使用 FP8 E5M2 格式 (__HIP_E5M2)
*/
#define DISPATCH_LAUNCH_CASE(hidden) \
{ \
auto dispatch_func = dispatch<hidden, 0, 0, kMaxNumWarps>; \
if (quant_group_size == 0) { \
switch (quant_type) { \
case 1: dispatch_func = dispatch<hidden, 1, 0, kMaxNumWarps>; break; \
case 2: dispatch_func = dispatch<hidden, 2, 0, kMaxNumWarps>; break; \
case 3: dispatch_func = dispatch<hidden, 3, 0, kMaxNumWarps>; break; \
case 4: dispatch_func = dispatch<hidden, 4, 0, kMaxNumWarps>; break; \
} \
} else { \
switch (quant_type) { \
case 1: dispatch_func = dispatch<hidden, 1, 128, kMaxNumWarps>; break; \
case 2: dispatch_func = dispatch<hidden, 2, 128, kMaxNumWarps>; break; \
case 3: dispatch_func = dispatch<hidden, 3, 128, kMaxNumWarps>; break; \
case 4: dispatch_func = dispatch<hidden, 4, 128, kMaxNumWarps>; break; \
} \
} \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, dispatch_func, \
packed_recv_x, packed_recv_x_scales, \
packed_recv_src_info, packed_recv_layout_range, packed_recv_count, \
global_atomic_counter, \
rdma_recv_x, rdma_recv_count, rdma_x, x, topk_idx, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, \
next_clean, num_next_clean_int, \
num_tokens, num_max_dispatch_tokens_per_rank, \
num_topk, num_experts, rank, num_ranks, \
num_warp_groups, num_warps_per_group, fp8_round_scale, phases); \
} \
break
SETUP_LAUNCH_CONFIG
(
num_sms
,
num_warps
*
kWarpSize
,
stream
);
SWITCH_HIDDEN
(
DISPATCH_LAUNCH_CASE
);
#undef DISPATCH_LAUNCH_CASE
}
template
<
int
kHidden
,
int
kNumMaxTopk
>
template
<
int
kHidden
,
int
kNumMaxTopk
,
int
kMaxNumWarps
=
16
>
__global__
__launch_bounds__
(
16
*
kWarpSize
,
1
)
void
combine
(
void
*
combined_x
,
void
*
rdma_recv_x
,
int64_t
*
rdma_recv_flag
,
void
*
rdma_send_x
,
...
...
@@ -574,12 +642,11 @@ combine(void* combined_x,
const
size_t
hidden_bf16_int4
=
kHidden
/
kNumElemsPerInt4
;
// Message package
EP_STATIC_ASSERT
(
kHidden
%
FP8_
QUANTIZATION_
NUM_PER_CHANNEL
==
0
,
"Invalid hidden"
);
EP_STATIC_ASSERT
(
kHidden
%
QUANTIZATION_
GROUPSIZE
==
0
,
"Invalid hidden"
);
constexpr
size_t
num_bytes_per_slot
=
kHidden
*
sizeof
(
hip_bfloat16
);
EP_STATIC_ASSERT
(
num_bytes_per_slot
%
sizeof
(
int4
)
==
0
,
"Invalid vectorization"
);
// 16 is the max possible number of warps in AMD GPUs
constexpr
int
kMaxNumWarps
=
1024
/
kWarpSize
;
// 初始化用于细粒度warp间同步的计数器数组
__shared__
volatile
int
sync_large_warp_counters
[
kMaxNumWarps
];
if
(
threadIdx
.
x
==
0
){
#pragma unroll
...
...
@@ -755,9 +822,10 @@ void combine(void* combined_x,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
,
bool
zero_copy
)
{
constexpr
int
kMaxNumWarps
=
16
;
constexpr
int
kNumMaxTopk
=
11
;
const
int
num_warp_groups
=
ceil_div
(
num_experts
,
num_device_sms
);
const
int
num_warps_per_group
=
16
/
num_warp_groups
;
// num_warps_per_group>1, "Requires more than one warp per group"
const
int
num_warps_per_group
=
kMaxNumWarps
/
num_warp_groups
;
// num_warps_per_group>1, "Requires more than one warp per group"
const
int
num_recv_per_sm
=
ceil_div
(
num_combined_tokens
,
num_device_sms
);
EP_HOST_ASSERT
(
num_warp_groups
>
0
and
num_warps_per_group
>
0
and
num_recv_per_sm
>=
0
);
...
...
@@ -770,20 +838,20 @@ void combine(void* combined_x,
EP_HOST_ASSERT
(
sizeof
(
int
)
<=
NUM_WORKSPACE_BYTES
);
EP_HOST_ASSERT
(
num_topk
<=
kNumMaxTopk
);
#define COMBINE_LAUNCH_CASE(hidden)
{
\
auto combine_func = combine<hidden, kNumMaxTopk>;
\
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func,
\
combined_x,
\
rdma_recv_x, rdma_recv_flag, rdma_send_x, \
x, topk_idx, topk_weights, src_info, layout_range, \
global_atomic_counter, \
combine_wait_recv_cost_stats,
\
next
_clean, num_
next_clean_int,
\
atomic_clean_flag,
\
num_combined_tokens, hidden, num_topk,
\
num_max_dispatch_tokens_per_rank,
\
num_experts, rank, num_ranks,
\
num_warp_groups, num_warps_per_group, phases, zero_copy); }
break
#define COMBINE_LAUNCH_CASE(hidden)
\
{
\
auto combine_func = combine<hidden, kNumMaxTopk, kMaxNumWarps>;
\
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func,
\
combined_x,
rdma_recv_x, rdma_recv_flag, rdma_send_x,
\
x, topk_idx, topk_weights, src_info, layout_range,
\
global_atomic_counter,
combine_wait_recv_cost_stats,
\
next_clean, num_next_clean_int,
\
atomic
_clean
_flag
, num_
combined_tokens, hidden,
\
num_topk, num_max_dispatch_tokens_per_rank,
\
num_experts, rank, num_ranks,
\
num_warp_groups, num_warps_per_group, phases, zero_copy);
\
}
\
break
SETUP_LAUNCH_CONFIG
(
num_sms
,
num_warps
*
kWarpSize
,
stream
);
SWITCH_HIDDEN
(
COMBINE_LAUNCH_CASE
);
...
...
csrc/kernels/launch.cuh
View file @
d0fcf024
...
...
@@ -62,7 +62,7 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...a
case 8: \
case_macro(8); \
default: \
EP_HOST_ASSERT(false and "Unsupported ranks"); \
EP_HOST_ASSERT(false and "Unsupported ranks");
\
} \
while (false)
...
...
@@ -83,7 +83,7 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...a
case 20: \
case_macro(20); \
default: \
EP_HOST_ASSERT(false and "Unsupported RDMA ranks"); \
EP_HOST_ASSERT(false and "Unsupported RDMA ranks");
\
} \
while (false)
...
...
@@ -96,7 +96,7 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...a
case 8: \
case_macro(dtype, 8); \
default: \
EP_HOST_ASSERT(false and "Unsupported ranks"); \
EP_HOST_ASSERT(false and "Unsupported ranks");
\
} \
while (false)
...
...
@@ -107,7 +107,7 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...a
case HIP_R_32F: \
case_macro(float); \
default: \
EP_HOST_ASSERT(false and "Unsupported type"); \
EP_HOST_ASSERT(false and "Unsupported type");
\
} \
while (false)
...
...
@@ -121,7 +121,9 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...a
case_macro(4096); \
case 7168: \
case_macro(7168); \
case 8192: \
case_macro(8192); \
default: \
EP_HOST_ASSERT(false and "Unsupported hidden"); \
EP_HOST_ASSERT(false and "Unsupported hidden");
\
} \
while (false)
csrc/kernels/utils.cuh
View file @
d0fcf024
...
...
@@ -341,10 +341,13 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
return
*
reinterpret_cast
<
dtype_t
*>
(
recv_int_values
);
}
constexpr
float
kFP8Margin
=
1e-4
;
constexpr
float
kFinfoAmaxE4M3
=
240
.0
f
;
// 设置不同的量化方式的最大值与相反数
constexpr
float
kFinfoAmaxE4M3
=
448
.0
f
;
constexpr
float
kFinfoAmaxInvE4M3
=
1.0
f
/
kFinfoAmaxE4M3
;
constexpr
float
kInt8Amax
=
127.0
f
;
constexpr
float
kFinfoAmaxE5M2
=
57344.0
f
;
constexpr
float
kFinfoAmaxInvE5M2
=
1.0
f
/
kFinfoAmaxE5M2
;
constexpr
float
kFinfoAmaxInt8
=
127.0
f
;
constexpr
float
kFinfoAmaxInvInt8
=
1.0
f
/
127.0
f
;
__forceinline__
__device__
float
fast_pow2
(
int
x
)
{
// We can ensure `-126 <= x and x <= 127`
...
...
@@ -359,22 +362,33 @@ __forceinline__ __device__ int fast_log2_ceil(float x) {
return
exp_x
-
127
+
(
man_bits
!=
0
);
}
__forceinline__
__device__
void
calculate_fp8_scales
(
float
amax
,
float
&
scale
,
float
&
scale_inv
,
bool
round_scale
)
{
if
(
round_scale
)
{
auto
exp_scale_inv
=
fast_log2_ceil
(
amax
*
kFinfoAmaxInvE4M3
);
scale
=
fast_pow2
(
-
exp_scale_inv
);
scale_inv
=
fast_pow2
(
exp_scale_inv
);
}
else
{
scale_inv
=
amax
*
kFinfoAmaxInvE4M3
;
scale
=
kFinfoAmaxE4M3
/
amax
;
template
<
int
kQuantType
>
__forceinline__
__device__
void
calculate_quant8bit_scales
(
float
amax
,
float
&
scale
,
float
&
scale_inv
,
bool
round_scale
=
0
)
{
amax
=
fmaxf
(
amax
,
1e-6
f
);
if
constexpr
(
kQuantType
==
1
)
{
// 使用 INT8 对称量化
scale_inv
=
kFinfoAmaxInvInt8
*
amax
;
scale
=
kFinfoAmaxInt8
/
amax
;
}
else
if
constexpr
(
kQuantType
==
2
||
kQuantType
==
3
)
{
// 使用 FP8_E4M3 或 FP8_UE8M0 非对称量化
if
(
round_scale
)
{
auto
exp_scale_inv
=
fast_log2_ceil
(
amax
*
kFinfoAmaxInvE4M3
);
scale
=
fast_pow2
(
-
exp_scale_inv
);
scale_inv
=
fast_pow2
(
exp_scale_inv
);
}
else
{
scale_inv
=
amax
*
kFinfoAmaxInvE4M3
;
scale
=
kFinfoAmaxE4M3
/
amax
;
}
}
else
if
constexpr
(
kQuantType
==
4
)
{
// 使用 FP8_E5M2 对称量化
if
(
round_scale
)
{
auto
exp_scale_inv
=
fast_log2_ceil
(
amax
*
kFinfoAmaxInvE5M2
);
scale
=
fast_pow2
(
-
exp_scale_inv
);
scale_inv
=
fast_pow2
(
exp_scale_inv
);
}
else
{
scale_inv
=
amax
*
kFinfoAmaxInvE5M2
;
scale
=
kFinfoAmaxE5M2
/
amax
;
}
}
}
__forceinline__
__device__
void
calculate_int8_scales
(
float
amax
,
float
&
scale
,
float
&
scale_inv
)
{
scale
=
kInt8Amax
/
amax
;
scale_inv
=
amax
/
kInt8Amax
;
}
template
<
bool
kIsUE8M0
,
typename
out_dtype_t
=
std
::
conditional_t
<
kIsUE8M0
,
uint8_t
,
float
>
>
__forceinline__
__device__
out_dtype_t
extract_required_scale_format
(
float
value
)
{
if
constexpr
(
kIsUE8M0
)
{
...
...
deep_ep/buffer.py
View file @
d0fcf024
...
...
@@ -841,7 +841,7 @@ class Buffer:
# noinspection PyTypeChecker
def
low_latency_dispatch
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
num_max_dispatch_tokens_per_rank
:
int
,
num_experts
:
int
,
use_fp8
:
bool
=
True
,
round_scale
:
bool
=
False
,
use_ue8m0
:
bool
=
False
,
use_int8
:
bool
=
False
,
quant_type
:
int
=
1
,
quant_group_size
:
int
=
0
,
fp8_round_scale
:
bool
=
False
,
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
)
->
\
Tuple
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
torch
.
Tensor
,
Tuple
,
EventOverlap
,
Callable
]:
"""
...
...
@@ -858,10 +858,20 @@ class Buffer:
only several top-k shapes are supported. `-1` indices (not selecting any expert) are supported.
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
num_experts: the number of all experts.
use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
round_scale: whether round the scaling factors into power of 2.
use_ue8m0: whether use UE8M0 as scaling factor format (available only with `round_scale=True`).
use_int8: whether to enable INT8 casting.
量化配置
quant_type: int 量化类型枚举
0 -> None 不量化,保持原始精度
1 -> Int8 使用 INT8 对称量化
2 -> FP8_E4M3 使用 FP8 E4M3 格式 (__HIP_E4M3_FNUZ)
3 -> FP8_UE8M0 使用 DeepSeekV3.1 提出的 UE8M0 格式 (仅支持round_scale=True)
4 -> FP8_E5M2 使用 FP8 E5M2 格式 (__HIP_E5M2_FNUZ)
quant_group_size: int 量化分组大小
0 -> 逐token量化 (per-channel)
128-> 每 128 元素一组 (per-group) 量化
fp8_round_scale: bool 是否将 FP8 缩放因子取整为 2 的幂
true -> 缩放因子 = 2^k,硬件零开销
false -> 缩放因子 = 任意浮点,精度更高
异步配置
async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
...
...
@@ -869,15 +879,25 @@ class Buffer:
Returns:
recv_x: a tensor or tuple with received tokens for each expert.
With `use_fp8=True`: the first element is a `torch.Tensor` shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`.
The second tensor is the corresponding scales for the first element with shape
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`,
if `use_ue8m0=False`. With `use_ue8m0=True`, the second one is packed and shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 512]` with type `torch.int`.
Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility.
With `use_fp8=False`, the result would be a tensor shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`.
- packed_recv_x:
存储接收到的 Token 数据,形状为
`[num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden]`。
数据类型取决于 quant_type:
quant_type == 1 -> torch.int8
quant_type == 2 -> torch.float8_e4m3fnuz
quant_type == 3 -> torch.float8_e4m3fnuz (UE8M0 使用 E4M3 格式存储)
quant_type == 4 -> torch.float8_e5m2fnuz
其他 (非量化) -> torch.bfloat16
- packed_recv_x_scales (可选):
仅在 quant_type > 0 时存在,存储量化的 Scale 值。
形状为 `[num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, scales_col_size]`。
- 当 quant_type == 3 (UE8M0) 时:
scales_col_size = hidden // 512
数据类型为 torch.int (内部打包存储 4-bit scale)。
*注意:此模式强制要求 fp8_round_scale=True 且 group_size=128。
- 当 quant_type == 1, 2, 4 时:
scales_col_size = hidden // 128 (若使用 group_size) 或 1 (per-channel)。
数据类型为 torch.float32。
Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
as we do not synchronize CPU received count with GPU (also not incompatible with CUDA graph if synced).
recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
...
...
@@ -889,14 +909,15 @@ class Buffer:
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_src_info
,
packed_recv_layout_range
,
event
,
hook
=
\
self
.
runtime
.
low_latency_dispatch
(
x
,
topk_idx
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
use_fp8
,
round_scale
,
use_ue8m0
,
use_int8
,
quant_type
,
quant_group_size
,
fp8_round_scale
,
async_finish
,
return_recv_hook
)
handle
=
(
packed_recv_src_info
,
packed_recv_layout_range
,
num_max_dispatch_tokens_per_rank
,
x
.
size
(
1
),
num_experts
)
tensors_to_record
=
(
x
,
topk_idx
,
packed_recv_x
,
packed_recv_x_scales
,
packed_recv_count
,
packed_recv_src_info
,
packed_recv_layout_range
)
return
(
packed_recv_x
,
packed_recv_x_scales
)
if
use_fp8
else
packed_recv_x
,
packed_recv_count
,
handle
,
\
EventOverlap
(
event
,
tensors_to_record
if
async_finish
else
None
),
hook
recv_x
=
(
packed_recv_x
,
packed_recv_x_scales
)
if
(
quant_type
>
0
)
else
packed_recv_x
return
recv_x
,
packed_recv_count
,
handle
,
EventOverlap
(
event
,
tensors_to_record
if
async_finish
else
None
),
hook
# noinspection PyTypeChecker
def
low_latency_combine
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
...
...
tests/test_low_latency_new.py
View file @
d0fcf024
...
...
@@ -6,7 +6,7 @@ from functools import partial
from
typing
import
Literal
,
Set
import
deep_ep
from
utils
import
init_dist
,
bench
,
bench_kineto
,
calc_diff
,
hash_tensor
,
per_token_cast_back
from
utils
import
init_dist
,
bench
,
bench_kineto
,
calc_diff
,
hash_tensor
,
per_token_cast_
pg_back
,
per_token_cast_pc_
back
def
simulate_failure_and_skip
(
rank
:
int
,
api
:
Literal
[
"dispatch"
,
"combine"
,
"clean"
],
expected_masked_ranks
:
Set
[
int
]):
...
...
@@ -58,7 +58,8 @@ def test_main(num_tokens: int,
x_list
=
[
x
]
# # NOTES: the last one is for performance testing
# # Most of the values in the perf case is lower than the threshold, casting most channels
# x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1)
# x_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1
# x_list = [x_rand]
scores
=
torch
.
randn
((
num_tokens
,
num_experts
),
dtype
=
torch
.
float32
,
device
=
'cuda'
).
abs
()
+
1
topk_idx
=
torch
.
topk
(
scores
,
num_topk
,
dim
=-
1
,
largest
=
True
,
sorted
=
True
)[
1
]
...
...
@@ -80,22 +81,35 @@ def test_main(num_tokens: int,
hash_value
,
num_times
=
0
,
0
for
current_x
in
x_list
:
for
return_recv_hook
in
(
False
,
True
):
for
dispatch_use_fp8
in
(
False
,
True
):
for
round_scale
in
(
False
,
True
)
if
dispatch_use_fp8
else
(
False
,):
for
use_ue8m0
in
(
False
,
True
)
if
round_scale
else
(
False
,):
for
quant_type
in
(
0
,
1
,
2
,
3
,
):
# 0: 不量化, 1: int8, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True), 4: FP8_E5M2
dispatch_use_quant
=
quant_type
>
0
for
fp8_round_scale
in
(
False
,
True
)
if
quant_type
!=
3
else
(
True
,
):
for
quant_group_size
in
(
0
,
128
,)
if
quant_type
>=
2
else
(
0
,
):
if
quant_type
==
3
and
(
fp8_round_scale
==
False
or
quant_group_size
==
0
):
continue
num_times
+=
1
for
_
in
range
((
num_times
%
2
)
+
1
):
packed_recv_x
,
packed_recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
use_fp8
=
dispatch_use_
fp8
,
round_scale
=
round_scale
,
use_ue8m0
=
use_ue8m0
,
quant_type
=
quant_type
,
fp8
_
round_scale
=
fp8_
round_scale
,
quant_group_size
=
quant_group_size
,
async_finish
=
not
return_recv_hook
,
return_recv_hook
=
return_recv_hook
)
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
packed_recv_x
=
(
packed_recv_x
[
0
],
packed_recv_x
[
1
].
contiguous
())
if
dispatch_use_fp8
else
packed_recv_x
simulated_gemm_x
=
per_token_cast_back
(
packed_recv_x
[
0
].
view
(
-
1
,
hidden
),
packed_recv_x
[
1
].
view
(
-
1
,
hidden
//
128
)).
view
(
packed_recv_x
[
0
].
shape
)
\
if
dispatch_use_fp8
else
packed_recv_x
.
clone
()
packed_recv_x
=
(
packed_recv_x
[
0
],
packed_recv_x
[
1
].
contiguous
())
if
dispatch_use_quant
else
packed_recv_x
if
not
dispatch_use_quant
:
simulated_gemm_x
=
packed_recv_x
.
clone
()
elif
quant_group_size
==
0
:
simulated_gemm_x
=
per_token_cast_pc_back
(
packed_recv_x
[
0
].
view
(
-
1
,
hidden
),
packed_recv_x
[
1
].
reshape
(
-
1
)).
view
(
packed_recv_x
[
0
].
shape
)
elif
quant_group_size
==
128
:
simulated_gemm_x
=
per_token_cast_pg_back
(
packed_recv_x
[
0
].
view
(
-
1
,
hidden
),
packed_recv_x
[
1
].
view
(
-
1
,
hidden
//
128
)).
view
(
packed_recv_x
[
0
].
shape
)
for
i
in
range
(
num_local_experts
if
do_check
else
0
):
expert_id
=
rank
*
num_local_experts
+
i
recv_x
=
per_token_cast_back
(
packed_recv_x
[
0
][
i
],
packed_recv_x
[
1
][
i
])
if
dispatch_use_fp8
else
packed_recv_x
[
i
]
if
not
dispatch_use_quant
:
recv_x
=
packed_recv_x
[
i
]
elif
quant_group_size
==
0
:
recv_x
=
per_token_cast_pc_back
(
packed_recv_x
[
0
][
i
],
packed_recv_x
[
1
][
i
])
elif
quant_group_size
==
128
:
recv_x
=
per_token_cast_pg_back
(
packed_recv_x
[
0
][
i
],
packed_recv_x
[
1
][
i
])
recv_count
,
recv_src_info
,
recv_layout_range
=
packed_recv_count
[
i
],
handle
[
0
][
i
],
handle
[
1
][
i
]
# Check expert indices
...
...
@@ -113,18 +127,25 @@ def test_main(num_tokens: int,
if
current_x
is
x
:
recv_x
=
recv_x
[:
num_valid_tokens
]
recv_x_amin
=
recv_x
[:,
:
-
128
].
amin
(
dim
=-
1
)
recv_x_amax
=
recv_x
[:,
:
-
128
].
amax
(
dim
=-
1
)
recv_src_info
=
recv_src_info
[:
num_valid_tokens
]
assert
torch
.
equal
(
recv_x_amin
,
recv_x
[:,
:
-
128
].
amax
(
dim
=-
1
))
if
round_scale
:
assert
torch
.
equal
(
recv_x_amin
,
recv_x_amax
)
if
dispatch_use_quant
:
assert
calc_diff
(
recv_x
[:,
-
1
],
recv_src_info
.
view
(
-
1
))
<
0.007
else
:
assert
(
recv_x
[:,
-
128
:]
-
recv_src_info
.
view
(
-
1
,
1
)
%
num_tokens
).
sum
().
item
()
==
0
for
j
in
range
(
num_ranks
):
begin_idx
,
count
=
(
recv_layout_range
[
j
]
>>
32
).
item
(),
(
recv_layout_range
[
j
]
&
int_mask
).
item
()
if
not
round_scale
:
assert
(
recv_x_amin
==
j
-
rank_offset
).
sum
().
item
()
==
(
all_topk_idx
[
j
]
==
expert_id
).
sum
().
item
()
assert
(
recv_x
[
begin_idx
:
begin_idx
+
count
,
:
-
128
]
-
j
+
rank_offset
).
sum
().
item
()
==
0
if
dispatch_use_fp8
:
assert
torch
.
equal
(
recv_x_amin
,
recv_x
[:,
:
-
128
].
amax
(
dim
=-
1
))
if
quant_group_size
!=
0
:
if
fp8_round_scale
:
assert
calc_diff
(
recv_x
[:,
-
1
],
recv_src_info
.
view
(
-
1
))
<
0.007
else
:
assert
(
recv_x
[:,
-
128
:]
-
recv_src_info
.
view
(
-
1
,
1
)
%
num_tokens
).
sum
().
item
()
==
0
for
j
in
range
(
num_ranks
):
begin_idx
,
count
=
(
recv_layout_range
[
j
]
>>
32
).
item
(),
(
recv_layout_range
[
j
]
&
int_mask
).
item
()
if
not
fp8_round_scale
:
assert
(
recv_x_amin
==
j
-
rank_offset
).
sum
().
item
()
==
(
all_topk_idx
[
j
]
==
expert_id
).
sum
().
item
()
assert
(
recv_x
[
begin_idx
:
begin_idx
+
count
,
:
-
128
]
-
j
+
rank_offset
).
sum
().
item
()
==
0
if
dispatch_use_quant
:
hash_value
^=
hash_tensor
(
packed_recv_x
[
0
][
i
,
:
num_valid_tokens
])
hash_value
^=
hash_tensor
(
packed_recv_x
[
1
][
i
,
:
num_valid_tokens
])
else
:
...
...
@@ -147,8 +168,8 @@ def test_main(num_tokens: int,
if
do_check
:
diff
=
calc_diff
(
current_x
*
topk_weights
.
masked_fill
(
topk_idx
==
-
1
,
0
).
sum
(
dim
=
1
).
view
(
-
1
,
1
),
combined_x
)
assert
torch
.
isnan
(
combined_x
).
sum
().
item
()
==
0
# if not round_scale:
assert
diff
<
(
9e-4
if
dispatch_use_
fp8
else
1e-5
),
f
'Error: diff=
{
diff
}
, dispatch_use_
fp8
=
{
dispatch_use_
fp8
}
, zero_copy=
{
zero_copy
}
'
# if not
fp8_
round_scale:
assert
diff
<
(
9e-4
if
dispatch_use_
quant
else
1e-5
),
f
'Error: diff=
{
diff
}
, dispatch_use_
quant
=
{
dispatch_use_
quant
}
, zero_copy=
{
zero_copy
}
'
hash_value
^=
hash_tensor
(
combined_x
)
# noinspection PyShadowingNames
...
...
@@ -162,7 +183,8 @@ def test_main(num_tokens: int,
def
test_func
(
return_recv_hook
:
bool
):
recv_x
,
recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
use_fp8
=
True
,
async_finish
=
False
,
return_recv_hook
=
return_recv_hook
)
quant_type
=
2
,
quant_group_size
=
0
,
async_finish
=
False
,
return_recv_hook
=
return_recv_hook
)
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
topk_idx
,
...
...
tests/test_low_latency_new_int8.py
deleted
100644 → 0
View file @
81e56124
import
argparse
import
random
import
os
import
torch
import
torch.distributed
as
dist
from
functools
import
partial
from
typing
import
Literal
,
Set
import
deep_ep
from
utils
import
init_dist
,
bench
,
bench_kineto
,
calc_diff
,
hash_tensor
,
per_token_cast_back_int8
def
test_main
(
num_tokens
:
int
,
hidden
:
int
,
num_experts
:
int
,
num_topk
:
int
,
rank
:
int
,
num_ranks
:
int
,
group
:
dist
.
ProcessGroup
,
buffer
:
deep_ep
.
Buffer
,
seed
:
int
=
0
):
torch
.
manual_seed
(
seed
+
rank
)
random
.
seed
(
seed
+
rank
)
assert
num_experts
%
num_ranks
==
0
num_local_experts
=
num_experts
//
num_ranks
# NOTES: the integers greater than 256 exceed the BF16 precision limit
rank_offset
=
128
assert
num_ranks
-
rank_offset
<
257
,
'Too many ranks (exceeding test precision limit)'
x
=
torch
.
ones
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
*
(
rank
-
rank_offset
)
x
[:,
-
128
:]
=
torch
.
arange
(
num_tokens
,
device
=
'cuda'
).
to
(
torch
.
bfloat16
).
view
(
-
1
,
1
)
x_list
=
[
x
]
# # NOTES: the last one is for performance testing
# # Most of the values in the perf case is lower than the threshold, casting most channels
# x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1)
scores
=
torch
.
randn
((
num_tokens
,
num_experts
),
dtype
=
torch
.
float32
,
device
=
'cuda'
).
abs
()
+
1
topk_idx
=
torch
.
topk
(
scores
,
num_topk
,
dim
=-
1
,
largest
=
True
,
sorted
=
True
)[
1
]
topk_weights
=
torch
.
randn
((
num_tokens
,
num_topk
),
dtype
=
torch
.
float32
,
device
=
'cuda'
).
abs
()
# Randomly mask some positions
for
_
in
range
(
10
):
topk_idx
[
random
.
randint
(
0
,
num_tokens
-
1
),
random
.
randint
(
0
,
num_topk
-
1
)]
=
-
1
all_topk_idx
=
torch
.
empty
((
num_ranks
,
num_tokens
,
num_topk
),
dtype
=
topk_idx
.
dtype
,
device
=
'cuda'
)
dist
.
all_gather_into_tensor
(
all_topk_idx
,
topk_idx
,
group
=
group
)
# For failure simulation and shrink testing
mask_status
=
torch
.
zeros
((
num_ranks
,),
dtype
=
torch
.
int
,
device
=
'cuda'
)
# Check dispatch correctness
do_check
=
True
hash_value
,
num_times
=
0
,
0
for
current_x
in
x_list
:
for
return_recv_hook
in
(
False
,
):
for
dispatch_use_fp8
in
(
True
,
):
for
round_scale
in
(
False
,
):
for
use_ue8m0
in
(
False
,
):
num_times
+=
1
use_int8
=
True
for
_
in
range
(
1
):
packed_recv_x
,
packed_recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
use_fp8
=
dispatch_use_fp8
,
round_scale
=
round_scale
,
use_ue8m0
=
use_ue8m0
,
use_int8
=
use_int8
,
async_finish
=
not
return_recv_hook
,
return_recv_hook
=
return_recv_hook
)
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
packed_recv_x
=
(
packed_recv_x
[
0
],
packed_recv_x
[
1
].
contiguous
())
if
dispatch_use_fp8
else
packed_recv_x
simulated_gemm_x
=
per_token_cast_back_int8
(
packed_recv_x
[
0
].
view
(
-
1
,
hidden
),
packed_recv_x
[
1
].
view
(
-
1
,
1
)).
view
(
packed_recv_x
[
0
].
shape
)
for
i
in
range
(
num_local_experts
if
do_check
else
0
):
expert_id
=
rank
*
num_local_experts
+
i
recv_x
=
per_token_cast_back_int8
(
packed_recv_x
[
0
][
i
],
packed_recv_x
[
1
][
i
])
if
dispatch_use_fp8
else
packed_recv_x
[
i
]
recv_count
,
recv_src_info
,
recv_layout_range
=
packed_recv_count
[
i
],
handle
[
0
][
i
],
handle
[
1
][
i
]
# Check expert indices
int_mask
=
(
2
**
32
)
-
1
num_valid_tokens
=
recv_count
.
item
()
assert
num_valid_tokens
==
(
recv_layout_range
&
int_mask
).
sum
().
item
(),
f
'
{
num_valid_tokens
}
!=
{
recv_layout_range
&
int_mask
}
.sum().item()'
assert
num_valid_tokens
==
(
all_topk_idx
==
expert_id
).
sum
(
dim
=
[
1
,
2
])[
mask_status
==
0
].
sum
().
item
(
),
f
'
{
num_valid_tokens
}
!=
{
(
all_topk_idx
==
expert_id
).
sum
(
dim
=
[
1
,
2
])[
mask_status
==
0
].
sum
().
item
()
}
'
if
num_valid_tokens
==
0
:
continue
# Check received data
if
current_x
is
x
:
recv_x
=
recv_x
[:
num_valid_tokens
]
recv_x_amin
=
recv_x
[:,
:
-
128
].
amin
(
dim
=-
1
)
recv_x_amax
=
recv_x
[:,
:
-
128
].
amax
(
dim
=-
1
)
recv_src_info
=
recv_src_info
[:
num_valid_tokens
]
assert
torch
.
equal
(
recv_x_amin
,
recv_x_amax
)
if
round_scale
:
assert
calc_diff
(
recv_x
[:,
-
1
],
recv_src_info
.
view
(
-
1
))
<
0.007
elif
use_int8
:
assert
calc_diff
(
recv_x
[:,
-
1
],
recv_src_info
.
view
(
-
1
))
<
0.01
else
:
assert
(
recv_x
[:,
-
128
:]
-
recv_src_info
.
view
(
-
1
,
1
)
%
num_tokens
).
sum
().
item
()
==
0
# for j in range(num_ranks):
# if (not round_scale):
# check_tmp1 = (recv_x_amin == j - rank_offset).sum().item()
# check_tmp2 = (all_topk_idx[j] == expert_id).sum().item()
# print(f'rank: {rank}, j: {j}, check_tmp1: {check_tmp1}, check_tmp2: {check_tmp2}, diff: {abs(check_tmp1 - check_tmp2)}')
# assert abs(check_tmp1 - check_tmp2) < 3
# assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0
if
dispatch_use_fp8
:
hash_value
^=
hash_tensor
(
packed_recv_x
[
0
][
i
,
:
num_valid_tokens
])
hash_value
^=
hash_tensor
(
packed_recv_x
[
1
][
i
,
:
num_valid_tokens
])
else
:
hash_value
^=
hash_tensor
(
packed_recv_x
[
i
,
:
num_valid_tokens
])
print
(
"dispatch int 8 pass"
)
# noinspection PyShadowingNames
def
large_gemm_with_hook
(
hook
):
mat_0
=
torch
.
randn
((
8192
,
8192
),
dtype
=
torch
.
float
)
mat_1
=
torch
.
randn
((
8192
,
8192
),
dtype
=
torch
.
float
)
mat_0
@
mat_1
hook
()
# noinspection PyShadowingNames
def
test_func
(
return_recv_hook
:
bool
):
recv_x
,
recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
use_fp8
=
True
,
round_scale
=
False
,
use_ue8m0
=
False
,
use_int8
=
True
,
async_finish
=
False
,
return_recv_hook
=
return_recv_hook
)
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
topk_idx
,
topk_weights
,
handle
,
return_recv_hook
=
return_recv_hook
)
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
# Calculate bandwidth
scale_size
=
1
# hidden / 128
num_fp8_bytes
,
num_bf16_bytes
=
(
hidden
+
scale_size
*
4
+
16
),
hidden
*
2
num_dispatch_comm_bytes
,
num_combine_comm_bytes
=
0
,
0
for
i
in
range
(
num_tokens
):
num_selections
=
(
topk_idx
[
i
]
!=
-
1
).
sum
().
item
()
num_dispatch_comm_bytes
+=
num_fp8_bytes
*
num_selections
num_combine_comm_bytes
+=
num_bf16_bytes
*
num_selections
# Separate profiling
for
return_recv_hook
in
(
True
,
False
):
group
.
barrier
()
dispatch_t
,
combine_t
=
bench_kineto
(
partial
(
test_func
,
return_recv_hook
=
return_recv_hook
),
kernel_names
=
(
'dispatch'
,
'combine'
),
barrier_comm_profiling
=
True
,
suppress_kineto_output
=
True
,
num_kernels_per_period
=
2
if
return_recv_hook
else
1
)
if
not
return_recv_hook
:
print
(
f
'[rank
{
rank
}
] Dispatch bandwidth:
{
num_dispatch_comm_bytes
/
1e9
/
dispatch_t
:.
2
f
}
GB/s, avg_t=
{
dispatch_t
*
1e6
:.
2
f
}
us | '
f
'Combine bandwidth:
{
num_combine_comm_bytes
/
1e9
/
combine_t
:.
2
f
}
GB/s, avg_t=
{
combine_t
*
1e6
:.
2
f
}
us'
,
flush
=
True
)
else
:
print
(
f
'[rank
{
rank
}
] Dispatch send/recv time:
{
dispatch_t
[
0
]
*
1e6
:.
2
f
}
+
{
dispatch_t
[
1
]
*
1e6
:.
2
f
}
us | '
f
'Combine send/recv time:
{
combine_t
[
0
]
*
1e6
:.
2
f
}
+
{
combine_t
[
1
]
*
1e6
:.
2
f
}
us'
,
flush
=
True
)
return
hash_value
# noinspection PyUnboundLocalVariable,PyShadowingNames
def
test_loop
(
local_rank
:
int
,
num_local_ranks
:
int
,
args
:
argparse
.
Namespace
):
rank
,
num_ranks
,
group
=
init_dist
(
local_rank
,
num_local_ranks
)
num_tokens
,
hidden
=
args
.
num_tokens
,
args
.
hidden
num_topk
,
num_experts
=
args
.
num_topk
,
args
.
num_experts
num_rdma_bytes
=
deep_ep
.
Buffer
.
get_low_latency_rdma_size_hint
(
num_tokens
,
hidden
,
num_ranks
,
num_experts
)
if
local_rank
==
0
:
print
(
f
'Allocating buffer size:
{
num_rdma_bytes
/
1e6
}
MB ...'
,
flush
=
True
)
buffer
=
deep_ep
.
Buffer
(
group
,
num_rdma_bytes
=
num_rdma_bytes
,
low_latency_mode
=
True
,
num_qps_per_rank
=
num_experts
//
num_ranks
,
allow_nvlink_for_low_latency_mode
=
not
args
.
disable_nvlink
,
explicitly_destroy
=
True
,
allow_mnnvl
=
args
.
allow_mnnvl
)
test_main
(
num_tokens
,
hidden
,
num_experts
,
num_topk
,
rank
,
num_ranks
,
group
,
buffer
,
seed
=
1
)
# Destroy the buffer runtime and communication group
buffer
.
destroy
()
dist
.
barrier
()
dist
.
destroy_process_group
()
if
__name__
==
'__main__'
:
# TODO: you may modify NUMA binding for less CPU overhead
# TODO: buggy with `num_tokens=512`
parser
=
argparse
.
ArgumentParser
(
description
=
'Test low-latency EP kernels'
)
parser
.
add_argument
(
'--num-processes'
,
type
=
int
,
default
=
8
,
help
=
'Number of processes to spawn (default: 8)'
)
parser
.
add_argument
(
'--num-tokens'
,
type
=
int
,
default
=
128
,
help
=
'Number of tokens (default: 128)'
)
parser
.
add_argument
(
'--hidden'
,
type
=
int
,
default
=
7168
,
help
=
'Hidden dimension size (default: 7168)'
)
parser
.
add_argument
(
'--num-topk'
,
type
=
int
,
default
=
8
,
help
=
'Number of top-k experts (default: 8)'
)
parser
.
add_argument
(
'--num-experts'
,
type
=
int
,
default
=
256
,
help
=
'Number of experts (default: 288)'
)
parser
.
add_argument
(
'--allow-mnnvl'
,
action
=
"store_true"
,
help
=
'Allow MNNVL for communication'
)
parser
.
add_argument
(
'--disable-nvlink'
,
action
=
'store_true'
,
help
=
'Whether to disable NVLink for testing'
)
parser
.
add_argument
(
"--pressure-test"
,
action
=
'store_true'
,
help
=
'Whether to do pressure test'
)
parser
.
add_argument
(
"--shrink-test"
,
action
=
'store_true'
,
help
=
'Whether to simulate failure and test shrink mode'
)
parser
.
add_argument
(
'--use-logfmt'
,
action
=
'store_true'
,
help
=
'Whether to test LogFMT combine'
)
args
=
parser
.
parse_args
()
num_processes
=
args
.
num_processes
torch
.
multiprocessing
.
spawn
(
test_loop
,
args
=
(
num_processes
,
args
),
nprocs
=
num_processes
)
tests/utils.py
View file @
d0fcf024
...
...
@@ -57,28 +57,22 @@ def per_token_cast_to_fp8(x: torch.Tensor):
return
(
x_padded_view
*
(
448.0
/
x_amax
.
unsqueeze
(
2
))).
to
(
torch
.
float8_e4m3fn
).
view
(
m
,
aligned_n
)[:,
:
n
].
contiguous
(),
(
x_amax
/
448.0
).
view
(
m
,
-
1
)
def
per_token_cast_back
(
x
_fp8
:
torch
.
Tensor
,
x_scales
:
torch
.
Tensor
):
if
x
_fp8
.
numel
()
==
0
:
return
x
_fp8
.
to
(
torch
.
bfloat16
)
def
per_token_cast_
pg_
back
(
x
:
torch
.
Tensor
,
x_scales
:
torch
.
Tensor
):
if
x
.
numel
()
==
0
:
return
x
.
to
(
torch
.
bfloat16
)
assert
x
_fp8
.
dim
()
==
2
m
,
n
=
x
_fp8
.
shape
assert
x
.
dim
()
==
2
m
,
n
=
x
.
shape
aligned_n
=
align_up
(
n
,
128
)
x_
fp8_
padded
=
torch
.
nn
.
functional
.
pad
(
x
_fp8
,
(
0
,
aligned_n
-
n
),
mode
=
'constant'
,
value
=
0
)
x_padded
=
torch
.
nn
.
functional
.
pad
(
x
,
(
0
,
aligned_n
-
n
),
mode
=
'constant'
,
value
=
0
)
if
x_scales
.
dtype
==
torch
.
int
:
x_scales
=
x_scales
.
view
(
dtype
=
torch
.
uint8
).
to
(
torch
.
int
)
<<
23
x_scales
=
x_scales
.
view
(
dtype
=
torch
.
float
)
x_fp32_padded
=
x_fp8_padded
.
to
(
torch
.
float32
).
view
(
x_fp8
.
size
(
0
),
-
1
,
128
)
x_scales
=
x_scales
.
view
(
x_fp8
.
size
(
0
),
-
1
,
1
)
return
(
x_fp32_padded
*
x_scales
).
view
(
x_fp8_padded
.
shape
).
to
(
torch
.
bfloat16
)[:,:
n
].
contiguous
()
x_fp32_padded
=
x_padded
.
to
(
torch
.
float32
).
view
(
x
.
size
(
0
),
-
1
,
128
)
x_scales
=
x_scales
.
view
(
x
.
size
(
0
),
-
1
,
1
)
return
(
x_fp32_padded
*
x_scales
).
view
(
x_padded
.
shape
).
to
(
torch
.
bfloat16
)[:,:
n
].
contiguous
()
def
per_token_cast_back_int8
(
x_int8
:
torch
.
Tensor
,
x_scales
:
torch
.
Tensor
):
"""
x_int8: [m, n] int8 tensor
x_scales: [m, n] 或 [m, 1] 或 [m, n/128] 量化 scale float
return: [m, n] bf16 tensor
"""
def
per_token_cast_pc_back
(
x_int8
:
torch
.
Tensor
,
x_scales
:
torch
.
Tensor
):
if
x_int8
.
numel
()
==
0
:
return
x_int8
.
to
(
torch
.
bfloat16
)
...
...
@@ -86,12 +80,9 @@ def per_token_cast_back_int8(x_int8: torch.Tensor, x_scales: torch.Tensor):
m
,
n
=
x_int8
.
shape
aligned_n
=
align_up
(
n
,
128
)
x_int8_padded
=
torch
.
nn
.
functional
.
pad
(
x_int8
,
(
0
,
aligned_n
-
n
),
mode
=
'constant'
,
value
=
0
)
x_int8_padded
=
torch
.
nn
.
functional
.
pad
(
x_int8
,
(
0
,
aligned_n
-
n
),
mode
=
'constant'
,
value
=
0
)
x_fp32_padded
=
x_int8_padded
.
to
(
torch
.
float32
).
view
(
m
,
-
1
,
1
)
x_scales
=
x_scales
.
view
(
m
,
-
1
,
1
).
to
(
torch
.
float32
)
# print(f'x_int8.shape: {x_int8.shape}, x_fp32_padded: {x_fp32_padded.shape}, x_scales: {x_scales.shape}')
x_deq
=
(
x_fp32_padded
*
x_scales
).
view
(
m
,
aligned_n
)
return
x_deq
[:,
:
n
].
to
(
torch
.
bfloat16
).
contiguous
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment