Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
77190058
"docs/en/git@developer.sourcefind.cn:OpenDAS/opencompass.git" did not exist on "3871188c89e7841e54f40363a5bb7dc62afa2510"
Commit
77190058
authored
Jun 07, 2023
by
Alan Turner
Browse files
Formatting
parent
421734ae
Changes
6
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
366 additions
and
330 deletions
+366
-330
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
+14
-10
library/src/jit_library/include/ck/host/common.hpp
library/src/jit_library/include/ck/host/common.hpp
+3
-2
library/src/jit_library/include/ck/host/device_gemm_multiple_d.hpp
...rc/jit_library/include/ck/host/device_gemm_multiple_d.hpp
+25
-26
library/src/jit_library/src/common.cpp
library/src/jit_library/src/common.cpp
+7
-6
library/src/jit_library/src/device_gemm_multiple_d.cpp
library/src/jit_library/src/device_gemm_multiple_d.cpp
+39
-37
test/jit_library/jit_library.cpp
test/jit_library/jit_library.cpp
+278
-249
No files found.
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
View file @
77190058
...
@@ -120,23 +120,25 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
...
@@ -120,23 +120,25 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
__host__
__device__
constexpr
BlockToCTileMap_M00_N0_M01Adapt
()
=
default
;
__host__
__device__
constexpr
BlockToCTileMap_M00_N0_M01Adapt
()
=
default
;
__host__
__device__
constexpr
BlockToCTileMap_M00_N0_M01Adapt
(
const
BlockToCTileMap_M00_N0_M01Adapt
&
)
=
__host__
__device__
constexpr
BlockToCTileMap_M00_N0_M01Adapt
(
default
;
const
BlockToCTileMap_M00_N0_M01Adapt
&
)
=
default
;
__host__
__device__
constexpr
BlockToCTileMap_M00_N0_M01Adapt
(
BlockToCTileMap_M00_N0_M01Adapt
&&
)
=
__host__
__device__
constexpr
BlockToCTileMap_M00_N0_M01Adapt
(
default
;
BlockToCTileMap_M00_N0_M01Adapt
&&
)
=
default
;
__host__
__device__
constexpr
BlockToCTileMap_M00_N0_M01Adapt
&
__host__
__device__
constexpr
BlockToCTileMap_M00_N0_M01Adapt
&
operator
=
(
const
BlockToCTileMap_M00_N0_M01Adapt
&
)
=
default
;
operator
=
(
const
BlockToCTileMap_M00_N0_M01Adapt
&
)
=
default
;
__host__
__device__
constexpr
BlockToCTileMap_M00_N0_M01Adapt
&
__host__
__device__
constexpr
BlockToCTileMap_M00_N0_M01Adapt
&
operator
=
(
BlockToCTileMap_M00_N0_M01Adapt
&&
)
=
default
;
operator
=
(
BlockToCTileMap_M00_N0_M01Adapt
&&
)
=
default
;
__host__
__device__
constexpr
BlockToCTileMap_M00_N0_M01Adapt
(
index_t
M
,
index_t
N
,
index_t
M01
=
8
)
__host__
__device__
constexpr
BlockToCTileMap_M00_N0_M01Adapt
(
index_t
M
,
index_t
N
,
index_t
M01
=
8
)
:
M_
(
M
),
N_
(
N
),
M01_
(
M01
)
:
M_
(
M
),
N_
(
N
),
M01_
(
M01
)
{
{
}
}
template
<
typename
CGridDesc_M_N
>
template
<
typename
CGridDesc_M_N
>
__host__
__device__
constexpr
BlockToCTileMap_M00_N0_M01Adapt
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
__host__
index_t
M01
=
8
)
__device__
constexpr
BlockToCTileMap_M00_N0_M01Adapt
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
index_t
M01
=
8
)
:
BlockToCTileMap_M00_N0_M01Adapt
(
:
BlockToCTileMap_M00_N0_M01Adapt
(
c_grid_desc_m_n
.
GetLength
(
I0
),
c_grid_desc_m_n
.
GetLength
(
I1
),
M01
)
c_grid_desc_m_n
.
GetLength
(
I0
),
c_grid_desc_m_n
.
GetLength
(
I1
),
M01
)
{
{
...
@@ -151,13 +153,15 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
...
@@ -151,13 +153,15 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
}
}
template
<
typename
CGridDesc_M_N
>
template
<
typename
CGridDesc_M_N
>
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
{
return
CalculateGridSize
(
c_grid_desc_m_n
.
GetLength
(
I0
),
c_grid_desc_m_n
.
GetLength
(
I1
));
return
CalculateGridSize
(
c_grid_desc_m_n
.
GetLength
(
I0
),
c_grid_desc_m_n
.
GetLength
(
I1
));
}
}
template
<
typename
CGridDesc_M_N
>
template
<
typename
CGridDesc_M_N
>
__host__
__device__
constexpr
bool
CheckValidity
(
const
CGridDesc_M_N
&
/* c_grid_desc_m_n */
)
const
__host__
__device__
constexpr
bool
CheckValidity
(
const
CGridDesc_M_N
&
/* c_grid_desc_m_n */
)
const
{
{
return
true
;
return
true
;
}
}
...
@@ -231,7 +235,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
...
@@ -231,7 +235,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
template
<
typename
CTileIdx
,
typename
CTileDim
>
template
<
typename
CTileIdx
,
typename
CTileDim
>
__host__
__device__
constexpr
bool
ValidCTileIndex
(
const
CTileIdx
&
/* c_tile_idx */
,
__host__
__device__
constexpr
bool
ValidCTileIndex
(
const
CTileIdx
&
/* c_tile_idx */
,
const
CTileDim
&
/* c_tile_dim */
)
const
const
CTileDim
&
/* c_tile_dim */
)
const
{
{
return
true
;
// always valid provided that user gets grid size from CalculateGridSize()
return
true
;
// always valid provided that user gets grid size from CalculateGridSize()
}
}
...
...
library/src/jit_library/include/ck/host/common.hpp
View file @
77190058
...
@@ -17,7 +17,8 @@ struct Solution
...
@@ -17,7 +17,8 @@ struct Solution
std
::
size_t
grid_size
;
std
::
size_t
grid_size
;
};
};
enum
class
DataType
{
enum
class
DataType
{
Half
,
Half
,
Float
,
Float
,
Int8
,
Int8
,
...
@@ -26,7 +27,7 @@ enum class DataType {
...
@@ -26,7 +27,7 @@ enum class DataType {
std
::
string
ToString
(
DataType
dt
);
std
::
string
ToString
(
DataType
dt
);
std
::
unordered_map
<
std
::
string
,
std
::
pair
<
const
char
*
,
const
char
*>>
GetHeaders
();
std
::
unordered_map
<
std
::
string
,
std
::
pair
<
const
char
*
,
const
char
*>>
GetHeaders
();
std
::
size_t
integer_divide_ceil
(
std
::
size_t
x
,
std
::
size_t
y
);
std
::
size_t
integer_divide_ceil
(
std
::
size_t
x
,
std
::
size_t
y
);
...
...
library/src/jit_library/include/ck/host/device_gemm_multiple_d.hpp
View file @
77190058
...
@@ -11,45 +11,44 @@
...
@@ -11,45 +11,44 @@
#include <numeric>
#include <numeric>
#include "ck/host/common.hpp"
#include "ck/host/common.hpp"
namespace
ck
{
namespace
ck
{
namespace
host
{
namespace
host
{
namespace
device_gemm_multiple_d
{
namespace
device_gemm_multiple_d
{
struct
Problem
struct
Problem
{
{
std
::
size_t
M
=
0
;
std
::
size_t
M
=
0
;
std
::
size_t
N
=
0
;
std
::
size_t
N
=
0
;
std
::
size_t
K
=
0
;
std
::
size_t
K
=
0
;
bool
TransA
=
false
;
bool
TransA
=
false
;
bool
TransB
=
false
;
bool
TransB
=
false
;
bool
TransE
=
false
;
bool
TransE
=
false
;
std
::
vector
<
bool
>
DsTrans
=
{};
std
::
vector
<
bool
>
DsTrans
=
{};
DataType
ADataType
=
DataType
::
Half
;
DataType
ADataType
=
DataType
::
Half
;
DataType
BDataType
=
DataType
::
Half
;
DataType
BDataType
=
DataType
::
Half
;
DataType
EDataType
=
DataType
::
Half
;
DataType
EDataType
=
DataType
::
Half
;
std
::
vector
<
DataType
>
DsDataType
=
{};
std
::
vector
<
DataType
>
DsDataType
=
{};
std
::
string
AElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
AElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
BElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
BElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
CDEElementOp
=
"ck::Tuple<>"
;
std
::
string
CDEElementOp
=
"ck::Tuple<>"
;
static
const
std
::
size_t
ds_layout_idx
=
3
;
static
const
std
::
size_t
ds_layout_idx
=
3
;
static
const
std
::
size_t
ds_data_type_idx
=
9
;
static
const
std
::
size_t
ds_data_type_idx
=
9
;
static
const
std
::
size_t
e_data_type_idx
=
10
;
static
const
std
::
size_t
e_data_type_idx
=
10
;
static
const
std
::
size_t
a_elementwise_op_idx
=
11
;
static
const
std
::
size_t
a_elementwise_op_idx
=
11
;
static
const
std
::
size_t
b_elementwise_op_idx
=
12
;
static
const
std
::
size_t
b_elementwise_op_idx
=
12
;
static
const
std
::
size_t
ds_elementwise_op_idx
=
13
;
static
const
std
::
size_t
ds_elementwise_op_idx
=
13
;
static
const
std
::
size_t
gemm_spec_idx
=
14
;
static
const
std
::
size_t
gemm_spec_idx
=
14
;
static
const
std
::
size_t
block_size_idx
=
16
;
static
const
std
::
size_t
block_size_idx
=
16
;
static
const
std
::
size_t
m_per_block_idx
=
17
;
static
const
std
::
size_t
m_per_block_idx
=
17
;
static
const
std
::
size_t
n_per_block_idx
=
18
;
static
const
std
::
size_t
n_per_block_idx
=
18
;
static
const
std
::
size_t
k_per_block_idx
=
19
;
static
const
std
::
size_t
k_per_block_idx
=
19
;
std
::
string
GetIncludeHeader
()
const
;
std
::
string
GetIncludeHeader
()
const
;
std
::
vector
<
Solution
>
GetSolutions
(
const
std
::
string
&
arch
)
const
;
std
::
vector
<
Solution
>
GetSolutions
(
const
std
::
string
&
arch
)
const
;
private:
private:
std
::
vector
<
std
::
string
>
GetInstances
(
const
std
::
string
&
arch
)
const
;
std
::
vector
<
std
::
string
>
GetInstances
(
const
std
::
string
&
arch
)
const
;
Solution
MakeSolution
(
std
::
size_t
idx
,
const
std
::
string
&
arch
)
const
;
Solution
MakeSolution
(
std
::
size_t
idx
,
const
std
::
string
&
arch
)
const
;
...
...
library/src/jit_library/src/common.cpp
View file @
77190058
...
@@ -8,16 +8,17 @@ namespace host {
...
@@ -8,16 +8,17 @@ namespace host {
std
::
string
ToString
(
DataType
dt
)
std
::
string
ToString
(
DataType
dt
)
{
{
switch
(
dt
)
{
switch
(
dt
)
case
DataType
::
Float
:
return
"float"
;
{
case
DataType
::
Half
:
return
"ck::half_t"
;
case
DataType
::
Float
:
return
"float"
;
case
DataType
::
Int8
:
return
"int8_t"
;
case
DataType
::
Half
:
return
"ck::half_t"
;
case
DataType
::
Int32
:
return
"int32_t"
;
case
DataType
::
Int8
:
return
"int8_t"
;
case
DataType
::
Int32
:
return
"int32_t"
;
}
}
throw
std
::
runtime_error
(
"Incorrect data type"
);
throw
std
::
runtime_error
(
"Incorrect data type"
);
}
}
std
::
unordered_map
<
std
::
string
,
std
::
pair
<
const
char
*
,
const
char
*>>
GetHeaders
()
std
::
unordered_map
<
std
::
string
,
std
::
pair
<
const
char
*
,
const
char
*>>
GetHeaders
()
{
{
return
ck_headers
();
return
ck_headers
();
}
}
...
...
library/src/jit_library/src/device_gemm_multiple_d.cpp
View file @
77190058
...
@@ -8,12 +8,12 @@ namespace ck {
...
@@ -8,12 +8,12 @@ namespace ck {
namespace
host
{
namespace
host
{
namespace
device_gemm_multiple_d
{
namespace
device_gemm_multiple_d
{
std
::
string
GetGemmSpec
(
const
std
::
size_t
m
,
std
::
string
GetGemmSpec
(
const
std
::
size_t
m
,
const
std
::
size_t
n
,
const
std
::
size_t
n
,
const
std
::
size_t
k
,
const
std
::
size_t
k
,
const
std
::
size_t
m_per_block
,
const
std
::
size_t
m_per_block
,
const
std
::
size_t
n_per_block
,
const
std
::
size_t
n_per_block
,
const
std
::
size_t
k_per_block
)
const
std
::
size_t
k_per_block
)
{
{
std
::
string
spec
=
""
;
std
::
string
spec
=
""
;
if
(
integer_divide_ceil
(
m
,
m_per_block
)
*
m_per_block
-
m
!=
0
)
if
(
integer_divide_ceil
(
m
,
m_per_block
)
*
m_per_block
-
m
!=
0
)
...
@@ -28,13 +28,12 @@ std::string GetGemmSpec(const std::size_t m,
...
@@ -28,13 +28,12 @@ std::string GetGemmSpec(const std::size_t m,
return
"ck::tensor_operation::device::GemmSpecialization::"
+
spec
+
"Padding"
;
return
"ck::tensor_operation::device::GemmSpecialization::"
+
spec
+
"Padding"
;
}
}
std
::
size_t
GetGridSize
(
const
std
::
size_t
m
,
std
::
size_t
GetGridSize
(
const
std
::
size_t
m
,
const
std
::
size_t
n
,
const
std
::
size_t
n
,
const
std
::
size_t
m_per_block
,
const
std
::
size_t
m_per_block
,
const
std
::
size_t
n_per_block
)
const
std
::
size_t
n_per_block
)
{
{
return
integer_divide_ceil
(
m
,
m_per_block
)
*
return
integer_divide_ceil
(
m
,
m_per_block
)
*
integer_divide_ceil
(
n
,
n_per_block
);
integer_divide_ceil
(
n
,
n_per_block
);
}
}
const
std
::
unordered_set
<
std
::
string
>&
get_xdlop_archs
()
const
std
::
unordered_set
<
std
::
string
>&
get_xdlop_archs
()
...
@@ -47,7 +46,7 @@ std::vector<std::string> Problem::GetInstances(const std::string& arch) const
...
@@ -47,7 +46,7 @@ std::vector<std::string> Problem::GetInstances(const std::string& arch) const
{
{
std
::
vector
<
std
::
string
>
instances
;
std
::
vector
<
std
::
string
>
instances
;
const
bool
quantize
=
ADataType
==
DataType
::
Int8
and
BDataType
==
DataType
::
Int8
;
const
bool
quantize
=
ADataType
==
DataType
::
Int8
and
BDataType
==
DataType
::
Int8
;
if
(
get_xdlop_archs
().
find
(
arch
)
!=
get_xdlop_archs
().
end
())
if
(
get_xdlop_archs
().
find
(
arch
)
!=
get_xdlop_archs
().
end
())
{
{
ck
::
host
::
instance
::
gemm_add_add_fastgelu_instances
all_instances
{};
ck
::
host
::
instance
::
gemm_add_add_fastgelu_instances
all_instances
{};
if
(
TransA
and
TransB
)
if
(
TransA
and
TransB
)
...
@@ -65,27 +64,28 @@ std::vector<std::string> Problem::GetInstances(const std::string& arch) const
...
@@ -65,27 +64,28 @@ std::vector<std::string> Problem::GetInstances(const std::string& arch) const
std
::
string
MakeLayoutTuple
(
const
std
::
vector
<
bool
>&
layouts
)
std
::
string
MakeLayoutTuple
(
const
std
::
vector
<
bool
>&
layouts
)
{
{
std
::
string
layout_tuple
=
"ck::Tuple<"
;
std
::
string
layout_tuple
=
"ck::Tuple<"
;
auto
it
=
layouts
.
begin
();
auto
it
=
layouts
.
begin
();
while
(
it
!=
layouts
.
end
())
while
(
it
!=
layouts
.
end
())
{
{
layout_tuple
+=
*
it
?
"ck::tensor_layout::gemm::ColumnMajor"
:
"ck::tensor_layout::gemm::RowMajor"
;
layout_tuple
+=
*
it
?
"ck::tensor_layout::gemm::ColumnMajor"
:
"ck::tensor_layout::gemm::RowMajor"
;
it
=
std
::
next
(
it
);
it
=
std
::
next
(
it
);
if
(
it
!=
layouts
.
end
())
if
(
it
!=
layouts
.
end
())
layout_tuple
+=
", "
;
layout_tuple
+=
", "
;
}
}
return
layout_tuple
+
">"
;
return
layout_tuple
+
">"
;
}
}
std
::
string
MakeTypeTuple
(
const
std
::
vector
<
DataType
>&
types
)
std
::
string
MakeTypeTuple
(
const
std
::
vector
<
DataType
>&
types
)
{
{
std
::
string
type_tuple
=
"ck::Tuple<"
;
std
::
string
type_tuple
=
"ck::Tuple<"
;
auto
it
=
types
.
begin
();
auto
it
=
types
.
begin
();
while
(
it
!=
types
.
end
())
while
(
it
!=
types
.
end
())
{
{
type_tuple
+=
ToString
(
*
it
);
type_tuple
+=
ToString
(
*
it
);
it
=
std
::
next
(
it
);
it
=
std
::
next
(
it
);
if
(
it
!=
types
.
end
())
if
(
it
!=
types
.
end
())
type_tuple
+=
", "
;
type_tuple
+=
", "
;
}
}
return
type_tuple
+
">"
;
return
type_tuple
+
">"
;
...
@@ -97,43 +97,46 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
...
@@ -97,43 +97,46 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
std
::
istringstream
iss
(
template_str
);
std
::
istringstream
iss
(
template_str
);
std
::
vector
<
std
::
string
>
params
(
std
::
istream_iterator
<
std
::
string
>
{
iss
},
std
::
vector
<
std
::
string
>
params
(
std
::
istream_iterator
<
std
::
string
>
{
iss
},
std
::
istream_iterator
<
std
::
string
>
());
std
::
istream_iterator
<
std
::
string
>
());
if
(
ADataType
==
DataType
::
Int8
and
BDataType
==
DataType
::
Int8
)
if
(
ADataType
==
DataType
::
Int8
and
BDataType
==
DataType
::
Int8
)
{
{
// Change CBlockTransfer ScalarPerVector if Ds contains other types
// Change CBlockTransfer ScalarPerVector if Ds contains other types
if
(
std
::
any_of
(
DsDataType
.
begin
(),
DsDataType
.
end
(),
[](
auto
t
)
{
return
t
==
DataType
::
Half
;
}))
if
(
std
::
any_of
(
DsDataType
.
begin
(),
DsDataType
.
end
(),
[](
auto
t
)
{
return
t
==
DataType
::
Half
;
}))
{
{
params
[
params
.
size
()
-
3
]
=
"8"
;
params
[
params
.
size
()
-
3
]
=
"8"
;
}
}
if
(
std
::
any_of
(
DsDataType
.
begin
(),
DsDataType
.
end
(),
[](
auto
t
)
{
return
t
==
DataType
::
Float
;
}))
if
(
std
::
any_of
(
DsDataType
.
begin
(),
DsDataType
.
end
(),
[](
auto
t
)
{
return
t
==
DataType
::
Float
;
}))
{
{
params
[
params
.
size
()
-
3
]
=
"4"
;
params
[
params
.
size
()
-
3
]
=
"4"
;
}
}
}
}
params
[
a_elementwise_op_idx
]
=
AElementOp
;
params
[
a_elementwise_op_idx
]
=
AElementOp
;
params
[
b_elementwise_op_idx
]
=
BElementOp
;
params
[
b_elementwise_op_idx
]
=
BElementOp
;
params
[
ds_layout_idx
]
=
MakeLayoutTuple
(
DsTrans
);
params
[
ds_layout_idx
]
=
MakeLayoutTuple
(
DsTrans
);
params
[
ds_data_type_idx
]
=
MakeTypeTuple
(
DsDataType
);
params
[
ds_data_type_idx
]
=
MakeTypeTuple
(
DsDataType
);
params
[
ds_elementwise_op_idx
]
=
CDEElementOp
;
params
[
ds_elementwise_op_idx
]
=
CDEElementOp
;
params
[
e_data_type_idx
]
=
ToString
(
EDataType
);
params
[
e_data_type_idx
]
=
ToString
(
EDataType
);
auto
block_size_str
=
params
[
block_size_idx
];
auto
block_size_str
=
params
[
block_size_idx
];
auto
m_per_block_str
=
params
[
m_per_block_idx
];
auto
m_per_block_str
=
params
[
m_per_block_idx
];
auto
n_per_block_str
=
params
[
n_per_block_idx
];
auto
n_per_block_str
=
params
[
n_per_block_idx
];
auto
k_per_block_str
=
params
[
k_per_block_idx
];
auto
k_per_block_str
=
params
[
k_per_block_idx
];
const
std
::
size_t
block_size
=
std
::
stoi
(
block_size_str
);
const
std
::
size_t
block_size
=
std
::
stoi
(
block_size_str
);
const
std
::
size_t
m_per_block
=
std
::
stoi
(
m_per_block_str
);
const
std
::
size_t
m_per_block
=
std
::
stoi
(
m_per_block_str
);
const
std
::
size_t
n_per_block
=
std
::
stoi
(
n_per_block_str
);
const
std
::
size_t
n_per_block
=
std
::
stoi
(
n_per_block_str
);
const
std
::
size_t
k_per_block
=
std
::
stoi
(
k_per_block_str
);
const
std
::
size_t
k_per_block
=
std
::
stoi
(
k_per_block_str
);
const
std
::
size_t
grid_size
=
GetGridSize
(
M
,
N
,
m_per_block
,
n_per_block
);
const
std
::
size_t
grid_size
=
GetGridSize
(
M
,
N
,
m_per_block
,
n_per_block
);
params
[
gemm_spec_idx
]
=
GetGemmSpec
(
M
,
N
,
K
,
m_per_block
,
n_per_block
,
k_per_block
);
params
[
gemm_spec_idx
]
=
GetGemmSpec
(
M
,
N
,
K
,
m_per_block
,
n_per_block
,
k_per_block
);
std
::
string
str
=
std
::
accumulate
(
params
.
begin
()
+
1
,
params
.
end
(),
std
::
string
{},
std
::
string
str
=
std
::
accumulate
(
[](
const
std
::
string
&
a
,
const
std
::
string
&
b
)
{
params
.
begin
()
+
1
,
return
a
.
empty
()
?
b
:
a
+
", "
+
b
;
params
.
end
(),
});
std
::
string
{},
[](
const
std
::
string
&
a
,
const
std
::
string
&
b
)
{
return
a
.
empty
()
?
b
:
a
+
", "
+
b
;
});
str
=
params
.
front
()
+
"< "
+
str
+
">"
;
str
=
params
.
front
()
+
"< "
+
str
+
">"
;
return
Solution
{
str
,
block_size
,
grid_size
};
return
Solution
{
str
,
block_size
,
grid_size
};
}
}
...
@@ -146,7 +149,7 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch) const
...
@@ -146,7 +149,7 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch) const
{
{
std
::
vector
<
Solution
>
solutions
;
std
::
vector
<
Solution
>
solutions
;
const
std
::
size_t
num_instances
=
GetInstances
(
arch
).
size
();
const
std
::
size_t
num_instances
=
GetInstances
(
arch
).
size
();
for
(
std
::
size_t
i
=
0
;
i
<
num_instances
;
++
i
)
for
(
std
::
size_t
i
=
0
;
i
<
num_instances
;
++
i
)
{
{
solutions
.
push_back
(
MakeSolution
(
i
,
arch
));
solutions
.
push_back
(
MakeSolution
(
i
,
arch
));
}
}
...
@@ -154,7 +157,6 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch) const
...
@@ -154,7 +157,6 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch) const
return
solutions
;
return
solutions
;
}
}
}
// namespace device_gemm_multiple_d
}
// namespace device_gemm_multiple_d
}
// namespace host
}
// namespace host
}
// namespace ck
}
// namespace ck
test/jit_library/jit_library.cpp
View file @
77190058
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment