Commit e4ecf265 authored by Paul's avatar Paul
Browse files

Merge

parents acc58cfe f3eb708b
......@@ -22,6 +22,7 @@
* THE SOFTWARE.
*/
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/stringutils.hpp>
......@@ -48,12 +49,13 @@ static std::vector<std::size_t> vector_sizes(const std::vector<shape>& inputs)
return {4, 2};
}
vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs)
vectorize vectorize::elements(std::size_t axis,
const std::vector<shape>& inputs,
const std::vector<std::size_t>& sizes)
{
if(std::all_of(
inputs.begin(), inputs.end(), [&](const auto& s) { return s.lens()[axis] == 1; }))
return {1, axis};
auto sizes = vector_sizes(inputs);
std::vector<std::size_t> max_vec_size;
std::transform(inputs.begin(),
inputs.end(),
......@@ -81,6 +83,33 @@ vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs
return {*std::min_element(max_vec_size.begin(), max_vec_size.end()), axis};
}
vectorize vectorize::elements(context& ctx, std::size_t axis, const std::vector<shape>& inputs)
{
if(inputs.empty())
return {1, axis};
std::size_t n = std::max_element(inputs.begin(),
inputs.end(),
by(std::less<>{}, [](const auto& s) { return s.elements(); }))
->elements();
std::size_t max_global = ctx.get_current_device().get_cu_count() *
ctx.get_current_device().get_max_workitems_per_cu();
std::size_t over = n / max_global;
bool broadcasted =
std::any_of(inputs.begin(), inputs.end(), [](const auto& s) { return s.broadcasted(); });
std::vector<std::size_t> sizes;
if(broadcasted and over > 8)
sizes.push_back(8);
if(over > 4)
sizes.push_back(4);
sizes.push_back(2);
return elements(axis, inputs, sizes);
}
vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs)
{
return elements(axis, inputs, vector_sizes(inputs));
}
std::string vectorize::str() const
{
return "vectorize<" + to_string(size) + ", " + to_string(axis) + ">()";
......
......@@ -36,6 +36,9 @@ inline namespace MIGRAPHX_INLINE_NS {
struct shape;
namespace gpu {
struct context;
namespace gen {
struct vectorize
......@@ -43,6 +46,10 @@ struct vectorize
std::size_t size = 1;
std::size_t axis = 0;
static vectorize elements(std::size_t axis, const std::vector<shape>& inputs);
static vectorize elements(context& ctx, std::size_t axis, const std::vector<shape>& inputs);
static vectorize elements(std::size_t axis,
const std::vector<shape>& inputs,
const std::vector<std::size_t>& sizes);
std::string str() const;
};
struct preload
......
......@@ -78,7 +78,8 @@ struct concat_compiler : compiler<concat_compiler>
options.params = "-Wno-float-equal";
options.kernel_name = v.get("kernel", "concat_kernel");
auto axis = find_fast_axis(options.inputs);
auto vec = vectorize::elements(axis, options.inputs);
auto vec = vectorize::elements(ctx, axis, options.inputs);
options.kernel_name = v.get("kernel", "concat_kernel");
options.set_launch_params(
v, compute_global_for(ctx, get_concat_elements(options.inputs) / vec.size, 256));
auto src = interpolate_string(
......
......@@ -50,7 +50,6 @@ ${preamble}
extern "C" {
__global__ void ${kernel}(${params})
{
auto idx = make_index();
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto... xs) {
${layernorm}<${axis}>(${post}, ${eps}, xs...);
});
......@@ -78,9 +77,8 @@ struct layernorm_compiler : compiler<layernorm_compiler>
// Vectorize if the axis is a reduction axis
if(axis == faxis)
{
vec = vectorize::elements(faxis, inputs);
vec = vectorize::elements(ctx, faxis, inputs);
}
auto preloads = preload::broadcasts(axis, inputs);
auto relements = inputs[0].lens()[axis] / vec.size;
auto nelements = (inputs.back().elements() / inputs[0].lens()[axis]);
auto block_size = compute_block_size(relements, 256);
......@@ -96,7 +94,7 @@ struct layernorm_compiler : compiler<layernorm_compiler>
{{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"transformers", make_transformer_args(preloads, vec)},
{"transformers", make_transformer_args(vec)},
{"post", v.get("post", std::string{"op::id{}"})},
{"preamble", v.get("preamble", std::string{})},
{"layernorm", v.get("layernorm", std::string{"layernorm"})},
......
......@@ -75,20 +75,16 @@ struct pointwise_compiler : compiler<pointwise_compiler>
options.virtual_inputs = reduce_dims(inputs);
options.params = "-Wno-float-equal";
auto axis = find_fast_axis(options.virtual_inputs);
auto vec = vectorize::elements(axis, options.virtual_inputs);
auto preloads = preload::broadcasts(axis, options.virtual_inputs);
auto vec = vectorize::elements(ctx, axis, options.virtual_inputs);
options.kernel_name = v.get("kernel", "kernel");
options.set_launch_params(
v,
compute_global_for(ctx,
options.output.elements() / vec.size,
oversubscribe_if(not preloads.is_preloading())));
v, compute_global_for(ctx, options.output.elements() / vec.size, 256));
auto src = interpolate_string(pointwise_kernel,
{{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"lambda", v.at("lambda").to<std::string>()},
{"transformers", make_transformer_args(preloads, vec)},
{"transformers", make_transformer_args(vec)},
{"preamble", v.get("preamble", std::string{})}});
return compile_hip_code_object(src, options);
}
......
......@@ -121,7 +121,7 @@ struct reduce_compiler : compiler<reduce_compiler>
// Vectorize if the axis is a reduction axis
if(options.virtual_inputs.back().lens()[faxis] == 1)
{
vec = vectorize::elements(faxis, options.virtual_inputs);
vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
}
auto relements = get_reduce_elements(options.virtual_inputs) / vec.size;
auto nelements = options.virtual_inputs.back().elements();
......
......@@ -71,7 +71,7 @@ struct softmax_compiler : compiler<softmax_compiler>
// Vectorize if the axis is a reduction axis
if(faxis == axis)
{
vec = vectorize::elements(faxis, inputs);
vec = vectorize::elements(ctx, faxis, inputs);
}
auto relements = inputs[0].lens()[axis] / vec.size;
auto nelements = (inputs.back().elements() / inputs[0].lens()[axis]);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment