Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
5584035d
Commit
5584035d
authored
Nov 28, 2025
by
zhangyue
Browse files
issue/676: kunlun topkrouter
parent
5f0f80d6
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
575 additions
and
1 deletion
+575
-1
src/infiniop/ops/topkrouter/kunlun/kernel.h
src/infiniop/ops/topkrouter/kunlun/kernel.h
+190
-0
src/infiniop/ops/topkrouter/kunlun/topkrouter_kunlun.h
src/infiniop/ops/topkrouter/kunlun/topkrouter_kunlun.h
+8
-0
src/infiniop/ops/topkrouter/kunlun/topkrouter_kunlun.xpu
src/infiniop/ops/topkrouter/kunlun/topkrouter_kunlun.xpu
+97
-0
src/infiniop/ops/topkrouter/operator.cc
src/infiniop/ops/topkrouter/operator.cc
+15
-0
src/infiniop/sort/kunlun/heap.h
src/infiniop/sort/kunlun/heap.h
+262
-0
test/infiniop/topkrouter.py
test/infiniop/topkrouter.py
+3
-1
No files found.
src/infiniop/ops/topkrouter/kunlun/kernel.h
0 → 100644
View file @
5584035d
#ifndef __TOPKROUTER_KUNLUN_KERNEL_H__
#define __TOPKROUTER_KUNLUN_KERNEL_H__
#include "../../../devices/kunlun/kunlun_kernel_common.h"
#include "../../../sort/kunlun/heap.h"
#include <xpu/kernel/xtdk_io.h>
#include <float.h>
using
namespace
device
::
kunlun
::
kernel
;
// Convert x to float, then return exp(x).
//
// Handles float, bfloat16_t and half explicitly via the vendor conversion
// intrinsics. Fix: the original left `data` uninitialized when T was none of
// those three types; a static_cast fallback now guarantees initialization and
// generalizes the helper to any type convertible to float.
template <typename T>
inline __device__ float expf_(T x) {
    float data;
    if constexpr (std::is_same_v<T, float>) {
        data = x;
    } else if constexpr (std::is_same_v<T, bfloat16_t>) {
        data = __bfloat162float(x);
    } else if constexpr (std::is_same_v<T, half>) {
        data = __half2float(x);
    } else {
        // Fallback for any other arithmetic-convertible type.
        data = static_cast<float>(x);
    }
    return exp(data);
}
// Logistic sigmoid computed in float: 1 / (1 + e^{-x}).
template <typename T>
inline __device__ float sigmoidf_(T x) {
    const float denom = 1.0f + expf_<T>(-x);
    return 1.0f / denom;
}
// In-place heapsort of (x, idx) key/value pairs into DESCENDING key order:
// building a min-heap and then draining it leaves the minimum at the tail,
// i.e. keys end up largest-first. The mfence_lm() calls order the
// local-memory writes between the build and drain phases.
template <typename T, typename TID>
inline __device__ void descending_sort(T *x, TID *idx, int32_t n) {
    make_lm_min_heap(x, idx, n);
    mfence_lm();
    sort_lm_min_heap(x, idx, n);
    mfence_lm();
}
// Top-k expert router (grouped top-k, DeepSeek-V3 style).
// One cluster processes one token: grid size must be >= N clusters of
// BLOCK_THREADS cores.
//
//   values_topk  [N, topk]  out: normalized, scaled router weights
//   indices_topk [N, topk]  out: selected expert indices
//   input        [N, n_experts]  router logits (dtype T)
//   d_correction_bias [n_experts]  bias added to sigmoid scores for selection
//
// Preconditions: n_experts <= MAX_EXPERTS, n_experts divisible by N_GROUPS,
// topk <= GROUP_SIZE * TOPK_GROUP.
//
// Fixes vs. original:
//   * output copies now offset by block_idx * topk — previously every cluster
//     wrote row 0 of values_topk/indices_topk, clobbering each other.
//   * runtime-sized local arrays (VLAs) replaced by compile-time bounds.
template <typename T, int32_t BLOCK_THREADS = 64, int32_t MAX_EXPERTS = 256,
          int32_t N_GROUPS = 8, int32_t TOPK_GROUP = 4, int32_t TOPK_PER_GROUP = 2>
__global__ void topkrouter_kernel(float *values_topk,              // out [N, topk]
                                  int32_t *indices_topk,           // out [N, topk]
                                  const T *input,                  // in  [N, n_experts]
                                  const float *d_correction_bias,  // in  [n_experts]
                                  const float routed_scaling_factor,
                                  const int32_t N,         // token count
                                  const int32_t n_experts, // <= MAX_EXPERTS
                                  const int32_t topk) {
    const int32_t block_idx = cluster_id();
    if (block_idx >= N) {
        return;
    }
    const int32_t thread_idx = core_id();
    const int32_t GROUP_SIZE = n_experts / N_GROUPS; // 32 in DeepSeek-V3
    // Compile-time upper bound on GROUP_SIZE, used to size local buffers.
    constexpr int32_t MAX_GROUP_SIZE = MAX_EXPERTS / N_GROUPS;

    // Stage this token's logits and the shared bias vector into shared memory.
    __shared__ T input_shm[MAX_EXPERTS];
    __shared__ float correction_bias_sm[MAX_EXPERTS];
    if (thread_idx == 0) {
        GM2SM_ASYNC(input + block_idx * n_experts, input_shm, n_experts * sizeof(T));
        GM2SM_ASYNC(d_correction_bias, correction_bias_sm, n_experts * sizeof(float));
    }
    sync_cluster();

    // Sigmoid scores; biased copy is used only for expert SELECTION, the
    // unbiased scores[] are what finally get scaled and written out.
    __shared__ float scores[MAX_EXPERTS];
    __shared__ float scores_with_bias_shm[MAX_EXPERTS];
    for (int32_t i = thread_idx; i < n_experts; i += BLOCK_THREADS) {
        float v = sigmoidf_<T>(input_shm[i]);
        scores[i] = v;
        scores_with_bias_shm[i] = v + correction_bias_sm[i];
    }
    sync_cluster();

    // Per group (one core per group): sum of the TOPK_PER_GROUP largest
    // biased scores, kept in a small descending insertion buffer.
    __shared__ float values_grouped_topk_shm[N_GROUPS];
    if (thread_idx < N_GROUPS) {
        const int32_t base = thread_idx * GROUP_SIZE;
        float tmp[TOPK_PER_GROUP];
#pragma unroll
        for (int32_t k = 0; k < TOPK_PER_GROUP; ++k) {
            tmp[k] = -FLT_MAX; // so any real score displaces the sentinel
        }
        for (int32_t i = 0; i < GROUP_SIZE; ++i) {
            const float val = scores_with_bias_shm[base + i];
            if (val > tmp[TOPK_PER_GROUP - 1]) {
                // Insertion into the sorted (descending) buffer.
                int pos = TOPK_PER_GROUP - 1;
                while (pos > 0 && val > tmp[pos - 1]) {
                    tmp[pos] = tmp[pos - 1];
                    --pos;
                }
                tmp[pos] = val;
            }
        }
        float group_sum = 0.f;
        for (int32_t k = 0; k < TOPK_PER_GROUP; ++k) {
            group_sum += tmp[k];
        }
        values_grouped_topk_shm[thread_idx] = group_sum;
    }
    sync_cluster();

    // Pick the TOPK_GROUP groups with the largest per-group sums (core 0).
    __shared__ int32_t indices_group[TOPK_GROUP];
    if (thread_idx == 0) {
        float values_group[TOPK_GROUP];
        int32_t indices_tmp[TOPK_GROUP];
#pragma unroll
        for (int32_t k = 0; k < TOPK_GROUP; ++k) {
            values_group[k] = -FLT_MAX;
            indices_tmp[k] = -1;
        }
        for (int32_t i = 0; i < N_GROUPS; i++) {
            const float val = values_grouped_topk_shm[i];
            if (val > values_group[TOPK_GROUP - 1]) {
                int32_t pos = TOPK_GROUP - 1;
                while (pos > 0 && val > values_group[pos - 1]) {
                    values_group[pos] = values_group[pos - 1];
                    indices_tmp[pos] = indices_tmp[pos - 1];
                    pos--;
                }
                values_group[pos] = val;
                indices_tmp[pos] = i;
            }
        }
#pragma unroll
        for (int32_t k = 0; k < TOPK_GROUP; ++k) {
            indices_group[k] = indices_tmp[k];
        }
    }
    sync_cluster();

    // Gather the selected groups' biased scores (and their original expert
    // indices) into a densely packed shared buffer, one core per group.
    __shared__ float values_group_select[MAX_EXPERTS];
    __shared__ int32_t indices_group_select[MAX_EXPERTS];
    if (thread_idx < TOPK_GROUP) {
        const int32_t group_id = indices_group[thread_idx];
        // Fixed bound instead of a runtime VLA (GROUP_SIZE <= MAX_GROUP_SIZE).
        float local_buffer[MAX_GROUP_SIZE];
        __builtin_memcpy(local_buffer, scores_with_bias_shm + group_id * GROUP_SIZE,
                         GROUP_SIZE * sizeof(float));
        mfence_lm();
        __builtin_memcpy(values_group_select + thread_idx * GROUP_SIZE, local_buffer,
                         GROUP_SIZE * sizeof(float));
        for (int32_t i = 0; i < GROUP_SIZE; i++) {
            indices_group_select[thread_idx * GROUP_SIZE + i] = group_id * GROUP_SIZE + i;
        }
    }
    sync_cluster();

    // Global top-k over the gathered candidates, normalize + scale, write out.
    if (thread_idx == 0) {
        const int32_t len = GROUP_SIZE * TOPK_GROUP;
        // Fixed bounds instead of runtime VLAs (len <= MAX_GROUP_SIZE * TOPK_GROUP).
        float values[MAX_GROUP_SIZE * TOPK_GROUP];
        int32_t indices[MAX_GROUP_SIZE * TOPK_GROUP];
        __builtin_memcpy(values, values_group_select, len * sizeof(float));
        __builtin_memcpy(indices, indices_group_select, len * sizeof(int32_t));
        mfence_lm();
        descending_sort<float, int32_t>(values, indices, len);
        // Normalize the UNBIASED scores of the chosen experts; 1e-9 guards
        // against division by zero.
        float sum = 1e-9f;
        for (int32_t k = 0; k < topk; k++) {
            sum += scores[indices[k]];
        }
        for (int32_t k = 0; k < topk; k++) {
            values[k] = routed_scaling_factor * scores[indices[k]] / sum;
        }
        mfence_lm();
        // BUG FIX: each token writes its own row — offset by block_idx * topk.
        LM2GM_ASYNC(values, values_topk + block_idx * topk, topk * sizeof(float));
        LM2GM_ASYNC(indices, indices_topk + block_idx * topk, topk * sizeof(int32_t));
    }
    sync_cluster();
}
#endif // __TOPKROUTER_KUNLUN_KERNEL_H__
src/infiniop/ops/topkrouter/kunlun/topkrouter_kunlun.h
0 → 100644
View file @
5584035d
#ifndef __TOPKROUTER_KUNLUN_H__
#define __TOPKROUTER_KUNLUN_H__

#include "../topkrouter.h"

// Declares op::topkrouter::kunlun::Descriptor via the shared DESCRIPTOR macro
// (defined in ../topkrouter.h); the implementation lives in
// topkrouter_kunlun.xpu.
DESCRIPTOR(kunlun)

#endif
src/infiniop/ops/topkrouter/kunlun/topkrouter_kunlun.xpu
0 → 100644
View file @
5584035d
#include "../../../devices/kunlun/kunlun_common.h"
#include "../../../devices/kunlun/kunlun_handle.h"
#include "../../../devices/kunlun/kunlun_kernel_common.h"
#include "kernel.h"
#include "topkrouter_kunlun.h"
#include <memory>
#include <stdint.h>
namespace op::topkrouter::kunlun {
// Backend-private state carried by the descriptor.
struct Descriptor::Opaque {
    // Shared Kunlun handle internals.
    // NOTE(review): exact contents are declared in kunlun_handle.h, not
    // visible here.
    std::shared_ptr<device::kunlun::Handle::Internal> internal;
};

// Releases the backend-private state allocated in create().
Descriptor::~Descriptor() {
    delete _opaque;
}
// Builds a topkrouter descriptor for the Kunlun backend.
// Validates the input tensor and rejects a non-contiguous expert dimension;
// no workspace is required. correction_bias_desc is currently unused here.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t x_desc,
    infiniopTensorDescriptor_t correction_bias_desc) {
    auto result = TopkrouterInfo::create(x_desc);
    CHECK_RESULT(result);
    auto info = result.take();

    // The kernel streams the expert axis contiguously; bail out otherwise.
    if (info.x_strides[1] != 1) {
        return INFINI_STATUS_BAD_TENSOR_STRIDES;
    }

    auto *kunlun_handle = reinterpret_cast<device::kunlun::Handle *>(handle);
    *desc_ptr = new Descriptor(
        new Opaque{kunlun_handle->internal()},
        std::move(info),
        0, // workspace size: none needed
        handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
// Dispatches topkrouter_kernel on the input dtype.
//   d_values_out  [N, topk]   output router weights
//   d_indices_out [N, topk]   output expert indices
//   d_input       [N, width]  router logits, dtype given by xtype
//   width == n_experts; assumed <= 256 and divisible into 8 groups
//   (TODO confirm the caller guarantees this — not checked here).
// Returns INFINI_STATUS_BAD_TENSOR_DTYPE for unsupported dtypes.
//
// Fix: the original C-style casts `(float *)d_input` / `(half *)d_input`
// silently dropped const; the kernel takes `const T *`, so cast to const.
template <int BLOCK_SIZE = 64>
infiniStatus_t launch_topkrouter(float *d_values_out, int *d_indices_out, const void *d_input, const float *d_correction_bias,
                                 const float routed_scaling_factor, const size_t N, const size_t width, const size_t topk, infiniDtype_t xtype,
                                 kunlunStream_t stream) {
    if (xtype == INFINI_DTYPE_F32) {
        topkrouter_kernel<float, BLOCK_SIZE, 256, 8, 4, 2>
            <<<N, BLOCK_SIZE, stream>>>(
                d_values_out,
                d_indices_out,
                static_cast<const float *>(d_input),
                d_correction_bias,
                routed_scaling_factor,
                N,
                width,
                topk);
    } else if (xtype == INFINI_DTYPE_F16) {
        topkrouter_kernel<half, BLOCK_SIZE, 256, 8, 4, 2>
            <<<N, BLOCK_SIZE, stream>>>(
                d_values_out,
                d_indices_out,
                static_cast<const half *>(d_input),
                d_correction_bias,
                routed_scaling_factor,
                N,
                width,
                topk);
    } else {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    return INFINI_STATUS_SUCCESS;
}
// Runs the topkrouter op on the given stream.
// values/indices are device buffers of shape [N, topk]; x holds the logits
// described at create() time; workspace is unused (size 0).
//
// Fix: the original discarded launch_topkrouter's return value and always
// reported SUCCESS, silently swallowing INFINI_STATUS_BAD_TENSOR_DTYPE.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    float *values,
    int *indices,
    const void *x,
    const float *correction_bias,
    const float routed_scaling_factor,
    const size_t topk,
    void *stream) const {
    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }
    const size_t N = _info.N;
    const size_t width = _info.width;
    auto kunlun_stream = reinterpret_cast<kunlunStream_t>(stream);
    // Propagate the launcher's status instead of unconditionally succeeding.
    return launch_topkrouter<64>(values, indices, x, correction_bias,
                                 routed_scaling_factor, N, width, topk,
                                 _info.xtype, kunlun_stream);
}
} // namespace op::topkrouter::kunlun
src/infiniop/ops/topkrouter/operator.cc
View file @
5584035d
...
...
@@ -8,6 +8,9 @@
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
#include "nvidia/topkrouter_nvidia.cuh"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/topkrouter_kunlun.h"
#endif
__C
infiniStatus_t
infiniopCreateTopkrouterDescriptor
(
infiniopHandle_t
handle
,
infiniopTopkrouterDescriptor_t
*
desc_ptr
,
infiniopTensorDescriptor_t
x_desc
,
...
...
@@ -26,6 +29,9 @@ __C infiniStatus_t infiniopCreateTopkrouterDescriptor(infiniopHandle_t handle, i
#endif
#ifdef ENABLE_QY_API
CREATE
(
INFINI_DEVICE_QY
,
nvidia
);
#endif
#ifdef ENABLE_KUNLUN_API
CREATE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
#endif
}
...
...
@@ -49,6 +55,9 @@ __C infiniStatus_t infiniopGetTopkrouterWorkspaceSize(infiniopTopkrouterDescript
#endif
#ifdef ENABLE_QY_API
GET
(
INFINI_DEVICE_QY
,
nvidia
);
#endif
#ifdef ENABLE_KUNLUN_API
GET
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
#endif
}
...
...
@@ -75,6 +84,9 @@ __C infiniStatus_t infiniopTopkrouter(infiniopTopkrouterDescriptor_t desc, void
#endif
#ifdef ENABLE_QY_API
CALCULATE
(
INFINI_DEVICE_QY
,
nvidia
);
#endif
#ifdef ENABLE_KUNLUN_API
CALCULATE
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
#endif
}
...
...
@@ -98,6 +110,9 @@ __C infiniStatus_t infiniopDestroyTopkrouterDescriptor(infiniopTopkrouterDescrip
#endif
#ifdef ENABLE_QY_API
DESTROY
(
INFINI_DEVICE_QY
,
nvidia
);
#endif
#ifdef ENABLE_KUNLUN_API
DESTROY
(
INFINI_DEVICE_KUNLUN
,
kunlun
);
#endif
}
...
...
src/infiniop/sort/kunlun/heap.h
0 → 100644
View file @
5584035d
#ifndef __INFINIOP_HEAP_KUNLUN_H__
#define __INFINIOP_HEAP_KUNLUN_H__
#include "xpu/kernel/xtdk_simd_xpu2.h"
// Swap two key/value pairs residing in shared memory.
template <typename TK, typename TV>
static __device__ inline void sm_swap_kv(_shared_ptr_ TK *k0, _shared_ptr_ TV *v0,
                                         _shared_ptr_ TK *k1, _shared_ptr_ TV *v1) {
    const TK key_saved = *k0;
    const TV val_saved = *v0;
    *k0 = *k1;
    *k1 = key_saved;
    *v0 = *v1;
    *v1 = val_saved;
}
// Sift-down for a binary MIN-heap stored in shared memory (array layout:
// children of node i are 2i+1 and 2i+2). Restores the heap property for the
// subtree rooted at idx, assuming both child subtrees are already heaps.
// heap_value is permuted in lockstep with heap_key.
template <typename TK, typename TV>
static __device__ inline void update_sm_min_heap(_shared_ptr_ TK *heap_key, _shared_ptr_ TV *heap_value, int idx, int heap_capacity) {
    while (idx < heap_capacity) {
        int child_l = idx * 2 + 1;
        int child_r = idx * 2 + 2;
        int child_min = child_l;
        if (child_r >= heap_capacity) {
            if (child_l >= heap_capacity) {
                // idx is a leaf node, sift-down finished
                break;
            } else {
                // child_r does not exist but child_l does: choose child_l
                child_min = child_l;
            }
        } else {
            // both children exist: branchless pick of the smaller key
            child_min = child_l + (heap_key[child_l] > heap_key[child_r]);
        }
        if (heap_key[idx] <= heap_key[child_min]) {
            // parent already <= smaller child: heap property holds
            break;
        }
        sm_swap_kv(&heap_key[idx], &heap_value[idx], &heap_key[child_min], &heap_value[child_min]);
        idx = child_min;
    }
}
// Build a min-heap in place over `size` shared-memory key/value pairs
// (Floyd's method: sift down each internal node, last parent first).
template <typename TK, typename TV>
static __device__ inline void make_sm_min_heap(_shared_ptr_ TK *heap_key, _shared_ptr_ TV *heap_value, int size) {
    const int last_parent = size / 2 - 1;
    for (int node = last_parent; node >= 0; node--) {
        update_sm_min_heap(heap_key, heap_value, node, size);
    }
}
// Heapsort drain of an existing shared-memory MIN-heap: repeatedly moves the
// root (current minimum) to the shrinking tail, leaving keys in DESCENDING
// order. heap_value is permuted alongside heap_key.
template <typename TK, typename TV>
static __device__ inline void sort_sm_min_heap(_shared_ptr_ TK *heap_key, _shared_ptr_ TV *heap_value, int heap_capacity) {
    for (int tail = heap_capacity - 1; tail > 0; tail--) {
        sm_swap_kv(&heap_key[0], &heap_value[0], &heap_key[tail], &heap_value[tail]);
        update_sm_min_heap(heap_key, heap_value, 0, tail);
    }
}
// Sift-down for a binary MAX-heap stored in shared memory (mirror of
// update_sm_min_heap with the comparisons inverted). Restores the heap
// property for the subtree rooted at idx; heap_value moves with heap_key.
template <typename TK, typename TV>
static __device__ inline void update_sm_max_heap(_shared_ptr_ TK *heap_key, _shared_ptr_ TV *heap_value, int idx, int heap_capacity) {
    while (idx < heap_capacity) {
        int child_l = idx * 2 + 1;
        int child_r = idx * 2 + 2;
        int child_max = child_l;
        if (child_r >= heap_capacity) {
            if (child_l >= heap_capacity) {
                // idx is a leaf node, sift-down finished
                break;
            } else {
                // child_r does not exist but child_l does: choose child_l
                child_max = child_l;
            }
        } else {
            // both children exist: branchless pick of the larger key
            child_max = child_l + (heap_key[child_l] < heap_key[child_r]);
        }
        if (heap_key[idx] >= heap_key[child_max]) {
            // parent already >= larger child: heap property holds
            break;
        }
        sm_swap_kv(&heap_key[idx], &heap_value[idx], &heap_key[child_max], &heap_value[child_max]);
        idx = child_max;
    }
}
// Build a max-heap in place over `size` shared-memory key/value pairs
// (Floyd's method: sift down each internal node, last parent first).
template <typename TK, typename TV>
static __device__ inline void make_sm_max_heap(_shared_ptr_ TK *heap_key, _shared_ptr_ TV *heap_value, int size) {
    const int last_parent = size / 2 - 1;
    for (int node = last_parent; node >= 0; node--) {
        update_sm_max_heap(heap_key, heap_value, node, size);
    }
}
// Heapsort drain of an existing shared-memory MAX-heap: repeatedly moves the
// root (current maximum) to the shrinking tail, leaving keys in ASCENDING
// order. heap_value is permuted alongside heap_key.
template <typename TK, typename TV>
static __device__ inline void sort_sm_max_heap(_shared_ptr_ TK *heap_key, _shared_ptr_ TV *heap_value, int heap_capacity) {
    for (int tail = heap_capacity - 1; tail > 0; tail--) {
        sm_swap_kv(&heap_key[0], &heap_value[0], &heap_key[tail], &heap_value[tail]);
        update_sm_max_heap(heap_key, heap_value, 0, tail);
    }
}
// Swap two key/value pairs residing in local memory.
template <typename TK, typename TV>
static __device__ inline void lm_swap_kv(TK *k0, TV *v0, TK *k1, TV *v1) {
    const TK key_saved = *k0;
    const TV val_saved = *v0;
    *k0 = *k1;
    *k1 = key_saved;
    *v0 = *v1;
    *v1 = val_saved;
}
// Sift-down for a binary MIN-heap stored in local memory (same algorithm as
// update_sm_min_heap, without the shared-memory pointer qualifier).
// heap_value is permuted in lockstep with heap_key.
template <typename TK, typename TV>
static __device__ inline void update_lm_min_heap(TK *heap_key, TV *heap_value, int idx, int heap_capacity) {
    while (idx < heap_capacity) {
        int child_l = idx * 2 + 1;
        int child_r = idx * 2 + 2;
        int child_min = child_l;
        if (child_r >= heap_capacity) {
            if (child_l >= heap_capacity) {
                // idx is a leaf node, sift-down finished
                break;
            } else {
                // child_r does not exist but child_l does: choose child_l
                child_min = child_l;
            }
        } else {
            // both children exist: branchless pick of the smaller key
            child_min = child_l + (heap_key[child_l] > heap_key[child_r]);
        }
        if (heap_key[idx] <= heap_key[child_min]) {
            // parent already <= smaller child: heap property holds
            break;
        }
        lm_swap_kv(&heap_key[idx], &heap_value[idx], &heap_key[child_min], &heap_value[child_min]);
        idx = child_min;
    }
}
// Build a min-heap in place over `size` local-memory key/value pairs
// (Floyd's method: sift down each internal node, last parent first).
template <typename TK, typename TV>
static __device__ inline void make_lm_min_heap(TK *heap_key, TV *heap_value, int size) {
    const int last_parent = size / 2 - 1;
    for (int node = last_parent; node >= 0; node--) {
        update_lm_min_heap(heap_key, heap_value, node, size);
    }
}
// Heapsort drain of an existing local-memory MIN-heap: repeatedly moves the
// root (current minimum) to the shrinking tail, leaving keys in DESCENDING
// order. heap_value is permuted alongside heap_key.
template <typename TK, typename TV>
static __device__ inline void sort_lm_min_heap(TK *heap_key, TV *heap_value, int heap_capacity) {
    for (int tail = heap_capacity - 1; tail > 0; tail--) {
        lm_swap_kv(&heap_key[0], &heap_value[0], &heap_key[tail], &heap_value[tail]);
        update_lm_min_heap(heap_key, heap_value, 0, tail);
    }
}
// Sift-down for a binary MAX-heap stored in local memory (mirror of
// update_lm_min_heap with the comparisons inverted). heap_value moves with
// heap_key.
template <typename TK, typename TV>
static __device__ inline void update_lm_max_heap(TK *heap_key, TV *heap_value, int idx, int heap_capacity) {
    while (idx < heap_capacity) {
        int child_l = idx * 2 + 1;
        int child_r = idx * 2 + 2;
        int child_max = child_l;
        if (child_r >= heap_capacity) {
            if (child_l >= heap_capacity) {
                // idx is a leaf node, sift-down finished
                break;
            } else {
                // child_r does not exist but child_l does: choose child_l
                child_max = child_l;
            }
        } else {
            // both children exist: branchless pick of the larger key
            child_max = child_l + (heap_key[child_l] < heap_key[child_r]);
        }
        if (heap_key[idx] >= heap_key[child_max]) {
            // parent already >= larger child: heap property holds
            break;
        }
        lm_swap_kv(&heap_key[idx], &heap_value[idx], &heap_key[child_max], &heap_value[child_max]);
        idx = child_max;
    }
}
// Build a max-heap in place over `size` local-memory key/value pairs
// (Floyd's method: sift down each internal node, last parent first).
template <typename TK, typename TV>
static __device__ inline void make_lm_max_heap(TK *heap_key, TV *heap_value, int size) {
    const int last_parent = size / 2 - 1;
    for (int node = last_parent; node >= 0; node--) {
        update_lm_max_heap(heap_key, heap_value, node, size);
    }
}
// Heapsort drain of an existing local-memory MAX-heap: repeatedly moves the
// root (current maximum) to the shrinking tail, leaving keys in ASCENDING
// order. heap_value is permuted alongside heap_key.
template <typename TK, typename TV>
static __device__ inline void sort_lm_max_heap(TK *heap_key, TV *heap_value, int heap_capacity) {
    for (int tail = heap_capacity - 1; tail > 0; tail--) {
        lm_swap_kv(&heap_key[0], &heap_value[0], &heap_key[tail], &heap_value[tail]);
        update_lm_max_heap(heap_key, heap_value, 0, tail);
    }
}
// Ceiling division: smallest q with q * b >= a (for non-negative a, b > 0).
template <typename TID>
__device__ TID roundup_div_p(TID a, TID b) {
    const TID biased = a + b - 1;
    return biased / b;
}
// Minimum of two values.
template <typename T>
__device__ T min_p(T a, T b) {
    if (b < a) {
        return b;
    }
    return a;
}
// Split [0, len) among nthreads threads into contiguous chunks whose
// boundaries fall on multiples of `align` elements; writes thread tid's
// half-open range into *start/*end. Alignment blocks are distributed as
// evenly as possible: the first (block_cnt % nthreads) threads each get one
// extra block.
template <typename TID>
static __device__ inline void partition(int tid, int nthreads, TID len, int align, TID *start, TID *end) {
    // Number of align-sized blocks covering len (the last may be partial).
    TID block_cnt = roundup_div_p<TID>(len, align);
    TID remain_block = block_cnt % nthreads;
    // First block owned by this thread: base share per thread, plus one for
    // every lower-numbered thread that received a remainder block.
    TID start_block = block_cnt / nthreads * static_cast<TID>(tid) + min_p<TID>(tid, remain_block);
    // One extra block if this thread is among the first remain_block threads.
    TID end_block = start_block + block_cnt / nthreads + (tid < remain_block);
    // Clamp to len so the trailing partial block does not overrun.
    *start = min_p<TID>(start_block * align, len);
    *end = min_p<TID>(end_block * align, len);
}
// Element-wise cast of x[0..len) (type TX) into y (type TY).
// NOTE(review): this unspecialized version is a silent no-op — only the
// (float->int), (int->float) and (float->float) specializations below perform
// any work. Confirm no caller instantiates it with other type pairs, or a
// static_assert here would be safer.
template <typename TX, typename TY>
static __device__ void primitive_cast(const TX *x, TY *y, int len) {
    return;
}
// float -> int32 cast in local memory using the XPU2 vector unit:
// 16 lanes per iteration, round-toward-zero (vfloat2fix.rz).
// NOTE(review): processes ceil(len/16)*16 elements — if len is not a multiple
// of 16 the last iteration loads/stores past the end of the buffers; confirm
// callers pad to 16 elements.
template <>
__device__ void primitive_cast(const float *x, int *y, int len) {
    for (int i = 0; i < len; i += 16) {
        // Load 16 floats, convert in vr0, masked-store 16 ints to y.
        float32x16_t Y = vload_lm_float32x16(x);
        __asm__ __volatile__(
            "vfloat2fix.rz vr0, %0\t\n"
            "vstore_mask16.mz vr0{mr1}, 0(%1)" ::"v"(Y),
            "r"(y)
            : "vr0");
        x += 16;
        y += 16;
    }
    // Make the vector stores visible before subsequent local-memory reads.
    mfence_lm();
}
// int32 -> float cast in local memory using the XPU2 vector unit:
// 16 lanes per iteration, round-to-nearest (vfix2float.rn).
// NOTE(review): processes ceil(len/16)*16 elements — if len is not a multiple
// of 16 the last iteration loads/stores past the end of the buffers; confirm
// callers pad to 16 elements.
template <>
__device__ void primitive_cast(const int *x, float *y, int len) {
    for (int i = 0; i < len; i += 16) {
        // Load 16 ints, convert in vr0, masked-store 16 floats to y.
        int32x16_t Y = vload_lm_int32x16(x);
        __asm__ __volatile__(
            "vfix2float.rn vr0, %0\t\n"
            "vstore_mask16.mz vr0{mr1}, 0(%1)" ::"v"(Y),
            "r"(y)
            : "vr0");
        x += 16;
        y += 16;
    }
    // Make the vector stores visible before subsequent local-memory reads.
    mfence_lm();
}
// Load two consecutive 16-lane float vectors (32 floats total) from local
// memory starting at ptr into vl (elements 0..15) and vh (elements 16..31).
static __device__ inline void vload2_lm(const float *ptr, float32x16_t &vl, float32x16_t &vh) {
    vl = __builtin_xpu2_vload_mask16_mr1(ptr, 0);
    vh = __builtin_xpu2_vload_mask16_mr1(ptr + 16, 0);
}
// Store two 16-lane float vectors (32 floats total) to local memory at ptr:
// vl to elements 0..15, vh to elements 16..31.
static __device__ inline void vstore2_lm(float *ptr, float32x16_t &vl, float32x16_t &vh) {
    vstore_lm_float32x16(ptr, vl);
    vstore_lm_float32x16(ptr + 16, vh);
}
// float -> float "cast": a vectorized local-memory copy, 32 elements (two
// 16-lane vectors) per iteration. An in-place call (x == y) returns
// immediately.
// NOTE(review): if len is not a multiple of 32 the last iteration reads and
// writes past len; confirm callers pad their buffers to 32 floats.
template <>
__device__ void primitive_cast(const float *x, float *y, int len) {
    if (x == y) {
        // Source and destination alias exactly: nothing to do.
        return;
    } else {
        // just copy
        float32x16_t vec_x_0;
        float32x16_t vec_x_1;
        for (int i = 0; i < len; i += 32) {
            vload2_lm(x + i, vec_x_0, vec_x_1);
            vstore2_lm(y + i, vec_x_0, vec_x_1);
        }
        // Make the vector stores visible before subsequent local-memory reads.
        mfence_lm();
    }
}
#endif
test/infiniop/topkrouter.py
View file @
5584035d
...
...
@@ -33,7 +33,7 @@ _TEST_CASES_ = [
# w (weight) types
# Note: 'None' means the same as input dtype
_X_DTYPES
=
[]
# [InfiniDtype.F32, InfiniDtype.BF16, InfiniDtype.F16]
_X_DTYPES
=
[
InfiniDtype
.
F32
,
InfiniDtype
.
F16
]
# [InfiniDtype.F32, InfiniDtype.BF16, InfiniDtype.F16]
# x types used for testing
_VALUE_DTYPES
=
[
InfiniDtype
.
F32
]
...
...
@@ -194,6 +194,8 @@ def test(
lib_topkrouter
()
torch
.
cuda
.
synchronize
()
lable_values
,
lable_indices
=
torch_topkrouter
(
x
.
actual_tensor
(),
correction_bias
.
actual_tensor
(),
routed_scaling_factor
,
topk
)
atol
,
rtol
=
get_tolerance
(
_TOLERANCE_MAP
,
dtype
)
if
DEBUG
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment