Commit f7838bc8 authored by turneram

Merge remote-tracking branch 'origin/develop' into ck-elementwise

parents fea58a7b d78bcdfb
......@@ -47,7 +47,7 @@ constexpr Iterator upper_bound(Iterator first, Iterator last, const T& value)
it = first;
step = count / 2;
std::advance(it, step);
if(!(value < *it))
if(not(value < *it))
{
first = ++it;
count -= step + 1;
......
......@@ -38,8 +38,11 @@ struct compile_op : action<compile_op>
context ctx;
auto inputs = p.parse_shapes(v.at("inputs"));
auto op = gpu::compile_op(v.at("name").to<std::string>(), ctx, inputs, v);
double t = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
std::cout << op << ": " << t << "ms" << std::endl;
auto [host_time, device_time] = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
std::cout << op << ": " << host_time << "ms";
if(device_time > 0)
std::cout << ", " << device_time << "ms";
std::cout << std::endl;
}
};
......
......@@ -33,7 +33,8 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace driver {
double time_op(context& ctx, operation op, const std::vector<shape>& inputs, int n = 100);
std::pair<double, double>
time_op(context& ictx, operation op, const std::vector<shape>& inputs, int n = 100);
} // namespace driver
} // namespace gpu
......
......@@ -42,22 +42,31 @@ std::vector<argument> generate_arguments(const std::vector<shape>& shapes, unsig
}
using milliseconds = std::chrono::duration<double, std::milli>;
double time_op(context& ctx, operation op, const std::vector<shape>& inputs, int n)
std::pair<double, double>
time_op(context& ictx, operation op, const std::vector<shape>& inputs, int n)
{
// TODO: Use std::ref
migraphx::context gctx = ctx;
auto output = op.compute_shape(inputs);
op.finalize(gctx, output, inputs);
migraphx::context ctx = ictx;
auto& gctx = any_cast<migraphx::gpu::context>(ctx);
auto output = op.compute_shape(inputs);
op.finalize(ctx, output, inputs);
auto args = generate_arguments(inputs);
auto run = [&] {
op.compute(gctx, output, args);
gctx.finish();
op.compute(ctx, output, args);
ctx.finish();
};
gctx.enable_perf_measurement();
run();
auto r = range(n);
double t = std::accumulate(
r.begin(), r.end(), double{0.0}, [&](auto x, auto) { return x + time<milliseconds>(run); });
return t / n;
double host_time = 0.0;
double device_time = 0.0;
for(auto i : range(n))
{
(void)i;
host_time += time<milliseconds>(run);
device_time += gctx.get_elapsed_ms();
}
return std::make_pair(host_time / n, device_time / n);
}
} // namespace driver
......
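The host-side number produced by time_op above is the wall-clock time of the run lambda, averaged over n iterations; the device-side number comes from the hipEvent pair the context records. A minimal sketch of such a time<Duration> helper, assuming only <chrono> rather than MIGraphX's own timing utility, could look like this:

#include <chrono>

// Minimal sketch: invoke a callable once and return the elapsed wall-clock time
// in the requested duration type (e.g. std::chrono::duration<double, std::milli>).
template <class Duration, class F>
double time(F f)
{
    auto start = std::chrono::steady_clock::now();
    f();
    auto stop = std::chrono::steady_clock::now();
    return std::chrono::duration_cast<Duration>(stop - start).count();
}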
......@@ -43,8 +43,8 @@ struct run_op : action<run_op>
auto op = make_op(name);
if(v.contains("fields"))
op.from_value(v.at("fields"));
double t = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
std::cout << op << ": " << t << "ms" << std::endl;
auto [host_time, device_time] = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
std::cout << op << ": " << host_time << "ms" << std::endl;
}
};
......
......@@ -26,7 +26,6 @@
#include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/clip.hpp>
#include <migraphx/gpu/convolution.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/oper.hpp>
......@@ -48,8 +47,8 @@
#include <migraphx/instruction.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/array.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/clip.hpp>
#include <cmath>
#include <set>
......@@ -279,6 +278,11 @@ MIGRAPHX_REGISTER_OP(hip_layernorm)
struct hip_triadd_layernorm : ternary_device<hip_triadd_layernorm, &device::triadd_layernorm>
{
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(4).standard();
return inputs[0];
}
// Empty finalize to skip dimension reduction
void finalize(context&, const shape&, const std::vector<shape>&) {}
};
......@@ -943,28 +947,70 @@ struct find_gemm_add
}
};
auto pointwise_name(const std::string& s)
{
return precompile_name("pointwise")(match::make_basic_pred_matcher([=](auto ins) {
module_ref pm = ins->module_inputs().front();
auto n = std::count_if(pm->begin(), pm->end(), [&](auto& i) { return i.name() == s; });
if(n != 1)
return false;
return std::all_of(pm->begin(), pm->end(), [&](auto& i) {
return starts_with(i.name(), "@") or i.name() == s;
});
}));
}
struct find_gemm_pointwise
{
auto matcher() const
{
return pointwise_name("add")(
return precompile_name("pointwise")(
match::nargs(3),
match::all_of[match::inputs()](match::standard_shape()),
match::either_arg(0, 1)(match::used_once().bind("c"),
match::name("gpu::gemm")(match::nargs(3)).bind("gemm")));
match::either_arg(0, 1)(
match::any_of(match::standard_shape(), match::is_constant()).bind("c"),
match::name("gpu::gemm")(match::nargs(3), match::used_once()).bind("gemm")));
}
// TODO: Move to matcher.hpp
static auto match_param(const std::string& name)
{
return match::make_basic_pred_matcher([=](auto ins) {
if(ins->name() != "@param")
return false;
auto p = any_cast<builtin::param>(ins->get_operator());
return p.parameter == name;
});
}
template <class M>
static auto match_mul_const(M m, const std::string& var)
{
return match::name("mul")(match::either_arg(0, 1)(match::name("@literal").bind(var), m))
.bind(var + "_mul");
}
static auto match_add(const std::string& input, const std::string& output)
{
auto param = match::name("@param");
auto add = match::name("add")(match::args(param, param));
auto inner_mul = match::any_of(match_mul_const(match_param(input), "alpha"),
match_mul_const(match_param(output), "beta"));
auto mul_add = match::name("add")(match::either_arg(0, 1)(inner_mul, param));
auto add_mul = match_mul_const(add, "gamma");
return match::name("@return")(match::args(match::any_of(add, mul_add, add_mul)));
}
static float get_float(instruction_ref ins) { return ins->get_literal().at<float>(); }
template <class Gemm>
static bool update_gemm(Gemm& gemm, module_ref pm, unsigned input)
{
auto names = pm->get_parameter_names();
if(names.size() != 2)
return false;
std::sort(names.begin(), names.end());
unsigned output = input == 0 ? 1 : 0;
auto mr = match::match_instruction(
*pm, std::prev(pm->end()), match_add(names[input], names[output]));
if(mr.result == pm->end())
return false;
if(contains(mr.instructions, "alpha_mul"))
gemm.alpha *= get_float(mr.instructions["alpha"]);
else if(contains(mr.instructions, "beta_mul"))
gemm.beta *= get_float(mr.instructions["beta"]);
else if(contains(mr.instructions, "gamma_mul"))
{
gemm.alpha *= get_float(mr.instructions["gamma"]);
gemm.beta *= get_float(mr.instructions["gamma"]);
}
return true;
}
void apply(module& m, const match::matcher_result& r) const
......@@ -978,6 +1024,19 @@ struct find_gemm_pointwise
// Already fused gemm
if(not float_equal(gemm.beta, 0))
return;
gemm.beta = 1;
if(not update_gemm(
gemm, ins->module_inputs().front(), ins->inputs().front() == gemm_ins ? 0 : 1))
return;
// const-fold input if not standard shape since rocblas can't handle it
if(not c_ins->get_shape().standard())
{
auto c = make_op("contiguous");
auto l = c.compute(c.compute_shape({c_ins->get_shape()}), {c_ins->eval()});
c_ins = m.add_literal(l.get_shape(), l.data());
}
auto inputs = gemm_ins->inputs();
inputs.pop_back();
......@@ -985,11 +1044,68 @@ struct find_gemm_pointwise
inputs.push_back(c_ins);
inputs.push_back(ins->inputs().back());
gemm.beta = 1;
m.replace_instruction(ins, gemm, inputs);
}
};
struct find_contiguous_tranpose_gemm
{
auto matcher() const
{
return match::name("gpu::contiguous")(match::arg(0)(
match::name("transpose")(
match::arg(0)(match::name("gpu::gemm")(match::used_once()).bind("gemm")))
.bind("transpose")));
}
template <class Vector>
static bool is_swapped(const Vector& perm, std::size_t i, std::size_t j)
{
if(i >= perm.size() or j >= perm.size())
return false;
auto perm2 = perm;
std::iota(perm2.begin(), perm2.end(), 0);
std::swap(perm2[i], perm2[j]);
return perm2 == perm;
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto gemm = r.instructions["gemm"];
auto alloc = gemm->inputs().back();
auto transpose = r.instructions["transpose"];
auto perm = transpose->get_operator().to_value()["permutation"].to_vector<int64_t>();
auto iperm = invert_permutation(perm);
if(perm.size() < 3)
return;
if(not is_swapped(perm, perm.size() - 3, perm.size() - 2))
return;
auto lens = gemm->get_shape().lens();
if(lens.size() > 3 and
not std::all_of(lens.begin(), lens.end() - 3, [](auto i) { return i == 1; }))
return;
auto gemmv = gemm->get_operator().to_value();
gemmv["trans_batch"] = 1;
auto s = shape{alloc->get_shape().type(), reorder_dims(alloc->get_shape().lens(), iperm)};
auto new_alloc = m.insert_instruction(gemm, make_op("allocate", {{"shape", to_value(s)}}));
auto alloc_transpose =
m.insert_instruction(gemm, make_op("transpose", {{"permutation", perm}}), new_alloc);
auto inputs = gemm->inputs();
inputs.back() = alloc_transpose;
auto new_gemm = m.insert_instruction(gemm, make_op("gpu::gemm", gemmv), inputs);
auto gemm_transpoe = m.insert_instruction(gemm, transpose->get_operator(), new_gemm);
m.replace_instruction(ins, gemm_transpoe);
}
};
struct find_commutative_broadcast
{
auto matcher() const
......@@ -1091,6 +1207,7 @@ void fuse_ops::apply(module& m) const
find_gemm_add{},
find_layernorm_pointwise{},
find_gemm_pointwise{},
find_contiguous_tranpose_gemm{},
find_commutative_broadcast{});
match::find_matches(m, find_contiguous{});
}
......
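find_gemm_pointwise above folds a trailing pointwise module into the GEMM's alpha/beta, which works because rocBLAS already computes D = alpha*(A*B) + beta*C. A small standalone check of the three recognized epilogues (plain C++, no MIGraphX types; the element values are made up):

#include <cassert>

int main()
{
    float ab = 3.0f;   // stands in for one element of A*B
    float c  = 5.0f;   // the matching element of C
    float alpha = 1.0f, beta = 1.0f; // update_gemm starts from beta = 1

    // mul_add with the constant on the GEMM input:  k*(A*B) + C  ->  alpha *= k
    float k = 0.5f;
    assert(k * ab + c == (alpha * k) * ab + beta * c);

    // mul_add with the constant on the other input: A*B + k*C    ->  beta *= k
    assert(ab + k * c == alpha * ab + (beta * k) * c);

    // add_mul: gamma*(A*B + C)  ->  alpha *= gamma, beta *= gamma
    float gamma = 2.0f;
    assert(gamma * (ab + c) == (alpha * gamma) * ab + (beta * gamma) * c);
    return 0;
}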
......@@ -24,6 +24,7 @@
#include <rocblas.h>
#include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/permutation.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -67,6 +68,19 @@ void blas_shape(const shape& s)
MIGRAPHX_THROW("GPU_GEMM: Batch dimension is not collapsible");
}
shape transpose_batch(const shape& s, unsigned trans_batch)
{
if(trans_batch == 0)
return s;
if(s.lens().size() < 3)
return s;
auto batch = s.lens().size() - 3;
std::vector<int64_t> perm(s.lens().size());
std::iota(perm.begin(), perm.end(), 0);
std::swap(perm[batch], perm[batch + trans_batch]);
return shape::from_permutation(s.type(), s.lens(), perm);
}
template <class R, class... Ts, class... Us>
R rocblas_invoke(R (*f)(Ts...), Us... xs)
{
......@@ -97,6 +111,12 @@ void gemm_impl(context& ctx,
bool int8_x4_format,
bool compute_fp32)
{
const bool is_3inputs = (args.size() == 4);
if(not is_3inputs)
{
beta = 0;
}
bool transa = is_transposed(args[0].get_shape());
bool transb = is_transposed(args[1].get_shape());
auto n_dim = output_shape.lens().size();
......@@ -105,12 +125,8 @@ void gemm_impl(context& ctx,
rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
rocblas_int ldc = args[2].get_shape().strides()[dim_0];
rocblas_int ldd = is_3inputs ? args[3].get_shape().strides()[dim_0] : ldc;
bool is_3inputs = (args.size() == 4);
if(!is_3inputs)
{
beta = 0;
}
rocblas_datatype arg_type = get_type(args[0].get_shape().type());
auto output_type = arg_type;
if(output_type == rocblas_datatype_i8_r)
......@@ -186,7 +202,7 @@ void gemm_impl(context& ctx,
ldc,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
output_type,
ldc,
ldd,
compute_type,
rocblas_gemm_algo_standard,
0,
......@@ -197,6 +213,7 @@ void gemm_impl(context& ctx,
auto a_stride = get_batch_stride(args[0]);
auto b_stride = get_batch_stride(args[1]);
auto c_stride = get_batch_stride(args[2]);
auto d_stride = is_3inputs ? get_batch_stride(args[3]) : c_stride;
rocblas_invoke(&rocblas_gemm_strided_batched_ex,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
......@@ -220,8 +237,8 @@ void gemm_impl(context& ctx,
c_stride,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
output_type,
ldc,
c_stride,
ldd,
d_stride,
num_matrices,
compute_type,
rocblas_gemm_algo_standard,
......
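transpose_batch above only changes how the GEMM output is laid out; the lens stay the same. A minimal sketch of the permutation it builds for a rank-4 output with trans_batch = 1 (the lens values here are made up):

#include <cstddef>
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

int main()
{
    std::vector<std::size_t> lens{2, 8, 64, 32}; // hypothetical output dims
    unsigned trans_batch = 1;

    auto batch = lens.size() - 3;                      // index 1 for a rank-4 shape
    std::vector<int64_t> perm(lens.size());
    std::iota(perm.begin(), perm.end(), 0);            // {0, 1, 2, 3}
    std::swap(perm[batch], perm[batch + trans_batch]); // {0, 2, 1, 3}

    // shape::from_permutation(type, lens, perm) would then yield a shape that keeps
    // lens {2, 8, 64, 32} but orders the strides as if dims 1 and 2 were stored
    // swapped, which is the layout find_contiguous_tranpose_gemm arranges for.
    return 0;
}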
......@@ -244,6 +244,15 @@ struct context
return hip_event_ptr{event};
}
static hip_event_ptr create_event_for_timing()
{
hipEvent_t event;
auto status = hipEventCreate(&event);
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to create event");
return hip_event_ptr{event};
}
value to_value() const
{
value result;
......@@ -267,10 +276,49 @@ struct context
any_ptr get_queue() { return get_stream().get(); }
void enable_perf_measurement(bool b = true)
{
if(b)
{
start_event = create_event_for_timing();
stop_event = create_event_for_timing();
get_stream().record(start_event.get());
get_stream().record(stop_event.get());
}
else
{
start_event = nullptr;
stop_event = nullptr;
}
measure_perf = b;
}
std::pair<hipEvent_t, hipEvent_t> get_perf_events() const
{
if(measure_perf)
return std::make_pair(start_event.get(), stop_event.get());
return std::make_pair(nullptr, nullptr);
}
float get_elapsed_ms() const
{
float result = 0;
if(start_event != nullptr and stop_event != nullptr)
{
auto status = hipEventElapsedTime(&result, start_event.get(), stop_event.get());
if(status != hipSuccess)
MIGRAPHX_THROW("Failed hipEventElapsedTime: " + hip_error(status));
}
return result;
}
private:
// TODO: Make this a vector to support multiple devices
std::shared_ptr<hip_device> current_device;
std::vector<shared<hip_event_ptr>> events;
bool measure_perf = false;
shared<hip_event_ptr> start_event = nullptr;
shared<hip_event_ptr> stop_event = nullptr;
};
inline void migraphx_to_value(value& v, const context& ctx) { v = ctx.to_value(); }
......
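enable_perf_measurement and get_elapsed_ms wrap the standard HIP event pattern; a standalone sketch of that pattern, assuming the measured work is enqueued on the given stream between the two records, is:

#include <hip/hip_runtime.h>

float elapsed_ms_on_stream(hipStream_t stream)
{
    hipEvent_t start;
    hipEvent_t stop;
    hipEventCreate(&start);
    hipEventCreate(&stop);

    hipEventRecord(start, stream);
    // ... enqueue the kernels to be measured on `stream` here ...
    hipEventRecord(stop, stream);

    hipEventSynchronize(stop); // wait until the stop event has completed
    float ms = 0;
    hipEventElapsedTime(&ms, start, stop);

    hipEventDestroy(start);
    hipEventDestroy(stop);
    return ms;
}

The context keeps the two events alive in shared<hip_event_ptr> members instead of destroying them, so get_elapsed_ms can be queried after every run.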
......@@ -27,7 +27,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/op/gather.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -42,15 +42,17 @@ namespace gpu {
struct context;
void blas_shape(const shape& s);
shape transpose_batch(const shape& s, unsigned trans_batch);
template <class Op>
struct rocblas_gemm
{
Op op;
float alpha = 1;
float beta = 0;
bool int8_x4_format = true;
bool compute_fp32 = false;
float alpha = 1;
float beta = 0;
bool int8_x4_format = true;
bool compute_fp32 = false;
unsigned trans_batch = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -58,7 +60,9 @@ struct rocblas_gemm
return pack_join(migraphx::reflect(self.op, f),
pack(f(self.alpha, "alpha"),
f(self.beta, "beta"),
f(self.int8_x4_format, "int8_x4_format")));
f(self.int8_x4_format, "int8_x4_format"),
f(self.compute_fp32, "compute_fp32"),
f(self.trans_batch, "trans_batch")));
}
std::string name() const
......@@ -74,13 +78,14 @@ struct rocblas_gemm
{
std::vector<shape> in_shapes(inputs);
in_shapes.pop_back();
check_shapes{in_shapes, *this}.not_broadcasted();
check_shapes{in_shapes, *this}.has(2, 3);
blas_shape(inputs[0]);
blas_shape(inputs[1]);
// if gemm and add are fused
if(in_shapes.size() > 2)
{
auto cmat_shape = in_shapes.back();
check_shapes{{cmat_shape}, *this}.not_transposed().not_broadcasted();
in_shapes.pop_back();
blas_shape(cmat_shape);
auto op_out_shape = op.compute_shape(in_shapes);
......@@ -97,10 +102,10 @@ struct rocblas_gemm
to_string(cmat_shape.type()) +
", it must be: " + to_string(op_out_shape.type()));
}
return op_out_shape;
return transpose_batch(op_out_shape, trans_batch);
}
return op.compute_shape(in_shapes);
return transpose_batch(op.compute_shape(in_shapes), trans_batch);
}
argument
......
......@@ -37,6 +37,8 @@ namespace gpu {
struct context;
std::string hip_error(int error);
argument allocate_gpu(const shape& s, bool host = false);
argument register_on_gpu(const argument& arg);
......
......@@ -25,7 +25,6 @@
#define MIGRAPHX_GUARD_RTGLIB_INT8_CONV_PACK_HPP
#include <migraphx/argument.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/config.hpp>
#include <utility>
......
......@@ -25,7 +25,6 @@
#define MIGRAPHX_GUARD_RTGLIB_INT8_GEMM_PACK_HPP
#include <migraphx/argument.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/config.hpp>
#include <utility>
......
......@@ -50,17 +50,22 @@ struct kernel
void launch(hipStream_t stream,
std::size_t global,
std::size_t local,
const std::vector<kernel_argument>& args) const;
const std::vector<kernel_argument>& args,
hipEvent_t start = nullptr,
hipEvent_t stop = nullptr) const;
void launch(hipStream_t stream,
std::size_t global,
std::size_t local,
std::vector<void*> args) const;
std::vector<void*> args,
hipEvent_t start = nullptr,
hipEvent_t stop = nullptr) const;
auto launch(hipStream_t stream, std::size_t global, std::size_t local) const
template <class... Ts>
auto launch(hipStream_t stream, std::size_t global, std::size_t local, Ts... zs) const
{
return [=](auto&&... xs) {
launch(stream, global, local, std::vector<kernel_argument>{xs...});
launch(stream, global, local, std::vector<kernel_argument>{xs...}, zs...);
};
}
......
......@@ -24,22 +24,10 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_LOGSOFTMAX_HPP
#define MIGRAPHX_GUARD_RTGLIB_LOGSOFTMAX_HPP
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/logsoftmax.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
#include <migraphx/gpu/device/add.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/gpu/context.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -26,7 +26,7 @@
#include <migraphx/shape.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -25,7 +25,7 @@
#define MIGRAPHX_GUARD_GPU_PREFUSE_OPS_HPP
#include <migraphx/config.hpp>
#include <migraphx/gpu/context.hpp>
#include <string>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -27,7 +27,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/op/reverse.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -24,22 +24,10 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_SOFTMAX_HPP
#define MIGRAPHX_GUARD_RTGLIB_SOFTMAX_HPP
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/softmax.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
#include <migraphx/gpu/device/add.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/gpu/context.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -25,8 +25,6 @@
#define MIGRAPHX_GUARD_RTGLIB_GPU_SYNC_DEVICE_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
......