Commit 2363d06c authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into int8_quantize

parents 6893dea9 3540f1b9
......@@ -39,8 +39,6 @@ else()
set(MIGRAPHX_ENABLE_GPU Off CACHE BOOL "")
endif()
set(MIGRAPHX_ENABLE_TF Off CACHE BOOL "")
add_compile_options(-std=c++14)
list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
......
#ifndef MIGRAPHX_GUARD_OPERATORS_PAD_CALC_HPP
#define MIGRAPHX_GUARD_OPERATORS_PAD_CALC_HPP
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
inline std::size_t calculate_padding(std::size_t weight_dim, std::size_t dilation)
{
return (dilation * (weight_dim - 1)) / 2;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -12,12 +12,7 @@ if(MIGRAPHX_ENABLE_PYTHON)
C_VISIBILITY_PRESET hidden
CXX_VISIBILITY_PRESET hidden
)
if(MIGRAPHX_ENABLE_TF)
target_link_libraries(migraphx_py PRIVATE migraphx migraphx_tf migraphx_cpu)
target_compile_definitions(migraphx_py PRIVATE -DENABLE_TF)
else()
target_link_libraries(migraphx_py PRIVATE migraphx migraphx_onnx migraphx_cpu)
endif()
target_link_libraries(migraphx_py PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_cpu)
if(MIGRAPHX_ENABLE_GPU)
target_link_libraries(migraphx_py PRIVATE migraphx_gpu)
target_compile_definitions(migraphx_py PRIVATE -DHAVE_GPU)
......
......@@ -6,11 +6,8 @@
#include <migraphx/generate.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/stringutils.hpp>
#ifdef ENABLE_TF
#include <migraphx/tf.hpp>
#else
#include <migraphx/onnx.hpp>
#endif
#ifdef HAVE_GPU
#include <migraphx/gpu/target.hpp>
......@@ -162,16 +159,13 @@ PYBIND11_MODULE(migraphx, m)
.def("__ne__", std::not_equal_to<migraphx::program>{})
.def("__repr__", [](const migraphx::program& p) { return migraphx::to_string(p); });
#ifdef ENABLE_TF
m.def("parse_tf",
&migraphx::parse_tf,
"Parse tf protobuf (default format is nhwc)",
py::arg("filename"),
py::arg("is_nhwc") = true);
#else
m.def("parse_onnx", &migraphx::parse_onnx);
#endif
m.def("get_target", [](const std::string& name) -> migraphx::target {
if(name == "cpu")
return migraphx::cpu::target{};
......
......@@ -17,6 +17,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/pad_calc.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -321,7 +322,12 @@ struct tf_parser
const std::string& pad_mode = attributes.at("padding").s();
if(pad_mode.find("SAME") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::same;
op.padding_mode = op::padding_mode_t::same;
std::vector<size_t> weight_dims = weights->get_shape().lens();
size_t weight_h = weight_dims[2];
size_t weight_w = weight_dims[3];
op.padding[0] = calculate_padding(weight_h, op.dilation[0]);
op.padding[1] = calculate_padding(weight_w, op.dilation[1]);
}
else if(pad_mode.find("VALID") != std::string::npos)
{
......@@ -354,14 +360,7 @@ struct tf_parser
op::convolution op;
size_t num_channels = args[0]->get_shape().lens()[1];
op.group = num_channels;
if(contains(attributes, "padding"))
{
const std::string& pad_mode = attributes.at("padding").s();
if(pad_mode.find("SAME") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::same;
}
}
if(contains(attributes, "strides"))
{
std::vector<size_t> stride;
......@@ -374,6 +373,19 @@ struct tf_parser
op.stride[0] = stride[2];
op.stride[1] = stride[3];
}
if(contains(attributes, "dilations"))
{
std::vector<size_t> dilation;
copy(attributes.at("dilations").list().i(), std::back_inserter(dilation));
reorder_data(dilation);
if(dilation.size() != 4)
{
MIGRAPHX_THROW("dilation should have 4 values");
}
op.dilation[0] = dilation[2];
op.dilation[1] = dilation[3];
}
auto weights = args[1];
// check if weights are from a constant
if(weights->name() != "@param")
......@@ -388,6 +400,24 @@ struct tf_parser
}
}
if(contains(attributes, "padding"))
{
const std::string& pad_mode = attributes.at("padding").s();
std::vector<size_t> weight_dims = weights->get_shape().lens();
size_t weight_h = weight_dims[2];
size_t weight_w = weight_dims[3];
if(pad_mode.find("SAME") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::same;
op.padding[0] = calculate_padding(weight_h, op.dilation[0]);
op.padding[1] = calculate_padding(weight_w, op.dilation[1]);
}
else if(pad_mode.find("VALID") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::valid;
}
}
std::vector<int64_t> new_weights_shape;
copy(weights->get_shape().lens(), std::back_inserter(new_weights_shape));
......@@ -513,18 +543,6 @@ struct tf_parser
{
op::pooling op{starts_with(name, "Max") ? "max" : "average"};
if(contains(attributes, "padding"))
{
const std::string& pad_mode = attributes.at("padding").s();
if(pad_mode.find("SAME") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::same;
}
else if(pad_mode.find("VALID") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::valid;
}
}
if(contains(attributes, "strides"))
{
std::vector<size_t> stride;
......@@ -549,6 +567,20 @@ struct tf_parser
op.lengths[0] = ksize[2];
op.lengths[1] = ksize[3];
}
if(contains(attributes, "padding"))
{
const std::string& pad_mode = attributes.at("padding").s();
if(pad_mode.find("SAME") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::same;
op.padding[0] = calculate_padding(op.lengths[0], 1);
op.padding[1] = calculate_padding(op.lengths[1], 1);
}
else if(pad_mode.find("VALID") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::valid;
}
}
return prog.add_instruction(op, args[0]);
}
......
......@@ -109,6 +109,7 @@ TEST_CASE(conv_test)
migraphx::op::convolution op;
op.padding_mode = migraphx::op::padding_mode_t::same;
op.padding = {1, 1};
op.stride = {1, 1};
op.dilation = {1, 1};
auto l2 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l1);
......@@ -131,6 +132,7 @@ TEST_CASE(depthwiseconv_test)
migraphx::op::convolution op;
op.padding_mode = migraphx::op::padding_mode_t::same;
op.padding = {1, 1};
op.stride = {1, 1};
op.dilation = {1, 1};
op.group = 3;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment