Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
e89422a8
Commit
e89422a8
authored
Dec 05, 2022
by
rocking
Browse files
use reference layernorm
parent
180290ba
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
27 additions
and
44 deletions
+27
-44
example/21_gemm_layernorm/gemm_add_add_layernorm_xdl_fp16.cpp
...ple/21_gemm_layernorm/gemm_add_add_layernorm_xdl_fp16.cpp
+27
-44
No files found.
example/21_gemm_layernorm/gemm_add_add_layernorm_xdl_fp16.cpp
View file @
e89422a8
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/check_err.hpp"
template
<
ck
::
index_t
...
Is
>
template
<
ck
::
index_t
...
Is
>
...
@@ -98,63 +99,45 @@ void host_gemm_layernorm(Tensor<HDataType>& e_m_n,
...
@@ -98,63 +99,45 @@ void host_gemm_layernorm(Tensor<HDataType>& e_m_n,
CDEElementOp
cde_element_op
,
CDEElementOp
cde_element_op
,
int
M
,
int
M
,
int
N
,
int
N
,
float
epsilon
=
1e-5
)
AccDataType
epsilon
=
1e-5
)
{
{
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
using
ReferenceGemm
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
BDataType
,
AccDataType
,
AccDataType
,
AccDataType
,
AccDataType
,
AElementOp
,
AElementOp
,
BElementOp
,
BElementOp
,
PassThrough
>
;
PassThrough
>
;
using
NormalizeFunctor
=
ck
::
tensor_operation
::
element_wise
::
Normalize
;
using
ReferenceLayernorm
=
ck
::
tensor_operation
::
host
::
ReferenceLayernorm
<
HDataType
,
GammaDataType
,
BetaDataType
,
HDataType
,
AccDataType
,
HElementOp
,
2
,
1
>
;
Tensor
<
AccDataType
>
c_m_n
(
HostTensorDescriptor
{
M
,
N
});
Tensor
<
AccDataType
>
c_m_n
(
HostTensorDescriptor
{
M
,
N
});
auto
ref_gemm
=
ReferenceGemm
Instance
{};
auto
ref_gemm
=
ReferenceGemm
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_
gemm_
invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
auto
ref_
gemm_
argument
=
ref_gemm
.
MakeArgument
(
a_m_k
,
b_k_n
,
c_m_n
,
a_element_op
,
b_element_op
,
PassThrough
{});
ref_gemm
.
MakeArgument
(
a_m_k
,
b_k_n
,
c_m_n
,
a_element_op
,
b_element_op
,
PassThrough
{});
ref_invoker
.
Run
(
ref_argument
);
ref_
gemm_
invoker
.
Run
(
ref_
gemm_
argument
);
for
(
int
m
=
0
;
m
<
M
;
++
m
)
for
(
int
m
=
0
;
m
<
M
;
++
m
)
for
(
int
n
=
0
;
n
<
N
;
++
n
)
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
cde_element_op
(
e_m_n
(
m
,
n
),
c_m_n
(
m
,
n
),
bias_n
(
n
),
d1_m_n
(
m
,
n
));
cde_element_op
(
e_m_n
(
m
,
n
),
c_m_n
(
m
,
n
),
bias_n
(
n
),
d1_m_n
(
m
,
n
));
}
// LayerNorm
Tensor
<
AccDataType
>
mean_m
(
f_host_tensor_descriptor1d
(
M
,
1
));
Tensor
<
AccDataType
>
meanSquare_m
(
f_host_tensor_descriptor1d
(
M
,
1
));
auto
layerNormInst
=
NormalizeFunctor
{
epsilon
};
for
(
int
m
=
0
;
m
<
M
;
++
m
)
{
AccDataType
mean
=
0
;
AccDataType
meanSquare
=
0
;
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
auto
e_val
=
ck
::
type_convert
<
AccDataType
>
(
e_m_n
(
m
,
n
));
mean
+=
e_val
;
meanSquare
+=
e_val
*
e_val
;
}
mean
/=
N
;
ReferenceLayernorm
ref_layernorm
;
meanSquare
/=
N
;
auto
ref_layernorm_invoker
=
ref_layernorm
.
MakeInvoker
()
;
for
(
int
n
=
0
;
n
<
N
;
++
n
)
auto
ref_layernorm_argument
=
ref_layernorm
.
MakeArgument
(
{
e_m_n
,
gamma_n
,
beta_n
,
h_m_n
,
HElementOp
{},
{
M
,
N
},
{
1
},
epsilon
);
AccDataType
h_val
=
0
;
ref_layernorm_invoker
.
Run
(
ref_layernorm_argument
);
AccDataType
e_val
=
ck
::
type_convert
<
AccDataType
>
(
e_m_n
(
m
,
n
));
AccDataType
gamma_val
=
ck
::
type_convert
<
AccDataType
>
(
gamma_n
(
n
));
AccDataType
beta_val
=
ck
::
type_convert
<
AccDataType
>
(
beta_n
(
n
));
layerNormInst
(
h_val
,
e_val
,
mean
,
meanSquare
,
gamma_val
,
beta_val
);
h_m_n
(
m
,
n
)
=
ck
::
type_convert
<
HDataType
>
(
h_val
);
}
}
}
}
int
main
()
int
main
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment