Commit 67d81412 authored by PanZezhong's avatar PanZezhong
Browse files

issue/5 添加reduce类通用代码,实现rms norm cpu算子

parent fd0242ed
...@@ -46,3 +46,4 @@ jobs: ...@@ -46,3 +46,4 @@ jobs:
run: | run: |
pip install torch pip install torch
LD_LIBRARY_PATH=$HOME/.infini/lib python test/infiniop/matmul.py --cpu LD_LIBRARY_PATH=$HOME/.infini/lib python test/infiniop/matmul.py --cpu
LD_LIBRARY_PATH=$HOME/.infini/lib python test/infiniop/rms_norm.py --cpu
...@@ -16,7 +16,7 @@ __C __export infiniStatus_t infiniopCreateRMSNormDescriptor( ...@@ -16,7 +16,7 @@ __C __export infiniStatus_t infiniopCreateRMSNormDescriptor(
__C __export infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t desc, size_t *size); __C __export infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t desc, size_t *size);
__C __export infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *workspace, size_t workspace_size, __C __export infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *workspace, size_t workspace_size,
void *y, void const *x, void const *w, void *stream); void *y, const void *x, const void *w, void *stream);
__C __export infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t desc); __C __export infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t desc);
......
#include "common_cpu.h" #include "common_cpu.h"
// Convert an IEEE-754 binary16 bit pattern to a binary32 float.
// Handles signed zero, subnormals, normals, Inf and NaN.
float f16_to_f32(uint16_t h) {
    const uint32_t sign_bits = static_cast<uint32_t>(h & 0x8000u) << 16;
    int32_t exp = (h >> 10) & 0x1F;
    uint32_t frac = h & 0x3FFu;
    uint32_t bits;

    if (exp == 0x1F) {
        // Inf (frac == 0) or NaN (frac != 0): exponent field becomes all ones.
        bits = sign_bits | 0x7F800000u | (frac << 13);
    } else if (exp != 0) {
        // Normal number: re-bias the exponent from 15 (half) to 127 (float).
        bits = sign_bits | static_cast<uint32_t>((exp + (127 - 15)) << 23) | (frac << 13);
    } else if (frac == 0) {
        // Signed zero.
        bits = sign_bits;
    } else {
        // Subnormal half: shift left until the implicit leading 1 lands in
        // bit 10, decrementing the exponent per shift; the result is a
        // normal float (f32 has far more subnormal range than f16).
        exp = -14;
        do {
            frac <<= 1;
            --exp;
        } while ((frac & 0x400u) == 0);
        bits = sign_bits | static_cast<uint32_t>((exp + 127) << 23) | ((frac & 0x3FFu) << 13);
    }

    float out;
    std::memcpy(&out, &bits, sizeof(out));
    return out;
}
// Convert a binary32 float to an IEEE-754 binary16 bit pattern.
// Rounding mode is truncation (round toward zero); overflow maps to Inf,
// underflow below the subnormal range maps to signed zero.
uint16_t f32_to_f16(float val) {
    uint32_t bits;
    std::memcpy(&bits, &val, sizeof(bits));
    const uint16_t sign = static_cast<uint16_t>((bits >> 16) & 0x8000u);
    const int32_t exp = static_cast<int32_t>((bits >> 23) & 0xFFu) - 127;
    uint32_t frac = bits & 0x7FFFFFu;

    if (exp >= 31) {
        // NaN keeps a quiet-NaN payload; Inf and finite overflow become Inf.
        return (exp == 128 && frac != 0)
                   ? static_cast<uint16_t>(sign | 0x7E00u)
                   : static_cast<uint16_t>(sign | 0x7C00u);
    }
    if (exp >= -14) {
        // Representable as a normal half: re-bias and truncate 13 low bits.
        return static_cast<uint16_t>(sign | ((exp + 15) << 10) | (frac >> 13));
    }
    if (exp >= -24) {
        // Subnormal half: restore the implicit leading 1, shift into range.
        frac = (frac | 0x800000u) >> (-14 - exp);
        return static_cast<uint16_t>(sign | (frac >> 13));
    }
    // Magnitude below the smallest subnormal half: signed zero.
    return sign;
}
size_t indexToReducedOffset( size_t indexToReducedOffset(
size_t flat_index, size_t flat_index,
size_t ndim, size_t ndim,
......
...@@ -2,17 +2,16 @@ ...@@ -2,17 +2,16 @@
#define __INFINIOP_COMMON_CPU_H__ #define __INFINIOP_COMMON_CPU_H__
#include "../../../utils.h" #include "../../../utils.h"
#include "cpu_handle.h"
#include <cmath> #include <cmath>
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
#include <cstring> #include <cstring>
#include <vector> #include <vector>
// convert half-precision float to single-precision float #ifdef ENABLE_OMP
float f16_to_f32(uint16_t code); #include <omp.h>
#endif
// convert single-precision float to half-precision float
uint16_t f32_to_f16(float val);
// return the memory offset of original tensor, given the flattened index of broadcasted tensor // return the memory offset of original tensor, given the flattened index of broadcasted tensor
size_t indexToReducedOffset(size_t flat_index, size_t ndim, const ptrdiff_t *broadcasted_strides, const ptrdiff_t *target_strides); size_t indexToReducedOffset(size_t flat_index, size_t ndim, const ptrdiff_t *broadcasted_strides, const ptrdiff_t *target_strides);
......
#include "matmul_cpu.h" #include "matmul_cpu.h"
#include "../../../devices/cpu/common_cpu.h" #include "../../../devices/cpu/common_cpu.h"
#include "../../../devices/cpu/cpu_handle.h"
namespace op::matmul::cpu { namespace op::matmul::cpu {
...@@ -52,17 +51,17 @@ void calculate( ...@@ -52,17 +51,17 @@ void calculate(
for (size_t k_ = 0; k_ < info.k; ++k_) { for (size_t k_ = 0; k_ < info.k; ++k_) {
auto a_ = reinterpret_cast<const Tdata *>(a) + i * info.a_matrix.stride + m_ * info.a_matrix.row_stride + k_ * info.a_matrix.col_stride; auto a_ = reinterpret_cast<const Tdata *>(a) + i * info.a_matrix.stride + m_ * info.a_matrix.row_stride + k_ * info.a_matrix.col_stride;
auto b_ = reinterpret_cast<const Tdata *>(b) + i * info.b_matrix.stride + n_ * info.b_matrix.col_stride + k_ * info.b_matrix.row_stride; auto b_ = reinterpret_cast<const Tdata *>(b) + i * info.b_matrix.stride + n_ * info.b_matrix.col_stride + k_ * info.b_matrix.row_stride;
if constexpr (std::is_same<Tdata, uint16_t>::value) { if constexpr (std::is_same<Tdata, fp16_t>::value) {
sum += f16_to_f32(*a_) * f16_to_f32(*b_); sum += utils::cast<float>(*a_) * utils::cast<float>(*b_);
} else { } else {
sum += *a_ * (*b_); sum += *a_ * (*b_);
} }
} }
if constexpr (std::is_same<Tdata, uint16_t>::value) { if constexpr (std::is_same<Tdata, fp16_t>::value) {
if (beta == 0) { if (beta == 0) {
*c_ = f32_to_f16(alpha * sum); *c_ = utils::cast<fp16_t>(alpha * sum);
} else { } else {
*c_ = f32_to_f16(beta * f16_to_f32(*c_) + alpha * sum); *c_ = utils::cast<fp16_t>(beta * utils::cast<float>(*c_) + alpha * sum);
} }
} else { } else {
*c_ = beta * (*c_) + alpha * sum; *c_ = beta * (*c_) + alpha * sum;
...@@ -84,7 +83,7 @@ infiniStatus_t Descriptor::calculate( ...@@ -84,7 +83,7 @@ infiniStatus_t Descriptor::calculate(
switch (_dtype) { switch (_dtype) {
case INFINI_DTYPE_F16: case INFINI_DTYPE_F16:
cpu::calculate<uint16_t>(_info, c, beta, a, b, alpha); cpu::calculate<fp16_t>(_info, c, beta, a, b, alpha);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
case INFINI_DTYPE_F32: case INFINI_DTYPE_F32:
......
#include "rms_norm_cpu.h"
#include "../../../devices/cpu/common_cpu.h"
#include "../../../reduce/cpu/reduce.h"
namespace op::rms_norm::cpu {
// The CPU backend allocates no opaque state or workspace, so nothing to free.
Descriptor::~Descriptor() {}
// Builds a CPU RMSNorm descriptor: validates the tensor descriptors, records
// the problem metadata, and allocates the Descriptor object.
// Returns a bad-dtype/bad-shape status (via CHECK_STATUS) on invalid inputs.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x_desc,
    infiniopTensorDescriptor_t w_desc,
    float epsilon) {
    // Validate shapes/dtypes and collect strides before allocating anything.
    RMSNormInfo info;
    CHECK_STATUS(createRMSNormInfo(&info, y_desc, x_desc, w_desc, epsilon));
    // The CPU path needs no opaque state and no workspace.
    auto *desc = new Descriptor(
        /*opaque=*/nullptr, info, /*workspace_size=*/0,
        handle->device, handle->device_id);
    *desc_ptr = desc;
    return INFINI_STATUS_SUCCESS;
}
// Row-wise RMS normalization for same-typed x/y/w (instantiated with
// T = float or double by Descriptor::calculate):
//   y[i, j] = x[i, j] * w[j] / sqrt(mean(x[i, :]^2) + epsilon)
// `info` describes a 2-D [batch, dim] problem as validated by
// createRMSNormInfo; strides allow non-contiguous rows.
template <typename T>
infiniStatus_t rmsnorm(const RMSNormInfo *info, T *y, const T *x, const T *w) {
#pragma omp parallel for
    for (ptrdiff_t i = 0; i < ptrdiff_t(info->shape[0]); i++) {
        // Row base pointers. Keep x const-correct instead of casting the
        // const away as the previous C-style casts did.
        const T *x_ = x + i * info->x_strides[0];
        T *y_ = y + i * info->y_strides[0];
        // [Reduce] sum of x^2 over the last dimension.
        T ss = op::common_cpu::reduce_op::sumSquared(x_, info->shape[1], info->x_strides[1]);
        // 1 / sqrt(sum / dim + eps)
        T rms = T(1) / std::sqrt(ss / T(info->shape[1]) + T(info->epsilon));
        for (size_t j = 0; j < info->shape[1]; j++) {
            y_[j * info->y_strides[1]] = x_[j * info->x_strides[1]] * w[j] * rms;
        }
    }
    return INFINI_STATUS_SUCCESS;
}
// Row-wise RMS normalization for fp16 activations with fp16 or fp32 weights.
// All arithmetic is carried out in float to avoid fp16 accumulation error;
// results are rounded back to fp16 on store.
template <typename Tw>
infiniStatus_t rmsnormF16(const RMSNormInfo *info, fp16_t *y, const fp16_t *x, const Tw *w) {
    // Reject unsupported weight types at compile time; the previous version
    // deferred this to a runtime std::abort() inside the inner loop.
    static_assert(std::is_same<Tw, float>::value || std::is_same<Tw, fp16_t>::value,
                  "rmsnormF16 supports float or fp16_t weights only");
#pragma omp parallel for
    for (ptrdiff_t i = 0; i < ptrdiff_t(info->shape[0]); i++) {
        // Row base pointers; x stays const-correct (no const-stripping casts).
        const fp16_t *x_ = x + i * info->x_strides[0];
        fp16_t *y_ = y + i * info->y_strides[0];
        // [Reduce] sum of x^2 over the last dimension (float accumulator).
        float ss = op::common_cpu::reduce_op::sumSquared(x_, info->shape[1], info->x_strides[1]);
        // 1 / sqrt(sum / dim + eps)
        float rms = 1.f / std::sqrt(ss / (float)(info->shape[1]) + info->epsilon);
        for (size_t j = 0; j < info->shape[1]; j++) {
            // utils::cast<float> is the identity for float weights, so one
            // code path covers both weight types.
            float wj = utils::cast<float>(w[j]);
            float val = utils::cast<float>(x_[j * info->x_strides[1]]) * wj * rms;
            y_[j * info->y_strides[1]] = utils::cast<fp16_t>(val);
        }
    }
    return INFINI_STATUS_SUCCESS;
}
// Dispatches the RMSNorm kernel by activation/weight dtype.
// workspace, workspace_size and stream are unused on CPU; they are kept for
// a uniform operator API across backends.
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
                                     void *y, const void *x, const void *w,
                                     void *stream) {
    switch (_info.atype) {
    case INFINI_DTYPE_F16:
        // F16 activations accept F16 or F32 weights.
        switch (_info.wtype) {
        case INFINI_DTYPE_F16:
            return rmsnormF16(&_info, (fp16_t *)y, (const fp16_t *)x, (const fp16_t *)w);
        case INFINI_DTYPE_F32:
            return rmsnormF16(&_info, (fp16_t *)y, (const fp16_t *)x, (const float *)w);
        default:
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
    case INFINI_DTYPE_F32:
        return rmsnorm(&_info, (float *)y, (const float *)x, (const float *)w);
    case INFINI_DTYPE_F64:
        return rmsnorm(&_info, (double *)y, (const double *)x, (const double *)w);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
}
} // namespace op::rms_norm::cpu
#ifndef __RMS_NORM_CPU_H__
#define __RMS_NORM_CPU_H__
#include "../rms_norm.h"
// Declares op::rms_norm::cpu::Descriptor via the shared DESCRIPTOR macro.
DESCRIPTOR(cpu)
#endif
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
#include "../../handle.h" #include "../../handle.h"
#include "infiniop/ops/rms_norm.h" #include "infiniop/ops/rms_norm.h"
#ifdef ENABLE_CPU_API
#include "cpu/rms_norm_cpu.h"
#endif
__C infiniStatus_t infiniopCreateRMSNormDescriptor( __C infiniStatus_t infiniopCreateRMSNormDescriptor(
infiniopHandle_t handle, infiniopHandle_t handle,
infiniopRMSNormDescriptor_t *desc_ptr, infiniopRMSNormDescriptor_t *desc_ptr,
...@@ -9,10 +13,20 @@ __C infiniStatus_t infiniopCreateRMSNormDescriptor( ...@@ -9,10 +13,20 @@ __C infiniStatus_t infiniopCreateRMSNormDescriptor(
infiniopTensorDescriptor_t x_desc, infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t w_desc, infiniopTensorDescriptor_t w_desc,
float epsilon) { float epsilon) {
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::rms_norm::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::rms_norm::NAMESPACE::Descriptor **>(desc_ptr), \
y_desc, \
x_desc, \
w_desc, \
epsilon);
switch (handle->device) { switch (handle->device) {
#ifdef ENABLE_CPU #ifdef ENABLE_CPU_API
case DevCpu: CREATE(INFINI_DEVICE_CPU, cpu)
return cpuCreateRMSNormDescriptor(handle, (RMSNormCpuDescriptor_t *)desc_ptr, y_desc, x_desc, w_desc, epsilon);
#endif #endif
#ifdef ENABLE_NV_GPU #ifdef ENABLE_NV_GPU
case DevNvGpu: { case DevNvGpu: {
...@@ -45,14 +59,22 @@ __C infiniStatus_t infiniopCreateRMSNormDescriptor( ...@@ -45,14 +59,22 @@ __C infiniStatus_t infiniopCreateRMSNormDescriptor(
} }
#endif #endif
} }
#undef CREATE
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t desc, size_t *size) { __C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t desc, size_t *size) {
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<op::rms_norm::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) { switch (desc->device_type) {
#ifdef ENABLE_CPU #ifdef ENABLE_CPU_API
case DevCpu: GET(INFINI_DEVICE_CPU, cpu)
return cpuGetRMSNormWorkspaceSize((RMSNormCpuDescriptor_t)desc, size);
#endif #endif
#ifdef ENABLE_NV_GPU #ifdef ENABLE_NV_GPU
case DevNvGpu: { case DevNvGpu: {
...@@ -82,15 +104,23 @@ __C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t d ...@@ -82,15 +104,23 @@ __C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t d
} }
#endif #endif
} }
#undef GET
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *workspace, size_t workspace_size, __C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *workspace, size_t workspace_size,
void *y, const void *x, const void *w, void *stream) { void *y, const void *x, const void *w, void *stream) {
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<op::rms_norm::NAMESPACE::Descriptor *>(desc)->calculate( \
workspace, workspace_size, y, x, w, stream);
switch (desc->device_type) { switch (desc->device_type) {
#ifdef ENABLE_CPU #ifdef ENABLE_CPU_API
case DevCpu: CALCULATE(INFINI_DEVICE_CPU, cpu)
return cpuRMSNorm((RMSNormCpuDescriptor_t)desc, workspace, workspace_size, y, x, w, stream);
#endif #endif
#ifdef ENABLE_NV_GPU #ifdef ENABLE_NV_GPU
case DevNvGpu: { case DevNvGpu: {
...@@ -125,14 +155,22 @@ __C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *works ...@@ -125,14 +155,22 @@ __C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *works
} }
#endif #endif
} }
#undef CALCULATE
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t desc) { __C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t desc) {
#define DESTROY(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<op::rms_norm::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) { switch (desc->device_type) {
#ifdef ENABLE_CPU #ifdef ENABLE_CPU_API
case DevCpu: DESTROY(INFINI_DEVICE_CPU, cpu)
return cpuDestroyRMSNormDescriptor((RMSNormCpuDescriptor_t)desc);
#endif #endif
#ifdef ENABLE_NV_GPU #ifdef ENABLE_NV_GPU
case DevNvGpu: { case DevNvGpu: {
...@@ -161,5 +199,8 @@ __C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t ...@@ -161,5 +199,8 @@ __C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t
} }
#endif #endif
} }
#undef DESTROY
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
#ifndef RMS_NORM_H
#define RMS_NORM_H
#include "../../operator.h"
#include "../../tensor.h"
#include <vector>
// Validated metadata for a 2-D [batch, dim] RMSNorm problem, filled by
// createRMSNormInfo below.
struct RMSNormInfo {
    infiniDtype_t wtype;              // weight dtype (may be F32 with F16 activations)
    infiniDtype_t atype;              // activation dtype shared by x and y
    float epsilon;                    // added to the mean of squares for stability
    std::vector<size_t> shape;        // {batch, dim} (taken from y)
    std::vector<ptrdiff_t> y_strides; // element strides of y
    std::vector<ptrdiff_t> x_strides; // element strides of x

    // Number of dimensions (always 2 after a successful createRMSNormInfo).
    // Const-qualified so it is callable through const RMSNormInfo*.
    size_t ndim() const { return shape.size(); }
    // Size of the normalized (last) dimension.
    size_t dim() const { return shape[ndim() - 1]; }
};
// Validates the y/x/w descriptors for RMSNorm and fills `info`.
// Constraints enforced:
//   - x and y are 2-D [batch, dim] with identical shape and dtype;
//   - w is 1-D [dim];
//   - F16 activations accept F16 or F32 weights; F32/F64 activations
//     require weights of the same dtype.
// Returns BAD_TENSOR_DTYPE / BAD_TENSOR_SHAPE on violation; `info` is only
// guaranteed fully populated on INFINI_STATUS_SUCCESS.
inline infiniStatus_t createRMSNormInfo(RMSNormInfo *info, infiniopTensorDescriptor_t y_desc,
                                        infiniopTensorDescriptor_t x_desc,
                                        infiniopTensorDescriptor_t w_desc,
                                        float epsilon) {
    auto atype = y_desc->dtype();
    auto wtype = w_desc->dtype();
    if (x_desc->dtype() != atype) {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    if (atype == INFINI_DTYPE_F16) {
        if (wtype != INFINI_DTYPE_F16 && wtype != INFINI_DTYPE_F32) {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
    } else if (atype == INFINI_DTYPE_F32 || atype == INFINI_DTYPE_F64) {
        if (atype != wtype) {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
    } else {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    info->wtype = wtype;
    info->atype = atype;
    info->epsilon = epsilon;
    if (y_desc->ndim() != 2 || x_desc->ndim() != 2 || w_desc->ndim() != 1) {
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }
    size_t batch = y_desc->shape()[0];
    size_t dim = y_desc->shape()[1];
    if (x_desc->shape()[0] != batch || x_desc->shape()[1] != dim || w_desc->shape()[0] != dim) {
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }
    // std::move on an accessor's return value is either redundant (by-value
    // return) or a silent copy (by-reference return), so plain assignment
    // is used; it is clearer and never worse.
    info->shape = y_desc->shape();
    info->y_strides = y_desc->strides();
    info->x_strides = x_desc->strides();
    return INFINI_STATUS_SUCCESS;
}
// Declares the per-backend RMSNorm Descriptor class inside
// op::rms_norm::<NAMESPACE>. Each backend provides out-of-line definitions
// of ~Descriptor(), create() and calculate(), and keeps backend-specific
// state behind the forward-declared Opaque pointer.
// (Comments cannot appear inside the macro body: a `//` comment would
// swallow the line-continuation backslashes after splicing.)
#define DESCRIPTOR(NAMESPACE)                                                \
    namespace op::rms_norm::NAMESPACE {                                      \
    class Descriptor final : public InfiniopDescriptor {                     \
        struct Opaque;                                                       \
        Opaque *_opaque;                                                     \
        RMSNormInfo _info;                                                   \
        size_t _workspace_size;                                              \
                                                                             \
        Descriptor(                                                          \
            Opaque *opaque,                                                  \
            RMSNormInfo info,                                                \
            size_t workspace_size,                                           \
            infiniDevice_t device_type,                                      \
            int device_id) : InfiniopDescriptor{device_type, device_id},     \
                             _opaque(opaque),                                \
                             _info(info),                                    \
                             _workspace_size(workspace_size) {}              \
                                                                             \
    public:                                                                  \
        ~Descriptor();                                                       \
        size_t workspaceSize() const { return _workspace_size; }             \
        static infiniStatus_t create(                                        \
            infiniopHandle_t handle,                                         \
            Descriptor **desc_ptr,                                           \
            infiniopTensorDescriptor_t y_desc,                               \
            infiniopTensorDescriptor_t x_desc,                               \
            infiniopTensorDescriptor_t w_desc,                               \
            float epsilon);                                                  \
        infiniStatus_t calculate(void *workspace, size_t workspace_size,     \
                                 void *y, const void *x, const void *w, void *stream); \
    };                                                                       \
    }
#endif // RMS_NORM_H
#ifndef __INFINIOP_REDUCE_CPU_H__
#define __INFINIOP_REDUCE_CPU_H__
#include "../../../utils.h"
#include <cstddef>
#ifdef ENABLE_OMP
#include <omp.h>
#endif
#include <type_traits>
namespace op::common_cpu {
namespace reduce_op {

// Arithmetic types whose reductions accumulate in the same type T.
// fp16_t is deliberately excluded: it gets dedicated overloads below that
// accumulate in float.
template <typename T>
using ReduceToSame = std::disjunction<
    std::is_same<T, float>,
    std::is_same<T, double>,
    std::is_same<T, uint8_t>,
    std::is_same<T, int8_t>,
    std::is_same<T, uint16_t>,
    std::is_same<T, int16_t>,
    std::is_same<T, uint32_t>,
    std::is_same<T, int32_t>,
    std::is_same<T, uint64_t>,
    std::is_same<T, int64_t>>;

// Sum of `len` elements read at element `stride`; accumulates in T.
template <typename T, typename = std::enable_if_t<ReduceToSame<T>::value>>
T sum(const T *data, size_t len, ptrdiff_t stride = 1) {
    T result = 0;
    for (size_t i = 0; i < len; i++) {
        result += data[i * stride];
    }
    return result;
}

// fp16 overload: accumulates in float to limit rounding error.
// `inline` is required — a non-template function defined in a header would
// otherwise be multiply defined (ODR violation) once this header is
// included from more than one translation unit.
inline float sum(const fp16_t *data, size_t len, ptrdiff_t stride = 1) {
    float result = 0;
    for (size_t i = 0; i < len; i++) {
        result += utils::cast<float>(data[i * stride]);
    }
    return result;
}

// Sum of squares of `len` elements read at element `stride`.
template <typename T, typename = std::enable_if_t<ReduceToSame<T>::value>>
T sumSquared(const T *data, size_t len, ptrdiff_t stride = 1) {
    T result = 0;
    for (size_t i = 0; i < len; i++) {
        T val = data[i * stride];
        result += val * val;
    }
    return result;
}

// fp16 overload: squares and accumulates in float (see `inline` note above).
inline float sumSquared(const fp16_t *data, size_t len, ptrdiff_t stride = 1) {
    float result = 0;
    for (size_t i = 0; i < len; i++) {
        float val = utils::cast<float>(data[i * stride]);
        result += val * val;
    }
    return result;
}

} // namespace reduce_op
} // namespace op::common_cpu
#endif //__INFINIOP_REDUCE_CPU_H__
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "infinicore.h" #include "infinicore.h"
#include "utils/check.h" #include "utils/check.h"
#include "utils/custom_types.h"
#include "utils/rearrange.h" #include "utils/rearrange.h"
inline size_t infiniSizeOf(infiniDtype_t dtype) { inline size_t infiniSizeOf(infiniDtype_t dtype) {
......
#include "custom_types.h"
#include <cstdint>
#include <cstring>
// Convert an fp16_t (IEEE-754 binary16 bit pattern) to a binary32 float.
// Handles signed zero, subnormals, normals, Inf and NaN.
float _f16_to_f32(fp16_t val) {
    const uint16_t h = val._v;
    const uint32_t sign_bits = static_cast<uint32_t>(h & 0x8000u) << 16;
    int32_t exp = (h >> 10) & 0x1F;
    uint32_t frac = h & 0x3FFu;
    uint32_t bits;

    if (exp == 0x1F) {
        // Inf (frac == 0) or NaN (frac != 0): exponent field becomes all ones.
        bits = sign_bits | 0x7F800000u | (frac << 13);
    } else if (exp != 0) {
        // Normal number: re-bias the exponent from 15 (half) to 127 (float).
        bits = sign_bits | static_cast<uint32_t>((exp + (127 - 15)) << 23) | (frac << 13);
    } else if (frac == 0) {
        // Signed zero.
        bits = sign_bits;
    } else {
        // Subnormal half: shift left until the implicit leading 1 lands in
        // bit 10, decrementing the exponent per shift; the result is a
        // normal float.
        exp = -14;
        do {
            frac <<= 1;
            --exp;
        } while ((frac & 0x400u) == 0);
        bits = sign_bits | static_cast<uint32_t>((exp + 127) << 23) | ((frac & 0x3FFu) << 13);
    }

    float out;
    std::memcpy(&out, &bits, sizeof(out));
    return out;
}
// Convert a binary32 float to an fp16_t (IEEE-754 binary16 bit pattern).
// Rounding mode is truncation (round toward zero); finite values above the
// fp16 range become Inf, values below the subnormal range become signed zero.
fp16_t _f32_to_f16(float val) {
    uint32_t f32;
    memcpy(&f32, &val, sizeof(f32)); // Read the bits of the float32
    uint16_t sign = (f32 >> 16) & 0x8000; // Extract the sign bit
    int32_t exponent = ((f32 >> 23) & 0xFF) - 127; // Extract and de-bias the exponent
    uint32_t mantissa = f32 & 0x7FFFFF; // Extract the mantissa (fraction part)
    if (exponent >= 31) { // Special cases for Inf and NaN (and finite overflow)
        // NaN: raw exponent 0xFF (de-biased 128) with nonzero mantissa
        if (exponent == 128 && mantissa != 0) {
            return fp16_t{static_cast<uint16_t>(sign | 0x7E00)};
        }
        // Infinity (also finite values too large for fp16)
        return fp16_t{static_cast<uint16_t>(sign | 0x7C00)};
    } else if (exponent >= -14) { // Normalized case: re-bias to 15, drop 13 low bits
        return fp16_t{(uint16_t)(sign | ((exponent + 15) << 10) | (mantissa >> 13))};
    } else if (exponent >= -24) {
        mantissa |= 0x800000; // Add implicit leading 1
        mantissa >>= (-14 - exponent); // Shift into the subnormal range
        return fp16_t{(uint16_t)(sign | (mantissa >> 13))};
    } else {
        // Too small for subnormal: return signed zero
        return fp16_t{(uint16_t)sign};
    }
}
#ifndef __INFINIUTILS_CUSTOM_TYPES_H__
#define __INFINIUTILS_CUSTOM_TYPES_H__

#include <stdint.h>
#include <type_traits>

// Bit-level storage for an IEEE-754 binary16 (half) value.
struct CustomFloat16 {
    uint16_t _v;
};
typedef struct CustomFloat16 fp16_t;

// Bit-level storage for a bfloat16 value.
struct CustomBFloat16 {
    uint16_t _v;
};
typedef struct CustomBFloat16 bf16_t;

// Conversions defined in custom_types.cc.
float _f16_to_f32(fp16_t val);
fp16_t _f32_to_f16(float val);

namespace utils {

// Converts between arithmetic types and fp16_t.
// fp16_t has no conversion operators, so any cast to/from it is routed
// through float via _f32_to_f16 / _f16_to_f32; all other conversions are a
// plain static_cast. Fixes the previous non-float -> fp16_t branch, which
// did static_cast<fp16_t>(val) — ill-formed, since fp16_t has no converting
// constructor; the value must be narrowed to float first.
template <typename TypeTo, typename TypeFrom>
TypeTo cast(TypeFrom val) {
    if constexpr (std::is_same<TypeTo, TypeFrom>::value) {
        return val;
    } else if constexpr (std::is_same<TypeTo, fp16_t>::value) {
        // Narrow through float (identity when TypeFrom is already float).
        return _f32_to_f16(static_cast<float>(val));
    } else if constexpr (std::is_same<TypeFrom, fp16_t>::value) {
        // Widen through float, then convert (identity when TypeTo is float).
        return static_cast<TypeTo>(_f16_to_f32(val));
    } else {
        return static_cast<TypeTo>(val);
    }
}
} // namespace utils

#endif
...@@ -4,6 +4,10 @@ ...@@ -4,6 +4,10 @@
#include <cstring> #include <cstring>
#include <vector> #include <vector>
#ifdef ENABLE_OMP
#include <omp.h>
#endif
namespace utils { namespace utils {
RearrangeMeta::RearrangeMeta(std::vector<ptrdiff_t> meta) RearrangeMeta::RearrangeMeta(std::vector<ptrdiff_t> meta)
...@@ -98,7 +102,8 @@ void RearrangeMeta::launch(void *dst_, const void *src_) const { ...@@ -98,7 +102,8 @@ void RearrangeMeta::launch(void *dst_, const void *src_) const {
if (count_ == 1) { if (count_ == 1) {
std::memcpy(dst_, src_, unit_); std::memcpy(dst_, src_, unit_);
} else { } else {
for (size_t i = 0; i < count_; ++i) { #pragma omp parallel for
for (ptrdiff_t i = 0; i < (ptrdiff_t)count_; ++i) {
auto dst = reinterpret_cast<char *>(dst_); auto dst = reinterpret_cast<char *>(dst_);
auto src = reinterpret_cast<const char *>(src_); auto src = reinterpret_cast<const char *>(src_);
auto rem = i; auto rem = i;
......
...@@ -25,6 +25,7 @@ from libinfiniop import ( ...@@ -25,6 +25,7 @@ from libinfiniop import (
# These are not meant to be imported from other modules # These are not meant to be imported from other modules
_TEST_CASES = [ _TEST_CASES = [
# y_shape, x_shape, w_shape, y_stride, x_stride, w_dtype # y_shape, x_shape, w_shape, y_stride, x_stride, w_dtype
((1, 4), (1, 4), (4,), None, None, torch.float32),
((16, 2048), (16, 2048), (2048,), None, None, torch.float32), ((16, 2048), (16, 2048), (2048,), None, None, torch.float32),
((16, 2048), (16, 2048), (2048,), None, None, torch.float16), ((16, 2048), (16, 2048), (2048,), None, None, torch.float16),
((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), torch.float32), ((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), torch.float32),
...@@ -57,7 +58,7 @@ def rms_norm(x, w, eps): ...@@ -57,7 +58,7 @@ def rms_norm(x, w, eps):
hidden_states = x.to(torch.float32) hidden_states = x.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True) variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + eps) hidden_states = hidden_states * torch.rsqrt(variance + eps)
return w * hidden_states.to(input_dtype) return (w * hidden_states).to(input_dtype)
def test( def test(
...@@ -79,7 +80,7 @@ def test( ...@@ -79,7 +80,7 @@ def test(
y = torch.zeros(y_shape, dtype=dtype).to(torch_device) y = torch.zeros(y_shape, dtype=dtype).to(torch_device)
x = torch.rand(x_shape, dtype=dtype).to(torch_device) x = torch.rand(x_shape, dtype=dtype).to(torch_device)
w = torch.ones(w_shape, dtype=w_dtype).to(torch_device) w = torch.rand(w_shape, dtype=w_dtype).to(torch_device)
eps = 1e-5 eps = 1e-5
ans = rms_norm(x, w, eps) ans = rms_norm(x, w, eps)
...@@ -106,7 +107,7 @@ def test( ...@@ -106,7 +107,7 @@ def test(
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
for tensor in [x_tensor, y_tensor, w_tensor]: for tensor in [x_tensor, y_tensor, w_tensor]:
tensor.descriptor.contents.invalidate() tensor.destroyDesc(lib)
workspace_size = c_uint64(0) workspace_size = c_uint64(0)
check_error( check_error(
......
...@@ -8,7 +8,6 @@ add_includedirs("include") ...@@ -8,7 +8,6 @@ add_includedirs("include")
set_encodings("utf-8") set_encodings("utf-8")
if is_mode("debug") then if is_mode("debug") then
add_cxflags("-g -O0")
add_defines("DEBUG_MODE") add_defines("DEBUG_MODE")
end end
...@@ -20,7 +19,7 @@ option("cpu") ...@@ -20,7 +19,7 @@ option("cpu")
option_end() option_end()
option("omp") option("omp")
set_default(false) set_default(true)
set_showmenu(true) set_showmenu(true)
set_description("Enable or disable OpenMP support for cpu kernel") set_description("Enable or disable OpenMP support for cpu kernel")
option_end() option_end()
...@@ -30,6 +29,10 @@ if has_config("cpu") then ...@@ -30,6 +29,10 @@ if has_config("cpu") then
add_defines("ENABLE_CPU_API") add_defines("ENABLE_CPU_API")
end end
if has_config("omp") then
add_defines("ENABLE_OMP")
end
-- 英伟达 -- 英伟达
option("nv-gpu") option("nv-gpu")
set_default(false) set_default(false)
......
...@@ -5,16 +5,21 @@ target("infiniop-cpu") ...@@ -5,16 +5,21 @@ target("infiniop-cpu")
set_warnings("all", "error") set_warnings("all", "error")
if not is_plat("windows") then if is_plat("windows") then
if has_config("omp") then
add_cxflags("/openmp")
end
else
add_cxflags("-fPIC") add_cxflags("-fPIC")
if has_config("omp") then
add_cxflags("-fopenmp")
add_ldflags("-fopenmp")
end
end end
set_languages("cxx17") set_languages("cxx17")
add_files("../src/infiniop/devices/cpu/*.cc", "../src/infiniop/ops/*/cpu/*.cc") add_files("../src/infiniop/devices/cpu/*.cc", "../src/infiniop/ops/*/cpu/*.cc")
if has_config("omp") then
add_cxflags("-fopenmp")
add_ldflags("-fopenmp")
end
target_end() target_end()
target("infinirt-cpu") target("infinirt-cpu")
......
...@@ -21,6 +21,7 @@ target("infiniop-cuda") ...@@ -21,6 +21,7 @@ target("infiniop-cuda")
if is_plat("windows") then if is_plat("windows") then
add_cuflags("-Xcompiler=/utf-8", "--expt-relaxed-constexpr", "--allow-unsupported-compiler") add_cuflags("-Xcompiler=/utf-8", "--expt-relaxed-constexpr", "--allow-unsupported-compiler")
add_cuflags("-Xcompiler=/W3", "-Xcompiler=/WX") add_cuflags("-Xcompiler=/W3", "-Xcompiler=/WX")
add_cxxflags("/FS")
if CUDNN_ROOT ~= nil then if CUDNN_ROOT ~= nil then
add_linkdirs(CUDNN_ROOT .. "\\lib\\x64") add_linkdirs(CUDNN_ROOT .. "\\lib\\x64")
end end
...@@ -46,6 +47,7 @@ target("infinirt-cuda") ...@@ -46,6 +47,7 @@ target("infinirt-cuda")
if is_plat("windows") then if is_plat("windows") then
add_cuflags("-Xcompiler=/utf-8", "--expt-relaxed-constexpr", "--allow-unsupported-compiler") add_cuflags("-Xcompiler=/utf-8", "--expt-relaxed-constexpr", "--allow-unsupported-compiler")
add_cxxflags("/FS")
else else
add_cuflags("-Xcompiler=-fPIC") add_cuflags("-Xcompiler=-fPIC")
add_culdflags("-Xcompiler=-fPIC") add_culdflags("-Xcompiler=-fPIC")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment