Commit 006fb46e authored by YdrMaster's avatar YdrMaster
Browse files

issue/121/refactor: 用 Result 修改 CausalSoftmax 构造流程,并与 Gemm 保持风格一致


Signed-off-by: default avatarYdrMaster <ydrml@hotmail.com>
parent fab5ed70
...@@ -2,57 +2,10 @@ ...@@ -2,57 +2,10 @@
#define CAUSAL_SOFTMAX_H #define CAUSAL_SOFTMAX_H
#include "../../operator.h" #include "../../operator.h"
#include "../../tensor.h" #include "info.h"
#include <iostream>
#include <vector>
struct CausalSoftmaxInfo {
infiniDtype_t dtype;
size_t batch_size;
ptrdiff_t stride_b;
size_t seq_len;
ptrdiff_t stride_i;
size_t total_seq_len;
ptrdiff_t stride_j;
};
inline infiniStatus_t createCausalSoftmaxInfo(CausalSoftmaxInfo *info, infiniopTensorDescriptor_t y_desc) {
auto dtype = y_desc->dtype();
if (y_desc->dtype() != INFINI_DTYPE_F16 && y_desc->dtype() != INFINI_DTYPE_F32) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
info->dtype = dtype;
if (y_desc->ndim() != 2 && y_desc->ndim() != 3) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (y_desc->shape()[y_desc->ndim() - 1] < y_desc->shape()[y_desc->ndim() - 2]) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
size_t batch_size = 1;
ptrdiff_t stride_b = 0;
size_t seq_len = y_desc->shape()[y_desc->ndim() - 2];
ptrdiff_t stride_i = y_desc->strides()[y_desc->ndim() - 2];
size_t total_seq_len = y_desc->shape()[y_desc->ndim() - 1];
ptrdiff_t stride_j = y_desc->strides()[y_desc->ndim() - 1];
if (y_desc->ndim() == 3) {
stride_b = y_desc->strides()[0];
batch_size = y_desc->shape()[0];
}
info->batch_size = batch_size;
info->stride_b = stride_b;
info->seq_len = seq_len;
info->stride_i = stride_i;
info->total_seq_len = total_seq_len;
info->stride_j = stride_j;
return INFINI_STATUS_SUCCESS;
}
#define DESCRIPTOR(NAMESPACE) \ #define DESCRIPTOR(NAMESPACE) \
\
namespace op::causal_softmax::NAMESPACE { \ namespace op::causal_softmax::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \ class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \ struct Opaque; \
...@@ -65,20 +18,26 @@ inline infiniStatus_t createCausalSoftmaxInfo(CausalSoftmaxInfo *info, infiniopT ...@@ -65,20 +18,26 @@ inline infiniStatus_t createCausalSoftmaxInfo(CausalSoftmaxInfo *info, infiniopT
CausalSoftmaxInfo info, \ CausalSoftmaxInfo info, \
size_t workspace_size, \ size_t workspace_size, \
infiniDevice_t device_type, \ infiniDevice_t device_type, \
int device_id) : InfiniopDescriptor{device_type, device_id}, \ int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \ _opaque(opaque), \
_info(info), \ _info(info), \
_workspace_size(workspace_size) {} \ _workspace_size(workspace_size) {} \
\ \
public: \ public: \
~Descriptor(); \ ~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \ size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \ static infiniStatus_t create( \
infiniopHandle_t handle, \ infiniopHandle_t handle, \
Descriptor **desc_ptr, \ Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc); \ infiniopTensorDescriptor_t y_desc); \
infiniStatus_t calculate(void *workspace, size_t workspace_size, \ \
void *data, void *stream); \ infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *data, \
void *stream) const; \
}; \ }; \
} }
......
...@@ -3,15 +3,16 @@ ...@@ -3,15 +3,16 @@
#include "../../../reduce/cpu/reduce.h" #include "../../../reduce/cpu/reduce.h"
namespace op::causal_softmax::cpu { namespace op::causal_softmax::cpu {
Descriptor::~Descriptor() {} Descriptor::~Descriptor() {}
infiniStatus_t Descriptor::create( infiniStatus_t Descriptor::create(
infiniopHandle_t handle, infiniopHandle_t handle,
Descriptor **desc_ptr, Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc) { infiniopTensorDescriptor_t y_desc) {
CausalSoftmaxInfo info; auto result = CausalSoftmaxInfo::create(y_desc);
CHECK_STATUS(createCausalSoftmaxInfo(&info, y_desc)); CHECK_RESULT(result);
*desc_ptr = new Descriptor(nullptr, info, 0, handle->device, handle->device_id); *desc_ptr = new Descriptor(nullptr, result.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
...@@ -53,9 +54,11 @@ infiniStatus_t causal_softmax(const CausalSoftmaxInfo *info, T *data) { ...@@ -53,9 +54,11 @@ infiniStatus_t causal_softmax(const CausalSoftmaxInfo *info, T *data) {
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *data, void *data,
void *stream) { void *stream) const {
if (_info.dtype == INFINI_DTYPE_F16) { if (_info.dtype == INFINI_DTYPE_F16) {
CHECK_STATUS(causal_softmax<fp16_t>(&_info, (fp16_t *)data)); CHECK_STATUS(causal_softmax<fp16_t>(&_info, (fp16_t *)data));
} else if (_info.dtype == INFINI_DTYPE_F32) { } else if (_info.dtype == INFINI_DTYPE_F32) {
......
#ifndef __CAUSAL_SOFTMAX_INFO_H__
#define __CAUSAL_SOFTMAX_INFO_H__
#include "../../../utils.h"
#include "../../tensor.h"
#include <vector>
namespace op::causal_softmax {
class CausalSoftmaxInfo {
CausalSoftmaxInfo() = default;
public:
infiniDtype_t dtype;
size_t batch_size;
ptrdiff_t stride_b;
size_t seq_len;
ptrdiff_t stride_i;
size_t total_seq_len;
ptrdiff_t stride_j;
static utils::Result<CausalSoftmaxInfo> create(infiniopTensorDescriptor_t y_desc) {
auto dtype = y_desc->dtype();
if (y_desc->dtype() != INFINI_DTYPE_F16 && y_desc->dtype() != INFINI_DTYPE_F32) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (y_desc->ndim() != 2 && y_desc->ndim() != 3) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (y_desc->shape()[y_desc->ndim() - 1] < y_desc->shape()[y_desc->ndim() - 2]) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
size_t batch_size = 1;
ptrdiff_t stride_b = 0;
size_t seq_len = y_desc->shape()[y_desc->ndim() - 2];
ptrdiff_t stride_i = y_desc->strides()[y_desc->ndim() - 2];
size_t total_seq_len = y_desc->shape()[y_desc->ndim() - 1];
ptrdiff_t stride_j = y_desc->strides()[y_desc->ndim() - 1];
if (y_desc->ndim() == 3) {
stride_b = y_desc->strides()[0];
batch_size = y_desc->shape()[0];
}
return utils::Result<CausalSoftmaxInfo>(CausalSoftmaxInfo{
dtype,
batch_size,
stride_b,
seq_len,
stride_i,
total_seq_len,
stride_j});
}
};
} // namespace op::causal_softmax
#endif // __CAUSAL_SOFTMAX_INFO_H__
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#define __GEMM_H__ #define __GEMM_H__
#include "../../operator.h" #include "../../operator.h"
#include "matmul_info.h" #include "info.h"
/** /**
* # 关于 `DESCRIPTOR(NAMESPACE)` 和 `struct Opaque;` 的说明 * # 关于 `DESCRIPTOR(NAMESPACE)` 和 `struct Opaque;` 的说明
...@@ -52,6 +52,7 @@ ...@@ -52,6 +52,7 @@
Opaque *_opaque; \ Opaque *_opaque; \
infiniDtype_t _dtype; \ infiniDtype_t _dtype; \
MatmulInfo _info; \ MatmulInfo _info; \
size_t _workspace_size; \
\ \
Descriptor( \ Descriptor( \
infiniDtype_t dtype, \ infiniDtype_t dtype, \
...@@ -64,13 +65,13 @@ ...@@ -64,13 +65,13 @@
_opaque(opaque), \ _opaque(opaque), \
_dtype(dtype), \ _dtype(dtype), \
_info(info), \ _info(info), \
workspace_size(workspace_size_) {} \ _workspace_size(workspace_size_) {} \
\ \
public: \ public: \
size_t workspace_size; \
\
~Descriptor(); \ ~Descriptor(); \
\ \
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \ static infiniStatus_t create( \
infiniopHandle_t handle, \ infiniopHandle_t handle, \
Descriptor **desc_ptr, \ Descriptor **desc_ptr, \
...@@ -79,8 +80,7 @@ ...@@ -79,8 +80,7 @@
infiniopTensorDescriptor_t b_desc); \ infiniopTensorDescriptor_t b_desc); \
\ \
infiniStatus_t calculate( \ infiniStatus_t calculate( \
void *workspace, \ void *workspace, size_t workspace_size, \
size_t workspace_size, \
void *c, \ void *c, \
float beta, \ float beta, \
const void *a, \ const void *a, \
......
#ifndef __BLAS_H__ #ifndef __GEMM_INFO_H__
#define __BLAS_H__ #define __GEMM_INFO_H__
#include "../../../utils.h" #include "../../../utils.h"
#include "../../operator.h" #include "../../operator.h"
...@@ -132,4 +132,4 @@ public: ...@@ -132,4 +132,4 @@ public:
} // namespace op::gemm } // namespace op::gemm
#endif // __BLAS_H__ #endif // __GEMM_INFO_H__
...@@ -72,7 +72,7 @@ infiniopGetGemmWorkspaceSize( ...@@ -72,7 +72,7 @@ infiniopGetGemmWorkspaceSize(
#define GET(CASE, NAMESPACE) \ #define GET(CASE, NAMESPACE) \
case CASE: \ case CASE: \
*size = reinterpret_cast<const op::gemm::NAMESPACE::Descriptor *>(desc)->workspace_size; \ *size = reinterpret_cast<const op::gemm::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS return INFINI_STATUS_SUCCESS
switch (desc->device_type) { switch (desc->device_type) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment