"vscode:/vscode.git/clone" did not exist on "658cdab084dccef1ac45adebdaddc6ce5c3b6e7c"
Unverified Commit edc7be5c authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Add compute_method for the experimental custom op (#1194)

Adds compute_method for the experimental custom ops.
Adds a test for the same using HIP APIs.
Depends on #1183
Solves #1101
parent f5760e21
...@@ -236,6 +236,11 @@ void print_program(const program& p) { std::cout << p << std::endl; } ...@@ -236,6 +236,11 @@ void print_program(const program& p) { std::cout << p << std::endl; }
void print_module(const module& m) { std::cout << m << std::endl; } void print_module(const module& m) { std::cout << m << std::endl; }
// Helper for the C API: appends an "allocate" instruction carrying shape `s`
// to module `m` and returns a reference to the new instruction.
migraphx::instruction_ref add_allocation(module& m, const migraphx::shape& s)
{
    const auto shape_value = migraphx::to_value(s);
    return m.add_instruction(migraphx::make_op("allocate", {{"shape", shape_value}}), {});
}
struct experimental_custom_op struct experimental_custom_op
{ {
std::string name; std::string name;
...@@ -260,7 +265,12 @@ struct custom_operation ...@@ -260,7 +265,12 @@ struct custom_operation
return op.compute_shape(std::move(inputs)); return op.compute_shape(std::move(inputs));
} }
argument compute(const std::vector<argument>&) const { MIGRAPHX_THROW("Not computable"); } // TODO: Compute method with module_args
// Forward compute to the user-provided custom op implementation; the context,
// output shape, and input arguments are moved through untouched.
argument
compute(migraphx::context ctx, migraphx::shape output_shape, std::vector<argument> inputs) const
{
return op.compute(std::move(ctx), std::move(output_shape), std::move(inputs));
}
}; };
template <class CustomOp> template <class CustomOp>
...@@ -577,6 +587,24 @@ struct migraphx_experimental_custom_op ...@@ -577,6 +587,24 @@ struct migraphx_experimental_custom_op
manage_generic_ptr<migraphx_experimental_custom_op_copy, migraphx_experimental_custom_op_delete> manage_generic_ptr<migraphx_experimental_custom_op_copy, migraphx_experimental_custom_op_delete>
object_ptr = nullptr; object_ptr = nullptr;
migraphx::experimental_custom_op xobject; migraphx::experimental_custom_op xobject;
// User-registered C callback implementing compute; null until set via
// migraphx_experimental_custom_op_set_compute.
migraphx_experimental_custom_op_compute compute_f = nullptr;
// Invoke the registered C compute callback and return the resulting argument.
// Throws if no callback was registered or if the callback reports an error status.
migraphx::argument compute(migraphx::context ctx,
migraphx::shape output,
std::vector<migraphx::argument> inputs) const
{
// Stack storage the callback writes its output handle into.
std::remove_pointer_t<migraphx_argument_t> out;
if(compute_f == nullptr)
throw std::runtime_error("compute function is missing.");
// C++ objects are passed across the C ABI by casting to their opaque handle types.
auto api_error_result = compute_f(&out,
object_ptr.data,
object_cast<migraphx_context_t>(&(ctx)),
object_cast<migraphx_shape_t>(&(output)),
object_cast<migraphx_arguments_t>(&(inputs)));
if(api_error_result != migraphx_status_success)
throw std::runtime_error("Error in compute.");
return (&out)->object;
}
migraphx_experimental_custom_op_compute_shape compute_shape_f = nullptr; migraphx_experimental_custom_op_compute_shape compute_shape_f = nullptr;
migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
{ {
...@@ -1141,6 +1169,21 @@ extern "C" migraphx_status migraphx_module_add_return(migraphx_instruction_t* ou ...@@ -1141,6 +1169,21 @@ extern "C" migraphx_status migraphx_module_add_return(migraphx_instruction_t* ou
return api_error_result; return api_error_result;
} }
// C API entry point: appends an "allocate" instruction for shape `s` to
// `module`, returning the new instruction handle through `out`.
// Null `module`/`s` are rejected with migraphx_status_bad_param.
extern "C" migraphx_status migraphx_module_add_allocation(migraphx_instruction_t* out,
migraphx_module_t module,
const_migraphx_shape_t s)
{
// try_ converts any thrown exception into a migraphx_status error code.
auto api_error_result = migraphx::try_([&] {
if(module == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer");
if(s == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter s: Null pointer");
*out = allocate<migraphx_instruction_t>(
migraphx::add_allocation((module->object), (s->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_destroy(migraphx_program_t program) extern "C" migraphx_status migraphx_program_destroy(migraphx_program_t program)
{ {
auto api_error_result = migraphx::try_([&] { destroy((program)); }); auto api_error_result = migraphx::try_([&] { destroy((program)); });
...@@ -1772,6 +1815,14 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi ...@@ -1772,6 +1815,14 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi
return api_error_result; return api_error_result;
} }
// C API entry point: stores the user-supplied compute callback on the
// experimental custom op object. Errors thrown inside try_ are converted
// into the returned migraphx_status.
extern "C" migraphx_status
migraphx_experimental_custom_op_set_compute(migraphx_experimental_custom_op_t obj,
                                            migraphx_experimental_custom_op_compute input)
{
    return migraphx::try_([&] { obj->compute_f = input; });
}
extern "C" migraphx_status migraphx_experimental_custom_op_set_compute_shape( extern "C" migraphx_status migraphx_experimental_custom_op_set_compute_shape(
migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input) migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input)
{ {
......
...@@ -129,6 +129,12 @@ typedef const struct migraphx_context* const_migraphx_context_t; ...@@ -129,6 +129,12 @@ typedef const struct migraphx_context* const_migraphx_context_t;
typedef struct migraphx_experimental_custom_op* migraphx_experimental_custom_op_t; typedef struct migraphx_experimental_custom_op* migraphx_experimental_custom_op_t;
typedef const struct migraphx_experimental_custom_op* const_migraphx_experimental_custom_op_t; typedef const struct migraphx_experimental_custom_op* const_migraphx_experimental_custom_op_t;
typedef migraphx_status (*migraphx_experimental_custom_op_compute)(migraphx_argument_t out,
void* obj,
migraphx_context_t ctx,
migraphx_shape_t output,
migraphx_arguments_t inputs);
typedef migraphx_status (*migraphx_experimental_custom_op_compute_shape)(migraphx_shape_t out, typedef migraphx_status (*migraphx_experimental_custom_op_compute_shape)(migraphx_shape_t out,
void* obj, void* obj,
migraphx_shapes_t inputs); migraphx_shapes_t inputs);
...@@ -295,6 +301,10 @@ migraphx_status migraphx_module_add_return(migraphx_instruction_t* out, ...@@ -295,6 +301,10 @@ migraphx_status migraphx_module_add_return(migraphx_instruction_t* out,
migraphx_module_t module, migraphx_module_t module,
migraphx_instructions_t args); migraphx_instructions_t args);
migraphx_status migraphx_module_add_allocation(migraphx_instruction_t* out,
migraphx_module_t module,
const_migraphx_shape_t s);
migraphx_status migraphx_program_destroy(migraphx_program_t program); migraphx_status migraphx_program_destroy(migraphx_program_t program);
migraphx_status migraphx_program_assign_to(migraphx_program_t output, migraphx_status migraphx_program_assign_to(migraphx_program_t output,
...@@ -477,6 +487,10 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi ...@@ -477,6 +487,10 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi
migraphx_experimental_custom_op_delete d, migraphx_experimental_custom_op_delete d,
const char* name); const char* name);
migraphx_status
migraphx_experimental_custom_op_set_compute(migraphx_experimental_custom_op_t obj,
migraphx_experimental_custom_op_compute input);
migraphx_status migraphx_experimental_custom_op_set_compute_shape( migraphx_status migraphx_experimental_custom_op_set_compute_shape(
migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input); migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input);
......
...@@ -401,11 +401,14 @@ struct interface_base : Base ...@@ -401,11 +401,14 @@ struct interface_base : Base
return x; return x;
} }
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
template <class T> template <class T>
auto auto_convert_param(rank<1>, T x) -> decltype(as_handle<T>{x}) auto auto_convert_param(rank<1>, T x) -> decltype(as_handle<T>{x})
{ {
return as_handle<T>{x}; return as_handle<T>{x};
} }
#pragma GCC diagnostic pop
template <class T> template <class T>
auto auto_convert_param(rank<2>, T x) -> decltype(as_handle<T>{x, borrow{}}) auto auto_convert_param(rank<2>, T x) -> decltype(as_handle<T>{x, borrow{}})
...@@ -565,6 +568,14 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument) ...@@ -565,6 +568,14 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
return pout; return pout;
} }
// Copy this argument's raw data buffer into a std::vector<T>.
// The element count is derived from the shape's byte size, so the buffer is
// assumed to hold a whole number of T elements — NOTE(review): callers must
// pass a T matching the argument's element type; verify at call sites.
template <typename T>
std::vector<T> as_vector() const
{
    auto* first = reinterpret_cast<T*>(this->data());
    auto* last  = first + this->get_shape().bytes() / sizeof(T);
    return std::vector<T>(first, last);
}
/// Generate an argument using random data /// Generate an argument using random data
static argument generate(shape ps, size_t pseed = 0) static argument generate(shape ps, size_t pseed = 0)
{ {
...@@ -802,13 +813,20 @@ struct module ...@@ -802,13 +813,20 @@ struct module
return instruction(ret_ins, own{}); return instruction(ret_ins, own{});
} }
// Append an "allocate" instruction for shape `s` to this module and wrap the
// resulting C handle in an owning instruction object.
instruction add_allocation(const migraphx::shape& s)
{
    migraphx_instruction_t alloc_ins;
    call(&migraphx_module_add_allocation, &alloc_ins, mm.get(), s.get_handle_ptr());
    return instruction(alloc_ins, own{});
}
migraphx_module_t get_handle_ptr() const { return mm.get(); } migraphx_module_t get_handle_ptr() const { return mm.get(); }
private: private:
std::shared_ptr<migraphx_module> mm; std::shared_ptr<migraphx_module> mm;
}; };
struct context struct context : handle_lookup<context, migraphx_context>
{ {
context(migraphx_context* p, borrow) : ctx(std::shared_ptr<migraphx_context*>(), p) {} context(migraphx_context* p, borrow) : ctx(std::shared_ptr<migraphx_context*>(), p) {}
...@@ -1178,6 +1196,7 @@ quantize_int8(const program& prog, const target& ptarget, const quantize_int8_op ...@@ -1178,6 +1196,7 @@ quantize_int8(const program& prog, const target& ptarget, const quantize_int8_op
struct experimental_custom_op_base struct experimental_custom_op_base
{ {
virtual std::string name() const = 0; virtual std::string name() const = 0;
virtual argument compute(context ctx, shape output, arguments inputs) const = 0;
virtual shape compute_shape(shapes inputs) const = 0; virtual shape compute_shape(shapes inputs) const = 0;
virtual ~experimental_custom_op_base() = default; virtual ~experimental_custom_op_base() = default;
}; };
...@@ -1189,6 +1208,7 @@ struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental ...@@ -1189,6 +1208,7 @@ struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental
{ {
this->make_interface(&migraphx_experimental_custom_op_create, obj, obj.name().c_str()); this->make_interface(&migraphx_experimental_custom_op_create, obj, obj.name().c_str());
MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute_shape); MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute_shape);
MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute);
} }
void register_op() { call(&migraphx_experimental_custom_op_register, this->get_handle_ptr()); } void register_op() { call(&migraphx_experimental_custom_op_register, this->get_handle_ptr()); }
......
...@@ -244,6 +244,10 @@ def module(h): ...@@ -244,6 +244,10 @@ def module(h):
h.method('add_return', h.method('add_return',
api.params(args='std::vector<migraphx::instruction_ref>'), api.params(args='std::vector<migraphx::instruction_ref>'),
returns='migraphx::instruction_ref') returns='migraphx::instruction_ref')
h.method('add_allocation',
api.params(s='const migraphx::shape&'),
invoke='migraphx::add_allocation($@)',
returns='migraphx::instruction_ref')
@auto_handle() @auto_handle()
...@@ -436,6 +440,11 @@ def context(h): ...@@ -436,6 +440,11 @@ def context(h):
'migraphx::experimental_custom_op') 'migraphx::experimental_custom_op')
def experimental_custom_op(h): def experimental_custom_op(h):
h.constructor('create', api.params(name='const char*')) h.constructor('create', api.params(name='const char*'))
h.virtual('compute',
api.params(ctx='migraphx::context',
output='migraphx::shape',
inputs='std::vector<migraphx::argument>'),
returns='migraphx::argument')
h.virtual('compute_shape', h.virtual('compute_shape',
api.params(inputs='std::vector<migraphx::shape>'), api.params(inputs='std::vector<migraphx::shape>'),
returns='migraphx::shape') returns='migraphx::shape')
......
...@@ -34,16 +34,18 @@ endfunction() ...@@ -34,16 +34,18 @@ endfunction()
add_api_test(array_base test_array_base.cpp ${TEST_ONNX_DIR}) add_api_test(array_base test_array_base.cpp ${TEST_ONNX_DIR})
add_api_test(assign test_assign.cpp ${TEST_ONNX_DIR}) add_api_test(assign test_assign.cpp ${TEST_ONNX_DIR})
add_api_test(custom_op test_custom_op.cpp ${TEST_ONNX_DIR})
add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR}) add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR})
add_api_test(lookup test_lookup.cpp ${TEST_ONNX_DIR}) add_api_test(lookup test_lookup.cpp ${TEST_ONNX_DIR})
add_api_test(module_construct test_module_construct.cpp ${TEST_ONNX_DIR}) add_api_test(module_construct test_module_construct.cpp ${TEST_ONNX_DIR})
add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR}) add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR})
add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR}) add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR})
add_api_test(op test_op_construct.cpp ${TEST_ONNX_DIR}) add_api_test(op test_op_construct.cpp ${TEST_ONNX_DIR})
add_api_test(custom_op test_custom_op.cpp ${TEST_ONNX_DIR})
add_api_test(tf_parser test_tf_parser.cpp ${TEST_TF_DIR}) add_api_test(tf_parser test_tf_parser.cpp ${TEST_TF_DIR})
# GPU-based tests # GPU-based tests
if(MIGRAPHX_ENABLE_GPU) if(MIGRAPHX_ENABLE_GPU)
add_api_test(gpu test_gpu.cpp ${TEST_ONNX_DIR}) add_api_test(gpu test_gpu.cpp ${TEST_ONNX_DIR})
target_link_libraries(test_api_gpu migraphx_gpu) target_link_libraries(test_api_gpu migraphx_gpu)
add_api_test(custom_op_gpu test_custom_op_gpu.cpp ${TEST_ONNX_DIR})
target_link_libraries(test_api_custom_op_gpu migraphx_gpu)
endif() endif()
...@@ -21,26 +21,66 @@ ...@@ -21,26 +21,66 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <algorithm>
#include <cmath>
#include <migraphx/migraphx.h> #include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp> #include <migraphx/migraphx.hpp>
#include "test.hpp" #include "test.hpp"
struct simple_custom_op final : migraphx::experimental_custom_op_base struct sigmoid_custom_op final : migraphx::experimental_custom_op_base
{ {
virtual std::string name() const override { return "simple_custom_op"; } virtual std::string name() const override { return "sigmoid_custom_op"; }
// Reference (CPU) sigmoid: applies 1/(1+exp(-x)) elementwise to input 0,
// writing into the pre-allocated buffer passed as input 1, which is returned
// as the op's output.
virtual migraphx::argument
compute(migraphx::context, migraphx::shape, migraphx::arguments inputs) const override
{
auto* output_ptr = reinterpret_cast<float*>(inputs[1].data());
auto input_vec = inputs[0].as_vector<float>();
std::transform(input_vec.begin(), input_vec.end(), output_ptr, [](auto x) {
return 1.f / (1.f + std::exp(-x));
});
return inputs[1];
}
virtual migraphx::shape compute_shape(migraphx::shapes inputs) const override virtual migraphx::shape compute_shape(migraphx::shapes inputs) const override
{ {
return inputs.front(); CHECK(inputs.size() == 2);
CHECK(inputs[0].lengths().size() == 1);
CHECK(inputs[0].type() == migraphx_shape_float_type);
CHECK(bool{inputs[0] == inputs[1]});
return inputs.back();
} }
}; };
TEST_CASE(register_custom_op) TEST_CASE(register_custom_op)
{ {
simple_custom_op simple_op; sigmoid_custom_op sigmoid_op;
migraphx::register_experimental_custom_op(simple_op); migraphx::register_experimental_custom_op(sigmoid_op);
auto op = migraphx::operation("sigmoid_custom_op");
EXPECT(op.name() == "sigmoid_custom_op");
}
auto op = migraphx::operation("simple_custom_op"); TEST_CASE(run_sigmoid_custom_op)
EXPECT(op.name() == "simple_custom_op"); {
migraphx::program p;
migraphx::shape s{migraphx_shape_float_type, {12}};
migraphx::module m = p.get_main_module();
auto x = m.add_parameter("x", s);
auto alloc = m.add_allocation(s);
auto custom_kernel = m.add_instruction(migraphx::operation("sigmoid_custom_op"), {x, alloc});
p.compile(migraphx::target("ref"));
// run program
migraphx::program_parameters pp;
auto param_shapes = p.get_parameter_shapes();
migraphx::argument input_arg = migraphx::argument::generate(param_shapes["x"]);
pp.add("x", input_arg);
auto results = p.eval(pp);
auto result = results[0];
auto expected_result = input_arg.as_vector<float>();
std::transform(expected_result.begin(),
expected_result.end(),
expected_result.begin(),
[](auto y) { return 1.f / (1.f + std::exp(-y)); });
EXPECT(bool{result == migraphx::argument(s, expected_result.data())});
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <hip/hip_runtime_api.h>
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp>
#include "test.hpp"
#define MIGRAPHX_HIP_ASSERT(x) (EXPECT(x == hipSuccess))
// GPU-test custom op: zeroes the first half (byte-wise) of input 0 and copies
// the rest unchanged into the output allocation (input 1), staging through
// pinned host memory with explicit HIP synchronization at each step.
struct simple_custom_op final : migraphx::experimental_custom_op_base
{
virtual std::string name() const override { return "simple_custom_op"; }
virtual migraphx::argument
compute(migraphx::context ctx, migraphx::shape, migraphx::arguments inputs) const override
{
// Zero the first half of input 0's bytes; the remaining bytes are copied as-is.
int* h_output = nullptr;
auto* d_output = reinterpret_cast<int*>(inputs[0].data());
auto input_bytes = inputs[0].get_shape().bytes();
auto* output_ptr = inputs[1].data();
auto copy_bytes = input_bytes / 2;
MIGRAPHX_HIP_ASSERT(hipSetDevice(0));
// Pinned host staging buffer for the device->host->device round trip.
MIGRAPHX_HIP_ASSERT(hipHostMalloc(&h_output, input_bytes));
// Async copy on MIGraphX's stream, then synchronize before touching the data.
MIGRAPHX_HIP_ASSERT(hipMemcpyAsync(
h_output, d_output, input_bytes, hipMemcpyDeviceToHost, ctx.get_queue<hipStream_t>()));
MIGRAPHX_HIP_ASSERT(hipDeviceSynchronize());
MIGRAPHX_HIP_ASSERT(hipMemset(h_output, 0, copy_bytes));
MIGRAPHX_HIP_ASSERT(hipDeviceSynchronize());
// Blocking copy of the modified buffer into the output allocation.
MIGRAPHX_HIP_ASSERT(hipMemcpy(output_ptr, h_output, input_bytes, hipMemcpyHostToDevice));
MIGRAPHX_HIP_ASSERT(hipDeviceSynchronize());
MIGRAPHX_HIP_ASSERT(hipHostFree(h_output));
return inputs[1];
}
// Output shape is the last input's shape (the pre-allocated output buffer).
virtual migraphx::shape compute_shape(migraphx::shapes inputs) const override
{
return inputs.back();
}
};
// End-to-end GPU test: builds x -> neg -> simple_custom_op -> relu, compiles
// for the "gpu" target with offload copy, and checks the custom op's
// half-zero/half-copy behavior against a hand-computed result.
TEST_CASE(run_simple_custom_op)
{
simple_custom_op simple_op;
migraphx::register_experimental_custom_op(simple_op);
migraphx::program p;
migraphx::shape s{migraphx_shape_int32_type, {4, 3}};
migraphx::module m = p.get_main_module();
auto x = m.add_parameter("x", s);
auto neg = m.add_instruction(migraphx::operation("neg"), x);
// Output buffer the custom op writes into (passed as its second input).
auto alloc = m.add_allocation(s);
auto custom_kernel = m.add_instruction(migraphx::operation("simple_custom_op"), {neg, alloc});
auto relu = m.add_instruction(migraphx::operation("relu"), custom_kernel);
m.add_return({relu});
migraphx::compile_options options;
// Host<->device copies are handled by the runtime so the test can feed host data.
options.set_offload_copy();
p.compile(migraphx::target("gpu"), options);
migraphx::program_parameters pp;
std::vector<int> x_data(12, -3);
pp.add("x", migraphx::argument(s, x_data.data()));
auto results = p.eval(pp);
auto result = results[0];
// NOTE(review): result_vec is unused by the assertion below; it appears to
// exercise the as_vector API only.
auto result_vec = result.as_vector<int>();
// neg(-3)=3; custom op zeroes the first half of the buffer, relu keeps the rest.
std::vector<int> expected_result(12, 0);
std::fill(expected_result.begin() + 6, expected_result.end(), 3);
EXPECT(bool{result == migraphx::argument(s, expected_result.data())});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -92,13 +92,7 @@ TEST_CASE(if_pl_test) ...@@ -92,13 +92,7 @@ TEST_CASE(if_pl_test)
auto outputs = p.eval(pp); auto outputs = p.eval(pp);
auto output = outputs[0]; auto output = outputs[0];
auto lens = output.get_shape().lengths(); return output.as_vector<float>();
auto elem_num =
std::accumulate(lens.begin(), lens.end(), 1, std::multiplies<std::size_t>());
float* data_ptr = reinterpret_cast<float*>(output.data());
std::vector<float> ret(data_ptr, data_ptr + elem_num);
return ret;
}; };
// then branch // then branch
...@@ -141,18 +135,11 @@ TEST_CASE(loop_test) ...@@ -141,18 +135,11 @@ TEST_CASE(loop_test)
auto outputs = p.eval(pp); auto outputs = p.eval(pp);
auto output = outputs[0]; auto output = outputs[0];
auto lens = output.get_shape().lengths();
auto elem_num =
std::accumulate(lens.begin(), lens.end(), 1, std::multiplies<std::size_t>());
float* data_ptr = reinterpret_cast<float*>(output.data());
std::vector<std::vector<float>> ret; std::vector<std::vector<float>> ret;
ret.push_back({data_ptr, data_ptr + elem_num}); ret.push_back(output.as_vector<float>());
output = outputs[1]; output = outputs[1];
lens = output.get_shape().lengths(); ret.push_back(output.as_vector<float>());
elem_num = std::accumulate(lens.begin(), lens.end(), 1, std::multiplies<std::size_t>());
data_ptr = reinterpret_cast<float*>(output.data());
ret.push_back({data_ptr, data_ptr + elem_num});
return ret; return ret;
}; };
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/migraphx.h> #include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp> #include <migraphx/migraphx.hpp>
#include "test.hpp" #include "test.hpp"
......
...@@ -236,6 +236,11 @@ void print_program(const program& p) { std::cout << p << std::endl; } ...@@ -236,6 +236,11 @@ void print_program(const program& p) { std::cout << p << std::endl; }
void print_module(const module& m) { std::cout << m << std::endl; } void print_module(const module& m) { std::cout << m << std::endl; }
// Helper for the C API: appends an "allocate" instruction carrying shape `s`
// to module `m` and returns a reference to the new instruction.
migraphx::instruction_ref add_allocation(module& m, const migraphx::shape& s)
{
return m.add_instruction(migraphx::make_op("allocate", {{"shape", migraphx::to_value(s)}}), {});
}
struct experimental_custom_op struct experimental_custom_op
{ {
std::string name; std::string name;
...@@ -260,7 +265,12 @@ struct custom_operation ...@@ -260,7 +265,12 @@ struct custom_operation
return op.compute_shape(std::move(inputs)); return op.compute_shape(std::move(inputs));
} }
argument compute(const std::vector<argument>&) const { MIGRAPHX_THROW("Not computable"); } // TODO: Compute method with module_args
// Forward compute to the user-provided custom op implementation; the context,
// output shape, and input arguments are moved through untouched.
argument
compute(migraphx::context ctx, migraphx::shape output_shape, std::vector<argument> inputs) const
{
return op.compute(std::move(ctx), std::move(output_shape), std::move(inputs));
}
}; };
template <class CustomOp> template <class CustomOp>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment