Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
adc64a23
Commit
adc64a23
authored
Sep 07, 2023
by
Jing Zhang
Browse files
fixed reference gemm
parent
9212f569
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
4 deletions
+5
-4
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
...library/reference_tensor_operation/cpu/reference_gemm.hpp
+5
-4
No files found.
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
View file @
adc64a23
...
@@ -21,7 +21,7 @@ template <typename ADataType,
...
@@ -21,7 +21,7 @@ template <typename ADataType,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
typename
Comput
e
Type
=
ADataType
>
typename
ComputType
=
ADataType
>
struct
ReferenceGemm
:
public
device
::
BaseOperator
struct
ReferenceGemm
:
public
device
::
BaseOperator
{
{
// Argument
// Argument
...
@@ -65,8 +65,8 @@ struct ReferenceGemm : public device::BaseOperator
...
@@ -65,8 +65,8 @@ struct ReferenceGemm : public device::BaseOperator
for
(
int
k
=
0
;
k
<
K
;
++
k
)
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
{
Comput
e
Type
v_a
;
ComputType
v_a
;
Comput
e
Type
v_b
;
ComputType
v_b
;
// use PassThrough instead of ConvertBF16RTN for reference calculation
// use PassThrough instead of ConvertBF16RTN for reference calculation
if
constexpr
(
is_same_v
<
AElementwiseOperation
,
if
constexpr
(
is_same_v
<
AElementwiseOperation
,
...
@@ -89,7 +89,8 @@ struct ReferenceGemm : public device::BaseOperator
...
@@ -89,7 +89,8 @@ struct ReferenceGemm : public device::BaseOperator
arg
.
b_element_op_
(
v_b
,
arg
.
b_k_n_
(
k
,
n
));
arg
.
b_element_op_
(
v_b
,
arg
.
b_k_n_
(
k
,
n
));
}
}
v_acc
+=
type_convert
<
AccDataType
>
(
v_a
*
v_b
);
v_acc
+=
ck
::
type_convert
<
AccDataType
>
(
v_a
)
*
ck
::
type_convert
<
AccDataType
>
(
v_b
);
}
}
CDataType
v_c
;
CDataType
v_c
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment