Commit 72011beb authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into jit-concat-pointwise

parents d48d9bf7 d37a4df9
...@@ -43,8 +43,8 @@ struct run_op : action<run_op> ...@@ -43,8 +43,8 @@ struct run_op : action<run_op>
auto op = make_op(name); auto op = make_op(name);
if(v.contains("fields")) if(v.contains("fields"))
op.from_value(v.at("fields")); op.from_value(v.at("fields"));
double t = time_op(ctx, op, inputs, p.get(v, "iterations", 100)); auto [host_time, device_time] = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
std::cout << op << ": " << t << "ms" << std::endl; std::cout << op << ": " << host_time << "ms" << std::endl;
} }
}; };
......
...@@ -48,8 +48,10 @@ ...@@ -48,8 +48,10 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/register_op.hpp> #include <migraphx/register_op.hpp>
#include <migraphx/array.hpp> #include <migraphx/array.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/op/clip.hpp> #include <migraphx/op/clip.hpp>
#include <migraphx/op/contiguous.hpp>
#include <cmath> #include <cmath>
#include <set> #include <set>
...@@ -279,6 +281,11 @@ MIGRAPHX_REGISTER_OP(hip_layernorm) ...@@ -279,6 +281,11 @@ MIGRAPHX_REGISTER_OP(hip_layernorm)
struct hip_triadd_layernorm : ternary_device<hip_triadd_layernorm, &device::triadd_layernorm> struct hip_triadd_layernorm : ternary_device<hip_triadd_layernorm, &device::triadd_layernorm>
{ {
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(4).standard();
return inputs[0];
}
// Empty finalize to skip dimension reduction // Empty finalize to skip dimension reduction
void finalize(context&, const shape&, const std::vector<shape>&) {} void finalize(context&, const shape&, const std::vector<shape>&) {}
}; };
...@@ -943,28 +950,70 @@ struct find_gemm_add ...@@ -943,28 +950,70 @@ struct find_gemm_add
} }
}; };
auto pointwise_name(const std::string& s)
{
return precompile_name("pointwise")(match::make_basic_pred_matcher([=](auto ins) {
module_ref pm = ins->module_inputs().front();
auto n = std::count_if(pm->begin(), pm->end(), [&](auto& i) { return i.name() == s; });
if(n != 1)
return false;
return std::all_of(pm->begin(), pm->end(), [&](auto& i) {
return starts_with(i.name(), "@") or i.name() == s;
});
}));
}
struct find_gemm_pointwise struct find_gemm_pointwise
{ {
auto matcher() const auto matcher() const
{ {
return pointwise_name("add")( return precompile_name("pointwise")(
match::nargs(3), match::nargs(3),
match::all_of[match::inputs()](match::standard_shape()), match::either_arg(0, 1)(
match::either_arg(0, 1)(match::used_once().bind("c"), match::any_of(match::standard_shape(), match::is_constant()).bind("c"),
match::name("gpu::gemm")(match::nargs(3)).bind("gemm"))); match::name("gpu::gemm")(match::nargs(3), match::used_once()).bind("gemm")));
}
// TODO: Move to matcher.hpp
static auto match_param(const std::string& name)
{
return match::make_basic_pred_matcher([=](auto ins) {
if(ins->name() != "@param")
return false;
auto p = any_cast<builtin::param>(ins->get_operator());
return p.parameter == name;
});
}
template <class M>
static auto match_mul_const(M m, const std::string& var)
{
return match::name("mul")(match::either_arg(0, 1)(match::name("@literal").bind(var), m))
.bind(var + "_mul");
}
static auto match_add(const std::string& input, const std::string& output)
{
auto param = match::name("@param");
auto add = match::name("add")(match::args(param, param));
auto inner_mul = match::any_of(match_mul_const(match_param(input), "alpha"),
match_mul_const(match_param(output), "beta"));
auto mul_add = match::name("add")(match::either_arg(0, 1)(inner_mul, param));
auto add_mul = match_mul_const(add, "gamma");
return match::name("@return")(match::args(match::any_of(add, mul_add, add_mul)));
}
static float get_float(instruction_ref ins) { return ins->get_literal().at<float>(); }
template <class Gemm>
static bool update_gemm(Gemm& gemm, module_ref pm, unsigned input)
{
auto names = pm->get_parameter_names();
if(names.size() != 2)
return false;
std::sort(names.begin(), names.end());
unsigned output = input == 0 ? 1 : 0;
auto mr = match::match_instruction(
*pm, std::prev(pm->end()), match_add(names[input], names[output]));
if(mr.result == pm->end())
return false;
if(contains(mr.instructions, "alpha_mul"))
gemm.alpha *= get_float(mr.instructions["alpha"]);
else if(contains(mr.instructions, "beta_mul"))
gemm.beta *= get_float(mr.instructions["beta"]);
else if(contains(mr.instructions, "gamma_mul"))
{
gemm.alpha *= get_float(mr.instructions["gamma"]);
gemm.beta *= get_float(mr.instructions["gamma"]);
}
return true;
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
...@@ -978,6 +1027,19 @@ struct find_gemm_pointwise ...@@ -978,6 +1027,19 @@ struct find_gemm_pointwise
// Already fused gemm // Already fused gemm
if(not float_equal(gemm.beta, 0)) if(not float_equal(gemm.beta, 0))
return; return;
gemm.beta = 1;
if(not update_gemm(
gemm, ins->module_inputs().front(), ins->inputs().front() == gemm_ins ? 0 : 1))
return;
// const-fold input if not standard shape since rocblas can't handle it
if(not c_ins->get_shape().standard())
{
auto c = op::contiguous{};
auto l = c.compute(c.compute_shape({c_ins->get_shape()}), {c_ins->eval()});
c_ins = m.add_literal(l.get_shape(), l.data());
}
auto inputs = gemm_ins->inputs(); auto inputs = gemm_ins->inputs();
inputs.pop_back(); inputs.pop_back();
...@@ -985,11 +1047,68 @@ struct find_gemm_pointwise ...@@ -985,11 +1047,68 @@ struct find_gemm_pointwise
inputs.push_back(c_ins); inputs.push_back(c_ins);
inputs.push_back(ins->inputs().back()); inputs.push_back(ins->inputs().back());
gemm.beta = 1;
m.replace_instruction(ins, gemm, inputs); m.replace_instruction(ins, gemm, inputs);
} }
}; };
struct find_contiguous_tranpose_gemm
{
auto matcher() const
{
return match::name("gpu::contiguous")(match::arg(0)(
match::name("transpose")(
match::arg(0)(match::name("gpu::gemm")(match::used_once()).bind("gemm")))
.bind("transpose")));
}
template <class Vector>
static bool is_swapped(const Vector& perm, std::size_t i, std::size_t j)
{
if(i >= perm.size() or j >= perm.size())
return false;
auto perm2 = perm;
std::iota(perm2.begin(), perm2.end(), 0);
std::swap(perm2[i], perm2[j]);
return perm2 == perm;
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto gemm = r.instructions["gemm"];
auto alloc = gemm->inputs().back();
auto transpose = r.instructions["transpose"];
auto perm = transpose->get_operator().to_value()["permutation"].to_vector<int64_t>();
auto iperm = invert_permutation(perm);
if(perm.size() < 3)
return;
if(not is_swapped(perm, perm.size() - 3, perm.size() - 2))
return;
auto lens = gemm->get_shape().lens();
if(lens.size() > 3 and
not std::all_of(lens.begin(), lens.end() - 3, [](auto i) { return i == 1; }))
return;
auto gemmv = gemm->get_operator().to_value();
gemmv["trans_batch"] = 1;
auto s = shape{alloc->get_shape().type(), reorder_dims(alloc->get_shape().lens(), iperm)};
auto new_alloc = m.insert_instruction(gemm, make_op("allocate", {{"shape", to_value(s)}}));
auto alloc_transpose =
m.insert_instruction(gemm, make_op("transpose", {{"permutation", perm}}), new_alloc);
auto inputs = gemm->inputs();
inputs.back() = alloc_transpose;
auto new_gemm = m.insert_instruction(gemm, make_op("gpu::gemm", gemmv), inputs);
auto gemm_transpoe = m.insert_instruction(gemm, transpose->get_operator(), new_gemm);
m.replace_instruction(ins, gemm_transpoe);
}
};
struct find_commutative_broadcast struct find_commutative_broadcast
{ {
auto matcher() const auto matcher() const
...@@ -1091,6 +1210,7 @@ void fuse_ops::apply(module& m) const ...@@ -1091,6 +1210,7 @@ void fuse_ops::apply(module& m) const
find_gemm_add{}, find_gemm_add{},
find_layernorm_pointwise{}, find_layernorm_pointwise{},
find_gemm_pointwise{}, find_gemm_pointwise{},
find_contiguous_tranpose_gemm{},
find_commutative_broadcast{}); find_commutative_broadcast{});
match::find_matches(m, find_contiguous{}); match::find_matches(m, find_contiguous{});
} }
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include <rocblas.h> #include <rocblas.h>
#include <migraphx/gpu/gemm_impl.hpp> #include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/reduce_dims.hpp> #include <migraphx/reduce_dims.hpp>
#include <migraphx/permutation.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -67,6 +68,19 @@ void blas_shape(const shape& s) ...@@ -67,6 +68,19 @@ void blas_shape(const shape& s)
MIGRAPHX_THROW("GPU_GEMM: Batch dimension is not collapsible"); MIGRAPHX_THROW("GPU_GEMM: Batch dimension is not collapsible");
} }
shape transpose_batch(const shape& s, unsigned trans_batch)
{
if(trans_batch == 0)
return s;
if(s.lens().size() < 3)
return s;
auto batch = s.lens().size() - 3;
std::vector<int64_t> perm(s.lens().size());
std::iota(perm.begin(), perm.end(), 0);
std::swap(perm[batch], perm[batch + trans_batch]);
return shape::from_permutation(s.type(), s.lens(), perm);
}
template <class R, class... Ts, class... Us> template <class R, class... Ts, class... Us>
R rocblas_invoke(R (*f)(Ts...), Us... xs) R rocblas_invoke(R (*f)(Ts...), Us... xs)
{ {
...@@ -97,6 +111,12 @@ void gemm_impl(context& ctx, ...@@ -97,6 +111,12 @@ void gemm_impl(context& ctx,
bool int8_x4_format, bool int8_x4_format,
bool compute_fp32) bool compute_fp32)
{ {
const bool is_3inputs = (args.size() == 4);
if(not is_3inputs)
{
beta = 0;
}
bool transa = is_transposed(args[0].get_shape()); bool transa = is_transposed(args[0].get_shape());
bool transb = is_transposed(args[1].get_shape()); bool transb = is_transposed(args[1].get_shape());
auto n_dim = output_shape.lens().size(); auto n_dim = output_shape.lens().size();
...@@ -105,12 +125,8 @@ void gemm_impl(context& ctx, ...@@ -105,12 +125,8 @@ void gemm_impl(context& ctx,
rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0]; rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0]; rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
rocblas_int ldc = args[2].get_shape().strides()[dim_0]; rocblas_int ldc = args[2].get_shape().strides()[dim_0];
rocblas_int ldd = is_3inputs ? args[3].get_shape().strides()[dim_0] : ldc;
bool is_3inputs = (args.size() == 4);
if(!is_3inputs)
{
beta = 0;
}
rocblas_datatype arg_type = get_type(args[0].get_shape().type()); rocblas_datatype arg_type = get_type(args[0].get_shape().type());
auto output_type = arg_type; auto output_type = arg_type;
if(output_type == rocblas_datatype_i8_r) if(output_type == rocblas_datatype_i8_r)
...@@ -186,7 +202,7 @@ void gemm_impl(context& ctx, ...@@ -186,7 +202,7 @@ void gemm_impl(context& ctx,
ldc, ldc,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]), is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
output_type, output_type,
ldc, ldd,
compute_type, compute_type,
rocblas_gemm_algo_standard, rocblas_gemm_algo_standard,
0, 0,
...@@ -197,6 +213,7 @@ void gemm_impl(context& ctx, ...@@ -197,6 +213,7 @@ void gemm_impl(context& ctx,
auto a_stride = get_batch_stride(args[0]); auto a_stride = get_batch_stride(args[0]);
auto b_stride = get_batch_stride(args[1]); auto b_stride = get_batch_stride(args[1]);
auto c_stride = get_batch_stride(args[2]); auto c_stride = get_batch_stride(args[2]);
auto d_stride = is_3inputs ? get_batch_stride(args[3]) : c_stride;
rocblas_invoke(&rocblas_gemm_strided_batched_ex, rocblas_invoke(&rocblas_gemm_strided_batched_ex,
ctx.get_stream().get_rocblas(), ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none, transb ? rocblas_operation_transpose : rocblas_operation_none,
...@@ -220,8 +237,8 @@ void gemm_impl(context& ctx, ...@@ -220,8 +237,8 @@ void gemm_impl(context& ctx,
c_stride, c_stride,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]), is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
output_type, output_type,
ldc, ldd,
c_stride, d_stride,
num_matrices, num_matrices,
compute_type, compute_type,
rocblas_gemm_algo_standard, rocblas_gemm_algo_standard,
......
...@@ -244,6 +244,15 @@ struct context ...@@ -244,6 +244,15 @@ struct context
return hip_event_ptr{event}; return hip_event_ptr{event};
} }
static hip_event_ptr create_event_for_timing()
{
hipEvent_t event;
auto status = hipEventCreate(&event);
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to create event");
return hip_event_ptr{event};
}
value to_value() const value to_value() const
{ {
value result; value result;
...@@ -267,10 +276,49 @@ struct context ...@@ -267,10 +276,49 @@ struct context
any_ptr get_queue() { return get_stream().get(); } any_ptr get_queue() { return get_stream().get(); }
void enable_perf_measurement(bool b = true)
{
if(b)
{
start_event = create_event_for_timing();
stop_event = create_event_for_timing();
get_stream().record(start_event.get());
get_stream().record(stop_event.get());
}
else
{
start_event = nullptr;
stop_event = nullptr;
}
measure_perf = b;
}
std::pair<hipEvent_t, hipEvent_t> get_perf_events() const
{
if(measure_perf)
return std::make_pair(start_event.get(), stop_event.get());
return std::make_pair(nullptr, nullptr);
}
float get_elapsed_ms() const
{
float result = 0;
if(start_event != nullptr and stop_event != nullptr)
{
auto status = hipEventElapsedTime(&result, start_event.get(), stop_event.get());
if(status != hipSuccess)
MIGRAPHX_THROW("Failed hipEventElapsedTime: " + hip_error(status));
}
return result;
}
private: private:
// TODO: Make this a vector to support multiple devices // TODO: Make this a vector to support multiple devices
std::shared_ptr<hip_device> current_device; std::shared_ptr<hip_device> current_device;
std::vector<shared<hip_event_ptr>> events; std::vector<shared<hip_event_ptr>> events;
bool measure_perf = false;
shared<hip_event_ptr> start_event = nullptr;
shared<hip_event_ptr> stop_event = nullptr;
}; };
inline void migraphx_to_value(value& v, const context& ctx) { v = ctx.to_value(); } inline void migraphx_to_value(value& v, const context& ctx) { v = ctx.to_value(); }
......
...@@ -42,15 +42,17 @@ namespace gpu { ...@@ -42,15 +42,17 @@ namespace gpu {
struct context; struct context;
void blas_shape(const shape& s); void blas_shape(const shape& s);
shape transpose_batch(const shape& s, unsigned trans_batch);
template <class Op> template <class Op>
struct rocblas_gemm struct rocblas_gemm
{ {
Op op; Op op;
float alpha = 1; float alpha = 1;
float beta = 0; float beta = 0;
bool int8_x4_format = true; bool int8_x4_format = true;
bool compute_fp32 = false; bool compute_fp32 = false;
unsigned trans_batch = 0;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -58,7 +60,9 @@ struct rocblas_gemm ...@@ -58,7 +60,9 @@ struct rocblas_gemm
return pack_join(migraphx::reflect(self.op, f), return pack_join(migraphx::reflect(self.op, f),
pack(f(self.alpha, "alpha"), pack(f(self.alpha, "alpha"),
f(self.beta, "beta"), f(self.beta, "beta"),
f(self.int8_x4_format, "int8_x4_format"))); f(self.int8_x4_format, "int8_x4_format"),
f(self.compute_fp32, "compute_fp32"),
f(self.trans_batch, "trans_batch")));
} }
std::string name() const std::string name() const
...@@ -74,13 +78,14 @@ struct rocblas_gemm ...@@ -74,13 +78,14 @@ struct rocblas_gemm
{ {
std::vector<shape> in_shapes(inputs); std::vector<shape> in_shapes(inputs);
in_shapes.pop_back(); in_shapes.pop_back();
check_shapes{in_shapes, *this}.not_broadcasted(); check_shapes{in_shapes, *this}.has(2, 3);
blas_shape(inputs[0]); blas_shape(inputs[0]);
blas_shape(inputs[1]); blas_shape(inputs[1]);
// if gemm and add are fused // if gemm and add are fused
if(in_shapes.size() > 2) if(in_shapes.size() > 2)
{ {
auto cmat_shape = in_shapes.back(); auto cmat_shape = in_shapes.back();
check_shapes{{cmat_shape}, *this}.not_transposed().not_broadcasted();
in_shapes.pop_back(); in_shapes.pop_back();
blas_shape(cmat_shape); blas_shape(cmat_shape);
auto op_out_shape = op.compute_shape(in_shapes); auto op_out_shape = op.compute_shape(in_shapes);
...@@ -97,10 +102,10 @@ struct rocblas_gemm ...@@ -97,10 +102,10 @@ struct rocblas_gemm
to_string(cmat_shape.type()) + to_string(cmat_shape.type()) +
", it must be: " + to_string(op_out_shape.type())); ", it must be: " + to_string(op_out_shape.type()));
} }
return op_out_shape; return transpose_batch(op_out_shape, trans_batch);
} }
return op.compute_shape(in_shapes); return transpose_batch(op.compute_shape(in_shapes), trans_batch);
} }
argument argument
......
...@@ -37,6 +37,8 @@ namespace gpu { ...@@ -37,6 +37,8 @@ namespace gpu {
struct context; struct context;
std::string hip_error(int error);
argument allocate_gpu(const shape& s, bool host = false); argument allocate_gpu(const shape& s, bool host = false);
argument register_on_gpu(const argument& arg); argument register_on_gpu(const argument& arg);
......
...@@ -50,17 +50,22 @@ struct kernel ...@@ -50,17 +50,22 @@ struct kernel
void launch(hipStream_t stream, void launch(hipStream_t stream,
std::size_t global, std::size_t global,
std::size_t local, std::size_t local,
const std::vector<kernel_argument>& args) const; const std::vector<kernel_argument>& args,
hipEvent_t start = nullptr,
hipEvent_t stop = nullptr) const;
void launch(hipStream_t stream, void launch(hipStream_t stream,
std::size_t global, std::size_t global,
std::size_t local, std::size_t local,
std::vector<void*> args) const; std::vector<void*> args,
hipEvent_t start = nullptr,
hipEvent_t stop = nullptr) const;
auto launch(hipStream_t stream, std::size_t global, std::size_t local) const template <class... Ts>
auto launch(hipStream_t stream, std::size_t global, std::size_t local, Ts... zs) const
{ {
return [=](auto&&... xs) { return [=](auto&&... xs) {
launch(stream, global, local, std::vector<kernel_argument>{xs...}); launch(stream, global, local, std::vector<kernel_argument>{xs...}, zs...);
}; };
} }
......
...@@ -80,7 +80,9 @@ void launch_kernel(hipFunction_t fun, ...@@ -80,7 +80,9 @@ void launch_kernel(hipFunction_t fun,
std::size_t global, std::size_t global,
std::size_t local, std::size_t local,
void* kernargs, void* kernargs,
std::size_t size) std::size_t size,
hipEvent_t start,
hipEvent_t stop)
{ {
assert(global > 0); assert(global > 0);
assert(local > 0); assert(local > 0);
...@@ -97,34 +99,55 @@ void launch_kernel(hipFunction_t fun, ...@@ -97,34 +99,55 @@ void launch_kernel(hipFunction_t fun,
#endif #endif
}; };
auto status = hipExtModuleLaunchKernel( auto status = hipExtModuleLaunchKernel(fun,
fun, global, 1, 1, local, 1, 1, 0, stream, nullptr, reinterpret_cast<void**>(&config)); global,
1,
1,
local,
1,
1,
0,
stream,
nullptr,
reinterpret_cast<void**>(&config),
start,
stop);
if(status != hipSuccess) if(status != hipSuccess)
MIGRAPHX_THROW("Failed to launch kernel: " + hip_error(status)); MIGRAPHX_THROW("Failed to launch kernel: " + hip_error(status));
if(stop != nullptr)
{
status = hipEventSynchronize(stop);
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to sync event: " + hip_error(status));
}
} }
void kernel::launch(hipStream_t stream, void kernel::launch(hipStream_t stream,
std::size_t global, std::size_t global,
std::size_t local, std::size_t local,
std::vector<void*> args) const std::vector<void*> args,
hipEvent_t start,
hipEvent_t stop) const
{ {
assert(impl != nullptr); assert(impl != nullptr);
void* kernargs = args.data(); void* kernargs = args.data();
std::size_t size = args.size() * sizeof(void*); std::size_t size = args.size() * sizeof(void*);
launch_kernel(impl->fun, stream, global, local, kernargs, size); launch_kernel(impl->fun, stream, global, local, kernargs, size, start, stop);
} }
void kernel::launch(hipStream_t stream, void kernel::launch(hipStream_t stream,
std::size_t global, std::size_t global,
std::size_t local, std::size_t local,
const std::vector<kernel_argument>& args) const const std::vector<kernel_argument>& args,
hipEvent_t start,
hipEvent_t stop) const
{ {
assert(impl != nullptr); assert(impl != nullptr);
std::vector<char> kernargs = pack_args(args); std::vector<char> kernargs = pack_args(args);
std::size_t size = kernargs.size(); std::size_t size = kernargs.size();
launch_kernel(impl->fun, stream, global, local, kernargs.data(), size); launch_kernel(impl->fun, stream, global, local, kernargs.data(), size, start, stop);
} }
} // namespace gpu } // namespace gpu
......
...@@ -163,7 +163,7 @@ constexpr Iterator1 search(Iterator1 first, Iterator1 last, Iterator2 s_first, I ...@@ -163,7 +163,7 @@ constexpr Iterator1 search(Iterator1 first, Iterator1 last, Iterator2 s_first, I
{ {
return last; return last;
} }
if(!(*it == *s_it)) if(not(*it == *s_it))
{ {
break; break;
} }
......
...@@ -153,7 +153,7 @@ struct array ...@@ -153,7 +153,7 @@ struct array
return true; return true;
} }
friend constexpr bool operator!=(const array& x, const array& y) { return !(x == y); } friend constexpr bool operator!=(const array& x, const array& y) { return not(x == y); }
// This uses the product order rather than lexical order // This uses the product order rather than lexical order
friend constexpr bool operator<(const array& x, const array& y) friend constexpr bool operator<(const array& x, const array& y)
{ {
......
...@@ -73,10 +73,10 @@ MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(!=) ...@@ -73,10 +73,10 @@ MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(!=)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(&) MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(&)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(^) MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(^)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(|) MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(|)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(&&) MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(and)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(||) MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(or)
MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(!) MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(not )
MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(~) MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(~)
MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(+) MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(+)
MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(-) MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(-)
......
...@@ -340,7 +340,7 @@ struct miopen_apply ...@@ -340,7 +340,7 @@ struct miopen_apply
catch(migraphx::exception&) catch(migraphx::exception&)
{ {
// In case no solver supports the default format, retry using the other format. // In case no solver supports the default format, retry using the other format.
compile_quant_conv_with_format(!int8_x4_format); compile_quant_conv_with_format(not int8_x4_format);
} }
auto args = ins->inputs(); auto args = ins->inputs();
......
...@@ -78,7 +78,7 @@ struct mlir_handle ...@@ -78,7 +78,7 @@ struct mlir_handle
friend bool operator==(ptr x, ptr y) { return x.get_value() == y.get_value(); } friend bool operator==(ptr x, ptr y) { return x.get_value() == y.get_value(); }
friend bool operator!=(ptr x, ptr y) { return !(x == y); } friend bool operator!=(ptr x, ptr y) { return not(x == y); }
T obj{}; T obj{};
}; };
...@@ -503,7 +503,7 @@ struct mlir_program ...@@ -503,7 +503,7 @@ struct mlir_program
pp = pp =
problem_params{ins->get_operator(), to_shapes(ins->inputs()), ins->get_shape()}; problem_params{ins->get_operator(), to_shapes(ins->inputs()), ins->get_shape()};
std::string tuned = get_tune_params(); std::string tuned = get_tune_params();
if(!tuned.empty()) if(not tuned.empty())
ops.add_attributes({{"perf_config", tuned}}); ops.add_attributes({{"perf_config", tuned}});
// check if HW supports xdlops // check if HW supports xdlops
if(contains(get_xdlops_archs(), target_name)) if(contains(get_xdlops_archs(), target_name))
......
...@@ -154,7 +154,7 @@ void pack_int8_args::apply(module& m) const ...@@ -154,7 +154,7 @@ void pack_int8_args::apply(module& m) const
bool transa = inputs[0]->get_shape().transposed(); bool transa = inputs[0]->get_shape().transposed();
bool transb = inputs[1]->get_shape().transposed(); bool transb = inputs[1]->get_shape().transposed();
if(!transb) if(not transb)
{ {
auto packed_b = m.insert_instruction( auto packed_b = m.insert_instruction(
ins, make_op("hip::allocate", {{"shape", to_value(inputs[1]->get_shape())}})); ins, make_op("hip::allocate", {{"shape", to_value(inputs[1]->get_shape())}}));
......
...@@ -42,6 +42,7 @@ ...@@ -42,6 +42,7 @@
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/replace_allocate.hpp> #include <migraphx/replace_allocate.hpp>
#include <migraphx/rewrite_batchnorm.hpp> #include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/rewrite_gelu.hpp>
#include <migraphx/rewrite_pooling.hpp> #include <migraphx/rewrite_pooling.hpp>
#include <migraphx/rewrite_quantization.hpp> #include <migraphx/rewrite_quantization.hpp>
#include <migraphx/rewrite_rnn.hpp> #include <migraphx/rewrite_rnn.hpp>
...@@ -116,6 +117,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -116,6 +117,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
inline_module{}, inline_module{},
rewrite_pooling{}, rewrite_pooling{},
dead_code_elimination{}, dead_code_elimination{},
rewrite_gelu{},
dead_code_elimination{},
eliminate_common_subexpression{}, eliminate_common_subexpression{},
dead_code_elimination{}, dead_code_elimination{},
simplify_algebra{}, simplify_algebra{},
...@@ -134,8 +137,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -134,8 +137,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
lowering{&ctx, options.offload_copy}, lowering{&ctx, options.offload_copy},
eliminate_contiguous{"gpu::contiguous"}, eliminate_contiguous{"gpu::contiguous"},
dead_code_elimination{}, dead_code_elimination{},
replace_allocate{gpu_allocation_model{}, options.offload_copy},
dead_code_elimination{},
eliminate_concat{concat_gpu_optimization{}}, eliminate_concat{concat_gpu_optimization{}},
dead_code_elimination{}, dead_code_elimination{},
pack_int8_args{}, pack_int8_args{},
...@@ -144,6 +145,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -144,6 +145,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{}, dead_code_elimination{},
fuse_ops{&ctx, options.fast_math}, fuse_ops{&ctx, options.fast_math},
dead_code_elimination{}, dead_code_elimination{},
replace_allocate{gpu_allocation_model{}, options.offload_copy},
dead_code_elimination{},
compile_ops{&ctx}, compile_ops{&ctx},
dead_code_elimination{}, dead_code_elimination{},
write_literals{&ctx}, write_literals{&ctx},
......
...@@ -100,7 +100,7 @@ struct parse_conv : op_parser<parse_conv> ...@@ -100,7 +100,7 @@ struct parse_conv : op_parser<parse_conv>
{ {
MIGRAPHX_THROW("padding should have 4 values"); MIGRAPHX_THROW("padding should have 4 values");
} }
if(padding[0] != padding[2] || padding[1] != padding[3]) if(padding[0] != padding[2] or padding[1] != padding[3])
{ {
MIGRAPHX_THROW("migraphx does not support asymetric padding"); MIGRAPHX_THROW("migraphx does not support asymetric padding");
} }
......
...@@ -90,7 +90,7 @@ struct parse_depthwiseconv : op_parser<parse_depthwiseconv> ...@@ -90,7 +90,7 @@ struct parse_depthwiseconv : op_parser<parse_depthwiseconv>
calculate_padding(0, pads, input_dims[2], op.stride[0], op.dilation[0], weight_h); calculate_padding(0, pads, input_dims[2], op.stride[0], op.dilation[0], weight_h);
calculate_padding(1, pads, input_dims[3], op.stride[1], op.dilation[1], weight_w); calculate_padding(1, pads, input_dims[3], op.stride[1], op.dilation[1], weight_w);
if(pads[0] != pads[2] || pads[1] != pads[3]) if(pads[0] != pads[2] or pads[1] != pads[3])
{ {
std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]}; std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
l0 = info.add_instruction(migraphx::make_op("pad", {{"pads", padding}}), l0); l0 = info.add_instruction(migraphx::make_op("pad", {{"pads", padding}}), l0);
......
...@@ -42,7 +42,7 @@ struct parse_pooling : op_parser<parse_pooling> ...@@ -42,7 +42,7 @@ struct parse_pooling : op_parser<parse_pooling>
tf_parser::node_info info, tf_parser::node_info info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
if(!starts_with(opd.tf_name, "Max") && !starts_with(opd.tf_name, "Av")) if(not starts_with(opd.tf_name, "Max") and not starts_with(opd.tf_name, "Av"))
{ {
MIGRAPHX_THROW("tf pooling mode must be Max or Average"); MIGRAPHX_THROW("tf pooling mode must be Max or Average");
} }
......
...@@ -371,7 +371,7 @@ void tf_parser::parse_node(const std::string& name) ...@@ -371,7 +371,7 @@ void tf_parser::parse_node(const std::string& name)
{ {
result = ops[node.op()](*this, {get_attributes(node), node.op(), mm}, args); result = ops[node.op()](*this, {get_attributes(node), node.op(), mm}, args);
} }
assert(!result.empty()); assert(not result.empty());
// First output has no ":" delimiter // First output has no ":" delimiter
instructions[name] = result.front(); instructions[name] = result.front();
for(size_t i = 1; i < result.size(); i++) for(size_t i = 1; i < result.size(); i++)
...@@ -458,7 +458,7 @@ literal tf_parser::parse_tensor(const tensorflow::TensorProto& t) const ...@@ -458,7 +458,7 @@ literal tf_parser::parse_tensor(const tensorflow::TensorProto& t) const
{ {
std::vector<size_t> dims = parse_dims(t.tensor_shape()); std::vector<size_t> dims = parse_dims(t.tensor_shape());
size_t shape_size = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>()); size_t shape_size = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>());
if(!t.tensor_content().empty()) // has raw data if(not t.tensor_content().empty()) // has raw data
{ {
const std::string& s = t.tensor_content(); const std::string& s = t.tensor_content();
switch(t.dtype()) switch(t.dtype())
......
...@@ -78,7 +78,7 @@ void tmp_dir::execute(const std::string& exe, const std::string& args) const ...@@ -78,7 +78,7 @@ void tmp_dir::execute(const std::string& exe, const std::string& args) const
tmp_dir::~tmp_dir() tmp_dir::~tmp_dir()
{ {
if(!enabled(MIGRAPHX_DEBUG_SAVE_TEMP_DIR{})) if(not enabled(MIGRAPHX_DEBUG_SAVE_TEMP_DIR{}))
{ {
fs::remove_all(this->path); fs::remove_all(this->path);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment