Commit 86c701a8 authored by PanZezhong's avatar PanZezhong
Browse files

issue/68/fix 使用utils中的cast函数

parent eab61cb3
...@@ -7,43 +7,10 @@ ...@@ -7,43 +7,10 @@
#define CHECK_OR(cmd, action) CHECK_API_OR(cmd, INFINI_STATUS_SUCCESS, action) #define CHECK_OR(cmd, action) CHECK_API_OR(cmd, INFINI_STATUS_SUCCESS, action)
inline float f16_to_f32(uint16_t h) {
uint32_t sign = (h & 0x8000) << 16;
int32_t exponent = (h >> 10) & 0x1F;
uint32_t mantissa = h & 0x3FF;
uint32_t f32;
if (exponent == 31) {
if (mantissa != 0) {
f32 = sign | 0x7F800000 | (mantissa << 13);
} else {
f32 = sign | 0x7F800000;
}
} else if (exponent == 0) {
if (mantissa == 0) {
f32 = sign;
} else {
exponent = -14;
while ((mantissa & 0x400) == 0) {
mantissa <<= 1;
exponent--;
}
mantissa &= 0x3FF;
f32 = sign | ((exponent + 127) << 23) | (mantissa << 13);
}
} else {
f32 = sign | ((exponent + 127 - 15) << 23) | (mantissa << 13);
}
float result;
memcpy(&result, &f32, sizeof(result));
return result;
}
inline double getVal(void *ptr, GGML_TYPE ggml_type) { inline double getVal(void *ptr, GGML_TYPE ggml_type) {
switch (ggml_type) { switch (ggml_type) {
case GGML_TYPE_F16: case GGML_TYPE_F16:
return f16_to_f32(*(uint16_t *)ptr); return utils::cast<double>(*(fp16_t *)ptr);
case GGML_TYPE_F32: case GGML_TYPE_F32:
return *(float *)ptr; return *(float *)ptr;
case GGML_TYPE_F64: case GGML_TYPE_F64:
......
...@@ -20,11 +20,11 @@ void printData(const T *data, const std::vector<size_t> &shape, const std::vecto ...@@ -20,11 +20,11 @@ void printData(const T *data, const std::vector<size_t> &shape, const std::vecto
} }
template <> template <>
void printData(const uint16_t *data, const std::vector<size_t> &shape, void printData(const fp16_t *data, const std::vector<size_t> &shape,
const std::vector<ptrdiff_t> &strides, size_t dim) { const std::vector<ptrdiff_t> &strides, size_t dim) {
if (dim == shape.size() - 1) { if (dim == shape.size() - 1) {
for (size_t i = 0; i < shape[dim]; i++) { for (size_t i = 0; i < shape[dim]; i++) {
std::cout << f16_to_f32(*(data + i * strides[dim])) << " "; std::cout << utils::cast<float>(*(data + i * strides[dim])) << " ";
} }
} else if (dim < shape.size() - 1) { } else if (dim < shape.size() - 1) {
for (size_t i = 0; i < shape[dim]; i++) { for (size_t i = 0; i < shape[dim]; i++) {
...@@ -177,7 +177,7 @@ void Tensor::debug() const { ...@@ -177,7 +177,7 @@ void Tensor::debug() const {
std::cout << "Tensor: " << tensor->info() << std::endl; std::cout << "Tensor: " << tensor->info() << std::endl;
switch (_ggml_type) { switch (_ggml_type) {
case GGML_TYPE_F16: case GGML_TYPE_F16:
printData((uint16_t *)(tensor->data()), _shape, _strides, 0); printData((fp16_t *)(tensor->data()), _shape, _strides, 0);
break; break;
case GGML_TYPE_F32: case GGML_TYPE_F32:
printData((float *)(tensor->data()), _shape, _strides, 0); printData((float *)(tensor->data()), _shape, _strides, 0);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment