Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
75b00cfb
Commit
75b00cfb
authored
Feb 09, 2026
by
lijian6
Browse files
Merge branch 'logfmt_master' into 'main'
低延迟combine支持10bit量化代码 See merge request dcutoolkit/deeplearing/DeepEP!21
parents
dbf9fd61
17d9c844
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
532 additions
and
60 deletions
+532
-60
csrc/config.hpp
csrc/config.hpp
+4
-4
csrc/deep_ep.cu
csrc/deep_ep.cu
+2
-0
csrc/deep_ep.hpp
csrc/deep_ep.hpp
+1
-0
csrc/kernels/api.cuh
csrc/kernels/api.cuh
+1
-0
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+226
-45
csrc/kernels/internode_ll_logfmt.cuh
csrc/kernels/internode_ll_logfmt.cuh
+264
-0
csrc/kernels/utils.cuh
csrc/kernels/utils.cuh
+12
-3
deep_ep/buffer.py
deep_ep/buffer.py
+4
-2
tests/test_low_latency_new.py
tests/test_low_latency_new.py
+18
-6
No files found.
csrc/config.hpp
View file @
75b00cfb
...
...
@@ -136,7 +136,7 @@ struct LowLatencyLayout {
LowLatencyLayout
(
void
*
rdma_buffer
,
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_ranks
,
int
num_experts
,
int
quant_group_size
=
0
)
{
const
int
num_scales
=
quant_group_size
==
0
?
4
:
hidden
/
QUANTIZATION_GROUPSIZE
;
// 应该是1,但是代码中为了满足int4对齐
const
int
num_scales
=
hidden
/
QUANTIZATION_GROUPSIZE
;
// Dispatch and combine layout:
// - 2 symmetric odd/even send buffer
...
...
@@ -148,11 +148,11 @@ struct LowLatencyLayout {
// transformation
EP_HOST_ASSERT
(
num_scales
*
sizeof
(
float
)
<=
static_cast
<
size_t
>
(
hidden
));
size_t
num_bytes_per_dispatch_msg
=
sizeof
(
int4
)
+
std
::
max
(
hidden
*
sizeof
(
hip_bfloat16
),
hidden
+
num_scales
*
sizeof
(
float
));
sizeof
(
int4
)
+
std
::
max
(
hidden
*
sizeof
(
hip_bfloat16
),
hidden
+
(
quant_group_size
==
0
?
4
:
num_scales
)
*
sizeof
(
float
));
// 应该是1,但是代码中为了满足int4对齐
// 与internode_ll::combine 中的 num_bytes_per_slot 相等
size_t
num_bytes_per_combine_msg
=
hidden
*
sizeof
(
hip_bfloat16
);
size_t
num_bytes_per_combine_msg
=
hidden
*
sizeof
(
hip_bfloat16
)
+
num_scales
*
sizeof
(
__hip_bfloat162
)
;
// Send buffer
size_t
dispatch_send_buffer_bytes
=
...
...
csrc/deep_ep.cu
View file @
75b00cfb
...
...
@@ -1413,6 +1413,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
const
torch
::
Tensor
&
src_info
,
const
torch
::
Tensor
&
layout_range
,
const
std
::
optional
<
torch
::
Tensor
>&
combine_wait_recv_cost_stats
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_logfmt
,
bool
zero_copy
,
bool
async
,
bool
return_recv_hook
,
const
std
::
optional
<
torch
::
Tensor
>&
out
)
{
EP_HOST_ASSERT
(
low_latency_mode
);
...
...
@@ -1482,6 +1483,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
next_clean_meta
.
first
,
next_clean_meta
.
second
,
num_combined_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
use_logfmt
,
workspace
,
num_device_sms
,
launch_stream
,
phases
,
zero_copy
);
};
...
...
csrc/deep_ep.hpp
View file @
75b00cfb
...
...
@@ -185,6 +185,7 @@ public:
const
torch
::
Tensor
&
src_info
,
const
torch
::
Tensor
&
layout_range
,
const
std
::
optional
<
torch
::
Tensor
>&
combine_wait_recv_cost_stats
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_logfmt
,
bool
zero_copy
,
bool
async
,
bool
return_recv_hook
,
const
std
::
optional
<
torch
::
Tensor
>&
out
=
std
::
nullopt
);
...
...
csrc/kernels/api.cuh
View file @
75b00cfb
...
...
@@ -159,6 +159,7 @@ void combine(void* combined_x,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_logfmt
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
,
bool
zero_copy
);
...
...
csrc/kernels/internode_ll.cu
View file @
75b00cfb
...
...
@@ -9,6 +9,7 @@
#include "hip/hip_runtime.h"
#include "shmem_wrapper.cuh"
#include "internode_ll_logfmt.cuh"
namespace
deep_ep
{
...
...
@@ -612,7 +613,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
#undef DISPATCH_LAUNCH_CASE
}
template
<
int
kHidden
,
int
kNumMaxTopk
,
int
kMaxNumWarps
=
16
>
template
<
bool
kUseLogFMT
,
int
kHidden
,
int
kNumMaxTopk
,
int
kMaxNumWarps
=
16
>
__global__
__launch_bounds__
(
16
*
kWarpSize
,
1
)
void
combine
(
void
*
combined_x
,
void
*
rdma_recv_x
,
int64_t
*
rdma_recv_flag
,
void
*
rdma_send_x
,
...
...
@@ -643,7 +644,24 @@ combine(void* combined_x,
// Message package
EP_STATIC_ASSERT
(
kHidden
%
QUANTIZATION_GROUPSIZE
==
0
,
"Invalid hidden"
);
constexpr
size_t
num_bytes_per_slot
=
kHidden
*
sizeof
(
hip_bfloat16
);
/////////////// LogFMT使用 ///////////////
constexpr
int
bSupportLogFMT
=
kUseLogFMT
&&
hidden_bf16_int4
%
(
kWarpSize
*
2
)
==
0
;
constexpr
int
kNumSendUnrolls
=
bSupportLogFMT
?
2
:
1
;
constexpr
int
kNumRecvUnrolls
=
bSupportLogFMT
?
2
:
1
;
constexpr
int
kNumMsgInt4ElemPerWarp
=
kWarpSize
*
kNumSendUnrolls
;
// 每个warp发送的int4元素数据量,即每个warp发送 kNumMsgInt4ElemPerWarp*sizeof(int4)/sizeof(bfloat16)
EP_STATIC_ASSERT
(
hidden_bf16_int4
%
(
kNumSendUnrolls
*
kWarpSize
)
==
0
,
"Invalid hidden"
);
EP_STATIC_ASSERT
(
kNumSendUnrolls
>=
kNumRecvUnrolls
,
"Invalid unroll factors"
);
constexpr
int
kNumDivisions
=
kHidden
/
QUANTIZATION_GROUPSIZE
;
constexpr
int
kNumMetaBytes
=
kNumDivisions
*
sizeof
(
__hip_bfloat162
);
// 用于记录数据的最大最小值
constexpr
int
kNumSendLogFMTBytes
=
kNumMsgInt4ElemPerWarp
*
sizeof
(
int4
);
constexpr
int
kNumStages
=
1
;
// 使用kNumStages>1,则需要的LDS大于64KB
constexpr
int
kLogFMTShmemSize
=
kMaxNumWarps
*
(
kNumStages
*
kNumSendLogFMTBytes
+
kNumMetaBytes
);
__shared__
uint8_t
smem_buffer
[
kLogFMTShmemSize
];
/////////////////////////////////////////////
constexpr
size_t
num_bytes_per_slot
=
kHidden
*
sizeof
(
hip_bfloat16
)
+
kNumMetaBytes
;
EP_STATIC_ASSERT
(
num_bytes_per_slot
%
sizeof
(
int4
)
==
0
,
"Invalid vectorization"
);
// 初始化用于细粒度warp间同步的计数器数组
...
...
@@ -683,6 +701,12 @@ combine(void* combined_x,
const
auto
local_src_info
=
src_info
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
;
const
auto
rdma_send_x_vec
=
reinterpret_cast
<
uint8_t
*>
(
rdma_send_x
)
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_slot
;
// 用于logfmt的LDS
auto
smem_ptr
=
smem_buffer
+
warp_id
*
(
kNumStages
*
kNumSendLogFMTBytes
+
kNumMetaBytes
);
// 存储logfmt的起始地址,并根据stage_idx进行索引块
auto
logfmt_buffers
=
PatternVisitor
([
=
](
const
int
&
i
)
{
return
reinterpret_cast
<
int4
*>
(
smem_ptr
+
i
*
kNumSendLogFMTBytes
);
});
// 存储logfmt的最大最小值
auto
meta_buffers
=
bSupportLogFMT
?
reinterpret_cast
<
__hip_bfloat162
*>
(
smem_ptr
+
kNumStages
*
kNumSendLogFMTBytes
)
:
nullptr
;
// Unpack layout
int
offset
,
num_tokens_to_send
;
...
...
@@ -699,20 +723,78 @@ combine(void* combined_x,
const
auto
buf_ptr
=
reinterpret_cast
<
int64_t
>
(
rdma_send_x_vec_row
);
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_recv_x
)
+
(
global_expert_idx
*
num_max_dispatch_tokens_per_rank
+
src_idx
)
*
num_bytes_per_slot
;
uint64_t
p2p_ptr
=
internode
::
shmem_get_p2p_ptr
((
void
*
)
dst_ptr
,
rank
,
dst_rank
);
if
(
p2p_ptr
==
0
)
{
// RDMA
const
auto
buf_int4_ptr
=
reinterpret_cast
<
int4
*>
(
buf_ptr
);
if
(
not
zero_copy
)
UNROLLED_WARP_COPY_LL
(
7
,
lane_id
,
hidden_bf16_int4
,
buf_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
// 采用logfmt或者直接拷贝
uint64_t
dst_p2p_ptr
=
internode
::
shmem_get_p2p_ptr
((
void
*
)
dst_ptr
,
rank
,
dst_rank
);
int
num_send_bytes
=
hidden
*
sizeof
(
hip_bfloat16
);
if
(
not
zero_copy
or
dst_p2p_ptr
!=
0
)
{
// Read from `cpy_src_int4_ptr` and copy into `cpy_dst_int4_ptr`
const
auto
cpy_src_int4_ptr
=
zero_copy
?
reinterpret_cast
<
int4
*>
(
buf_ptr
)
:
x_int4
;
const
auto
cpy_dst_int4_ptr
=
dst_p2p_ptr
==
0
?
reinterpret_cast
<
int4
*>
(
buf_ptr
)
:
reinterpret_cast
<
int4
*>
(
dst_p2p_ptr
);
// 设置数据的真实偏移量
int
logfmt_offset_bytes
=
kNumMetaBytes
;
// 进入循环,逐步拷贝数据
constexpr
int
encode_num_warps
=
hidden_bf16_int4
/
kNumMsgInt4ElemPerWarp
;
for
(
int
iter_idx
=
0
;
iter_idx
<
encode_num_warps
;
++
iter_idx
)
{
int
num_logfmt_bytes
=
kNumMsgInt4ElemPerWarp
*
sizeof
(
int4
);
// 原始数据的warp级编译
int
warp_offset
=
iter_idx
*
kNumMsgInt4ElemPerWarp
;
if
constexpr
(
bSupportLogFMT
)
{
// 采用 寄存器->lds->global 的流水线方式, 量化后拷贝到buf_ptr中
const
int
&
stage_idx
=
iter_idx
%
kNumStages
;
// thread偏移
int
thread_offset
=
warp_offset
+
lane_id
*
kNumSendUnrolls
;
constexpr
int
kNumInt4PerDivision
=
128
/
kNumElemsPerInt4
;
// = 128/(sizeof(int4) / sizeof(hip_bfloat16)) = 128/(16/2)=16
num_logfmt_bytes
=
logfmt_encode
<
kNumSendUnrolls
>
(
cpy_src_int4_ptr
+
warp_offset
,
// 等同于 x_int4
logfmt_buffers
[
stage_idx
],
// NOTES: only the leader lane will write the result
(
thread_offset
%
kNumInt4PerDivision
==
0
)
?
meta_buffers
+
thread_offset
/
kNumInt4PerDivision
:
nullptr
,
lane_id
);
// 将量化后的数据写入
using
vec_type
=
uint32_t
;
UNROLLED_WARP_COPY_LL
(
2
,
lane_id
,
num_logfmt_bytes
/
sizeof
(
vec_type
),
reinterpret_cast
<
vec_type
*>
(
reinterpret_cast
<
uint8_t
*>
(
cpy_dst_int4_ptr
)
+
logfmt_offset_bytes
),
reinterpret_cast
<
vec_type
*>
(
logfmt_buffers
[
stage_idx
]),
ld_nc_global
,
st_na_global
);
// 起始地址偏移
logfmt_offset_bytes
+=
num_logfmt_bytes
;
}
else
{
// 非量化数据的传输
UNROLLED_WARP_COPY_LL
(
2
,
lane_id
,
kNumMsgInt4ElemPerWarp
,
reinterpret_cast
<
int4
*>
(
cpy_dst_int4_ptr
+
warp_offset
),
reinterpret_cast
<
const
int4
*>
(
cpy_src_int4_ptr
+
warp_offset
),
ld_nc_global
,
st_na_global
);
}
syncwarp
();
}
// Store metadata (min/max values) for LogFMT
if
constexpr
(
bSupportLogFMT
)
{
// 最终设置节点间传输的字节数
num_send_bytes
=
logfmt_offset_bytes
;
using
vec_type
=
uint32_t
;
auto
meta_buffers_ptr
=
reinterpret_cast
<
vec_type
*>
(
meta_buffers
);
auto
cpy_dst_uint32_ptr
=
reinterpret_cast
<
vec_type
*>
(
cpy_dst_int4_ptr
);
for
(
int
j
=
lane_id
;
j
<
kNumMetaBytes
/
sizeof
(
vec_type
);
j
+=
kWarpSize
)
{
*
(
cpy_dst_uint32_ptr
+
j
)
=
meta_buffers_ptr
[
j
];
}
}
syncwarp
();
}
if
(
dst_p2p_ptr
==
0
)
{
internode_ll_putmem_nbi
((
void
*
)
dst_ptr
,
(
void
*
)
buf_ptr
,
num_ranks
,
dst_rank
,
local_expert_idx
,
hidden
*
sizeof
(
hip_bfloat16
));
}
else
{
// 本地 GPU 和 同一计算节点的 其他 GPU 地址
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
x_int4
);
const
auto
*
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
p2p_ptr
);
UNROLLED_WARP_COPY_LL
(
7
,
lane_id
,
hidden_bf16_int4
,
dst_int4_ptr
,
src_int4_ptr
,
ld_nc_global
,
st_na_global
);
num_ranks
,
dst_rank
,
local_expert_idx
,
num_send_bytes
);
}
}
...
...
@@ -773,40 +855,136 @@ LOW_LATENCY_COMBINE_RECV:
// Reduce tokens with FP8 cast
// EP_DEVICE_ASSERT(num_topk <= kWarpSize and hidden_bf16_int4 <= num_threads);
EP_STATIC_ASSERT
(
kHidden
%
(
kWarpSize
*
kNumElemsPerInt4
)
==
0
,
"Invalid vectorization"
);
if
(
thread_id
<
hidden_bf16_int4
)
{
for
(
int
token_idx
=
sm_id
;
token_idx
<
num_combined_tokens
;
token_idx
+=
num_sms
)
{
// Read top-k indices and weights
int
reg_topk_idx
[
kNumMaxTopk
];
float
reg_topk_weights
[
kNumMaxTopk
];
#pragma unroll
for
(
int
i
=
0
;
i
<
num_topk
;
++
i
)
{
reg_topk_idx
[
i
]
=
static_cast
<
int
>
(
__ldg
(
topk_idx
+
token_idx
*
num_topk
+
i
));
reg_topk_weights
[
i
]
=
__ldg
(
topk_weights
+
token_idx
*
num_topk
+
i
);
}
float
combined_values
[
kNumElemsPerInt4
]
=
{
0.0
f
};
#pragma unroll
for
(
int
i
=
0
;
i
<
num_topk
;
++
i
)
if
(
reg_topk_idx
[
i
]
>=
0
)
{
// Read from sources
auto
rdma_buffer_type
=
reinterpret_cast
<
const
int
*>
(
reinterpret_cast
<
uint8_t
*>
(
rdma_recv_x
)
+
(
reg_topk_idx
[
i
]
*
num_max_dispatch_tokens_per_rank
+
token_idx
)
*
num_bytes_per_slot
);
auto
rdma_buffer_row
=
reinterpret_cast
<
const
uint8_t
*>
(
rdma_buffer_type
);
// Reduce
auto
x_vec
=
ld_nc_global
(
reinterpret_cast
<
const
int4
*>
(
rdma_buffer_row
)
+
thread_id
);
const
auto
x_bf16
=
reinterpret_cast
<
hip_bfloat16
*>
(
&
x_vec
);
// 计算需要多少个warp
constexpr
int
num_decode_warps
=
hidden_bf16_int4
/
(
kNumRecvUnrolls
*
kWarpSize
);
// 限制thread_id
if
(
warp_id
>=
num_decode_warps
)
{
return
;
}
// 每128个数据记录一个max/min值,即该数为总的max/min值数量
constexpr
int
kNumDivisionBytes
=
kNumDivisions
*
sizeof
(
float
);
// 每个warp内总的BF16值的数量
constexpr
int
kNumBF16PerWarpBytes
=
kWarpSize
*
kNumRecvUnrolls
*
sizeof
(
int4
);
constexpr
int
kNumLogFMTPerWarpBytes
=
kNumBF16PerWarpBytes
*
10
/
16
;
// 用于记录 max/min 值的 log 值
auto
log_amax_buffers
=
PatternVisitor
([
=
](
const
int
&
i
)
{
return
reinterpret_cast
<
float
*>
(
smem_buffer
+
i
*
kNumDivisionBytes
);
});
auto
log_amin_buffers
=
PatternVisitor
([
=
](
const
int
&
i
)
{
return
reinterpret_cast
<
float
*>
(
smem_buffer
+
kNumStages
*
kNumDivisionBytes
+
i
*
kNumDivisionBytes
);
});
auto
cast_info_buffers
=
PatternVisitor
([
=
](
const
int
&
i
)
{
return
reinterpret_cast
<
int
*>
(
smem_buffer
+
kNumStages
*
kNumDivisionBytes
*
2
+
i
*
kNumDivisionBytes
);
});
// 初始化 topk_idx 和 topk_weights
int
topk_idx_by_lane
=
-
1
;
float
topk_weights_by_lane
=
-
1
;
int
stage_idx
=
0
;
for
(
int
token_idx
=
sm_id
;
token_idx
<
num_combined_tokens
;
token_idx
+=
num_sms
)
{
if
(
lane_id
<
num_topk
)
{
topk_idx_by_lane
=
static_cast
<
int
>
(
__ldg
(
topk_idx
+
token_idx
*
num_topk
+
lane_id
));
topk_weights_by_lane
=
__ldg
(
topk_weights
+
token_idx
*
num_topk
+
lane_id
);
}
float
combined_values
[
kNumElemsPerInt4
*
kNumRecvUnrolls
]
=
{
0.0
f
};
#pragma unroll
for
(
int
i
=
0
;
i
<
num_topk
;
++
i
)
{
int
topk_idx_reg
=
shfl_sync
(
topk_idx_by_lane
,
i
);
if
(
topk_idx_reg
<
0
)
continue
;
const
auto
&
topk_weight_reg
=
shfl_sync
(
topk_weights_by_lane
,
i
);
// Read from sources
auto
rdma_buffer_type
=
reinterpret_cast
<
const
uint8_t
*>
(
reinterpret_cast
<
uint8_t
*>
(
rdma_recv_x
)
+
(
topk_idx_reg
*
num_max_dispatch_tokens_per_rank
+
token_idx
)
*
num_bytes_per_slot
);
if
constexpr
(
bSupportLogFMT
)
{
// 接收到的数据位置
const
uint8_t
*
data_buffer
=
rdma_buffer_type
+
kNumMetaBytes
;
// 读取max/min数据
if
(
warp_id
==
0
)
{
// 因为每个warp能处理数据量为 kWarpSize*sizeof(int4)/sizeof(bfloat16) * kNumSendUnrolls
// 即不考虑kNumSendUnrolls,一共 kWarpSize*sizeof(int4)/sizeof(bfloat16)/128 组, 代入参数 = kWarpSize / 16 个warp,nv上为2,dcu上为4
logfmt_check_amaxmin
<
kNumDivisions
/
(
kWarpSize
/
16
),
kNumSendUnrolls
,
kNumRecvUnrolls
>
(
/*meta_buffer*/
rdma_buffer_type
,
reinterpret_cast
<
int4
*>
(
log_amax_buffers
[
stage_idx
]),
reinterpret_cast
<
int4
*>
(
log_amin_buffers
[
stage_idx
]),
cast_info_buffers
[
stage_idx
],
lane_id
);
}
__syncthreads
();
// 获取cast_info_buffers
const
auto
&
info
=
cast_info_buffers
[
stage_idx
][
warp_id
];
bool
enable_cast
=
info
&
1
;
int
num_casted_prefix
=
info
>>
1
;
// 可用的
// 计算偏移(与TMA版本逻辑一致)
int
warp_offset
=
kNumLogFMTPerWarpBytes
*
num_casted_prefix
+
kNumBF16PerWarpBytes
*
(
warp_id
-
num_casted_prefix
);
int
lane_offset
=
(
enable_cast
?
kNumLogFMTPerWarpBytes
:
kNumBF16PerWarpBytes
)
/
kWarpSize
*
lane_id
;
// 使用临时缓冲区进行归约
const
uint8_t
*
thread_data_ptr
=
data_buffer
+
warp_offset
+
lane_offset
;
/**
一共有kNumDivisions个max/min数据对,读取时每warp默认处理256bit的max/min,所以logfmt_check_amaxmin的kNumLanes设置为 kNumDivisions/2
保存数据时每个log_amax_buffers为float2数据类型,保存总的warpkNumDivisions / 2
实际保存数据时,每个warp保存的实际数据个数为 kWarpSize*kNumRecvUnrolls*sizeof(int4)/sizeof(hip_bfloat16)
实际每个warp读取的max/min的 warp_idx=kWarpSize*kNumRecvUnrolls*sizeof(int4)/sizeof(hip_bfloat16) / 128 = kNumRecvUnrolls * 2
具体的lane_id处理的数据量为 warp_idx / kWarpSize
*/
int
log_amaxmin_per_warp
=
kNumRecvUnrolls
*
kWarpSize
*
sizeof
(
int4
)
/
sizeof
(
hip_bfloat16
)
/
QUANTIZATION_GROUPSIZE
;
int
division_idx
=
warp_id
*
log_amaxmin_per_warp
+
lane_id
*
log_amaxmin_per_warp
/
kWarpSize
;
// 反量化
decode_and_accumulate
<
kNumRecvUnrolls
>
(
reinterpret_cast
<
const
uint32_t
*>
(
thread_data_ptr
),
// 直接使用全局内存地址
combined_values
,
log_amax_buffers
[
stage_idx
][
division_idx
],
log_amin_buffers
[
stage_idx
][
division_idx
],
enable_cast
,
topk_weight_reg
);
}
else
{
// 接收到的数据位置
const
uint8_t
*
data_buffer
=
rdma_buffer_type
;
// 计算偏移
int
warp_offset
=
kNumBF16PerWarpBytes
*
warp_id
;
int
lane_offset
=
kNumBF16PerWarpBytes
/
kWarpSize
*
lane_id
;
// 使用临时缓冲区进行归约
const
uint8_t
*
thread_data_ptr
=
data_buffer
+
warp_offset
+
lane_offset
;
#pragma unroll
for
(
int
j
=
0
;
j
<
kNumElemsPerInt4
;
++
j
)
combined_values
[
j
]
+=
static_cast
<
float
>
(
x_bf16
[
j
])
*
reg_topk_weights
[
i
];
for
(
int
j
=
0
;
j
<
kNumRecvUnrolls
;
++
j
)
{
auto
tmp_rdma_value
=
ld_nc_global
(
reinterpret_cast
<
const
int4
*>
(
thread_data_ptr
)
+
j
);
const
auto
x_bf16
=
reinterpret_cast
<
const
hip_bfloat16
*>
(
&
tmp_rdma_value
);
#pragma unroll
for
(
int
k
=
0
;
k
<
kNumElemsPerInt4
;
++
k
)
{
int
combined_idx
=
j
*
kNumElemsPerInt4
+
k
;
combined_values
[
combined_idx
]
+=
static_cast
<
float
>
(
x_bf16
[
k
])
*
topk_weight_reg
;
}
}
}
}
// Write results
int4
&
combined_int4
=
*
reinterpret_cast
<
int4
*>
(
combined_values
);
auto
combined_bf16
=
reinterpret_cast
<
hip_bfloat16
*>
(
&
combined_values
);
#pragma unroll
for
(
int
j
=
0
;
j
<
kNumElemsPerInt4
;
++
j
)
combined_bf16
[
j
]
=
static_cast
<
hip_bfloat16
>
(
combined_values
[
j
]);
(
reinterpret_cast
<
int4
*>
(
combined_x
)
+
token_idx
*
hidden_bf16_int4
)[
thread_id
]
=
combined_int4
;
// Write results,kNumRecvUnrolls==2时则写256bit的数
int4
combined_int4
[
kNumRecvUnrolls
];
auto
combined_bf16
=
reinterpret_cast
<
hip_bfloat16
*>
(
&
combined_int4
[
0
]);
#pragma unroll
for
(
int
j
=
0
;
j
<
kNumElemsPerInt4
*
kNumRecvUnrolls
;
++
j
)
{
combined_bf16
[
j
]
=
static_cast
<
hip_bfloat16
>
(
combined_values
[
j
]);
}
for
(
int
j
=
0
;
j
<
kNumRecvUnrolls
;
++
j
)
{
(
reinterpret_cast
<
int4
*>
(
combined_x
)
+
token_idx
*
hidden_bf16_int4
+
warp_id
*
kWarpSize
*
kNumRecvUnrolls
)[
lane_id
*
kNumRecvUnrolls
+
j
]
=
combined_int4
[
j
];
}
}
}
...
...
@@ -820,6 +998,7 @@ void combine(void* combined_x,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_logfmt
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
,
bool
zero_copy
)
{
constexpr
int
kMaxNumWarps
=
16
;
...
...
@@ -840,7 +1019,9 @@ void combine(void* combined_x,
#define COMBINE_LAUNCH_CASE(hidden) \
{ \
auto combine_func = combine<hidden, kNumMaxTopk, kMaxNumWarps>; \
auto combine_func = use_logfmt ? \
combine<true, hidden, kNumMaxTopk, kMaxNumWarps> : \
combine<false, hidden, kNumMaxTopk, kMaxNumWarps>; \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func, \
combined_x, rdma_recv_x, rdma_recv_flag, rdma_send_x, \
x, topk_idx, topk_weights, src_info, layout_range, \
...
...
csrc/kernels/internode_ll_logfmt.cuh
0 → 100644
View file @
75b00cfb
#pragma once
#include "configs.cuh"
#include "exception.cuh"
#include "launch.cuh"
#include "buffer.cuh"
#include "utils.cuh"
#include <iostream>
#include "hip/hip_runtime.h"
#include "shmem_wrapper.cuh"
namespace
deep_ep
{
namespace
internode_ll
{
template
<
int
kNumSendUnrolls
>
__forceinline__
__device__
int
logfmt_encode
(
const
int4
*
cpy_src_int4_ptr
,
int4
*
dst_buffer
,
__hip_bfloat162
*
shared_amaxmin
,
const
int
&
lane_id
)
{
EP_STATIC_ASSERT
(
kNumSendUnrolls
==
2
,
"kNumSendUnrolls == 2 only"
);
constexpr
int
kNumElemsPerInt4
=
sizeof
(
int4
)
/
sizeof
(
__hip_bfloat16
);
// 8
constexpr
float
kLogThreshold
=
0
;
constexpr
float
kMinClip
=
32
;
// `== log_2(2 ^ (2 ^ 5))`
constexpr
int
kNumBits
=
10
;
constexpr
int
kNumValues
=
1
<<
(
kNumBits
-
1
);
// = 512
constexpr
int
kSendValueBytes
=
kNumSendUnrolls
*
sizeof
(
int4
);
//=2*16=32
constexpr
int
kNumElementPerInt4
=
sizeof
(
int4
)
/
sizeof
(
uint32_t
);
int4
int4_values
[
kNumSendUnrolls
];
const
auto
&
uint32_values
=
reinterpret_cast
<
uint32_t
*>
(
int4_values
);
const
auto
&
bf162_values
=
reinterpret_cast
<
__hip_bfloat162
*>
(
int4_values
);
// Calculate lane offset
const
auto
&
ld_buffer
=
cpy_src_int4_ptr
+
lane_id
*
kNumSendUnrolls
;
// Local log amax
auto
bf162_amax
=
__hip_bfloat162
(
HIPRT_ZERO_BF16
,
HIPRT_ZERO_BF16
);
auto
bf162_amin
=
__hip_bfloat162
(
HIPRT_INF_BF16
,
HIPRT_INF_BF16
);
uint32_t
local_signs
=
0
;
#pragma unroll
for
(
int
v
=
0
;
v
<
kNumSendUnrolls
;
++
v
)
{
int4
ld_int4_value
=
ld_nc_global
(
ld_buffer
+
v
);
// 向量化读取
uint32_t
*
ld_u32_ptr
=
reinterpret_cast
<
uint32_t
*>
(
&
ld_int4_value
);
#pragma unroll
for
(
int
k
=
0
;
k
<
kNumElementPerInt4
;
++
k
)
{
// 也是kNumSendUnrolls * kNumElemsPerInt4 / 2
// TODO: eliminate bank conflicts
uint32_t
ld_u32_value
=
ld_u32_ptr
[
k
];
int
k_offset
=
v
*
kNumElementPerInt4
+
k
;
// 提取符号位: 每个bfloat16的最高位是符号位
local_signs
|=
((
ld_u32_value
>>
15
)
&
1
)
<<
(
k_offset
*
2
);
local_signs
|=
((
ld_u32_value
>>
31
)
&
1
)
<<
(
k_offset
*
2
+
1
);
// 清除符号位,保留幅值
ld_u32_value
&=
0x7fff7fff
;
auto
ld_bf16_value
=
*
reinterpret_cast
<
__hip_bfloat162
*>
(
&
ld_u32_value
);
bf162_amax
=
__hmax2
(
bf162_amax
,
ld_bf16_value
);
bf162_amin
=
__hmin2
(
bf162_amin
,
ld_bf16_value
);
uint32_values
[
k_offset
]
=
ld_u32_value
;
}
}
// Reduce per 128 channels
// TODO: figure out how hardware do 2-byte min/max
auto
amax
=
__builtin_fmaxf
(
static_cast
<
float
>
(
bf162_amax
.
x
),
static_cast
<
float
>
(
bf162_amax
.
y
));
auto
amin
=
__builtin_fminf
(
static_cast
<
float
>
(
bf162_amin
.
x
),
static_cast
<
float
>
(
bf162_amin
.
y
));
// 即每128个值进行一次reduce
constexpr
static
int
kNumLanesToReduce
=
128
*
sizeof
(
__hip_bfloat16
)
/
kSendValueBytes
;
// =128*2 / (kNumSendUnrolls * sizeof(int4)) = 8
amax
=
warp_reduce_max
<
kNumLanesToReduce
>
(
amax
);
amin
=
warp_reduce_min
<
kNumLanesToReduce
>
(
amin
);
// Write min/max into the shared memory
if
(
shared_amaxmin
!=
nullptr
)
{
*
shared_amaxmin
=
__hip_bfloat162
(
amax
,
amin
);
}
syncwarp
();
// Calculate log amin/amax float
const
auto
&
log_amax
=
__builtin_log2f
(
amax
);
const
auto
&
log_amin
=
__builtin_fmaxf
(
__builtin_log2f
(
amin
),
log_amax
-
kMinClip
);
// 在组内广播enable_cast结果
const
bool
&
enable_cast
=
warp_reduce_and
<
kNumLanesToReduce
,
true
>
(
log_amax
<
kLogThreshold
and
log_amin
<
log_amax
);
// Case into LogFMT-10 if satisfied
if
(
enable_cast
)
{
constexpr
int
dst_buffer_step
=
kSendValueBytes
*
10
/
16
;
const
auto
&
st_buffer
=
reinterpret_cast
<
uint32_t
*>
(
reinterpret_cast
<
uint8_t
*>
(
dst_buffer
)
+
lane_id
*
dst_buffer_step
);
uint32_t
st_u32_values
[
dst_buffer_step
/
sizeof
(
uint32_t
)];
// = 5
// 计算10bit数据的两个相邻数值的差值
const
auto
step
=
(
log_amax
-
log_amin
)
/
static_cast
<
float
>
(
kNumValues
-
2
);
const
auto
step_inv
=
1.0
f
/
step
;
// 计算舍入值
const
auto
rounding
=
2.0
f
-
__builtin_log2f
((
1.0
f
+
__builtin_exp2f
(
step
))
*
0.5
f
)
*
step_inv
;
const
auto
fused_rounding
=
rounding
-
log_amin
*
step_inv
;
// 用于存储编码后的值
uint32_t
encoded
[
kNumElemsPerInt4
*
2
];
// 展开循环,处理数据打包
{
// 将int4值(128bit)转换为 bfloat162
#pragma unroll
for
(
int
k
=
0
;
k
<
kNumElemsPerInt4
;
++
k
)
{
// 8
// 将 bfloat162 转换为 float2
const
auto
&
fp162_fvalue
=
__bfloat1622float2
(
bf162_values
[
k
]);
/*
实际进行压缩的公式为:
q = clamp( round( (log2(abs(x)) - log_min) / (log_max - log_min) * (K - 2) + 0.5 ), 0, K - 1)
其中:
x: 输入的浮点数
q: 输出的整数,表示压缩后的值
log_min: 输入中最小值的log2值
log_max: 输入中最大值的log2值
K: 压缩后的整数的最大值(即,K为2的幂)
*/
// 对 float 值进行编码
encoded
[
k
*
2
+
0
]
=
__float2uint_rd
(
__builtin_fmaxf
(
__builtin_log2f
(
fp162_fvalue
.
x
)
*
step_inv
+
fused_rounding
,
0
));
encoded
[
k
*
2
+
1
]
=
__float2uint_rd
(
__builtin_fmaxf
(
__builtin_log2f
(
fp162_fvalue
.
y
)
*
step_inv
+
fused_rounding
,
0
));
}
// 批量打包编码后的值到 st_buffer
st_u32_values
[
0
]
=
(
encoded
[
0
]
>>
0
)
|
(
encoded
[
1
]
<<
9
)
|
(
encoded
[
2
]
<<
18
)
|
(
encoded
[
3
]
<<
27
);
st_u32_values
[
1
]
=
(
encoded
[
3
]
>>
5
)
|
(
encoded
[
4
]
<<
4
)
|
(
encoded
[
5
]
<<
13
)
|
(
encoded
[
6
]
<<
22
)
|
(
encoded
[
7
]
<<
31
);
st_u32_values
[
2
]
=
(
encoded
[
7
]
>>
1
)
|
(
encoded
[
8
]
<<
8
)
|
(
encoded
[
9
]
<<
17
)
|
(
encoded
[
10
]
<<
26
);
st_u32_values
[
3
]
=
(
encoded
[
10
]
>>
6
)
|
(
encoded
[
11
]
<<
3
)
|
(
encoded
[
12
]
<<
12
)
|
(
encoded
[
13
]
<<
21
)
|
(
encoded
[
14
]
<<
30
);
st_u32_values
[
4
]
=
(
encoded
[
14
]
>>
2
)
|
(
encoded
[
15
]
<<
7
)
|
(
local_signs
<<
16
);
}
// 保存160bit的数据到st_buffer
st_buffer
[
0
]
=
st_u32_values
[
0
];
*
(
reinterpret_cast
<
int4
*>
(
st_buffer
+
1
))
=
*
(
reinterpret_cast
<
int4
*>
(
st_u32_values
+
1
));
}
else
{
// 准备收发数据
using
vec_type
=
int4
;
const
auto
&
ld_buffer_vec
=
reinterpret_cast
<
const
vec_type
*>
(
ld_buffer
);
auto
st_buffer_vec
=
reinterpret_cast
<
vec_type
*>
(
reinterpret_cast
<
uint8_t
*>
(
dst_buffer
)
+
lane_id
*
kSendValueBytes
);
constexpr
int
kLoopIter
=
kSendValueBytes
/
sizeof
(
vec_type
);
#pragma unroll
for
(
int
k
=
0
;
k
<
kLoopIter
;
++
k
)
{
st_buffer_vec
[
k
]
=
ld_nc_global
(
ld_buffer_vec
+
k
);
}
}
// 确保 warp 内的所有线程都完成打包操作
syncwarp
();
// 计算量化成功和失败时的数据量
constexpr
int
unable_cast_num_bytes
=
kWarpSize
*
kSendValueBytes
;
// = 64*2*16 = 2048
constexpr
int
enable_cast_num_bytes
=
unable_cast_num_bytes
*
10
/
16
;
// = 2048/16*10=1280
// Return TMA copy bytes
return
enable_cast
?
enable_cast_num_bytes
:
unable_cast_num_bytes
;
}
template
<
int
kNumLanes
,
int
kNumSendUnrolls
,
int
kNumRecvUnrolls
>
__forceinline__
__device__
void
logfmt_check_amaxmin
(
const
uint8_t
*
meta_buffer
,
int4
*
shared_log_amax
,
int4
*
shared_log_amin
,
int
*
shared_cast_info
,
const
int
lane_id
)
{
// 定义log阈值和最小剪切值
constexpr
float
kLogThreshold
=
0
;
constexpr
float
kMinClip
=
32
;
// `== log_2(2 ^ (2 ^ 5))`
constexpr
int
kNumQuantGroupsPerWarp
=
kWarpSize
/
16
;
using
log_vec_type
=
int4
;
EP_STATIC_ASSERT
(
sizeof
(
log_vec_type
)
/
sizeof
(
__hip_bfloat162
)
==
kNumQuantGroupsPerWarp
,
"kNumQuantGroupsPerWarp == sizeof(log_vec_type) only"
);
// 初始化类型转换启用标志
bool
enable_cast
=
true
;
// 如果 lane_id 小于 kNumLanes,则进行计算
if
(
lane_id
<
kNumLanes
)
{
// 从 meta_buffer 中读取 amaxmin2 值
auto
amaxmin4
=
reinterpret_cast
<
const
log_vec_type
*>
(
meta_buffer
)[
lane_id
];
const
auto
&
bf162_amaxmin
=
reinterpret_cast
<
__hip_bfloat162
*>
(
&
amaxmin4
);
// 定义 log_amax 和 log_amin 数组
float
log_amax
[
kNumQuantGroupsPerWarp
],
log_amin
[
kNumQuantGroupsPerWarp
];
// 展开循环,计算 log_amax 和 log_amin
#pragma unroll
for
(
int
i
=
0
;
i
<
kNumQuantGroupsPerWarp
;
++
i
)
{
// sizeof(uint64_t) / sizeof(__hip_bfloat162) = 2
auto
amax
=
static_cast
<
float
>
(
bf162_amaxmin
[
i
].
x
);
auto
amin
=
static_cast
<
float
>
(
bf162_amaxmin
[
i
].
y
);
log_amax
[
i
]
=
__builtin_log2f
(
amax
);
log_amin
[
i
]
=
amin
==
0
?
log_amax
[
i
]
-
kMinClip
:
__builtin_fmaxf
(
__builtin_log2f
(
amin
),
log_amax
[
i
]
-
kMinClip
);
enable_cast
=
enable_cast
and
log_amax
[
i
]
<
kLogThreshold
and
log_amin
[
i
]
<
log_amax
[
i
];
}
// 将计算结果存储到 shared_log_amax 和 shared_log_amin 中
int4
log_amax_int4
=
*
reinterpret_cast
<
int4
*>
(
log_amax
);
int4
log_amin_int4
=
*
reinterpret_cast
<
int4
*>
(
log_amin
);
shared_log_amax
[
lane_id
]
=
log_amax_int4
;
shared_log_amin
[
lane_id
]
=
log_amin_int4
;
}
// 计算 casted 值。根据当前线程是否启用了类型转换,计算它所属的组的索引
const
auto
&
casted
=
warp_reduce_and
<
kNumSendUnrolls
>
(
enable_cast
)
?
1u
<<
(
lane_id
/
kNumRecvUnrolls
)
:
0u
;
// 计算 num_casted_prefix 值。计算当前线程之前有多少个线程启用了类型转换。
const
auto
&
num_casted_prefix
=
__popc
(
warp_reduce_or
<
kNumRecvUnrolls
,
true
>
(
casted
)
&
((
1u
<<
(
lane_id
/
kNumRecvUnrolls
))
-
1
));
// 如果 lane_id 小于 kNumLanes 且 lane_id 是 kNumRecvUnrolls 的倍数,则更新 shared_cast_info
if
(
lane_id
<
kNumLanes
and
lane_id
%
kNumRecvUnrolls
==
0
)
{
// 最低1位保存casted结果,最高31位保存num_casted_prefix值
shared_cast_info
[
lane_id
/
kNumRecvUnrolls
]
=
(
num_casted_prefix
<<
1
)
|
(
casted
?
1u
:
0u
);
}
}
template
<
int
kNumRecvUnrolls
>
__forceinline__
__device__
void
decode_and_accumulate
(
const
uint32_t
*
ld_buffer
,
float
*
accum
,
const
float
&
log_amax
,
const
float
&
log_amin
,
const
bool
&
enable_cast
,
const
float
&
weight
)
{
EP_STATIC_ASSERT
(
kNumRecvUnrolls
==
2
,
"kNumRecvUnrolls == 2 only"
);
if
(
enable_cast
)
{
constexpr
int
kNumBits
=
10
;
constexpr
int
kNumValues
=
1
<<
(
kNumBits
-
1
);
const
auto
&
step
=
(
log_amax
-
log_amin
)
/
static_cast
<
float
>
(
kNumValues
-
2
);
auto
decode
=
[
=
](
const
uint32_t
&
encoded
,
const
uint32_t
&
sign
)
{
const
auto
decoded
=
encoded
==
0
?
.0
f
:
__builtin_exp2f
((
encoded
-
1
)
*
step
+
log_amin
);
return
sign
?
-
decoded
:
decoded
;
};
uint32_t
concat
[
6
];
concat
[
0
]
=
ld_buffer
[
0
];
#pragma unroll
for
(
int
k
=
1
;
k
<
5
;
++
k
)
concat
[
k
]
=
(
ld_buffer
[
k
-
1
]
>>
(
32
-
k
*
5
))
|
(
ld_buffer
[
k
]
<<
(
k
*
5
));
concat
[
5
]
=
ld_buffer
[
4
]
>>
7
;
const
uint32_t
&
local_signs
=
ld_buffer
[
4
]
>>
16
;
#pragma unroll
for
(
int
k
=
0
;
k
<
5
;
++
k
)
{
accum
[
k
*
3
+
0
]
+=
decode
((
concat
[
k
]
>>
0
)
&
0x1ff
,
(
local_signs
>>
(
k
*
3
+
0
))
&
1
)
*
weight
;
accum
[
k
*
3
+
1
]
+=
decode
((
concat
[
k
]
>>
9
)
&
0x1ff
,
(
local_signs
>>
(
k
*
3
+
1
))
&
1
)
*
weight
;
accum
[
k
*
3
+
2
]
+=
decode
((
concat
[
k
]
>>
18
)
&
0x1ff
,
(
local_signs
>>
(
k
*
3
+
2
))
&
1
)
*
weight
;
}
accum
[
15
]
+=
decode
(
concat
[
5
]
&
0x1ff
,
(
local_signs
>>
15
)
&
1
)
*
weight
;
}
else
{
constexpr
int
kLoopIter
=
kNumRecvUnrolls
*
sizeof
(
int4
)
/
sizeof
(
uint32_t
);
#pragma unroll
for
(
int
k
=
0
;
k
<
kLoopIter
;
++
k
)
{
auto
bf16_pack
=
*
reinterpret_cast
<
const
__hip_bfloat162
*>
(
ld_buffer
+
k
);
accum
[
k
*
2
+
0
]
+=
static_cast
<
float
>
(
bf16_pack
.
x
)
*
weight
;
accum
[
k
*
2
+
1
]
+=
static_cast
<
float
>
(
bf16_pack
.
y
)
*
weight
;
}
}
}
}
// namespace internode_ll
}
// namespace deep_ep
csrc/kernels/utils.cuh
View file @
75b00cfb
...
...
@@ -72,9 +72,9 @@ __device__ __forceinline__ T shfl_xor(const T val, int laneMask, int width = kWa
return
__shfl_xor
(
val
,
laneMask
,
width
);
}
__device__
__forceinline__
int
shfl_sync
(
const
int
val
,
int
srcLane
=
0
,
int
width
=
kWarpSize
,
uint64_t
shfl_sync_mask
=
kFullWarpMask
)
{
// Let compiler deduce type
template
<
typename
T
>
__device__
__forceinline__
T
shfl_sync
(
const
T
val
,
int
srcLane
=
0
,
int
width
=
kWarpSize
,
uint64_t
shfl_sync_mask
=
kFullWarpMask
)
{
// Let compiler deduce type
return
__shfl
(
val
,
srcLane
,
width
);
}
...
...
@@ -115,6 +115,15 @@ template <> struct VecInt<16> {
using
vec_t
=
native_int4
;
};
template
<
typename
FuncT
>
struct
PatternVisitor
{
FuncT
func
;
__device__
__host__
explicit
PatternVisitor
(
FuncT
&&
func
)
:
func
(
std
::
forward
<
FuncT
>
(
func
))
{}
__device__
__host__
auto
operator
[](
const
uint32_t
&
i
)
{
return
func
(
i
);
}
};
__device__
__forceinline__
void
trap
()
{
abort
();
}
...
...
deep_ep/buffer.py
View file @
75b00cfb
...
...
@@ -923,7 +923,8 @@ class Buffer:
# noinspection PyTypeChecker
def
low_latency_combine
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
handle
:
tuple
,
zero_copy
:
bool
=
False
,
async_finish
:
bool
=
False
,
handle
:
tuple
,
use_logfmt
:
bool
=
False
,
zero_copy
:
bool
=
False
,
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
,
out
:
Optional
[
torch
.
Tensor
]
=
None
,
combine_wait_recv_cost_stats
:
Optional
[
torch
.
Tensor
]
=
None
)
->
\
Tuple
[
torch
.
Tensor
,
EventOverlap
,
Callable
]:
...
...
@@ -944,6 +945,7 @@ class Buffer:
topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched
tokens. The received tokens will be reduced with the weights in this tensor.
handle: the communication handle given by the `dispatch` function.
use_logfmt: whether to use an internal "LogFMT with dynamic per-64-channel cast" format (10 bits).
zero_copy: whether the tensor is already copied into the RDMA buffer, should be cooperative
with `get_next_low_latency_combine_buffer`.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
...
...
@@ -964,7 +966,7 @@ class Buffer:
combined_x
,
event
,
hook
=
self
.
runtime
.
low_latency_combine
(
x
,
topk_idx
,
topk_weights
,
src_info
,
layout_range
,
combine_wait_recv_cost_stats
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
zero_copy
,
async_finish
,
return_recv_hook
,
out
)
use_logfmt
,
zero_copy
,
async_finish
,
return_recv_hook
,
out
)
tensors_to_record
=
(
x
,
topk_idx
,
topk_weights
,
src_info
,
layout_range
,
combined_x
)
return
combined_x
,
EventOverlap
(
event
,
tensors_to_record
if
async_finish
else
None
),
hook
...
...
tests/test_low_latency_new.py
View file @
75b00cfb
...
...
@@ -42,6 +42,7 @@ def test_main(num_tokens: int,
num_ranks
:
int
,
group
:
dist
.
ProcessGroup
,
buffer
:
deep_ep
.
Buffer
,
use_logfmt
:
bool
=
False
,
seed
:
int
=
0
):
torch
.
manual_seed
(
seed
+
rank
)
random
.
seed
(
seed
+
rank
)
...
...
@@ -56,10 +57,12 @@ def test_main(num_tokens: int,
x
=
torch
.
ones
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
*
(
rank
-
rank_offset
)
x
[:,
-
128
:]
=
torch
.
arange
(
num_tokens
,
device
=
'cuda'
).
to
(
torch
.
bfloat16
).
view
(
-
1
,
1
)
x_list
=
[
x
]
# # NOTES: the last one is for performance testing
# # Most of the values in the perf case is lower than the threshold, casting most channels
# x_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1
# x_list = [x_rand]
for
_
in
range
(
4
if
use_logfmt
else
0
):
# NOTES: make more LogFMT casts and also with some BF16
x_list
.
append
(
torch
.
randn
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
*
0.5
*
random
.
random
())
# NOTES: the last one is for performance testing
# Most of the values in the perf case is lower than the threshold, casting most channels
x_list
.
append
(
torch
.
randn
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
*
0.1
)
scores
=
torch
.
randn
((
num_tokens
,
num_experts
),
dtype
=
torch
.
float32
,
device
=
'cuda'
).
abs
()
+
1
topk_idx
=
torch
.
topk
(
scores
,
num_topk
,
dim
=-
1
,
largest
=
True
,
sorted
=
True
)[
1
]
...
...
@@ -79,7 +82,7 @@ def test_main(num_tokens: int,
# Check dispatch correctness
do_check
=
True
hash_value
,
num_times
=
0
,
0
for
current_x
in
x_list
:
for
x_i
,
current_x
in
enumerate
(
x_list
)
:
for
return_recv_hook
in
(
False
,
True
):
for
quant_type
in
(
0
,
1
,
2
,
3
,
):
# 0: 不量化, 1: int8, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True), 4: FP8_E5M2
dispatch_use_quant
=
quant_type
>
0
...
...
@@ -152,7 +155,7 @@ def test_main(num_tokens: int,
hash_value
^=
hash_tensor
(
packed_recv_x
[
i
,
:
num_valid_tokens
])
# Check combine correctness
for
zero_copy
in
(
False
,
True
):
for
zero_copy
in
(
False
,
)
if
use_logfmt
else
(
False
,
True
,
):
if
zero_copy
:
buffer
.
get_next_low_latency_combine_buffer
(
handle
)[:,
:,
:]
=
simulated_gemm_x
out
=
torch
.
empty
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
...
...
@@ -160,6 +163,7 @@ def test_main(num_tokens: int,
topk_idx
,
topk_weights
,
handle
,
use_logfmt
=
use_logfmt
,
async_finish
=
not
return_recv_hook
,
zero_copy
=
zero_copy
,
return_recv_hook
=
return_recv_hook
,
...
...
@@ -172,6 +176,10 @@ def test_main(num_tokens: int,
assert
diff
<
(
9e-4
if
dispatch_use_quant
else
1e-5
),
f
'Error: diff=
{
diff
}
, dispatch_use_quant=
{
dispatch_use_quant
}
, zero_copy=
{
zero_copy
}
'
hash_value
^=
hash_tensor
(
combined_x
)
if
rank
==
0
:
print
(
f
"data:
{
x_i
}
, return_recv_hook:
{
return_recv_hook
}
, quant_type:
{
quant_type
}
, "
,
f
"fp8_round_scale:
{
fp8_round_scale
}
, quant_group_size:
{
quant_group_size
}
pass"
)
# noinspection PyShadowingNames
def
large_gemm_with_hook
(
hook
):
mat_0
=
torch
.
randn
((
8192
,
8192
),
dtype
=
torch
.
float
)
...
...
@@ -190,6 +198,7 @@ def test_main(num_tokens: int,
topk_idx
,
topk_weights
,
handle
,
use_logfmt
=
use_logfmt
,
return_recv_hook
=
return_recv_hook
)
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
...
...
@@ -251,6 +260,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_ranks
,
group
,
buffer
,
use_logfmt
=
args
.
use_logfmt
,
seed
=
1
)
do_pressure_test
=
args
.
pressure_test
...
...
@@ -265,6 +275,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_ranks
,
group
,
buffer
,
use_logfmt
=
args
.
use_logfmt
,
seed
=
seed
)
for
_
in
range
(
20
):
assert
test_main
(
num_tokens
,
...
...
@@ -275,6 +286,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_ranks
,
group
,
buffer
,
use_logfmt
=
args
.
use_logfmt
,
seed
=
seed
)
==
ref_hash
,
f
'Error: seed=
{
seed
}
'
# Destroy the buffer runtime and communication group
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment