Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
3421b74c
Commit
3421b74c
authored
Jul 06, 2022
by
Chao Liu
Browse files
use type_convert
parent
1deb01b2
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
14 additions
and
14 deletions
+14
-14
example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_fp16.cpp
..._gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_fp16.cpp
+10
-10
example/24_contraction/contraction_bilinear_xdl_fp32.cpp
example/24_contraction/contraction_bilinear_xdl_fp32.cpp
+2
-2
example/24_contraction/contraction_scale_xdl_fp32.cpp
example/24_contraction/contraction_scale_xdl_fp32.cpp
+2
-2
No files found.
example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_fp16.cpp
View file @
3421b74c
...
...
@@ -166,15 +166,15 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
for
(
int
m
=
0
;
m
<
M
;
++
m
)
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
AccDataType
acc
=
static_cast
<
AccDataType
>
(
c_m_n
(
m
,
n
))
+
static_cas
t
<
AccDataType
>
(
bias_n
(
n
));
AccDataType
acc
=
ck
::
type_convert
<
AccDataType
>
(
c_m_n
(
m
,
n
))
+
ck
::
type_conver
t
<
AccDataType
>
(
bias_n
(
n
));
AccDataType
c1
=
static_cas
t
<
AccDataType
>
(
c1_m_n
(
m
,
n
));
AccDataType
c1
=
ck
::
type_conver
t
<
AccDataType
>
(
c1_m_n
(
m
,
n
));
c_element_op
(
acc
,
acc
);
c1_element_op
(
c1
,
c1
);
acc
+=
c1
;
c_m_n
(
m
,
n
)
=
static_cas
t
<
CDataType
>
(
acc
);
c_m_n
(
m
,
n
)
=
ck
::
type_conver
t
<
CDataType
>
(
acc
);
}
// reduce_mean and reduce_square_mean
...
...
@@ -208,12 +208,12 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
{
AccDataType
out_acc
=
0
;
layerNormInst
(
out_acc
,
static_cas
t
<
AccDataType
>
(
c_m_n
(
m
,
n
)),
static_cas
t
<
AccDataType
>
(
mean_m
(
m
)),
static_cas
t
<
AccDataType
>
(
meanSquare_m
(
m
)),
static_cas
t
<
AccDataType
>
(
gamma_n
(
n
)),
static_cas
t
<
AccDataType
>
(
beta_n
(
n
)));
out_m_n
(
m
,
n
)
=
static_cas
t
<
ReduceDataType
>
(
out_acc
);
ck
::
type_conver
t
<
AccDataType
>
(
c_m_n
(
m
,
n
)),
ck
::
type_conver
t
<
AccDataType
>
(
mean_m
(
m
)),
ck
::
type_conver
t
<
AccDataType
>
(
meanSquare_m
(
m
)),
ck
::
type_conver
t
<
AccDataType
>
(
gamma_n
(
n
)),
ck
::
type_conver
t
<
AccDataType
>
(
beta_n
(
n
)));
out_m_n
(
m
,
n
)
=
ck
::
type_conver
t
<
ReduceDataType
>
(
out_acc
);
}
}
}
...
...
example/24_contraction/contraction_bilinear_xdl_fp32.cpp
View file @
3421b74c
...
...
@@ -135,9 +135,9 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
AccDataType
v_b
;
arg
.
a_element_op_
(
v_a
,
static_cas
t
<
const
AccDataType
>
(
arg
.
a_ms_ks_
(
m0
,
m1
,
k0
,
k1
)));
v_a
,
ck
::
type_conver
t
<
const
AccDataType
>
(
arg
.
a_ms_ks_
(
m0
,
m1
,
k0
,
k1
)));
arg
.
b_element_op_
(
v_b
,
static_cas
t
<
const
AccDataType
>
(
arg
.
b_ns_ks_
(
n0
,
n1
,
k0
,
k1
)));
v_b
,
ck
::
type_conver
t
<
const
AccDataType
>
(
arg
.
b_ns_ks_
(
n0
,
n1
,
k0
,
k1
)));
v_acc
+=
v_a
*
v_b
;
}
...
...
example/24_contraction/contraction_scale_xdl_fp32.cpp
View file @
3421b74c
...
...
@@ -134,9 +134,9 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
AccDataType
v_b
;
arg
.
a_element_op_
(
v_a
,
static_cas
t
<
const
AccDataType
>
(
arg
.
a_ms_ks_
(
m0
,
m1
,
k0
,
k1
)));
v_a
,
ck
::
type_conver
t
<
const
AccDataType
>
(
arg
.
a_ms_ks_
(
m0
,
m1
,
k0
,
k1
)));
arg
.
b_element_op_
(
v_b
,
static_cas
t
<
const
AccDataType
>
(
arg
.
b_ns_ks_
(
n0
,
n1
,
k0
,
k1
)));
v_b
,
ck
::
type_conver
t
<
const
AccDataType
>
(
arg
.
b_ns_ks_
(
n0
,
n1
,
k0
,
k1
)));
v_acc
+=
v_a
*
v_b
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment