Commit 538dbd75 authored by Brian Pickrell's avatar Brian Pickrell
Browse files

Merge branch 'develop' into resize_op

parents c7161d99 e3e00547
......@@ -58,10 +58,10 @@ struct hiprtc_src_file
MIGRAPHX_GPU_EXPORT bool hip_has_flags(const std::vector<std::string>& flags);
MIGRAPHX_GPU_EXPORT std::vector<std::vector<char>> compile_hip_src_with_hiprtc(
std::vector<hiprtc_src_file> srcs, std::string params, const std::string& arch);
std::vector<hiprtc_src_file> srcs, const std::string& params, const std::string& arch);
MIGRAPHX_GPU_EXPORT std::vector<std::vector<char>>
compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch);
MIGRAPHX_GPU_EXPORT std::vector<std::vector<char>> compile_hip_src(
const std::vector<src_file>& srcs, const std::string& params, const std::string& arch);
MIGRAPHX_GPU_EXPORT std::string enum_params(std::size_t count, std::string param);
......
......@@ -34,6 +34,7 @@ struct module_pass_manager;
namespace gpu {
MIGRAPHX_GPU_EXPORT bool mlir_enabled();
MIGRAPHX_GPU_EXPORT bool mlir_attention_enabled();
struct MIGRAPHX_GPU_EXPORT fuse_mlir
{
......
......@@ -66,6 +66,10 @@ struct gemm_softmax_gemm
}
static bool is_ck_supported_type(shape::type_t t) { return contains({shape::half_type}, t); }
static bool is_mlir_supported_type(shape::type_t t)
{
return contains({shape::type_t::float_type, shape::half_type}, t);
}
};
} // namespace gpu
......
......@@ -211,6 +211,12 @@ inline pooling_descriptor make_pooling(const migraphx::op::pooling& op)
ss << op.mode;
MIGRAPHX_THROW(ss.str());
}
if(not std::all_of(
op.dilations.cbegin(), op.dilations.cend(), [](std::size_t d) { return d == 1; }))
{
MIGRAPHX_THROW("Unsupported dilations for pooling: [" + to_string_range(op.dilations) +
"]");
}
auto p = make_obj<pooling_descriptor>(&miopenCreatePoolingDescriptor);
int kdims = op.kdims();
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -28,6 +28,7 @@
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/ranges.hpp>
#include <migraphx/kernels/vec.hpp>
namespace migraphx {
......@@ -53,9 +54,9 @@ __device__ void pad(const index& idx,
if(any_of(range_multi.begin(), range_multi.end(), [&](auto j) {
return multi[j] < offsets[j] or input_idx[j] >= input_bounds[j];
}))
output[multi] = pad_val;
output[multi] = implicit_conversion(pad_val);
else
output[multi] = input[input_idx];
output[multi] = implicit_conversion(input[input_idx]);
});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment