Commit 0bf217c1 authored by PanZezhong
Browse files

fix: 修复寒武纪编译cpu时的警告

parent 7a833987
#include "./common_cpu.h" #include "./common_cpu.h"
// Converts an IEEE 754 half-precision (binary16) bit pattern to a float.
//
// h: raw 16-bit pattern (1 sign bit, 5 exponent bits, 10 mantissa bits).
// Returns the float with the same value. Signed zeros, infinities, NaNs
// (payload shifted into the top mantissa bits) and subnormals are handled.
float f16_to_f32(uint16_t h) {
    // Widen before shifting so the sign bit is moved within an unsigned type.
    const uint32_t sign = (static_cast<uint32_t>(h) & 0x8000u) << 16;
    int32_t exponent = (h >> 10) & 0x1F;
    uint32_t mantissa = h & 0x3FFu;

    uint32_t f32;
    if (exponent == 31) {
        // Inf when mantissa == 0, NaN otherwise; one expression covers both
        // because a zero mantissa contributes no bits.
        f32 = sign | 0x7F800000u | (mantissa << 13);
    } else if (exponent == 0) {
        if (mantissa == 0) {
            // Positive or negative zero.
            f32 = sign;
        } else {
            // Subnormal half: shift the mantissa up until the implicit
            // leading 1 (bit 10) appears, lowering the exponent each step.
            exponent = -14;
            while ((mantissa & 0x400u) == 0) {
                mantissa <<= 1;
                exponent--;
            }
            mantissa &= 0x3FFu; // drop the now-implicit leading 1
            f32 = sign | (static_cast<uint32_t>(exponent + 127) << 23) | (mantissa << 13);
        }
    } else {
        // Normalized half: re-bias the exponent from 15 to 127.
        f32 = sign | (static_cast<uint32_t>(exponent + 127 - 15) << 23) | (mantissa << 13);
    }

    // memcpy instead of a pointer cast: reinterpreting uint32_t storage
    // through a float* violates strict aliasing.
    float result;
    memcpy(&result, &f32, sizeof(result));
    return result;
}
uint16_t f32_to_f16(float val) { uint16_t f32_to_f16(float val) {
uint32_t f32 = *(uint32_t *) &val; // Read the bits of the float32 uint32_t f32;
memcpy(&f32, &val, sizeof(f32)); // Read the bits of the float32
uint16_t sign = (f32 >> 16) & 0x8000; // Extract the sign bit uint16_t sign = (f32 >> 16) & 0x8000; // Extract the sign bit
int32_t exponent = ((f32 >> 23) & 0xFF) - 127;// Extract and de-bias the exponent int32_t exponent =
((f32 >> 23) & 0xFF) - 127; // Extract and de-bias the exponent
uint32_t mantissa = f32 & 0x7FFFFF; // Extract the mantissa (fraction part) uint32_t mantissa = f32 & 0x7FFFFF; // Extract the mantissa (fraction part)
if (exponent >= 31) {// Special cases for Inf and NaN if (exponent >= 31) { // Special cases for Inf and NaN
// NaN // NaN
if (exponent == 128 && mantissa != 0) { if (exponent == 128 && mantissa != 0) {
return sign | 0x7E00; return sign | 0x7E00;
} }
// Infinity // Infinity
return sign | 0x7C00; return sign | 0x7C00;
} else if (exponent >= -14) {// Normalized case } else if (exponent >= -14) { // Normalized case
return (uint16_t)(sign | ((exponent + 15) << 10) | (mantissa >> 13)); return (uint16_t)(sign | ((exponent + 15) << 10) | (mantissa >> 13));
} else if (exponent >= -24) { } else if (exponent >= -24) {
mantissa |= 0x800000;// Add implicit leading 1 mantissa |= 0x800000; // Add implicit leading 1
mantissa >>= (-14 - exponent); mantissa >>= (-14 - exponent);
return (uint16_t)(sign | (mantissa >> 13)); return (uint16_t)(sign | (mantissa >> 13));
} else { } else {
...@@ -63,7 +60,9 @@ uint16_t f32_to_f16(float val) { ...@@ -63,7 +60,9 @@ uint16_t f32_to_f16(float val) {
} }
} }
size_t indexToReducedOffset(size_t flat_index, size_t ndim, int64_t const *broadcasted_strides, int64_t const *target_strides) { size_t indexToReducedOffset(size_t flat_index, size_t ndim,
int64_t const *broadcasted_strides,
int64_t const *target_strides) {
size_t res = 0; size_t res = 0;
for (size_t i = 0; i < ndim; ++i) { for (size_t i = 0; i < ndim; ++i) {
res += flat_index / broadcasted_strides[i] * target_strides[i]; res += flat_index / broadcasted_strides[i] * target_strides[i];
...@@ -72,7 +71,8 @@ size_t indexToReducedOffset(size_t flat_index, size_t ndim, int64_t const *broad ...@@ -72,7 +71,8 @@ size_t indexToReducedOffset(size_t flat_index, size_t ndim, int64_t const *broad
return res; return res;
} }
size_t indexToOffset(size_t flat_index, size_t ndim, size_t const *shape, int64_t const *strides) { size_t indexToOffset(size_t flat_index, size_t ndim, size_t const *shape,
int64_t const *strides) {
size_t res = 0; size_t res = 0;
for (size_t i = ndim; i-- >= 0;) { for (size_t i = ndim; i-- >= 0;) {
res += (flat_index % shape[i]) * strides[i]; res += (flat_index % shape[i]) * strides[i];
...@@ -89,11 +89,12 @@ size_t getPaddedSize(size_t ndim, size_t *shape, size_t const *pads) { ...@@ -89,11 +89,12 @@ size_t getPaddedSize(size_t ndim, size_t *shape, size_t const *pads) {
return total_size; return total_size;
} }
// Returns a copy of `shape` in which every dimension from index 2 onward is
// enlarged by twice its padding (pads[i - 2] applied on both sides).
//
// ndim:  number of entries in `shape`.
// shape: array of `ndim` dimension sizes; the first two are copied unchanged.
// pads:  per-dimension padding for dimensions 2..ndim-1; read at indices
//        0..ndim-3, and not read at all when ndim <= 2.
std::vector<size_t> getPaddedShape(size_t ndim, size_t const *shape,
                                   size_t const *pads) {
    // Range construction copies the shape and stays well-defined for
    // ndim == 0, where memcpy from a null pointer would not be.
    std::vector<size_t> padded_shape(shape, shape + ndim);
    for (size_t i = 2; i < ndim; ++i) {
        padded_shape[i] += 2 * pads[i - 2];
    }
    // Return the local by name: wrapping it in std::move would block NRVO.
    return padded_shape;
}
...@@ -46,15 +46,15 @@ infiniopStatus_t matmul_cpu(infiniopMatmulCpuDescriptor_t desc, void *c, ...@@ -46,15 +46,15 @@ infiniopStatus_t matmul_cpu(infiniopMatmulCpuDescriptor_t desc, void *c,
std::swap(a, b); std::swap(a, b);
} }
for (int i = 0; i < info.batch; ++i) { for (size_t i = 0; i < info.batch; ++i) {
for (int m_ = 0; m_ < info.m; ++m_) { for (size_t m_ = 0; m_ < info.m; ++m_) {
for (int n_ = 0; n_ < info.n; ++n_) { for (size_t n_ = 0; n_ < info.n; ++n_) {
auto c_ = reinterpret_cast<Tdata *>(c) + auto c_ = reinterpret_cast<Tdata *>(c) +
i * info.c_matrix.stride + i * info.c_matrix.stride +
m_ * info.c_matrix.row_stride + m_ * info.c_matrix.row_stride +
n_ * info.c_matrix.col_stride; n_ * info.c_matrix.col_stride;
float sum = 0; float sum = 0;
for (int k_ = 0; k_ < info.k; ++k_) { for (size_t k_ = 0; k_ < info.k; ++k_) {
auto a_ = reinterpret_cast<Tdata const *>(a) + auto a_ = reinterpret_cast<Tdata const *>(a) +
i * info.a_matrix.stride + i * info.a_matrix.stride +
m_ * info.a_matrix.row_stride + m_ * info.a_matrix.row_stride +
......
...@@ -46,7 +46,7 @@ __C infiniopStatus_t infiniopCreateMatmulDescriptor( ...@@ -46,7 +46,7 @@ __C infiniopStatus_t infiniopCreateMatmulDescriptor(
} }
#endif #endif
} }
return INFINIOP_STATUS_BAD_DEVICE; return INFINIOP_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniopStatus_t __C infiniopStatus_t
...@@ -76,7 +76,7 @@ infiniopGetMatmulWorkspaceSize(infiniopMatmulDescriptor_t desc, size_t *size) { ...@@ -76,7 +76,7 @@ infiniopGetMatmulWorkspaceSize(infiniopMatmulDescriptor_t desc, size_t *size) {
} }
#endif #endif
} }
return INFINIOP_STATUS_BAD_DEVICE; return INFINIOP_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniopStatus_t infiniopMatmul(infiniopMatmulDescriptor_t desc, __C infiniopStatus_t infiniopMatmul(infiniopMatmulDescriptor_t desc,
...@@ -106,7 +106,7 @@ __C infiniopStatus_t infiniopMatmul(infiniopMatmulDescriptor_t desc, ...@@ -106,7 +106,7 @@ __C infiniopStatus_t infiniopMatmul(infiniopMatmulDescriptor_t desc,
workspace_size, c, a, b, alpha, beta, stream); workspace_size, c, a, b, alpha, beta, stream);
#endif #endif
} }
return INFINIOP_STATUS_BAD_DEVICE; return INFINIOP_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniopStatus_t __C infiniopStatus_t
...@@ -134,5 +134,5 @@ infiniopDestroyMatmulDescriptor(infiniopMatmulDescriptor_t desc) { ...@@ -134,5 +134,5 @@ infiniopDestroyMatmulDescriptor(infiniopMatmulDescriptor_t desc) {
} }
#endif #endif
} }
return INFINIOP_STATUS_BAD_DEVICE; return INFINIOP_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment