Commit 11e155c2 authored by Paul

Merge

parents 8a9c5bce aa7ff911
#include <migraphx/gpu/compiler.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
auto& compiler_map()
{
static std::unordered_map<std::string, compiler_compile> m; // NOLINT
return m;
}
auto& compiler_op_map()
{
static std::unordered_map<std::string, compiler_compile_op> m; // NOLINT
return m;
}
void register_compiler(const std::string& name, compiler_compile c, compiler_compile_op cop)
{
compiler_map()[name] = std::move(c);
compiler_op_map()[name] = std::move(cop);
}
bool has_compiler_for(const std::string& name) { return compiler_map().count(name) > 0; }
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op)
{
return compiler_map().at(op.name())(ctx, ins, op);
}
operation
compile_op(const std::string& name, context& ctx, const std::vector<shape>& inputs, const value& v)
{
return compiler_op_map().at(name)(ctx, inputs, v);
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -57,9 +57,10 @@ inline auto mi_nglobal(const hip_shape<N>& s, index_int nlocal)
{
    assert(s.standard);
    assert(s.elements() > 0);
    index_int n = s.elements();
    index_int groups = (n + nlocal - 1) / nlocal;
-   index_int nglobal = std::min<index_int>(128, groups) * nlocal;
+   // max possible number of blocks is set to 1B (1,073,741,824)
+   index_int nglobal = std::min<index_int>(1073741824, groups) * nlocal;
    assert(groups > 0);
    assert(nglobal > 0);
...
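The capped grid-size arithmetic above is easy to sanity-check on the host; a minimal standalone sketch (plain C++, mirroring the same ceil-divide and cap):

#include <algorithm>
#include <cstdint>
#include <iostream>

using index_int = std::uint32_t;

// Mirrors the launch computation above: ceil-divide n elements into groups
// of nlocal threads, then cap the number of groups (blocks) at 2^30.
index_int compute_nglobal(index_int n, index_int nlocal)
{
    index_int groups = (n + nlocal - 1) / nlocal;
    return std::min<index_int>(1073741824, groups) * nlocal;
}

int main()
{
    // 1,000,000 elements at 256 threads per block -> 3907 blocks
    std::cout << compute_nglobal(1000000, 256) << "\n"; // prints 1000192
}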
@@ -44,12 +44,19 @@ __device__ void block_scan(index idx, Op op, T init, ForStride fs, Input input,
template <index_int N, class Op, class T, class Input, class Output>
__device__ void block_scan(index idx, Op op, T init, index_int n, Input input, Output output)
{
-   block_scan<N>(idx,
-                 op,
-                 init,
-                 [&](auto f) -> decltype(f(index_int{})) { return idx.local_stride(n, f); },
-                 input,
-                 output);
+   block_scan<N>(
+       idx,
+       op,
+       init,
+       [&](auto f) -> decltype(f(index_int{})) { return idx.local_stride(n, f); },
+       input,
+       output);
+}
+
+template <class F>
+constexpr auto reverse_scan(index_int n, F f)
+{
+   return [=](auto i, auto&&... xs) { return f(n - i - 1, xs...); };
}
} // namespace device
...
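reverse_scan only remaps indices, which is how the prefix-scan kernels below reuse a forward block_scan for reverse scans. A serial host-side sketch (the loop stands in for the device block_scan) showing that wrapping both the read and the write callbacks yields a suffix sum:

#include <cstddef>
#include <iostream>
#include <vector>

template <class F>
constexpr auto reverse_scan(std::size_t n, F f)
{
    return [=](auto i, auto&&... xs) { return f(n - i - 1, xs...); };
}

int main()
{
    std::vector<int> in = {1, 2, 3, 4};
    std::vector<int> out(in.size());
    std::size_t n = in.size();

    auto read  = reverse_scan(n, [&](auto j) { return in[j]; });
    auto write = reverse_scan(n, [&](auto j, auto x) { out[j] = x; });

    int acc = 0;
    for(std::size_t i = 0; i < n; i++) // serial stand-in for block_scan
    {
        acc += read(i);
        write(i, acc);
    }
    // out == {10, 9, 7, 4}: the suffix sums of {1, 2, 3, 4}
    for(auto x : out)
        std::cout << x << " ";
}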
@@ -14,28 +14,23 @@ constexpr void visit_tensor_size(index_int n, F f)
{
    switch(n)
    {
-   case 1:
-   {
+   case 1: {
        f(std::integral_constant<index_int, 1>{});
        break;
    }
-   case 2:
-   {
+   case 2: {
        f(std::integral_constant<index_int, 2>{});
        break;
    }
-   case 3:
-   {
+   case 3: {
        f(std::integral_constant<index_int, 3>{});
        break;
    }
-   case 4:
-   {
+   case 4: {
        f(std::integral_constant<index_int, 4>{});
        break;
    }
-   case 5:
-   {
+   case 5: {
        f(std::integral_constant<index_int, 5>{});
        break;
    }
@@ -181,7 +176,13 @@ template <index_int N, class T, class... Ts>
auto hip_vec_visit_all(T&& x, Ts&&... xs)
{
    return [&](auto f) {
-       hip_visit_all_impl(get_shape(x),
+       auto sx   = get_shape(x);
+       auto lens = sx.lens();
+       assert(lens.back() % N == 0);
+       assert(sx.strides().back() == 1);
+       lens.back() /= N;
+       shape vec_sx{sx.type(), lens};
+       hip_visit_all_impl(vec_sx,
                           make_hip_convert([](auto* p) { return as_vec<N>(device_cast(p)); }),
                           f,
                           x,
...
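The shape fix-up above makes the visited shape reflect the vectorized view: a contiguous innermost dimension divisible by N shrinks by N once the data is seen as N-wide vectors. A standalone sketch of the same idea, with vec standing in for the HIP vector type behind as_vec (all names illustrative):

#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

template <class T, std::size_t N>
using vec = std::array<T, N>; // stand-in for a HIP vector type

// View a contiguous buffer whose innermost dimension is divisible by N
// as a buffer of vec<T, N>; the innermost length shrinks by a factor of N.
template <std::size_t N, class T>
vec<T, N>* as_vec(T* p)
{
    return reinterpret_cast<vec<T, N>*>(p);
}

int main()
{
    std::vector<float> data(2 * 8); // logical shape {2, 8}, strides {8, 1}
    assert(8 % 4 == 0);
    auto* v = as_vec<4>(data.data()); // now viewed as shape {2, 2} of float4
    v[3] = {1, 2, 3, 4};              // writes data[12..15]
}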
@@ -25,22 +25,23 @@ argument nonzero(hipStream_t stream, const argument& result, const argument& arg
        // fill all output to 0 first
        idx.local_stride(out_elem_num, [&](auto j) { ptr[j] = 0; });
-       block_scan<block_size>(idx,
-                              sum{},
-                              0,
-                              elem_num,
-                              [&](auto j) { return (float_equal(in_ptr[j], 0)) ? 0 : 1; },
-                              [&](auto j, auto x) {
-                                  auto out_loc = x - 1;
-                                  if(float_equal(in_ptr[j], 0))
-                                      return;
-                                  auto index = si.multi(j);
-                                  for(size_t k = 0; k < index.size(); ++k)
-                                  {
-                                      ptr[k * elem_num + out_loc] = index[k];
-                                  }
-                              });
+       block_scan<block_size>(
+           idx,
+           sum{},
+           0,
+           elem_num,
+           [&](auto j) { return (float_equal(in_ptr[j], 0)) ? 0 : 1; },
+           [&](auto j, auto x) {
+               auto out_loc = x - 1;
+               if(float_equal(in_ptr[j], 0))
+                   return;
+               auto index = si.multi(j);
+               for(size_t k = 0; k < index.size(); ++k)
+               {
+                   ptr[k * elem_num + out_loc] = index[k];
+               }
+           });
    });
});
...
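The kernel above turns an inclusive prefix sum of a 0/1 indicator into output slots: at each nonzero element, the running sum is that element's 1-based rank among the nonzeros. A serial host sketch of the same bookkeeping, assuming a 1-D input for brevity:

#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    std::vector<float> in = {0, 5, 0, 7};
    std::vector<int> out_index(in.size(), 0);

    int acc = 0;
    for(std::size_t j = 0; j < in.size(); j++)
    {
        acc += (in[j] == 0) ? 0 : 1; // indicator
        if(in[j] == 0)
            continue;
        auto out_loc = acc - 1;      // rank among the nonzeros
        out_index[out_loc] = j;      // write the coordinate
    }
    // out_index begins with {1, 3}: the positions of the nonzero values
    for(int k : out_index)
        std::cout << k << " ";
}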
#include <migraphx/gpu/device/prefix_scan_sum.hpp>
#include <migraphx/gpu/device/scan.hpp>
#include <migraphx/gpu/device/reduce_ops.hpp>
+#include <migraphx/gpu/device/reduce.hpp>
#include <migraphx/gpu/device/types.hpp>

namespace migraphx {
@@ -8,29 +9,108 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

-void prefix_scan_sum(hipStream_t stream, const argument& result, const argument& arg, int32_t axis)
+void prefix_scan_sum(hipStream_t stream,
+                     const argument& result,
+                     const argument& arg,
+                     int32_t axis,
+                     bool exclusive,
+                     bool reverse)
{
-   const index_int block_size = 256;
+   const index_int max_block_size = 256;
    const index_int n = arg.get_shape().lens()[axis];
    auto rlens = result.get_shape().lens();
    rlens[axis] = 1;
    hip_visit_all(result, arg, result.get_shape().with_lens(rlens))(
        [=](auto output, auto input, auto rshape) {
-           gs_launch(stream, rshape.elements() * block_size, block_size)(
-               [=](auto i, auto idx) __device__ {
-                   const auto ridx = rshape.multi(i / block_size);
-                   auto compute_idx = [&](auto j) {
-                       auto k = ridx;
-                       k[axis] = j;
-                       return k;
-                   };
-                   block_scan<block_size>(idx,
-                                          sum{},
-                                          0,
-                                          n,
-                                          [&](auto j) { return input[compute_idx(j)]; },
-                                          [&](auto j, auto x) { output[compute_idx(j)] = x; });
-               });
+           const index_int block_size = compute_block_size(rshape.elements(), max_block_size);
+           if(reverse and exclusive)
+           {
+               gs_launch(stream, rshape.elements() * block_size, block_size)(
+                   [=](auto i, auto idx) __device__ {
+                       const auto ridx = rshape.multi(i / block_size);
+                       auto compute_idx = [&](auto j) {
+                           auto k = ridx;
+                           k[axis] = j;
+                           return k;
+                       };
+                       block_scan<max_block_size>(
+                           idx,
+                           sum{},
+                           0,
+                           n,
+                           reverse_scan(n, [&](auto j) { return input[compute_idx(j)]; }),
+                           reverse_scan(n, [&](auto j, auto x) {
+                               if(j == n - 1)
+                                   output[compute_idx(j)] = 0;
+                               if(j > 0)
+                                   output[compute_idx(j - 1)] = x;
+                           }));
+                   });
+           }
+           else if(reverse)
+           {
+               gs_launch(stream, rshape.elements() * block_size, block_size)(
+                   [=](auto i, auto idx) __device__ {
+                       const auto ridx = rshape.multi(i / block_size);
+                       auto compute_idx = [&](auto j) {
+                           auto k = ridx;
+                           k[axis] = j;
+                           return k;
+                       };
+                       block_scan<max_block_size>(
+                           idx,
+                           sum{},
+                           0,
+                           n,
+                           reverse_scan(n, [&](auto j) { return input[compute_idx(j)]; }),
+                           reverse_scan(n, [&](auto j, auto x) { output[compute_idx(j)] = x; }));
+                   });
+           }
+           else if(exclusive)
+           {
+               gs_launch(stream, rshape.elements() * block_size, block_size)(
+                   [=](auto i, auto idx) __device__ {
+                       const auto ridx = rshape.multi(i / block_size);
+                       auto compute_idx = [&](auto j) {
+                           auto k = ridx;
+                           k[axis] = j;
+                           return k;
+                       };
+                       block_scan<max_block_size>(
+                           idx,
+                           sum{},
+                           0,
+                           n,
+                           [&](auto j) { return input[compute_idx(j)]; },
+                           [&](auto j, auto x) {
+                               auto k = j + 1;
+                               if(j == 0)
+                                   output[compute_idx(0)] = 0;
+                               if(k < n)
+                                   output[compute_idx(k)] = x;
+                           });
+                   });
+           }
+           else
+           {
+               gs_launch(stream, rshape.elements() * block_size, block_size)(
+                   [=](auto i, auto idx) __device__ {
+                       const auto ridx = rshape.multi(i / block_size);
+                       auto compute_idx = [&](auto j) {
+                           auto k = ridx;
+                           k[axis] = j;
+                           return k;
+                       };
+                       block_scan<max_block_size>(
+                           idx,
+                           sum{},
+                           0,
+                           n,
+                           [&](auto j) { return input[compute_idx(j)]; },
+                           [&](auto j, auto x) { output[compute_idx(j)] = x; });
+                   });
+           }
        });
}
...
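The four branches differ only in where each partial sum is written. A serial host-side sketch (not the kernel) printing what each flag combination produces for {1, 2, 3, 4}:

#include <cstddef>
#include <iostream>
#include <vector>

// inclusive         -> 1 3 6 10
// exclusive         -> 0 1 3 6
// reverse           -> 10 9 7 4
// reverse+exclusive -> 9 7 4 0
int main()
{
    std::vector<int> in = {1, 2, 3, 4};
    std::size_t n = in.size();
    std::vector<int> out(n);
    for(int reverse = 0; reverse < 2; reverse++)
    {
        for(int exclusive = 0; exclusive < 2; exclusive++)
        {
            int acc = 0;
            for(std::size_t i = 0; i < n; i++)
            {
                auto j = reverse ? n - i - 1 : i; // the reverse_scan index flip
                if(exclusive)
                    out[j] = acc;                 // value before adding in[j]
                acc += in[j];
                if(not exclusive)
                    out[j] = acc;                 // value after adding in[j]
            }
            for(auto x : out)
                std::cout << x << " ";
            std::cout << "\n";
        }
    }
}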
+file(GLOB GPU_DRIVER_SRCS ${CONFIGURE_DEPENDS} ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)
add_executable(gpu-driver
-   action.cpp
-   compile_pointwise.cpp
-   main.cpp
-   parser.cpp
-   perf.cpp
-   run_op.cpp
+   ${GPU_DRIVER_SRCS}
)
target_include_directories(gpu-driver PRIVATE include)
target_link_libraries(gpu-driver PRIVATE migraphx_gpu)
#include <migraphx/gpu/driver/action.hpp>
#include <migraphx/gpu/driver/perf.hpp>
-#include <migraphx/gpu/compile_pointwise.hpp>
+#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>

namespace migraphx {
@@ -8,13 +8,13 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace driver {

-struct compile_pointwise : action<compile_pointwise>
+struct compile_op : action<compile_op>
{
    static void apply(const parser& p, const value& v)
    {
        context ctx;
        auto inputs = p.parse_shapes(v.at("inputs"));
-       auto op = gpu::compile_pointwise(ctx, inputs, v.at("lambda").to<std::string>());
+       auto op = gpu::compile_op(v.at("name").to<std::string>(), ctx, inputs, v);
        double t = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
        std::cout << op << ": " << t << "ms" << std::endl;
    }
...
@@ -17,8 +17,10 @@ struct run_op : action<run_op>
        auto name = v.at("name").to<std::string>();
        if(not contains(name, "::"))
            name = "gpu::" + name;
        auto op = make_op(name);
-       double t = time_op(ctx, op, inputs);
+       if(v.contains("fields"))
+           op.from_value(v.at("fields"));
+       double t = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
        std::cout << op << ": " << t << "ms" << std::endl;
    }
};
...
@@ -11,11 +11,11 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

-void eliminate_workspace::apply(module& p) const
+void eliminate_workspace::apply(module& m) const
{
    std::size_t n = 0;
    std::vector<instruction_ref> allocs;
-   for(auto ins : iterator_for(p))
+   for(auto ins : iterator_for(m))
    {
        if(ins->outputs().size() != 1)
            continue;
@@ -30,11 +30,11 @@ void eliminate_workspace::apply(module& m) const
    }
    if(n > 0)
    {
-       auto ws = p.add_parameter("workspace", shape{shape::int8_type, {n}});
+       auto ws = m.add_parameter("workspace", shape{shape::int8_type, {n}});
        for(auto&& a : allocs)
        {
-           p.replace_instruction(a, ws);
-           p.remove_instruction(a);
+           m.replace_instruction(a, ws);
+           m.remove_instruction(a);
        }
    }
}
...
@@ -316,7 +316,7 @@ struct find_layernorm
{
    auto matcher() const { return match::layernorm(&gpu_name); }

-   void apply(module& p, match::matcher_result r) const
+   void apply(module& m, const match::matcher_result& r) const
    {
        auto ins = r.result;
        auto x_ins = r.instructions["x"];
@@ -331,7 +331,7 @@ struct find_layernorm
        if(relements > 1024 or (relements % 4 != 0 and relements > 256))
            return;

-       p.replace_instruction(ins, hip_layernorm{}, x_ins, args.back());
+       m.replace_instruction(ins, hip_layernorm{}, x_ins, args.back());
    }
};
@@ -343,11 +343,11 @@ struct find_triadd_layernorm
            match::used_once(), match::all_of[match::inputs()](match::standard_shape()))));
    }

-   void apply(module& p, const match::matcher_result& r) const
+   void apply(module& m, const match::matcher_result& r) const
    {
        auto ins = r.result;
        auto triadd = ins->inputs().front();
-       p.replace_instruction(ins, hip_triadd_layernorm{}, triadd->inputs());
+       m.replace_instruction(ins, hip_triadd_layernorm{}, triadd->inputs());
    }
};
@@ -355,13 +355,13 @@ struct find_gelu
{
    auto matcher() const { return match::gelu_erf(&gpu_name); }

-   void apply(module& p, match::matcher_result r) const
+   void apply(module& m, const match::matcher_result& r) const
    {
        auto ins = r.result;
        auto x_ins = r.instructions["x"];
        auto args = ins->inputs();

-       p.replace_instruction(ins, hip_gelu{}, x_ins, args.back());
+       m.replace_instruction(ins, hip_gelu{}, x_ins, args.back());
    }
};
@@ -372,7 +372,7 @@ struct find_add_gelu
        return match::name("gpu::gelu")(match::arg(0)(match::name("gpu::add").bind("add")));
    }

-   void apply(module& p, match::matcher_result r) const
+   void apply(module& m, const match::matcher_result& r) const
    {
        auto add_ins = r.instructions["add"];
        auto ins = r.result;
@@ -381,7 +381,7 @@ struct find_add_gelu
        move_broadcasted_back(args);
        args.back() = ins->inputs().back();

-       p.replace_instruction(ins, hip_add_gelu{}, args);
+       m.replace_instruction(ins, hip_add_gelu{}, args);
    }
};
@@ -391,16 +391,16 @@ struct find_gelu_new
    auto matcher() const { return match::gelu_tanh(&gpu_name); }

-   void apply(module& p, match::matcher_result r) const
+   void apply(module& m, const match::matcher_result& r) const
    {
        auto ins = r.result;
        auto x_ins = r.instructions["x"];
        auto args = ins->inputs();

        if(fast_math)
-           p.replace_instruction(ins, hip_gelu{}, x_ins, args.back());
+           m.replace_instruction(ins, hip_gelu{}, x_ins, args.back());
        else
-           p.replace_instruction(ins, hip_gelu_new{}, x_ins, args.back());
+           m.replace_instruction(ins, hip_gelu_new{}, x_ins, args.back());
    }
};
@@ -411,7 +411,7 @@ struct find_add_gelu_new
        return match::name("gpu::gelu_new")(match::arg(0)(match::name("gpu::add").bind("add")));
    }

-   void apply(module& p, match::matcher_result r) const
+   void apply(module& m, const match::matcher_result& r) const
    {
        auto add_ins = r.instructions["add"];
        auto ins = r.result;
@@ -420,7 +420,7 @@ struct find_add_gelu_new
        move_broadcasted_back(args);
        args.back() = ins->inputs().back();

-       p.replace_instruction(ins, hip_add_gelu_new{}, args);
+       m.replace_instruction(ins, hip_add_gelu_new{}, args);
    }
};
@@ -435,7 +435,7 @@ struct find_add_clip
                                  .bind("add")));
    }

-   void apply(module& p, match::matcher_result r) const
+   void apply(module& m, const match::matcher_result& r) const
    {
        auto add_ins = r.instructions["add"];
        auto ins = r.result;
@@ -448,9 +448,9 @@ struct find_add_clip
        add_args.pop_back();
        add_args.insert(add_args.end(), std::next(ins_args.begin()), ins_args.end());
        if(add_ins->name() == "gpu::add")
-           p.replace_instruction(ins, hip_add_clip{}, add_args);
+           m.replace_instruction(ins, hip_add_clip{}, add_args);
        else if(add_ins->name() == "gpu::triadd")
-           p.replace_instruction(ins, hip_triadd_clip{}, add_args);
+           m.replace_instruction(ins, hip_triadd_clip{}, add_args);
    }
};
@@ -470,7 +470,7 @@ struct find_add_unary
                                  .bind("add")));
    }

-   void apply(module& p, match::matcher_result r) const
+   void apply(module& m, const match::matcher_result& r) const
    {
        auto add_ins = r.instructions["add"];
        auto ins = r.result;
@@ -481,9 +481,9 @@ struct find_add_unary
        // Use the allocation from the relu operator
        args.back() = ins->inputs().back();
        if(add_ins->name() == "gpu::add")
-           p.replace_instruction(ins, binary_add_op, args);
+           m.replace_instruction(ins, binary_add_op, args);
        else if(add_ins->name() == "gpu::triadd")
-           p.replace_instruction(ins, ternary_add_op, args);
+           m.replace_instruction(ins, ternary_add_op, args);
    }
};
@@ -498,7 +498,7 @@ struct find_triadd
                                  .bind("input")));
    }

-   void apply(module& p, match::matcher_result r) const
+   void apply(module& m, const match::matcher_result& r) const
    {
        auto add_ins = r.instructions["add"];
        auto input_ins = r.instructions["input"];
@@ -513,7 +513,7 @@ struct find_triadd
        move_broadcasted_back(args);
        args.back() = ins->inputs().back();

-       p.replace_instruction(ins, hip_triadd{}, args);
+       m.replace_instruction(ins, hip_triadd{}, args);
    }
};
@@ -525,7 +525,7 @@ struct find_mul_add
            match::name("gpu::mul")(match::used_once()).bind("mul"), match::any().bind("b")));
    }

-   void apply(module& p, match::matcher_result r) const
+   void apply(module& m, const match::matcher_result& r) const
    {
        auto mul_ins = r.instructions["mul"];
        auto b_ins = r.instructions["b"];
@@ -538,7 +538,7 @@ struct find_mul_add
        args.insert(std::prev(args.end()), b_ins);
        args.back() = ins->inputs().back();

-       p.replace_instruction(ins, hip_mul_add{}, args);
+       m.replace_instruction(ins, hip_mul_add{}, args);
    }
};
@@ -550,7 +550,7 @@ struct find_mul_add_relu
            match::arg(0)(match::name("gpu::mul_add")(match::used_once()).bind("mul_add")));
    }

-   void apply(module& p, match::matcher_result r) const
+   void apply(module& m, const match::matcher_result& r) const
    {
        auto mul_add_ins = r.instructions["mul_add"];
        auto ins = r.result;
@@ -558,7 +558,7 @@ struct find_mul_add_relu
        // Use the allocation from the relu operator
        args.back() = ins->inputs().back();

-       p.replace_instruction(ins, hip_mul_add_relu{}, args);
+       m.replace_instruction(ins, hip_mul_add_relu{}, args);
    }
};
@@ -587,6 +587,11 @@ struct miopen_fusion
        return pack(f(self.ops, "ops"));
    }

+   std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
+   {
+       return shapes.size() - 1;
+   }
+
    value compile(context& ctx, const shape&, std::vector<shape> inputs)
    {
        // Compensate for allocation
@@ -676,7 +681,7 @@ struct miopen_fusion
struct miopen_conv_bias
{
    op::convolution op;
-   fusion f = {};
+   fusion fp = {};
    fusion::op_t conv = {};
    fusion::op_t bias = {};
@@ -700,19 +705,19 @@ struct miopen_conv_bias
        float beta = 0;
        miopenSetOpArgsConvForward(fargs.get(), conv, &alpha, &beta, args[1].implicit());
        miopenSetOpArgsBiasForward(fargs.get(), bias, &alpha, &beta, args[3].implicit());
-       return f.execute(ctx, fargs, args[0], args[4]);
+       return fp.execute(ctx, fargs, args[0], args[4]);
    }

    void finalize(context& ctx, const shape&, const std::vector<shape>& inputs)
    {
-       f = fusion(inputs[0]);
-       conv = f.create_conv(op, inputs[1]);
-       bias = f.create_bias(inputs[3]);
-       if(not f.compile(ctx))
+       fp = fusion(inputs[0]);
+       conv = fp.create_conv(op, inputs[1]);
+       bias = fp.create_bias(inputs[3]);
+       if(not fp.compile(ctx))
            MIGRAPHX_THROW("Failed to compile fusion plan");
    }

-   shape get_workspace(context& ctx) { return f.get_workspace(ctx); }
+   shape get_workspace(context& ctx) { return fp.get_workspace(ctx); }
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return shapes.size() - 1;
@@ -723,7 +728,7 @@ MIGRAPHX_REGISTER_OP(miopen_conv_bias)
struct miopen_conv_bias_relu
{
    op::convolution op;
-   fusion f = {};
+   fusion fp = {};
    fusion::op_t conv = {};
    fusion::op_t bias = {};
    fusion::op_t relu = {};
@@ -749,18 +754,18 @@ struct miopen_conv_bias_relu
        miopenSetOpArgsConvForward(fargs.get(), conv, &alpha, &beta, args[1].implicit());
        miopenSetOpArgsBiasForward(fargs.get(), bias, &alpha, &beta, args[3].implicit());
        miopenSetOpArgsActivForward(fargs.get(), relu, &alpha, &beta, 0, 0, 0);
-       return f.execute(ctx, fargs, args[0], args[4]);
+       return fp.execute(ctx, fargs, args[0], args[4]);
    }

    void finalize(context& ctx, const shape&, const std::vector<shape>& inputs)
    {
-       f = fusion(inputs[0]);
-       conv = f.create_conv(op, inputs[1]);
-       bias = f.create_bias(inputs[3]);
-       relu = f.create_relu();
-       f.compile(ctx);
+       fp = fusion(inputs[0]);
+       conv = fp.create_conv(op, inputs[1]);
+       bias = fp.create_bias(inputs[3]);
+       relu = fp.create_relu();
+       fp.compile(ctx);
    }

-   shape get_workspace(context& ctx) { return f.get_workspace(ctx); }
+   shape get_workspace(context& ctx) { return fp.get_workspace(ctx); }
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return shapes.size() - 1;
@@ -778,7 +783,7 @@ auto conv_bias(Ms... ms)
}

template <class Op>
-void apply_conv_bias(context& ctx, module& p, match::matcher_result r)
+void apply_conv_bias(context& ctx, module& m, const match::matcher_result& r)
{
    auto conv_ins = r.instructions["conv"];
    auto bias_ins = r.instructions["bias"];
@@ -793,7 +798,7 @@ void apply_conv_bias(context& ctx, module& m, const match::matcher_result& r)
    // TODO: Insert ws allocation
    auto ws = cb.get_workspace(ctx);
    (void)ws;
-   p.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins);
+   m.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins);
}
inline auto precompile_name(std::string s) // NOLINT
@@ -824,9 +829,9 @@ struct find_conv_bias
                  match::output(match::name(std::unordered_set<std::string>{"gpu::relu"}))));
    }

-   void apply(module& p, match::matcher_result r) const
+   void apply(module& m, const match::matcher_result& r) const
    {
-       apply_conv_bias<miopen_conv_bias>(*ctx, p, std::move(r));
+       apply_conv_bias<miopen_conv_bias>(*ctx, m, r);
    }
};
@@ -835,9 +840,9 @@ struct find_conv_bias_relu
    context* ctx = nullptr;
    auto matcher() const { return match::name("gpu::relu")(match::arg(0)(conv_bias())); }

-   void apply(module& p, match::matcher_result r) const
+   void apply(module& m, const match::matcher_result& r) const
    {
-       apply_conv_bias<miopen_conv_bias_relu>(*ctx, p, std::move(r));
+       apply_conv_bias<miopen_conv_bias_relu>(*ctx, m, r);
    }
};
@@ -852,7 +857,7 @@ struct find_conv_pointwise
            fusable_conv(match::used_once()).bind("conv")));
    }

-   void apply(module& m, match::matcher_result r) const
+   void apply(module& m, const match::matcher_result& r) const
    {
        auto conv_ins = r.instructions["conv"];
        auto bias_ins = r.instructions["bias"];
@@ -870,7 +875,6 @@ struct find_conv_pointwise
        {
            if(i.name()[0] == '@')
                continue;
-           auto inputs = to_shapes(i.inputs());
            op.ops.push_back({{i.get_operator()}});
        }
        std::vector<instruction_ref> inputs = {input_ins, weights_ins, bias_ins, alloc_ins};
@@ -891,7 +895,7 @@ struct find_gemm_add
                               match::name("gpu::gemm")(match::nargs(3)).bind("gemm")));
    }

-   void apply(module& p, match::matcher_result r) const
+   void apply(module& m, const match::matcher_result& r) const
    {
        auto ins = r.result;
        auto gemm_ins = r.instructions["gemm"];
@@ -903,26 +907,68 @@ struct find_gemm_add
        if(not float_equal(gemm.beta, 0))
            return;

+       if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto i) {
+              return not i->get_shape().standard();
+          }))
+           return;
+
        auto inputs = gemm_ins->inputs();
        inputs.pop_back();

        auto copy_ins = c_ins;
        // Insert copy
-       if(ins == p.end() or c_ins->outputs().size() > 1 or c_ins->inputs().empty())
+       if(ins == m.end() or c_ins->outputs().size() > 1 or c_ins->inputs().empty())
        {
-           copy_ins = p.insert_instruction(ins, hip_copy{}, c_ins, ins->inputs().back());
+           copy_ins = m.insert_instruction(ins, hip_copy{}, c_ins, ins->inputs().back());
        }

        inputs.push_back(copy_ins);
        inputs.push_back(copy_ins);
        gemm.beta = 1;
-       p.replace_instruction(ins, gemm, inputs);
+       m.replace_instruction(ins, gemm, inputs);
    }
};
auto pointwise_name(const std::string& s)
{
return precompile_name("pointwise")(match::make_basic_pred_matcher([=](auto ins) {
module_ref pm = ins->module_inputs().front();
auto n = std::count_if(pm->begin(), pm->end(), [&](auto& i) { return i.name() == s; });
if(n != 1)
return false;
return std::all_of(pm->begin(), pm->end(), [&](auto& i) {
return starts_with(i.name(), "@") or i.name() == s;
});
}));
}
struct find_gemm_pointwise
{
auto matcher() const
{
return pointwise_name("add")(
match::nargs(3),
match::all_of[match::inputs()](match::standard_shape()),
match::either_arg(0, 1)(match::used_once().bind("c"),
match::name("gpu::gemm")(match::nargs(3)).bind("gemm")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto gemm_ins = r.instructions["gemm"];
auto c_ins = r.instructions["c"];
auto gemm = any_cast<rocblas_gemm<op::dot>>(gemm_ins->get_operator());
// Already fused gemm
if(not float_equal(gemm.beta, 0))
return;
auto inputs = gemm_ins->inputs();
inputs.pop_back();
inputs.push_back(c_ins);
inputs.push_back(ins->inputs().back());
gemm.beta = 1;
m.replace_instruction(ins, gemm, inputs);
} }
}; };
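pointwise_name accepts a precompiled pointwise module only when it computes exactly one op named s besides builtins. A standalone sketch of that predicate over a plain list of instruction names:

#include <algorithm>
#include <string>
#include <vector>

// Standalone sketch of the predicate inside pointwise_name: the submodule
// must contain exactly one instruction named `s`, and every other
// instruction must be a builtin (name starting with '@', e.g. @param,
// @literal, @return).
bool is_single_op_pointwise(const std::vector<std::string>& names, const std::string& s)
{
    auto n = std::count(names.begin(), names.end(), s);
    if(n != 1)
        return false;
    return std::all_of(names.begin(), names.end(), [&](const std::string& name) {
        return name.rfind('@', 0) == 0 or name == s;
    });
}

// is_single_op_pointwise({"@param", "@param", "add", "@return"}, "add") -> true
// is_single_op_pointwise({"@param", "add", "mul", "@return"}, "add")    -> false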
@@ -933,22 +979,22 @@ struct find_commutative_broadcast
        return match::name("gpu::add", "gpu::mul")(match::arg(1)(match::broadcast_shape()));
    }

-   void apply(module& p, const match::matcher_result& r) const
+   void apply(module& m, const match::matcher_result& r) const
    {
        auto ins = r.result;
        auto args = ins->inputs();
        move_broadcasted_back(args);

-       p.replace_instruction(ins, ins->get_operator(), args);
+       m.replace_instruction(ins, ins->get_operator(), args);
    }
};

-void fuse_ops::apply(module& p) const
+void fuse_ops::apply(module& m) const
{
-   match::find_matches(p, find_gelu{}, find_gelu_new{fast_math});
-   run_passes(p, {dead_code_elimination{}});
-   match::find_matches(p, find_triadd{});
-   match::find_matches(p,
+   match::find_matches(m, find_gelu{}, find_gelu_new{fast_math});
+   run_passes(m, {dead_code_elimination{}});
+   match::find_matches(m, find_triadd{});
+   match::find_matches(m,
                        find_layernorm{},
                        find_conv_pointwise{ctx},
                        find_conv_bias_relu{ctx},
@@ -961,8 +1007,12 @@ void fuse_ops::apply(module& m) const
                        find_add_unary{"gpu::sigmoid", hip_add_sigmoid{}, hip_triadd_sigmoid{}},
                        find_add_unary{"gpu::tanh", hip_add_tanh{}, hip_triadd_tanh{}},
                        find_add_clip{});
-   run_passes(p, {dead_code_elimination{}});
-   match::find_matches(p, find_triadd_layernorm{}, find_gemm_add{}, find_commutative_broadcast{});
+   run_passes(m, {dead_code_elimination{}});
+   match::find_matches(m,
+                       find_triadd_layernorm{},
+                       find_gemm_add{},
+                       find_gemm_pointwise{},
+                       find_commutative_broadcast{});
}

} // namespace gpu
...
#include <rocblas.h>
#include <migraphx/gpu/gemm_impl.hpp>
+#include <migraphx/reduce_dims.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -27,6 +28,22 @@ rocblas_datatype get_type(shape::type_t type)
    MIGRAPHX_THROW("ROCBLAS_GEMM: data type not supported!");
}
void blas_shape(const shape& s)
{
if(s.lens().size() < 2)
return;
if(std::none_of(s.strides().end() - 2, s.strides().end(), [&](auto i) { return i == 1; }))
MIGRAPHX_THROW("GPU_GEMM: needs to have one matrix stride as 1");
if(s.lens().size() < 3)
return;
shape batch_shape{s.type(),
{s.lens().begin(), s.lens().end() - 2},
{s.strides().begin(), s.strides().end() - 2}};
auto batch_shapes = reduce_dims({batch_shape});
if(batch_shapes.front().lens().size() != 1)
MIGRAPHX_THROW("GPU_GEMM: Batch dimension is not collapsible");
}
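blas_shape only accepts batch dimensions that reduce_dims can merge into a single dimension. The real check is reduce_dims itself; the following standalone predicate is an illustration of the collapsibility condition it enforces:

#include <cstddef>
#include <iostream>
#include <vector>

// The batch dimensions (all but the last two) can be merged into one only
// if each batch stride equals the next length times the next stride, i.e.
// the batch dims describe a single contiguous iteration space.
bool batch_collapsible(const std::vector<std::size_t>& lens,
                       const std::vector<std::size_t>& strides)
{
    for(std::size_t i = 0; i + 3 < lens.size(); i++)
    {
        if(strides[i] != lens[i + 1] * strides[i + 1])
            return false;
    }
    return true;
}

int main()
{
    // {2, 3, 4, 5} with standard strides: batch dims {2, 3} merge into {6}
    std::cout << batch_collapsible({2, 3, 4, 5}, {60, 20, 5, 1}) << "\n"; // 1
    // A transposed batch layout that no single stride can describe
    std::cout << batch_collapsible({3, 2, 4, 5}, {20, 60, 5, 1}) << "\n"; // 0
}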
template <class R, class... Ts, class... Us>
R rocblas_invoke(R (*f)(Ts...), Us... xs)
{
@@ -36,16 +53,29 @@ R rocblas_invoke(R (*f)(Ts...), Us... xs)
    return f(xs..., nullptr, nullptr);
}
static bool is_transposed(const shape& s)
{
if(not s.transposed())
return false;
return s.strides().back() != 1;
}
static rocblas_int get_batch_stride(const argument& a)
{
return a.get_shape().strides()[a.get_shape().strides().size() - 3];
}
template <class T>
void gemm_impl(context& ctx,
               const shape& output_shape,
               const std::vector<argument>& args,
               T alpha,
               T beta,
-              bool int8_x4_format)
+              bool int8_x4_format,
+              bool compute_fp32)
{
-   bool transa = args[0].get_shape().transposed();
-   bool transb = args[1].get_shape().transposed();
+   bool transa = is_transposed(args[0].get_shape());
+   bool transb = is_transposed(args[1].get_shape());
    auto n_dim = output_shape.lens().size();
    auto dim_1 = n_dim - 1;
    auto dim_0 = n_dim - 2;
@@ -65,6 +95,11 @@ void gemm_impl(context& ctx,
        output_type = rocblas_datatype_i32_r;
    }
    auto compute_type = output_type;
+   if(compute_fp32)
+   {
+       if(arg_type == rocblas_datatype_f16_r)
+           compute_type = rocblas_datatype_f32_r;
+   }
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
    rocblas_gemm_flags flag =
@@ -77,8 +112,19 @@ void gemm_impl(context& ctx,
    auto a_lens = args[0].get_shape().lens();
    auto b_lens = args[1].get_shape().lens();
    output_shape.visit_type([&](auto as) {
        auto alpha_r = as(alpha);
        auto beta_r = as(beta);
+       // use void pointer to select different data type if using fp32 mode
+       void* alpha_v = &alpha_r;
+       void* beta_v = &beta_r;
+       if(compute_fp32)
+       {
+           alpha_v = &alpha;
+           beta_v = &beta;
+       }
        auto out_lens = output_shape.lens();
        rocblas_int m = out_lens[dim_0];
        rocblas_int n = out_lens[dim_1];
@@ -104,14 +150,14 @@ void gemm_impl(context& ctx,
                       n,
                       m,
                       k,
-                      &alpha_r,
+                      alpha_v,
                       to_pointer(args.at(1)),
                       arg_type,
                       ldb,
                       to_pointer(args.at(0)),
                       arg_type,
                       lda,
-                      &beta_r,
+                      beta_v,
                       to_pointer(args[2]),
                       output_type,
                       ldc,
@@ -125,6 +171,9 @@ void gemm_impl(context& ctx,
        }
        else
        {
+           auto a_stride = get_batch_stride(args[0]);
+           auto b_stride = get_batch_stride(args[1]);
+           auto c_stride = get_batch_stride(args[2]);
            rocblas_invoke(&rocblas_gemm_strided_batched_ex,
                           ctx.get_stream().get_rocblas(),
                           transb ? rocblas_operation_transpose : rocblas_operation_none,
@@ -132,24 +181,24 @@ void gemm_impl(context& ctx,
                           n,
                           m,
                           k,
-                          &alpha_r,
+                          alpha_v,
                           to_pointer(args.at(1)),
                           arg_type,
                           ldb,
-                          k * n,
+                          b_stride,
                           to_pointer(args.at(0)),
                           arg_type,
                           lda,
-                          m * k,
+                          a_stride,
-                          &beta_r,
+                          beta_v,
                           to_pointer(args[2]),
                           output_type,
                           ldc,
-                          m * n,
+                          c_stride,
                           is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
                           output_type,
                           ldc,
-                          m * n,
+                          c_stride,
                           num_matrices,
                           compute_type,
                           rocblas_gemm_algo_standard,
@@ -164,9 +213,10 @@ void gemm(context& ctx,
          const std::vector<argument>& args,
          float alpha,
          float beta,
-         bool int8_x4_format)
+         bool int8_x4_format,
+         bool compute_fp32)
{
-   gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format);
+   gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32);
}

void gemm(context& ctx,
@@ -174,9 +224,10 @@ void gemm(context& ctx,
          const std::vector<argument>& args,
          int32_t alpha,
          int32_t beta,
-         bool int8_x4_format)
+         bool int8_x4_format,
+         bool compute_fp32)
{
-   gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format);
+   gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32);
}

} // namespace gpu
...
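The alpha_v/beta_v indirection above exists because rocBLAS reads the scalar pointers in the compute type: fp16 data computed in fp32 must be given float scalars, not their half-converted copies. A minimal standalone sketch of the idea (half_t and fake_gemm are hypothetical stand-ins, not rocBLAS API):

#include <cstdint>
#include <iostream>

struct half_t // stand-in for a 16-bit float type
{
    std::uint16_t bits;
};

// Stand-in for rocblas_gemm_ex: reads the scalar in the compute type
void fake_gemm(const void* alpha, bool compute_fp32)
{
    if(compute_fp32)
        std::cout << "alpha as float: " << *static_cast<const float*>(alpha) << "\n";
    else
        std::cout << "alpha as half bits: " << static_cast<const half_t*>(alpha)->bits << "\n";
}

int main()
{
    float alpha = 2.0f;
    half_t alpha_h{0x4000}; // 2.0 in IEEE half
    const void* alpha_v = &alpha_h;
    bool compute_fp32 = true;
    if(compute_fp32) // compute type is f32: pass the float scalar instead
        alpha_v = &alpha;
    fake_gemm(alpha_v, compute_fp32);
}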
@@ -27,6 +27,15 @@ using hip_host_ptr = MIGRAPHX_MANAGE_PTR(void, hipHostUnregister);

std::string hip_error(int error) { return hipGetErrorString(static_cast<hipError_t>(error)); }
bool is_device_ptr(const void* ptr)
{
hipPointerAttribute_t attr;
auto status = hipPointerGetAttributes(&attr, ptr);
if(status != hipSuccess)
return false;
return attr.memoryType == hipMemoryTypeDevice;
}
std::size_t get_available_gpu_memory()
{
    size_t free;
@@ -50,8 +59,8 @@ hip_ptr allocate_gpu(std::size_t sz, bool host = false)
{
    if(sz > get_available_gpu_memory())
        MIGRAPHX_THROW("Memory not available to allocate buffer: " + std::to_string(sz));
-   void* result;
-   auto status = host ? hipHostMalloc(&result, sz) : hipMalloc(&result, sz);
+   void* result = nullptr;
+   auto status  = host ? hipHostMalloc(&result, sz) : hipMalloc(&result, sz);
    if(status != hipSuccess)
    {
        if(host)
@@ -59,6 +68,7 @@ hip_ptr allocate_gpu(std::size_t sz, bool host = false)
        else
            return allocate_gpu(sz, true);
    }
+   assert(result != nullptr);
    return hip_ptr{result};
}
@@ -75,6 +85,8 @@ std::vector<T> read_from_gpu(const void* x, std::size_t sz)
{
    gpu_sync();
    std::vector<T> result(sz);
+   assert(not is_device_ptr(result.data()));
+   assert(is_device_ptr(x));
    auto status = hipMemcpy(result.data(), x, sz * sizeof(T), hipMemcpyDeviceToHost);
    if(status != hipSuccess)
        MIGRAPHX_THROW("Copy from gpu failed: " + hip_error(status)); // NOLINT
@@ -85,6 +97,8 @@ hip_ptr write_to_gpu(const void* x, std::size_t sz, bool host = false)
{
    gpu_sync();
    auto result = allocate_gpu(sz, host);
+   assert(is_device_ptr(result.get()));
+   assert(not is_device_ptr(x));
    auto status = hipMemcpy(result.get(), x, sz, hipMemcpyHostToDevice);
    if(status != hipSuccess)
        MIGRAPHX_THROW("Copy to gpu failed: " + hip_error(status));
@@ -109,10 +123,9 @@ argument register_on_gpu(const argument& arg)
{
    auto arg_shared = arg.share();
    auto p = share(register_on_gpu(arg_shared.data(), arg_shared.get_shape().bytes()));
-   return {arg_shared.get_shape(),
-           [ p, a = std::move(arg_shared) ]() mutable {return get_device_ptr(p.get());
-   }}; // namespace gpu
-}; // namespace gpu
+   return {arg_shared.get_shape(), [p, a = std::move(arg_shared)]() mutable {
+       return get_device_ptr(p.get());
+   }};
}

argument to_gpu(const argument& arg, bool host)
...
@@ -11,7 +11,7 @@ struct module;
namespace gpu {

-std::vector<stream_race> analyze_streams(const module& p);
+std::vector<stream_race> analyze_streams(const module& m);

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
...
@@ -34,6 +34,10 @@ struct code_object_op
          f(self.output, "output"));
    }

+   value attributes() const { return {{"group", group()}}; }
+
+   std::string group() const { return "gpu::code_object::" + symbol_name; }
+
    std::string name() const { return "gpu::code_object"; }
    shape compute_shape(std::vector<shape> inputs) const;
    argument
...
#ifndef MIGRAPHX_GUARD_GPU_COMPILE_GEN_HPP
#define MIGRAPHX_GUARD_GPU_COMPILE_GEN_HPP
#include <migraphx/config.hpp>
#include <string>
#include <unordered_map>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct shape;
namespace gpu {
namespace gen {
struct vectorize
{
std::size_t size = 1;
std::size_t axis = 0;
static vectorize elements(std::size_t axis, const std::vector<shape>& inputs);
std::string str() const;
};
struct preload
{
std::vector<bool> args = {};
static preload broadcasts(std::size_t axis, const std::vector<shape>& inputs);
bool is_preloading() const;
std::string str() const;
};
std::size_t find_fast_axis(const std::vector<shape>& inputs);
std::string make_transformer_args(std::vector<std::string> transformers);
template <class... Ts>
std::string make_transformer_args(Ts... xs)
{
return make_transformer_args({xs.str()...});
}
} // namespace gen
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_COMPILE_GEN_HPP
@@ -17,8 +17,6 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
std::string enum_params(std::size_t count, std::string param);

-std::size_t compute_global(std::size_t n, std::size_t local = 1024);
-
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...
@@ -8,6 +8,8 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

+struct context;
+
struct hip_compile_options
{
    std::size_t global;
@@ -17,10 +19,35 @@ struct hip_compile_options
    std::string kernel_name = "kernel";
    std::string params = "";
    std::vector<shape> virtual_inputs = {};
+
+   /**
+    * @brief Set the launch parameters but allow v to override the values
+    *
+    * @param v A value class which can have "global" and/or "local" keys to override the
+    * default global and local
+    * @param compute_global A function used to compute the global based on the local
+    * @param default_local The default local to use if it is missing from the v parameter
+    */
+   void set_launch_params(const value& v,
+                          const std::function<std::size_t(std::size_t local)>& compute_global,
+                          std::size_t default_local = 1024);
+
+   void
+   set_launch_params(const value& v, std::size_t default_global, std::size_t default_local = 1024)
+   {
+       set_launch_params(
+           v, [=](auto) { return default_global; }, default_local);
+   }
};

+/// Compute global for n elements, but max out on target-specific upper limit
+std::function<std::size_t(std::size_t local)>
+compute_global_for(context& ctx, std::size_t n, std::size_t over = 1);
+
operation compile_hip_code_object(const std::string& content, hip_compile_options options);

+std::size_t compute_block_size(std::size_t n, std::size_t max_block_size = 1024);
+
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...
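A hedged sketch of how these declarations fit together inside a compiler (assuming hip_compile_options also carries the inputs and output fields it has elsewhere in this codebase; the kernel source and name are placeholders):

operation compile_example(context& ctx, const std::vector<shape>& inputs, const value& v)
{
    hip_compile_options options;
    options.inputs      = inputs;        // assumed fields, not shown in this hunk
    options.output      = inputs.back();
    options.kernel_name = "example_kernel";
    // Let v's "global"/"local" keys override; otherwise derive global from the
    // output element count, capped at the target-specific upper limit.
    options.set_launch_params(v, compute_global_for(ctx, options.output.elements()), 256);
    return compile_hip_code_object("/* kernel source */", options);
}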
#ifndef MIGRAPHX_GUARD_GPU_COMPILE_POINTWISE_HPP
#define MIGRAPHX_GUARD_GPU_COMPILE_POINTWISE_HPP
#include <migraphx/config.hpp>
#include <migraphx/operation.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
namespace gpu {
struct context;
operation compile_pointwise(context& ctx,
const std::vector<shape>& inputs,
const std::string& lambda,
const std::string& preamble = "");
operation compile_pointwise(context& ctx, const std::vector<shape>& inputs, module m);
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_COMPILE_POINTWISE_HPP
#ifndef MIGRAPHX_GUARD_GPU_COMPILER_HPP
#define MIGRAPHX_GUARD_GPU_COMPILER_HPP
#include <migraphx/config.hpp>
#include <migraphx/auto_register.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/value.hpp>
#include <migraphx/module.hpp>
#include <migraphx/instruction.hpp>
#include <functional>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
using compiler_replace = std::function<void(module& m, instruction_ref ins)>;
using compiler_compile = std::function<compiler_replace(context&, instruction_ref, operation)>;
using compiler_compile_op =
std::function<operation(context&, const std::vector<shape>& inputs, const value&)>;
void register_compiler(const std::string& name, compiler_compile c, compiler_compile_op cop);
bool has_compiler_for(const std::string& name);
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op);
operation
compile_op(const std::string& name, context& ctx, const std::vector<shape>& inputs, const value& v);
template <class T>
void register_compiler()
{
T c;
for(auto&& name : c.names())
{
register_compiler(
name,
[=](auto&&... xs) { return c.compile(std::forward<decltype(xs)>(xs)...); },
[=](auto&&... xs) { return c.compile_op(std::forward<decltype(xs)>(xs)...); });
}
}
struct register_compiler_action
{
template <class T>
static void apply()
{
register_compiler<T>();
}
};
template <class T>
using auto_register_compiler = auto_register<register_compiler_action, T>;
template <class Derived>
struct compiler : auto_register_compiler<Derived>
{
auto replace(const operation& op) const
{
return
[=](module& m, instruction_ref ins) { m.replace_instruction(ins, op, ins->inputs()); };
}
operation compile_op(context&, const std::vector<shape>&, const value&) const { return {}; }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_COMPILER_HPP
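A hypothetical end-to-end use of the registry above (not part of this commit): names() lists the handled ops, compile_op builds the operation (here via the compile_pointwise declaration shown earlier, with an illustrative identity lambda), and compile returns the module replacement.

struct example_compiler : compiler<example_compiler>
{
    std::vector<std::string> names() const { return {"gpu::example"}; }

    operation compile_op(context& ctx, const std::vector<shape>& inputs, const value&) const
    {
        // Illustrative: compile a trivial pointwise kernel for these shapes
        return compile_pointwise(ctx, inputs, "[](auto x) { return x; }");
    }

    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
    {
        return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
    }
};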