Unverified Commit a5065265 authored by kahmed10's avatar kahmed10 Committed by GitHub
Browse files

Add build flag for fast math (#639)



* add flag

* formatting

* remove env variable

* fix api expression

* add api test

* add api test

* add op test

* formatting

* fix function name

* fix syntax

* formatting

* modify test

* remove test and update doc

* move test to new file

* formatting

* revert test files

* rewrite check

* New
Co-authored-by: default avatarPaul Fultz II <pfultz2@yahoo.com>
parent 4789b387
...@@ -151,12 +151,13 @@ program ...@@ -151,12 +151,13 @@ program
:rtype: shape :rtype: shape
.. py:method:: compile(t, offload_copy=True) .. py:method:: compile(t, offload_copy=True, fast_math=True)
Compiles the program for the target and optimizes it. Compiles the program for the target and optimizes it.
:param target t: This is the target to compile the program for. :param target t: This is the target to compile the program for.
:param bool offload_copy: For targets with offloaded memory(such as the gpu), this will insert instructions during compilation to copy the input parameters to the offloaded memory and to copy the final result from the offloaded memory back to main memory. :param bool offload_copy: For targets with offloaded memory(such as the gpu), this will insert instructions during compilation to copy the input parameters to the offloaded memory and to copy the final result from the offloaded memory back to main memory.
:param bool fast_math: Optimize math functions to use faster approximate versions. There may be slight accuracy degradation when enabled.
.. py:method:: run(params) .. py:method:: run(params)
......
...@@ -70,6 +70,7 @@ migraphx::compile_options to_compile_options(const migraphx_compile_options& opt ...@@ -70,6 +70,7 @@ migraphx::compile_options to_compile_options(const migraphx_compile_options& opt
{ {
migraphx::compile_options result{}; migraphx::compile_options result{};
result.offload_copy = options.offload_copy; result.offload_copy = options.offload_copy;
result.fast_math = options.fast_math;
return result; return result;
} }
......
...@@ -42,6 +42,7 @@ typedef enum { ...@@ -42,6 +42,7 @@ typedef enum {
typedef struct typedef struct
{ {
bool offload_copy; bool offload_copy;
bool fast_math;
} migraphx_compile_options; } migraphx_compile_options;
typedef struct typedef struct
......
...@@ -221,6 +221,7 @@ struct compiler ...@@ -221,6 +221,7 @@ struct compiler
program_params parameters; program_params parameters;
bool gpu = true; bool gpu = true;
bool offload_copy = false; bool offload_copy = false;
bool fast_math = true;
int quantize = 0; int quantize = 0;
std::vector<std::string> fill0; std::vector<std::string> fill0;
...@@ -235,6 +236,10 @@ struct compiler ...@@ -235,6 +236,10 @@ struct compiler
{"--enable-offload-copy"}, {"--enable-offload-copy"},
ap.help("Enable implicit offload copying"), ap.help("Enable implicit offload copying"),
ap.set_value(true)); ap.set_value(true));
ap(fast_math,
{"--disable-fast-math"},
ap.help("Disable fast math optimization"),
ap.set_value(false));
ap(quantize, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(q_fp16)); ap(quantize, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(q_fp16));
ap(quantize, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(q_int8)); ap(quantize, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(q_int8));
} }
...@@ -261,6 +266,7 @@ struct compiler ...@@ -261,6 +266,7 @@ struct compiler
} }
compile_options options; compile_options options;
options.offload_copy = offload_copy; options.offload_copy = offload_copy;
options.fast_math = fast_math;
p.compile(t, options); p.compile(t, options);
l.save(p); l.save(p);
return p; return p;
...@@ -300,6 +306,7 @@ struct verify : command<verify> ...@@ -300,6 +306,7 @@ struct verify : command<verify>
bool per_instruction = false; bool per_instruction = false;
bool reduce = false; bool reduce = false;
bool offload_copy = false; bool offload_copy = false;
bool fast_math = true;
void parse(argument_parser& ap) void parse(argument_parser& ap)
{ {
l.parse(ap); l.parse(ap);
...@@ -308,6 +315,10 @@ struct verify : command<verify> ...@@ -308,6 +315,10 @@ struct verify : command<verify>
{"--enable-offload-copy"}, {"--enable-offload-copy"},
ap.help("Enable implicit offload copying"), ap.help("Enable implicit offload copying"),
ap.set_value(true)); ap.set_value(true));
ap(fast_math,
{"--disable-fast-math"},
ap.help("Disable fast math optimization"),
ap.set_value(false));
ap(tolerance, {"--tolerance"}, ap.help("Tolerance for errors")); ap(tolerance, {"--tolerance"}, ap.help("Tolerance for errors"));
ap(per_instruction, ap(per_instruction,
{"-i", "--per-instruction"}, {"-i", "--per-instruction"},
...@@ -324,7 +335,7 @@ struct verify : command<verify> ...@@ -324,7 +335,7 @@ struct verify : command<verify>
compile_options options; compile_options options;
options.offload_copy = offload_copy; options.offload_copy = offload_copy;
options.fast_math = fast_math;
auto m = parameters.generate(p, false); auto m = parameters.generate(p, false);
if(per_instruction) if(per_instruction)
......
...@@ -10,6 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -10,6 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
struct compile_options struct compile_options
{ {
bool offload_copy = false; bool offload_copy = false;
bool fast_math = true;
tracer trace{}; tracer trace{};
}; };
......
...@@ -172,14 +172,17 @@ PYBIND11_MODULE(migraphx, m) ...@@ -172,14 +172,17 @@ PYBIND11_MODULE(migraphx, m)
.def("get_parameter_names", &migraphx::program::get_parameter_names) .def("get_parameter_names", &migraphx::program::get_parameter_names)
.def("get_parameter_shapes", &migraphx::program::get_parameter_shapes) .def("get_parameter_shapes", &migraphx::program::get_parameter_shapes)
.def("get_output_shapes", &migraphx::program::get_output_shapes) .def("get_output_shapes", &migraphx::program::get_output_shapes)
.def("compile", .def(
[](migraphx::program& p, const migraphx::target& t, bool offload_copy) { "compile",
[](migraphx::program& p, const migraphx::target& t, bool offload_copy, bool fast_math) {
migraphx::compile_options options; migraphx::compile_options options;
options.offload_copy = offload_copy; options.offload_copy = offload_copy;
options.fast_math = fast_math;
p.compile(t, options); p.compile(t, options);
}, },
py::arg("t"), py::arg("t"),
py::arg("offload_copy") = true) py::arg("offload_copy") = true,
py::arg("fast_math") = true)
.def("run", .def("run",
[](migraphx::program& p, py::dict params) { [](migraphx::program& p, py::dict params) {
migraphx::program::parameter_map pm; migraphx::program::parameter_map pm;
......
...@@ -29,7 +29,6 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -29,7 +29,6 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MIOPEN_FUSION) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MIOPEN_FUSION)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_FAST_GELU)
struct fusion struct fusion
{ {
...@@ -396,6 +395,7 @@ struct find_add_gelu ...@@ -396,6 +395,7 @@ struct find_add_gelu
struct find_gelu_new struct find_gelu_new
{ {
bool fast_math = true;
static auto pow_fn() static auto pow_fn()
{ {
...@@ -430,7 +430,7 @@ struct find_gelu_new ...@@ -430,7 +430,7 @@ struct find_gelu_new
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
auto args = ins->inputs(); auto args = ins->inputs();
if(enabled(MIGRAPHX_DISABLE_FAST_GELU{})) if(not fast_math)
p.replace_instruction(ins, hip_gelu_new{}, x_ins, args.back()); p.replace_instruction(ins, hip_gelu_new{}, x_ins, args.back());
else else
p.replace_instruction(ins, hip_gelu{}, x_ins, args.back()); p.replace_instruction(ins, hip_gelu{}, x_ins, args.back());
...@@ -807,7 +807,7 @@ struct find_commutative_broadcast ...@@ -807,7 +807,7 @@ struct find_commutative_broadcast
void fuse_ops::apply(program& p) const void fuse_ops::apply(program& p) const
{ {
match::find_matches(p, find_gelu{}, find_gelu_new{}); match::find_matches(p, find_gelu{}, find_gelu_new{fast_math});
run_passes(p, {dead_code_elimination{}}); run_passes(p, {dead_code_elimination{}});
match::find_matches(p, find_triadd{}); match::find_matches(p, find_triadd{});
match::find_matches(p, match::find_matches(p,
......
...@@ -13,6 +13,7 @@ namespace gpu { ...@@ -13,6 +13,7 @@ namespace gpu {
struct fuse_ops struct fuse_ops
{ {
context* ctx = nullptr; context* ctx = nullptr;
bool fast_math = true;
std::string name() const { return "gpu::fuse_ops"; } std::string name() const { return "gpu::fuse_ops"; }
void apply(program& p) const; void apply(program& p) const;
}; };
......
...@@ -73,7 +73,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -73,7 +73,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{}, dead_code_elimination{},
pack_int8_args{}, pack_int8_args{},
dead_code_elimination{}, dead_code_elimination{},
fuse_ops{&ctx}, fuse_ops{&ctx, options.fast_math},
dead_code_elimination{}, dead_code_elimination{},
write_literals{&ctx}, write_literals{&ctx},
schedule{gpu::schedule_model{ctx.get_current_device().nstreams()}, not enabled(MIGRAPHX_DISABLE_SCHEDULE_PASS{})}, schedule{gpu::schedule_model{ctx.get_current_device().nstreams()}, not enabled(MIGRAPHX_DISABLE_SCHEDULE_PASS{})},
......
#include <test.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/iterator_for.hpp>
#include "test_utils.hpp"
#include "test.hpp"
// Builds a program that spells out the tanh-based GELU approximation
//     x * (0.5 * tanh(0.797885 * (x + 0.044715 * x^3)) + 0.5)
// as individual pow/mul/add/tanh instructions over a float parameter "x"
// with dimensions {1, 1, 5}. Every scalar constant is a one-element literal
// that is multibroadcast to the parameter's dimensions before use.
// The resulting program returns the final multiplication.
migraphx::program create_gelu()
{
    migraphx::program prog;

    // One-element payloads for the scalar constants of the formula.
    std::vector<float> cubic_coeff_data = {0.044715};
    std::vector<float> tanh_scale_data  = {0.797885};
    std::vector<float> exponent_data    = {3};
    std::vector<float> half_data        = {0.5};
    migraphx::shape scalar_shape{migraphx::shape::float_type, {1}};
    std::vector<size_t> input_dims{1, 1, 5};

    auto input =
        prog.add_parameter("x", migraphx::shape{migraphx::shape::float_type, input_dims});
    auto cubic_coeff = prog.add_literal(migraphx::literal{scalar_shape, cubic_coeff_data});
    auto tanh_scale  = prog.add_literal(migraphx::literal{scalar_shape, tanh_scale_data});
    auto exponent    = prog.add_literal(migraphx::literal{scalar_shape, exponent_data});
    auto half        = prog.add_literal(migraphx::literal{scalar_shape, half_data});

    // Expands a one-element literal to the dimensions of the input parameter.
    auto broadcast = [&](auto scalar) {
        return prog.add_instruction(migraphx::op::multibroadcast{input_dims}, scalar);
    };

    // x^3
    auto exponent_bc = broadcast(exponent);
    auto cubed       = prog.add_instruction(migraphx::op::pow{}, input, exponent_bc);

    // x + 0.044715 * x^3
    auto cubic_coeff_bc = broadcast(cubic_coeff);
    auto scaled_cube    = prog.add_instruction(migraphx::op::mul{}, cubic_coeff_bc, cubed);
    auto inner_sum      = prog.add_instruction(migraphx::op::add{}, input, scaled_cube);

    // tanh(0.797885 * (x + 0.044715 * x^3))
    auto tanh_scale_bc = broadcast(tanh_scale);
    auto tanh_arg      = prog.add_instruction(migraphx::op::mul{}, tanh_scale_bc, inner_sum);
    auto tanh_val      = prog.add_instruction(migraphx::op::tanh{}, tanh_arg);

    // 0.5 * tanh(...) + 0.5 -- the same broadcast of 0.5 feeds both the mul and the add.
    auto half_bc     = broadcast(half);
    auto half_tanh   = prog.add_instruction(migraphx::op::mul{}, half_bc, tanh_val);
    auto gate_factor = prog.add_instruction(migraphx::op::add{}, half_tanh, half_bc);

    // x * (0.5 * tanh(...) + 0.5)
    auto result = prog.add_instruction(migraphx::op::mul{}, input, gate_factor);
    prog.add_return({result});
    return prog;
}
// Compiles the hand-built GELU program from create_gelu() for the GPU target
// without passing any compile options, then checks that at least one
// instruction of the compiled program is named "gpu::gelu".
// NOTE(review): "gpu::gelu" is presumably the op emitted when fast math is on
// (the compile_options default) -- confirm against the fuse_ops pass.
TEST_CASE(enable_fast_gelu)
{
migraphx::program p = create_gelu();
p.compile(migraphx::gpu::target{});
CHECK(any_of(p, [&](auto&& i) { return i.name() == "gpu::gelu"; }));
}
// Compiles the hand-built GELU program from create_gelu() for the GPU target
// with compile_options::fast_math explicitly set to false, then checks that at
// least one instruction of the compiled program is named "gpu::gelu_new".
// NOTE(review): "gpu::gelu_new" is presumably the non-approximated GELU op
// selected when fast math is off -- confirm against the fuse_ops pass.
TEST_CASE(disable_fast_gelu)
{
migraphx::program p = create_gelu();
migraphx::compile_options options;
// Opt out of the fast-math rewrite for this compilation only.
options.fast_math = false;
p.compile(migraphx::gpu::target{}, options);
CHECK(any_of(p, [&](auto&& i) { return i.name() == "gpu::gelu_new"; }));
}
// Entry point of the test binary: forwards the command-line arguments to the
// test framework's runner.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}
...@@ -70,6 +70,7 @@ migraphx::compile_options to_compile_options(const migraphx_compile_options& opt ...@@ -70,6 +70,7 @@ migraphx::compile_options to_compile_options(const migraphx_compile_options& opt
{ {
migraphx::compile_options result{}; migraphx::compile_options result{};
result.offload_copy = options.offload_copy; result.offload_copy = options.offload_copy;
result.fast_math = options.fast_math;
return result; return result;
} }
......
...@@ -42,6 +42,7 @@ typedef enum { ...@@ -42,6 +42,7 @@ typedef enum {
typedef struct typedef struct
{ {
bool offload_copy; bool offload_copy;
bool fast_math;
} migraphx_compile_options; } migraphx_compile_options;
typedef struct typedef struct
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment