Commit 538dbd75 authored by Brian Pickrell's avatar Brian Pickrell
Browse files

Merge branch 'develop' into resize_op

parents c7161d99 e3e00547
...@@ -58,10 +58,10 @@ struct hiprtc_src_file ...@@ -58,10 +58,10 @@ struct hiprtc_src_file
MIGRAPHX_GPU_EXPORT bool hip_has_flags(const std::vector<std::string>& flags); MIGRAPHX_GPU_EXPORT bool hip_has_flags(const std::vector<std::string>& flags);
MIGRAPHX_GPU_EXPORT std::vector<std::vector<char>> compile_hip_src_with_hiprtc( MIGRAPHX_GPU_EXPORT std::vector<std::vector<char>> compile_hip_src_with_hiprtc(
std::vector<hiprtc_src_file> srcs, std::string params, const std::string& arch); std::vector<hiprtc_src_file> srcs, const std::string& params, const std::string& arch);
MIGRAPHX_GPU_EXPORT std::vector<std::vector<char>> MIGRAPHX_GPU_EXPORT std::vector<std::vector<char>> compile_hip_src(
compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch); const std::vector<src_file>& srcs, const std::string& params, const std::string& arch);
MIGRAPHX_GPU_EXPORT std::string enum_params(std::size_t count, std::string param); MIGRAPHX_GPU_EXPORT std::string enum_params(std::size_t count, std::string param);
......
...@@ -34,10 +34,11 @@ struct module_pass_manager; ...@@ -34,10 +34,11 @@ struct module_pass_manager;
namespace gpu { namespace gpu {
MIGRAPHX_GPU_EXPORT bool mlir_enabled(); MIGRAPHX_GPU_EXPORT bool mlir_enabled();
MIGRAPHX_GPU_EXPORT bool mlir_attention_enabled();
struct MIGRAPHX_GPU_EXPORT fuse_mlir struct MIGRAPHX_GPU_EXPORT fuse_mlir
{ {
context* ctx = nullptr; context* ctx = nullptr;
bool enable_extra = false; bool enable_extra = false;
std::string name() const { return "gpu::fuse_mlir"; } std::string name() const { return "gpu::fuse_mlir"; }
void apply(module_pass_manager& mpm) const; void apply(module_pass_manager& mpm) const;
......
...@@ -66,6 +66,10 @@ struct gemm_softmax_gemm ...@@ -66,6 +66,10 @@ struct gemm_softmax_gemm
} }
static bool is_ck_supported_type(shape::type_t t) { return contains({shape::half_type}, t); } static bool is_ck_supported_type(shape::type_t t) { return contains({shape::half_type}, t); }
static bool is_mlir_supported_type(shape::type_t t)
{
return contains({shape::type_t::float_type, shape::half_type}, t);
}
}; };
} // namespace gpu } // namespace gpu
......
...@@ -211,6 +211,12 @@ inline pooling_descriptor make_pooling(const migraphx::op::pooling& op) ...@@ -211,6 +211,12 @@ inline pooling_descriptor make_pooling(const migraphx::op::pooling& op)
ss << op.mode; ss << op.mode;
MIGRAPHX_THROW(ss.str()); MIGRAPHX_THROW(ss.str());
} }
if(not std::all_of(
op.dilations.cbegin(), op.dilations.cend(), [](std::size_t d) { return d == 1; }))
{
MIGRAPHX_THROW("Unsupported dilations for pooling: [" + to_string_range(op.dilations) +
"]");
}
auto p = make_obj<pooling_descriptor>(&miopenCreatePoolingDescriptor); auto p = make_obj<pooling_descriptor>(&miopenCreatePoolingDescriptor);
int kdims = op.kdims(); int kdims = op.kdims();
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -43,7 +43,7 @@ __device__ void softmax(Input input1, Output output) ...@@ -43,7 +43,7 @@ __device__ void softmax(Input input1, Output output)
auto exp_in = r.inner([&](auto x) { return migraphx::exp(x - c); })(input); auto exp_in = r.inner([&](auto x) { return migraphx::exp(x - c); })(input);
auto batch_sum = auto batch_sum =
r.reduce(op::sum{}, 0, [](auto x) { return migraphx::convert<float>(x); })(exp_in); r.reduce(op::sum{}, 0, [](auto x) { return migraphx::convert<float>(x); })(exp_in);
r.inner([&](auto& y, auto x) { y = x / batch_sum; })(output, exp_in); r.inner([&](auto& y, auto x) { y = implicit_conversion(x / batch_sum); })(output, exp_in);
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment