Commit 006fb46e authored by YdrMaster's avatar YdrMaster
Browse files

issue/121/refactor: 用 Result 修改 CausalSoftmax 构造流程,并与 Gemm 保持风格一致


Signed-off-by: default avatarYdrMaster <ydrml@hotmail.com>
parent fab5ed70
......@@ -2,84 +2,43 @@
#define CAUSAL_SOFTMAX_H
#include "../../operator.h"
#include "../../tensor.h"
#include <iostream>
#include <vector>
struct CausalSoftmaxInfo {
infiniDtype_t dtype;
size_t batch_size;
ptrdiff_t stride_b;
size_t seq_len;
ptrdiff_t stride_i;
size_t total_seq_len;
ptrdiff_t stride_j;
};
inline infiniStatus_t createCausalSoftmaxInfo(CausalSoftmaxInfo *info, infiniopTensorDescriptor_t y_desc) {
auto dtype = y_desc->dtype();
if (y_desc->dtype() != INFINI_DTYPE_F16 && y_desc->dtype() != INFINI_DTYPE_F32) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
info->dtype = dtype;
if (y_desc->ndim() != 2 && y_desc->ndim() != 3) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (y_desc->shape()[y_desc->ndim() - 1] < y_desc->shape()[y_desc->ndim() - 2]) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
size_t batch_size = 1;
ptrdiff_t stride_b = 0;
size_t seq_len = y_desc->shape()[y_desc->ndim() - 2];
ptrdiff_t stride_i = y_desc->strides()[y_desc->ndim() - 2];
size_t total_seq_len = y_desc->shape()[y_desc->ndim() - 1];
ptrdiff_t stride_j = y_desc->strides()[y_desc->ndim() - 1];
if (y_desc->ndim() == 3) {
stride_b = y_desc->strides()[0];
batch_size = y_desc->shape()[0];
}
info->batch_size = batch_size;
info->stride_b = stride_b;
info->seq_len = seq_len;
info->stride_i = stride_i;
info->total_seq_len = total_seq_len;
info->stride_j = stride_j;
return INFINI_STATUS_SUCCESS;
}
#define DESCRIPTOR(NAMESPACE) \
namespace op::causal_softmax::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
CausalSoftmaxInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
Opaque *opaque, \
CausalSoftmaxInfo info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) : InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
size_t workspaceSize() const { return _workspace_size; } \
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc); \
infiniStatus_t calculate(void *workspace, size_t workspace_size, \
void *data, void *stream); \
}; \
#include "info.h"
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::causal_softmax::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
CausalSoftmaxInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
Opaque *opaque, \
CausalSoftmaxInfo info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *data, \
void *stream) const; \
}; \
}
#endif // CAUSAL_SOFTMAX_H
......@@ -3,15 +3,16 @@
#include "../../../reduce/cpu/reduce.h"
namespace op::causal_softmax::cpu {
Descriptor::~Descriptor() {}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc) {
CausalSoftmaxInfo info;
CHECK_STATUS(createCausalSoftmaxInfo(&info, y_desc));
*desc_ptr = new Descriptor(nullptr, info, 0, handle->device, handle->device_id);
auto result = CausalSoftmaxInfo::create(y_desc);
CHECK_RESULT(result);
*desc_ptr = new Descriptor(nullptr, result.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
......@@ -53,9 +54,11 @@ infiniStatus_t causal_softmax(const CausalSoftmaxInfo *info, T *data) {
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
void *data,
void *stream) {
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *data,
void *stream) const {
if (_info.dtype == INFINI_DTYPE_F16) {
CHECK_STATUS(causal_softmax<fp16_t>(&_info, (fp16_t *)data));
} else if (_info.dtype == INFINI_DTYPE_F32) {
......
#ifndef __CAUSAL_SOFTMAX_INFO_H__
#define __CAUSAL_SOFTMAX_INFO_H__
#include "../../../utils.h"
#include "../../tensor.h"
#include <vector>
namespace op::causal_softmax {
class CausalSoftmaxInfo {
CausalSoftmaxInfo() = default;
public:
infiniDtype_t dtype;
size_t batch_size;
ptrdiff_t stride_b;
size_t seq_len;
ptrdiff_t stride_i;
size_t total_seq_len;
ptrdiff_t stride_j;
static utils::Result<CausalSoftmaxInfo> create(infiniopTensorDescriptor_t y_desc) {
auto dtype = y_desc->dtype();
if (y_desc->dtype() != INFINI_DTYPE_F16 && y_desc->dtype() != INFINI_DTYPE_F32) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (y_desc->ndim() != 2 && y_desc->ndim() != 3) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (y_desc->shape()[y_desc->ndim() - 1] < y_desc->shape()[y_desc->ndim() - 2]) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
size_t batch_size = 1;
ptrdiff_t stride_b = 0;
size_t seq_len = y_desc->shape()[y_desc->ndim() - 2];
ptrdiff_t stride_i = y_desc->strides()[y_desc->ndim() - 2];
size_t total_seq_len = y_desc->shape()[y_desc->ndim() - 1];
ptrdiff_t stride_j = y_desc->strides()[y_desc->ndim() - 1];
if (y_desc->ndim() == 3) {
stride_b = y_desc->strides()[0];
batch_size = y_desc->shape()[0];
}
return utils::Result<CausalSoftmaxInfo>(CausalSoftmaxInfo{
dtype,
batch_size,
stride_b,
seq_len,
stride_i,
total_seq_len,
stride_j});
}
};
} // namespace op::causal_softmax
#endif // __CAUSAL_SOFTMAX_INFO_H__
......@@ -2,7 +2,7 @@
#define __GEMM_H__
#include "../../operator.h"
#include "matmul_info.h"
#include "info.h"
/**
* # 关于 `DESCRIPTOR(NAMESPACE)` 和 `struct Opaque;` 的说明
......@@ -44,50 +44,50 @@
* 这个宏仅适用于矩阵乘,但这种模式很容易复制到其他算子,以简化和规范算子的声明。
*/
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::gemm::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
infiniDtype_t _dtype; \
MatmulInfo _info; \
\
Descriptor( \
infiniDtype_t dtype, \
MatmulInfo info, \
size_t workspace_size_, \
Opaque *opaque, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_dtype(dtype), \
_info(info), \
workspace_size(workspace_size_) {} \
\
public: \
size_t workspace_size; \
\
~Descriptor(); \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t c_desc, \
infiniopTensorDescriptor_t a_desc, \
infiniopTensorDescriptor_t b_desc); \
\
infiniStatus_t calculate( \
void *workspace, \
size_t workspace_size, \
void *c, \
float beta, \
const void *a, \
const void *b, \
float alpha, \
void *stream) const; \
}; \
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::gemm::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
infiniDtype_t _dtype; \
MatmulInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
infiniDtype_t dtype, \
MatmulInfo info, \
size_t workspace_size_, \
Opaque *opaque, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_dtype(dtype), \
_info(info), \
_workspace_size(workspace_size_) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t c_desc, \
infiniopTensorDescriptor_t a_desc, \
infiniopTensorDescriptor_t b_desc); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *c, \
float beta, \
const void *a, \
const void *b, \
float alpha, \
void *stream) const; \
}; \
}
#endif // __GEMM_H__
#ifndef __BLAS_H__
#define __BLAS_H__
#ifndef __GEMM_INFO_H__
#define __GEMM_INFO_H__
#include "../../../utils.h"
#include "../../operator.h"
......@@ -132,4 +132,4 @@ public:
} // namespace op::gemm
#endif // __BLAS_H__
#endif // __GEMM_INFO_H__
......@@ -70,9 +70,9 @@ infiniopGetGemmWorkspaceSize(
infiniopGemmDescriptor_t desc,
size_t *size) {
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<const op::gemm::NAMESPACE::Descriptor *>(desc)->workspace_size; \
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<const op::gemm::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS
switch (desc->device_type) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment