Commit 9db8a28d authored by Paul's avatar Paul
Browse files

Merge

parents 1f8aa24f 4b1c1c41
...@@ -11,7 +11,7 @@ jobs: ...@@ -11,7 +11,7 @@ jobs:
with: with:
access_token: ${{ github.token }} access_token: ${{ github.token }}
tidy: tidy:
runs-on: ubuntu-18.04 runs-on: ubuntu-20.04
steps: steps:
- name: Free space - name: Free space
...@@ -61,7 +61,7 @@ jobs: ...@@ -61,7 +61,7 @@ jobs:
make -j2 -k onnx-proto tf-proto tidy make -j2 -k onnx-proto tf-proto tidy
cppcheck: cppcheck:
runs-on: ubuntu-18.04 runs-on: ubuntu-20.04
steps: steps:
- name: Free space - name: Free space
...@@ -106,7 +106,7 @@ jobs: ...@@ -106,7 +106,7 @@ jobs:
make -j2 cppcheck make -j2 cppcheck
format: format:
runs-on: ubuntu-18.04 runs-on: ubuntu-20.04
steps: steps:
- name: Free space - name: Free space
...@@ -142,7 +142,7 @@ jobs: ...@@ -142,7 +142,7 @@ jobs:
| xargs -n 1 -P 1 -I{} -t sh -c 'yapf {} | diff - {}' | xargs -n 1 -P 1 -I{} -t sh -c 'yapf {} | diff - {}'
pyflakes: pyflakes:
runs-on: ubuntu-18.04 runs-on: ubuntu-20.04
steps: steps:
- name: Free space - name: Free space
...@@ -163,7 +163,7 @@ jobs: ...@@ -163,7 +163,7 @@ jobs:
mypy tools/api.py mypy tools/api.py
licensing: licensing:
runs-on: ubuntu-18.04 runs-on: ubuntu-20.04
steps: steps:
- name: Free space - name: Free space
...@@ -190,7 +190,6 @@ jobs: ...@@ -190,7 +190,6 @@ jobs:
strategy: strategy:
matrix: matrix:
os: os:
- ubuntu-18.04
- ubuntu-20.04 - ubuntu-20.04
configuration: configuration:
- debug - debug
...@@ -204,7 +203,7 @@ jobs: ...@@ -204,7 +203,7 @@ jobs:
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v2 uses: actions/setup-python@v2
with: with:
python-version: 3.6 python-version: 3.7
- name: Cache dependencies - name: Cache dependencies
# Ignore the failure of a step and avoid terminating the job. # Ignore the failure of a step and avoid terminating the job.
continue-on-error: true continue-on-error: true
...@@ -287,7 +286,6 @@ jobs: ...@@ -287,7 +286,6 @@ jobs:
strategy: strategy:
matrix: matrix:
os: os:
- ubuntu-18.04
- ubuntu-20.04 - ubuntu-20.04
configuration: configuration:
- debug - debug
...@@ -301,7 +299,7 @@ jobs: ...@@ -301,7 +299,7 @@ jobs:
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v2 uses: actions/setup-python@v2
with: with:
python-version: 3.6 python-version: 3.7
- name: Cache dependencies - name: Cache dependencies
# Ignore the failure of a step and avoid terminating the job. # Ignore the failure of a step and avoid terminating the job.
continue-on-error: true continue-on-error: true
......
...@@ -63,7 +63,7 @@ set(CMAKE_EXTRA_INCLUDE_FILES) ...@@ -63,7 +63,7 @@ set(CMAKE_EXTRA_INCLUDE_FILES)
include(ROCMSetupVersion) include(ROCMSetupVersion)
rocm_setup_version(VERSION 2.4) rocm_setup_version(VERSION 2.5)
set(MIGRAPHX_SO_VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR}) set(MIGRAPHX_SO_VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR})
option( BUILD_SHARED_LIBS "Build as a shared library" ON ) option( BUILD_SHARED_LIBS "Build as a shared library" ON )
...@@ -114,6 +114,7 @@ rocm_enable_clang_tidy( ...@@ -114,6 +114,7 @@ rocm_enable_clang_tidy(
hicpp-signed-bitwise hicpp-signed-bitwise
llvm-namespace-comment llvm-namespace-comment
misc-* misc-*
-misc-confusable-identifiers
modernize-* modernize-*
performance-* performance-*
readability-* readability-*
......
...@@ -6,7 +6,7 @@ ARG PREFIX=/usr/local ...@@ -6,7 +6,7 @@ ARG PREFIX=/usr/local
RUN dpkg --add-architecture i386 RUN dpkg --add-architecture i386
# Add rocm repository # Add rocm repository
RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/5.0.2/ ubuntu main > /etc/apt/sources.list.d/rocm.list' RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/5.3/ ubuntu main > /etc/apt/sources.list.d/rocm.list'
# Install dependencies # Install dependencies
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
...@@ -71,7 +71,7 @@ RUN /download_models.sh && rm /download_models.sh ...@@ -71,7 +71,7 @@ RUN /download_models.sh && rm /download_models.sh
# Install latest ccache version # Install latest ccache version
RUN cget -p $PREFIX install facebook/zstd@v1.4.5 -X subdir -DCMAKE_DIR=build/cmake RUN cget -p $PREFIX install facebook/zstd@v1.4.5 -X subdir -DCMAKE_DIR=build/cmake
RUN cget -p $PREFIX install ccache@v4.1 RUN cget -p $PREFIX install ccache@v4.1 -DENABLE_TESTING=OFF
# Install newer cmake for onnx runtime # Install newer cmake for onnx runtime
RUN cget -p /opt/cmake install kitware/cmake@v3.13.4 RUN cget -p /opt/cmake install kitware/cmake@v3.13.4
...@@ -86,7 +86,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR ...@@ -86,7 +86,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR
ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
RUN cget -p /usr/local install ROCmSoftwarePlatform/llvm-project-mlir@e8e77eb16be413d301ea8509726d47f265d9011f -DBUILD_MIXR_TARGET=On RUN cget -p /usr/local install ROCmSoftwarePlatform/llvm-project-mlir@c0723a7e50043d973cb73ae51dc30d36679ee7e5 -DBUILD_MIXR_TARGET=On
ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
......
...@@ -21,9 +21,9 @@ ...@@ -21,9 +21,9 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE. # THE SOFTWARE.
##################################################################################### #####################################################################################
pfultz2/rocm-recipes ROCmSoftwarePlatform/rocm-recipes
facebook/zstd@v1.4.5 -X subdir -DCMAKE_DIR=build/cmake facebook/zstd@v1.4.5 -X subdir -DCMAKE_DIR=build/cmake
ccache@v4.1 ccache@v4.1 -DENABLE_TESTING=OFF
pcre,pfultz2/pcre@8.45 -H sha256:d6f7182602a775a7d500a0cedca6449af0400c6493951513046d17615ed0bf11 pcre,pfultz2/pcre@8.45 -H sha256:d6f7182602a775a7d500a0cedca6449af0400c6493951513046d17615ed0bf11
danmar/cppcheck@2.9 -DHAVE_RULES=1 danmar/cppcheck@2.9 -DHAVE_RULES=1
RadeonOpenCompute/rocm-cmake@1ebf7e7bc61bb5e949c171562b421264065230a7 --build RadeonOpenCompute/rocm-cmake@1ebf7e7bc61bb5e949c171562b421264065230a7 --build
......
...@@ -6,7 +6,7 @@ ARG PREFIX=/usr/local ...@@ -6,7 +6,7 @@ ARG PREFIX=/usr/local
RUN dpkg --add-architecture i386 RUN dpkg --add-architecture i386
# Add rocm repository # Add rocm repository
RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/5.0.2/ ubuntu main > /etc/apt/sources.list.d/rocm.list' RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/5.3/ ubuntu main > /etc/apt/sources.list.d/rocm.list'
# Install dependencies # Install dependencies
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
cxx = ${rocm_path}/llvm/bin/clang++ cxx = ${rocm_path}/llvm/bin/clang++
cc = ${rocm_path}/llvm/bin/clang cc = ${rocm_path}/llvm/bin/clang
deps = deps =
pfultz2/rocm-recipes ROCmSoftwarePlatform/rocm-recipes
-f requirements.txt -f requirements.txt
[gh] [gh]
...@@ -24,4 +24,4 @@ deps = ...@@ -24,4 +24,4 @@ deps =
define = define =
CMAKE_C_COMPILER_LAUNCHER=${deps_dir}/bin/ccache CMAKE_C_COMPILER_LAUNCHER=${deps_dir}/bin/ccache
CMAKE_CXX_COMPILER_LAUNCHER=${deps_dir}/bin/ccache CMAKE_CXX_COMPILER_LAUNCHER=${deps_dir}/bin/ccache
MIGRAPHX_ENABLE_CPU=On MIGRAPHX_ENABLE_CPU=On
\ No newline at end of file
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
google/protobuf@v3.11.0 -DCMAKE_POSITION_INDEPENDENT_CODE=On -X subdir -Dprotobuf_BUILD_TESTS=Off google/protobuf@v3.11.0 -DCMAKE_POSITION_INDEPENDENT_CODE=On -X subdir -Dprotobuf_BUILD_TESTS=Off
nlohmann/json@v3.8.0 nlohmann/json@v3.8.0
live-clones/blaze@v3.8 -X header -DHEADER_DIR=blaze -H sha256:d0ff011f47538285178908ea5f2cab46bb6a8f55b1edb6e03224a82dbc1a3212 live-clones/blaze@v3.8 -X header -DHEADER_DIR=blaze -H sha256:d0ff011f47538285178908ea5f2cab46bb6a8f55b1edb6e03224a82dbc1a3212
half,https://github.com/pfultz2/half/archive/1.12.0.tar.gz -X header -H sha256:0a08660b68abb176ebc2a0cdf8de46e3182a7f46c66443bb80dbfaaec98cf969 half,https://github.com/ROCmSoftwarePlatform/half/archive/1.12.0.tar.gz -X header -H sha256:0a08660b68abb176ebc2a0cdf8de46e3182a7f46c66443bb80dbfaaec98cf969
pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.17 -DCMAKE_POSITION_INDEPENDENT_CODE=On sqlite3@3.17 -DCMAKE_POSITION_INDEPENDENT_CODE=On
...@@ -82,7 +82,6 @@ add_library(migraphx ...@@ -82,7 +82,6 @@ add_library(migraphx
replace_allocate.cpp replace_allocate.cpp
simplify_qdq.cpp simplify_qdq.cpp
sqlite.cpp sqlite.cpp
rewrite_batchnorm.cpp
rewrite_gelu.cpp rewrite_gelu.cpp
rewrite_pooling.cpp rewrite_pooling.cpp
rewrite_quantization.cpp rewrite_quantization.cpp
...@@ -116,7 +115,6 @@ register_migraphx_ops( ...@@ -116,7 +115,6 @@ register_migraphx_ops(
as_shape as_shape
atanh atanh
atan atan
batch_norm_inference
broadcast broadcast
capture capture
ceil ceil
......
...@@ -74,9 +74,9 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size) ...@@ -74,9 +74,9 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
auto x_main_module_19 = mmain->add_literal(migraphx::generate_literal( auto x_main_module_19 = mmain->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {64, 3, 11, 11}}, 18)); migraphx::shape{migraphx::shape::float_type, {64, 3, 11, 11}}, 18));
auto x_main_module_20 = mmain->add_instruction( auto x_main_module_20 = mmain->add_instruction(
migraphx::make_json_op("convolution", migraphx::make_json_op(
"{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[4," "convolution",
"4],use_dynamic_same_auto_pad:0}"), "{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[4,4]}"),
x_0, x_0,
x_main_module_19); x_main_module_19);
auto x_main_module_21 = mmain->add_instruction( auto x_main_module_21 = mmain->add_instruction(
...@@ -90,9 +90,9 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size) ...@@ -90,9 +90,9 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"), "{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"),
x_main_module_23); x_main_module_23);
auto x_main_module_25 = mmain->add_instruction( auto x_main_module_25 = mmain->add_instruction(
migraphx::make_json_op("convolution", migraphx::make_json_op(
"{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[1," "convolution",
"1],use_dynamic_same_auto_pad:0}"), "{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[1,1]}"),
x_main_module_24, x_main_module_24,
x_main_module_17); x_main_module_17);
auto x_main_module_26 = mmain->add_instruction( auto x_main_module_26 = mmain->add_instruction(
...@@ -106,9 +106,9 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size) ...@@ -106,9 +106,9 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"), "{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"),
x_main_module_28); x_main_module_28);
auto x_main_module_30 = mmain->add_instruction( auto x_main_module_30 = mmain->add_instruction(
migraphx::make_json_op("convolution", migraphx::make_json_op(
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1," "convolution",
"1],use_dynamic_same_auto_pad:0}"), "{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}"),
x_main_module_29, x_main_module_29,
x_main_module_15); x_main_module_15);
auto x_main_module_31 = mmain->add_instruction( auto x_main_module_31 = mmain->add_instruction(
...@@ -117,9 +117,9 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size) ...@@ -117,9 +117,9 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
mmain->add_instruction(migraphx::make_op("add"), x_main_module_30, x_main_module_31); mmain->add_instruction(migraphx::make_op("add"), x_main_module_30, x_main_module_31);
auto x_main_module_33 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_32); auto x_main_module_33 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_32);
auto x_main_module_34 = mmain->add_instruction( auto x_main_module_34 = mmain->add_instruction(
migraphx::make_json_op("convolution", migraphx::make_json_op(
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1," "convolution",
"1],use_dynamic_same_auto_pad:0}"), "{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}"),
x_main_module_33, x_main_module_33,
x_main_module_13); x_main_module_13);
auto x_main_module_35 = mmain->add_instruction( auto x_main_module_35 = mmain->add_instruction(
...@@ -128,9 +128,9 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size) ...@@ -128,9 +128,9 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
mmain->add_instruction(migraphx::make_op("add"), x_main_module_34, x_main_module_35); mmain->add_instruction(migraphx::make_op("add"), x_main_module_34, x_main_module_35);
auto x_main_module_37 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_36); auto x_main_module_37 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_36);
auto x_main_module_38 = mmain->add_instruction( auto x_main_module_38 = mmain->add_instruction(
migraphx::make_json_op("convolution", migraphx::make_json_op(
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1," "convolution",
"1],use_dynamic_same_auto_pad:0}"), "{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}"),
x_main_module_37, x_main_module_37,
x_main_module_11); x_main_module_11);
auto x_main_module_39 = mmain->add_instruction( auto x_main_module_39 = mmain->add_instruction(
......
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -44,7 +44,6 @@ ...@@ -44,7 +44,6 @@
#include <migraphx/propagate_constant.hpp> #include <migraphx/propagate_constant.hpp>
#include <migraphx/quantization.hpp> #include <migraphx/quantization.hpp>
#include <migraphx/register_op.hpp> #include <migraphx/register_op.hpp>
#include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/simplify_algebra.hpp> #include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_reshapes.hpp> #include <migraphx/simplify_reshapes.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
...@@ -221,7 +220,6 @@ struct loader ...@@ -221,7 +220,6 @@ struct loader
{ {
migraphx::run_passes(*p.get_main_module(), migraphx::run_passes(*p.get_main_module(),
{ {
migraphx::rewrite_batchnorm{},
migraphx::eliminate_identity{}, migraphx::eliminate_identity{},
migraphx::dead_code_elimination{}, migraphx::dead_code_elimination{},
migraphx::simplify_algebra{}, migraphx::simplify_algebra{},
......
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -145,7 +145,7 @@ void verify_reduced(program p, ...@@ -145,7 +145,7 @@ void verify_reduced(program p,
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto last = std::prev(mm->end(), n + 1); auto last = std::prev(mm->end(), n + 1);
mm->remove_instructions(last, mm->end()); mm->remove_instructions(last, mm->end());
std::cout << "Verify: " << std::endl; std::cout << "Verify: " << n << std::endl;
std::cout << p << std::endl; std::cout << p << std::endl;
verify_program(std::to_string(n), p, t, options, quantize, inputs, tolerance); verify_program(std::to_string(n), p, t, options, quantize, inputs, tolerance);
} }
...@@ -159,6 +159,7 @@ void verify_reduced_program(const program& p, ...@@ -159,6 +159,7 @@ void verify_reduced_program(const program& p,
{ {
const auto* mm = p.get_main_module(); const auto* mm = p.get_main_module();
auto n = std::distance(mm->begin(), mm->end()); auto n = std::distance(mm->begin(), mm->end());
std::cout << "Verify steps: " << n << std::endl;
for(std::size_t i = 0; i < n; i++) for(std::size_t i = 0; i < n; i++)
{ {
verify_reduced(p, i, t, options, quantize, inputs, tolerance); verify_reduced(p, i, t, options, quantize, inputs, tolerance);
......
...@@ -39,7 +39,7 @@ static literal get_scalar(instruction_ref ins) ...@@ -39,7 +39,7 @@ static literal get_scalar(instruction_ref ins)
if(ins->name() == "contiguous") if(ins->name() == "contiguous")
return get_scalar(ins->inputs().front()); return get_scalar(ins->inputs().front());
const auto& s = ins->get_shape(); const auto& s = ins->get_shape();
if(not(s.elements() == 1 or s.scalar())) if(s.elements() != 1 && not(s.scalar()))
return {}; return {};
if(not ins->can_eval()) if(not ins->can_eval())
return {}; return {};
......
...@@ -107,6 +107,7 @@ struct argument : raw_data<argument> ...@@ -107,6 +107,7 @@ struct argument : raw_data<argument>
data_t m_data{}; data_t m_data{};
}; };
std::vector<shape> to_shapes(const std::vector<argument>& args);
void migraphx_to_value(value& v, const argument& a); void migraphx_to_value(value& v, const argument& a);
void migraphx_from_value(const value& v, argument& a); void migraphx_from_value(const value& v, argument& a);
......
...@@ -21,41 +21,55 @@ ...@@ -21,41 +21,55 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#ifndef MIGRAPHX_GUARD_RTGLIB_BATCHNORM_HPP #ifndef MIGRAPHX_GUARD_MIGRAPHLIB_DYN_OUTPUT_HPP
#define MIGRAPHX_GUARD_RTGLIB_BATCHNORM_HPP #define MIGRAPHX_GUARD_MIGRAPHLIB_DYN_OUTPUT_HPP
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/op/batch_norm_inference.hpp>
#include <migraphx/reflect.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context; struct dyn_output
{
// original shape from the instruction
shape ins_shape;
// shape computed at eval time using input arguments
shape computed_shape;
};
struct miopen_batch_norm_inference /**
* Handle dynamic and static shape at evaluation time.
* If converted to shape type, returns original ins_shape.
* If converted to dyn_output type, will compute an output shape using the input arguments.
*/
template <class F>
struct compute_output_shape
{ {
op::batch_norm_inference op; F ins_inputs;
template <class Self, class F> operator dyn_output() const
static auto reflect(Self& self, F f)
{ {
return migraphx::reflect(self.op, f); return ins_inputs([](const auto& x, shape ins_shape, const std::vector<argument>& inputs) {
if(ins_shape.dynamic())
return dyn_output{ins_shape, compute_shape(x, to_shapes(inputs))};
return dyn_output{ins_shape, ins_shape};
});
} }
std::string name() const { return "gpu::batch_norm_inference"; } operator shape() const
shape compute_shape(const std::vector<shape>& inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{ {
return shapes.size() - 1; return ins_inputs(
[](const auto&, shape ins_shape, const std::vector<argument>&) { return ins_shape; });
} }
}; };
} // namespace gpu template <class F>
compute_output_shape<F> make_compute_output_shape(F f)
{
return {f};
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif #endif
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_BATCH_NORM_HPP
#define MIGRAPHX_GUARD_OPERATORS_BATCH_NORM_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <cmath>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct batch_norm_inference
{
float epsilon = 1.0e-6f;
float momentum = 0.9f;
std::string name() const { return "batch_norm_inference"; }
enum bn_infer_mode_t
{
per_activation,
spatial,
};
bn_infer_mode_t bn_mode = spatial;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(
f(self.epsilon, "epsilon"), f(self.momentum, "momentum"), f(self.bn_mode, "bn_mode"));
}
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(5);
check_shapes{inputs.data(), inputs.data() + 1, *this}.same_ndims();
check_shapes{inputs.data() + 1, inputs.data() + inputs.size(), *this}.same_shape();
return inputs.front();
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -33,11 +33,11 @@ namespace migraphx { ...@@ -33,11 +33,11 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
// Padding mode is default_ for fixed shape padding.
// same_lower and same_upper used for dynamic padding.
enum padding_mode_t enum padding_mode_t
{ {
default_, // NOLINT default_, // NOLINT
same,
valid,
same_lower, same_lower,
same_upper same_upper
}; };
......
...@@ -44,7 +44,7 @@ struct convert : unary<convert> ...@@ -44,7 +44,7 @@ struct convert : unary<convert>
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this, true}.has(1);
auto input = inputs.at(0); auto input = inputs.at(0);
if(input.dynamic()) if(input.dynamic())
{ {
......
...@@ -41,9 +41,8 @@ struct convolution ...@@ -41,9 +41,8 @@ struct convolution
std::vector<std::size_t> stride = {1, 1}; std::vector<std::size_t> stride = {1, 1};
std::vector<std::size_t> dilation = {1, 1}; std::vector<std::size_t> dilation = {1, 1};
int group = 1; int group = 1;
padding_mode_t padding_mode = default_; padding_mode_t padding_mode = default_;
bool use_dynamic_same_auto_pad = false;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -52,16 +51,15 @@ struct convolution ...@@ -52,16 +51,15 @@ struct convolution
f(self.stride, "stride"), f(self.stride, "stride"),
f(self.dilation, "dilation"), f(self.dilation, "dilation"),
f(self.group, "group"), f(self.group, "group"),
f(self.padding_mode, "padding_mode"), f(self.padding_mode, "padding_mode"));
f(self.use_dynamic_same_auto_pad, "use_dynamic_same_auto_pad"));
} }
std::string name() const { return "convolution"; } std::string name() const { return "convolution"; }
void check_attribute_size() const void check_attribute_size() const
{ {
if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and if((padding.size() != stride.size() and (padding.size() / 2) != stride.size()) or
stride.size() == dilation.size())) stride.size() != dilation.size())
{ {
MIGRAPHX_THROW("CONVOLUTION: inconsistent attribute sizes"); MIGRAPHX_THROW("CONVOLUTION: inconsistent attribute sizes");
} }
...@@ -76,7 +74,8 @@ struct convolution ...@@ -76,7 +74,8 @@ struct convolution
// num of dims of input and attribute should match // num of dims of input and attribute should match
const auto input_size = inputs[0].max_lens().size(); const auto input_size = inputs[0].max_lens().size();
const auto padding_size = padding.size(); const auto padding_size = padding.size();
if(not(input_size == padding_size / 2 + 2 or input_size == padding_size + 2))
if(input_size != padding_size / 2 + 2 && input_size != padding_size + 2)
{ {
MIGRAPHX_THROW("CONVOLUTION: input and attribute size mismatch!"); MIGRAPHX_THROW("CONVOLUTION: input and attribute size mismatch!");
} }
...@@ -93,13 +92,6 @@ struct convolution ...@@ -93,13 +92,6 @@ struct convolution
x_shape.lens().at(1) != (w_shape.lens().at(1) * group)) x_shape.lens().at(1) != (w_shape.lens().at(1) * group))
MIGRAPHX_THROW("CONVOLUTION: mismatched channel numbers"); MIGRAPHX_THROW("CONVOLUTION: mismatched channel numbers");
std::vector<op::padding_mode_t> dyn_pad_modes = {op::padding_mode_t::same_upper,
op::padding_mode_t::same_lower};
if(use_dynamic_same_auto_pad and not contains(dyn_pad_modes, padding_mode))
{
MIGRAPHX_THROW("CONVOLUTION: use_dynamic_same_auto_pad set with invalid padding mode");
}
if(x_shape.dynamic() or w_shape.dynamic()) if(x_shape.dynamic() or w_shape.dynamic())
{ {
return dynamic_compute_shape(x_shape, w_shape); return dynamic_compute_shape(x_shape, w_shape);
...@@ -161,7 +153,7 @@ struct convolution ...@@ -161,7 +153,7 @@ struct convolution
dynamic_shape_push_back(w_shape); dynamic_shape_push_back(w_shape);
const size_t num_spatial_dims = x_shape.max_lens().size() - 2; const size_t num_spatial_dims = x_shape.max_lens().size() - 2;
if(use_dynamic_same_auto_pad) if(padding_mode != default_)
{ {
for(std::size_t i = 0; i < num_spatial_dims; ++i) for(std::size_t i = 0; i < num_spatial_dims; ++i)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment