Commit 995d68e4 authored by umangyadav's avatar umangyadav
Browse files

value for basic data types

parent 86061b4d
...@@ -441,6 +441,17 @@ struct migraphx_module ...@@ -441,6 +441,17 @@ struct migraphx_module
migraphx::module object; migraphx::module object;
}; };
extern "C" struct migraphx_value;
struct migraphx_value
{
template <class... Ts>
migraphx_value(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{
}
migraphx::value object;
};
extern "C" struct migraphx_program; extern "C" struct migraphx_program;
struct migraphx_program struct migraphx_program
{ {
...@@ -1012,7 +1023,7 @@ migraphx_modules_create(migraphx_modules_t* modules, migraphx_module_t* ptr, siz ...@@ -1012,7 +1023,7 @@ migraphx_modules_create(migraphx_modules_t* modules, migraphx_module_t* ptr, siz
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_module_create(migraphx_module_t* module, char* name) extern "C" migraphx_status migraphx_module_create(migraphx_module_t* module, const char* name)
{ {
auto api_error_result = migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
if(name == nullptr) if(name == nullptr)
...@@ -1118,6 +1129,289 @@ extern "C" migraphx_status migraphx_module_add_return(migraphx_instruction_t* ou ...@@ -1118,6 +1129,289 @@ extern "C" migraphx_status migraphx_module_add_return(migraphx_instruction_t* ou
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_value_destroy(migraphx_value_t value)
{
auto api_error_result = migraphx::try_([&] { destroy((value)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_value_assign_to(migraphx_value_t output,
const_migraphx_value_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_value_create(migraphx_value_t* value)
{
auto api_error_result = migraphx::try_(
[&] { *value = object_cast<migraphx_value_t>(allocate<migraphx::value>()); });
return api_error_result;
}
extern "C" migraphx_status migraphx_value_is_null(bool* out, const_migraphx_value_t value)
{
auto api_error_result = migraphx::try_([&] {
if(value == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter value: Null pointer");
*out = (value->object).is_null();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_value_get_key(char* out, size_t out_size, const_migraphx_value_t value)
{
auto api_error_result = migraphx::try_([&] {
if(out == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(value == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter value: Null pointer");
auto&& api_result = (value->object).get_key();
auto* it = std::copy_n(api_result.begin(), std::min(api_result.size(), out_size - 1), out);
*it = '\0';
});
return api_error_result;
}
extern "C" migraphx_status migraphx_value_create_int64(migraphx_value_t* value, int64_t i)
{
auto api_error_result = migraphx::try_(
[&] { *value = object_cast<migraphx_value_t>(allocate<migraphx::value>((i))); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_value_create_int64_with_key(migraphx_value_t* value, const char* pkey, int64_t i)
{
auto api_error_result = migraphx::try_(
[&] { *value = object_cast<migraphx_value_t>(allocate<migraphx::value>((pkey), (i))); });
return api_error_result;
}
extern "C" migraphx_status migraphx_value_if_int64(const int64_t** out,
const_migraphx_value_t value)
{
auto api_error_result = migraphx::try_([&] {
if(value == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter value: Null pointer");
*out = (value->object).if_int64();
});
return api_error_result;
}
extern "C" migraphx_status migraphx_value_is_int64(bool* out, const_migraphx_value_t value)
{
auto api_error_result = migraphx::try_([&] {
if(value == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter value: Null pointer");
*out = (value->object).is_int64();
});
return api_error_result;
}
extern "C" migraphx_status migraphx_value_get_int64(int64_t* out, const_migraphx_value_t value)
{
auto api_error_result = migraphx::try_([&] {
if(value == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter value: Null pointer");
*out = (value->object).get_int64();
});
return api_error_result;
}
extern "C" migraphx_status migraphx_value_create_uint64(migraphx_value_t* value, uint64_t i)
{
auto api_error_result = migraphx::try_(
[&] { *value = object_cast<migraphx_value_t>(allocate<migraphx::value>((i))); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_value_create_uint64_with_key(migraphx_value_t* value, const char* pkey, uint64_t i)
{
auto api_error_result = migraphx::try_(
[&] { *value = object_cast<migraphx_value_t>(allocate<migraphx::value>((pkey), (i))); });
return api_error_result;
}
extern "C" migraphx_status migraphx_value_if_uint64(const uint64_t** out,
const_migraphx_value_t value)
{
auto api_error_result = migraphx::try_([&] {
if(value == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter value: Null pointer");
*out = (value->object).if_uint64();
});
return api_error_result;
}
extern "C" migraphx_status migraphx_value_is_uint64(bool* out, const_migraphx_value_t value)
{
auto api_error_result = migraphx::try_([&] {
if(value == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter value: Null pointer");
*out = (value->object).is_uint64();
});
return api_error_result;
}
extern "C" migraphx_status migraphx_value_get_uint64(uint64_t* out, const_migraphx_value_t value)
{
auto api_error_result = migraphx::try_([&] {
if(value == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter value: Null pointer");
*out = (value->object).get_uint64();
});
return api_error_result;
}
extern "C" migraphx_status migraphx_value_create_float(migraphx_value_t* value, double i)
{
auto api_error_result = migraphx::try_(
[&] { *value = object_cast<migraphx_value_t>(allocate<migraphx::value>((i))); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_value_create_float_with_key(migraphx_value_t* value, const char* pkey, double i)
{
auto api_error_result = migraphx::try_(
[&] { *value = object_cast<migraphx_value_t>(allocate<migraphx::value>((pkey), (i))); });
return api_error_result;
}
extern "C" migraphx_status migraphx_value_if_float(const double** out, const_migraphx_value_t value)
{
auto api_error_result = migraphx::try_([&] {
if(value == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter value: Null pointer");
*out = (value->object).if_float();
});
return api_error_result;
}
extern "C" migraphx_status migraphx_value_is_float(bool* out, const_migraphx_value_t value)
{
auto api_error_result = migraphx::try_([&] {
if(value == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter value: Null pointer");
*out = (value->object).is_float();
});
return api_error_result;
}
extern "C" migraphx_status migraphx_value_get_float(double* out, const_migraphx_value_t value)
{
auto api_error_result = migraphx::try_([&] {
if(value == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter value: Null pointer");
*out = (value->object).get_float();
});
return api_error_result;
}
extern "C" migraphx_status migraphx_value_create_string(migraphx_value_t* value, const char* i)
{
auto api_error_result = migraphx::try_(
[&] { *value = object_cast<migraphx_value_t>(allocate<migraphx::value>((i))); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_value_create_string_with_key(migraphx_value_t* value, const char* pkey, const char* i)
{
auto api_error_result = migraphx::try_(
[&] { *value = object_cast<migraphx_value_t>(allocate<migraphx::value>((pkey), (i))); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_value_if_string(char* out, size_t out_size, const_migraphx_value_t value)
{
auto api_error_result = migraphx::try_([&] {
if(out == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(value == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter value: Null pointer");
auto&& api_result = (value->object).if_string();
auto* it =
std::copy_n(api_result->begin(), std::min(api_result->size(), out_size - 1), out);
*it = '\0';
});
return api_error_result;
}
extern "C" migraphx_status migraphx_value_is_string(bool* out, const_migraphx_value_t value)
{
auto api_error_result = migraphx::try_([&] {
if(value == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter value: Null pointer");
*out = (value->object).is_string();
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_value_get_string(char* out, size_t out_size, const_migraphx_value_t value)
{
auto api_error_result = migraphx::try_([&] {
if(out == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
if(value == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter value: Null pointer");
auto&& api_result = (value->object).get_string();
auto* it = std::copy_n(api_result.begin(), std::min(api_result.size(), out_size - 1), out);
*it = '\0';
});
return api_error_result;
}
extern "C" migraphx_status migraphx_value_create_bool(migraphx_value_t* value, bool i)
{
auto api_error_result = migraphx::try_(
[&] { *value = object_cast<migraphx_value_t>(allocate<migraphx::value>((i))); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_value_create_bool_with_key(migraphx_value_t* value, const char* pkey, bool i)
{
auto api_error_result = migraphx::try_(
[&] { *value = object_cast<migraphx_value_t>(allocate<migraphx::value>((pkey), (i))); });
return api_error_result;
}
extern "C" migraphx_status migraphx_value_if_bool(const bool** out, const_migraphx_value_t value)
{
auto api_error_result = migraphx::try_([&] {
if(value == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter value: Null pointer");
*out = (value->object).if_bool();
});
return api_error_result;
}
extern "C" migraphx_status migraphx_value_is_bool(bool* out, const_migraphx_value_t value)
{
auto api_error_result = migraphx::try_([&] {
if(value == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter value: Null pointer");
*out = (value->object).is_bool();
});
return api_error_result;
}
extern "C" migraphx_status migraphx_value_get_bool(bool* out, const_migraphx_value_t value)
{
auto api_error_result = migraphx::try_([&] {
if(value == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter value: Null pointer");
*out = (value->object).get_bool();
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_destroy(migraphx_program_t program) extern "C" migraphx_status migraphx_program_destroy(migraphx_program_t program)
{ {
auto api_error_result = migraphx::try_([&] { destroy((program)); }); auto api_error_result = migraphx::try_([&] { destroy((program)); });
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#define MIGRAPHX_GUARD_C_API_MIGRAPHX_H #define MIGRAPHX_GUARD_C_API_MIGRAPHX_H
#include <stdlib.h> #include <stdlib.h>
#include <cstdint>
// Add new types here // Add new types here
// clang-format off // clang-format off
...@@ -76,6 +77,9 @@ typedef const struct migraphx_modules* const_migraphx_modules_t; ...@@ -76,6 +77,9 @@ typedef const struct migraphx_modules* const_migraphx_modules_t;
typedef struct migraphx_module* migraphx_module_t; typedef struct migraphx_module* migraphx_module_t;
typedef const struct migraphx_module* const_migraphx_module_t; typedef const struct migraphx_module* const_migraphx_module_t;
typedef struct migraphx_value* migraphx_value_t;
typedef const struct migraphx_value* const_migraphx_value_t;
typedef struct migraphx_program* migraphx_program_t; typedef struct migraphx_program* migraphx_program_t;
typedef const struct migraphx_program* const_migraphx_program_t; typedef const struct migraphx_program* const_migraphx_program_t;
...@@ -243,7 +247,7 @@ migraphx_status migraphx_modules_assign_to(migraphx_modules_t output, ...@@ -243,7 +247,7 @@ migraphx_status migraphx_modules_assign_to(migraphx_modules_t output,
migraphx_status migraphx_status
migraphx_modules_create(migraphx_modules_t* modules, migraphx_module_t* ptr, size_t size); migraphx_modules_create(migraphx_modules_t* modules, migraphx_module_t* ptr, size_t size);
migraphx_status migraphx_module_create(migraphx_module_t* module, char* name); migraphx_status migraphx_module_create(migraphx_module_t* module, const char* name);
migraphx_status migraphx_module_print(const_migraphx_module_t module); migraphx_status migraphx_module_print(const_migraphx_module_t module);
...@@ -272,6 +276,71 @@ migraphx_status migraphx_module_add_return(migraphx_instruction_t* out, ...@@ -272,6 +276,71 @@ migraphx_status migraphx_module_add_return(migraphx_instruction_t* out,
migraphx_module_t module, migraphx_module_t module,
migraphx_instructions_t args); migraphx_instructions_t args);
migraphx_status migraphx_value_destroy(migraphx_value_t value);
migraphx_status migraphx_value_assign_to(migraphx_value_t output, const_migraphx_value_t input);
migraphx_status migraphx_value_create(migraphx_value_t* value);
migraphx_status migraphx_value_is_null(bool* out, const_migraphx_value_t value);
migraphx_status migraphx_value_get_key(char* out, size_t out_size, const_migraphx_value_t value);
migraphx_status migraphx_value_create_int64(migraphx_value_t* value, int64_t i);
migraphx_status
migraphx_value_create_int64_with_key(migraphx_value_t* value, const char* pkey, int64_t i);
migraphx_status migraphx_value_if_int64(const int64_t** out, const_migraphx_value_t value);
migraphx_status migraphx_value_is_int64(bool* out, const_migraphx_value_t value);
migraphx_status migraphx_value_get_int64(int64_t* out, const_migraphx_value_t value);
migraphx_status migraphx_value_create_uint64(migraphx_value_t* value, uint64_t i);
migraphx_status
migraphx_value_create_uint64_with_key(migraphx_value_t* value, const char* pkey, uint64_t i);
migraphx_status migraphx_value_if_uint64(const uint64_t** out, const_migraphx_value_t value);
migraphx_status migraphx_value_is_uint64(bool* out, const_migraphx_value_t value);
migraphx_status migraphx_value_get_uint64(uint64_t* out, const_migraphx_value_t value);
migraphx_status migraphx_value_create_float(migraphx_value_t* value, double i);
migraphx_status
migraphx_value_create_float_with_key(migraphx_value_t* value, const char* pkey, double i);
migraphx_status migraphx_value_if_float(const double** out, const_migraphx_value_t value);
migraphx_status migraphx_value_is_float(bool* out, const_migraphx_value_t value);
migraphx_status migraphx_value_get_float(double* out, const_migraphx_value_t value);
migraphx_status migraphx_value_create_string(migraphx_value_t* value, const char* i);
migraphx_status
migraphx_value_create_string_with_key(migraphx_value_t* value, const char* pkey, const char* i);
migraphx_status migraphx_value_if_string(char* out, size_t out_size, const_migraphx_value_t value);
migraphx_status migraphx_value_is_string(bool* out, const_migraphx_value_t value);
migraphx_status migraphx_value_get_string(char* out, size_t out_size, const_migraphx_value_t value);
migraphx_status migraphx_value_create_bool(migraphx_value_t* value, bool i);
migraphx_status
migraphx_value_create_bool_with_key(migraphx_value_t* value, const char* pkey, bool i);
migraphx_status migraphx_value_if_bool(const bool** out, const_migraphx_value_t value);
migraphx_status migraphx_value_is_bool(bool* out, const_migraphx_value_t value);
migraphx_status migraphx_value_get_bool(bool* out, const_migraphx_value_t value);
migraphx_status migraphx_program_destroy(migraphx_program_t program); migraphx_status migraphx_program_destroy(migraphx_program_t program);
migraphx_status migraphx_program_assign_to(migraphx_program_t output, migraphx_status migraphx_program_assign_to(migraphx_program_t output,
......
...@@ -832,6 +832,90 @@ struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options) ...@@ -832,6 +832,90 @@ struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options)
} }
}; };
struct value : MIGRAPHX_HANDLE_BASE(value)
{
value() { this->make_handle(&migraphx_value_create); }
value(migraphx_value* p, own) { this->set_handle(p, own{}); }
value(migraphx_value* p, borrow) { this->set_handle(p, borrow{}); }
bool is_null() const
{
bool result;
call(&migraphx_value_is_null, &result, this->get_handle_ptr());
return result;
}
std::string get_key() const
{
std::array<char, 1024> str_array;
call(&migraphx_value_get_key, str_array.data(), 1024, this->get_handle_ptr());
return {str_array.data()};
}
value(std::string i) { this->make_handle(&migraphx_value_create_string, i.data()); }
value(const std::string& pkey, std::string i)
{
this->make_handle(&migraphx_value_create_string_with_key, pkey.data(), i.data());
}
bool is_string() const
{
bool result;
call(&migraphx_value_is_string, &result, this->get_handle_ptr());
return result;
}
std::string get_string() const
{
std::array<char, 1024> str_array;
call(&migraphx_value_get_string, str_array.data(), 1024, this->get_handle_ptr());
return {str_array.data()};
}
// TODO(umang): need to return pointer to make it consistent across all data types, or make all
// data type return value instead of pointers
// TODO(umang): need to check if size of 1024 holds for serialization
std::string if_string() const
{
std::array<char, 1024> str_array;
call(&migraphx_value_if_string, str_array.data(), 1024, this->get_handle_ptr());
return {str_array.data()};
}
#define MIGRAPHX_VISIT_VALUE_TYPES(m) \
m(int64, std::int64_t) m(uint64, std::uint64_t) m(float, double) m(bool, bool)
#define MIGRAPHX_VALUE_GENERATE_DEFINE_METHODS(vt, cpp_type) \
value(cpp_type i) { this->make_handle(&migraphx_value_create_##vt, i); } \
value(const std::string& pkey, cpp_type i) \
{ \
this->make_handle(&migraphx_value_create_##vt##_with_key, pkey.data(), i); \
} \
bool is_##vt() const \
{ \
bool result; \
call(&migraphx_value_is_##vt, &result, this->get_handle_ptr()); \
return result; \
} \
cpp_type get_##vt() const \
{ \
cpp_type get_value; \
call(&migraphx_value_get_##vt, &get_value, this->get_handle_ptr()); \
return get_value; \
} \
const cpp_type* if_##vt() const \
{ \
const cpp_type* ret_value = nullptr; \
call(&migraphx_value_if_##vt, &ret_value, this->get_handle_ptr()); \
return ret_value; \
}
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_DEFINE_METHODS)
};
/// A program represents the all computation graphs to be compiled and executed /// A program represents the all computation graphs to be compiled and executed
struct program : MIGRAPHX_HANDLE_BASE(program) struct program : MIGRAPHX_HANDLE_BASE(program)
{ {
......
...@@ -223,6 +223,27 @@ def module(h): ...@@ -223,6 +223,27 @@ def module(h):
returns='migraphx::instruction_ref') returns='migraphx::instruction_ref')
@auto_handle()
def value(h):
h.constructor('create')
h.method('is_null', returns='bool', const=True)
h.method('get_key', returns='const std::string', const=True)
cpp_types = ['int64_t', 'uint64_t', 'double', 'std::string', 'bool']
vt = ['int64', 'uint64', 'float', 'string', 'bool']
for vt, cpp_type in zip(vt, cpp_types):
if (vt == 'string'):
h.constructor('create_' + vt, api.params(i='const char*'))
h.constructor('create_' + vt + '_with_key',
api.params(pkey='const char*', i='const char*'))
else:
h.constructor('create_' + vt, api.params(i=cpp_type))
h.constructor('create_' + vt + '_with_key',
api.params(pkey='const char*', i=cpp_type))
h.method('if_' + vt, returns='const ' + cpp_type + '*', const=True)
h.method('is_' + vt, returns='bool', const=True)
h.method('get_' + vt, returns=cpp_type, const=True)
@auto_handle() @auto_handle()
def program(h): def program(h):
h.constructor('create') h.constructor('create')
......
...@@ -15,6 +15,7 @@ add_api_test(custom_op test_custom_op.cpp ${TEST_ONNX_DIR}) ...@@ -15,6 +15,7 @@ add_api_test(custom_op test_custom_op.cpp ${TEST_ONNX_DIR})
add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR}) add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR})
add_api_test(lookup test_lookup.cpp ${TEST_ONNX_DIR}) add_api_test(lookup test_lookup.cpp ${TEST_ONNX_DIR})
add_api_test(module_construct test_module_construct.cpp ${TEST_ONNX_DIR}) add_api_test(module_construct test_module_construct.cpp ${TEST_ONNX_DIR})
add_api_test(value test_value.cpp ${TEST_ONNX_DIR})
add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR}) add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR})
add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR}) add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR})
add_api_test(op test_op_construct.cpp ${TEST_ONNX_DIR}) add_api_test(op test_op_construct.cpp ${TEST_ONNX_DIR})
......
#include <cstdint>
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp>
#include "test.hpp"
TEST_CASE(value_default_construct)
{
migraphx::value v;
EXPECT(v.is_null());
EXPECT(v.get_key().empty());
}
TEST_CASE(value_construct_int1)
{
EXPECT(migraphx::value(int64_t{1}).is_int64());
migraphx::value v(int64_t{1});
EXPECT(v.is_int64());
EXPECT(v.get_int64() == 1);
EXPECT(*v.if_int64() == 1);
EXPECT(v.get_key().empty());
}
TEST_CASE(value_construct_int2)
{
migraphx::value v = int64_t{1};
EXPECT(v.is_int64());
EXPECT(v.get_int64() == 1);
EXPECT(*v.if_int64() == 1);
EXPECT(v.get_key().empty());
}
TEST_CASE(value_construct_string)
{
migraphx::value v = std::string{"one"};
EXPECT(v.is_string());
EXPECT(v.get_string() == "one");
EXPECT(v.if_string() == "one");
EXPECT(v.get_key().empty());
}
TEST_CASE(value_construct_key_string_literal_pair)
{
// Use parens instead {} to construct to test the key-pair constructor
migraphx::value v("key", std::string{"one"});
EXPECT(v.is_string());
EXPECT(v.get_string() == "one");
EXPECT(v.if_string() == "one");
EXPECT(v.get_key() == "key");
}
TEST_CASE(value_construct_float)
{
migraphx::value v = 1.0;
EXPECT(v.is_float());
// TODO: add float_equal method
EXPECT(v.get_key().empty());
}
TEST_CASE(value_construct_bool)
{
migraphx::value v = true;
EXPECT(v.is_bool());
EXPECT(v.get_bool() == true);
EXPECT(v.get_key().empty());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -918,7 +918,10 @@ def vector_c_wrap(p: Parameter) -> None: ...@@ -918,7 +918,10 @@ def vector_c_wrap(p: Parameter) -> None:
@cwrap('std::string') @cwrap('std::string')
def string_c_wrap(p: Parameter) -> None: def string_c_wrap(p: Parameter) -> None:
t = Type('char*') if p.type.is_const and not p.returns:
t = Type('const char*')
else:
t = Type('char*')
if p.returns: if p.returns:
if p.type.is_reference(): if p.type.is_reference():
p.add_param(t.add_pointer()) p.add_param(t.add_pointer())
...@@ -936,6 +939,11 @@ def string_c_wrap(p: Parameter) -> None: ...@@ -936,6 +939,11 @@ def string_c_wrap(p: Parameter) -> None:
p.virtual_read = ['${name}.c_str()'] p.virtual_read = ['${name}.c_str()']
if p.type.is_reference(): if p.type.is_reference():
p.write = ['*${name} = ${result}.c_str()'] p.write = ['*${name} = ${result}.c_str()']
elif p.type.is_pointer():
p.write = [
'auto* it = std::copy_n(${result}->begin(), std::min(${result}->size(), ${name}_size - 1), ${name});'
'*it = \'\\0\''
]
else: else:
p.write = [ p.write = [
'auto* it = std::copy_n(${result}.begin(), std::min(${result}.size(), ${name}_size - 1), ${name});' 'auto* it = std::copy_n(${result}.begin(), std::min(${result}.size(), ${name}_size - 1), ${name});'
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#define MIGRAPHX_GUARD_C_API_MIGRAPHX_H #define MIGRAPHX_GUARD_C_API_MIGRAPHX_H
#include <stdlib.h> #include <stdlib.h>
#include <cstdint>
// Add new types here // Add new types here
// clang-format off // clang-format off
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment