"sgl-kernel/git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "af6535e7aaf5c1e9352149f0edfde37d977cd473"
Unverified Commit 40fbef9b authored by Ted Themistokleous, committed by GitHub

Merge branch 'develop' into threaded_nms

parents d164b151 aeb9f78c
@@ -28,10 +28,6 @@
 #include <hip/hip_runtime.h>
 #include <hip/hip_fp16.h>
 #include <hip/math_functions.h>
-#include <hip/hip_math_constants.h>
-#elif defined(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS)
-#include <hip/hip_common.h>
-#include <hip/hip_math_constants.h>
 #endif
 #endif // MIGRAPHX_GUARD_KERNELS_HIP_HPP
@@ -130,6 +130,8 @@ struct index
         return blockDim.x;
     }
 #endif
+    constexpr auto ngroup() const { return nglobal() / max_nlocal(); }
+
     template <class N, class Stride>
     static constexpr auto max_stride_iterations(N n, Stride stride)
     {
@@ -231,6 +233,12 @@ struct index
     {
         for_stride<true>(local, n, nlocal(), f);
     }
+
+    template <class F, class N>
+    __device__ void group_stride(N n, F f) const
+    {
+        for_stride<false>(group, n, ngroup(), f);
+    }
 };

 #ifdef MIGRAPHX_NLOCAL
...
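The two additions above give kernels a per-workgroup counterpart to the existing local/global strided loops: ngroup() is the number of workgroups (nglobal() / max_nlocal()), and group_stride(n, f) calls f for every index the current workgroup owns, stepping by ngroup(). A minimal usage sketch; the kernel body and its wiring are illustrative, not part of this diff:

    // Hypothetical kernel: workgroup g handles tiles g, g + ngroup(), g + 2*ngroup(), ...
    // so any number of tiles can be covered by a fixed-size launch.
    template <class T>
    __device__ void scale_tiles(T* data, unsigned int ntiles, unsigned int tile_size)
    {
        auto idx = migraphx::make_index();
        idx.group_stride(ntiles, [&](auto tile) {
            // within a tile, threads cooperate with the usual local stride
            idx.local_stride(tile_size, [&](auto i) { data[tile * tile_size + i] *= 2; });
        });
    }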
@@ -138,7 +138,7 @@ MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, floor, ::hfloor)
 MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, isnan, ::__hisnan)
 MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, log, ::hlog)
 MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, rsqrt, ::hrsqrt)
-// MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sin, ::hsin)
+MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sin, ::hsin)
 MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sqrt, ::hsqrt)

 // Use float to compute half overload
@@ -161,8 +161,7 @@ MIGRAPHX_DEVICE_MATH_HALF(fmod, ::fmod)
 // Map math functions to hip half2 functions
 // The half2 type is defined in include/hip/amd_detail/hip_fp16_gcc.h and is 2 16-bit floats
 // packed into a 32-bit number. See include/hip/amd_detail/hip_fp16_math_fwd.h for the HIP names
-// Most but not all of these math ops have operators of the same names. Ones not yet implemented
-// at this time are: exp2, exp10, log2, log10, isinf
+// Most but not all of these math ops have operators of the same names.
 MIGRAPHX_DEVICE_MATH_HALF2(abs, ::__habs2)
 MIGRAPHX_DEVICE_MATH_HALF2(ceil, ::h2ceil)
 MIGRAPHX_DEVICE_MATH_HALF2(cos, ::h2cos)
@@ -176,7 +175,7 @@ MIGRAPHX_DEVICE_MATH_HALF2(log, ::h2log)
 MIGRAPHX_DEVICE_MATH_HALF2(log10, ::h2log10)
 MIGRAPHX_DEVICE_MATH_HALF2(log2, ::h2log2)
 MIGRAPHX_DEVICE_MATH_HALF2(rsqrt, ::h2rsqrt)
-// MIGRAPHX_DEVICE_MATH_HALF2(sin, ::h2sin)
+MIGRAPHX_DEVICE_MATH_HALF2(sin, ::h2sin)
 MIGRAPHX_DEVICE_MATH_HALF2(sqrt, ::h2sqrt)

 template <class T, class U>
@@ -189,9 +188,8 @@ MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, max, ::max)
 MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, min, ::min)
 MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, max, ::max)
 MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, min, ::min)
-// Add overloads for half that calls the float version
-MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, max, ::fmaxf)
-MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, min, ::fminf)
+MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, max, ::__hmax)
+MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, min, ::__hmin)

 template <class T, MIGRAPHX_REQUIRES(not is_any_vec<T>())>
 constexpr auto max(const T& a, const T& b)
@@ -217,14 +215,6 @@ constexpr auto min(const T& a, const U& b)
     return min<common_type_t<T, U>>(a, b);
 }

-// Sin for half is broken on hip, so use cos instead
-template <class T, MIGRAPHX_REQUIRES(is_same<vec_type<T>, half>{})>
-constexpr T sin(T x)
-{
-    constexpr const T shift = HIP_PIO2_F;
-    return migraphx::cos(shift - x);
-}
-
 MIGRAPHX_DEVICE_MATH_VEC(abs)
 MIGRAPHX_DEVICE_MATH_VEC(acos)
 MIGRAPHX_DEVICE_MATH_VEC(acosh)
...
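Context for the three sin-related edits above: the deleted overload worked around a broken half-precision sin by using the co-function identity sin(x) = cos(pi/2 - x), which is why it needed HIP_PIO2_F from <hip/hip_math_constants.h>, and presumably why that include could be dropped in the first hunk. With the workaround gone, sin for half and half2 maps straight to the re-enabled ::hsin and ::h2sin intrinsics, and max/min for half use the native ::__hmax/::__hmin instead of round-tripping through the float ::fmaxf/::fminf.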
@@ -244,13 +244,13 @@ __device__ void print_once(Ts... xs)
 template <class... Ts>
 __device__ void println(Ts... xs)
 {
-    print_each(&coutln, xs...);
+    print_each(&cout, xs..., '\n');
 }

 template <class... Ts>
 __device__ void println_once(Ts... xs)
 {
-    print_each_once(&coutln, xs...);
+    print_each_once(&cout, xs..., '\n');
 }

 } // namespace migraphx
...
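The println change folds the newline into the argument pack instead of routing through a separate coutln sink. A host-side C++17 sketch of the same pattern; print_each here is a stand-in, not MIGraphX's device implementation:

    #include <iostream>

    // The trailing '\n' is just one more pack element, so no dedicated
    // newline-appending stream object is needed.
    template <class Stream, class... Ts>
    void print_each(Stream* os, Ts... xs)
    {
        ((*os << xs), ...); // C++17 fold over the pack
    }

    int main()
    {
        print_each(&std::cout, "x = ", 42, '\n');
        return 0;
    }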
@@ -79,20 +79,21 @@ __device__ void dpp_reduce(T& in, Op op)
 #endif

 // NOLINTNEXTLINE
-#define MIGRAPHX_DPP_REDUCE(op, prefix) \
+#define MIGRAPHX_DPP_REDUCE(op, prefix, sign) \
     __device__ inline void dpp_reduce(double& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64); } \
     __device__ inline void dpp_reduce(float& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32); } \
     __device__ inline void dpp_reduce(half& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16); } \
     __device__ inline void dpp_reduce(int32_t& x, op) \
     { \
-        MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32); \
+        MIGRAPHX_DPP_REDUCE_ASM(x, prefix##sign##32); \
     } \
     __device__ inline void dpp_reduce(uint32_t& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32); }

-MIGRAPHX_DPP_REDUCE(op::sum, v_add)
-MIGRAPHX_DPP_REDUCE(op::max, v_max)
-MIGRAPHX_DPP_REDUCE(op::min, v_min)
-MIGRAPHX_DPP_REDUCE(op::product, v_mul)
+// Note: when max and min operate on int32_t, the signed version of the instruction must be used.
+MIGRAPHX_DPP_REDUCE(op::sum, v_add, _u)
+MIGRAPHX_DPP_REDUCE(op::product, v_mul, _u)
+MIGRAPHX_DPP_REDUCE(op::max, v_max, _i)
+MIGRAPHX_DPP_REDUCE(op::min, v_min, _i)

 template <class Op, class T, class Index, class F>
 __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
@@ -570,7 +571,7 @@ template <class Algo, class Reduced, class Output, class F>
 __device__ void fused_reduce(Output output, F f)
 {
     Algo::template run<Reduced>([&](auto out_idx, auto r) {
-        auto result = f(r);
+        auto result = f(r, out_idx);
         if constexpr(reduce::is_inner_storage<decltype(result)>{})
         {
             r.inner([&](auto& y, auto x) { y = x; })(output, result);
...
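Why the new sign parameter matters: v_max_u32/v_min_u32 compare 32-bit values as unsigned, so any negative int32_t (sign bit set) would incorrectly compare as a huge value; max and min therefore need the _i32 forms. Sum and product can keep the unsigned forms because two's-complement add and multiply produce the same bit pattern either way. A host-side illustration, not the GPU code:

    #include <cstdint>
    #include <cstdio>

    // Reinterpreting int32 as uint32 breaks max/min for negative values,
    // which is why the signed v_max_i32/v_min_i32 instructions are needed.
    int main()
    {
        int32_t a = -1, b = 1;
        auto ua = static_cast<uint32_t>(a); // 0xFFFFFFFF
        auto ub = static_cast<uint32_t>(b); // 0x00000001
        std::printf("signed max:   %d\n", a > b ? a : b);     // 1
        std::printf("unsigned max: %u\n", ua > ub ? ua : ub); // 4294967295, i.e. -1 reinterpreted
        return 0;
    }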
@@ -218,7 +218,15 @@ using common_type_t = typename common_type<Ts...>::type;
 #define MIGRAPHX_REQUIRES(...) class = enable_if_t<__VA_ARGS__>

-constexpr unsigned long int_max(unsigned long n) { return (1u << (n * 8)) - 1; }
+constexpr unsigned long int_max(unsigned long n)
+{
+    // Note: left shift cannot be used to get the maximum value of int64_type or
+    // uint64_type because it is undefined behavior to left shift by 64 bits for
+    // these types
+    if(n == sizeof(int64_t))
+        return -1;
+    return (1ul << (n * 8)) - 1;
+}

 template <class T,
           MIGRAPHX_REQUIRES(is_integral<T>{} or is_floating_point<T>{} or
@@ -228,9 +236,9 @@ constexpr T numeric_max()
     if constexpr(is_integral<T>{})
     {
         if constexpr(is_unsigned<T>{})
-            return int_max(sizeof(T)) * 2;
-        else
             return int_max(sizeof(T));
+        else
+            return int_max(sizeof(T)) / 2;
     }
     else if constexpr(is_same<T, double>{})
         return __DBL_MAX__;
...
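To see what the new int_max computes: it returns the largest unsigned value representable in n bytes, with the n == 8 case returning -1, which converts modularly to all ones (0xFFFFFFFFFFFFFFFF) without ever shifting by 64 bits; halving then yields the signed maximum used by numeric_max. A standalone re-implementation for illustration, assuming a 64-bit unsigned long as on the LP64 targets this kernel code compiles for:

    #include <cstdint>

    // Standalone copy of the kernel's int_max, for illustration only.
    constexpr unsigned long int_max(unsigned long n)
    {
        if(n == sizeof(int64_t))
            return -1; // wraps to 0xFFFFFFFFFFFFFFFF; 1ul << 64 would be UB
        return (1ul << (n * 8)) - 1;
    }

    static_assert(int_max(1) == 255, "max of uint8_t");
    static_assert(int_max(1) / 2 == 127, "max of int8_t");
    static_assert(int_max(4) == 4294967295ul, "max of uint32_t");
    static_assert(int_max(8) == 0xFFFFFFFFFFFFFFFFul, "max of uint64_t");
    static_assert(int_max(8) / 2 == 0x7FFFFFFFFFFFFFFFul, "max of int64_t");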
@@ -135,7 +135,7 @@ constexpr vec<vec_type<T>, N> vec_packed_at(T x, I i)
         return vec<T, N>{x};
     else
     {
-        MIGRAPHX_ASSERT((i + N) < vec_size<T>());
+        MIGRAPHX_ASSERT((i + N) <= vec_size<T>());
         vec<vec_type<T>, N> result = {0};
         for(int j = 0; j < N; j++)
         {
...
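The assert fix is a classic fencepost correction: reading N lanes starting at offset i touches lanes i through i + N - 1, so the slice is in bounds exactly when i + N <= vec_size<T>(). With the old strict <, the final valid slice (for example i = 2, N = 2 on a 4-wide vector, which touches lanes 2 and 3) was spuriously rejected.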
@@ -22,12 +22,19 @@
  * THE SOFTWARE.
  */
 #include <iterator>
-#include <migraphx/gpu/lowering.hpp>
+#include <utility>
+#include <functional>
+#include <algorithm>
+#include <map>
 #include <migraphx/manage_ptr.hpp>
 #include <migraphx/instruction.hpp>
 #include <migraphx/make_op.hpp>
 #include <migraphx/instruction_ref.hpp>
 #include <migraphx/stringutils.hpp>
+#include <migraphx/pass_manager.hpp>
+#include <migraphx/iterator_for.hpp>
+#include <migraphx/program.hpp>
 #include <migraphx/op/dot.hpp>
 #include <migraphx/op/if_op.hpp>
@@ -35,17 +42,12 @@
 #include <migraphx/op/quant_dot.hpp>
 #include <migraphx/gpu/context.hpp>
+#include <migraphx/gpu/lowering.hpp>
 #include <migraphx/gpu/device_name.hpp>
 #include <migraphx/gpu/gemm.hpp>
 #include <migraphx/gpu/miopen.hpp>
 #include <migraphx/gpu/rocblas.hpp>
 #include <migraphx/gpu/compiler.hpp>
-#include <migraphx/iterator_for.hpp>
-#include <migraphx/program.hpp>
-#include <utility>
-#include <functional>
-#include <algorithm>
-#include <map>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -53,8 +55,9 @@ namespace gpu {

 struct miopen_apply
 {
     module* mod = nullptr;
-    const lowering* pass = nullptr;
+    module_pass_manager* mpm = nullptr;
+    const lowering* pass = nullptr;
     std::unordered_map<std::string, std::function<instruction_ref(instruction_ref)>> apply_map{};
     instruction_ref last{};
     bool offload_copy = false;
@@ -83,7 +86,7 @@ struct miopen_apply
         auto& ctx = get_context();
         int8_x4_format = get_int8_x4_format(ctx);
         compute_fp32 = get_compute_fp32_flag();
-        offload_copy = (mod->name() == "main") ? pass->offload_copy : false;
+        offload_copy = (mod == mpm->get_root_module()) ? pass->offload_copy : false;

         add_generic_op("contiguous");
@@ -103,7 +106,7 @@ struct miopen_apply
         add_extend_op("topk");

         add_convolution_op("convolution");
-        add_convolution_op("deconvolution");
+        add_convolution_op("convolution_backwards");
         add_convolution_op("quant_convolution");
         add_gemm_op<op::dot>("dot");
         add_gemm_op<op::quant_dot>("quant_dot");
@@ -375,7 +378,10 @@ struct miopen_apply
     }
 };

-void lowering::apply(module& m) const { miopen_apply{&m, this}.apply(); }
+void lowering::apply(module_pass_manager& mpm) const
+{
+    miopen_apply{&mpm.get_module(), &mpm, this}.apply();
+}

 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
...
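The lowering pass now implements apply(module_pass_manager&) rather than apply(module&), which is what lets it replace the fragile name check mod->name() == "main" with an identity comparison against the root module. A simplified sketch of the shape of such a pass; the header names follow MIGraphX conventions and the body is illustrative:

    #include <string>
    #include <migraphx/module.hpp>
    #include <migraphx/pass_manager.hpp>

    // A pass taking module_pass_manager& can ask structural questions about
    // the program, e.g. whether the module being processed is the root module.
    struct example_pass
    {
        std::string name() const { return "example_pass"; }
        void apply(migraphx::module_pass_manager& mpm) const
        {
            migraphx::module& m = mpm.get_module();
            // Only the root module sits at the program boundary, so only it
            // needs offload copies (mirroring the offload_copy logic above).
            const bool is_root = (&m == mpm.get_root_module());
            (void)is_root;
        }
    };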
@@ -52,6 +52,7 @@
 #include <migraphx/gpu/context.hpp>
 #include <migraphx/gpu/device_name.hpp>
 #include <migraphx/gpu/perfdb.hpp>
+#include <migraphx/gpu/tuning_config.hpp>
 #include <migraphx/iterator_for.hpp>
 #include <migraphx/permutation.hpp>
 #include <deque>
@@ -121,7 +122,10 @@ struct mlir_handle
 #define MIGRAPHX_MANAGE_MLIR_HANDLE(T, F) migraphx::gpu::mlir_handle<T, decltype(&F), &F> // NOLINT

 using mlir_context = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirContext, mlirContextDestroy);
+using mlir_thread_pool = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirLlvmThreadPool, mlirLlvmThreadPoolDestroy);
+using mlir_dialect_registry = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirDialectRegistry,
+                                                          mlirDialectRegistryDestroy);
 using mlir_module = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirModule, mlirModuleDestroy);
 using mlir_operation = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirOperation, mlirOperationDestroy);
 using mlir_op_printing_flags = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirOpPrintingFlags,
@@ -131,6 +135,10 @@ using mlir_block = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirBlock, mlirBlockDestroy);
 using mlir_pass_manager = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirPassManager, mlirPassManagerDestroy);
 using mlir_tuning_table = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirRockTuningTable,
                                                       mlirRockTuningTableDestroy);
+using mlir_tuning_space = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirRockTuningSpace,
+                                                      mlirRockTuningSpaceDestroy);
+using mlir_tuning_param = MIGRAPHX_MANAGE_MLIR_HANDLE(MlirRockTuningParam,
+                                                      mlirRockTuningParamDestroy);

 std::string_view to_string_view(MlirStringRef s) { return {s.data, s.length}; }
@@ -164,25 +172,47 @@ std::string mlir_print(F f, T x)
     return ss.str();
 }

-const std::unordered_set<std::string>& get_xdlops_archs()
+bool has_xdlops(const std::string& target_arch)
 {
-    static std::unordered_set<std::string> supported_archs{"gfx908", "gfx90a"};
-    return supported_archs;
+    const auto device_name = trim(split_string(target_arch, ':').front());
+    return (starts_with(device_name, "gfx9") and device_name >= "gfx908");
 }

 struct mlir_program
 {
     mlir_program()
-        : ctx(mlirContextCreate()),
+        : ctx(mlirContextCreateWithRegistry(get_dialect_registry().get(),
+                                            /*threadingEnable=*/false)),
           location(mlirLocationUnknownGet(ctx.get())),
           mmodule(mlirModuleCreateEmpty(location))
     {
-        MlirDialectRegistry registry = mlirDialectRegistryCreate();
-        mlirRegisterRocMLIRDialects(registry);
-        mlirContextAppendDialectRegistry(ctx.get(), registry);
+        mlirContextSetThreadPool(ctx.get(), get_thread_pool().get());
         mlirContextLoadAllAvailableDialects(ctx.get());
-        mlirDialectRegistryDestroy(registry);
-        mlirContextSetAllowUnregisteredDialects(ctx.get(), true /*allow*/);
+    }
+
+    static mlir_dialect_registry& get_dialect_registry()
+    {
+        static std::once_flag init_guard;
+        static mlir_dialect_registry the_registry;
+        // The MLIR registration functions (for dialects and passes) are not
+        // necessarily thread-safe and need to be executed exactly once
+        // (especially since they eventually call non-thread-safe LLVM
+        // initializations).
+        std::call_once(init_guard, [&]() {
+            the_registry = mlirDialectRegistryCreate();
+            mlirRegisterRocMLIRDialects(the_registry.get());
+            mlirRegisterRocMLIRPasses();
+        });
+        return the_registry;
+    }
+
+    static mlir_thread_pool& get_thread_pool()
+    {
+        // To save on overhead, we create one LLVM thread pool and reuse it
+        // across all MLIR contexts, as recommended by MLIR upstream.
+        // Note that this is thread-safe as of C++11.
+        static mlir_thread_pool the_pool = mlirLlvmThreadPoolCreate();
+        return the_pool;
     }

     MlirType make_type(shape::type_t t) const
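A side note on why the two static accessors above need different synchronization: the registry mutates global LLVM state through C registration functions, hence the explicit std::call_once, while the thread pool only needs its own construction to be unique. For the latter, C++11 "magic statics" already guarantee exactly-once, thread-safe initialization of function-local statics:

    // Since C++11, initialization of a function-local static is thread-safe:
    // concurrent first callers block until one of them finishes constructing it.
    struct thread_pool
    {
        // stand-in for the real MlirLlvmThreadPool handle
    };

    thread_pool& get_thread_pool()
    {
        static thread_pool the_pool; // constructed exactly once
        return the_pool;
    }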
@@ -244,8 +274,6 @@ struct mlir_program
     MlirAttribute attribute(std::int64_t i) const
     {
-        if(i < 0)
-            MIGRAPHX_THROW("MLIR cant handle negative values since they are ambiguous");
         return mlirIntegerAttrGet(mlirIntegerTypeGet(ctx.get(), 64), i);
     }
     MlirAttribute attribute(std::uint64_t i) const
@@ -324,7 +352,8 @@ struct mlir_program
                                      std::string,
                                      value,
                                      std::vector<value>,
-                                     MlirType>;
+                                     MlirType,
+                                     MlirAttribute>;
     using named_attribute_t = std::pair<std::string_view, attribute_t>;

     MlirNamedAttribute name_attribute(const named_attribute_t& na) const
@@ -365,14 +394,20 @@ struct mlir_program
         mlir_operation_state& add_attributes(const std::vector<named_attribute_t>& named_attrs)
         {
             auto attributes = prog->name_attributes(named_attrs);
-            mlirOperationStateAddAttributes(&op_state, attributes.size(), attributes.data());
+            if(not attributes.empty())
+            {
+                mlirOperationStateAddAttributes(&op_state, attributes.size(), attributes.data());
+            }
             return *this;
         }

         mlir_operation_state& add_attribute_value(const value& v)
         {
             auto attributes = prog->name_attributes(v);
-            mlirOperationStateAddAttributes(&op_state, attributes.size(), attributes.data());
+            if(not attributes.empty())
+            {
+                mlirOperationStateAddAttributes(&op_state, attributes.size(), attributes.data());
+            }
             return *this;
         }
@@ -395,13 +430,19 @@ struct mlir_program
                 return shape{r.type(), r.lens()};
             });
             auto x = prog->make_tensors(reshaped);
-            mlirOperationStateAddResults(&op_state, x.size(), x.data());
+            if(not x.empty())
+            {
+                mlirOperationStateAddResults(&op_state, x.size(), x.data());
+            }
             return *this;
         }

         mlir_operation_state& add_operands(const std::vector<MlirValue>& inputs)
         {
-            mlirOperationStateAddOperands(&op_state, inputs.size(), inputs.data());
+            if(not inputs.empty())
+            {
+                mlirOperationStateAddOperands(&op_state, inputs.size(), inputs.data());
+            }
             return *this;
         }
@@ -411,7 +452,10 @@ struct mlir_program
             std::transform(regions.begin(), regions.end(), mregions.begin(), [](const auto& r) {
                 return r.get();
             });
-            mlirOperationStateAddOwnedRegions(&op_state, mregions.size(), mregions.data());
+            if(not mregions.empty())
+            {
+                mlirOperationStateAddOwnedRegions(&op_state, mregions.size(), mregions.data());
+            }
             mlir_operation op(mlirOperationCreate(&op_state));
             // Release memory since mlir_operation owns it
             for(auto& r : regions)
@@ -481,6 +525,10 @@ struct mlir_program
     {
         if(ins->name() == "@return")
             return "func.return";
+        if(ins->name() == "@literal")
+        {
+            return "tosa.const";
+        }
         return "migraphx." + ins->name();
     }
@@ -532,19 +580,30 @@ struct mlir_program
         {
             if(ins->name() == "@param")
                 continue;
+            if(ins->name() == "contiguous")
+            {
+                ins_map[ins] = ins_map[ins->inputs().at(0)];
+                continue;
+            }
             auto name = get_name(ins);
             auto ops = create_operation_state(name);
             ops.add_attribute_value(get_operator_value(ins->get_operator()));
             if(ins->name() != "@return")
                 ops.add_results({get_shape(ins)});
+            if(ins->name() == "@literal")
+            {
+                literal r = ins->get_literal();
+                MlirType tensor_type = make_tensor(ins->get_shape());
+                MlirAttribute mlir_value_attr =
+                    mlirDenseElementsAttrRawBufferGet(tensor_type, r.get_shape().bytes(), r.data());
+                ops.add_attributes({{"value", mlir_value_attr}});
+            }
             if(ins->name() == "convolution" or ins->name() == "dot")
             {
                 pp =
                     problem_params{ins->get_operator(), to_shapes(ins->inputs()), ins->get_shape()};
                 // check if HW supports xdlops
-                auto target_chip = trim(split_string(target_arch, ':').front());
-                bool xdlops = contains(get_xdlops_archs(), target_chip);
-                if(xdlops)
+                if(has_xdlops(target_arch))
                     ops.add_attributes({{"xdlopsV2", true}});
             }
@@ -562,18 +621,30 @@ struct mlir_program
         }
     }

-    code_object_op compile() MIGRAPHX_TIDY_CONST
+    void run_high_level_pipeline() MIGRAPHX_TIDY_CONST
     {
         mlir_pass_manager pm_front{mlirPassManagerCreate(ctx.get())};
-        mlir_pass_manager pm_back{mlirPassManagerCreate(ctx.get())};
-        // 1st pipeline to call
         mlirMIGraphXAddHighLevelPipeline(pm_front.get());
-        mlirPassManagerRun(pm_front.get(), mmodule.get());
+        mlirPassManagerRunOnOp(pm_front.get(), mlirModuleGetOperation(mmodule.get()));
+    }

-        // 2nd pipeline to call
-        get_module_tuned();
+    void run_backend_pipeline() MIGRAPHX_TIDY_CONST
+    {
+        mlir_pass_manager pm_back{mlirPassManagerCreate(ctx.get())};
         mlirMIGraphXAddBackendPipeline(pm_back.get(), target_arch.c_str());
-        mlirPassManagerRun(pm_back.get(), mmodule.get());
+        mlirPassManagerRunOnOp(pm_back.get(), mlirModuleGetOperation(mmodule.get()));
+    }
+
+    code_object_op compile(const value& solution) MIGRAPHX_TIDY_CONST
+    {
+        // 1st pipeline to call
+        run_high_level_pipeline();
+        if(solution.is_null())
+            get_module_tuned();
+        else
+            set_tuning(solution);
+        // 2nd pipeline to call
+        run_backend_pipeline();

         code_object_op op{};
         op.symbol_name = sym_name;
@@ -604,6 +675,33 @@ struct mlir_program
         MIGRAPHX_THROW("Failed to compile mlir program");
     }

+    void set_tuning(const value& v)
+    {
+        auto str = v.to<std::string>();
+        // We need to make a copy of the buffer since mlirRockTuningSetFromStr may modify the string
+        std::vector<char> buffer(str.begin(), str.end());
+        buffer.push_back(0);
+        if(not mlirRockTuningSetFromStr(mmodule.get(), buffer.data()))
+            MIGRAPHX_THROW("Failed setting tuning key: " + str);
+    }
+
+    tuning_config get_tuning_config() MIGRAPHX_TIDY_CONST
+    {
+        tuning_config tc;
+        run_high_level_pipeline();
+        mlir_tuning_space params{mlirRockTuningSpaceCreate(mmodule.get())};
+        for(auto i : range(mlirRockTuningGetNumParamsFull(params.get())))
+        {
+            mlir_tuning_param param{mlirRockTuningParamCreate()};
+            if(not mlirRockTuningParamGet(params.get(), i, param.get()))
+                MIGRAPHX_THROW("Incorrect mlir tuning parameter: " + std::to_string(i));
+            tc.solutions.push_back(std::string{mlirRockTuningGetParamStr(param.get())});
+        }
+        mlir_tuning_table tuning_table{mlirRockTuningTableCreate()};
+        tc.problem = std::string{mlirRockTuningGetKey(tuning_table.get(), mmodule.get())};
+        return tc;
+    }
+
     std::string get_tune_params(bool xdlops) const { return get_mlir_perf_for_conv(pp, xdlops); }

     // This function appends to the tuning cfg file that could be
@@ -662,6 +760,11 @@ struct mlir_program
     bool get_module_tuned() const
     {
         static mlir_tuning_table tuning_table = create_tuning_table();
+        // The tuning table as currently implemented is not thread safe.
+        // This will be fixed in the future. For now, stick a mutex around
+        // all tuning table interaction.
+        static std::mutex lock;
+        std::lock_guard<std::mutex> guard(lock);
         if(!mlirRockTuningSetFromTable(tuning_table.get(), mmodule.get()))
         {
             const char* prob_config = mlirRockTuningGetKey(tuning_table.get(), mmodule.get());
@@ -690,14 +793,14 @@ std::string dump_mlir(const module& m)
     return mlir_print(&mlirOperationPrint, mod_op);
 }

-void adjust_param_shapes(module& m, const std::vector<instruction_ref>& inputs)
+void adjust_param_shapes(module& m, const std::vector<shape>& inputs)
 {
     auto names = m.get_parameter_names();
     std::sort(names.begin(), names.end());
     for(auto i : range(names.size()))
     {
         const auto& name = names[i];
-        const auto& input = inputs[i]->get_shape();
+        const auto& input = inputs[i];
         auto param = m.get_parameter(name);
         if(input.standard())
             continue;
@@ -735,24 +838,26 @@ void adjust_param_shapes(module& m, const std::vector<shape>& inputs)
     }
 }

-code_object_op compile_mlir(const context&, module m, const std::vector<instruction_ref>& inputs)
+code_object_op compile_mlir(const context&,
+                            module m,
+                            const std::vector<instruction_ref>& inputs,
+                            const value& solution)
 {
-    adjust_param_shapes(m, inputs);
+    adjust_param_shapes(m, to_shapes(inputs));
     const bool trace = enabled(MIGRAPHX_TRACE_MLIR{});
     if(trace)
         std::cout << m << std::endl;

-    // set mutex while llvm thread support is disabled.
-    static std::mutex g_mlirc_mutex; // NOLINT
-    const std::lock_guard<std::mutex> lock(g_mlirc_mutex);
     mlir_program mp;
     mp.find_target();
     mp.parse(m);
     auto mod_op = mlirModuleGetOperation(mp.mmodule.get());
     if(trace)
         std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl;
-    auto co = mp.compile();
+    auto co = mp.compile(solution);

-    co.output = m.get_output_shapes().front();
+    co.expected_inputs = to_shapes(inputs);
+    co.output          = m.get_output_shapes().front();
     return co;
 }
@@ -772,6 +877,16 @@ instruction_ref insert_mlir(module& m,
     return m.insert_instruction(ins, co, refs);
 }

+tuning_config get_tuning_config_mlir(module m, const std::vector<shape>& inputs)
+{
+    adjust_param_shapes(m, inputs);
+    mlir_program mp;
+    mp.find_target();
+    mp.parse(m);
+    return mp.get_tuning_config();
+}
+
 #else

 std::string dump_mlir(const module&) { return {}; }
@@ -783,11 +898,11 @@ void use(T&)
 // Disabling clang-tidy warning on non-real usage.
 // NOLINTBEGIN(performance-unnecessary-value-param)
-code_object_op compile_mlir(const context&, module, const std::vector<instruction_ref>&)
+code_object_op
+compile_mlir(const context&, module, const std::vector<instruction_ref>&, const value&)
 {
     return {};
 }
-// NOLINTEND(performance-unnecessary-value-param)

 instruction_ref
 // cppcheck-suppress funcArgNamesDifferent
@@ -797,6 +912,9 @@ insert_mlir(module& m, instruction_ref, code_object_op co, const std::vector<ins
     return m.end();
 }

+tuning_config get_tuning_config_mlir(module, const std::vector<shape>&) { return {}; }
+
+// NOLINTEND(performance-unnecessary-value-param)
 #endif

 } // namespace gpu
...
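Taken together, the new entry points split compilation into enumerate/measure/commit phases: get_tuning_config_mlir lists candidate solutions after the high-level pipeline, compile_mlir(..., solution) commits one, and a null solution falls back to the tuning table. A hypothetical driver loop, where benchmark() is an illustrative stand-in and the real wiring presumably goes through compile_ops{&ctx, options.exhaustive_tune} shown in target.cpp below:

    #include <limits>
    #include <vector>

    // Illustrative stand-in for a device timing helper.
    double benchmark(const migraphx::gpu::context& ctx, const migraphx::gpu::code_object_op& co);

    migraphx::gpu::code_object_op
    tune_and_compile(const migraphx::gpu::context& ctx,
                     migraphx::module m,
                     const std::vector<migraphx::instruction_ref>& inputs)
    {
        auto tc = migraphx::gpu::get_tuning_config_mlir(m, migraphx::to_shapes(inputs));
        migraphx::value best{};
        double best_time = std::numeric_limits<double>::max();
        for(const auto& solution : tc.solutions)
        {
            auto co  = migraphx::gpu::compile_mlir(ctx, m, inputs, solution);
            double t = benchmark(ctx, co); // time the code object on device
            if(t < best_time)
            {
                best_time = t;
                best      = solution;
            }
        }
        // compile falls back to the tuning table when best is still null
        return migraphx::gpu::compile_mlir(ctx, m, inputs, best);
    }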
@@ -47,32 +47,24 @@ rocblas_handle_ptr create_rocblas_handle_ptr(hipStream_t s)
     return rb;
 }

-const std::unordered_set<std::string>& get_rocblas_fp32_archs()
-{
-    static std::unordered_set<std::string> supported_archs{"gfx908", "gfx90a"};
-    return supported_archs;
-}
-
 bool get_compute_fp32_flag()
 {
-    bool compute_fp32 = false;
-#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
     const auto device_name = trim(split_string(get_device_name(), ':').front());
-    if(contains(get_rocblas_fp32_archs(), device_name))
-        compute_fp32 = true;
-#endif
-    return compute_fp32;
+    return (starts_with(device_name, "gfx9") and device_name >= "gfx908");
 }

 bool get_int8_x4_format(context& ctx)
 {
-    bool int8_x4_format = true;
-#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
+#if ROCBLAS_VERSION_MAJOR >= 3
+    (void)(ctx);
+    return false;
+#else
+    // int8x4 packed format is only available starting from rocblas-v2.38 and it is deprecated in
+    // v3.0 and will be removed in v4.0
     rocblas_gemm_flags flag;
     rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag);
-    int8_x4_format = (flag == rocblas_gemm_flags_pack_int8x4);
+    return flag == rocblas_gemm_flags_pack_int8x4;
 #endif
-    return int8_x4_format;
 }

 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
...
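Both this file and the MLIR change above replace a hard-coded {gfx908, gfx90a} whitelist with the same two-part test: the gfx9 prefix scopes the comparison to one naming generation, and within that generation plain lexicographic ordering happens to match chronology ('a' sorts after '8', so gfx90a passes). A host-side illustration, where has_fp32_arch is a stand-in name:

    #include <cassert>
    #include <string>

    // Lexicographic compare scoped to gfx9 replaces the whitelist and
    // automatically admits newer gfx9 parts.
    bool has_fp32_arch(const std::string& name)
    {
        return name.rfind("gfx9", 0) == 0 and name >= "gfx908";
    }

    int main()
    {
        assert(not has_fp32_arch("gfx906"));  // too old
        assert(has_fp32_arch("gfx908"));
        assert(has_fp32_arch("gfx90a"));      // 'a' > '8' lexicographically
        assert(has_fp32_arch("gfx940"));      // future gfx9 parts admitted
        assert(not has_fp32_arch("gfx1030")); // not gfx9
        return 0;
    }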
@@ -57,6 +57,7 @@
 #include <migraphx/gpu/concat_gpu_opt.hpp>
 #include <migraphx/gpu/context.hpp>
 #include <migraphx/gpu/device_name.hpp>
+#include <migraphx/gpu/fuse_ck.hpp>
 #include <migraphx/gpu/fuse_mlir.hpp>
 #include <migraphx/gpu/fuse_ops.hpp>
 #include <migraphx/gpu/prefuse_ops.hpp>
@@ -72,9 +73,12 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {

 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
-MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION)
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION)
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC)
+#ifndef _WIN32
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK)
+#endif

 struct id_pass
 {
     std::string name() const { return "id"; }
@@ -98,16 +102,17 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
     unsupported_types.erase(shape::type_t::bool_type);
     unsupported_types.erase(shape::type_t::int8_type);
     unsupported_types.erase(shape::type_t::uint8_type);
+    unsupported_types.erase(shape::type_t::int32_type);
     unsupported_types.erase(shape::type_t::tuple_type);
     // clang-format off
     return
     {
-        enable_pass(options.split_single_dyn_dim, split_single_dyn_dim{}),
-        enable_pass(options.split_single_dyn_dim, dead_code_elimination{}),
+        split_single_dyn_dim{},
+        dead_code_elimination{},
         normalize_ops{},
         dead_code_elimination{},
         simplify_qdq{},
-        rewrite_quantization{},
+        enable_pass(not mlir_enabled(), rewrite_quantization{}),
         dead_code_elimination{},
         eliminate_data_type{unsupported_types, shape::type_t::float_type},
         simplify_reshapes{},
@@ -121,7 +126,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
         inline_module{},
         rewrite_pooling{},
         dead_code_elimination{},
-        rewrite_gelu{},
+        enable_pass(options.fast_math, rewrite_gelu{}),
         optimize_module{},
         enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), layout_nhwc{}),
         dead_code_elimination{},
@@ -129,11 +134,15 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
         dead_code_elimination{},
         auto_contiguous{},
         optimize_module{},
-        enable_pass(not enabled(MIGRAPHX_DISABLE_POINTWISE_FUSION{}), fuse_pointwise{}),
+        fuse_pointwise{},
         dead_code_elimination{},
         enable_pass(not enabled(MIGRAPHX_DISABLE_REDUCE_FUSION{}), fuse_reduce{}),
         dead_code_elimination{},
-        fuse_mlir{&ctx},
+#ifndef _WIN32
+        enable_pass(enabled(MIGRAPHX_ENABLE_CK{}), fuse_ck{}),
+#endif
+        dead_code_elimination{},
+        enable_pass(mlir_enabled(), fuse_mlir{&ctx}),
         dead_code_elimination{},
         lowering{&ctx, options.offload_copy},
         eliminate_contiguous{"gpu::contiguous"},
@@ -150,7 +159,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
         dead_code_elimination{},
         adjust_allocation{gpu_allocation_model{}},
         dead_code_elimination{},
-        compile_ops{&ctx},
+        compile_ops{&ctx, options.exhaustive_tune},
         dead_code_elimination{},
         promote_literals{},
         dead_code_elimination{},
...
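Most of the churn in this pass list uses the enable_pass helper together with the id_pass declared above, which keeps the sequence flat and declarative: disabled entries degrade to a no-op instead of being spliced out. A simplified sketch of what that combination presumably looks like; the real helper lives elsewhere in this file:

    // Wrap a pass so that a false condition yields the no-op id_pass
    // rather than removing the entry from the list.
    inline migraphx::pass enable_pass(bool enabled, migraphx::pass p)
    {
        if(enabled)
            return p;
        return id_pass{};
    }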
@@ -21,7 +21,7 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
-#include <migraphx/gpu/driver/perf.hpp>
+#include <migraphx/gpu/time_op.hpp>
 #include <migraphx/context.hpp>
 #include <migraphx/generate.hpp>
 #include <migraphx/time.hpp>
@@ -30,7 +30,6 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
-namespace driver {

 std::vector<argument> generate_arguments(const std::vector<shape>& shapes, unsigned long seed = 0)
 {
@@ -69,7 +68,6 @@ time_op(context& ictx, operation op, const std::vector<shape>& inputs, int n)
     return std::make_pair(host_time / n, device_time / n);
 }

-} // namespace driver
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...
@@ -37,6 +37,8 @@ target_link_libraries(migraphx_ref PUBLIC migraphx)
 target_include_directories(migraphx_ref PRIVATE ${BLAZE_INCLUDE})
 target_compile_definitions(migraphx_ref PRIVATE -DBLAZE_USE_CPP_THREADS)

+migraphx_generate_export_header(migraphx_ref)
+
 rocm_install_targets(
   TARGETS migraphx_ref
   INCLUDE
...
@@ -25,6 +25,7 @@
 #define MIGRAPHX_GUARD_RTGLIB_CONTEXT_HPP

 #include <migraphx/config.hpp>
+#include <migraphx/ref/export.h>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
...
@@ -24,14 +24,14 @@
 #ifndef MIGRAPHX_GUARD_RTGLIB_CPU_LOWERING_HPP
 #define MIGRAPHX_GUARD_RTGLIB_CPU_LOWERING_HPP

-#include <migraphx/ref/context.hpp>
 #include <migraphx/program.hpp>
+#include <migraphx/config.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace ref {

-struct lowering
+struct MIGRAPHX_REF_EXPORT lowering
 {
     std::string name() const { return "ref::lowering"; }
     void apply(module& m) const;
...
@@ -35,7 +35,7 @@ inline namespace MIGRAPHX_INLINE_NS {
 struct pass;
 namespace ref {

-struct target
+struct MIGRAPHX_REF_EXPORT target
 {
     std::string name() const;
     std::vector<pass> get_passes(migraphx::context& ctx, const compile_options&) const;
...
@@ -27,7 +27,7 @@
 #include <migraphx/dfor.hpp>
 #include <migraphx/op/identity.hpp>
 #include <migraphx/op/convolution.hpp>
-#include <migraphx/op/deconvolution.hpp>
+#include <migraphx/op/convolution_backwards.hpp>
 #include <migraphx/op/quant_convolution.hpp>
 #include <migraphx/op/dot.hpp>
 #include <migraphx/op/quant_dot.hpp>
...
@@ -42,8 +42,9 @@ target_compile_options(tf-proto PRIVATE -w)
 target_link_libraries(tf-proto PRIVATE ${PROTOBUF_LIBRARY})
 set_target_properties(tf-proto PROPERTIES POSITION_INDEPENDENT_CODE On)

-file(GLOB TF_SRCS ${CONFIGURE_DEPENDS} *.cpp)
+file(GLOB TF_SRCS CONFIGURE_DEPENDS *.cpp)
 add_library(migraphx_tf ${TF_SRCS})
+migraphx_generate_export_header(migraphx_tf)
 target_include_directories(migraphx_tf PRIVATE include)
 set_target_properties(migraphx_tf PROPERTIES EXPORT_NAME tf)
 rocm_set_soversion(migraphx_tf ${MIGRAPHX_SO_VERSION})
...
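The file(GLOB ...) fix is subtle: CONFIGURE_DEPENDS is a keyword, but the old line wrapped it as ${CONFIGURE_DEPENDS}, a variable expansion. Since no such variable is defined, it expanded to nothing, so CMake silently never recorded the glob as a configure dependency and would not re-run it when source files were added or removed. Passing the bare keyword restores that behavior.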
@@ -46,6 +46,7 @@ std::vector<std::string> get_op_parsers()
                    op_parser_map().end(),
                    std::back_inserter(result),
                    [&](auto&& p) { return p.first; });
+    std::sort(result.begin(), result.end());
     return result;
 }
...
@@ -52,7 +52,6 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
         auto x_type = args[0]->get_shape().type();

         // unsqueeze tensors of shape (C) to broadcast correctly
-        auto rt = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0.5}});
         auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}});
         auto scale_unsqueeze =
@@ -64,11 +63,11 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
         auto var_unsqueeze =
             info.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), args[4]);

-        auto numer = info.add_broadcastable_binary_op("sub", args[0], mean_unsqueeze);
+        auto x_sub_mean = info.add_broadcastable_binary_op("sub", args[0], mean_unsqueeze);
         auto var_eps = info.add_broadcastable_binary_op("add", var_unsqueeze, eps);
-        auto denom = info.add_broadcastable_binary_op("pow", var_eps, rt);
-        auto div0 = info.add_broadcastable_binary_op("div", numer, denom);
-        auto r0 = info.add_broadcastable_binary_op("mul", div0, scale_unsqueeze);
+        auto rsqrt = info.add_instruction(make_op("rsqrt"), var_eps);
+        auto mul0 = info.add_broadcastable_binary_op("mul", scale_unsqueeze, rsqrt);
+        auto r0 = info.add_broadcastable_binary_op("mul", x_sub_mean, mul0);
         return info.add_broadcastable_binary_op("add", r0, bias_unsqueeze);
     }
 };
...
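The batchnorm rewrite is a pure algebraic refactoring of inference-mode batch normalization:

    y = \gamma \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta
      = (x - \mu) \cdot \bigl(\gamma \cdot \mathrm{rsqrt}(\sigma^2 + \epsilon)\bigr) + \beta

The old lowering computed the square root via pow(var + eps, 0.5) (hence the deleted rt = 0.5 literal) and then divided; the new one emits a single rsqrt and two multiplies, grouping gamma * rsqrt(var + eps) first so the per-channel scale becomes one broadcast factor.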