Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
304adaad
"docs/vscode:/vscode.git/clone" did not exist on "5d7ea6616fc127469f43605464803d8521fcc51d"
Commit
304adaad
authored
May 09, 2023
by
Bartlomiej Kocot
Browse files
Allow to use any elementwise operator for ref_contraction
parent
93ce856f
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
50 additions
and
39 deletions
+50
-39
example/26_contraction/contraction_bilinear_xdl_fp32.cpp
example/26_contraction/contraction_bilinear_xdl_fp32.cpp
+1
-0
example/26_contraction/contraction_bilinear_xdl_fp64.cpp
example/26_contraction/contraction_bilinear_xdl_fp64.cpp
+1
-0
example/26_contraction/contraction_scale_xdl_fp32.cpp
example/26_contraction/contraction_scale_xdl_fp32.cpp
+2
-1
example/26_contraction/contraction_scale_xdl_fp64.cpp
example/26_contraction/contraction_scale_xdl_fp64.cpp
+2
-1
library/include/ck/library/reference_tensor_operation/cpu/reference_contraction.hpp
.../reference_tensor_operation/cpu/reference_contraction.hpp
+30
-24
profiler/include/profiler/profile_contraction_impl.hpp
profiler/include/profiler/profile_contraction_impl.hpp
+13
-12
test/contraction/test_contraction.cpp
test/contraction/test_contraction.cpp
+1
-1
No files found.
example/26_contraction/contraction_bilinear_xdl_fp32.cpp
View file @
304adaad
...
...
@@ -260,6 +260,7 @@ int main(int argc, char* argv[])
AElementOp
,
BElementOp
,
CDEElementOp
,
true
,
DDataType
>
;
auto
ref_gemm
=
ReferenceOpInstance
{};
...
...
example/26_contraction/contraction_bilinear_xdl_fp64.cpp
View file @
304adaad
...
...
@@ -260,6 +260,7 @@ int main(int argc, char* argv[])
AElementOp
,
BElementOp
,
CDEElementOp
,
true
,
DDataType
>
;
auto
ref_gemm
=
ReferenceOpInstance
{};
...
...
example/26_contraction/contraction_scale_xdl_fp32.cpp
View file @
304adaad
...
...
@@ -242,7 +242,8 @@ int main(int argc, char* argv[])
AccDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
>
;
CDEElementOp
,
false
>
;
auto
ref_gemm
=
ReferenceOpInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
...
...
example/26_contraction/contraction_scale_xdl_fp64.cpp
View file @
304adaad
...
...
@@ -242,7 +242,8 @@ int main(int argc, char* argv[])
AccDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
>
;
CDEElementOp
,
false
>
;
auto
ref_gemm
=
ReferenceOpInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_contraction.hpp
View file @
304adaad
...
...
@@ -23,11 +23,12 @@ template <ck::index_t NumDimM,
ck
::
index_t
NumDimK
,
typename
ADataType
,
typename
BDataType
,
typename
C
DataType
,
typename
E
DataType
,
typename
AccDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
bool
UseDToBinaryOp
,
typename
DDataType
=
float
,
ck
::
enable_if_t
<
NumDimM
==
2
&&
NumDimN
==
2
&&
NumDimK
==
2
,
bool
>
=
false
>
struct
ReferenceContraction_M2_N2_K2
:
public
ck
::
tensor_operation
::
device
::
BaseOperator
...
...
@@ -38,14 +39,14 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
Argument
(
const
Tensor
<
ADataType
>&
a_ms_ks
,
const
Tensor
<
BDataType
>&
b_ns_ks
,
const
Tensor
<
DDataType
>&
d_ms_ns
,
Tensor
<
C
DataType
>&
c
_ms_ns
,
Tensor
<
E
DataType
>&
e
_ms_ns
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
:
a_ms_ks_
{
a_ms_ks
},
b_ns_ks_
{
b_ns_ks
},
d_ms_ns_
{
d_ms_ns
},
c
_ms_ns_
{
c
_ms_ns
},
e
_ms_ns_
{
e
_ms_ns
},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
}
...
...
@@ -55,7 +56,7 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
const
Tensor
<
ADataType
>&
a_ms_ks_
;
const
Tensor
<
BDataType
>&
b_ns_ks_
;
const
Tensor
<
DDataType
>&
d_ms_ns_
;
Tensor
<
C
DataType
>&
c
_ms_ns_
;
Tensor
<
E
DataType
>&
e
_ms_ns_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
...
...
@@ -67,19 +68,17 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
{
using
Argument
=
ReferenceContraction_M2_N2_K2
::
Argument
;
template
<
typename
Op
>
void
apply_op
(
Op
&
op
,
DDataType
&
v_d
,
CDataType
&
v_c
,
AccDataType
&
v_acc
)
void
apply_unary_op
(
const
CDEElementwiseOperation
&
op
,
EDataType
&
v_e
,
AccDataType
&
v_acc
)
{
op
(
v_
c
,
static_cast
<
AccDataType
>
(
v_d
+
v_acc
)
)
;
op
(
v_
e
,
v_acc
);
}
template
<
>
void
apply_op
<
const
Bilinear
>
(
const
Bilinear
&
bilinear
,
DDataType
&
v_d
,
CDataType
&
v_c
,
AccDataType
&
v_acc
)
void
apply_binary_op
(
const
CDEElementwiseOperation
&
op
,
EDataType
&
v_e
,
AccDataType
&
v_acc
,
DDataType
&
v_d
)
{
bilinear
(
v_c
,
v_d
,
v_acc
);
op
(
v_e
,
v_acc
,
v_d
);
}
float
Run
(
const
Argument
&
arg
)
...
...
@@ -106,19 +105,26 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
}
}
AccDataType
v_
c
;
AccDataType
v_
e
;
DDataType
v_d
=
arg
.
d_ms_ns_
.
GetNumOfDimension
()
==
0
?
0
:
arg
.
d_ms_ns_
(
m0
,
m1
,
n0
,
n1
);
apply_op
(
arg
.
cde_element_op_
,
v_d
,
v_c
,
v_acc
);
if
constexpr
(
UseDToBinaryOp
)
{
apply_binary_op
(
arg
.
cde_element_op_
,
v_e
,
v_acc
,
v_d
);
}
else
{
apply_unary_op
(
arg
.
cde_element_op_
,
v_e
,
v_acc
);
}
arg
.
c
_ms_ns_
(
m0
,
m1
,
n0
,
n1
)
=
v_
c
;
arg
.
e
_ms_ns_
(
m0
,
m1
,
n0
,
n1
)
=
v_
e
;
};
make_ParallelTensorFunctor
(
f_ms_ns
,
arg
.
c
_ms_ns_
.
mDesc
.
GetLengths
()[
0
],
arg
.
c
_ms_ns_
.
mDesc
.
GetLengths
()[
1
],
arg
.
c
_ms_ns_
.
mDesc
.
GetLengths
()[
2
],
arg
.
c
_ms_ns_
.
mDesc
.
GetLengths
()[
3
])(
arg
.
e
_ms_ns_
.
mDesc
.
GetLengths
()[
0
],
arg
.
e
_ms_ns_
.
mDesc
.
GetLengths
()[
1
],
arg
.
e
_ms_ns_
.
mDesc
.
GetLengths
()[
2
],
arg
.
e
_ms_ns_
.
mDesc
.
GetLengths
()[
3
])(
std
::
thread
::
hardware_concurrency
());
return
0
;
...
...
@@ -145,23 +151,23 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
static
auto
MakeArgument
(
const
Tensor
<
ADataType
>&
a_ms_ks
,
const
Tensor
<
BDataType
>&
b_ns_ks
,
const
Tensor
<
DDataType
>&
d_ms_ns
,
Tensor
<
C
DataType
>&
c
_ms_ns
,
Tensor
<
E
DataType
>&
e
_ms_ns
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
{
return
Argument
{
a_ms_ks
,
b_ns_ks
,
d_ms_ns
,
c
_ms_ns
,
a_element_op
,
b_element_op
,
cde_element_op
};
a_ms_ks
,
b_ns_ks
,
d_ms_ns
,
e
_ms_ns
,
a_element_op
,
b_element_op
,
cde_element_op
};
}
static
auto
MakeArgument
(
const
Tensor
<
ADataType
>&
a_ms_ks
,
const
Tensor
<
BDataType
>&
b_ns_ks
,
Tensor
<
C
DataType
>&
c
_ms_ns
,
Tensor
<
E
DataType
>&
e
_ms_ns
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
{
return
Argument
{
a_ms_ks
,
b_ns_ks
,
c
_ms_ns
,
a_element_op
,
b_element_op
,
cde_element_op
};
return
Argument
{
a_ms_ks
,
b_ns_ks
,
e
_ms_ns
,
a_element_op
,
b_element_op
,
cde_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
profiler/include/profiler/profile_contraction_impl.hpp
View file @
304adaad
...
...
@@ -129,8 +129,8 @@ int profile_contraction_impl(ck::index_t do_verification,
// Run reference op
if
(
do_verification
)
{
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceContraction_M2_N2_K2
<
NumDim
,
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceContraction_M2_N2_K2
<
NumDim
,
NumDim
,
NumDim
,
DataType
,
...
...
@@ -140,6 +140,7 @@ int profile_contraction_impl(ck::index_t do_verification,
AElementOp
,
BElementOp
,
CDElementOp
,
std
::
is_same
<
CDElementOp
,
Bilinear
>::
value
,
DataType
>
;
auto
ref_op
=
ReferenceGemmInstance
{};
...
...
test/contraction/test_contraction.cpp
View file @
304adaad
...
...
@@ -125,7 +125,7 @@ TYPED_TEST(TestContractionBilinear, bilinear)
{
this
->
p_cd_element_op
=
std
::
make_unique
<
Bilinear
>
(
1.
f
,
1.
f
);
this
->
Run
();
this
->
p_cd_element_op
=
std
::
make_unique
<
Bilinear
>
(
0.5
f
,
0.5
f
);
this
->
p_cd_element_op
=
std
::
make_unique
<
Bilinear
>
(
-
0.5
f
,
0.5
f
);
this
->
Run
();
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment