Commit 9466f4c0 authored by Paul
Browse files

Use larger vec sizes when possible

parent 7662d9c0
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/gpu/compile_gen.hpp> #include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/permutation.hpp> #include <migraphx/permutation.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
...@@ -48,12 +49,11 @@ static std::vector<std::size_t> vector_sizes(const std::vector<shape>& inputs) ...@@ -48,12 +49,11 @@ static std::vector<std::size_t> vector_sizes(const std::vector<shape>& inputs)
return {4, 2}; return {4, 2};
} }
vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs) vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs, const std::vector<std::size_t>& sizes)
{ {
if(std::all_of( if(std::all_of(
inputs.begin(), inputs.end(), [&](const auto& s) { return s.lens()[axis] == 1; })) inputs.begin(), inputs.end(), [&](const auto& s) { return s.lens()[axis] == 1; }))
return {1, axis}; return {1, axis};
auto sizes = vector_sizes(inputs);
std::vector<std::size_t> max_vec_size; std::vector<std::size_t> max_vec_size;
std::transform(inputs.begin(), std::transform(inputs.begin(),
inputs.end(), inputs.end(),
...@@ -81,6 +81,28 @@ vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs ...@@ -81,6 +81,28 @@ vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs
return {*std::min_element(max_vec_size.begin(), max_vec_size.end()), axis}; return {*std::min_element(max_vec_size.begin(), max_vec_size.end()), axis};
} }
// Chooses a vectorization for `axis`, offering wider candidate vector sizes
// the more the largest input's element count exceeds the product of the
// current device's cu count and max workitems per cu.
vectorize vectorize::elements(context& ctx, std::size_t axis, const std::vector<shape>& inputs)
{
    // With no inputs there is nothing to inspect; use a vector size of 1.
    if(inputs.empty())
        return {1, axis};

    // Element count of the largest input.
    std::size_t largest = 0;
    for(const auto& input : inputs)
    {
        const std::size_t count = input.elements();
        if(largest < count)
            largest = count;
    }

    auto&& device = ctx.get_current_device();
    const std::size_t device_capacity = device.get_cu_count() * device.get_max_workitems_per_cu();
    // How many times over the largest input exceeds the device capacity.
    const std::size_t ratio = largest / device_capacity;

    // Candidate sizes are listed widest first; 2 is always offered.
    std::vector<std::size_t> candidates;
    for(std::size_t width : {std::size_t{8}, std::size_t{4}})
    {
        if(ratio > width)
            candidates.push_back(width);
    }
    candidates.push_back(2);
    return elements(axis, inputs, candidates);
}
// Overload that takes its candidate vector sizes from vector_sizes(inputs)
// and forwards to the overload accepting an explicit size list.
vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs)
{
    const auto default_sizes = vector_sizes(inputs);
    return elements(axis, inputs, default_sizes);
}
std::string vectorize::str() const std::string vectorize::str() const
{ {
return "vectorize<" + to_string(size) + ", " + to_string(axis) + ">()"; return "vectorize<" + to_string(size) + ", " + to_string(axis) + ">()";
......
...@@ -36,6 +36,9 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -36,6 +36,9 @@ inline namespace MIGRAPHX_INLINE_NS {
struct shape; struct shape;
namespace gpu { namespace gpu {
struct context;
namespace gen { namespace gen {
struct vectorize struct vectorize
...@@ -43,6 +46,8 @@ struct vectorize ...@@ -43,6 +46,8 @@ struct vectorize
std::size_t size = 1; std::size_t size = 1;
std::size_t axis = 0; std::size_t axis = 0;
static vectorize elements(std::size_t axis, const std::vector<shape>& inputs); static vectorize elements(std::size_t axis, const std::vector<shape>& inputs);
static vectorize elements(context& ctx, std::size_t axis, const std::vector<shape>& inputs);
static vectorize elements(std::size_t axis, const std::vector<shape>& inputs, const std::vector<std::size_t>& sizes);
std::string str() const; std::string str() const;
}; };
struct preload struct preload
......
...@@ -75,20 +75,19 @@ struct pointwise_compiler : compiler<pointwise_compiler> ...@@ -75,20 +75,19 @@ struct pointwise_compiler : compiler<pointwise_compiler>
options.virtual_inputs = reduce_dims(inputs); options.virtual_inputs = reduce_dims(inputs);
options.params = "-Wno-float-equal"; options.params = "-Wno-float-equal";
auto axis = find_fast_axis(options.virtual_inputs); auto axis = find_fast_axis(options.virtual_inputs);
auto vec = vectorize::elements(axis, options.virtual_inputs); auto vec = vectorize::elements(ctx, axis, options.virtual_inputs);
auto preloads = preload::broadcasts(axis, options.virtual_inputs);
options.kernel_name = v.get("kernel", "kernel"); options.kernel_name = v.get("kernel", "kernel");
options.set_launch_params( options.set_launch_params(
v, v,
compute_global_for(ctx, compute_global_for(ctx,
options.output.elements() / vec.size, options.output.elements() / vec.size,
oversubscribe_if(not preloads.is_preloading()))); 256));
auto src = interpolate_string(pointwise_kernel, auto src = interpolate_string(pointwise_kernel,
{{"kernel", options.kernel_name}, {{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")}, {"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")}, {"args", enum_params(inputs.size(), "private_p")},
{"lambda", v.at("lambda").to<std::string>()}, {"lambda", v.at("lambda").to<std::string>()},
{"transformers", make_transformer_args(preloads, vec)}, {"transformers", make_transformer_args(vec)},
{"preamble", v.get("preamble", std::string{})}}); {"preamble", v.get("preamble", std::string{})}});
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment