Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
8354aad7
Commit
8354aad7
authored
May 10, 2023
by
Bartlomiej Kocot
Browse files
Make ref_contraction generic and extend interface tests
parent
1864dfe1
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
216 additions
and
167 deletions
+216
-167
example/26_contraction/contraction_bilinear_xdl_fp32.cpp
example/26_contraction/contraction_bilinear_xdl_fp32.cpp
+21
-11
example/26_contraction/contraction_bilinear_xdl_fp64.cpp
example/26_contraction/contraction_bilinear_xdl_fp64.cpp
+21
-11
example/26_contraction/contraction_scale_xdl_fp32.cpp
example/26_contraction/contraction_scale_xdl_fp32.cpp
+20
-10
example/26_contraction/contraction_scale_xdl_fp64.cpp
example/26_contraction/contraction_scale_xdl_fp64.cpp
+20
-10
library/include/ck/library/reference_tensor_operation/cpu/reference_contraction.hpp
.../reference_tensor_operation/cpu/reference_contraction.hpp
+14
-61
profiler/include/profiler/profile_contraction_impl.hpp
profiler/include/profiler/profile_contraction_impl.hpp
+64
-44
test/contraction/test_contraction_interface.cpp
test/contraction/test_contraction_interface.cpp
+56
-20
No files found.
example/26_contraction/contraction_bilinear_xdl_fp32.cpp
View file @
8354aad7
...
...
@@ -249,6 +249,8 @@ int main(int argc, char* argv[])
if
(
do_verification
)
{
Tensor
<
CShuffleDataType
>
c_ms_ns_host_result
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
using
ReferenceOpInstance
=
ck
::
tensor_operation
::
host
::
ReferenceContraction_M2_N2_K2
<
NumDimM
,
NumDimN
,
...
...
@@ -258,24 +260,32 @@ int main(int argc, char* argv[])
CShuffleDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
true
,
DDataType
>
;
BElementOp
>
;
auto
ref_gemm
=
ReferenceOpInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_ms_ks
,
b_ns_ks
,
d_ms_ns
,
e_ms_ns_host_result
,
a_element_op
,
b_element_op
,
cde_element_op
);
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_ms_ks
,
b_ns_ks
,
c_ms_ns_host_result
,
a_element_op
,
b_element_op
);
ref_invoker
.
Run
(
ref_argument
);
for
(
size_t
m0
=
0
;
m0
<
e_ms_ns_host_result
.
mDesc
.
GetLengths
()[
0
];
++
m0
)
{
for
(
size_t
m1
=
0
;
m1
<
e_ms_ns_host_result
.
mDesc
.
GetLengths
()[
1
];
++
m1
)
{
for
(
size_t
n0
=
0
;
n0
<
e_ms_ns_host_result
.
mDesc
.
GetLengths
()[
2
];
++
n0
)
{
for
(
size_t
n1
=
0
;
n1
<
e_ms_ns_host_result
.
mDesc
.
GetLengths
()[
3
];
++
n1
)
{
cde_element_op
(
e_ms_ns_host_result
(
m0
,
m1
,
n0
,
n1
),
c_ms_ns_host_result
(
m0
,
m1
,
n0
,
n1
),
d_ms_ns
(
m0
,
m1
,
n0
,
n1
));
}
}
}
}
return
ck
::
utils
::
check_err
(
e_ms_ns_device_result
,
e_ms_ns_host_result
)
?
0
:
1
;
}
...
...
example/26_contraction/contraction_bilinear_xdl_fp64.cpp
View file @
8354aad7
...
...
@@ -249,6 +249,8 @@ int main(int argc, char* argv[])
if
(
do_verification
)
{
Tensor
<
CShuffleDataType
>
c_ms_ns_host_result
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
using
ReferenceOpInstance
=
ck
::
tensor_operation
::
host
::
ReferenceContraction_M2_N2_K2
<
NumDimM
,
NumDimN
,
...
...
@@ -258,24 +260,32 @@ int main(int argc, char* argv[])
CShuffleDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
true
,
DDataType
>
;
BElementOp
>
;
auto
ref_gemm
=
ReferenceOpInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_ms_ks
,
b_ns_ks
,
d_ms_ns
,
e_ms_ns_host_result
,
a_element_op
,
b_element_op
,
cde_element_op
);
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_ms_ks
,
b_ns_ks
,
c_ms_ns_host_result
,
a_element_op
,
b_element_op
);
ref_invoker
.
Run
(
ref_argument
);
for
(
size_t
m0
=
0
;
m0
<
e_ms_ns_host_result
.
mDesc
.
GetLengths
()[
0
];
++
m0
)
{
for
(
size_t
m1
=
0
;
m1
<
e_ms_ns_host_result
.
mDesc
.
GetLengths
()[
1
];
++
m1
)
{
for
(
size_t
n0
=
0
;
n0
<
e_ms_ns_host_result
.
mDesc
.
GetLengths
()[
2
];
++
n0
)
{
for
(
size_t
n1
=
0
;
n1
<
e_ms_ns_host_result
.
mDesc
.
GetLengths
()[
3
];
++
n1
)
{
cde_element_op
(
e_ms_ns_host_result
(
m0
,
m1
,
n0
,
n1
),
c_ms_ns_host_result
(
m0
,
m1
,
n0
,
n1
),
d_ms_ns
(
m0
,
m1
,
n0
,
n1
));
}
}
}
}
return
ck
::
utils
::
check_err
(
e_ms_ns_device_result
,
e_ms_ns_host_result
)
?
0
:
1
;
}
...
...
example/26_contraction/contraction_scale_xdl_fp32.cpp
View file @
8354aad7
...
...
@@ -232,6 +232,8 @@ int main(int argc, char* argv[])
if
(
do_verification
)
{
Tensor
<
CShuffleDataType
>
c_ms_ns_host_result
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
using
ReferenceOpInstance
=
ck
::
tensor_operation
::
host
::
ReferenceContraction_M2_N2_K2
<
NumDimM
,
NumDimN
,
...
...
@@ -241,24 +243,32 @@ int main(int argc, char* argv[])
CShuffleDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
false
>
;
BElementOp
>
;
auto
ref_gemm
=
ReferenceOpInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
Tensor
<
float
>
empty_tensor
(
std
::
vector
<
ck
::
index_t
>
{},
std
::
vector
<
ck
::
index_t
>
{});
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_ms_ks
,
b_ns_ks
,
empty_tensor
,
e_ms_ns_host_result
,
a_element_op
,
b_element_op
,
cde_element_op
);
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_ms_ks
,
b_ns_ks
,
c_ms_ns_host_result
,
a_element_op
,
b_element_op
);
ref_invoker
.
Run
(
ref_argument
);
for
(
size_t
m0
=
0
;
m0
<
e_ms_ns_host_result
.
mDesc
.
GetLengths
()[
0
];
++
m0
)
{
for
(
size_t
m1
=
0
;
m1
<
e_ms_ns_host_result
.
mDesc
.
GetLengths
()[
1
];
++
m1
)
{
for
(
size_t
n0
=
0
;
n0
<
e_ms_ns_host_result
.
mDesc
.
GetLengths
()[
2
];
++
n0
)
{
for
(
size_t
n1
=
0
;
n1
<
e_ms_ns_host_result
.
mDesc
.
GetLengths
()[
3
];
++
n1
)
{
cde_element_op
(
e_ms_ns_host_result
(
m0
,
m1
,
n0
,
n1
),
c_ms_ns_host_result
(
m0
,
m1
,
n0
,
n1
));
}
}
}
}
return
ck
::
utils
::
check_err
(
e_ms_ns_device_result
,
e_ms_ns_host_result
)
?
0
:
1
;
}
...
...
example/26_contraction/contraction_scale_xdl_fp64.cpp
View file @
8354aad7
...
...
@@ -232,6 +232,8 @@ int main(int argc, char* argv[])
if
(
do_verification
)
{
Tensor
<
CShuffleDataType
>
c_ms_ns_host_result
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
using
ReferenceOpInstance
=
ck
::
tensor_operation
::
host
::
ReferenceContraction_M2_N2_K2
<
NumDimM
,
NumDimN
,
...
...
@@ -241,24 +243,32 @@ int main(int argc, char* argv[])
CShuffleDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
false
>
;
BElementOp
>
;
auto
ref_gemm
=
ReferenceOpInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
Tensor
<
float
>
empty_tensor
(
std
::
vector
<
ck
::
index_t
>
{},
std
::
vector
<
ck
::
index_t
>
{});
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_ms_ks
,
b_ns_ks
,
empty_tensor
,
e_ms_ns_host_result
,
a_element_op
,
b_element_op
,
cde_element_op
);
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_ms_ks
,
b_ns_ks
,
c_ms_ns_host_result
,
a_element_op
,
b_element_op
);
ref_invoker
.
Run
(
ref_argument
);
for
(
size_t
m0
=
0
;
m0
<
e_ms_ns_host_result
.
mDesc
.
GetLengths
()[
0
];
++
m0
)
{
for
(
size_t
m1
=
0
;
m1
<
e_ms_ns_host_result
.
mDesc
.
GetLengths
()[
1
];
++
m1
)
{
for
(
size_t
n0
=
0
;
n0
<
e_ms_ns_host_result
.
mDesc
.
GetLengths
()[
2
];
++
n0
)
{
for
(
size_t
n1
=
0
;
n1
<
e_ms_ns_host_result
.
mDesc
.
GetLengths
()[
3
];
++
n1
)
{
cde_element_op
(
e_ms_ns_host_result
(
m0
,
m1
,
n0
,
n1
),
c_ms_ns_host_result
(
m0
,
m1
,
n0
,
n1
));
}
}
}
}
return
ck
::
utils
::
check_err
(
e_ms_ns_device_result
,
e_ms_ns_host_result
)
?
0
:
1
;
}
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_contraction.hpp
View file @
8354aad7
...
...
@@ -23,13 +23,10 @@ template <ck::index_t NumDimM,
ck
::
index_t
NumDimK
,
typename
ADataType
,
typename
BDataType
,
typename
E
DataType
,
typename
C
DataType
,
typename
AccDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
bool
UseDToBinaryOp
,
typename
DDataType
=
float
,
ck
::
enable_if_t
<
NumDimM
==
2
&&
NumDimN
==
2
&&
NumDimK
==
2
,
bool
>
=
false
>
struct
ReferenceContraction_M2_N2_K2
:
public
ck
::
tensor_operation
::
device
::
BaseOperator
{
...
...
@@ -38,29 +35,23 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
{
Argument
(
const
Tensor
<
ADataType
>&
a_ms_ks
,
const
Tensor
<
BDataType
>&
b_ns_ks
,
const
Tensor
<
DDataType
>&
d_ms_ns
,
Tensor
<
EDataType
>&
e_ms_ns
,
Tensor
<
CDataType
>&
c_ms_ns
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
BElementwiseOperation
b_element_op
)
:
a_ms_ks_
{
a_ms_ks
},
b_ns_ks_
{
b_ns_ks
},
d_ms_ns_
{
d_ms_ns
},
e_ms_ns_
{
e_ms_ns
},
c_ms_ns_
{
c_ms_ns
},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
}
b_element_op_
{
b_element_op
}
{
}
const
Tensor
<
ADataType
>&
a_ms_ks_
;
const
Tensor
<
BDataType
>&
b_ns_ks_
;
const
Tensor
<
DDataType
>&
d_ms_ns_
;
Tensor
<
EDataType
>&
e_ms_ns_
;
Tensor
<
CDataType
>&
c_ms_ns_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
};
// Invoker
...
...
@@ -68,19 +59,6 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
{
using
Argument
=
ReferenceContraction_M2_N2_K2
::
Argument
;
void
apply_unary_op
(
const
CDEElementwiseOperation
&
op
,
EDataType
&
v_e
,
AccDataType
&
v_acc
)
{
op
(
v_e
,
v_acc
);
}
void
apply_binary_op
(
const
CDEElementwiseOperation
&
op
,
EDataType
&
v_e
,
AccDataType
&
v_acc
,
DDataType
&
v_d
)
{
op
(
v_e
,
v_acc
,
v_d
);
}
float
Run
(
const
Argument
&
arg
)
{
auto
f_ms_ns
=
[
&
](
auto
m0
,
auto
m1
,
auto
n0
,
auto
n1
)
{
...
...
@@ -105,26 +83,14 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
}
}
AccDataType
v_e
;
DDataType
v_d
=
arg
.
d_ms_ns_
.
GetNumOfDimension
()
==
0
?
0
:
arg
.
d_ms_ns_
(
m0
,
m1
,
n0
,
n1
);
if
constexpr
(
UseDToBinaryOp
)
{
apply_binary_op
(
arg
.
cde_element_op_
,
v_e
,
v_acc
,
v_d
);
}
else
{
apply_unary_op
(
arg
.
cde_element_op_
,
v_e
,
v_acc
);
}
arg
.
e_ms_ns_
(
m0
,
m1
,
n0
,
n1
)
=
v_e
;
arg
.
c_ms_ns_
(
m0
,
m1
,
n0
,
n1
)
=
v_acc
;
};
make_ParallelTensorFunctor
(
f_ms_ns
,
arg
.
e
_ms_ns_
.
mDesc
.
GetLengths
()[
0
],
arg
.
e
_ms_ns_
.
mDesc
.
GetLengths
()[
1
],
arg
.
e
_ms_ns_
.
mDesc
.
GetLengths
()[
2
],
arg
.
e
_ms_ns_
.
mDesc
.
GetLengths
()[
3
])(
arg
.
c
_ms_ns_
.
mDesc
.
GetLengths
()[
0
],
arg
.
c
_ms_ns_
.
mDesc
.
GetLengths
()[
1
],
arg
.
c
_ms_ns_
.
mDesc
.
GetLengths
()[
2
],
arg
.
c
_ms_ns_
.
mDesc
.
GetLengths
()[
3
])(
std
::
thread
::
hardware_concurrency
());
return
0
;
...
...
@@ -150,24 +116,11 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
static
auto
MakeArgument
(
const
Tensor
<
ADataType
>&
a_ms_ks
,
const
Tensor
<
BDataType
>&
b_ns_ks
,
const
Tensor
<
DDataType
>&
d_ms_ns
,
Tensor
<
EDataType
>&
e_ms_ns
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
{
return
Argument
{
a_ms_ks
,
b_ns_ks
,
d_ms_ns
,
e_ms_ns
,
a_element_op
,
b_element_op
,
cde_element_op
};
}
static
auto
MakeArgument
(
const
Tensor
<
ADataType
>&
a_ms_ks
,
const
Tensor
<
BDataType
>&
b_ns_ks
,
Tensor
<
EDataType
>&
e_ms_ns
,
Tensor
<
CDataType
>&
c_ms_ns
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
BElementwiseOperation
b_element_op
)
{
return
Argument
{
a_ms_ks
,
b_ns_ks
,
e
_ms_ns
,
a_element_op
,
b_element_op
,
cde_element_op
};
return
Argument
{
a_ms_ks
,
b_ns_ks
,
c
_ms_ns
,
a_element_op
,
b_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
profiler/include/profiler/profile_contraction_impl.hpp
View file @
8354aad7
...
...
@@ -41,7 +41,7 @@ int profile_contraction_impl(ck::index_t do_verification,
ck
::
index_t
init_method
,
bool
do_log
,
bool
time_kernel
,
CDElementOp
cd_element_op
,
CDElementOp
cd
e
_element_op
,
const
std
::
vector
<
ck
::
index_t
>&
M
,
const
std
::
vector
<
ck
::
index_t
>&
N
,
const
std
::
vector
<
ck
::
index_t
>&
K
,
...
...
@@ -64,14 +64,14 @@ int profile_contraction_impl(ck::index_t do_verification,
Tensor
<
DataType
>
a_m_k
(
f_host_tensor_descriptor
(
M
,
K
,
StridesA
));
Tensor
<
DataType
>
b_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
StridesB
));
Tensor
<
DataType
>
c
_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StridesC
));
Tensor
<
DataType
>
c
_m_n_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StridesC
));
Tensor
<
DataType
>
e
_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StridesC
));
Tensor
<
DataType
>
e
_m_n_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StridesC
));
Tensor
<
DataType
>
d_m_n
(
f_host_tensor_descriptor
(
M
,
N
,
StridesD
));
std
::
cout
<<
"a_m_k: "
<<
a_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_k_n: "
<<
b_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d_m_n: "
<<
d_m_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"
c
_m_n: "
<<
c
_m_n_device_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"
e
_m_n: "
<<
e
_m_n_device_result
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
{
...
...
@@ -92,18 +92,18 @@ int profile_contraction_impl(ck::index_t do_verification,
DeviceMem
a_device_buf
(
sizeof
(
DataType
)
*
a_m_k
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_device_buf
(
sizeof
(
DataType
)
*
b_k_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c
_device_buf
(
sizeof
(
DataType
)
*
c
_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e
_device_buf
(
sizeof
(
DataType
)
*
e
_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d_device_buf
(
sizeof
(
DataType
)
*
d_m_n
.
mDesc
.
GetElementSpaceSize
());
a_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
c
_device_buf
.
SetZero
();
e
_device_buf
.
SetZero
();
d_device_buf
.
ToDevice
(
d_m_n
.
mData
.
data
());
const
std
::
vector
<
index_t
>
a_m
s
_k
s
_lengths
=
{
M
[
0
],
M
[
1
],
K
[
0
],
K
[
1
]};
const
std
::
vector
<
index_t
>
b_n
s
_k
s
_lengths
=
{
N
[
0
],
N
[
1
],
K
[
0
],
K
[
1
]};
const
std
::
vector
<
index_t
>
c_m
s
_n
s
_lengths
=
{
M
[
0
],
M
[
1
],
N
[
0
],
N
[
1
]};
const
std
::
vector
<
index_t
>
d_m
s
_n
s
_lengths
=
{
M
[
0
],
M
[
1
],
N
[
0
],
N
[
1
]};
const
std
::
vector
<
index_t
>
a_m_k_lengths
=
{
M
[
0
],
M
[
1
],
K
[
0
],
K
[
1
]};
const
std
::
vector
<
index_t
>
b_n_k_lengths
=
{
N
[
0
],
N
[
1
],
K
[
0
],
K
[
1
]};
const
std
::
vector
<
index_t
>
c_m_n_lengths
=
{
M
[
0
],
M
[
1
],
N
[
0
],
N
[
1
]};
const
std
::
vector
<
index_t
>
d_m_n_lengths
=
{
M
[
0
],
M
[
1
],
N
[
0
],
N
[
1
]};
const
auto
a_element_op
=
AElementOp
{};
const
auto
b_element_op
=
BElementOp
{};
...
...
@@ -129,8 +129,8 @@ int profile_contraction_impl(ck::index_t do_verification,
// Run reference op
if
(
do_verification
)
{
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceContraction_M2_N2_K2
<
NumDim
,
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceContraction_M2_N2_K2
<
NumDim
,
NumDim
,
NumDim
,
DataType
,
...
...
@@ -138,21 +138,41 @@ int profile_contraction_impl(ck::index_t do_verification,
DataType
,
DataType
,
AElementOp
,
BElementOp
,
CDElementOp
,
std
::
is_same
<
CDElementOp
,
Bilinear
>::
value
,
DataType
>
;
BElementOp
>
;
auto
ref_op
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_op
.
MakeInvoker
();
if
constexpr
(
std
::
is_same
<
CDElementOp
,
Scale
>::
value
)
d_m_n
=
Tensor
<
DataType
>
(
std
::
vector
<
ck
::
index_t
>
{},
std
::
vector
<
ck
::
index_t
>
{});
Tensor
<
DataType
>
c_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StridesC
));
auto
ref_argument
=
ref_op
.
MakeArgument
(
a_m_k
,
b_k_n
,
d_m_n
,
c_m_n_host_result
,
a_element_op
,
b_element_op
,
cd_element_op
);
auto
ref_argument
=
ref_op
.
MakeArgument
(
a_m_k
,
b_k_n
,
c_m_n_host_result
,
a_element_op
,
b_element_op
);
ref_invoker
.
Run
(
ref_argument
);
for
(
size_t
m0
=
0
;
m0
<
e_m_n_host_result
.
mDesc
.
GetLengths
()[
0
];
++
m0
)
{
for
(
size_t
m1
=
0
;
m1
<
e_m_n_host_result
.
mDesc
.
GetLengths
()[
1
];
++
m1
)
{
for
(
size_t
n0
=
0
;
n0
<
e_m_n_host_result
.
mDesc
.
GetLengths
()[
2
];
++
n0
)
{
for
(
size_t
n1
=
0
;
n1
<
e_m_n_host_result
.
mDesc
.
GetLengths
()[
3
];
++
n1
)
{
if
constexpr
(
is_same
<
CDElementOp
,
Bilinear
>::
value
)
{
cde_element_op
(
e_m_n_host_result
(
m0
,
m1
,
n0
,
n1
),
c_m_n_host_result
(
m0
,
m1
,
n0
,
n1
),
d_m_n
(
m0
,
m1
,
n0
,
n1
));
}
else
if
constexpr
(
is_same
<
CDElementOp
,
Scale
>::
value
)
{
cde_element_op
(
e_m_n_host_result
(
m0
,
m1
,
n0
,
n1
),
c_m_n_host_result
(
m0
,
m1
,
n0
,
n1
));
}
}
}
}
}
}
std
::
string
best_op_name
;
...
...
@@ -170,18 +190,18 @@ int profile_contraction_impl(ck::index_t do_verification,
static_cast
<
DataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
b_device_buf
.
GetDeviceBuffer
()),
std
::
array
<
const
void
*
,
1
>
{
d_device_buf
.
GetDeviceBuffer
()},
static_cast
<
DataType
*>
(
c
_device_buf
.
GetDeviceBuffer
()),
a_m
s
_k
s
_lengths
,
static_cast
<
DataType
*>
(
e
_device_buf
.
GetDeviceBuffer
()),
a_m_k_lengths
,
StridesA
,
b_n
s
_k
s
_lengths
,
b_n_k_lengths
,
StridesB
,
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d_m
s
_n
s
_lengths
},
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d_m_n_lengths
},
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
StridesD
},
c_m
s
_n
s
_lengths
,
c_m_n_lengths
,
StridesC
,
a_element_op
,
b_element_op
,
cd_element_op
);
cd
e
_element_op
);
}
else
{
...
...
@@ -189,18 +209,18 @@ int profile_contraction_impl(ck::index_t do_verification,
op_ptr
->
MakeArgumentPointer
(
static_cast
<
DataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DataType
*>
(
b_device_buf
.
GetDeviceBuffer
()),
std
::
array
<
const
void
*
,
0
>
{},
static_cast
<
DataType
*>
(
c
_device_buf
.
GetDeviceBuffer
()),
a_m
s
_k
s
_lengths
,
static_cast
<
DataType
*>
(
e
_device_buf
.
GetDeviceBuffer
()),
a_m_k_lengths
,
StridesA
,
b_n
s
_k
s
_lengths
,
b_n_k_lengths
,
StridesB
,
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
0
>
{},
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
0
>
{},
c_m
s
_n
s
_lengths
,
c_m_n_lengths
,
StridesC
,
a_element_op
,
b_element_op
,
cd_element_op
);
cd
e
_element_op
);
}
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
...
...
@@ -212,7 +232,7 @@ int profile_contraction_impl(ck::index_t do_verification,
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
// re-init C to zero before profiling next kernel
c
_device_buf
.
SetZero
();
e
_device_buf
.
SetZero
();
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
...
...
@@ -242,12 +262,12 @@ int profile_contraction_impl(ck::index_t do_verification,
if
(
do_verification
)
{
c
_device_buf
.
FromDevice
(
c
_m_n_device_result
.
mData
.
data
());
e
_device_buf
.
FromDevice
(
e
_m_n_device_result
.
mData
.
data
());
float
threshold
=
static_cast
<
DataType
>
(
nelems_k
)
*
std
::
numeric_limits
<
DataType
>::
epsilon
();
pass
=
pass
&
ck
::
utils
::
check_err
(
c
_m_n_device_result
,
c
_m_n_host_result
,
pass
=
pass
&
ck
::
utils
::
check_err
(
e
_m_n_device_result
,
e
_m_n_host_result
,
"Error: incorrect results!"
,
threshold
,
threshold
);
...
...
@@ -256,9 +276,9 @@ int profile_contraction_impl(ck::index_t do_verification,
{
LogRangeAsType
<
float
>
(
std
::
cout
<<
"a : "
,
a_m_k
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"b: "
,
b_k_n
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_host : "
,
c
_m_n_host_result
.
mData
,
","
)
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_host : "
,
e
_m_n_host_result
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_device: "
,
c
_m_n_device_result
.
mData
,
","
)
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_device: "
,
e
_m_n_device_result
.
mData
,
","
)
<<
std
::
endl
;
}
}
...
...
test/contraction/test_contraction_interface.cpp
View file @
8354aad7
...
...
@@ -23,7 +23,7 @@ template <typename DataTypeA,
typename
DataTypeB
,
typename
DataTypeC
,
typename
DataTypeD
,
in
t
NumDim
>
ck
::
index_
t
NumDim
>
class
ContractionDeviceWrapper
{
...
...
@@ -40,10 +40,27 @@ class ContractionDeviceWrapper
Bilinear
>
;
public:
ContractionDeviceWrapper
(
std
::
vector
<
ck
::
index_t
>&
Dims
,
std
::
vector
<
ck
::
index_t
>&
Strides
)
:
InputDims_
(
Dims
),
OutputDims_
(
Dims
),
InputStrides_
(
Strides
),
OutputStrides_
(
Strides
)
{
}
ContractionDeviceWrapper
(
std
::
vector
<
ck
::
index_t
>&
InDims
,
std
::
vector
<
ck
::
index_t
>&
OutDims
,
std
::
vector
<
ck
::
index_t
>&
InStrides
,
std
::
vector
<
ck
::
index_t
>&
OutStrides
)
:
InputDims_
(
InDims
),
OutputDims_
(
OutDims
),
InputStrides_
(
InStrides
),
OutputStrides_
(
OutStrides
)
{
}
std
::
vector
<
ck
::
index_t
>&
InputDims_
;
std
::
vector
<
ck
::
index_t
>&
OutputDims_
;
std
::
vector
<
ck
::
index_t
>&
InputStrides_
;
std
::
vector
<
ck
::
index_t
>&
OutputStrides_
;
bool
IsSupported
()
const
{
std
::
vector
<
ck
::
index_t
>
dummy_dims
(
NumDim
*
2
,
4
);
std
::
vector
<
ck
::
index_t
>
dummy_strides
(
NumDim
*
2
,
1
);
bool
supported
=
false
;
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
...
...
@@ -56,14 +73,14 @@ class ContractionDeviceWrapper
nullptr
,
std
::
array
<
const
void
*
,
1
>
{
nullptr
},
nullptr
,
dummy_dims
,
dummy_s
trides
,
dummy_dims
,
dummy_s
trides
,
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
dummy_dims
},
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
dummy_s
trides
},
dummy_d
ims
,
dummy_s
trides
,
InputStrides_
,
InputS
trides
_
,
InputStrides_
,
InputS
trides
_
,
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
InputStrides_
},
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
InputS
trides
_
},
OutputD
ims
_
,
OutputS
trides
_
,
Pass
{},
Pass
{},
Bilinear
{
1.
f
,
1.
f
});
...
...
@@ -76,9 +93,11 @@ class ContractionDeviceWrapper
TEST
(
TestContractionInterface
,
IncorrectNumDims
)
{
ContractionDeviceWrapper
<
F32
,
F32
,
F32
,
F32
,
1
>
wrapper_1d
;
ContractionDeviceWrapper
<
F32
,
F32
,
F32
,
F32
,
2
>
wrapper_2d
;
ContractionDeviceWrapper
<
F32
,
F32
,
F32
,
F32
,
3
>
wrapper_3d
;
std
::
vector
<
std
::
vector
<
ck
::
index_t
>>
Dims
=
{{
4
,
4
},
{
4
,
4
,
4
,
4
},
{
4
,
4
,
4
,
4
,
4
,
4
}};
std
::
vector
<
std
::
vector
<
ck
::
index_t
>>
Strides
=
{{
1
,
1
},
{
1
,
1
,
1
,
1
},
{
1
,
1
,
1
,
1
,
1
,
1
}};
ContractionDeviceWrapper
<
F32
,
F32
,
F32
,
F32
,
1
>
wrapper_1d
(
Dims
[
0
],
Strides
[
0
]);
ContractionDeviceWrapper
<
F32
,
F32
,
F32
,
F32
,
2
>
wrapper_2d
(
Dims
[
1
],
Strides
[
1
]);
ContractionDeviceWrapper
<
F32
,
F32
,
F32
,
F32
,
3
>
wrapper_3d
(
Dims
[
2
],
Strides
[
2
]);
EXPECT_FALSE
(
wrapper_1d
.
IsSupported
());
EXPECT_TRUE
(
wrapper_2d
.
IsSupported
());
EXPECT_FALSE
(
wrapper_3d
.
IsSupported
());
...
...
@@ -86,13 +105,30 @@ TEST(TestContractionInterface, IncorrectNumDims)
TEST
(
TestContractionInterface
,
IncorrectDataTypes
)
{
ContractionDeviceWrapper
<
F32
,
F32
,
F64
,
F64
,
2
>
wrapper_1
;
ContractionDeviceWrapper
<
F64
,
F64
,
F32
,
F32
,
2
>
wrapper_2
;
std
::
vector
<
ck
::
index_t
>
Dims
=
{
4
,
4
,
4
,
4
};
std
::
vector
<
ck
::
index_t
>
Strides
=
{
64
,
16
,
4
,
1
};
ContractionDeviceWrapper
<
F32
,
F32
,
F64
,
F64
,
2
>
wrapper_1
(
Dims
,
Strides
);
ContractionDeviceWrapper
<
F64
,
F64
,
F32
,
F32
,
2
>
wrapper_2
(
Dims
,
Strides
);
EXPECT_FALSE
(
wrapper_1
.
IsSupported
());
EXPECT_FALSE
(
wrapper_2
.
IsSupported
());
}
// TEST(TestContractionInterface, CornerCases)
// {
// EXPECT_FALSE()
// }
TEST
(
TestContractionInterface
,
GridwiseGemm
)
{
std
::
vector
<
ck
::
index_t
>
InDims
=
{
1
,
2
,
3
,
4
};
std
::
vector
<
ck
::
index_t
>
InStrides
=
{
24
,
12
,
4
,
1
};
std
::
vector
<
ck
::
index_t
>
OutDims
=
{
4
,
3
,
2
,
1
};
std
::
vector
<
ck
::
index_t
>
OutStrides
=
{
6
,
2
,
1
,
1
};
ContractionDeviceWrapper
<
F32
,
F32
,
F32
,
F32
,
2
>
wrapper
(
InDims
,
OutDims
,
InStrides
,
OutStrides
);
EXPECT_FALSE
(
wrapper
.
IsSupported
());
}
TEST
(
TestContractionInterface
,
MemoryAccess
)
{
std
::
vector
<
ck
::
index_t
>
Dims
=
{
4
,
4
,
4
,
4
};
std
::
vector
<
ck
::
index_t
>
Strides
=
{
4
,
16
,
64
,
256
};
ContractionDeviceWrapper
<
F32
,
F32
,
F32
,
F32
,
2
>
wrapper
(
Dims
,
Strides
);
EXPECT_FALSE
(
wrapper
.
IsSupported
());
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment