Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
33bafa16
Commit
33bafa16
authored
Mar 06, 2026
by
lishen
Browse files
lowlatency combine实现3级流水
parent
61bc0aff
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
215 additions
and
181 deletions
+215
-181
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+189
-142
csrc/kernels/internode_ll_logfmt.cuh
csrc/kernels/internode_ll_logfmt.cuh
+20
-39
csrc/kernels/utils.cuh
csrc/kernels/utils.cuh
+6
-0
No files found.
csrc/kernels/internode_ll.cu
View file @
33bafa16
...
@@ -636,6 +636,7 @@ combine(void* combined_x,
...
@@ -636,6 +636,7 @@ combine(void* combined_x,
const
auto
num_local_experts
=
num_experts
/
num_ranks
;
const
auto
num_local_experts
=
num_experts
/
num_ranks
;
const
auto
warp_group_id
=
warp_id
/
num_warps_per_group
;
const
auto
warp_group_id
=
warp_id
/
num_warps_per_group
;
const
auto
sub_warp_id
=
warp_id
%
num_warps_per_group
;
const
auto
sub_warp_id
=
warp_id
%
num_warps_per_group
;
const
auto
num_warps
=
num_threads
/
kWarpSize
;
const
auto
responsible_expert_idx
=
sm_id
*
num_warp_groups
+
warp_group_id
;
const
auto
responsible_expert_idx
=
sm_id
*
num_warp_groups
+
warp_group_id
;
// Data type staffs
// Data type staffs
...
@@ -656,7 +657,7 @@ combine(void* combined_x,
...
@@ -656,7 +657,7 @@ combine(void* combined_x,
constexpr
int
kNumDivisions
=
kHidden
/
QUANTIZATION_GROUPSIZE
;
constexpr
int
kNumDivisions
=
kHidden
/
QUANTIZATION_GROUPSIZE
;
constexpr
int
kNumMetaBytes
=
kNumDivisions
*
sizeof
(
__hip_bfloat162
);
// 用于记录数据的最大最小值
constexpr
int
kNumMetaBytes
=
kNumDivisions
*
sizeof
(
__hip_bfloat162
);
// 用于记录数据的最大最小值
constexpr
int
kNumSendLogFMTBytes
=
kNumMsgInt4ElemPerWarp
*
sizeof
(
int4
);
constexpr
int
kNumSendLogFMTBytes
=
kNumMsgInt4ElemPerWarp
*
sizeof
(
int4
);
constexpr
int
kNumStages
=
1
;
// 使用kNumStages>1,则需要的LDS大于64KB
constexpr
int
kNumStages
=
3
;
// 使用kNumStages>1,则需要的LDS大于64KB
constexpr
int
kLogFMTShmemSize
=
kMaxNumWarps
*
(
kNumStages
*
kNumSendLogFMTBytes
+
kNumMetaBytes
);
constexpr
int
kLogFMTShmemSize
=
kMaxNumWarps
*
(
kNumStages
*
kNumSendLogFMTBytes
+
kNumMetaBytes
);
__shared__
uint8_t
smem_buffer
[
kLogFMTShmemSize
];
__shared__
uint8_t
smem_buffer
[
kLogFMTShmemSize
];
/////////////////////////////////////////////
/////////////////////////////////////////////
...
@@ -707,6 +708,17 @@ combine(void* combined_x,
...
@@ -707,6 +708,17 @@ combine(void* combined_x,
auto
logfmt_buffers
=
PatternVisitor
([
=
](
const
int
&
i
)
{
return
reinterpret_cast
<
int4
*>
(
smem_ptr
+
i
*
kNumSendLogFMTBytes
);
});
auto
logfmt_buffers
=
PatternVisitor
([
=
](
const
int
&
i
)
{
return
reinterpret_cast
<
int4
*>
(
smem_ptr
+
i
*
kNumSendLogFMTBytes
);
});
// 存储logfmt的最大最小值
// 存储logfmt的最大最小值
auto
meta_buffers
=
bSupportLogFMT
?
reinterpret_cast
<
__hip_bfloat162
*>
(
smem_ptr
+
kNumStages
*
kNumSendLogFMTBytes
)
:
nullptr
;
auto
meta_buffers
=
bSupportLogFMT
?
reinterpret_cast
<
__hip_bfloat162
*>
(
smem_ptr
+
kNumStages
*
kNumSendLogFMTBytes
)
:
nullptr
;
// 用于多buffer时临时存储
auto
get_num_logfmt_bytes
=
[
&
](
const
int
&
offset_int4
)
{
return
min
(
kNumSendLogFMTBytes
,
static_cast
<
int
>
((
hidden_bf16_int4
-
offset_int4
)
*
sizeof
(
int4
)));
};
// 简化从global到LDS的存储写法
auto
logfmt_load_global2lds
=
[
&
](
const
int
&
stage_idx
,
const
int4
*
gmem_ptr
,
const
int
&
num_bytes
)
{
UNROLLED_WARP_COPY_LL
(
1
,
lane_id
,
num_bytes
/
sizeof
(
int4
),
reinterpret_cast
<
int4
*>
(
logfmt_buffers
[
stage_idx
]),
reinterpret_cast
<
const
int4
*>
(
gmem_ptr
),
ld_direct_global
,
st_na_global
);
};
// Unpack layout
// Unpack layout
int
offset
,
num_tokens_to_send
;
int
offset
,
num_tokens_to_send
;
...
@@ -722,72 +734,109 @@ combine(void* combined_x,
...
@@ -722,72 +734,109 @@ combine(void* combined_x,
const
auto
src_idx
=
__ldg
(
local_src_info
+
token_idx
);
const
auto
src_idx
=
__ldg
(
local_src_info
+
token_idx
);
const
auto
buf_ptr
=
reinterpret_cast
<
int64_t
>
(
rdma_send_x_vec_row
);
const
auto
buf_ptr
=
reinterpret_cast
<
int64_t
>
(
rdma_send_x_vec_row
);
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_recv_x
)
+
(
global_expert_idx
*
num_max_dispatch_tokens_per_rank
+
src_idx
)
*
num_bytes_per_slot
;
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_recv_x
)
+
(
global_expert_idx
*
num_max_dispatch_tokens_per_rank
+
src_idx
)
*
num_bytes_per_slot
;
// 采用logfmt或者直接拷贝
// 采用logfmt或者直接拷贝
uint64_t
dst_p2p_ptr
=
internode
::
shmem_get_p2p_ptr
((
void
*
)
dst_ptr
,
rank
,
dst_rank
);
uint64_t
dst_p2p_ptr
=
internode
::
shmem_get_p2p_ptr
((
void
*
)
dst_ptr
,
rank
,
dst_rank
);
int
num_send_bytes
=
hidden
*
sizeof
(
hip_bfloat16
);
int
num_send_bytes
=
hidden
*
sizeof
(
hip_bfloat16
);
if
(
not
zero_copy
or
dst_p2p_ptr
!=
0
)
{
if
(
not
zero_copy
or
dst_p2p_ptr
!=
0
)
{
// Read from `cpy_src_int4_ptr` and copy into `cpy_dst_int4_ptr`
const
auto
cpy_src_int4_ptr
=
zero_copy
?
reinterpret_cast
<
int4
*>
(
buf_ptr
)
:
x_int4
;
const
auto
cpy_src_int4_ptr
=
zero_copy
?
reinterpret_cast
<
int4
*>
(
buf_ptr
)
:
x_int4
;
const
auto
cpy_dst_int4_ptr
=
dst_p2p_ptr
==
0
?
reinterpret_cast
<
int4
*>
(
buf_ptr
)
:
reinterpret_cast
<
int4
*>
(
dst_p2p_ptr
);
const
auto
cpy_dst_int4_ptr
=
dst_p2p_ptr
==
0
?
reinterpret_cast
<
int4
*>
(
buf_ptr
)
:
reinterpret_cast
<
int4
*>
(
dst_p2p_ptr
);
// 设置数据的真实偏移量
int
logfmt_offset_bytes
=
kNumMetaBytes
;
// 进入循环,逐步拷贝数据
constexpr
int
encode_num_warps
=
hidden_bf16_int4
/
kNumMsgInt4ElemPerWarp
;
for
(
int
iter_idx
=
0
;
iter_idx
<
encode_num_warps
;
++
iter_idx
)
{
int
num_logfmt_bytes
=
kNumMsgInt4ElemPerWarp
*
sizeof
(
int4
);
// 原始数据的warp级编译
int
warp_offset
=
iter_idx
*
kNumMsgInt4ElemPerWarp
;
if
constexpr
(
bSupportLogFMT
)
{
// 采用 寄存器->lds->global 的流水线方式, 量化后拷贝到buf_ptr中
const
int
&
stage_idx
=
iter_idx
%
kNumStages
;
// thread偏移
int
thread_offset
=
warp_offset
+
lane_id
*
kNumSendUnrolls
;
constexpr
int
kNumInt4PerDivision
=
128
/
kNumElemsPerInt4
;
// = 128/(sizeof(int4) / sizeof(hip_bfloat16)) = 128/(16/2)=16
num_logfmt_bytes
=
logfmt_encode
<
kNumSendUnrolls
>
(
cpy_src_int4_ptr
+
warp_offset
,
// 等同于 x_int4
logfmt_buffers
[
stage_idx
],
// NOTES: only the leader lane will write the result
(
thread_offset
%
kNumInt4PerDivision
==
0
)
?
meta_buffers
+
thread_offset
/
kNumInt4PerDivision
:
nullptr
,
lane_id
);
// 将量化后的数据写入
constexpr
int
kNumIters
=
hidden_bf16_int4
/
kNumMsgInt4ElemPerWarp
;
using
vec_type
=
uint32_t
;
EP_STATIC_ASSERT
(
kNumIters
>=
1
,
"hidden length too small"
);
UNROLLED_WARP_COPY_LL
(
2
,
lane_id
,
num_logfmt_bytes
/
sizeof
(
vec_type
),
reinterpret_cast
<
vec_type
*>
(
reinterpret_cast
<
uint8_t
*>
(
cpy_dst_int4_ptr
)
+
logfmt_offset_bytes
),
reinterpret_cast
<
vec_type
*>
(
logfmt_buffers
[
stage_idx
]),
ld_nc_global
,
st_na_global
);
// 起始地址偏移
if
constexpr
(
bSupportLogFMT
)
{
logfmt_offset_bytes
+=
num_logfmt_bytes
;
// ===== LogFMT 路径:使用 LDS + encode + 多级流水 =====
}
else
{
int
logfmt_offset_bytes
=
kNumMetaBytes
;
// 非量化数据的传输
// meta_buffers 存储的thread间隔
UNROLLED_WARP_COPY_LL
(
2
,
lane_id
,
kNumMsgInt4ElemPerWarp
,
constexpr
int
kNumInt4PerDivision
=
128
/
kNumElemsPerInt4
;
reinterpret_cast
<
int4
*>
(
cpy_dst_int4_ptr
+
warp_offset
),
// 记录S1~S3的编码字节数
reinterpret_cast
<
const
int4
*>
(
cpy_src_int4_ptr
+
warp_offset
),
int
encoded_bytes
[
kNumStages
];
ld_nc_global
,
st_na_global
);
// Prefetch: iter0执行S1
logfmt_load_global2lds
(
0
,
cpy_src_int4_ptr
,
get_num_logfmt_bytes
(
0
));
syncwarp
();
// Prefetch: iter0执行S2, iter1执行S1
if
(
kNumStages
>
2
&&
kNumIters
>
1
)
{
int
warp_offset
=
/*1 * */
kNumMsgInt4ElemPerWarp
;
logfmt_load_global2lds
(
1
,
cpy_src_int4_ptr
+
warp_offset
,
get_num_logfmt_bytes
(
warp_offset
));
int
thread_offset
=
/*0 + */
lane_id
*
kNumSendUnrolls
;
int
num_bytes
=
logfmt_encode
<
kNumSendUnrolls
>
(
logfmt_buffers
[
0
],
(
thread_offset
%
kNumInt4PerDivision
==
0
)
?
meta_buffers
+
thread_offset
/
kNumInt4PerDivision
:
nullptr
,
lane_id
);
encoded_bytes
[
0
]
=
num_bytes
;
}
}
syncwarp
();
syncwarp
();
}
// Store metadata (min/max values) for LogFMT
// 采用3级流水
if
constexpr
(
bSupportLogFMT
)
{
for
(
int
iter_idx
=
0
;
iter_idx
<
kNumIters
;
++
iter_idx
)
{
// 最终设置节点间传输的字节数
// 流水线S1: 加载第 (kNumStages-1) 轮之后的数据
const
int
stage_last_iter
=
iter_idx
+
kNumStages
-
1
;
// 当前iter所在stage中的最后一个,初始为S3的读取数据
if
(
stage_last_iter
<
kNumIters
)
{
int
stage_idx
=
stage_last_iter
%
kNumStages
;
int
warp_offset
=
stage_last_iter
*
kNumMsgInt4ElemPerWarp
;
logfmt_load_global2lds
(
stage_idx
,
cpy_src_int4_ptr
+
warp_offset
,
get_num_logfmt_bytes
(
warp_offset
));
}
// 流水线S2: 处理下一轮的数据量化
const
int
stage_next_iter
=
iter_idx
+
1
;
if
(
stage_next_iter
<
kNumIters
)
{
int
stage_idx
=
stage_next_iter
%
kNumStages
;
int
warp_offset
=
stage_next_iter
*
kNumMsgInt4ElemPerWarp
;
int
thread_offset
=
warp_offset
+
lane_id
*
kNumSendUnrolls
;
int
num_bytes
=
logfmt_encode
<
kNumSendUnrolls
>
(
logfmt_buffers
[
stage_idx
],
(
thread_offset
%
kNumInt4PerDivision
==
0
)
?
meta_buffers
+
thread_offset
/
kNumInt4PerDivision
:
nullptr
,
lane_id
);
encoded_bytes
[
stage_idx
]
=
num_bytes
;
}
// 流水线S3:当前轮进行数据拷贝到通信显存
if
(
iter_idx
<
kNumIters
)
{
int
stage_idx
=
iter_idx
%
kNumStages
;
using
vec_type
=
uint64_t
;
int
nvecs
=
encoded_bytes
[
stage_idx
]
/
sizeof
(
vec_type
);
if
(
nvecs
>
0
)
{
UNROLLED_WARP_COPY_LL
(
1
,
lane_id
,
nvecs
,
reinterpret_cast
<
vec_type
*>
(
reinterpret_cast
<
uint8_t
*>
(
cpy_dst_int4_ptr
)
+
logfmt_offset_bytes
),
reinterpret_cast
<
vec_type
*>
(
logfmt_buffers
[
stage_idx
]),
ld_direct_global
,
st_na_global
);
}
logfmt_offset_bytes
+=
encoded_bytes
[
stage_idx
];
}
syncwarp
();
}
num_send_bytes
=
logfmt_offset_bytes
;
num_send_bytes
=
logfmt_offset_bytes
;
using
vec_type
=
uint32_t
;
// Store metadata
auto
meta_buffers_ptr
=
reinterpret_cast
<
vec_type
*>
(
meta_buffers
);
using
meta_vec_type
=
uint32_t
;
auto
cpy_dst_uint32_ptr
=
reinterpret_cast
<
vec_type
*>
(
cpy_dst_int4_ptr
);
UNROLLED_WARP_COPY_LL
(
1
,
lane_id
,
kNumMetaBytes
/
sizeof
(
meta_vec_type
),
reinterpret_cast
<
meta_vec_type
*>
(
cpy_dst_int4_ptr
),
reinterpret_cast
<
meta_vec_type
*>
(
meta_buffers
),
ld_direct_global
,
st_na_global
);
for
(
int
j
=
lane_id
;
j
<
kNumMetaBytes
/
sizeof
(
vec_type
);
j
+=
kWarpSize
)
{
}
else
{
*
(
cpy_dst_uint32_ptr
+
j
)
=
meta_buffers_ptr
[
j
];
// ===== 非 LogFMT 路径:直接 global -> global,不经过 LDS =====
for
(
int
iter_idx
=
0
;
iter_idx
<
kNumIters
;
++
iter_idx
)
{
int
warp_offset
=
iter_idx
*
kNumMsgInt4ElemPerWarp
;
UNROLLED_WARP_COPY_LL
(
kNumSendUnrolls
,
lane_id
,
kNumMsgInt4ElemPerWarp
,
cpy_dst_int4_ptr
+
warp_offset
,
cpy_src_int4_ptr
+
warp_offset
,
ld_direct_global
,
st_na_global
);
syncwarp
();
}
}
// 非 LogFMT 时,发送字节数为原始大小
num_send_bytes
=
hidden_bf16_int4
*
sizeof
(
int4
);
// 或根据实际计算
}
}
syncwarp
();
syncwarp
();
}
}
...
@@ -858,10 +907,6 @@ LOW_LATENCY_COMBINE_RECV:
...
@@ -858,10 +907,6 @@ LOW_LATENCY_COMBINE_RECV:
// 计算需要多少个warp
// 计算需要多少个warp
constexpr
int
num_decode_warps
=
hidden_bf16_int4
/
(
kNumRecvUnrolls
*
kWarpSize
);
constexpr
int
num_decode_warps
=
hidden_bf16_int4
/
(
kNumRecvUnrolls
*
kWarpSize
);
// 限制thread_id
if
(
warp_id
>=
num_decode_warps
)
{
return
;
}
// 每128个数据记录一个max/min值,即该数为总的max/min值数量
// 每128个数据记录一个max/min值,即该数为总的max/min值数量
constexpr
int
kNumDivisionBytes
=
kNumDivisions
*
sizeof
(
float
);
constexpr
int
kNumDivisionBytes
=
kNumDivisions
*
sizeof
(
float
);
...
@@ -889,102 +934,104 @@ LOW_LATENCY_COMBINE_RECV:
...
@@ -889,102 +934,104 @@ LOW_LATENCY_COMBINE_RECV:
topk_weights_by_lane
=
__ldg
(
topk_weights
+
token_idx
*
num_topk
+
lane_id
);
topk_weights_by_lane
=
__ldg
(
topk_weights
+
token_idx
*
num_topk
+
lane_id
);
}
}
float
combined_values
[
kNumElemsPerInt4
*
kNumRecvUnrolls
]
=
{
0.0
f
};
for
(
int
w_i
=
warp_id
;
w_i
<
num_decode_warps
;
w_i
+=
num_warps
)
{
#pragma unroll
float
combined_values
[
kNumElemsPerInt4
*
kNumRecvUnrolls
]
=
{
0.0
f
};
for
(
int
i
=
0
;
i
<
num_topk
;
++
i
)
{
#pragma unroll
int
topk_idx_reg
=
shfl_sync
(
topk_idx_by_lane
,
i
);
for
(
int
i
=
0
;
i
<
num_topk
;
++
i
)
{
if
(
topk_idx_reg
<
0
)
int
topk_idx_reg
=
shfl_sync
(
topk_idx_by_lane
,
i
);
continue
;
if
(
topk_idx_reg
<
0
)
const
auto
&
topk_weight_reg
=
shfl_sync
(
topk_weights_by_lane
,
i
);
continue
;
const
auto
&
topk_weight_reg
=
shfl_sync
(
topk_weights_by_lane
,
i
);
// Read from sources
auto
rdma_buffer_type
=
reinterpret_cast
<
const
uint8_t
*>
(
reinterpret_cast
<
uint8_t
*>
(
rdma_recv_x
)
+
// Read from sources
(
topk_idx_reg
*
num_max_dispatch_tokens_per_rank
+
token_idx
)
*
num_bytes_per_slot
);
auto
rdma_buffer_type
=
reinterpret_cast
<
const
uint8_t
*>
(
reinterpret_cast
<
uint8_t
*>
(
rdma_recv_x
)
+
(
topk_idx_reg
*
num_max_dispatch_tokens_per_rank
+
token_idx
)
*
num_bytes_per_slot
);
if
constexpr
(
bSupportLogFMT
)
{
// 接收到的数据位置
if
constexpr
(
bSupportLogFMT
)
{
const
uint8_t
*
data_buffer
=
rdma_buffer_type
+
kNumMetaBytes
;
// 接收到的数据位置
const
uint8_t
*
data_buffer
=
rdma_buffer_type
+
kNumMetaBytes
;
// 读取max/min数据
if
(
warp_id
==
0
)
{
// 读取max/min数据
// 因为每个warp能处理数据量为 kWarpSize*sizeof(int4)/sizeof(bfloat16) * kNumSendUnrolls
if
(
w_i
==
0
)
{
// 即不考虑kNumSendUnrolls,一共 kWarpSize*sizeof(int4)/sizeof(bfloat16)/128 组, 代入参数 = kWarpSize / 16 个warp,nv上为2,dcu上为4
// 因为每个warp能处理数据量为 kWarpSize*sizeof(int4)/sizeof(bfloat16) * kNumSendUnrolls
logfmt_check_amaxmin
<
kNumDivisions
/
(
kWarpSize
/
16
),
kNumSendUnrolls
,
kNumRecvUnrolls
>
(
// 即不考虑kNumSendUnrolls,一共 kWarpSize*sizeof(int4)/sizeof(bfloat16)/128 组, 代入参数 = kWarpSize / 16 个warp,nv上为2,dcu上为4
/*meta_buffer*/
rdma_buffer_type
,
logfmt_check_amaxmin
<
kNumDivisions
/
(
kWarpSize
/
16
),
kNumSendUnrolls
,
kNumRecvUnrolls
>
(
reinterpret_cast
<
int4
*>
(
log_amax_buffers
[
stage_idx
]),
/*meta_buffer*/
rdma_buffer_type
,
reinterpret_cast
<
int4
*>
(
log_amin_buffers
[
stage_idx
]),
reinterpret_cast
<
int4
*>
(
log_amax_buffers
[
stage_idx
]),
cast_info_buffers
[
stage_idx
],
reinterpret_cast
<
int4
*>
(
log_amin_buffers
[
stage_idx
]),
lane_id
);
cast_info_buffers
[
stage_idx
],
}
lane_id
);
}
__syncthreads
();
// 获取cast_info_buffers
__syncthreads
();
const
auto
&
info
=
cast_info_buffers
[
stage_idx
][
warp_id
];
bool
enable_cast
=
info
&
1
;
int
num_casted_prefix
=
info
>>
1
;
// 可用的
// 计算偏移(与TMA版本逻辑一致)
int
warp_offset
=
kNumLogFMTPerWarpBytes
*
num_casted_prefix
+
kNumBF16PerWarpBytes
*
(
warp_id
-
num_casted_prefix
);
int
lane_offset
=
(
enable_cast
?
kNumLogFMTPerWarpBytes
:
kNumBF16PerWarpBytes
)
/
kWarpSize
*
lane_id
;
// 使用临时缓冲区进行归约
const
uint8_t
*
thread_data_ptr
=
data_buffer
+
warp_offset
+
lane_offset
;
/**
一共有kNumDivisions个max/min数据对,读取时每warp默认处理256bit的max/min,所以logfmt_check_amaxmin的kNumLanes设置为 kNumDivisions/2
保存数据时每个log_amax_buffers为float2数据类型,保存总的warpkNumDivisions / 2
实际保存数据时,每个warp保存的实际数据个数为 kWarpSize*kNumRecvUnrolls*sizeof(int4)/sizeof(hip_bfloat16)
实际每个warp读取的max/min的 warp_idx=kWarpSize*kNumRecvUnrolls*sizeof(int4)/sizeof(hip_bfloat16) / 128 = kNumRecvUnrolls * 2
具体的lane_id处理的数据量为 warp_idx / kWarpSize
*/
int
log_amaxmin_per_warp
=
kNumRecvUnrolls
*
kWarpSize
*
sizeof
(
int4
)
/
sizeof
(
hip_bfloat16
)
/
QUANTIZATION_GROUPSIZE
;
int
division_idx
=
warp_id
*
log_amaxmin_per_warp
+
lane_id
*
log_amaxmin_per_warp
/
kWarpSize
;
// 反量化
decode_and_accumulate
<
kNumRecvUnrolls
>
(
reinterpret_cast
<
const
uint32_t
*>
(
thread_data_ptr
),
// 直接使用全局内存地址
combined_values
,
log_amax_buffers
[
stage_idx
][
division_idx
],
log_amin_buffers
[
stage_idx
][
division_idx
],
enable_cast
,
topk_weight_reg
);
}
else
{
// 接收到的数据位置
const
uint8_t
*
data_buffer
=
rdma_buffer_type
;
// 计算偏移
// 获取cast_info_buffers
int
warp_offset
=
kNumBF16PerWarpBytes
*
warp_id
;
const
auto
&
info
=
cast_info_buffers
[
stage_idx
][
w_i
];
int
lane_offset
=
kNumBF16PerWarpBytes
/
kWarpSize
*
lane_id
;
bool
enable_cast
=
info
&
1
;
// 使用临时缓冲区进行归约
int
num_casted_prefix
=
info
>>
1
;
// 可用的
const
uint8_t
*
thread_data_ptr
=
data_buffer
+
warp_offset
+
lane_offset
;
// 计算偏移(与TMA版本逻辑一致)
int
warp_offset
=
kNumLogFMTPerWarpBytes
*
num_casted_prefix
+
kNumBF16PerWarpBytes
*
(
w_i
-
num_casted_prefix
);
int
lane_offset
=
(
enable_cast
?
kNumLogFMTPerWarpBytes
:
kNumBF16PerWarpBytes
)
/
kWarpSize
*
lane_id
;
// 使用临时缓冲区进行归约
const
uint8_t
*
thread_data_ptr
=
data_buffer
+
warp_offset
+
lane_offset
;
/**
一共有kNumDivisions个max/min数据对,读取时每warp默认处理256bit的max/min,所以logfmt_check_amaxmin的kNumLanes设置为 kNumDivisions/2
保存数据时每个log_amax_buffers为float2数据类型,保存总的warpkNumDivisions / 2
实际保存数据时,每个warp保存的实际数据个数为 kWarpSize*kNumRecvUnrolls*sizeof(int4)/sizeof(hip_bfloat16)
实际每个warp读取的max/min的 warp_idx=kWarpSize*kNumRecvUnrolls*sizeof(int4)/sizeof(hip_bfloat16) / 128 = kNumRecvUnrolls * 2
具体的lane_id处理的数据量为 warp_idx / kWarpSize
*/
int
log_amaxmin_per_warp
=
kNumRecvUnrolls
*
kWarpSize
*
sizeof
(
int4
)
/
sizeof
(
hip_bfloat16
)
/
QUANTIZATION_GROUPSIZE
;
int
division_idx
=
w_i
*
log_amaxmin_per_warp
+
lane_id
*
log_amaxmin_per_warp
/
kWarpSize
;
// 反量化
decode_and_accumulate
<
kNumRecvUnrolls
>
(
reinterpret_cast
<
const
uint32_t
*>
(
thread_data_ptr
),
// 直接使用全局内存地址
combined_values
,
log_amax_buffers
[
stage_idx
][
division_idx
],
log_amin_buffers
[
stage_idx
][
division_idx
],
enable_cast
,
topk_weight_reg
);
}
else
{
// 接收到的数据位置
const
uint8_t
*
data_buffer
=
rdma_buffer_type
;
#pragma unroll
// 计算偏移
for
(
int
j
=
0
;
j
<
kNumRecvUnrolls
;
++
j
)
{
int
warp_offset
=
kNumBF16PerWarpBytes
*
w_i
;
auto
tmp_rdma_value
=
ld_nc_global
(
reinterpret_cast
<
const
int4
*>
(
thread_data_ptr
)
+
j
);
int
lane_offset
=
kNumBF16PerWarpBytes
/
kWarpSize
*
lane_id
;
const
auto
x_bf16
=
reinterpret_cast
<
const
hip_bfloat16
*>
(
&
tmp_rdma_value
);
// 使用临时缓冲区进行归约
const
uint8_t
*
thread_data_ptr
=
data_buffer
+
warp_offset
+
lane_offset
;
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
kNumElemsPerInt4
;
++
k
)
{
for
(
int
j
=
0
;
j
<
kNumRecvUnrolls
;
++
j
)
{
int
combined_idx
=
j
*
kNumElemsPerInt4
+
k
;
auto
tmp_rdma_value
=
ld_nc_global
(
reinterpret_cast
<
const
int4
*>
(
thread_data_ptr
)
+
j
);
combined_values
[
combined_idx
]
+=
static_cast
<
float
>
(
x_bf16
[
k
])
*
topk_weight_reg
;
const
auto
x_bf16
=
reinterpret_cast
<
const
hip_bfloat16
*>
(
&
tmp_rdma_value
);
#pragma unroll
for
(
int
k
=
0
;
k
<
kNumElemsPerInt4
;
++
k
)
{
int
combined_idx
=
j
*
kNumElemsPerInt4
+
k
;
combined_values
[
combined_idx
]
+=
static_cast
<
float
>
(
x_bf16
[
k
])
*
topk_weight_reg
;
}
}
}
}
}
}
}
}
// Write results,kNumRecvUnrolls==2时则写256bit的数
// Write results,kNumRecvUnrolls==2时则写256bit的数
int4
combined_int4
[
kNumRecvUnrolls
];
int4
combined_int4
[
kNumRecvUnrolls
];
auto
combined_bf16
=
reinterpret_cast
<
hip_bfloat16
*>
(
&
combined_int4
[
0
]);
auto
combined_bf16
=
reinterpret_cast
<
hip_bfloat16
*>
(
&
combined_int4
[
0
]);
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
kNumElemsPerInt4
*
kNumRecvUnrolls
;
++
j
)
{
for
(
int
j
=
0
;
j
<
kNumElemsPerInt4
*
kNumRecvUnrolls
;
++
j
)
{
combined_bf16
[
j
]
=
static_cast
<
hip_bfloat16
>
(
combined_values
[
j
]);
combined_bf16
[
j
]
=
static_cast
<
hip_bfloat16
>
(
combined_values
[
j
]);
}
}
for
(
int
j
=
0
;
j
<
kNumRecvUnrolls
;
++
j
)
{
for
(
int
j
=
0
;
j
<
kNumRecvUnrolls
;
++
j
)
{
(
reinterpret_cast
<
int4
*>
(
combined_x
)
+
token_idx
*
hidden_bf16_int4
+
(
reinterpret_cast
<
int4
*>
(
combined_x
)
+
token_idx
*
hidden_bf16_int4
+
warp_id
*
kWarpSize
*
kNumRecvUnrolls
)[
lane_id
*
kNumRecvUnrolls
+
j
]
=
combined_int4
[
j
];
w_i
*
kWarpSize
*
kNumRecvUnrolls
)[
lane_id
*
kNumRecvUnrolls
+
j
]
=
combined_int4
[
j
];
}
}
}
}
}
}
}
...
@@ -1001,7 +1048,7 @@ void combine(void* combined_x,
...
@@ -1001,7 +1048,7 @@ void combine(void* combined_x,
bool
use_logfmt
,
bool
use_logfmt
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
void
*
workspace
,
int
num_device_sms
,
hipStream_t
stream
,
int
phases
,
bool
zero_copy
)
{
int
phases
,
bool
zero_copy
)
{
constexpr
int
kMaxNumWarps
=
16
;
constexpr
int
kMaxNumWarps
=
8
;
constexpr
int
kNumMaxTopk
=
11
;
constexpr
int
kNumMaxTopk
=
11
;
const
int
num_warp_groups
=
ceil_div
(
num_experts
,
num_device_sms
);
const
int
num_warp_groups
=
ceil_div
(
num_experts
,
num_device_sms
);
const
int
num_warps_per_group
=
kMaxNumWarps
/
num_warp_groups
;
// num_warps_per_group>1, "Requires more than one warp per group"
const
int
num_warps_per_group
=
kMaxNumWarps
/
num_warp_groups
;
// num_warps_per_group>1, "Requires more than one warp per group"
...
...
csrc/kernels/internode_ll_logfmt.cuh
View file @
33bafa16
...
@@ -17,7 +17,7 @@ namespace internode_ll {
...
@@ -17,7 +17,7 @@ namespace internode_ll {
template
<
int
kNumSendUnrolls
>
template
<
int
kNumSendUnrolls
>
__forceinline__
__device__
int
logfmt_encode
(
const
int4
*
cpy_src_int4_ptr
,
int4
*
ds
t
_buffer
,
__hip_bfloat162
*
shared_amaxmin
,
const
int
&
lane_id
)
{
__forceinline__
__device__
int
logfmt_encode
(
int4
*
l
ds_buffer
,
__hip_bfloat162
*
shared_amaxmin
,
const
int
&
lane_id
)
{
EP_STATIC_ASSERT
(
kNumSendUnrolls
==
2
,
"kNumSendUnrolls == 2 only"
);
EP_STATIC_ASSERT
(
kNumSendUnrolls
==
2
,
"kNumSendUnrolls == 2 only"
);
constexpr
int
kNumElemsPerInt4
=
sizeof
(
int4
)
/
sizeof
(
__hip_bfloat16
);
// 8
constexpr
int
kNumElemsPerInt4
=
sizeof
(
int4
)
/
sizeof
(
__hip_bfloat16
);
// 8
...
@@ -33,7 +33,8 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
...
@@ -33,7 +33,8 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
const
auto
&
bf162_values
=
reinterpret_cast
<
__hip_bfloat162
*>
(
int4_values
);
const
auto
&
bf162_values
=
reinterpret_cast
<
__hip_bfloat162
*>
(
int4_values
);
// Calculate lane offset
// Calculate lane offset
const
auto
&
ld_buffer
=
cpy_src_int4_ptr
+
lane_id
*
kNumSendUnrolls
;
const
auto
&
ld_buffer
=
reinterpret_cast
<
int4
*>
(
reinterpret_cast
<
uint8_t
*>
(
lds_buffer
)
+
lane_id
*
kSendValueBytes
);
const
auto
&
st_buffer
=
reinterpret_cast
<
uint32_t
*>
(
reinterpret_cast
<
uint8_t
*>
(
lds_buffer
)
+
lane_id
*
kSendValueBytes
*
10
/
16
);
// Local log amax
// Local log amax
auto
bf162_amax
=
__hip_bfloat162
(
HIPRT_ZERO_BF16
,
HIPRT_ZERO_BF16
);
auto
bf162_amax
=
__hip_bfloat162
(
HIPRT_ZERO_BF16
,
HIPRT_ZERO_BF16
);
...
@@ -68,6 +69,8 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
...
@@ -68,6 +69,8 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
// Reduce per 128 channels
// Reduce per 128 channels
// TODO: figure out how hardware do 2-byte min/max
// TODO: figure out how hardware do 2-byte min/max
const
auto
&
fp162_max
=
__bfloat1622float2
(
bf162_amax
);
auto
amax
=
__builtin_fmaxf
(
static_cast
<
float
>
(
bf162_amax
.
x
),
static_cast
<
float
>
(
bf162_amax
.
y
));
auto
amax
=
__builtin_fmaxf
(
static_cast
<
float
>
(
bf162_amax
.
x
),
static_cast
<
float
>
(
bf162_amax
.
y
));
auto
amin
=
__builtin_fminf
(
static_cast
<
float
>
(
bf162_amin
.
x
),
static_cast
<
float
>
(
bf162_amin
.
y
));
auto
amin
=
__builtin_fminf
(
static_cast
<
float
>
(
bf162_amin
.
x
),
static_cast
<
float
>
(
bf162_amin
.
y
));
...
@@ -80,26 +83,22 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
...
@@ -80,26 +83,22 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
if
(
shared_amaxmin
!=
nullptr
)
{
if
(
shared_amaxmin
!=
nullptr
)
{
*
shared_amaxmin
=
__hip_bfloat162
(
amax
,
amin
);
*
shared_amaxmin
=
__hip_bfloat162
(
amax
,
amin
);
}
}
syncwarp
();
// Calculate log amin/amax float
// Calculate log amin/amax float
const
auto
&
log_amax
=
__builtin_log2f
(
amax
);
const
auto
&
log_amax
=
__builtin_amdgcn_logf
(
amax
);
const
auto
&
log_amin
=
__builtin_fmaxf
(
__builtin_log2f
(
amin
),
log_amax
-
kMinClip
);
const
auto
&
log_amin
=
__builtin_fmaxf
(
__builtin_amdgcn_logf
(
amin
),
log_amax
-
kMinClip
);
// 在组内广播enable_cast结果
// 在组内广播enable_cast结果
const
bool
&
enable_cast
=
warp_reduce_and
<
kNumLanesToReduce
,
true
>
(
log_amax
<
kLogThreshold
and
log_amin
<
log_amax
);
const
bool
&
enable_cast
=
warp_reduce_and
<
kNumLanesToReduce
,
true
>
(
log_amax
<
kLogThreshold
and
log_amin
<
log_amax
);
// Case into LogFMT-10 if satisfied
// Case into LogFMT-10 if satisfied
if
(
enable_cast
)
{
if
(
enable_cast
)
{
constexpr
int
dst_buffer_step
=
kSendValueBytes
*
10
/
16
;
const
auto
&
st_buffer
=
reinterpret_cast
<
uint32_t
*>
(
reinterpret_cast
<
uint8_t
*>
(
dst_buffer
)
+
lane_id
*
dst_buffer_step
);
uint32_t
st_u32_values
[
dst_buffer_step
/
sizeof
(
uint32_t
)];
// = 5
// 计算10bit数据的两个相邻数值的差值
// 计算10bit数据的两个相邻数值的差值
const
auto
step
=
(
log_amax
-
log_amin
)
/
static_cast
<
float
>
(
kNumValues
-
2
);
const
auto
step
=
(
log_amax
-
log_amin
)
/
static_cast
<
float
>
(
kNumValues
-
2
);
const
auto
step_inv
=
1.0
f
/
step
;
const
auto
step_inv
=
1.0
f
/
step
;
// 计算舍入值
// 计算舍入值
const
auto
rounding
=
2.0
f
-
__builtin_log
2
f
((
1.0
f
+
__builtin_exp2f
(
step
))
*
0.5
f
)
*
step_inv
;
const
auto
rounding
=
2.0
f
-
__builtin_
amdgcn_
logf
((
1.0
f
+
__builtin_
amdgcn_
exp2f
(
step
))
*
0.5
f
)
*
step_inv
;
const
auto
fused_rounding
=
rounding
-
log_amin
*
step_inv
;
const
auto
fused_rounding
=
rounding
-
log_amin
*
step_inv
;
// 用于存储编码后的值
// 用于存储编码后的值
...
@@ -111,7 +110,7 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
...
@@ -111,7 +110,7 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
kNumElemsPerInt4
;
++
k
)
{
// 8
for
(
int
k
=
0
;
k
<
kNumElemsPerInt4
;
++
k
)
{
// 8
// 将 bfloat162 转换为 float2
// 将 bfloat162 转换为 float2
const
auto
&
fp
16
2_fvalue
=
__bfloat1622float2
(
bf162_values
[
k
]);
const
auto
&
fp
32
2_fvalue
=
__bfloat1622float2
(
bf162_values
[
k
]);
/*
/*
实际进行压缩的公式为:
实际进行压缩的公式为:
...
@@ -124,37 +123,19 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
...
@@ -124,37 +123,19 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
K: 压缩后的整数的最大值(即,K为2的幂)
K: 压缩后的整数的最大值(即,K为2的幂)
*/
*/
// 对 float 值进行编码
// 对 float 值进行编码
encoded
[
k
*
2
+
0
]
=
__float2uint_rd
(
__builtin_fmaxf
(
__builtin_log
2
f
(
fp
16
2_fvalue
.
x
)
*
step_inv
+
fused_rounding
,
0
));
encoded
[
k
*
2
+
0
]
=
__float2uint_rd
(
__builtin_fmaxf
(
__builtin_
amdgcn_
logf
(
fp
32
2_fvalue
.
x
)
*
step_inv
+
fused_rounding
,
0
));
encoded
[
k
*
2
+
1
]
=
__float2uint_rd
(
__builtin_fmaxf
(
__builtin_log
2
f
(
fp
16
2_fvalue
.
y
)
*
step_inv
+
fused_rounding
,
0
));
encoded
[
k
*
2
+
1
]
=
__float2uint_rd
(
__builtin_fmaxf
(
__builtin_
amdgcn_
logf
(
fp
32
2_fvalue
.
y
)
*
step_inv
+
fused_rounding
,
0
));
}
}
// 批量打包编码后的值到 st_buffer
// 批量打包编码后的值到 st_buffer
st_u32_values
[
0
]
=
(
encoded
[
0
]
>>
0
)
|
(
encoded
[
1
]
<<
9
)
|
(
encoded
[
2
]
<<
18
)
|
(
encoded
[
3
]
<<
27
);
st_buffer
[
0
]
=
(
encoded
[
0
]
>>
0
)
|
(
encoded
[
1
]
<<
9
)
|
(
encoded
[
2
]
<<
18
)
|
(
encoded
[
3
]
<<
27
);
st_u32_values
[
1
]
=
(
encoded
[
3
]
>>
5
)
|
(
encoded
[
4
]
<<
4
)
|
(
encoded
[
5
]
<<
13
)
|
(
encoded
[
6
]
<<
22
)
|
(
encoded
[
7
]
<<
31
);
st_buffer
[
1
]
=
(
encoded
[
3
]
>>
5
)
|
(
encoded
[
4
]
<<
4
)
|
(
encoded
[
5
]
<<
13
)
|
(
encoded
[
6
]
<<
22
)
|
(
encoded
[
7
]
<<
31
);
st_u32_values
[
2
]
=
(
encoded
[
7
]
>>
1
)
|
(
encoded
[
8
]
<<
8
)
|
(
encoded
[
9
]
<<
17
)
|
(
encoded
[
10
]
<<
26
);
st_buffer
[
2
]
=
(
encoded
[
7
]
>>
1
)
|
(
encoded
[
8
]
<<
8
)
|
(
encoded
[
9
]
<<
17
)
|
(
encoded
[
10
]
<<
26
);
st_u32_values
[
3
]
=
(
encoded
[
10
]
>>
6
)
|
(
encoded
[
11
]
<<
3
)
|
(
encoded
[
12
]
<<
12
)
|
(
encoded
[
13
]
<<
21
)
|
(
encoded
[
14
]
<<
30
);
st_buffer
[
3
]
=
(
encoded
[
10
]
>>
6
)
|
(
encoded
[
11
]
<<
3
)
|
(
encoded
[
12
]
<<
12
)
|
(
encoded
[
13
]
<<
21
)
|
(
encoded
[
14
]
<<
30
);
st_u32_values
[
4
]
=
(
encoded
[
14
]
>>
2
)
|
(
encoded
[
15
]
<<
7
)
|
(
local_signs
<<
16
);
st_buffer
[
4
]
=
(
encoded
[
14
]
>>
2
)
|
(
encoded
[
15
]
<<
7
)
|
(
local_signs
<<
16
);
}
// 保存160bit的数据到st_buffer
st_buffer
[
0
]
=
st_u32_values
[
0
];
*
(
reinterpret_cast
<
int4
*>
(
st_buffer
+
1
))
=
*
(
reinterpret_cast
<
int4
*>
(
st_u32_values
+
1
));
}
else
{
// 准备收发数据
using
vec_type
=
int4
;
const
auto
&
ld_buffer_vec
=
reinterpret_cast
<
const
vec_type
*>
(
ld_buffer
);
auto
st_buffer_vec
=
reinterpret_cast
<
vec_type
*>
(
reinterpret_cast
<
uint8_t
*>
(
dst_buffer
)
+
lane_id
*
kSendValueBytes
);
constexpr
int
kLoopIter
=
kSendValueBytes
/
sizeof
(
vec_type
);
#pragma unroll
for
(
int
k
=
0
;
k
<
kLoopIter
;
++
k
)
{
st_buffer_vec
[
k
]
=
ld_nc_global
(
ld_buffer_vec
+
k
);
}
}
}
}
// 确保 warp 内的所有线程都完成打包操作
syncwarp
();
// 计算量化成功和失败时的数据量
// 计算量化成功和失败时的数据量
constexpr
int
unable_cast_num_bytes
=
kWarpSize
*
kSendValueBytes
;
// = 64*2*16 = 2048
constexpr
int
unable_cast_num_bytes
=
kWarpSize
*
kSendValueBytes
;
// = 64*2*16 = 2048
constexpr
int
enable_cast_num_bytes
=
unable_cast_num_bytes
*
10
/
16
;
// = 2048/16*10=1280
constexpr
int
enable_cast_num_bytes
=
unable_cast_num_bytes
*
10
/
16
;
// = 2048/16*10=1280
...
@@ -191,8 +172,8 @@ __forceinline__ __device__ void logfmt_check_amaxmin(
...
@@ -191,8 +172,8 @@ __forceinline__ __device__ void logfmt_check_amaxmin(
for
(
int
i
=
0
;
i
<
kNumQuantGroupsPerWarp
;
++
i
)
{
// sizeof(uint64_t) / sizeof(__hip_bfloat162) = 2
for
(
int
i
=
0
;
i
<
kNumQuantGroupsPerWarp
;
++
i
)
{
// sizeof(uint64_t) / sizeof(__hip_bfloat162) = 2
auto
amax
=
static_cast
<
float
>
(
bf162_amaxmin
[
i
].
x
);
auto
amax
=
static_cast
<
float
>
(
bf162_amaxmin
[
i
].
x
);
auto
amin
=
static_cast
<
float
>
(
bf162_amaxmin
[
i
].
y
);
auto
amin
=
static_cast
<
float
>
(
bf162_amaxmin
[
i
].
y
);
log_amax
[
i
]
=
__builtin_log
2
f
(
amax
);
log_amax
[
i
]
=
__builtin_
amdgcn_
logf
(
amax
);
log_amin
[
i
]
=
amin
==
0
?
log_amax
[
i
]
-
kMinClip
:
__builtin_fmaxf
(
__builtin_log
2
f
(
amin
),
log_amax
[
i
]
-
kMinClip
);
log_amin
[
i
]
=
amin
==
0
?
log_amax
[
i
]
-
kMinClip
:
__builtin_fmaxf
(
__builtin_
amdgcn_
logf
(
amin
),
log_amax
[
i
]
-
kMinClip
);
enable_cast
=
enable_cast
and
log_amax
[
i
]
<
kLogThreshold
and
log_amin
[
i
]
<
log_amax
[
i
];
enable_cast
=
enable_cast
and
log_amax
[
i
]
<
kLogThreshold
and
log_amin
[
i
]
<
log_amax
[
i
];
}
}
...
@@ -229,7 +210,7 @@ __forceinline__ __device__ void decode_and_accumulate(
...
@@ -229,7 +210,7 @@ __forceinline__ __device__ void decode_and_accumulate(
const
auto
&
step
=
(
log_amax
-
log_amin
)
/
static_cast
<
float
>
(
kNumValues
-
2
);
const
auto
&
step
=
(
log_amax
-
log_amin
)
/
static_cast
<
float
>
(
kNumValues
-
2
);
auto
decode
=
[
=
](
const
uint32_t
&
encoded
,
const
uint32_t
&
sign
)
{
auto
decode
=
[
=
](
const
uint32_t
&
encoded
,
const
uint32_t
&
sign
)
{
const
auto
decoded
=
encoded
==
0
?
.0
f
:
__builtin_exp2f
((
encoded
-
1
)
*
step
+
log_amin
);
const
auto
decoded
=
encoded
==
0
?
.0
f
:
__builtin_
amdgcn_
exp2f
((
encoded
-
1
)
*
step
+
log_amin
);
return
sign
?
-
decoded
:
decoded
;
return
sign
?
-
decoded
:
decoded
;
};
};
...
...
csrc/kernels/utils.cuh
View file @
33bafa16
...
@@ -240,6 +240,12 @@ template <typename dtype_t> __device__ __forceinline__ dtype_t ld_nc_global(cons
...
@@ -240,6 +240,12 @@ template <typename dtype_t> __device__ __forceinline__ dtype_t ld_nc_global(cons
return
*
reinterpret_cast
<
dtype_t
*>
(
&
ret
);
return
*
reinterpret_cast
<
dtype_t
*>
(
&
ret
);
}
}
template
<
typename
dtype_t
>
__device__
__forceinline__
dtype_t
ld_direct_global
(
const
dtype_t
*
ptr
)
{
using
T
=
typename
VecInt
<
sizeof
(
dtype_t
)
>::
vec_t
;
auto
ret
=
*
(
reinterpret_cast
<
const
T
*>
(
ptr
));
return
*
reinterpret_cast
<
dtype_t
*>
(
&
ret
);
}
////////////////// used in ibgda
////////////////// used in ibgda
__device__
__forceinline__
void
st_na_relaxed
(
const
uint8_t
*
ptr
,
uint8_t
val
)
{
__device__
__forceinline__
void
st_na_relaxed
(
const
uint8_t
*
ptr
,
uint8_t
val
)
{
uint8_t
*
non_const_ptr
=
const_cast
<
uint8_t
*>
(
ptr
);
uint8_t
*
non_const_ptr
=
const_cast
<
uint8_t
*>
(
ptr
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment