Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
c1af9783
Commit
c1af9783
authored
Nov 28, 2025
by
zhangyue
Browse files
issue/676: format
parent
5584035d
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
78 additions
and
68 deletions
+78
-68
src/infiniop/ops/topkrouter/kunlun/kernel.h
src/infiniop/ops/topkrouter/kunlun/kernel.h
+5
-6
src/infiniop/ops/topkrouter/kunlun/topkrouter_kunlun.xpu
src/infiniop/ops/topkrouter/kunlun/topkrouter_kunlun.xpu
+14
-3
src/infiniop/sort/kunlun/heap.h
src/infiniop/sort/kunlun/heap.h
+58
-58
test/infiniop/topkrouter.py
test/infiniop/topkrouter.py
+1
-1
No files found.
src/infiniop/ops/topkrouter/kunlun/kernel.h
View file @
c1af9783
...
...
@@ -3,7 +3,6 @@
#include "../../../devices/kunlun/kunlun_kernel_common.h"
#include "../../../sort/kunlun/heap.h"
#include <xpu/kernel/xtdk_io.h>
#include <float.h>
using
namespace
device
::
kunlun
::
kernel
;
...
...
@@ -34,8 +33,8 @@ inline __device__ void descending_sort(T *x, TID *idx, int32_t n) {
mfence_lm
();
}
template
<
typename
T
,
int32_t
BLOCK_THREADS
=
64
,
int32_t
MAX_EXPERTS
=
256
,
int32_t
N_GROUPS
=
8
,
int32_t
TOPK_GROUP
=
4
,
int32_t
TOPK_PER_GROUP
=
2
>
template
<
typename
T
,
int32_t
BLOCK_THREADS
=
64
,
int32_t
MAX_EXPERTS
=
256
,
int32_t
N_GROUPS
=
8
,
int32_t
TOPK_GROUP
=
4
,
int32_t
TOPK_PER_GROUP
=
2
>
__global__
void
topkrouter_kernel
(
float
*
values_topk
,
// 输出数据, 形状[N, topk]
int32_t
*
indices_topk
,
// 输出索引, 形状[N, topk]
...
...
src/infiniop/ops/topkrouter/kunlun/topkrouter_kunlun.xpu
View file @
c1af9783
...
...
@@ -64,6 +64,17 @@ infiniStatus_t launch_topkrouter(float *d_values_out, int *d_indices_out, const
N,
width,
topk);
} else if (xtype == INFINI_DTYPE_BF16) {
topkrouter_kernel<bfloat16_t, BLOCK_SIZE, 256, 8, 4, 2>
<<<N, BLOCK_SIZE, stream>>>(
d_values_out,
d_indices_out,
(bfloat16_t *)d_input,
(const float *)d_correction_bias,
routed_scaling_factor,
N,
width,
topk);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
...
...
src/infiniop/sort/kunlun/heap.h
View file @
c1af9783
...
...
@@ -3,8 +3,8 @@
#include "xpu/kernel/xtdk_simd_xpu2.h"
template
<
typename
TK
,
typename
TV
>
static
__device__
inline
void
sm_swap_kv
(
_shared_ptr_
TK
*
k0
,
_shared_ptr_
TV
*
v0
,
_shared_ptr_
TK
*
k1
,
_shared_ptr_
TV
*
v1
)
{
static
__device__
inline
void
sm_swap_kv
(
_shared_ptr_
TK
*
k0
,
_shared_ptr_
TV
*
v0
,
_shared_ptr_
TK
*
k1
,
_shared_ptr_
TV
*
v1
)
{
TK
tmpk
=
*
k0
;
TV
tmpv
=
*
v0
;
*
k0
=
*
k1
;
...
...
@@ -13,9 +13,9 @@ static __device__ inline void sm_swap_kv(_shared_ptr_ TK* k0, _shared_ptr_ TV* v
*
v1
=
tmpv
;
}
template
<
typename
TK
,
typename
TV
>
static
__device__
inline
void
update_sm_min_heap
(
_shared_ptr_
TK
*
heap_key
,
_shared_ptr_
TV
*
heap_value
,
int
idx
,
int
heap_capacity
)
{
template
<
typename
TK
,
typename
TV
>
static
__device__
inline
void
update_sm_min_heap
(
_shared_ptr_
TK
*
heap_key
,
_shared_ptr_
TV
*
heap_value
,
int
idx
,
int
heap_capacity
)
{
while
(
idx
<
heap_capacity
)
{
int
child_l
=
idx
*
2
+
1
;
int
child_r
=
idx
*
2
+
2
;
...
...
@@ -23,10 +23,10 @@ static __device__ inline void update_sm_min_heap(_shared_ptr_ TK* heap_key,
if
(
child_r
>=
heap_capacity
)
{
if
(
child_l
>=
heap_capacity
)
{
// idx is leaf node, shift finished
break
;
}
else
{
// if child_r does not exist while child_l does, choose child_l
}
else
{
// if child_r does not exist while child_l does, choose child_l
child_min
=
child_l
;
}
}
else
{
// both child L & R exists
}
else
{
// both child L & R exists
child_min
=
child_l
+
(
heap_key
[
child_l
]
>
heap_key
[
child_r
]);
}
if
(
heap_key
[
idx
]
<=
heap_key
[
child_min
])
{
...
...
@@ -37,26 +37,26 @@ static __device__ inline void update_sm_min_heap(_shared_ptr_ TK* heap_key,
}
}
template
<
typename
TK
,
typename
TV
>
template
<
typename
TK
,
typename
TV
>
static
__device__
inline
void
make_sm_min_heap
(
_shared_ptr_
TK
*
heap_key
,
_shared_ptr_
TV
*
heap_value
,
int
size
)
{
_shared_ptr_
TK
*
heap_key
,
_shared_ptr_
TV
*
heap_value
,
int
size
)
{
for
(
int
i
=
size
/
2
-
1
;
i
>=
0
;
i
--
)
{
update_sm_min_heap
(
heap_key
,
heap_value
,
i
,
size
);
}
}
template
<
typename
TK
,
typename
TV
>
template
<
typename
TK
,
typename
TV
>
static
__device__
inline
void
sort_sm_min_heap
(
_shared_ptr_
TK
*
heap_key
,
_shared_ptr_
TV
*
heap_value
,
int
heap_capacity
)
{
_shared_ptr_
TK
*
heap_key
,
_shared_ptr_
TV
*
heap_value
,
int
heap_capacity
)
{
for
(
int
i
=
heap_capacity
-
1
;
i
>
0
;
i
--
)
{
sm_swap_kv
(
&
heap_key
[
0
],
&
heap_value
[
0
],
&
heap_key
[
i
],
&
heap_value
[
i
]);
update_sm_min_heap
(
heap_key
,
heap_value
,
0
,
i
);
}
}
template
<
typename
TK
,
typename
TV
>
static
__device__
inline
void
update_sm_max_heap
(
_shared_ptr_
TK
*
heap_key
,
_shared_ptr_
TV
*
heap_value
,
int
idx
,
int
heap_capacity
)
{
template
<
typename
TK
,
typename
TV
>
static
__device__
inline
void
update_sm_max_heap
(
_shared_ptr_
TK
*
heap_key
,
_shared_ptr_
TV
*
heap_value
,
int
idx
,
int
heap_capacity
)
{
while
(
idx
<
heap_capacity
)
{
int
child_l
=
idx
*
2
+
1
;
int
child_r
=
idx
*
2
+
2
;
...
...
@@ -64,10 +64,10 @@ static __device__ inline void update_sm_max_heap(_shared_ptr_ TK* heap_key,
if
(
child_r
>=
heap_capacity
)
{
if
(
child_l
>=
heap_capacity
)
{
// idx is leaf node, shift finished
break
;
}
else
{
// if child_r does not exist while child_l does, choose child_l
}
else
{
// if child_r does not exist while child_l does, choose child_l
child_max
=
child_l
;
}
}
else
{
// both child L & R exists
}
else
{
// both child L & R exists
child_max
=
child_l
+
(
heap_key
[
child_l
]
<
heap_key
[
child_r
]);
}
if
(
heap_key
[
idx
]
>=
heap_key
[
child_max
])
{
...
...
@@ -78,17 +78,17 @@ static __device__ inline void update_sm_max_heap(_shared_ptr_ TK* heap_key,
}
}
template
<
typename
TK
,
typename
TV
>
template
<
typename
TK
,
typename
TV
>
static
__device__
inline
void
make_sm_max_heap
(
_shared_ptr_
TK
*
heap_key
,
_shared_ptr_
TV
*
heap_value
,
int
size
)
{
_shared_ptr_
TK
*
heap_key
,
_shared_ptr_
TV
*
heap_value
,
int
size
)
{
for
(
int
i
=
size
/
2
-
1
;
i
>=
0
;
i
--
)
{
update_sm_max_heap
(
heap_key
,
heap_value
,
i
,
size
);
}
}
template
<
typename
TK
,
typename
TV
>
static
__device__
inline
void
sort_sm_max_heap
(
_shared_ptr_
TK
*
heap_key
,
_shared_ptr_
TV
*
heap_value
,
int
heap_capacity
)
{
template
<
typename
TK
,
typename
TV
>
static
__device__
inline
void
sort_sm_max_heap
(
_shared_ptr_
TK
*
heap_key
,
_shared_ptr_
TV
*
heap_value
,
int
heap_capacity
)
{
for
(
int
i
=
heap_capacity
-
1
;
i
>
0
;
i
--
)
{
sm_swap_kv
(
&
heap_key
[
0
],
&
heap_value
[
0
],
&
heap_key
[
i
],
&
heap_value
[
i
]);
update_sm_max_heap
(
heap_key
,
heap_value
,
0
,
i
);
...
...
@@ -96,8 +96,8 @@ static __device__ inline void sort_sm_max_heap(_shared_ptr_ TK* heap_key,
}
template
<
typename
TK
,
typename
TV
>
static
__device__
inline
void
lm_swap_kv
(
TK
*
k0
,
TV
*
v0
,
TK
*
k1
,
TV
*
v1
)
{
static
__device__
inline
void
lm_swap_kv
(
TK
*
k0
,
TV
*
v0
,
TK
*
k1
,
TV
*
v1
)
{
TK
tmpk
=
*
k0
;
TV
tmpv
=
*
v0
;
*
k0
=
*
k1
;
...
...
@@ -107,7 +107,7 @@ static __device__ inline void lm_swap_kv(TK* k0, TV* v0,
}
template
<
typename
TK
,
typename
TV
>
static
__device__
inline
void
update_lm_min_heap
(
TK
*
heap_key
,
TV
*
heap_value
,
int
idx
,
int
heap_capacity
)
{
static
__device__
inline
void
update_lm_min_heap
(
TK
*
heap_key
,
TV
*
heap_value
,
int
idx
,
int
heap_capacity
)
{
while
(
idx
<
heap_capacity
)
{
int
child_l
=
idx
*
2
+
1
;
int
child_r
=
idx
*
2
+
2
;
...
...
@@ -115,10 +115,10 @@ static __device__ inline void update_lm_min_heap(TK* heap_key, TV* heap_value, i
if
(
child_r
>=
heap_capacity
)
{
if
(
child_l
>=
heap_capacity
)
{
// idx is leaf node, shift finished
break
;
}
else
{
// if child_r does not exist while child_l does, choose child_l
}
else
{
// if child_r does not exist while child_l does, choose child_l
child_min
=
child_l
;
}
}
else
{
// both child L & R exists
}
else
{
// both child L & R exists
child_min
=
child_l
+
(
heap_key
[
child_l
]
>
heap_key
[
child_r
]);
}
if
(
heap_key
[
idx
]
<=
heap_key
[
child_min
])
{
...
...
@@ -129,24 +129,24 @@ static __device__ inline void update_lm_min_heap(TK* heap_key, TV* heap_value, i
}
}
template
<
typename
TK
,
typename
TV
>
template
<
typename
TK
,
typename
TV
>
static
__device__
inline
void
make_lm_min_heap
(
TK
*
heap_key
,
TV
*
heap_value
,
int
size
)
{
TK
*
heap_key
,
TV
*
heap_value
,
int
size
)
{
for
(
int
i
=
size
/
2
-
1
;
i
>=
0
;
i
--
)
{
update_lm_min_heap
(
heap_key
,
heap_value
,
i
,
size
);
}
}
template
<
typename
TK
,
typename
TV
>
static
__device__
inline
void
sort_lm_min_heap
(
TK
*
heap_key
,
TV
*
heap_value
,
int
heap_capacity
)
{
template
<
typename
TK
,
typename
TV
>
static
__device__
inline
void
sort_lm_min_heap
(
TK
*
heap_key
,
TV
*
heap_value
,
int
heap_capacity
)
{
for
(
int
i
=
heap_capacity
-
1
;
i
>
0
;
i
--
)
{
lm_swap_kv
(
&
heap_key
[
0
],
&
heap_value
[
0
],
&
heap_key
[
i
],
&
heap_value
[
i
]);
update_lm_min_heap
(
heap_key
,
heap_value
,
0
,
i
);
}
}
template
<
typename
TK
,
typename
TV
>
static
__device__
inline
void
update_lm_max_heap
(
TK
*
heap_key
,
TV
*
heap_value
,
int
idx
,
int
heap_capacity
)
{
template
<
typename
TK
,
typename
TV
>
static
__device__
inline
void
update_lm_max_heap
(
TK
*
heap_key
,
TV
*
heap_value
,
int
idx
,
int
heap_capacity
)
{
while
(
idx
<
heap_capacity
)
{
int
child_l
=
idx
*
2
+
1
;
int
child_r
=
idx
*
2
+
2
;
...
...
@@ -154,10 +154,10 @@ static __device__ inline void update_lm_max_heap(TK* heap_key, TV* heap_value, i
if
(
child_r
>=
heap_capacity
)
{
if
(
child_l
>=
heap_capacity
)
{
// idx is leaf node, shift finished
break
;
}
else
{
// if child_r does not exist while child_l does, choose child_l
}
else
{
// if child_r does not exist while child_l does, choose child_l
child_max
=
child_l
;
}
}
else
{
// both child L & R exists
}
else
{
// both child L & R exists
child_max
=
child_l
+
(
heap_key
[
child_l
]
<
heap_key
[
child_r
]);
}
if
(
heap_key
[
idx
]
>=
heap_key
[
child_max
])
{
...
...
@@ -168,34 +168,34 @@ static __device__ inline void update_lm_max_heap(TK* heap_key, TV* heap_value, i
}
}
template
<
typename
TK
,
typename
TV
>
template
<
typename
TK
,
typename
TV
>
static
__device__
inline
void
make_lm_max_heap
(
TK
*
heap_key
,
TV
*
heap_value
,
int
size
)
{
TK
*
heap_key
,
TV
*
heap_value
,
int
size
)
{
for
(
int
i
=
size
/
2
-
1
;
i
>=
0
;
i
--
)
{
update_lm_max_heap
(
heap_key
,
heap_value
,
i
,
size
);
}
}
template
<
typename
TK
,
typename
TV
>
static
__device__
inline
void
sort_lm_max_heap
(
TK
*
heap_key
,
TV
*
heap_value
,
int
heap_capacity
)
{
template
<
typename
TK
,
typename
TV
>
static
__device__
inline
void
sort_lm_max_heap
(
TK
*
heap_key
,
TV
*
heap_value
,
int
heap_capacity
)
{
for
(
int
i
=
heap_capacity
-
1
;
i
>
0
;
i
--
)
{
lm_swap_kv
(
&
heap_key
[
0
],
&
heap_value
[
0
],
&
heap_key
[
i
],
&
heap_value
[
i
]);
update_lm_max_heap
(
heap_key
,
heap_value
,
0
,
i
);
}
}
template
<
typename
TID
>
template
<
typename
TID
>
__device__
TID
roundup_div_p
(
TID
a
,
TID
b
)
{
return
(
a
+
b
-
1
)
/
b
;
}
template
<
typename
T
>
__device__
T
min_p
(
T
a
,
T
b
){
template
<
typename
T
>
__device__
T
min_p
(
T
a
,
T
b
)
{
return
a
<
b
?
a
:
b
;
}
template
<
typename
TID
>
static
__device__
inline
void
partition
(
int
tid
,
int
nthreads
,
TID
len
,
int
align
,
TID
*
start
,
TID
*
end
)
{
template
<
typename
TID
>
static
__device__
inline
void
partition
(
int
tid
,
int
nthreads
,
TID
len
,
int
align
,
TID
*
start
,
TID
*
end
)
{
TID
block_cnt
=
roundup_div_p
<
TID
>
(
len
,
align
);
TID
remain_block
=
block_cnt
%
nthreads
;
TID
start_block
=
block_cnt
/
nthreads
*
static_cast
<
TID
>
(
tid
)
+
min_p
<
TID
>
(
tid
,
remain_block
);
...
...
@@ -204,48 +204,48 @@ static __device__ inline void partition(int tid, int nthreads, TID len, int alig
*
end
=
min_p
<
TID
>
(
end_block
*
align
,
len
);
}
template
<
typename
TX
,
typename
TY
>
static
__device__
void
primitive_cast
(
const
TX
*
x
,
TY
*
y
,
int
len
)
{
template
<
typename
TX
,
typename
TY
>
static
__device__
void
primitive_cast
(
const
TX
*
x
,
TY
*
y
,
int
len
)
{
return
;
}
template
<
>
__device__
void
primitive_cast
(
const
float
*
x
,
int
*
y
,
int
len
)
{
template
<
>
__device__
void
primitive_cast
(
const
float
*
x
,
int
*
y
,
int
len
)
{
for
(
int
i
=
0
;
i
<
len
;
i
+=
16
)
{
float32x16_t
Y
=
vload_lm_float32x16
(
x
);
__asm__
__volatile__
(
"vfloat2fix.rz vr0, %0
\t\n
"
"vstore_mask16.mz vr0{mr1}, 0(%1)"
::
"v"
(
Y
),
"r"
(
y
)
:
"vr0"
);
"vstore_mask16.mz vr0{mr1}, 0(%1)"
::
"v"
(
Y
),
"r"
(
y
)
:
"vr0"
);
x
+=
16
;
y
+=
16
;
}
mfence_lm
();
}
template
<
>
__device__
void
primitive_cast
(
const
int
*
x
,
float
*
y
,
int
len
)
{
template
<
>
__device__
void
primitive_cast
(
const
int
*
x
,
float
*
y
,
int
len
)
{
for
(
int
i
=
0
;
i
<
len
;
i
+=
16
)
{
int32x16_t
Y
=
vload_lm_int32x16
(
x
);
__asm__
__volatile__
(
"vfix2float.rn vr0, %0
\t\n
"
"vstore_mask16.mz vr0{mr1}, 0(%1)"
::
"v"
(
Y
),
"r"
(
y
)
:
"vr0"
);
"vstore_mask16.mz vr0{mr1}, 0(%1)"
::
"v"
(
Y
),
"r"
(
y
)
:
"vr0"
);
x
+=
16
;
y
+=
16
;
}
mfence_lm
();
}
static
__device__
inline
void
vload2_lm
(
const
float
*
ptr
,
float32x16_t
&
vl
,
float32x16_t
&
vh
)
{
static
__device__
inline
void
vload2_lm
(
const
float
*
ptr
,
float32x16_t
&
vl
,
float32x16_t
&
vh
)
{
vl
=
__builtin_xpu2_vload_mask16_mr1
(
ptr
,
0
);
vh
=
__builtin_xpu2_vload_mask16_mr1
(
ptr
+
16
,
0
);
}
static
__device__
inline
void
vstore2_lm
(
float
*
ptr
,
float32x16_t
&
vl
,
float32x16_t
&
vh
)
{
static
__device__
inline
void
vstore2_lm
(
float
*
ptr
,
float32x16_t
&
vl
,
float32x16_t
&
vh
)
{
vstore_lm_float32x16
(
ptr
,
vl
);
vstore_lm_float32x16
(
ptr
+
16
,
vh
);
}
template
<
>
__device__
void
primitive_cast
(
const
float
*
x
,
float
*
y
,
int
len
)
{
template
<
>
__device__
void
primitive_cast
(
const
float
*
x
,
float
*
y
,
int
len
)
{
if
(
x
==
y
)
{
return
;
}
else
{
// just copy
...
...
test/infiniop/topkrouter.py
View file @
c1af9783
...
...
@@ -33,7 +33,7 @@ _TEST_CASES_ = [
# w (weight) types
# Note: 'None' means the same as input dtype
_X_DTYPES
=
[
InfiniDtype
.
F32
,
InfiniDtype
.
F16
]
# [InfiniDtype.F32,
InfiniDtype.BF16, InfiniDtype.F16]
_X_DTYPES
=
[
InfiniDtype
.
F32
,
InfiniDtype
.
BF16
,
InfiniDtype
.
F16
]
# x types used for testing
_VALUE_DTYPES
=
[
InfiniDtype
.
F32
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment