Commit dac5563d authored by umangyadav's avatar umangyadav
Browse files

Merge remote-tracking branch 'upstream/develop' into resnet50_partition

parents 95f2cdb9 7e2a550c
......@@ -70,5 +70,13 @@ docs/_images
docs/_static
docs/_templates
docs/.doxygen/docBin
docs/.doxygen/DoxygenWarningLog.txt
docs/_doxygen
docs/html
/_readthedocs
# JetBrains config directories (ignoring symlinks)
.idea/
cmake-build*/
build*/
......@@ -90,7 +90,7 @@ def rocmnodename(name) {
} else if(name == "mi100+") {
node_name = "${rocmtest_name} && (gfx908 || gfx90a) && !vm";
} else if(name == "cdna") {
node_name = "${rocmtest_name} && (gfx908 || gfx90a || vega) && !vm";
node_name = "${rocmtest_name} && (gfx908 || gfx90a || vega20) && !vm";
} else if(name == "nogpu") {
node_name = "${rocmtest_name} && nogpu";
}
......
......@@ -24,7 +24,7 @@
google/protobuf@v3.11.0 -DCMAKE_POSITION_INDEPENDENT_CODE=On -X subdir -Dprotobuf_BUILD_TESTS=Off
nlohmann/json@v3.8.0
live-clones/blaze@v3.8 -X header -DHEADER_DIR=blaze -H sha256:d0ff011f47538285178908ea5f2cab46bb6a8f55b1edb6e03224a82dbc1a3212
ROCmSoftwarePlatform/half@rocm-5.4.2
ROCmSoftwarePlatform/half@rocm-5.6.0
pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.17 -DCMAKE_POSITION_INDEPENDENT_CODE=On
......
......@@ -32,15 +32,17 @@ add_executable(driver
marker_roctx.cpp
)
set_target_properties(driver PROPERTIES OUTPUT_NAME migraphx-driver)
# Copy driver for backwards compatibility
add_custom_command(
if(NOT WIN32)
# Copy driver for backwards compatibility (Linux only)
add_custom_command(
TARGET driver
POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy
$<TARGET_FILE:driver>
${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver
BYPRODUCTS ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver
)
set_directory_properties(PROPERTIES ADDITIONAL_MAKE_CLEAN_FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver)
)
set_directory_properties(PROPERTIES ADDITIONAL_CLEAN_FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver)
endif()
rocm_clang_tidy_check(driver)
target_link_libraries(driver migraphx_all_targets migraphx_onnx migraphx_tf)
......
......@@ -96,7 +96,7 @@ struct allocation_model
{
using std::swap;
auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
if(derived and private_detail_te_handle_mem_var.unique())
if(derived and private_detail_te_handle_mem_var.use_count() == 1)
{
*derived = std::forward<PrivateDetailTypeErasedT>(value);
}
......@@ -267,7 +267,7 @@ struct allocation_model
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(not private_detail_te_handle_mem_var.unique())
if(private_detail_te_handle_mem_var.use_count() > 1)
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
......
......@@ -88,7 +88,7 @@ struct concat_optimization
{
using std::swap;
auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
if(derived and private_detail_te_handle_mem_var.unique())
if(derived and private_detail_te_handle_mem_var.use_count() == 1)
{
*derived = std::forward<PrivateDetailTypeErasedT>(value);
}
......@@ -233,7 +233,7 @@ struct concat_optimization
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(not private_detail_te_handle_mem_var.unique())
if(private_detail_te_handle_mem_var.use_count() > 1)
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
......
......@@ -118,7 +118,7 @@ struct context
{
using std::swap;
auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
if(derived and private_detail_te_handle_mem_var.unique())
if(derived and private_detail_te_handle_mem_var.use_count() == 1)
{
*derived = std::forward<PrivateDetailTypeErasedT>(value);
}
......@@ -373,7 +373,7 @@ struct context
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(not private_detail_te_handle_mem_var.unique())
if(private_detail_te_handle_mem_var.use_count() > 1)
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
......
......@@ -80,7 +80,7 @@ struct marker
{
using std::swap;
auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
if(derived and private_detail_te_handle_mem_var.unique())
if(derived and private_detail_te_handle_mem_var.use_count() == 1)
{
*derived = std::forward<PrivateDetailTypeErasedT>(value);
}
......@@ -233,7 +233,7 @@ struct marker
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(not private_detail_te_handle_mem_var.unique())
if(private_detail_te_handle_mem_var.use_count() > 1)
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -21,6 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_SCAN_OP_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCAN_OP_HPP
......@@ -37,6 +38,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
/**
* Parent struct for prefix scan operations. A prefix scan is equivalent to the C++
* std::exclusive_scan or std::inclusive_scan. Given a list of numbers, a prefix scan
* sum op returns an equal size list of running totals of the values. Other operations
* besides addition can be supported by their own child ops.
*/
template <class Derived>
struct prefix_scan_op : op_name<Derived>
{
......@@ -60,9 +67,13 @@ struct prefix_scan_op : op_name<Derived>
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
check_shapes{inputs, *this, true}.has(1);
auto s = inputs.front();
if(s.broadcasted())
if(s.dynamic())
{
return s;
}
else if(s.broadcasted())
{
return {s.type(), s.lens()};
}
......@@ -72,8 +83,9 @@ struct prefix_scan_op : op_name<Derived>
}
}
argument compute(const shape& output_shape, std::vector<argument> args) const
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
shape output_shape(dyn_out.computed_shape);
argument result{output_shape};
auto s = args[0].get_shape();
if(s == output_shape)
......
......@@ -575,7 +575,7 @@ struct operation
{
using std::swap;
auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
if(derived and private_detail_te_handle_mem_var.unique())
if(derived and private_detail_te_handle_mem_var.use_count() == 1)
{
*derived = std::forward<PrivateDetailTypeErasedT>(value);
}
......@@ -1265,7 +1265,7 @@ struct operation
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(not private_detail_te_handle_mem_var.unique())
if(private_detail_te_handle_mem_var.use_count() > 1)
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
......
......@@ -116,7 +116,7 @@ struct pass
{
using std::swap;
auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
if(derived and private_detail_te_handle_mem_var.unique())
if(derived and private_detail_te_handle_mem_var.use_count() == 1)
{
*derived = std::forward<PrivateDetailTypeErasedT>(value);
}
......@@ -292,7 +292,7 @@ struct pass
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(not private_detail_te_handle_mem_var.unique())
if(private_detail_te_handle_mem_var.use_count() > 1)
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
......
......@@ -99,7 +99,7 @@ struct schedule_model
{
using std::swap;
auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
if(derived and private_detail_te_handle_mem_var.unique())
if(derived and private_detail_te_handle_mem_var.use_count() == 1)
{
*derived = std::forward<PrivateDetailTypeErasedT>(value);
}
......@@ -274,7 +274,7 @@ struct schedule_model
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(not private_detail_te_handle_mem_var.unique())
if(private_detail_te_handle_mem_var.use_count() > 1)
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
......
......@@ -100,7 +100,7 @@ struct stream_model
{
using std::swap;
auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
if(derived and private_detail_te_handle_mem_var.unique())
if(derived and private_detail_te_handle_mem_var.use_count() == 1)
{
*derived = std::forward<PrivateDetailTypeErasedT>(value);
}
......@@ -288,7 +288,7 @@ struct stream_model
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(not private_detail_te_handle_mem_var.unique())
if(private_detail_te_handle_mem_var.use_count() > 1)
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
......
......@@ -167,7 +167,7 @@ struct target
{
using std::swap;
auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
if(derived and private_detail_te_handle_mem_var.unique())
if(derived and private_detail_te_handle_mem_var.use_count() == 1)
{
*derived = std::forward<PrivateDetailTypeErasedT>(value);
}
......@@ -428,7 +428,7 @@ struct target
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(not private_detail_te_handle_mem_var.unique())
if(private_detail_te_handle_mem_var.use_count() > 1)
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
......
......@@ -33,7 +33,10 @@ if(NOT TARGET MIOpen)
message(SEND_ERROR "Cant find miopen")
endif()
find_package(composable_kernel 1.0.0 COMPONENTS jit_library REQUIRED)
if(NOT WIN32)
# TODO: re-enable when CK is ported to Windows
find_package(composable_kernel 1.0.0 REQUIRED COMPONENTS jit_library)
endif()
if(BUILD_DEV)
set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "Use hipRTC APIs")
......@@ -85,6 +88,12 @@ target_link_libraries(kernel_file_check compile_for_gpu)
rocm_clang_tidy_check(kernel_file_check)
file(GLOB JIT_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/jit/*.cpp)
if(NOT WIN32)
# TODO: re-enable when CK is ported to Windows
list(REMOVE_ITEM JIT_GPU_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm.cpp)
endif()
add_library(migraphx_gpu
abs.cpp
analyze_streams.cpp
......@@ -133,6 +142,7 @@ add_library(migraphx_gpu
write_literals.cpp
${JIT_GPU_SRCS}
)
set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
migraphx_generate_export_header(migraphx_gpu)
......@@ -255,7 +265,11 @@ else()
endif()
target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels composable_kernel::jit_library)
target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels)
if(NOT WIN32)
# TODO: re-enable when CK is ported to Windows
target_link_libraries(migraphx_gpu PRIVATE composable_kernel::jit_library)
endif()
add_subdirectory(driver)
add_subdirectory(hiprtc)
......
......@@ -140,8 +140,11 @@ void gemm_impl(context& ctx,
compute_type = rocblas_datatype_f32_r;
}
rocblas_gemm_flags flag =
int8_x4_format ? rocblas_gemm_flags_pack_int8x4 : rocblas_gemm_flags_none;
rocblas_gemm_flags flag = rocblas_gemm_flags_none;
#if ROCBLAS_VERSION_MAJOR < 3
if(int8_x4_format)
flag = rocblas_gemm_flags_pack_int8x4;
#endif
auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens();
......
......@@ -75,7 +75,9 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC)
#ifdef _WIN32
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK)
#endif
struct id_pass
{
......@@ -136,7 +138,9 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{},
enable_pass(not enabled(MIGRAPHX_DISABLE_REDUCE_FUSION{}), fuse_reduce{}),
dead_code_elimination{},
#ifdef _WIN32
enable_pass(enabled(MIGRAPHX_ENABLE_CK{}), fuse_ck{}),
#endif
dead_code_elimination{},
enable_pass(mlir_enabled(), fuse_mlir{&ctx}),
dead_code_elimination{},
......
......@@ -36,7 +36,7 @@ endfunction()
function(add_c_api_test TEST_NAME TEST_SRC TEST_DIR)
set(NAME test_api_${TEST_NAME})
add_executable(${NAME} EXCLUDE_FROM_ALL ${TEST_SRC})
target_link_libraries(${NAME} migraphx_c migraphx)
target_link_libraries(${NAME} migraphx_c)
target_include_directories(${NAME} PUBLIC ../include)
add_test(NAME ${NAME} COMMAND $<TARGET_FILE:${NAME}> WORKING_DIRECTORY ${TEST_DIR})
add_dependencies(tests ${NAME})
......
......@@ -384,7 +384,7 @@ bool throws(F f, const std::string& msg = "")
}
template <class T, class U>
auto near(T px, U py, double ptol = 1e-6f)
auto within_abs(T px, U py, double ptol = 1e-6f)
{
return make_function("near", [](auto x, auto y, auto tol) { return std::abs(x - y) < tol; })(
px, py, ptol);
......
......@@ -82,9 +82,9 @@ TEST_CASE(generate_module)
auto f = compile_module<float(float, float)>(m);
EXPECT(test::near(f(2, 2), 2));
EXPECT(test::near(f(10, 6), 4));
EXPECT(test::near(f(1, 2), std::sqrt(3)));
EXPECT(test::within_abs(f(2, 2), 2));
EXPECT(test::within_abs(f(10, 6), 4));
EXPECT(test::within_abs(f(1, 2), std::sqrt(3)));
}
TEST_CASE(generate_module_with_literals)
......@@ -99,9 +99,9 @@ TEST_CASE(generate_module_with_literals)
auto f = compile_module<float(float, float)>(m);
EXPECT(test::near(f(1, 2), 2));
EXPECT(test::near(f(9, 6), 4));
EXPECT(test::near(f(0, 2), std::sqrt(3)));
EXPECT(test::within_abs(f(1, 2), 2));
EXPECT(test::within_abs(f(9, 6), 4));
EXPECT(test::within_abs(f(0, 2), std::sqrt(3)));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment