Unverified Commit 9c91c08d authored by Chris Austen's avatar Chris Austen Committed by GitHub
Browse files

Merge branch 'develop' into enable_navi_32_ci

parents a56bb11d c1b8c975
...@@ -18,19 +18,8 @@ jobs: ...@@ -18,19 +18,8 @@ jobs:
with: with:
access_token: ${{ github.token }} access_token: ${{ github.token }}
tidy: tidy:
runs-on: ubuntu-20.04 runs-on: ROCM-Ubuntu
steps: steps:
- name: Free space
uses: jlumbroso/free-disk-space@main
with:
tool-cache: true
android: true
dotnet: true
haskell: true
large-packages: true
swap-storage: true
- uses: actions/checkout@v3 - uses: actions/checkout@v3
# In this step, this action saves a list of existing images, # In this step, this action saves a list of existing images,
...@@ -71,7 +60,7 @@ jobs: ...@@ -71,7 +60,7 @@ jobs:
-DCLANG_TIDY_DEPEND_ON_TARGET=Off \ -DCLANG_TIDY_DEPEND_ON_TARGET=Off \
-DCLANG_TIDY_CACHE=/data/tidy-cache \ -DCLANG_TIDY_CACHE=/data/tidy-cache \
.. ..
make -j2 -k onnx-proto tf-proto tidy make -j$(nproc) -k onnx-proto tf-proto tidy
# GH actions can not update existing cache, as a workaround clear cache and then save it # GH actions can not update existing cache, as a workaround clear cache and then save it
- name: Clear tidy cache before saving - name: Clear tidy cache before saving
...@@ -93,20 +82,8 @@ jobs: ...@@ -93,20 +82,8 @@ jobs:
cppcheck: cppcheck:
runs-on: ubuntu-20.04 runs-on: ROCM-Ubuntu
steps: steps:
- name: Free space
uses: jlumbroso/free-disk-space@main
with:
tool-cache: true
android: true
dotnet: true
haskell: true
large-packages: true
swap-storage: true
- uses: actions/checkout@v3 - uses: actions/checkout@v3
# In this step, this action saves a list of existing images, # In this step, this action saves a list of existing images,
...@@ -142,7 +119,7 @@ jobs: ...@@ -142,7 +119,7 @@ jobs:
-DBUILD_DEV=On \ -DBUILD_DEV=On \
-DROCM_ENABLE_GH_ANNOTATIONS=On \ -DROCM_ENABLE_GH_ANNOTATIONS=On \
.. ..
make -j2 cppcheck make -j$(nproc) cppcheck
# GH actions can not update existing cache, as a workaround clear cache and then save it # GH actions can not update existing cache, as a workaround clear cache and then save it
- name: Clear cppcheck cache before saving - name: Clear cppcheck cache before saving
...@@ -164,18 +141,8 @@ jobs: ...@@ -164,18 +141,8 @@ jobs:
format: format:
runs-on: ubuntu-20.04 runs-on: ROCM-Ubuntu
steps: steps:
- name: Free space
uses: jlumbroso/free-disk-space@main
with:
tool-cache: true
android: true
dotnet: true
haskell: true
large-packages: true
swap-storage: true
- uses: actions/checkout@v3 - uses: actions/checkout@v3
# In this step, this action saves a list of existing images, # In this step, this action saves a list of existing images,
......
...@@ -12,7 +12,7 @@ on: ...@@ -12,7 +12,7 @@ on:
rocm_release: rocm_release:
description: ROCm Version description: ROCm Version
required: true required: true
default: '5.4.2' default: '5.5'
performance_reports_repo: performance_reports_repo:
description: Repository where performance reports are stored description: Repository where performance reports are stored
required: true required: true
...@@ -48,7 +48,7 @@ jobs: ...@@ -48,7 +48,7 @@ jobs:
release: release:
uses: ROCmSoftwarePlatform/migraphx-benchmark/.github/workflows/perf-test.yml@main uses: ROCmSoftwarePlatform/migraphx-benchmark/.github/workflows/perf-test.yml@main
with: with:
rocm_release: ${{ github.event.inputs.rocm_release || '5.4.2' }} rocm_release: ${{ github.event.inputs.rocm_release || '5.5' }}
result_number: ${{ github.event.inputs.result_number || '10' }} result_number: ${{ github.event.inputs.result_number || '10' }}
flags: ${{ github.event.inputs.flags || '-r' }} flags: ${{ github.event.inputs.flags || '-r' }}
performance_reports_repo: ${{ github.event.inputs.performance_reports_repo || 'ROCmSoftwarePlatform/migraphx-reports' }} performance_reports_repo: ${{ github.event.inputs.performance_reports_repo || 'ROCmSoftwarePlatform/migraphx-reports' }}
......
...@@ -27,28 +27,27 @@ if("${CMAKE_SOURCE_DIR}" STREQUAL "${CMAKE_BINARY_DIR}") ...@@ -27,28 +27,27 @@ if("${CMAKE_SOURCE_DIR}" STREQUAL "${CMAKE_BINARY_DIR}")
message(FATAL_ERROR "The binary and source directroy cannot be the same") message(FATAL_ERROR "The binary and source directroy cannot be the same")
endif() endif()
get_property(_GENERATOR_IS_MULTI_CONFIG GLOBAL PROPERTY GENERATOR_IS_MULTI_CONFIG)
# This has to be initialized before the project() command appears # This has to be initialized before the project() command appears
# Set the default of CMAKE_BUILD_TYPE to be release, unless user specifies with -D. MSVC_IDE does not use CMAKE_BUILD_TYPE # Set the default of CMAKE_BUILD_TYPE to be release, unless user specifies with -D. MSVC_IDE does not use CMAKE_BUILD_TYPE
if( NOT MSVC_IDE AND NOT CMAKE_BUILD_TYPE ) if(_GENERATOR_IS_MULTI_CONFIG)
set( CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build, options are: None Debug Release RelWithDebInfo MinSizeRel." ) if (NOT CMAKE_CONFIGURATION_TYPES)
endif() set(CMAKE_CONFIGURATION_TYPES "Debug;Release;RelWithDebInfo;MinSizeRel" CACHE STRING
"Available build types (configurations) on multi-config generators")
# Setup valid strings for build type endif()
if (NOT CMAKE_CONFIGURATION_TYPES)
set(CMAKE_CONFIGURATION_TYPES "Debug;Release;RelWithDebInfo;MinSizeRel" CACHE STRING "Configs")
endif()
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS ${CMAKE_CONFIGURATION_TYPES})
# Default installation path
if(WIN32)
set(CMAKE_INSTALL_PREFIX "/opt/rocm/x86_64-w64-mingw32" CACHE PATH "")
else() else()
set(CMAKE_INSTALL_PREFIX "/opt/rocm" CACHE PATH "") if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release CACHE STRING
"Choose the type of build, options are: None Debug Release RelWithDebInfo MinSizeRel.")
endif()
endif() endif()
set(CMAKE_BUILD_RPATH "${CMAKE_BINARY_DIR}/lib") set(CMAKE_BUILD_RPATH "${CMAKE_BINARY_DIR}/lib")
project(migraphx) project(migraphx LANGUAGES C CXX)
include(CTest)
find_package(ROCM REQUIRED) find_package(ROCM REQUIRED)
find_path(HALF_INCLUDE_DIR half.hpp PATH_SUFFIXES half) find_path(HALF_INCLUDE_DIR half.hpp PATH_SUFFIXES half)
...@@ -128,6 +127,7 @@ rocm_enable_clang_tidy( ...@@ -128,6 +127,7 @@ rocm_enable_clang_tidy(
-bugprone-implicit-widening-of-multiplication-result -bugprone-implicit-widening-of-multiplication-result
-bugprone-macro-parentheses -bugprone-macro-parentheses
-bugprone-signed-char-misuse -bugprone-signed-char-misuse
-bugprone-unchecked-optional-access
# Disable the aliased reserved identifiers # Disable the aliased reserved identifiers
-cert-dcl37-c -cert-dcl37-c
-cert-dcl51-cpp -cert-dcl51-cpp
...@@ -269,7 +269,9 @@ set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib) ...@@ -269,7 +269,9 @@ set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/bin) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/bin)
add_subdirectory(src) add_subdirectory(src)
add_subdirectory(docs) add_subdirectory(docs)
add_subdirectory(test) if(BUILD_TESTING)
add_subdirectory(test)
endif()
add_subdirectory(tools) add_subdirectory(tools)
set(DEST_DIR ${CMAKE_BINARY_DIR}) set(DEST_DIR ${CMAKE_BINARY_DIR})
......
...@@ -12,6 +12,9 @@ RUN apt-get update && apt-get install -y gnupg2 --no-install-recommends curl && ...@@ -12,6 +12,9 @@ RUN apt-get update && apt-get install -y gnupg2 --no-install-recommends curl &&
# Add rocm repository # Add rocm repository
RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/5.5/ focal main > /etc/apt/sources.list.d/rocm.list' RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/5.5/ focal main > /etc/apt/sources.list.d/rocm.list'
# From docs.amd.com for installing rocm. Needed to install properly
RUN sh -c "echo 'Package: *\nPin: release o=repo.radeon.com\nPin-priority: 600' > /etc/apt/preferences.d/rocm-pin-600"
# Install dependencies # Install dependencies
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
apt-utils \ apt-utils \
...@@ -110,7 +113,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR ...@@ -110,7 +113,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR
ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
RUN cget -p /usr/local install ROCmSoftwarePlatform/rocMLIR@a997d5f51314b45d7a4c04f1599966dcf53f9b4d -DBUILD_MIXR_TARGET=On -DLLVM_ENABLE_ZSTD=Off -DLLVM_ENABLE_THREADS=Off RUN cget -p /usr/local install ROCmSoftwarePlatform/rocMLIR@8d25af3b3721c159bb41cc6388e9453b1018c126 -DBUILD_MIXR_TARGET=On -DLLVM_ENABLE_ZSTD=Off -DLLVM_ENABLE_THREADS=Off
ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
......
...@@ -89,6 +89,8 @@ def rocmnodename(name) { ...@@ -89,6 +89,8 @@ def rocmnodename(name) {
node_name = "${rocmtest_name} && vega"; node_name = "${rocmtest_name} && vega";
} else if(name == "navi21") { } else if(name == "navi21") {
node_name = "${rocmtest_name} && navi21"; node_name = "${rocmtest_name} && navi21";
} else if(name == "mi100+") {
node_name = "${rocmtest_name} && (gfx908 || gfx90a)";
} else if(name == "anygpu") { } else if(name == "anygpu") {
node_name = "${rocmtest_name} && (gfx908 || gfx90a || vega)"; node_name = "${rocmtest_name} && (gfx908 || gfx90a || vega)";
} else if(name == "navi32") { } else if(name == "navi32") {
...@@ -136,6 +138,12 @@ rocmtest clang_debug: rocmnode('vega') { cmake_build -> ...@@ -136,6 +138,12 @@ rocmtest clang_debug: rocmnode('vega') { cmake_build ->
cmake_build(flags: "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_MLIR=On -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}' -DCMAKE_C_FLAGS_DEBUG='${debug_flags}'") cmake_build(flags: "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_MLIR=On -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}' -DCMAKE_C_FLAGS_DEBUG='${debug_flags}'")
} }
} }
}, ck_release: rocmnode('mi100+') { cmake_build ->
stage('CK Release') {
withEnv(['MIGRAPHX_ENABLE_CK=1', 'MIGRAPHX_TUNE_CK=1']) {
cmake_build(flags: "-DCMAKE_BUILD_TYPE=release")
}
}
}, clang_asan: rocmnode('nogpu') { cmake_build -> }, clang_asan: rocmnode('nogpu') { cmake_build ->
stage('Clang ASAN') { stage('Clang ASAN') {
def sanitizers = "undefined,address" def sanitizers = "undefined,address"
......
...@@ -24,6 +24,40 @@ ...@@ -24,6 +24,40 @@
find_program(EMBED_LD ld) find_program(EMBED_LD ld)
find_program(EMBED_OBJCOPY objcopy) find_program(EMBED_OBJCOPY objcopy)
if(LINUX)
option(EMBED_USE_LD "Use ld to embed data files" ON)
else()
option(EMBED_USE_LD "Use ld to embed data files" OFF)
endif()
function(wrap_string)
set(options)
set(oneValueArgs VARIABLE AT_COLUMN)
set(multiValueArgs)
cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
cmake_parse_arguments(WRAP_STRING "${options}" "${oneValueArgs}" "" ${ARGN})
string(LENGTH ${${PARSE_VARIABLE}} string_length)
math(EXPR offset "0")
while(string_length GREATER 0)
if(string_length GREATER ${PARSE_AT_COLUMN})
math(EXPR length "${PARSE_AT_COLUMN}")
else()
math(EXPR length "${string_length}")
endif()
string(SUBSTRING ${${PARSE_VARIABLE}} ${offset} ${length} line)
set(lines "${lines}\n${line}")
math(EXPR string_length "${string_length} - ${length}")
math(EXPR offset "${offset} + ${length}")
endwhile()
set(${PARSE_VARIABLE} "${lines}" PARENT_SCOPE)
endfunction()
function(generate_embed_source EMBED_NAME) function(generate_embed_source EMBED_NAME)
set(options) set(options)
set(oneValueArgs SRC HEADER) set(oneValueArgs SRC HEADER)
...@@ -46,14 +80,21 @@ function(generate_embed_source EMBED_NAME) ...@@ -46,14 +80,21 @@ function(generate_embed_source EMBED_NAME)
list(GET PARSE_OBJECTS ${idx} OBJECT) list(GET PARSE_OBJECTS ${idx} OBJECT)
set(START_SYMBOL "_binary_${SYMBOL}_start") set(START_SYMBOL "_binary_${SYMBOL}_start")
set(END_SYMBOL "_binary_${SYMBOL}_end") set(END_SYMBOL "_binary_${SYMBOL}_end")
if(EMBED_USE_LD)
string(APPEND EXTERNS " string(APPEND EXTERNS "
extern const char ${START_SYMBOL}[]; extern const char ${START_SYMBOL}[];
extern const char ${END_SYMBOL}[]; extern const char ${END_SYMBOL}[];
") ")
else()
string(APPEND EXTERNS "
extern const char ${START_SYMBOL}[];
extern const char* ${END_SYMBOL};
")
endif()
# TODO: Should use NAME_WLE # TODO: Should use NAME_WLE
get_filename_component(BASE_NAME "${OBJECT}" NAME) get_filename_component(BASE_NAME "${OBJECT}" NAME)
string(REGEX REPLACE ".[A-Za-z0-9_]$" "" BASE_NAME ${BASE_NAME}) string(REGEX REPLACE ".[A-Za-z0-9_]+$" "" BASE_NAME ${BASE_NAME})
string(APPEND INIT_KERNELS " string(APPEND INIT_KERNELS "
{ \"${BASE_NAME}\", { ${START_SYMBOL}, ${END_SYMBOL}} }, { \"${BASE_NAME}\", { ${START_SYMBOL}, ${END_SYMBOL}} },
...@@ -86,9 +127,14 @@ function(embed_file OUTPUT_FILE OUTPUT_SYMBOL FILE) ...@@ -86,9 +127,14 @@ function(embed_file OUTPUT_FILE OUTPUT_SYMBOL FILE)
string(MAKE_C_IDENTIFIER "${REL_FILE}" SYMBOL) string(MAKE_C_IDENTIFIER "${REL_FILE}" SYMBOL)
get_filename_component(OUTPUT_FILE_DIR "${REL_FILE}" DIRECTORY) get_filename_component(OUTPUT_FILE_DIR "${REL_FILE}" DIRECTORY)
file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${OUTPUT_FILE_DIR}") file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${OUTPUT_FILE_DIR}")
if(EMBED_USE_LD)
set(OUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/${REL_FILE}.o") set(OUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/${REL_FILE}.o")
else()
set(OUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/${REL_FILE}.cpp")
endif()
set(${OUTPUT_SYMBOL} ${SYMBOL} PARENT_SCOPE) set(${OUTPUT_SYMBOL} ${SYMBOL} PARENT_SCOPE)
set(${OUTPUT_FILE} "${OUT_FILE}" PARENT_SCOPE) set(${OUTPUT_FILE} "${OUT_FILE}" PARENT_SCOPE)
if(EMBED_USE_LD)
add_custom_command( add_custom_command(
OUTPUT "${OUT_FILE}" OUTPUT "${OUT_FILE}"
COMMAND ${EMBED_LD} -r -o "${OUT_FILE}" -z noexecstack --format=binary "${REL_FILE}" COMMAND ${EMBED_LD} -r -o "${OUT_FILE}" -z noexecstack --format=binary "${REL_FILE}"
...@@ -97,6 +143,21 @@ function(embed_file OUTPUT_FILE OUTPUT_SYMBOL FILE) ...@@ -97,6 +143,21 @@ function(embed_file OUTPUT_FILE OUTPUT_SYMBOL FILE)
DEPENDS ${FILE} DEPENDS ${FILE}
VERBATIM VERBATIM
) )
else()
set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS ${FILE})
# reads source file contents as hex string
file(READ ${FILE} HEX_STRING HEX)
# wraps the hex string into multiple lines
wrap_string(VARIABLE HEX_STRING AT_COLUMN 80)
# adds '0x' prefix and comma suffix before and after every byte respectively
string(REGEX REPLACE "([0-9a-f][0-9a-f])" "0x\\1, " ARRAY_VALUES ${HEX_STRING})
# removes trailing comma
string(REGEX REPLACE ", $" "" ARRAY_VALUES ${ARRAY_VALUES})
file(WRITE "${OUT_FILE}" "
extern const char _binary_${SYMBOL}_start[] = { ${ARRAY_VALUES} };
extern const char* _binary_${SYMBOL}_end = _binary_${SYMBOL}_start + sizeof(_binary_${SYMBOL}_start);
\n")
endif()
endforeach() endforeach()
endfunction() endfunction()
...@@ -119,6 +180,6 @@ function(add_embed_library EMBED_NAME) ...@@ -119,6 +180,6 @@ function(add_embed_library EMBED_NAME)
generate_embed_source(${EMBED_NAME} SRC ${SRC_FILE} HEADER ${HEADER_FILE} OBJECTS ${OUTPUT_FILES} SYMBOLS ${SYMBOLS}) generate_embed_source(${EMBED_NAME} SRC ${SRC_FILE} HEADER ${HEADER_FILE} OBJECTS ${OUTPUT_FILES} SYMBOLS ${SYMBOLS})
add_library(${EMBED_NAME} STATIC ${OUTPUT_FILES} "${SRC_FILE}") add_library(${EMBED_NAME} STATIC ${OUTPUT_FILES} "${SRC_FILE}")
target_include_directories(${EMBED_NAME} PUBLIC "${EMBED_DIR}/include") target_include_directories(${EMBED_NAME} PUBLIC "${EMBED_DIR}/include")
target_compile_options(${EMBED_NAME} PRIVATE -Wno-reserved-identifier) target_compile_options(${EMBED_NAME} PRIVATE -Wno-reserved-identifier -Wno-extern-initializer -Wno-missing-variable-declarations)
set_target_properties(${EMBED_NAME} PROPERTIES POSITION_INDEPENDENT_CODE On) set_target_properties(${EMBED_NAME} PROPERTIES POSITION_INDEPENDENT_CODE On)
endfunction() endfunction()
...@@ -28,3 +28,4 @@ ROCmSoftwarePlatform/half@rocm-5.4.2 ...@@ -28,3 +28,4 @@ ROCmSoftwarePlatform/half@rocm-5.4.2
pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.17 -DCMAKE_POSITION_INDEPENDENT_CODE=On sqlite3@3.17 -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/composable_kernel@5172ec5280f14974beee2acf1af1db3b2670244c -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
...@@ -94,6 +94,7 @@ add_library(migraphx ...@@ -94,6 +94,7 @@ add_library(migraphx
simplify_algebra.cpp simplify_algebra.cpp
simplify_reshapes.cpp simplify_reshapes.cpp
split_single_dyn_dim.cpp split_single_dyn_dim.cpp
target.cpp
tmp_dir.cpp tmp_dir.cpp
value.cpp value.cpp
verify_args.cpp verify_args.cpp
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -31,20 +31,6 @@ ...@@ -31,20 +31,6 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
// Example:
// s0 = (3,2,4,5) and s1 = (2,1,1)
//
// In this case we need to broadcast (:,1,1) portion of
// s1 plus broadcast the 1st dimension of s1
// giving output_lens = (3,2,4,5)
//
// Another example:
// s0 = (3,2,1,5) and s1 = (2,7,5)
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
//
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0, std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
std::vector<std::size_t> s1) std::vector<std::size_t> s1)
{ {
...@@ -77,20 +63,18 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha ...@@ -77,20 +63,18 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
} }
auto offset = s1.ndim() - s0.ndim(); auto offset = s1.ndim() - s0.ndim();
std::vector<shape::dynamic_dimension> out_dims(s1.dyn_dims()); std::vector<shape::dynamic_dimension> out_dims(s1.dyn_dims());
std::transform( std::transform(s0.dyn_dims().cbegin(),
s0.dyn_dims().cbegin(),
s0.dyn_dims().cend(), s0.dyn_dims().cend(),
s1.dyn_dims().cbegin() + offset, s1.dyn_dims().cbegin() + offset,
out_dims.begin() + offset, out_dims.begin() + offset,
[&](auto a, auto b) { [&](auto a, auto b) {
if(a == b) if(a == b or b == 1)
{ {
return a; return a;
} }
else if(a == 1 or b == 1) else if(a == 1)
{ {
// setting optimals to empty, may need to be changed return b;
return shape::dynamic_dimension{std::max(a.min, b.min), std::max(a.max, b.max)};
} }
else else
{ {
...@@ -102,7 +86,15 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha ...@@ -102,7 +86,15 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
return out_dims; return out_dims;
} }
// Compute the common (broadcasted) dimensions of a list of fixed shapes std::vector<shape::dynamic_dimension> compute_common_dyn_dims(const std::vector<shape>& shapes)
{
auto ret_shape = shapes.at(0);
std::for_each(shapes.cbegin() + 1, shapes.cend(), [&](auto s) {
ret_shape = shape{ret_shape.type(), compute_broadcasted_dyn_dims(ret_shape, s)};
});
return ret_shape.dyn_dims();
}
std::vector<std::size_t> compute_common_lens(const std::vector<shape>& shapes) std::vector<std::size_t> compute_common_lens(const std::vector<shape>& shapes)
{ {
assert(not shapes.empty()); assert(not shapes.empty());
...@@ -154,34 +146,29 @@ insert_common_args(module& m, instruction_ref ins, std::vector<instruction_ref> ...@@ -154,34 +146,29 @@ insert_common_args(module& m, instruction_ref ins, std::vector<instruction_ref>
if(std::any_of( if(std::any_of(
inputs.cbegin(), inputs.cend(), [](auto input) { return input->get_shape().dynamic(); })) inputs.cbegin(), inputs.cend(), [](auto input) { return input->get_shape().dynamic(); }))
{ {
// currently only handles the binary case auto input_shapes = to_shapes(inputs);
if(inputs.size() != 2) auto c_type = compute_common_types(input_shapes);
{ auto c_dyn_dims = compute_common_dyn_dims(input_shapes);
MIGRAPHX_THROW("INSERT_COMMON_OP: not handled; " + migraphx::to_string(inputs.size()) +
"inputs, only handle two inputs if any are dynamic shape");
}
auto c_type = compute_common_types(to_shapes(inputs));
auto c_dyn_dims =
compute_broadcasted_dyn_dims(inputs[0]->get_shape(), inputs[1]->get_shape());
// following should work for a static or dynamic shape // following should work for a static or dynamic shape
if(inputs[0]->get_shape().dyn_dims() != c_dyn_dims) if(inputs[0]->get_shape().dyn_dims() != c_dyn_dims)
{ {
inputs[0] = m.insert_instruction( inputs[0] = m.insert_instruction(
ins, ins, make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}), inputs);
make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}),
inputs[0],
inputs[1]);
} }
if(inputs[1]->get_shape().dyn_dims() != c_dyn_dims) std::transform(inputs.begin() + 1, inputs.end(), inputs.begin() + 1, [&](auto input) {
// uses previous input to avoid recalculating the common shape from the
// full set of input shapes at runtime
if(input->get_shape().dyn_dims() != c_dyn_dims)
{ {
inputs[1] = m.insert_instruction( return m.insert_instruction(
ins, ins,
make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}), make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}),
inputs[1], input,
inputs[0]); inputs[0]);
} }
return input;
});
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) { std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
if(input->get_shape().type() != c_type) if(input->get_shape().type() != c_type)
{ {
......
...@@ -49,8 +49,10 @@ void dead_code_elimination::apply(module& m) const ...@@ -49,8 +49,10 @@ void dead_code_elimination::apply(module& m) const
if(i == last) if(i == last)
break; break;
// Skip instruction with empty shape as output unless its [dynamic, builtin, undefined, // Skip instruction with empty shape as output unless its [dynamic, builtin, undefined,
// identity, allocate] // identity, allocate or tuple_type]
if((not i->get_shape().dynamic() and i->get_shape().elements() == 0) and if((not i->get_shape().dynamic() and
(i->get_shape().elements() == 0 and
i->get_shape().type() != migraphx::shape::tuple_type)) and
not(i->name().front() == '@') and not contains({"identity", "allocate"}, i->name()) and not(i->name().front() == '@') and not contains({"identity", "allocate"}, i->name()) and
not i->is_undefined()) not i->is_undefined())
continue; continue;
......
...@@ -455,8 +455,29 @@ struct compiler ...@@ -455,8 +455,29 @@ struct compiler
{ {
auto p = l.load(); auto p = l.load();
// Dont compile if its already been compiled // Dont compile if its already been compiled
if(p.is_compiled()) if(p.is_compiled())
{
if(ct.target_name == "gpu")
{
if(is_offload_copy_set(p) and not co.offload_copy)
{
std::cout << "MIGraphX program was likely compiled with offload_copy set, Try "
"passing "
"`--enable-offload-copy` if program run fails.\n";
}
else if(co.offload_copy)
{
std::cout << "MIGraphX program was likely compiled without "
"offload_copy set, Try "
"removing "
"`--enable-offload-copy` flag if passed to driver, if program run "
"fails.\n";
}
}
return p; return p;
}
auto t = ct.get_target(); auto t = ct.get_target();
if(to_fp16) if(to_fp16)
{ {
......
...@@ -24,6 +24,8 @@ ...@@ -24,6 +24,8 @@
#include "perf.hpp" #include "perf.hpp"
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#ifdef HAVE_GPU #ifdef HAVE_GPU
#include <migraphx/gpu/hip.hpp> #include <migraphx/gpu/hip.hpp>
...@@ -97,6 +99,38 @@ target get_target(bool gpu) ...@@ -97,6 +99,38 @@ target get_target(bool gpu)
return make_target("cpu"); return make_target("cpu");
} }
bool is_offload_copy_set(const program& p)
{
assert(p.is_compiled());
const module* mm = p.get_main_module();
std::vector<std::string> param_names = mm->get_parameter_names();
std::unordered_set<instruction_ref> param_ins;
std::transform(param_names.begin(),
param_names.end(),
std::inserter(param_ins, param_ins.begin()),
[&](const auto& i) { return mm->get_parameter(i); });
for(const auto& i : *mm)
{
if(i.name() == "hip::copy_to_gpu")
{
auto copy_arg = instruction::get_output_alias(i.inputs().front(), true);
param_ins.erase(copy_arg);
}
else if(i.name() == "@return")
{
auto return_args = i.inputs();
for(const auto& j : return_args)
{
auto alias_ins = instruction::get_output_alias(j, true);
if((alias_ins->name() == "@param" && param_ins.erase(alias_ins) == 0) or
(alias_ins->name() != "hip::copy_from_gpu"))
return false;
}
}
}
return param_ins.empty();
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace driver } // namespace driver
} // namespace migraphx } // namespace migraphx
...@@ -39,6 +39,15 @@ parameter_map create_param_map(const program& p, const target& t, bool offload = ...@@ -39,6 +39,15 @@ parameter_map create_param_map(const program& p, const target& t, bool offload =
parameter_map fill_param_map(parameter_map& m, const program& p, bool gpu); parameter_map fill_param_map(parameter_map& m, const program& p, bool gpu);
parameter_map create_param_map(const program& p, bool gpu = true); parameter_map create_param_map(const program& p, bool gpu = true);
target get_target(bool gpu); target get_target(bool gpu);
/**
* @brief Checks if MIGraphX program compiled for "GPU" has offload_copy set of not. This is
intended to print a HINT for the users and would not always correctly classify compiled program as
with or without offload_copy in all cases.
* @param p Compiled MIGraphX program for GPU backend
* @return true if program is classified as compiled with "offload_copy" set
*/
bool is_offload_copy_set(const program& p);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace driver } // namespace driver
......
...@@ -34,6 +34,26 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -34,6 +34,26 @@ inline namespace MIGRAPHX_INLINE_NS {
struct module; struct module;
struct operation; struct operation;
/**
* Broadcasting works by comparing the shapes element-wise starting with
* the trailing (right-most) dimensions and working leftwards. This is equivalent
* to what is done in NumPy.
* example 1:
* s0 = (3,2,4,5) and s1 = (2,1,1)
* In this case we need to broadcast (:,1,1) portion of
* s1 plus broadcast the 1st dimension of s0
* giving output_lens = (3,2,4,5)
*
* example 2:
* s0 = (3,2,1,5) and s1 = (2,7,5)
* In this case we need to broadcast the (:,:,1:,:) axis
* of s0 plus the 1st dimension of s1 giving
* output_lens = (3,2,7,5)
*
* example 3:
* s0 = (4, 1, 1) and s1 = (3, 4)
* output_lens = (4, 3, 4)
*/
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0, std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
std::vector<std::size_t> s1); std::vector<std::size_t> s1);
...@@ -41,6 +61,28 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha ...@@ -41,6 +61,28 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
shape common_shape(const std::vector<shape>& shapes); shape common_shape(const std::vector<shape>& shapes);
/**
* @brief Compute the common (broadcasted) dimensions of a list of fixed shapes
*/
std::vector<std::size_t> compute_common_lens(const std::vector<shape>& shapes);
/**
* @ brief Compute the common (broadcasted) dynamic dimensions of a list of dynamic shapes
*/
std::vector<shape::dynamic_dimension> compute_common_dyn_dims(const std::vector<shape>& shapes);
/**
* @brief Creates and adds instructions to convert input arguments to common shapes and types
* by adding multi-broadcast and type convert operations. This is a utility function for creating
* operations where the shape and type of inputs need to match. It supports both dynamic and
* static-shaped arguments.
*
* @param m containing module for instruction
* @param ins insertion location in instruction list
* @param inputs instructions to use as argument list; also, the shapes
* attached to each instruction_ref are considered for broadcasting
* @return std::vector<instruction_ref> a modified argument list
*/
std::vector<instruction_ref> std::vector<instruction_ref>
insert_common_args(module& m, instruction_ref ins, std::vector<instruction_ref> inputs); insert_common_args(module& m, instruction_ref ins, std::vector<instruction_ref> inputs);
...@@ -50,6 +92,10 @@ instruction_ref insert_common_op(module& m, ...@@ -50,6 +92,10 @@ instruction_ref insert_common_op(module& m,
instruction_ref ins, instruction_ref ins,
const operation& op, const operation& op,
std::vector<instruction_ref> inputs); std::vector<instruction_ref> inputs);
/**
* @brief Wrapper for insert_common_args() which inserts operation at the end of the module.
*/
instruction_ref add_common_op(module& m, const operation& op, std::vector<instruction_ref> inputs); instruction_ref add_common_op(module& m, const operation& op, std::vector<instruction_ref> inputs);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_HASH_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_HASH_HPP
#include <migraphx/config.hpp>
#include <functional>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class T>
std::size_t hash_value(const T& v)
{
return std::hash<T>{}(v);
}
template <class T>
void hash_combine(std::size_t& seed, const T& v)
{
seed ^= hash_value(v) + 0x9e3779b9 + (seed << 6u) + (seed >> 2u);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_HASH_HPP
...@@ -137,6 +137,7 @@ struct instruction ...@@ -137,6 +137,7 @@ struct instruction
operation normalized_operator() const; operation normalized_operator() const;
std::size_t get_target_id() const; std::size_t get_target_id() const;
void set_target_id(std::size_t tid); void set_target_id(std::size_t tid);
void debug_print() const; void debug_print() const;
......
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include <migraphx/optional.hpp> #include <migraphx/optional.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/type_name.hpp> #include <migraphx/type_name.hpp>
#include <migraphx/source_location.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
...@@ -135,14 +136,14 @@ template <class M> ...@@ -135,14 +136,14 @@ template <class M>
auto bind_match(M m, std::string name) auto bind_match(M m, std::string name)
{ {
return make_function_matcher( return make_function_matcher(
[=, name = std::move(name)](matcher_context& ctx, [=, m_name = std::move(name)](matcher_context& ctx,
instruction_ref ins) -> optional<instruction_ref> { instruction_ref ins) -> optional<instruction_ref> {
auto result = m.match(ctx, ins); auto result = m.match(ctx, ins);
if(result) if(result)
{ {
if(not ctx.has_instruction(ins)) if(not ctx.has_instruction(ins))
return nullopt; return nullopt;
ctx.instructions[name] = ins; ctx.instructions[m_name] = ins;
} }
return result; return result;
}); });
...@@ -370,31 +371,30 @@ match::matcher_result find_match(module& modl, M&& m) ...@@ -370,31 +371,30 @@ match::matcher_result find_match(module& modl, M&& m)
} }
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES_FOR)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_VALIDATE_MATCHES) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_VALIDATE_MATCHES)
/// Find matches for an instruction in the module /// Find matches for an instruction in the module for per section of matchers
template <class Mod, class... Ms> template <class Mod, class... Ms>
void find_matches(Mod& mod, instruction_ref ins, Ms&&... ms) void find_matches_for(source_location location, Mod& mod, instruction_ref ins, Ms&&... ms)
{ {
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5 const int trace = value_of(MIGRAPHX_TRACE_MATCHES{});
const const bool validate = enabled(MIGRAPHX_VALIDATE_MATCHES{});
#endif const auto trace_filter = string_value_of(MIGRAPHX_TRACE_MATCHES_FOR{});
int trace = value_of(MIGRAPHX_TRACE_MATCHES{}); const bool trace_for = not trace_filter.empty() and
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5 (contains(std::string{location.file_name()}, trace_filter) or
const contains(std::string{location.function_name()}, trace_filter));
#endif
bool validate = enabled(MIGRAPHX_VALIDATE_MATCHES{});
bool match = false; bool match = false;
each_args( each_args(
[&](auto&& m) { [&](auto&& m) {
if(match) if(match)
return; return;
if(trace > 1) if(trace > 1 or trace_for)
std::cout << "Match: " << get_type_name(m) << std::endl; std::cout << "Match: " << get_type_name(m) << std::endl;
auto r = match_instruction(get_module(mod), ins, m.matcher()); auto r = match_instruction(get_module(mod), ins, m.matcher());
if(r.result == get_module(mod).end()) if(r.result == get_module(mod).end())
return; return;
if(trace > 0) if(trace > 0 or trace_for)
{ {
std::cout << "Matched by " << get_type_name(m) << std::endl; std::cout << "Matched by " << get_type_name(m) << std::endl;
get_module(mod).debug_print(ins); get_module(mod).debug_print(ins);
...@@ -420,13 +420,19 @@ void find_matches(Mod& mod, instruction_ref ins, Ms&&... ms) ...@@ -420,13 +420,19 @@ void find_matches(Mod& mod, instruction_ref ins, Ms&&... ms)
/// Find matches in a module /// Find matches in a module
template <class Mod, class... Ms> template <class Mod, class... Ms>
void find_matches(Mod& mod, Ms&&... ms) struct find_matches
{ {
find_matches(Mod& mod, Ms&&... ms, source_location location = source_location::current())
{
for(auto ins : iterator_for(get_module(mod))) for(auto ins : iterator_for(get_module(mod)))
{ {
find_matches(mod, ins, ms...); find_matches_for(location, mod, ins, ms...);
} }
} }
};
template <class Mod, class... Ms>
find_matches(Mod& mod, Ms&&... ms) -> find_matches<Mod, Ms...>;
template <class M, class F> template <class M, class F>
struct find_generic_match struct find_generic_match
...@@ -655,9 +661,9 @@ auto skip_output(Ms... ms) ...@@ -655,9 +661,9 @@ auto skip_output(Ms... ms)
inline auto var(std::string s) inline auto var(std::string s)
{ {
return make_basic_fun_matcher( return make_basic_fun_matcher(
[=, s = std::move(s)](const matcher_context& ctx, [=, m_s = std::move(s)](const matcher_context& ctx,
instruction_ref) -> optional<instruction_ref> { instruction_ref) -> optional<instruction_ref> {
auto it = ctx.instructions.find(s); auto it = ctx.instructions.find(m_s);
if(it == ctx.instructions.end()) if(it == ctx.instructions.end())
return nullopt; return nullopt;
return it->second; return it->second;
...@@ -667,7 +673,7 @@ inline auto var(std::string s) ...@@ -667,7 +673,7 @@ inline auto var(std::string s)
inline auto name(std::string s) inline auto name(std::string s)
{ {
return make_basic_pred_matcher( return make_basic_pred_matcher(
[=, s = std::move(s)](instruction_ref ins) { return ins->name() == s; }); [=, m_s = std::move(s)](instruction_ref ins) { return ins->name() == m_s; });
} }
inline auto name_contains(const std::string& name) inline auto name_contains(const std::string& name)
...@@ -678,8 +684,8 @@ inline auto name_contains(const std::string& name) ...@@ -678,8 +684,8 @@ inline auto name_contains(const std::string& name)
inline auto name(std::unordered_set<std::string> names) inline auto name(std::unordered_set<std::string> names)
{ {
return make_basic_pred_matcher([=, names = std::move(names)](instruction_ref ins) { return make_basic_pred_matcher([=, m_names = std::move(names)](instruction_ref ins) {
return names.count(ins->name()) > 0; return m_names.count(ins->name()) > 0;
}); });
} }
......
...@@ -189,7 +189,7 @@ struct module ...@@ -189,7 +189,7 @@ struct module
instruction_ref validate() const; instruction_ref validate() const;
instruction_ref find_dangling_reference() const; instruction_ref find_dangling_reference() const;
void finalize(context& ctx); void finalize(std::vector<context>& contexts);
void debug_print() const; void debug_print() const;
void debug_print(instruction_ref ins) const; void debug_print(instruction_ref ins) const;
......
...@@ -37,10 +37,13 @@ namespace op { ...@@ -37,10 +37,13 @@ namespace op {
* 1 input version: * 1 input version:
* Broadcasts a tensor from the original shape to the broadcast_lens by setting the stride of * Broadcasts a tensor from the original shape to the broadcast_lens by setting the stride of
* broadcasted dimensions to zero. `axis` attribute for a 1D input shape is the output dimension * broadcasted dimensions to zero. `axis` attribute for a 1D input shape is the output dimension
* that stays the same. ex: broadcasting shape [1024] -> [4, 1024, 3] has axis = 1 For higher rank * that stays the same.
* input shapes, axis is an offset parameter for the broadcasting. Such that this operator would * ex: broadcasting shape [1024] -> [4, 1024, 3] has axis = 1.
* work in the opposite direction of NumPy broadcasting. ex: broadcasting shape [2, 2] -> [2, 2, 3] *
* with axis = 0 * For higher rank input shapes, axis is an offset parameter for the broadcasting.
* Such that this operator would work in the opposite direction of NumPy broadcasting
* (left-most to rightwards element-wise comparison)
* ex: broadcasting shape [2, 2] -> [2, 2, 3] with axis = 0
* *
* 2 input version: * 2 input version:
* Broadcast the first input 1D shape into the second input shape based on the axis parameter. * Broadcast the first input 1D shape into the second input shape based on the axis parameter.
...@@ -68,6 +71,9 @@ struct broadcast ...@@ -68,6 +71,9 @@ struct broadcast
{ {
// the ONNX broadcast op is deprecated now, so not handling the negative // the ONNX broadcast op is deprecated now, so not handling the negative
// value of axis anymore // value of axis anymore
if(s0.dynamic())
MIGRAPHX_THROW(
"BROADCAST: Single dynamic input shape not supported. Use two inputs.");
if(axis >= broadcast_lens.size()) if(axis >= broadcast_lens.size())
{ {
MIGRAPHX_THROW("BROADCAST : axis " + migraphx::to_string(axis) + MIGRAPHX_THROW("BROADCAST : axis " + migraphx::to_string(axis) +
......
...@@ -25,12 +25,13 @@ ...@@ -25,12 +25,13 @@
#define MIGRAPHX_GUARD_OPERATORS_CLIP_HPP #define MIGRAPHX_GUARD_OPERATORS_CLIP_HPP
#include <array> #include <array>
#include <cmath>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp> #include <migraphx/par_for.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <cmath> #include <migraphx/dyn_output.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -48,15 +49,15 @@ struct clip ...@@ -48,15 +49,15 @@ struct clip
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(3).same_type().same_dims(); check_shapes{inputs, *this, true}.has(3).same_type().same_dims();
return inputs.front(); return inputs.front();
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{dyn_out.computed_shape};
visit_all(result, args[0], args[1], args[2])([&](auto output, auto x, auto min, auto max) { visit_all(result, args[0], args[1], args[2])([&](auto output, auto x, auto min, auto max) {
par_for(output_shape.elements(), par_for(dyn_out.computed_shape.elements(),
[&](auto i) { output[i] = std::min(std::max(min[i], x[i]), max[i]); }); [&](auto i) { output[i] = std::min(std::max(min[i], x[i]), max[i]); });
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment