Commit 546a764e authored by Artur Wojcik

Merge branch 'migraphx' into uif2-migraphx

parents: 8da3dfff, 57cdd70b
#include "ck/host/device_batched_gemm_softmax_gemm.hpp"
#include "ck/host/common.hpp"
#include "batched_gemm_softmax_gemm_instances.hpp"
#include <algorithm>
#include <unordered_set>
namespace ck {
namespace host {
namespace device_batched_gemm_softmax_gemm {
// Builds the GemmSpecialization template argument for an instance.
// A padding letter (M/N/K/O) is appended for every gemm dimension that is
// not an exact multiple of its per-block tile size; when nothing needs
// padding the Default specialization is returned.
std::string GetGemmSpec(const std::size_t m,
                        const std::size_t n,
                        const std::size_t k,
                        const std::size_t n1,
                        const std::size_t m_per_block,
                        const std::size_t n_per_block,
                        const std::size_t k_per_block,
                        const std::size_t n1_per_block)
{
    std::string padded_dims;
    if(integer_divide_ceil(m, m_per_block) * m_per_block != m)
        padded_dims += "M";
    if(integer_divide_ceil(n, n_per_block) * n_per_block != n)
        padded_dims += "N";
    if(integer_divide_ceil(k, k_per_block) * k_per_block != k)
        padded_dims += "K";
    if(integer_divide_ceil(n1, n1_per_block) * n1_per_block != n1)
        padded_dims += "O";
    if(padded_dims.empty())
        return "ck::tensor_operation::device::GemmSpecialization::Default";
    return "ck::tensor_operation::device::GemmSpecialization::" + padded_dims + "Padding";
}
// Launch grid size: one workgroup per (m_per_block x n_per_block) output tile.
std::size_t GetGridSize(const std::size_t m,
                        const std::size_t n,
                        const std::size_t m_per_block,
                        const std::size_t n_per_block)
{
    const auto tiles_m = integer_divide_ceil(m, m_per_block);
    const auto tiles_n = integer_divide_ceil(n, n_per_block);
    return tiles_m * tiles_n;
}
// GPU targets with XDL (matrix core) support; instances are only emitted
// for these architectures.
const std::unordered_set<std::string>& get_xdlop_archs()
{
    static const std::unordered_set<std::string> archs = {
        "gfx90a", "gfx908", "gfx940", "gfx942"};
    return archs;
}
std::vector<std::string> Problem::GetInstances(const std::string& arch) const
{
std::vector<std::string> instances;
if(get_xdlop_archs().find(arch) != get_xdlop_archs().end())
{
ck::host::instance::batched_gemm_softmax_gemm_instances all_instances{};
instances = all_instances.get_instances();
}
return instances;
}
// Builds the idx-th Solution for the given arch: splits the instance
// template string into whitespace-separated template arguments, substitutes
// the problem's elementwise operators and the derived GemmSpecialization,
// and reassembles the full C++ template instantiation.
Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
{
    auto template_str = GetInstances(arch).at(idx);
    // The generator emits each instance with space-separated template
    // parameters (commas/brackets already stripped), so a plain whitespace
    // tokenize recovers the parameter list.
    std::istringstream iss(template_str);
    std::vector<std::string> params(std::istream_iterator<std::string>{iss},
                                    std::istream_iterator<std::string>());
    // Patch in the user-provided elementwise operators; the B operator is
    // used for both the B0 (gemm0) and B1 (gemm1) matrices.
    params[AElementwiseOperation_idx] = AElementOp;
    params[B0ElementwiseOperation_idx] = BElementOp;
    params[B1ElementwiseOperation_idx] = BElementOp;
    params[CElementwiseOperation_idx] = CElementOp;
    params[Acc0ElementwiseOperation_idx] = AccElementOp;
    // Tile/block sizes are read back out of the instance's own parameters.
    auto block_size_str = params[BlockSize_idx];
    auto m_per_block_str = params[Gemm01MPerBlock_idx];
    auto n_per_block_str = params[Gemm0NPerBlock_idx];
    auto k_per_block_str = params[Gemm0KPerBlock_idx];
    auto n1_per_block_str = params[Gemm1NPerBlock_idx];
    const std::size_t block_size = std::stoi(block_size_str);
    const std::size_t m_per_block = std::stoi(m_per_block_str);
    const std::size_t n_per_block = std::stoi(n_per_block_str);
    const std::size_t k_per_block = std::stoi(k_per_block_str);
    const std::size_t n1_per_block = std::stoi(n1_per_block_str);
    // The launch grid tiles the final (gemm1) output: M rows by O columns.
    const std::size_t grid_size = GetGridSize(M, O, m_per_block, n1_per_block);
    // Padding specialization depends on tile divisibility of all four gemm
    // dimensions (M, N, K, O).
    params[GEMMSpecialization_idx] =
        GetGemmSpec(M, N, K, O, m_per_block, n_per_block, k_per_block, n1_per_block);
    // Re-join as "ClassName< p1, p2, ... >": the first token is the device
    // op class name, the rest are its template arguments.
    std::string str = std::accumulate(
        params.begin() + 1,
        params.end(),
        std::string{},
        [](const std::string& a, const std::string& b) { return a.empty() ? b : a + ", " + b; });
    str = params.front() + "< " + str + ">";
    return Solution{str, block_size, grid_size};
}
std::string Problem::GetIncludeHeader() const
{
return ck::host::instance::batched_gemm_softmax_gemm_instances{}.get_include_header();
}
std::vector<Solution> Problem::GetSolutions(const std::string& arch) const
{
std::vector<Solution> solutions;
const std::size_t num_instances = GetInstances(arch).size();
for(std::size_t i = 0; i < num_instances; ++i)
{
solutions.push_back(MakeSolution(i, arch));
}
return solutions;
}
} // namespace device_batched_gemm_softmax_gemm
} // namespace host
} // namespace ck
#include "ck/host/device_gemm_multiple_d.hpp"
#include "ck/host/common.hpp"
#include "gemm_add_add_fastgelu_instances.hpp"
#include <algorithm>
#include <unordered_set>
namespace ck {
namespace host {
namespace device_gemm_multiple_d {
// Builds the GemmSpecialization template argument: a padding letter (M/N/K)
// is appended for every dimension that is not an exact multiple of its
// per-block tile size; no padding yields the Default specialization.
std::string GetGemmSpec(const std::size_t m,
                        const std::size_t n,
                        const std::size_t k,
                        const std::size_t m_per_block,
                        const std::size_t n_per_block,
                        const std::size_t k_per_block)
{
    std::string padded_dims;
    if(integer_divide_ceil(m, m_per_block) * m_per_block != m)
        padded_dims += "M";
    if(integer_divide_ceil(n, n_per_block) * n_per_block != n)
        padded_dims += "N";
    if(integer_divide_ceil(k, k_per_block) * k_per_block != k)
        padded_dims += "K";
    if(padded_dims.empty())
        return "ck::tensor_operation::device::GemmSpecialization::Default";
    return "ck::tensor_operation::device::GemmSpecialization::" + padded_dims + "Padding";
}
// Launch grid size: one workgroup per (m_per_block x n_per_block) output tile.
std::size_t GetGridSize(const std::size_t m,
                        const std::size_t n,
                        const std::size_t m_per_block,
                        const std::size_t n_per_block)
{
    const auto tiles_m = integer_divide_ceil(m, m_per_block);
    const auto tiles_n = integer_divide_ceil(n, n_per_block);
    return tiles_m * tiles_n;
}
// GPU targets with XDL (matrix core) support; instances are only emitted
// for these architectures.
const std::unordered_set<std::string>& get_xdlop_archs()
{
    static const std::unordered_set<std::string> archs = {
        "gfx90a", "gfx908", "gfx940", "gfx942"};
    return archs;
}
std::vector<std::string> Problem::GetInstances(const std::string& arch) const
{
std::vector<std::string> instances;
const bool quantize = ADataType == DataType::Int8 and BDataType == DataType::Int8;
if(get_xdlop_archs().find(arch) != get_xdlop_archs().end())
{
ck::host::instance::gemm_add_add_fastgelu_instances all_instances{};
if(TransA and TransB)
instances = all_instances.get_col_col_instances(quantize);
else if(TransA and not TransB)
instances = all_instances.get_col_row_instances(quantize);
else if(not TransA and not TransB)
instances = all_instances.get_row_row_instances(quantize);
else
instances = all_instances.get_row_col_instances(quantize);
}
return instances;
}
// Renders a "ck::Tuple<...>" of gemm layout tags; each true entry becomes
// ColumnMajor and each false entry RowMajor, separated by ", ".
std::string MakeLayoutTuple(const std::vector<bool>& layouts)
{
    std::string result = "ck::Tuple<";
    for(std::size_t i = 0; i < layouts.size(); ++i)
    {
        if(i != 0)
            result += ", ";
        result += layouts[i] ? "ck::tensor_layout::gemm::ColumnMajor"
                             : "ck::tensor_layout::gemm::RowMajor";
    }
    return result + ">";
}
std::string MakeTypeTuple(const std::vector<DataType>& types)
{
std::string type_tuple = "ck::Tuple<";
auto it = types.begin();
while(it != types.end())
{
type_tuple += ToString(*it);
it = std::next(it);
if(it != types.end())
type_tuple += ", ";
}
return type_tuple + ">";
}
// Builds the idx-th Solution for the given arch: splits the instance
// template string into whitespace-separated template arguments, substitutes
// the problem's D/E types, elementwise operators and the derived
// GemmSpecialization, and reassembles the full template instantiation.
// Returns a Solution with an empty template_str when the instance is not
// applicable (v2 pipeline with K not divisible by KPerBlock).
Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
{
    auto template_str = GetInstances(arch).at(idx);
    // The generator emits each instance with space-separated template
    // parameters, so a whitespace tokenize recovers the parameter list.
    std::istringstream iss(template_str);
    std::vector<std::string> params(std::istream_iterator<std::string>{iss},
                                    std::istream_iterator<std::string>());
    if(ADataType == DataType::Int8 and BDataType == DataType::Int8)
    {
        // Change CBlockTransfer ScalarPerVector if Ds contains other types
        // (params.size() - 3 is that parameter's position from the end).
        // NOTE(review): these branches are not mutually exclusive — when E/Ds
        // mix Half with Float/Int32 the later branch wins; confirm intended.
        if(EDataType == DataType::Half or std::any_of(DsDataType.begin(),
                                                      DsDataType.end(),
                                                      [](auto t) { return t == DataType::Half; }))
        {
            params[params.size() - 3] = "8";
        }
        if(EDataType == DataType::Float or std::any_of(DsDataType.begin(),
                                                       DsDataType.end(),
                                                       [](auto t) { return t == DataType::Float; }))
        {
            params[params.size() - 3] = "4";
        }
        if(EDataType == DataType::Int32 or std::any_of(DsDataType.begin(),
                                                       DsDataType.end(),
                                                       [](auto t) { return t == DataType::Int32; }))
        {
            params[params.size() - 3] = "4";
        }
    }
    // Patch in the problem's operators, D-tensor layouts/types and E type.
    params[a_elementwise_op_idx] = AElementOp;
    params[b_elementwise_op_idx] = BElementOp;
    params[ds_layout_idx] = MakeLayoutTuple(DsTrans);
    params[ds_data_type_idx] = MakeTypeTuple(DsDataType);
    params[ds_elementwise_op_idx] = CDEElementOp;
    params[e_data_type_idx] = ToString(EDataType);
    // Tile/block sizes are read back out of the instance's own parameters.
    auto block_size_str = params[block_size_idx];
    auto m_per_block_str = params[m_per_block_idx];
    auto n_per_block_str = params[n_per_block_idx];
    auto k_per_block_str = params[k_per_block_idx];
    const std::size_t block_size = std::stoi(block_size_str);
    const std::size_t m_per_block = std::stoi(m_per_block_str);
    const std::size_t n_per_block = std::stoi(n_per_block_str);
    const std::size_t k_per_block = std::stoi(k_per_block_str);
    const std::size_t grid_size = GetGridSize(M, N, m_per_block, n_per_block);
    params[gemm_spec_idx] = GetGemmSpec(M, N, K, m_per_block, n_per_block, k_per_block);
    // Re-join as "ClassName< p1, p2, ... >": the first token is the device
    // op class name, the rest are its template arguments.
    std::string str = std::accumulate(
        params.begin() + 1,
        params.end(),
        std::string{},
        [](const std::string& a, const std::string& b) { return a.empty() ? b : a + ", " + b; });
    str = params.front() + "< " + str + ">";
    // v2-pipeline instances (last parameter) require K to be a multiple of
    // KPerBlock; signal "not applicable" with an empty template string.
    if(params.back().find("v2") != std::string::npos and K % k_per_block != 0)
        str = "";
    return Solution{str, block_size, grid_size};
}
std::string Problem::GetIncludeHeader() const
{
return ck::host::instance::gemm_add_add_fastgelu_instances{}.get_include_header();
}
std::vector<Solution> Problem::GetSolutions(const std::string& arch) const
{
std::vector<Solution> solutions;
const std::size_t num_instances = GetInstances(arch).size();
for(std::size_t i = 0; i < num_instances; ++i)
{
auto solution = MakeSolution(i, arch);
if(solution.template_str != "")
solutions.push_back(solution);
}
return solutions;
}
} // namespace device_gemm_multiple_d
} // namespace host
} // namespace ck
# str.format() template for a generated instance-registry header that carries
# both the regular and the int8-quantized instance lists. Doubled braces
# ({{ }}) escape literal C++ braces; single-brace fields are filled in by
# get_device_gemm_multiple_d_file().
out_file_with_quant = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
namespace ck {{
namespace host {{
namespace instance {{
struct {op_name}_instances
{{
static inline std::vector<std::string> {col_row_name} =
{{
{col_row_instances}
}};
static inline std::vector<std::string> {col_col_name} =
{{
{col_col_instances}
}};
static inline std::vector<std::string> {row_row_name} =
{{
{row_row_instances}
}};
static inline std::vector<std::string> {row_col_name} =
{{
{row_col_instances}
}};
static inline std::vector<std::string> {int8_col_row_name} =
{{
{int8_col_row_instances}
}};
static inline std::vector<std::string> {int8_col_col_name} =
{{
{int8_col_col_instances}
}};
static inline std::vector<std::string> {int8_row_row_name} =
{{
{int8_row_row_instances}
}};
static inline std::vector<std::string> {int8_row_col_name} =
{{
{int8_row_col_instances}
}};
static auto get_col_row_instances(const bool quantize)
{{
return quantize ? {int8_col_row_name} :
{col_row_name};
}}
static auto get_col_col_instances(const bool quantize)
{{
return quantize ? {int8_col_col_name} :
{col_col_name};
}}
static auto get_row_row_instances(const bool quantize)
{{
return quantize ? {int8_row_row_name} :
{row_row_name};
}}
static auto get_row_col_instances(const bool quantize)
{{
return quantize ? {int8_row_col_name} :
{row_col_name};
}}
static auto get_include_header()
{{
return "{include_header}";
}}
}};
}} // namespace instance
}} // namespace host
}} // namespace ck
"""
# str.format() template for a generated instance-registry header with a single
# (non-quantized) instance list. Doubled braces ({{ }}) escape literal C++
# braces; single-brace fields are filled in by
# get_device_gemm_softmax_gemm_file().
out_file_no_quant = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
namespace ck {{
namespace host {{
namespace instance {{
struct {op_name}_instances
{{
static inline std::vector<std::string> {instances_name} =
{{
{instances}
}};
static auto get_instances()
{{
return {instances_name};
}}
static auto get_include_header()
{{
return "{include_header}";
}}
}};
}} // namespace instance
}} // namespace host
}} // namespace ck
"""
def get_device_gemm_multiple_d_file(op_name,
                                    col_row_name,
                                    col_row_instances,
                                    col_col_name,
                                    col_col_instances,
                                    row_row_name,
                                    row_row_instances,
                                    row_col_name,
                                    row_col_instances,
                                    int8_col_row_name,
                                    int8_col_row_instances,
                                    int8_col_col_name,
                                    int8_col_col_instances,
                                    int8_row_row_name,
                                    int8_row_row_instances,
                                    int8_row_col_name,
                                    int8_row_col_instances,
                                    include_header):
    """Render the quantization-aware instance header from out_file_with_quant.

    Each argument maps one-to-one onto the template's format fields.
    """
    fields = dict(op_name=op_name,
                  col_row_name=col_row_name,
                  col_row_instances=col_row_instances,
                  col_col_name=col_col_name,
                  col_col_instances=col_col_instances,
                  row_row_name=row_row_name,
                  row_row_instances=row_row_instances,
                  row_col_name=row_col_name,
                  row_col_instances=row_col_instances,
                  int8_col_row_name=int8_col_row_name,
                  int8_col_row_instances=int8_col_row_instances,
                  int8_col_col_name=int8_col_col_name,
                  int8_col_col_instances=int8_col_col_instances,
                  int8_row_row_name=int8_row_row_name,
                  int8_row_row_instances=int8_row_row_instances,
                  int8_row_col_name=int8_row_col_name,
                  int8_row_col_instances=int8_row_col_instances,
                  include_header=include_header)
    return out_file_with_quant.format(**fields)
def get_device_gemm_softmax_gemm_file(op_name,
                                      instances_name,
                                      instances,
                                      include_header):
    """Render the plain (non-quantized) instance header from out_file_no_quant."""
    fields = {"op_name": op_name,
              "instances_name": instances_name,
              "instances": instances,
              "include_header": include_header}
    return out_file_no_quant.format(**fields)
import argparse, re, json, os, sys, file_templates
def strip_sequences(str):
    """Collapse whitespace inside S<...> integer sequences and rewrite the
    S< prefix to ck::Sequence<.

    Note: the parameter shadows the built-in ``str``; the name is kept for
    interface compatibility with existing callers.
    """
    for seq in re.findall(r'S<\s*\d+(?:,\s*\d+)*>', str):
        str = str.replace(seq, seq.replace(' ', ''))
    return str.replace('S<', "ck::Sequence<")
def remove_commas_and_brackets(string):
    """Strip every ',', '<' and '>' from ``string`` except inside
    ck::Sequence<...> spans, which are preserved verbatim.

    Protected spans are temporarily encoded with placeholder characters
    (',' -> '|', '<' -> '%', '>' -> '$') and restored afterwards.
    """
    def encode(span):
        return span.replace(',', '|').replace('<', '%').replace('>', '$')

    protected = re.findall(r'ck::Sequence<.*?>', string)
    for span in protected:
        string = string.replace(span, encode(span))
    string = string.replace(',', '').replace('<', '').replace('>', '')
    for span in protected:
        string = string.replace(encode(span), span)
    return string
def get_int8_instances(src, file, template_name):
    """Parse the int8 quantization instance header at ``src + file`` and
    collect instance template strings, bucketed by A/B layout.

    Each source line containing ``template_name`` is expanded into three
    pipeline/scheduler variants (v1/Default, v1/Interwave, v2/Default).
    Returns a dict with row_row/row_col/col_col/col_row instance lists plus
    the matching *_name keys taken from the header's ``using`` aliases.
    """
    # Textual substitutions applied to every instance line.
    aliases = {"Empty_Tuple": "ck::Tuple<>",
               "Row": "ck::tensor_layout::gemm::RowMajor",
               "Col": "ck::tensor_layout::gemm::ColumnMajor",
               "OutElementOp": "PassThrough"}
    instances = {"row_row": [],
                 "row_col": [],
                 "col_col": [],
                 "col_row": [],
                 "row_row_name": [],
                 "row_col_name": [],
                 "col_col_name": [],
                 "col_row_name": []}
    path = src + file
    with open(path) as f:
        for line in f:
            if "impl" in line:
                # NOTE(review): include_header is computed but never used or
                # returned by this function — looks like dead code.
                include_header = line.replace("#include \"", "").replace("\"", "").replace("\n", "")
            elif "using" in line:
                # The mk/kn/km/nk letters in the alias name encode the A/B
                # layouts (e.g. mk_kn -> row-major A, row-major B).
                if bool(re.search(".*mk.*kn.*", line)):
                    instances["row_row_name"] = re.search("device_gemm.*instance", line).group()
                elif bool(re.search(".*mk.*nk.*", line)):
                    instances["row_col_name"] = re.search("device_gemm.*instance", line).group()
                elif bool(re.search(".*km.*nk.*", line)):
                    instances["col_col_name"] = re.search("device_gemm.*instance", line).group()
                elif bool(re.search(".*km.*kn.*", line)):
                    instances["col_row_name"] = re.search("device_gemm.*instance", line).group()
            elif template_name in line:
                # Turn all whitespace into single spaces
                new_line = " ".join(line.split())
                # Remove whitespace from S<*>
                new_line = strip_sequences(new_line)
                new_line = remove_commas_and_brackets(new_line)
                # NOTE(review): last_char is assigned but never used.
                last_char = "\n"
                if new_line[-1] == ",":
                    last_char = ",\n"
                    new_line = new_line[:-1]
                new_line = ' "ck::tensor_operation::device::' + new_line + '",'
                versions = []
                for key in aliases:
                    new_line = new_line.replace(key, aliases[key])
                # Expand each instance into its three supported variants.
                versions.append(new_line.replace("GemmPipeline", "ck::PipelineVersion::v1").replace("GemmLoopScheduler", "ck::LoopScheduler::Default"))
                versions.append(new_line.replace("GemmPipeline", "ck::PipelineVersion::v1").replace("GemmLoopScheduler", "ck::LoopScheduler::Interwave"))
                versions.append(new_line.replace("GemmPipeline", "ck::PipelineVersion::v2").replace("GemmLoopScheduler", "ck::LoopScheduler::Default"))
                # Bucket by the (already alias-substituted) A/B layout pair.
                if "ck::tensor_layout::gemm::RowMajor ck::tensor_layout::gemm::RowMajor" in new_line:
                    instances["row_row"].extend(versions)
                elif "ck::tensor_layout::gemm::RowMajor ck::tensor_layout::gemm::ColumnMajor" in new_line:
                    instances["row_col"].extend(versions)
                elif "ck::tensor_layout::gemm::ColumnMajor ck::tensor_layout::gemm::ColumnMajor" in new_line:
                    instances["col_col"].extend(versions)
                elif "ck::tensor_layout::gemm::ColumnMajor ck::tensor_layout::gemm::RowMajor" in new_line:
                    instances["col_row"].extend(versions)
    # Drop the trailing comma from the last entry of each bucket so the lists
    # can be pasted into a C++ initializer.
    # NOTE(review): raises IndexError if any bucket is empty — assumes the
    # parsed header defines at least one instance per layout pair.
    instances["row_row"][-1] = instances["row_row"][-1][:-1]
    instances["row_col"][-1] = instances["row_col"][-1][:-1]
    instances["col_col"][-1] = instances["col_col"][-1][:-1]
    instances["col_row"][-1] = instances["col_row"][-1][:-1]
    return instances
def parse_instances(source, out_dir):
    """Parse CK gemm_add_add_fastgelu instance sources under ``source`` and
    write the generated registry header into ``out_dir``.

    This function was a byte-for-byte duplicate of
    parse_device_gemm_multiple_d_instances; it is kept as a thin wrapper so
    any existing callers of the old name keep working.
    """
    return parse_device_gemm_multiple_d_instances(source, out_dir)
def parse_device_gemm_multiple_d_instances(source, out_dir):
    """Walk ``source`` for op directories listed in device_ops, parse their
    .cpp instance files, and write a generated registry header (one per op)
    into ``out_dir`` via file_templates.get_device_gemm_multiple_d_file.

    Layout buckets (row/col x row/col) are inferred from the mk/kn/km/nk
    letters in each file name; int8 instances come from a separate
    quantization header parsed by get_int8_instances.
    """
    # Textual substitutions applied to every instance line.
    aliases = {"F16_F16_Tuple": "ck::Tuple<F16,F16>",
               "Row_Row_Tuple": "ck::Tuple<Row,Row>",
               "Empty_Tuple": "ck::Tuple<>",
               "LoopScheduler": "ck::LoopScheduler",
               "PipelineVersion": "ck::PipelineVersion",
               "Row": "ck::tensor_layout::gemm::RowMajor",
               "Col": "ck::tensor_layout::gemm::ColumnMajor",
               "F16": "ck::half_t",
               "F32": "float",
               "OutElementOp": "PassThrough"}
    # Maps op directory name -> device op class name to look for.
    device_ops = {"gemm_add_add_fastgelu": "DeviceGemmMultipleD_Xdl_CShuffle",
                  #"batched_gemm_softmax_gemm": "DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle"
                  }
    for root_, dirs_, files_ in os.walk(source):
        for dir in dirs_:
            op_name = os.path.split(dir)[-1]
            if op_name not in device_ops:
                continue
            col_row_name = ""
            col_col_name = ""
            row_row_name = ""
            row_col_name = ""
            row_row_instances = []
            col_row_instances = []
            row_col_instances = []
            col_col_instances = []
            for root, dirs, files in os.walk(os.path.join(root_, dir)):
                for file in files:
                    if not file.endswith(".cpp"):
                        continue;
                    file_name = os.path.split(file)[-1]
                    # mk/kn/km/nk in the file name encode the A/B layouts.
                    is_row_row = bool(re.search(".*mk.*kn.*", file_name))
                    is_col_row = bool(re.search(".*km.*kn.*", file_name))
                    is_row_col = bool(re.search(".*mk.*nk.*", file_name))
                    is_col_col = bool(re.search(".*km.*nk.*", file_name))
                    if is_row_row:
                        row_row_name = file_name[:-4]
                    if is_col_row:
                        col_row_name = file_name[:-4]
                    if is_row_col:
                        row_col_name = file_name[:-4]
                    if is_col_col:
                        col_col_name = file_name[:-4]
                    instances_list = []
                    template_name = device_ops[op_name]
                    include_header = ""
                    with open(os.path.join(root, file)) as f:
                        for line in f:
                            if "impl" in line:
                                # NOTE(review): the final write below uses the
                                # include_header from the last file processed.
                                include_header = line.replace("#include \"", "").replace("\"", "").replace("\n", "")
                            elif template_name in line:
                                # Turn all whitespace into single spaces
                                new_line = " ".join(line.split())
                                # Remove whitespace from S<*>
                                new_line = strip_sequences(new_line)
                                new_line = remove_commas_and_brackets(new_line)
                                # NOTE(review): last_char is assigned but never used.
                                last_char = "\n"
                                if new_line[-1] == ",":
                                    last_char = ",\n"
                                    new_line = new_line[:-1]
                                new_line = ' "ck::tensor_operation::device::' + new_line + '",'
                                for key in aliases:
                                    new_line = new_line.replace(key, aliases[key])
                                instances_list.append(new_line)
                    # Drop the trailing comma from the last collected instance.
                    instances_list[-1] = instances_list[-1][:-1]
                    if is_row_row:
                        row_row_instances = instances_list
                    if is_col_row:
                        col_row_instances = instances_list
                    if is_row_col:
                        row_col_instances = instances_list
                    if is_col_col:
                        col_col_instances = instances_list
            out_file_name = op_name + "_instances.hpp"
            if not os.path.exists(out_dir):
                os.mkdir(out_dir)
            # int8 (quantized) instances live in a dedicated header.
            int8_file = "/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_instance.hpp"
            int8_instances = get_int8_instances(source, int8_file, "DeviceGemmMultipleD_Xdl_CShuffle")
            with open(os.path.join(out_dir, out_file_name), "w+") as f:
                f.write(file_templates.get_device_gemm_multiple_d_file(
                    op_name,
                    col_row_name,
                    "\n".join(col_row_instances),
                    col_col_name,
                    "\n".join(col_col_instances),
                    row_row_name,
                    "\n".join(row_row_instances),
                    row_col_name,
                    "\n".join(row_col_instances),
                    int8_instances["col_row_name"],
                    "\n".join(int8_instances["col_row"]),
                    int8_instances["col_col_name"],
                    "\n".join(int8_instances["col_col"]),
                    int8_instances["row_row_name"],
                    "\n".join(int8_instances["row_row"]),
                    int8_instances["row_col_name"],
                    "\n".join(int8_instances["row_col"]),
                    include_header))
def parse_param_names(file):
    """Extract template-parameter names from a '//#...|' comment banner.

    Banner lines (matching ``//#+`` with '|'-separated columns) are merged
    column-wise; the first non-banner line after them supplies the template
    name (text before '<') as element 0. The file is rewound to the start
    before returning, and the trailing column is dropped.
    """
    names = []
    for line in file:
        if re.search(r"\s*//#+", line):
            columns = [part.strip() for part in line.split('|')]
            if not names:
                names = [""] * len(columns)
            names = [acc + col for acc, col in zip(names, columns)]
        elif names:
            names[0] = line.split('<')[0].strip()
            break
    file.seek(0)
    return names[:-1]
def parse_device_batched_gemm_softmax_gemm_instances(source, out_dir):
    """Walk ``source`` for batched_gemm_softmax_gemm instance sources and
    write a generated registry header into ``out_dir`` via
    file_templates.get_device_gemm_softmax_gemm_file.

    Each source instance line is expanded into two variants (masking true and
    false). Directories whose name contains "permute" are skipped.
    """
    # Textual substitutions applied to every instance line.
    aliases = {"Row": "ck::tensor_layout::gemm::RowMajor",
               "Col": "ck::tensor_layout::gemm::ColumnMajor",
               "F16": "ck::half_t",
               "F32": "float",
               "PassThrough": "ck::tensor_operation::element_wise::PassThrough",
               "Scale": "ck::tensor_operation::element_wise::Scale",
               "GemmPadded": "ck::tensor_operation::device::GemmSpecialization::MNKOPadding",
               "GemmDefault": "ck::tensor_operation::device::GemmSpecialization::Default"}
    # Maps op directory name -> device op class name to look for.
    device_ops = {"batched_gemm_softmax_gemm": "DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle"
                  }
    for root_, dirs_, files_ in os.walk(source):
        for dir in dirs_:
            op_name = os.path.split(dir)[-1]
            if "permute" in op_name or op_name not in device_ops:
                continue
            for root, dirs, files in os.walk(os.path.join(root_, dir)):
                for file in files:
                    if not file.endswith(".cpp"):
                        continue;
                    file_name = os.path.split(file)[-1]
                    # NOTE(review): instances_name/instances_list are reset per
                    # file — the write below only reflects the last .cpp seen.
                    instances_name = file_name[:-4]
                    instances_list = []
                    template_name = device_ops[op_name]
                    include_header = ""
                    with open(os.path.join(root, file)) as f:
                        param_names = parse_param_names(f)
                        # for i in range(len(param_names)):
                        # print(f"{i}: {param_names[i]}")
                        for line in f:
                            if "impl" in line:
                                include_header = line.replace("#include \"", "").replace("\"", "").replace("\n", "")
                            elif template_name in line:
                                # Turn all whitespace into single spaces
                                new_line = " ".join(line.split())
                                # Remove whitespace from S<*>
                                new_line = strip_sequences(new_line)
                                new_line = remove_commas_and_brackets(new_line)
                                # NOTE(review): last_char is assigned but never used.
                                last_char = "\n"
                                if new_line[-1] == ",":
                                    last_char = ",\n"
                                    new_line = new_line[:-1]
                                new_line = ' "ck::tensor_operation::device::' + new_line + '",'
                                for key in aliases:
                                    new_line = new_line.replace(key, aliases[key])
                                # Emit one masked and one unmasked variant.
                                masking = new_line.replace("Masking", "true")
                                no_masking = new_line.replace("Masking", "false")
                                instances_list.append(masking)
                                instances_list.append(no_masking)
            out_file_name = op_name + "_instances.hpp"
            if not os.path.exists(out_dir):
                os.mkdir(out_dir)
            with open(os.path.join(out_dir, out_file_name), "w+") as f:
                f.write(file_templates.get_device_gemm_softmax_gemm_file(
                    op_name,
                    instances_name,
                    "\n".join(instances_list),
                    include_header))
def run(args):
    """Entry point: ``args`` must be ``[source_dir, out_dir]``.

    Parses the CK instance sources under source_dir and writes the generated
    registry headers into out_dir. Exits with a usage message instead of an
    opaque IndexError when the two required paths are missing.
    """
    if len(args) < 2:
        raise SystemExit("usage: parse_instances.py <source_dir> <out_dir>")
    parse_device_gemm_multiple_d_instances(args[0], args[1])
    parse_device_batched_gemm_softmax_gemm_instances(args[0], args[1])
# Script entry point: forward CLI arguments (source dir, output dir) to run().
if __name__ == '__main__':
    run(sys.argv[1:])
\ No newline at end of file
......@@ -120,36 +120,40 @@ function(add_gtest_executable TEST_NAME)
set(result ${result} PARENT_SCOPE)
endfunction()
# NOTE(review): this region looks like unresolved merge/diff residue — the
# first if(GPU_TARGETS MATCHES "gfx11") below (just before the
# CK_BUILD_JIT_LIB check) has no matching endif(), so the if()/endif()
# nesting is unbalanced. Verify against the intended post-merge file.
add_subdirectory(magic_number_division)
add_subdirectory(space_filling_curve)
add_subdirectory(conv_util)
add_subdirectory(reference_conv_fwd)
add_subdirectory(gemm)
add_subdirectory(gemm_layernorm)
add_subdirectory(gemm_split_k)
add_subdirectory(gemm_reduce)
add_subdirectory(batched_gemm)
add_subdirectory(batched_gemm_reduce)
add_subdirectory(batched_gemm_gemm)
add_subdirectory(batched_gemm_softmax_gemm)
add_subdirectory(batched_gemm_softmax_gemm_permute)
add_subdirectory(grouped_gemm)
add_subdirectory(reduce)
add_subdirectory(convnd_fwd)
add_subdirectory(convnd_bwd_data)
add_subdirectory(grouped_convnd_fwd)
add_subdirectory(grouped_convnd_bwd_weight)
add_subdirectory(block_to_ctile_map)
add_subdirectory(softmax)
add_subdirectory(normalization)
add_subdirectory(data_type)
add_subdirectory(elementwise_normalization)
add_subdirectory(batchnorm)
add_subdirectory(contraction)
add_subdirectory(pool)
add_subdirectory(batched_gemm_multi_d)
add_subdirectory(grouped_convnd_bwd_data)
add_subdirectory(conv_tensor_rearrange)
# WMMA tests require a gfx11 target.
if(GPU_TARGETS MATCHES "gfx11")
add_subdirectory(wmma_op)
# When building the JIT library only its tests are added; otherwise the full
# standalone unit-test suite below is configured.
if(CK_BUILD_JIT_LIB)
add_subdirectory(jit_library)
else()
add_subdirectory(magic_number_division)
add_subdirectory(space_filling_curve)
add_subdirectory(conv_util)
add_subdirectory(reference_conv_fwd)
add_subdirectory(gemm)
add_subdirectory(gemm_layernorm)
add_subdirectory(gemm_split_k)
add_subdirectory(gemm_reduce)
add_subdirectory(batched_gemm)
add_subdirectory(batched_gemm_reduce)
add_subdirectory(batched_gemm_gemm)
add_subdirectory(batched_gemm_softmax_gemm)
add_subdirectory(batched_gemm_softmax_gemm_permute)
add_subdirectory(grouped_gemm)
add_subdirectory(reduce)
add_subdirectory(convnd_fwd)
add_subdirectory(convnd_bwd_data)
add_subdirectory(grouped_convnd_fwd)
add_subdirectory(grouped_convnd_bwd_weight)
add_subdirectory(block_to_ctile_map)
add_subdirectory(softmax)
add_subdirectory(normalization)
add_subdirectory(data_type)
add_subdirectory(elementwise_normalization)
add_subdirectory(batchnorm)
add_subdirectory(contraction)
add_subdirectory(pool)
add_subdirectory(batched_gemm_multi_d)
add_subdirectory(grouped_convnd_bwd_data)
add_subdirectory(conv_tensor_rearrange)
if(GPU_TARGETS MATCHES "gfx11")
add_subdirectory(wmma_op)
endif()
endif()
# JIT library smoke test: links against jit_library and uses its headers.
add_test_executable(test_jit_library jit_library.cpp)
add_dependencies(test_jit_library jit_library)
target_include_directories(test_jit_library PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/../../library/src/jit_library/include>)
target_link_libraries(test_jit_library PRIVATE jit_library ck_headers)
#include "ck/host/device_gemm_multiple_d.hpp"
#include <iostream>
// Smoke test for device_gemm_multiple_d::Problem on a 256x256x256 fp16
// problem: checks the include header, the number of solutions, the exact
// first template instantiation string, and its launch dimensions.
bool test_Problem()
{
    // NOTE(review): field meanings are inferred from the Problem aggregate's
    // initialization order (M, N, K, transposes, D info, data types, element
    // ops) — confirm against ck/host/device_gemm_multiple_d.hpp.
    auto problem = ck::host::device_gemm_multiple_d::Problem{
        256,
        256,
        256,
        false,
        true,
        false,
        {},
        ck::host::DataType::Half,
        ck::host::DataType::Half,
        ck::host::DataType::Half,
        {},
        "ck::tensor_operation::element_wise::Passthrough",
        "ck::tensor_operation::element_wise::Passthrough",
        "ck::tensor_operation::element_wise::Passthrough"};
    const auto include_header = problem.GetIncludeHeader();
    const auto solutions = problem.GetSolutions("gfx90a");
    const auto& solution = solutions.at(0);
    const auto template_str = solution.template_str;
    const auto grid_size = solution.grid_size;
    const auto block_size = solution.block_size;
    bool pass = true;
    pass &= include_header ==
        "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp";
    // Expected instance count for this layout/datatype combination.
    pass &= solutions.size() == 42;
    // The first generated instance, fully expanded.
    pass &= template_str ==
        "ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle< "
        "ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, ck::Tuple<>, "
        "ck::tensor_layout::gemm::RowMajor, ck::half_t, ck::half_t, float, float, ck::Tuple<>, "
        "ck::half_t, ck::tensor_operation::element_wise::Passthrough, "
        "ck::tensor_operation::element_wise::Passthrough, "
        "ck::tensor_operation::element_wise::Passthrough, "
        "ck::tensor_operation::device::GemmSpecialization::Default, 1, 256, 256, 128, 32, 8, "
        "8, 32, 32, 4, 2, ck::Sequence<4,64,1>, ck::Sequence<1,0,2>, ck::Sequence<1,0,2>, 2, "
        "8, 8, 1, ck::Sequence<4,64,1>, ck::Sequence<1,0,2>, ck::Sequence<1,0,2>, 2, 8, 8, 1, "
        "1, 1, ck::Sequence<1,32,1,8>, 8, ck::LoopScheduler::Default, ck::PipelineVersion::v1>";
    // 256x256 output with 256x128 tiles -> 1 * 2 = 2 workgroups.
    pass &= grid_size == 2;
    pass &= block_size == 256;
    return pass;
}
// Checks that the generated GemmSpecialization reflects tile divisibility:
// 255 (indivisible by the tile sizes) must select MNKPadding, while 256
// (divisible) must select Default.
bool test_GetGemmSpec()
{
    bool pass = true;
    {
        // PadMNK
        auto problem = ck::host::device_gemm_multiple_d::Problem{
            255,
            255,
            255,
            false,
            true,
            false,
            {},
            ck::host::DataType::Half,
            ck::host::DataType::Half,
            ck::host::DataType::Half,
            {},
            "ck::tensor_operation::element_wise::Passthrough",
            "ck::tensor_operation::element_wise::Passthrough",
            "ck::tensor_operation::element_wise::Passthrough"};
        const auto solutions = problem.GetSolutions("gfx90a");
        const auto& solution = solutions.at(0);
        const auto template_str = solution.template_str;
        pass &= template_str.find("GemmSpecialization::MNKPadding") != std::string::npos;
    }
    {
        // Default
        auto problem = ck::host::device_gemm_multiple_d::Problem{
            256,
            256,
            256,
            false,
            true,
            false,
            {},
            ck::host::DataType::Half,
            ck::host::DataType::Half,
            ck::host::DataType::Half,
            {},
            "ck::tensor_operation::element_wise::Passthrough",
            "ck::tensor_operation::element_wise::Passthrough",
            "ck::tensor_operation::element_wise::Passthrough"};
        const auto solutions = problem.GetSolutions("gfx90a");
        const auto& solution = solutions.at(0);
        const auto template_str = solution.template_str;
        pass &= template_str.find("GemmSpecialization::Default") != std::string::npos;
    }
    return pass;
}
bool test_GetInstances()
{
bool pass = true;
{
// Col Col Fp16
auto problem = ck::host::device_gemm_multiple_d::Problem{
256,
256,
256,
true,
true,
false,
{},
ck::host::DataType::Half,
ck::host::DataType::Half,
ck::host::DataType::Half,
{},
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough"};
pass &= problem.GetSolutions("gfx90a").size() == 51;
}
{
// Col Row Fp16
auto problem = ck::host::device_gemm_multiple_d::Problem{
256,
256,
256,
true,
false,
false,
{},
ck::host::DataType::Half,
ck::host::DataType::Half,
ck::host::DataType::Half,
{},
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough"};
pass &= problem.GetSolutions("gfx90a").size() == 51;
}
{
// Row Col Fp16
auto problem = ck::host::device_gemm_multiple_d::Problem{
256,
256,
256,
false,
true,
false,
{},
ck::host::DataType::Half,
ck::host::DataType::Half,
ck::host::DataType::Half,
{},
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough"};
pass &= problem.GetSolutions("gfx90a").size() == 42;
}
{
// Row Row Int8
auto problem = ck::host::device_gemm_multiple_d::Problem{
256,
256,
256,
false,
false,
false,
{},
ck::host::DataType::Int8,
ck::host::DataType::Int8,
ck::host::DataType::Half,
{},
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough"};
pass &= problem.GetSolutions("gfx90a").size() == 48;
}
{
// Col Col Int8
auto problem = ck::host::device_gemm_multiple_d::Problem{
256,
256,
256,
true,
true,
false,
{},
ck::host::DataType::Int8,
ck::host::DataType::Int8,
ck::host::DataType::Half,
{},
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough"};
pass &= problem.GetSolutions("gfx90a").size() == 48;
}
{
// Col Row Int8
auto problem = ck::host::device_gemm_multiple_d::Problem{
256,
256,
256,
true,
false,
false,
{},
ck::host::DataType::Int8,
ck::host::DataType::Int8,
ck::host::DataType::Half,
{},
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough"};
pass &= problem.GetSolutions("gfx90a").size() == 48;
}
{
// Row Col Int8
auto problem = ck::host::device_gemm_multiple_d::Problem{
256,
256,
256,
false,
true,
false,
{},
ck::host::DataType::Int8,
ck::host::DataType::Int8,
ck::host::DataType::Half,
{},
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough"};
pass &= problem.GetSolutions("gfx90a").size() == 39;
}
{
// Row Row Int8
auto problem = ck::host::device_gemm_multiple_d::Problem{
256,
256,
256,
false,
false,
false,
{},
ck::host::DataType::Int8,
ck::host::DataType::Int8,
ck::host::DataType::Half,
{},
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough"};
pass &= problem.GetSolutions("gfx90a").size() == 48;
}
return pass;
}
// Checks that the D-tensor layout tuple rendered into the instance template
// string matches the layouts requested on the problem. Returns true when
// both cases pass.
bool test_MakeLayoutsTuple()
{
    bool ok = true;
    {
        // Empty Tuple: no D layouts supplied, so the generated template must
        // contain an empty ck::Tuple for the D layouts.
        auto problem = ck::host::device_gemm_multiple_d::Problem{
            256,
            256,
            256,
            false,
            false,
            false,
            {},
            ck::host::DataType::Half,
            ck::host::DataType::Half,
            ck::host::DataType::Half,
            {ck::host::DataType::Half},
            "ck::tensor_operation::element_wise::Passthrough",
            "ck::tensor_operation::element_wise::Passthrough",
            "ck::tensor_operation::element_wise::Passthrough"};
        const auto generated = problem.GetSolutions("gfx90a").at(0).template_str;
        ok &= generated.find("ck::Tuple<>") != std::string::npos;
    }
    {
        // RowColRow Tuple: three D layouts (row-major, column-major,
        // row-major) must appear as the corresponding layout tags.
        auto problem = ck::host::device_gemm_multiple_d::Problem{
            256,
            256,
            256,
            false,
            false,
            false,
            {false, true, false},
            ck::host::DataType::Half,
            ck::host::DataType::Half,
            ck::host::DataType::Half,
            {ck::host::DataType::Half},
            "ck::tensor_operation::element_wise::Passthrough",
            "ck::tensor_operation::element_wise::Passthrough",
            "ck::tensor_operation::element_wise::Passthrough"};
        const auto generated = problem.GetSolutions("gfx90a").at(0).template_str;
        const char* expected =
            "ck::Tuple<ck::tensor_layout::gemm::RowMajor, "
            "ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor>";
        ok &= generated.find(expected) != std::string::npos;
    }
    return ok;
}
// Checks that the D-tensor data-type tuple rendered into the instance
// template string matches the Ds data types requested on the problem.
// Returns true when both cases pass.
bool test_MakeTypeTuple()
{
    bool ok = true;
    {
        // Empty Tuple: no Ds data types supplied, so the generated template
        // must contain an empty ck::Tuple for the D types.
        auto problem = ck::host::device_gemm_multiple_d::Problem{
            256,
            256,
            256,
            false,
            false,
            false,
            {true},
            ck::host::DataType::Half,
            ck::host::DataType::Half,
            ck::host::DataType::Half,
            {},
            "ck::tensor_operation::element_wise::Passthrough",
            "ck::tensor_operation::element_wise::Passthrough",
            "ck::tensor_operation::element_wise::Passthrough"};
        const auto generated = problem.GetSolutions("gfx90a").at(0).template_str;
        ok &= generated.find("ck::Tuple<>") != std::string::npos;
    }
    {
        // Half Int8 Tuple: the Ds types {Half, Int8} must be rendered as
        // their C++ type names in the tuple.
        auto problem = ck::host::device_gemm_multiple_d::Problem{
            256,
            256,
            256,
            false,
            false,
            false,
            {},
            ck::host::DataType::Half,
            ck::host::DataType::Half,
            ck::host::DataType::Half,
            {ck::host::DataType::Half, ck::host::DataType::Int8},
            "ck::tensor_operation::element_wise::Passthrough",
            "ck::tensor_operation::element_wise::Passthrough",
            "ck::tensor_operation::element_wise::Passthrough"};
        const auto generated = problem.GetSolutions("gfx90a").at(0).template_str;
        ok &= generated.find("ck::Tuple<ck::half_t, int8_t>") != std::string::npos;
    }
    return ok;
}
// Entry point: runs every test unconditionally (no short-circuiting, so all
// failures surface in a single run), prints an overall Pass/Fail line, and
// returns 0 on success or -1 on any failure.
int main()
{
    const bool results[] = {test_Problem(),
                            test_GetGemmSpec(),
                            test_GetInstances(),
                            test_MakeLayoutsTuple(),
                            test_MakeTypeTuple()};
    bool all_passed = true;
    for(const bool r : results)
        all_passed = all_passed && r;
    std::cout << "Test jit library: " << (all_passed ? "Pass" : "Fail") << std::endl;
    return all_passed ? 0 : -1;
}
/* NOTE(review): the following lines are web-page residue accidentally
   captured from the repository's browser view; they are not source code
   and would break compilation if left bare, so they are commented out:
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment
*/