Commit 076bfefb authored by Paul
Browse files

Compile format

parent f208c2f3
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp> #include <migraphx/register_op.hpp>
#include <migraphx/op/identity.hpp> #include <migraphx/op/identity.hpp>
#include <migraphx/gpu/rocblas.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -59,16 +60,33 @@ struct miopen_op ...@@ -59,16 +60,33 @@ struct miopen_op
}; };
MIGRAPHX_REGISTER_OP(miopen_op); MIGRAPHX_REGISTER_OP(miopen_op);
/// Compile `op` for the instruction `ins` using the requested int8x4 setting.
///
/// The operation's "int8_x4_format" attribute is overwritten with `format`
/// through `from_value` before compilation, so `op` is modified by this call.
///
/// @param op     operation to configure and compile (mutated in place)
/// @param ins    instruction that supplies the output shape and the inputs
///               whose shapes are handed to the operation's compile step
/// @param format value stored under the "int8_x4_format" key of `op`
/// @return the "workspace" entry read from the value produced by
///         `op.compile`, with 0 passed as the fallback
std::size_t compile_miopen::compile(operation& op, instruction_ref ins, bool format) const
{
    // Apply the layout choice first; the compile step below reads it from op.
    op.from_value({{"int8_x4_format", format}});

    const auto input_shapes = to_shapes(ins->inputs());
    auto compiled           = op.compile(*ctx, ins->get_shape(), input_shapes);

    const std::size_t workspace = compiled.get("workspace", 0);
    return workspace;
}
void compile_miopen::apply(module& m) const void compile_miopen::apply(module& m) const
{ {
assert(ctx); assert(ctx);
const bool int8_x4_format = get_int8_x4_format(any_cast<migraphx::gpu::context>(*ctx));
for(auto ins : iterator_for(m)) for(auto ins : iterator_for(m))
{ {
if(ins->name() != "gpu::miopen_op") if(ins->name() != "gpu::miopen_op")
continue; continue;
auto op = any_cast<miopen_op>(ins->get_operator()).op; auto op = any_cast<miopen_op>(ins->get_operator()).op;
auto v = op.compile(*ctx, ins->get_shape(), to_shapes(ins->inputs())); std::size_t ws = 0;
std::size_t ws = v.get("workspace", 0); try
{
// for the regular convolution and deconvolution, this try would always succeed
ws = compile(op, ins, int8_x4_format);
}
catch(migraphx::exception&)
{
// In case no solver supports the default format, retry using the other format.
ws = compile(op, ins, not int8_x4_format);
}
auto inputs = ins->inputs(); auto inputs = ins->inputs();
auto alloc = m.insert_instruction( auto alloc = m.insert_instruction(
ins, make_op("allocate", {{"shape", to_value(shape{shape::int8_type, {ws}})}})); ins, make_op("allocate", {{"shape", to_value(shape{shape::int8_type, {ws}})}}));
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_GPU_COMPILE_MIOPEN_HPP #define MIGRAPHX_GUARD_GPU_COMPILE_MIOPEN_HPP
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/instruction_ref.hpp>
#include <string> #include <string>
namespace migraphx { namespace migraphx {
...@@ -32,6 +33,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -32,6 +33,7 @@ inline namespace MIGRAPHX_INLINE_NS {
struct module; struct module;
struct context; struct context;
struct operation;
namespace gpu { namespace gpu {
...@@ -40,6 +42,7 @@ struct compile_miopen ...@@ -40,6 +42,7 @@ struct compile_miopen
context* ctx = nullptr; context* ctx = nullptr;
std::string name() const { return "gpu::compile_miopen"; } std::string name() const { return "gpu::compile_miopen"; }
void apply(module& m) const; void apply(module& m) const;
std::size_t compile(operation& op, instruction_ref ins, bool format) const;
}; };
} // namespace gpu } // namespace gpu
......
...@@ -145,12 +145,9 @@ struct miopen_convolution ...@@ -145,12 +145,9 @@ struct miopen_convolution
#endif #endif
} }
inline void set_conv_descriptor() void set_conv_descriptor()
{ {
if(cd == nullptr) cd = (op.name() == "deconvolution") ? make_deconv(op) : make_conv(op);
{
cd = (op.name() == "deconvolution") ? make_deconv(op) : make_conv(op);
}
} }
value compile(migraphx::context& ctx, const shape& output, const std::vector<shape>& input) value compile(migraphx::context& ctx, const shape& output, const std::vector<shape>& input)
...@@ -240,7 +237,6 @@ struct miopen_convolution ...@@ -240,7 +237,6 @@ struct miopen_convolution
if(status != miopenStatusSuccess) if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen " + op.name() + " : find convolution failed"); MIGRAPHX_THROW("MIOpen " + op.name() + " : find convolution failed");
algo = perf.fwd_algo; algo = perf.fwd_algo;
size_t solution_count; size_t solution_count;
status = miopenConvolutionForwardGetSolutionCount(ctx.get_stream().get_miopen(), status = miopenConvolutionForwardGetSolutionCount(ctx.get_stream().get_miopen(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment