Commit ff878ce6 authored by Paul's avatar Paul
Browse files

Format

parent 93a5de9f
...@@ -113,7 +113,7 @@ auto action_decorate(F f, Action action) ...@@ -113,7 +113,7 @@ auto action_decorate(F f, Action action)
using tuning_entry = std::pair<std::vector<shape>, size_t>; using tuning_entry = std::pair<std::vector<shape>, size_t>;
static std::vector<tuning_entry> read_tuning(const std::string& s) static std::vector<tuning_entry> read_tuning(const std::string& s)
{ {
if (not fs::exists(s)) if(not fs::exists(s))
return {}; return {};
return from_value<std::vector<tuning_entry>>(from_json_string(read_string(s))); return from_value<std::vector<tuning_entry>>(from_json_string(read_string(s)));
} }
...@@ -121,10 +121,9 @@ static std::vector<tuning_entry> read_tuning(const std::string& s) ...@@ -121,10 +121,9 @@ static std::vector<tuning_entry> read_tuning(const std::string& s)
static std::size_t get_tuning_for(const std::vector<shape>& inputs) static std::size_t get_tuning_for(const std::vector<shape>& inputs)
{ {
static auto tuning = read_tuning(string_value_of(MIGRAPHX_CK_TUNING{}, "")); static auto tuning = read_tuning(string_value_of(MIGRAPHX_CK_TUNING{}, ""));
auto it = std::find_if(tuning.begin(), tuning.end(), [&](const auto& p) { auto it = std::find_if(
return p.first == inputs; tuning.begin(), tuning.end(), [&](const auto& p) { return p.first == inputs; });
}); if(it == tuning.end())
if (it == tuning.end())
return 4; return 4;
return it->second; return it->second;
} }
...@@ -159,7 +158,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -159,7 +158,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
auto sb = b_shape.strides().front(); auto sb = b_shape.strides().front();
auto sc = c_shape.strides().front(); auto sc = c_shape.strides().front();
auto i = v.get("tuning_val", get_tuning_for(inputs)); auto i = v.get("tuning_val", get_tuning_for(inputs));
const auto& instance = get_instance(i, [&](const auto& x) -> bool { const auto& instance = get_instance(i, [&](const auto& x) -> bool {
return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
get_layout(c_shape) == x[2] and get_type(a_shape) == x[3] and get_layout(c_shape) == x[2] and get_type(a_shape) == x[3] and
......
import os, json, subprocess, tempfile, sys, argparse, contextlib import os, json, subprocess, tempfile, sys, argparse, contextlib
@contextlib.contextmanager @contextlib.contextmanager
def tmp_file(dump=None): def tmp_file(dump=None):
tmp_name = None tmp_name = None
...@@ -12,9 +13,11 @@ def tmp_file(dump=None): ...@@ -12,9 +13,11 @@ def tmp_file(dump=None):
finally: finally:
os.unlink(tmp_name) os.unlink(tmp_name)
def pretty_print(obj): def pretty_print(obj):
print(json.dumps(obj, indent=2)) print(json.dumps(obj, indent=2))
def benchmark_one(config, tuning): def benchmark_one(config, tuning):
b = { b = {
'settings': { 'settings': {
...@@ -29,7 +32,8 @@ def benchmark_one(config, tuning): ...@@ -29,7 +32,8 @@ def benchmark_one(config, tuning):
print(b) print(b)
with tmp_file(lambda tf: json.dump(b, tf)) as tf: with tmp_file(lambda tf: json.dump(b, tf)) as tf:
cp = subprocess.run('./bin/gpu-driver {}'.format(tf), cp = subprocess.run('./bin/gpu-driver {}'.format(tf),
capture_output=True, shell=True) capture_output=True,
shell=True)
for line in cp.stdout.decode().split("\n"): for line in cp.stdout.decode().split("\n"):
s = line.strip() s = line.strip()
if not s: if not s:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment