Commit 076bfefb authored by Paul's avatar Paul
Browse files

Compile format

parent f208c2f3
......@@ -29,6 +29,7 @@
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/gpu/rocblas.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -59,16 +60,33 @@ struct miopen_op
};
MIGRAPHX_REGISTER_OP(miopen_op);
std::size_t compile_miopen::compile(operation& op, instruction_ref ins, bool format) const
{
op.from_value({{"int8_x4_format", format}});
auto v = op.compile(*ctx, ins->get_shape(), to_shapes(ins->inputs()));
return v.get("workspace", 0);
}
void compile_miopen::apply(module& m) const
{
assert(ctx);
const bool int8_x4_format = get_int8_x4_format(any_cast<migraphx::gpu::context>(*ctx));
for(auto ins : iterator_for(m))
{
if(ins->name() != "gpu::miopen_op")
continue;
auto op = any_cast<miopen_op>(ins->get_operator()).op;
auto v = op.compile(*ctx, ins->get_shape(), to_shapes(ins->inputs()));
std::size_t ws = v.get("workspace", 0);
std::size_t ws = 0;
try
{
// for the regular convolution and deconvolution, this try would always succeed
ws = compile(op, ins, int8_x4_format);
}
catch(migraphx::exception&)
{
// In case no solver supports the default format, retry using the other format.
ws = compile(op, ins, not int8_x4_format);
}
auto inputs = ins->inputs();
auto alloc = m.insert_instruction(
ins, make_op("allocate", {{"shape", to_value(shape{shape::int8_type, {ws}})}}));
......
......@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_GPU_COMPILE_MIOPEN_HPP
#include <migraphx/config.hpp>
#include <migraphx/instruction_ref.hpp>
#include <string>
namespace migraphx {
......@@ -32,6 +33,7 @@ inline namespace MIGRAPHX_INLINE_NS {
struct module;
struct context;
struct operation;
namespace gpu {
......@@ -40,6 +42,7 @@ struct compile_miopen
context* ctx = nullptr;
std::string name() const { return "gpu::compile_miopen"; }
void apply(module& m) const;
std::size_t compile(operation& op, instruction_ref ins, bool format) const;
};
} // namespace gpu
......
......@@ -145,12 +145,9 @@ struct miopen_convolution
#endif
}
inline void set_conv_descriptor()
void set_conv_descriptor()
{
if(cd == nullptr)
{
cd = (op.name() == "deconvolution") ? make_deconv(op) : make_conv(op);
}
cd = (op.name() == "deconvolution") ? make_deconv(op) : make_conv(op);
}
value compile(migraphx::context& ctx, const shape& output, const std::vector<shape>& input)
......@@ -240,7 +237,6 @@ struct miopen_convolution
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen " + op.name() + " : find convolution failed");
algo = perf.fwd_algo;
size_t solution_count;
status = miopenConvolutionForwardGetSolutionCount(ctx.get_stream().get_miopen(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment