Commit 206af368 authored by Paul's avatar Paul
Browse files

Format

parent ceee865f
......@@ -88,8 +88,9 @@ struct problem_cache
assert(not solution.is_null());
cache[create_key(name, problem)] = solution;
}
void mark(const std::string& name, const value& problem) {
cache.insert(std::make_pair(create_key(name, problem), value{}));
void mark(const std::string& name, const value& problem)
{
cache.insert(std::make_pair(create_key(name, problem), value{}));
}
optional<value> get(const std::string& name, const value& problem) const
{
......@@ -242,7 +243,7 @@ void compile_ops::apply(module& m) const
cm.compile(m);
// Compile already tuned configs
cm.compile(m);
if (not cm.cps.empty())
if(not cm.cps.empty())
MIGRAPHX_THROW("Untuned configs");
}
......
......@@ -69,16 +69,17 @@ void multinomial(hipStream_t stream,
visit_all(arg0, arg1)([&](auto cdf_host, auto dist_host) {
result.visit([&](auto output_host) {
hip_visit_views(cdf_host, dist_host, output_host)([&](auto cdf, auto dist, auto output) {
gs_launch(stream, batch_size * sample_size)([=](auto i) __device__ {
auto idx = output.get_shape().multi(i);
auto cdf_begin = cdf.begin() + (idx.front() * class_size);
auto cdf_end = cdf_begin + class_size;
auto sample_iter =
upper_bound(cdf_begin, cdf_end, dist[i] * *(std::prev(cdf_end)));
output[i] = std::distance(cdf_begin, sample_iter);
hip_visit_views(cdf_host, dist_host, output_host)(
[&](auto cdf, auto dist, auto output) {
gs_launch(stream, batch_size * sample_size)([=](auto i) __device__ {
auto idx = output.get_shape().multi(i);
auto cdf_begin = cdf.begin() + (idx.front() * class_size);
auto cdf_end = cdf_begin + class_size;
auto sample_iter =
upper_bound(cdf_begin, cdf_end, dist[i] * *(std::prev(cdf_end)));
output[i] = std::distance(cdf_begin, sample_iter);
});
});
});
});
});
}
......
......@@ -84,7 +84,7 @@ struct find_ck_gemm_pointwise
auto inputs = ins->inputs();
auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins);
auto gemm_idx = gemm_it - inputs.begin();
if (ins->get_shape().type() != gemm_ins->get_shape().type())
if(ins->get_shape().type() != gemm_ins->get_shape().type())
return;
assert(gemm_it != inputs.end());
if(gemm_idx != 0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment