Commit ceee865f authored by Paul
Browse files

Use is_null

parent 19ebf40d
...@@ -85,9 +85,12 @@ struct problem_cache ...@@ -85,9 +85,12 @@ struct problem_cache
} }
void insert(const std::string& name, const value& problem, const value& solution) void insert(const std::string& name, const value& problem, const value& solution)
{ {
assert(not solution.is_null());
cache[create_key(name, problem)] = solution; cache[create_key(name, problem)] = solution;
} }
void mark(const std::string& name, const value& problem) { insert(name, problem, value{}); } void mark(const std::string& name, const value& problem) {
cache.insert(std::make_pair(create_key(name, problem), value{}));
}
optional<value> get(const std::string& name, const value& problem) const optional<value> get(const std::string& name, const value& problem) const
{ {
auto it = cache.find(create_key(name, problem)); auto it = cache.find(create_key(name, problem));
...@@ -120,7 +123,7 @@ struct compile_plan ...@@ -120,7 +123,7 @@ struct compile_plan
{ {
auto solution = sol.value(); auto solution = sol.value();
// No solution yet until benchmarked so skip for now // No solution yet until benchmarked so skip for now
if(solution.empty()) if(solution.is_null())
return; return;
compiles.emplace_back([=] { compiles.emplace_back([=] {
results[0] = compiled_result{compile(*ctx, ins, preop, solution), ins}; results[0] = compiled_result{compile(*ctx, ins, preop, solution), ins};
...@@ -163,7 +166,7 @@ struct compile_plan ...@@ -163,7 +166,7 @@ struct compile_plan
time_op(*ctx, cr.replace.code_object, to_shapes(cr.ins->inputs()), 20).first); time_op(*ctx, cr.replace.code_object, to_shapes(cr.ins->inputs()), 20).first);
} }
auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end())); auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end()));
pc.insert(preop.name(), config.value().problem, config.value().solutions[i]); pc.insert(preop.name(), config.value().problem, config.value().solutions.at(i));
return results[i]; return results[i];
} }
void replace(module& m, problem_cache& pc) const void replace(module& m, problem_cache& pc) const
...@@ -239,6 +242,8 @@ void compile_ops::apply(module& m) const ...@@ -239,6 +242,8 @@ void compile_ops::apply(module& m) const
cm.compile(m); cm.compile(m);
// Compile already tuned configs // Compile already tuned configs
cm.compile(m); cm.compile(m);
if (not cm.cps.empty())
MIGRAPHX_THROW("Untuned configs");
} }
} // namespace gpu } // namespace gpu
......
...@@ -67,9 +67,9 @@ void multinomial(hipStream_t stream, ...@@ -67,9 +67,9 @@ void multinomial(hipStream_t stream,
size_t class_size = arg0.get_shape().lens().back(); size_t class_size = arg0.get_shape().lens().back();
size_t sample_size = result.get_shape().lens().back(); size_t sample_size = result.get_shape().lens().back();
hip_visit_all(arg0, arg1)([&](auto cdf, auto dist) { visit_all(arg0, arg1)([&](auto cdf_host, auto dist_host) {
result.visit([&](auto out) { result.visit([&](auto output_host) {
hip_visit_views(out)([&](auto output) { hip_visit_views(cdf_host, dist_host, output_host)([&](auto cdf, auto dist, auto output) {
gs_launch(stream, batch_size * sample_size)([=](auto i) __device__ { gs_launch(stream, batch_size * sample_size)([=](auto i) __device__ {
auto idx = output.get_shape().multi(i); auto idx = output.get_shape().multi(i);
auto cdf_begin = cdf.begin() + (idx.front() * class_size); auto cdf_begin = cdf.begin() + (idx.front() * class_size);
......
...@@ -84,6 +84,8 @@ struct find_ck_gemm_pointwise ...@@ -84,6 +84,8 @@ struct find_ck_gemm_pointwise
auto inputs = ins->inputs(); auto inputs = ins->inputs();
auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins); auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins);
auto gemm_idx = gemm_it - inputs.begin(); auto gemm_idx = gemm_it - inputs.begin();
if (ins->get_shape().type() != gemm_ins->get_shape().type())
return;
assert(gemm_it != inputs.end()); assert(gemm_it != inputs.end());
if(gemm_idx != 0) if(gemm_idx != 0)
{ {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment