Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b57c3879
Commit
b57c3879
authored
Jul 27, 2022
by
Anthony Chang
Browse files
harmonize interface between ref_gemm and ref_batched_gemm
parent
237371ad
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
22 additions
and
9 deletions
+22
-9
example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp
...e/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp
+8
-2
example/24_batched_gemm_e_permute/batched_gemm_e_permute_xdl_fp16.cpp
...atched_gemm_e_permute/batched_gemm_e_permute_xdl_fp16.cpp
+7
-2
library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp
...reference_tensor_operation/cpu/reference_batched_gemm.hpp
+5
-5
profiler/include/profile_batched_gemm_impl.hpp
profiler/include/profile_batched_gemm_impl.hpp
+1
-0
profiler/include/profile_batched_gemm_reduce_impl.hpp
profiler/include/profile_batched_gemm_reduce_impl.hpp
+1
-0
No files found.
example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp
View file @
b57c3879
...
...
@@ -66,8 +66,14 @@ using DeviceBatchedGemmReduceInstance = ck::tensor_operation::device::DeviceBatc
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
ReducePtrsGlobal
,
AElementOp
,
BElementOp
,
CElementOp
,
ReduceOps
,
ReduceInElementOps
,
ReduceOutElementOps
,
ReduceGlobalMemOps
,
GemmSpecialization
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
;
// clang-format on
using
ReferenceBatchedGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
ADataType
,
BDataType
,
CDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
using
ReferenceBatchedGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
ADataType
,
BDataType
,
CDataType
,
ReduceAccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
int
main
(
int
argc
,
char
*
argv
[])
{
...
...
example/24_batched_gemm_e_permute/batched_gemm_e_permute_xdl_fp16.cpp
View file @
b57c3879
...
...
@@ -51,8 +51,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmEPermu
<
ALayout
,
BLayout
,
ELayout
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
// clang-format on
using
ReferenceBatchedGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
ADataType
,
BDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
>
;
using
ReferenceBatchedGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
>
;
int
main
(
int
argc
,
char
*
argv
[])
{
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp
View file @
b57c3879
...
...
@@ -16,6 +16,7 @@ namespace host {
template
<
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
AccDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
...
...
@@ -58,7 +59,7 @@ struct ReferenceBatchedGemm : public device::BaseOperator
auto
f_gmk_gkn_gmn
=
[
&
](
auto
g
,
auto
m
,
auto
n
)
{
const
int
K
=
arg
.
a_g_m_k_
.
mDesc
.
GetLengths
()[
2
];
float
v_acc
=
0
;
AccDataType
v_acc
=
0
;
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
...
...
@@ -68,10 +69,10 @@ struct ReferenceBatchedGemm : public device::BaseOperator
arg
.
a_element_op_
(
v_a
,
arg
.
a_g_m_k_
(
g
,
m
,
k
));
arg
.
b_element_op_
(
v_b
,
arg
.
b_g_k_n_
(
g
,
k
,
n
));
v_acc
+=
ck
::
type_convert
<
float
>
(
v_a
)
*
ck
::
type_convert
<
float
>
(
v_b
);
v_acc
+=
ck
::
type_convert
<
AccDataType
>
(
v_a
)
*
ck
::
type_convert
<
AccDataType
>
(
v_b
);
}
float
v_c
;
AccDataType
v_c
;
arg
.
c_element_op_
(
v_c
,
v_acc
);
...
...
@@ -81,8 +82,7 @@ struct ReferenceBatchedGemm : public device::BaseOperator
make_ParallelTensorFunctor
(
f_gmk_gkn_gmn
,
arg
.
c_g_m_n_
.
mDesc
.
GetLengths
()[
0
],
arg
.
c_g_m_n_
.
mDesc
.
GetLengths
()[
1
],
arg
.
c_g_m_n_
.
mDesc
.
GetLengths
()[
2
])(
std
::
thread
::
hardware_concurrency
());
arg
.
c_g_m_n_
.
mDesc
.
GetLengths
()[
2
])();
return
0
;
}
...
...
profiler/include/profile_batched_gemm_impl.hpp
View file @
b57c3879
...
...
@@ -101,6 +101,7 @@ bool profile_batched_gemm_impl(int do_verification,
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
ADataType
,
BDataType
,
CDataType
,
float
,
AElementOp
,
BElementOp
,
CElementOp
>
;
...
...
profiler/include/profile_batched_gemm_reduce_impl.hpp
View file @
b57c3879
...
...
@@ -155,6 +155,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
ADataType
,
BDataType
,
CDataType
,
float
,
AElementOp
,
BElementOp
,
CElementOp
>
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment