Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
77190058
"docs/vscode:/vscode.git/clone" did not exist on "666743302ff5bd1e02c204b81a80e566648d60de"
Commit
77190058
authored
Jun 07, 2023
by
Alan Turner
Browse files
Formatting
parent
421734ae
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
366 additions
and
330 deletions
+366
-330
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
+14
-10
library/src/jit_library/include/ck/host/common.hpp
library/src/jit_library/include/ck/host/common.hpp
+3
-2
library/src/jit_library/include/ck/host/device_gemm_multiple_d.hpp
...rc/jit_library/include/ck/host/device_gemm_multiple_d.hpp
+25
-26
library/src/jit_library/src/common.cpp
library/src/jit_library/src/common.cpp
+7
-6
library/src/jit_library/src/device_gemm_multiple_d.cpp
library/src/jit_library/src/device_gemm_multiple_d.cpp
+39
-37
test/jit_library/jit_library.cpp
test/jit_library/jit_library.cpp
+278
-249
No files found.
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
View file @
77190058
...
@@ -120,23 +120,25 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
...
@@ -120,23 +120,25 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
__host__
__device__
constexpr
BlockToCTileMap_M00_N0_M01Adapt
()
=
default
;
__host__
__device__
constexpr
BlockToCTileMap_M00_N0_M01Adapt
()
=
default
;
__host__
__device__
constexpr
BlockToCTileMap_M00_N0_M01Adapt
(
const
BlockToCTileMap_M00_N0_M01Adapt
&
)
=
__host__
__device__
constexpr
BlockToCTileMap_M00_N0_M01Adapt
(
default
;
const
BlockToCTileMap_M00_N0_M01Adapt
&
)
=
default
;
__host__
__device__
constexpr
BlockToCTileMap_M00_N0_M01Adapt
(
BlockToCTileMap_M00_N0_M01Adapt
&&
)
=
__host__
__device__
constexpr
BlockToCTileMap_M00_N0_M01Adapt
(
default
;
BlockToCTileMap_M00_N0_M01Adapt
&&
)
=
default
;
__host__
__device__
constexpr
BlockToCTileMap_M00_N0_M01Adapt
&
__host__
__device__
constexpr
BlockToCTileMap_M00_N0_M01Adapt
&
operator
=
(
const
BlockToCTileMap_M00_N0_M01Adapt
&
)
=
default
;
operator
=
(
const
BlockToCTileMap_M00_N0_M01Adapt
&
)
=
default
;
__host__
__device__
constexpr
BlockToCTileMap_M00_N0_M01Adapt
&
__host__
__device__
constexpr
BlockToCTileMap_M00_N0_M01Adapt
&
operator
=
(
BlockToCTileMap_M00_N0_M01Adapt
&&
)
=
default
;
operator
=
(
BlockToCTileMap_M00_N0_M01Adapt
&&
)
=
default
;
__host__
__device__
constexpr
BlockToCTileMap_M00_N0_M01Adapt
(
index_t
M
,
index_t
N
,
index_t
M01
=
8
)
__host__
__device__
constexpr
BlockToCTileMap_M00_N0_M01Adapt
(
index_t
M
,
index_t
N
,
index_t
M01
=
8
)
:
M_
(
M
),
N_
(
N
),
M01_
(
M01
)
:
M_
(
M
),
N_
(
N
),
M01_
(
M01
)
{
{
}
}
template
<
typename
CGridDesc_M_N
>
template
<
typename
CGridDesc_M_N
>
__host__
__device__
constexpr
BlockToCTileMap_M00_N0_M01Adapt
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
__host__
index_t
M01
=
8
)
__device__
constexpr
BlockToCTileMap_M00_N0_M01Adapt
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
index_t
M01
=
8
)
:
BlockToCTileMap_M00_N0_M01Adapt
(
:
BlockToCTileMap_M00_N0_M01Adapt
(
c_grid_desc_m_n
.
GetLength
(
I0
),
c_grid_desc_m_n
.
GetLength
(
I1
),
M01
)
c_grid_desc_m_n
.
GetLength
(
I0
),
c_grid_desc_m_n
.
GetLength
(
I1
),
M01
)
{
{
...
@@ -151,13 +153,15 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
...
@@ -151,13 +153,15 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
}
}
template
<
typename
CGridDesc_M_N
>
template
<
typename
CGridDesc_M_N
>
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
{
return
CalculateGridSize
(
c_grid_desc_m_n
.
GetLength
(
I0
),
c_grid_desc_m_n
.
GetLength
(
I1
));
return
CalculateGridSize
(
c_grid_desc_m_n
.
GetLength
(
I0
),
c_grid_desc_m_n
.
GetLength
(
I1
));
}
}
template
<
typename
CGridDesc_M_N
>
template
<
typename
CGridDesc_M_N
>
__host__
__device__
constexpr
bool
CheckValidity
(
const
CGridDesc_M_N
&
/* c_grid_desc_m_n */
)
const
__host__
__device__
constexpr
bool
CheckValidity
(
const
CGridDesc_M_N
&
/* c_grid_desc_m_n */
)
const
{
{
return
true
;
return
true
;
}
}
...
@@ -231,7 +235,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
...
@@ -231,7 +235,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
template
<
typename
CTileIdx
,
typename
CTileDim
>
template
<
typename
CTileIdx
,
typename
CTileDim
>
__host__
__device__
constexpr
bool
ValidCTileIndex
(
const
CTileIdx
&
/* c_tile_idx */
,
__host__
__device__
constexpr
bool
ValidCTileIndex
(
const
CTileIdx
&
/* c_tile_idx */
,
const
CTileDim
&
/* c_tile_dim */
)
const
const
CTileDim
&
/* c_tile_dim */
)
const
{
{
return
true
;
// always valid provided that user gets grid size from CalculateGridSize()
return
true
;
// always valid provided that user gets grid size from CalculateGridSize()
}
}
...
...
library/src/jit_library/include/ck/host/common.hpp
View file @
77190058
...
@@ -17,7 +17,8 @@ struct Solution
...
@@ -17,7 +17,8 @@ struct Solution
std
::
size_t
grid_size
;
std
::
size_t
grid_size
;
};
};
enum
class
DataType
{
enum
class
DataType
{
Half
,
Half
,
Float
,
Float
,
Int8
,
Int8
,
...
@@ -26,7 +27,7 @@ enum class DataType {
...
@@ -26,7 +27,7 @@ enum class DataType {
std
::
string
ToString
(
DataType
dt
);
std
::
string
ToString
(
DataType
dt
);
std
::
unordered_map
<
std
::
string
,
std
::
pair
<
const
char
*
,
const
char
*>>
GetHeaders
();
std
::
unordered_map
<
std
::
string
,
std
::
pair
<
const
char
*
,
const
char
*>>
GetHeaders
();
std
::
size_t
integer_divide_ceil
(
std
::
size_t
x
,
std
::
size_t
y
);
std
::
size_t
integer_divide_ceil
(
std
::
size_t
x
,
std
::
size_t
y
);
...
...
library/src/jit_library/include/ck/host/device_gemm_multiple_d.hpp
View file @
77190058
...
@@ -11,45 +11,44 @@
...
@@ -11,45 +11,44 @@
#include <numeric>
#include <numeric>
#include "ck/host/common.hpp"
#include "ck/host/common.hpp"
namespace
ck
{
namespace
ck
{
namespace
host
{
namespace
host
{
namespace
device_gemm_multiple_d
{
namespace
device_gemm_multiple_d
{
struct
Problem
struct
Problem
{
{
std
::
size_t
M
=
0
;
std
::
size_t
M
=
0
;
std
::
size_t
N
=
0
;
std
::
size_t
N
=
0
;
std
::
size_t
K
=
0
;
std
::
size_t
K
=
0
;
bool
TransA
=
false
;
bool
TransA
=
false
;
bool
TransB
=
false
;
bool
TransB
=
false
;
bool
TransE
=
false
;
bool
TransE
=
false
;
std
::
vector
<
bool
>
DsTrans
=
{};
std
::
vector
<
bool
>
DsTrans
=
{};
DataType
ADataType
=
DataType
::
Half
;
DataType
ADataType
=
DataType
::
Half
;
DataType
BDataType
=
DataType
::
Half
;
DataType
BDataType
=
DataType
::
Half
;
DataType
EDataType
=
DataType
::
Half
;
DataType
EDataType
=
DataType
::
Half
;
std
::
vector
<
DataType
>
DsDataType
=
{};
std
::
vector
<
DataType
>
DsDataType
=
{};
std
::
string
AElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
AElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
BElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
BElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
CDEElementOp
=
"ck::Tuple<>"
;
std
::
string
CDEElementOp
=
"ck::Tuple<>"
;
static
const
std
::
size_t
ds_layout_idx
=
3
;
static
const
std
::
size_t
ds_layout_idx
=
3
;
static
const
std
::
size_t
ds_data_type_idx
=
9
;
static
const
std
::
size_t
ds_data_type_idx
=
9
;
static
const
std
::
size_t
e_data_type_idx
=
10
;
static
const
std
::
size_t
e_data_type_idx
=
10
;
static
const
std
::
size_t
a_elementwise_op_idx
=
11
;
static
const
std
::
size_t
a_elementwise_op_idx
=
11
;
static
const
std
::
size_t
b_elementwise_op_idx
=
12
;
static
const
std
::
size_t
b_elementwise_op_idx
=
12
;
static
const
std
::
size_t
ds_elementwise_op_idx
=
13
;
static
const
std
::
size_t
ds_elementwise_op_idx
=
13
;
static
const
std
::
size_t
gemm_spec_idx
=
14
;
static
const
std
::
size_t
gemm_spec_idx
=
14
;
static
const
std
::
size_t
block_size_idx
=
16
;
static
const
std
::
size_t
block_size_idx
=
16
;
static
const
std
::
size_t
m_per_block_idx
=
17
;
static
const
std
::
size_t
m_per_block_idx
=
17
;
static
const
std
::
size_t
n_per_block_idx
=
18
;
static
const
std
::
size_t
n_per_block_idx
=
18
;
static
const
std
::
size_t
k_per_block_idx
=
19
;
static
const
std
::
size_t
k_per_block_idx
=
19
;
std
::
string
GetIncludeHeader
()
const
;
std
::
string
GetIncludeHeader
()
const
;
std
::
vector
<
Solution
>
GetSolutions
(
const
std
::
string
&
arch
)
const
;
std
::
vector
<
Solution
>
GetSolutions
(
const
std
::
string
&
arch
)
const
;
private:
private:
std
::
vector
<
std
::
string
>
GetInstances
(
const
std
::
string
&
arch
)
const
;
std
::
vector
<
std
::
string
>
GetInstances
(
const
std
::
string
&
arch
)
const
;
Solution
MakeSolution
(
std
::
size_t
idx
,
const
std
::
string
&
arch
)
const
;
Solution
MakeSolution
(
std
::
size_t
idx
,
const
std
::
string
&
arch
)
const
;
...
...
library/src/jit_library/src/common.cpp
View file @
77190058
...
@@ -8,16 +8,17 @@ namespace host {
...
@@ -8,16 +8,17 @@ namespace host {
std
::
string
ToString
(
DataType
dt
)
std
::
string
ToString
(
DataType
dt
)
{
{
switch
(
dt
)
{
switch
(
dt
)
case
DataType
::
Float
:
return
"float"
;
{
case
DataType
::
Half
:
return
"ck::half_t"
;
case
DataType
::
Float
:
return
"float"
;
case
DataType
::
Int8
:
return
"int8_t"
;
case
DataType
::
Half
:
return
"ck::half_t"
;
case
DataType
::
Int32
:
return
"int32_t"
;
case
DataType
::
Int8
:
return
"int8_t"
;
case
DataType
::
Int32
:
return
"int32_t"
;
}
}
throw
std
::
runtime_error
(
"Incorrect data type"
);
throw
std
::
runtime_error
(
"Incorrect data type"
);
}
}
std
::
unordered_map
<
std
::
string
,
std
::
pair
<
const
char
*
,
const
char
*>>
GetHeaders
()
std
::
unordered_map
<
std
::
string
,
std
::
pair
<
const
char
*
,
const
char
*>>
GetHeaders
()
{
{
return
ck_headers
();
return
ck_headers
();
}
}
...
...
library/src/jit_library/src/device_gemm_multiple_d.cpp
View file @
77190058
...
@@ -8,12 +8,12 @@ namespace ck {
...
@@ -8,12 +8,12 @@ namespace ck {
namespace
host
{
namespace
host
{
namespace
device_gemm_multiple_d
{
namespace
device_gemm_multiple_d
{
std
::
string
GetGemmSpec
(
const
std
::
size_t
m
,
std
::
string
GetGemmSpec
(
const
std
::
size_t
m
,
const
std
::
size_t
n
,
const
std
::
size_t
n
,
const
std
::
size_t
k
,
const
std
::
size_t
k
,
const
std
::
size_t
m_per_block
,
const
std
::
size_t
m_per_block
,
const
std
::
size_t
n_per_block
,
const
std
::
size_t
n_per_block
,
const
std
::
size_t
k_per_block
)
const
std
::
size_t
k_per_block
)
{
{
std
::
string
spec
=
""
;
std
::
string
spec
=
""
;
if
(
integer_divide_ceil
(
m
,
m_per_block
)
*
m_per_block
-
m
!=
0
)
if
(
integer_divide_ceil
(
m
,
m_per_block
)
*
m_per_block
-
m
!=
0
)
...
@@ -28,13 +28,12 @@ std::string GetGemmSpec(const std::size_t m,
...
@@ -28,13 +28,12 @@ std::string GetGemmSpec(const std::size_t m,
return
"ck::tensor_operation::device::GemmSpecialization::"
+
spec
+
"Padding"
;
return
"ck::tensor_operation::device::GemmSpecialization::"
+
spec
+
"Padding"
;
}
}
std
::
size_t
GetGridSize
(
const
std
::
size_t
m
,
std
::
size_t
GetGridSize
(
const
std
::
size_t
m
,
const
std
::
size_t
n
,
const
std
::
size_t
n
,
const
std
::
size_t
m_per_block
,
const
std
::
size_t
m_per_block
,
const
std
::
size_t
n_per_block
)
const
std
::
size_t
n_per_block
)
{
{
return
integer_divide_ceil
(
m
,
m_per_block
)
*
return
integer_divide_ceil
(
m
,
m_per_block
)
*
integer_divide_ceil
(
n
,
n_per_block
);
integer_divide_ceil
(
n
,
n_per_block
);
}
}
const
std
::
unordered_set
<
std
::
string
>&
get_xdlop_archs
()
const
std
::
unordered_set
<
std
::
string
>&
get_xdlop_archs
()
...
@@ -47,7 +46,7 @@ std::vector<std::string> Problem::GetInstances(const std::string& arch) const
...
@@ -47,7 +46,7 @@ std::vector<std::string> Problem::GetInstances(const std::string& arch) const
{
{
std
::
vector
<
std
::
string
>
instances
;
std
::
vector
<
std
::
string
>
instances
;
const
bool
quantize
=
ADataType
==
DataType
::
Int8
and
BDataType
==
DataType
::
Int8
;
const
bool
quantize
=
ADataType
==
DataType
::
Int8
and
BDataType
==
DataType
::
Int8
;
if
(
get_xdlop_archs
().
find
(
arch
)
!=
get_xdlop_archs
().
end
())
if
(
get_xdlop_archs
().
find
(
arch
)
!=
get_xdlop_archs
().
end
())
{
{
ck
::
host
::
instance
::
gemm_add_add_fastgelu_instances
all_instances
{};
ck
::
host
::
instance
::
gemm_add_add_fastgelu_instances
all_instances
{};
if
(
TransA
and
TransB
)
if
(
TransA
and
TransB
)
...
@@ -65,27 +64,28 @@ std::vector<std::string> Problem::GetInstances(const std::string& arch) const
...
@@ -65,27 +64,28 @@ std::vector<std::string> Problem::GetInstances(const std::string& arch) const
std
::
string
MakeLayoutTuple
(
const
std
::
vector
<
bool
>&
layouts
)
std
::
string
MakeLayoutTuple
(
const
std
::
vector
<
bool
>&
layouts
)
{
{
std
::
string
layout_tuple
=
"ck::Tuple<"
;
std
::
string
layout_tuple
=
"ck::Tuple<"
;
auto
it
=
layouts
.
begin
();
auto
it
=
layouts
.
begin
();
while
(
it
!=
layouts
.
end
())
while
(
it
!=
layouts
.
end
())
{
{
layout_tuple
+=
*
it
?
"ck::tensor_layout::gemm::ColumnMajor"
:
"ck::tensor_layout::gemm::RowMajor"
;
layout_tuple
+=
*
it
?
"ck::tensor_layout::gemm::ColumnMajor"
:
"ck::tensor_layout::gemm::RowMajor"
;
it
=
std
::
next
(
it
);
it
=
std
::
next
(
it
);
if
(
it
!=
layouts
.
end
())
if
(
it
!=
layouts
.
end
())
layout_tuple
+=
", "
;
layout_tuple
+=
", "
;
}
}
return
layout_tuple
+
">"
;
return
layout_tuple
+
">"
;
}
}
std
::
string
MakeTypeTuple
(
const
std
::
vector
<
DataType
>&
types
)
std
::
string
MakeTypeTuple
(
const
std
::
vector
<
DataType
>&
types
)
{
{
std
::
string
type_tuple
=
"ck::Tuple<"
;
std
::
string
type_tuple
=
"ck::Tuple<"
;
auto
it
=
types
.
begin
();
auto
it
=
types
.
begin
();
while
(
it
!=
types
.
end
())
while
(
it
!=
types
.
end
())
{
{
type_tuple
+=
ToString
(
*
it
);
type_tuple
+=
ToString
(
*
it
);
it
=
std
::
next
(
it
);
it
=
std
::
next
(
it
);
if
(
it
!=
types
.
end
())
if
(
it
!=
types
.
end
())
type_tuple
+=
", "
;
type_tuple
+=
", "
;
}
}
return
type_tuple
+
">"
;
return
type_tuple
+
">"
;
...
@@ -97,43 +97,46 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
...
@@ -97,43 +97,46 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
std
::
istringstream
iss
(
template_str
);
std
::
istringstream
iss
(
template_str
);
std
::
vector
<
std
::
string
>
params
(
std
::
istream_iterator
<
std
::
string
>
{
iss
},
std
::
vector
<
std
::
string
>
params
(
std
::
istream_iterator
<
std
::
string
>
{
iss
},
std
::
istream_iterator
<
std
::
string
>
());
std
::
istream_iterator
<
std
::
string
>
());
if
(
ADataType
==
DataType
::
Int8
and
BDataType
==
DataType
::
Int8
)
if
(
ADataType
==
DataType
::
Int8
and
BDataType
==
DataType
::
Int8
)
{
{
// Change CBlockTransfer ScalarPerVector if Ds contains other types
// Change CBlockTransfer ScalarPerVector if Ds contains other types
if
(
std
::
any_of
(
DsDataType
.
begin
(),
DsDataType
.
end
(),
[](
auto
t
)
{
return
t
==
DataType
::
Half
;
}))
if
(
std
::
any_of
(
DsDataType
.
begin
(),
DsDataType
.
end
(),
[](
auto
t
)
{
return
t
==
DataType
::
Half
;
}))
{
{
params
[
params
.
size
()
-
3
]
=
"8"
;
params
[
params
.
size
()
-
3
]
=
"8"
;
}
}
if
(
std
::
any_of
(
DsDataType
.
begin
(),
DsDataType
.
end
(),
[](
auto
t
)
{
return
t
==
DataType
::
Float
;
}))
if
(
std
::
any_of
(
DsDataType
.
begin
(),
DsDataType
.
end
(),
[](
auto
t
)
{
return
t
==
DataType
::
Float
;
}))
{
{
params
[
params
.
size
()
-
3
]
=
"4"
;
params
[
params
.
size
()
-
3
]
=
"4"
;
}
}
}
}
params
[
a_elementwise_op_idx
]
=
AElementOp
;
params
[
a_elementwise_op_idx
]
=
AElementOp
;
params
[
b_elementwise_op_idx
]
=
BElementOp
;
params
[
b_elementwise_op_idx
]
=
BElementOp
;
params
[
ds_layout_idx
]
=
MakeLayoutTuple
(
DsTrans
);
params
[
ds_layout_idx
]
=
MakeLayoutTuple
(
DsTrans
);
params
[
ds_data_type_idx
]
=
MakeTypeTuple
(
DsDataType
);
params
[
ds_data_type_idx
]
=
MakeTypeTuple
(
DsDataType
);
params
[
ds_elementwise_op_idx
]
=
CDEElementOp
;
params
[
ds_elementwise_op_idx
]
=
CDEElementOp
;
params
[
e_data_type_idx
]
=
ToString
(
EDataType
);
params
[
e_data_type_idx
]
=
ToString
(
EDataType
);
auto
block_size_str
=
params
[
block_size_idx
];
auto
block_size_str
=
params
[
block_size_idx
];
auto
m_per_block_str
=
params
[
m_per_block_idx
];
auto
m_per_block_str
=
params
[
m_per_block_idx
];
auto
n_per_block_str
=
params
[
n_per_block_idx
];
auto
n_per_block_str
=
params
[
n_per_block_idx
];
auto
k_per_block_str
=
params
[
k_per_block_idx
];
auto
k_per_block_str
=
params
[
k_per_block_idx
];
const
std
::
size_t
block_size
=
std
::
stoi
(
block_size_str
);
const
std
::
size_t
block_size
=
std
::
stoi
(
block_size_str
);
const
std
::
size_t
m_per_block
=
std
::
stoi
(
m_per_block_str
);
const
std
::
size_t
m_per_block
=
std
::
stoi
(
m_per_block_str
);
const
std
::
size_t
n_per_block
=
std
::
stoi
(
n_per_block_str
);
const
std
::
size_t
n_per_block
=
std
::
stoi
(
n_per_block_str
);
const
std
::
size_t
k_per_block
=
std
::
stoi
(
k_per_block_str
);
const
std
::
size_t
k_per_block
=
std
::
stoi
(
k_per_block_str
);
const
std
::
size_t
grid_size
=
GetGridSize
(
M
,
N
,
m_per_block
,
n_per_block
);
const
std
::
size_t
grid_size
=
GetGridSize
(
M
,
N
,
m_per_block
,
n_per_block
);
params
[
gemm_spec_idx
]
=
GetGemmSpec
(
M
,
N
,
K
,
m_per_block
,
n_per_block
,
k_per_block
);
params
[
gemm_spec_idx
]
=
GetGemmSpec
(
M
,
N
,
K
,
m_per_block
,
n_per_block
,
k_per_block
);
std
::
string
str
=
std
::
accumulate
(
params
.
begin
()
+
1
,
params
.
end
(),
std
::
string
{},
std
::
string
str
=
std
::
accumulate
(
[](
const
std
::
string
&
a
,
const
std
::
string
&
b
)
{
params
.
begin
()
+
1
,
return
a
.
empty
()
?
b
:
a
+
", "
+
b
;
params
.
end
(),
});
std
::
string
{},
[](
const
std
::
string
&
a
,
const
std
::
string
&
b
)
{
return
a
.
empty
()
?
b
:
a
+
", "
+
b
;
});
str
=
params
.
front
()
+
"< "
+
str
+
">"
;
str
=
params
.
front
()
+
"< "
+
str
+
">"
;
return
Solution
{
str
,
block_size
,
grid_size
};
return
Solution
{
str
,
block_size
,
grid_size
};
}
}
...
@@ -146,7 +149,7 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch) const
...
@@ -146,7 +149,7 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch) const
{
{
std
::
vector
<
Solution
>
solutions
;
std
::
vector
<
Solution
>
solutions
;
const
std
::
size_t
num_instances
=
GetInstances
(
arch
).
size
();
const
std
::
size_t
num_instances
=
GetInstances
(
arch
).
size
();
for
(
std
::
size_t
i
=
0
;
i
<
num_instances
;
++
i
)
for
(
std
::
size_t
i
=
0
;
i
<
num_instances
;
++
i
)
{
{
solutions
.
push_back
(
MakeSolution
(
i
,
arch
));
solutions
.
push_back
(
MakeSolution
(
i
,
arch
));
}
}
...
@@ -154,7 +157,6 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch) const
...
@@ -154,7 +157,6 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch) const
return
solutions
;
return
solutions
;
}
}
}
// namespace device_gemm_multiple_d
}
// namespace device_gemm_multiple_d
}
// namespace host
}
// namespace host
}
// namespace ck
}
// namespace ck
test/jit_library/jit_library.cpp
View file @
77190058
...
@@ -3,36 +3,48 @@
...
@@ -3,36 +3,48 @@
bool
test_Problem
()
bool
test_Problem
()
{
{
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
256
,
256
,
256
,
false
,
256
,
true
,
false
,
false
,
true
,
{},
false
,
ck
::
host
::
DataType
::
Half
,
{},
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
{},
ck
::
host
::
DataType
::
Half
,
"ck::tensor_operation::element_wise::Passthrough"
,
{},
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
const
auto
include_header
=
problem
.
GetIncludeHeader
();
const
auto
include_header
=
problem
.
GetIncludeHeader
();
const
auto
solutions
=
problem
.
GetSolutions
(
"gfx90a"
);
const
auto
solutions
=
problem
.
GetSolutions
(
"gfx90a"
);
const
auto
&
solution
=
solutions
.
at
(
0
);
const
auto
&
solution
=
solutions
.
at
(
0
);
const
auto
template_str
=
solution
.
template_str
;
const
auto
template_str
=
solution
.
template_str
;
const
auto
grid_size
=
solution
.
grid_size
;
const
auto
grid_size
=
solution
.
grid_size
;
const
auto
block_size
=
solution
.
block_size
;
const
auto
block_size
=
solution
.
block_size
;
bool
pass
=
true
;
bool
pass
=
true
;
pass
&=
include_header
==
"ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp"
;
pass
&=
include_header
==
"ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp"
;
pass
&=
solutions
.
size
()
==
42
;
pass
&=
solutions
.
size
()
==
42
;
pass
&=
template_str
==
"ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle< ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, ck::Tuple<>, ck::tensor_layout::gemm::RowMajor, ck::half_t, ck::half_t, float, float, ck::Tuple<>, ck::half_t, ck::tensor_operation::element_wise::Passthrough, ck::tensor_operation::element_wise::Passthrough, ck::tensor_operation::element_wise::Passthrough, ck::tensor_operation::device::GemmSpecialization::Default, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, ck::Sequence<4,64,1>, ck::Sequence<1,0,2>, ck::Sequence<1,0,2>, 2, 8, 8, 1, ck::Sequence<4,64,1>, ck::Sequence<1,0,2>, ck::Sequence<1,0,2>, 2, 8, 8, 1, 1, 1, ck::Sequence<1,32,1,8>, 8, ck::LoopScheduler::Default, ck::PipelineVersion::v1>"
;
pass
&=
template_str
==
pass
&=
grid_size
==
2
;
"ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle< "
"ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, ck::Tuple<>, "
"ck::tensor_layout::gemm::RowMajor, ck::half_t, ck::half_t, float, float, ck::Tuple<>, "
"ck::half_t, ck::tensor_operation::element_wise::Passthrough, "
"ck::tensor_operation::element_wise::Passthrough, "
"ck::tensor_operation::element_wise::Passthrough, "
"ck::tensor_operation::device::GemmSpecialization::Default, 1, 256, 256, 128, 32, 8, "
"8, 32, 32, 4, 2, ck::Sequence<4,64,1>, ck::Sequence<1,0,2>, ck::Sequence<1,0,2>, 2, "
"8, 8, 1, ck::Sequence<4,64,1>, ck::Sequence<1,0,2>, ck::Sequence<1,0,2>, 2, 8, 8, 1, "
"1, 1, ck::Sequence<1,32,1,8>, 8, ck::LoopScheduler::Default, ck::PipelineVersion::v1>"
;
pass
&=
grid_size
==
2
;
pass
&=
block_size
==
256
;
pass
&=
block_size
==
256
;
return
pass
;
return
pass
;
}
}
...
@@ -40,46 +52,48 @@ bool test_GetGemmSpec()
...
@@ -40,46 +52,48 @@ bool test_GetGemmSpec()
{
{
bool
pass
=
true
;
bool
pass
=
true
;
{
{
//PadMNK
// PadMNK
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
255
,
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
255
,
255
,
255
,
255
,
false
,
255
,
true
,
false
,
false
,
true
,
{},
false
,
ck
::
host
::
DataType
::
Half
,
{},
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
{},
ck
::
host
::
DataType
::
Half
,
"ck::tensor_operation::element_wise::Passthrough"
,
{},
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
"ck::tensor_operation::element_wise::Passthrough"
,
const
auto
solutions
=
problem
.
GetSolutions
(
"gfx90a"
);
"ck::tensor_operation::element_wise::Passthrough"
};
const
auto
&
solution
=
solutions
.
at
(
0
);
const
auto
solutions
=
problem
.
GetSolutions
(
"gfx90a"
);
const
auto
template_str
=
solution
.
template_str
;
const
auto
&
solution
=
solutions
.
at
(
0
);
const
auto
template_str
=
solution
.
template_str
;
pass
&=
template_str
.
find
(
"GemmSpecialization::MNKPadding"
)
!=
std
::
string
::
npos
;
pass
&=
template_str
.
find
(
"GemmSpecialization::MNKPadding"
)
!=
std
::
string
::
npos
;
}
}
{
{
//Default
// Default
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
256
,
256
,
256
,
false
,
256
,
true
,
false
,
false
,
true
,
{},
false
,
ck
::
host
::
DataType
::
Half
,
{},
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
{},
ck
::
host
::
DataType
::
Half
,
"ck::tensor_operation::element_wise::Passthrough"
,
{},
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
"ck::tensor_operation::element_wise::Passthrough"
,
const
auto
solutions
=
problem
.
GetSolutions
(
"gfx90a"
);
"ck::tensor_operation::element_wise::Passthrough"
};
const
auto
&
solution
=
solutions
.
at
(
0
);
const
auto
solutions
=
problem
.
GetSolutions
(
"gfx90a"
);
const
auto
template_str
=
solution
.
template_str
;
const
auto
&
solution
=
solutions
.
at
(
0
);
const
auto
template_str
=
solution
.
template_str
;
pass
&=
template_str
.
find
(
"GemmSpecialization::Default"
)
!=
std
::
string
::
npos
;
pass
&=
template_str
.
find
(
"GemmSpecialization::Default"
)
!=
std
::
string
::
npos
;
}
}
...
@@ -91,147 +105,155 @@ bool test_GetInstances()
...
@@ -91,147 +105,155 @@ bool test_GetInstances()
{
{
bool
pass
=
true
;
bool
pass
=
true
;
{
{
//Col Col Fp16
// Col Col Fp16
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
256
,
256
,
256
,
true
,
256
,
true
,
true
,
false
,
true
,
{},
false
,
ck
::
host
::
DataType
::
Half
,
{},
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
{},
ck
::
host
::
DataType
::
Half
,
"ck::tensor_operation::element_wise::Passthrough"
,
{},
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
pass
&=
problem
.
GetSolutions
(
"gfx90a"
).
size
()
==
51
;
pass
&=
problem
.
GetSolutions
(
"gfx90a"
).
size
()
==
51
;
}
}
{
{
//Col Row Fp16
// Col Row Fp16
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
256
,
256
,
256
,
true
,
256
,
false
,
true
,
false
,
false
,
{},
false
,
ck
::
host
::
DataType
::
Half
,
{},
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
{},
ck
::
host
::
DataType
::
Half
,
"ck::tensor_operation::element_wise::Passthrough"
,
{},
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
pass
&=
problem
.
GetSolutions
(
"gfx90a"
).
size
()
==
51
;
pass
&=
problem
.
GetSolutions
(
"gfx90a"
).
size
()
==
51
;
}
}
{
{
//Row Col Fp16
// Row Col Fp16
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
256
,
256
,
256
,
false
,
256
,
true
,
false
,
false
,
true
,
{},
false
,
ck
::
host
::
DataType
::
Half
,
{},
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
{},
ck
::
host
::
DataType
::
Half
,
"ck::tensor_operation::element_wise::Passthrough"
,
{},
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
pass
&=
problem
.
GetSolutions
(
"gfx90a"
).
size
()
==
42
;
pass
&=
problem
.
GetSolutions
(
"gfx90a"
).
size
()
==
42
;
}
}
{
{
//Row Row Int8
// Row Row Int8
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
256
,
256
,
256
,
false
,
256
,
false
,
false
,
false
,
false
,
{},
false
,
ck
::
host
::
DataType
::
Int8
,
{},
ck
::
host
::
DataType
::
Int8
,
ck
::
host
::
DataType
::
Int8
,
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Int8
,
{},
ck
::
host
::
DataType
::
Half
,
"ck::tensor_operation::element_wise::Passthrough"
,
{},
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
pass
&=
problem
.
GetSolutions
(
"gfx90a"
).
size
()
==
48
;
pass
&=
problem
.
GetSolutions
(
"gfx90a"
).
size
()
==
48
;
}
}
{
{
//Col Col Int8
// Col Col Int8
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
256
,
256
,
256
,
true
,
256
,
true
,
true
,
false
,
true
,
{},
false
,
ck
::
host
::
DataType
::
Int8
,
{},
ck
::
host
::
DataType
::
Int8
,
ck
::
host
::
DataType
::
Int8
,
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Int8
,
{},
ck
::
host
::
DataType
::
Half
,
"ck::tensor_operation::element_wise::Passthrough"
,
{},
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
pass
&=
problem
.
GetSolutions
(
"gfx90a"
).
size
()
==
48
;
pass
&=
problem
.
GetSolutions
(
"gfx90a"
).
size
()
==
48
;
}
}
{
{
//Col Row Int8
// Col Row Int8
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
256
,
256
,
256
,
true
,
256
,
false
,
true
,
false
,
false
,
{},
false
,
ck
::
host
::
DataType
::
Int8
,
{},
ck
::
host
::
DataType
::
Int8
,
ck
::
host
::
DataType
::
Int8
,
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Int8
,
{},
ck
::
host
::
DataType
::
Half
,
"ck::tensor_operation::element_wise::Passthrough"
,
{},
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
pass
&=
problem
.
GetSolutions
(
"gfx90a"
).
size
()
==
48
;
pass
&=
problem
.
GetSolutions
(
"gfx90a"
).
size
()
==
48
;
}
}
{
{
//Row Col Int8
// Row Col Int8
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
256
,
256
,
256
,
false
,
256
,
true
,
false
,
false
,
true
,
{},
false
,
ck
::
host
::
DataType
::
Int8
,
{},
ck
::
host
::
DataType
::
Int8
,
ck
::
host
::
DataType
::
Int8
,
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Int8
,
{},
ck
::
host
::
DataType
::
Half
,
"ck::tensor_operation::element_wise::Passthrough"
,
{},
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
pass
&=
problem
.
GetSolutions
(
"gfx90a"
).
size
()
==
39
;
pass
&=
problem
.
GetSolutions
(
"gfx90a"
).
size
()
==
39
;
}
}
{
{
//Row Row Int8
// Row Row Int8
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
256
,
256
,
256
,
false
,
256
,
false
,
false
,
false
,
false
,
{},
false
,
ck
::
host
::
DataType
::
Int8
,
{},
ck
::
host
::
DataType
::
Int8
,
ck
::
host
::
DataType
::
Int8
,
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Int8
,
{},
ck
::
host
::
DataType
::
Half
,
"ck::tensor_operation::element_wise::Passthrough"
,
{},
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
pass
&=
problem
.
GetSolutions
(
"gfx90a"
).
size
()
==
48
;
pass
&=
problem
.
GetSolutions
(
"gfx90a"
).
size
()
==
48
;
}
}
...
@@ -243,45 +265,50 @@ bool test_MakeLayoutsTuple()
...
@@ -243,45 +265,50 @@ bool test_MakeLayoutsTuple()
bool
pass
=
true
;
bool
pass
=
true
;
{
{
// Empty Tuple
// Empty Tuple
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
256
,
256
,
256
,
false
,
256
,
false
,
false
,
false
,
false
,
{},
false
,
ck
::
host
::
DataType
::
Half
,
{},
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
{
ck
::
host
::
DataType
::
Half
},
ck
::
host
::
DataType
::
Half
,
"ck::tensor_operation::element_wise::Passthrough"
,
{
ck
::
host
::
DataType
::
Half
},
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
"ck::tensor_operation::element_wise::Passthrough"
,
const
auto
solutions
=
problem
.
GetSolutions
(
"gfx90a"
);
"ck::tensor_operation::element_wise::Passthrough"
};
const
auto
&
solution
=
solutions
.
at
(
0
);
const
auto
solutions
=
problem
.
GetSolutions
(
"gfx90a"
);
const
auto
template_str
=
solution
.
template_str
;
const
auto
&
solution
=
solutions
.
at
(
0
);
const
auto
template_str
=
solution
.
template_str
;
pass
&=
template_str
.
find
(
"ck::Tuple<>"
)
!=
std
::
string
::
npos
;
pass
&=
template_str
.
find
(
"ck::Tuple<>"
)
!=
std
::
string
::
npos
;
}
}
{
{
// RowColRow Tuple
// RowColRow Tuple
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
256
,
256
,
256
,
false
,
256
,
false
,
false
,
false
,
false
,
{
false
,
true
,
false
},
false
,
ck
::
host
::
DataType
::
Half
,
{
false
,
true
,
false
},
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
{
ck
::
host
::
DataType
::
Half
},
ck
::
host
::
DataType
::
Half
,
"ck::tensor_operation::element_wise::Passthrough"
,
{
ck
::
host
::
DataType
::
Half
},
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
"ck::tensor_operation::element_wise::Passthrough"
,
const
auto
solutions
=
problem
.
GetSolutions
(
"gfx90a"
);
"ck::tensor_operation::element_wise::Passthrough"
};
const
auto
&
solution
=
solutions
.
at
(
0
);
const
auto
solutions
=
problem
.
GetSolutions
(
"gfx90a"
);
const
auto
template_str
=
solution
.
template_str
;
const
auto
&
solution
=
solutions
.
at
(
0
);
pass
&=
template_str
.
find
(
"ck::Tuple<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor>"
)
!=
std
::
string
::
npos
;
const
auto
template_str
=
solution
.
template_str
;
pass
&=
template_str
.
find
(
"ck::Tuple<ck::tensor_layout::gemm::RowMajor, "
"ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor>"
)
!=
std
::
string
::
npos
;
}
}
return
pass
;
return
pass
;
...
@@ -292,44 +319,46 @@ bool test_MakeTypeTuple()
...
@@ -292,44 +319,46 @@ bool test_MakeTypeTuple()
bool
pass
=
true
;
bool
pass
=
true
;
{
{
// Empty Tuple
// Empty Tuple
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
256
,
256
,
256
,
false
,
256
,
false
,
false
,
false
,
false
,
{
true
},
false
,
ck
::
host
::
DataType
::
Half
,
{
true
},
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
{},
ck
::
host
::
DataType
::
Half
,
"ck::tensor_operation::element_wise::Passthrough"
,
{},
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
"ck::tensor_operation::element_wise::Passthrough"
,
const
auto
solutions
=
problem
.
GetSolutions
(
"gfx90a"
);
"ck::tensor_operation::element_wise::Passthrough"
};
const
auto
&
solution
=
solutions
.
at
(
0
);
const
auto
solutions
=
problem
.
GetSolutions
(
"gfx90a"
);
const
auto
template_str
=
solution
.
template_str
;
const
auto
&
solution
=
solutions
.
at
(
0
);
const
auto
template_str
=
solution
.
template_str
;
pass
&=
template_str
.
find
(
"ck::Tuple<>"
)
!=
std
::
string
::
npos
;
pass
&=
template_str
.
find
(
"ck::Tuple<>"
)
!=
std
::
string
::
npos
;
}
}
{
{
// Half Int8 Tuple
// Half Int8 Tuple
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
256
,
256
,
256
,
false
,
256
,
false
,
false
,
false
,
false
,
{},
false
,
ck
::
host
::
DataType
::
Half
,
{},
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
{
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Int8
},
ck
::
host
::
DataType
::
Half
,
"ck::tensor_operation::element_wise::Passthrough"
,
{
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Int8
},
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
"ck::tensor_operation::element_wise::Passthrough"
,
const
auto
solutions
=
problem
.
GetSolutions
(
"gfx90a"
);
"ck::tensor_operation::element_wise::Passthrough"
};
const
auto
&
solution
=
solutions
.
at
(
0
);
const
auto
solutions
=
problem
.
GetSolutions
(
"gfx90a"
);
const
auto
template_str
=
solution
.
template_str
;
const
auto
&
solution
=
solutions
.
at
(
0
);
const
auto
template_str
=
solution
.
template_str
;
pass
&=
template_str
.
find
(
"ck::Tuple<ck::half_t, int8_t>"
)
!=
std
::
string
::
npos
;
pass
&=
template_str
.
find
(
"ck::Tuple<ck::half_t, int8_t>"
)
!=
std
::
string
::
npos
;
}
}
return
pass
;
return
pass
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment