Commit b3813f86 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

update pass to remove old parms

parent f0f9120a
......@@ -78,41 +78,31 @@ static void quantize_module(module& m, const std::vector<std::string>& ins_names
// Replace original instruction
m.replace_instruction(ins, converted_ins);
}
// m.debug_print();
// std::cout << "HERE" << std::endl;
}
static void quantize_params(module& m)
{
std::vector<std::string> param_names = m.get_parameter_names();
std::unordered_set<std::string> processed_params;
for(auto param_name : param_names)
{
auto param = m.get_parameter(param_name);
// m.debug_print(param);
if(not contains(processed_params, param_name) and
param->get_shape().type() == shape::float_type)
auto param_shape = param->get_shape();
if(param_shape.type() == shape::float_type)
{
auto new_param = m.add_parameter(
param_name, migraphx::shape{shape::half_type, param->get_shape().lens()});
// m.debug_print(new_param);
// m.debug_print();
// m.debug_print();
param_name, migraphx::shape{shape::half_type, param_shape.lens()});
m.replace_instruction(param, new_param);
// std::cout << "HERE" << std::endl;
m.remove_instruction(param);
}
processed_params.insert(param_name);
}
}
void quantize_fp16_pass::apply(module_pass_manager& mpm) const
{
module m = mpm.get_module();
module& m = mpm.get_module();
quantize_module(m, ins_names);
// mpm.run_pass(dead_code_elimination{});
// m.debug_print();
// quantize_params(m);
mpm.run_pass(dead_code_elimination{});
quantize_params(m);
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -188,8 +188,8 @@ MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, max, ::max)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, min, ::min)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, max, ::max)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, min, ::min)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, max, ::__hmax)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, min, ::__hmin)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, max, ::fmaxf)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, min, ::fminf)
template <class T, MIGRAPHX_REQUIRES(not is_any_vec<T>())>
constexpr auto max(const T& a, const T& b)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment