Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
75b00cfb
Commit
75b00cfb
authored
Feb 09, 2026
by
lijian6
Browse files
Merge branch 'logfmt_master' into 'main'
低延迟combine支持10bit量化代码 See merge request dcutoolkit/deeplearing/DeepEP!21
parents
dbf9fd61
17d9c844
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
532 additions
and
60 deletions
+532
-60
csrc/config.hpp
csrc/config.hpp
+4
-4
csrc/deep_ep.cu
csrc/deep_ep.cu
+2
-0
csrc/deep_ep.hpp
csrc/deep_ep.hpp
+1
-0
csrc/kernels/api.cuh
csrc/kernels/api.cuh
+1
-0
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+226
-45
csrc/kernels/internode_ll_logfmt.cuh
csrc/kernels/internode_ll_logfmt.cuh
+264
-0
csrc/kernels/utils.cuh
csrc/kernels/utils.cuh
+12
-3
deep_ep/buffer.py
deep_ep/buffer.py
+4
-2
tests/test_low_latency_new.py
tests/test_low_latency_new.py
+18
-6
No files found.
csrc/config.hpp
View file @
75b00cfb
...
@@ -136,7 +136,7 @@ struct LowLatencyLayout {
...
@@ -136,7 +136,7 @@ struct LowLatencyLayout {
LowLatencyLayout
(
void
*
rdma_buffer
,
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
LowLatencyLayout
(
void
*
rdma_buffer
,
int
num_max_dispatch_tokens_per_rank
,
int
hidden
,
int
num_ranks
,
int
num_experts
,
int
quant_group_size
=
0
)
{
int
num_ranks
,
int
num_experts
,
int
quant_group_size
=
0
)
{
const
int
num_scales
=
quant_group_size
==
0
?
4
:
hidden
/
QUANTIZATION_GROUPSIZE
;
// 应该是1,但是代码中为了满足int4对齐
const
int
num_scales
=
hidden
/
QUANTIZATION_GROUPSIZE
;
// Dispatch and combine layout:
// Dispatch and combine layout:
// - 2 symmetric odd/even send buffer
// - 2 symmetric odd/even send buffer
...
@@ -148,11 +148,11 @@ struct LowLatencyLayout {
...
@@ -148,11 +148,11 @@ struct LowLatencyLayout {
// transformation
// transformation
EP_HOST_ASSERT
(
num_scales
*
sizeof
(
float
)
<=
static_cast
<
size_t
>
(
hidden
));
EP_HOST_ASSERT
(
num_scales
*
sizeof
(
float
)
<=
static_cast
<
size_t
>
(
hidden
));
size_t
num_bytes_per_dispatch_msg
=
size_t
num_bytes_per_dispatch_msg
=
sizeof
(
int4
)
+
sizeof
(
int4
)
+
std
::
max
(
hidden
*
sizeof
(
hip_bfloat16
),
hidden
+
std
::
max
(
hidden
*
sizeof
(
hip_bfloat16
),
hidden
+
num_scales
*
sizeof
(
float
));
(
quant_group_size
==
0
?
4
:
num_scales
)
*
sizeof
(
float
));
// 应该是1,但是代码中为了满足int4对齐
// 与internode_ll::combine 中的 num_bytes_per_slot 相等
// 与internode_ll::combine 中的 num_bytes_per_slot 相等
size_t
num_bytes_per_combine_msg
=
hidden
*
sizeof
(
hip_bfloat16
);
size_t
num_bytes_per_combine_msg
=
hidden
*
sizeof
(
hip_bfloat16
)
+
num_scales
*
sizeof
(
__hip_bfloat162
)
;
// Send buffer
// Send buffer
size_t
dispatch_send_buffer_bytes
=
size_t
dispatch_send_buffer_bytes
=
...
...
csrc/deep_ep.cu
View file @
75b00cfb
...
@@ -1413,6 +1413,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
...
@@ -1413,6 +1413,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
const
torch
::
Tensor
&
src_info
,
const
torch
::
Tensor
&
layout_range
,
const
torch
::
Tensor
&
src_info
,
const
torch
::
Tensor
&
layout_range
,
const
std
::
optional
<
torch
::
Tensor
>&
combine_wait_recv_cost_stats
,
const
std
::
optional
<
torch
::
Tensor
>&
combine_wait_recv_cost_stats
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_logfmt
,
bool
zero_copy
,
bool
async
,
bool
return_recv_hook
,
bool
zero_copy
,
bool
async
,
bool
return_recv_hook
,
const
std
::
optional
<
torch
::
Tensor
>&
out
)
{
const
std
::
optional
<
torch
::
Tensor
>&
out
)
{
EP_HOST_ASSERT
(
low_latency_mode
);
EP_HOST_ASSERT
(
low_latency_mode
);
...
@@ -1482,6 +1483,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
...
@@ -1482,6 +1483,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
next_clean_meta
.
first
,
next_clean_meta
.
second
,
next_clean_meta
.
first
,
next_clean_meta
.
second
,
num_combined_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_combined_tokens
,
hidden
,
num_max_dispatch_tokens_per_rank
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
num_topk
,
num_experts
,
rank
,
num_ranks
,
use_logfmt
,
workspace
,
num_device_sms
,
launch_stream
,
workspace
,
num_device_sms
,
launch_stream
,
phases
,
zero_copy
);
phases
,
zero_copy
);
};
};
...
...
csrc/deep_ep.hpp
View file @
75b00cfb
...
@@ -185,6 +185,7 @@ public:
...
@@ -185,6 +185,7 @@ public:
const
torch
::
Tensor
&
src_info
,
const
torch
::
Tensor
&
layout_range
,
const
torch
::
Tensor
&
src_info
,
const
torch
::
Tensor
&
layout_range
,
const
std
::
optional
<
torch
::
Tensor
>&
combine_wait_recv_cost_stats
,
const
std
::
optional
<
torch
::
Tensor
>&
combine_wait_recv_cost_stats
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
int
num_max_dispatch_tokens_per_rank
,
int
num_experts
,
bool
use_logfmt
,
bool
zero_copy
,
bool
async
,
bool
return_recv_hook
,
bool
zero_copy
,
bool
async
,
bool
return_recv_hook
,
const
std
::
optional
<
torch
::
Tensor
>&
out
=
std
::
nullopt
);
const
std
::
optional
<
torch
::
Tensor
>&
out
=
std
::
nullopt
);
...
...
csrc/kernels/api.cuh
View file @
75b00cfb
...
@@ -159,6 +159,7 @@ void combine(void* combined_x,
...
@@ -159,6 +159,7 @@ void combine(void* combined_x,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_logfmt
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
,
bool
zero_copy
);
int
phases
,
bool
zero_copy
);
...
...
csrc/kernels/internode_ll.cu
View file @
75b00cfb
...
@@ -9,6 +9,7 @@
...
@@ -9,6 +9,7 @@
#include "hip/hip_runtime.h"
#include "hip/hip_runtime.h"
#include "shmem_wrapper.cuh"
#include "shmem_wrapper.cuh"
#include "internode_ll_logfmt.cuh"
namespace
deep_ep
{
namespace
deep_ep
{
...
@@ -612,7 +613,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -612,7 +613,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
#undef DISPATCH_LAUNCH_CASE
#undef DISPATCH_LAUNCH_CASE
}
}
template
<
int
kHidden
,
int
kNumMaxTopk
,
int
kMaxNumWarps
=
16
>
template
<
bool
kUseLogFMT
,
int
kHidden
,
int
kNumMaxTopk
,
int
kMaxNumWarps
=
16
>
__global__
__launch_bounds__
(
16
*
kWarpSize
,
1
)
void
__global__
__launch_bounds__
(
16
*
kWarpSize
,
1
)
void
combine
(
void
*
combined_x
,
combine
(
void
*
combined_x
,
void
*
rdma_recv_x
,
int64_t
*
rdma_recv_flag
,
void
*
rdma_send_x
,
void
*
rdma_recv_x
,
int64_t
*
rdma_recv_flag
,
void
*
rdma_send_x
,
...
@@ -643,7 +644,24 @@ combine(void* combined_x,
...
@@ -643,7 +644,24 @@ combine(void* combined_x,
// Message package
// Message package
EP_STATIC_ASSERT
(
kHidden
%
QUANTIZATION_GROUPSIZE
==
0
,
"Invalid hidden"
);
EP_STATIC_ASSERT
(
kHidden
%
QUANTIZATION_GROUPSIZE
==
0
,
"Invalid hidden"
);
constexpr
size_t
num_bytes_per_slot
=
kHidden
*
sizeof
(
hip_bfloat16
);
/////////////// LogFMT使用 ///////////////
constexpr
int
bSupportLogFMT
=
kUseLogFMT
&&
hidden_bf16_int4
%
(
kWarpSize
*
2
)
==
0
;
constexpr
int
kNumSendUnrolls
=
bSupportLogFMT
?
2
:
1
;
constexpr
int
kNumRecvUnrolls
=
bSupportLogFMT
?
2
:
1
;
constexpr
int
kNumMsgInt4ElemPerWarp
=
kWarpSize
*
kNumSendUnrolls
;
// 每个warp发送的int4元素数据量,即每个warp发送 kNumMsgInt4ElemPerWarp*sizeof(int4)/sizeof(bfloat16)
EP_STATIC_ASSERT
(
hidden_bf16_int4
%
(
kNumSendUnrolls
*
kWarpSize
)
==
0
,
"Invalid hidden"
);
EP_STATIC_ASSERT
(
kNumSendUnrolls
>=
kNumRecvUnrolls
,
"Invalid unroll factors"
);
constexpr
int
kNumDivisions
=
kHidden
/
QUANTIZATION_GROUPSIZE
;
constexpr
int
kNumMetaBytes
=
kNumDivisions
*
sizeof
(
__hip_bfloat162
);
// 用于记录数据的最大最小值
constexpr
int
kNumSendLogFMTBytes
=
kNumMsgInt4ElemPerWarp
*
sizeof
(
int4
);
constexpr
int
kNumStages
=
1
;
// 使用kNumStages>1,则需要的LDS大于64KB
constexpr
int
kLogFMTShmemSize
=
kMaxNumWarps
*
(
kNumStages
*
kNumSendLogFMTBytes
+
kNumMetaBytes
);
__shared__
uint8_t
smem_buffer
[
kLogFMTShmemSize
];
/////////////////////////////////////////////
constexpr
size_t
num_bytes_per_slot
=
kHidden
*
sizeof
(
hip_bfloat16
)
+
kNumMetaBytes
;
EP_STATIC_ASSERT
(
num_bytes_per_slot
%
sizeof
(
int4
)
==
0
,
"Invalid vectorization"
);
EP_STATIC_ASSERT
(
num_bytes_per_slot
%
sizeof
(
int4
)
==
0
,
"Invalid vectorization"
);
// 初始化用于细粒度warp间同步的计数器数组
// 初始化用于细粒度warp间同步的计数器数组
...
@@ -683,6 +701,12 @@ combine(void* combined_x,
...
@@ -683,6 +701,12 @@ combine(void* combined_x,
const
auto
local_src_info
=
src_info
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
;
const
auto
local_src_info
=
src_info
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
;
const
auto
rdma_send_x_vec
=
reinterpret_cast
<
uint8_t
*>
(
rdma_send_x
)
+
const
auto
rdma_send_x_vec
=
reinterpret_cast
<
uint8_t
*>
(
rdma_send_x
)
+
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_slot
;
local_expert_idx
*
num_ranks
*
num_max_dispatch_tokens_per_rank
*
num_bytes_per_slot
;
// 用于logfmt的LDS
auto
smem_ptr
=
smem_buffer
+
warp_id
*
(
kNumStages
*
kNumSendLogFMTBytes
+
kNumMetaBytes
);
// 存储logfmt的起始地址,并根据stage_idx进行索引块
auto
logfmt_buffers
=
PatternVisitor
([
=
](
const
int
&
i
)
{
return
reinterpret_cast
<
int4
*>
(
smem_ptr
+
i
*
kNumSendLogFMTBytes
);
});
// 存储logfmt的最大最小值
auto
meta_buffers
=
bSupportLogFMT
?
reinterpret_cast
<
__hip_bfloat162
*>
(
smem_ptr
+
kNumStages
*
kNumSendLogFMTBytes
)
:
nullptr
;
// Unpack layout
// Unpack layout
int
offset
,
num_tokens_to_send
;
int
offset
,
num_tokens_to_send
;
...
@@ -699,20 +723,78 @@ combine(void* combined_x,
...
@@ -699,20 +723,78 @@ combine(void* combined_x,
const
auto
buf_ptr
=
reinterpret_cast
<
int64_t
>
(
rdma_send_x_vec_row
);
const
auto
buf_ptr
=
reinterpret_cast
<
int64_t
>
(
rdma_send_x_vec_row
);
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_recv_x
)
+
(
global_expert_idx
*
num_max_dispatch_tokens_per_rank
+
src_idx
)
*
num_bytes_per_slot
;
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_recv_x
)
+
(
global_expert_idx
*
num_max_dispatch_tokens_per_rank
+
src_idx
)
*
num_bytes_per_slot
;
uint64_t
p2p_ptr
=
internode
::
shmem_get_p2p_ptr
((
void
*
)
dst_ptr
,
rank
,
dst_rank
);
// 采用logfmt或者直接拷贝
if
(
p2p_ptr
==
0
)
{
// RDMA
uint64_t
dst_p2p_ptr
=
internode
::
shmem_get_p2p_ptr
((
void
*
)
dst_ptr
,
rank
,
dst_rank
);
const
auto
buf_int4_ptr
=
reinterpret_cast
<
int4
*>
(
buf_ptr
);
int
num_send_bytes
=
hidden
*
sizeof
(
hip_bfloat16
);
if
(
not
zero_copy
)
UNROLLED_WARP_COPY_LL
(
7
,
lane_id
,
hidden_bf16_int4
,
buf_int4_ptr
,
x_int4
,
ld_nc_global
,
st_na_global
);
if
(
not
zero_copy
or
dst_p2p_ptr
!=
0
)
{
// Read from `cpy_src_int4_ptr` and copy into `cpy_dst_int4_ptr`
const
auto
cpy_src_int4_ptr
=
zero_copy
?
reinterpret_cast
<
int4
*>
(
buf_ptr
)
:
x_int4
;
const
auto
cpy_dst_int4_ptr
=
dst_p2p_ptr
==
0
?
reinterpret_cast
<
int4
*>
(
buf_ptr
)
:
reinterpret_cast
<
int4
*>
(
dst_p2p_ptr
);
// 设置数据的真实偏移量
int
logfmt_offset_bytes
=
kNumMetaBytes
;
// 进入循环,逐步拷贝数据
constexpr
int
encode_num_warps
=
hidden_bf16_int4
/
kNumMsgInt4ElemPerWarp
;
for
(
int
iter_idx
=
0
;
iter_idx
<
encode_num_warps
;
++
iter_idx
)
{
int
num_logfmt_bytes
=
kNumMsgInt4ElemPerWarp
*
sizeof
(
int4
);
// 原始数据的warp级编译
int
warp_offset
=
iter_idx
*
kNumMsgInt4ElemPerWarp
;
if
constexpr
(
bSupportLogFMT
)
{
// 采用 寄存器->lds->global 的流水线方式, 量化后拷贝到buf_ptr中
const
int
&
stage_idx
=
iter_idx
%
kNumStages
;
// thread偏移
int
thread_offset
=
warp_offset
+
lane_id
*
kNumSendUnrolls
;
constexpr
int
kNumInt4PerDivision
=
128
/
kNumElemsPerInt4
;
// = 128/(sizeof(int4) / sizeof(hip_bfloat16)) = 128/(16/2)=16
num_logfmt_bytes
=
logfmt_encode
<
kNumSendUnrolls
>
(
cpy_src_int4_ptr
+
warp_offset
,
// 等同于 x_int4
logfmt_buffers
[
stage_idx
],
// NOTES: only the leader lane will write the result
(
thread_offset
%
kNumInt4PerDivision
==
0
)
?
meta_buffers
+
thread_offset
/
kNumInt4PerDivision
:
nullptr
,
lane_id
);
// 将量化后的数据写入
using
vec_type
=
uint32_t
;
UNROLLED_WARP_COPY_LL
(
2
,
lane_id
,
num_logfmt_bytes
/
sizeof
(
vec_type
),
reinterpret_cast
<
vec_type
*>
(
reinterpret_cast
<
uint8_t
*>
(
cpy_dst_int4_ptr
)
+
logfmt_offset_bytes
),
reinterpret_cast
<
vec_type
*>
(
logfmt_buffers
[
stage_idx
]),
ld_nc_global
,
st_na_global
);
// 起始地址偏移
logfmt_offset_bytes
+=
num_logfmt_bytes
;
}
else
{
// 非量化数据的传输
UNROLLED_WARP_COPY_LL
(
2
,
lane_id
,
kNumMsgInt4ElemPerWarp
,
reinterpret_cast
<
int4
*>
(
cpy_dst_int4_ptr
+
warp_offset
),
reinterpret_cast
<
const
int4
*>
(
cpy_src_int4_ptr
+
warp_offset
),
ld_nc_global
,
st_na_global
);
}
syncwarp
();
}
// Store metadata (min/max values) for LogFMT
if
constexpr
(
bSupportLogFMT
)
{
// 最终设置节点间传输的字节数
num_send_bytes
=
logfmt_offset_bytes
;
using
vec_type
=
uint32_t
;
auto
meta_buffers_ptr
=
reinterpret_cast
<
vec_type
*>
(
meta_buffers
);
auto
cpy_dst_uint32_ptr
=
reinterpret_cast
<
vec_type
*>
(
cpy_dst_int4_ptr
);
for
(
int
j
=
lane_id
;
j
<
kNumMetaBytes
/
sizeof
(
vec_type
);
j
+=
kWarpSize
)
{
*
(
cpy_dst_uint32_ptr
+
j
)
=
meta_buffers_ptr
[
j
];
}
}
syncwarp
();
}
if
(
dst_p2p_ptr
==
0
)
{
internode_ll_putmem_nbi
((
void
*
)
dst_ptr
,
(
void
*
)
buf_ptr
,
internode_ll_putmem_nbi
((
void
*
)
dst_ptr
,
(
void
*
)
buf_ptr
,
num_ranks
,
dst_rank
,
local_expert_idx
,
num_ranks
,
dst_rank
,
local_expert_idx
,
hidden
*
sizeof
(
hip_bfloat16
));
num_send_bytes
);
}
else
{
// 本地 GPU 和 同一计算节点的 其他 GPU 地址
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
x_int4
);
const
auto
*
dst_int4_ptr
=
reinterpret_cast
<
int4
*>
(
p2p_ptr
);
UNROLLED_WARP_COPY_LL
(
7
,
lane_id
,
hidden_bf16_int4
,
dst_int4_ptr
,
src_int4_ptr
,
ld_nc_global
,
st_na_global
);
}
}
}
}
...
@@ -773,40 +855,136 @@ LOW_LATENCY_COMBINE_RECV:
...
@@ -773,40 +855,136 @@ LOW_LATENCY_COMBINE_RECV:
// Reduce tokens with FP8 cast
// Reduce tokens with FP8 cast
// EP_DEVICE_ASSERT(num_topk <= kWarpSize and hidden_bf16_int4 <= num_threads);
// EP_DEVICE_ASSERT(num_topk <= kWarpSize and hidden_bf16_int4 <= num_threads);
EP_STATIC_ASSERT
(
kHidden
%
(
kWarpSize
*
kNumElemsPerInt4
)
==
0
,
"Invalid vectorization"
);
EP_STATIC_ASSERT
(
kHidden
%
(
kWarpSize
*
kNumElemsPerInt4
)
==
0
,
"Invalid vectorization"
);
if
(
thread_id
<
hidden_bf16_int4
)
{
for
(
int
token_idx
=
sm_id
;
token_idx
<
num_combined_tokens
;
token_idx
+=
num_sms
)
{
// Read top-k indices and weights
int
reg_topk_idx
[
kNumMaxTopk
];
float
reg_topk_weights
[
kNumMaxTopk
];
#pragma unroll
for
(
int
i
=
0
;
i
<
num_topk
;
++
i
)
{
reg_topk_idx
[
i
]
=
static_cast
<
int
>
(
__ldg
(
topk_idx
+
token_idx
*
num_topk
+
i
));
reg_topk_weights
[
i
]
=
__ldg
(
topk_weights
+
token_idx
*
num_topk
+
i
);
}
float
combined_values
[
kNumElemsPerInt4
]
=
{
0.0
f
};
// 计算需要多少个warp
#pragma unroll
constexpr
int
num_decode_warps
=
hidden_bf16_int4
/
(
kNumRecvUnrolls
*
kWarpSize
);
for
(
int
i
=
0
;
i
<
num_topk
;
++
i
)
if
(
reg_topk_idx
[
i
]
>=
0
)
{
// 限制thread_id
// Read from sources
if
(
warp_id
>=
num_decode_warps
)
{
auto
rdma_buffer_type
=
reinterpret_cast
<
const
int
*>
(
reinterpret_cast
<
uint8_t
*>
(
rdma_recv_x
)
+
return
;
(
reg_topk_idx
[
i
]
*
num_max_dispatch_tokens_per_rank
+
token_idx
)
*
num_bytes_per_slot
);
}
auto
rdma_buffer_row
=
reinterpret_cast
<
const
uint8_t
*>
(
rdma_buffer_type
);
// 每128个数据记录一个max/min值,即该数为总的max/min值数量
// Reduce
constexpr
int
kNumDivisionBytes
=
kNumDivisions
*
sizeof
(
float
);
auto
x_vec
=
ld_nc_global
(
reinterpret_cast
<
const
int4
*>
(
rdma_buffer_row
)
+
thread_id
);
// 每个warp内总的BF16值的数量
const
auto
x_bf16
=
reinterpret_cast
<
hip_bfloat16
*>
(
&
x_vec
);
constexpr
int
kNumBF16PerWarpBytes
=
kWarpSize
*
kNumRecvUnrolls
*
sizeof
(
int4
);
constexpr
int
kNumLogFMTPerWarpBytes
=
kNumBF16PerWarpBytes
*
10
/
16
;
// 用于记录 max/min 值的 log 值
auto
log_amax_buffers
=
PatternVisitor
([
=
](
const
int
&
i
)
{
return
reinterpret_cast
<
float
*>
(
smem_buffer
+
i
*
kNumDivisionBytes
);
});
auto
log_amin_buffers
=
PatternVisitor
([
=
](
const
int
&
i
)
{
return
reinterpret_cast
<
float
*>
(
smem_buffer
+
kNumStages
*
kNumDivisionBytes
+
i
*
kNumDivisionBytes
);
});
auto
cast_info_buffers
=
PatternVisitor
([
=
](
const
int
&
i
)
{
return
reinterpret_cast
<
int
*>
(
smem_buffer
+
kNumStages
*
kNumDivisionBytes
*
2
+
i
*
kNumDivisionBytes
);
});
// 初始化 topk_idx 和 topk_weights
int
topk_idx_by_lane
=
-
1
;
float
topk_weights_by_lane
=
-
1
;
int
stage_idx
=
0
;
for
(
int
token_idx
=
sm_id
;
token_idx
<
num_combined_tokens
;
token_idx
+=
num_sms
)
{
if
(
lane_id
<
num_topk
)
{
topk_idx_by_lane
=
static_cast
<
int
>
(
__ldg
(
topk_idx
+
token_idx
*
num_topk
+
lane_id
));
topk_weights_by_lane
=
__ldg
(
topk_weights
+
token_idx
*
num_topk
+
lane_id
);
}
float
combined_values
[
kNumElemsPerInt4
*
kNumRecvUnrolls
]
=
{
0.0
f
};
#pragma unroll
for
(
int
i
=
0
;
i
<
num_topk
;
++
i
)
{
int
topk_idx_reg
=
shfl_sync
(
topk_idx_by_lane
,
i
);
if
(
topk_idx_reg
<
0
)
continue
;
const
auto
&
topk_weight_reg
=
shfl_sync
(
topk_weights_by_lane
,
i
);
// Read from sources
auto
rdma_buffer_type
=
reinterpret_cast
<
const
uint8_t
*>
(
reinterpret_cast
<
uint8_t
*>
(
rdma_recv_x
)
+
(
topk_idx_reg
*
num_max_dispatch_tokens_per_rank
+
token_idx
)
*
num_bytes_per_slot
);
if
constexpr
(
bSupportLogFMT
)
{
// 接收到的数据位置
const
uint8_t
*
data_buffer
=
rdma_buffer_type
+
kNumMetaBytes
;
// 读取max/min数据
if
(
warp_id
==
0
)
{
// 因为每个warp能处理数据量为 kWarpSize*sizeof(int4)/sizeof(bfloat16) * kNumSendUnrolls
// 即不考虑kNumSendUnrolls,一共 kWarpSize*sizeof(int4)/sizeof(bfloat16)/128 组, 代入参数 = kWarpSize / 16 个warp,nv上为2,dcu上为4
logfmt_check_amaxmin
<
kNumDivisions
/
(
kWarpSize
/
16
),
kNumSendUnrolls
,
kNumRecvUnrolls
>
(
/*meta_buffer*/
rdma_buffer_type
,
reinterpret_cast
<
int4
*>
(
log_amax_buffers
[
stage_idx
]),
reinterpret_cast
<
int4
*>
(
log_amin_buffers
[
stage_idx
]),
cast_info_buffers
[
stage_idx
],
lane_id
);
}
__syncthreads
();
// 获取cast_info_buffers
const
auto
&
info
=
cast_info_buffers
[
stage_idx
][
warp_id
];
bool
enable_cast
=
info
&
1
;
int
num_casted_prefix
=
info
>>
1
;
// 可用的
// 计算偏移(与TMA版本逻辑一致)
int
warp_offset
=
kNumLogFMTPerWarpBytes
*
num_casted_prefix
+
kNumBF16PerWarpBytes
*
(
warp_id
-
num_casted_prefix
);
int
lane_offset
=
(
enable_cast
?
kNumLogFMTPerWarpBytes
:
kNumBF16PerWarpBytes
)
/
kWarpSize
*
lane_id
;
// 使用临时缓冲区进行归约
const
uint8_t
*
thread_data_ptr
=
data_buffer
+
warp_offset
+
lane_offset
;
/**
一共有kNumDivisions个max/min数据对,读取时每warp默认处理256bit的max/min,所以logfmt_check_amaxmin的kNumLanes设置为 kNumDivisions/2
保存数据时每个log_amax_buffers为float2数据类型,保存总的warpkNumDivisions / 2
实际保存数据时,每个warp保存的实际数据个数为 kWarpSize*kNumRecvUnrolls*sizeof(int4)/sizeof(hip_bfloat16)
实际每个warp读取的max/min的 warp_idx=kWarpSize*kNumRecvUnrolls*sizeof(int4)/sizeof(hip_bfloat16) / 128 = kNumRecvUnrolls * 2
具体的lane_id处理的数据量为 warp_idx / kWarpSize
*/
int
log_amaxmin_per_warp
=
kNumRecvUnrolls
*
kWarpSize
*
sizeof
(
int4
)
/
sizeof
(
hip_bfloat16
)
/
QUANTIZATION_GROUPSIZE
;
int
division_idx
=
warp_id
*
log_amaxmin_per_warp
+
lane_id
*
log_amaxmin_per_warp
/
kWarpSize
;
// 反量化
decode_and_accumulate
<
kNumRecvUnrolls
>
(
reinterpret_cast
<
const
uint32_t
*>
(
thread_data_ptr
),
// 直接使用全局内存地址
combined_values
,
log_amax_buffers
[
stage_idx
][
division_idx
],
log_amin_buffers
[
stage_idx
][
division_idx
],
enable_cast
,
topk_weight_reg
);
}
else
{
// 接收到的数据位置
const
uint8_t
*
data_buffer
=
rdma_buffer_type
;
// 计算偏移
int
warp_offset
=
kNumBF16PerWarpBytes
*
warp_id
;
int
lane_offset
=
kNumBF16PerWarpBytes
/
kWarpSize
*
lane_id
;
// 使用临时缓冲区进行归约
const
uint8_t
*
thread_data_ptr
=
data_buffer
+
warp_offset
+
lane_offset
;
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
kNumElemsPerInt4
;
++
j
)
for
(
int
j
=
0
;
j
<
kNumRecvUnrolls
;
++
j
)
{
combined_values
[
j
]
+=
static_cast
<
float
>
(
x_bf16
[
j
])
*
reg_topk_weights
[
i
];
auto
tmp_rdma_value
=
ld_nc_global
(
reinterpret_cast
<
const
int4
*>
(
thread_data_ptr
)
+
j
);
const
auto
x_bf16
=
reinterpret_cast
<
const
hip_bfloat16
*>
(
&
tmp_rdma_value
);
#pragma unroll
for
(
int
k
=
0
;
k
<
kNumElemsPerInt4
;
++
k
)
{
int
combined_idx
=
j
*
kNumElemsPerInt4
+
k
;
combined_values
[
combined_idx
]
+=
static_cast
<
float
>
(
x_bf16
[
k
])
*
topk_weight_reg
;
}
}
}
}
}
// Write results
// Write results,kNumRecvUnrolls==2时则写256bit的数
int4
&
combined_int4
=
*
reinterpret_cast
<
int4
*>
(
combined_values
);
int4
combined_int4
[
kNumRecvUnrolls
];
auto
combined_bf16
=
reinterpret_cast
<
hip_bfloat16
*>
(
&
combined_values
);
auto
combined_bf16
=
reinterpret_cast
<
hip_bfloat16
*>
(
&
combined_int4
[
0
]);
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
kNumElemsPerInt4
;
++
j
)
for
(
int
j
=
0
;
j
<
kNumElemsPerInt4
*
kNumRecvUnrolls
;
++
j
)
{
combined_bf16
[
j
]
=
static_cast
<
hip_bfloat16
>
(
combined_values
[
j
]);
combined_bf16
[
j
]
=
static_cast
<
hip_bfloat16
>
(
combined_values
[
j
]);
(
reinterpret_cast
<
int4
*>
(
combined_x
)
+
token_idx
*
hidden_bf16_int4
)[
thread_id
]
=
combined_int4
;
}
for
(
int
j
=
0
;
j
<
kNumRecvUnrolls
;
++
j
)
{
(
reinterpret_cast
<
int4
*>
(
combined_x
)
+
token_idx
*
hidden_bf16_int4
+
warp_id
*
kWarpSize
*
kNumRecvUnrolls
)[
lane_id
*
kNumRecvUnrolls
+
j
]
=
combined_int4
[
j
];
}
}
}
}
}
}
...
@@ -820,6 +998,7 @@ void combine(void* combined_x,
...
@@ -820,6 +998,7 @@ void combine(void* combined_x,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int64_t
*
next_clean
,
int
num_next_clean_int
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_combined_tokens
,
int
hidden
,
int
num_max_dispatch_tokens_per_rank
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
int
num_topk
,
int
num_experts
,
int
rank
,
int
num_ranks
,
bool
use_logfmt
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
,
bool
zero_copy
)
{
int
phases
,
bool
zero_copy
)
{
constexpr
int
kMaxNumWarps
=
16
;
constexpr
int
kMaxNumWarps
=
16
;
...
@@ -840,7 +1019,9 @@ void combine(void* combined_x,
...
@@ -840,7 +1019,9 @@ void combine(void* combined_x,
#define COMBINE_LAUNCH_CASE(hidden) \
#define COMBINE_LAUNCH_CASE(hidden) \
{ \
{ \
auto combine_func = combine<hidden, kNumMaxTopk, kMaxNumWarps>; \
auto combine_func = use_logfmt ? \
combine<true, hidden, kNumMaxTopk, kMaxNumWarps> : \
combine<false, hidden, kNumMaxTopk, kMaxNumWarps>; \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func, \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func, \
combined_x, rdma_recv_x, rdma_recv_flag, rdma_send_x, \
combined_x, rdma_recv_x, rdma_recv_flag, rdma_send_x, \
x, topk_idx, topk_weights, src_info, layout_range, \
x, topk_idx, topk_weights, src_info, layout_range, \
...
...
csrc/kernels/internode_ll_logfmt.cuh
0 → 100644
View file @
75b00cfb
#pragma once
#include "configs.cuh"
#include "exception.cuh"
#include "launch.cuh"
#include "buffer.cuh"
#include "utils.cuh"
#include <iostream>
#include "hip/hip_runtime.h"
#include "shmem_wrapper.cuh"
namespace
deep_ep
{
namespace
internode_ll
{
template
<
int
kNumSendUnrolls
>
__forceinline__
__device__
int
logfmt_encode
(
const
int4
*
cpy_src_int4_ptr
,
int4
*
dst_buffer
,
__hip_bfloat162
*
shared_amaxmin
,
const
int
&
lane_id
)
{
EP_STATIC_ASSERT
(
kNumSendUnrolls
==
2
,
"kNumSendUnrolls == 2 only"
);
constexpr
int
kNumElemsPerInt4
=
sizeof
(
int4
)
/
sizeof
(
__hip_bfloat16
);
// 8
constexpr
float
kLogThreshold
=
0
;
constexpr
float
kMinClip
=
32
;
// `== log_2(2 ^ (2 ^ 5))`
constexpr
int
kNumBits
=
10
;
constexpr
int
kNumValues
=
1
<<
(
kNumBits
-
1
);
// = 512
constexpr
int
kSendValueBytes
=
kNumSendUnrolls
*
sizeof
(
int4
);
//=2*16=32
constexpr
int
kNumElementPerInt4
=
sizeof
(
int4
)
/
sizeof
(
uint32_t
);
int4
int4_values
[
kNumSendUnrolls
];
const
auto
&
uint32_values
=
reinterpret_cast
<
uint32_t
*>
(
int4_values
);
const
auto
&
bf162_values
=
reinterpret_cast
<
__hip_bfloat162
*>
(
int4_values
);
// Calculate lane offset
const
auto
&
ld_buffer
=
cpy_src_int4_ptr
+
lane_id
*
kNumSendUnrolls
;
// Local log amax
auto
bf162_amax
=
__hip_bfloat162
(
HIPRT_ZERO_BF16
,
HIPRT_ZERO_BF16
);
auto
bf162_amin
=
__hip_bfloat162
(
HIPRT_INF_BF16
,
HIPRT_INF_BF16
);
uint32_t
local_signs
=
0
;
#pragma unroll
for
(
int
v
=
0
;
v
<
kNumSendUnrolls
;
++
v
)
{
int4
ld_int4_value
=
ld_nc_global
(
ld_buffer
+
v
);
// 向量化读取
uint32_t
*
ld_u32_ptr
=
reinterpret_cast
<
uint32_t
*>
(
&
ld_int4_value
);
#pragma unroll
for
(
int
k
=
0
;
k
<
kNumElementPerInt4
;
++
k
)
{
// 也是kNumSendUnrolls * kNumElemsPerInt4 / 2
// TODO: eliminate bank conflicts
uint32_t
ld_u32_value
=
ld_u32_ptr
[
k
];
int
k_offset
=
v
*
kNumElementPerInt4
+
k
;
// 提取符号位: 每个bfloat16的最高位是符号位
local_signs
|=
((
ld_u32_value
>>
15
)
&
1
)
<<
(
k_offset
*
2
);
local_signs
|=
((
ld_u32_value
>>
31
)
&
1
)
<<
(
k_offset
*
2
+
1
);
// 清除符号位,保留幅值
ld_u32_value
&=
0x7fff7fff
;
auto
ld_bf16_value
=
*
reinterpret_cast
<
__hip_bfloat162
*>
(
&
ld_u32_value
);
bf162_amax
=
__hmax2
(
bf162_amax
,
ld_bf16_value
);
bf162_amin
=
__hmin2
(
bf162_amin
,
ld_bf16_value
);
uint32_values
[
k_offset
]
=
ld_u32_value
;
}
}
// Reduce per 128 channels
// TODO: figure out how hardware do 2-byte min/max
auto
amax
=
__builtin_fmaxf
(
static_cast
<
float
>
(
bf162_amax
.
x
),
static_cast
<
float
>
(
bf162_amax
.
y
));
auto
amin
=
__builtin_fminf
(
static_cast
<
float
>
(
bf162_amin
.
x
),
static_cast
<
float
>
(
bf162_amin
.
y
));
// 即每128个值进行一次reduce
constexpr
static
int
kNumLanesToReduce
=
128
*
sizeof
(
__hip_bfloat16
)
/
kSendValueBytes
;
// =128*2 / (kNumSendUnrolls * sizeof(int4)) = 8
amax
=
warp_reduce_max
<
kNumLanesToReduce
>
(
amax
);
amin
=
warp_reduce_min
<
kNumLanesToReduce
>
(
amin
);
// Write min/max into the shared memory
if
(
shared_amaxmin
!=
nullptr
)
{
*
shared_amaxmin
=
__hip_bfloat162
(
amax
,
amin
);
}
syncwarp
();
// Calculate log amin/amax float
const
auto
&
log_amax
=
__builtin_log2f
(
amax
);
const
auto
&
log_amin
=
__builtin_fmaxf
(
__builtin_log2f
(
amin
),
log_amax
-
kMinClip
);
// 在组内广播enable_cast结果
const
bool
&
enable_cast
=
warp_reduce_and
<
kNumLanesToReduce
,
true
>
(
log_amax
<
kLogThreshold
and
log_amin
<
log_amax
);
// Case into LogFMT-10 if satisfied
if
(
enable_cast
)
{
constexpr
int
dst_buffer_step
=
kSendValueBytes
*
10
/
16
;
const
auto
&
st_buffer
=
reinterpret_cast
<
uint32_t
*>
(
reinterpret_cast
<
uint8_t
*>
(
dst_buffer
)
+
lane_id
*
dst_buffer_step
);
uint32_t
st_u32_values
[
dst_buffer_step
/
sizeof
(
uint32_t
)];
// = 5
// 计算10bit数据的两个相邻数值的差值
const
auto
step
=
(
log_amax
-
log_amin
)
/
static_cast
<
float
>
(
kNumValues
-
2
);
const
auto
step_inv
=
1.0
f
/
step
;
// 计算舍入值
const
auto
rounding
=
2.0
f
-
__builtin_log2f
((
1.0
f
+
__builtin_exp2f
(
step
))
*
0.5
f
)
*
step_inv
;
const
auto
fused_rounding
=
rounding
-
log_amin
*
step_inv
;
// 用于存储编码后的值
uint32_t
encoded
[
kNumElemsPerInt4
*
2
];
// 展开循环,处理数据打包
{
// 将int4值(128bit)转换为 bfloat162
#pragma unroll
for
(
int
k
=
0
;
k
<
kNumElemsPerInt4
;
++
k
)
{
// 8
// 将 bfloat162 转换为 float2
const
auto
&
fp162_fvalue
=
__bfloat1622float2
(
bf162_values
[
k
]);
/*
实际进行压缩的公式为:
q = clamp( round( (log2(abs(x)) - log_min) / (log_max - log_min) * (K - 2) + 0.5 ), 0, K - 1)
其中:
x: 输入的浮点数
q: 输出的整数,表示压缩后的值
log_min: 输入中最小值的log2值
log_max: 输入中最大值的log2值
K: 压缩后的整数的最大值(即,K为2的幂)
*/
// 对 float 值进行编码
encoded
[
k
*
2
+
0
]
=
__float2uint_rd
(
__builtin_fmaxf
(
__builtin_log2f
(
fp162_fvalue
.
x
)
*
step_inv
+
fused_rounding
,
0
));
encoded
[
k
*
2
+
1
]
=
__float2uint_rd
(
__builtin_fmaxf
(
__builtin_log2f
(
fp162_fvalue
.
y
)
*
step_inv
+
fused_rounding
,
0
));
}
// 批量打包编码后的值到 st_buffer
st_u32_values
[
0
]
=
(
encoded
[
0
]
>>
0
)
|
(
encoded
[
1
]
<<
9
)
|
(
encoded
[
2
]
<<
18
)
|
(
encoded
[
3
]
<<
27
);
st_u32_values
[
1
]
=
(
encoded
[
3
]
>>
5
)
|
(
encoded
[
4
]
<<
4
)
|
(
encoded
[
5
]
<<
13
)
|
(
encoded
[
6
]
<<
22
)
|
(
encoded
[
7
]
<<
31
);
st_u32_values
[
2
]
=
(
encoded
[
7
]
>>
1
)
|
(
encoded
[
8
]
<<
8
)
|
(
encoded
[
9
]
<<
17
)
|
(
encoded
[
10
]
<<
26
);
st_u32_values
[
3
]
=
(
encoded
[
10
]
>>
6
)
|
(
encoded
[
11
]
<<
3
)
|
(
encoded
[
12
]
<<
12
)
|
(
encoded
[
13
]
<<
21
)
|
(
encoded
[
14
]
<<
30
);
st_u32_values
[
4
]
=
(
encoded
[
14
]
>>
2
)
|
(
encoded
[
15
]
<<
7
)
|
(
local_signs
<<
16
);
}
// 保存160bit的数据到st_buffer
st_buffer
[
0
]
=
st_u32_values
[
0
];
*
(
reinterpret_cast
<
int4
*>
(
st_buffer
+
1
))
=
*
(
reinterpret_cast
<
int4
*>
(
st_u32_values
+
1
));
}
else
{
// 准备收发数据
using
vec_type
=
int4
;
const
auto
&
ld_buffer_vec
=
reinterpret_cast
<
const
vec_type
*>
(
ld_buffer
);
auto
st_buffer_vec
=
reinterpret_cast
<
vec_type
*>
(
reinterpret_cast
<
uint8_t
*>
(
dst_buffer
)
+
lane_id
*
kSendValueBytes
);
constexpr
int
kLoopIter
=
kSendValueBytes
/
sizeof
(
vec_type
);
#pragma unroll
for
(
int
k
=
0
;
k
<
kLoopIter
;
++
k
)
{
st_buffer_vec
[
k
]
=
ld_nc_global
(
ld_buffer_vec
+
k
);
}
}
// 确保 warp 内的所有线程都完成打包操作
syncwarp
();
// 计算量化成功和失败时的数据量
constexpr
int
unable_cast_num_bytes
=
kWarpSize
*
kSendValueBytes
;
// = 64*2*16 = 2048
constexpr
int
enable_cast_num_bytes
=
unable_cast_num_bytes
*
10
/
16
;
// = 2048/16*10=1280
// Return TMA copy bytes
return
enable_cast
?
enable_cast_num_bytes
:
unable_cast_num_bytes
;
}
template
<
int
kNumLanes
,
int
kNumSendUnrolls
,
int
kNumRecvUnrolls
>
__forceinline__
__device__
void
logfmt_check_amaxmin
(
const
uint8_t
*
meta_buffer
,
int4
*
shared_log_amax
,
int4
*
shared_log_amin
,
int
*
shared_cast_info
,
const
int
lane_id
)
{
// 定义log阈值和最小剪切值
constexpr
float
kLogThreshold
=
0
;
constexpr
float
kMinClip
=
32
;
// `== log_2(2 ^ (2 ^ 5))`
constexpr
int
kNumQuantGroupsPerWarp
=
kWarpSize
/
16
;
using
log_vec_type
=
int4
;
EP_STATIC_ASSERT
(
sizeof
(
log_vec_type
)
/
sizeof
(
__hip_bfloat162
)
==
kNumQuantGroupsPerWarp
,
"kNumQuantGroupsPerWarp == sizeof(log_vec_type) only"
);
// 初始化类型转换启用标志
bool
enable_cast
=
true
;
// 如果 lane_id 小于 kNumLanes,则进行计算
if
(
lane_id
<
kNumLanes
)
{
// 从 meta_buffer 中读取 amaxmin2 值
auto
amaxmin4
=
reinterpret_cast
<
const
log_vec_type
*>
(
meta_buffer
)[
lane_id
];
const
auto
&
bf162_amaxmin
=
reinterpret_cast
<
__hip_bfloat162
*>
(
&
amaxmin4
);
// 定义 log_amax 和 log_amin 数组
float
log_amax
[
kNumQuantGroupsPerWarp
],
log_amin
[
kNumQuantGroupsPerWarp
];
// 展开循环,计算 log_amax 和 log_amin
#pragma unroll
for
(
int
i
=
0
;
i
<
kNumQuantGroupsPerWarp
;
++
i
)
{
// sizeof(uint64_t) / sizeof(__hip_bfloat162) = 2
auto
amax
=
static_cast
<
float
>
(
bf162_amaxmin
[
i
].
x
);
auto
amin
=
static_cast
<
float
>
(
bf162_amaxmin
[
i
].
y
);
log_amax
[
i
]
=
__builtin_log2f
(
amax
);
log_amin
[
i
]
=
amin
==
0
?
log_amax
[
i
]
-
kMinClip
:
__builtin_fmaxf
(
__builtin_log2f
(
amin
),
log_amax
[
i
]
-
kMinClip
);
enable_cast
=
enable_cast
and
log_amax
[
i
]
<
kLogThreshold
and
log_amin
[
i
]
<
log_amax
[
i
];
}
// 将计算结果存储到 shared_log_amax 和 shared_log_amin 中
int4
log_amax_int4
=
*
reinterpret_cast
<
int4
*>
(
log_amax
);
int4
log_amin_int4
=
*
reinterpret_cast
<
int4
*>
(
log_amin
);
shared_log_amax
[
lane_id
]
=
log_amax_int4
;
shared_log_amin
[
lane_id
]
=
log_amin_int4
;
}
// 计算 casted 值。根据当前线程是否启用了类型转换,计算它所属的组的索引
const
auto
&
casted
=
warp_reduce_and
<
kNumSendUnrolls
>
(
enable_cast
)
?
1u
<<
(
lane_id
/
kNumRecvUnrolls
)
:
0u
;
// 计算 num_casted_prefix 值。计算当前线程之前有多少个线程启用了类型转换。
const
auto
&
num_casted_prefix
=
__popc
(
warp_reduce_or
<
kNumRecvUnrolls
,
true
>
(
casted
)
&
((
1u
<<
(
lane_id
/
kNumRecvUnrolls
))
-
1
));
// 如果 lane_id 小于 kNumLanes 且 lane_id 是 kNumRecvUnrolls 的倍数,则更新 shared_cast_info
if
(
lane_id
<
kNumLanes
and
lane_id
%
kNumRecvUnrolls
==
0
)
{
// 最低1位保存casted结果,最高31位保存num_casted_prefix值
shared_cast_info
[
lane_id
/
kNumRecvUnrolls
]
=
(
num_casted_prefix
<<
1
)
|
(
casted
?
1u
:
0u
);
}
}
template
<
int
kNumRecvUnrolls
>
__forceinline__
__device__
void
decode_and_accumulate
(
const
uint32_t
*
ld_buffer
,
float
*
accum
,
const
float
&
log_amax
,
const
float
&
log_amin
,
const
bool
&
enable_cast
,
const
float
&
weight
)
{
EP_STATIC_ASSERT
(
kNumRecvUnrolls
==
2
,
"kNumRecvUnrolls == 2 only"
);
if
(
enable_cast
)
{
constexpr
int
kNumBits
=
10
;
constexpr
int
kNumValues
=
1
<<
(
kNumBits
-
1
);
const
auto
&
step
=
(
log_amax
-
log_amin
)
/
static_cast
<
float
>
(
kNumValues
-
2
);
auto
decode
=
[
=
](
const
uint32_t
&
encoded
,
const
uint32_t
&
sign
)
{
const
auto
decoded
=
encoded
==
0
?
.0
f
:
__builtin_exp2f
((
encoded
-
1
)
*
step
+
log_amin
);
return
sign
?
-
decoded
:
decoded
;
};
uint32_t
concat
[
6
];
concat
[
0
]
=
ld_buffer
[
0
];
#pragma unroll
for
(
int
k
=
1
;
k
<
5
;
++
k
)
concat
[
k
]
=
(
ld_buffer
[
k
-
1
]
>>
(
32
-
k
*
5
))
|
(
ld_buffer
[
k
]
<<
(
k
*
5
));
concat
[
5
]
=
ld_buffer
[
4
]
>>
7
;
const
uint32_t
&
local_signs
=
ld_buffer
[
4
]
>>
16
;
#pragma unroll
for
(
int
k
=
0
;
k
<
5
;
++
k
)
{
accum
[
k
*
3
+
0
]
+=
decode
((
concat
[
k
]
>>
0
)
&
0x1ff
,
(
local_signs
>>
(
k
*
3
+
0
))
&
1
)
*
weight
;
accum
[
k
*
3
+
1
]
+=
decode
((
concat
[
k
]
>>
9
)
&
0x1ff
,
(
local_signs
>>
(
k
*
3
+
1
))
&
1
)
*
weight
;
accum
[
k
*
3
+
2
]
+=
decode
((
concat
[
k
]
>>
18
)
&
0x1ff
,
(
local_signs
>>
(
k
*
3
+
2
))
&
1
)
*
weight
;
}
accum
[
15
]
+=
decode
(
concat
[
5
]
&
0x1ff
,
(
local_signs
>>
15
)
&
1
)
*
weight
;
}
else
{
constexpr
int
kLoopIter
=
kNumRecvUnrolls
*
sizeof
(
int4
)
/
sizeof
(
uint32_t
);
#pragma unroll
for
(
int
k
=
0
;
k
<
kLoopIter
;
++
k
)
{
auto
bf16_pack
=
*
reinterpret_cast
<
const
__hip_bfloat162
*>
(
ld_buffer
+
k
);
accum
[
k
*
2
+
0
]
+=
static_cast
<
float
>
(
bf16_pack
.
x
)
*
weight
;
accum
[
k
*
2
+
1
]
+=
static_cast
<
float
>
(
bf16_pack
.
y
)
*
weight
;
}
}
}
}
// namespace internode_ll
}
// namespace deep_ep
csrc/kernels/utils.cuh
View file @
75b00cfb
...
@@ -72,9 +72,9 @@ __device__ __forceinline__ T shfl_xor(const T val, int laneMask, int width = kWa
...
@@ -72,9 +72,9 @@ __device__ __forceinline__ T shfl_xor(const T val, int laneMask, int width = kWa
return
__shfl_xor
(
val
,
laneMask
,
width
);
return
__shfl_xor
(
val
,
laneMask
,
width
);
}
}
__device__
__forceinline__
int
template
<
typename
T
>
shfl_sync
(
const
int
val
,
int
srcLane
=
0
,
int
width
=
kWarpSize
,
__device__
__forceinline__
T
shfl_sync
(
const
T
val
,
int
srcLane
=
0
,
int
width
=
kWarpSize
,
uint64_t
shfl_sync_mask
=
kFullWarpMask
)
{
// Let compiler deduce type
uint64_t
shfl_sync_mask
=
kFullWarpMask
)
{
// Let compiler deduce type
return
__shfl
(
val
,
srcLane
,
width
);
return
__shfl
(
val
,
srcLane
,
width
);
}
}
...
@@ -115,6 +115,15 @@ template <> struct VecInt<16> {
...
@@ -115,6 +115,15 @@ template <> struct VecInt<16> {
using
vec_t
=
native_int4
;
using
vec_t
=
native_int4
;
};
};
template
<
typename
FuncT
>
struct
PatternVisitor
{
FuncT
func
;
__device__
__host__
explicit
PatternVisitor
(
FuncT
&&
func
)
:
func
(
std
::
forward
<
FuncT
>
(
func
))
{}
__device__
__host__
auto
operator
[](
const
uint32_t
&
i
)
{
return
func
(
i
);
}
};
__device__
__forceinline__
void
trap
()
{
__device__
__forceinline__
void
trap
()
{
abort
();
abort
();
}
}
...
...
deep_ep/buffer.py
View file @
75b00cfb
...
@@ -923,7 +923,8 @@ class Buffer:
...
@@ -923,7 +923,8 @@ class Buffer:
# noinspection PyTypeChecker
# noinspection PyTypeChecker
def
low_latency_combine
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
def
low_latency_combine
(
self
,
x
:
torch
.
Tensor
,
topk_idx
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
handle
:
tuple
,
zero_copy
:
bool
=
False
,
async_finish
:
bool
=
False
,
handle
:
tuple
,
use_logfmt
:
bool
=
False
,
zero_copy
:
bool
=
False
,
async_finish
:
bool
=
False
,
return_recv_hook
:
bool
=
False
,
out
:
Optional
[
torch
.
Tensor
]
=
None
,
return_recv_hook
:
bool
=
False
,
out
:
Optional
[
torch
.
Tensor
]
=
None
,
combine_wait_recv_cost_stats
:
Optional
[
torch
.
Tensor
]
=
None
)
->
\
combine_wait_recv_cost_stats
:
Optional
[
torch
.
Tensor
]
=
None
)
->
\
Tuple
[
torch
.
Tensor
,
EventOverlap
,
Callable
]:
Tuple
[
torch
.
Tensor
,
EventOverlap
,
Callable
]:
...
@@ -944,6 +945,7 @@ class Buffer:
...
@@ -944,6 +945,7 @@ class Buffer:
topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched
topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched
tokens. The received tokens will be reduced with the weights in this tensor.
tokens. The received tokens will be reduced with the weights in this tensor.
handle: the communication handle given by the `dispatch` function.
handle: the communication handle given by the `dispatch` function.
use_logfmt: whether to use an internal "LogFMT with dynamic per-64-channel cast" format (10 bits).
zero_copy: whether the tensor is already copied into the RDMA buffer, should be cooperative
zero_copy: whether the tensor is already copied into the RDMA buffer, should be cooperative
with `get_next_low_latency_combine_buffer`.
with `get_next_low_latency_combine_buffer`.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
...
@@ -964,7 +966,7 @@ class Buffer:
...
@@ -964,7 +966,7 @@ class Buffer:
combined_x
,
event
,
hook
=
self
.
runtime
.
low_latency_combine
(
x
,
topk_idx
,
topk_weights
,
src_info
,
layout_range
,
combined_x
,
event
,
hook
=
self
.
runtime
.
low_latency_combine
(
x
,
topk_idx
,
topk_weights
,
src_info
,
layout_range
,
combine_wait_recv_cost_stats
,
combine_wait_recv_cost_stats
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
num_max_dispatch_tokens_per_rank
,
num_experts
,
zero_copy
,
async_finish
,
return_recv_hook
,
out
)
use_logfmt
,
zero_copy
,
async_finish
,
return_recv_hook
,
out
)
tensors_to_record
=
(
x
,
topk_idx
,
topk_weights
,
src_info
,
layout_range
,
combined_x
)
tensors_to_record
=
(
x
,
topk_idx
,
topk_weights
,
src_info
,
layout_range
,
combined_x
)
return
combined_x
,
EventOverlap
(
event
,
tensors_to_record
if
async_finish
else
None
),
hook
return
combined_x
,
EventOverlap
(
event
,
tensors_to_record
if
async_finish
else
None
),
hook
...
...
tests/test_low_latency_new.py
View file @
75b00cfb
...
@@ -42,6 +42,7 @@ def test_main(num_tokens: int,
...
@@ -42,6 +42,7 @@ def test_main(num_tokens: int,
num_ranks
:
int
,
num_ranks
:
int
,
group
:
dist
.
ProcessGroup
,
group
:
dist
.
ProcessGroup
,
buffer
:
deep_ep
.
Buffer
,
buffer
:
deep_ep
.
Buffer
,
use_logfmt
:
bool
=
False
,
seed
:
int
=
0
):
seed
:
int
=
0
):
torch
.
manual_seed
(
seed
+
rank
)
torch
.
manual_seed
(
seed
+
rank
)
random
.
seed
(
seed
+
rank
)
random
.
seed
(
seed
+
rank
)
...
@@ -56,10 +57,12 @@ def test_main(num_tokens: int,
...
@@ -56,10 +57,12 @@ def test_main(num_tokens: int,
x
=
torch
.
ones
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
*
(
rank
-
rank_offset
)
x
=
torch
.
ones
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
*
(
rank
-
rank_offset
)
x
[:,
-
128
:]
=
torch
.
arange
(
num_tokens
,
device
=
'cuda'
).
to
(
torch
.
bfloat16
).
view
(
-
1
,
1
)
x
[:,
-
128
:]
=
torch
.
arange
(
num_tokens
,
device
=
'cuda'
).
to
(
torch
.
bfloat16
).
view
(
-
1
,
1
)
x_list
=
[
x
]
x_list
=
[
x
]
# # NOTES: the last one is for performance testing
for
_
in
range
(
4
if
use_logfmt
else
0
):
# # Most of the values in the perf case is lower than the threshold, casting most channels
# NOTES: make more LogFMT casts and also with some BF16
# x_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1
x_list
.
append
(
torch
.
randn
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
*
0.5
*
random
.
random
())
# x_list = [x_rand]
# NOTES: the last one is for performance testing
# Most of the values in the perf case is lower than the threshold, casting most channels
x_list
.
append
(
torch
.
randn
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
*
0.1
)
scores
=
torch
.
randn
((
num_tokens
,
num_experts
),
dtype
=
torch
.
float32
,
device
=
'cuda'
).
abs
()
+
1
scores
=
torch
.
randn
((
num_tokens
,
num_experts
),
dtype
=
torch
.
float32
,
device
=
'cuda'
).
abs
()
+
1
topk_idx
=
torch
.
topk
(
scores
,
num_topk
,
dim
=-
1
,
largest
=
True
,
sorted
=
True
)[
1
]
topk_idx
=
torch
.
topk
(
scores
,
num_topk
,
dim
=-
1
,
largest
=
True
,
sorted
=
True
)[
1
]
...
@@ -79,7 +82,7 @@ def test_main(num_tokens: int,
...
@@ -79,7 +82,7 @@ def test_main(num_tokens: int,
# Check dispatch correctness
# Check dispatch correctness
do_check
=
True
do_check
=
True
hash_value
,
num_times
=
0
,
0
hash_value
,
num_times
=
0
,
0
for
current_x
in
x_list
:
for
x_i
,
current_x
in
enumerate
(
x_list
)
:
for
return_recv_hook
in
(
False
,
True
):
for
return_recv_hook
in
(
False
,
True
):
for
quant_type
in
(
0
,
1
,
2
,
3
,
):
# 0: 不量化, 1: int8, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True), 4: FP8_E5M2
for
quant_type
in
(
0
,
1
,
2
,
3
,
):
# 0: 不量化, 1: int8, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True), 4: FP8_E5M2
dispatch_use_quant
=
quant_type
>
0
dispatch_use_quant
=
quant_type
>
0
...
@@ -152,7 +155,7 @@ def test_main(num_tokens: int,
...
@@ -152,7 +155,7 @@ def test_main(num_tokens: int,
hash_value
^=
hash_tensor
(
packed_recv_x
[
i
,
:
num_valid_tokens
])
hash_value
^=
hash_tensor
(
packed_recv_x
[
i
,
:
num_valid_tokens
])
# Check combine correctness
# Check combine correctness
for
zero_copy
in
(
False
,
True
):
for
zero_copy
in
(
False
,
)
if
use_logfmt
else
(
False
,
True
,
):
if
zero_copy
:
if
zero_copy
:
buffer
.
get_next_low_latency_combine_buffer
(
handle
)[:,
:,
:]
=
simulated_gemm_x
buffer
.
get_next_low_latency_combine_buffer
(
handle
)[:,
:,
:]
=
simulated_gemm_x
out
=
torch
.
empty
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
out
=
torch
.
empty
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
...
@@ -160,6 +163,7 @@ def test_main(num_tokens: int,
...
@@ -160,6 +163,7 @@ def test_main(num_tokens: int,
topk_idx
,
topk_idx
,
topk_weights
,
topk_weights
,
handle
,
handle
,
use_logfmt
=
use_logfmt
,
async_finish
=
not
return_recv_hook
,
async_finish
=
not
return_recv_hook
,
zero_copy
=
zero_copy
,
zero_copy
=
zero_copy
,
return_recv_hook
=
return_recv_hook
,
return_recv_hook
=
return_recv_hook
,
...
@@ -172,6 +176,10 @@ def test_main(num_tokens: int,
...
@@ -172,6 +176,10 @@ def test_main(num_tokens: int,
assert
diff
<
(
9e-4
if
dispatch_use_quant
else
1e-5
),
f
'Error: diff=
{
diff
}
, dispatch_use_quant=
{
dispatch_use_quant
}
, zero_copy=
{
zero_copy
}
'
assert
diff
<
(
9e-4
if
dispatch_use_quant
else
1e-5
),
f
'Error: diff=
{
diff
}
, dispatch_use_quant=
{
dispatch_use_quant
}
, zero_copy=
{
zero_copy
}
'
hash_value
^=
hash_tensor
(
combined_x
)
hash_value
^=
hash_tensor
(
combined_x
)
if
rank
==
0
:
print
(
f
"data:
{
x_i
}
, return_recv_hook:
{
return_recv_hook
}
, quant_type:
{
quant_type
}
, "
,
f
"fp8_round_scale:
{
fp8_round_scale
}
, quant_group_size:
{
quant_group_size
}
pass"
)
# noinspection PyShadowingNames
# noinspection PyShadowingNames
def
large_gemm_with_hook
(
hook
):
def
large_gemm_with_hook
(
hook
):
mat_0
=
torch
.
randn
((
8192
,
8192
),
dtype
=
torch
.
float
)
mat_0
=
torch
.
randn
((
8192
,
8192
),
dtype
=
torch
.
float
)
...
@@ -190,6 +198,7 @@ def test_main(num_tokens: int,
...
@@ -190,6 +198,7 @@ def test_main(num_tokens: int,
topk_idx
,
topk_idx
,
topk_weights
,
topk_weights
,
handle
,
handle
,
use_logfmt
=
use_logfmt
,
return_recv_hook
=
return_recv_hook
)
return_recv_hook
=
return_recv_hook
)
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
...
@@ -251,6 +260,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
...
@@ -251,6 +260,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_ranks
,
num_ranks
,
group
,
group
,
buffer
,
buffer
,
use_logfmt
=
args
.
use_logfmt
,
seed
=
1
)
seed
=
1
)
do_pressure_test
=
args
.
pressure_test
do_pressure_test
=
args
.
pressure_test
...
@@ -265,6 +275,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
...
@@ -265,6 +275,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_ranks
,
num_ranks
,
group
,
group
,
buffer
,
buffer
,
use_logfmt
=
args
.
use_logfmt
,
seed
=
seed
)
seed
=
seed
)
for
_
in
range
(
20
):
for
_
in
range
(
20
):
assert
test_main
(
num_tokens
,
assert
test_main
(
num_tokens
,
...
@@ -275,6 +286,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
...
@@ -275,6 +286,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_ranks
,
num_ranks
,
group
,
group
,
buffer
,
buffer
,
use_logfmt
=
args
.
use_logfmt
,
seed
=
seed
)
==
ref_hash
,
f
'Error: seed=
{
seed
}
'
seed
=
seed
)
==
ref_hash
,
f
'Error: seed=
{
seed
}
'
# Destroy the buffer runtime and communication group
# Destroy the buffer runtime and communication group
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment