Unverified Commit 0ff00ef6 authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

Update C/C++ API for dynamic batch (#1712)

Relies on Removed split_single_dyn_dim compile flag #1711
Exposes dynamic_dimension as a opaque object with dynamic_dimensions and optimals
Exposes ONNX dyn_input_dims and default_dyn_dim to run with dynamic batch
Updates api.py to be able to create objects from aggregate initialization (used for dynamic_dimension)
Uses offload copy for now
parent a8ace295
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include <migraphx/execution_environment.hpp> #include <migraphx/execution_environment.hpp>
#include <migraphx/migraphx.h> #include <migraphx/migraphx.h>
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
...@@ -145,6 +146,11 @@ void set_default_dim_value(onnx_options& options, size_t value) ...@@ -145,6 +146,11 @@ void set_default_dim_value(onnx_options& options, size_t value)
options.default_dim_value = value; options.default_dim_value = value;
} }
void set_default_dyn_dim_value(onnx_options& options, const shape::dynamic_dimension& dd)
{
options.default_dyn_dim_value = dd;
}
void set_default_loop_iterations(onnx_options& options, int64_t value) void set_default_loop_iterations(onnx_options& options, int64_t value)
{ {
options.max_loop_iterations = value; options.max_loop_iterations = value;
...@@ -161,6 +167,13 @@ void set_input_parameter_shape(onnx_options& options, ...@@ -161,6 +167,13 @@ void set_input_parameter_shape(onnx_options& options,
options.map_input_dims[std::string(name)] = std::move(dims); options.map_input_dims[std::string(name)] = std::move(dims);
} }
void set_dyn_input_parameter_shape(onnx_options& options,
const char* name,
std::vector<shape::dynamic_dimension> dyn_dims)
{
options.map_dyn_input_dims[std::string(name)] = std::move(dyn_dims);
}
void set_input_parameter_shape(tf_options& options, const char* name, std::vector<std::size_t> dims) void set_input_parameter_shape(tf_options& options, const char* name, std::vector<std::size_t> dims)
{ {
options.map_input_dims[std::string(name)] = std::move(dims); options.map_input_dims[std::string(name)] = std::move(dims);
...@@ -187,6 +200,12 @@ std::vector<const char*> get_names(const std::unordered_map<std::string, Value>& ...@@ -187,6 +200,12 @@ std::vector<const char*> get_names(const std::unordered_map<std::string, Value>&
return result; return result;
} }
template <class T>
std::set<T> make_set(const T* x, std::size_t n)
{
return {x, x + n};
}
void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names) void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names)
{ {
if(names.empty()) if(names.empty())
...@@ -346,6 +365,9 @@ const Target* object_cast(const U* x) ...@@ -346,6 +365,9 @@ const Target* object_cast(const U* x)
template <class T, class... Ts, class Target = std::remove_pointer_t<T>> template <class T, class... Ts, class Target = std::remove_pointer_t<T>>
Target* allocate(Ts&&... xs) Target* allocate(Ts&&... xs)
{ {
if constexpr(std::is_aggregate<Target>{})
return new Target{std::forward<Ts>(xs)...}; // NOLINT
else
return new Target(std::forward<Ts>(xs)...); // NOLINT return new Target(std::forward<Ts>(xs)...); // NOLINT
} }
...@@ -409,6 +431,39 @@ struct manage_generic_ptr ...@@ -409,6 +431,39 @@ struct manage_generic_ptr
D deleter = nullptr; D deleter = nullptr;
}; };
extern "C" struct migraphx_optimals;
struct migraphx_optimals
{
template <class... Ts>
migraphx_optimals(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
std::set<size_t> object;
};
extern "C" struct migraphx_dynamic_dimension;
struct migraphx_dynamic_dimension
{
template <class... Ts>
migraphx_dynamic_dimension(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
migraphx::shape::dynamic_dimension object;
};
extern "C" struct migraphx_dynamic_dimensions;
struct migraphx_dynamic_dimensions
{
template <class... Ts>
migraphx_dynamic_dimensions(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
std::vector<migraphx::shape::dynamic_dimension> object;
};
extern "C" struct migraphx_shape; extern "C" struct migraphx_shape;
struct migraphx_shape struct migraphx_shape
{ {
...@@ -736,6 +791,152 @@ struct migraphx_experimental_custom_op ...@@ -736,6 +791,152 @@ struct migraphx_experimental_custom_op
} }
}; };
extern "C" migraphx_status migraphx_optimals_destroy(migraphx_optimals_t optimals)
{
auto api_error_result = migraphx::try_([&] { destroy((optimals)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_optimals_assign_to(migraphx_optimals_t output,
const_migraphx_optimals_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status
migraphx_optimals_create(migraphx_optimals_t* optimals, const size_t* ptr, size_t size)
{
auto api_error_result = migraphx::try_([&] {
*optimals = object_cast<migraphx_optimals_t>(
allocate<std::set<size_t>>(migraphx::make_set<size_t>((ptr), (size))));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_dynamic_dimension_destroy(migraphx_dynamic_dimension_t dynamic_dimension)
{
auto api_error_result = migraphx::try_([&] { destroy((dynamic_dimension)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_dynamic_dimension_assign_to(migraphx_dynamic_dimension_t output,
const_migraphx_dynamic_dimension_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_dynamic_dimension_create_min_max(
migraphx_dynamic_dimension_t* dynamic_dimension, size_t min, size_t max)
{
auto api_error_result = migraphx::try_([&] {
*dynamic_dimension = object_cast<migraphx_dynamic_dimension_t>(
allocate<migraphx::shape::dynamic_dimension>((min), (max)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_dynamic_dimension_create_min_max_optimals(migraphx_dynamic_dimension_t* dynamic_dimension,
size_t min,
size_t max,
migraphx_optimals_t optimals)
{
auto api_error_result = migraphx::try_([&] {
if(optimals == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter optimals: Null pointer");
*dynamic_dimension = object_cast<migraphx_dynamic_dimension_t>(
allocate<migraphx::shape::dynamic_dimension>((min), (max), (optimals->object)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_dynamic_dimension_is_fixed(bool* out, const_migraphx_dynamic_dimension_t dynamic_dimension)
{
auto api_error_result = migraphx::try_([&] {
if(dynamic_dimension == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter dynamic_dimension: Null pointer");
*out = (dynamic_dimension->object).is_fixed();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_dynamic_dimension_equal(bool* out,
const_migraphx_dynamic_dimension_t dynamic_dimension,
const_migraphx_dynamic_dimension_t x)
{
auto api_error_result = migraphx::try_([&] {
if(dynamic_dimension == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter dynamic_dimension: Null pointer");
if(x == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter x: Null pointer");
*out = migraphx::equal((dynamic_dimension->object), (x->object));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_dynamic_dimensions_destroy(migraphx_dynamic_dimensions_t dynamic_dimensions)
{
auto api_error_result = migraphx::try_([&] { destroy((dynamic_dimensions)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_dynamic_dimensions_assign_to(migraphx_dynamic_dimensions_t output,
const_migraphx_dynamic_dimensions_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status
migraphx_dynamic_dimensions_create(migraphx_dynamic_dimensions_t* dynamic_dimensions,
const_migraphx_dynamic_dimension_t* ptr,
size_t size)
{
auto api_error_result = migraphx::try_([&] {
*dynamic_dimensions = object_cast<migraphx_dynamic_dimensions_t>(
allocate<std::vector<migraphx::shape::dynamic_dimension>>(
migraphx::to_obj_vector<const_migraphx_dynamic_dimension_t>((ptr), (size))));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_dynamic_dimensions_size(size_t* out, migraphx_dynamic_dimensions_t dynamic_dimensions)
{
auto api_error_result = migraphx::try_([&] {
if(dynamic_dimensions == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter dynamic_dimensions: Null pointer");
*out = (dynamic_dimensions->object).size();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_dynamic_dimensions_get(const_migraphx_dynamic_dimension_t* out,
migraphx_dynamic_dimensions_t dynamic_dimensions,
size_t idx)
{
auto api_error_result = migraphx::try_([&] {
if(dynamic_dimensions == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter dynamic_dimensions: Null pointer");
*out = object_cast<const_migraphx_dynamic_dimension_t>(
&((dynamic_dimensions->object).at((idx))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape) extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape)
{ {
auto api_error_result = migraphx::try_([&] { destroy((shape)); }); auto api_error_result = migraphx::try_([&] { destroy((shape)); });
...@@ -794,6 +995,19 @@ extern "C" migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape, ...@@ -794,6 +995,19 @@ extern "C" migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape,
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_shape_create_dynamic(migraphx_shape_t* shape,
migraphx_shape_datatype_t type,
migraphx_dynamic_dimensions_t dims)
{
auto api_error_result = migraphx::try_([&] {
if(dims == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter dims: Null pointer");
*shape = object_cast<migraphx_shape_t>(
allocate<migraphx::shape>((migraphx::to_shape_type(type)), (dims->object)));
});
return api_error_result;
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shape_t shape) migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shape_t shape)
{ {
...@@ -824,6 +1038,17 @@ migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shap ...@@ -824,6 +1038,17 @@ migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shap
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_shape_dyn_dims(migraphx_dynamic_dimensions_t* out,
const_migraphx_shape_t shape)
{
auto api_error_result = migraphx::try_([&] {
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = allocate<migraphx_dynamic_dimensions_t>((shape->object).dyn_dims());
});
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out, extern "C" migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out,
const_migraphx_shape_t shape) const_migraphx_shape_t shape)
{ {
...@@ -857,6 +1082,16 @@ extern "C" migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shap ...@@ -857,6 +1082,16 @@ extern "C" migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shap
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_shape_ndim(size_t* out, const_migraphx_shape_t shape)
{
auto api_error_result = migraphx::try_([&] {
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = (shape->object).ndim();
});
return api_error_result;
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_shape_t x) migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_shape_t x)
{ {
...@@ -880,6 +1115,16 @@ extern "C" migraphx_status migraphx_shape_standard(bool* out, const_migraphx_sha ...@@ -880,6 +1115,16 @@ extern "C" migraphx_status migraphx_shape_standard(bool* out, const_migraphx_sha
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_shape_dynamic(bool* out, const_migraphx_shape_t shape)
{
auto api_error_result = migraphx::try_([&] {
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = (shape->object).dynamic();
});
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_index(size_t* out, const_migraphx_shape_t shape, size_t i) extern "C" migraphx_status migraphx_shape_index(size_t* out, const_migraphx_shape_t shape, size_t i)
{ {
auto api_error_result = migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
...@@ -915,6 +1160,17 @@ migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t s ...@@ -915,6 +1160,17 @@ migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t s
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_argument_create_empty(migraphx_argument_t* argument,
const_migraphx_shape_t shape)
{
auto api_error_result = migraphx::try_([&] {
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*argument = object_cast<migraphx_argument_t>(allocate<migraphx::argument>((shape->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_argument_shape(const_migraphx_shape_t* out, extern "C" migraphx_status migraphx_argument_shape(const_migraphx_shape_t* out,
const_migraphx_argument_t argument) const_migraphx_argument_t argument)
{ {
...@@ -1590,6 +1846,19 @@ extern "C" migraphx_status migraphx_onnx_options_set_input_parameter_shape( ...@@ -1590,6 +1846,19 @@ extern "C" migraphx_status migraphx_onnx_options_set_input_parameter_shape(
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_onnx_options_set_dyn_input_parameter_shape(
migraphx_onnx_options_t onnx_options, const char* name, migraphx_dynamic_dimensions_t dims)
{
auto api_error_result = migraphx::try_([&] {
if(onnx_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
if(dims == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter dims: Null pointer");
migraphx::set_dyn_input_parameter_shape((onnx_options->object), (name), (dims->object));
});
return api_error_result;
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options, size_t value) migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options, size_t value)
{ {
...@@ -1601,6 +1870,20 @@ migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options ...@@ -1601,6 +1870,20 @@ migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status
migraphx_onnx_options_set_default_dyn_dim_value(migraphx_onnx_options_t onnx_options,
const_migraphx_dynamic_dimension_t dd)
{
auto api_error_result = migraphx::try_([&] {
if(onnx_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
if(dd == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter dd: Null pointer");
migraphx::set_default_dyn_dim_value((onnx_options->object), (dd->object));
});
return api_error_result;
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_options, migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_options,
int64_t value) int64_t value)
......
...@@ -66,6 +66,15 @@ typedef enum ...@@ -66,6 +66,15 @@ typedef enum
} migraphx_shape_datatype_t; } migraphx_shape_datatype_t;
#undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES #undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES
typedef struct migraphx_optimals* migraphx_optimals_t;
typedef const struct migraphx_optimals* const_migraphx_optimals_t;
typedef struct migraphx_dynamic_dimension* migraphx_dynamic_dimension_t;
typedef const struct migraphx_dynamic_dimension* const_migraphx_dynamic_dimension_t;
typedef struct migraphx_dynamic_dimensions* migraphx_dynamic_dimensions_t;
typedef const struct migraphx_dynamic_dimensions* const_migraphx_dynamic_dimensions_t;
typedef struct migraphx_shape* migraphx_shape_t; typedef struct migraphx_shape* migraphx_shape_t;
typedef const struct migraphx_shape* const_migraphx_shape_t; typedef const struct migraphx_shape* const_migraphx_shape_t;
...@@ -157,6 +166,55 @@ typedef migraphx_status (*migraphx_experimental_custom_op_copy)(void** out, void ...@@ -157,6 +166,55 @@ typedef migraphx_status (*migraphx_experimental_custom_op_copy)(void** out, void
typedef migraphx_status (*migraphx_experimental_custom_op_delete)(void* input); typedef migraphx_status (*migraphx_experimental_custom_op_delete)(void* input);
migraphx_status migraphx_optimals_destroy(migraphx_optimals_t optimals);
migraphx_status migraphx_optimals_assign_to(migraphx_optimals_t output,
const_migraphx_optimals_t input);
migraphx_status
migraphx_optimals_create(migraphx_optimals_t* optimals, const size_t* ptr, size_t size);
migraphx_status migraphx_dynamic_dimension_destroy(migraphx_dynamic_dimension_t dynamic_dimension);
migraphx_status migraphx_dynamic_dimension_assign_to(migraphx_dynamic_dimension_t output,
const_migraphx_dynamic_dimension_t input);
migraphx_status migraphx_dynamic_dimension_create_min_max(
migraphx_dynamic_dimension_t* dynamic_dimension, size_t min, size_t max);
migraphx_status
migraphx_dynamic_dimension_create_min_max_optimals(migraphx_dynamic_dimension_t* dynamic_dimension,
size_t min,
size_t max,
migraphx_optimals_t optimals);
migraphx_status
migraphx_dynamic_dimension_is_fixed(bool* out,
const_migraphx_dynamic_dimension_t dynamic_dimension);
migraphx_status
migraphx_dynamic_dimension_equal(bool* out,
const_migraphx_dynamic_dimension_t dynamic_dimension,
const_migraphx_dynamic_dimension_t x);
migraphx_status
migraphx_dynamic_dimensions_destroy(migraphx_dynamic_dimensions_t dynamic_dimensions);
migraphx_status migraphx_dynamic_dimensions_assign_to(migraphx_dynamic_dimensions_t output,
const_migraphx_dynamic_dimensions_t input);
migraphx_status
migraphx_dynamic_dimensions_create(migraphx_dynamic_dimensions_t* dynamic_dimensions,
const_migraphx_dynamic_dimension_t* ptr,
size_t size);
migraphx_status migraphx_dynamic_dimensions_size(size_t* out,
migraphx_dynamic_dimensions_t dynamic_dimensions);
migraphx_status migraphx_dynamic_dimensions_get(const_migraphx_dynamic_dimension_t* out,
migraphx_dynamic_dimensions_t dynamic_dimensions,
size_t idx);
migraphx_status migraphx_shape_destroy(migraphx_shape_t shape); migraphx_status migraphx_shape_destroy(migraphx_shape_t shape);
migraphx_status migraphx_shape_assign_to(migraphx_shape_t output, const_migraphx_shape_t input); migraphx_status migraphx_shape_assign_to(migraphx_shape_t output, const_migraphx_shape_t input);
...@@ -176,23 +234,34 @@ migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t* shape, ...@@ -176,23 +234,34 @@ migraphx_status migraphx_shape_create_with_strides(migraphx_shape_t* shape,
migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape, migraphx_status migraphx_shape_create_scalar(migraphx_shape_t* shape,
migraphx_shape_datatype_t type); migraphx_shape_datatype_t type);
migraphx_status migraphx_shape_create_dynamic(migraphx_shape_t* shape,
migraphx_shape_datatype_t type,
migraphx_dynamic_dimensions_t dims);
migraphx_status migraphx_status
migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shape_t shape); migraphx_shape_lengths(const size_t** out, size_t* out_size, const_migraphx_shape_t shape);
migraphx_status migraphx_status
migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shape_t shape); migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shape_t shape);
migraphx_status migraphx_shape_dyn_dims(migraphx_dynamic_dimensions_t* out,
const_migraphx_shape_t shape);
migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out, const_migraphx_shape_t shape); migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out, const_migraphx_shape_t shape);
migraphx_status migraphx_shape_elements(size_t* out, const_migraphx_shape_t shape); migraphx_status migraphx_shape_elements(size_t* out, const_migraphx_shape_t shape);
migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape); migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape);
migraphx_status migraphx_shape_ndim(size_t* out, const_migraphx_shape_t shape);
migraphx_status migraphx_status
migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_shape_t x); migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_shape_t x);
migraphx_status migraphx_shape_standard(bool* out, const_migraphx_shape_t shape); migraphx_status migraphx_shape_standard(bool* out, const_migraphx_shape_t shape);
migraphx_status migraphx_shape_dynamic(bool* out, const_migraphx_shape_t shape);
migraphx_status migraphx_shape_index(size_t* out, const_migraphx_shape_t shape, size_t i); migraphx_status migraphx_shape_index(size_t* out, const_migraphx_shape_t shape, size_t i);
migraphx_status migraphx_argument_destroy(migraphx_argument_t argument); migraphx_status migraphx_argument_destroy(migraphx_argument_t argument);
...@@ -203,6 +272,9 @@ migraphx_status migraphx_argument_assign_to(migraphx_argument_t output, ...@@ -203,6 +272,9 @@ migraphx_status migraphx_argument_assign_to(migraphx_argument_t output,
migraphx_status migraphx_status
migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t shape, void* buffer); migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t shape, void* buffer);
migraphx_status migraphx_argument_create_empty(migraphx_argument_t* argument,
const_migraphx_shape_t shape);
migraphx_status migraphx_argument_shape(const_migraphx_shape_t* out, migraphx_status migraphx_argument_shape(const_migraphx_shape_t* out,
const_migraphx_argument_t argument); const_migraphx_argument_t argument);
...@@ -397,9 +469,16 @@ migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_optio ...@@ -397,9 +469,16 @@ migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_optio
migraphx_status migraphx_onnx_options_set_input_parameter_shape( migraphx_status migraphx_onnx_options_set_input_parameter_shape(
migraphx_onnx_options_t onnx_options, const char* name, size_t* dims, size_t dims_size); migraphx_onnx_options_t onnx_options, const char* name, size_t* dims, size_t dims_size);
migraphx_status migraphx_onnx_options_set_dyn_input_parameter_shape(
migraphx_onnx_options_t onnx_options, const char* name, migraphx_dynamic_dimensions_t dims);
migraphx_status migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options, migraphx_status migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options,
size_t value); size_t value);
migraphx_status
migraphx_onnx_options_set_default_dyn_dim_value(migraphx_onnx_options_t onnx_options,
const_migraphx_dynamic_dimension_t dd);
migraphx_status migraphx_status
migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_options, migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_options,
int64_t value); int64_t value);
......
...@@ -571,10 +571,90 @@ using require_interface = ...@@ -571,10 +571,90 @@ using require_interface =
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_CONST_HANDLE_BASE(name) MIGRAPHX_DETAIL_HANDLE_BASE(name, const) #define MIGRAPHX_CONST_HANDLE_BASE(name) MIGRAPHX_DETAIL_HANDLE_BASE(name, const)
/**
* Container to hold optimal dynamic dimension values.
*/
struct optimals : MIGRAPHX_HANDLE_BASE(optimals)
{
MIGRAPHX_HANDLE_CONSTRUCTOR(optimals)
optimals(std::initializer_list<size_t> init_list)
{
this->make_handle(&migraphx_optimals_create, init_list.begin(), init_list.size());
}
};
/**
* @brief Dynamic dimension object.
* @details minimum, maximum, and optimal dimensions
*/
struct dynamic_dimension : MIGRAPHX_CONST_HANDLE_BASE(dynamic_dimension)
{
MIGRAPHX_HANDLE_CONSTRUCTOR(dynamic_dimension)
dynamic_dimension(size_t min, size_t max)
{
this->make_handle(&migraphx_dynamic_dimension_create_min_max, min, max);
}
dynamic_dimension(size_t min, size_t max, const optimals& opts)
{
this->make_handle(
&migraphx_dynamic_dimension_create_min_max_optimals, min, max, opts.get_handle_ptr());
}
bool is_fixed() const
{
bool result = false;
call(&migraphx_dynamic_dimension_is_fixed, &result, this->get_handle_ptr());
return result;
}
friend bool operator==(const dynamic_dimension& x, const dynamic_dimension& y)
{
bool pout;
call(&migraphx_dynamic_dimension_equal, &pout, x.get_handle_ptr(), y.get_handle_ptr());
return pout;
}
friend bool operator!=(const dynamic_dimension& x, const dynamic_dimension& y)
{
return not(x == y);
}
};
/**
* Container to hold dynamic_dimension objects.
*/
struct dynamic_dimensions : MIGRAPHX_HANDLE_BASE(dynamic_dimensions)
{
MIGRAPHX_HANDLE_CONSTRUCTOR(dynamic_dimensions)
template <class... Ts>
dynamic_dimensions(Ts... xs)
{
std::array<const_migraphx_dynamic_dimension_t, sizeof...(Ts)> a{xs.get_handle_ptr()...};
this->make_handle(&migraphx_dynamic_dimensions_create, a.data(), a.size());
}
size_t size() const
{
size_t pout;
call(&migraphx_dynamic_dimensions_size, &pout, this->get_handle_ptr());
return pout;
}
dynamic_dimension operator[](size_t pidx) const
{
const_migraphx_dynamic_dimension_t pout;
call(&migraphx_dynamic_dimensions_get, &pout, this->get_handle_ptr(), pidx);
return {pout, this->share_handle()};
}
};
/** /**
* @brief Describe shape of tensor * @brief Describe shape of tensor
* @details A shape consists of a data type, lengths of multi-dimension tensor, and strides * @details A shape consists of a data type, lengths of multi-dimension tensor, and strides
*
*/ */
struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
{ {
...@@ -598,6 +678,13 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -598,6 +678,13 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
this->make_handle(&migraphx_shape_create, type, plengths.data(), plengths.size()); this->make_handle(&migraphx_shape_create, type, plengths.data(), plengths.size());
} }
// Force all calls of the format `shape( type_t, { size_t compatibles } )` to map to
// shape(type_t, std::vector<std::size_t> l)
shape(migraphx_shape_datatype_t t, std::initializer_list<std::size_t> d)
: shape::shape(t, std::vector<std::size_t>{d.begin(), d.end()})
{
}
shape(migraphx_shape_datatype_t type, shape(migraphx_shape_datatype_t type,
std::vector<size_t> plengths, std::vector<size_t> plengths,
std::vector<size_t> pstrides) std::vector<size_t> pstrides)
...@@ -610,6 +697,11 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -610,6 +697,11 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
pstrides.size()); pstrides.size());
} }
shape(migraphx_shape_datatype_t type, const dynamic_dimensions& dyn_dims)
{
this->make_handle(&migraphx_shape_create_dynamic, type, dyn_dims.get_handle_ptr());
}
std::vector<size_t> lengths() const std::vector<size_t> lengths() const
{ {
const size_t* pout; const size_t* pout;
...@@ -626,6 +718,14 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -626,6 +718,14 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
return {pout, pout + pout_size}; return {pout, pout + pout_size};
} }
/// Get the dynamic dimensions of the shape
dynamic_dimensions dyn_dims() const
{
migraphx_dynamic_dimensions_t pout;
call(&migraphx_shape_dyn_dims, &pout, this->get_handle_ptr());
return {pout, own{}};
}
migraphx_shape_datatype_t type() const migraphx_shape_datatype_t type() const
{ {
migraphx_shape_datatype_t pout; migraphx_shape_datatype_t pout;
...@@ -654,6 +754,14 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -654,6 +754,14 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
return result; return result;
} }
/// Is the shape dynamic
bool dynamic() const
{
bool result = false;
call(&migraphx_shape_dynamic, &result, this->get_handle_ptr());
return result;
}
// map element index to space index // map element index to space index
size_t index(size_t i) const size_t index(size_t i) const
{ {
...@@ -687,6 +795,11 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument) ...@@ -687,6 +795,11 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.") MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
argument(const migraphx_argument* p) { this->set_handle(p, borrow{}); } argument(const migraphx_argument* p) { this->set_handle(p, borrow{}); }
argument(shape pshape)
{
this->make_handle(&migraphx_argument_create_empty, pshape.get_handle_ptr());
}
argument(shape pshape, void* pbuffer) argument(shape pshape, void* pbuffer)
{ {
this->make_handle(&migraphx_argument_create, pshape.get_handle_ptr(), pbuffer); this->make_handle(&migraphx_argument_create, pshape.get_handle_ptr(), pbuffer);
...@@ -1182,12 +1295,27 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options) ...@@ -1182,12 +1295,27 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
dim.size()); dim.size());
} }
void set_dyn_input_parameter_shape(const std::string& name, const dynamic_dimensions& dyn_dims)
{
call(&migraphx_onnx_options_set_dyn_input_parameter_shape,
this->get_handle_ptr(),
name.c_str(),
dyn_dims.get_handle_ptr());
}
/// When there is a dimension parameter, then use this default value /// When there is a dimension parameter, then use this default value
void set_default_dim_value(unsigned int value) void set_default_dim_value(unsigned int value)
{ {
call(&migraphx_onnx_options_set_default_dim_value, this->get_handle_ptr(), value); call(&migraphx_onnx_options_set_default_dim_value, this->get_handle_ptr(), value);
} }
void set_default_dyn_dim_value(const dynamic_dimension& dd)
{
call(&migraphx_onnx_options_set_default_dyn_dim_value,
this->get_handle_ptr(),
dd.get_handle_ptr());
}
/// Set default max iteration number for the loop operator /// Set default max iteration number for the loop operator
void set_default_loop_iterations(int64_t value) void set_default_loop_iterations(int64_t value)
{ {
......
...@@ -45,56 +45,48 @@ def shape_type_wrap(p): ...@@ -45,56 +45,48 @@ def shape_type_wrap(p):
p.read = 'migraphx::to_shape_type(${name})' p.read = 'migraphx::to_shape_type(${name})'
@api.cwrap('migraphx::compile_options') def auto_handle(*args, **kwargs):
def compile_options_type_wrap(p): def with_handle(f):
if p.returns: return api.handle('migraphx_' + f.__name__, 'migraphx::' + f.__name__,
p.add_param('migraphx_compile_options *') *args, **kwargs)(f)
p.bad_param('${name} == nullptr', 'Null pointer')
p.write = ['*${name} = migraphx::to_compile_options(${result})']
else:
p.add_param('migraphx_compile_options *')
p.read = '${name} == nullptr ? migraphx::compile_options{} : migraphx::to_compile_options(*${name})'
@api.cwrap('migraphx::file_options')
def file_options_type_wrap(p):
if p.returns:
p.add_param('migraphx_file_options *')
p.bad_param('${name} == nullptr', 'Null pointer')
p.write = ['*${name} = migraphx::to_file_options(${result})']
else:
p.add_param('migraphx_file_options *')
p.read = '${name} == nullptr ? migraphx::file_options{} : migraphx::to_file_options(*${name})'
return with_handle
@api.cwrap('migraphx::onnx_options')
def onnx_options_type_wrap(p):
if p.returns:
p.add_param('migraphx_onnx_options *')
p.bad_param('${name} == nullptr', 'Null pointer')
p.write = ['*${name} = migraphx::to_onnx_options(${result})']
else:
p.add_param('migraphx_onnx_options *')
p.read = '${name} == nullptr ? migraphx::onnx_options{} : migraphx::to_onnx_options(*${name})'
@api.handle('migraphx_optimals', 'std::set<size_t>')
def optimals(h):
h.constructor('create',
api.params(ptr='const size_t*', size='size_t'),
fname='migraphx::make_set<size_t>')
@api.cwrap('migraphx::tf_options')
def tf_options_type_wrap(p):
if p.returns:
p.add_param('migraphx_tf_options *')
p.bad_param('${name} == nullptr', 'Null pointer')
p.write = ['*${name} = migraphx::to_tf_options(${result})']
else:
p.add_param('migraphx_tf_options *')
p.read = '${name} == nullptr ? migraphx::tf_options{} : migraphx::to_tf_options(*${name})'
@api.handle('migraphx_dynamic_dimension', 'migraphx::shape::dynamic_dimension')
def dynamic_dimension(h):
h.constructor('create_min_max', api.params(min='size_t', max='size_t'))
h.constructor(
'create_min_max_optimals',
api.params(min='size_t', max='size_t', optimals='std::set<size_t>'))
h.method('is_fixed', returns='bool', const=True)
h.method('equal',
api.params(x='const migraphx::shape::dynamic_dimension&'),
invoke='migraphx::equal($@)',
returns='bool',
const=True)
def auto_handle(*args, **kwargs):
def with_handle(f):
return api.handle('migraphx_' + f.__name__, 'migraphx::' + f.__name__,
*args, **kwargs)(f)
return with_handle @api.handle('migraphx_dynamic_dimensions',
'std::vector<migraphx::shape::dynamic_dimension>')
def dynamic_dimensions(h):
h.constructor(
'create',
api.params(ptr='const_migraphx_dynamic_dimension_t*', size='size_t'),
fname='migraphx::to_obj_vector<const_migraphx_dynamic_dimension_t>')
h.method('size', returns='size_t')
h.method('get',
api.params(idx='size_t'),
fname='at',
cpp_name='operator[]',
returns='const migraphx::shape::dynamic_dimension&')
@auto_handle() @auto_handle()
...@@ -109,20 +101,29 @@ def shape(h): ...@@ -109,20 +101,29 @@ def shape(h):
lengths='std::vector<size_t>', lengths='std::vector<size_t>',
strides='std::vector<size_t>')) strides='std::vector<size_t>'))
h.constructor('create_scalar', api.params(type='migraphx::shape::type_t')) h.constructor('create_scalar', api.params(type='migraphx::shape::type_t'))
h.constructor(
'create_dynamic',
api.params(type='migraphx::shape::type_t',
dims='std::vector<migraphx::shape::dynamic_dimension>'))
h.method('lengths', h.method('lengths',
fname='lens', fname='lens',
returns='const std::vector<size_t>&', returns='const std::vector<size_t>&',
const=True) const=True)
h.method('strides', returns='const std::vector<size_t>&', const=True) h.method('strides', returns='const std::vector<size_t>&', const=True)
h.method('dyn_dims',
returns='std::vector<migraphx::shape::dynamic_dimension>',
const=True)
h.method('type', returns='migraphx::shape::type_t', const=True) h.method('type', returns='migraphx::shape::type_t', const=True)
h.method('elements', returns='size_t', const=True) h.method('elements', returns='size_t', const=True)
h.method('bytes', returns='size_t', const=True) h.method('bytes', returns='size_t', const=True)
h.method('ndim', returns='size_t', const=True)
h.method('equal', h.method('equal',
api.params(x='const migraphx::shape&'), api.params(x='const migraphx::shape&'),
invoke='migraphx::equal($@)', invoke='migraphx::equal($@)',
returns='bool', returns='bool',
const=True) const=True)
h.method('standard', returns='bool', const=True) h.method('standard', returns='bool', const=True)
h.method('dynamic', returns='bool', const=True)
h.method('index', api.params(i='size_t'), returns='size_t', const=True) h.method('index', api.params(i='size_t'), returns='size_t', const=True)
...@@ -130,6 +131,7 @@ def shape(h): ...@@ -130,6 +131,7 @@ def shape(h):
def argument(h): def argument(h):
h.constructor('create', h.constructor('create',
api.params(shape='const migraphx::shape&', buffer='void*')) api.params(shape='const migraphx::shape&', buffer='void*'))
h.constructor('create_empty', api.params(shape='const migraphx::shape&'))
h.method('shape', h.method('shape',
fname='get_shape', fname='get_shape',
cpp_name='get_shape', cpp_name='get_shape',
...@@ -325,11 +327,22 @@ def onnx_options(h): ...@@ -325,11 +327,22 @@ def onnx_options(h):
api.params(name='const char*', dims='std::vector<size_t>'), api.params(name='const char*', dims='std::vector<size_t>'),
invoke='migraphx::set_input_parameter_shape($@)', invoke='migraphx::set_input_parameter_shape($@)',
) )
h.method(
'set_dyn_input_parameter_shape',
api.params(name='const char*',
dims='std::vector<migraphx::shape::dynamic_dimension>'),
invoke='migraphx::set_dyn_input_parameter_shape($@)',
)
h.method( h.method(
'set_default_dim_value', 'set_default_dim_value',
api.params(value='size_t'), api.params(value='size_t'),
invoke='migraphx::set_default_dim_value($@)', invoke='migraphx::set_default_dim_value($@)',
) )
h.method(
'set_default_dyn_dim_value',
api.params(dd='const migraphx::shape::dynamic_dimension&'),
invoke='migraphx::set_default_dyn_dim_value($@)',
)
h.method( h.method(
'set_default_loop_iterations', 'set_default_loop_iterations',
api.params(value='int64_t'), api.params(value='int64_t'),
......
...@@ -48,6 +48,7 @@ add_api_test(assign test_assign.cpp ${TEST_ONNX_DIR}) ...@@ -48,6 +48,7 @@ add_api_test(assign test_assign.cpp ${TEST_ONNX_DIR})
add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR}) add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR})
add_api_test(lookup test_lookup.cpp ${TEST_ONNX_DIR}) add_api_test(lookup test_lookup.cpp ${TEST_ONNX_DIR})
add_api_test(module_construct test_module_construct.cpp ${TEST_ONNX_DIR}) add_api_test(module_construct test_module_construct.cpp ${TEST_ONNX_DIR})
add_api_test(dynamic_shape test_dynamic_shape.cpp ${TEST_ONNX_DIR})
add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR}) add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR})
add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR}) add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR})
add_api_test(op test_op_construct.cpp ${TEST_ONNX_DIR}) add_api_test(op test_op_construct.cpp ${TEST_ONNX_DIR})
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp>
#include "test.hpp"
TEST_CASE(create_dynamic_dimensions)
{
migraphx::dynamic_dimension dd0{1, 4};
EXPECT(not dd0.is_fixed());
migraphx::dynamic_dimension dd1{4, 4};
EXPECT(dd1.is_fixed());
migraphx::optimals opts{1, 2, 4};
migraphx::dynamic_dimension dd2{1, 4, opts};
migraphx::dynamic_dimensions dyn_dims0{dd0, dd1, dd2};
CHECK(bool{dyn_dims0[0] == dd0});
CHECK(bool{dyn_dims0[1] == dd1});
CHECK(bool{dyn_dims0[2] == dd2});
CHECK(bool{dyn_dims0[2] != dd0});
EXPECT(dyn_dims0.size() == 3);
}
TEST_CASE(create_dynamic_shape)
{
migraphx::dynamic_dimensions dyn_dims(migraphx::dynamic_dimension{1, 4},
migraphx::dynamic_dimension{78, 92},
migraphx::dynamic_dimension{1, 4, {1, 4}});
migraphx::shape dyn_shape{migraphx_shape_float_type, dyn_dims};
CHECK(bool{dyn_shape.dynamic()});
CHECK(bool{dyn_shape.dyn_dims()[0] == migraphx::dynamic_dimension{1, 4}});
migraphx::shape static_shape{migraphx_shape_float_type, {3, 8}};
EXPECT(not static_shape.dynamic());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -25,7 +25,6 @@ ...@@ -25,7 +25,6 @@
#include <hip/hip_runtime_api.h> #include <hip/hip_runtime_api.h>
#include <migraphx/migraphx.h> #include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp> #include <migraphx/migraphx.hpp>
#include <migraphx/manage_ptr.hpp> #include <migraphx/manage_ptr.hpp>
#include "test.hpp" #include "test.hpp"
...@@ -72,6 +71,105 @@ hip_ptr get_hip_buffer(size_t size) ...@@ -72,6 +71,105 @@ hip_ptr get_hip_buffer(size_t size)
return hip_ptr{ptr}; return hip_ptr{ptr};
} }
// TODO: placeholder until we have a way to copy tuple arguments to/from device through c++ api
// TEST_CASE(dynamic_batch_load_and_run)
//{
// migraphx::onnx_options o_options;
// migraphx::dynamic_dimensions dyn_dims = {{1, 4, {2, 4}}, {3, 3}, {4, 4}, {4, 4}};
// o_options.set_dyn_input_parameter_shape("0", dyn_dims);
// dyn_dims = {{2, 2}, {3, 3}, {3, 3}, {3, 3}};
// o_options.set_dyn_input_parameter_shape("1", dyn_dims);
// auto p = migraphx::parse_onnx("conv_dynamic_batch_test.onnx", o_options);
// migraphx::compile_options c_options;
// c_options.set_split_single_dyn_dim();
// p.compile(migraphx::target("gpu"), c_options);
// auto out_shapes = p.get_output_shapes();
// CHECK(out_shapes.size() == 1);
// EXPECT(out_shapes[0].dynamic());
//
// std::vector<float> a(0.12, 2*3*4*4);
// std::vector<float> c(0.75, 2*3*3*3);
//
// auto param_shapes = p.get_parameter_shapes();
// int batch_size = 2;
// std::unordered_map<std::string, migraphx::argument> arg_map;
//
// arg_map["0"] = migraphx::argument(param_shapes["0"].to_static(batch_size), a.data());
// arg_map["1"] = migraphx::argument(param_shapes["1"].to_static(batch_size), c.data());
//
// migraphx::program_parameters pp;
// std::vector<hip_ptr> buffs;
// std::vector<migraphx::argument> args;
//
// // copy to GPU and create parameter map
// for(auto&& name : param_shapes.names())
// {
// if(arg_map.find(name) != arg_map.end())
// {
// args.push_back(arg_map.at(name));
// }
// else
// {
// migraphx::shape static_shape = param_shapes[name].to_static(batch_size);
// auto output_arg = migraphx::argument(static_shape);
// args.push_back(output_arg);
// }
// buffs.push_back(get_hip_buffer(args.rbegin()->get_shape().bytes()));
// auto err = hipMemcpy(buffs.rbegin()->get(),
// args.rbegin()->data(),
// args.rbegin()->get_shape().bytes(),
// hipMemcpyHostToDevice);
// EXPECT(err == hipSuccess);
// pp.add(name, migraphx::argument(args.rbegin()->get_shape(), buffs.rbegin()->get()));
// }
//
// auto output = p.eval(pp)[0];
//
// // copy output back to host
// auto host_arg = migraphx::argument(output.get_shape());
// auto err = hipMemcpy(
// host_arg.data(), output.data(), output.get_shape().bytes(), hipMemcpyDeviceToHost);
// EXPECT(err == hipSuccess);
//}
TEST_CASE(dynamic_batch_load_and_run_offload)
{
migraphx::onnx_options o_options;
migraphx::dynamic_dimensions dyn_dims = {migraphx::dynamic_dimension{1, 4, {2, 4}},
migraphx::dynamic_dimension{3, 3},
migraphx::dynamic_dimension{4, 4},
migraphx::dynamic_dimension{4, 4}};
o_options.set_dyn_input_parameter_shape("0", dyn_dims);
dyn_dims = {migraphx::dynamic_dimension{2, 2},
migraphx::dynamic_dimension{3, 3},
migraphx::dynamic_dimension{3, 3},
migraphx::dynamic_dimension{3, 3}};
o_options.set_dyn_input_parameter_shape("1", dyn_dims);
auto p = migraphx::parse_onnx("conv_dynamic_batch_test.onnx", o_options);
auto shapes_before = p.get_output_shapes();
migraphx::compile_options c_options;
c_options.set_offload_copy();
p.compile(migraphx::target("gpu"), c_options);
auto out_shapes = p.get_output_shapes();
CHECK(out_shapes.size() == 1);
EXPECT(out_shapes[0].dynamic());
// batch size = 2
std::vector<float> a(2 * 3 * 4 * 4, 0.12);
std::vector<float> c(2 * 3 * 3 * 3, 0.75);
migraphx::program_parameters pp;
auto param_shapes = p.get_parameter_shapes();
pp.add("0",
migraphx::argument(migraphx::shape(migraphx_shape_float_type, {2, 3, 4, 4}), a.data()));
pp.add("1",
migraphx::argument(migraphx::shape(migraphx_shape_float_type, {2, 3, 3, 3}), c.data()));
auto outputs = p.eval(pp);
CHECK(shapes_before.size() == outputs.size());
CHECK(bool{outputs.front().get_shape() ==
migraphx::shape(migraphx_shape_float_type, {2, 1, 3, 3})});
}
TEST_CASE(load_and_run_async) TEST_CASE(load_and_run_async)
{ {
auto p = migraphx::parse_onnx("conv_relu_maxpool_test.onnx"); auto p = migraphx::parse_onnx("conv_relu_maxpool_test.onnx");
......
...@@ -1132,50 +1132,6 @@ TEST_CASE(conv_dyn_batch_test) ...@@ -1132,50 +1132,6 @@ TEST_CASE(conv_dyn_batch_test)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, sol)); EXPECT(migraphx::verify_range(results_vector, sol));
a = {2.71567607, -0.9960829, 0.91671127, 0.28140706, 0.63235772, 0.08077253, 0.80927712,
-0.59108931, -1.05421555, -2.76622486, -0.85044265, -0.52049929, 0.67726439, -0.65290606,
0.02345525, -0.33579525, 0.38901961, 1.05473483, -1.31188095, 1.8963089, -0.07265259,
0.947339, 0.41949373, -0.70814759, 0.25892952, 1.07311416, 1.2571274, -0.62318051,
-0.19951548, -0.94232577, -0.29393643, 0.42292568, -0.80230367, 1.40909171, 0.63617158,
0.13900366, 1.09253144, -0.15265895, 1.54781747, 0.72780299, 1.09189606, -0.38068101,
0.97057933, -0.58958799, 1.56188643, 0.21474874, 0.58725154, -1.27097559, -0.03024297,
1.09437096, -0.4897908, 0.34838957, -1.31042492, -1.69069934, 0.86956722, -0.40457946,
0.46691212, 1.29273605, 0.26464137, 0.22073045, -1.02178168, 0.22163901, -1.84387338,
0.75522131, -0.45775682, -0.42241111, -1.50944722, 1.07256448, -1.95876884, -0.28106022,
0.3341668, 2.13129425, -1.14728117, -1.06555498, -0.298444, -0.88322699, -0.65866792,
-2.06007552, 0.01374334, 0.45612028, 0.52715492, 1.01914406, -1.72659791, 0.80650896,
0.16860051, 2.24112225, -0.78620857, 0.36566174, -0.07020134, -0.47976932, -0.68230027,
-0.94711417, -0.54506505, 1.66504931, -0.71860826, 0.61132306};
c = {-0.14601797, -0.13000923, 0.06521662, 0.06178288, -0.11083675, 0.10154136, 0.09990512,
0.06030385, -0.11374587, -0.17523311, -0.14344215, 0.17802463, 0.06300922, -0.15325832,
0.07066704, 0.05166031, 0.00615084, -0.02606523, 0.08083995, -0.17913306, 0.0624622,
0.0735731, -0.04198661, -0.0164391, -0.06374192, 0.16569914, 0.10681538, 0.07370754,
0.02802075, 0.00282027, 0.15104802, -0.11084409, -0.00197773, 0.07924436, 0.03528272,
0.04765259, -0.15896152, 0.07917164, 0.12125669, -0.1154705, -0.11999125, 0.12749968,
-0.06269585, 0.18658121, -0.03944227, 0.0111798, -0.17731084, 0.11789055, -0.09982193,
0.08142821, 0.0729029, 0.11303909, 0.12735154, 0.03885292};
sol = {-0.20817225,
0.87965256,
0.14958936,
-1.24887264,
-0.06540672,
0.20778663,
0.40456355,
-0.99900877};
migraphx::shape input_fixed_shape1{migraphx::shape::float_type, {1, 3, 4, 4}};
migraphx::parameter_map params1;
params1["X"] = migraphx::argument(input_fixed_shape1, a.data());
params1["W"] = migraphx::argument(weights_shape, c.data());
result = p.eval(params1).back();
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, sol));
} }
TEST_CASE(conv_dyn_img_shape_test) TEST_CASE(conv_dyn_img_shape_test)
......
...@@ -764,9 +764,12 @@ const Target* object_cast(const U* x) ...@@ -764,9 +764,12 @@ const Target* object_cast(const U* x)
return reinterpret_cast<const Target*>(x); return reinterpret_cast<const Target*>(x);
} }
template<class T, class... Ts, class Target=std::remove_pointer_t<T>> template <class T, class... Ts, class Target = std::remove_pointer_t<T>>
Target* allocate(Ts&&... xs) Target* allocate(Ts&&... xs)
{ {
if constexpr(std::is_aggregate<Target>{})
return new Target{std::forward<Ts>(xs)...}; // NOLINT
else
return new Target(std::forward<Ts>(xs)...); // NOLINT return new Target(std::forward<Ts>(xs)...); // NOLINT
} }
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include <migraphx/execution_environment.hpp> #include <migraphx/execution_environment.hpp>
#include <migraphx/migraphx.h> #include <migraphx/migraphx.h>
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
...@@ -145,6 +146,11 @@ void set_default_dim_value(onnx_options& options, size_t value) ...@@ -145,6 +146,11 @@ void set_default_dim_value(onnx_options& options, size_t value)
options.default_dim_value = value; options.default_dim_value = value;
} }
void set_default_dyn_dim_value(onnx_options& options, const shape::dynamic_dimension& dd)
{
options.default_dyn_dim_value = dd;
}
void set_default_loop_iterations(onnx_options& options, int64_t value) void set_default_loop_iterations(onnx_options& options, int64_t value)
{ {
options.max_loop_iterations = value; options.max_loop_iterations = value;
...@@ -161,6 +167,13 @@ void set_input_parameter_shape(onnx_options& options, ...@@ -161,6 +167,13 @@ void set_input_parameter_shape(onnx_options& options,
options.map_input_dims[std::string(name)] = std::move(dims); options.map_input_dims[std::string(name)] = std::move(dims);
} }
void set_dyn_input_parameter_shape(onnx_options& options,
const char* name,
std::vector<shape::dynamic_dimension> dyn_dims)
{
options.map_dyn_input_dims[std::string(name)] = std::move(dyn_dims);
}
void set_input_parameter_shape(tf_options& options, const char* name, std::vector<std::size_t> dims) void set_input_parameter_shape(tf_options& options, const char* name, std::vector<std::size_t> dims)
{ {
options.map_input_dims[std::string(name)] = std::move(dims); options.map_input_dims[std::string(name)] = std::move(dims);
...@@ -187,6 +200,12 @@ std::vector<const char*> get_names(const std::unordered_map<std::string, Value>& ...@@ -187,6 +200,12 @@ std::vector<const char*> get_names(const std::unordered_map<std::string, Value>&
return result; return result;
} }
template <class T>
std::set<T> make_set(const T* x, std::size_t n)
{
return {x, x + n};
}
void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names) void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names)
{ {
if(names.empty()) if(names.empty())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment