Commit a22ec139 authored by Manupa Karunaratne's avatar Manupa Karunaratne
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into mlir-attention

parents 9898823d 650ba45f
##################################################################################### #####################################################################################
# The MIT License (MIT) # The MIT License (MIT)
# #
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
# #
# Permission is hereby granted, free of charge, to any person obtaining a copy # Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal # of this software and associated documentation files (the "Software"), to deal
...@@ -23,8 +23,14 @@ ...@@ -23,8 +23,14 @@
##################################################################################### #####################################################################################
include(PythonModules) include(PythonModules)
set(VENV ${CMAKE_BINARY_DIR}/test/py/venv)
set(VENV_ONNX ${CMAKE_BINARY_DIR}/test/py/venv-onnx)
set(REQUIREMENTS ${CMAKE_CURRENT_SOURCE_DIR}/requirements.txt)
set(REQUIREMENTS_ONNX ${CMAKE_CURRENT_SOURCE_DIR}/requirements-onnx.txt)
set(PYTHON_VERSION_TO_DISABLE_ONNX 3.6)
function(add_py_test NAME SCRIPT)
function(add_py_venv_fixture FIXTURE_NAME VIRTUAL_ENV_DIR REQUIREMENTS_FILE)
foreach(PYTHON_VERSION ${PYTHON_VERSIONS}) foreach(PYTHON_VERSION ${PYTHON_VERSIONS})
set (ENV_COMMAND ${CMAKE_COMMAND} -E env set (ENV_COMMAND ${CMAKE_COMMAND} -E env
"PYTHONPATH=$<TARGET_FILE_DIR:migraphx_pybind_${PYTHON_VERSION}>" "PYTHONPATH=$<TARGET_FILE_DIR:migraphx_pybind_${PYTHON_VERSION}>"
...@@ -32,28 +38,57 @@ function(add_py_test NAME SCRIPT) ...@@ -32,28 +38,57 @@ function(add_py_test NAME SCRIPT)
"MALLOC_CHECK_=3" "MALLOC_CHECK_=3"
) )
set(PYTHON_EXECUTABLE ${PYTHON_${PYTHON_VERSION}_EXECUTABLE}) set(PYTHON_EXECUTABLE ${PYTHON_${PYTHON_VERSION}_EXECUTABLE})
if(NOT TEST py_${PYTHON_VERSION}_${FIXTURE_NAME}_initialize_env)
if (NOT (${FIXTURE_NAME} STREQUAL "onnx" AND ${PYTHON_VERSION} STREQUAL ${PYTHON_VERSION_TO_DISABLE_ONNX}))
add_test(NAME py_${PYTHON_VERSION}_${FIXTURE_NAME}_initialize_env COMMAND ${PYTHON_EXECUTABLE} -m venv ${VIRTUAL_ENV_DIR}/${PYTHON_VERSION} --clear)
set_tests_properties(py_${PYTHON_VERSION}_${FIXTURE_NAME}_initialize_env PROPERTIES FIXTURES_SETUP ${FIXTURE_NAME}_${PYTHON_VERSION}_INIT_VENV)
set(PYTHON_EXECUTABLE ${VIRTUAL_ENV_DIR}/${PYTHON_VERSION}/bin/python)
add_test(
NAME py_${PYTHON_VERSION}_${FIXTURE_NAME}_setup_env
COMMAND ${PYTHON_EXECUTABLE} -m pip install -r ${REQUIREMENTS_FILE})
set_tests_properties(py_${PYTHON_VERSION}_${FIXTURE_NAME}_setup_env PROPERTIES FIXTURES_REQUIRED ${FIXTURE_NAME}_${PYTHON_VERSION}_INIT_VENV)
set_tests_properties(py_${PYTHON_VERSION}_${FIXTURE_NAME}_setup_env PROPERTIES FIXTURES_SETUP ${FIXTURE_NAME}_${PYTHON_VERSION}_VENV)
endif()
endif()
endforeach()
endfunction()
function(add_py_test NAME SCRIPT FIXTURE_NAME VENV_DIR)
foreach(PYTHON_VERSION ${PYTHON_VERSIONS})
set (ENV_COMMAND ${CMAKE_COMMAND} -E env
"PYTHONPATH=$<TARGET_FILE_DIR:migraphx_pybind_${PYTHON_VERSION}>"
"PYTHONMALLOC=debug"
"MALLOC_CHECK_=3"
)
set(PYTHON_EXECUTABLE ${VENV_DIR}/${PYTHON_VERSION}/bin/python)
if(NOT (${FIXTURE_NAME} STREQUAL "onnx" AND ${PYTHON_VERSION} STREQUAL ${PYTHON_VERSION_TO_DISABLE_ONNX}))
add_test( add_test(
NAME test_py_${PYTHON_VERSION}_${NAME} NAME test_py_${PYTHON_VERSION}_${NAME}
COMMAND ${ENV_COMMAND} ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/${SCRIPT} ${ARGN}) COMMAND ${ENV_COMMAND} ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/${SCRIPT} ${ARGN})
set_tests_properties(test_py_${PYTHON_VERSION}_${NAME} PROPERTIES FIXTURES_REQUIRED ${FIXTURE_NAME}_${PYTHON_VERSION}_VENV)
add_custom_target(test_py_${PYTHON_VERSION}_${NAME} add_custom_target(test_py_${PYTHON_VERSION}_${NAME}
COMMAND ${ENV_COMMAND} ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/${SCRIPT} ${ARGN} COMMAND ${ENV_COMMAND} ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/${SCRIPT} ${ARGN}
COMMENT "${PYTHON_EXECUTABLE} ${SCRIPT}") COMMENT "${PYTHON_EXECUTABLE} ${SCRIPT}")
endif()
endforeach() endforeach()
endfunction() endfunction()
add_dependencies(tests migraphx_py) add_dependencies(tests migraphx_py)
add_dependencies(check migraphx_py) add_dependencies(check migraphx_py)
add_py_test(ref test_cpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_venv_fixture(common ${VENV} ${REQUIREMENTS})
add_py_test(save_load test_save_load.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_venv_fixture(onnx ${VENV_ONNX} ${REQUIREMENTS_ONNX})
add_py_test(op test_op.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(shape test_shape.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(ref test_cpu.py common ${VENV} WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(module_construct test_module_construct.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(save_load test_save_load.py common ${VENV} WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(literal test_literal.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(op test_op.py common ${VENV} WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(shape test_shape.py common ${VENV} WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(module_construct test_module_construct.py common ${VENV} WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(literal test_literal.py common ${VENV} WORKING_DIRECTORY ${TEST_ONNX_DIR})
if(MIGRAPHX_ENABLE_GPU) if(MIGRAPHX_ENABLE_GPU)
add_py_test(gpu_offload test_gpu_offload.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(gpu_offload test_gpu_offload.py common ${VENV} WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(gpu test_gpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(gpu test_gpu.py common ${VENV} WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(array test_array.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(array test_array.py common ${VENV} WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(backend onnx_backend_test.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(backend onnx_backend_test.py onnx ${VENV_ONNX} WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(gpu_async test_gpu_async.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(gpu_async test_gpu_async.py common ${VENV} WORKING_DIRECTORY ${TEST_ONNX_DIR})
endif() endif()
##################################################################################### #####################################################################################
# The MIT License (MIT) # The MIT License (MIT)
# #
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
# #
# Permission is hereby granted, free of charge, to any person obtaining a copy # Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal # of this software and associated documentation files (the "Software"), to deal
...@@ -21,23 +21,9 @@ ...@@ -21,23 +21,9 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE. # THE SOFTWARE.
##################################################################################### #####################################################################################
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
CLANG_FORMAT=/opt/rocm/llvm/bin/clang-format
SRC_DIR=$DIR/../src
PYTHON=python3
if type -p python3.6 > /dev/null ; then
PYTHON=python3.6
fi
if type -p python3.8 > /dev/null ; then
PYTHON=python3.8
fi
ls -1 $DIR/include/ | xargs -n 1 -P $(nproc) -I{} -t bash -c "$PYTHON $DIR/te.py $DIR/include/{} | $CLANG_FORMAT -style=file > $SRC_DIR/include/migraphx/{}"
function api { onnx==1.14.1
$PYTHON $DIR/api.py $SRC_DIR/api/migraphx.py $1 | $CLANG_FORMAT -style=file > $2 protobuf==3.20.2
} numpy==1.21.6
packaging==23.0
api $DIR/api/migraphx.h $SRC_DIR/api/include/migraphx/migraphx.h pytest==6.0.1
echo "Finished generating header migraphx.h" \ No newline at end of file
api $DIR/api/api.cpp $SRC_DIR/api/api.cpp
echo "Finished generating source api.cpp "
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
numpy==1.21.6
\ No newline at end of file
##################################################################################### #####################################################################################
# The MIT License (MIT) # The MIT License (MIT)
# #
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
# #
# Permission is hereby granted, free of charge, to any person obtaining a copy # Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal # of this software and associated documentation files (the "Software"), to deal
...@@ -21,12 +21,8 @@ ...@@ -21,12 +21,8 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE. # THE SOFTWARE.
##################################################################################### #####################################################################################
import sys
import migraphx import migraphx
try: import numpy as np
import numpy as np
except:
sys.exit()
def test_conv_relu(): def test_conv_relu():
......
##################################################################################### #####################################################################################
# The MIT License (MIT) # The MIT License (MIT)
# #
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
# #
# Permission is hereby granted, free of charge, to any person obtaining a copy # Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal # of this software and associated documentation files (the "Software"), to deal
...@@ -21,11 +21,8 @@ ...@@ -21,11 +21,8 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE. # THE SOFTWARE.
##################################################################################### #####################################################################################
import migraphx, sys import migraphx
try: import numpy as np
import numpy as np
except:
sys.exit()
def test_add_op(): def test_add_op():
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -669,6 +669,23 @@ TEST_CASE(simplify_inner_broadcast_different_broadcasts) ...@@ -669,6 +669,23 @@ TEST_CASE(simplify_inner_broadcast_different_broadcasts)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(simplify_inner_broadcast_no_common_axis)
{
auto b = migraphx::make_op("multibroadcast", {{"out_lens", {1, 5, 10}}});
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {5, 10}});
auto y = m1.add_parameter("y", {migraphx::shape::int32_type, {1, 5, 1}});
auto xb = m1.add_instruction(b, x);
auto yb = m1.add_instruction(b, y);
auto sum = m1.add_instruction(migraphx::make_op("add"), xb, yb);
m1.add_instruction(pass_op{}, sum);
}
migraphx::module m2 = m1;
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(simplify_add_conv1) TEST_CASE(simplify_add_conv1)
{ {
migraphx::module m; migraphx::module m;
......
...@@ -67,6 +67,57 @@ migraphx::module make_concat_multibroadcast(const std::vector<size_t>& in_lens, ...@@ -67,6 +67,57 @@ migraphx::module make_concat_multibroadcast(const std::vector<size_t>& in_lens,
return m; return m;
} }
TEST_CASE(broadcast_transpose)
{
migraphx::module m1;
{
auto l = m1.add_parameter("x", {migraphx::shape::float_type, {5}});
auto mb =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 5}}}), l);
auto t1 =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1}}}), mb);
m1.add_return({t1});
}
run_pass(m1);
migraphx::module m2;
{
auto l = m2.add_parameter("x", {migraphx::shape::float_type, {5}});
auto u1 = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), l);
auto t1 =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1}}}), u1);
auto mb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5, 2, 3}}}), t1);
m2.add_return({mb});
}
EXPECT(m1 == m2);
}
TEST_CASE(broadcast_transpose_opt)
{
// extra transpose from transformation will be optimized out
migraphx::module m1;
{
auto l = m1.add_parameter("x", {migraphx::shape::float_type, {5}});
auto mb =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 5}}}), l);
auto t1 =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 2}}}), mb);
m1.add_return({t1});
}
run_pass(m1);
migraphx::module m2;
{
auto l = m2.add_parameter("x", {migraphx::shape::float_type, {5}});
auto u1 = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), l);
auto mb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 2, 5}}}), u1);
m2.add_return({mb});
}
EXPECT(m1 == m2);
}
TEST_CASE(broadcast_transpose_scalar) TEST_CASE(broadcast_transpose_scalar)
{ {
migraphx::module m1; migraphx::module m1;
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 256}};
migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 256, 256}};
auto m2_elements = m2_shape.elements();
auto a = mm->add_parameter("1", m1_shape);
auto b = mm->add_parameter("2", m1_shape);
auto b1 = mm->add_parameter("3", m1_shape);
std::vector<float> eights(m2_elements, 0.125);
auto eight = mm->add_literal(migraphx::literal{m2_shape, eights});
std::vector<float> zeros(m2_elements, 0);
auto zero = mm->add_literal(migraphx::literal{m2_shape, zeros});
b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b);
auto scale = mm->add_instruction(migraphx::make_op("mul"), gemm1, eight);
auto bias = mm->add_instruction(migraphx::make_op("add"), scale, zero);
auto softmax = mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), bias);
mm->add_instruction(migraphx::make_op("dot"), softmax, b1);
return p;
}
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct gemm_softmax_gemm_relu : verify_program<gemm_softmax_gemm_relu>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 256}};
migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 256, 256}};
auto m2_elements = m2_shape.elements();
auto a = mm->add_parameter("1", m1_shape);
auto b = mm->add_parameter("2", m1_shape);
auto b1 = mm->add_parameter("3", m1_shape);
std::vector<float> eights(m2_elements, 0.125);
auto eight = mm->add_literal(migraphx::literal{m2_shape, eights});
std::vector<float> zeros(m2_elements, 0);
auto zero = mm->add_literal(migraphx::literal{m2_shape, zeros});
b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b);
auto scale = mm->add_instruction(migraphx::make_op("mul"), gemm1, eight);
auto bias = mm->add_instruction(migraphx::make_op("add"), scale, zero);
auto softmax = mm->add_instruction(migraphx::make_op("softmax", {{"axis", 3}}), bias);
auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), softmax, b1);
mm->add_instruction(migraphx::make_op("relu"), gemm2);
return p;
}
};
...@@ -44,8 +44,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DUMP_TEST) ...@@ -44,8 +44,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DUMP_TEST)
// An improved async, that doesn't block // An improved async, that doesn't block
template <class Function> template <class Function>
std::future<typename std::result_of<Function()>::type> detach_async(Function&& f, std::future<std::invoke_result_t<Function>> detach_async(Function&& f, bool parallel = true)
bool parallel = true)
{ {
if(parallel) if(parallel)
{ {
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_flatten_dot_relu : verify_program<test_flatten_dot_relu>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto a =
mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 3, 5}});
a = mm->add_instruction(migraphx::make_op("flatten", {{"axis", 3}}), a);
auto b =
mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 5, 3, 3, 1}});
b = mm->add_instruction(migraphx::make_op("flatten", {{"axis", 3}}), b);
auto dot = mm->add_instruction(migraphx::make_op("dot"), a, b);
mm->add_instruction(migraphx::make_op("relu"), dot);
return p;
}
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_scatter_nonstandard_shape : verify_program<test_scatter_nonstandard_shape>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sd{migraphx::shape::float_type, {3, 1, 3}, {1, 3, 2}};
migraphx::shape si{migraphx::shape::int32_type, {2, 1, 3}, {1, 3, 2}};
std::vector<int> vi = {1, 0, 2, 0, 2, 1};
migraphx::shape su{migraphx::shape::float_type, {2, 1, 3}, {1, 2, 3}};
auto pd = mm->add_parameter("data", sd);
auto li = mm->add_literal(migraphx::literal{si, vi});
auto pu = mm->add_parameter("update", su);
auto r = mm->add_instruction(migraphx::make_op("scatter_none", {{"axis", -1}}), pd, li, pu);
mm->add_return({r});
return p;
}
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_squeeze_conv_relu : verify_program<test_squeeze_conv_relu>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 1, 3, 3}});
input = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), input);
auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
mm->add_instruction(migraphx::make_op("relu"), conv);
return p;
}
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_unsqueeze_conv_relu : verify_program<test_unsqueeze_conv_relu>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 3, 3, 3}});
auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {3, 3, 3}});
weights = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), weights);
auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
mm->add_instruction(migraphx::make_op("relu"), conv);
return p;
}
};
...@@ -22,4 +22,16 @@ ...@@ -22,4 +22,16 @@
# THE SOFTWARE. # THE SOFTWARE.
##################################################################################### #####################################################################################
add_custom_target(generate bash ${CMAKE_CURRENT_SOURCE_DIR}/generate.sh) find_package(Python 3 COMPONENTS Interpreter)
if(NOT Python_EXECUTABLE)
message(WARNING "Python 3 interpreter not found - skipping 'generate' target!")
return()
endif()
find_program(CLANG_FORMAT clang-format PATHS /opt/rocm/llvm ENV HIP_PATH PATH_SUFFIXES bin)
if(NOT CLANG_FORMAT)
message(WARNING "clang-format not found - skipping 'generate' target!")
return()
endif()
add_custom_target(generate ${Python_EXECUTABLE} generate.py -f ${CLANG_FORMAT} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
...@@ -27,6 +27,7 @@ import re ...@@ -27,6 +27,7 @@ import re
import runpy import runpy
from functools import wraps from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from pathlib import Path
type_map: Dict[str, Callable[['Parameter'], None]] = {} type_map: Dict[str, Callable[['Parameter'], None]] = {}
cpp_type_map: Dict[str, str] = {} cpp_type_map: Dict[str, str] = {}
...@@ -1281,18 +1282,17 @@ def template_eval(template, **kwargs): ...@@ -1281,18 +1282,17 @@ def template_eval(template, **kwargs):
return template return template
def run(args: List[str]) -> None: def run(path: Union[Path, str]) -> str:
runpy.run_path(args[0]) return template_eval(open(path).read())
if len(args) > 1:
f = open(args[1]).read()
r = template_eval(f) if __name__ == "__main__":
sys.modules['api'] = sys.modules['__main__']
runpy.run_path(sys.argv[1])
if len(sys.argv) > 2:
r = run(sys.argv[2])
sys.stdout.write(r) sys.stdout.write(r)
else: else:
sys.stdout.write(generate_c_header()) sys.stdout.write(generate_c_header())
sys.stdout.write(generate_c_api_body()) sys.stdout.write(generate_c_api_body())
# sys.stdout.write(generate_cpp_header()) # sys.stdout.write(generate_cpp_header())
if __name__ == "__main__":
sys.modules['api'] = sys.modules['__main__']
run(sys.argv[1:])
...@@ -38,26 +38,32 @@ ...@@ -38,26 +38,32 @@
#include <migraphx/register_op.hpp> #include <migraphx/register_op.hpp>
#include <migraphx/json.hpp> #include <migraphx/json.hpp>
#include <migraphx/convert_to_json.hpp> #include <migraphx/convert_to_json.hpp>
#include <array>
#include <algorithm> #include <algorithm>
#include <cstdarg> #include <cstdarg>
namespace migraphx { namespace migraphx {
#ifdef MIGRAPHX_BUILD_TESTING
static thread_local bool disable_exception_catch = false; // NOLINT static thread_local bool disable_exception_catch = false; // NOLINT
extern "C" MIGRAPHX_C_EXPORT void migraphx_test_private_disable_exception_catch(bool b) extern "C" MIGRAPHX_C_EXPORT void migraphx_test_private_disable_exception_catch(bool b)
{ {
disable_exception_catch = b; disable_exception_catch = b;
} }
#endif
template <class F> template <class F>
migraphx_status try_(F f, bool output = true) // NOLINT migraphx_status try_(F f, bool output = true) // NOLINT
{ {
#ifdef MIGRAPHX_BUILD_TESTING
if(disable_exception_catch) if(disable_exception_catch)
{ {
f(); f();
} }
else else
{ {
#endif
try try
{ {
f(); f();
...@@ -81,7 +87,9 @@ migraphx_status try_(F f, bool output = true) // NOLINT ...@@ -81,7 +87,9 @@ migraphx_status try_(F f, bool output = true) // NOLINT
{ {
return migraphx_status_unknown_error; return migraphx_status_unknown_error;
} }
#ifdef MIGRAPHX_BUILD_TESTING
} }
#endif
return migraphx_status_success; return migraphx_status_success;
} }
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <stdlib.h> #include <stdlib.h>
#include <stdbool.h> #include <stdbool.h>
#include <stdint.h>
#include <migraphx/api/export.h> #include <migraphx/api/export.h>
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
##################################################################################### #####################################################################################
# The MIT License (MIT) # The MIT License (MIT)
# #
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
# #
# Permission is hereby granted, free of charge, to any person obtaining a copy # Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal # of this software and associated documentation files (the "Software"), to deal
...@@ -27,11 +27,11 @@ import sys ...@@ -27,11 +27,11 @@ import sys
debug = False debug = False
# The filetypes we want to check for that are stamped # The filetypes we want to check for that are stamped
# LICENSE is included here as it SHOULD have a liscence in it otherwise flag it as unstamped # LICENSE is included here as it SHOULD have a license in it otherwise flag it as unstamped
supported_file_types = (".cpp", ".hpp", ".h", ".ipynb", ".py", ".txt", ".sh", supported_file_types = (".cpp", ".hpp", ".h", ".ipynb", ".py", ".txt", ".sh",
".bsh", "LICENSE", ".cmake") ".bsh", "LICENSE", ".cmake")
#add general stuff we shouldn't stamp and any exceptions here # add general stuff we shouldn't stamp and any exceptions here
unsupported_file_types = [ unsupported_file_types = [
".onnx", ".pb", ".rst", ".jpg", ".jpeg", ".proto", ".md", ".clang", ".onnx", ".pb", ".rst", ".jpg", ".jpeg", ".proto", ".md", ".clang",
".weight", ".ini", ".json", ".docker", ".git", ".rules", ".yml" ".weight", ".ini", ".json", ".docker", ".git", ".rules", ".yml"
...@@ -40,105 +40,89 @@ unsupported_file_types = [ ...@@ -40,105 +40,89 @@ unsupported_file_types = [
specificIgnores = ("digits.txt", "Dockerfile", "Jenkinsfile", "") specificIgnores = ("digits.txt", "Dockerfile", "Jenkinsfile", "")
def hasKeySequence(inputfile, key_message): def hasKeySequence(inputfile: str, key_message: str) -> bool:
result = False
if key_message in inputfile: if key_message in inputfile:
result = True return True
return result return False
#Simple just open and write stuff to each file with the license stamp # Simple just open and write stuff to each file with the license stamp
def openAndCheckFile(filename): def needStampCheck(filename: str) -> bool:
result = False # open save old contents and append things here
#open save old contents and append things here if debug: print("Open", filename, end=' ')
if debug is True:
print("Open", filename, end='')
try: try:
file = open(filename, 'r') file = open(filename, 'r')
except OSError as e: except OSError as e:
if debug is True: if debug: print(str(e) + "....Open Error: Skipping file ")
print(str(e) + "....Open Error: Skipping file ")
file.close() file.close()
return return False
else: else:
with file as contents: with file as contents:
try: try:
save = contents.read() save = contents.read()
hasAmdLic = hasKeySequence(
save, "Advanced Micro Devices, Inc. All rights reserved")
#Check if we have a licence stamp already # Check if we have a license stamp already
if hasAmdLic is True: if hasKeySequence(
if debug is True: save,
print("....Already Stamped: Skipping file ") "Advanced Micro Devices, Inc. All rights reserved"):
if debug: print("....Already Stamped: Skipping file ")
contents.close() contents.close()
result = True return False
except UnicodeDecodeError as eu: except UnicodeDecodeError as eu:
if debug is True: if debug: print(f"{str(eu)}...Skipping binary file ")
print(str(eu) + "...Skipping binary file ")
contents.close() contents.close()
result = True return False
return result return True
# Deterine if filename is desired in the fileTuple past in # Check if any element in fileTuple is in filename
def check_filename(filename, fileTuple): def check_filename(filename: str, fileTuple: tuple or list) -> bool:
supported = False if any([x in filename for x in fileTuple]):
for key in fileTuple: return True
if key in filename: return False
supported = True
break
return supported
def main(): def main() -> None:
unsupported_file_types.extend(specificIgnores) unsupported_file_types.extend(specificIgnores)
#Get a list of all the tracked files in our git repo # Get a list of all the tracked files in our git repo
proc = subprocess.run("git ls-files --exclude-standard", proc = subprocess.run("git ls-files --exclude-standard",
shell=True, shell=True,
stdout=subprocess.PIPE) stdout=subprocess.PIPE)
fileList = proc.stdout.decode().split('\n') fileList = proc.stdout.decode().split('\n')
if debug is True: if debug: print("Target file list:\n" + str(fileList))
print("Target file list:\n" + str(fileList))
unsupportedFiles = [] unsupportedFiles = []
unstampedFiles = [] unstampedFiles = []
unknownFiles = [] unknownFiles = []
for file in fileList: for file in fileList:
supported = check_filename(file, supported_file_types) if check_filename(file, supported_file_types):
if supported is True: if needStampCheck(file):
isStamped = openAndCheckFile(file)
if isStamped is False:
unstampedFiles.append(file) unstampedFiles.append(file)
else: elif check_filename(file, unsupported_file_types):
unsupported = check_filename(file, unsupported_file_types)
if unsupported is True:
unsupportedFiles.append(file) unsupportedFiles.append(file)
else: else:
unknownFiles.append(file) unknownFiles.append(file)
#Do a bunch of checks based on our file lists # Do a bunch of checks based on our file lists
if len(unstampedFiles) > 0: if len(unstampedFiles) > 0:
print("Error: The following " + str(len(unstampedFiles)) + print("\nError: The following " + str(len(unstampedFiles)) +
" files are currently without a license:") " files are currently without a license:")
print(str(unstampedFiles)) print(str(unstampedFiles))
sys.exit(1) sys.exit(1)
if len(unknownFiles) > 0: if len(unknownFiles) > 0:
print("Error: The following " + str(len(unknownFiles)) + print("\nError: The following " + str(len(unknownFiles)) +
" files not handled:") " files not handled:")
print(str(unknownFiles)) print(str(unknownFiles))
sys.exit(2) sys.exit(2)
sys.exit(0)
if __name__ == "__main__": if __name__ == "__main__":
main() main()
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
import api, argparse, os, runpy, subprocess, sys, te
from pathlib import Path
clang_format_path = Path('clang-format.exe' if os.name ==
'nt' else '/opt/rocm/llvm/bin/clang-format')
work_dir = Path().cwd()
src_dir = (work_dir / '../src').absolute()
migraphx_py_path = src_dir / 'api/migraphx.py'
def clang_format(buffer, **kwargs):
return subprocess.run(f'{clang_format_path} -style=file',
capture_output=True,
shell=True,
check=True,
input=buffer.encode('utf-8'),
cwd=work_dir,
**kwargs).stdout.decode('utf-8')
def api_generate(input_path: Path, output_path: Path):
with open(output_path, 'w') as f:
f.write(clang_format(api.run(input_path)))
def te_generate(input_path: Path, output_path: Path):
with open(output_path, 'w') as f:
f.write(clang_format(te.run(input_path)))
def main():
parser = argparse.ArgumentParser()
parser.add_argument('-f', '--clang-format', type=Path)
args = parser.parse_args()
global clang_format_path
if args.clang_format:
clang_format_path = args.clang_format
if not clang_format_path.is_file():
print(f"{clang_format_path}: invalid path or not installed",
file=sys.stderr)
return
try:
files = Path('include').absolute().iterdir()
for f in [f for f in files if f.is_file()]:
te_generate(f, src_dir / f'include/migraphx/{f.name}')
runpy.run_path(str(migraphx_py_path))
api_generate(work_dir / 'api/migraphx.h',
src_dir / 'api/include/migraphx/migraphx.h')
print('Finished generating header migraphx.h')
api_generate(work_dir / 'api/api.cpp', src_dir / 'api/api.cpp')
print('Finished generating source api.cpp')
except subprocess.CalledProcessError as ex:
if ex.stdout:
print(ex.stdout.decode('utf-8'))
if ex.stderr:
print(ex.stdout.decode('utf-8'))
print(f"Command '{ex.cmd}' returned {ex.returncode}")
raise
if __name__ == "__main__":
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment