Commit 3905f4a2 authored by Paul's avatar Paul
Browse files

Fix compile errors

parent b155a0ac
......@@ -6,10 +6,11 @@ add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${PROJECT_SOURCE_DIR}/buil
execute_process(
COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/util/make_instance_strings.py
${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu
${CMAKE_CURRENT_BINARY_DIR}/solution_instances
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../tensor_operation_instance/gpu/
)
add_library(jit_library STATIC
src/device_gemm_multiple_d.cpp
src/common.cpp
......@@ -21,6 +22,7 @@ set_target_properties(jit_library PROPERTIES LINKER_LANGUAGE CXX)
target_include_directories(jit_library PRIVATE
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/library/src/jit_library/solution_instances>
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/solution_instances>
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/embed/ck_headers/include>
)
......
......@@ -9,6 +9,7 @@
#include <sstream>
#include <iterator>
#include <numeric>
#include "ck/host/common.hpp"
namespace ck {
......
#include "ck/host/device_gemm_multiple_d.hpp"
#include "ck/host/common.hpp"
#include "ck/solution_instances/gemm_add_add_fastgelu_instances.hpp"
#include "gemm_add_add_fastgelu_instances.hpp"
#include <algorithm>
#include <unordered_set>
namespace ck {
......@@ -120,11 +121,11 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
auto m_per_block_str = params[m_per_block_idx];
auto n_per_block_str = params[n_per_block_idx];
auto k_per_block_str = params[k_per_block_idx];
const auto block_size = std::stoi(block_size_str);
const auto m_per_block = std::stoi(m_per_block_str);
const auto n_per_block = std::stoi(n_per_block_str);
const auto k_per_block = std::stoi(k_per_block_str);
const auto grid_size = GetGridSize(M, N, m_per_block, n_per_block);
const std::size_t block_size = std::stoi(block_size_str);
const std::size_t m_per_block = std::stoi(m_per_block_str);
const std::size_t n_per_block = std::stoi(n_per_block_str);
const std::size_t k_per_block = std::stoi(k_per_block_str);
const std::size_t grid_size = GetGridSize(M, N, m_per_block, n_per_block);
params[gemm_spec_idx] = GetGemmSpec(M, N, K, m_per_block, n_per_block, k_per_block);
std::string str = std::accumulate(params.begin() + 1, params.end(), std::string{},
......@@ -144,8 +145,8 @@ std::string Problem::GetIncludeHeader() const
std::vector<Solution> Problem::GetSolutions(const std::string& arch) const
{
std::vector<Solution> solutions;
const auto num_instances = GetInstances(arch).size();
for (auto i = 0; i < num_instances; ++i)
const std::size_t num_instances = GetInstances(arch).size();
for (std::size_t i = 0; i < num_instances; ++i)
{
solutions.push_back(MakeSolution(i, arch));
}
......
import argparse, re, json, os
import argparse, re, json, os, sys
out_file = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
......@@ -10,8 +10,7 @@ out_file = """// SPDX-License-Identifier: MIT
#include <memory>
namespace ck {{
namespace tensor_operation {{
namespace device {{
namespace host {{
namespace instance {{
struct {op_name}_instances
......@@ -87,8 +86,7 @@ struct {op_name}_instances
}};
}} // namespace instance
}} // namespace device
}} // namespace tensor_operation
}} // namespace host
}} // namespace ck
"""
......@@ -172,8 +170,7 @@ def get_int8_instances(src, file, template_name):
instances["col_row"][-1] = instances["col_row"][-1][:-1]
return instances
def parse_instances(source):
out_dir = os.path.join(source, "../../../src/jit_library/solution_instances")
def parse_instances(source, out_dir):
aliases = {"F16_F16_Tuple": "ck::Tuple<F16,F16>",
"Row_Row_Tuple": "ck::Tuple<Row,Row>",
"Empty_Tuple": "ck::Tuple<>",
......@@ -273,9 +270,9 @@ def parse_instances(source):
int8_row_col_instances="\n".join(int8_instances["row_col"]),
include_header=include_header))
def run():
source = "/code/composable_kernel/library/src/tensor_operation_instance/gpu"
parse_instances(source)
def run(args):
parse_instances(args[0], args[1])
if __name__ == '__main__':
run()
\ No newline at end of file
run(sys.argv[1:])
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment