Commit e88d46a6 authored by Umang Yadav's avatar Umang Yadav
Browse files

disable DPP For FP8

parent 32033d85
...@@ -146,6 +146,7 @@ struct simple_reduce_compiler : compiler<simple_reduce_compiler> ...@@ -146,6 +146,7 @@ struct simple_reduce_compiler : compiler<simple_reduce_compiler>
vectorize vec{}; vectorize vec{};
auto nelements = options.virtual_inputs.back().elements(); auto nelements = options.virtual_inputs.back().elements();
auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs)); auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs));
if(algo == "block") if(algo == "block")
{ {
// Vectorize if the axis is a reduction axis // Vectorize if the axis is a reduction axis
...@@ -169,13 +170,20 @@ struct simple_reduce_compiler : compiler<simple_reduce_compiler> ...@@ -169,13 +170,20 @@ struct simple_reduce_compiler : compiler<simple_reduce_compiler>
options.kernel_name = "reduce_kernel"; options.kernel_name = "reduce_kernel";
std::string identity = "[](auto x) { return x; }"; std::string identity = "[](auto x) { return x; }";
auto src = interpolate_string(simple_reduce_kernel, auto src = interpolate_string(simple_reduce_kernel,
{{"reduction", v.at("reduction").to<std::string>()}, {{"reduction", v.at("reduction").to<std::string>()},
{"init", v.get("init", std::string{"0"})}, {"init", v.get("init", std::string{"0"})},
{"read", v.get("read", identity)}, {"read", v.get("read", identity)},
{"write", v.get("write", identity)}, {"write", v.get("write", identity)},
{"algo", algo}, {"algo", algo},
{"transformers", make_transformer_args(vec)}, {"transformers", make_transformer_args(vec)},
{"preamble", v.get("preamble", std::string{})}}); {"preamble", v.get("preamble", std::string{})}});
// disable DPP for FP8 for now,, TODO: need to disable for Any FP8 types
if(std::any_of(inputs.begin(), inputs.end(), [](const auto& s) {
return s.type() == migraphx::shape::fp8e4m3fnuz_type;
}))
{
options.params += "-DMIGRAPHX_HAS_DPP=0 ";
}
options.params += "-Wno-float-equal"; options.params += "-Wno-float-equal";
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
...@@ -266,13 +274,13 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler> ...@@ -266,13 +274,13 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
auto src = interpolate_string( auto src = interpolate_string(
fused_reduce_kernel, fused_reduce_kernel,
{{"kernel", options.kernel_name}, {{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")}, {"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")}, {"args", enum_params(inputs.size(), "private_p")},
{"algo", algo}, {"algo", algo},
{"reduced", "decltype(" + generate_make_shape(reduce_output_shape) + ")"}, {"reduced", "decltype(" + generate_make_shape(reduce_output_shape) + ")"},
{"lambda", v.at("lambda").to<std::string>()}, {"lambda", v.at("lambda").to<std::string>()},
{"transformers", make_transformer_args(vec)}, {"transformers", make_transformer_args(vec)},
{"preamble", v.get("preamble", std::string{})}}); {"preamble", v.get("preamble", std::string{})}});
options.params += "-Wno-float-equal"; options.params += "-Wno-float-equal";
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment