"examples/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "6de15707c1f9ec7409432aea231e149c91b79626"
Commit 206af368 authored by Paul's avatar Paul
Browse files

Format

parent ceee865f
...@@ -88,8 +88,9 @@ struct problem_cache ...@@ -88,8 +88,9 @@ struct problem_cache
assert(not solution.is_null()); assert(not solution.is_null());
cache[create_key(name, problem)] = solution; cache[create_key(name, problem)] = solution;
} }
void mark(const std::string& name, const value& problem) { void mark(const std::string& name, const value& problem)
cache.insert(std::make_pair(create_key(name, problem), value{})); {
cache.insert(std::make_pair(create_key(name, problem), value{}));
} }
optional<value> get(const std::string& name, const value& problem) const optional<value> get(const std::string& name, const value& problem) const
{ {
...@@ -242,7 +243,7 @@ void compile_ops::apply(module& m) const ...@@ -242,7 +243,7 @@ void compile_ops::apply(module& m) const
cm.compile(m); cm.compile(m);
// Compile already tuned configs // Compile already tuned configs
cm.compile(m); cm.compile(m);
if (not cm.cps.empty()) if(not cm.cps.empty())
MIGRAPHX_THROW("Untuned configs"); MIGRAPHX_THROW("Untuned configs");
} }
......
...@@ -69,16 +69,17 @@ void multinomial(hipStream_t stream, ...@@ -69,16 +69,17 @@ void multinomial(hipStream_t stream,
visit_all(arg0, arg1)([&](auto cdf_host, auto dist_host) { visit_all(arg0, arg1)([&](auto cdf_host, auto dist_host) {
result.visit([&](auto output_host) { result.visit([&](auto output_host) {
hip_visit_views(cdf_host, dist_host, output_host)([&](auto cdf, auto dist, auto output) { hip_visit_views(cdf_host, dist_host, output_host)(
gs_launch(stream, batch_size * sample_size)([=](auto i) __device__ { [&](auto cdf, auto dist, auto output) {
auto idx = output.get_shape().multi(i); gs_launch(stream, batch_size * sample_size)([=](auto i) __device__ {
auto cdf_begin = cdf.begin() + (idx.front() * class_size); auto idx = output.get_shape().multi(i);
auto cdf_end = cdf_begin + class_size; auto cdf_begin = cdf.begin() + (idx.front() * class_size);
auto sample_iter = auto cdf_end = cdf_begin + class_size;
upper_bound(cdf_begin, cdf_end, dist[i] * *(std::prev(cdf_end))); auto sample_iter =
output[i] = std::distance(cdf_begin, sample_iter); upper_bound(cdf_begin, cdf_end, dist[i] * *(std::prev(cdf_end)));
output[i] = std::distance(cdf_begin, sample_iter);
});
}); });
});
}); });
}); });
} }
......
...@@ -84,7 +84,7 @@ struct find_ck_gemm_pointwise ...@@ -84,7 +84,7 @@ struct find_ck_gemm_pointwise
auto inputs = ins->inputs(); auto inputs = ins->inputs();
auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins); auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins);
auto gemm_idx = gemm_it - inputs.begin(); auto gemm_idx = gemm_it - inputs.begin();
if (ins->get_shape().type() != gemm_ins->get_shape().type()) if(ins->get_shape().type() != gemm_ins->get_shape().type())
return; return;
assert(gemm_it != inputs.end()); assert(gemm_it != inputs.end());
if(gemm_idx != 0) if(gemm_idx != 0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment