"vscode:/vscode.git/clone" did not exist on "9310bff0ab5f4636601a00e49997f21ee6e67029"
Commit 8a1619d8 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

testing changes

parent 3f566882
......@@ -24,6 +24,7 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZE_FP16_HPP
#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_FP16_HPP
#include <migraphx/pass_manager.hpp>
#include <string>
#include <vector>
#include <migraphx/config.hpp>
......@@ -41,7 +42,7 @@ struct quantize_fp16_pass
{
std::vector<std::string> ins_names = {"all"};
std::string name() const { return "quantize_fp16"; }
void apply(module& m) const;
void apply(module_pass_manager& mpm) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -187,6 +187,7 @@ struct raw_data : raw_data_base
std::string to_string() const
{
std::stringstream ss;
ss.precision(std::numeric_limits<double>::max_digits10);
ss << static_cast<const Derived&>(*this);
return ss.str();
}
......
......@@ -21,6 +21,9 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/quantize_fp16.hpp>
......@@ -31,6 +34,7 @@
#include <migraphx/ranges.hpp>
#include <migraphx/target.hpp>
#include <migraphx/make_op.hpp>
#include <unordered_set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -74,9 +78,41 @@ static void quantize_module(module& m, const std::vector<std::string>& ins_names
// Replace original instruction
m.replace_instruction(ins, converted_ins);
}
// m.debug_print();
// std::cout << "HERE" << std::endl;
}
void quantize_fp16_pass::apply(module& m) const { quantize_module(m, ins_names); }
static void quantize_params(module& m)
{
std::vector<std::string> param_names = m.get_parameter_names();
std::unordered_set<std::string> processed_params;
for( auto param_name : param_names)
{
auto param = m.get_parameter(param_name);
// m.debug_print(param);
if( not contains(processed_params, param_name) and param->get_shape().type() == shape::float_type)
{
auto new_param = m.add_parameter(param_name, migraphx::shape{shape::half_type, param->get_shape().lens()});
// m.debug_print(new_param);
// m.debug_print();
// m.debug_print();
m.replace_instruction(param, new_param);
// std::cout << "HERE" << std::endl;
}
processed_params.insert(param_name);
}
}
void quantize_fp16_pass::apply(module_pass_manager& mpm) const {
module m = mpm.get_module();
quantize_module(m, ins_names);
// mpm.run_pass(dead_code_elimination{});
// m.debug_print();
// quantize_params(m);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -22,6 +22,7 @@
* THE SOFTWARE.
*/
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/ref/lowering.hpp>
#include <migraphx/register_target.hpp>
......@@ -45,6 +46,8 @@ std::vector<pass> target::get_passes(migraphx::context&, const compile_options&)
return {normalize_ops{},
eliminate_pad{},
dead_code_elimination{},
simplify_qdq{},
dead_code_elimination{},
insert_pad{},
dead_code_elimination{},
rewrite_rnn{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment