Commit 006fb46e authored by YdrMaster
Browse files

issue/121/refactor: 用 Result 修改 CausalSoftmax 构造流程,并与 Gemm 保持风格一致


Signed-off-by: YdrMaster <ydrml@hotmail.com>
parent fab5ed70
...@@ -2,84 +2,43 @@ ...@@ -2,84 +2,43 @@
#define CAUSAL_SOFTMAX_H #define CAUSAL_SOFTMAX_H
#include "../../operator.h" #include "../../operator.h"
#include "../../tensor.h" #include "info.h"
#include <iostream>
#include <vector> #define DESCRIPTOR(NAMESPACE) \
\
struct CausalSoftmaxInfo { namespace op::causal_softmax::NAMESPACE { \
infiniDtype_t dtype; class Descriptor final : public InfiniopDescriptor { \
size_t batch_size; struct Opaque; \
ptrdiff_t stride_b; Opaque *_opaque; \
size_t seq_len; CausalSoftmaxInfo _info; \
ptrdiff_t stride_i; size_t _workspace_size; \
size_t total_seq_len; \
ptrdiff_t stride_j; Descriptor( \
}; Opaque *opaque, \
CausalSoftmaxInfo info, \
inline infiniStatus_t createCausalSoftmaxInfo(CausalSoftmaxInfo *info, infiniopTensorDescriptor_t y_desc) { size_t workspace_size, \
auto dtype = y_desc->dtype(); infiniDevice_t device_type, \
if (y_desc->dtype() != INFINI_DTYPE_F16 && y_desc->dtype() != INFINI_DTYPE_F32) { int device_id) \
return INFINI_STATUS_BAD_TENSOR_DTYPE; : InfiniopDescriptor{device_type, device_id}, \
} _opaque(opaque), \
info->dtype = dtype; _info(info), \
_workspace_size(workspace_size) {} \
if (y_desc->ndim() != 2 && y_desc->ndim() != 3) { \
return INFINI_STATUS_BAD_TENSOR_SHAPE; public: \
} ~Descriptor(); \
\
if (y_desc->shape()[y_desc->ndim() - 1] < y_desc->shape()[y_desc->ndim() - 2]) { size_t workspaceSize() const { return _workspace_size; } \
return INFINI_STATUS_BAD_TENSOR_SHAPE; \
} static infiniStatus_t create( \
infiniopHandle_t handle, \
size_t batch_size = 1; Descriptor **desc_ptr, \
ptrdiff_t stride_b = 0; infiniopTensorDescriptor_t y_desc); \
size_t seq_len = y_desc->shape()[y_desc->ndim() - 2]; \
ptrdiff_t stride_i = y_desc->strides()[y_desc->ndim() - 2]; infiniStatus_t calculate( \
size_t total_seq_len = y_desc->shape()[y_desc->ndim() - 1]; void *workspace, size_t workspace_size, \
ptrdiff_t stride_j = y_desc->strides()[y_desc->ndim() - 1]; void *data, \
if (y_desc->ndim() == 3) { void *stream) const; \
stride_b = y_desc->strides()[0]; }; \
batch_size = y_desc->shape()[0];
}
info->batch_size = batch_size;
info->stride_b = stride_b;
info->seq_len = seq_len;
info->stride_i = stride_i;
info->total_seq_len = total_seq_len;
info->stride_j = stride_j;
return INFINI_STATUS_SUCCESS;
}
#define DESCRIPTOR(NAMESPACE) \
namespace op::causal_softmax::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
CausalSoftmaxInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
Opaque *opaque, \
CausalSoftmaxInfo info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) : InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
size_t workspaceSize() const { return _workspace_size; } \
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc); \
infiniStatus_t calculate(void *workspace, size_t workspace_size, \
void *data, void *stream); \
}; \
} }
#endif // CAUSAL_SOFTMAX_H #endif // CAUSAL_SOFTMAX_H
...@@ -3,15 +3,16 @@ ...@@ -3,15 +3,16 @@
#include "../../../reduce/cpu/reduce.h" #include "../../../reduce/cpu/reduce.h"
namespace op::causal_softmax::cpu { namespace op::causal_softmax::cpu {
Descriptor::~Descriptor() {} Descriptor::~Descriptor() {}
infiniStatus_t Descriptor::create( infiniStatus_t Descriptor::create(
infiniopHandle_t handle, infiniopHandle_t handle,
Descriptor **desc_ptr, Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc) { infiniopTensorDescriptor_t y_desc) {
CausalSoftmaxInfo info; auto result = CausalSoftmaxInfo::create(y_desc);
CHECK_STATUS(createCausalSoftmaxInfo(&info, y_desc)); CHECK_RESULT(result);
*desc_ptr = new Descriptor(nullptr, info, 0, handle->device, handle->device_id); *desc_ptr = new Descriptor(nullptr, result.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
...@@ -53,9 +54,11 @@ infiniStatus_t causal_softmax(const CausalSoftmaxInfo *info, T *data) { ...@@ -53,9 +54,11 @@ infiniStatus_t causal_softmax(const CausalSoftmaxInfo *info, T *data) {
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, infiniStatus_t Descriptor::calculate(
void *data, void *workspace, size_t workspace_size,
void *stream) { void *data,
void *stream) const {
if (_info.dtype == INFINI_DTYPE_F16) { if (_info.dtype == INFINI_DTYPE_F16) {
CHECK_STATUS(causal_softmax<fp16_t>(&_info, (fp16_t *)data)); CHECK_STATUS(causal_softmax<fp16_t>(&_info, (fp16_t *)data));
} else if (_info.dtype == INFINI_DTYPE_F32) { } else if (_info.dtype == INFINI_DTYPE_F32) {
......
#ifndef __CAUSAL_SOFTMAX_INFO_H__
#define __CAUSAL_SOFTMAX_INFO_H__

#include "../../../utils.h"
#include "../../tensor.h"

#include <vector>

namespace op::causal_softmax {

// Validated shape/stride metadata for a causal-softmax launch, derived from
// the output tensor descriptor. Instances are only produced by create(); the
// default constructor is private so an unvalidated object cannot be made
// elsewhere.
class CausalSoftmaxInfo {
    CausalSoftmaxInfo() = default;

public:
    infiniDtype_t dtype;  // element type; restricted to F16/F32 by create()
    size_t batch_size;    // leading batch dimension (1 for 2-D tensors)
    ptrdiff_t stride_b;   // batch stride (0 for 2-D tensors)
    size_t seq_len;       // second-to-last dimension
    ptrdiff_t stride_i;   // stride of the second-to-last dimension
    size_t total_seq_len; // last dimension (>= seq_len, see create())
    ptrdiff_t stride_j;   // stride of the last dimension

    // Validates y_desc and builds a CausalSoftmaxInfo.
    // Returns INFINI_STATUS_BAD_TENSOR_DTYPE unless the dtype is F16 or F32,
    // and INFINI_STATUS_BAD_TENSOR_SHAPE unless y_desc is 2-D or 3-D with
    // last dim >= second-to-last dim (required for the causal mask).
    static utils::Result<CausalSoftmaxInfo> create(infiniopTensorDescriptor_t y_desc) {
        auto dtype = y_desc->dtype();
        if (dtype != INFINI_DTYPE_F16 && dtype != INFINI_DTYPE_F32) {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }

        // Hoist the rank; it is used repeatedly below.
        auto ndim = y_desc->ndim();
        if (ndim != 2 && ndim != 3) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }

        // Causal masking needs at least as many columns as rows.
        if (y_desc->shape()[ndim - 1] < y_desc->shape()[ndim - 2]) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }

        size_t batch_size = 1;
        ptrdiff_t stride_b = 0;
        size_t seq_len = y_desc->shape()[ndim - 2];
        ptrdiff_t stride_i = y_desc->strides()[ndim - 2];
        size_t total_seq_len = y_desc->shape()[ndim - 1];
        ptrdiff_t stride_j = y_desc->strides()[ndim - 1];
        if (ndim == 3) {
            stride_b = y_desc->strides()[0];
            batch_size = y_desc->shape()[0];
        }

        // NOTE(review): this brace-init relies on pre-C++20 aggregate rules
        // (the defaulted ctor is user-declared but not user-provided); if the
        // project moves to C++20, add a matching constructor — confirm the
        // build's language standard.
        return utils::Result<CausalSoftmaxInfo>(CausalSoftmaxInfo{
            dtype,
            batch_size,
            stride_b,
            seq_len,
            stride_i,
            total_seq_len,
            stride_j});
    }
};

} // namespace op::causal_softmax

#endif // __CAUSAL_SOFTMAX_INFO_H__
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#define __GEMM_H__ #define __GEMM_H__
#include "../../operator.h" #include "../../operator.h"
#include "matmul_info.h" #include "info.h"
/** /**
* # 关于 `DESCRIPTOR(NAMESPACE)` 和 `struct Opaque;` 的说明 * # 关于 `DESCRIPTOR(NAMESPACE)` 和 `struct Opaque;` 的说明
...@@ -44,50 +44,50 @@ ...@@ -44,50 +44,50 @@
* 这个宏仅适用于矩阵乘,但这种模式很容易复制到其他算子,以简化和规范算子的声明。 * 这个宏仅适用于矩阵乘,但这种模式很容易复制到其他算子,以简化和规范算子的声明。
*/ */
#define DESCRIPTOR(NAMESPACE) \ #define DESCRIPTOR(NAMESPACE) \
\ \
namespace op::gemm::NAMESPACE { \ namespace op::gemm::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \ class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \ struct Opaque; \
Opaque *_opaque; \ Opaque *_opaque; \
infiniDtype_t _dtype; \ infiniDtype_t _dtype; \
MatmulInfo _info; \ MatmulInfo _info; \
\ size_t _workspace_size; \
Descriptor( \ \
infiniDtype_t dtype, \ Descriptor( \
MatmulInfo info, \ infiniDtype_t dtype, \
size_t workspace_size_, \ MatmulInfo info, \
Opaque *opaque, \ size_t workspace_size_, \
infiniDevice_t device_type, \ Opaque *opaque, \
int device_id) \ infiniDevice_t device_type, \
: InfiniopDescriptor{device_type, device_id}, \ int device_id) \
_opaque(opaque), \ : InfiniopDescriptor{device_type, device_id}, \
_dtype(dtype), \ _opaque(opaque), \
_info(info), \ _dtype(dtype), \
workspace_size(workspace_size_) {} \ _info(info), \
\ _workspace_size(workspace_size_) {} \
public: \ \
size_t workspace_size; \ public: \
\ ~Descriptor(); \
~Descriptor(); \ \
\ size_t workspaceSize() const { return _workspace_size; } \
static infiniStatus_t create( \ \
infiniopHandle_t handle, \ static infiniStatus_t create( \
Descriptor **desc_ptr, \ infiniopHandle_t handle, \
infiniopTensorDescriptor_t c_desc, \ Descriptor **desc_ptr, \
infiniopTensorDescriptor_t a_desc, \ infiniopTensorDescriptor_t c_desc, \
infiniopTensorDescriptor_t b_desc); \ infiniopTensorDescriptor_t a_desc, \
\ infiniopTensorDescriptor_t b_desc); \
infiniStatus_t calculate( \ \
void *workspace, \ infiniStatus_t calculate( \
size_t workspace_size, \ void *workspace, size_t workspace_size, \
void *c, \ void *c, \
float beta, \ float beta, \
const void *a, \ const void *a, \
const void *b, \ const void *b, \
float alpha, \ float alpha, \
void *stream) const; \ void *stream) const; \
}; \ }; \
} }
#endif // __GEMM_H__ #endif // __GEMM_H__
#ifndef __BLAS_H__ #ifndef __GEMM_INFO_H__
#define __BLAS_H__ #define __GEMM_INFO_H__
#include "../../../utils.h" #include "../../../utils.h"
#include "../../operator.h" #include "../../operator.h"
...@@ -132,4 +132,4 @@ public: ...@@ -132,4 +132,4 @@ public:
} // namespace op::gemm } // namespace op::gemm
#endif // __BLAS_H__ #endif // __GEMM_INFO_H__
...@@ -70,9 +70,9 @@ infiniopGetGemmWorkspaceSize( ...@@ -70,9 +70,9 @@ infiniopGetGemmWorkspaceSize(
infiniopGemmDescriptor_t desc, infiniopGemmDescriptor_t desc,
size_t *size) { size_t *size) {
#define GET(CASE, NAMESPACE) \ #define GET(CASE, NAMESPACE) \
case CASE: \ case CASE: \
*size = reinterpret_cast<const op::gemm::NAMESPACE::Descriptor *>(desc)->workspace_size; \ *size = reinterpret_cast<const op::gemm::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS return INFINI_STATUS_SUCCESS
switch (desc->device_type) { switch (desc->device_type) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment