Commit 137adc72 authored by xgqdut2016's avatar xgqdut2016
Browse files

issue/3: 实现 causal_softmax CPU 算子 (implement the causal_softmax CPU operator)

parent bd8ae651
#ifndef CAUSAL_SOFTMAX_H
#define CAUSAL_SOFTMAX_H
#include "../../operator.h"
#include "../../tensor.h"
#include <iostream>
#include <vector>
// Shape/stride metadata for one causal-softmax invocation, extracted from the
// output tensor descriptor by createCausalSoftmaxInfo. The tensor is viewed as
// [batch_size, seq_len, total_seq_len]; batch_size is 1 for 2-D tensors.
struct CausalSoftmaxInfo {
infiniDtype_t dtype;   // element type: INFINI_DTYPE_F16 or INFINI_DTYPE_F32
size_t batch_size;     // leading batch dimension (1 when the tensor is 2-D)
ptrdiff_t stride_b;    // element stride between batches (0 when the tensor is 2-D)
size_t seq_len;        // number of rows (second-to-last dimension)
ptrdiff_t stride_i;    // element stride between rows
size_t total_seq_len;  // number of columns (last dimension); >= seq_len
ptrdiff_t stride_j;    // element stride between columns
};
// Populate *info with dtype/shape/stride metadata taken from the output
// tensor descriptor.
//
// Returns INFINI_STATUS_BAD_TENSOR_DTYPE unless the dtype is f16 or f32, and
// INFINI_STATUS_BAD_TENSOR_SHAPE unless the tensor is 2-D or 3-D with its
// last dimension (total_seq_len) no smaller than its second-to-last (seq_len).
inline infiniStatus_t createCausalSoftmaxInfo(CausalSoftmaxInfo *info, infiniopTensorDescriptor_t y_desc) {
    const auto dtype = y_desc->dtype();
    if (dtype != INFINI_DTYPE_F16 && dtype != INFINI_DTYPE_F32) {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    info->dtype = dtype;

    const size_t ndim = y_desc->ndim();
    if (ndim != 2 && ndim != 3) {
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }
    const auto &shape = y_desc->shape();
    const auto &strides = y_desc->strides();
    // A causal mask only makes sense when total_seq_len >= seq_len.
    if (shape[ndim - 1] < shape[ndim - 2]) {
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }

    info->seq_len = shape[ndim - 2];
    info->stride_i = strides[ndim - 2];
    info->total_seq_len = shape[ndim - 1];
    info->stride_j = strides[ndim - 1];
    if (ndim == 3) {
        info->batch_size = shape[0];
        info->stride_b = strides[0];
    } else {
        // 2-D input: a single implicit batch with no batch stride.
        info->batch_size = 1;
        info->stride_b = 0;
    }
    return INFINI_STATUS_SUCCESS;
}
// DESCRIPTOR(NAMESPACE) stamps out the per-backend Descriptor class inside
// namespace op::causal_softmax::<NAMESPACE>. Each backend (cpu, cuda, ...)
// invokes it once in its own header and then defines, in its .cc file:
//   - ~Descriptor()  : releases the backend-specific Opaque state;
//   - create(...)    : validates y_desc and allocates a Descriptor;
//   - calculate(...) : runs causal softmax in place on `data`.
// NOTE: comments cannot appear inside the macro body — a `//` comment would
// swallow the trailing line-continuation backslash.
#define DESCRIPTOR(NAMESPACE) \
namespace op::causal_softmax::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
CausalSoftmaxInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
Opaque *opaque, \
CausalSoftmaxInfo info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) : InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
size_t workspaceSize() const { return _workspace_size; } \
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc); \
infiniStatus_t calculate(void *workspace, size_t workspace_size, \
void *data, void *stream); \
}; \
}
#endif // CAUSAL_SOFTMAX_H
#include "causal_softmax_cpu.h"
#include "../../../devices/cpu/common_cpu.h"
#include "../../../reduce/cpu/reduce.h"
namespace op::causal_softmax::cpu {
// The CPU backend keeps no Opaque state, so there is nothing to release.
Descriptor::~Descriptor() {}
// Validate y_desc and build a CPU descriptor.
// The CPU path needs no backend state and no workspace, so the descriptor is
// constructed with opaque = nullptr and workspace_size = 0.
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc) {
CausalSoftmaxInfo info;
CHECK_STATUS(createCausalSoftmaxInfo(&info, y_desc));
*desc_ptr = new Descriptor(nullptr, info, 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
// In-place causal softmax over `data`, viewed as a strided
// [batch_size, seq_len, total_seq_len] tensor described by *info.
// Row i may only attend to columns [0, total_seq_len - seq_len + i]; the
// remaining "future" columns are forced to zero. fp16 data is widened to
// float for all arithmetic.
template <typename T>
infiniStatus_t causal_softmax(const CausalSoftmaxInfo *info, T *data) {
    const ptrdiff_t rows = ptrdiff_t(info->batch_size * info->seq_len);
    // Rows are independent, so the flattened (batch, row) loop parallelizes.
#pragma omp parallel for
    for (ptrdiff_t row = 0; row < rows; row++) {
        const size_t i = size_t(row) % info->seq_len; // row index within one matrix
        const size_t b = size_t(row) / info->seq_len; // batch index
        const size_t base = i * info->stride_i + b * info->stride_b;
        // Count of attended (unmasked) positions in this row.
        const size_t valid = info->total_seq_len - info->seq_len + i + 1;

        // 1) Zero the masked tail of the row.
        for (size_t j = valid; j < info->total_seq_len; j++) {
            if constexpr (std::is_same<T, fp16_t>::value) {
                data[base + j * info->stride_j] = utils::cast<fp16_t>(0.0f);
            } else {
                data[base + j * info->stride_j] = 0.0f;
            }
        }

        // 2) Numerically stable exponentiation: subtract the row max first.
        const float row_max = op::common_cpu::reduce_op::max(&data[base], valid, info->stride_j);
        for (size_t j = 0; j < valid; j++) {
            if constexpr (std::is_same<T, fp16_t>::value) {
                data[base + j * info->stride_j] = utils::cast<fp16_t>(std::exp(utils::cast<float>(data[base + j * info->stride_j]) - row_max));
            } else {
                data[base + j * info->stride_j] = std::exp(data[base + j * info->stride_j] - row_max);
            }
        }

        // 3) Normalize. row_sum >= exp(0) == 1, so no division by zero.
        const float row_sum = op::common_cpu::reduce_op::sum(&data[base], valid, info->stride_j);
        for (size_t j = 0; j < valid; j++) {
            if constexpr (std::is_same<T, fp16_t>::value) {
                data[base + j * info->stride_j] = utils::cast<fp16_t>(utils::cast<float>(data[base + j * info->stride_j]) / row_sum);
            } else {
                data[base + j * info->stride_j] = data[base + j * info->stride_j] / row_sum;
            }
        }
    }
    return INFINI_STATUS_SUCCESS;
}
// Run causal softmax in place on `data`.
// The CPU backend ignores workspace/workspace_size (it needs none) and
// `stream` (execution is synchronous); it dispatches on the dtype captured
// at descriptor-creation time.
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
                                     void *data,
                                     void *stream) {
    switch (_info.dtype) {
    case INFINI_DTYPE_F16:
        CHECK_STATUS(causal_softmax<fp16_t>(&_info, (fp16_t *)data));
        break;
    case INFINI_DTYPE_F32:
        CHECK_STATUS(causal_softmax<float>(&_info, (float *)data));
        break;
    default:
        // Unreachable if create() succeeded, but fail loudly regardless.
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    return INFINI_STATUS_SUCCESS;
}
} // namespace op::causal_softmax::cpu
#ifndef __CAUSAL_SOFTMAX_CPU_H__
#define __CAUSAL_SOFTMAX_CPU_H__
#include "../causal_softmax.h"
// Instantiate op::causal_softmax::cpu::Descriptor via the shared DESCRIPTOR
// macro; the member definitions live in causal_softmax_cpu.cc.
DESCRIPTOR(cpu)
#endif
...@@ -2,14 +2,25 @@ ...@@ -2,14 +2,25 @@
#include "../../handle.h" #include "../../handle.h"
#include "infiniop/ops/causal_softmax.h" #include "infiniop/ops/causal_softmax.h"
#ifdef ENABLE_CPU_API
#include "cpu/causal_softmax_cpu.h"
#endif
__C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor( __C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor(
infiniopHandle_t handle, infiniopHandle_t handle,
infiniopCausalSoftmaxDescriptor_t *desc_ptr, infiniopCausalSoftmaxDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y_desc) { infiniopTensorDescriptor_t y_desc) {
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::causal_softmax::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::causal_softmax::NAMESPACE::Descriptor **>(desc_ptr), \
y_desc);
switch (handle->device) { switch (handle->device) {
#ifdef ENABLE_CPU #ifdef ENABLE_CPU_API
case DevCpu: CREATE(INFINI_DEVICE_CPU, cpu)
return cpuCreateCausalSoftmaxDescriptor(handle, (CausalSoftmaxCpuDescriptor_t *)desc_ptr, y_desc);
#endif #endif
#ifdef ENABLE_NV_GPU #ifdef ENABLE_NV_GPU
case DevNvGpu: { case DevNvGpu: {
...@@ -43,10 +54,15 @@ __C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor( ...@@ -43,10 +54,15 @@ __C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor(
} }
__C infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDescriptor_t desc, size_t *size) { __C infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDescriptor_t desc, size_t *size) {
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<op::causal_softmax::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) { switch (desc->device_type) {
#ifdef ENABLE_CPU #ifdef ENABLE_CPU_API
case DevCpu: GET(INFINI_DEVICE_CPU, cpu)
return cpuGetCausalSoftmaxWorkspaceSize((CausalSoftmaxCpuDescriptor_t)desc, size);
#endif #endif
#ifdef ENABLE_NV_GPU #ifdef ENABLE_NV_GPU
case DevNvGpu: { case DevNvGpu: {
...@@ -81,10 +97,15 @@ __C infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDe ...@@ -81,10 +97,15 @@ __C infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDe
} }
__C infiniStatus_t infiniopCausalSoftmax(infiniopCausalSoftmaxDescriptor_t desc, void *workspace, size_t workspace_size, void *data, void *stream) { __C infiniStatus_t infiniopCausalSoftmax(infiniopCausalSoftmaxDescriptor_t desc, void *workspace, size_t workspace_size, void *data, void *stream) {
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<op::causal_softmax::NAMESPACE::Descriptor *>(desc)->calculate( \
workspace, workspace_size, data, stream);
switch (desc->device_type) { switch (desc->device_type) {
#ifdef ENABLE_CPU #ifdef ENABLE_CPU_API
case DevCpu: CALCULATE(INFINI_DEVICE_CPU, cpu)
return cpuCausalSoftmax((CausalSoftmaxCpuDescriptor_t)desc, workspace, workspace_size, data, stream);
#endif #endif
#ifdef ENABLE_NV_GPU #ifdef ENABLE_NV_GPU
case DevNvGpu: { case DevNvGpu: {
...@@ -118,10 +139,15 @@ __C infiniStatus_t infiniopCausalSoftmax(infiniopCausalSoftmaxDescriptor_t desc, ...@@ -118,10 +139,15 @@ __C infiniStatus_t infiniopCausalSoftmax(infiniopCausalSoftmaxDescriptor_t desc,
} }
__C infiniStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftmaxDescriptor_t desc) { __C infiniStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftmaxDescriptor_t desc) {
#define DESTROY(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<op::causal_softmax::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) { switch (desc->device_type) {
#ifdef ENABLE_CPU #ifdef ENABLE_CPU_API
case DevCpu: DESTROY(INFINI_DEVICE_CPU, cpu)
return cpuDestroyCausalSoftmaxDescriptor((CausalSoftmaxCpuDescriptor_t)desc);
#endif #endif
#ifdef ENABLE_NV_GPU #ifdef ENABLE_NV_GPU
case DevNvGpu: { case DevNvGpu: {
......
#include "reduce.h"
namespace op::common_cpu::reduce_op {
float sum(const fp16_t *data, size_t len, ptrdiff_t stride) {
float result = 0;
for (size_t i = 0; i < len; i++) {
result += utils::cast<float>(data[i * stride]);
}
return result;
}
float max(const fp16_t *data, size_t len, ptrdiff_t stride) {
float result = utils::cast<float>(data[0]);
for (size_t i = 1; i < len; i++) {
result = std::max(result, utils::cast<float>(data[i * stride]));
}
return result;
}
float sumSquared(const fp16_t *data, size_t len, ptrdiff_t stride) {
float result = 0;
for (size_t i = 0; i < len; i++) {
float val = utils::cast<float>(data[i * stride]);
result += val * val;
}
return result;
}
} // namespace op::common_cpu::reduce_op
...@@ -36,15 +36,20 @@ T sum(const T *data, size_t len, ptrdiff_t stride = 1) { ...@@ -36,15 +36,20 @@ T sum(const T *data, size_t len, ptrdiff_t stride = 1) {
return result; return result;
} }
float sum(const fp16_t *data, size_t len, ptrdiff_t stride = 1) { float sum(const fp16_t *data, size_t len, ptrdiff_t stride = 1);
float result = 0;
for (size_t i = 0; i < len; i++) { template <typename T, typename = std::enable_if_t<ReduceToSame<T>::value>>
result += utils::cast<float>(data[i * stride]); T max(const T *data, size_t len, ptrdiff_t stride = 1) {
T result = data[0];
for (size_t i = 1; i < len; i++) {
result = std::max(result, data[i * stride]);
} }
return result; return result;
} }
float max(const fp16_t *data, size_t len, ptrdiff_t stride = 1);
template <typename T, typename = std::enable_if_t<ReduceToSame<T>::value>> template <typename T, typename = std::enable_if_t<ReduceToSame<T>::value>>
T sumSquared(const T *data, size_t len, ptrdiff_t stride = 1) { T sumSquared(const T *data, size_t len, ptrdiff_t stride = 1) {
T result = 0; T result = 0;
...@@ -56,15 +61,7 @@ T sumSquared(const T *data, size_t len, ptrdiff_t stride = 1) { ...@@ -56,15 +61,7 @@ T sumSquared(const T *data, size_t len, ptrdiff_t stride = 1) {
return result; return result;
} }
float sumSquared(const fp16_t *data, size_t len, ptrdiff_t stride = 1) { float sumSquared(const fp16_t *data, size_t len, ptrdiff_t stride = 1);
float result = 0;
for (size_t i = 0; i < len; i++) {
float val = utils::cast<float>(data[i * stride]);
result += val * val;
}
return result;
}
} // namespace reduce_op } // namespace reduce_op
......
...@@ -80,7 +80,7 @@ def test(lib, handle, torch_device, x_shape, x_stride=None, dtype=torch.float16) ...@@ -80,7 +80,7 @@ def test(lib, handle, torch_device, x_shape, x_stride=None, dtype=torch.float16)
) )
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
x_tensor.descriptor.contents.invalidate() x_tensor.destroyDesc(lib)
workspace_size = c_uint64(0) workspace_size = c_uint64(0)
check_error( check_error(
......
...@@ -19,7 +19,7 @@ target("infiniop-cpu") ...@@ -19,7 +19,7 @@ target("infiniop-cpu")
end end
set_languages("cxx17") set_languages("cxx17")
add_files("../src/infiniop/devices/cpu/*.cc", "../src/infiniop/ops/*/cpu/*.cc") add_files("../src/infiniop/devices/cpu/*.cc", "../src/infiniop/ops/*/cpu/*.cc", "../src/infiniop/reduce/cpu/*.cc")
target_end() target_end()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment