Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
fcca3307
Commit
fcca3307
authored
Jun 01, 2023
by
Alan Turner
Browse files
Merge remote-tracking branch 'origin/migx-jit-lib2' into migx-jit-lib
parents
7295e38d
9bf51c4c
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
311 additions
and
252 deletions
+311
-252
cmake/Embed.cmake
cmake/Embed.cmake
+4
-4
cmake/EnableCompilerWarnings.cmake
cmake/EnableCompilerWarnings.cmake
+0
-1
library/src/jit_library/CMakeLists.txt
library/src/jit_library/CMakeLists.txt
+15
-17
library/src/jit_library/include/ck/host/common.hpp
library/src/jit_library/include/ck/host/common.hpp
+34
-0
library/src/jit_library/include/ck/host/device_gemm_multiple_d.hpp
...rc/jit_library/include/ck/host/device_gemm_multiple_d.hpp
+60
-0
library/src/jit_library/include/device_gemm_multiple_d.hpp
library/src/jit_library/include/device_gemm_multiple_d.hpp
+0
-219
library/src/jit_library/src/common.cpp
library/src/jit_library/src/common.cpp
+30
-0
library/src/jit_library/src/device_gemm_multiple_d.cpp
library/src/jit_library/src/device_gemm_multiple_d.cpp
+160
-0
library/src/jit_library/util/make_instance_strings.py
library/src/jit_library/util/make_instance_strings.py
+8
-11
No files found.
cmake/Embed.cmake
View file @
fcca3307
...
...
@@ -27,7 +27,7 @@ find_program(EMBED_OBJCOPY objcopy)
function
(
generate_embed_source EMBED_NAME
)
set
(
options
)
set
(
oneValueArgs SRC HEADER RELATIVE
)
set
(
multiValueArgs OBJECTS SYMBOLS
)
set
(
multiValueArgs OBJECTS SYMBOLS
FILES
)
cmake_parse_arguments
(
PARSE
"
${
options
}
"
"
${
oneValueArgs
}
"
"
${
multiValueArgs
}
"
${
ARGN
}
)
...
...
@@ -44,6 +44,7 @@ function(generate_embed_source EMBED_NAME)
foreach
(
idx RANGE
${
LEN
}
)
list
(
GET PARSE_SYMBOLS
${
idx
}
SYMBOL
)
list
(
GET PARSE_OBJECTS
${
idx
}
OBJECT
)
list
(
GET PARSE_FILES
${
idx
}
FILE
)
set
(
START_SYMBOL
"_binary_
${
SYMBOL
}
_start"
)
set
(
END_SYMBOL
"_binary_
${
SYMBOL
}
_end"
)
string
(
APPEND EXTERNS
"
...
...
@@ -52,8 +53,7 @@ function(generate_embed_source EMBED_NAME)
"
)
file
(
RELATIVE_PATH BASE_NAME
${
PARSE_RELATIVE
}
"
${
OBJECT
}
"
)
string
(
REGEX REPLACE
".[A-Za-z0-9_]$"
""
BASE_NAME
${
BASE_NAME
}
)
file
(
RELATIVE_PATH BASE_NAME
${
PARSE_RELATIVE
}
"
${
FILE
}
"
)
string
(
APPEND INIT_KERNELS
"
{
\"
${
BASE_NAME
}
\"
, {
${
START_SYMBOL
}
,
${
END_SYMBOL
}
} },
...
...
@@ -121,7 +121,7 @@ function(add_embed_library EMBED_NAME)
list
(
APPEND SYMBOLS
${
OUTPUT_SYMBOL
}
)
endforeach
()
message
(
STATUS
"Generating embedding library
${
EMBED_NAME
}
"
)
generate_embed_source
(
${
EMBED_NAME
}
SRC
${
SRC_FILE
}
HEADER
${
HEADER_FILE
}
OBJECTS
${
OUTPUT_FILES
}
SYMBOLS
${
SYMBOLS
}
RELATIVE
${
PARSE_RELATIVE
}
)
generate_embed_source
(
${
EMBED_NAME
}
SRC
${
SRC_FILE
}
HEADER
${
HEADER_FILE
}
OBJECTS
${
OUTPUT_FILES
}
SYMBOLS
${
SYMBOLS
}
RELATIVE
${
PARSE_RELATIVE
}
FILES
${
PARSE_UNPARSED_ARGUMENTS
}
)
add_library
(
${
EMBED_NAME
}
STATIC
${
OUTPUT_FILES
}
"
${
SRC_FILE
}
"
)
target_include_directories
(
${
EMBED_NAME
}
PUBLIC
"$<BUILD_INTERFACE:
${
EMBED_DIR
}
/include>"
)
target_compile_options
(
${
EMBED_NAME
}
PRIVATE -Wno-reserved-identifier
)
...
...
cmake/EnableCompilerWarnings.cmake
View file @
fcca3307
...
...
@@ -66,7 +66,6 @@ else()
-Wunreachable-code
-Wunused
-Wno-reserved-identifier
-Werror
-Wsign-compare
-Wno-extra-semi-stmt
)
...
...
library/src/jit_library/CMakeLists.txt
View file @
fcca3307
include
(
Embed
)
file
(
GLOB_RECURSE KERNEL_FILES
${
CONFIGURE_DEPENDS
}
file
(
GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS
${
PROJECT_SOURCE_DIR
}
/include/ck/*.hpp
)
message
(
STATUS
"KERNEL_FILES:
${
KERNEL_FILES
}
"
)
add_embed_library
(
ck_headers
${
KERNEL_FILES
}
RELATIVE
${
PROJECT_SOURCE_DIR
}
/build/include
)
message
(
STATUS
"RELATIVE:
${
PROJECT_SOURCE_DIR
}
/include"
)
add_embed_library
(
ck_headers
${
KERNEL_FILES
}
RELATIVE
${
PROJECT_SOURCE_DIR
}
/include
)
execute_process
(
COMMAND python3
${
CMAKE_CURRENT_SOURCE_DIR
}
/util/make_instance_strings.py
COMMAND python3
${
CMAKE_CURRENT_SOURCE_DIR
}
/util/make_instance_strings.py
${
PROJECT_SOURCE_DIR
}
/library/src/tensor_operation_instance/gpu
${
CMAKE_CURRENT_BINARY_DIR
}
/solution_instances
WORKING_DIRECTORY
${
CMAKE_CURRENT_SOURCE_DIR
}
/../tensor_operation_instance/gpu/
)
set
(
JIT_LIB_SOURCE
${
CMAKE_CURRENT_SOURCE_DIR
}
/include/device_gemm_multiple_d.h
pp
add_library
(
jit_library STATIC
src/device_gemm_multiple_d.cpp
src/common.c
pp
)
add_library
(
jit_library STATIC
${
JIT_LIB_SOURCE
}
)
add_library
(
composable_kernel::jit_library ALIAS jit_library
)
set_target_properties
(
jit_library PROPERTIES LINKER_LANGUAGE CXX
)
target_include_directories
(
jit_library P
UBLIC
target_include_directories
(
jit_library P
RIVATE
$<BUILD_INTERFACE:
${
CMAKE_CURRENT_SOURCE_DIR
}
/include>
$<BUILD_INTERFACE:
${
PROJECT_SOURCE_DIR
}
/library/src/jit_library/solution_instances>
$<BUILD_INTERFACE:
${
CMAKE_CURRENT_BINARY_DIR
}
/solution_instances>
$<BUILD_INTERFACE:
${
CMAKE_CURRENT_BINARY_DIR
}
/embed/ck_headers/include>
)
target_link_libraries
(
jit_library PRIVATE ck_headers
)
...
...
@@ -30,14 +34,8 @@ rocm_install(
EXPORT jit_libraryTargets
)
set
(
INCLUDE_DIRS
${
PROJECT_SOURCE_DIR
}
/include/ck/
${
PROJECT_SOURCE_DIR
}
/library/src/jit_library/include
${
PROJECT_SOURCE_DIR
}
/library/src/jit_library/solution_instances
${
CMAKE_CURRENT_BINARY_DIR
}
/embed/ck_headers/include
)
rocm_install
(
DIRECTORY
${
INCLUDE_DIRS
}
DESTINATION
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck
)
rocm_install
(
DIRECTORY include/ck DESTINATION
${
CMAKE_INSTALL_INCLUDEDIR
}
)
rocm_install
(
DIRECTORY
${
PROJECT_SOURCE_DIR
}
/include/ck DESTINATION
${
CMAKE_INSTALL_INCLUDEDIR
}
)
rocm_install
(
EXPORT jit_libraryTargets
...
...
library/src/jit_library/include/ck/host/common.hpp
0 → 100644
View file @
fcca3307
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include <utility>
#include <unordered_map>
namespace ck {
namespace host {

// One generated kernel candidate.
// template_str is the fully substituted template instantiation string;
// block_size and grid_size are the sizes computed alongside it by the
// problem description that produced the solution.
struct Solution
{
    std::string template_str;
    std::size_t block_size;
    std::size_t grid_size;
};

// Element types a problem description can select; ToString() gives the
// textual C++ type name for each enumerator.
enum class DataType
{
    Half,
    Float,
    Int8,
    Int32
};

// Returns the C++ type name for dt as a string.
std::string ToString(DataType dt);

// Returns the embedded header table: a map from a name to a pair of
// character pointers. NOTE(review): the pair presumably delimits the
// [begin, end) bytes of each embedded header - confirm against the
// generated ck_headers.hpp.
std::unordered_map<std::string, std::pair<const char*, const char*>> GetHeaders();

// Integer division of x by y, rounded up.
std::size_t integer_divide_ceil(std::size_t x, std::size_t y);

} // namespace host
} // namespace ck
library/src/jit_library/include/ck/host/device_gemm_multiple_d.hpp
0 → 100644
View file @
fcca3307
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include <sstream>
#include <iterator>
#include <numeric>
#include "ck/host/common.hpp"
namespace ck {
namespace host {
namespace device_gemm_multiple_d {

// Describes one GEMM-with-multiple-D problem and produces the kernel
// template strings (Solutions) that can be instantiated for it.
struct Problem
{
    // GEMM extents; M and N feed the grid-size computation, M/N/K feed the
    // padding specialization chosen for each solution.
    std::size_t M = 0;
    std::size_t N = 0;
    std::size_t K = 0;
    // TransA/TransB select which of the row/col instance lists is used.
    bool TransA = false;
    bool TransB = false;
    // NOTE(review): TransE is not read by the implementation visible in this change.
    bool TransE = false;
    // One flag per D tensor; true is rendered as ColumnMajor, false as RowMajor.
    std::vector<bool> DsTrans = {};
    DataType ADataType = DataType::Half;
    DataType BDataType = DataType::Half;
    DataType EDataType = DataType::Half;
    // One entry per D tensor, rendered into a ck::Tuple<...> of type names.
    std::vector<DataType> DsDataType = {};
    // Element-wise operator type names substituted verbatim into the template string.
    std::string AElementOp = "ck::tensor_operation::element_wise::PassThrough";
    std::string BElementOp = "ck::tensor_operation::element_wise::PassThrough";
    std::string CDEElementOp = "ck::Tuple<>";

    // Positions of the whitespace-separated tokens of an instance template
    // string that MakeSolution overwrites (layouts, types, operators, GEMM
    // specialization) or reads (block size and per-block tile sizes).
    // Token 0 is the template name itself.
    static const std::size_t ds_layout_idx = 3;
    static const std::size_t ds_data_type_idx = 9;
    static const std::size_t e_data_type_idx = 10;
    static const std::size_t a_elementwise_op_idx = 11;
    static const std::size_t b_elementwise_op_idx = 12;
    static const std::size_t ds_elementwise_op_idx = 13;
    static const std::size_t gemm_spec_idx = 14;
    static const std::size_t block_size_idx = 16;
    static const std::size_t m_per_block_idx = 17;
    static const std::size_t n_per_block_idx = 18;
    static const std::size_t k_per_block_idx = 19;

    // Returns the include header string reported by the generated instance table.
    std::string GetIncludeHeader() const;

    // Builds one Solution per instance available for arch; the result is
    // empty when arch is not a supported architecture.
    std::vector<Solution> GetSolutions(const std::string& arch) const;

    private:
    // Raw instance template strings matching this problem's layouts and
    // quantization for arch.
    std::vector<std::string> GetInstances(const std::string& arch) const;

    // Specializes the idx-th instance template string for this problem.
    Solution MakeSolution(std::size_t idx, const std::string& arch) const;
};

} // namespace device_gemm_multiple_d
} // namespace host
} // namespace ck
library/src/jit_library/include/device_gemm_multiple_d.hpp
deleted
100644 → 0
View file @
7295e38d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include <sstream>
#include <iterator>
#include <numeric>
#include "ck/solution_instances/gemm_add_add_fastgelu_instances.hpp"
#include "ck/ck.hpp"
#include "ck/utility/math.hpp"
#include "ck_headers.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_multiple_d
{
struct
Solution
{
std
::
string
template_str
;
index_t
block_size
;
index_t
grid_size
;
};
std
::
string
GetGemmSpec
(
const
index_t
m
,
const
index_t
n
,
const
index_t
k
,
const
index_t
m_per_block
,
const
index_t
n_per_block
,
const
index_t
k_per_block
)
{
std
::
string
spec
=
""
;
if
(
math
::
integer_divide_ceil
(
m
,
m_per_block
)
*
m_per_block
-
m
!=
0
)
spec
+=
"M"
;
if
(
math
::
integer_divide_ceil
(
n
,
n_per_block
)
*
n_per_block
-
n
!=
0
)
spec
+=
"N"
;
if
(
math
::
integer_divide_ceil
(
k
,
k_per_block
)
*
k_per_block
-
k
!=
0
)
spec
+=
"K"
;
if
(
spec
==
""
)
return
"ck::tensor_operation::device::GemmSpecialization::Default"
;
return
"ck::tensor_operation::device::GemmSpecialization::"
+
spec
+
"Padding"
;
}
index_t
GetGridSize
(
const
index_t
m
,
const
index_t
n
,
const
index_t
m_per_block
,
const
index_t
n_per_block
)
{
return
math
::
integer_divide_ceil
(
m
,
m_per_block
)
*
math
::
integer_divide_ceil
(
n
,
n_per_block
);
}
const
std
::
unordered_set
<
std
::
string
>&
get_xdlop_archs
()
{
static
std
::
unordered_set
<
std
::
string
>
supported_archs
{
"gfx90a"
};
return
supported_archs
;
}
struct
Problem
{
index_t
M
=
0
;
index_t
N
=
0
;
index_t
K
=
0
;
bool
TransA
=
false
;
bool
TransB
=
false
;
bool
TransE
=
false
;
std
::
vector
<
bool
>
DsLayout
=
{};
std
::
string
ADataType
=
"ck::half_t"
;
std
::
string
BDataType
=
"ck::half_t"
;
std
::
string
EDataType
=
"ck::half_t"
;
std
::
vector
<
std
::
string
>
DsDataType
=
{};
std
::
string
AElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
BElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
CDEElementOp
=
"ck::Tuple<>"
;
static
const
index_t
ds_layout_idx
=
3
;
static
const
index_t
ds_data_type_idx
=
9
;
static
const
index_t
e_data_type_idx
=
10
;
static
const
index_t
a_elementwise_op_idx
=
11
;
static
const
index_t
b_elementwise_op_idx
=
12
;
static
const
index_t
ds_elementwise_op_idx
=
13
;
static
const
index_t
gemm_spec_idx
=
14
;
static
const
index_t
block_size_idx
=
16
;
static
const
index_t
m_per_block_idx
=
17
;
static
const
index_t
n_per_block_idx
=
18
;
static
const
index_t
k_per_block_idx
=
19
;
private:
auto
GetInstances
(
const
std
::
string
&
arch
)
const
{
std
::
vector
<
std
::
string
>
instances
;
const
bool
quantize
=
ADataType
==
"int8_t"
and
BDataType
==
"int8_t"
;
if
(
get_xdlop_archs
().
find
(
arch
)
!=
get_xdlop_archs
().
end
())
{
instance
::
gemm_add_add_fastgelu_instances
all_instances
{};
if
(
TransA
and
TransB
)
instances
=
all_instances
.
get_col_col_instances
(
quantize
);
else
if
(
TransA
and
not
TransB
)
instances
=
all_instances
.
get_col_row_instances
(
quantize
);
else
if
(
not
TransA
and
not
TransB
)
instances
=
all_instances
.
get_row_row_instances
(
quantize
);
else
instances
=
all_instances
.
get_row_col_instances
(
quantize
);
}
return
instances
;
}
auto
MakeLayoutTuple
(
const
std
::
vector
<
bool
>&
layouts
)
const
{
std
::
string
layout_tuple
=
"ck::Tuple<"
;
auto
it
=
layouts
.
begin
();
while
(
it
!=
layouts
.
end
())
{
layout_tuple
+=
*
it
?
"ck::tensor_layout::gemm::ColumnMajor"
:
"ck::tensor_layout::gemm::RowMajor"
;
it
=
std
::
next
(
it
);
if
(
it
!=
layouts
.
end
())
layout_tuple
+=
", "
;
}
return
layout_tuple
+
">"
;
}
auto
MakeTypeTuple
(
const
std
::
vector
<
std
::
string
>&
types
)
const
{
std
::
string
type_tuple
=
"ck::Tuple<"
;
auto
it
=
types
.
begin
();
while
(
it
!=
types
.
end
())
{
type_tuple
+=
*
it
;
it
=
std
::
next
(
it
);
if
(
it
!=
types
.
end
())
type_tuple
+=
", "
;
}
return
type_tuple
+
">"
;
}
auto
MakeSolution
(
index_t
idx
,
const
std
::
string
&
arch
)
const
{
auto
template_str
=
GetInstances
(
arch
).
at
(
idx
);
std
::
istringstream
iss
(
template_str
);
std
::
vector
<
std
::
string
>
params
(
std
::
istream_iterator
<
std
::
string
>
{
iss
},
std
::
istream_iterator
<
std
::
string
>
());
if
(
ADataType
==
"int8_t"
and
BDataType
==
"int8_t"
)
{
// Change CBlockTransfer ScalarPerVector if Ds contains other types
if
(
std
::
any_of
(
DsDataType
.
begin
(),
DsDataType
.
end
(),
[](
auto
t
)
{
return
t
==
"ck::half_t"
;
})
or
EDataType
==
"ck::half_t"
)
{
params
[
params
.
size
()
-
3
]
=
"8"
;
}
if
(
std
::
any_of
(
DsDataType
.
begin
(),
DsDataType
.
end
(),
[](
auto
t
)
{
return
t
==
"float"
;
})
or
EDataType
==
"float"
)
{
params
[
params
.
size
()
-
3
]
=
"4"
;
}
}
params
[
a_elementwise_op_idx
]
=
AElementOp
;
params
[
b_elementwise_op_idx
]
=
BElementOp
;
params
[
ds_layout_idx
]
=
MakeLayoutTuple
(
DsLayout
);
params
[
ds_data_type_idx
]
=
MakeTypeTuple
(
DsDataType
);
params
[
ds_elementwise_op_idx
]
=
CDEElementOp
;
params
[
e_data_type_idx
]
=
EDataType
;
auto
block_size_str
=
params
[
block_size_idx
];
auto
m_per_block_str
=
params
[
m_per_block_idx
];
auto
n_per_block_str
=
params
[
n_per_block_idx
];
auto
k_per_block_str
=
params
[
k_per_block_idx
];
const
auto
block_size
=
std
::
stoi
(
block_size_str
);
const
auto
m_per_block
=
std
::
stoi
(
m_per_block_str
);
const
auto
n_per_block
=
std
::
stoi
(
n_per_block_str
);
const
auto
k_per_block
=
std
::
stoi
(
k_per_block_str
);
const
auto
grid_size
=
GetGridSize
(
M
,
N
,
m_per_block
,
n_per_block
);
params
[
gemm_spec_idx
]
=
GetGemmSpec
(
M
,
N
,
K
,
m_per_block
,
n_per_block
,
k_per_block
);
std
::
string
str
=
std
::
accumulate
(
params
.
begin
()
+
1
,
params
.
end
(),
std
::
string
{},
[](
const
std
::
string
&
a
,
const
std
::
string
&
b
)
{
return
a
.
empty
()
?
b
:
a
+
", "
+
b
;
});
str
=
params
.
front
()
+
"< "
+
str
+
">"
;
return
Solution
{
str
,
block_size
,
grid_size
};
}
public:
auto
GetHeaders
()
const
{
return
ck_headers
();
}
auto
GetIncludeHeader
()
const
{
return
instance
::
gemm_add_add_fastgelu_instances
{}.
get_include_header
();
}
auto
GetSolutions
(
const
std
::
string
&
arch
)
const
{
std
::
vector
<
Solution
>
solutions
;
const
auto
num_instances
=
GetInstances
(
arch
).
size
();
for
(
auto
i
=
0
;
i
<
num_instances
;
++
i
)
{
solutions
.
push_back
(
MakeSolution
(
i
,
arch
));
}
return
solutions
;
}
};
}
// namespace device_gemm_multiple_d
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/jit_library/src/common.cpp
0 → 100644
View file @
fcca3307
#include "ck/host/common.hpp"
#include "ck_headers.hpp"
#include <stdexcept>
namespace
ck
{
namespace
host
{
// Converts a DataType enumerator into the C++ type name used in generated
// template strings. Throws std::runtime_error for a value that matches no
// enumerator.
std::string ToString(DataType dt)
{
    const char* type_name = nullptr;
    switch(dt)
    {
    case DataType::Float: type_name = "float"; break;
    case DataType::Half: type_name = "ck::half_t"; break;
    case DataType::Int8: type_name = "int8_t"; break;
    case DataType::Int32: type_name = "int32_t"; break;
    }
    // Only reachable with an out-of-range value cast to DataType
    if(type_name == nullptr)
    {
        throw std::runtime_error("Incorrect data type");
    }
    return type_name;
}
// Exposes the header table produced by the generated ck_headers() function.
std::unordered_map<std::string, std::pair<const char*, const char*>> GetHeaders()
{
    using HeaderTable = std::unordered_map<std::string, std::pair<const char*, const char*>>;
    HeaderTable embedded = ck_headers();
    return embedded;
}
// Returns x divided by y, rounded up to the next integer.
//
// Throws std::invalid_argument when y is zero (the division would otherwise
// be undefined behaviour).
std::size_t integer_divide_ceil(std::size_t x, std::size_t y)
{
    if(y == 0)
    {
        throw std::invalid_argument("integer_divide_ceil: divisor must be non-zero");
    }
    // Quotient plus one when a remainder exists. Unlike (x + y - 1) / y this
    // form cannot wrap around for x close to the maximum of std::size_t.
    return x / y + (x % y != 0 ? std::size_t{1} : std::size_t{0});
}
}
// namespace host
}
// namespace ck
library/src/jit_library/src/device_gemm_multiple_d.cpp
0 → 100644
View file @
fcca3307
#include "ck/host/device_gemm_multiple_d.hpp"
#include "ck/host/common.hpp"
#include "gemm_add_add_fastgelu_instances.hpp"
#include <algorithm>
#include <unordered_set>
namespace
ck
{
namespace
host
{
namespace
device_gemm_multiple_d
{
// Chooses the GemmSpecialization enumerator name for a problem of size
// m x n x k tiled with the given per-block sizes.
//
// A dimension needs padding when it is not an exact multiple of its tile
// size. The letters of the padded dimensions are concatenated in M, N, K
// order (e.g. "MNPadding"); when nothing needs padding the Default
// specialization is returned.
//
// The per-block sizes must be non-zero.
std::string GetGemmSpec(const std::size_t m,
                        const std::size_t n,
                        const std::size_t k,
                        const std::size_t m_per_block,
                        const std::size_t n_per_block,
                        const std::size_t k_per_block)
{
    std::string spec;
    // A direct remainder test is equivalent to ceil(x / tile) * tile - x != 0
    // but has no intermediate sum or product that could overflow.
    if(m % m_per_block != 0)
        spec += "M";
    if(n % n_per_block != 0)
        spec += "N";
    if(k % k_per_block != 0)
        spec += "K";
    if(spec.empty())
        return "ck::tensor_operation::device::GemmSpecialization::Default";

    return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding";
}
// Number of output tiles: tiles along M (rounded up) times tiles along N
// (rounded up).
std::size_t GetGridSize(const std::size_t m,
                        const std::size_t n,
                        const std::size_t m_per_block,
                        const std::size_t n_per_block)
{
    const std::size_t tiles_along_m = integer_divide_ceil(m, m_per_block);
    const std::size_t tiles_along_n = integer_divide_ceil(n, n_per_block);
    return tiles_along_m * tiles_along_n;
}
// Returns the set of architecture names for which instances are offered.
// The set is built once, on first use, and shared by all callers.
const std::unordered_set<std::string>& get_xdlop_archs()
{
    static std::unordered_set<std::string> archs = {"gfx90a"};
    return archs;
}
std
::
vector
<
std
::
string
>
Problem
::
GetInstances
(
const
std
::
string
&
arch
)
const
{
std
::
vector
<
std
::
string
>
instances
;
const
bool
quantize
=
ADataType
==
DataType
::
Int8
and
BDataType
==
DataType
::
Int8
;
if
(
get_xdlop_archs
().
find
(
arch
)
!=
get_xdlop_archs
().
end
())
{
instance
::
gemm_add_add_fastgelu_instances
all_instances
{};
if
(
TransA
and
TransB
)
instances
=
all_instances
.
get_col_col_instances
(
quantize
);
else
if
(
TransA
and
not
TransB
)
instances
=
all_instances
.
get_col_row_instances
(
quantize
);
else
if
(
not
TransA
and
not
TransB
)
instances
=
all_instances
.
get_row_row_instances
(
quantize
);
else
instances
=
all_instances
.
get_row_col_instances
(
quantize
);
}
return
instances
;
}
// Renders a list of transpose flags as a "ck::Tuple<...>" of layout type
// names: true becomes ColumnMajor, false becomes RowMajor, entries are
// separated by ", ". An empty list yields "ck::Tuple<>".
std::string MakeLayoutTuple(const std::vector<bool>& layouts)
{
    std::string result = "ck::Tuple<";
    for(std::size_t i = 0; i < layouts.size(); ++i)
    {
        // Separator goes between entries, never after the last one
        if(i != 0)
        {
            result += ", ";
        }
        if(layouts[i])
        {
            result += "ck::tensor_layout::gemm::ColumnMajor";
        }
        else
        {
            result += "ck::tensor_layout::gemm::RowMajor";
        }
    }
    result += ">";
    return result;
}
std
::
string
MakeTypeTuple
(
const
std
::
vector
<
DataType
>&
types
)
{
std
::
string
type_tuple
=
"ck::Tuple<"
;
auto
it
=
types
.
begin
();
while
(
it
!=
types
.
end
())
{
type_tuple
+=
ToString
(
*
it
);
it
=
std
::
next
(
it
);
if
(
it
!=
types
.
end
())
type_tuple
+=
", "
;
}
return
type_tuple
+
">"
;
}
// Specializes the idx-th instance template for this problem.
//
// The instance string is split on whitespace into tokens; token 0 is the
// template name and the remaining tokens are its arguments. Selected tokens
// are replaced with this problem's layouts, data types, element-wise
// operators and GEMM specialization, and the tokens are then re-joined as
// "Name< arg1, arg2, ...>".
//
// Returns the resulting string together with the block size read from the
// template and the grid size computed from M, N and the per-block tile sizes.
//
// Throws std::out_of_range when idx is not a valid instance index or when
// the instance string has fewer tokens than the positions accessed here.
// std::stoi may throw if a size token is not a number.
Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
{
    auto template_str = GetInstances(arch).at(idx);
    std::istringstream iss(template_str);
    std::vector<std::string> params(std::istream_iterator<std::string>{iss},
                                    std::istream_iterator<std::string>());

    // All token accesses below use at(): a malformed or truncated instance
    // string then raises std::out_of_range instead of writing out of bounds
    // (params.size() - 3 would also wrap around for fewer than three tokens).
    if(ADataType == DataType::Int8 and BDataType == DataType::Int8)
    {
        // Change CBlockTransfer ScalarPerVector if Ds contains other types
        // NOTE(review): the previous header-only implementation also applied
        // these overrides when EDataType was half/float - confirm that
        // dropping the EDataType check is intentional.
        if(std::any_of(
               DsDataType.begin(), DsDataType.end(), [](auto t) { return t == DataType::Half; }))
        {
            params.at(params.size() - 3) = "8";
        }
        // Float is checked last, so it wins when Ds holds both Half and Float
        if(std::any_of(
               DsDataType.begin(), DsDataType.end(), [](auto t) { return t == DataType::Float; }))
        {
            params.at(params.size() - 3) = "4";
        }
    }

    params.at(a_elementwise_op_idx)  = AElementOp;
    params.at(b_elementwise_op_idx)  = BElementOp;
    params.at(ds_layout_idx)         = MakeLayoutTuple(DsTrans);
    params.at(ds_data_type_idx)      = MakeTypeTuple(DsDataType);
    params.at(ds_elementwise_op_idx) = CDEElementOp;
    params.at(e_data_type_idx)       = ToString(EDataType);

    auto block_size_str  = params.at(block_size_idx);
    auto m_per_block_str = params.at(m_per_block_idx);
    auto n_per_block_str = params.at(n_per_block_idx);
    auto k_per_block_str = params.at(k_per_block_idx);

    const std::size_t block_size  = std::stoi(block_size_str);
    const std::size_t m_per_block = std::stoi(m_per_block_str);
    const std::size_t n_per_block = std::stoi(n_per_block_str);
    const std::size_t k_per_block = std::stoi(k_per_block_str);
    const std::size_t grid_size   = GetGridSize(M, N, m_per_block, n_per_block);

    params.at(gemm_spec_idx) = GetGemmSpec(M, N, K, m_per_block, n_per_block, k_per_block);

    // Join every token after the template name with ", "
    std::string str = std::accumulate(
        params.begin() + 1,
        params.end(),
        std::string{},
        [](const std::string& a, const std::string& b) { return a.empty() ? b : a + ", " + b; });
    str = params.front() + "< " + str + ">";

    return Solution{str, block_size, grid_size};
}
std
::
string
Problem
::
GetIncludeHeader
()
const
{
return
instance
::
gemm_add_add_fastgelu_instances
{}.
get_include_header
();
}
std
::
vector
<
Solution
>
Problem
::
GetSolutions
(
const
std
::
string
&
arch
)
const
{
std
::
vector
<
Solution
>
solutions
;
const
std
::
size_t
num_instances
=
GetInstances
(
arch
).
size
();
for
(
std
::
size_t
i
=
0
;
i
<
num_instances
;
++
i
)
{
solutions
.
push_back
(
MakeSolution
(
i
,
arch
));
}
return
solutions
;
}
}
// namespace device_gemm_multiple_d
}
// namespace host
}
// namespace ck
library/src/jit_library/util/make_instance_strings.py
View file @
fcca3307
import
argparse
,
re
,
json
,
os
import
argparse
,
re
,
json
,
os
,
sys
out_file
=
"""// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
...
...
@@ -10,8 +10,7 @@ out_file = """// SPDX-License-Identifier: MIT
#include <memory>
namespace ck {{
namespace tensor_operation {{
namespace device {{
namespace host {{
namespace instance {{
struct {op_name}_instances
...
...
@@ -87,8 +86,7 @@ struct {op_name}_instances
}};
}} // namespace instance
}} // namespace device
}} // namespace tensor_operation
}} // namespace host
}} // namespace ck
"""
...
...
@@ -172,8 +170,7 @@ def get_int8_instances(src, file, template_name):
instances
[
"col_row"
][
-
1
]
=
instances
[
"col_row"
][
-
1
][:
-
1
]
return
instances
def
parse_instances
(
source
):
out_dir
=
os
.
path
.
join
(
source
,
"../../../src/jit_library/solution_instances"
)
def
parse_instances
(
source
,
out_dir
):
aliases
=
{
"F16_F16_Tuple"
:
"ck::Tuple<F16,F16>"
,
"Row_Row_Tuple"
:
"ck::Tuple<Row,Row>"
,
"Empty_Tuple"
:
"ck::Tuple<>"
,
...
...
@@ -273,9 +270,9 @@ def parse_instances(source):
int8_row_col_instances
=
"
\n
"
.
join
(
int8_instances
[
"row_col"
]),
include_header
=
include_header
))
def
run
():
source
=
"/code/composable_kernel/library/src/tensor_operation_instance/gpu"
parse_instances
(
source
)
def
run
(
args
):
parse_instances
(
args
[
0
],
args
[
1
])
if
__name__
==
'__main__'
:
run
()
\ No newline at end of file
run
(
sys
.
argv
[
1
:])
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment