Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
33bafa16
Commit
33bafa16
authored
Mar 06, 2026
by
lishen
Browse files
lowlatency combine实现3级流水
parent
61bc0aff
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
215 additions
and
181 deletions
+215
-181
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+189
-142
csrc/kernels/internode_ll_logfmt.cuh
csrc/kernels/internode_ll_logfmt.cuh
+20
-39
csrc/kernels/utils.cuh
csrc/kernels/utils.cuh
+6
-0
No files found.
csrc/kernels/internode_ll.cu
View file @
33bafa16
This diff is collapsed.
Click to expand it.
csrc/kernels/internode_ll_logfmt.cuh
View file @
33bafa16
...
@@ -17,7 +17,7 @@ namespace internode_ll {
...
@@ -17,7 +17,7 @@ namespace internode_ll {
template
<
int
kNumSendUnrolls
>
template
<
int
kNumSendUnrolls
>
__forceinline__
__device__
int
logfmt_encode
(
const
int4
*
cpy_src_int4_ptr
,
int4
*
ds
t
_buffer
,
__hip_bfloat162
*
shared_amaxmin
,
const
int
&
lane_id
)
{
__forceinline__
__device__
int
logfmt_encode
(
int4
*
l
ds_buffer
,
__hip_bfloat162
*
shared_amaxmin
,
const
int
&
lane_id
)
{
EP_STATIC_ASSERT
(
kNumSendUnrolls
==
2
,
"kNumSendUnrolls == 2 only"
);
EP_STATIC_ASSERT
(
kNumSendUnrolls
==
2
,
"kNumSendUnrolls == 2 only"
);
constexpr
int
kNumElemsPerInt4
=
sizeof
(
int4
)
/
sizeof
(
__hip_bfloat16
);
// 8
constexpr
int
kNumElemsPerInt4
=
sizeof
(
int4
)
/
sizeof
(
__hip_bfloat16
);
// 8
...
@@ -33,7 +33,8 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
...
@@ -33,7 +33,8 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
const
auto
&
bf162_values
=
reinterpret_cast
<
__hip_bfloat162
*>
(
int4_values
);
const
auto
&
bf162_values
=
reinterpret_cast
<
__hip_bfloat162
*>
(
int4_values
);
// Calculate lane offset
// Calculate lane offset
const
auto
&
ld_buffer
=
cpy_src_int4_ptr
+
lane_id
*
kNumSendUnrolls
;
const
auto
&
ld_buffer
=
reinterpret_cast
<
int4
*>
(
reinterpret_cast
<
uint8_t
*>
(
lds_buffer
)
+
lane_id
*
kSendValueBytes
);
const
auto
&
st_buffer
=
reinterpret_cast
<
uint32_t
*>
(
reinterpret_cast
<
uint8_t
*>
(
lds_buffer
)
+
lane_id
*
kSendValueBytes
*
10
/
16
);
// Local log amax
// Local log amax
auto
bf162_amax
=
__hip_bfloat162
(
HIPRT_ZERO_BF16
,
HIPRT_ZERO_BF16
);
auto
bf162_amax
=
__hip_bfloat162
(
HIPRT_ZERO_BF16
,
HIPRT_ZERO_BF16
);
...
@@ -68,6 +69,8 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
...
@@ -68,6 +69,8 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
// Reduce per 128 channels
// Reduce per 128 channels
// TODO: figure out how hardware do 2-byte min/max
// TODO: figure out how hardware do 2-byte min/max
const
auto
&
fp162_max
=
__bfloat1622float2
(
bf162_amax
);
auto
amax
=
__builtin_fmaxf
(
static_cast
<
float
>
(
bf162_amax
.
x
),
static_cast
<
float
>
(
bf162_amax
.
y
));
auto
amax
=
__builtin_fmaxf
(
static_cast
<
float
>
(
bf162_amax
.
x
),
static_cast
<
float
>
(
bf162_amax
.
y
));
auto
amin
=
__builtin_fminf
(
static_cast
<
float
>
(
bf162_amin
.
x
),
static_cast
<
float
>
(
bf162_amin
.
y
));
auto
amin
=
__builtin_fminf
(
static_cast
<
float
>
(
bf162_amin
.
x
),
static_cast
<
float
>
(
bf162_amin
.
y
));
...
@@ -80,26 +83,22 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
...
@@ -80,26 +83,22 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
if
(
shared_amaxmin
!=
nullptr
)
{
if
(
shared_amaxmin
!=
nullptr
)
{
*
shared_amaxmin
=
__hip_bfloat162
(
amax
,
amin
);
*
shared_amaxmin
=
__hip_bfloat162
(
amax
,
amin
);
}
}
syncwarp
();
// Calculate log amin/amax float
// Calculate log amin/amax float
const
auto
&
log_amax
=
__builtin_log2f
(
amax
);
const
auto
&
log_amax
=
__builtin_amdgcn_logf
(
amax
);
const
auto
&
log_amin
=
__builtin_fmaxf
(
__builtin_log2f
(
amin
),
log_amax
-
kMinClip
);
const
auto
&
log_amin
=
__builtin_fmaxf
(
__builtin_amdgcn_logf
(
amin
),
log_amax
-
kMinClip
);
// 在组内广播enable_cast结果
// 在组内广播enable_cast结果
const
bool
&
enable_cast
=
warp_reduce_and
<
kNumLanesToReduce
,
true
>
(
log_amax
<
kLogThreshold
and
log_amin
<
log_amax
);
const
bool
&
enable_cast
=
warp_reduce_and
<
kNumLanesToReduce
,
true
>
(
log_amax
<
kLogThreshold
and
log_amin
<
log_amax
);
// Case into LogFMT-10 if satisfied
// Case into LogFMT-10 if satisfied
if
(
enable_cast
)
{
if
(
enable_cast
)
{
constexpr
int
dst_buffer_step
=
kSendValueBytes
*
10
/
16
;
const
auto
&
st_buffer
=
reinterpret_cast
<
uint32_t
*>
(
reinterpret_cast
<
uint8_t
*>
(
dst_buffer
)
+
lane_id
*
dst_buffer_step
);
uint32_t
st_u32_values
[
dst_buffer_step
/
sizeof
(
uint32_t
)];
// = 5
// 计算10bit数据的两个相邻数值的差值
// 计算10bit数据的两个相邻数值的差值
const
auto
step
=
(
log_amax
-
log_amin
)
/
static_cast
<
float
>
(
kNumValues
-
2
);
const
auto
step
=
(
log_amax
-
log_amin
)
/
static_cast
<
float
>
(
kNumValues
-
2
);
const
auto
step_inv
=
1.0
f
/
step
;
const
auto
step_inv
=
1.0
f
/
step
;
// 计算舍入值
// 计算舍入值
const
auto
rounding
=
2.0
f
-
__builtin_log
2
f
((
1.0
f
+
__builtin_exp2f
(
step
))
*
0.5
f
)
*
step_inv
;
const
auto
rounding
=
2.0
f
-
__builtin_
amdgcn_
logf
((
1.0
f
+
__builtin_
amdgcn_
exp2f
(
step
))
*
0.5
f
)
*
step_inv
;
const
auto
fused_rounding
=
rounding
-
log_amin
*
step_inv
;
const
auto
fused_rounding
=
rounding
-
log_amin
*
step_inv
;
// 用于存储编码后的值
// 用于存储编码后的值
...
@@ -111,7 +110,7 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
...
@@ -111,7 +110,7 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
kNumElemsPerInt4
;
++
k
)
{
// 8
for
(
int
k
=
0
;
k
<
kNumElemsPerInt4
;
++
k
)
{
// 8
// 将 bfloat162 转换为 float2
// 将 bfloat162 转换为 float2
const
auto
&
fp
16
2_fvalue
=
__bfloat1622float2
(
bf162_values
[
k
]);
const
auto
&
fp
32
2_fvalue
=
__bfloat1622float2
(
bf162_values
[
k
]);
/*
/*
实际进行压缩的公式为:
实际进行压缩的公式为:
...
@@ -124,37 +123,19 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
...
@@ -124,37 +123,19 @@ __forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4*
K: 压缩后的整数的最大值(即,K为2的幂)
K: 压缩后的整数的最大值(即,K为2的幂)
*/
*/
// 对 float 值进行编码
// 对 float 值进行编码
encoded
[
k
*
2
+
0
]
=
__float2uint_rd
(
__builtin_fmaxf
(
__builtin_log
2
f
(
fp
16
2_fvalue
.
x
)
*
step_inv
+
fused_rounding
,
0
));
encoded
[
k
*
2
+
0
]
=
__float2uint_rd
(
__builtin_fmaxf
(
__builtin_
amdgcn_
logf
(
fp
32
2_fvalue
.
x
)
*
step_inv
+
fused_rounding
,
0
));
encoded
[
k
*
2
+
1
]
=
__float2uint_rd
(
__builtin_fmaxf
(
__builtin_log
2
f
(
fp
16
2_fvalue
.
y
)
*
step_inv
+
fused_rounding
,
0
));
encoded
[
k
*
2
+
1
]
=
__float2uint_rd
(
__builtin_fmaxf
(
__builtin_
amdgcn_
logf
(
fp
32
2_fvalue
.
y
)
*
step_inv
+
fused_rounding
,
0
));
}
}
// 批量打包编码后的值到 st_buffer
// 批量打包编码后的值到 st_buffer
st_u32_values
[
0
]
=
(
encoded
[
0
]
>>
0
)
|
(
encoded
[
1
]
<<
9
)
|
(
encoded
[
2
]
<<
18
)
|
(
encoded
[
3
]
<<
27
);
st_buffer
[
0
]
=
(
encoded
[
0
]
>>
0
)
|
(
encoded
[
1
]
<<
9
)
|
(
encoded
[
2
]
<<
18
)
|
(
encoded
[
3
]
<<
27
);
st_u32_values
[
1
]
=
(
encoded
[
3
]
>>
5
)
|
(
encoded
[
4
]
<<
4
)
|
(
encoded
[
5
]
<<
13
)
|
(
encoded
[
6
]
<<
22
)
|
(
encoded
[
7
]
<<
31
);
st_buffer
[
1
]
=
(
encoded
[
3
]
>>
5
)
|
(
encoded
[
4
]
<<
4
)
|
(
encoded
[
5
]
<<
13
)
|
(
encoded
[
6
]
<<
22
)
|
(
encoded
[
7
]
<<
31
);
st_u32_values
[
2
]
=
(
encoded
[
7
]
>>
1
)
|
(
encoded
[
8
]
<<
8
)
|
(
encoded
[
9
]
<<
17
)
|
(
encoded
[
10
]
<<
26
);
st_buffer
[
2
]
=
(
encoded
[
7
]
>>
1
)
|
(
encoded
[
8
]
<<
8
)
|
(
encoded
[
9
]
<<
17
)
|
(
encoded
[
10
]
<<
26
);
st_u32_values
[
3
]
=
(
encoded
[
10
]
>>
6
)
|
(
encoded
[
11
]
<<
3
)
|
(
encoded
[
12
]
<<
12
)
|
(
encoded
[
13
]
<<
21
)
|
(
encoded
[
14
]
<<
30
);
st_buffer
[
3
]
=
(
encoded
[
10
]
>>
6
)
|
(
encoded
[
11
]
<<
3
)
|
(
encoded
[
12
]
<<
12
)
|
(
encoded
[
13
]
<<
21
)
|
(
encoded
[
14
]
<<
30
);
st_u32_values
[
4
]
=
(
encoded
[
14
]
>>
2
)
|
(
encoded
[
15
]
<<
7
)
|
(
local_signs
<<
16
);
st_buffer
[
4
]
=
(
encoded
[
14
]
>>
2
)
|
(
encoded
[
15
]
<<
7
)
|
(
local_signs
<<
16
);
}
// 保存160bit的数据到st_buffer
st_buffer
[
0
]
=
st_u32_values
[
0
];
*
(
reinterpret_cast
<
int4
*>
(
st_buffer
+
1
))
=
*
(
reinterpret_cast
<
int4
*>
(
st_u32_values
+
1
));
}
else
{
// 准备收发数据
using
vec_type
=
int4
;
const
auto
&
ld_buffer_vec
=
reinterpret_cast
<
const
vec_type
*>
(
ld_buffer
);
auto
st_buffer_vec
=
reinterpret_cast
<
vec_type
*>
(
reinterpret_cast
<
uint8_t
*>
(
dst_buffer
)
+
lane_id
*
kSendValueBytes
);
constexpr
int
kLoopIter
=
kSendValueBytes
/
sizeof
(
vec_type
);
#pragma unroll
for
(
int
k
=
0
;
k
<
kLoopIter
;
++
k
)
{
st_buffer_vec
[
k
]
=
ld_nc_global
(
ld_buffer_vec
+
k
);
}
}
}
}
// 确保 warp 内的所有线程都完成打包操作
syncwarp
();
// 计算量化成功和失败时的数据量
// 计算量化成功和失败时的数据量
constexpr
int
unable_cast_num_bytes
=
kWarpSize
*
kSendValueBytes
;
// = 64*2*16 = 2048
constexpr
int
unable_cast_num_bytes
=
kWarpSize
*
kSendValueBytes
;
// = 64*2*16 = 2048
constexpr
int
enable_cast_num_bytes
=
unable_cast_num_bytes
*
10
/
16
;
// = 2048/16*10=1280
constexpr
int
enable_cast_num_bytes
=
unable_cast_num_bytes
*
10
/
16
;
// = 2048/16*10=1280
...
@@ -191,8 +172,8 @@ __forceinline__ __device__ void logfmt_check_amaxmin(
...
@@ -191,8 +172,8 @@ __forceinline__ __device__ void logfmt_check_amaxmin(
for
(
int
i
=
0
;
i
<
kNumQuantGroupsPerWarp
;
++
i
)
{
// sizeof(uint64_t) / sizeof(__hip_bfloat162) = 2
for
(
int
i
=
0
;
i
<
kNumQuantGroupsPerWarp
;
++
i
)
{
// sizeof(uint64_t) / sizeof(__hip_bfloat162) = 2
auto
amax
=
static_cast
<
float
>
(
bf162_amaxmin
[
i
].
x
);
auto
amax
=
static_cast
<
float
>
(
bf162_amaxmin
[
i
].
x
);
auto
amin
=
static_cast
<
float
>
(
bf162_amaxmin
[
i
].
y
);
auto
amin
=
static_cast
<
float
>
(
bf162_amaxmin
[
i
].
y
);
log_amax
[
i
]
=
__builtin_log
2
f
(
amax
);
log_amax
[
i
]
=
__builtin_
amdgcn_
logf
(
amax
);
log_amin
[
i
]
=
amin
==
0
?
log_amax
[
i
]
-
kMinClip
:
__builtin_fmaxf
(
__builtin_log
2
f
(
amin
),
log_amax
[
i
]
-
kMinClip
);
log_amin
[
i
]
=
amin
==
0
?
log_amax
[
i
]
-
kMinClip
:
__builtin_fmaxf
(
__builtin_
amdgcn_
logf
(
amin
),
log_amax
[
i
]
-
kMinClip
);
enable_cast
=
enable_cast
and
log_amax
[
i
]
<
kLogThreshold
and
log_amin
[
i
]
<
log_amax
[
i
];
enable_cast
=
enable_cast
and
log_amax
[
i
]
<
kLogThreshold
and
log_amin
[
i
]
<
log_amax
[
i
];
}
}
...
@@ -229,7 +210,7 @@ __forceinline__ __device__ void decode_and_accumulate(
...
@@ -229,7 +210,7 @@ __forceinline__ __device__ void decode_and_accumulate(
const
auto
&
step
=
(
log_amax
-
log_amin
)
/
static_cast
<
float
>
(
kNumValues
-
2
);
const
auto
&
step
=
(
log_amax
-
log_amin
)
/
static_cast
<
float
>
(
kNumValues
-
2
);
auto
decode
=
[
=
](
const
uint32_t
&
encoded
,
const
uint32_t
&
sign
)
{
auto
decode
=
[
=
](
const
uint32_t
&
encoded
,
const
uint32_t
&
sign
)
{
const
auto
decoded
=
encoded
==
0
?
.0
f
:
__builtin_exp2f
((
encoded
-
1
)
*
step
+
log_amin
);
const
auto
decoded
=
encoded
==
0
?
.0
f
:
__builtin_
amdgcn_
exp2f
((
encoded
-
1
)
*
step
+
log_amin
);
return
sign
?
-
decoded
:
decoded
;
return
sign
?
-
decoded
:
decoded
;
};
};
...
...
csrc/kernels/utils.cuh
View file @
33bafa16
...
@@ -240,6 +240,12 @@ template <typename dtype_t> __device__ __forceinline__ dtype_t ld_nc_global(cons
...
@@ -240,6 +240,12 @@ template <typename dtype_t> __device__ __forceinline__ dtype_t ld_nc_global(cons
return
*
reinterpret_cast
<
dtype_t
*>
(
&
ret
);
return
*
reinterpret_cast
<
dtype_t
*>
(
&
ret
);
}
}
// Reads one `dtype_t` value from `ptr` using an ordinary pointer dereference.
//
// The pointer is reinterpreted as `VecInt<sizeof(dtype_t)>::vec_t`, the value is
// read through that type, and the loaded bits are then reinterpreted back as
// `dtype_t` and returned by value.
//
// NOTE(review): `VecInt` is declared outside this block; presumably it maps a
// byte count to a same-sized integer/vector type -- confirm against its definition.
// NOTE(review): `ptr` presumably has to satisfy the alignment of `vec_t` -- confirm
// at the call sites.
template <typename dtype_t>
__device__ __forceinline__ dtype_t ld_direct_global(const dtype_t* ptr) {
    using vec_t = typename VecInt<sizeof(dtype_t)>::vec_t;

    // Single direct load through the same-sized carrier type.
    vec_t loaded = *reinterpret_cast<const vec_t*>(ptr);

    // Hand the raw bits back to the caller as the requested type.
    return *reinterpret_cast<dtype_t*>(&loaded);
}
////////////////// used in ibgda
////////////////// used in ibgda
__device__
__forceinline__
void
st_na_relaxed
(
const
uint8_t
*
ptr
,
uint8_t
val
)
{
__device__
__forceinline__
void
st_na_relaxed
(
const
uint8_t
*
ptr
,
uint8_t
val
)
{
uint8_t
*
non_const_ptr
=
const_cast
<
uint8_t
*>
(
ptr
);
uint8_t
*
non_const_ptr
=
const_cast
<
uint8_t
*>
(
ptr
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment