Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
33bafa16
"csrc/vscode:/vscode.git/clone" did not exist on "2ff5a7733b213de5024c95f8a0da31bed82e28ca"
Commit
33bafa16
authored
Mar 06, 2026
by
lishen
Browse files
lowlatency combine实现3级流水
parent
61bc0aff
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
215 additions
and
181 deletions
+215
-181
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+189
-142
csrc/kernels/internode_ll_logfmt.cuh
csrc/kernels/internode_ll_logfmt.cuh
+20
-39
csrc/kernels/utils.cuh
csrc/kernels/utils.cuh
+6
-0
No files found.
csrc/kernels/internode_ll.cu
View file @
33bafa16
...
...
@@ -636,6 +636,7 @@ combine(void* combined_x,
const
auto
num_local_experts
=
num_experts
/
num_ranks
;
const
auto
warp_group_id
=
warp_id
/
num_warps_per_group
;
const
auto
sub_warp_id
=
warp_id
%
num_warps_per_group
;
const
auto
num_warps
=
num_threads
/
kWarpSize
;
const
auto
responsible_expert_idx
=
sm_id
*
num_warp_groups
+
warp_group_id
;
// Data type staffs
...
...
@@ -656,7 +657,7 @@ combine(void* combined_x,
constexpr
int
kNumDivisions
=
kHidden
/
QUANTIZATION_GROUPSIZE
;
constexpr
int
kNumMetaBytes
=
kNumDivisions
*
sizeof
(
__hip_bfloat162
);
// 用于记录数据的最大最小值
constexpr
int
kNumSendLogFMTBytes
=
kNumMsgInt4ElemPerWarp
*
sizeof
(
int4
);
constexpr
int
kNumStages
=
1
;
// 使用kNumStages>1,则需要的LDS大于64KB
constexpr
int
kNumStages
=
3
;
// 使用kNumStages>1,则需要的LDS大于64KB
constexpr
int
kLogFMTShmemSize
=
kMaxNumWarps
*
(
kNumStages
*
kNumSendLogFMTBytes
+
kNumMetaBytes
);
__shared__
uint8_t
smem_buffer
[
kLogFMTShmemSize
];
/////////////////////////////////////////////
...
...
@@ -707,6 +708,17 @@ combine(void* combined_x,
auto
logfmt_buffers
=
PatternVisitor
([
=
](
const
int
&
i
)
{
return
reinterpret_cast
<
int4
*>
(
smem_ptr
+
i
*
kNumSendLogFMTBytes
);
});
// 存储logfmt的最大最小值
auto
meta_buffers
=
bSupportLogFMT
?
reinterpret_cast
<
__hip_bfloat162
*>
(
smem_ptr
+
kNumStages
*
kNumSendLogFMTBytes
)
:
nullptr
;
// 用于多buffer时临时存储
auto
get_num_logfmt_bytes
=
[
&
](
const
int
&
offset_int4
)
{
return
min
(
kNumSendLogFMTBytes
,
static_cast
<
int
>
((
hidden_bf16_int4
-
offset_int4
)
*
sizeof
(
int4
)));
};
// 简化从global到LDS的存储写法
auto
logfmt_load_global2lds
=
[
&
](
const
int
&
stage_idx
,
const
int4
*
gmem_ptr
,
const
int
&
num_bytes
)
{
UNROLLED_WARP_COPY_LL
(
1
,
lane_id
,
num_bytes
/
sizeof
(
int4
),
reinterpret_cast
<
int4
*>
(
logfmt_buffers
[
stage_idx
]),
reinterpret_cast
<
const
int4
*>
(
gmem_ptr
),
ld_direct_global
,
st_na_global
);
};
// Unpack layout
int
offset
,
num_tokens_to_send
;
...
...
@@ -728,66 +740,103 @@ combine(void* combined_x,
int
num_send_bytes
=
hidden
*
sizeof
(
hip_bfloat16
);
if
(
not
zero_copy
or
dst_p2p_ptr
!=
0
)
{
// Read from `cpy_src_int4_ptr` and copy into `cpy_dst_int4_ptr`
const
auto
cpy_src_int4_ptr
=
zero_copy
?
reinterpret_cast
<
int4
*>
(
buf_ptr
)
:
x_int4
;
const
auto
cpy_dst_int4_ptr
=
dst_p2p_ptr
==
0
?
reinterpret_cast
<
int4
*>
(
buf_ptr
)
:
reinterpret_cast
<
int4
*>
(
dst_p2p_ptr
);
const
auto
cpy_dst_int4_ptr
=
dst_p2p_ptr
==
0
?
reinterpret_cast
<
int4
*>
(
buf_ptr
)
:
reinterpret_cast
<
int4
*>
(
dst_p2p_ptr
);
// 设置数据的真实偏移量
constexpr
int
kNumIters
=
hidden_bf16_int4
/
kNumMsgInt4ElemPerWarp
;
EP_STATIC_ASSERT
(
kNumIters
>=
1
,
"hidden length too small"
);
if
constexpr
(
bSupportLogFMT
)
{
// ===== LogFMT 路径:使用 LDS + encode + 多级流水 =====
int
logfmt_offset_bytes
=
kNumMetaBytes
;
// 进入循环,逐步拷贝数据
constexpr
int
encode_num_warps
=
hidden_bf16_int4
/
kNumMsgInt4
ElemPer
Warp
;
for
(
int
iter_idx
=
0
;
iter_idx
<
encode_num_warps
;
++
iter_idx
)
{
int
num_logfmt
_bytes
=
kNum
MsgInt4ElemPerWarp
*
sizeof
(
int4
)
;
// meta_buffers 存储的thread间隔
constexpr
int
kNumInt4PerDivision
=
128
/
kNum
Elem
s
Per
Int4
;
// 记录S1~S3的编码字节数
int
encoded
_bytes
[
kNum
Stages
]
;
// 原始数据的warp级编译
int
warp_offset
=
iter_idx
*
kNumMsgInt4ElemPerWarp
;
// Prefetch: iter0执行S1
logfmt_load_global2lds
(
0
,
cpy_src_int4_ptr
,
get_num_logfmt_bytes
(
0
));
syncwarp
();
if
constexpr
(
bSupportLogFMT
)
{
// 采用 寄存器->lds->global 的流水线方式, 量化后拷贝到buf_ptr中
const
int
&
stage_idx
=
iter_idx
%
kNumStages
;
// Prefetch: iter0执行S2, iter1执行S1
if
(
kNumStages
>
2
&&
kNumIters
>
1
)
{
int
warp_offset
=
/*1 * */
kNumMsgInt4ElemPerWarp
;
logfmt_load_global2lds
(
1
,
cpy_src_int4_ptr
+
warp_offset
,
get_num_logfmt_bytes
(
warp_offset
));
// thread偏移
int
thread_offset
=
/*0 + */
lane_id
*
kNumSendUnrolls
;
int
num_bytes
=
logfmt_encode
<
kNumSendUnrolls
>
(
logfmt_buffers
[
0
],
(
thread_offset
%
kNumInt4PerDivision
==
0
)
?
meta_buffers
+
thread_offset
/
kNumInt4PerDivision
:
nullptr
,
lane_id
);
encoded_bytes
[
0
]
=
num_bytes
;
}
syncwarp
();
// 采用3级流水
for
(
int
iter_idx
=
0
;
iter_idx
<
kNumIters
;
++
iter_idx
)
{
// 流水线S1: 加载第 (kNumStages-1) 轮之后的数据
const
int
stage_last_iter
=
iter_idx
+
kNumStages
-
1
;
// 当前iter所在stage中的最后一个,初始为S3的读取数据
if
(
stage_last_iter
<
kNumIters
)
{
int
stage_idx
=
stage_last_iter
%
kNumStages
;
int
warp_offset
=
stage_last_iter
*
kNumMsgInt4ElemPerWarp
;
logfmt_load_global2lds
(
stage_idx
,
cpy_src_int4_ptr
+
warp_offset
,
get_num_logfmt_bytes
(
warp_offset
));
}
// 流水线S2: 处理下一轮的数据量化
const
int
stage_next_iter
=
iter_idx
+
1
;
if
(
stage_next_iter
<
kNumIters
)
{
int
stage_idx
=
stage_next_iter
%
kNumStages
;
int
warp_offset
=
stage_next_iter
*
kNumMsgInt4ElemPerWarp
;
int
thread_offset
=
warp_offset
+
lane_id
*
kNumSendUnrolls
;
constexpr
int
kNumInt4PerDivision
=
128
/
kNumElemsPerInt4
;
// = 128/(sizeof(int4) / sizeof(hip_bfloat16)) = 128/(16/2)=16
num_logfmt_bytes
=
logfmt_encode
<
kNumSendUnrolls
>
(
cpy_src_int4_ptr
+
warp_offset
,
// 等同于 x_int4
int
num_bytes
=
logfmt_encode
<
kNumSendUnrolls
>
(
logfmt_buffers
[
stage_idx
],
// NOTES: only the leader lane will write the result
(
thread_offset
%
kNumInt4PerDivision
==
0
)
?
meta_buffers
+
thread_offset
/
kNumInt4PerDivision
:
nullptr
,
lane_id
);
// 将量化后的数据写入
using
vec_type
=
uint32_t
;
UNROLLED_WARP_COPY_LL
(
2
,
lane_id
,
num_logfmt_bytes
/
sizeof
(
vec_type
),
reinterpret_cast
<
vec_type
*>
(
reinterpret_cast
<
uint8_t
*>
(
cpy_dst_int4_ptr
)
+
logfmt_offset_bytes
),
reinterpret_cast
<
vec_type
*>
(
logfmt_buffers
[
stage_idx
]),
ld_nc_global
,
st_na_global
);
lane_id
);
encoded_bytes
[
stage_idx
]
=
num_bytes
;
}
// 起始地址偏移
logfmt_offset_bytes
+=
num_logfmt_bytes
;
}
else
{
// 非量化数据的传输
UNROLLED_WARP_COPY_LL
(
2
,
lane_id
,
kNumMsgInt4ElemPerWarp
,
reinterpret_cast
<
int4
*>
(
cpy_dst_int4_ptr
+
warp_offset
),
reinterpret_cast
<
const
int4
*>
(
cpy_src_int4_ptr
+
warp_offset
),
ld_nc_global
,
st_na_global
);
// 流水线S3:当前轮进行数据拷贝到通信显存
if
(
iter_idx
<
kNumIters
)
{
int
stage_idx
=
iter_idx
%
kNumStages
;
using
vec_type
=
uint64_t
;
int
nvecs
=
encoded_bytes
[
stage_idx
]
/
sizeof
(
vec_type
);
if
(
nvecs
>
0
)
{
UNROLLED_WARP_COPY_LL
(
1
,
lane_id
,
nvecs
,
reinterpret_cast
<
vec_type
*>
(
reinterpret_cast
<
uint8_t
*>
(
cpy_dst_int4_ptr
)
+
logfmt_offset_bytes
),
reinterpret_cast
<
vec_type
*>
(
logfmt_buffers
[
stage_idx
]),
ld_direct_global
,
st_na_global
);
}
logfmt_offset_bytes
+=
encoded_bytes
[
stage_idx
];
}
syncwarp
();
}
// Store metadata (min/max values) for LogFMT
if
constexpr
(
bSupportLogFMT
)
{
// 最终设置节点间传输的字节数
num_send_bytes
=
logfmt_offset_bytes
;
using
vec_type
=
uint32_t
;
auto
meta_buffers_ptr
=
reinterpret_cast
<
vec_type
*>
(
meta_buffers
);
auto
cpy_dst_uint32_ptr
=
reinterpret_cast
<
vec_type
*>
(
cpy_dst_int4_ptr
);
// Store metadata
using
meta_vec_type
=
uint32_t
;
UNROLLED_WARP_COPY_LL
(
1
,
lane_id
,
kNumMetaBytes
/
sizeof
(
meta_vec_type
),
reinterpret_cast
<
meta_vec_type
*>
(
cpy_dst_int4_ptr
),
reinterpret_cast
<
meta_vec_type
*>
(
meta_buffers
),
ld_direct_global
,
st_na_global
);
for
(
int
j
=
lane_id
;
j
<
kNumMetaBytes
/
sizeof
(
vec_type
);
j
+=
kWarpSize
)
{
*
(
cpy_dst_uint32_ptr
+
j
)
=
meta_buffers_ptr
[
j
];
}
else
{
// ===== 非 LogFMT 路径:直接 global -> global,不经过 LDS =====
for
(
int
iter_idx
=
0
;
iter_idx
<
kNumIters
;
++
iter_idx
)
{
int
warp_offset
=
iter_idx
*
kNumMsgInt4ElemPerWarp
;
UNROLLED_WARP_COPY_LL
(
kNumSendUnrolls
,
lane_id
,
kNumMsgInt4ElemPerWarp
,
cpy_dst_int4_ptr
+
warp_offset
,
cpy_src_int4_ptr
+
warp_offset
,
ld_direct_global
,
st_na_global
);
syncwarp
();
}
// 非 LogFMT 时,发送字节数为原始大小
num_send_bytes
=
hidden_bf16_int4
*
sizeof
(
int4
);
// 或根据实际计算
}
syncwarp
();
}
...
...
@@ -858,10 +907,6 @@ LOW_LATENCY_COMBINE_RECV:
// 计算需要多少个warp
constexpr
int
num_decode_warps
=
hidden_bf16_int4
/
(
kNumRecvUnrolls
*
kWarpSize
);
// 限制thread_id
if
(
warp_id
>=
num_decode_warps
)
{
return
;
}
// 每128个数据记录一个max/min值,即该数为总的max/min值数量
constexpr
int
kNumDivisionBytes
=
kNumDivisions
*
sizeof
(
float
);
...
...
@@ -889,6 +934,7 @@ LOW_LATENCY_COMBINE_RECV:
topk_weights_by_lane
=
__ldg
(
topk_weights
+
token_idx
*
num_topk
+
lane_id
);
}
for
(
int
w_i
=
warp_id
;
w_i
<
num_decode_warps
;
w_i
+=
num_warps
)
{
float
combined_values
[
kNumElemsPerInt4
*
kNumRecvUnrolls
]
=
{
0.0
f
};
#pragma unroll
for
(
int
i
=
0
;
i
<
num_topk
;
++
i
)
{
...
...
@@ -906,7 +952,7 @@ LOW_LATENCY_COMBINE_RECV:
const
uint8_t
*
data_buffer
=
rdma_buffer_type
+
kNumMetaBytes
;
// 读取max/min数据
if
(
w
arp
_i
d
==
0
)
{
if
(
w_i
==
0
)
{
// 因为每个warp能处理数据量为 kWarpSize*sizeof(int4)/sizeof(bfloat16) * kNumSendUnrolls
// 即不考虑kNumSendUnrolls,一共 kWarpSize*sizeof(int4)/sizeof(bfloat16)/128 组, 代入参数 = kWarpSize / 16 个warp,nv上为2,dcu上为4
logfmt_check_amaxmin
<
kNumDivisions
/
(
kWarpSize
/
16
),
kNumSendUnrolls
,
kNumRecvUnrolls
>
(
...
...
@@ -920,13 +966,13 @@ LOW_LATENCY_COMBINE_RECV:
__syncthreads
();
// 获取cast_info_buffers
const
auto
&
info
=
cast_info_buffers
[
stage_idx
][
w
arp
_i
d
];
const
auto
&
info
=
cast_info_buffers
[
stage_idx
][
w_i
];
bool
enable_cast
=
info
&
1
;
int
num_casted_prefix
=
info
>>
1
;
// 可用的
// 计算偏移(与TMA版本逻辑一致)
int
warp_offset
=
kNumLogFMTPerWarpBytes
*
num_casted_prefix
+
kNumBF16PerWarpBytes
*
(
w
arp
_i
d
-
num_casted_prefix
);
kNumBF16PerWarpBytes
*
(
w_i
-
num_casted_prefix
);
int
lane_offset
=
(
enable_cast
?
kNumLogFMTPerWarpBytes
:
kNumBF16PerWarpBytes
)
/
kWarpSize
*
lane_id
;
// 使用临时缓冲区进行归约
...
...
@@ -940,7 +986,7 @@ LOW_LATENCY_COMBINE_RECV:
具体的lane_id处理的数据量为 warp_idx / kWarpSize
*/
int
log_amaxmin_per_warp
=
kNumRecvUnrolls
*
kWarpSize
*
sizeof
(
int4
)
/
sizeof
(
hip_bfloat16
)
/
QUANTIZATION_GROUPSIZE
;
int
division_idx
=
w
arp
_i
d
*
log_amaxmin_per_warp
+
lane_id
*
log_amaxmin_per_warp
/
kWarpSize
;
int
division_idx
=
w_i
*
log_amaxmin_per_warp
+
lane_id
*
log_amaxmin_per_warp
/
kWarpSize
;
// 反量化
decode_and_accumulate
<
kNumRecvUnrolls
>
(
...
...
@@ -955,7 +1001,7 @@ LOW_LATENCY_COMBINE_RECV:
const
uint8_t
*
data_buffer
=
rdma_buffer_type
;
// 计算偏移
int
warp_offset
=
kNumBF16PerWarpBytes
*
w
arp
_i
d
;
int
warp_offset
=
kNumBF16PerWarpBytes
*
w_i
;
int
lane_offset
=
kNumBF16PerWarpBytes
/
kWarpSize
*
lane_id
;
// 使用临时缓冲区进行归约
const
uint8_t
*
thread_data_ptr
=
data_buffer
+
warp_offset
+
lane_offset
;
...
...
@@ -984,7 +1030,8 @@ LOW_LATENCY_COMBINE_RECV:
for
(
int
j
=
0
;
j
<
kNumRecvUnrolls
;
++
j
)
{
(
reinterpret_cast
<
int4
*>
(
combined_x
)
+
token_idx
*
hidden_bf16_int4
+
warp_id
*
kWarpSize
*
kNumRecvUnrolls
)[
lane_id
*
kNumRecvUnrolls
+
j
]
=
combined_int4
[
j
];
w_i
*
kWarpSize
*
kNumRecvUnrolls
)[
lane_id
*
kNumRecvUnrolls
+
j
]
=
combined_int4
[
j
];
}
}
}
}
...
...
@@ -1001,7 +1048,7 @@ void combine(void* combined_x,
bool
use_logfmt
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
,
bool
zero_copy
)
{
constexpr
int
kMaxNumWarps
=
16
;
constexpr
int
kMaxNumWarps
=
8
;
constexpr
int
kNumMaxTopk
=
11
;
const
int
num_warp_groups
=
ceil_div
(
num_experts
,
num_device_sms
);
const
int
num_warps_per_group
=
kMaxNumWarps
/
num_warp_groups
;
// num_warps_per_group>1, "Requires more than one warp per group"
...
...
csrc/kernels/internode_ll_logfmt.cuh
View file @
33bafa16
...
...
@@ -17,7 +17,7 @@ namespace internode_ll {
template
<
int
kNumSendUnrolls
>
__forceinline__
__device__
int
logfmt_encode
(
const
int4
*
cpy_src_int4_ptr
,
int4
*
ds
t
_buffer
,
__hip_bfloat162
*
shared_amaxmin
,
const
int
&
lane_id
)
{
__forceinline__
__device__
int
logfmt_encode
(
int4
*
l
ds_buffer
,
__hip_bfloat162
*
shared_amaxmin
,
const
int
&
lane_id
)
{
EP_STATIC_ASSERT
(
kNumSendUnrolls
==
2
,
"kNumSendUnrolls == 2 only"
);
constexpr
int
kNumElemsPerInt4
=
sizeof
(
int4
)
/
sizeof
(
__hip_bfloat16
);
// 8
...
...
@@ -33,7 +33,8 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
const
auto
&
bf162_values
=
reinterpret_cast
<
__hip_bfloat162
*>
(
int4_values
);
// Calculate lane offset
const
auto
&
ld_buffer
=
cpy_src_int4_ptr
+
lane_id
*
kNumSendUnrolls
;
const
auto
&
ld_buffer
=
reinterpret_cast
<
int4
*>
(
reinterpret_cast
<
uint8_t
*>
(
lds_buffer
)
+
lane_id
*
kSendValueBytes
);
const
auto
&
st_buffer
=
reinterpret_cast
<
uint32_t
*>
(
reinterpret_cast
<
uint8_t
*>
(
lds_buffer
)
+
lane_id
*
kSendValueBytes
*
10
/
16
);
// Local log amax
auto
bf162_amax
=
__hip_bfloat162
(
HIPRT_ZERO_BF16
,
HIPRT_ZERO_BF16
);
...
...
@@ -68,6 +69,8 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
// Reduce per 128 channels
// TODO: figure out how hardware do 2-byte min/max
const
auto
&
fp162_max
=
__bfloat1622float2
(
bf162_amax
);
auto
amax
=
__builtin_fmaxf
(
static_cast
<
float
>
(
bf162_amax
.
x
),
static_cast
<
float
>
(
bf162_amax
.
y
));
auto
amin
=
__builtin_fminf
(
static_cast
<
float
>
(
bf162_amin
.
x
),
static_cast
<
float
>
(
bf162_amin
.
y
));
...
...
@@ -80,26 +83,22 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
if
(
shared_amaxmin
!=
nullptr
)
{
*
shared_amaxmin
=
__hip_bfloat162
(
amax
,
amin
);
}
syncwarp
();
// Calculate log amin/amax float
const
auto
&
log_amax
=
__builtin_log2f
(
amax
);
const
auto
&
log_amin
=
__builtin_fmaxf
(
__builtin_log2f
(
amin
),
log_amax
-
kMinClip
);
const
auto
&
log_amax
=
__builtin_amdgcn_logf
(
amax
);
const
auto
&
log_amin
=
__builtin_fmaxf
(
__builtin_amdgcn_logf
(
amin
),
log_amax
-
kMinClip
);
// 在组内广播enable_cast结果
const
bool
&
enable_cast
=
warp_reduce_and
<
kNumLanesToReduce
,
true
>
(
log_amax
<
kLogThreshold
and
log_amin
<
log_amax
);
// Case into LogFMT-10 if satisfied
if
(
enable_cast
)
{
constexpr
int
dst_buffer_step
=
kSendValueBytes
*
10
/
16
;
const
auto
&
st_buffer
=
reinterpret_cast
<
uint32_t
*>
(
reinterpret_cast
<
uint8_t
*>
(
dst_buffer
)
+
lane_id
*
dst_buffer_step
);
uint32_t
st_u32_values
[
dst_buffer_step
/
sizeof
(
uint32_t
)];
// = 5
// 计算10bit数据的两个相邻数值的差值
const
auto
step
=
(
log_amax
-
log_amin
)
/
static_cast
<
float
>
(
kNumValues
-
2
);
const
auto
step_inv
=
1.0
f
/
step
;
// 计算舍入值
const
auto
rounding
=
2.0
f
-
__builtin_log
2
f
((
1.0
f
+
__builtin_exp2f
(
step
))
*
0.5
f
)
*
step_inv
;
const
auto
rounding
=
2.0
f
-
__builtin_
amdgcn_
logf
((
1.0
f
+
__builtin_
amdgcn_
exp2f
(
step
))
*
0.5
f
)
*
step_inv
;
const
auto
fused_rounding
=
rounding
-
log_amin
*
step_inv
;
// 用于存储编码后的值
...
...
@@ -111,7 +110,7 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
#pragma unroll
for
(
int
k
=
0
;
k
<
kNumElemsPerInt4
;
++
k
)
{
// 8
// 将 bfloat162 转换为 float2
const
auto
&
fp
16
2_fvalue
=
__bfloat1622float2
(
bf162_values
[
k
]);
const
auto
&
fp
32
2_fvalue
=
__bfloat1622float2
(
bf162_values
[
k
]);
/*
实际进行压缩的公式为:
...
...
@@ -124,37 +123,19 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
K: 压缩后的整数的最大值(即,K为2的幂)
*/
// 对 float 值进行编码
encoded
[
k
*
2
+
0
]
=
__float2uint_rd
(
__builtin_fmaxf
(
__builtin_log
2
f
(
fp
16
2_fvalue
.
x
)
*
step_inv
+
fused_rounding
,
0
));
encoded
[
k
*
2
+
1
]
=
__float2uint_rd
(
__builtin_fmaxf
(
__builtin_log
2
f
(
fp
16
2_fvalue
.
y
)
*
step_inv
+
fused_rounding
,
0
));
encoded
[
k
*
2
+
0
]
=
__float2uint_rd
(
__builtin_fmaxf
(
__builtin_
amdgcn_
logf
(
fp
32
2_fvalue
.
x
)
*
step_inv
+
fused_rounding
,
0
));
encoded
[
k
*
2
+
1
]
=
__float2uint_rd
(
__builtin_fmaxf
(
__builtin_
amdgcn_
logf
(
fp
32
2_fvalue
.
y
)
*
step_inv
+
fused_rounding
,
0
));
}
// 批量打包编码后的值到 st_buffer
st_u32_values
[
0
]
=
(
encoded
[
0
]
>>
0
)
|
(
encoded
[
1
]
<<
9
)
|
(
encoded
[
2
]
<<
18
)
|
(
encoded
[
3
]
<<
27
);
st_u32_values
[
1
]
=
(
encoded
[
3
]
>>
5
)
|
(
encoded
[
4
]
<<
4
)
|
(
encoded
[
5
]
<<
13
)
|
(
encoded
[
6
]
<<
22
)
|
(
encoded
[
7
]
<<
31
);
st_u32_values
[
2
]
=
(
encoded
[
7
]
>>
1
)
|
(
encoded
[
8
]
<<
8
)
|
(
encoded
[
9
]
<<
17
)
|
(
encoded
[
10
]
<<
26
);
st_u32_values
[
3
]
=
(
encoded
[
10
]
>>
6
)
|
(
encoded
[
11
]
<<
3
)
|
(
encoded
[
12
]
<<
12
)
|
(
encoded
[
13
]
<<
21
)
|
(
encoded
[
14
]
<<
30
);
st_u32_values
[
4
]
=
(
encoded
[
14
]
>>
2
)
|
(
encoded
[
15
]
<<
7
)
|
(
local_signs
<<
16
);
}
// 保存160bit的数据到st_buffer
st_buffer
[
0
]
=
st_u32_values
[
0
];
*
(
reinterpret_cast
<
int4
*>
(
st_buffer
+
1
))
=
*
(
reinterpret_cast
<
int4
*>
(
st_u32_values
+
1
));
}
else
{
// 准备收发数据
using
vec_type
=
int4
;
const
auto
&
ld_buffer_vec
=
reinterpret_cast
<
const
vec_type
*>
(
ld_buffer
);
auto
st_buffer_vec
=
reinterpret_cast
<
vec_type
*>
(
reinterpret_cast
<
uint8_t
*>
(
dst_buffer
)
+
lane_id
*
kSendValueBytes
);
constexpr
int
kLoopIter
=
kSendValueBytes
/
sizeof
(
vec_type
);
#pragma unroll
for
(
int
k
=
0
;
k
<
kLoopIter
;
++
k
)
{
st_buffer_vec
[
k
]
=
ld_nc_global
(
ld_buffer_vec
+
k
);
st_buffer
[
0
]
=
(
encoded
[
0
]
>>
0
)
|
(
encoded
[
1
]
<<
9
)
|
(
encoded
[
2
]
<<
18
)
|
(
encoded
[
3
]
<<
27
);
st_buffer
[
1
]
=
(
encoded
[
3
]
>>
5
)
|
(
encoded
[
4
]
<<
4
)
|
(
encoded
[
5
]
<<
13
)
|
(
encoded
[
6
]
<<
22
)
|
(
encoded
[
7
]
<<
31
);
st_buffer
[
2
]
=
(
encoded
[
7
]
>>
1
)
|
(
encoded
[
8
]
<<
8
)
|
(
encoded
[
9
]
<<
17
)
|
(
encoded
[
10
]
<<
26
);
st_buffer
[
3
]
=
(
encoded
[
10
]
>>
6
)
|
(
encoded
[
11
]
<<
3
)
|
(
encoded
[
12
]
<<
12
)
|
(
encoded
[
13
]
<<
21
)
|
(
encoded
[
14
]
<<
30
);
st_buffer
[
4
]
=
(
encoded
[
14
]
>>
2
)
|
(
encoded
[
15
]
<<
7
)
|
(
local_signs
<<
16
);
}
}
// 确保 warp 内的所有线程都完成打包操作
syncwarp
();
// 计算量化成功和失败时的数据量
constexpr
int
unable_cast_num_bytes
=
kWarpSize
*
kSendValueBytes
;
// = 64*2*16 = 2048
constexpr
int
enable_cast_num_bytes
=
unable_cast_num_bytes
*
10
/
16
;
// = 2048/16*10=1280
...
...
@@ -191,8 +172,8 @@ __forceinline__ __device__ void logfmt_check_amaxmin(
for
(
int
i
=
0
;
i
<
kNumQuantGroupsPerWarp
;
++
i
)
{
// sizeof(uint64_t) / sizeof(__hip_bfloat162) = 2
auto
amax
=
static_cast
<
float
>
(
bf162_amaxmin
[
i
].
x
);
auto
amin
=
static_cast
<
float
>
(
bf162_amaxmin
[
i
].
y
);
log_amax
[
i
]
=
__builtin_log
2
f
(
amax
);
log_amin
[
i
]
=
amin
==
0
?
log_amax
[
i
]
-
kMinClip
:
__builtin_fmaxf
(
__builtin_log
2
f
(
amin
),
log_amax
[
i
]
-
kMinClip
);
log_amax
[
i
]
=
__builtin_
amdgcn_
logf
(
amax
);
log_amin
[
i
]
=
amin
==
0
?
log_amax
[
i
]
-
kMinClip
:
__builtin_fmaxf
(
__builtin_
amdgcn_
logf
(
amin
),
log_amax
[
i
]
-
kMinClip
);
enable_cast
=
enable_cast
and
log_amax
[
i
]
<
kLogThreshold
and
log_amin
[
i
]
<
log_amax
[
i
];
}
...
...
@@ -229,7 +210,7 @@ __forceinline__ __device__ void decode_and_accumulate(
const
auto
&
step
=
(
log_amax
-
log_amin
)
/
static_cast
<
float
>
(
kNumValues
-
2
);
auto
decode
=
[
=
](
const
uint32_t
&
encoded
,
const
uint32_t
&
sign
)
{
const
auto
decoded
=
encoded
==
0
?
.0
f
:
__builtin_exp2f
((
encoded
-
1
)
*
step
+
log_amin
);
const
auto
decoded
=
encoded
==
0
?
.0
f
:
__builtin_
amdgcn_
exp2f
((
encoded
-
1
)
*
step
+
log_amin
);
return
sign
?
-
decoded
:
decoded
;
};
...
...
csrc/kernels/utils.cuh
View file @
33bafa16
...
...
@@ -240,6 +240,12 @@ template <typename dtype_t> __device__ __forceinline__ dtype_t ld_nc_global(cons
return
*
reinterpret_cast
<
dtype_t
*>
(
&
ret
);
}
template
<
typename
dtype_t
>
__device__
__forceinline__
dtype_t
ld_direct_global
(
const
dtype_t
*
ptr
)
{
using
T
=
typename
VecInt
<
sizeof
(
dtype_t
)
>::
vec_t
;
auto
ret
=
*
(
reinterpret_cast
<
const
T
*>
(
ptr
));
return
*
reinterpret_cast
<
dtype_t
*>
(
&
ret
);
}
////////////////// used in ibgda
__device__
__forceinline__
void
st_na_relaxed
(
const
uint8_t
*
ptr
,
uint8_t
val
)
{
uint8_t
*
non_const_ptr
=
const_cast
<
uint8_t
*>
(
ptr
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment