Commit 264a7647 authored by Brian Pickrell's avatar Brian Pickrell
Browse files

Merge branch 'develop' into multinomial_parse_merge

parents d99729f8 8e18544f
...@@ -89,38 +89,23 @@ struct find_reshaper ...@@ -89,38 +89,23 @@ struct find_reshaper
{ {
auto matcher() const auto matcher() const
{ {
return match::name(reshaper_names())( auto reshaper = match::name(reshaper_names());
match::any_of[match::outputs()](match::name(reshaper_names()))); auto contiguous = match::name("contiguous");
auto no_output_reshape = match::none_of[match::outputs()](reshaper);
auto input_reshape = match::arg(0)(match::skip(contiguous)(reshaper));
auto input = match::skip(reshaper, contiguous)(match::any().bind("x"));
return reshaper(no_output_reshape, input_reshape, input);
} }
void apply(module& m, const match::matcher_result& mr) const void apply(module& m, const match::matcher_result& mr) const
{ {
auto ins = mr.result; auto ins = mr.result;
std::vector<instruction_ref> reshapes{ins}; auto input = mr.instructions["x"];
while(is_reshaper(reshapes.back())) auto dims = ins->get_shape().lens();
{
assert(not reshapes.back()->inputs().empty());
assert(m.has_instruction(reshapes.back()->inputs().front()));
auto input = reshapes.back()->inputs().front();
reshapes.push_back(input);
}
std::pair<instruction_ref, instruction_ref> r{m.end(), m.end()}; if(not input->get_shape().standard())
for(auto start : iterator_for(reshapes)) input = m.insert_instruction(ins, make_op("contiguous"), input);
{ m.replace_instruction(ins, make_op("reshape", {{"dims", dims}}), input);
auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
return i->get_shape() == (*start)->get_shape() and i != (*start);
});
if(last != reshapes.rend())
{
r = std::make_pair(*start, *last);
break;
}
}
if(r.first != r.second)
{
m.replace_instruction(r.first, r.second);
}
} }
}; };
...@@ -804,9 +789,9 @@ void simplify_reshapes::apply(module& m) const ...@@ -804,9 +789,9 @@ void simplify_reshapes::apply(module& m) const
match::find_matches(m, match::find_matches(m,
find_where_op{}, find_where_op{},
find_resize{}, find_resize{},
find_reshape_cont{},
find_nop_reshapes{}, find_nop_reshapes{},
find_reshaper{}, find_reshaper{},
find_reshape_cont{},
find_transpose{}, find_transpose{},
find_concat_transpose{}, find_concat_transpose{},
find_concat_multibroadcasts{}, find_concat_multibroadcasts{},
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/target.hpp>
#include <migraphx/register_target.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void migraphx_to_value(value& v, const target& t) { v["name"] = t.name(); }
void migraphx_from_value(const value& v, target& t)
{
t = make_target(v.at("name").to<std::string>());
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -78,6 +78,8 @@ else() ...@@ -78,6 +78,8 @@ else()
endif() endif()
target_link_libraries(migraphx_cpu PRIVATE migraphx) target_link_libraries(migraphx_cpu PRIVATE migraphx)
migraphx_generate_export_header(migraphx_cpu)
find_package(OpenMP) find_package(OpenMP)
target_link_libraries(migraphx_cpu PUBLIC OpenMP::OpenMP_CXX) target_link_libraries(migraphx_cpu PUBLIC OpenMP::OpenMP_CXX)
# Add library path to rpath to workaround issues with our broken packages # Add library path to rpath to workaround issues with our broken packages
......
...@@ -23,14 +23,14 @@ ...@@ -23,14 +23,14 @@
*/ */
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/cpu/dnnl.hpp> #include <migraphx/cpu/dnnl.hpp>
#include <migraphx/op/deconvolution.hpp> #include <migraphx/op/convolution_backwards.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace cpu { namespace cpu {
struct dnnl_deconvolution struct dnnl_deconvolution
: dnnl_extend_op<dnnl_deconvolution, dnnl::deconvolution_forward, op::deconvolution> : dnnl_extend_op<dnnl_deconvolution, dnnl::deconvolution_forward, op::convolution_backwards>
{ {
std::vector<int> arg_map(int) const std::vector<int> arg_map(int) const
{ {
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <migraphx/cpu/dnnl.hpp> #include <migraphx/cpu/dnnl.hpp>
#include <migraphx/cpu/parallel.hpp> #include <migraphx/cpu/parallel.hpp>
#include <migraphx/par_for.hpp> #include <migraphx/par_for.hpp>
#include <migraphx/cpu/export.h>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
...@@ -24,8 +24,7 @@ ...@@ -24,8 +24,7 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_CPU_LOWERING_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_CPU_LOWERING_HPP
#define MIGRAPHX_GUARD_RTGLIB_CPU_LOWERING_HPP #define MIGRAPHX_GUARD_RTGLIB_CPU_LOWERING_HPP
#include <string> #include <migraphx/cpu/context.hpp>
#include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -34,7 +33,7 @@ struct module; ...@@ -34,7 +33,7 @@ struct module;
namespace cpu { namespace cpu {
struct lowering struct MIGRAPHX_CPU_EXPORT lowering
{ {
std::string name() const { return "cpu::lowering"; } std::string name() const { return "cpu::lowering"; }
void apply(module& m) const; void apply(module& m) const;
......
...@@ -28,14 +28,13 @@ ...@@ -28,14 +28,13 @@
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/compile_options.hpp> #include <migraphx/compile_options.hpp>
#include <migraphx/cpu/context.hpp> #include <migraphx/cpu/context.hpp>
#include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct pass; struct pass;
namespace cpu { namespace cpu {
struct target struct MIGRAPHX_CPU_EXPORT target
{ {
std::string name() const; std::string name() const;
std::vector<pass> get_passes(migraphx::context& gctx, const compile_options&) const; std::vector<pass> get_passes(migraphx::context& gctx, const compile_options&) const;
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <migraphx/dfor.hpp> #include <migraphx/dfor.hpp>
#include <migraphx/op/identity.hpp> #include <migraphx/op/identity.hpp>
#include <migraphx/op/convolution.hpp> #include <migraphx/op/convolution.hpp>
#include <migraphx/op/deconvolution.hpp> #include <migraphx/op/convolution_backwards.hpp>
#include <migraphx/op/quant_convolution.hpp> #include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/dot.hpp> #include <migraphx/op/dot.hpp>
#include <migraphx/op/quant_dot.hpp> #include <migraphx/op/quant_dot.hpp>
...@@ -345,7 +345,7 @@ struct cpu_apply ...@@ -345,7 +345,7 @@ struct cpu_apply
extend_op("contiguous", "dnnl::reorder"); extend_op("contiguous", "dnnl::reorder");
extend_op("convolution", "dnnl::convolution"); extend_op("convolution", "dnnl::convolution");
#ifndef MIGRAPHX_ENABLE_ZENDNN #ifndef MIGRAPHX_ENABLE_ZENDNN
extend_op("deconvolution", "dnnl::deconvolution"); extend_op("convolution_backwards", "dnnl::convolution_backwards");
extend_op("dot", "dnnl::dot"); extend_op("dot", "dnnl::dot");
#endif #endif
extend_op("erf", "cpu::erf"); extend_op("erf", "cpu::erf");
......
...@@ -33,7 +33,10 @@ if(NOT TARGET MIOpen) ...@@ -33,7 +33,10 @@ if(NOT TARGET MIOpen)
message(SEND_ERROR "Cant find miopen") message(SEND_ERROR "Cant find miopen")
endif() endif()
find_package(composable_kernel 1.0.0 COMPONENTS jit_library REQUIRED) if(NOT WIN32)
# TODO: re-enable when CK is ported to Windows
find_package(composable_kernel 1.0.0 REQUIRED COMPONENTS jit_library)
endif()
if(BUILD_DEV) if(BUILD_DEV)
set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "Use hipRTC APIs") set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "Use hipRTC APIs")
...@@ -42,12 +45,12 @@ else() ...@@ -42,12 +45,12 @@ else()
endif() endif()
include(Embed) include(Embed)
file(GLOB KERNEL_FILES ${CONFIGURE_DEPENDS} file(GLOB KERNEL_FILES CONFIGURE_DEPENDS
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp) ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp)
message(STATUS "KERNEL_FILES: ${KERNEL_FILES}") message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
add_embed_library(migraphx_kernels ${KERNEL_FILES}) add_embed_library(migraphx_kernels ${KERNEL_FILES})
file(GLOB DEVICE_GPU_SRCS ${CONFIGURE_DEPENDS} ${CMAKE_CURRENT_SOURCE_DIR}/device/*.cpp) file(GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/device/*.cpp)
add_library(migraphx_device ${DEVICE_GPU_SRCS}) add_library(migraphx_device ${DEVICE_GPU_SRCS})
add_library(compile_for_gpu INTERFACE) add_library(compile_for_gpu INTERFACE)
...@@ -67,6 +70,8 @@ target_link_libraries(migraphx_device PUBLIC migraphx) ...@@ -67,6 +70,8 @@ target_link_libraries(migraphx_device PUBLIC migraphx)
target_link_libraries(migraphx_device PRIVATE compile_for_gpu) target_link_libraries(migraphx_device PRIVATE compile_for_gpu)
target_include_directories(migraphx_device PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>) target_include_directories(migraphx_device PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>)
target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/device/include>) target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/device/include>)
target_compile_options(migraphx_device PRIVATE -Wno-ignored-attributes)
migraphx_generate_export_header(migraphx_device DIRECTORY migraphx/gpu/device)
add_library(kernel_file_check EXCLUDE_FROM_ALL) add_library(kernel_file_check EXCLUDE_FROM_ALL)
...@@ -82,7 +87,13 @@ target_link_libraries(kernel_file_check compile_for_gpu) ...@@ -82,7 +87,13 @@ target_link_libraries(kernel_file_check compile_for_gpu)
rocm_clang_tidy_check(kernel_file_check) rocm_clang_tidy_check(kernel_file_check)
file(GLOB JIT_GPU_SRCS ${CONFIGURE_DEPENDS} ${CMAKE_CURRENT_SOURCE_DIR}/jit/*.cpp) file(GLOB JIT_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/jit/*.cpp)
if(NOT WIN32)
# TODO: re-enable when CK is ported to Windows
list(REMOVE_ITEM JIT_GPU_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm.cpp)
endif()
add_library(migraphx_gpu add_library(migraphx_gpu
abs.cpp abs.cpp
analyze_streams.cpp analyze_streams.cpp
...@@ -131,7 +142,9 @@ add_library(migraphx_gpu ...@@ -131,7 +142,9 @@ add_library(migraphx_gpu
write_literals.cpp write_literals.cpp
${JIT_GPU_SRCS} ${JIT_GPU_SRCS}
) )
set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu) set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
migraphx_generate_export_header(migraphx_gpu)
function(register_migraphx_gpu_ops PREFIX) function(register_migraphx_gpu_ops PREFIX)
foreach(OP ${ARGN}) foreach(OP ${ARGN})
...@@ -173,7 +186,7 @@ register_op(migraphx_gpu ...@@ -173,7 +186,7 @@ register_op(migraphx_gpu
OPERATORS gpu::rocblas_gemm<op::dot> gpu::rocblas_gemm<op::quant_dot> OPERATORS gpu::rocblas_gemm<op::dot> gpu::rocblas_gemm<op::quant_dot>
INCLUDES migraphx/gpu/context.hpp) INCLUDES migraphx/gpu/context.hpp)
register_op(migraphx_gpu HEADER migraphx/gpu/convolution.hpp register_op(migraphx_gpu HEADER migraphx/gpu/convolution.hpp
OPERATORS gpu::miopen_convolution<op::convolution> gpu::miopen_convolution<op::deconvolution> gpu::miopen_convolution<op::quant_convolution> OPERATORS gpu::miopen_convolution<op::convolution> gpu::miopen_convolution<op::convolution_backwards> gpu::miopen_convolution<op::quant_convolution>
INCLUDES migraphx/gpu/context.hpp) INCLUDES migraphx/gpu/context.hpp)
rocm_set_soversion(migraphx_gpu ${MIGRAPHX_SO_VERSION}) rocm_set_soversion(migraphx_gpu ${MIGRAPHX_SO_VERSION})
rocm_clang_tidy_check(migraphx_gpu) rocm_clang_tidy_check(migraphx_gpu)
...@@ -185,7 +198,9 @@ if(MIGRAPHX_ENABLE_MLIR) ...@@ -185,7 +198,9 @@ if(MIGRAPHX_ENABLE_MLIR)
find_package(rocMLIR 1.0.0 CONFIG REQUIRED) find_package(rocMLIR 1.0.0 CONFIG REQUIRED)
message(STATUS "Build with rocMLIR::rockCompiler ${rocMLIR_VERSION}") message(STATUS "Build with rocMLIR::rockCompiler ${rocMLIR_VERSION}")
target_compile_definitions(migraphx_gpu PRIVATE "-DMIGRAPHX_MLIR") target_compile_definitions(migraphx_gpu PRIVATE "-DMIGRAPHX_MLIR")
target_link_libraries(migraphx_gpu PUBLIC rocMLIR::rockCompiler) # Make this private to avoid multiple inclusions of LLVM symbols.
# TODO: Fix rocMLIR's library to hide LLVM internals.
target_link_libraries(migraphx_gpu PRIVATE rocMLIR::rockCompiler)
endif() endif()
if(MIGRAPHX_USE_HIPRTC) if(MIGRAPHX_USE_HIPRTC)
...@@ -231,7 +246,12 @@ check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_ ...@@ -231,7 +246,12 @@ check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_
set(MIGRAPHX_USE_FIND_2_API "${HAS_FIND_2_API}" CACHE BOOL "") set(MIGRAPHX_USE_FIND_2_API "${HAS_FIND_2_API}" CACHE BOOL "")
if(MIGRAPHX_USE_FIND_2_API) if(MIGRAPHX_USE_FIND_2_API)
target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_2_API) check_library_exists(MIOpen "miopenSetFindOptionPreallocatedTensor" "${MIOPEN_LOCATION}" HAS_PREALLOCATION_API)
if(HAS_PREALLOCATION_API)
target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_2_API -DMIGRAPHX_PREALLOCATE_MIOPEN_BUFFERS)
else()
target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_2_API)
endif()
message(STATUS "MIGraphx is using Find-2.0 API of MIOpen") message(STATUS "MIGraphx is using Find-2.0 API of MIOpen")
else() else()
message(STATUS "MIGraphx is using legacy Find API in MIOpen") message(STATUS "MIGraphx is using legacy Find API in MIOpen")
...@@ -245,7 +265,11 @@ else() ...@@ -245,7 +265,11 @@ else()
endif() endif()
target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas) target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels composable_kernel::jit_library) target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels)
if(NOT WIN32)
# TODO: re-enable when CK is ported to Windows
target_link_libraries(migraphx_gpu PRIVATE composable_kernel::jit_library)
endif()
add_subdirectory(driver) add_subdirectory(driver)
add_subdirectory(hiprtc) add_subdirectory(hiprtc)
......
...@@ -135,14 +135,13 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over) ...@@ -135,14 +135,13 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over)
std::size_t max_global = ctx.get_current_device().get_cu_count() * std::size_t max_global = ctx.get_current_device().get_cu_count() *
ctx.get_current_device().get_max_workitems_per_cu(); ctx.get_current_device().get_max_workitems_per_cu();
return [n, over, max_global](std::size_t local) { return [n, over, max_global](std::size_t local) {
std::size_t num_elements = n; // hip require global workitems multiple of local workitems. It may degrade performance.
// [TODO]: consider adding "fno-hip-uniform-block" flag when it becomes available.
// https://reviews.llvm.org/D155213
std::size_t num_elements = ((n + local - 1) / local) * local;
std::size_t groups = (num_elements + local - 1) / local; std::size_t groups = (num_elements + local - 1) / local;
std::size_t max_blocks = max_global / local; std::size_t max_blocks = max_global / local;
std::size_t nglobal = std::min(max_blocks * over, groups) * local; std::size_t nglobal = std::min(max_blocks * over, groups) * local;
#ifdef MIGRAPHX_USE_HIPRTC
if(enabled(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS{}))
num_elements = ((num_elements + local - 1) / local) * local;
#endif
return std::min(nglobal, num_elements); return std::min(nglobal, num_elements);
}; };
} }
......
...@@ -79,7 +79,7 @@ void compile_miopen::apply(module& m) const ...@@ -79,7 +79,7 @@ void compile_miopen::apply(module& m) const
std::size_t ws = 0; std::size_t ws = 0;
try try
{ {
// for the regular convolution and deconvolution, this try would always succeed // for the regular convolution and convolution_backwards, this try would always succeed
ws = compile(op, ins, int8_x4_format); ws = compile(op, ins, int8_x4_format);
} }
catch(migraphx::exception&) catch(migraphx::exception&)
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
# THE SOFTWARE. # THE SOFTWARE.
##################################################################################### #####################################################################################
file(GLOB GPU_DRIVER_SRCS ${CONFIGURE_DEPENDS} ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp) file(GLOB GPU_DRIVER_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)
add_executable(gpu-driver add_executable(gpu-driver
${GPU_DRIVER_SRCS} ${GPU_DRIVER_SRCS}
) )
......
...@@ -216,6 +216,7 @@ struct find_mlir_op ...@@ -216,6 +216,7 @@ struct find_mlir_op
"quant_dot", "quant_dot",
"add", "add",
"clip", "clip",
"relu",
"sub", "sub",
"mul", "mul",
"div", "div",
......
...@@ -140,8 +140,11 @@ void gemm_impl(context& ctx, ...@@ -140,8 +140,11 @@ void gemm_impl(context& ctx,
compute_type = rocblas_datatype_f32_r; compute_type = rocblas_datatype_f32_r;
} }
rocblas_gemm_flags flag = rocblas_gemm_flags flag = rocblas_gemm_flags_none;
int8_x4_format ? rocblas_gemm_flags_pack_int8x4 : rocblas_gemm_flags_none; #if ROCBLAS_VERSION_MAJOR < 3
if(int8_x4_format)
flag = rocblas_gemm_flags_pack_int8x4;
#endif
auto a_lens = args[0].get_shape().lens(); auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens(); auto b_lens = args[1].get_shape().lens();
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_GPU_ALLOCATION_MODEL_HPP #ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_GPU_ALLOCATION_MODEL_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_GPU_ALLOCATION_MODEL_HPP #define MIGRAPHX_GUARD_AMDMIGRAPHX_GPU_ALLOCATION_MODEL_HPP
#include <migraphx/config.hpp> #include <migraphx/gpu/config.hpp>
#include <migraphx/operation.hpp> #include <migraphx/operation.hpp>
#include <migraphx/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <string> #include <string>
...@@ -33,7 +33,7 @@ namespace migraphx { ...@@ -33,7 +33,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
struct gpu_allocation_model struct MIGRAPHX_GPU_EXPORT gpu_allocation_model
{ {
std::string name() const; std::string name() const;
std::string copy() const; std::string copy() const;
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_GPU_ANALYZE_STREAMS_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_GPU_ANALYZE_STREAMS_HPP
#define MIGRAPHX_GUARD_RTGLIB_GPU_ANALYZE_STREAMS_HPP #define MIGRAPHX_GUARD_RTGLIB_GPU_ANALYZE_STREAMS_HPP
#include <migraphx/config.hpp> #include <migraphx/gpu/config.hpp>
#include <migraphx/analyze_streams.hpp> #include <migraphx/analyze_streams.hpp>
namespace migraphx { namespace migraphx {
...@@ -34,7 +34,7 @@ struct module; ...@@ -34,7 +34,7 @@ struct module;
namespace gpu { namespace gpu {
std::vector<stream_race> analyze_streams(const module& m); MIGRAPHX_GPU_EXPORT std::vector<stream_race> analyze_streams(const module& m);
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_COMPILE_HIP_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_COMPILE_HIP_HPP
#define MIGRAPHX_GUARD_RTGLIB_COMPILE_HIP_HPP #define MIGRAPHX_GUARD_RTGLIB_COMPILE_HIP_HPP
#include <migraphx/config.hpp> #include <migraphx/gpu/config.hpp>
#include <migraphx/filesystem.hpp> #include <migraphx/filesystem.hpp>
#include <migraphx/compile_src.hpp> #include <migraphx/compile_src.hpp>
#include <migraphx/env.hpp> #include <migraphx/env.hpp>
...@@ -58,14 +58,13 @@ struct hiprtc_src_file ...@@ -58,14 +58,13 @@ struct hiprtc_src_file
} }
}; };
std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_src_file> srcs, MIGRAPHX_GPU_EXPORT std::vector<std::vector<char>> compile_hip_src_with_hiprtc(
std::string params, std::vector<hiprtc_src_file> srcs, std::string params, const std::string& arch);
const std::string& arch);
std::vector<std::vector<char>> MIGRAPHX_GPU_EXPORT std::vector<std::vector<char>>
compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch); compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch);
std::string enum_params(std::size_t count, std::string param); MIGRAPHX_GPU_EXPORT std::string enum_params(std::size_t count, std::string param);
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_GPU_COMPILE_HIP_CODE_OBJECT_HPP #ifndef MIGRAPHX_GUARD_GPU_COMPILE_HIP_CODE_OBJECT_HPP
#define MIGRAPHX_GUARD_GPU_COMPILE_HIP_CODE_OBJECT_HPP #define MIGRAPHX_GUARD_GPU_COMPILE_HIP_CODE_OBJECT_HPP
#include <migraphx/config.hpp> #include <migraphx/gpu/config.hpp>
#include <migraphx/operation.hpp> #include <migraphx/operation.hpp>
#include <migraphx/compile_src.hpp> #include <migraphx/compile_src.hpp>
...@@ -66,14 +66,16 @@ struct hip_compile_options ...@@ -66,14 +66,16 @@ struct hip_compile_options
}; };
/// Compute global for n elements, but max out on target-specific upper limit /// Compute global for n elements, but max out on target-specific upper limit
std::function<std::size_t(std::size_t local)> MIGRAPHX_GPU_EXPORT std::function<std::size_t(std::size_t local)>
compute_global_for(context& ctx, std::size_t n, std::size_t over = 1); compute_global_for(context& ctx, std::size_t n, std::size_t over = 1);
operation compile_hip_code_object(const std::string& content, hip_compile_options options); MIGRAPHX_GPU_EXPORT operation compile_hip_code_object(const std::string& content,
hip_compile_options options);
std::size_t compute_block_size(std::size_t n, std::size_t max_block_size = 1024); MIGRAPHX_GPU_EXPORT std::size_t compute_block_size(std::size_t n,
std::size_t max_block_size = 1024);
std::string generate_make_shape(const shape& s); MIGRAPHX_GPU_EXPORT std::string generate_make_shape(const shape& s);
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_GPU_COMPILE_OPS_HPP #ifndef MIGRAPHX_GUARD_GPU_COMPILE_OPS_HPP
#define MIGRAPHX_GUARD_GPU_COMPILE_OPS_HPP #define MIGRAPHX_GUARD_GPU_COMPILE_OPS_HPP
#include <migraphx/config.hpp> #include <migraphx/gpu/config.hpp>
#include <string> #include <string>
namespace migraphx { namespace migraphx {
...@@ -36,7 +36,7 @@ namespace gpu { ...@@ -36,7 +36,7 @@ namespace gpu {
struct context; struct context;
struct compile_ops struct MIGRAPHX_GPU_EXPORT compile_ops
{ {
context* ctx = nullptr; context* ctx = nullptr;
bool exhaustive_tune = false; bool exhaustive_tune = false;
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_GPU_COMPILER_HPP #ifndef MIGRAPHX_GUARD_GPU_COMPILER_HPP
#define MIGRAPHX_GUARD_GPU_COMPILER_HPP #define MIGRAPHX_GUARD_GPU_COMPILER_HPP
#include <migraphx/config.hpp> #include <migraphx/gpu/config.hpp>
#include <migraphx/auto_register.hpp> #include <migraphx/auto_register.hpp>
#include <migraphx/operation.hpp> #include <migraphx/operation.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
...@@ -81,17 +81,21 @@ using compiler_compile_op = ...@@ -81,17 +81,21 @@ using compiler_compile_op =
using compiler_tuning_config = using compiler_tuning_config =
std::function<optional<tuning_config>(context&, instruction_ref, const operation&, bool)>; std::function<optional<tuning_config>(context&, instruction_ref, const operation&, bool)>;
void register_compiler(const std::string& name, MIGRAPHX_GPU_EXPORT void register_compiler(const std::string& name,
compiler_compile c, compiler_compile c,
compiler_compile_op cop, compiler_compile_op cop,
compiler_tuning_config ctg); compiler_tuning_config ctg);
bool has_compiler_for(const std::string& name); MIGRAPHX_GPU_EXPORT bool has_compiler_for(const std::string& name);
compiler_replace MIGRAPHX_GPU_EXPORT compiler_replace compile(context& ctx,
compile(context& ctx, instruction_ref ins, const operation& op, const value& solution); instruction_ref ins,
operation const operation& op,
compile_op(const std::string& name, context& ctx, const std::vector<shape>& inputs, const value& v); const value& solution);
optional<tuning_config> MIGRAPHX_GPU_EXPORT operation compile_op(const std::string& name,
context& ctx,
const std::vector<shape>& inputs,
const value& v);
MIGRAPHX_GPU_EXPORT optional<tuning_config>
get_tuning_config(context& ctx, instruction_ref ins, const operation& op, bool exhaustive); get_tuning_config(context& ctx, instruction_ref ins, const operation& op, bool exhaustive);
template <class T> template <class T>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment