Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
74ffbff5
Commit
74ffbff5
authored
Apr 02, 2025
by
YdrMaster
Browse files
issue/121/refactor: 用 Result 修改 RmsNormInfo 构造流程,并与 Gemm 保持风格一致
Signed-off-by:
YdrMaster
<
ydrml@hotmail.com
>
parent
006fb46e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
129 additions
and
102 deletions
+129
-102
src/infiniop/ops/rms_norm/cpu/rms_norm_cpu.cc
src/infiniop/ops/rms_norm/cpu/rms_norm_cpu.cc
+8
-6
src/infiniop/ops/rms_norm/info.h
src/infiniop/ops/rms_norm/info.h
+78
-0
src/infiniop/ops/rms_norm/rms_norm.h
src/infiniop/ops/rms_norm/rms_norm.h
+43
-96
No files found.
src/infiniop/ops/rms_norm/cpu/rms_norm_cpu.cc
View file @
74ffbff5
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#include "../../../reduce/cpu/reduce.h"
#include "../../../reduce/cpu/reduce.h"
namespace
op
::
rms_norm
::
cpu
{
namespace
op
::
rms_norm
::
cpu
{
Descriptor
::~
Descriptor
()
{}
Descriptor
::~
Descriptor
()
{}
infiniStatus_t
Descriptor
::
create
(
infiniStatus_t
Descriptor
::
create
(
...
@@ -12,9 +13,9 @@ infiniStatus_t Descriptor::create(
...
@@ -12,9 +13,9 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t
x_desc
,
infiniopTensorDescriptor_t
x_desc
,
infiniopTensorDescriptor_t
w_desc
,
infiniopTensorDescriptor_t
w_desc
,
float
epsilon
)
{
float
epsilon
)
{
RMSNormInfo
info
;
auto
result
=
RMSNormInfo
::
create
(
y_desc
,
x_desc
,
w_desc
,
epsilon
)
;
CHECK_
STATUS
(
createRMSNormInfo
(
&
info
,
y_desc
,
x_desc
,
w_desc
,
epsilon
)
);
CHECK_
RESULT
(
result
);
*
desc_ptr
=
new
Descriptor
(
nullptr
,
info
,
0
,
handle
->
device
,
handle
->
device_id
);
*
desc_ptr
=
new
Descriptor
(
nullptr
,
result
.
take
()
,
0
,
handle
->
device
,
handle
->
device_id
);
return
INFINI_STATUS_SUCCESS
;
return
INFINI_STATUS_SUCCESS
;
}
}
...
@@ -68,9 +69,10 @@ infiniStatus_t rmsnormF16(const RMSNormInfo *info, fp16_t *y, const fp16_t *x, c
...
@@ -68,9 +69,10 @@ infiniStatus_t rmsnormF16(const RMSNormInfo *info, fp16_t *y, const fp16_t *x, c
return
INFINI_STATUS_SUCCESS
;
return
INFINI_STATUS_SUCCESS
;
}
}
infiniStatus_t
Descriptor
::
calculate
(
void
*
workspace
,
size_t
workspace_size
,
infiniStatus_t
Descriptor
::
calculate
(
void
*
y
,
const
void
*
x
,
const
void
*
w
,
void
*
workspace
,
size_t
workspace_size
,
void
*
stream
)
{
void
*
y
,
const
void
*
x
,
const
void
*
w
,
void
*
stream
)
const
{
if
(
_info
.
atype
==
INFINI_DTYPE_F16
)
{
if
(
_info
.
atype
==
INFINI_DTYPE_F16
)
{
if
(
_info
.
wtype
==
INFINI_DTYPE_F16
)
{
if
(
_info
.
wtype
==
INFINI_DTYPE_F16
)
{
CHECK_STATUS
(
rmsnormF16
(
&
_info
,
(
fp16_t
*
)
y
,
(
const
fp16_t
*
)
x
,
(
const
fp16_t
*
)
w
));
CHECK_STATUS
(
rmsnormF16
(
&
_info
,
(
fp16_t
*
)
y
,
(
const
fp16_t
*
)
x
,
(
const
fp16_t
*
)
w
));
...
...
src/infiniop/ops/rms_norm/info.h
0 → 100644
View file @
74ffbff5
#
ifndef
__RMS_NORM_INFO_H__
#define __RMS_NORM_INFO_H__
#include "../../../utils.h"
#include "../../tensor.h"
#include <vector>
namespace
op
::
rms_norm
{
class
RMSNormInfo
{
RMSNormInfo
()
=
default
;
public:
infiniDtype_t
wtype
;
infiniDtype_t
atype
;
float
epsilon
;
std
::
vector
<
size_t
>
shape
;
std
::
vector
<
ptrdiff_t
>
y_strides
;
std
::
vector
<
ptrdiff_t
>
x_strides
;
size_t
ndim
()
const
{
return
shape
.
size
();
}
size_t
dim
()
const
{
return
shape
[
ndim
()
-
1
];
}
static
utils
::
Result
<
RMSNormInfo
>
create
(
infiniopTensorDescriptor_t
y_desc
,
infiniopTensorDescriptor_t
x_desc
,
infiniopTensorDescriptor_t
w_desc
,
float
epsilon
)
{
auto
atype
=
y_desc
->
dtype
();
auto
wtype
=
w_desc
->
dtype
();
if
(
x_desc
->
dtype
()
!=
atype
)
{
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
if
(
atype
==
INFINI_DTYPE_F16
)
{
if
(
wtype
!=
INFINI_DTYPE_F16
&&
wtype
!=
INFINI_DTYPE_F32
)
{
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
}
else
if
(
atype
==
INFINI_DTYPE_F32
||
atype
==
INFINI_DTYPE_F64
)
{
if
(
atype
!=
wtype
)
{
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
}
else
{
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
if
(
y_desc
->
ndim
()
!=
2
||
x_desc
->
ndim
()
!=
2
||
w_desc
->
ndim
()
!=
1
)
{
return
INFINI_STATUS_BAD_TENSOR_SHAPE
;
}
size_t
batch
=
y_desc
->
shape
()[
0
];
size_t
dim
=
y_desc
->
shape
()[
1
];
if
(
x_desc
->
shape
()[
0
]
!=
batch
||
x_desc
->
shape
()[
1
]
!=
dim
||
w_desc
->
shape
()[
0
]
!=
dim
)
{
return
INFINI_STATUS_BAD_TENSOR_SHAPE
;
}
if
(
w_desc
->
stride
(
0
)
!=
1
)
{
return
INFINI_STATUS_BAD_TENSOR_STRIDES
;
}
if
(
x_desc
->
stride
(
1
)
!=
1
||
y_desc
->
stride
(
1
)
!=
1
)
{
return
INFINI_STATUS_BAD_TENSOR_STRIDES
;
}
return
utils
::
Result
<
RMSNormInfo
>
(
RMSNormInfo
{
wtype
,
atype
,
epsilon
,
y_desc
->
shape
(),
y_desc
->
strides
(),
x_desc
->
strides
(),
});
}
};
}
// namespace op::rms_norm
#endif // __RMS_NORM_INFO_H__
src/infiniop/ops/rms_norm/rms_norm.h
View file @
74ffbff5
#ifndef RMS_NORM_H
#ifndef RMS_NORM_H
#define RMS_NORM_H
#define RMS_NORM_H
#include "../../operator.h"
#include "../../tensor.h"
#include <vector>
struct
RMSNormInfo
{
infiniDtype_t
wtype
;
infiniDtype_t
atype
;
float
epsilon
;
std
::
vector
<
size_t
>
shape
;
std
::
vector
<
ptrdiff_t
>
y_strides
;
std
::
vector
<
ptrdiff_t
>
x_strides
;
size_t
ndim
()
{
return
shape
.
size
();
}
size_t
dim
()
{
return
shape
[
ndim
()
-
1
];
}
};
inline
infiniStatus_t
createRMSNormInfo
(
RMSNormInfo
*
info
,
infiniopTensorDescriptor_t
y_desc
,
infiniopTensorDescriptor_t
x_desc
,
infiniopTensorDescriptor_t
w_desc
,
float
epsilon
)
{
auto
atype
=
y_desc
->
dtype
();
auto
wtype
=
w_desc
->
dtype
();
if
(
x_desc
->
dtype
()
!=
atype
)
{
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
if
(
atype
==
INFINI_DTYPE_F16
)
{
if
(
wtype
!=
INFINI_DTYPE_F16
&&
wtype
!=
INFINI_DTYPE_F32
)
{
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
}
else
if
(
atype
==
INFINI_DTYPE_F32
||
atype
==
INFINI_DTYPE_F64
)
{
if
(
atype
!=
wtype
)
{
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
}
else
{
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
info
->
wtype
=
wtype
;
info
->
atype
=
atype
;
info
->
epsilon
=
epsilon
;
if
(
y_desc
->
ndim
()
!=
2
||
x_desc
->
ndim
()
!=
2
||
w_desc
->
ndim
()
!=
1
)
{
return
INFINI_STATUS_BAD_TENSOR_SHAPE
;
}
size_t
batch
=
y_desc
->
shape
()[
0
];
#include "../../operator.h"
size_t
dim
=
y_desc
->
shape
()[
1
];
#include "info.h"
if
(
x_desc
->
shape
()[
0
]
!=
batch
||
x_desc
->
shape
()[
1
]
!=
dim
||
w_desc
->
shape
()[
0
]
!=
dim
)
{
return
INFINI_STATUS_BAD_TENSOR_SHAPE
;
#define DESCRIPTOR(NAMESPACE) \
}
\
namespace op::rms_norm::NAMESPACE { \
if
(
w_desc
->
stride
(
0
)
!=
1
)
{
class Descriptor final : public InfiniopDescriptor { \
return
INFINI_STATUS_BAD_TENSOR_STRIDES
;
struct Opaque; \
}
Opaque *_opaque; \
RMSNormInfo _info; \
if
(
x_desc
->
stride
(
1
)
!=
1
||
y_desc
->
stride
(
1
)
!=
1
)
{
size_t _workspace_size; \
return
INFINI_STATUS_BAD_TENSOR_STRIDES
;
\
}
Descriptor( \
Opaque *opaque, \
info
->
shape
=
std
::
move
(
y_desc
->
shape
());
RMSNormInfo info, \
info
->
y_strides
=
std
::
move
(
y_desc
->
strides
());
size_t workspace_size, \
info
->
x_strides
=
std
::
move
(
x_desc
->
strides
());
infiniDevice_t device_type, \
int device_id) \
return
INFINI_STATUS_SUCCESS
;
: InfiniopDescriptor{device_type, device_id}, \
}
_opaque(opaque), \
_info(info), \
#define DESCRIPTOR(NAMESPACE) \
_workspace_size(workspace_size) {} \
namespace op::rms_norm::NAMESPACE { \
\
class Descriptor final : public InfiniopDescriptor { \
public: \
struct Opaque; \
~Descriptor(); \
Opaque *_opaque; \
\
RMSNormInfo _info; \
size_t workspaceSize() const { return _workspace_size; } \
size_t _workspace_size; \
\
\
static infiniStatus_t create( \
Descriptor( \
infiniopHandle_t handle, \
Opaque *opaque, \
Descriptor **desc_ptr, \
RMSNormInfo info, \
infiniopTensorDescriptor_t y_desc, \
size_t workspace_size, \
infiniopTensorDescriptor_t x_desc, \
infiniDevice_t device_type, \
infiniopTensorDescriptor_t w_desc, \
int device_id) : InfiniopDescriptor{device_type, device_id}, \
float epsilon); \
_opaque(opaque), \
\
_info(info), \
infiniStatus_t calculate( \
_workspace_size(workspace_size) {} \
void *workspace, size_t workspace_size, \
\
void *y, \
public: \
const void *x, \
~Descriptor(); \
const void *w, \
size_t workspaceSize() const { return _workspace_size; } \
void *stream) const; \
static infiniStatus_t create( \
}; \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc, \
infiniopTensorDescriptor_t x_desc, \
infiniopTensorDescriptor_t w_desc, \
float epsilon); \
infiniStatus_t calculate(void *workspace, size_t workspace_size, \
void *y, const void *x, const void *w, void *stream); \
}; \
}
}
#endif // RMS_NORM_H
#endif // RMS_NORM_H
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment