"...gmock/git@developer.sourcefind.cn:yangql/googletest.git" did not exist on "e0d051ea64dd5f32d5b6af9831747d1acb2a9c40"
Commit c1af9783 authored by zhangyue's avatar zhangyue
Browse files

issue/676: format

parent 5584035d
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
#include "../../../devices/kunlun/kunlun_kernel_common.h" #include "../../../devices/kunlun/kunlun_kernel_common.h"
#include "../../../sort/kunlun/heap.h" #include "../../../sort/kunlun/heap.h"
#include <xpu/kernel/xtdk_io.h>
#include <float.h> #include <float.h>
using namespace device::kunlun::kernel; using namespace device::kunlun::kernel;
...@@ -34,11 +33,11 @@ inline __device__ void descending_sort(T *x, TID *idx, int32_t n) { ...@@ -34,11 +33,11 @@ inline __device__ void descending_sort(T *x, TID *idx, int32_t n) {
mfence_lm(); mfence_lm();
} }
template <typename T, int32_t BLOCK_THREADS=64, int32_t MAX_EXPERTS=256, template <typename T, int32_t BLOCK_THREADS = 64, int32_t MAX_EXPERTS = 256,
int32_t N_GROUPS=8, int32_t TOPK_GROUP=4, int32_t TOPK_PER_GROUP=2> int32_t N_GROUPS = 8, int32_t TOPK_GROUP = 4, int32_t TOPK_PER_GROUP = 2>
__global__ void topkrouter_kernel( __global__ void topkrouter_kernel(
float *values_topk, // 输出数据, 形状[N, topk] float *values_topk, // 输出数据, 形状[N, topk]
int32_t *indices_topk, // 输出索引, 形状[N, topk] int32_t *indices_topk, // 输出索引, 形状[N, topk]
const T *input, // 输入数据 [N, n_experts] const T *input, // 输入数据 [N, n_experts]
const float *d_correction_bias, // 输入数据 [n_experts] const float *d_correction_bias, // 输入数据 [n_experts]
const float routed_scaling_factor, const float routed_scaling_factor,
...@@ -56,14 +55,14 @@ __global__ void topkrouter_kernel( ...@@ -56,14 +55,14 @@ __global__ void topkrouter_kernel(
__shared__ T input_shm[MAX_EXPERTS]; // input shm for i-th token, total N __shared__ T input_shm[MAX_EXPERTS]; // input shm for i-th token, total N
__shared__ float correction_bias_sm[MAX_EXPERTS]; __shared__ float correction_bias_sm[MAX_EXPERTS];
// Copy data into SM // Copy data into SM
if (thread_idx == 0) { if (thread_idx == 0) {
GM2SM_ASYNC(input + block_idx * n_experts, input_shm, n_experts * sizeof(T)); GM2SM_ASYNC(input + block_idx * n_experts, input_shm, n_experts * sizeof(T));
GM2SM_ASYNC(d_correction_bias, correction_bias_sm, n_experts * sizeof(float)); GM2SM_ASYNC(d_correction_bias, correction_bias_sm, n_experts * sizeof(float));
} }
sync_cluster(); sync_cluster();
// Calculate sigmoid scores and add bias // Calculate sigmoid scores and add bias
__shared__ float scores[MAX_EXPERTS]; __shared__ float scores[MAX_EXPERTS];
__shared__ float scores_with_bias_shm[MAX_EXPERTS]; __shared__ float scores_with_bias_shm[MAX_EXPERTS];
......
...@@ -45,8 +45,8 @@ infiniStatus_t launch_topkrouter(float *d_values_out, int *d_indices_out, const ...@@ -45,8 +45,8 @@ infiniStatus_t launch_topkrouter(float *d_values_out, int *d_indices_out, const
if (xtype == INFINI_DTYPE_F32) { if (xtype == INFINI_DTYPE_F32) {
topkrouter_kernel<float, BLOCK_SIZE, 256, 8, 4, 2> topkrouter_kernel<float, BLOCK_SIZE, 256, 8, 4, 2>
<<<N, BLOCK_SIZE, stream>>>( <<<N, BLOCK_SIZE, stream>>>(
d_values_out, d_values_out,
d_indices_out, d_indices_out,
(float *)d_input, (float *)d_input,
(const float *)d_correction_bias, (const float *)d_correction_bias,
routed_scaling_factor, routed_scaling_factor,
...@@ -57,13 +57,24 @@ infiniStatus_t launch_topkrouter(float *d_values_out, int *d_indices_out, const ...@@ -57,13 +57,24 @@ infiniStatus_t launch_topkrouter(float *d_values_out, int *d_indices_out, const
topkrouter_kernel<half, BLOCK_SIZE, 256, 8, 4, 2> topkrouter_kernel<half, BLOCK_SIZE, 256, 8, 4, 2>
<<<N, BLOCK_SIZE, stream>>>( <<<N, BLOCK_SIZE, stream>>>(
d_values_out, d_values_out,
d_indices_out, d_indices_out,
(half *)d_input, (half *)d_input,
(const float *)d_correction_bias, (const float *)d_correction_bias,
routed_scaling_factor, routed_scaling_factor,
N, N,
width, width,
topk); topk);
} else if (xtype == INFINI_DTYPE_BF16) {
topkrouter_kernel<bfloat16_t, BLOCK_SIZE, 256, 8, 4, 2>
<<<N, BLOCK_SIZE, stream>>>(
d_values_out,
d_indices_out,
(bfloat16_t *)d_input,
(const float *)d_correction_bias,
routed_scaling_factor,
N,
width,
topk);
} else { } else {
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
#include "xpu/kernel/xtdk_simd_xpu2.h" #include "xpu/kernel/xtdk_simd_xpu2.h"
template <typename TK, typename TV> template <typename TK, typename TV>
static __device__ inline void sm_swap_kv(_shared_ptr_ TK* k0, _shared_ptr_ TV* v0, static __device__ inline void sm_swap_kv(_shared_ptr_ TK *k0, _shared_ptr_ TV *v0,
_shared_ptr_ TK* k1, _shared_ptr_ TV* v1) { _shared_ptr_ TK *k1, _shared_ptr_ TV *v1) {
TK tmpk = *k0; TK tmpk = *k0;
TV tmpv = *v0; TV tmpv = *v0;
*k0 = *k1; *k0 = *k1;
...@@ -13,9 +13,9 @@ static __device__ inline void sm_swap_kv(_shared_ptr_ TK* k0, _shared_ptr_ TV* v ...@@ -13,9 +13,9 @@ static __device__ inline void sm_swap_kv(_shared_ptr_ TK* k0, _shared_ptr_ TV* v
*v1 = tmpv; *v1 = tmpv;
} }
template<typename TK, typename TV> template <typename TK, typename TV>
static __device__ inline void update_sm_min_heap(_shared_ptr_ TK* heap_key, static __device__ inline void update_sm_min_heap(_shared_ptr_ TK *heap_key,
_shared_ptr_ TV* heap_value, int idx, int heap_capacity) { _shared_ptr_ TV *heap_value, int idx, int heap_capacity) {
while (idx < heap_capacity) { while (idx < heap_capacity) {
int child_l = idx * 2 + 1; int child_l = idx * 2 + 1;
int child_r = idx * 2 + 2; int child_r = idx * 2 + 2;
...@@ -23,10 +23,10 @@ static __device__ inline void update_sm_min_heap(_shared_ptr_ TK* heap_key, ...@@ -23,10 +23,10 @@ static __device__ inline void update_sm_min_heap(_shared_ptr_ TK* heap_key,
if (child_r >= heap_capacity) { if (child_r >= heap_capacity) {
if (child_l >= heap_capacity) { // idx is leaf node, shift finished if (child_l >= heap_capacity) { // idx is leaf node, shift finished
break; break;
} else {// if child_r does not exist while child_l does, choose child_l } else { // if child_r does not exist while child_l does, choose child_l
child_min = child_l; child_min = child_l;
} }
} else {// both child L & R exists } else { // both child L & R exists
child_min = child_l + (heap_key[child_l] > heap_key[child_r]); child_min = child_l + (heap_key[child_l] > heap_key[child_r]);
} }
if (heap_key[idx] <= heap_key[child_min]) { if (heap_key[idx] <= heap_key[child_min]) {
...@@ -37,26 +37,26 @@ static __device__ inline void update_sm_min_heap(_shared_ptr_ TK* heap_key, ...@@ -37,26 +37,26 @@ static __device__ inline void update_sm_min_heap(_shared_ptr_ TK* heap_key,
} }
} }
template<typename TK, typename TV> template <typename TK, typename TV>
static __device__ inline void make_sm_min_heap( static __device__ inline void make_sm_min_heap(
_shared_ptr_ TK* heap_key, _shared_ptr_ TV* heap_value, int size) { _shared_ptr_ TK *heap_key, _shared_ptr_ TV *heap_value, int size) {
for (int i = size / 2 - 1; i >= 0; i--) { for (int i = size / 2 - 1; i >= 0; i--) {
update_sm_min_heap(heap_key, heap_value, i, size); update_sm_min_heap(heap_key, heap_value, i, size);
} }
} }
template<typename TK, typename TV> template <typename TK, typename TV>
static __device__ inline void sort_sm_min_heap( static __device__ inline void sort_sm_min_heap(
_shared_ptr_ TK* heap_key, _shared_ptr_ TV* heap_value, int heap_capacity) { _shared_ptr_ TK *heap_key, _shared_ptr_ TV *heap_value, int heap_capacity) {
for (int i = heap_capacity - 1; i > 0; i--) { for (int i = heap_capacity - 1; i > 0; i--) {
sm_swap_kv(&heap_key[0], &heap_value[0], &heap_key[i], &heap_value[i]); sm_swap_kv(&heap_key[0], &heap_value[0], &heap_key[i], &heap_value[i]);
update_sm_min_heap(heap_key, heap_value, 0, i); update_sm_min_heap(heap_key, heap_value, 0, i);
} }
} }
template<typename TK, typename TV> template <typename TK, typename TV>
static __device__ inline void update_sm_max_heap(_shared_ptr_ TK* heap_key, static __device__ inline void update_sm_max_heap(_shared_ptr_ TK *heap_key,
_shared_ptr_ TV* heap_value, int idx, int heap_capacity) { _shared_ptr_ TV *heap_value, int idx, int heap_capacity) {
while (idx < heap_capacity) { while (idx < heap_capacity) {
int child_l = idx * 2 + 1; int child_l = idx * 2 + 1;
int child_r = idx * 2 + 2; int child_r = idx * 2 + 2;
...@@ -64,10 +64,10 @@ static __device__ inline void update_sm_max_heap(_shared_ptr_ TK* heap_key, ...@@ -64,10 +64,10 @@ static __device__ inline void update_sm_max_heap(_shared_ptr_ TK* heap_key,
if (child_r >= heap_capacity) { if (child_r >= heap_capacity) {
if (child_l >= heap_capacity) { // idx is leaf node, shift finished if (child_l >= heap_capacity) { // idx is leaf node, shift finished
break; break;
} else {// if child_r does not exist while child_l does, choose child_l } else { // if child_r does not exist while child_l does, choose child_l
child_max = child_l; child_max = child_l;
} }
} else {// both child L & R exists } else { // both child L & R exists
child_max = child_l + (heap_key[child_l] < heap_key[child_r]); child_max = child_l + (heap_key[child_l] < heap_key[child_r]);
} }
if (heap_key[idx] >= heap_key[child_max]) { if (heap_key[idx] >= heap_key[child_max]) {
...@@ -78,17 +78,17 @@ static __device__ inline void update_sm_max_heap(_shared_ptr_ TK* heap_key, ...@@ -78,17 +78,17 @@ static __device__ inline void update_sm_max_heap(_shared_ptr_ TK* heap_key,
} }
} }
template<typename TK, typename TV> template <typename TK, typename TV>
static __device__ inline void make_sm_max_heap( static __device__ inline void make_sm_max_heap(
_shared_ptr_ TK* heap_key, _shared_ptr_ TV* heap_value, int size) { _shared_ptr_ TK *heap_key, _shared_ptr_ TV *heap_value, int size) {
for (int i = size / 2 - 1; i >= 0; i--) { for (int i = size / 2 - 1; i >= 0; i--) {
update_sm_max_heap(heap_key, heap_value, i, size); update_sm_max_heap(heap_key, heap_value, i, size);
} }
} }
template<typename TK, typename TV> template <typename TK, typename TV>
static __device__ inline void sort_sm_max_heap(_shared_ptr_ TK* heap_key, static __device__ inline void sort_sm_max_heap(_shared_ptr_ TK *heap_key,
_shared_ptr_ TV* heap_value, int heap_capacity) { _shared_ptr_ TV *heap_value, int heap_capacity) {
for (int i = heap_capacity - 1; i > 0; i--) { for (int i = heap_capacity - 1; i > 0; i--) {
sm_swap_kv(&heap_key[0], &heap_value[0], &heap_key[i], &heap_value[i]); sm_swap_kv(&heap_key[0], &heap_value[0], &heap_key[i], &heap_value[i]);
update_sm_max_heap(heap_key, heap_value, 0, i); update_sm_max_heap(heap_key, heap_value, 0, i);
...@@ -96,8 +96,8 @@ static __device__ inline void sort_sm_max_heap(_shared_ptr_ TK* heap_key, ...@@ -96,8 +96,8 @@ static __device__ inline void sort_sm_max_heap(_shared_ptr_ TK* heap_key,
} }
template <typename TK, typename TV> template <typename TK, typename TV>
static __device__ inline void lm_swap_kv(TK* k0, TV* v0, static __device__ inline void lm_swap_kv(TK *k0, TV *v0,
TK* k1, TV* v1) { TK *k1, TV *v1) {
TK tmpk = *k0; TK tmpk = *k0;
TV tmpv = *v0; TV tmpv = *v0;
*k0 = *k1; *k0 = *k1;
...@@ -107,7 +107,7 @@ static __device__ inline void lm_swap_kv(TK* k0, TV* v0, ...@@ -107,7 +107,7 @@ static __device__ inline void lm_swap_kv(TK* k0, TV* v0,
} }
template <typename TK, typename TV> template <typename TK, typename TV>
static __device__ inline void update_lm_min_heap(TK* heap_key, TV* heap_value, int idx, int heap_capacity) { static __device__ inline void update_lm_min_heap(TK *heap_key, TV *heap_value, int idx, int heap_capacity) {
while (idx < heap_capacity) { while (idx < heap_capacity) {
int child_l = idx * 2 + 1; int child_l = idx * 2 + 1;
int child_r = idx * 2 + 2; int child_r = idx * 2 + 2;
...@@ -115,10 +115,10 @@ static __device__ inline void update_lm_min_heap(TK* heap_key, TV* heap_value, i ...@@ -115,10 +115,10 @@ static __device__ inline void update_lm_min_heap(TK* heap_key, TV* heap_value, i
if (child_r >= heap_capacity) { if (child_r >= heap_capacity) {
if (child_l >= heap_capacity) { // idx is leaf node, shift finished if (child_l >= heap_capacity) { // idx is leaf node, shift finished
break; break;
} else {// if child_r does not exist while child_l does, choose child_l } else { // if child_r does not exist while child_l does, choose child_l
child_min = child_l; child_min = child_l;
} }
} else {// both child L & R exists } else { // both child L & R exists
child_min = child_l + (heap_key[child_l] > heap_key[child_r]); child_min = child_l + (heap_key[child_l] > heap_key[child_r]);
} }
if (heap_key[idx] <= heap_key[child_min]) { if (heap_key[idx] <= heap_key[child_min]) {
...@@ -129,24 +129,24 @@ static __device__ inline void update_lm_min_heap(TK* heap_key, TV* heap_value, i ...@@ -129,24 +129,24 @@ static __device__ inline void update_lm_min_heap(TK* heap_key, TV* heap_value, i
} }
} }
template<typename TK, typename TV> template <typename TK, typename TV>
static __device__ inline void make_lm_min_heap( static __device__ inline void make_lm_min_heap(
TK* heap_key, TV* heap_value, int size) { TK *heap_key, TV *heap_value, int size) {
for (int i = size / 2 - 1; i >= 0; i--) { for (int i = size / 2 - 1; i >= 0; i--) {
update_lm_min_heap(heap_key, heap_value, i, size); update_lm_min_heap(heap_key, heap_value, i, size);
} }
} }
template<typename TK, typename TV> template <typename TK, typename TV>
static __device__ inline void sort_lm_min_heap(TK* heap_key, TV* heap_value, int heap_capacity) { static __device__ inline void sort_lm_min_heap(TK *heap_key, TV *heap_value, int heap_capacity) {
for (int i = heap_capacity - 1; i > 0; i--) { for (int i = heap_capacity - 1; i > 0; i--) {
lm_swap_kv(&heap_key[0], &heap_value[0], &heap_key[i], &heap_value[i]); lm_swap_kv(&heap_key[0], &heap_value[0], &heap_key[i], &heap_value[i]);
update_lm_min_heap(heap_key, heap_value, 0, i); update_lm_min_heap(heap_key, heap_value, 0, i);
} }
} }
template<typename TK, typename TV> template <typename TK, typename TV>
static __device__ inline void update_lm_max_heap(TK* heap_key, TV* heap_value, int idx, int heap_capacity) { static __device__ inline void update_lm_max_heap(TK *heap_key, TV *heap_value, int idx, int heap_capacity) {
while (idx < heap_capacity) { while (idx < heap_capacity) {
int child_l = idx * 2 + 1; int child_l = idx * 2 + 1;
int child_r = idx * 2 + 2; int child_r = idx * 2 + 2;
...@@ -154,10 +154,10 @@ static __device__ inline void update_lm_max_heap(TK* heap_key, TV* heap_value, i ...@@ -154,10 +154,10 @@ static __device__ inline void update_lm_max_heap(TK* heap_key, TV* heap_value, i
if (child_r >= heap_capacity) { if (child_r >= heap_capacity) {
if (child_l >= heap_capacity) { // idx is leaf node, shift finished if (child_l >= heap_capacity) { // idx is leaf node, shift finished
break; break;
} else {// if child_r does not exist while child_l does, choose child_l } else { // if child_r does not exist while child_l does, choose child_l
child_max = child_l; child_max = child_l;
} }
} else {// both child L & R exists } else { // both child L & R exists
child_max = child_l + (heap_key[child_l] < heap_key[child_r]); child_max = child_l + (heap_key[child_l] < heap_key[child_r]);
} }
if (heap_key[idx] >= heap_key[child_max]) { if (heap_key[idx] >= heap_key[child_max]) {
...@@ -168,34 +168,34 @@ static __device__ inline void update_lm_max_heap(TK* heap_key, TV* heap_value, i ...@@ -168,34 +168,34 @@ static __device__ inline void update_lm_max_heap(TK* heap_key, TV* heap_value, i
} }
} }
template<typename TK, typename TV> template <typename TK, typename TV>
static __device__ inline void make_lm_max_heap( static __device__ inline void make_lm_max_heap(
TK* heap_key, TV* heap_value, int size) { TK *heap_key, TV *heap_value, int size) {
for (int i = size / 2 - 1; i >= 0; i--) { for (int i = size / 2 - 1; i >= 0; i--) {
update_lm_max_heap(heap_key, heap_value, i, size); update_lm_max_heap(heap_key, heap_value, i, size);
} }
} }
template<typename TK, typename TV> template <typename TK, typename TV>
static __device__ inline void sort_lm_max_heap(TK* heap_key, TV* heap_value, int heap_capacity) { static __device__ inline void sort_lm_max_heap(TK *heap_key, TV *heap_value, int heap_capacity) {
for (int i = heap_capacity - 1; i > 0; i--) { for (int i = heap_capacity - 1; i > 0; i--) {
lm_swap_kv(&heap_key[0], &heap_value[0], &heap_key[i], &heap_value[i]); lm_swap_kv(&heap_key[0], &heap_value[0], &heap_key[i], &heap_value[i]);
update_lm_max_heap(heap_key, heap_value, 0, i); update_lm_max_heap(heap_key, heap_value, 0, i);
} }
} }
template<typename TID> template <typename TID>
__device__ TID roundup_div_p(TID a, TID b) { __device__ TID roundup_div_p(TID a, TID b) {
return (a + b - 1) / b; return (a + b - 1) / b;
} }
template<typename T> template <typename T>
__device__ T min_p(T a, T b){ __device__ T min_p(T a, T b) {
return a < b ? a : b; return a < b ? a : b;
} }
template<typename TID> template <typename TID>
static __device__ inline void partition(int tid, int nthreads, TID len, int align, TID* start, TID* end) { static __device__ inline void partition(int tid, int nthreads, TID len, int align, TID *start, TID *end) {
TID block_cnt = roundup_div_p<TID>(len, align); TID block_cnt = roundup_div_p<TID>(len, align);
TID remain_block = block_cnt % nthreads; TID remain_block = block_cnt % nthreads;
TID start_block = block_cnt / nthreads * static_cast<TID>(tid) + min_p<TID>(tid, remain_block); TID start_block = block_cnt / nthreads * static_cast<TID>(tid) + min_p<TID>(tid, remain_block);
...@@ -204,51 +204,51 @@ static __device__ inline void partition(int tid, int nthreads, TID len, int alig ...@@ -204,51 +204,51 @@ static __device__ inline void partition(int tid, int nthreads, TID len, int alig
*end = min_p<TID>(end_block * align, len); *end = min_p<TID>(end_block * align, len);
} }
template<typename TX, typename TY> template <typename TX, typename TY>
static __device__ void primitive_cast(const TX* x, TY* y, int len) { static __device__ void primitive_cast(const TX *x, TY *y, int len) {
return; return;
} }
template<> template <>
__device__ void primitive_cast(const float* x, int* y, int len) { __device__ void primitive_cast(const float *x, int *y, int len) {
for (int i = 0; i < len; i += 16) { for (int i = 0; i < len; i += 16) {
float32x16_t Y = vload_lm_float32x16(x); float32x16_t Y = vload_lm_float32x16(x);
__asm__ __volatile__("vfloat2fix.rz vr0, %0\t\n" __asm__ __volatile__("vfloat2fix.rz vr0, %0\t\n"
"vstore_mask16.mz vr0{mr1}, 0(%1)" "vstore_mask16.mz vr0{mr1}, 0(%1)" ::"v"(Y),
::"v"(Y), "r"(y):"vr0"); "r"(y) : "vr0");
x += 16; x += 16;
y += 16; y += 16;
} }
mfence_lm(); mfence_lm();
} }
template<> template <>
__device__ void primitive_cast(const int* x, float* y, int len) { __device__ void primitive_cast(const int *x, float *y, int len) {
for (int i = 0; i < len; i += 16) { for (int i = 0; i < len; i += 16) {
int32x16_t Y = vload_lm_int32x16(x); int32x16_t Y = vload_lm_int32x16(x);
__asm__ __volatile__("vfix2float.rn vr0, %0\t\n" __asm__ __volatile__("vfix2float.rn vr0, %0\t\n"
"vstore_mask16.mz vr0{mr1}, 0(%1)" "vstore_mask16.mz vr0{mr1}, 0(%1)" ::"v"(Y),
::"v"(Y), "r"(y):"vr0"); "r"(y) : "vr0");
x += 16; x += 16;
y += 16; y += 16;
} }
mfence_lm(); mfence_lm();
} }
static __device__ inline void vload2_lm(const float* ptr, float32x16_t& vl, float32x16_t& vh) { static __device__ inline void vload2_lm(const float *ptr, float32x16_t &vl, float32x16_t &vh) {
vl = __builtin_xpu2_vload_mask16_mr1(ptr, 0); vl = __builtin_xpu2_vload_mask16_mr1(ptr, 0);
vh = __builtin_xpu2_vload_mask16_mr1(ptr + 16, 0); vh = __builtin_xpu2_vload_mask16_mr1(ptr + 16, 0);
} }
static __device__ inline void vstore2_lm(float* ptr, float32x16_t& vl, float32x16_t& vh) { static __device__ inline void vstore2_lm(float *ptr, float32x16_t &vl, float32x16_t &vh) {
vstore_lm_float32x16(ptr, vl); vstore_lm_float32x16(ptr, vl);
vstore_lm_float32x16(ptr + 16, vh); vstore_lm_float32x16(ptr + 16, vh);
} }
template<> template <>
__device__ void primitive_cast(const float* x, float* y, int len) { __device__ void primitive_cast(const float *x, float *y, int len) {
if (x == y) { if (x == y) {
return; return;
} else { // just copy } else { // just copy
float32x16_t vec_x_0; float32x16_t vec_x_0;
float32x16_t vec_x_1; float32x16_t vec_x_1;
for (int i = 0; i < len; i += 32) { for (int i = 0; i < len; i += 32) {
......
...@@ -33,7 +33,7 @@ _TEST_CASES_ = [ ...@@ -33,7 +33,7 @@ _TEST_CASES_ = [
# w (weight) types # w (weight) types
# Note: 'None' means the same as input dtype # Note: 'None' means the same as input dtype
_X_DTYPES = [InfiniDtype.F32, InfiniDtype.F16] # [InfiniDtype.F32, InfiniDtype.BF16, InfiniDtype.F16] _X_DTYPES = [InfiniDtype.F32, InfiniDtype.BF16, InfiniDtype.F16]
# x types used for testing # x types used for testing
_VALUE_DTYPES = [InfiniDtype.F32] _VALUE_DTYPES = [InfiniDtype.F32]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment