Commit c1af9783 authored by zhangyue

issue/676: format

parent 5584035d
@@ -3,7 +3,6 @@
#include "../../../devices/kunlun/kunlun_kernel_common.h"
#include "../../../sort/kunlun/heap.h"
#include <xpu/kernel/xtdk_io.h>
#include <float.h>
using namespace device::kunlun::kernel;
@@ -34,8 +33,8 @@ inline __device__ void descending_sort(T *x, TID *idx, int32_t n) {
mfence_lm();
}
template <typename T, int32_t BLOCK_THREADS=64, int32_t MAX_EXPERTS=256,
int32_t N_GROUPS=8, int32_t TOPK_GROUP=4, int32_t TOPK_PER_GROUP=2>
template <typename T, int32_t BLOCK_THREADS = 64, int32_t MAX_EXPERTS = 256,
int32_t N_GROUPS = 8, int32_t TOPK_GROUP = 4, int32_t TOPK_PER_GROUP = 2>
__global__ void topkrouter_kernel(
    float *values_topk,    // output values, shape [N, topk]
    int32_t *indices_topk, // output indices, shape [N, topk]
......
@@ -64,6 +64,17 @@ infiniStatus_t launch_topkrouter(float *d_values_out, int *d_indices_out, const
N,
width,
topk);
} else if (xtype == INFINI_DTYPE_BF16) {
topkrouter_kernel<bfloat16_t, BLOCK_SIZE, 256, 8, 4, 2>
<<<N, BLOCK_SIZE, stream>>>(
d_values_out,
d_indices_out,
(bfloat16_t *)d_input,
(const float *)d_correction_bias,
routed_scaling_factor,
N,
width,
topk);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
......
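Note on the dispatch above: the new BF16 branch mirrors the F32 launch whose tail is visible at the top of the hunk, and `<<<N, BLOCK_SIZE, stream>>>` is the three-argument Kunlun XPU launch form (clusters, cores, stream), not CUDA's four-argument one. Since the test change at the bottom of this commit already lists F16, an F16 branch presumably sits in the elided part of the same chain; a sketch of what it would look like under this pattern (the `float16_t` element type and the `INFINI_DTYPE_F16` enum name are assumptions, named by analogy with `bfloat16_t`):

    } else if (xtype == INFINI_DTYPE_F16) { // hypothetical branch, for illustration only
        topkrouter_kernel<float16_t, BLOCK_SIZE, 256, 8, 4, 2>
            <<<N, BLOCK_SIZE, stream>>>(
                d_values_out,
                d_indices_out,
                (float16_t *)d_input,
                (const float *)d_correction_bias,
                routed_scaling_factor,
                N,
                width,
                topk);
    }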
@@ -3,8 +3,8 @@
#include "xpu/kernel/xtdk_simd_xpu2.h"
template <typename TK, typename TV>
static __device__ inline void sm_swap_kv(_shared_ptr_ TK* k0, _shared_ptr_ TV* v0,
_shared_ptr_ TK* k1, _shared_ptr_ TV* v1) {
static __device__ inline void sm_swap_kv(_shared_ptr_ TK *k0, _shared_ptr_ TV *v0,
_shared_ptr_ TK *k1, _shared_ptr_ TV *v1) {
TK tmpk = *k0;
TV tmpv = *v0;
*k0 = *k1;
@@ -13,9 +13,9 @@ static __device__ inline void sm_swap_kv(_shared_ptr_ TK* k0, _shared_ptr_ TV* v
*v1 = tmpv;
}
template<typename TK, typename TV>
static __device__ inline void update_sm_min_heap(_shared_ptr_ TK* heap_key,
_shared_ptr_ TV* heap_value, int idx, int heap_capacity) {
template <typename TK, typename TV>
static __device__ inline void update_sm_min_heap(_shared_ptr_ TK *heap_key,
_shared_ptr_ TV *heap_value, int idx, int heap_capacity) {
while (idx < heap_capacity) {
int child_l = idx * 2 + 1;
int child_r = idx * 2 + 2;
@@ -23,10 +23,10 @@ static __device__ inline void update_sm_min_heap(_shared_ptr_ TK* heap_key,
if (child_r >= heap_capacity) {
            if (child_l >= heap_capacity) { // idx is a leaf node, sift finished
break;
} else {// if child_r does not exist while child_l does, choose child_l
} else { // if child_r does not exist while child_l does, choose child_l
child_min = child_l;
}
        } else {// both children L & R exist
        } else { // both children L & R exist
child_min = child_l + (heap_key[child_l] > heap_key[child_r]);
}
if (heap_key[idx] <= heap_key[child_min]) {
@@ -37,26 +37,26 @@ static __device__ inline void update_sm_min_heap(_shared_ptr_ TK* heap_key,
}
}
template<typename TK, typename TV>
template <typename TK, typename TV>
static __device__ inline void make_sm_min_heap(
_shared_ptr_ TK* heap_key, _shared_ptr_ TV* heap_value, int size) {
_shared_ptr_ TK *heap_key, _shared_ptr_ TV *heap_value, int size) {
for (int i = size / 2 - 1; i >= 0; i--) {
update_sm_min_heap(heap_key, heap_value, i, size);
}
}
template<typename TK, typename TV>
template <typename TK, typename TV>
static __device__ inline void sort_sm_min_heap(
_shared_ptr_ TK* heap_key, _shared_ptr_ TV* heap_value, int heap_capacity) {
_shared_ptr_ TK *heap_key, _shared_ptr_ TV *heap_value, int heap_capacity) {
for (int i = heap_capacity - 1; i > 0; i--) {
sm_swap_kv(&heap_key[0], &heap_value[0], &heap_key[i], &heap_value[i]);
update_sm_min_heap(heap_key, heap_value, 0, i);
}
}
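A detail worth flagging in these sift-down loops: the child pick is branchless. Because the two children sit at adjacent indices 2*idx+1 and 2*idx+2, `child_min = child_l + (heap_key[child_l] > heap_key[child_r])` adds 1 exactly when the right child holds the smaller key; the max-heap variants flip the comparison. A standalone check of the idiom:

    #include <cassert>

    int main() {
        float key[3] = {5.0f, 2.0f, 1.0f}; // parent at 0, children at 1 and 2
        int child_l = 1, child_r = 2;
        int child_min = child_l + (key[child_l] > key[child_r]); // right child smaller -> 2
        assert(child_min == 2);
        key[2] = 3.0f;                                           // now the left child is smaller
        child_min = child_l + (key[child_l] > key[child_r]);     // -> 1
        assert(child_min == 1);
        return 0;
    }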
template<typename TK, typename TV>
static __device__ inline void update_sm_max_heap(_shared_ptr_ TK* heap_key,
_shared_ptr_ TV* heap_value, int idx, int heap_capacity) {
template <typename TK, typename TV>
static __device__ inline void update_sm_max_heap(_shared_ptr_ TK *heap_key,
_shared_ptr_ TV *heap_value, int idx, int heap_capacity) {
while (idx < heap_capacity) {
int child_l = idx * 2 + 1;
int child_r = idx * 2 + 2;
@@ -64,10 +64,10 @@ static __device__ inline void update_sm_max_heap(_shared_ptr_ TK* heap_key,
if (child_r >= heap_capacity) {
            if (child_l >= heap_capacity) { // idx is a leaf node, sift finished
break;
} else {// if child_r does not exist while child_l does, choose child_l
} else { // if child_r does not exist while child_l does, choose child_l
child_max = child_l;
}
        } else {// both children L & R exist
        } else { // both children L & R exist
child_max = child_l + (heap_key[child_l] < heap_key[child_r]);
}
if (heap_key[idx] >= heap_key[child_max]) {
@@ -78,17 +78,17 @@ static __device__ inline void update_sm_max_heap(_shared_ptr_ TK* heap_key,
}
}
template<typename TK, typename TV>
template <typename TK, typename TV>
static __device__ inline void make_sm_max_heap(
_shared_ptr_ TK* heap_key, _shared_ptr_ TV* heap_value, int size) {
_shared_ptr_ TK *heap_key, _shared_ptr_ TV *heap_value, int size) {
for (int i = size / 2 - 1; i >= 0; i--) {
update_sm_max_heap(heap_key, heap_value, i, size);
}
}
template<typename TK, typename TV>
static __device__ inline void sort_sm_max_heap(_shared_ptr_ TK* heap_key,
_shared_ptr_ TV* heap_value, int heap_capacity) {
template <typename TK, typename TV>
static __device__ inline void sort_sm_max_heap(_shared_ptr_ TK *heap_key,
_shared_ptr_ TV *heap_value, int heap_capacity) {
for (int i = heap_capacity - 1; i > 0; i--) {
sm_swap_kv(&heap_key[0], &heap_value[0], &heap_key[i], &heap_value[i]);
update_sm_max_heap(heap_key, heap_value, 0, i);
@@ -96,8 +96,8 @@ static __device__ inline void sort_sm_max_heap(_shared_ptr_ TK* heap_key,
}
template <typename TK, typename TV>
static __device__ inline void lm_swap_kv(TK* k0, TV* v0,
TK* k1, TV* v1) {
static __device__ inline void lm_swap_kv(TK *k0, TV *v0,
TK *k1, TV *v1) {
TK tmpk = *k0;
TV tmpv = *v0;
*k0 = *k1;
@@ -107,7 +107,7 @@ static __device__ inline void lm_swap_kv(TK* k0, TV* v0,
}
template <typename TK, typename TV>
static __device__ inline void update_lm_min_heap(TK* heap_key, TV* heap_value, int idx, int heap_capacity) {
static __device__ inline void update_lm_min_heap(TK *heap_key, TV *heap_value, int idx, int heap_capacity) {
while (idx < heap_capacity) {
int child_l = idx * 2 + 1;
int child_r = idx * 2 + 2;
@@ -115,10 +115,10 @@ static __device__ inline void update_lm_min_heap(TK* heap_key, TV* heap_value, i
if (child_r >= heap_capacity) {
            if (child_l >= heap_capacity) { // idx is a leaf node, sift finished
break;
} else {// if child_r does not exist while child_l does, choose child_l
} else { // if child_r does not exist while child_l does, choose child_l
child_min = child_l;
}
        } else {// both children L & R exist
        } else { // both children L & R exist
child_min = child_l + (heap_key[child_l] > heap_key[child_r]);
}
if (heap_key[idx] <= heap_key[child_min]) {
@@ -129,24 +129,24 @@ static __device__ inline void update_lm_min_heap(TK* heap_key, TV* heap_value, i
}
}
template<typename TK, typename TV>
template <typename TK, typename TV>
static __device__ inline void make_lm_min_heap(
TK* heap_key, TV* heap_value, int size) {
TK *heap_key, TV *heap_value, int size) {
for (int i = size / 2 - 1; i >= 0; i--) {
update_lm_min_heap(heap_key, heap_value, i, size);
}
}
template<typename TK, typename TV>
static __device__ inline void sort_lm_min_heap(TK* heap_key, TV* heap_value, int heap_capacity) {
template <typename TK, typename TV>
static __device__ inline void sort_lm_min_heap(TK *heap_key, TV *heap_value, int heap_capacity) {
for (int i = heap_capacity - 1; i > 0; i--) {
lm_swap_kv(&heap_key[0], &heap_value[0], &heap_key[i], &heap_value[i]);
update_lm_min_heap(heap_key, heap_value, 0, i);
}
}
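Taken together, make_lm_min_heap followed by sort_lm_min_heap is classic in-place heapsort: build the heap bottom-up, then repeatedly swap the root (the current minimum) to the back and re-sift, which leaves the keys in descending order (the max-heap variants give ascending order). A host-side C++ equivalent of the key/value min-heap path, for illustration (same logic, XPU qualifiers and _shared_ptr_ dropped):

    #include <cstdio>

    // Same sift-down as update_lm_min_heap, written for the host.
    static void sift_down_min(float *key, int *val, int idx, int n) {
        while (idx < n) {
            int l = idx * 2 + 1, r = idx * 2 + 2;
            if (l >= n) break;                          // idx is a leaf
            int c = (r < n && key[l] > key[r]) ? r : l; // smaller child
            if (key[idx] <= key[c]) break;              // heap property already holds
            float tk = key[idx]; key[idx] = key[c]; key[c] = tk;
            int   tv = val[idx]; val[idx] = val[c]; val[c] = tv;
            idx = c;
        }
    }

    int main() {
        float k[5] = {3.0f, 1.0f, 4.0f, 1.5f, 9.0f};
        int   v[5] = {0, 1, 2, 3, 4};
        for (int i = 5 / 2 - 1; i >= 0; i--) sift_down_min(k, v, i, 5); // make_lm_min_heap
        for (int i = 5 - 1; i > 0; i--) {                               // sort_lm_min_heap
            float tk = k[0]; k[0] = k[i]; k[i] = tk;
            int   tv = v[0]; v[0] = v[i]; v[i] = tv;
            sift_down_min(k, v, 0, i);
        }
        for (int i = 0; i < 5; i++) printf("%g:%d ", k[i], v[i]); // 9:4 4:2 3:0 1.5:3 1:1
        return 0;
    }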
template<typename TK, typename TV>
static __device__ inline void update_lm_max_heap(TK* heap_key, TV* heap_value, int idx, int heap_capacity) {
template <typename TK, typename TV>
static __device__ inline void update_lm_max_heap(TK *heap_key, TV *heap_value, int idx, int heap_capacity) {
while (idx < heap_capacity) {
int child_l = idx * 2 + 1;
int child_r = idx * 2 + 2;
@@ -154,10 +154,10 @@ static __device__ inline void update_lm_max_heap(TK* heap_key, TV* heap_value, i
if (child_r >= heap_capacity) {
            if (child_l >= heap_capacity) { // idx is a leaf node, sift finished
break;
} else {// if child_r does not exist while child_l does, choose child_l
} else { // if child_r does not exist while child_l does, choose child_l
child_max = child_l;
}
        } else {// both children L & R exist
        } else { // both children L & R exist
child_max = child_l + (heap_key[child_l] < heap_key[child_r]);
}
if (heap_key[idx] >= heap_key[child_max]) {
@@ -168,34 +168,34 @@ static __device__ inline void update_lm_max_heap(TK* heap_key, TV* heap_value, i
}
}
template<typename TK, typename TV>
template <typename TK, typename TV>
static __device__ inline void make_lm_max_heap(
TK* heap_key, TV* heap_value, int size) {
TK *heap_key, TV *heap_value, int size) {
for (int i = size / 2 - 1; i >= 0; i--) {
update_lm_max_heap(heap_key, heap_value, i, size);
}
}
template<typename TK, typename TV>
static __device__ inline void sort_lm_max_heap(TK* heap_key, TV* heap_value, int heap_capacity) {
template <typename TK, typename TV>
static __device__ inline void sort_lm_max_heap(TK *heap_key, TV *heap_value, int heap_capacity) {
for (int i = heap_capacity - 1; i > 0; i--) {
lm_swap_kv(&heap_key[0], &heap_value[0], &heap_key[i], &heap_value[i]);
update_lm_max_heap(heap_key, heap_value, 0, i);
}
}
template<typename TID>
template <typename TID>
__device__ TID roundup_div_p(TID a, TID b) {
return (a + b - 1) / b;
}
template<typename T>
__device__ T min_p(T a, T b){
template <typename T>
__device__ T min_p(T a, T b) {
return a < b ? a : b;
}
template<typename TID>
static __device__ inline void partition(int tid, int nthreads, TID len, int align, TID* start, TID* end) {
template <typename TID>
static __device__ inline void partition(int tid, int nthreads, TID len, int align, TID *start, TID *end) {
TID block_cnt = roundup_div_p<TID>(len, align);
TID remain_block = block_cnt % nthreads;
TID start_block = block_cnt / nthreads * static_cast<TID>(tid) + min_p<TID>(tid, remain_block);
@@ -204,48 +204,48 @@ static __device__ inline void partition(int tid, int nthreads, TID len, int alig
*end = min_p<TID>(end_block * align, len);
}
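partition chops len elements into align-sized blocks and assigns each of nthreads a contiguous run, with the first block_cnt % nthreads threads taking one extra block. The end_block computation falls in the elided middle of this hunk, so the sketch below reconstructs it from the visible formulas (that reconstruction is an assumption) and checks the arithmetic on the host:

    #include <cstdio>

    // Host-side restatement of partition. len=100, align=16 -> 7 blocks;
    // 4 threads get 2,2,2,1 blocks, i.e. the ranges printed below.
    static void partition_host(int tid, int nthreads, int len, int align,
                               int *start, int *end) {
        int block_cnt = (len + align - 1) / align; // roundup_div_p
        int remain_block = block_cnt % nthreads;
        int start_block = block_cnt / nthreads * tid
                        + (tid < remain_block ? tid : remain_block); // min_p
        int end_block = start_block + block_cnt / nthreads
                      + (tid < remain_block ? 1 : 0);                // assumed elided line
        *start = start_block * align;
        *end = (end_block * align < len) ? end_block * align : len;  // clamp to len
    }

    int main() {
        for (int t = 0; t < 4; t++) {
            int s, e;
            partition_host(t, 4, 100, 16, &s, &e);
            printf("tid %d: [%d, %d)\n", t, s, e); // [0,32) [32,64) [64,96) [96,100)
        }
        return 0;
    }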
template<typename TX, typename TY>
static __device__ void primitive_cast(const TX* x, TY* y, int len) {
template <typename TX, typename TY>
static __device__ void primitive_cast(const TX *x, TY *y, int len) {
return;
}
template<>
__device__ void primitive_cast(const float* x, int* y, int len) {
template <>
__device__ void primitive_cast(const float *x, int *y, int len) {
for (int i = 0; i < len; i += 16) {
float32x16_t Y = vload_lm_float32x16(x);
__asm__ __volatile__("vfloat2fix.rz vr0, %0\t\n"
"vstore_mask16.mz vr0{mr1}, 0(%1)"
::"v"(Y), "r"(y):"vr0");
"vstore_mask16.mz vr0{mr1}, 0(%1)" ::"v"(Y),
"r"(y) : "vr0");
x += 16;
y += 16;
}
mfence_lm();
}
template<>
__device__ void primitive_cast(const int* x, float* y, int len) {
template <>
__device__ void primitive_cast(const int *x, float *y, int len) {
for (int i = 0; i < len; i += 16) {
int32x16_t Y = vload_lm_int32x16(x);
__asm__ __volatile__("vfix2float.rn vr0, %0\t\n"
"vstore_mask16.mz vr0{mr1}, 0(%1)"
::"v"(Y), "r"(y):"vr0");
"vstore_mask16.mz vr0{mr1}, 0(%1)" ::"v"(Y),
"r"(y) : "vr0");
x += 16;
y += 16;
}
mfence_lm();
}
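Each of the two specializations above converts 16 lanes per loop iteration through the XPU vector ISA. Assuming the .rz and .rn suffixes carry their conventional meanings (round toward zero, round to nearest), every iteration computes the scalar equivalent of the following, just 16 elements at a time:

    // Scalar semantics of the vectorized casts (plain C++, for reference).
    static void cast_f32_to_i32(const float *x, int *y, int len) {
        for (int i = 0; i < len; i++)
            y[i] = (int)x[i];   // vfloat2fix.rz: truncate toward zero
    }

    static void cast_i32_to_f32(const int *x, float *y, int len) {
        for (int i = 0; i < len; i++)
            y[i] = (float)x[i]; // vfix2float.rn: round to nearest
    }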
static __device__ inline void vload2_lm(const float* ptr, float32x16_t& vl, float32x16_t& vh) {
static __device__ inline void vload2_lm(const float *ptr, float32x16_t &vl, float32x16_t &vh) {
vl = __builtin_xpu2_vload_mask16_mr1(ptr, 0);
vh = __builtin_xpu2_vload_mask16_mr1(ptr + 16, 0);
}
static __device__ inline void vstore2_lm(float* ptr, float32x16_t& vl, float32x16_t& vh) {
static __device__ inline void vstore2_lm(float *ptr, float32x16_t &vl, float32x16_t &vh) {
vstore_lm_float32x16(ptr, vl);
vstore_lm_float32x16(ptr + 16, vh);
}
template<>
__device__ void primitive_cast(const float* x, float* y, int len) {
template <>
__device__ void primitive_cast(const float *x, float *y, int len) {
if (x == y) {
return;
} else { // just copy
......
@@ -33,7 +33,7 @@ _TEST_CASES_ = [
# w (weight) types
# Note: 'None' means the same as input dtype
_X_DTYPES = [InfiniDtype.F32, InfiniDtype.F16] # [InfiniDtype.F32, InfiniDtype.BF16, InfiniDtype.F16]
_X_DTYPES = [InfiniDtype.F32, InfiniDtype.BF16, InfiniDtype.F16]
# x types used for testing
_VALUE_DTYPES = [InfiniDtype.F32]
......