Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
006fb46e
Commit
006fb46e
authored
Apr 02, 2025
by
YdrMaster
Browse files
issue/121/refactor: 用 Result 修改 CausalSoftmax 构造流程,并与 Gemm 保持风格一致
Signed-off-by:
YdrMaster
<
ydrml@hotmail.com
>
parent
fab5ed70
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
157 additions
and
135 deletions
+157
-135
src/infiniop/ops/causal_softmax/causal_softmax.h
src/infiniop/ops/causal_softmax/causal_softmax.h
+37
-78
src/infiniop/ops/causal_softmax/cpu/causal_softmax_cpu.cc
src/infiniop/ops/causal_softmax/cpu/causal_softmax_cpu.cc
+9
-6
src/infiniop/ops/causal_softmax/info.h
src/infiniop/ops/causal_softmax/info.h
+60
-0
src/infiniop/ops/gemm/gemm.h
src/infiniop/ops/gemm/gemm.h
+45
-45
src/infiniop/ops/gemm/info.h
src/infiniop/ops/gemm/info.h
+3
-3
src/infiniop/ops/gemm/operator.cc
src/infiniop/ops/gemm/operator.cc
+3
-3
No files found.
src/infiniop/ops/causal_softmax/causal_softmax.h
View file @
006fb46e
...
@@ -2,57 +2,10 @@
...
@@ -2,57 +2,10 @@
#define CAUSAL_SOFTMAX_H
#define CAUSAL_SOFTMAX_H
#include "../../operator.h"
#include "../../operator.h"
#include "../../tensor.h"
#include "info.h"
#include <iostream>
#include <vector>
/// Plain record describing the tensor a causal-softmax operates on.
/// Populated by `createCausalSoftmaxInfo` from a tensor descriptor.
struct CausalSoftmaxInfo {
    /// Element type of the tensor (the factory accepts only F16 and F32).
    infiniDtype_t dtype;

    /// Size of dimension 0 for a 3-D tensor; 1 for a 2-D tensor.
    size_t batch_size;
    /// Stride of dimension 0 for a 3-D tensor; 0 for a 2-D tensor.
    ptrdiff_t stride_b;

    /// Size of the second-to-last dimension.
    size_t seq_len;
    /// Stride of the second-to-last dimension.
    ptrdiff_t stride_i;

    /// Size of the last dimension (validated to be >= seq_len).
    size_t total_seq_len;
    /// Stride of the last dimension.
    ptrdiff_t stride_j;
};
/**
 * Validates `y_desc` and, on success, records its dtype, sizes and strides in `*info`.
 *
 * @param info   Output record. Its `dtype` field is written as soon as the dtype
 *               check passes, i.e. even when a later shape check fails; all other
 *               fields are written only on success.
 * @param y_desc Tensor descriptor to inspect. It is dereferenced without a null check.
 * @return INFINI_STATUS_BAD_TENSOR_DTYPE if the dtype is neither F16 nor F32;
 *         INFINI_STATUS_BAD_TENSOR_SHAPE if the rank is not 2 or 3, or if the last
 *         dimension is smaller than the second-to-last one;
 *         INFINI_STATUS_SUCCESS otherwise.
 */
inline infiniStatus_t createCausalSoftmaxInfo(CausalSoftmaxInfo *info, infiniopTensorDescriptor_t y_desc) {
    // --- dtype validation ---------------------------------------------------
    auto dtype = y_desc->dtype();
    const bool dtype_supported = (dtype == INFINI_DTYPE_F16) || (dtype == INFINI_DTYPE_F32);
    if (!dtype_supported) {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    info->dtype = dtype;

    // --- rank validation ----------------------------------------------------
    const auto ndim = y_desc->ndim();
    if (!(ndim == 2 || ndim == 3)) {
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }

    // The two innermost axes; safe to compute now that ndim is known to be 2 or 3.
    const auto row_axis = ndim - 2;
    const auto col_axis = ndim - 1;

    const auto &shape = y_desc->shape();
    const auto &strides = y_desc->strides();

    // --- shape validation: last dimension must not be shorter than the one before it
    if (shape[col_axis] < shape[row_axis]) {
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }

    // --- batch axis: only present for 3-D tensors ---------------------------
    size_t batch_size = 1;
    ptrdiff_t stride_b = 0;
    if (ndim == 3) {
        batch_size = shape[0];
        stride_b = strides[0];
    }

    info->batch_size = batch_size;
    info->stride_b = stride_b;
    info->seq_len = shape[row_axis];
    info->stride_i = strides[row_axis];
    info->total_seq_len = shape[col_axis];
    info->stride_j = strides[col_axis];

    return INFINI_STATUS_SUCCESS;
}
#define DESCRIPTOR(NAMESPACE) \
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::causal_softmax::NAMESPACE { \
namespace op::causal_softmax::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
struct Opaque; \
...
@@ -65,20 +18,26 @@ inline infiniStatus_t createCausalSoftmaxInfo(CausalSoftmaxInfo *info, infiniopT
...
@@ -65,20 +18,26 @@ inline infiniStatus_t createCausalSoftmaxInfo(CausalSoftmaxInfo *info, infiniopT
CausalSoftmaxInfo info, \
CausalSoftmaxInfo info, \
size_t workspace_size, \
size_t workspace_size, \
infiniDevice_t device_type, \
infiniDevice_t device_type, \
int device_id) : InfiniopDescriptor{device_type, device_id}, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_opaque(opaque), \
_info(info), \
_info(info), \
_workspace_size(workspace_size) {} \
_workspace_size(workspace_size) {} \
\
\
public: \
public: \
~Descriptor(); \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
static infiniStatus_t create( \
infiniopHandle_t handle, \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc); \
infiniopTensorDescriptor_t y_desc); \
infiniStatus_t calculate(void *workspace, size_t workspace_size, \
\
void *data, void *stream); \
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *data, \
void *stream) const; \
}; \
}; \
}
}
...
...
src/infiniop/ops/causal_softmax/cpu/causal_softmax_cpu.cc
View file @
006fb46e
...
@@ -3,15 +3,16 @@
...
@@ -3,15 +3,16 @@
#include "../../../reduce/cpu/reduce.h"
#include "../../../reduce/cpu/reduce.h"
namespace
op
::
causal_softmax
::
cpu
{
namespace
op
::
causal_softmax
::
cpu
{
Descriptor
::~
Descriptor
()
{}
Descriptor
::~
Descriptor
()
{}
infiniStatus_t
Descriptor
::
create
(
infiniStatus_t
Descriptor
::
create
(
infiniopHandle_t
handle
,
infiniopHandle_t
handle
,
Descriptor
**
desc_ptr
,
Descriptor
**
desc_ptr
,
infiniopTensorDescriptor_t
y_desc
)
{
infiniopTensorDescriptor_t
y_desc
)
{
CausalSoftmaxInfo
info
;
auto
result
=
CausalSoftmaxInfo
::
create
(
y_desc
)
;
CHECK_
STATUS
(
createCausalSoftmaxInfo
(
&
info
,
y_desc
)
);
CHECK_
RESULT
(
result
);
*
desc_ptr
=
new
Descriptor
(
nullptr
,
info
,
0
,
handle
->
device
,
handle
->
device_id
);
*
desc_ptr
=
new
Descriptor
(
nullptr
,
result
.
take
()
,
0
,
handle
->
device
,
handle
->
device_id
);
return
INFINI_STATUS_SUCCESS
;
return
INFINI_STATUS_SUCCESS
;
}
}
...
@@ -53,9 +54,11 @@ infiniStatus_t causal_softmax(const CausalSoftmaxInfo *info, T *data) {
...
@@ -53,9 +54,11 @@ infiniStatus_t causal_softmax(const CausalSoftmaxInfo *info, T *data) {
return
INFINI_STATUS_SUCCESS
;
return
INFINI_STATUS_SUCCESS
;
}
}
infiniStatus_t
Descriptor
::
calculate
(
void
*
workspace
,
size_t
workspace_size
,
infiniStatus_t
Descriptor
::
calculate
(
void
*
workspace
,
size_t
workspace_size
,
void
*
data
,
void
*
data
,
void
*
stream
)
{
void
*
stream
)
const
{
if
(
_info
.
dtype
==
INFINI_DTYPE_F16
)
{
if
(
_info
.
dtype
==
INFINI_DTYPE_F16
)
{
CHECK_STATUS
(
causal_softmax
<
fp16_t
>
(
&
_info
,
(
fp16_t
*
)
data
));
CHECK_STATUS
(
causal_softmax
<
fp16_t
>
(
&
_info
,
(
fp16_t
*
)
data
));
}
else
if
(
_info
.
dtype
==
INFINI_DTYPE_F32
)
{
}
else
if
(
_info
.
dtype
==
INFINI_DTYPE_F32
)
{
...
...
src/infiniop/ops/causal_softmax/info.h
0 → 100644
View file @
006fb46e
#
ifndef
__CAUSAL_SOFTMAX_INFO_H__
#define __CAUSAL_SOFTMAX_INFO_H__
#include "../../../utils.h"
#include "../../tensor.h"
#include <vector>
namespace
op
::
causal_softmax
{
/// Layout record for the tensor a causal-softmax operates on.
/// Instances are produced by the static `create` factory below.
class CausalSoftmaxInfo {
    // Default construction is private; presumably so that callers go through create().
    CausalSoftmaxInfo() = default;

public:
    /// Element type of the tensor (create() accepts only F16 and F32).
    infiniDtype_t dtype;
    /// Size of dimension 0 for a 3-D tensor; 1 for a 2-D tensor.
    size_t batch_size;
    /// Stride of dimension 0 for a 3-D tensor; 0 for a 2-D tensor.
    ptrdiff_t stride_b;
    /// Size of the second-to-last dimension.
    size_t seq_len;
    /// Stride of the second-to-last dimension.
    ptrdiff_t stride_i;
    /// Size of the last dimension (validated to be >= seq_len).
    size_t total_seq_len;
    /// Stride of the last dimension.
    ptrdiff_t stride_j;

    /**
     * Validates `y_desc` and builds the corresponding info record.
     *
     * @param y_desc Tensor descriptor to inspect. It is dereferenced without a null check.
     * @return A Result built from INFINI_STATUS_BAD_TENSOR_DTYPE if the dtype is
     *         neither F16 nor F32; from INFINI_STATUS_BAD_TENSOR_SHAPE if the rank
     *         is not 2 or 3, or if the last dimension is smaller than the
     *         second-to-last one; otherwise a Result holding the populated record.
     */
    static utils::Result<CausalSoftmaxInfo> create(infiniopTensorDescriptor_t y_desc) {
        // --- dtype validation -----------------------------------------------
        auto dtype = y_desc->dtype();
        const bool dtype_supported = (dtype == INFINI_DTYPE_F16) || (dtype == INFINI_DTYPE_F32);
        if (!dtype_supported) {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }

        // --- rank validation ------------------------------------------------
        const auto ndim = y_desc->ndim();
        if (!(ndim == 2 || ndim == 3)) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }

        // The two innermost axes; safe to compute now that ndim is known to be 2 or 3.
        const auto row_axis = ndim - 2;
        const auto col_axis = ndim - 1;

        const auto &shape = y_desc->shape();
        const auto &strides = y_desc->strides();

        // --- shape validation: last dimension must not be shorter than the one before it
        if (shape[col_axis] < shape[row_axis]) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }

        // --- batch axis: only present for 3-D tensors -----------------------
        size_t batch_size = 1;
        ptrdiff_t stride_b = 0;
        if (ndim == 3) {
            batch_size = shape[0];
            stride_b = strides[0];
        }

        // Explicitly typed locals keep the same conversions the field types imply
        // before the values enter the braced initializer.
        size_t seq_len = shape[row_axis];
        ptrdiff_t stride_i = strides[row_axis];
        size_t total_seq_len = shape[col_axis];
        ptrdiff_t stride_j = strides[col_axis];

        return utils::Result<CausalSoftmaxInfo>(CausalSoftmaxInfo{
            dtype,
            batch_size,
            stride_b,
            seq_len,
            stride_i,
            total_seq_len,
            stride_j});
    }
};
}
// namespace op::causal_softmax
#endif // __CAUSAL_SOFTMAX_INFO_H__
src/infiniop/ops/gemm/gemm.h
View file @
006fb46e
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
#define __GEMM_H__
#define __GEMM_H__
#include "../../operator.h"
#include "../../operator.h"
#include "
matmul_
info.h"
#include "info.h"
/**
/**
* # 关于 `DESCRIPTOR(NAMESPACE)` 和 `struct Opaque;` 的说明
* # 关于 `DESCRIPTOR(NAMESPACE)` 和 `struct Opaque;` 的说明
...
@@ -52,6 +52,7 @@
...
@@ -52,6 +52,7 @@
Opaque *_opaque; \
Opaque *_opaque; \
infiniDtype_t _dtype; \
infiniDtype_t _dtype; \
MatmulInfo _info; \
MatmulInfo _info; \
size_t _workspace_size; \
\
\
Descriptor( \
Descriptor( \
infiniDtype_t dtype, \
infiniDtype_t dtype, \
...
@@ -64,13 +65,13 @@
...
@@ -64,13 +65,13 @@
_opaque(opaque), \
_opaque(opaque), \
_dtype(dtype), \
_dtype(dtype), \
_info(info), \
_info(info), \
workspace_size(workspace_size_) {} \
_
workspace_size(workspace_size_) {}
\
\
\
public: \
public: \
size_t workspace_size; \
\
~Descriptor(); \
~Descriptor(); \
\
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
static infiniStatus_t create( \
infiniopHandle_t handle, \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
Descriptor **desc_ptr, \
...
@@ -79,8 +80,7 @@
...
@@ -79,8 +80,7 @@
infiniopTensorDescriptor_t b_desc); \
infiniopTensorDescriptor_t b_desc); \
\
\
infiniStatus_t calculate( \
infiniStatus_t calculate( \
void *workspace, \
void *workspace, size_t workspace_size, \
size_t workspace_size, \
void *c, \
void *c, \
float beta, \
float beta, \
const void *a, \
const void *a, \
...
...
src/infiniop/ops/gemm/
matmul_
info.h
→
src/infiniop/ops/gemm/info.h
View file @
006fb46e
#ifndef __
BLAS
_H__
#ifndef __
GEMM_INFO
_H__
#define __
BLAS
_H__
#define __
GEMM_INFO
_H__
#include "../../../utils.h"
#include "../../../utils.h"
#include "../../operator.h"
#include "../../operator.h"
...
@@ -132,4 +132,4 @@ public:
...
@@ -132,4 +132,4 @@ public:
}
// namespace op::gemm
}
// namespace op::gemm
#endif // __
BLAS
_H__
#endif // __
GEMM_INFO
_H__
src/infiniop/ops/gemm/operator.cc
View file @
006fb46e
...
@@ -72,7 +72,7 @@ infiniopGetGemmWorkspaceSize(
...
@@ -72,7 +72,7 @@ infiniopGetGemmWorkspaceSize(
#define GET(CASE, NAMESPACE) \
#define GET(CASE, NAMESPACE) \
case CASE: \
case CASE: \
*size = reinterpret_cast<const op::gemm::NAMESPACE::Descriptor *>(desc)->workspace
_s
ize; \
*size = reinterpret_cast<const op::gemm::NAMESPACE::Descriptor *>(desc)->workspace
S
ize
()
; \
return INFINI_STATUS_SUCCESS
return INFINI_STATUS_SUCCESS
switch
(
desc
->
device_type
)
{
switch
(
desc
->
device_type
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment