gaoqiong / composable_kernel / Commits

Commit 546a764e
Authored Oct 24, 2023 by Artur Wojcik

Merge branch 'migraphx' into uif2-migraphx

Parents: 8da3dfff, 57cdd70b
Changes: 47

Showing 7 changed files with 1259 additions and 32 deletions (+1259 -32)
library/src/jit_library/src/device_batched_gemm_softmax_gemm.cpp  +115  -0
library/src/jit_library/src/device_gemm_multiple_d.cpp            +174  -0
library/src/jit_library/util/file_templates.py                    +177  -0
library/src/jit_library/util/make_instance_strings.py             +367  -0
test/CMakeLists.txt                                                +36  -32
test/jit_library/CMakeLists.txt                                     +4  -0
test/jit_library/jit_library.cpp                                   +386  -0
library/src/jit_library/src/device_batched_gemm_softmax_gemm.cpp (new file, 0 → 100644)
#include "ck/host/device_batched_gemm_softmax_gemm.hpp"
#include "ck/host/common.hpp"
#include "batched_gemm_softmax_gemm_instances.hpp"
#include <algorithm>
#include <unordered_set>
namespace
ck
{
namespace
host
{
namespace
device_batched_gemm_softmax_gemm
{
std
::
string
GetGemmSpec
(
const
std
::
size_t
m
,
const
std
::
size_t
n
,
const
std
::
size_t
k
,
const
std
::
size_t
n1
,
const
std
::
size_t
m_per_block
,
const
std
::
size_t
n_per_block
,
const
std
::
size_t
k_per_block
,
const
std
::
size_t
n1_per_block
)
{
std
::
string
spec
=
""
;
if
(
integer_divide_ceil
(
m
,
m_per_block
)
*
m_per_block
-
m
!=
0
)
spec
+=
"M"
;
if
(
integer_divide_ceil
(
n
,
n_per_block
)
*
n_per_block
-
n
!=
0
)
spec
+=
"N"
;
if
(
integer_divide_ceil
(
k
,
k_per_block
)
*
k_per_block
-
k
!=
0
)
spec
+=
"K"
;
if
(
integer_divide_ceil
(
n1
,
n1_per_block
)
*
n1_per_block
-
n1
!=
0
)
spec
+=
"O"
;
if
(
spec
==
""
)
return
"ck::tensor_operation::device::GemmSpecialization::Default"
;
return
"ck::tensor_operation::device::GemmSpecialization::"
+
spec
+
"Padding"
;
}
std
::
size_t
GetGridSize
(
const
std
::
size_t
m
,
const
std
::
size_t
n
,
const
std
::
size_t
m_per_block
,
const
std
::
size_t
n_per_block
)
{
return
integer_divide_ceil
(
m
,
m_per_block
)
*
integer_divide_ceil
(
n
,
n_per_block
);
}
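The padding rule above is easy to sanity-check. A minimal Python sketch of the same arithmetic (an illustration, not part of the commit; n1 is the second gemm's N dimension and maps to the letter "O"):

    from math import ceil

    def gemm_spec(m, n, k, n1, m_blk, n_blk, k_blk, n1_blk):
        # a dimension contributes its letter when it is not an exact
        # multiple of its per-block tile size
        dims = [(m, m_blk, "M"), (n, n_blk, "N"), (k, k_blk, "K"), (n1, n1_blk, "O")]
        spec = "".join(ch for size, blk, ch in dims if ceil(size / blk) * blk != size)
        base = "ck::tensor_operation::device::GemmSpecialization::"
        return base + (spec + "Padding" if spec else "Default")

    print(gemm_spec(255, 256, 256, 100, 128, 128, 32, 64))  # ...::MOPadding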
const std::unordered_set<std::string>& get_xdlop_archs()
{
    static std::unordered_set<std::string> supported_archs{
        "gfx90a", "gfx908", "gfx940", "gfx942"};
    return supported_archs;
}

std::vector<std::string> Problem::GetInstances(const std::string& arch) const
{
    std::vector<std::string> instances;
    if(get_xdlop_archs().find(arch) != get_xdlop_archs().end())
    {
        ck::host::instance::batched_gemm_softmax_gemm_instances all_instances{};
        instances = all_instances.get_instances();
    }
    return instances;
}

Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
{
    auto template_str = GetInstances(arch).at(idx);
    std::istringstream iss(template_str);
    std::vector<std::string> params(std::istream_iterator<std::string>{iss},
                                    std::istream_iterator<std::string>());

    params[AElementwiseOperation_idx]    = AElementOp;
    params[B0ElementwiseOperation_idx]   = BElementOp;
    params[B1ElementwiseOperation_idx]   = BElementOp;
    params[CElementwiseOperation_idx]    = CElementOp;
    params[Acc0ElementwiseOperation_idx] = AccElementOp;

    auto block_size_str   = params[BlockSize_idx];
    auto m_per_block_str  = params[Gemm01MPerBlock_idx];
    auto n_per_block_str  = params[Gemm0NPerBlock_idx];
    auto k_per_block_str  = params[Gemm0KPerBlock_idx];
    auto n1_per_block_str = params[Gemm1NPerBlock_idx];

    const std::size_t block_size   = std::stoi(block_size_str);
    const std::size_t m_per_block  = std::stoi(m_per_block_str);
    const std::size_t n_per_block  = std::stoi(n_per_block_str);
    const std::size_t k_per_block  = std::stoi(k_per_block_str);
    const std::size_t n1_per_block = std::stoi(n1_per_block_str);
    const std::size_t grid_size    = GetGridSize(M, O, m_per_block, n1_per_block);

    params[GEMMSpecialization_idx] =
        GetGemmSpec(M, N, K, O, m_per_block, n_per_block, k_per_block, n1_per_block);

    std::string str = std::accumulate(params.begin() + 1,
                                      params.end(),
                                      std::string{},
                                      [](const std::string& a, const std::string& b) {
                                          return a.empty() ? b : a + ", " + b;
                                      });
    str = params.front() + "< " + str + ">";
    return Solution{str, block_size, grid_size};
}

std::string Problem::GetIncludeHeader() const
{
    return ck::host::instance::batched_gemm_softmax_gemm_instances{}.get_include_header();
}

std::vector<Solution> Problem::GetSolutions(const std::string& arch) const
{
    std::vector<Solution> solutions;
    const std::size_t num_instances = GetInstances(arch).size();
    for(std::size_t i = 0; i < num_instances; ++i)
    {
        solutions.push_back(MakeSolution(i, arch));
    }
    return solutions;
}

} // namespace device_batched_gemm_softmax_gemm
} // namespace host
} // namespace ck
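MakeSolution above is plain string surgery: it splits an instance template on whitespace, overwrites a few positional slots (the *_idx constants come from the jit_library headers), and re-joins everything as a comma-separated template argument list. A rough Python mirror, with a made-up instance string and slot index:

    tmpl = ("ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle "
            "Row Col PassThrough 256 128 128 32 64")
    params = tmpl.split()  # whitespace-separated positional template parameters
    params[3] = "ck::tensor_operation::element_wise::PassThrough"  # patch one slot
    rebuilt = params[0] + "< " + ", ".join(params[1:]) + ">"
    # -> "ck::...::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, ..., 64>"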
library/src/jit_library/src/device_gemm_multiple_d.cpp (new file, 0 → 100644)
#include "ck/host/device_gemm_multiple_d.hpp"
#include "ck/host/common.hpp"
#include "gemm_add_add_fastgelu_instances.hpp"
#include <algorithm>
#include <unordered_set>
namespace
ck
{
namespace
host
{
namespace
device_gemm_multiple_d
{
std
::
string
GetGemmSpec
(
const
std
::
size_t
m
,
const
std
::
size_t
n
,
const
std
::
size_t
k
,
const
std
::
size_t
m_per_block
,
const
std
::
size_t
n_per_block
,
const
std
::
size_t
k_per_block
)
{
std
::
string
spec
=
""
;
if
(
integer_divide_ceil
(
m
,
m_per_block
)
*
m_per_block
-
m
!=
0
)
spec
+=
"M"
;
if
(
integer_divide_ceil
(
n
,
n_per_block
)
*
n_per_block
-
n
!=
0
)
spec
+=
"N"
;
if
(
integer_divide_ceil
(
k
,
k_per_block
)
*
k_per_block
-
k
!=
0
)
spec
+=
"K"
;
if
(
spec
==
""
)
return
"ck::tensor_operation::device::GemmSpecialization::Default"
;
return
"ck::tensor_operation::device::GemmSpecialization::"
+
spec
+
"Padding"
;
}
std
::
size_t
GetGridSize
(
const
std
::
size_t
m
,
const
std
::
size_t
n
,
const
std
::
size_t
m_per_block
,
const
std
::
size_t
n_per_block
)
{
return
integer_divide_ceil
(
m
,
m_per_block
)
*
integer_divide_ceil
(
n
,
n_per_block
);
}
const
std
::
unordered_set
<
std
::
string
>&
get_xdlop_archs
()
{
static
std
::
unordered_set
<
std
::
string
>
supported_archs
{
"gfx90a"
,
"gfx908"
,
"gfx940"
,
"gfx942"
};
return
supported_archs
;
}
std
::
vector
<
std
::
string
>
Problem
::
GetInstances
(
const
std
::
string
&
arch
)
const
{
std
::
vector
<
std
::
string
>
instances
;
const
bool
quantize
=
ADataType
==
DataType
::
Int8
and
BDataType
==
DataType
::
Int8
;
if
(
get_xdlop_archs
().
find
(
arch
)
!=
get_xdlop_archs
().
end
())
{
ck
::
host
::
instance
::
gemm_add_add_fastgelu_instances
all_instances
{};
if
(
TransA
and
TransB
)
instances
=
all_instances
.
get_col_col_instances
(
quantize
);
else
if
(
TransA
and
not
TransB
)
instances
=
all_instances
.
get_col_row_instances
(
quantize
);
else
if
(
not
TransA
and
not
TransB
)
instances
=
all_instances
.
get_row_row_instances
(
quantize
);
else
instances
=
all_instances
.
get_row_col_instances
(
quantize
);
}
return
instances
;
}
std
::
string
MakeLayoutTuple
(
const
std
::
vector
<
bool
>&
layouts
)
{
std
::
string
layout_tuple
=
"ck::Tuple<"
;
auto
it
=
layouts
.
begin
();
while
(
it
!=
layouts
.
end
())
{
layout_tuple
+=
*
it
?
"ck::tensor_layout::gemm::ColumnMajor"
:
"ck::tensor_layout::gemm::RowMajor"
;
it
=
std
::
next
(
it
);
if
(
it
!=
layouts
.
end
())
layout_tuple
+=
", "
;
}
return
layout_tuple
+
">"
;
}
std
::
string
MakeTypeTuple
(
const
std
::
vector
<
DataType
>&
types
)
{
std
::
string
type_tuple
=
"ck::Tuple<"
;
auto
it
=
types
.
begin
();
while
(
it
!=
types
.
end
())
{
type_tuple
+=
ToString
(
*
it
);
it
=
std
::
next
(
it
);
if
(
it
!=
types
.
end
())
type_tuple
+=
", "
;
}
return
type_tuple
+
">"
;
}
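MakeLayoutTuple and MakeTypeTuple just comma-join one rendered string per element inside "ck::Tuple<...>". A quick Python equivalent of the layout variant (illustration only):

    def make_layout_tuple(trans):
        row = "ck::tensor_layout::gemm::RowMajor"
        col = "ck::tensor_layout::gemm::ColumnMajor"
        return "ck::Tuple<" + ", ".join(col if t else row for t in trans) + ">"

    print(make_layout_tuple([]))             # ck::Tuple<>
    print(make_layout_tuple([False, True]))  # ck::Tuple<...RowMajor, ...ColumnMajor>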
Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
{
    auto template_str = GetInstances(arch).at(idx);
    std::istringstream iss(template_str);
    std::vector<std::string> params(std::istream_iterator<std::string>{iss},
                                    std::istream_iterator<std::string>());

    if(ADataType == DataType::Int8 and BDataType == DataType::Int8)
    {
        // Change CBlockTransfer ScalarPerVector if Ds contains other types
        if(EDataType == DataType::Half or
           std::any_of(DsDataType.begin(), DsDataType.end(), [](auto t) {
               return t == DataType::Half;
           }))
        {
            params[params.size() - 3] = "8";
        }
        if(EDataType == DataType::Float or
           std::any_of(DsDataType.begin(), DsDataType.end(), [](auto t) {
               return t == DataType::Float;
           }))
        {
            params[params.size() - 3] = "4";
        }
        if(EDataType == DataType::Int32 or
           std::any_of(DsDataType.begin(), DsDataType.end(), [](auto t) {
               return t == DataType::Int32;
           }))
        {
            params[params.size() - 3] = "4";
        }
    }

    params[a_elementwise_op_idx]  = AElementOp;
    params[b_elementwise_op_idx]  = BElementOp;
    params[ds_layout_idx]         = MakeLayoutTuple(DsTrans);
    params[ds_data_type_idx]      = MakeTypeTuple(DsDataType);
    params[ds_elementwise_op_idx] = CDEElementOp;
    params[e_data_type_idx]       = ToString(EDataType);

    auto block_size_str  = params[block_size_idx];
    auto m_per_block_str = params[m_per_block_idx];
    auto n_per_block_str = params[n_per_block_idx];
    auto k_per_block_str = params[k_per_block_idx];

    const std::size_t block_size  = std::stoi(block_size_str);
    const std::size_t m_per_block = std::stoi(m_per_block_str);
    const std::size_t n_per_block = std::stoi(n_per_block_str);
    const std::size_t k_per_block = std::stoi(k_per_block_str);
    const std::size_t grid_size   = GetGridSize(M, N, m_per_block, n_per_block);

    params[gemm_spec_idx] = GetGemmSpec(M, N, K, m_per_block, n_per_block, k_per_block);

    std::string str = std::accumulate(params.begin() + 1,
                                      params.end(),
                                      std::string{},
                                      [](const std::string& a, const std::string& b) {
                                          return a.empty() ? b : a + ", " + b;
                                      });
    str = params.front() + "< " + str + ">";

    if(params.back().find("v2") != std::string::npos and K % k_per_block != 0)
        str = "";

    return Solution{str, block_size, grid_size};
}

std::string Problem::GetIncludeHeader() const
{
    return ck::host::instance::gemm_add_add_fastgelu_instances{}.get_include_header();
}

std::vector<Solution> Problem::GetSolutions(const std::string& arch) const
{
    std::vector<Solution> solutions;
    const std::size_t num_instances = GetInstances(arch).size();
    for(std::size_t i = 0; i < num_instances; ++i)
    {
        auto solution = MakeSolution(i, arch);
        if(solution.template_str != "")
            solutions.push_back(solution);
    }
    return solutions;
}

} // namespace device_gemm_multiple_d
} // namespace host
} // namespace ck
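Two details of this file tie directly to the tests further down. GetGridSize launches one workgroup per MPerBlock x NPerBlock output tile, so the first fp16 instance (a 256x128 tile with BlockSize 256) on a 256x256x256 problem gives a grid of 2, which test_Problem asserts below. And MakeSolution returns an empty template_str for v2-pipeline instances whose K is not a multiple of KPerBlock, which GetSolutions then filters out. A one-line check of the grid arithmetic:

    from math import ceil
    print(ceil(256 / 256) * ceil(256 / 128))  # 2, matching test_Problem's grid_size check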
library/src/jit_library/util/file_templates.py (new file, 0 → 100644)
out_file_with_quant = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
#include <cstdlib>
#include <vector>
#include <memory>

namespace ck {{
namespace host {{
namespace instance {{

struct {op_name}_instances
{{
    static inline std::vector<std::string> {col_row_name} =
    {{
{col_row_instances}
    }};
    static inline std::vector<std::string> {col_col_name} =
    {{
{col_col_instances}
    }};
    static inline std::vector<std::string> {row_row_name} =
    {{
{row_row_instances}
    }};
    static inline std::vector<std::string> {row_col_name} =
    {{
{row_col_instances}
    }};
    static inline std::vector<std::string> {int8_col_row_name} =
    {{
{int8_col_row_instances}
    }};
    static inline std::vector<std::string> {int8_col_col_name} =
    {{
{int8_col_col_instances}
    }};
    static inline std::vector<std::string> {int8_row_row_name} =
    {{
{int8_row_row_instances}
    }};
    static inline std::vector<std::string> {int8_row_col_name} =
    {{
{int8_row_col_instances}
    }};

    static auto get_col_row_instances(const bool quantize)
    {{
        return quantize ? {int8_col_row_name} : {col_row_name};
    }}
    static auto get_col_col_instances(const bool quantize)
    {{
        return quantize ? {int8_col_col_name} : {col_col_name};
    }}
    static auto get_row_row_instances(const bool quantize)
    {{
        return quantize ? {int8_row_row_name} : {row_row_name};
    }}
    static auto get_row_col_instances(const bool quantize)
    {{
        return quantize ? {int8_row_col_name} : {row_col_name};
    }}
    static auto get_include_header()
    {{
        return "{include_header}";
    }}
}};

}} // namespace instance
}} // namespace host
}} // namespace ck
"""

out_file_no_quant = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
#include <cstdlib>
#include <vector>
#include <memory>

namespace ck {{
namespace host {{
namespace instance {{

struct {op_name}_instances
{{
    static inline std::vector<std::string> {instances_name} =
    {{
{instances}
    }};

    static auto get_instances()
    {{
        return {instances_name};
    }}
    static auto get_include_header()
    {{
        return "{include_header}";
    }}
}};

}} // namespace instance
}} // namespace host
}} // namespace ck
"""


def get_device_gemm_multiple_d_file(op_name,
                                    col_row_name, col_row_instances,
                                    col_col_name, col_col_instances,
                                    row_row_name, row_row_instances,
                                    row_col_name, row_col_instances,
                                    int8_col_row_name, int8_col_row_instances,
                                    int8_col_col_name, int8_col_col_instances,
                                    int8_row_row_name, int8_row_row_instances,
                                    int8_row_col_name, int8_row_col_instances,
                                    include_header):
    return out_file_with_quant.format(op_name=op_name,
                                      col_row_name=col_row_name,
                                      col_row_instances=col_row_instances,
                                      col_col_name=col_col_name,
                                      col_col_instances=col_col_instances,
                                      row_row_name=row_row_name,
                                      row_row_instances=row_row_instances,
                                      row_col_name=row_col_name,
                                      row_col_instances=row_col_instances,
                                      int8_col_row_name=int8_col_row_name,
                                      int8_col_row_instances=int8_col_row_instances,
                                      int8_col_col_name=int8_col_col_name,
                                      int8_col_col_instances=int8_col_col_instances,
                                      int8_row_row_name=int8_row_row_name,
                                      int8_row_row_instances=int8_row_row_instances,
                                      int8_row_col_name=int8_row_col_name,
                                      int8_row_col_instances=int8_row_col_instances,
                                      include_header=include_header)


def get_device_gemm_softmax_gemm_file(op_name, instances_name, instances, include_header):
    return out_file_no_quant.format(op_name=op_name,
                                    instances_name=instances_name,
                                    instances=instances,
                                    include_header=include_header)
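A sketch of how the no-quant template gets filled in (toy argument values; the real instance strings and include header are produced by make_instance_strings.py below):

    header_text = get_device_gemm_softmax_gemm_file(
        op_name="batched_gemm_softmax_gemm",
        instances_name="device_batched_gemm_softmax_gemm_instances",
        instances='"ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< ... >"',
        include_header="ck/tensor_operation/gpu/device/impl/some_impl_header.hpp",
    )
    print(header_text)  # a complete .hpp defining the batched_gemm_softmax_gemm_instances struct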
library/src/jit_library/util/make_instance_strings.py (new file, 0 → 100644)
import argparse, re, json, os, sys, file_templates


def strip_sequences(str):
    matches = re.findall(r'S<\s*\d+(?:,\s*\d+)*>', str)
    for match in matches:
        str = str.replace(match, match.replace(' ', ''))
    str = str.replace('S<', "ck::Sequence<")
    return str


def remove_commas_and_brackets(string):
    regex_matches = re.findall(r'ck::Sequence<.*?>', string)
    for match in regex_matches:
        string = string.replace(match,
                                match.replace(',', '|').replace('<', '%').replace('>', '$'))
    string = string.replace(',', '').replace('<', '').replace('>', '')
    for match in regex_matches:
        string = string.replace(match.replace(',', '|').replace('<', '%').replace('>', '$'),
                                match)
    return string


def get_int8_instances(src, file, template_name):
    aliases = {
        "Empty_Tuple": "ck::Tuple<>",
        "Row": "ck::tensor_layout::gemm::RowMajor",
        "Col": "ck::tensor_layout::gemm::ColumnMajor",
        "OutElementOp": "PassThrough"
    }
    instances = {
        "row_row": [], "row_col": [], "col_col": [], "col_row": [],
        "row_row_name": [], "row_col_name": [], "col_col_name": [], "col_row_name": []
    }
    path = src + file
    with open(path) as f:
        for line in f:
            if "impl" in line:
                include_header = line.replace("#include \"", "").replace("\"", "").replace("\n", "")
            elif "using" in line:
                if bool(re.search(".*mk.*kn.*", line)):
                    instances["row_row_name"] = re.search("device_gemm.*instance", line).group()
                elif bool(re.search(".*mk.*nk.*", line)):
                    instances["row_col_name"] = re.search("device_gemm.*instance", line).group()
                elif bool(re.search(".*km.*nk.*", line)):
                    instances["col_col_name"] = re.search("device_gemm.*instance", line).group()
                elif bool(re.search(".*km.*kn.*", line)):
                    instances["col_row_name"] = re.search("device_gemm.*instance", line).group()
            elif template_name in line:
                # Turn all whitespace into single spaces
                new_line = " ".join(line.split())
                # Remove whitespace from S<*>
                new_line = strip_sequences(new_line)
                new_line = remove_commas_and_brackets(new_line)
                last_char = "\n"
                if new_line[-1] == ",":
                    last_char = ",\n"
                    new_line = new_line[:-1]
                new_line = ' "ck::tensor_operation::device::' + new_line + '",'
                versions = []
                for key in aliases:
                    new_line = new_line.replace(key, aliases[key])
                versions.append(new_line.replace("GemmPipeline", "ck::PipelineVersion::v1")
                                        .replace("GemmLoopScheduler", "ck::LoopScheduler::Default"))
                versions.append(new_line.replace("GemmPipeline", "ck::PipelineVersion::v1")
                                        .replace("GemmLoopScheduler", "ck::LoopScheduler::Interwave"))
                versions.append(new_line.replace("GemmPipeline", "ck::PipelineVersion::v2")
                                        .replace("GemmLoopScheduler", "ck::LoopScheduler::Default"))
                if "ck::tensor_layout::gemm::RowMajor ck::tensor_layout::gemm::RowMajor" in new_line:
                    instances["row_row"].extend(versions)
                elif "ck::tensor_layout::gemm::RowMajor ck::tensor_layout::gemm::ColumnMajor" in new_line:
                    instances["row_col"].extend(versions)
                elif "ck::tensor_layout::gemm::ColumnMajor ck::tensor_layout::gemm::ColumnMajor" in new_line:
                    instances["col_col"].extend(versions)
                elif "ck::tensor_layout::gemm::ColumnMajor ck::tensor_layout::gemm::RowMajor" in new_line:
                    instances["col_row"].extend(versions)
    instances["row_row"][-1] = instances["row_row"][-1][:-1]
    instances["row_col"][-1] = instances["row_col"][-1][:-1]
    instances["col_col"][-1] = instances["col_col"][-1][:-1]
    instances["col_row"][-1] = instances["col_row"][-1][:-1]
    return instances


def parse_instances(source, out_dir):
    aliases = {
        "F16_F16_Tuple": "ck::Tuple<F16,F16>",
        "Row_Row_Tuple": "ck::Tuple<Row,Row>",
        "Empty_Tuple": "ck::Tuple<>",
        "LoopScheduler": "ck::LoopScheduler",
        "PipelineVersion": "ck::PipelineVersion",
        "Row": "ck::tensor_layout::gemm::RowMajor",
        "Col": "ck::tensor_layout::gemm::ColumnMajor",
        "F16": "ck::half_t",
        "F32": "float",
        "OutElementOp": "PassThrough"
    }
    device_ops = {
        "gemm_add_add_fastgelu": "DeviceGemmMultipleD_Xdl_CShuffle",
        # "batched_gemm_softmax_gemm": "DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle"
    }
    for root_, dirs_, files_ in os.walk(source):
        for dir in dirs_:
            op_name = os.path.split(dir)[-1]
            if op_name not in device_ops:
                continue
            col_row_name = ""
            col_col_name = ""
            row_row_name = ""
            row_col_name = ""
            row_row_instances = []
            col_row_instances = []
            row_col_instances = []
            col_col_instances = []
            for root, dirs, files in os.walk(os.path.join(root_, dir)):
                for file in files:
                    if not file.endswith(".cpp"):
                        continue
                    file_name = os.path.split(file)[-1]
                    is_row_row = bool(re.search(".*mk.*kn.*", file_name))
                    is_col_row = bool(re.search(".*km.*kn.*", file_name))
                    is_row_col = bool(re.search(".*mk.*nk.*", file_name))
                    is_col_col = bool(re.search(".*km.*nk.*", file_name))
                    if is_row_row:
                        row_row_name = file_name[:-4]
                    if is_col_row:
                        col_row_name = file_name[:-4]
                    if is_row_col:
                        row_col_name = file_name[:-4]
                    if is_col_col:
                        col_col_name = file_name[:-4]
                    instances_list = []
                    template_name = device_ops[op_name]
                    include_header = ""
                    with open(os.path.join(root, file)) as f:
                        for line in f:
                            if "impl" in line:
                                include_header = line.replace("#include \"", "").replace("\"", "").replace("\n", "")
                            elif template_name in line:
                                # Turn all whitespace into single spaces
                                new_line = " ".join(line.split())
                                # Remove whitespace from S<*>
                                new_line = strip_sequences(new_line)
                                new_line = remove_commas_and_brackets(new_line)
                                last_char = "\n"
                                if new_line[-1] == ",":
                                    last_char = ",\n"
                                    new_line = new_line[:-1]
                                new_line = ' "ck::tensor_operation::device::' + new_line + '",'
                                for key in aliases:
                                    new_line = new_line.replace(key, aliases[key])
                                instances_list.append(new_line)
                    instances_list[-1] = instances_list[-1][:-1]
                    if is_row_row:
                        row_row_instances = instances_list
                    if is_col_row:
                        col_row_instances = instances_list
                    if is_row_col:
                        row_col_instances = instances_list
                    if is_col_col:
                        col_col_instances = instances_list
            out_file_name = op_name + "_instances.hpp"
            if not os.path.exists(out_dir):
                os.mkdir(out_dir)
            int8_file = "/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_instance.hpp"
            int8_instances = get_int8_instances(source, int8_file, "DeviceGemmMultipleD_Xdl_CShuffle")
            with open(os.path.join(out_dir, out_file_name), "w+") as f:
                f.write(file_templates.get_device_gemm_multiple_d_file(
                    op_name,
                    col_row_name, "\n".join(col_row_instances),
                    col_col_name, "\n".join(col_col_instances),
                    row_row_name, "\n".join(row_row_instances),
                    row_col_name, "\n".join(row_col_instances),
                    int8_instances["col_row_name"], "\n".join(int8_instances["col_row"]),
                    int8_instances["col_col_name"], "\n".join(int8_instances["col_col"]),
                    int8_instances["row_row_name"], "\n".join(int8_instances["row_row"]),
                    int8_instances["row_col_name"], "\n".join(int8_instances["row_col"]),
                    include_header))


def parse_device_gemm_multiple_d_instances(source, out_dir):
    aliases = {
        "F16_F16_Tuple": "ck::Tuple<F16,F16>",
        "Row_Row_Tuple": "ck::Tuple<Row,Row>",
        "Empty_Tuple": "ck::Tuple<>",
        "LoopScheduler": "ck::LoopScheduler",
        "PipelineVersion": "ck::PipelineVersion",
        "Row": "ck::tensor_layout::gemm::RowMajor",
        "Col": "ck::tensor_layout::gemm::ColumnMajor",
        "F16": "ck::half_t",
        "F32": "float",
        "OutElementOp": "PassThrough"
    }
    device_ops = {
        "gemm_add_add_fastgelu": "DeviceGemmMultipleD_Xdl_CShuffle",
        # "batched_gemm_softmax_gemm": "DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle"
    }
    for root_, dirs_, files_ in os.walk(source):
        for dir in dirs_:
            op_name = os.path.split(dir)[-1]
            if op_name not in device_ops:
                continue
            col_row_name = ""
            col_col_name = ""
            row_row_name = ""
            row_col_name = ""
            row_row_instances = []
            col_row_instances = []
            row_col_instances = []
            col_col_instances = []
            for root, dirs, files in os.walk(os.path.join(root_, dir)):
                for file in files:
                    if not file.endswith(".cpp"):
                        continue
                    file_name = os.path.split(file)[-1]
                    is_row_row = bool(re.search(".*mk.*kn.*", file_name))
                    is_col_row = bool(re.search(".*km.*kn.*", file_name))
                    is_row_col = bool(re.search(".*mk.*nk.*", file_name))
                    is_col_col = bool(re.search(".*km.*nk.*", file_name))
                    if is_row_row:
                        row_row_name = file_name[:-4]
                    if is_col_row:
                        col_row_name = file_name[:-4]
                    if is_row_col:
                        row_col_name = file_name[:-4]
                    if is_col_col:
                        col_col_name = file_name[:-4]
                    instances_list = []
                    template_name = device_ops[op_name]
                    include_header = ""
                    with open(os.path.join(root, file)) as f:
                        for line in f:
                            if "impl" in line:
                                include_header = line.replace("#include \"", "").replace("\"", "").replace("\n", "")
                            elif template_name in line:
                                # Turn all whitespace into single spaces
                                new_line = " ".join(line.split())
                                # Remove whitespace from S<*>
                                new_line = strip_sequences(new_line)
                                new_line = remove_commas_and_brackets(new_line)
                                last_char = "\n"
                                if new_line[-1] == ",":
                                    last_char = ",\n"
                                    new_line = new_line[:-1]
                                new_line = ' "ck::tensor_operation::device::' + new_line + '",'
                                for key in aliases:
                                    new_line = new_line.replace(key, aliases[key])
                                instances_list.append(new_line)
                    instances_list[-1] = instances_list[-1][:-1]
                    if is_row_row:
                        row_row_instances = instances_list
                    if is_col_row:
                        col_row_instances = instances_list
                    if is_row_col:
                        row_col_instances = instances_list
                    if is_col_col:
                        col_col_instances = instances_list
            out_file_name = op_name + "_instances.hpp"
            if not os.path.exists(out_dir):
                os.mkdir(out_dir)
            int8_file = "/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_instance.hpp"
            int8_instances = get_int8_instances(source, int8_file, "DeviceGemmMultipleD_Xdl_CShuffle")
            with open(os.path.join(out_dir, out_file_name), "w+") as f:
                f.write(file_templates.get_device_gemm_multiple_d_file(
                    op_name,
                    col_row_name, "\n".join(col_row_instances),
                    col_col_name, "\n".join(col_col_instances),
                    row_row_name, "\n".join(row_row_instances),
                    row_col_name, "\n".join(row_col_instances),
                    int8_instances["col_row_name"], "\n".join(int8_instances["col_row"]),
                    int8_instances["col_col_name"], "\n".join(int8_instances["col_col"]),
                    int8_instances["row_row_name"], "\n".join(int8_instances["row_row"]),
                    int8_instances["row_col_name"], "\n".join(int8_instances["row_col"]),
                    include_header))


def parse_param_names(file):
    param_names = []
    for line in file:
        if bool(re.search(r"\s*//#+", line)):
            names = line.split('|')
            names = [n.strip() for n in names]
            if not param_names:
                param_names = [""] * len(names)
            param_names = [a + b for a, b in zip(param_names, names)]
        elif param_names:
            param_names[0] = line.split('<')[0].strip()
            file.seek(0)
            return param_names[:-1]
    file.seek(0)
    return param_names[:-1]


def parse_device_batched_gemm_softmax_gemm_instances(source, out_dir):
    aliases = {
        "Row": "ck::tensor_layout::gemm::RowMajor",
        "Col": "ck::tensor_layout::gemm::ColumnMajor",
        "F16": "ck::half_t",
        "F32": "float",
        "PassThrough": "ck::tensor_operation::element_wise::PassThrough",
        "Scale": "ck::tensor_operation::element_wise::Scale",
        "GemmPadded": "ck::tensor_operation::device::GemmSpecialization::MNKOPadding",
        "GemmDefault": "ck::tensor_operation::device::GemmSpecialization::Default"
    }
    device_ops = {"batched_gemm_softmax_gemm": "DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle"}
    for root_, dirs_, files_ in os.walk(source):
        for dir in dirs_:
            op_name = os.path.split(dir)[-1]
            if "permute" in op_name or op_name not in device_ops:
                continue
            for root, dirs, files in os.walk(os.path.join(root_, dir)):
                for file in files:
                    if not file.endswith(".cpp"):
                        continue
                    file_name = os.path.split(file)[-1]
                    instances_name = file_name[:-4]
                    instances_list = []
                    template_name = device_ops[op_name]
                    include_header = ""
                    with open(os.path.join(root, file)) as f:
                        param_names = parse_param_names(f)
                        # for i in range(len(param_names)):
                        #     print(f"{i}: {param_names[i]}")
                        for line in f:
                            if "impl" in line:
                                include_header = line.replace("#include \"", "").replace("\"", "").replace("\n", "")
                            elif template_name in line:
                                # Turn all whitespace into single spaces
                                new_line = " ".join(line.split())
                                # Remove whitespace from S<*>
                                new_line = strip_sequences(new_line)
                                new_line = remove_commas_and_brackets(new_line)
                                last_char = "\n"
                                if new_line[-1] == ",":
                                    last_char = ",\n"
                                    new_line = new_line[:-1]
                                new_line = ' "ck::tensor_operation::device::' + new_line + '",'
                                for key in aliases:
                                    new_line = new_line.replace(key, aliases[key])
                                masking = new_line.replace("Masking", "true")
                                no_masking = new_line.replace("Masking", "false")
                                instances_list.append(masking)
                                instances_list.append(no_masking)
            out_file_name = op_name + "_instances.hpp"
            if not os.path.exists(out_dir):
                os.mkdir(out_dir)
            with open(os.path.join(out_dir, out_file_name), "w+") as f:
                f.write(file_templates.get_device_gemm_softmax_gemm_file(
                    op_name, instances_name, "\n".join(instances_list), include_header))


def run(args):
    parse_device_gemm_multiple_d_instances(args[0], args[1])
    parse_device_batched_gemm_softmax_gemm_instances(args[0], args[1])


if __name__ == '__main__':
    run(sys.argv[1:])
\ No newline at end of file
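A minimal trace of the two string helpers at the top of this script, on a made-up input line:

    line = "DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, S< 4, 64, 1>, 8>"
    line = strip_sequences(line)
    # -> "DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Sequence<4,64,1>, 8>"
    line = remove_commas_and_brackets(line)
    # -> "DeviceGemmMultipleD_Xdl_CShuffle Row Col ck::Sequence<4,64,1> 8"

Commas and angle brackets survive only inside ck::Sequence<...> (they are temporarily encoded as | % $ and restored afterwards), so MakeSolution can later split an instance line on whitespace into positional parameters.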
test/CMakeLists.txt

@@ -120,36 +120,40 @@ function(add_gtest_executable TEST_NAME)
This hunk wraps the list of test subdirectories in a new CK_BUILD_JIT_LIB guard: with the JIT library enabled only jit_library is added, otherwise the previous list is built unchanged.

    set(result ${result} PARENT_SCOPE)
endfunction()

if(CK_BUILD_JIT_LIB)
    add_subdirectory(jit_library)
else()
    add_subdirectory(magic_number_division)
    add_subdirectory(space_filling_curve)
    add_subdirectory(conv_util)
    add_subdirectory(reference_conv_fwd)
    add_subdirectory(gemm)
    add_subdirectory(gemm_layernorm)
    add_subdirectory(gemm_split_k)
    add_subdirectory(gemm_reduce)
    add_subdirectory(batched_gemm)
    add_subdirectory(batched_gemm_reduce)
    add_subdirectory(batched_gemm_gemm)
    add_subdirectory(batched_gemm_softmax_gemm)
    add_subdirectory(batched_gemm_softmax_gemm_permute)
    add_subdirectory(grouped_gemm)
    add_subdirectory(reduce)
    add_subdirectory(convnd_fwd)
    add_subdirectory(convnd_bwd_data)
    add_subdirectory(grouped_convnd_fwd)
    add_subdirectory(grouped_convnd_bwd_weight)
    add_subdirectory(block_to_ctile_map)
    add_subdirectory(softmax)
    add_subdirectory(normalization)
    add_subdirectory(data_type)
    add_subdirectory(elementwise_normalization)
    add_subdirectory(batchnorm)
    add_subdirectory(contraction)
    add_subdirectory(pool)
    add_subdirectory(batched_gemm_multi_d)
    add_subdirectory(grouped_convnd_bwd_data)
    add_subdirectory(conv_tensor_rearrange)
    if(GPU_TARGETS MATCHES "gfx11")
        add_subdirectory(wmma_op)
    endif()
endif()
test/jit_library/CMakeLists.txt (new file, 0 → 100644)
add_test_executable(test_jit_library jit_library.cpp)
add_dependencies(test_jit_library jit_library)
target_include_directories(test_jit_library PUBLIC
    $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/../../library/src/jit_library/include>)
target_link_libraries(test_jit_library PRIVATE jit_library ck_headers)
test/jit_library/jit_library.cpp (new file, 0 → 100644)
#include "ck/host/device_gemm_multiple_d.hpp"
#include <iostream>
bool
test_Problem
()
{
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
256
,
256
,
false
,
true
,
false
,
{},
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
{},
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
const
auto
include_header
=
problem
.
GetIncludeHeader
();
const
auto
solutions
=
problem
.
GetSolutions
(
"gfx90a"
);
const
auto
&
solution
=
solutions
.
at
(
0
);
const
auto
template_str
=
solution
.
template_str
;
const
auto
grid_size
=
solution
.
grid_size
;
const
auto
block_size
=
solution
.
block_size
;
bool
pass
=
true
;
pass
&=
include_header
==
"ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp"
;
pass
&=
solutions
.
size
()
==
42
;
pass
&=
template_str
==
"ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle< "
"ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, ck::Tuple<>, "
"ck::tensor_layout::gemm::RowMajor, ck::half_t, ck::half_t, float, float, ck::Tuple<>, "
"ck::half_t, ck::tensor_operation::element_wise::Passthrough, "
"ck::tensor_operation::element_wise::Passthrough, "
"ck::tensor_operation::element_wise::Passthrough, "
"ck::tensor_operation::device::GemmSpecialization::Default, 1, 256, 256, 128, 32, 8, "
"8, 32, 32, 4, 2, ck::Sequence<4,64,1>, ck::Sequence<1,0,2>, ck::Sequence<1,0,2>, 2, "
"8, 8, 1, ck::Sequence<4,64,1>, ck::Sequence<1,0,2>, ck::Sequence<1,0,2>, 2, 8, 8, 1, "
"1, 1, ck::Sequence<1,32,1,8>, 8, ck::LoopScheduler::Default, ck::PipelineVersion::v1>"
;
pass
&=
grid_size
==
2
;
pass
&=
block_size
==
256
;
return
pass
;
}
bool
test_GetGemmSpec
()
{
bool
pass
=
true
;
{
// PadMNK
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
255
,
255
,
255
,
false
,
true
,
false
,
{},
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
{},
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
const
auto
solutions
=
problem
.
GetSolutions
(
"gfx90a"
);
const
auto
&
solution
=
solutions
.
at
(
0
);
const
auto
template_str
=
solution
.
template_str
;
pass
&=
template_str
.
find
(
"GemmSpecialization::MNKPadding"
)
!=
std
::
string
::
npos
;
}
{
// Default
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
256
,
256
,
false
,
true
,
false
,
{},
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
{},
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
const
auto
solutions
=
problem
.
GetSolutions
(
"gfx90a"
);
const
auto
&
solution
=
solutions
.
at
(
0
);
const
auto
template_str
=
solution
.
template_str
;
pass
&=
template_str
.
find
(
"GemmSpecialization::Default"
)
!=
std
::
string
::
npos
;
}
return
pass
;
}
bool
test_GetInstances
()
{
bool
pass
=
true
;
{
// Col Col Fp16
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
256
,
256
,
true
,
true
,
false
,
{},
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
{},
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
pass
&=
problem
.
GetSolutions
(
"gfx90a"
).
size
()
==
51
;
}
{
// Col Row Fp16
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
256
,
256
,
true
,
false
,
false
,
{},
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
{},
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
pass
&=
problem
.
GetSolutions
(
"gfx90a"
).
size
()
==
51
;
}
{
// Row Col Fp16
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
256
,
256
,
false
,
true
,
false
,
{},
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
{},
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
pass
&=
problem
.
GetSolutions
(
"gfx90a"
).
size
()
==
42
;
}
{
// Row Row Int8
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
256
,
256
,
false
,
false
,
false
,
{},
ck
::
host
::
DataType
::
Int8
,
ck
::
host
::
DataType
::
Int8
,
ck
::
host
::
DataType
::
Half
,
{},
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
pass
&=
problem
.
GetSolutions
(
"gfx90a"
).
size
()
==
48
;
}
{
// Col Col Int8
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
256
,
256
,
true
,
true
,
false
,
{},
ck
::
host
::
DataType
::
Int8
,
ck
::
host
::
DataType
::
Int8
,
ck
::
host
::
DataType
::
Half
,
{},
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
pass
&=
problem
.
GetSolutions
(
"gfx90a"
).
size
()
==
48
;
}
{
// Col Row Int8
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
256
,
256
,
true
,
false
,
false
,
{},
ck
::
host
::
DataType
::
Int8
,
ck
::
host
::
DataType
::
Int8
,
ck
::
host
::
DataType
::
Half
,
{},
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
pass
&=
problem
.
GetSolutions
(
"gfx90a"
).
size
()
==
48
;
}
{
// Row Col Int8
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
256
,
256
,
false
,
true
,
false
,
{},
ck
::
host
::
DataType
::
Int8
,
ck
::
host
::
DataType
::
Int8
,
ck
::
host
::
DataType
::
Half
,
{},
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
pass
&=
problem
.
GetSolutions
(
"gfx90a"
).
size
()
==
39
;
}
{
// Row Row Int8
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
256
,
256
,
false
,
false
,
false
,
{},
ck
::
host
::
DataType
::
Int8
,
ck
::
host
::
DataType
::
Int8
,
ck
::
host
::
DataType
::
Half
,
{},
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
pass
&=
problem
.
GetSolutions
(
"gfx90a"
).
size
()
==
48
;
}
return
pass
;
}
bool
test_MakeLayoutsTuple
()
{
bool
pass
=
true
;
{
// Empty Tuple
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
256
,
256
,
false
,
false
,
false
,
{},
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
{
ck
::
host
::
DataType
::
Half
},
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
const
auto
solutions
=
problem
.
GetSolutions
(
"gfx90a"
);
const
auto
&
solution
=
solutions
.
at
(
0
);
const
auto
template_str
=
solution
.
template_str
;
pass
&=
template_str
.
find
(
"ck::Tuple<>"
)
!=
std
::
string
::
npos
;
}
{
// RowColRow Tuple
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
256
,
256
,
false
,
false
,
false
,
{
false
,
true
,
false
},
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
{
ck
::
host
::
DataType
::
Half
},
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
const
auto
solutions
=
problem
.
GetSolutions
(
"gfx90a"
);
const
auto
&
solution
=
solutions
.
at
(
0
);
const
auto
template_str
=
solution
.
template_str
;
pass
&=
template_str
.
find
(
"ck::Tuple<ck::tensor_layout::gemm::RowMajor, "
"ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor>"
)
!=
std
::
string
::
npos
;
}
return
pass
;
}
bool
test_MakeTypeTuple
()
{
bool
pass
=
true
;
{
// Empty Tuple
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
256
,
256
,
false
,
false
,
false
,
{
true
},
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
{},
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
const
auto
solutions
=
problem
.
GetSolutions
(
"gfx90a"
);
const
auto
&
solution
=
solutions
.
at
(
0
);
const
auto
template_str
=
solution
.
template_str
;
pass
&=
template_str
.
find
(
"ck::Tuple<>"
)
!=
std
::
string
::
npos
;
}
{
// Half Int8 Tuple
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
256
,
256
,
256
,
false
,
false
,
false
,
{},
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Half
,
{
ck
::
host
::
DataType
::
Half
,
ck
::
host
::
DataType
::
Int8
},
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
,
"ck::tensor_operation::element_wise::Passthrough"
};
const
auto
solutions
=
problem
.
GetSolutions
(
"gfx90a"
);
const
auto
&
solution
=
solutions
.
at
(
0
);
const
auto
template_str
=
solution
.
template_str
;
pass
&=
template_str
.
find
(
"ck::Tuple<ck::half_t, int8_t>"
)
!=
std
::
string
::
npos
;
}
return
pass
;
}
int
main
()
{
bool
pass
=
true
;
pass
&=
test_Problem
();
pass
&=
test_GetGemmSpec
();
pass
&=
test_GetInstances
();
pass
&=
test_MakeLayoutsTuple
();
pass
&=
test_MakeTypeTuple
();
if
(
pass
)
{
std
::
cout
<<
"Test jit library: Pass"
<<
std
::
endl
;
return
0
;
}
else
{
std
::
cout
<<
"Test jit library: Fail"
<<
std
::
endl
;
return
-
1
;
}
}