Unverified Commit 7220dd18 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Add code object custom op (#744)



* Add code object op

* Formatting

* Add more value tests

* Formatting

* Fix from_value conversion from binary

* Formatting

* Don't use offload copy

* Remove iostream header
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent b90fcb18
...@@ -174,6 +174,12 @@ void replace(Range&& r, const T& old, const T& new_x) ...@@ -174,6 +174,12 @@ void replace(Range&& r, const T& old, const T& new_x)
std::replace(r.begin(), r.end(), old, new_x); std::replace(r.begin(), r.end(), old, new_x);
} }
// Compare two ranges element-by-element. Returns true only when both
// ranges have the same length and every corresponding pair of elements
// compares equal (uses the four-iterator std::equal overload, which
// checks the lengths itself).
template <class Range1, class Range2>
bool equal(Range1&& lhs, Range2&& rhs)
{
    auto lhs_first = lhs.begin();
    auto lhs_last  = lhs.end();
    auto rhs_first = rhs.begin();
    auto rhs_last  = rhs.end();
    return std::equal(lhs_first, lhs_last, rhs_first, rhs_last);
}
template <class R> template <class R>
using range_value = std::decay_t<decltype(*std::declval<R>().begin())>; using range_value = std::decay_t<decltype(*std::declval<R>().begin())>;
......
...@@ -130,8 +130,25 @@ auto from_value_impl(rank<1>, const value& v, T& x) ...@@ -130,8 +130,25 @@ auto from_value_impl(rank<1>, const value& v, T& x)
x.insert(x.end(), from_value<typename T::value_type>(e)); x.insert(x.end(), from_value<typename T::value_type>(e));
} }
// Convert a value into a sequence container of arithmetic elements
// (e.g. std::vector<std::uint8_t>). Selected at rank<2> in the
// from_value_impl overload ladder only when T::value_type is arithmetic
// and T supports insert(pos, elem). When v holds a binary payload the
// raw bytes are copied directly instead of iterating v as an array of
// values (per the commit, this fixes from_value conversion from binary).
template <class T, MIGRAPHX_REQUIRES(std::is_arithmetic<typename T::value_type>{})>
auto from_value_impl(rank<2>, const value& v, T& x)
    -> decltype(x.insert(x.end(), *x.begin()), void())
{
    x.clear();
    if(v.is_binary())
    {
        // Binary data is stored as raw bytes; copy each byte into the
        // destination container.
        for(auto&& e : v.get_binary())
            x.insert(x.end(), e);
    }
    else
    {
        // Otherwise treat v as an array and convert each element
        // recursively to the container's value_type.
        for(auto&& e : v)
            x.insert(x.end(), from_value<typename T::value_type>(e));
    }
}
template <class T> template <class T>
auto from_value_impl(rank<2>, const value& v, T& x) -> decltype(x.insert(*x.begin()), void()) auto from_value_impl(rank<3>, const value& v, T& x) -> decltype(x.insert(*x.begin()), void())
{ {
x.clear(); x.clear();
for(auto&& e : v) for(auto&& e : v)
...@@ -139,7 +156,7 @@ auto from_value_impl(rank<2>, const value& v, T& x) -> decltype(x.insert(*x.begi ...@@ -139,7 +156,7 @@ auto from_value_impl(rank<2>, const value& v, T& x) -> decltype(x.insert(*x.begi
} }
template <class T, MIGRAPHX_REQUIRES(is_reflectable<T>{})> template <class T, MIGRAPHX_REQUIRES(is_reflectable<T>{})>
void from_value_impl(rank<3>, const value& v, T& x) void from_value_impl(rank<4>, const value& v, T& x)
{ {
reflect_each(x, [&](auto& y, const std::string& name) { reflect_each(x, [&](auto& y, const std::string& name) {
using type = std::decay_t<decltype(y)>; using type = std::decay_t<decltype(y)>;
...@@ -149,27 +166,27 @@ void from_value_impl(rank<3>, const value& v, T& x) ...@@ -149,27 +166,27 @@ void from_value_impl(rank<3>, const value& v, T& x)
} }
template <class T, MIGRAPHX_REQUIRES(std::is_arithmetic<T>{})> template <class T, MIGRAPHX_REQUIRES(std::is_arithmetic<T>{})>
void from_value_impl(rank<4>, const value& v, T& x) void from_value_impl(rank<5>, const value& v, T& x)
{ {
x = v.to<T>(); x = v.to<T>();
} }
template <class T, MIGRAPHX_REQUIRES(std::is_enum<T>{})> template <class T, MIGRAPHX_REQUIRES(std::is_enum<T>{})>
void from_value_impl(rank<5>, const value& v, T& x) void from_value_impl(rank<6>, const value& v, T& x)
{ {
x = v.to<T>(); x = v.to<T>();
} }
inline void from_value_impl(rank<6>, const value& v, std::string& x) { x = v.to<std::string>(); } inline void from_value_impl(rank<7>, const value& v, std::string& x) { x = v.to<std::string>(); }
template <class T> template <class T>
auto from_value_impl(rank<7>, const value& v, T& x) -> decltype(x.from_value(v), void()) auto from_value_impl(rank<8>, const value& v, T& x) -> decltype(x.from_value(v), void())
{ {
x.from_value(v); x.from_value(v);
} }
template <class T> template <class T>
auto from_value_impl(rank<8>, const value& v, T& x) -> decltype(migraphx_from_value(v, x), void()) auto from_value_impl(rank<9>, const value& v, T& x) -> decltype(migraphx_from_value(v, x), void())
{ {
migraphx_from_value(v, x); migraphx_from_value(v, x);
} }
...@@ -185,7 +202,7 @@ value to_value(const T& x) ...@@ -185,7 +202,7 @@ value to_value(const T& x)
template <class T> template <class T>
void from_value(const value& v, T& x) void from_value(const value& v, T& x)
{ {
detail::from_value_impl(rank<8>{}, v, x); detail::from_value_impl(rank<9>{}, v, x);
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
File mode changed from 100644 to 100755
...@@ -103,6 +103,7 @@ add_library(migraphx_gpu ...@@ -103,6 +103,7 @@ add_library(migraphx_gpu
allocation_model.cpp allocation_model.cpp
argmax.cpp argmax.cpp
argmin.cpp argmin.cpp
code_object_op.cpp
eliminate_workspace.cpp eliminate_workspace.cpp
fuse_ops.cpp fuse_ops.cpp
hip.cpp hip.cpp
......
#include <migraphx/gpu/code_object_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/register_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
MIGRAPHX_REGISTER_OP(code_object_op);
// Validate that the incoming shapes still match the shapes this code
// object was compiled for, and return the precomputed output shape.
// The kernel binary is shape-specific, so any mismatch is a hard error.
shape code_object_op::compute_shape(std::vector<shape> inputs) const
{
    // Both sides are compared in normalize_standard() form —
    // presumably this canonicalizes layout so equivalent shapes do not
    // trigger a spurious mismatch (TODO confirm normalize_standard
    // semantics).
    std::transform(inputs.begin(), inputs.end(), inputs.begin(), [](const shape& s) {
        return s.normalize_standard();
    });
    auto einputs = expected_inputs;
    std::transform(einputs.begin(), einputs.end(), einputs.begin(), [](const shape& s) {
        return s.normalize_standard();
    });
    if(einputs != inputs)
        MIGRAPHX_THROW("Input shapes have changed: [" + to_string_range(einputs) + "] -> [" +
                       to_string_range(inputs) + "]");
    return output;
}
// Launch the resolved kernel, passing each argument's raw device
// pointer. The last argument is the output buffer (see output_alias)
// and is returned as the op's result.
argument
code_object_op::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
    // Collect the raw data pointers; this overload of kernel::launch
    // uses the pointer array itself as the kernarg payload.
    std::vector<void*> kargs(args.size());
    std::transform(
        args.begin(), args.end(), kargs.begin(), [](const argument& a) { return a.data(); });
    k.launch(ctx.get_stream().get(), global, local, std::move(kargs));
    return args.back();
}
// Load the embedded code object and resolve the kernel symbol once,
// before evaluation. compute() depends on `k` being initialized here.
void code_object_op::finalize(context&, const shape&, const std::vector<shape>&)
{
    assert(not code_object.empty());
    k = kernel(code_object, symbol_name);
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_CODE_OBJECT_OP_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_CODE_OBJECT_OP_HPP
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/gpu/kernel.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
// GPU operator that runs a precompiled HIP code object (kernel binary)
// directly. The binary, kernel symbol, launch dimensions, and the
// input/output shapes it was compiled for are all carried as data so
// the op can be serialized and rebuilt via reflect().
struct code_object_op
{
    value::binary code_object;          // compiled kernel image
    std::string symbol_name;            // kernel symbol to resolve in the image
    std::size_t global;                 // global launch size
    std::size_t local;                  // local (workgroup) launch size
    std::vector<shape> expected_inputs; // shapes the kernel was compiled for
    shape output;                       // result shape
    kernel k;                           // resolved kernel; set in finalize(), not reflected

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.code_object, "code_object"),
                    f(self.symbol_name, "symbol_name"),
                    f(self.global, "global"),
                    f(self.local, "local"),
                    f(self.expected_inputs, "expected_inputs"),
                    f(self.output, "output"));
    }
    std::string name() const { return "gpu::code_object"; }
    shape compute_shape(std::vector<shape> inputs) const;
    argument
    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
    void finalize(context&, const shape&, const std::vector<shape>&);
    // The last input is the output buffer; the result aliases it.
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return shapes.size() - 1;
    }
    friend std::ostream& operator<<(std::ostream& os, const code_object_op& op)
    {
        os << op.name() << "[";
        os << "code_object=" << op.code_object.size() << ",";
        os << "symbol_name=" << op.symbol_name << ",";
        os << "global=" << op.global << ",";
        os << "local=" << op.local << ",";
        // Fix: close the bracket opened above so the printed form is
        // balanced (previously the "]" was never emitted).
        os << "]";
        return os;
    }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -17,16 +17,28 @@ struct kernel_impl; ...@@ -17,16 +17,28 @@ struct kernel_impl;
struct kernel struct kernel
{ {
kernel() = default; kernel() = default;
kernel(const std::vector<char>& image, const std::string& name); kernel(const char* image, const std::string& name);
template <class T, MIGRAPHX_REQUIRES(sizeof(T) == 1)>
kernel(const std::vector<T>& image, const std::string& name)
: kernel(reinterpret_cast<const char*>(image.data()), name)
{
}
void launch(hipStream_t stream,
std::size_t global,
std::size_t local,
const std::vector<kernel_argument>& args) const;
void launch(hipStream_t stream, void launch(hipStream_t stream,
std::size_t global, std::size_t global,
std::size_t local, std::size_t local,
const std::vector<kernel_argument>& args); std::vector<void*> args) const;
auto launch(hipStream_t stream, std::size_t global, std::size_t local) auto launch(hipStream_t stream, std::size_t global, std::size_t local) const
{ {
return [=](auto&&... xs) { launch(stream, global, local, {xs...}); }; return [=](auto&&... xs) {
launch(stream, global, local, std::vector<kernel_argument>{xs...});
};
} }
private: private:
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <migraphx/manage_ptr.hpp> #include <migraphx/manage_ptr.hpp>
#include <migraphx/errors.hpp> #include <migraphx/errors.hpp>
#include <migraphx/gpu/pack_args.hpp> #include <migraphx/gpu/pack_args.hpp>
#include <cassert>
// extern declare the function since hip/hip_ext.h header is broken // extern declare the function since hip/hip_ext.h header is broken
extern hipError_t hipExtModuleLaunchKernel(hipFunction_t, // NOLINT extern hipError_t hipExtModuleLaunchKernel(hipFunction_t, // NOLINT
...@@ -33,18 +34,17 @@ struct kernel_impl ...@@ -33,18 +34,17 @@ struct kernel_impl
hipFunction_t fun = nullptr; hipFunction_t fun = nullptr;
}; };
hip_module_ptr load_module(const std::vector<char>& image) hip_module_ptr load_module(const char* image)
{ {
hipModule_t raw_m; hipModule_t raw_m;
auto status = hipModuleLoadData(&raw_m, image.data()); auto status = hipModuleLoadData(&raw_m, image);
hip_module_ptr m{raw_m}; hip_module_ptr m{raw_m};
if(status != hipSuccess) if(status != hipSuccess)
MIGRAPHX_THROW("Failed to load module: " + hip_error(status)); MIGRAPHX_THROW("Failed to load module: " + hip_error(status));
return m; return m;
} }
kernel::kernel(const std::vector<char>& image, const std::string& name) kernel::kernel(const char* image, const std::string& name) : impl(std::make_shared<kernel_impl>())
: impl(std::make_shared<kernel_impl>())
{ {
impl->module = load_module(image); impl->module = load_module(image);
auto status = hipModuleGetFunction(&impl->fun, impl->module.get(), name.c_str()); auto status = hipModuleGetFunction(&impl->fun, impl->module.get(), name.c_str());
...@@ -52,42 +52,56 @@ kernel::kernel(const std::vector<char>& image, const std::string& name) ...@@ -52,42 +52,56 @@ kernel::kernel(const std::vector<char>& image, const std::string& name)
MIGRAPHX_THROW("Failed to get function: " + name + ": " + hip_error(status)); MIGRAPHX_THROW("Failed to get function: " + name + ": " + hip_error(status));
} }
void kernel::launch(hipStream_t stream, void launch_kernel(hipFunction_t fun,
hipStream_t stream,
std::size_t global, std::size_t global,
std::size_t local, std::size_t local,
const std::vector<kernel_argument>& args) void* kernargs,
std::size_t size)
{ {
std::vector<char> kernargs = pack_args(args);
std::size_t size = kernargs.size();
void* config[] = { void* config[] = {
// HIP_LAUNCH_PARAM_* are macros that do horrible things // HIP_LAUNCH_PARAM_* are macros that do horrible things
#ifdef MIGRAPHX_USE_CLANG_TIDY #ifdef MIGRAPHX_USE_CLANG_TIDY
nullptr, kernargs.data(), nullptr, &size, nullptr nullptr, kernargs, nullptr, &size, nullptr
#else #else
HIP_LAUNCH_PARAM_BUFFER_POINTER, HIP_LAUNCH_PARAM_BUFFER_POINTER,
kernargs.data(), kernargs,
HIP_LAUNCH_PARAM_BUFFER_SIZE, HIP_LAUNCH_PARAM_BUFFER_SIZE,
&size, &size,
HIP_LAUNCH_PARAM_END HIP_LAUNCH_PARAM_END
#endif #endif
}; };
auto status = hipExtModuleLaunchKernel(impl->fun, auto status = hipExtModuleLaunchKernel(
global, fun, global, 1, 1, local, 1, 1, 0, stream, nullptr, reinterpret_cast<void**>(&config));
1,
1,
local,
1,
1,
0,
stream,
nullptr,
reinterpret_cast<void**>(&config));
if(status != hipSuccess) if(status != hipSuccess)
MIGRAPHX_THROW("Failed to launch kernel: " + hip_error(status)); MIGRAPHX_THROW("Failed to launch kernel: " + hip_error(status));
} }
// Launch the kernel using the argument pointers directly as the
// kernarg buffer. Used by code_object_op, whose arguments are plain
// device pointers rather than typed kernel_argument values.
void kernel::launch(hipStream_t stream,
                    std::size_t global,
                    std::size_t local,
                    std::vector<void*> args) const
{
    assert(impl != nullptr);
    // The array of pointers itself is the kernarg payload, so its size
    // in bytes is count * sizeof(void*).
    void* kernargs    = args.data();
    std::size_t size  = args.size() * sizeof(void*);
    launch_kernel(impl->fun, stream, global, local, kernargs, size);
}
// Launch the kernel with typed arguments: pack_args builds the
// ABI-correct kernarg byte buffer (handling per-argument alignment)
// before delegating to the shared launch_kernel helper.
void kernel::launch(hipStream_t stream,
                    std::size_t global,
                    std::size_t local,
                    const std::vector<kernel_argument>& args) const
{
    assert(impl != nullptr);
    std::vector<char> kernargs = pack_args(args);
    std::size_t size           = kernargs.size();
    launch_kernel(impl->fun, stream, global, local, kernargs.data(), size);
}
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#include <test.hpp> #include <test.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/program.hpp>
#include <migraphx/gpu/kernel.hpp> #include <migraphx/gpu/kernel.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp> #include <migraphx/gpu/hip.hpp>
#include <migraphx/gpu/compile_hip.hpp> #include <migraphx/gpu/compile_hip.hpp>
...@@ -21,6 +25,23 @@ int main() {} ...@@ -21,6 +25,23 @@ int main() {}
)migraphx"; )migraphx";
// HIP source compiled at test time into a code object; the add_2
// kernel writes x[i] + 2 into y[i] for each launched thread. The raw
// string must stay byte-identical — it is the compiler input.
// NOLINTNEXTLINE
const std::string add_2s_binary = R"migraphx(
#include <hip/hip_runtime.h>
extern "C" {
__global__ void add_2(std::int32_t* x, std::int32_t* y)
{
    int num = threadIdx.x + blockDim.x * blockIdx.x;
    y[num] = x[num] + 2;
}
}
int main() {}
)migraphx";
migraphx::gpu::src_file make_src_file(const std::string& name, const std::string& content) migraphx::gpu::src_file make_src_file(const std::string& name, const std::string& content)
{ {
return {name, std::make_pair(content.data(), content.data() + content.size())}; return {name, std::make_pair(content.data(), content.data() + content.size())};
...@@ -52,4 +73,37 @@ TEST_CASE(simple_compile_hip) ...@@ -52,4 +73,37 @@ TEST_CASE(simple_compile_hip)
EXPECT(migraphx::all_of(data, [](auto x) { return x == 2; })); EXPECT(migraphx::all_of(data, [](auto x) { return x == 2; }));
} }
// End-to-end test of the gpu::code_object op: compile the HIP source
// above into a binary, construct the op from its value representation,
// run it inside a program, and verify the elementwise "+2" result.
TEST_CASE(code_object_hip)
{
    auto binaries = migraphx::gpu::compile_hip_src(
        {make_src_file("main.cpp", add_2s_binary)}, "", get_device_name());
    EXPECT(binaries.size() == 1);
    migraphx::shape input{migraphx::shape::int32_type, {5}};
    // The op takes the input buffer plus an output allocation, so both
    // shapes appear in expected_inputs.
    std::vector<migraphx::shape> expected_inputs = {input, input};
    auto co = migraphx::make_op("gpu::code_object",
                                {{"code_object", migraphx::value::binary{binaries.front()}},
                                 {"symbol_name", "add_2"},
                                 {"global", input.elements()},
                                 {"local", 1024},
                                 {"expected_inputs", migraphx::to_value(expected_inputs)},
                                 {"output", migraphx::to_value(input)}});
    migraphx::program p;
    auto* mm           = p.get_main_module();
    auto input_literal = migraphx::generate_literal(input);
    // Expected result computed on the host for comparison.
    auto output_literal = migraphx::transform(input_literal, [](auto x) { return x + 2; });
    auto x              = mm->add_literal(input_literal);
    // Allocate the output buffer explicitly; the op aliases its last input.
    auto y = mm->add_instruction(
        migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(input)}}));
    mm->add_instruction(co, x, y);
    migraphx::compile_options options;
    p.compile(migraphx::gpu::target{}, options);
    auto result = migraphx::gpu::from_gpu(p.eval({}).front());
    EXPECT(result == output_literal.get_argument());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <test.hpp> #include <test.hpp>
#include <numeric>
struct empty_type struct empty_type
{ {
}; };
...@@ -100,4 +102,15 @@ TEST_CASE(serialize_empty_struct) ...@@ -100,4 +102,15 @@ TEST_CASE(serialize_empty_struct)
EXPECT(v.at("a").to<int>() == 1); EXPECT(v.at("a").to<int>() == 1);
} }
// Round-trip a binary payload through value: a value::binary read back
// with from_value must reproduce the original bytes (exercises the
// rank<2> binary branch of from_value_impl).
TEST_CASE(from_value_binary)
{
    std::vector<std::uint8_t> data(10);
    std::iota(data.begin(), data.end(), 0);
    migraphx::value v = migraphx::value::binary{data};
    auto out          = migraphx::from_value<migraphx::value::binary>(v);
    EXPECT(out == data);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/float_equal.hpp> #include <migraphx/float_equal.hpp>
#include <migraphx/ranges.hpp>
#include <test.hpp> #include <test.hpp>
enum class enum_type enum class enum_type
...@@ -770,8 +771,31 @@ TEST_CASE(value_binary) ...@@ -770,8 +771,31 @@ TEST_CASE(value_binary)
std::iota(data.begin(), data.end(), 0); std::iota(data.begin(), data.end(), 0);
v = migraphx::value::binary{data}; v = migraphx::value::binary{data};
EXPECT(v.is_binary()); EXPECT(v.is_binary());
EXPECT(v.get_binary().size() == data.size());
EXPECT(v.get_binary() == data); EXPECT(v.get_binary() == data);
EXPECT(v.get_key().empty()); EXPECT(v.get_key().empty());
} }
// A binary payload stored as a field of a value object must keep its
// type, size, and exact byte contents when read back.
TEST_CASE(value_binary_object)
{
    std::vector<std::uint8_t> data(20);
    std::iota(data.begin(), data.end(), 0);
    migraphx::value v = {{"data", migraphx::value::binary{data}}};
    EXPECT(v["data"].is_binary());
    EXPECT(v["data"].get_binary().size() == data.size());
    EXPECT(v["data"].get_binary() == data);
}
// Same as value_binary_object but sourced from int8_t data: the stored
// bytes must still compare equal element-wise to the signed input
// (migraphx::equal compares across the differing element types).
TEST_CASE(value_binary_object_conv)
{
    std::vector<std::int8_t> data(20);
    std::iota(data.begin(), data.end(), 0);
    migraphx::value v = {{"data", migraphx::value::binary{data}}};
    EXPECT(v["data"].is_binary());
    EXPECT(v["data"].get_binary().size() == data.size());
    EXPECT(migraphx::equal(v["data"].get_binary(), data));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment