Commit f7838bc8 authored by turneram's avatar turneram
Browse files

Merge remote-tracking branch 'origin/develop' into ck-elementwise

parents fea58a7b d78bcdfb
...@@ -47,7 +47,7 @@ constexpr Iterator upper_bound(Iterator first, Iterator last, const T& value) ...@@ -47,7 +47,7 @@ constexpr Iterator upper_bound(Iterator first, Iterator last, const T& value)
it = first; it = first;
step = count / 2; step = count / 2;
std::advance(it, step); std::advance(it, step);
if(!(value < *it)) if(not(value < *it))
{ {
first = ++it; first = ++it;
count -= step + 1; count -= step + 1;
......
...@@ -38,8 +38,11 @@ struct compile_op : action<compile_op> ...@@ -38,8 +38,11 @@ struct compile_op : action<compile_op>
context ctx; context ctx;
auto inputs = p.parse_shapes(v.at("inputs")); auto inputs = p.parse_shapes(v.at("inputs"));
auto op = gpu::compile_op(v.at("name").to<std::string>(), ctx, inputs, v); auto op = gpu::compile_op(v.at("name").to<std::string>(), ctx, inputs, v);
double t = time_op(ctx, op, inputs, p.get(v, "iterations", 100)); auto [host_time, device_time] = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
std::cout << op << ": " << t << "ms" << std::endl; std::cout << op << ": " << host_time << "ms";
if(device_time > 0)
std::cout << ", " << device_time << "ms";
std::cout << std::endl;
} }
}; };
......
...@@ -33,7 +33,8 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -33,7 +33,8 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace driver { namespace driver {
double time_op(context& ctx, operation op, const std::vector<shape>& inputs, int n = 100); std::pair<double, double>
time_op(context& ictx, operation op, const std::vector<shape>& inputs, int n = 100);
} // namespace driver } // namespace driver
} // namespace gpu } // namespace gpu
......
...@@ -42,22 +42,31 @@ std::vector<argument> generate_arguments(const std::vector<shape>& shapes, unsig ...@@ -42,22 +42,31 @@ std::vector<argument> generate_arguments(const std::vector<shape>& shapes, unsig
} }
using milliseconds = std::chrono::duration<double, std::milli>; using milliseconds = std::chrono::duration<double, std::milli>;
double time_op(context& ctx, operation op, const std::vector<shape>& inputs, int n) std::pair<double, double>
time_op(context& ictx, operation op, const std::vector<shape>& inputs, int n)
{ {
// TODO: Use std::ref // TODO: Use std::ref
migraphx::context gctx = ctx; migraphx::context ctx = ictx;
auto output = op.compute_shape(inputs); auto& gctx = any_cast<migraphx::gpu::context>(ctx);
op.finalize(gctx, output, inputs); auto output = op.compute_shape(inputs);
op.finalize(ctx, output, inputs);
auto args = generate_arguments(inputs); auto args = generate_arguments(inputs);
auto run = [&] { auto run = [&] {
op.compute(gctx, output, args); op.compute(ctx, output, args);
gctx.finish(); ctx.finish();
}; };
gctx.enable_perf_measurement();
run(); run();
auto r = range(n); double host_time = 0.0;
double t = std::accumulate( double device_time = 0.0;
r.begin(), r.end(), double{0.0}, [&](auto x, auto) { return x + time<milliseconds>(run); }); for(auto i : range(n))
return t / n; {
(void)i;
host_time += time<milliseconds>(run);
device_time += gctx.get_elapsed_ms();
}
return std::make_pair(host_time / n, device_time / n);
} }
} // namespace driver } // namespace driver
......
...@@ -43,8 +43,8 @@ struct run_op : action<run_op> ...@@ -43,8 +43,8 @@ struct run_op : action<run_op>
auto op = make_op(name); auto op = make_op(name);
if(v.contains("fields")) if(v.contains("fields"))
op.from_value(v.at("fields")); op.from_value(v.at("fields"));
double t = time_op(ctx, op, inputs, p.get(v, "iterations", 100)); auto [host_time, device_time] = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
std::cout << op << ": " << t << "ms" << std::endl; std::cout << op << ": " << host_time << "ms" << std::endl;
} }
}; };
......
...@@ -26,7 +26,6 @@ ...@@ -26,7 +26,6 @@
#include <migraphx/gpu/fuse_ops.hpp> #include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/matcher.hpp> #include <migraphx/matcher.hpp>
#include <migraphx/gpu/miopen.hpp> #include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/clip.hpp>
#include <migraphx/gpu/convolution.hpp> #include <migraphx/gpu/convolution.hpp>
#include <migraphx/gpu/device_name.hpp> #include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/oper.hpp> #include <migraphx/gpu/oper.hpp>
...@@ -48,8 +47,8 @@ ...@@ -48,8 +47,8 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/register_op.hpp> #include <migraphx/register_op.hpp>
#include <migraphx/array.hpp> #include <migraphx/array.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/op/clip.hpp>
#include <cmath> #include <cmath>
#include <set> #include <set>
...@@ -279,6 +278,11 @@ MIGRAPHX_REGISTER_OP(hip_layernorm) ...@@ -279,6 +278,11 @@ MIGRAPHX_REGISTER_OP(hip_layernorm)
struct hip_triadd_layernorm : ternary_device<hip_triadd_layernorm, &device::triadd_layernorm> struct hip_triadd_layernorm : ternary_device<hip_triadd_layernorm, &device::triadd_layernorm>
{ {
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(4).standard();
return inputs[0];
}
// Empty finalize to skip dimension reduction // Empty finalize to skip dimension reduction
void finalize(context&, const shape&, const std::vector<shape>&) {} void finalize(context&, const shape&, const std::vector<shape>&) {}
}; };
...@@ -943,28 +947,70 @@ struct find_gemm_add ...@@ -943,28 +947,70 @@ struct find_gemm_add
} }
}; };
auto pointwise_name(const std::string& s)
{
return precompile_name("pointwise")(match::make_basic_pred_matcher([=](auto ins) {
module_ref pm = ins->module_inputs().front();
auto n = std::count_if(pm->begin(), pm->end(), [&](auto& i) { return i.name() == s; });
if(n != 1)
return false;
return std::all_of(pm->begin(), pm->end(), [&](auto& i) {
return starts_with(i.name(), "@") or i.name() == s;
});
}));
}
struct find_gemm_pointwise struct find_gemm_pointwise
{ {
auto matcher() const auto matcher() const
{ {
return pointwise_name("add")( return precompile_name("pointwise")(
match::nargs(3), match::nargs(3),
match::all_of[match::inputs()](match::standard_shape()), match::either_arg(0, 1)(
match::either_arg(0, 1)(match::used_once().bind("c"), match::any_of(match::standard_shape(), match::is_constant()).bind("c"),
match::name("gpu::gemm")(match::nargs(3)).bind("gemm"))); match::name("gpu::gemm")(match::nargs(3), match::used_once()).bind("gemm")));
}
// TODO: Move to matcher.hpp
static auto match_param(const std::string& name)
{
return match::make_basic_pred_matcher([=](auto ins) {
if(ins->name() != "@param")
return false;
auto p = any_cast<builtin::param>(ins->get_operator());
return p.parameter == name;
});
}
template <class M>
static auto match_mul_const(M m, const std::string& var)
{
return match::name("mul")(match::either_arg(0, 1)(match::name("@literal").bind(var), m))
.bind(var + "_mul");
}
static auto match_add(const std::string& input, const std::string& output)
{
auto param = match::name("@param");
auto add = match::name("add")(match::args(param, param));
auto inner_mul = match::any_of(match_mul_const(match_param(input), "alpha"),
match_mul_const(match_param(output), "beta"));
auto mul_add = match::name("add")(match::either_arg(0, 1)(inner_mul, param));
auto add_mul = match_mul_const(add, "gamma");
return match::name("@return")(match::args(match::any_of(add, mul_add, add_mul)));
}
static float get_float(instruction_ref ins) { return ins->get_literal().at<float>(); }
template <class Gemm>
static bool update_gemm(Gemm& gemm, module_ref pm, unsigned input)
{
auto names = pm->get_parameter_names();
if(names.size() != 2)
return false;
std::sort(names.begin(), names.end());
unsigned output = input == 0 ? 1 : 0;
auto mr = match::match_instruction(
*pm, std::prev(pm->end()), match_add(names[input], names[output]));
if(mr.result == pm->end())
return false;
if(contains(mr.instructions, "alpha_mul"))
gemm.alpha *= get_float(mr.instructions["alpha"]);
else if(contains(mr.instructions, "beta_mul"))
gemm.beta *= get_float(mr.instructions["beta"]);
else if(contains(mr.instructions, "gamma_mul"))
{
gemm.alpha *= get_float(mr.instructions["gamma"]);
gemm.beta *= get_float(mr.instructions["gamma"]);
}
return true;
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
...@@ -978,6 +1024,19 @@ struct find_gemm_pointwise ...@@ -978,6 +1024,19 @@ struct find_gemm_pointwise
// Already fused gemm // Already fused gemm
if(not float_equal(gemm.beta, 0)) if(not float_equal(gemm.beta, 0))
return; return;
gemm.beta = 1;
if(not update_gemm(
gemm, ins->module_inputs().front(), ins->inputs().front() == gemm_ins ? 0 : 1))
return;
// const-fold input if not standard shape since rocblas can't handle it
if(not c_ins->get_shape().standard())
{
auto c = make_op("contiguous");
auto l = c.compute(c.compute_shape({c_ins->get_shape()}), {c_ins->eval()});
c_ins = m.add_literal(l.get_shape(), l.data());
}
auto inputs = gemm_ins->inputs(); auto inputs = gemm_ins->inputs();
inputs.pop_back(); inputs.pop_back();
...@@ -985,11 +1044,68 @@ struct find_gemm_pointwise ...@@ -985,11 +1044,68 @@ struct find_gemm_pointwise
inputs.push_back(c_ins); inputs.push_back(c_ins);
inputs.push_back(ins->inputs().back()); inputs.push_back(ins->inputs().back());
gemm.beta = 1;
m.replace_instruction(ins, gemm, inputs); m.replace_instruction(ins, gemm, inputs);
} }
}; };
struct find_contiguous_tranpose_gemm
{
auto matcher() const
{
return match::name("gpu::contiguous")(match::arg(0)(
match::name("transpose")(
match::arg(0)(match::name("gpu::gemm")(match::used_once()).bind("gemm")))
.bind("transpose")));
}
template <class Vector>
static bool is_swapped(const Vector& perm, std::size_t i, std::size_t j)
{
if(i >= perm.size() or j >= perm.size())
return false;
auto perm2 = perm;
std::iota(perm2.begin(), perm2.end(), 0);
std::swap(perm2[i], perm2[j]);
return perm2 == perm;
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto gemm = r.instructions["gemm"];
auto alloc = gemm->inputs().back();
auto transpose = r.instructions["transpose"];
auto perm = transpose->get_operator().to_value()["permutation"].to_vector<int64_t>();
auto iperm = invert_permutation(perm);
if(perm.size() < 3)
return;
if(not is_swapped(perm, perm.size() - 3, perm.size() - 2))
return;
auto lens = gemm->get_shape().lens();
if(lens.size() > 3 and
not std::all_of(lens.begin(), lens.end() - 3, [](auto i) { return i == 1; }))
return;
auto gemmv = gemm->get_operator().to_value();
gemmv["trans_batch"] = 1;
auto s = shape{alloc->get_shape().type(), reorder_dims(alloc->get_shape().lens(), iperm)};
auto new_alloc = m.insert_instruction(gemm, make_op("allocate", {{"shape", to_value(s)}}));
auto alloc_transpose =
m.insert_instruction(gemm, make_op("transpose", {{"permutation", perm}}), new_alloc);
auto inputs = gemm->inputs();
inputs.back() = alloc_transpose;
auto new_gemm = m.insert_instruction(gemm, make_op("gpu::gemm", gemmv), inputs);
auto gemm_transpoe = m.insert_instruction(gemm, transpose->get_operator(), new_gemm);
m.replace_instruction(ins, gemm_transpoe);
}
};
struct find_commutative_broadcast struct find_commutative_broadcast
{ {
auto matcher() const auto matcher() const
...@@ -1091,6 +1207,7 @@ void fuse_ops::apply(module& m) const ...@@ -1091,6 +1207,7 @@ void fuse_ops::apply(module& m) const
find_gemm_add{}, find_gemm_add{},
find_layernorm_pointwise{}, find_layernorm_pointwise{},
find_gemm_pointwise{}, find_gemm_pointwise{},
find_contiguous_tranpose_gemm{},
find_commutative_broadcast{}); find_commutative_broadcast{});
match::find_matches(m, find_contiguous{}); match::find_matches(m, find_contiguous{});
} }
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include <rocblas.h> #include <rocblas.h>
#include <migraphx/gpu/gemm_impl.hpp> #include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/reduce_dims.hpp> #include <migraphx/reduce_dims.hpp>
#include <migraphx/permutation.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -67,6 +68,19 @@ void blas_shape(const shape& s) ...@@ -67,6 +68,19 @@ void blas_shape(const shape& s)
MIGRAPHX_THROW("GPU_GEMM: Batch dimension is not collapsible"); MIGRAPHX_THROW("GPU_GEMM: Batch dimension is not collapsible");
} }
shape transpose_batch(const shape& s, unsigned trans_batch)
{
if(trans_batch == 0)
return s;
if(s.lens().size() < 3)
return s;
auto batch = s.lens().size() - 3;
std::vector<int64_t> perm(s.lens().size());
std::iota(perm.begin(), perm.end(), 0);
std::swap(perm[batch], perm[batch + trans_batch]);
return shape::from_permutation(s.type(), s.lens(), perm);
}
template <class R, class... Ts, class... Us> template <class R, class... Ts, class... Us>
R rocblas_invoke(R (*f)(Ts...), Us... xs) R rocblas_invoke(R (*f)(Ts...), Us... xs)
{ {
...@@ -97,6 +111,12 @@ void gemm_impl(context& ctx, ...@@ -97,6 +111,12 @@ void gemm_impl(context& ctx,
bool int8_x4_format, bool int8_x4_format,
bool compute_fp32) bool compute_fp32)
{ {
const bool is_3inputs = (args.size() == 4);
if(not is_3inputs)
{
beta = 0;
}
bool transa = is_transposed(args[0].get_shape()); bool transa = is_transposed(args[0].get_shape());
bool transb = is_transposed(args[1].get_shape()); bool transb = is_transposed(args[1].get_shape());
auto n_dim = output_shape.lens().size(); auto n_dim = output_shape.lens().size();
...@@ -105,12 +125,8 @@ void gemm_impl(context& ctx, ...@@ -105,12 +125,8 @@ void gemm_impl(context& ctx,
rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0]; rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0]; rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
rocblas_int ldc = args[2].get_shape().strides()[dim_0]; rocblas_int ldc = args[2].get_shape().strides()[dim_0];
rocblas_int ldd = is_3inputs ? args[3].get_shape().strides()[dim_0] : ldc;
bool is_3inputs = (args.size() == 4);
if(!is_3inputs)
{
beta = 0;
}
rocblas_datatype arg_type = get_type(args[0].get_shape().type()); rocblas_datatype arg_type = get_type(args[0].get_shape().type());
auto output_type = arg_type; auto output_type = arg_type;
if(output_type == rocblas_datatype_i8_r) if(output_type == rocblas_datatype_i8_r)
...@@ -186,7 +202,7 @@ void gemm_impl(context& ctx, ...@@ -186,7 +202,7 @@ void gemm_impl(context& ctx,
ldc, ldc,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]), is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
output_type, output_type,
ldc, ldd,
compute_type, compute_type,
rocblas_gemm_algo_standard, rocblas_gemm_algo_standard,
0, 0,
...@@ -197,6 +213,7 @@ void gemm_impl(context& ctx, ...@@ -197,6 +213,7 @@ void gemm_impl(context& ctx,
auto a_stride = get_batch_stride(args[0]); auto a_stride = get_batch_stride(args[0]);
auto b_stride = get_batch_stride(args[1]); auto b_stride = get_batch_stride(args[1]);
auto c_stride = get_batch_stride(args[2]); auto c_stride = get_batch_stride(args[2]);
auto d_stride = is_3inputs ? get_batch_stride(args[3]) : c_stride;
rocblas_invoke(&rocblas_gemm_strided_batched_ex, rocblas_invoke(&rocblas_gemm_strided_batched_ex,
ctx.get_stream().get_rocblas(), ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none, transb ? rocblas_operation_transpose : rocblas_operation_none,
...@@ -220,8 +237,8 @@ void gemm_impl(context& ctx, ...@@ -220,8 +237,8 @@ void gemm_impl(context& ctx,
c_stride, c_stride,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]), is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
output_type, output_type,
ldc, ldd,
c_stride, d_stride,
num_matrices, num_matrices,
compute_type, compute_type,
rocblas_gemm_algo_standard, rocblas_gemm_algo_standard,
......
...@@ -244,6 +244,15 @@ struct context ...@@ -244,6 +244,15 @@ struct context
return hip_event_ptr{event}; return hip_event_ptr{event};
} }
static hip_event_ptr create_event_for_timing()
{
hipEvent_t event;
auto status = hipEventCreate(&event);
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to create event");
return hip_event_ptr{event};
}
value to_value() const value to_value() const
{ {
value result; value result;
...@@ -267,10 +276,49 @@ struct context ...@@ -267,10 +276,49 @@ struct context
any_ptr get_queue() { return get_stream().get(); } any_ptr get_queue() { return get_stream().get(); }
void enable_perf_measurement(bool b = true)
{
if(b)
{
start_event = create_event_for_timing();
stop_event = create_event_for_timing();
get_stream().record(start_event.get());
get_stream().record(stop_event.get());
}
else
{
start_event = nullptr;
stop_event = nullptr;
}
measure_perf = b;
}
std::pair<hipEvent_t, hipEvent_t> get_perf_events() const
{
if(measure_perf)
return std::make_pair(start_event.get(), stop_event.get());
return std::make_pair(nullptr, nullptr);
}
float get_elapsed_ms() const
{
float result = 0;
if(start_event != nullptr and stop_event != nullptr)
{
auto status = hipEventElapsedTime(&result, start_event.get(), stop_event.get());
if(status != hipSuccess)
MIGRAPHX_THROW("Failed hipEventElapsedTime: " + hip_error(status));
}
return result;
}
private: private:
// TODO: Make this a vector to support multiple devices // TODO: Make this a vector to support multiple devices
std::shared_ptr<hip_device> current_device; std::shared_ptr<hip_device> current_device;
std::vector<shared<hip_event_ptr>> events; std::vector<shared<hip_event_ptr>> events;
bool measure_perf = false;
shared<hip_event_ptr> start_event = nullptr;
shared<hip_event_ptr> stop_event = nullptr;
}; };
inline void migraphx_to_value(value& v, const context& ctx) { v = ctx.to_value(); } inline void migraphx_to_value(value& v, const context& ctx) { v = ctx.to_value(); }
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/reflect.hpp> #include <migraphx/reflect.hpp>
#include <migraphx/op/gather.hpp> #include <migraphx/op/gather.hpp>
#include <migraphx/gpu/miopen.hpp> #include <migraphx/gpu/context.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
...@@ -42,15 +42,17 @@ namespace gpu { ...@@ -42,15 +42,17 @@ namespace gpu {
struct context; struct context;
void blas_shape(const shape& s); void blas_shape(const shape& s);
shape transpose_batch(const shape& s, unsigned trans_batch);
template <class Op> template <class Op>
struct rocblas_gemm struct rocblas_gemm
{ {
Op op; Op op;
float alpha = 1; float alpha = 1;
float beta = 0; float beta = 0;
bool int8_x4_format = true; bool int8_x4_format = true;
bool compute_fp32 = false; bool compute_fp32 = false;
unsigned trans_batch = 0;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -58,7 +60,9 @@ struct rocblas_gemm ...@@ -58,7 +60,9 @@ struct rocblas_gemm
return pack_join(migraphx::reflect(self.op, f), return pack_join(migraphx::reflect(self.op, f),
pack(f(self.alpha, "alpha"), pack(f(self.alpha, "alpha"),
f(self.beta, "beta"), f(self.beta, "beta"),
f(self.int8_x4_format, "int8_x4_format"))); f(self.int8_x4_format, "int8_x4_format"),
f(self.compute_fp32, "compute_fp32"),
f(self.trans_batch, "trans_batch")));
} }
std::string name() const std::string name() const
...@@ -74,13 +78,14 @@ struct rocblas_gemm ...@@ -74,13 +78,14 @@ struct rocblas_gemm
{ {
std::vector<shape> in_shapes(inputs); std::vector<shape> in_shapes(inputs);
in_shapes.pop_back(); in_shapes.pop_back();
check_shapes{in_shapes, *this}.not_broadcasted(); check_shapes{in_shapes, *this}.has(2, 3);
blas_shape(inputs[0]); blas_shape(inputs[0]);
blas_shape(inputs[1]); blas_shape(inputs[1]);
// if gemm and add are fused // if gemm and add are fused
if(in_shapes.size() > 2) if(in_shapes.size() > 2)
{ {
auto cmat_shape = in_shapes.back(); auto cmat_shape = in_shapes.back();
check_shapes{{cmat_shape}, *this}.not_transposed().not_broadcasted();
in_shapes.pop_back(); in_shapes.pop_back();
blas_shape(cmat_shape); blas_shape(cmat_shape);
auto op_out_shape = op.compute_shape(in_shapes); auto op_out_shape = op.compute_shape(in_shapes);
...@@ -97,10 +102,10 @@ struct rocblas_gemm ...@@ -97,10 +102,10 @@ struct rocblas_gemm
to_string(cmat_shape.type()) + to_string(cmat_shape.type()) +
", it must be: " + to_string(op_out_shape.type())); ", it must be: " + to_string(op_out_shape.type()));
} }
return op_out_shape; return transpose_batch(op_out_shape, trans_batch);
} }
return op.compute_shape(in_shapes); return transpose_batch(op.compute_shape(in_shapes), trans_batch);
} }
argument argument
......
...@@ -37,6 +37,8 @@ namespace gpu { ...@@ -37,6 +37,8 @@ namespace gpu {
struct context; struct context;
std::string hip_error(int error);
argument allocate_gpu(const shape& s, bool host = false); argument allocate_gpu(const shape& s, bool host = false);
argument register_on_gpu(const argument& arg); argument register_on_gpu(const argument& arg);
......
...@@ -25,7 +25,6 @@ ...@@ -25,7 +25,6 @@
#define MIGRAPHX_GUARD_RTGLIB_INT8_CONV_PACK_HPP #define MIGRAPHX_GUARD_RTGLIB_INT8_CONV_PACK_HPP
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <utility> #include <utility>
......
...@@ -25,7 +25,6 @@ ...@@ -25,7 +25,6 @@
#define MIGRAPHX_GUARD_RTGLIB_INT8_GEMM_PACK_HPP #define MIGRAPHX_GUARD_RTGLIB_INT8_GEMM_PACK_HPP
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <utility> #include <utility>
......
...@@ -50,17 +50,22 @@ struct kernel ...@@ -50,17 +50,22 @@ struct kernel
void launch(hipStream_t stream, void launch(hipStream_t stream,
std::size_t global, std::size_t global,
std::size_t local, std::size_t local,
const std::vector<kernel_argument>& args) const; const std::vector<kernel_argument>& args,
hipEvent_t start = nullptr,
hipEvent_t stop = nullptr) const;
void launch(hipStream_t stream, void launch(hipStream_t stream,
std::size_t global, std::size_t global,
std::size_t local, std::size_t local,
std::vector<void*> args) const; std::vector<void*> args,
hipEvent_t start = nullptr,
hipEvent_t stop = nullptr) const;
auto launch(hipStream_t stream, std::size_t global, std::size_t local) const template <class... Ts>
auto launch(hipStream_t stream, std::size_t global, std::size_t local, Ts... zs) const
{ {
return [=](auto&&... xs) { return [=](auto&&... xs) {
launch(stream, global, local, std::vector<kernel_argument>{xs...}); launch(stream, global, local, std::vector<kernel_argument>{xs...}, zs...);
}; };
} }
......
...@@ -24,22 +24,10 @@ ...@@ -24,22 +24,10 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_LOGSOFTMAX_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_LOGSOFTMAX_HPP
#define MIGRAPHX_GUARD_RTGLIB_LOGSOFTMAX_HPP #define MIGRAPHX_GUARD_RTGLIB_LOGSOFTMAX_HPP
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/logsoftmax.hpp> #include <migraphx/op/logsoftmax.hpp>
#include <migraphx/generate.hpp> #include <migraphx/shape.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/reflect.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
#include <migraphx/gpu/device/add.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <utility>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/reflect.hpp> #include <migraphx/reflect.hpp>
#include <migraphx/gpu/miopen.hpp> #include <migraphx/gpu/context.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#define MIGRAPHX_GUARD_GPU_PREFUSE_OPS_HPP #define MIGRAPHX_GUARD_GPU_PREFUSE_OPS_HPP
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/gpu/context.hpp> #include <string>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/reflect.hpp> #include <migraphx/reflect.hpp>
#include <migraphx/op/reverse.hpp> #include <migraphx/op/reverse.hpp>
#include <migraphx/gpu/miopen.hpp> #include <migraphx/gpu/context.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
...@@ -24,22 +24,10 @@ ...@@ -24,22 +24,10 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_SOFTMAX_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_SOFTMAX_HPP
#define MIGRAPHX_GUARD_RTGLIB_SOFTMAX_HPP #define MIGRAPHX_GUARD_RTGLIB_SOFTMAX_HPP
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/softmax.hpp> #include <migraphx/op/softmax.hpp>
#include <migraphx/generate.hpp> #include <migraphx/shape.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/reflect.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
#include <migraphx/gpu/device/add.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <utility>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
...@@ -25,8 +25,6 @@ ...@@ -25,8 +25,6 @@
#define MIGRAPHX_GUARD_RTGLIB_GPU_SYNC_DEVICE_HPP #define MIGRAPHX_GUARD_RTGLIB_GPU_SYNC_DEVICE_HPP
#include <string> #include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment