Commit 09818ae6 authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into fuse-horiz-contiguous

parents 6545452a bd503d89
...@@ -164,6 +164,10 @@ struct module ...@@ -164,6 +164,10 @@ struct module
instruction_ref replace_return(std::vector<instruction_ref> args); instruction_ref replace_return(std::vector<instruction_ref> args);
instruction_ref insert_literal(instruction_ref ins, literal l);
instruction_ref insert_parameter(instruction_ref ins, std::string name, shape s);
std::vector<std::string> get_parameter_names() const; std::vector<std::string> get_parameter_names() const;
shape get_parameter_shape(std::string name) const; shape get_parameter_shape(std::string name) const;
......
...@@ -42,11 +42,12 @@ namespace op { ...@@ -42,11 +42,12 @@ namespace op {
struct unsqueeze struct unsqueeze
{ {
std::vector<int64_t> axes; std::vector<int64_t> axes;
std::vector<int64_t> steps;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.axes, "axes")); return pack(f(self.axes, "axes"), f(self.steps, "steps"));
} }
value attributes() const value attributes() const
...@@ -73,6 +74,9 @@ struct unsqueeze ...@@ -73,6 +74,9 @@ struct unsqueeze
MIGRAPHX_THROW("UNSQUEEZE: Input must be a scalar"); MIGRAPHX_THROW("UNSQUEEZE: Input must be a scalar");
} }
if(steps.size() > axes.size())
MIGRAPHX_THROW("UNSQUEEZE: Steps provided with no axis");
std::size_t new_size = old_lens.size() + axes.size(); std::size_t new_size = old_lens.size() + axes.size();
std::vector<std::size_t> new_lens(new_size); std::vector<std::size_t> new_lens(new_size);
...@@ -80,16 +84,27 @@ struct unsqueeze ...@@ -80,16 +84,27 @@ struct unsqueeze
std::size_t p = 0; std::size_t p = 0;
for(auto i : range(new_size)) for(auto i : range(new_size))
{ {
if(std::find(axes.begin(), axes.end(), i) != axes.end()) auto axis_idx = std::find(axes.begin(), axes.end(), i) - axes.begin();
if(axis_idx < axes.size())
{ {
new_lens[i] = 1; std::int64_t step = 1;
if(p == 0) // unsqueeze on the first axes if(axis_idx < steps.size())
step = steps[axis_idx];
if(step == 0)
MIGRAPHX_THROW("UNSQUEEZE: step must be non-zero");
new_lens[i] = step;
if(p < old_strides.size())
{ {
new_strides[i] = old_lens[0] * old_strides[0]; if((old_lens[p] % step) != 0)
MIGRAPHX_THROW("UNSQUEEZE: Axis dimenstion is not divisible by step");
old_lens[p] /= step;
new_strides[i] = old_strides[p] * old_lens[p];
} }
else // unsqueeze on middle or last axes else
{ {
new_strides[i] = (p < old_strides.size()) ? old_strides[p - 1] : 1; if(step != 1)
MIGRAPHX_THROW("UNSQUEEZE: Step must be 1 for extra axes");
new_strides[i] = 1;
} }
} }
else else
......
...@@ -439,11 +439,7 @@ module::insert_instructions(instruction_ref ins, ...@@ -439,11 +439,7 @@ module::insert_instructions(instruction_ref ins,
return insert_generic_instructions(*this, ins, iterator_for(r), std::move(map_ins)); return insert_generic_instructions(*this, ins, iterator_for(r), std::move(map_ins));
} }
instruction_ref module::add_literal(literal l) instruction_ref module::add_literal(literal l) { return insert_literal(begin(), std::move(l)); }
{
impl->emplace_front(std::move(l));
return impl->instructions.begin();
}
instruction_ref module::add_outline(const shape& s) instruction_ref module::add_outline(const shape& s)
{ {
...@@ -453,10 +449,7 @@ instruction_ref module::add_outline(const shape& s) ...@@ -453,10 +449,7 @@ instruction_ref module::add_outline(const shape& s)
instruction_ref module::add_parameter(std::string name, shape s) instruction_ref module::add_parameter(std::string name, shape s)
{ {
assert(get_parameter_shape(name) == shape{}); return insert_parameter(begin(), std::move(name), std::move(s));
impl->push_front({builtin::param{std::move(name), impl->nparams}, std::move(s), {}});
impl->nparams++;
return impl->instructions.begin();
} }
instruction_ref module::add_return(std::vector<instruction_ref> args) instruction_ref module::add_return(std::vector<instruction_ref> args)
...@@ -469,6 +462,20 @@ instruction_ref module::add_return(std::vector<instruction_ref> args) ...@@ -469,6 +462,20 @@ instruction_ref module::add_return(std::vector<instruction_ref> args)
return result; return result;
} }
instruction_ref module::insert_literal(instruction_ref ins, literal l)
{
impl->emplace(ins, std::move(l));
return std::prev(ins);
}
instruction_ref module::insert_parameter(instruction_ref ins, std::string name, shape s)
{
assert(get_parameter_shape(name) == shape{});
impl->insert(ins, {builtin::param{std::move(name), impl->nparams}, std::move(s), {}});
impl->nparams++;
return std::prev(ins);
}
instruction_ref module::replace_return(std::vector<instruction_ref> args) instruction_ref module::replace_return(std::vector<instruction_ref> args)
{ {
auto last = std::prev(this->end()); auto last = std::prev(this->end());
......
...@@ -504,12 +504,14 @@ static void mod_from_val(module_ref mod, ...@@ -504,12 +504,14 @@ static void mod_from_val(module_ref mod,
if(name == "@param") if(name == "@param")
{ {
output = mod->add_parameter(fields["parameter"].to<std::string>(), output = mod->insert_parameter(mod->end(),
fields["parameter"].to<std::string>(),
migraphx::from_value<shape>(node.at("shape"))); migraphx::from_value<shape>(node.at("shape")));
} }
else if(name == "@literal") else if(name == "@literal")
{ {
output = mod->add_literal(migraphx::from_value<literal>(node.at("literal"))); output =
mod->insert_literal(mod->end(), migraphx::from_value<literal>(node.at("literal")));
} }
else else
{ {
...@@ -544,11 +546,11 @@ static void mod_from_val(module_ref mod, ...@@ -544,11 +546,11 @@ static void mod_from_val(module_ref mod,
} }
else if(module_inputs.empty()) else if(module_inputs.empty())
{ {
output = mod->add_instruction(op, inputs); output = mod->insert_instruction(mod->end(), op, inputs);
} }
else else
{ {
output = mod->add_instruction(op, inputs, module_inputs); output = mod->insert_instruction(mod->end(), op, inputs, module_inputs);
} }
} }
output->set_normalized(normalized); output->set_normalized(normalized);
......
...@@ -36,7 +36,7 @@ void raw_data_to_value(value& v, const RawData& rd) ...@@ -36,7 +36,7 @@ void raw_data_to_value(value& v, const RawData& rd)
result["shape"] = migraphx::to_value(rd.get_shape()); result["shape"] = migraphx::to_value(rd.get_shape());
if(rd.get_shape().type() == shape::tuple_type) if(rd.get_shape().type() == shape::tuple_type)
result["sub"] = migraphx::to_value(rd.get_sub_objects()); result["sub"] = migraphx::to_value(rd.get_sub_objects());
else else if(not rd.empty())
result["data"] = migraphx::value::binary(rd.data(), rd.get_shape().bytes()); result["data"] = migraphx::value::binary(rd.data(), rd.get_shape().bytes());
v = result; v = result;
} }
...@@ -56,7 +56,7 @@ void migraphx_from_value(const value& v, argument& a) ...@@ -56,7 +56,7 @@ void migraphx_from_value(const value& v, argument& a)
literal l = migraphx::from_value<literal>(v); literal l = migraphx::from_value<literal>(v);
a = l.get_argument(); a = l.get_argument();
} }
else else if(v.contains("sub"))
{ {
a = migraphx::from_value<std::vector<argument>>(v.at("sub")); a = migraphx::from_value<std::vector<argument>>(v.at("sub"));
} }
......
...@@ -272,7 +272,7 @@ struct find_concat_transpose ...@@ -272,7 +272,7 @@ struct find_concat_transpose
{ {
auto matcher() const auto matcher() const
{ {
return match::name("concat")(match::all_of[match::inputs()](match::transpose_shape())); return match::name("concat")(match::all_of[match::inputs()](match::name("transpose")));
} }
void apply(module& m, const match::matcher_result& mr) const void apply(module& m, const match::matcher_result& mr) const
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include <migraphx/module.hpp> #include <migraphx/module.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/register_op.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -52,6 +53,7 @@ struct cpu_literal ...@@ -52,6 +53,7 @@ struct cpu_literal
return os; return os;
} }
}; };
MIGRAPHX_REGISTER_OP(cpu_literal);
void write_literals::apply(module& m) const void write_literals::apply(module& m) const
{ {
......
...@@ -43,6 +43,9 @@ static std::vector<std::size_t> vector_sizes(const std::vector<shape>& inputs) ...@@ -43,6 +43,9 @@ static std::vector<std::size_t> vector_sizes(const std::vector<shape>& inputs)
vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs) vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs)
{ {
if(std::all_of(
inputs.begin(), inputs.end(), [&](const auto& s) { return s.lens()[axis] == 1; }))
return {1, axis};
auto sizes = vector_sizes(inputs); auto sizes = vector_sizes(inputs);
std::vector<std::size_t> max_vec_size; std::vector<std::size_t> max_vec_size;
std::transform(inputs.begin(), std::transform(inputs.begin(),
......
...@@ -59,31 +59,30 @@ argument miopen_deconvolution::compute(context& ctx, ...@@ -59,31 +59,30 @@ argument miopen_deconvolution::compute(context& ctx,
auto w_desc = make_tensor(reshape_if_1d(args[1].get_shape())); auto w_desc = make_tensor(reshape_if_1d(args[1].get_shape()));
auto y_desc = make_tensor(reshape_if_1d(output_shape)); auto y_desc = make_tensor(reshape_if_1d(output_shape));
float alpha = 1; if(solution_id == 0)
float beta = 0; MIGRAPHX_THROW("MIOpen Deconvolution: invalid solution ID");
auto status = miopenConvolutionForward(ctx.get_stream().get_miopen(),
&alpha, auto status = miopenConvolutionForwardImmediate(ctx.get_stream().get_miopen(),
x_desc.get(),
args[0].implicit(),
w_desc.get(), w_desc.get(),
args[1].implicit(), args[1].implicit(),
x_desc.get(),
args[0].implicit(),
cd.get(), cd.get(),
algo,
&beta,
y_desc.get(), y_desc.get(),
args[3].implicit(), args[3].implicit(),
args[2].implicit(), args[2].implicit(),
args[2].get_shape().bytes()); args[2].get_shape().bytes(),
solution_id);
if(status != miopenStatusSuccess) if(status != miopenStatusSuccess)
MIGRAPHX_THROW("Running deconvolution failed"); MIGRAPHX_THROW("MIOpen Deconvolution: running convolution failed");
return args[3]; return args[3];
} }
shape miopen_deconvolution::compile(context& ctx, shape miopen_deconvolution::find(context& ctx, const shape& output_shape, std::vector<shape> inputs)
const shape& output_shape,
std::vector<shape> inputs)
{ {
shape workspace_shape{}; shape workspace_shape{};
auto x_desc = make_tensor(reshape_if_1d(inputs[0])); auto x_desc = make_tensor(reshape_if_1d(inputs[0]));
auto w_desc = make_tensor(reshape_if_1d(inputs[1])); auto w_desc = make_tensor(reshape_if_1d(inputs[1]));
auto y_desc = make_tensor(reshape_if_1d(output_shape)); auto y_desc = make_tensor(reshape_if_1d(output_shape));
...@@ -119,9 +118,35 @@ shape miopen_deconvolution::compile(context& ctx, ...@@ -119,9 +118,35 @@ shape miopen_deconvolution::compile(context& ctx,
workspace_size, workspace_size,
false); false);
if(status != miopenStatusSuccess) if(status != miopenStatusSuccess)
MIGRAPHX_THROW("Find deconvolution failed"); MIGRAPHX_THROW("MIOpen Deconvolution: find convolution failed");
handle = ctx.get_stream().get_miopen();
algo = perf.fwd_algo; algo = perf.fwd_algo;
size_t solution_count;
status = miopenConvolutionForwardGetSolutionCount(ctx.get_stream().get_miopen(),
w_desc.get(),
x_desc.get(),
cd.get(),
y_desc.get(),
&solution_count);
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Deconvolution: get solution count failed");
std::vector<miopenConvSolution_t> solutions(solution_count);
status = miopenConvolutionForwardGetSolution(ctx.get_stream().get_miopen(),
w_desc.get(),
x_desc.get(),
cd.get(),
y_desc.get(),
solution_count,
&solution_count,
solutions.data());
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Deconvolution: get solution failed");
solution_id = solutions.front().solution_id;
return shape{shape::int8_type, {perf.memory}}; return shape{shape::int8_type, {perf.memory}};
} }
...@@ -129,13 +154,29 @@ void miopen_deconvolution::finalize(context& ctx, ...@@ -129,13 +154,29 @@ void miopen_deconvolution::finalize(context& ctx,
const shape& output_shape, const shape& output_shape,
std::vector<shape> inputs) std::vector<shape> inputs)
{ {
if(handle == ctx.get_stream().get_miopen()) if(cd == nullptr)
return; cd = make_deconv(op);
if(solution_id == 0)
{
// Check that workspace hasn't changed // Check that workspace hasn't changed
auto size = inputs.at(2).bytes(); auto size = inputs.at(2).bytes();
auto ws = compile(ctx, output_shape, std::move(inputs)); auto ws = find(ctx, output_shape, inputs);
if(ws.bytes() > size) if(ws.bytes() > size)
MIGRAPHX_THROW("Workspace has changed during finalization."); MIGRAPHX_THROW("MIOpen Deconvolution: workspace has changed during finalization.");
}
auto x_desc = make_tensor(reshape_if_1d(inputs[0]));
auto w_desc = make_tensor(reshape_if_1d(inputs[1]));
auto y_desc = make_tensor(reshape_if_1d(output_shape));
auto status = miopenConvolutionForwardCompileSolution(ctx.get_stream().get_miopen(),
w_desc.get(),
x_desc.get(),
cd.get(),
y_desc.get(),
solution_id);
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Deconvolution: compile solution failed");
} }
} // namespace gpu } // namespace gpu
......
...@@ -39,20 +39,20 @@ struct miopen_deconvolution ...@@ -39,20 +39,20 @@ struct miopen_deconvolution
op::deconvolution op; op::deconvolution op;
shared<convolution_descriptor> cd; shared<convolution_descriptor> cd;
miopenConvFwdAlgorithm_t algo{}; miopenConvFwdAlgorithm_t algo{};
miopenHandle_t handle = nullptr; uint64_t solution_id = 0;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
// TODO: Add algo return pack_join(op::deconvolution::reflect(self.op, f),
return op::convolution::reflect(self.op, f); pack(f(self.solution_id, "solution_id")));
} }
std::string name() const { return "gpu::deconv"; } std::string name() const { return "gpu::deconv"; }
shape compute_shape(const std::vector<shape>& inputs) const; shape compute_shape(const std::vector<shape>& inputs) const;
argument argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const; compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
shape compile(context& ctx, const shape& output_shape, std::vector<shape> inputs); shape find(context& ctx, const shape& output_shape, std::vector<shape> inputs);
void finalize(context& ctx, const shape& output_shape, std::vector<shape> inputs); void finalize(context& ctx, const shape& output_shape, std::vector<shape> inputs);
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{ {
......
...@@ -41,7 +41,7 @@ struct miopen_quant_convolution ...@@ -41,7 +41,7 @@ struct miopen_quant_convolution
bool int8_x4_format = false; bool int8_x4_format = false;
shared<convolution_descriptor> cd; shared<convolution_descriptor> cd;
miopenConvFwdAlgorithm_t algo{}; miopenConvFwdAlgorithm_t algo{};
miopenHandle_t handle = nullptr; uint64_t solution_id = 0;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -55,7 +55,7 @@ struct miopen_quant_convolution ...@@ -55,7 +55,7 @@ struct miopen_quant_convolution
shape compute_shape(const std::vector<shape>& inputs) const; shape compute_shape(const std::vector<shape>& inputs) const;
argument argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const; compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
shape compile(context& ctx, const shape& output_shape, std::vector<shape> inputs); shape find(context& ctx, const shape& output_shape, std::vector<shape> inputs);
void finalize(context& ctx, const shape& output_shape, std::vector<shape> inputs); void finalize(context& ctx, const shape& output_shape, std::vector<shape> inputs);
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{ {
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
using namespace migraphx::gpu::gen; // NOLINT
// HIP kernel source template for the fused softmax kernel. The placeholders
// ${transformers} and ${axis} are substituted by interpolate_string() in
// softmax_compiler::compile_op before the source is JIT-compiled; the literal
// text of the template itself must not change, as it is compiled verbatim.
static const char* const softmax_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/softmax.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void softmax_kernel(void* input_p, void* output_p)
{
    transform_args(make_tensors(), ${transformers})(input_p, output_p)([](auto input, auto output) {
        softmax<${axis}>(input, output);
    });
}
}
} // namespace migraphx
)__migraphx__";
// JIT compiler for the "softmax" operation: instantiates the softmax_kernel
// source template for a given reduction axis and launch configuration, then
// compiles it into a HIP code object.
struct softmax_compiler : compiler<softmax_compiler>
{
    // Operator names this compiler handles.
    std::vector<std::string> names() const { return {"softmax"}; }

    // Build the kernel for the given input shapes and op attributes.
    // `v` must contain the "axis" attribute identifying the softmax axis.
    operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
    {
        // TODO: Use reduce_dims
        const auto reduce_axis = v.at("axis").to<int64_t>();
        const auto fast_axis   = find_fast_axis({inputs.front()});

        // Only vectorize when the fastest-varying axis is the reduction axis.
        vectorize vectorization{};
        if(fast_axis == reduce_axis)
            vectorization = vectorize::elements(fast_axis, inputs);

        const auto axis_len     = inputs[0].lens()[reduce_axis];
        const auto reduced_elems = axis_len / vectorization.size; // per-row work items
        const auto batch_elems   = inputs.back().elements() / axis_len; // number of rows
        const auto block_size    = compute_block_size(reduced_elems, 256);

        hip_compile_options options;
        options.set_launch_params(
            v, compute_global_for(ctx, batch_elems * block_size, 256), block_size);
        options.inputs      = inputs;
        options.output      = inputs.back();
        options.kernel_name = "softmax_kernel";

        // Substitute the vectorization wrappers and the axis into the template.
        const auto kernel_src = interpolate_string(
            softmax_kernel,
            {{"transformers", make_transformer_args(vectorization)},
             {"axis", to_string(reduce_axis)}});
        return compile_hip_code_object(kernel_src, options);
    }

    // Compile an instruction in place, replacing it with the generated kernel.
    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
    {
        return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
    }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <migraphx/kernels/types.hpp> #include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/type_traits.hpp> #include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/integral_constant.hpp> #include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/debug.hpp> #include <migraphx/kernels/debug.hpp>
namespace migraphx { namespace migraphx {
...@@ -213,6 +214,13 @@ constexpr auto transform(integral_const_array<T, Xs...>, F f) ...@@ -213,6 +214,13 @@ constexpr auto transform(integral_const_array<T, Xs...>, F f)
return integral_const_array<T, f(Xs)...>{}; return integral_const_array<T, f(Xs)...>{};
} }
template <class T, T... Xs, class F>
constexpr auto transform_i(integral_const_array<T, Xs...>, F f)
{
return sequence_c<sizeof...(Xs)>(
[=](auto... is) { return integral_const_array<T, f(Xs, is)...>{}; });
}
template <class T, T... Xs, class U, U... Ys, class F> template <class T, T... Xs, class U, U... Ys, class F>
constexpr auto transform(integral_const_array<T, Xs...>, integral_const_array<U, Ys...>, F f) constexpr auto transform(integral_const_array<T, Xs...>, integral_const_array<U, Ys...>, F f)
{ {
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP #ifndef MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
#define MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP #define MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
#include <migraphx/kernels/array.hpp> #include <migraphx/kernels/integral_constant.hpp>
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_RETURNS(...) \ #define MIGRAPHX_RETURNS(...) \
......
...@@ -175,6 +175,21 @@ constexpr auto sliced(Slicer slicer, F f) ...@@ -175,6 +175,21 @@ constexpr auto sliced(Slicer slicer, F f)
}; };
} }
template <class Input, index_int Axis>
constexpr auto compute_reduce_axis()
{
constexpr auto lens =
transform_i(get_shape_c<Input>{}.lens, [](index_int x, index_int i) -> index_int {
if(i == Axis)
return 1;
return x;
});
return make_shape(lens, get_shape_c<Input>{}.strides);
}
template <class Input, index_int Axis>
using with_axis = decltype(compute_reduce_axis<Input, Axis>());
struct block struct block
{ {
template <class Slicer> template <class Slicer>
...@@ -201,6 +216,14 @@ struct block ...@@ -201,6 +216,14 @@ struct block
if(idx.local == 0) if(idx.local == 0)
f(); f();
} }
template <class F>
__device__ auto inner(F f) const
{
return sliced(slicer, [=](auto x, auto... xs) {
idx.local_stride(x.get_shape().elements(), [&](auto j) { f(x[j], xs[j]...); });
});
}
}; };
template <class Slicer> template <class Slicer>
...@@ -247,6 +270,17 @@ struct lane ...@@ -247,6 +270,17 @@ struct lane
{ {
f(); f();
} }
template <class F>
__device__ auto inner(F f) const
{
return sliced(slicer, [=](auto x, auto... xs) {
for(index_int j = 0; j < x.get_shape().elements(); j++)
{
f(x[j], xs[j]...);
}
});
}
}; };
template <class Slicer> template <class Slicer>
......
...@@ -32,6 +32,7 @@ namespace migraphx { ...@@ -32,6 +32,7 @@ namespace migraphx {
template <class Lens, class Strides> template <class Lens, class Strides>
struct shape struct shape
{ {
using shape_type = shape;
using index_array = typename Lens::base_array; using index_array = typename Lens::base_array;
Lens lens = {}; Lens lens = {};
Strides strides = {}; Strides strides = {};
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_SOFTMAX_HPP
#define MIGRAPHX_GUARD_KERNELS_SOFTMAX_HPP
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/ops.hpp>
namespace migraphx {
// Numerically stable softmax along Axis: for each row r,
//   output[i] = exp(input[i] - max(r)) / sum_j exp(input[j] - max(r)).
// Runs one block-level reduction per row via reduce::block.
template <index_int Axis, class Input, class Output>
__device__ void softmax(Input input, Output output)
{
    reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto reducer) {
        // Row maximum, subtracted before exponentiation to avoid overflow.
        auto row_max = reducer.reduce(op::max{}, lowest{}, op::id{})(input);
        // Sum of shifted exponentials over the row.
        auto exp_sum = reducer.reduce(op::sum{}, 0, [&](auto v) {
            return migraphx::exp(v - row_max);
        })(input);
        // Normalize each element of the row.
        reducer.inner([&](auto& out, auto v) {
            out = migraphx::exp(v - row_max) / exp_sum;
        })(output, input);
    });
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_SOFTMAX_HPP
...@@ -27,6 +27,8 @@ ...@@ -27,6 +27,8 @@
#include <migraphx/kernels/types.hpp> #include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp> #include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/functional.hpp> #include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/debug.hpp>
namespace migraphx { namespace migraphx {
......
...@@ -186,7 +186,6 @@ struct miopen_apply ...@@ -186,7 +186,6 @@ struct miopen_apply
add_extend_op("rnn_var_sl_shift_output"); add_extend_op("rnn_var_sl_shift_output");
add_extend_op("rnn_var_sl_shift_sequence"); add_extend_op("rnn_var_sl_shift_sequence");
add_extend_op("scatter_none"); add_extend_op("scatter_none");
add_extend_op("softmax");
add_extend_op("topk"); add_extend_op("topk");
add_batch_norm_inference_op(); add_batch_norm_inference_op();
...@@ -301,7 +300,7 @@ struct miopen_apply ...@@ -301,7 +300,7 @@ struct miopen_apply
auto&& op = any_cast<op::deconvolution>(ins->get_operator()); auto&& op = any_cast<op::deconvolution>(ins->get_operator());
auto conv = miopen_deconvolution{op, make_deconv(op)}; auto conv = miopen_deconvolution{op, make_deconv(op)};
auto ws = conv.compile(get_context(), ins->get_shape(), to_shapes(ins->inputs())); auto ws = conv.find(get_context(), ins->get_shape(), to_shapes(ins->inputs()));
auto workspace = insert_allocation(ins, ws); auto workspace = insert_allocation(ins, ws);
auto output = insert_allocation(ins, ins->get_shape()); auto output = insert_allocation(ins, ins->get_shape());
...@@ -332,7 +331,7 @@ struct miopen_apply ...@@ -332,7 +331,7 @@ struct miopen_apply
miopen_quant_convolution conv; miopen_quant_convolution conv;
auto compile_quant_conv_with_format = [&](bool format) { auto compile_quant_conv_with_format = [&](bool format) {
conv = miopen_quant_convolution{op, format, make_conv(op)}; conv = miopen_quant_convolution{op, format, make_conv(op)};
ws = conv.compile(get_context(), ins->get_shape(), to_shapes(ins->inputs())); ws = conv.find(get_context(), ins->get_shape(), to_shapes(ins->inputs()));
}; };
try try
......
...@@ -67,7 +67,7 @@ argument miopen_quant_convolution::compute(context& ctx, ...@@ -67,7 +67,7 @@ argument miopen_quant_convolution::compute(context& ctx,
return args[3]; return args[3];
} }
shape miopen_quant_convolution::compile(context& ctx, shape miopen_quant_convolution::find(context& ctx,
const shape& output_shape, const shape& output_shape,
std::vector<shape> inputs) std::vector<shape> inputs)
{ {
...@@ -92,8 +92,8 @@ shape miopen_quant_convolution::compile(context& ctx, ...@@ -92,8 +92,8 @@ shape miopen_quant_convolution::compile(context& ctx,
x_shape = pack_int8_shape(x_shape); x_shape = pack_int8_shape(x_shape);
w_shape = pack_int8_shape(w_shape); w_shape = pack_int8_shape(w_shape);
} }
auto arg_vec4_x = to_gpu(generate_argument(x_shape)); auto x = to_gpu(generate_argument(x_shape));
auto arg_vec4_w = to_gpu(generate_argument(w_shape)); auto w = to_gpu(generate_argument(w_shape));
auto y = allocate_gpu(output_shape); auto y = allocate_gpu(output_shape);
auto workspace = allocate_gpu(workspace_shape); auto workspace = allocate_gpu(workspace_shape);
...@@ -101,9 +101,9 @@ shape miopen_quant_convolution::compile(context& ctx, ...@@ -101,9 +101,9 @@ shape miopen_quant_convolution::compile(context& ctx,
miopenConvAlgoPerf_t perf; miopenConvAlgoPerf_t perf;
auto status = miopenFindConvolutionForwardAlgorithm(ctx.get_stream().get_miopen(), auto status = miopenFindConvolutionForwardAlgorithm(ctx.get_stream().get_miopen(),
x_desc.get(), x_desc.get(),
arg_vec4_x.implicit(), x.implicit(),
w_desc.get(), w_desc.get(),
arg_vec4_w.implicit(), w.implicit(),
cd.get(), cd.get(),
y_desc.get(), y_desc.get(),
y.implicit(), y.implicit(),
...@@ -114,11 +114,35 @@ shape miopen_quant_convolution::compile(context& ctx, ...@@ -114,11 +114,35 @@ shape miopen_quant_convolution::compile(context& ctx,
workspace_size, workspace_size,
false); false);
if(status != miopenStatusSuccess) if(status != miopenStatusSuccess)
{ MIGRAPHX_THROW("MIOpen Quant Convolution: find convolution failed");
MIGRAPHX_THROW("QUANT_CONVOLUTION: find convolution failed");
}
handle = ctx.get_stream().get_miopen();
algo = perf.fwd_algo; algo = perf.fwd_algo;
size_t solution_count;
status = miopenConvolutionForwardGetSolutionCount(ctx.get_stream().get_miopen(),
w_desc.get(),
x_desc.get(),
cd.get(),
y_desc.get(),
&solution_count);
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Quant Convolution: get solution count failed");
std::vector<miopenConvSolution_t> solutions(solution_count);
status = miopenConvolutionForwardGetSolution(ctx.get_stream().get_miopen(),
w_desc.get(),
x_desc.get(),
cd.get(),
y_desc.get(),
solution_count,
&solution_count,
solutions.data());
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Quant Convolution: get solution failed");
solution_id = solutions.front().solution_id;
return shape{shape::int8_type, {perf.memory}}; return shape{shape::int8_type, {perf.memory}};
} }
...@@ -126,13 +150,29 @@ void miopen_quant_convolution::finalize(context& ctx, ...@@ -126,13 +150,29 @@ void miopen_quant_convolution::finalize(context& ctx,
const shape& output_shape, const shape& output_shape,
std::vector<shape> inputs) std::vector<shape> inputs)
{ {
if(handle == ctx.get_stream().get_miopen()) if(cd == nullptr)
return; cd = make_conv(op);
if(solution_id == 0)
{
// Check that workspace hasn't changed // Check that workspace hasn't changed
auto size = inputs.at(2).bytes(); auto size = inputs.at(2).bytes();
auto ws = compile(ctx, output_shape, std::move(inputs)); auto ws = find(ctx, output_shape, inputs);
if(ws.bytes() > size) if(ws.bytes() > size)
MIGRAPHX_THROW("Workspace has changed during finalization."); MIGRAPHX_THROW("MIOpen Quant Convolution: workspace has changed during finalization.");
}
auto x_desc = make_tensor(inputs[0], int8_x4_format);
auto w_desc = make_tensor(inputs[1], int8_x4_format);
auto y_desc = make_tensor(output_shape);
auto status = miopenConvolutionForwardCompileSolution(ctx.get_stream().get_miopen(),
w_desc.get(),
x_desc.get(),
cd.get(),
y_desc.get(),
solution_id);
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Quant Convolution: compile solution failed");
} }
shape miopen_quant_convolution::pack_int8_shape(const shape& s) const shape miopen_quant_convolution::pack_int8_shape(const shape& s) const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment