Unverified Commit 1887c3f1 authored by zhangyue's avatar zhangyue Committed by GitHub
Browse files

Merge pull request #677 from InfiniTensor/issue/676

Issue/676 昆仑芯 topkrouter
parents e290ff22 56ec08ae
#ifndef __TOPKROUTER_KUNLUN_KERNEL_H__
#define __TOPKROUTER_KUNLUN_KERNEL_H__
#include "../../../devices/kunlun/kunlun_kernel_common.h"
#include "../../../sort/kunlun/heap.h"
#include <float.h>
using namespace device::kunlun::kernel;
// Convert a supported element type (float / bfloat16_t / half) to float and
// return its single-precision exponential.
template <typename T>
inline __device__ float expf_(T x) {
    float data = 0.0f; // defined value even if T matches none of the branches
    if constexpr (std::is_same_v<T, float>) {
        data = x;
    } else if constexpr (std::is_same_v<T, bfloat16_t>) {
        data = __bfloat162float(x);
    } else if constexpr (std::is_same_v<T, half>) {
        data = __half2float(x);
    }
    // Use the float overload explicitly; plain exp() takes the double path.
    return expf(data);
}
// Logistic (sigmoid) function evaluated in float precision: 1 / (1 + e^-x).
template <typename T>
inline __device__ float sigmoidf_(T x) {
    const float e = expf_<T>(-x);
    return 1.0f / (1.0f + e);
}
// Heapsort of the key array `x` (and its parallel index array `idx`, both in
// local memory) into DESCENDING order: building a MIN-heap and repeatedly
// swapping its root to the back pushes the smallest keys to the tail.
template <typename T, typename TID>
inline __device__ void descending_sort(T *x, TID *idx, int32_t n) {
    make_lm_min_heap(x, idx, n);
    mfence_lm(); // flush local-memory writes before the sort phase reads them
    sort_lm_min_heap(x, idx, n);
    mfence_lm(); // flush before the caller consumes the sorted arrays
}
// DeepSeek-V3-style grouped top-k router. One cluster handles one token:
//   1. sigmoid the logits and add the per-expert correction bias,
//   2. score each of N_GROUPS groups by the sum of its TOPK_PER_GROUP best
//      biased scores and keep the best TOPK_GROUP groups,
//   3. take the global top-k inside the kept groups, normalize with the
//      UNBIASED sigmoid scores and scale by routed_scaling_factor.
// Requires n_experts <= MAX_EXPERTS and n_experts divisible by N_GROUPS;
// topk must not exceed GROUP_SIZE * TOPK_GROUP.
template <typename T, int32_t BLOCK_THREADS = 64, int32_t MAX_EXPERTS = 256,
          int32_t N_GROUPS = 8, int32_t TOPK_GROUP = 4, int32_t TOPK_PER_GROUP = 2>
__global__ void topkrouter_kernel(
    float *values_topk,             // output values, shape [N, topk]
    int32_t *indices_topk,          // output expert indices, shape [N, topk]
    const T *input,                 // input logits, shape [N, n_experts]
    const float *d_correction_bias, // per-expert bias, shape [n_experts]
    const float routed_scaling_factor,
    const int32_t N,         // number of tokens (one cluster each)
    const int32_t n_experts, // n_experts <= MAX_EXPERTS
    const int32_t topk) {
    const int32_t block_idx = cluster_id();
    if (block_idx >= N) {
        return;
    }
    const int32_t thread_idx = core_id();
    const int32_t GROUP_SIZE = n_experts / N_GROUPS; // 32 in DeepSeek-V3
    // Compile-time upper bounds so thread-local buffers have fixed sizes
    // (runtime-sized local arrays are non-standard C++ VLAs).
    constexpr int32_t MAX_GROUP_SIZE = MAX_EXPERTS / N_GROUPS;
    constexpr int32_t MAX_SELECTED = MAX_GROUP_SIZE * TOPK_GROUP;

    // Stage this token's logits and the bias vector in shared memory.
    __shared__ T input_shm[MAX_EXPERTS];
    __shared__ float correction_bias_sm[MAX_EXPERTS];
    if (thread_idx == 0) {
        GM2SM_ASYNC(input + block_idx * n_experts, input_shm, n_experts * sizeof(T));
        GM2SM_ASYNC(d_correction_bias, correction_bias_sm, n_experts * sizeof(float));
        // NOTE(review): wait for the async DMA issued above before other
        // cores read SM — sync_cluster alone does not drain core-0's DMA.
        mfence();
    }
    sync_cluster();

    // scores[i] = sigmoid(logit_i); scores_with_bias = scores[i] + bias[i].
    __shared__ float scores[MAX_EXPERTS];
    __shared__ float scores_with_bias_shm[MAX_EXPERTS];
    for (int32_t i = thread_idx; i < n_experts; i += BLOCK_THREADS) {
        float v = sigmoidf_<T>(input_shm[i]);
        scores[i] = v;
        scores_with_bias_shm[i] = v + correction_bias_sm[i];
    }
    sync_cluster();

    // One core per group: sum of the TOPK_PER_GROUP largest biased scores.
    __shared__ float values_grouped_topk_shm[N_GROUPS];
    if (thread_idx < N_GROUPS) {
        const int32_t base = thread_idx * GROUP_SIZE;
        float tmp[TOPK_PER_GROUP]; // kept sorted descending
#pragma unroll
        for (int32_t k = 0; k < TOPK_PER_GROUP; ++k) {
            tmp[k] = -FLT_MAX; // sentinel so any real score displaces it
        }
        for (int32_t i = 0; i < GROUP_SIZE; ++i) {
            const float val = scores_with_bias_shm[base + i];
            if (val > tmp[TOPK_PER_GROUP - 1]) { // insertion into sorted list
                int pos = TOPK_PER_GROUP - 1;
                while (pos > 0 && val > tmp[pos - 1]) {
                    tmp[pos] = tmp[pos - 1];
                    --pos;
                }
                tmp[pos] = val;
            }
        }
        float group_sum = 0.f;
        for (int32_t k = 0; k < TOPK_PER_GROUP; ++k) {
            group_sum += tmp[k];
        }
        values_grouped_topk_shm[thread_idx] = group_sum;
    }
    sync_cluster();

    // Core 0: keep the TOPK_GROUP groups with the largest summed scores.
    __shared__ int32_t indices_group[TOPK_GROUP];
    if (thread_idx == 0) {
        float values_group[TOPK_GROUP];
        int32_t indices_tmp[TOPK_GROUP];
#pragma unroll
        for (int32_t k = 0; k < TOPK_GROUP; ++k) {
            values_group[k] = -FLT_MAX;
            indices_tmp[k] = -1;
        }
        for (int32_t i = 0; i < N_GROUPS; i++) {
            const float val = values_grouped_topk_shm[i];
            if (val > values_group[TOPK_GROUP - 1]) {
                int32_t pos = TOPK_GROUP - 1;
                while (pos > 0 && val > values_group[pos - 1]) {
                    values_group[pos] = values_group[pos - 1];
                    indices_tmp[pos] = indices_tmp[pos - 1];
                    pos--;
                }
                values_group[pos] = val;
                indices_tmp[pos] = i;
            }
        }
#pragma unroll
        for (int32_t k = 0; k < TOPK_GROUP; ++k) {
            indices_group[k] = indices_tmp[k];
        }
    }
    sync_cluster();

    // Gather the biased scores (and original expert indices) of the selected
    // groups into compact shared buffers, one core per selected group.
    __shared__ float values_group_select[MAX_EXPERTS];
    __shared__ int32_t indices_group_select[MAX_EXPERTS];
    if (thread_idx < TOPK_GROUP) {
        const int32_t group_id = indices_group[thread_idx];
        float local_buffer[MAX_GROUP_SIZE]; // fixed-size staging buffer
        __builtin_memcpy(local_buffer, scores_with_bias_shm + group_id * GROUP_SIZE, GROUP_SIZE * sizeof(float));
        mfence_lm();
        __builtin_memcpy(values_group_select + thread_idx * GROUP_SIZE, local_buffer, GROUP_SIZE * sizeof(float));
        for (int32_t i = 0; i < GROUP_SIZE; i++) {
            indices_group_select[thread_idx * GROUP_SIZE + i] = group_id * GROUP_SIZE + i;
        }
    }
    sync_cluster();

    // Core 0: global top-k over the selected experts, normalize, write out.
    if (thread_idx == 0) {
        const int32_t len = GROUP_SIZE * TOPK_GROUP;
        float values[MAX_SELECTED];
        int32_t indices[MAX_SELECTED];
        __builtin_memcpy(values, values_group_select, len * sizeof(float));
        __builtin_memcpy(indices, indices_group_select, len * sizeof(int32_t));
        mfence_lm();
        descending_sort<float, int32_t>(values, indices, len);
        // Normalize with the UNBIASED sigmoid scores of the chosen experts;
        // 1e-9f guards against division by zero.
        float sum = 1e-9f;
        for (int32_t k = 0; k < topk; k++) {
            sum += scores[indices[k]];
        }
        for (int32_t k = 0; k < topk; k++) {
            values[k] = routed_scaling_factor * scores[indices[k]] / sum;
        }
        mfence_lm();
        // BUGFIX: offset by block_idx so each token writes its own [topk]
        // slice of the [N, topk] outputs instead of all clusters clobbering
        // row 0.
        LM2GM_ASYNC(values, values_topk + block_idx * topk, topk * sizeof(float));
        LM2GM_ASYNC(indices, indices_topk + block_idx * topk, topk * sizeof(int32_t));
        mfence(); // drain the async store before the stack buffers go away
    }
    sync_cluster();
}
#endif // __TOPKROUTER_KUNLUN_KERNEL_H__
#ifndef __TOPKROUTER_KUNLUN_H__
#define __TOPKROUTER_KUNLUN_H__
#include "../topkrouter.h"
// Declare the topkrouter operator Descriptor for the Kunlun backend via the
// shared DESCRIPTOR macro (defined in ../topkrouter.h).
DESCRIPTOR(kunlun)
#endif
#include "../../../devices/kunlun/kunlun_common.h"
#include "../../../devices/kunlun/kunlun_handle.h"
#include "../../../devices/kunlun/kunlun_kernel_common.h"
#include "kernel.h"
#include "topkrouter_kunlun.h"
#include <memory>
#include <stdint.h>
namespace op::topkrouter::kunlun {
// Backend-private state: keeps the Kunlun handle's internal context alive for
// the lifetime of the descriptor.
struct Descriptor::Opaque {
    std::shared_ptr<device::kunlun::Handle::Internal> internal;
};
// Release the backend-private state owned by this descriptor.
Descriptor::~Descriptor() {
    delete _opaque;
}
// Validate the input tensor and build a Kunlun topkrouter descriptor.
// Returns BAD_TENSOR_STRIDES when the expert dimension is not contiguous,
// which the kernel requires.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t x_desc,
    infiniopTensorDescriptor_t correction_bias_desc) {
    auto info_result = TopkrouterInfo::create(x_desc);
    CHECK_RESULT(info_result);
    auto info = info_result.take();
    // The kernel reads rows of x contiguously; reject strided layouts.
    if (info.x_strides[1] != 1) {
        return INFINI_STATUS_BAD_TENSOR_STRIDES;
    }
    auto *kunlun_handle = reinterpret_cast<device::kunlun::Handle *>(handle);
    *desc_ptr = new Descriptor(
        new Opaque{kunlun_handle->internal()},
        std::move(info),
        0, // no workspace needed
        handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
// Dispatch topkrouter_kernel on the requested input dtype.
// Launch geometry: one cluster per token (N clusters), BLOCK_SIZE cores each.
// Returns BAD_TENSOR_DTYPE for unsupported dtypes.
template <int BLOCK_SIZE = 64>
infiniStatus_t launch_topkrouter(float *d_values_out, int *d_indices_out, const void *d_input, const float *d_correction_bias,
                                 const float routed_scaling_factor, const size_t N, const size_t width, const size_t topk, infiniDtype_t xtype,
                                 kunlunStream_t stream) {
    switch (xtype) {
    case INFINI_DTYPE_F32:
        topkrouter_kernel<float, BLOCK_SIZE, 256, 8, 4, 2>
            <<<N, BLOCK_SIZE, stream>>>(
                d_values_out,
                d_indices_out,
                // const-correct cast: the kernel only reads the input
                reinterpret_cast<const float *>(d_input),
                d_correction_bias,
                routed_scaling_factor,
                N,
                width,
                topk);
        break;
    case INFINI_DTYPE_F16:
        topkrouter_kernel<half, BLOCK_SIZE, 256, 8, 4, 2>
            <<<N, BLOCK_SIZE, stream>>>(
                d_values_out,
                d_indices_out,
                reinterpret_cast<const half *>(d_input),
                d_correction_bias,
                routed_scaling_factor,
                N,
                width,
                topk);
        break;
    case INFINI_DTYPE_BF16:
        topkrouter_kernel<bfloat16_t, BLOCK_SIZE, 256, 8, 4, 2>
            <<<N, BLOCK_SIZE, stream>>>(
                d_values_out,
                d_indices_out,
                reinterpret_cast<const bfloat16_t *>(d_input),
                d_correction_bias,
                routed_scaling_factor,
                N,
                width,
                topk);
        break;
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    return INFINI_STATUS_SUCCESS;
}
// Run the topkrouter kernel for this descriptor's problem size.
// values/indices: device outputs of shape [N, topk]; x: device input of
// shape [N, width]; correction_bias: device input of shape [width].
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    float *values,
    int *indices,
    const void *x,
    const float *correction_bias,
    const float routed_scaling_factor,
    const size_t topk,
    void *stream) const {
    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }
    const size_t N = _info.N;
    const size_t width = _info.width;
    auto kunlun_stream = reinterpret_cast<kunlunStream_t>(stream);
    // BUGFIX: propagate the launch status (e.g. BAD_TENSOR_DTYPE) instead of
    // discarding it and unconditionally reporting success.
    return launch_topkrouter<64>(values, indices, x, correction_bias, routed_scaling_factor, N, width, topk, _info.xtype, kunlun_stream);
}
} // namespace op::topkrouter::kunlun
...@@ -8,6 +8,9 @@ ...@@ -8,6 +8,9 @@
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API) #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
#include "nvidia/topkrouter_nvidia.cuh" #include "nvidia/topkrouter_nvidia.cuh"
#endif #endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/topkrouter_kunlun.h"
#endif
__C infiniStatus_t infiniopCreateTopkrouterDescriptor(infiniopHandle_t handle, infiniopTopkrouterDescriptor_t *desc_ptr, __C infiniStatus_t infiniopCreateTopkrouterDescriptor(infiniopHandle_t handle, infiniopTopkrouterDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t x_desc, infiniopTensorDescriptor_t x_desc,
...@@ -26,6 +29,9 @@ __C infiniStatus_t infiniopCreateTopkrouterDescriptor(infiniopHandle_t handle, i ...@@ -26,6 +29,9 @@ __C infiniStatus_t infiniopCreateTopkrouterDescriptor(infiniopHandle_t handle, i
#endif #endif
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia); CREATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif #endif
} }
...@@ -49,6 +55,9 @@ __C infiniStatus_t infiniopGetTopkrouterWorkspaceSize(infiniopTopkrouterDescript ...@@ -49,6 +55,9 @@ __C infiniStatus_t infiniopGetTopkrouterWorkspaceSize(infiniopTopkrouterDescript
#endif #endif
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia); GET(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif #endif
} }
...@@ -75,6 +84,9 @@ __C infiniStatus_t infiniopTopkrouter(infiniopTopkrouterDescriptor_t desc, void ...@@ -75,6 +84,9 @@ __C infiniStatus_t infiniopTopkrouter(infiniopTopkrouterDescriptor_t desc, void
#endif #endif
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia); CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_KUNLUN_API
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif #endif
} }
...@@ -98,6 +110,9 @@ __C infiniStatus_t infiniopDestroyTopkrouterDescriptor(infiniopTopkrouterDescrip ...@@ -98,6 +110,9 @@ __C infiniStatus_t infiniopDestroyTopkrouterDescriptor(infiniopTopkrouterDescrip
#endif #endif
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
DESTROY(INFINI_DEVICE_QY, nvidia); DESTROY(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_KUNLUN_API
DESTROY(INFINI_DEVICE_KUNLUN, kunlun);
#endif #endif
} }
......
#ifndef __INFINIOP_HEAP_KUNLUN_H__
#define __INFINIOP_HEAP_KUNLUN_H__
#include "xpu/kernel/xtdk_simd_xpu2.h"
// Exchange the key/value pair at (k0, v0) with the pair at (k1, v1),
// all residing in shared memory.
template <typename TK, typename TV>
static __device__ inline void sm_swap_kv(_shared_ptr_ TK *k0, _shared_ptr_ TV *v0,
                                         _shared_ptr_ TK *k1, _shared_ptr_ TV *v1) {
    const TK saved_key = *k0;
    *k0 = *k1;
    *k1 = saved_key;
    const TV saved_val = *v0;
    *v0 = *v1;
    *v1 = saved_val;
}
// Sift the element at `idx` down a shared-memory min-heap of size
// `heap_capacity` until the heap property holds (ties favor the left child).
template <typename TK, typename TV>
static __device__ inline void update_sm_min_heap(_shared_ptr_ TK *heap_key,
                                                 _shared_ptr_ TV *heap_value, int idx, int heap_capacity) {
    for (;;) {
        const int left = 2 * idx + 1;
        if (left >= heap_capacity) { // leaf node: nothing below to compare
            break;
        }
        const int right = left + 1;
        int smallest = left;
        if (right < heap_capacity && heap_key[right] < heap_key[left]) {
            smallest = right;
        }
        if (!(heap_key[idx] > heap_key[smallest])) { // heap property restored
            break;
        }
        sm_swap_kv(&heap_key[idx], &heap_value[idx], &heap_key[smallest], &heap_value[smallest]);
        idx = smallest;
    }
}
// Floyd heap construction: turn the first `size` shared-memory key/value
// pairs into a min-heap by sifting down every internal node, last first.
template <typename TK, typename TV>
static __device__ inline void make_sm_min_heap(
    _shared_ptr_ TK *heap_key, _shared_ptr_ TV *heap_value, int size) {
    for (int node = size / 2 - 1; node >= 0; node--) {
        update_sm_min_heap(heap_key, heap_value, node, size);
    }
}
// Heapsort phase on an existing shared-memory min-heap: repeatedly move the
// minimum root to the shrinking tail, leaving keys in DESCENDING order.
template <typename TK, typename TV>
static __device__ inline void sort_sm_min_heap(
    _shared_ptr_ TK *heap_key, _shared_ptr_ TV *heap_value, int heap_capacity) {
    for (int tail = heap_capacity - 1; tail > 0; tail--) {
        sm_swap_kv(&heap_key[0], &heap_value[0], &heap_key[tail], &heap_value[tail]);
        update_sm_min_heap(heap_key, heap_value, 0, tail);
    }
}
// Sift the element at `idx` down a shared-memory max-heap of size
// `heap_capacity` until the heap property holds (ties favor the left child).
template <typename TK, typename TV>
static __device__ inline void update_sm_max_heap(_shared_ptr_ TK *heap_key,
                                                 _shared_ptr_ TV *heap_value, int idx, int heap_capacity) {
    for (;;) {
        const int left = 2 * idx + 1;
        if (left >= heap_capacity) { // leaf node: nothing below to compare
            break;
        }
        const int right = left + 1;
        int largest = left;
        if (right < heap_capacity && heap_key[right] > heap_key[left]) {
            largest = right;
        }
        if (!(heap_key[idx] < heap_key[largest])) { // heap property restored
            break;
        }
        sm_swap_kv(&heap_key[idx], &heap_value[idx], &heap_key[largest], &heap_value[largest]);
        idx = largest;
    }
}
// Floyd heap construction: turn the first `size` shared-memory key/value
// pairs into a max-heap by sifting down every internal node, last first.
template <typename TK, typename TV>
static __device__ inline void make_sm_max_heap(
    _shared_ptr_ TK *heap_key, _shared_ptr_ TV *heap_value, int size) {
    for (int node = size / 2 - 1; node >= 0; node--) {
        update_sm_max_heap(heap_key, heap_value, node, size);
    }
}
// Heapsort phase on an existing shared-memory max-heap: repeatedly move the
// maximum root to the shrinking tail, leaving keys in ASCENDING order.
template <typename TK, typename TV>
static __device__ inline void sort_sm_max_heap(_shared_ptr_ TK *heap_key,
                                               _shared_ptr_ TV *heap_value, int heap_capacity) {
    for (int tail = heap_capacity - 1; tail > 0; tail--) {
        sm_swap_kv(&heap_key[0], &heap_value[0], &heap_key[tail], &heap_value[tail]);
        update_sm_max_heap(heap_key, heap_value, 0, tail);
    }
}
// Exchange the key/value pair at (k0, v0) with the pair at (k1, v1),
// all residing in local memory.
template <typename TK, typename TV>
static __device__ inline void lm_swap_kv(TK *k0, TV *v0,
                                         TK *k1, TV *v1) {
    const TK saved_key = *k0;
    *k0 = *k1;
    *k1 = saved_key;
    const TV saved_val = *v0;
    *v0 = *v1;
    *v1 = saved_val;
}
// Sift the element at `idx` down a local-memory min-heap of size
// `heap_capacity` until the heap property holds (ties favor the left child).
template <typename TK, typename TV>
static __device__ inline void update_lm_min_heap(TK *heap_key, TV *heap_value, int idx, int heap_capacity) {
    for (;;) {
        const int left = 2 * idx + 1;
        if (left >= heap_capacity) { // leaf node: nothing below to compare
            break;
        }
        const int right = left + 1;
        int smallest = left;
        if (right < heap_capacity && heap_key[right] < heap_key[left]) {
            smallest = right;
        }
        if (!(heap_key[idx] > heap_key[smallest])) { // heap property restored
            break;
        }
        lm_swap_kv(&heap_key[idx], &heap_value[idx], &heap_key[smallest], &heap_value[smallest]);
        idx = smallest;
    }
}
// Floyd heap construction: turn the first `size` local-memory key/value
// pairs into a min-heap by sifting down every internal node, last first.
template <typename TK, typename TV>
static __device__ inline void make_lm_min_heap(
    TK *heap_key, TV *heap_value, int size) {
    for (int node = size / 2 - 1; node >= 0; node--) {
        update_lm_min_heap(heap_key, heap_value, node, size);
    }
}
// Heapsort phase on an existing local-memory min-heap: repeatedly move the
// minimum root to the shrinking tail, leaving keys in DESCENDING order.
template <typename TK, typename TV>
static __device__ inline void sort_lm_min_heap(TK *heap_key, TV *heap_value, int heap_capacity) {
    for (int tail = heap_capacity - 1; tail > 0; tail--) {
        lm_swap_kv(&heap_key[0], &heap_value[0], &heap_key[tail], &heap_value[tail]);
        update_lm_min_heap(heap_key, heap_value, 0, tail);
    }
}
// Sift the element at `idx` down a local-memory max-heap of size
// `heap_capacity` until the heap property holds (ties favor the left child).
template <typename TK, typename TV>
static __device__ inline void update_lm_max_heap(TK *heap_key, TV *heap_value, int idx, int heap_capacity) {
    for (;;) {
        const int left = 2 * idx + 1;
        if (left >= heap_capacity) { // leaf node: nothing below to compare
            break;
        }
        const int right = left + 1;
        int largest = left;
        if (right < heap_capacity && heap_key[right] > heap_key[left]) {
            largest = right;
        }
        if (!(heap_key[idx] < heap_key[largest])) { // heap property restored
            break;
        }
        lm_swap_kv(&heap_key[idx], &heap_value[idx], &heap_key[largest], &heap_value[largest]);
        idx = largest;
    }
}
// Floyd heap construction: turn the first `size` local-memory key/value
// pairs into a max-heap by sifting down every internal node, last first.
template <typename TK, typename TV>
static __device__ inline void make_lm_max_heap(
    TK *heap_key, TV *heap_value, int size) {
    for (int node = size / 2 - 1; node >= 0; node--) {
        update_lm_max_heap(heap_key, heap_value, node, size);
    }
}
// Heapsort phase on an existing local-memory max-heap: repeatedly move the
// maximum root to the shrinking tail, leaving keys in ASCENDING order.
template <typename TK, typename TV>
static __device__ inline void sort_lm_max_heap(TK *heap_key, TV *heap_value, int heap_capacity) {
    for (int tail = heap_capacity - 1; tail > 0; tail--) {
        lm_swap_kv(&heap_key[0], &heap_value[0], &heap_key[tail], &heap_value[tail]);
        update_lm_max_heap(heap_key, heap_value, 0, tail);
    }
}
// Ceiling division for non-negative operands: ceil(a / b).
template <typename TID>
__device__ TID roundup_div_p(TID a, TID b) {
    const TID biased = a + b - 1; // bias the numerator so truncation rounds up
    return biased / b;
}
// Return the smaller of two values (on a tie, `b` is returned, matching
// the `a < b ? a : b` convention).
template <typename T>
__device__ T min_p(T a, T b) {
    if (b <= a) {
        return b;
    }
    return a;
}
// Split `len` elements into `align`-sized blocks and assign thread `tid`
// (of `nthreads`) a contiguous run of blocks, handing one extra block to the
// lowest-numbered threads when the count does not divide evenly. The
// resulting element range [*start, *end) is clamped to [0, len].
template <typename TID>
static __device__ inline void partition(int tid, int nthreads, TID len, int align, TID *start, TID *end) {
    const TID total_blocks = roundup_div_p<TID>(len, align);
    const TID per_thread = total_blocks / nthreads;
    const TID extra = total_blocks % nthreads;
    const TID first_block = per_thread * static_cast<TID>(tid) + min_p<TID>(tid, extra);
    const TID last_block = first_block + per_thread + (tid < extra);
    *start = min_p<TID>(first_block * align, len);
    *end = min_p<TID>(last_block * align, len);
}
// Fallback element-type cast: intentionally a no-op for any (TX, TY) pair
// without an explicit specialization below — only the specialized pairs
// (float<->int, float->float) are supported.
template <typename TX, typename TY>
static __device__ void primitive_cast(const TX *x, TY *y, int len) {
    return;
}
// Specialization: convert `len` floats in local memory to ints using the
// 16-lane SIMD unit (vfloat2fix.rz = round toward zero).
// NOTE(review): appears to assume `len` is a multiple of 16 and that mask
// register mr1 is preset to all-ones — confirm against callers.
template <>
__device__ void primitive_cast(const float *x, int *y, int len) {
    for (int i = 0; i < len; i += 16) {
        float32x16_t Y = vload_lm_float32x16(x);
        // Convert 16 lanes and store the result to local memory at y.
        __asm__ __volatile__("vfloat2fix.rz vr0, %0\t\n"
                             "vstore_mask16.mz vr0{mr1}, 0(%1)" ::"v"(Y),
                             "r"(y)
                             : "vr0");
        x += 16;
        y += 16;
    }
    mfence_lm(); // make the SIMD stores visible before returning
}
// Specialization: convert `len` ints in local memory to floats using the
// 16-lane SIMD unit (vfix2float.rn = round to nearest).
// NOTE(review): appears to assume `len` is a multiple of 16 and that mask
// register mr1 is preset to all-ones — confirm against callers.
template <>
__device__ void primitive_cast(const int *x, float *y, int len) {
    for (int i = 0; i < len; i += 16) {
        int32x16_t Y = vload_lm_int32x16(x);
        // Convert 16 lanes and store the result to local memory at y.
        __asm__ __volatile__("vfix2float.rn vr0, %0\t\n"
                             "vstore_mask16.mz vr0{mr1}, 0(%1)" ::"v"(Y),
                             "r"(y)
                             : "vr0");
        x += 16;
        y += 16;
    }
    mfence_lm(); // make the SIMD stores visible before returning
}
// Load 32 consecutive floats from local memory into two 16-lane vectors
// (vl = lanes 0..15, vh = lanes 16..31).
static __device__ inline void vload2_lm(const float *ptr, float32x16_t &vl, float32x16_t &vh) {
    vl = __builtin_xpu2_vload_mask16_mr1(ptr, 0);
    vh = __builtin_xpu2_vload_mask16_mr1(ptr + 16, 0);
}
// Store two 16-lane vectors as 32 consecutive floats in local memory
// (vl to lanes 0..15, vh to lanes 16..31).
static __device__ inline void vstore2_lm(float *ptr, float32x16_t &vl, float32x16_t &vh) {
    vstore_lm_float32x16(ptr, vl);
    vstore_lm_float32x16(ptr + 16, vh);
}
// Specialization: same-type "cast" degenerates to a vectorized copy of
// 32 floats per iteration; a self-copy (x == y) is skipped entirely.
// NOTE(review): appears to assume `len` is a multiple of 32 and that the
// buffers do not partially overlap — confirm against callers.
template <>
__device__ void primitive_cast(const float *x, float *y, int len) {
    if (x == y) {
        return;
    } else { // just copy
        float32x16_t vec_x_0;
        float32x16_t vec_x_1;
        for (int i = 0; i < len; i += 32) {
            vload2_lm(x + i, vec_x_0, vec_x_1);
            vstore2_lm(y + i, vec_x_0, vec_x_1);
        }
        mfence_lm(); // make the vector stores visible before returning
    }
}
#endif
...@@ -23,7 +23,7 @@ __C infiniStatus_t infinirtGetAllDeviceCount(int *count_array) { ...@@ -23,7 +23,7 @@ __C infiniStatus_t infinirtGetAllDeviceCount(int *count_array) {
return INFINI_STATUS_NULL_POINTER; return INFINI_STATUS_NULL_POINTER;
} }
for (size_t i = 0; i < INFINI_DEVICE_TYPE_COUNT; i++) { for (size_t i = 0; i < INFINI_DEVICE_TYPE_COUNT; i++) {
if (i == INFINI_DEVICE_ILUVATAR || i == INFINI_DEVICE_QY || i == INFINI_DEVICE_KUNLUN || i == INFINI_DEVICE_HYGON) { if (i == INFINI_DEVICE_ILUVATAR || i == INFINI_DEVICE_HYGON || i == INFINI_DEVICE_QY) {
count_array[i] = 0; count_array[i] = 0;
continue; continue;
} }
......
...@@ -33,7 +33,8 @@ _TEST_CASES_ = [ ...@@ -33,7 +33,8 @@ _TEST_CASES_ = [
# w (weight) types # w (weight) types
# Note: 'None' means the same as input dtype # Note: 'None' means the same as input dtype
_X_DTYPES = [] # [InfiniDtype.F32, InfiniDtype.BF16, InfiniDtype.F16] # _X_DTYPES = [InfiniDtype.F32, InfiniDtype.BF16, InfiniDtype.F16]
_X_DTYPES = [] # CPU CI
# x types used for testing # x types used for testing
_VALUE_DTYPES = [InfiniDtype.F32] _VALUE_DTYPES = [InfiniDtype.F32]
...@@ -194,6 +195,7 @@ def test( ...@@ -194,6 +195,7 @@ def test(
lib_topkrouter() lib_topkrouter()
lable_values, lable_indices = torch_topkrouter(x.actual_tensor(), correction_bias.actual_tensor(), routed_scaling_factor, topk) lable_values, lable_indices = torch_topkrouter(x.actual_tensor(), correction_bias.actual_tensor(), routed_scaling_factor, topk)
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG: if DEBUG:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment