Commit 1f106ca7 authored by turneram's avatar turneram
Browse files

Add envvars for AB testing

parent f1c8e6c9
...@@ -32,6 +32,8 @@ ...@@ -32,6 +32,8 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_BROADCAST_Q);
void apply_quantizelinear(module& m, instruction_ref ins) void apply_quantizelinear(module& m, instruction_ref ins)
{ {
assert(ins->name() == "quantizelinear"); assert(ins->name() == "quantizelinear");
...@@ -61,15 +63,33 @@ void apply_quantizelinear(module& m, instruction_ref ins) ...@@ -61,15 +63,33 @@ void apply_quantizelinear(module& m, instruction_ref ins)
max_quant = qt.max(); max_quant = qt.max();
min_quant = qt.min(); min_quant = qt.min();
}); });
auto s = add_zero_point->get_shape(); if (enabled(MIGRAPHX_BROADCAST_Q{}))
std::vector<int> min_data(s.elements(), min_quant); {
std::vector<int> max_data(s.elements(), max_quant); auto s = add_zero_point->get_shape();
auto min_arg = m.add_literal(literal(s, min_data)); auto min_arg = m.add_literal(literal{shape{s.type()}, {min_quant}});
auto max_arg = m.add_literal(literal(s, max_data)); auto max_arg = m.add_literal(literal{shape{s.type()}, {max_quant}});
auto min_mbcast =
m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", s.lens()}}), min_arg);
auto max_mbcast =
m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", s.lens()}}), max_arg);
auto saturate =
m.insert_instruction(ins, make_op("clip"), add_zero_point, min_mbcast, max_mbcast);
m.replace_instruction(
ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), saturate);
}
else
{
auto s = add_zero_point->get_shape();
std::vector<int> min_data(s.elements(), min_quant);
std::vector<int> max_data(s.elements(), max_quant);
auto min_arg = m.add_literal(literal(s, min_data));
auto max_arg = m.add_literal(literal(s, max_data));
auto saturate = m.insert_instruction(ins, make_op("clip"), add_zero_point, min_arg, max_arg); auto saturate = m.insert_instruction(ins, make_op("clip"), add_zero_point, min_arg, max_arg);
m.replace_instruction( m.replace_instruction(
ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), saturate); ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), saturate);
}
} }
void apply_dequantizelinear(module& m, instruction_ref ins) void apply_dequantizelinear(module& m, instruction_ref ins)
......
...@@ -1095,8 +1095,9 @@ MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins) ...@@ -1095,8 +1095,9 @@ MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins)
}; };
}; };
auto dots = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("dot")); auto dots = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("dot"));
auto qdots = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("quant_dot"));
auto convs = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("convolution")); auto convs = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("convolution"));
return (dots >= 2 or convs >= 2); return (dots >= 2 or convs >= 2 or qdots >= 2);
} }
struct find_conv_dot_horiz_fusion struct find_conv_dot_horiz_fusion
...@@ -1110,7 +1111,7 @@ struct find_conv_dot_horiz_fusion ...@@ -1110,7 +1111,7 @@ struct find_conv_dot_horiz_fusion
auto pred = [](auto i, auto j) { auto pred = [](auto i, auto j) {
if(i->get_operator() != j->get_operator()) if(i->get_operator() != j->get_operator())
return false; return false;
if(not contains({"dot", "convolution"}, i->name())) if(not contains({"quant_dot", "dot", "convolution"}, i->name()))
return true; return true;
auto x = i->inputs()[1]->get_shape().lens(); auto x = i->inputs()[1]->get_shape().lens();
auto y = j->inputs()[1]->get_shape().lens(); auto y = j->inputs()[1]->get_shape().lens();
...@@ -1118,7 +1119,7 @@ struct find_conv_dot_horiz_fusion ...@@ -1118,7 +1119,7 @@ struct find_conv_dot_horiz_fusion
return false; return false;
// Check that non-axes match // Check that non-axes match
int axis = 1; int axis = 1;
if(i->name() == "dot") if(i->name() == "dot" or i->name() == "quant_dot")
{ {
axis = x.size() - 1; axis = x.size() - 1;
} }
...@@ -1129,7 +1130,7 @@ struct find_conv_dot_horiz_fusion ...@@ -1129,7 +1130,7 @@ struct find_conv_dot_horiz_fusion
if(std::distance(start, last) < 2) if(std::distance(start, last) < 2)
return; return;
auto&& name = (*start)->name(); auto&& name = (*start)->name();
if(not contains({"dot", "convolution"}, name)) if(not contains({"quant_dot", "dot", "convolution"}, name))
return; return;
auto op = (*start)->get_operator(); auto op = (*start)->get_operator();
int group = 1; int group = 1;
...@@ -1144,7 +1145,7 @@ struct find_conv_dot_horiz_fusion ...@@ -1144,7 +1145,7 @@ struct find_conv_dot_horiz_fusion
start, last, std::back_inserter(args), [&](auto x) { return x->inputs().at(1); }); start, last, std::back_inserter(args), [&](auto x) { return x->inputs().at(1); });
int axis = 1; int axis = 1;
int concat_axis = 0; int concat_axis = 0;
if(name == "dot") if(name == "dot" or name == "quant_dot")
{ {
axis = int(args.front()->get_shape().lens().size() - 1); axis = int(args.front()->get_shape().lens().size() - 1);
concat_axis = axis; concat_axis = axis;
......
...@@ -29,6 +29,10 @@ ...@@ -29,6 +29,10 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_USE_LARGE_K);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_CK_FUSION);
struct module; struct module;
namespace gpu { namespace gpu {
...@@ -72,7 +76,7 @@ namespace { ...@@ -72,7 +76,7 @@ namespace {
bool is_ck_supported_type(shape::type_t t) bool is_ck_supported_type(shape::type_t t)
{ {
return contains({shape::half_type, shape::int8_type}, t); return contains({shape::half_type, shape::int8_type, shape::int32_type}, t);
} }
MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
...@@ -89,7 +93,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) ...@@ -89,7 +93,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
// Integer gemms must be divisible by 4 in ck // Integer gemms must be divisible by 4 in ck
if(contains({shape::int8_type, shape::int32_type}, ins->get_shape().type())) if(contains({shape::int8_type, shape::int32_type}, ins->get_shape().type()))
{ {
if(m % 4 != 0) if(m != 1 and m % 4 != 0)
return false; return false;
if(n % 4 != 0) if(n % 4 != 0)
return false; return false;
...@@ -99,7 +103,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) ...@@ -99,7 +103,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
// Skipping GEMMs with a K dimension greater than 2048 is a coarse-grained strategy // Skipping GEMMs with a K dimension greater than 2048 is a coarse-grained strategy
// to avoid poor-performing GEMM kernels from CK // to avoid poor-performing GEMM kernels from CK
// To-do: Investigate a more precise strategy // To-do: Investigate a more precise strategy
return k <= 2048; return k <= 2048 or enabled(MIGRAPHX_USE_LARGE_K{});
} }
struct find_ck_gemm_pointwise struct find_ck_gemm_pointwise
...@@ -130,6 +134,10 @@ struct find_ck_gemm_pointwise ...@@ -130,6 +134,10 @@ struct find_ck_gemm_pointwise
return not is_ck_supported_type(input->get_shape().type()); return not is_ck_supported_type(input->get_shape().type());
})) }))
return; return;
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto input) {
return not input->inputs().empty() and input->inputs().front()->name() == "capture";
}))
return;
assert(gemm_it != inputs.end()); assert(gemm_it != inputs.end());
if(gemm_idx != 0) if(gemm_idx != 0)
{ {
...@@ -152,7 +160,7 @@ struct find_ck_gemm_pointwise ...@@ -152,7 +160,7 @@ struct find_ck_gemm_pointwise
struct find_ck_gemm struct find_ck_gemm
{ {
auto matcher() const { return match::name("dot")(is_ck_gemm().bind("gemm")); } auto matcher() const { return match::name("dot", "quant_dot")(is_ck_gemm().bind("gemm")); }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{ {
...@@ -165,7 +173,8 @@ struct find_ck_gemm ...@@ -165,7 +173,8 @@ struct find_ck_gemm
void fuse_ck::apply(module_pass_manager& mpm) const void fuse_ck::apply(module_pass_manager& mpm) const
{ {
match::find_matches(mpm, find_ck_gemm_pointwise{}); if (not enabled(MIGRAPHX_DISABLE_CK_FUSION{}))
match::find_matches(mpm, find_ck_gemm_pointwise{});
match::find_matches(mpm, find_ck_gemm{}); match::find_matches(mpm, find_ck_gemm{});
} }
......
import subprocess, csv, re
def get_device_name():
    """Return the first ``gfx...`` device name found in ``rocminfo`` output.

    Runs ``rocminfo`` through the shell and scans its standard output for the
    first token of the form ``gfx<digits><lowercase letters>``.

    Returns:
        str: the first matching token.

    Raises:
        subprocess.CalledProcessError: if ``rocminfo`` exits with a non-zero
            status (``check=True``).
        IndexError: if the output contains no ``gfx`` token. The exception
            type is the same one the bare ``matches[0]`` lookup produced.
    """
    out = subprocess.run("rocminfo",
                         capture_output=True,
                         check=True,
                         shell=True)
    # Decode the captured bytes instead of matching against their repr().
    text = out.stdout.decode(errors="replace")
    # Raw string: "\d" inside a plain literal is an invalid escape sequence.
    matches = re.findall(r"gfx\d*[a-z]*", text)
    if not matches:
        raise IndexError("no gfx device name found in rocminfo output")
    return matches[0]
def run_perf(model, batch_size, int8=False, use_ck=False, use_large_k=False, disable_fusion=False):
    """Run ``driver perf`` on ``model`` and record the reported summary.

    Builds a shell command that invokes ``../build/bin/driver perf`` with the
    requested environment variables, runs it, prints the summary section and
    the total time, and writes the command and summary to ``summaries.txt``.

    Args:
        model: path of the model file passed to the driver.
        batch_size: value used for both ``--input-dim @input_ids`` and
            ``--batch``.
        int8: when true, adds ``--int8`` to the command line.
        use_ck: when true, sets ``MIGRAPHX_ENABLE_CK=1``.
        use_large_k: when true, sets ``MIGRAPHX_USE_LARGE_K=1``.
        disable_fusion: when true, sets ``MIGRAPHX_DISABLE_CK_FUSION=1``.

    Returns:
        float: the number following ``Total time: `` in the summary.

    Raises:
        subprocess.CalledProcessError: if the driver exits with a non-zero
            status (``check=True``).
        IndexError: if the output has no ``Summary`` section or no
            ``Total time`` line.
    """
    import shlex  # local import: only needed to quote the model path

    env_vars = ""
    if use_ck:
        env_vars += "MIGRAPHX_ENABLE_CK=1 "
    if use_large_k:
        env_vars += "MIGRAPHX_USE_LARGE_K=1 "
    if disable_fusion:
        env_vars += "MIGRAPHX_DISABLE_CK_FUSION=1 "
    int8_str = "--int8" if int8 else ""
    # The command runs through the shell, so quote the model path; a simple
    # path is left unchanged by shlex.quote.
    cmd = "{env_vars} ../build/bin/driver perf {model} --fill1 input_ids --input-dim @input_ids {batch_size} 384 --batch {batch_size} --fp16 {int8} --exhaustive-tune".format(
        env_vars=env_vars,
        model=shlex.quote(str(model)),
        batch_size=str(batch_size),
        int8=int8_str
    )
    out = subprocess.run(cmd,
                         capture_output=True,
                         check=True,
                         shell=True)

    # Work on decoded text rather than the repr() of the captured bytes, so
    # the summary keeps real newlines and carries no stray quote characters.
    text = out.stdout.decode(errors="replace")
    start = text.find("Summary")
    if start < 0:
        raise IndexError("no 'Summary' section found in driver output")
    # Everything from the first "Summary" to the end of the output.
    summary = text[start:]

    times = re.findall(r"Total time: \d+\.\d*", summary)
    if not times:
        raise IndexError("no 'Total time' entry found in driver summary")
    total_time = times[0].replace("Total time: ", "")

    print(summary)
    print(total_time)
    # NOTE(review): this truncates summaries.txt on every call, so only the
    # most recent run is kept — confirm whether appending is intended.
    with open("summaries.txt", "w") as f:
        f.write(cmd + "\n")
        f.write(summary + "\n\n")
    return float(total_time)
# run model with:
# RocBlas
# Get gemm info
# CK
# With fusions
# Without fusions
if __name__ == "__main__":
    # Query the GPU name first; the value is not used further in this script.
    device_id = get_device_name()
    model = "/code/bert_base_cased_1_fp16_gpu.onnx"
    # Single run: batch size 1, int8 on, CK on, large-K on, CK fusion disabled.
    run_perf(model,
             1,
             int8=True,
             use_ck=True,
             use_large_k=True,
             disable_fusion=True)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment