Unverified commit 59e2dc29, authored by Paul Fultz II and committed by GitHub

Updates to ck host library API (#731)

* Move functions to cpp file

* Move another function to cpp file

* Fix semicolon

* Move solution to common.hpp

* Fix compile errors

* Use enum for data types

* Remove -Werror

* Fix header install

* Fix relative path

* Fix header path

* Install all headers
parent 61386bf9
@@ -27,7 +27,7 @@ find_program(EMBED_OBJCOPY objcopy)
 function(generate_embed_source EMBED_NAME)
     set(options)
     set(oneValueArgs SRC HEADER RELATIVE)
-    set(multiValueArgs OBJECTS SYMBOLS)
+    set(multiValueArgs OBJECTS SYMBOLS FILES)
     cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
@@ -44,6 +44,7 @@ function(generate_embed_source EMBED_NAME)
     foreach(idx RANGE ${LEN})
         list(GET PARSE_SYMBOLS ${idx} SYMBOL)
         list(GET PARSE_OBJECTS ${idx} OBJECT)
+        list(GET PARSE_FILES ${idx} FILE)
         set(START_SYMBOL "_binary_${SYMBOL}_start")
         set(END_SYMBOL "_binary_${SYMBOL}_end")
         string(APPEND EXTERNS "
@@ -52,8 +53,7 @@ function(generate_embed_source EMBED_NAME)
 ")
-        file(RELATIVE_PATH BASE_NAME ${PARSE_RELATIVE} "${OBJECT}")
-        string(REGEX REPLACE ".[A-Za-z0-9_]$" "" BASE_NAME ${BASE_NAME})
+        file(RELATIVE_PATH BASE_NAME ${PARSE_RELATIVE} "${FILE}")
         string(APPEND INIT_KERNELS "
        { \"${BASE_NAME}\", { ${START_SYMBOL}, ${END_SYMBOL}} },
@@ -121,7 +121,7 @@ function(add_embed_library EMBED_NAME)
         list(APPEND SYMBOLS ${OUTPUT_SYMBOL})
     endforeach()
     message(STATUS "Generating embedding library ${EMBED_NAME}")
-    generate_embed_source(${EMBED_NAME} SRC ${SRC_FILE} HEADER ${HEADER_FILE} OBJECTS ${OUTPUT_FILES} SYMBOLS ${SYMBOLS} RELATIVE ${PARSE_RELATIVE})
+    generate_embed_source(${EMBED_NAME} SRC ${SRC_FILE} HEADER ${HEADER_FILE} OBJECTS ${OUTPUT_FILES} SYMBOLS ${SYMBOLS} RELATIVE ${PARSE_RELATIVE} FILES ${PARSE_UNPARSED_ARGUMENTS})
     add_library(${EMBED_NAME} STATIC ${OUTPUT_FILES} "${SRC_FILE}")
     target_include_directories(${EMBED_NAME} PUBLIC "$<BUILD_INTERFACE:${EMBED_DIR}/include>")
     target_compile_options(${EMBED_NAME} PRIVATE -Wno-reserved-identifier)
......
@@ -66,7 +66,6 @@ else()
         -Wunreachable-code
         -Wunused
         -Wno-reserved-identifier
-        -Werror
         -Wsign-compare
         -Wno-extra-semi-stmt
     )
......
 include(Embed)
-file(GLOB_RECURSE KERNEL_FILES ${CONFIGURE_DEPENDS}
+file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS
     ${PROJECT_SOURCE_DIR}/include/ck/*.hpp)
 message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
-add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${PROJECT_SOURCE_DIR}/build/include)
+message(STATUS "RELATIVE: ${PROJECT_SOURCE_DIR}/include")
+add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${PROJECT_SOURCE_DIR}/include)
 execute_process(
     COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/util/make_instance_strings.py
+        ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu
+        ${CMAKE_CURRENT_BINARY_DIR}/solution_instances
     WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../tensor_operation_instance/gpu/
 )
-set(JIT_LIB_SOURCE
-    ${CMAKE_CURRENT_SOURCE_DIR}/include/device_gemm_multiple_d.hpp
-)
-add_library(jit_library STATIC ${JIT_LIB_SOURCE})
+add_library(jit_library STATIC
+    src/device_gemm_multiple_d.cpp
+    src/common.cpp
+)
 add_library(composable_kernel::jit_library ALIAS jit_library)
 set_target_properties(jit_library PROPERTIES LINKER_LANGUAGE CXX)
-target_include_directories(jit_library PUBLIC
+target_include_directories(jit_library PRIVATE
     $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
-    $<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/library/src/jit_library/solution_instances>
+    $<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/solution_instances>
+    $<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/embed/ck_headers/include>
 )
 target_link_libraries(jit_library PRIVATE ck_headers)
@@ -30,14 +34,8 @@ rocm_install(
     EXPORT jit_libraryTargets
 )
-set(INCLUDE_DIRS
-    ${PROJECT_SOURCE_DIR}/include/ck/
-    ${PROJECT_SOURCE_DIR}/library/src/jit_library/include
-    ${PROJECT_SOURCE_DIR}/library/src/jit_library/solution_instances
-    ${CMAKE_CURRENT_BINARY_DIR}/embed/ck_headers/include
-)
-rocm_install(DIRECTORY ${INCLUDE_DIRS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck)
+rocm_install(DIRECTORY include/ck DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
+rocm_install(DIRECTORY ${PROJECT_SOURCE_DIR}/include/ck DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
 rocm_install(
     EXPORT jit_libraryTargets
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include <utility>
#include <unordered_map>
namespace ck {
namespace host {
struct Solution
{
std::string template_str;
std::size_t block_size;
std::size_t grid_size;
};
enum class DataType {
Half,
Float,
Int8,
Int32
};
std::string ToString(DataType dt);
std::unordered_map<std::string, std::pair<const char*,const char*>> GetHeaders();
std::size_t integer_divide_ceil(std::size_t x, std::size_t y);
} // namespace host
} // namespace ck
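
A minimal sketch of how this common.hpp surface might be exercised; the asserted values simply restate the definitions in common.cpp further down, and nothing beyond the declarations above is assumed:

#include <cassert>
#include "ck/host/common.hpp"

int main()
{
    using namespace ck::host;
    // The DataType enum replaces the raw type strings the old header used.
    assert(ToString(DataType::Half) == "ck::half_t");
    // Rounding-up integer division: ceil(10 / 4) == 3.
    assert(integer_divide_ceil(10, 4) == 3);
    // GetHeaders() maps each embedded header path to a (begin, end) byte range.
    assert(!GetHeaders().empty());
}
......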
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include <sstream>
#include <iterator>
#include <numeric>
#include "ck/host/common.hpp"
namespace ck {
namespace host {
namespace device_gemm_multiple_d {
struct Problem
{
std::size_t M = 0;
std::size_t N = 0;
std::size_t K = 0;
bool TransA = false;
bool TransB = false;
bool TransE = false;
std::vector<bool> DsTrans = {};
DataType ADataType = DataType::Half;
DataType BDataType = DataType::Half;
DataType EDataType = DataType::Half;
std::vector<DataType> DsDataType = {};
std::string AElementOp = "ck::tensor_operation::element_wise::PassThrough";
std::string BElementOp = "ck::tensor_operation::element_wise::PassThrough";
std::string CDEElementOp = "ck::Tuple<>";
static const std::size_t ds_layout_idx = 3;
static const std::size_t ds_data_type_idx = 9;
static const std::size_t e_data_type_idx = 10;
static const std::size_t a_elementwise_op_idx = 11;
static const std::size_t b_elementwise_op_idx = 12;
static const std::size_t ds_elementwise_op_idx = 13;
static const std::size_t gemm_spec_idx = 14;
static const std::size_t block_size_idx = 16;
static const std::size_t m_per_block_idx = 17;
static const std::size_t n_per_block_idx = 18;
static const std::size_t k_per_block_idx = 19;
std::string GetIncludeHeader() const;
std::vector<Solution> GetSolutions(const std::string& arch) const;
private:
std::vector<std::string> GetInstances(const std::string& arch) const;
Solution MakeSolution(std::size_t idx, const std::string& arch) const;
};
} // namespace device_gemm_multiple_d
} // namespace host
} // namespace ck
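
For context, a hedged sketch of how a consumer might drive the Problem API declared above; the problem sizes and the "gfx90a" arch string are illustrative, not prescribed by this commit:

#include <iostream>
#include "ck/host/device_gemm_multiple_d.hpp"

int main()
{
    namespace dg = ck::host::device_gemm_multiple_d;
    dg::Problem prob{};
    prob.M = 1024; // illustrative sizes
    prob.N = 1024;
    prob.K = 64;
    prob.ADataType = ck::host::DataType::Half; // enum instead of "ck::half_t"
    prob.BDataType = ck::host::DataType::Half;
    // Each Solution carries a fully instantiated device-op template string
    // plus launch dimensions derived from the instance's tile sizes.
    for(const auto& s : prob.GetSolutions("gfx90a"))
        std::cout << s.template_str << "\n  block=" << s.block_size
                  << " grid=" << s.grid_size << "\n";
}

The header-only implementation below is the version this commit removes; the same logic reappears in device_gemm_multiple_d.cpp with DataType enums in place of type strings.
......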
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include <sstream>
#include <iterator>
#include <numeric>
#include "ck/solution_instances/gemm_add_add_fastgelu_instances.hpp"
#include "ck/ck.hpp"
#include "ck/utility/math.hpp"
#include "ck_headers.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_multiple_d {
struct Solution
{
std::string template_str;
index_t block_size;
index_t grid_size;
};
std::string GetGemmSpec(const index_t m,
const index_t n,
const index_t k,
const index_t m_per_block,
const index_t n_per_block,
const index_t k_per_block)
{
std::string spec = "";
if(math::integer_divide_ceil(m, m_per_block) * m_per_block - m != 0)
spec += "M";
if(math::integer_divide_ceil(n, n_per_block) * n_per_block - n != 0)
spec += "N";
if(math::integer_divide_ceil(k, k_per_block) * k_per_block - k != 0)
spec += "K";
if(spec == "")
return "ck::tensor_operation::device::GemmSpecialization::Default";
return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding";
}
index_t GetGridSize(const index_t m,
const index_t n,
const index_t m_per_block,
const index_t n_per_block)
{
return math::integer_divide_ceil(m, m_per_block) *
math::integer_divide_ceil(n, n_per_block);
}
const std::unordered_set<std::string>& get_xdlop_archs()
{
static std::unordered_set<std::string> supported_archs{"gfx90a"};
return supported_archs;
}
struct Problem
{
index_t M = 0;
index_t N = 0;
index_t K = 0;
bool TransA = false;
bool TransB = false;
bool TransE = false;
std::vector<bool> DsLayout = {};
std::string ADataType = "ck::half_t";
std::string BDataType = "ck::half_t";
std::string EDataType = "ck::half_t";
std::vector<std::string> DsDataType = {};
std::string AElementOp = "ck::tensor_operation::element_wise::PassThrough";
std::string BElementOp = "ck::tensor_operation::element_wise::PassThrough";
std::string CDEElementOp = "ck::Tuple<>";
static const index_t ds_layout_idx = 3;
static const index_t ds_data_type_idx = 9;
static const index_t e_data_type_idx = 10;
static const index_t a_elementwise_op_idx = 11;
static const index_t b_elementwise_op_idx = 12;
static const index_t ds_elementwise_op_idx = 13;
static const index_t gemm_spec_idx = 14;
static const index_t block_size_idx = 16;
static const index_t m_per_block_idx = 17;
static const index_t n_per_block_idx = 18;
static const index_t k_per_block_idx = 19;
private:
auto GetInstances(const std::string& arch) const
{
std::vector<std::string> instances;
const bool quantize = ADataType == "int8_t" and BDataType == "int8_t";
if (get_xdlop_archs().find(arch) != get_xdlop_archs().end())
{
instance::gemm_add_add_fastgelu_instances all_instances{};
if(TransA and TransB)
instances = all_instances.get_col_col_instances(quantize);
else if(TransA and not TransB)
instances = all_instances.get_col_row_instances(quantize);
else if(not TransA and not TransB)
instances = all_instances.get_row_row_instances(quantize);
else
instances = all_instances.get_row_col_instances(quantize);
}
return instances;
}
auto MakeLayoutTuple(const std::vector<bool>& layouts) const
{
std::string layout_tuple = "ck::Tuple<";
auto it = layouts.begin();
while(it != layouts.end())
{
layout_tuple += *it ? "ck::tensor_layout::gemm::ColumnMajor" : "ck::tensor_layout::gemm::RowMajor";
it = std::next(it);
if (it != layouts.end())
layout_tuple += ", ";
}
return layout_tuple + ">";
}
auto MakeTypeTuple(const std::vector<std::string>& types) const
{
std::string type_tuple = "ck::Tuple<";
auto it = types.begin();
while(it != types.end())
{
type_tuple += *it;
it = std::next(it);
if (it != types.end())
type_tuple += ", ";
}
return type_tuple + ">";
}
auto MakeSolution(index_t idx, const std::string& arch) const
{
auto template_str = GetInstances(arch).at(idx);
std::istringstream iss(template_str);
std::vector<std::string> params(std::istream_iterator<std::string>{iss},
std::istream_iterator<std::string>());
if (ADataType == "int8_t" and BDataType == "int8_t")
{
// Change CBlockTransfer ScalarPerVector if Ds contains other types
if (std::any_of(DsDataType.begin(), DsDataType.end(), [](auto t) { return t == "ck::half_t"; }))
{
params[params.size() - 3] = "8";
}
if (std::any_of(DsDataType.begin(), DsDataType.end(), [](auto t) { return t == "float"; }))
{
params[params.size() - 3] = "4";
}
}
params[a_elementwise_op_idx] = AElementOp;
params[b_elementwise_op_idx] = BElementOp;
params[ds_layout_idx] = MakeLayoutTuple(DsLayout);
params[ds_data_type_idx] = MakeTypeTuple(DsDataType);
params[ds_elementwise_op_idx] = CDEElementOp;
params[e_data_type_idx] = EDataType;
auto block_size_str = params[block_size_idx];
auto m_per_block_str = params[m_per_block_idx];
auto n_per_block_str = params[n_per_block_idx];
auto k_per_block_str = params[k_per_block_idx];
const auto block_size = std::stoi(block_size_str);
const auto m_per_block = std::stoi(m_per_block_str);
const auto n_per_block = std::stoi(n_per_block_str);
const auto k_per_block = std::stoi(k_per_block_str);
const auto grid_size = GetGridSize(M, N, m_per_block, n_per_block);
params[gemm_spec_idx] = GetGemmSpec(M, N, K, m_per_block, n_per_block, k_per_block);
std::string str = std::accumulate(params.begin() + 1, params.end(), std::string{},
[](const std::string& a, const std::string& b) {
return a.empty() ? b : a + ", " + b;
});
str = params.front() + "< " + str + ">";
return Solution{str, block_size, grid_size};
}
public:
auto GetHeaders() const
{
return ck_headers();
}
auto GetIncludeHeader() const
{
return instance::gemm_add_add_fastgelu_instances{}.get_include_header();
}
auto GetSolutions(const std::string& arch) const
{
std::vector<Solution> solutions;
const auto num_instances = GetInstances(arch).size();
for (auto i = 0; i < num_instances; ++i)
{
solutions.push_back(MakeSolution(i, arch));
}
return solutions;
}
};
} // namespace device_gemm_multiple_d
} // namespace device
} // namespace tensor_operation
} // namespace ck
......
#include "ck/host/common.hpp"
#include "ck_headers.hpp"
namespace ck {
namespace host {
std::string ToString(DataType dt)
{
switch (dt) {
case DataType::Float: return "float";
case DataType::Half: return "ck::half_t";
case DataType::Int8: return "int8_t";
case DataType::Int32: return "int32_t";
}
throw std::runtime_error("Incorrect data type");
}
std::unordered_map<std::string, std::pair<const char*,const char*>> GetHeaders()
{
return ck_headers();
}
std::size_t integer_divide_ceil(std::size_t x, std::size_t y)
{
return (x + y - std::size_t{1}) / y;
}
} // namespace host
} // namespace ck
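
One plausible way to consume the header map above, e.g. handing the embedded sources to an in-memory filesystem for runtime compilation; the JIT hand-off is an assumption for illustration, not something this commit implements:

#include <cstddef>
#include <string_view>
#include "ck/host/common.hpp"

// Wrap each embedded CK header (a pair of begin/end pointers produced by the
// objcopy step in Embed.cmake) as a named string_view that a JIT front end
// such as hipRTC could take as a virtual source file.
void collect_headers()
{
    for(const auto& [name, range] : ck::host::GetHeaders())
    {
        std::string_view src(range.first,
                             static_cast<std::size_t>(range.second - range.first));
        (void)name; // hypothetical: register (name, src) with the JIT compiler
        (void)src;
    }
}
......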
#include "ck/host/device_gemm_multiple_d.hpp"
#include "ck/host/common.hpp"
#include "gemm_add_add_fastgelu_instances.hpp"
#include <algorithm>
#include <unordered_set>
namespace ck {
namespace host {
namespace device_gemm_multiple_d {
std::string GetGemmSpec(const std::size_t m,
const std::size_t n,
const std::size_t k,
const std::size_t m_per_block,
const std::size_t n_per_block,
const std::size_t k_per_block)
{
std::string spec = "";
if(integer_divide_ceil(m, m_per_block) * m_per_block - m != 0)
spec += "M";
if(integer_divide_ceil(n, n_per_block) * n_per_block - n != 0)
spec += "N";
if(integer_divide_ceil(k, k_per_block) * k_per_block - k != 0)
spec += "K";
if(spec == "")
return "ck::tensor_operation::device::GemmSpecialization::Default";
return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding";
}
std::size_t GetGridSize(const std::size_t m,
const std::size_t n,
const std::size_t m_per_block,
const std::size_t n_per_block)
{
return integer_divide_ceil(m, m_per_block) *
integer_divide_ceil(n, n_per_block);
}
const std::unordered_set<std::string>& get_xdlop_archs()
{
static std::unordered_set<std::string> supported_archs{"gfx90a"};
return supported_archs;
}
std::vector<std::string> Problem::GetInstances(const std::string& arch) const
{
std::vector<std::string> instances;
const bool quantize = ADataType == DataType::Int8 and BDataType == DataType::Int8;
if (get_xdlop_archs().find(arch) != get_xdlop_archs().end())
{
instance::gemm_add_add_fastgelu_instances all_instances{};
if(TransA and TransB)
instances = all_instances.get_col_col_instances(quantize);
else if(TransA and not TransB)
instances = all_instances.get_col_row_instances(quantize);
else if(not TransA and not TransB)
instances = all_instances.get_row_row_instances(quantize);
else
instances = all_instances.get_row_col_instances(quantize);
}
return instances;
}
std::string MakeLayoutTuple(const std::vector<bool>& layouts)
{
std::string layout_tuple = "ck::Tuple<";
auto it = layouts.begin();
while(it != layouts.end())
{
layout_tuple += *it ? "ck::tensor_layout::gemm::ColumnMajor" : "ck::tensor_layout::gemm::RowMajor";
it = std::next(it);
if (it != layouts.end())
layout_tuple += ", ";
}
return layout_tuple + ">";
}
std::string MakeTypeTuple(const std::vector<DataType>& types)
{
std::string type_tuple = "ck::Tuple<";
auto it = types.begin();
while(it != types.end())
{
type_tuple += ToString(*it);
it = std::next(it);
if (it != types.end())
type_tuple += ", ";
}
return type_tuple + ">";
}
Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
{
auto template_str = GetInstances(arch).at(idx);
std::istringstream iss(template_str);
std::vector<std::string> params(std::istream_iterator<std::string>{iss},
std::istream_iterator<std::string>());
if (ADataType == DataType::Int8 and BDataType == DataType::Int8)
{
// Change CBlockTransfer ScalarPerVector if Ds contains other types
if (std::any_of(DsDataType.begin(), DsDataType.end(), [](auto t) { return t == DataType::Half; }))
{
params[params.size() - 3] = "8";
}
if (std::any_of(DsDataType.begin(), DsDataType.end(), [](auto t) { return t == DataType::Float; }))
{
params[params.size() - 3] = "4";
}
}
params[a_elementwise_op_idx] = AElementOp;
params[b_elementwise_op_idx] = BElementOp;
params[ds_layout_idx] = MakeLayoutTuple(DsTrans);
params[ds_data_type_idx] = MakeTypeTuple(DsDataType);
params[ds_elementwise_op_idx] = CDEElementOp;
params[e_data_type_idx] = ToString(EDataType);
auto block_size_str = params[block_size_idx];
auto m_per_block_str = params[m_per_block_idx];
auto n_per_block_str = params[n_per_block_idx];
auto k_per_block_str = params[k_per_block_idx];
const std::size_t block_size = std::stoi(block_size_str);
const std::size_t m_per_block = std::stoi(m_per_block_str);
const std::size_t n_per_block = std::stoi(n_per_block_str);
const std::size_t k_per_block = std::stoi(k_per_block_str);
const std::size_t grid_size = GetGridSize(M, N, m_per_block, n_per_block);
params[gemm_spec_idx] = GetGemmSpec(M, N, K, m_per_block, n_per_block, k_per_block);
std::string str = std::accumulate(params.begin() + 1, params.end(), std::string{},
[](const std::string& a, const std::string& b) {
return a.empty() ? b : a + ", " + b;
});
str = params.front() + "< " + str + ">";
return Solution{str, block_size, grid_size};
}
std::string Problem::GetIncludeHeader() const
{
return instance::gemm_add_add_fastgelu_instances{}.get_include_header();
}
std::vector<Solution> Problem::GetSolutions(const std::string& arch) const
{
std::vector<Solution> solutions;
const std::size_t num_instances = GetInstances(arch).size();
for (std::size_t i = 0; i < num_instances; ++i)
{
solutions.push_back(MakeSolution(i, arch));
}
return solutions;
}
} // namespace device_gemm_multiple_d
} // namespace host
} // namespace ck
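
To make the padding decision in GetGemmSpec concrete: integer_divide_ceil(m, m_per_block) * m_per_block - m != 0 is equivalent to m % m_per_block != 0 for positive sizes, so the rule can be sketched standalone (names here are local to the example):

#include <cstddef>
#include <iostream>
#include <string>

// Same rule as GetGemmSpec above: a dimension needs padding exactly when
// the per-block tile size does not divide it evenly.
std::string gemm_spec(std::size_t m, std::size_t n, std::size_t k,
                      std::size_t mpb, std::size_t npb, std::size_t kpb)
{
    std::string spec;
    if(m % mpb != 0)
        spec += "M";
    if(n % npb != 0)
        spec += "N";
    if(k % kpb != 0)
        spec += "K";
    if(spec.empty())
        return "ck::tensor_operation::device::GemmSpecialization::Default";
    return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding";
}

int main()
{
    // 1000 % 256 != 0 while 1024 % 128 == 0 and 64 % 32 == 0,
    // so only M needs padding -> "...::MPadding".
    std::cout << gemm_spec(1000, 1024, 64, 256, 128, 32) << "\n";
}
......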
-import argparse, re, json, os
+import argparse, re, json, os, sys
 out_file = """// SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
@@ -10,8 +10,7 @@ out_file = """// SPDX-License-Identifier: MIT
 #include <memory>
 namespace ck {{
-namespace tensor_operation {{
-namespace device {{
+namespace host {{
 namespace instance {{
 struct {op_name}_instances
@@ -87,8 +86,7 @@ struct {op_name}_instances
 }};
 }} // namespace instance
-}} // namespace device
-}} // namespace tensor_operation
+}} // namespace host
 }} // namespace ck
 """
@@ -172,8 +170,7 @@ def get_int8_instances(src, file, template_name):
     instances["col_row"][-1] = instances["col_row"][-1][:-1]
     return instances
-def parse_instances(source):
-    out_dir = os.path.join(source, "../../../src/jit_library/solution_instances")
+def parse_instances(source, out_dir):
     aliases = {"F16_F16_Tuple": "ck::Tuple<F16,F16>",
                "Row_Row_Tuple": "ck::Tuple<Row,Row>",
                "Empty_Tuple": "ck::Tuple<>",
@@ -273,9 +270,9 @@ def parse_instances(source):
         int8_row_col_instances="\n".join(int8_instances["row_col"]),
         include_header=include_header))
-def run():
-    source = "/code/composable_kernel/library/src/tensor_operation_instance/gpu"
-    parse_instances(source)
+def run(args):
+    parse_instances(args[0], args[1])
 if __name__ == '__main__':
-    run()
\ No newline at end of file
+    run(sys.argv[1:])
\ No newline at end of file