Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
33bafa16
Commit
33bafa16
authored
Mar 06, 2026
by
lishen
Browse files
lowlatency combine实现3级流水
parent
61bc0aff
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
215 additions
and
181 deletions
+215
-181
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+189
-142
csrc/kernels/internode_ll_logfmt.cuh
csrc/kernels/internode_ll_logfmt.cuh
+20
-39
csrc/kernels/utils.cuh
csrc/kernels/utils.cuh
+6
-0
No files found.
csrc/kernels/internode_ll.cu
View file @
33bafa16
...
...
@@ -636,6 +636,7 @@ combine(void* combined_x,
const
auto
num_local_experts
=
num_experts
/
num_ranks
;
const
auto
warp_group_id
=
warp_id
/
num_warps_per_group
;
const
auto
sub_warp_id
=
warp_id
%
num_warps_per_group
;
const
auto
num_warps
=
num_threads
/
kWarpSize
;
const
auto
responsible_expert_idx
=
sm_id
*
num_warp_groups
+
warp_group_id
;
// Data type stuff
...
...
@@ -656,7 +657,7 @@ combine(void* combined_x,
constexpr
int
kNumDivisions
=
kHidden
/
QUANTIZATION_GROUPSIZE
;
constexpr
int
kNumMetaBytes
=
kNumDivisions
*
sizeof
(
__hip_bfloat162
);
// 用于记录数据的最大最小值
constexpr
int
kNumSendLogFMTBytes
=
kNumMsgInt4ElemPerWarp
*
sizeof
(
int4
);
constexpr
int
kNumStages
=
1
;
// 使用kNumStages>1,则需要的LDS大于64KB
constexpr
int
kNumStages
=
3
;
// 使用kNumStages>1,则需要的LDS大于64KB
constexpr
int
kLogFMTShmemSize
=
kMaxNumWarps
*
(
kNumStages
*
kNumSendLogFMTBytes
+
kNumMetaBytes
);
__shared__
uint8_t
smem_buffer
[
kLogFMTShmemSize
];
/////////////////////////////////////////////
...
...
@@ -707,6 +708,17 @@ combine(void* combined_x,
auto
logfmt_buffers
=
PatternVisitor
([
=
](
const
int
&
i
)
{
return
reinterpret_cast
<
int4
*>
(
smem_ptr
+
i
*
kNumSendLogFMTBytes
);
});
// 存储logfmt的最大最小值
auto
meta_buffers
=
bSupportLogFMT
?
reinterpret_cast
<
__hip_bfloat162
*>
(
smem_ptr
+
kNumStages
*
kNumSendLogFMTBytes
)
:
nullptr
;
// 用于多buffer时临时存储
auto
get_num_logfmt_bytes
=
[
&
](
const
int
&
offset_int4
)
{
return
min
(
kNumSendLogFMTBytes
,
static_cast
<
int
>
((
hidden_bf16_int4
-
offset_int4
)
*
sizeof
(
int4
)));
};
// 简化从global到LDS的存储写法
auto
logfmt_load_global2lds
=
[
&
](
const
int
&
stage_idx
,
const
int4
*
gmem_ptr
,
const
int
&
num_bytes
)
{
UNROLLED_WARP_COPY_LL
(
1
,
lane_id
,
num_bytes
/
sizeof
(
int4
),
reinterpret_cast
<
int4
*>
(
logfmt_buffers
[
stage_idx
]),
reinterpret_cast
<
const
int4
*>
(
gmem_ptr
),
ld_direct_global
,
st_na_global
);
};
// Unpack layout
int
offset
,
num_tokens_to_send
;
...
...
@@ -728,66 +740,103 @@ combine(void* combined_x,
int
num_send_bytes
=
hidden
*
sizeof
(
hip_bfloat16
);
if
(
not
zero_copy
or
dst_p2p_ptr
!=
0
)
{
// Read from `cpy_src_int4_ptr` and copy into `cpy_dst_int4_ptr`
const
auto
cpy_src_int4_ptr
=
zero_copy
?
reinterpret_cast
<
int4
*>
(
buf_ptr
)
:
x_int4
;
const
auto
cpy_dst_int4_ptr
=
dst_p2p_ptr
==
0
?
reinterpret_cast
<
int4
*>
(
buf_ptr
)
:
reinterpret_cast
<
int4
*>
(
dst_p2p_ptr
);
const
auto
cpy_dst_int4_ptr
=
dst_p2p_ptr
==
0
?
reinterpret_cast
<
int4
*>
(
buf_ptr
)
:
reinterpret_cast
<
int4
*>
(
dst_p2p_ptr
);
// 设置数据的真实偏移量
constexpr
int
kNumIters
=
hidden_bf16_int4
/
kNumMsgInt4ElemPerWarp
;
EP_STATIC_ASSERT
(
kNumIters
>=
1
,
"hidden length too small"
);
if
constexpr
(
bSupportLogFMT
)
{
// ===== LogFMT 路径:使用 LDS + encode + 多级流水 =====
int
logfmt_offset_bytes
=
kNumMetaBytes
;
// 进入循环,逐步拷贝数据
constexpr
int
encode_num_warps
=
hidden_bf16_int4
/
kNumMsgInt4
ElemPer
Warp
;
for
(
int
iter_idx
=
0
;
iter_idx
<
encode_num_warps
;
++
iter_idx
)
{
int
num_logfmt
_bytes
=
kNum
MsgInt4ElemPerWarp
*
sizeof
(
int4
)
;
// meta_buffers 存储的thread间隔
constexpr
int
kNumInt4PerDivision
=
128
/
kNum
Elem
s
Per
Int4
;
// 记录S1~S3的编码字节数
int
encoded
_bytes
[
kNum
Stages
]
;
// 原始数据的warp级编译
int
warp_offset
=
iter_idx
*
kNumMsgInt4ElemPerWarp
;
// Prefetch: iter0执行S1
logfmt_load_global2lds
(
0
,
cpy_src_int4_ptr
,
get_num_logfmt_bytes
(
0
));
syncwarp
();
if
constexpr
(
bSupportLogFMT
)
{
// 采用 寄存器->lds->global 的流水线方式, 量化后拷贝到buf_ptr中
const
int
&
stage_idx
=
iter_idx
%
kNumStages
;
// Prefetch: iter0执行S2, iter1执行S1
if
(
kNumStages
>
2
&&
kNumIters
>
1
)
{
int
warp_offset
=
/*1 * */
kNumMsgInt4ElemPerWarp
;
logfmt_load_global2lds
(
1
,
cpy_src_int4_ptr
+
warp_offset
,
get_num_logfmt_bytes
(
warp_offset
));
// thread偏移
int
thread_offset
=
/*0 + */
lane_id
*
kNumSendUnrolls
;
int
num_bytes
=
logfmt_encode
<
kNumSendUnrolls
>
(
logfmt_buffers
[
0
],
(
thread_offset
%
kNumInt4PerDivision
==
0
)
?
meta_buffers
+
thread_offset
/
kNumInt4PerDivision
:
nullptr
,
lane_id
);
encoded_bytes
[
0
]
=
num_bytes
;
}
syncwarp
();
// 采用3级流水
for
(
int
iter_idx
=
0
;
iter_idx
<
kNumIters
;
++
iter_idx
)
{
// 流水线S1: 加载第 (kNumStages-1) 轮之后的数据
const
int
stage_last_iter
=
iter_idx
+
kNumStages
-
1
;
// 当前iter所在stage中的最后一个,初始为S3的读取数据
if
(
stage_last_iter
<
kNumIters
)
{
int
stage_idx
=
stage_last_iter
%
kNumStages
;
int
warp_offset
=
stage_last_iter
*
kNumMsgInt4ElemPerWarp
;
logfmt_load_global2lds
(
stage_idx
,
cpy_src_int4_ptr
+
warp_offset
,
get_num_logfmt_bytes
(
warp_offset
));
}
// 流水线S2: 处理下一轮的数据量化
const
int
stage_next_iter
=
iter_idx
+
1
;
if
(
stage_next_iter
<
kNumIters
)
{
int
stage_idx
=
stage_next_iter
%
kNumStages
;
int
warp_offset
=
stage_next_iter
*
kNumMsgInt4ElemPerWarp
;
int
thread_offset
=
warp_offset
+
lane_id
*
kNumSendUnrolls
;
constexpr
int
kNumInt4PerDivision
=
128
/
kNumElemsPerInt4
;
// = 128/(sizeof(int4) / sizeof(hip_bfloat16)) = 128/(16/2)=16
num_logfmt_bytes
=
logfmt_encode
<
kNumSendUnrolls
>
(
cpy_src_int4_ptr
+
warp_offset
,
// 等同于 x_int4
int
num_bytes
=
logfmt_encode
<
kNumSendUnrolls
>
(
logfmt_buffers
[
stage_idx
],
// NOTES: only the leader lane will write the result
(
thread_offset
%
kNumInt4PerDivision
==
0
)
?
meta_buffers
+
thread_offset
/
kNumInt4PerDivision
:
nullptr
,
lane_id
);
// 将量化后的数据写入
using
vec_type
=
uint32_t
;
UNROLLED_WARP_COPY_LL
(
2
,
lane_id
,
num_logfmt_bytes
/
sizeof
(
vec_type
),
reinterpret_cast
<
vec_type
*>
(
reinterpret_cast
<
uint8_t
*>
(
cpy_dst_int4_ptr
)
+
logfmt_offset_bytes
),
reinterpret_cast
<
vec_type
*>
(
logfmt_buffers
[
stage_idx
]),
ld_nc_global
,
st_na_global
);
lane_id
);
encoded_bytes
[
stage_idx
]
=
num_bytes
;
}
// 起始地址偏移
logfmt_offset_bytes
+=
num_logfmt_bytes
;
}
else
{
// 非量化数据的传输
UNROLLED_WARP_COPY_LL
(
2
,
lane_id
,
kNumMsgInt4ElemPerWarp
,
reinterpret_cast
<
int4
*>
(
cpy_dst_int4_ptr
+
warp_offset
),
reinterpret_cast
<
const
int4
*>
(
cpy_src_int4_ptr
+
warp_offset
),
ld_nc_global
,
st_na_global
);
// 流水线S3:当前轮进行数据拷贝到通信显存
if
(
iter_idx
<
kNumIters
)
{
int
stage_idx
=
iter_idx
%
kNumStages
;
using
vec_type
=
uint64_t
;
int
nvecs
=
encoded_bytes
[
stage_idx
]
/
sizeof
(
vec_type
);
if
(
nvecs
>
0
)
{
UNROLLED_WARP_COPY_LL
(
1
,
lane_id
,
nvecs
,
reinterpret_cast
<
vec_type
*>
(
reinterpret_cast
<
uint8_t
*>
(
cpy_dst_int4_ptr
)
+
logfmt_offset_bytes
),
reinterpret_cast
<
vec_type
*>
(
logfmt_buffers
[
stage_idx
]),
ld_direct_global
,
st_na_global
);
}
logfmt_offset_bytes
+=
encoded_bytes
[
stage_idx
];
}
syncwarp
();
}
// Store metadata (min/max values) for LogFMT
if
constexpr
(
bSupportLogFMT
)
{
// 最终设置节点间传输的字节数
num_send_bytes
=
logfmt_offset_bytes
;
using
vec_type
=
uint32_t
;
auto
meta_buffers_ptr
=
reinterpret_cast
<
vec_type
*>
(
meta_buffers
);
auto
cpy_dst_uint32_ptr
=
reinterpret_cast
<
vec_type
*>
(
cpy_dst_int4_ptr
);
// Store metadata
using
meta_vec_type
=
uint32_t
;
UNROLLED_WARP_COPY_LL
(
1
,
lane_id
,
kNumMetaBytes
/
sizeof
(
meta_vec_type
),
reinterpret_cast
<
meta_vec_type
*>
(
cpy_dst_int4_ptr
),
reinterpret_cast
<
meta_vec_type
*>
(
meta_buffers
),
ld_direct_global
,
st_na_global
);
for
(
int
j
=
lane_id
;
j
<
kNumMetaBytes
/
sizeof
(
vec_type
);
j
+=
kWarpSize
)
{
*
(
cpy_dst_uint32_ptr
+
j
)
=
meta_buffers_ptr
[
j
];
}
else
{
// ===== 非 LogFMT 路径:直接 global -> global,不经过 LDS =====
for
(
int
iter_idx
=
0
;
iter_idx
<
kNumIters
;
++
iter_idx
)
{
int
warp_offset
=
iter_idx
*
kNumMsgInt4ElemPerWarp
;
UNROLLED_WARP_COPY_LL
(
kNumSendUnrolls
,
lane_id
,
kNumMsgInt4ElemPerWarp
,
cpy_dst_int4_ptr
+
warp_offset
,
cpy_src_int4_ptr
+
warp_offset
,
ld_direct_global
,
st_na_global
);
syncwarp
();
}
// 非 LogFMT 时,发送字节数为原始大小
num_send_bytes
=
hidden_bf16_int4
*
sizeof
(
int4
);
// 或根据实际计算
}
syncwarp
();
}
...
...
@@ -858,10 +907,6 @@ LOW_LATENCY_COMBINE_RECV:
// 计算需要多少个warp
constexpr
int
num_decode_warps
=
hidden_bf16_int4
/
(
kNumRecvUnrolls
*
kWarpSize
);
// 限制thread_id
if
(
warp_id
>=
num_decode_warps
)
{
return
;
}
// 每128个数据记录一个max/min值,即该数为总的max/min值数量
constexpr
int
kNumDivisionBytes
=
kNumDivisions
*
sizeof
(
float
);
...
...
@@ -889,6 +934,7 @@ LOW_LATENCY_COMBINE_RECV:
topk_weights_by_lane
=
__ldg
(
topk_weights
+
token_idx
*
num_topk
+
lane_id
);
}
for
(
int
w_i
=
warp_id
;
w_i
<
num_decode_warps
;
w_i
+=
num_warps
)
{
float
combined_values
[
kNumElemsPerInt4
*
kNumRecvUnrolls
]
=
{
0.0
f
};
#pragma unroll
for
(
int
i
=
0
;
i
<
num_topk
;
++
i
)
{
...
...
@@ -906,7 +952,7 @@ LOW_LATENCY_COMBINE_RECV:
const
uint8_t
*
data_buffer
=
rdma_buffer_type
+
kNumMetaBytes
;
// 读取max/min数据
if
(
w
arp
_i
d
==
0
)
{
if
(
w_i
==
0
)
{
// 因为每个warp能处理数据量为 kWarpSize*sizeof(int4)/sizeof(bfloat16) * kNumSendUnrolls
// 即不考虑kNumSendUnrolls,一共 kWarpSize*sizeof(int4)/sizeof(bfloat16)/128 组, 代入参数 = kWarpSize / 16 个warp,nv上为2,dcu上为4
logfmt_check_amaxmin
<
kNumDivisions
/
(
kWarpSize
/
16
),
kNumSendUnrolls
,
kNumRecvUnrolls
>
(
...
...
@@ -920,13 +966,13 @@ LOW_LATENCY_COMBINE_RECV:
__syncthreads
();
// 获取cast_info_buffers
const
auto
&
info
=
cast_info_buffers
[
stage_idx
][
w
arp
_i
d
];
const
auto
&
info
=
cast_info_buffers
[
stage_idx
][
w_i
];
bool
enable_cast
=
info
&
1
;
int
num_casted_prefix
=
info
>>
1
;
// 可用的
// 计算偏移(与TMA版本逻辑一致)
int
warp_offset
=
kNumLogFMTPerWarpBytes
*
num_casted_prefix
+
kNumBF16PerWarpBytes
*
(
w
arp
_i
d
-
num_casted_prefix
);
kNumBF16PerWarpBytes
*
(
w_i
-
num_casted_prefix
);
int
lane_offset
=
(
enable_cast
?
kNumLogFMTPerWarpBytes
:
kNumBF16PerWarpBytes
)
/
kWarpSize
*
lane_id
;
// 使用临时缓冲区进行归约
...
...
@@ -940,7 +986,7 @@ LOW_LATENCY_COMBINE_RECV:
具体的lane_id处理的数据量为 warp_idx / kWarpSize
*/
int
log_amaxmin_per_warp
=
kNumRecvUnrolls
*
kWarpSize
*
sizeof
(
int4
)
/
sizeof
(
hip_bfloat16
)
/
QUANTIZATION_GROUPSIZE
;
int
division_idx
=
w
arp
_i
d
*
log_amaxmin_per_warp
+
lane_id
*
log_amaxmin_per_warp
/
kWarpSize
;
int
division_idx
=
w_i
*
log_amaxmin_per_warp
+
lane_id
*
log_amaxmin_per_warp
/
kWarpSize
;
// 反量化
decode_and_accumulate
<
kNumRecvUnrolls
>
(
...
...
@@ -955,7 +1001,7 @@ LOW_LATENCY_COMBINE_RECV:
const
uint8_t
*
data_buffer
=
rdma_buffer_type
;
// 计算偏移
int
warp_offset
=
kNumBF16PerWarpBytes
*
w
arp
_i
d
;
int
warp_offset
=
kNumBF16PerWarpBytes
*
w_i
;
int
lane_offset
=
kNumBF16PerWarpBytes
/
kWarpSize
*
lane_id
;
// 使用临时缓冲区进行归约
const
uint8_t
*
thread_data_ptr
=
data_buffer
+
warp_offset
+
lane_offset
;
...
...
@@ -984,7 +1030,8 @@ LOW_LATENCY_COMBINE_RECV:
for
(
int
j
=
0
;
j
<
kNumRecvUnrolls
;
++
j
)
{
(
reinterpret_cast
<
int4
*>
(
combined_x
)
+
token_idx
*
hidden_bf16_int4
+
warp_id
*
kWarpSize
*
kNumRecvUnrolls
)[
lane_id
*
kNumRecvUnrolls
+
j
]
=
combined_int4
[
j
];
w_i
*
kWarpSize
*
kNumRecvUnrolls
)[
lane_id
*
kNumRecvUnrolls
+
j
]
=
combined_int4
[
j
];
}
}
}
}
...
...
@@ -1001,7 +1048,7 @@ void combine(void* combined_x,
bool
use_logfmt
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
,
bool
zero_copy
)
{
constexpr
int
kMaxNumWarps
=
16
;
constexpr
int
kMaxNumWarps
=
8
;
constexpr
int
kNumMaxTopk
=
11
;
const
int
num_warp_groups
=
ceil_div
(
num_experts
,
num_device_sms
);
const
int
num_warps_per_group
=
kMaxNumWarps
/
num_warp_groups
;
// num_warps_per_group>1, "Requires more than one warp per group"
...
...
csrc/kernels/internode_ll_logfmt.cuh
View file @
33bafa16
...
...
@@ -17,7 +17,7 @@ namespace internode_ll {
template
<
int
kNumSendUnrolls
>
__forceinline__
__device__
int
logfmt_encode
(
const
int4
*
cpy_src_int4_ptr
,
int4
*
ds
t
_buffer
,
__hip_bfloat162
*
shared_amaxmin
,
const
int
&
lane_id
)
{
__forceinline__
__device__
int
logfmt_encode
(
int4
*
l
ds_buffer
,
__hip_bfloat162
*
shared_amaxmin
,
const
int
&
lane_id
)
{
EP_STATIC_ASSERT
(
kNumSendUnrolls
==
2
,
"kNumSendUnrolls == 2 only"
);
constexpr
int
kNumElemsPerInt4
=
sizeof
(
int4
)
/
sizeof
(
__hip_bfloat16
);
// 8
...
...
@@ -33,7 +33,8 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
const
auto
&
bf162_values
=
reinterpret_cast
<
__hip_bfloat162
*>
(
int4_values
);
// Calculate lane offset
const
auto
&
ld_buffer
=
cpy_src_int4_ptr
+
lane_id
*
kNumSendUnrolls
;
const
auto
&
ld_buffer
=
reinterpret_cast
<
int4
*>
(
reinterpret_cast
<
uint8_t
*>
(
lds_buffer
)
+
lane_id
*
kSendValueBytes
);
const
auto
&
st_buffer
=
reinterpret_cast
<
uint32_t
*>
(
reinterpret_cast
<
uint8_t
*>
(
lds_buffer
)
+
lane_id
*
kSendValueBytes
*
10
/
16
);
// Local log amax
auto
bf162_amax
=
__hip_bfloat162
(
HIPRT_ZERO_BF16
,
HIPRT_ZERO_BF16
);
...
...
@@ -68,6 +69,8 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
// Reduce per 128 channels
// TODO: figure out how hardware do 2-byte min/max
const
auto
&
fp162_max
=
__bfloat1622float2
(
bf162_amax
);
auto
amax
=
__builtin_fmaxf
(
static_cast
<
float
>
(
bf162_amax
.
x
),
static_cast
<
float
>
(
bf162_amax
.
y
));
auto
amin
=
__builtin_fminf
(
static_cast
<
float
>
(
bf162_amin
.
x
),
static_cast
<
float
>
(
bf162_amin
.
y
));
...
...
@@ -80,26 +83,22 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
if
(
shared_amaxmin
!=
nullptr
)
{
*
shared_amaxmin
=
__hip_bfloat162
(
amax
,
amin
);
}
syncwarp
();
// Calculate log amin/amax float
const
auto
&
log_amax
=
__builtin_log2f
(
amax
);
const
auto
&
log_amin
=
__builtin_fmaxf
(
__builtin_log2f
(
amin
),
log_amax
-
kMinClip
);
const
auto
&
log_amax
=
__builtin_amdgcn_logf
(
amax
);
const
auto
&
log_amin
=
__builtin_fmaxf
(
__builtin_amdgcn_logf
(
amin
),
log_amax
-
kMinClip
);
// 在组内广播enable_cast结果
const
bool
&
enable_cast
=
warp_reduce_and
<
kNumLanesToReduce
,
true
>
(
log_amax
<
kLogThreshold
and
log_amin
<
log_amax
);
// Cast into LogFMT-10 if satisfied
if
(
enable_cast
)
{
constexpr
int
dst_buffer_step
=
kSendValueBytes
*
10
/
16
;
const
auto
&
st_buffer
=
reinterpret_cast
<
uint32_t
*>
(
reinterpret_cast
<
uint8_t
*>
(
dst_buffer
)
+
lane_id
*
dst_buffer_step
);
uint32_t
st_u32_values
[
dst_buffer_step
/
sizeof
(
uint32_t
)];
// = 5
// 计算10bit数据的两个相邻数值的差值
const
auto
step
=
(
log_amax
-
log_amin
)
/
static_cast
<
float
>
(
kNumValues
-
2
);
const
auto
step_inv
=
1.0
f
/
step
;
// 计算舍入值
const
auto
rounding
=
2.0
f
-
__builtin_log
2
f
((
1.0
f
+
__builtin_exp2f
(
step
))
*
0.5
f
)
*
step_inv
;
const
auto
rounding
=
2.0
f
-
__builtin_
amdgcn_
logf
((
1.0
f
+
__builtin_
amdgcn_
exp2f
(
step
))
*
0.5
f
)
*
step_inv
;
const
auto
fused_rounding
=
rounding
-
log_amin
*
step_inv
;
// 用于存储编码后的值
...
...
@@ -111,7 +110,7 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
#pragma unroll
for
(
int
k
=
0
;
k
<
kNumElemsPerInt4
;
++
k
)
{
// 8
// 将 bfloat162 转换为 float2
const
auto
&
fp
16
2_fvalue
=
__bfloat1622float2
(
bf162_values
[
k
]);
const
auto
&
fp
32
2_fvalue
=
__bfloat1622float2
(
bf162_values
[
k
]);
/*
实际进行压缩的公式为:
...
...
@@ -124,37 +123,19 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
K: 压缩后的整数的最大值(即,K为2的幂)
*/
// 对 float 值进行编码
encoded
[
k
*
2
+
0
]
=
__float2uint_rd
(
__builtin_fmaxf
(
__builtin_log
2
f
(
fp
16
2_fvalue
.
x
)
*
step_inv
+
fused_rounding
,
0
));
encoded
[
k
*
2
+
1
]
=
__float2uint_rd
(
__builtin_fmaxf
(
__builtin_log
2
f
(
fp
16
2_fvalue
.
y
)
*
step_inv
+
fused_rounding
,
0
));
encoded
[
k
*
2
+
0
]
=
__float2uint_rd
(
__builtin_fmaxf
(
__builtin_
amdgcn_
logf
(
fp
32
2_fvalue
.
x
)
*
step_inv
+
fused_rounding
,
0
));
encoded
[
k
*
2
+
1
]
=
__float2uint_rd
(
__builtin_fmaxf
(
__builtin_
amdgcn_
logf
(
fp
32
2_fvalue
.
y
)
*
step_inv
+
fused_rounding
,
0
));
}
// 批量打包编码后的值到 st_buffer
st_u32_values
[
0
]
=
(
encoded
[
0
]
>>
0
)
|
(
encoded
[
1
]
<<
9
)
|
(
encoded
[
2
]
<<
18
)
|
(
encoded
[
3
]
<<
27
);
st_u32_values
[
1
]
=
(
encoded
[
3
]
>>
5
)
|
(
encoded
[
4
]
<<
4
)
|
(
encoded
[
5
]
<<
13
)
|
(
encoded
[
6
]
<<
22
)
|
(
encoded
[
7
]
<<
31
);
st_u32_values
[
2
]
=
(
encoded
[
7
]
>>
1
)
|
(
encoded
[
8
]
<<
8
)
|
(
encoded
[
9
]
<<
17
)
|
(
encoded
[
10
]
<<
26
);
st_u32_values
[
3
]
=
(
encoded
[
10
]
>>
6
)
|
(
encoded
[
11
]
<<
3
)
|
(
encoded
[
12
]
<<
12
)
|
(
encoded
[
13
]
<<
21
)
|
(
encoded
[
14
]
<<
30
);
st_u32_values
[
4
]
=
(
encoded
[
14
]
>>
2
)
|
(
encoded
[
15
]
<<
7
)
|
(
local_signs
<<
16
);
}
// 保存160bit的数据到st_buffer
st_buffer
[
0
]
=
st_u32_values
[
0
];
*
(
reinterpret_cast
<
int4
*>
(
st_buffer
+
1
))
=
*
(
reinterpret_cast
<
int4
*>
(
st_u32_values
+
1
));
}
else
{
// 准备收发数据
using
vec_type
=
int4
;
const
auto
&
ld_buffer_vec
=
reinterpret_cast
<
const
vec_type
*>
(
ld_buffer
);
auto
st_buffer_vec
=
reinterpret_cast
<
vec_type
*>
(
reinterpret_cast
<
uint8_t
*>
(
dst_buffer
)
+
lane_id
*
kSendValueBytes
);
constexpr
int
kLoopIter
=
kSendValueBytes
/
sizeof
(
vec_type
);
#pragma unroll
for
(
int
k
=
0
;
k
<
kLoopIter
;
++
k
)
{
st_buffer_vec
[
k
]
=
ld_nc_global
(
ld_buffer_vec
+
k
);
st_buffer
[
0
]
=
(
encoded
[
0
]
>>
0
)
|
(
encoded
[
1
]
<<
9
)
|
(
encoded
[
2
]
<<
18
)
|
(
encoded
[
3
]
<<
27
);
st_buffer
[
1
]
=
(
encoded
[
3
]
>>
5
)
|
(
encoded
[
4
]
<<
4
)
|
(
encoded
[
5
]
<<
13
)
|
(
encoded
[
6
]
<<
22
)
|
(
encoded
[
7
]
<<
31
);
st_buffer
[
2
]
=
(
encoded
[
7
]
>>
1
)
|
(
encoded
[
8
]
<<
8
)
|
(
encoded
[
9
]
<<
17
)
|
(
encoded
[
10
]
<<
26
);
st_buffer
[
3
]
=
(
encoded
[
10
]
>>
6
)
|
(
encoded
[
11
]
<<
3
)
|
(
encoded
[
12
]
<<
12
)
|
(
encoded
[
13
]
<<
21
)
|
(
encoded
[
14
]
<<
30
);
st_buffer
[
4
]
=
(
encoded
[
14
]
>>
2
)
|
(
encoded
[
15
]
<<
7
)
|
(
local_signs
<<
16
);
}
}
// 确保 warp 内的所有线程都完成打包操作
syncwarp
();
// 计算量化成功和失败时的数据量
constexpr
int
unable_cast_num_bytes
=
kWarpSize
*
kSendValueBytes
;
// = 64*2*16 = 2048
constexpr
int
enable_cast_num_bytes
=
unable_cast_num_bytes
*
10
/
16
;
// = 2048/16*10=1280
...
...
@@ -191,8 +172,8 @@ __forceinline__ __device__ void logfmt_check_amaxmin(
for
(
int
i
=
0
;
i
<
kNumQuantGroupsPerWarp
;
++
i
)
{
// sizeof(uint64_t) / sizeof(__hip_bfloat162) = 2
auto
amax
=
static_cast
<
float
>
(
bf162_amaxmin
[
i
].
x
);
auto
amin
=
static_cast
<
float
>
(
bf162_amaxmin
[
i
].
y
);
log_amax
[
i
]
=
__builtin_log
2
f
(
amax
);
log_amin
[
i
]
=
amin
==
0
?
log_amax
[
i
]
-
kMinClip
:
__builtin_fmaxf
(
__builtin_log
2
f
(
amin
),
log_amax
[
i
]
-
kMinClip
);
log_amax
[
i
]
=
__builtin_
amdgcn_
logf
(
amax
);
log_amin
[
i
]
=
amin
==
0
?
log_amax
[
i
]
-
kMinClip
:
__builtin_fmaxf
(
__builtin_
amdgcn_
logf
(
amin
),
log_amax
[
i
]
-
kMinClip
);
enable_cast
=
enable_cast
and
log_amax
[
i
]
<
kLogThreshold
and
log_amin
[
i
]
<
log_amax
[
i
];
}
...
...
@@ -229,7 +210,7 @@ __forceinline__ __device__ void decode_and_accumulate(
const
auto
&
step
=
(
log_amax
-
log_amin
)
/
static_cast
<
float
>
(
kNumValues
-
2
);
auto
decode
=
[
=
](
const
uint32_t
&
encoded
,
const
uint32_t
&
sign
)
{
const
auto
decoded
=
encoded
==
0
?
.0
f
:
__builtin_exp2f
((
encoded
-
1
)
*
step
+
log_amin
);
const
auto
decoded
=
encoded
==
0
?
.0
f
:
__builtin_
amdgcn_
exp2f
((
encoded
-
1
)
*
step
+
log_amin
);
return
sign
?
-
decoded
:
decoded
;
};
...
...
csrc/kernels/utils.cuh
View file @
33bafa16
...
...
@@ -240,6 +240,12 @@ template <typename dtype_t> __device__ __forceinline__ dtype_t ld_nc_global(cons
return
*
reinterpret_cast
<
dtype_t
*>
(
&
ret
);
}
// Loads one `dtype_t` value from `ptr` with a plain pointer dereference.
//
// The bytes are read through `VecInt<sizeof(dtype_t)>::vec_t` and then
// reinterpreted as `dtype_t`, so exactly `sizeof(dtype_t)` bytes are fetched
// in a single access of that carrier type.
// NOTE(review): `VecInt` is declared elsewhere; presumably it maps a byte
// count to a same-sized integer/vector type -- confirm against its definition.
//
// @param ptr  Address of the value to load.
// @return     The loaded value, bit-for-bit, as `dtype_t`.
template <typename dtype_t>
__device__ __forceinline__ dtype_t ld_direct_global(const dtype_t* ptr) {
    using carrier_t = typename VecInt<sizeof(dtype_t)>::vec_t;
    // Ordinary load: no intrinsic or inline assembly is involved here.
    carrier_t raw = *reinterpret_cast<const carrier_t*>(ptr);
    // Hand the same bits back under the caller's type.
    return *reinterpret_cast<dtype_t*>(&raw);
}
////////////////// used in ibgda
__device__
__forceinline__
void
st_na_relaxed
(
const
uint8_t
*
ptr
,
uint8_t
val
)
{
uint8_t
*
non_const_ptr
=
const_cast
<
uint8_t
*>
(
ptr
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment