Commit 6289e36f authored by Alan Turner's avatar Alan Turner
Browse files

Add int8 instances

parent 6dd246a6
......@@ -95,17 +95,18 @@ private:
auto GetInstances(const std::string& arch) const
{
std::vector<std::string> instances;
const bool quantize = ADataType == "int8_t" and BDataType == "int8_t";
if (get_xdlop_archs().find(arch) != get_xdlop_archs().end())
{
instance::gemm_add_add_fastgelu_instances all_instances{};
if(TransA and TransB)
instances = all_instances.get_col_col_instances();
instances = all_instances.get_col_col_instances(quantize);
else if(TransA and not TransB)
instances = all_instances.get_col_row_instances();
instances = all_instances.get_col_row_instances(quantize);
else if(not TransA and not TransB)
instances = all_instances.get_row_row_instances();
instances = all_instances.get_row_row_instances(quantize);
else
instances = all_instances.get_row_col_instances();
instances = all_instances.get_row_col_instances(quantize);
}
return instances;
}
......
......@@ -36,24 +36,48 @@ struct {op_name}_instances
{row_col_instances}
}};
static auto get_col_row_instances()
static inline std::vector<std::string> {int8_col_row_name} =
{{
return {col_row_name};
{int8_col_row_instances}
}};
static inline std::vector<std::string> {int8_col_col_name} =
{{
{int8_col_col_instances}
}};
static inline std::vector<std::string> {int8_row_row_name} =
{{
{int8_row_row_instances}
}};
static inline std::vector<std::string> {int8_row_col_name} =
{{
{int8_row_col_instances}
}};
static auto get_col_row_instances(const bool quantize)
{{
return quantize ? {int8_col_row_name} :
{col_row_name};
}}
static auto get_col_col_instances()
static auto get_col_col_instances(const bool quantize)
{{
return {col_col_name};
return quantize ? {int8_col_col_name} :
{col_col_name};
}}
static auto get_row_row_instances()
static auto get_row_row_instances(const bool quantize)
{{
return {row_row_name};
return quantize ? {int8_row_row_name} :
{row_row_name};
}}
static auto get_row_col_instances()
static auto get_row_col_instances(const bool quantize)
{{
return {row_col_name};
return quantize ? {int8_row_col_name} :
{row_col_name};
}}
static auto get_include_header()
......@@ -87,19 +111,83 @@ def remove_commas_and_brackets(string):
return string
def get_int8_instances(src, file, template_name):
aliases = {"Empty_Tuple": "ck::Tuple<>",
"Row": "ck::tensor_layout::gemm::RowMajor",
"Col": "ck::tensor_layout::gemm::ColumnMajor",
"OutElementOp": "PassThrough"}
instances = {"row_row": [],
"row_col": [],
"col_col": [],
"col_row": [],
"row_row_name": [],
"row_col_name": [],
"col_col_name": [],
"col_row_name": []}
path = src + file
with open(path) as f:
for line in f:
if "impl" in line:
include_header = line.replace("#include \"", "").replace("\"", "").replace("\n", "")
elif "using" in line:
if bool(re.search(".*mk.*kn.*", line)):
instances["row_row_name"] = re.search("device_gemm.*instance", line).group()
elif bool(re.search(".*mk.*nk.*", line)):
instances["row_col_name"] = re.search("device_gemm.*instance", line).group()
elif bool(re.search(".*km.*nk.*", line)):
instances["col_col_name"] = re.search("device_gemm.*instance", line).group()
elif bool(re.search(".*km.*kn.*", line)):
instances["col_row_name"] = re.search("device_gemm.*instance", line).group()
elif template_name in line:
# Turn all whitespace into single spaces
new_line = " ".join(line.split())
# Remove whitespace from S<*>
new_line = strip_sequences(new_line)
new_line = remove_commas_and_brackets(new_line)
last_char = "\n"
if new_line[-1] == ",":
last_char = ",\n"
new_line = new_line[:-1]
new_line = ' "ck::tensor_operation::device::' + new_line + '",'
versions = []
for key in aliases:
new_line = new_line.replace(key, aliases[key])
versions.append(new_line.replace("GemmPipeline", "ck:PipelineVersion::v1").replace("GemmLoopScheduler", "ck::LoopScheduler::Default"))
versions.append(new_line.replace("GemmPipeline", "ck:PipelineVersion::v1").replace("GemmLoopScheduler", "ck::LoopScheduler::Interwave"))
versions.append(new_line.replace("GemmPipeline", "ck:PipelineVersion::v2").replace("GemmLoopScheduler", "ck::LoopScheduler::Default"))
if "ck::tensor_layout::gemm::RowMajor ck::tensor_layout::gemm::RowMajor" in new_line:
instances["row_row"].extend(versions)
elif "ck::tensor_layout::gemm::RowMajor ck::tensor_layout::gemm::ColumnMajor" in new_line:
instances["row_col"].extend(versions)
elif "ck::tensor_layout::gemm::ColumnMajor ck::tensor_layout::gemm::ColumnMajor" in new_line:
instances["col_col"].extend(versions)
elif "ck::tensor_layout::gemm::ColumnMajor ck::tensor_layout::gemm::RowMajor" in new_line:
instances["col_row"].extend(versions)
instances["row_row"][-1] = instances["row_row"][-1][:-1]
instances["row_col"][-1] = instances["row_col"][-1][:-1]
instances["col_col"][-1] = instances["col_col"][-1][:-1]
instances["col_row"][-1] = instances["col_row"][-1][:-1]
return instances
def parse_instances(source):
out_dir = os.path.join(source, "../../../src/jit_library/solution_instances")
aliases = {"F16_F16_Tuple": "ck::Tuple<F16,F16>",
"Row_Row_Tuple": "ck::Tuple<Row,Row>",
"Empty_Tuple": "ck::Tuple<>",
"LoopScheduler": "ck::LoopScheduler",
"PipelineVersion": "ck::PipelineVersion",
"Row": "ck::tensor_layout::gemm::RowMajor",
"Col": "ck::tensor_layout::gemm::ColumnMajor",
"F16": "ck::half_t",
"F32": "float"}
"F32": "float",
"OutElementOp": "PassThrough"}
device_ops = {"gemm_add_add_fastgelu": "DeviceGemmMultipleD_Xdl_CShuffle",
#"batched_gemm_softmax_gemm": "DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle"
}
for root_, dirs_, files_ in os.walk(source):
for dir in dirs_:
op_name = os.path.split(dir)[-1]
......@@ -163,6 +251,8 @@ def parse_instances(source):
out_file_name = op_name + "_instances.hpp"
if not os.path.exists(out_dir):
os.mkdir(out_dir)
int8_file = "/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_instance.hpp"
int8_instances = get_int8_instances(source, int8_file, "DeviceGemmMultipleD_Xdl_CShuffle")
with open(os.path.join(out_dir, out_file_name), "w+") as f:
f.write(out_file.format(op_name=op_name,
col_row_name=col_row_name,
......@@ -173,6 +263,14 @@ def parse_instances(source):
row_row_instances="\n".join(row_row_instances),
row_col_name=row_col_name,
row_col_instances="\n".join(row_col_instances),
int8_col_row_name=int8_instances["col_row_name"],
int8_col_row_instances="\n".join(int8_instances["col_row"]),
int8_col_col_name=int8_instances["col_col_name"],
int8_col_col_instances="\n".join(int8_instances["col_col"]),
int8_row_row_name=int8_instances["row_row_name"],
int8_row_row_instances="\n".join(int8_instances["row_row"]),
int8_row_col_name=int8_instances["row_col_name"],
int8_row_col_instances="\n".join(int8_instances["row_col"]),
include_header=include_header))
def run():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment