"...targets/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "e6686d255242ddeb808dcd03cf7012fa3a347860"
Commit 1f106ca7 authored by turneram's avatar turneram
Browse files

Add envvars for AB testing

parent f1c8e6c9
......@@ -32,6 +32,8 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_BROADCAST_Q);
void apply_quantizelinear(module& m, instruction_ref ins)
{
assert(ins->name() == "quantizelinear");
......@@ -61,15 +63,33 @@ void apply_quantizelinear(module& m, instruction_ref ins)
max_quant = qt.max();
min_quant = qt.min();
});
auto s = add_zero_point->get_shape();
std::vector<int> min_data(s.elements(), min_quant);
std::vector<int> max_data(s.elements(), max_quant);
auto min_arg = m.add_literal(literal(s, min_data));
auto max_arg = m.add_literal(literal(s, max_data));
if (enabled(MIGRAPHX_BROADCAST_Q{}))
{
auto s = add_zero_point->get_shape();
auto min_arg = m.add_literal(literal{shape{s.type()}, {min_quant}});
auto max_arg = m.add_literal(literal{shape{s.type()}, {max_quant}});
auto min_mbcast =
m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", s.lens()}}), min_arg);
auto max_mbcast =
m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", s.lens()}}), max_arg);
auto saturate =
m.insert_instruction(ins, make_op("clip"), add_zero_point, min_mbcast, max_mbcast);
m.replace_instruction(
ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), saturate);
}
else
{
auto s = add_zero_point->get_shape();
std::vector<int> min_data(s.elements(), min_quant);
std::vector<int> max_data(s.elements(), max_quant);
auto min_arg = m.add_literal(literal(s, min_data));
auto max_arg = m.add_literal(literal(s, max_data));
auto saturate = m.insert_instruction(ins, make_op("clip"), add_zero_point, min_arg, max_arg);
m.replace_instruction(
ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), saturate);
auto saturate = m.insert_instruction(ins, make_op("clip"), add_zero_point, min_arg, max_arg);
m.replace_instruction(
ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), saturate);
}
}
void apply_dequantizelinear(module& m, instruction_ref ins)
......
......@@ -1095,8 +1095,9 @@ MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins)
};
};
auto dots = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("dot"));
auto qdots = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("quant_dot"));
auto convs = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("convolution"));
return (dots >= 2 or convs >= 2);
return (dots >= 2 or convs >= 2 or qdots >= 2);
}
struct find_conv_dot_horiz_fusion
......@@ -1110,7 +1111,7 @@ struct find_conv_dot_horiz_fusion
auto pred = [](auto i, auto j) {
if(i->get_operator() != j->get_operator())
return false;
if(not contains({"dot", "convolution"}, i->name()))
if(not contains({"quant_dot", "dot", "convolution"}, i->name()))
return true;
auto x = i->inputs()[1]->get_shape().lens();
auto y = j->inputs()[1]->get_shape().lens();
......@@ -1118,7 +1119,7 @@ struct find_conv_dot_horiz_fusion
return false;
// Check that non-axes match
int axis = 1;
if(i->name() == "dot")
if(i->name() == "dot" or i->name() == "quant_dot")
{
axis = x.size() - 1;
}
......@@ -1129,7 +1130,7 @@ struct find_conv_dot_horiz_fusion
if(std::distance(start, last) < 2)
return;
auto&& name = (*start)->name();
if(not contains({"dot", "convolution"}, name))
if(not contains({"quant_dot", "dot", "convolution"}, name))
return;
auto op = (*start)->get_operator();
int group = 1;
......@@ -1144,7 +1145,7 @@ struct find_conv_dot_horiz_fusion
start, last, std::back_inserter(args), [&](auto x) { return x->inputs().at(1); });
int axis = 1;
int concat_axis = 0;
if(name == "dot")
if(name == "dot" or name == "quant_dot")
{
axis = int(args.front()->get_shape().lens().size() - 1);
concat_axis = axis;
......
......@@ -29,6 +29,10 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_USE_LARGE_K);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_CK_FUSION);
struct module;
namespace gpu {
......@@ -72,7 +76,7 @@ namespace {
bool is_ck_supported_type(shape::type_t t)
{
return contains({shape::half_type, shape::int8_type}, t);
return contains({shape::half_type, shape::int8_type, shape::int32_type}, t);
}
MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
......@@ -89,7 +93,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
// Integer gemms must be divisible by 4 in ck
if(contains({shape::int8_type, shape::int32_type}, ins->get_shape().type()))
{
if(m % 4 != 0)
if(m != 1 and m % 4 != 0)
return false;
if(n % 4 != 0)
return false;
......@@ -99,7 +103,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
// Skipping GEMMs with a K dimension greater than 2048 is a coarse-grained strategy
// to avoid poor-performing GEMM kernels from CK
// To-do: Investigate a more precise strategy
return k <= 2048;
return k <= 2048 or enabled(MIGRAPHX_USE_LARGE_K{});
}
struct find_ck_gemm_pointwise
......@@ -130,6 +134,10 @@ struct find_ck_gemm_pointwise
return not is_ck_supported_type(input->get_shape().type());
}))
return;
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto input) {
return not input->inputs().empty() and input->inputs().front()->name() == "capture";
}))
return;
assert(gemm_it != inputs.end());
if(gemm_idx != 0)
{
......@@ -152,7 +160,7 @@ struct find_ck_gemm_pointwise
struct find_ck_gemm
{
auto matcher() const { return match::name("dot")(is_ck_gemm().bind("gemm")); }
auto matcher() const { return match::name("dot", "quant_dot")(is_ck_gemm().bind("gemm")); }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
......@@ -165,7 +173,8 @@ struct find_ck_gemm
void fuse_ck::apply(module_pass_manager& mpm) const
{
match::find_matches(mpm, find_ck_gemm_pointwise{});
if (not enabled(MIGRAPHX_DISABLE_CK_FUSION{}))
match::find_matches(mpm, find_ck_gemm_pointwise{});
match::find_matches(mpm, find_ck_gemm{});
}
......
import subprocess, csv, re
def get_device_name():
    """Return the first ``gfx*`` name found in the output of ``rocminfo``.

    Runs ``rocminfo`` through the shell, captures its standard output and
    returns the first token of the form ``gfx<digits><lowercase letters>``
    (for example ``gfx90a``).

    Returns:
        str: the first matching token.

    Raises:
        subprocess.CalledProcessError: if ``rocminfo`` exits with a non-zero
            status (``check=True``).
        IndexError: if the output contains no ``gfx*`` token.
    """
    out = subprocess.run("rocminfo",
                         capture_output=True,
                         check=True,
                         shell=True)
    # Match against the decoded text rather than the repr() of the bytes
    # object; undecodable bytes are replaced so decoding itself cannot fail.
    text = out.stdout.decode(errors="replace")
    # Raw string: "\d" in a plain string literal is an invalid escape sequence.
    matches = re.findall(r"gfx\d*[a-z]*", text)
    if not matches:
        # Same exception type as indexing an empty list, with a usable message.
        raise IndexError("no gfx device name found in rocminfo output")
    return matches[0]
def run_perf(model, batch_size, int8=False, use_ck=False, use_large_k=False, disable_fusion=False):
    """Run ``driver perf`` on ``model`` and record the reported summary.

    Builds a shell command that invokes ``../build/bin/driver perf`` with the
    given batch size (sequence length fixed at 384, ``--fp16`` and
    ``--exhaustive-tune`` always passed), optionally prefixed with environment
    variable assignments. The "Summary" section of the driver's standard
    output is printed together with the total time, and the command plus the
    summary are appended to ``summaries.txt`` in the current directory.

    Args:
        model: path of the model file handed to the driver.
        batch_size: value used for both ``--input-dim`` and ``--batch``.
        int8: when true, pass ``--int8`` to the driver.
        use_ck: when true, prefix the command with ``MIGRAPHX_ENABLE_CK=1``.
        use_large_k: when true, prefix the command with ``MIGRAPHX_USE_LARGE_K=1``.
        disable_fusion: when true, prefix the command with
            ``MIGRAPHX_DISABLE_CK_FUSION=1``.

    Returns:
        float: the number following ``Total time: `` in the summary.

    Raises:
        subprocess.CalledProcessError: if the driver exits with a non-zero
            status (``check=True``).
        IndexError: if the output has no ``Summary`` section or the summary
            has no ``Total time: <number>`` entry.
    """
    env_vars = ""
    if use_ck:
        env_vars += "MIGRAPHX_ENABLE_CK=1 "
    if use_large_k:
        env_vars += "MIGRAPHX_USE_LARGE_K=1 "
    if disable_fusion:
        env_vars += "MIGRAPHX_DISABLE_CK_FUSION=1 "
    int8_str = "--int8" if int8 else ""
    cmd = "{env_vars} ../build/bin/driver perf {model} --fill1 input_ids --input-dim @input_ids {batch_size} 384 --batch {batch_size} --fp16 {int8} --exhaustive-tune".format(
        env_vars=env_vars,
        model=model,
        batch_size=str(batch_size),
        int8=int8_str
    )
    out = subprocess.run(cmd,
                         capture_output=True,
                         check=True,
                         shell=True)

    # Work on decoded text instead of repr(bytes): the repr leaks a trailing
    # quote and literal escape sequences into the saved summary.
    stdout = out.stdout.decode(errors="replace")
    # DOTALL lets ".*" run across newlines, so the match covers everything
    # from the first "Summary" to the end of the output.
    summaries = re.findall(r"Summary.*", stdout, re.DOTALL)
    if not summaries:
        raise IndexError("no 'Summary' section found in driver output")
    summary = summaries[0]

    times = re.findall(r"Total time: (\d+\.\d*)", summary)
    if not times:
        raise IndexError("no 'Total time' entry found in driver summary")
    total_time = times[0]

    print(summary)
    print(total_time)
    # Append rather than truncate, so that consecutive runs of a comparison
    # all end up in the same file.
    with open("summaries.txt", "a") as f:
        f.write(cmd + "\n")
        f.write(summary + "\n\n")
    return float(total_time)
# run model with:
# RocBlas
# Get gemm info
# CK
# With fusions
# Without fusions
if __name__ == "__main__":
    # Query the GPU name first; the value is not used by anything below.
    device_id = get_device_name()
    model = "/code/bert_base_cased_1_fp16_gpu.onnx"
    # Batch size 1 with int8, CK, large-K and fusion-disable all switched on.
    run_perf(model,
             batch_size=1,
             int8=True,
             use_ck=True,
             use_large_k=True,
             disable_fusion=True)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment