Commit ceee865f authored by Paul's avatar Paul
Browse files

Use is_null

parent 19ebf40d
......@@ -85,9 +85,12 @@ struct problem_cache
}
void insert(const std::string& name, const value& problem, const value& solution)
{
assert(not solution.is_null());
cache[create_key(name, problem)] = solution;
}
void mark(const std::string& name, const value& problem) { insert(name, problem, value{}); }
void mark(const std::string& name, const value& problem) {
cache.insert(std::make_pair(create_key(name, problem), value{}));
}
optional<value> get(const std::string& name, const value& problem) const
{
auto it = cache.find(create_key(name, problem));
......@@ -120,7 +123,7 @@ struct compile_plan
{
auto solution = sol.value();
// No solution yet until benchmarked so skip for now
if(solution.empty())
if(solution.is_null())
return;
compiles.emplace_back([=] {
results[0] = compiled_result{compile(*ctx, ins, preop, solution), ins};
......@@ -163,7 +166,7 @@ struct compile_plan
time_op(*ctx, cr.replace.code_object, to_shapes(cr.ins->inputs()), 20).first);
}
auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end()));
pc.insert(preop.name(), config.value().problem, config.value().solutions[i]);
pc.insert(preop.name(), config.value().problem, config.value().solutions.at(i));
return results[i];
}
void replace(module& m, problem_cache& pc) const
......@@ -239,6 +242,8 @@ void compile_ops::apply(module& m) const
cm.compile(m);
// Compile already tuned configs
cm.compile(m);
if (not cm.cps.empty())
MIGRAPHX_THROW("Untuned configs");
}
} // namespace gpu
......
......@@ -67,9 +67,9 @@ void multinomial(hipStream_t stream,
size_t class_size = arg0.get_shape().lens().back();
size_t sample_size = result.get_shape().lens().back();
hip_visit_all(arg0, arg1)([&](auto cdf, auto dist) {
result.visit([&](auto out) {
hip_visit_views(out)([&](auto output) {
visit_all(arg0, arg1)([&](auto cdf_host, auto dist_host) {
result.visit([&](auto output_host) {
hip_visit_views(cdf_host, dist_host, output_host)([&](auto cdf, auto dist, auto output) {
gs_launch(stream, batch_size * sample_size)([=](auto i) __device__ {
auto idx = output.get_shape().multi(i);
auto cdf_begin = cdf.begin() + (idx.front() * class_size);
......
......@@ -84,6 +84,8 @@ struct find_ck_gemm_pointwise
auto inputs = ins->inputs();
auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins);
auto gemm_idx = gemm_it - inputs.begin();
if (ins->get_shape().type() != gemm_ins->get_shape().type())
return;
assert(gemm_it != inputs.end());
if(gemm_idx != 0)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment