Commit 38218edc authored by Umang Yadav
Browse files

few changes

parent f155b0e6
...@@ -146,7 +146,6 @@ struct simple_reduce_compiler : compiler<simple_reduce_compiler> ...@@ -146,7 +146,6 @@ struct simple_reduce_compiler : compiler<simple_reduce_compiler>
vectorize vec{}; vectorize vec{};
auto nelements = options.virtual_inputs.back().elements(); auto nelements = options.virtual_inputs.back().elements();
auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs)); auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs));
if(algo == "block") if(algo == "block")
{ {
// Vectorize if the axis is a reduction axis // Vectorize if the axis is a reduction axis
...@@ -170,13 +169,13 @@ struct simple_reduce_compiler : compiler<simple_reduce_compiler> ...@@ -170,13 +169,13 @@ struct simple_reduce_compiler : compiler<simple_reduce_compiler>
options.kernel_name = "reduce_kernel"; options.kernel_name = "reduce_kernel";
std::string identity = "[](auto x) { return x; }"; std::string identity = "[](auto x) { return x; }";
auto src = interpolate_string(simple_reduce_kernel, auto src = interpolate_string(simple_reduce_kernel,
{{"reduction", v.at("reduction").to<std::string>()}, {{"reduction", v.at("reduction").to<std::string>()},
{"init", v.get("init", std::string{"0"})}, {"init", v.get("init", std::string{"0"})},
{"read", v.get("read", identity)}, {"read", v.get("read", identity)},
{"write", v.get("write", identity)}, {"write", v.get("write", identity)},
{"algo", algo}, {"algo", algo},
{"transformers", make_transformer_args(vec)}, {"transformers", make_transformer_args(vec)},
{"preamble", v.get("preamble", std::string{})}}); {"preamble", v.get("preamble", std::string{})}});
options.params += "-Wno-float-equal"; options.params += "-Wno-float-equal";
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
...@@ -267,13 +266,13 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler> ...@@ -267,13 +266,13 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
auto src = interpolate_string( auto src = interpolate_string(
fused_reduce_kernel, fused_reduce_kernel,
{{"kernel", options.kernel_name}, {{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")}, {"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")}, {"args", enum_params(inputs.size(), "private_p")},
{"algo", algo}, {"algo", algo},
{"reduced", "decltype(" + generate_make_shape(reduce_output_shape) + ")"}, {"reduced", "decltype(" + generate_make_shape(reduce_output_shape) + ")"},
{"lambda", v.at("lambda").to<std::string>()}, {"lambda", v.at("lambda").to<std::string>()},
{"transformers", make_transformer_args(vec)}, {"transformers", make_transformer_args(vec)},
{"preamble", v.get("preamble", std::string{})}}); {"preamble", v.get("preamble", std::string{})}});
options.params += "-Wno-float-equal"; options.params += "-Wno-float-equal";
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
......
...@@ -501,9 +501,7 @@ class numeric_limits<fp8e5m2fnuz> ...@@ -501,9 +501,7 @@ class numeric_limits<fp8e5m2fnuz>
{ {
return fp8e5m2fnuz(0x7F, fp8e5m2fnuz::from_bits()); return fp8e5m2fnuz(0x7F, fp8e5m2fnuz::from_bits());
} }
// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01. I am not sure if we // this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01.
// want to make this distinction. For the floating points we would end up using lowest most of
// the times.
static constexpr __device__ fp8e5m2fnuz min() static constexpr __device__ fp8e5m2fnuz min()
{ {
return fp8e5m2fnuz(0x4, fp8e5m2fnuz::from_bits()); return fp8e5m2fnuz(0x4, fp8e5m2fnuz::from_bits());
...@@ -528,9 +526,7 @@ class numeric_limits<fp8e5m2> ...@@ -528,9 +526,7 @@ class numeric_limits<fp8e5m2>
} }
static constexpr __device__ fp8e5m2 max() { return fp8e5m2(0x7B, fp8e5m2::from_bits()); } static constexpr __device__ fp8e5m2 max() { return fp8e5m2(0x7B, fp8e5m2::from_bits()); }
// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01. I am not sure if we // this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01.
// want to make this distinction. For the floating points we would end up using lowest most of
// the times.
static constexpr __device__ fp8e5m2 min() { return fp8e5m2(0x4, fp8e5m2::from_bits()); } static constexpr __device__ fp8e5m2 min() { return fp8e5m2(0x4, fp8e5m2::from_bits()); }
static constexpr __device__ fp8e5m2 lowest() { return fp8e5m2(0xFB, fp8e5m2::from_bits()); } static constexpr __device__ fp8e5m2 lowest() { return fp8e5m2(0xFB, fp8e5m2::from_bits()); }
......
...@@ -22,9 +22,9 @@ ...@@ -22,9 +22,9 @@
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include "migraphx/shape.hpp"
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
......
...@@ -22,9 +22,9 @@ ...@@ -22,9 +22,9 @@
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include "migraphx/shape.hpp"
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
......
...@@ -22,8 +22,8 @@ ...@@ -22,8 +22,8 @@
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include "migraphx/shape.hpp"
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/shape.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment