Commit e4ecf265 authored by Paul's avatar Paul
Browse files

Merge

parents acc58cfe f3eb708b
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/gpu/compile_gen.hpp> #include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/permutation.hpp> #include <migraphx/permutation.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
...@@ -48,12 +49,13 @@ static std::vector<std::size_t> vector_sizes(const std::vector<shape>& inputs) ...@@ -48,12 +49,13 @@ static std::vector<std::size_t> vector_sizes(const std::vector<shape>& inputs)
return {4, 2}; return {4, 2};
} }
vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs) vectorize vectorize::elements(std::size_t axis,
const std::vector<shape>& inputs,
const std::vector<std::size_t>& sizes)
{ {
if(std::all_of( if(std::all_of(
inputs.begin(), inputs.end(), [&](const auto& s) { return s.lens()[axis] == 1; })) inputs.begin(), inputs.end(), [&](const auto& s) { return s.lens()[axis] == 1; }))
return {1, axis}; return {1, axis};
auto sizes = vector_sizes(inputs);
std::vector<std::size_t> max_vec_size; std::vector<std::size_t> max_vec_size;
std::transform(inputs.begin(), std::transform(inputs.begin(),
inputs.end(), inputs.end(),
...@@ -81,6 +83,33 @@ vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs ...@@ -81,6 +83,33 @@ vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs
return {*std::min_element(max_vec_size.begin(), max_vec_size.end()), axis}; return {*std::min_element(max_vec_size.begin(), max_vec_size.end()), axis};
} }
vectorize vectorize::elements(context& ctx, std::size_t axis, const std::vector<shape>& inputs)
{
if(inputs.empty())
return {1, axis};
std::size_t n = std::max_element(inputs.begin(),
inputs.end(),
by(std::less<>{}, [](const auto& s) { return s.elements(); }))
->elements();
std::size_t max_global = ctx.get_current_device().get_cu_count() *
ctx.get_current_device().get_max_workitems_per_cu();
std::size_t over = n / max_global;
bool broadcasted =
std::any_of(inputs.begin(), inputs.end(), [](const auto& s) { return s.broadcasted(); });
std::vector<std::size_t> sizes;
if(broadcasted and over > 8)
sizes.push_back(8);
if(over > 4)
sizes.push_back(4);
sizes.push_back(2);
return elements(axis, inputs, sizes);
}
vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs)
{
return elements(axis, inputs, vector_sizes(inputs));
}
std::string vectorize::str() const std::string vectorize::str() const
{ {
return "vectorize<" + to_string(size) + ", " + to_string(axis) + ">()"; return "vectorize<" + to_string(size) + ", " + to_string(axis) + ">()";
......
...@@ -36,6 +36,9 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -36,6 +36,9 @@ inline namespace MIGRAPHX_INLINE_NS {
struct shape; struct shape;
namespace gpu { namespace gpu {
struct context;
namespace gen { namespace gen {
struct vectorize struct vectorize
...@@ -43,6 +46,10 @@ struct vectorize ...@@ -43,6 +46,10 @@ struct vectorize
std::size_t size = 1; std::size_t size = 1;
std::size_t axis = 0; std::size_t axis = 0;
static vectorize elements(std::size_t axis, const std::vector<shape>& inputs); static vectorize elements(std::size_t axis, const std::vector<shape>& inputs);
static vectorize elements(context& ctx, std::size_t axis, const std::vector<shape>& inputs);
static vectorize elements(std::size_t axis,
const std::vector<shape>& inputs,
const std::vector<std::size_t>& sizes);
std::string str() const; std::string str() const;
}; };
struct preload struct preload
......
...@@ -78,7 +78,8 @@ struct concat_compiler : compiler<concat_compiler> ...@@ -78,7 +78,8 @@ struct concat_compiler : compiler<concat_compiler>
options.params = "-Wno-float-equal"; options.params = "-Wno-float-equal";
options.kernel_name = v.get("kernel", "concat_kernel"); options.kernel_name = v.get("kernel", "concat_kernel");
auto axis = find_fast_axis(options.inputs); auto axis = find_fast_axis(options.inputs);
auto vec = vectorize::elements(axis, options.inputs); auto vec = vectorize::elements(ctx, axis, options.inputs);
options.kernel_name = v.get("kernel", "concat_kernel");
options.set_launch_params( options.set_launch_params(
v, compute_global_for(ctx, get_concat_elements(options.inputs) / vec.size, 256)); v, compute_global_for(ctx, get_concat_elements(options.inputs) / vec.size, 256));
auto src = interpolate_string( auto src = interpolate_string(
......
...@@ -50,7 +50,6 @@ ${preamble} ...@@ -50,7 +50,6 @@ ${preamble}
extern "C" { extern "C" {
__global__ void ${kernel}(${params}) __global__ void ${kernel}(${params})
{ {
auto idx = make_index();
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto... xs) { transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto... xs) {
${layernorm}<${axis}>(${post}, ${eps}, xs...); ${layernorm}<${axis}>(${post}, ${eps}, xs...);
}); });
...@@ -78,9 +77,8 @@ struct layernorm_compiler : compiler<layernorm_compiler> ...@@ -78,9 +77,8 @@ struct layernorm_compiler : compiler<layernorm_compiler>
// Vectorize if the axis is a reduction axis // Vectorize if the axis is a reduction axis
if(axis == faxis) if(axis == faxis)
{ {
vec = vectorize::elements(faxis, inputs); vec = vectorize::elements(ctx, faxis, inputs);
} }
auto preloads = preload::broadcasts(axis, inputs);
auto relements = inputs[0].lens()[axis] / vec.size; auto relements = inputs[0].lens()[axis] / vec.size;
auto nelements = (inputs.back().elements() / inputs[0].lens()[axis]); auto nelements = (inputs.back().elements() / inputs[0].lens()[axis]);
auto block_size = compute_block_size(relements, 256); auto block_size = compute_block_size(relements, 256);
...@@ -96,7 +94,7 @@ struct layernorm_compiler : compiler<layernorm_compiler> ...@@ -96,7 +94,7 @@ struct layernorm_compiler : compiler<layernorm_compiler>
{{"kernel", options.kernel_name}, {{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")}, {"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")}, {"args", enum_params(inputs.size(), "private_p")},
{"transformers", make_transformer_args(preloads, vec)}, {"transformers", make_transformer_args(vec)},
{"post", v.get("post", std::string{"op::id{}"})}, {"post", v.get("post", std::string{"op::id{}"})},
{"preamble", v.get("preamble", std::string{})}, {"preamble", v.get("preamble", std::string{})},
{"layernorm", v.get("layernorm", std::string{"layernorm"})}, {"layernorm", v.get("layernorm", std::string{"layernorm"})},
......
...@@ -75,20 +75,16 @@ struct pointwise_compiler : compiler<pointwise_compiler> ...@@ -75,20 +75,16 @@ struct pointwise_compiler : compiler<pointwise_compiler>
options.virtual_inputs = reduce_dims(inputs); options.virtual_inputs = reduce_dims(inputs);
options.params = "-Wno-float-equal"; options.params = "-Wno-float-equal";
auto axis = find_fast_axis(options.virtual_inputs); auto axis = find_fast_axis(options.virtual_inputs);
auto vec = vectorize::elements(axis, options.virtual_inputs); auto vec = vectorize::elements(ctx, axis, options.virtual_inputs);
auto preloads = preload::broadcasts(axis, options.virtual_inputs);
options.kernel_name = v.get("kernel", "kernel"); options.kernel_name = v.get("kernel", "kernel");
options.set_launch_params( options.set_launch_params(
v, v, compute_global_for(ctx, options.output.elements() / vec.size, 256));
compute_global_for(ctx,
options.output.elements() / vec.size,
oversubscribe_if(not preloads.is_preloading())));
auto src = interpolate_string(pointwise_kernel, auto src = interpolate_string(pointwise_kernel,
{{"kernel", options.kernel_name}, {{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")}, {"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")}, {"args", enum_params(inputs.size(), "private_p")},
{"lambda", v.at("lambda").to<std::string>()}, {"lambda", v.at("lambda").to<std::string>()},
{"transformers", make_transformer_args(preloads, vec)}, {"transformers", make_transformer_args(vec)},
{"preamble", v.get("preamble", std::string{})}}); {"preamble", v.get("preamble", std::string{})}});
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
......
...@@ -121,7 +121,7 @@ struct reduce_compiler : compiler<reduce_compiler> ...@@ -121,7 +121,7 @@ struct reduce_compiler : compiler<reduce_compiler>
// Vectorize if the axis is a reduction axis // Vectorize if the axis is a reduction axis
if(options.virtual_inputs.back().lens()[faxis] == 1) if(options.virtual_inputs.back().lens()[faxis] == 1)
{ {
vec = vectorize::elements(faxis, options.virtual_inputs); vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
} }
auto relements = get_reduce_elements(options.virtual_inputs) / vec.size; auto relements = get_reduce_elements(options.virtual_inputs) / vec.size;
auto nelements = options.virtual_inputs.back().elements(); auto nelements = options.virtual_inputs.back().elements();
......
...@@ -71,7 +71,7 @@ struct softmax_compiler : compiler<softmax_compiler> ...@@ -71,7 +71,7 @@ struct softmax_compiler : compiler<softmax_compiler>
// Vectorize if the axis is a reduction axis // Vectorize if the axis is a reduction axis
if(faxis == axis) if(faxis == axis)
{ {
vec = vectorize::elements(faxis, inputs); vec = vectorize::elements(ctx, faxis, inputs);
} }
auto relements = inputs[0].lens()[axis] / vec.size; auto relements = inputs[0].lens()[axis] / vec.size;
auto nelements = (inputs.back().elements() / inputs[0].lens()[axis]); auto nelements = (inputs.back().elements() / inputs[0].lens()[axis]);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment