"vscode:/vscode.git/clone" did not exist on "69f35439b42494dbdd9d145828161dc41cc327e6"
Commit 174c2d06 authored by Paul
Browse files

Some update and fixes

parent 4f053f22
...@@ -680,7 +680,7 @@ struct find_contiguous_tranpose_precompile ...@@ -680,7 +680,7 @@ struct find_contiguous_tranpose_precompile
auto matcher() const auto matcher() const
{ {
return match::name("gpu::contiguous")(match::arg(0)( return match::name("gpu::contiguous")(match::arg(0)(
match::name("transpose")( match::name("transpose")(match::used_once(),
match::arg(0)(match::name("gpu::precompile_op")(match::used_once()).bind("op"))) match::arg(0)(match::name("gpu::precompile_op")(match::used_once()).bind("op")))
.bind("transpose"))); .bind("transpose")));
} }
...@@ -694,11 +694,12 @@ struct find_contiguous_tranpose_precompile ...@@ -694,11 +694,12 @@ struct find_contiguous_tranpose_precompile
auto perm = transpose->get_operator().to_value()["permutation"].to_vector<int64_t>(); auto perm = transpose->get_operator().to_value()["permutation"].to_vector<int64_t>();
auto iperm = invert_permutation(perm); auto iperm = invert_permutation(perm);
auto s = auto s =
shape::from_permutation(op_ins->get_shape().type(), op_ins->get_shape().lens(), iperm); shape::from_permutation(op_ins->get_shape().type(), op_ins->get_shape().lens(), perm); // perm or iperm?
auto v = op_ins->get_operator().to_value(); auto v = op_ins->get_operator().to_value();
v["output_shape"] = to_value(s); v["output_shape"] = to_value(s);
auto new_op = make_op("gpu::precompile_op", v); auto new_op = make_op("gpu::precompile_op", v);
m.replace_instruction(op_ins, new_op, op_ins->inputs(), op_ins->module_inputs()); m.replace_instruction(op_ins, new_op, op_ins->inputs(), op_ins->module_inputs());
assert(ins->get_shape() == transpose->get_shape());
m.replace_instruction(ins, transpose); m.replace_instruction(ins, transpose);
} }
}; };
......
...@@ -237,7 +237,7 @@ struct index ...@@ -237,7 +237,7 @@ struct index
template <class F, class N> template <class F, class N>
__device__ void group_stride(N n, F f) const __device__ void group_stride(N n, F f) const
{ {
for_stride(group, n, ngroup(), f); for_stride<false>(group, n, ngroup(), f);
} }
}; };
......
...@@ -21,10 +21,14 @@ def pretty_print(obj): ...@@ -21,10 +21,14 @@ def pretty_print(obj):
def run_driver(b): def run_driver(b):
print(b) print(b)
with tmp_file(lambda tf: json.dump(b, tf)) as tf: with tmp_file(lambda tf: json.dump(b, tf)) as tf:
if not os.path.exists('./bin/gpu-driver'):
print("./bin/gpu-driver not found")
os.abort()
cp = subprocess.run('./bin/gpu-driver {}'.format(tf), cp = subprocess.run('./bin/gpu-driver {}'.format(tf),
capture_output=True, capture_output=True,
check=True,
shell=True) shell=True)
print(cp.stderr.decode())
cp.check_returncode()
for line in cp.stdout.decode().split("\n"): for line in cp.stdout.decode().split("\n"):
s = line.strip() s = line.strip()
if not s: if not s:
...@@ -60,6 +64,8 @@ def benchmark_ck(config, tuning): ...@@ -60,6 +64,8 @@ def benchmark_ck(config, tuning):
dtime = get_device_time(line) dtime = get_device_time(line)
print(dtime) print(dtime)
return float(dtime) return float(dtime)
print("Failed")
sys.exit(1)
except: except:
return sys.float_info.max return sys.float_info.max
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment