Commit a416dbc6 authored by umangyadav's avatar umangyadav
Browse files

use sharing

parent dd219d6e
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <limits>
#include <migraphx/serialize.hpp> #include <migraphx/serialize.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
...@@ -37,15 +38,57 @@ void raw_data_to_value(value& v, const RawData& rd) ...@@ -37,15 +38,57 @@ void raw_data_to_value(value& v, const RawData& rd)
if(rd.get_shape().type() == shape::tuple_type) if(rd.get_shape().type() == shape::tuple_type)
result["sub"] = migraphx::to_value(rd.get_sub_objects()); result["sub"] = migraphx::to_value(rd.get_sub_objects());
else if(not rd.empty()) else if(not rd.empty())
result["data"] = migraphx::value::binary(rd.data(), rd.get_shape().bytes()); {
size_t binary_size = rd.get_shape().bytes();
size_t partition_length = std::numeric_limits<uint32_t>::max();
if(binary_size > partition_length)
{
size_t array_size = 1 + ((binary_size - 1) / partition_length);
std::vector<migraphx::value> v_array(array_size);
for(size_t i = 0; i < array_size; ++i)
{
size_t chunk_size =
(i == (array_size - 1)) ? (binary_size % partition_length) : partition_length;
v_array[i] =
migraphx::value::binary{(rd.data() + (i * partition_length)), chunk_size};
}
result["data"] = migraphx::value(v_array);
}
else
{
result["data"] = migraphx::value::binary(rd.data(), rd.get_shape().bytes());
}
}
v = result; v = result;
} }
void migraphx_to_value(value& v, const literal& l) { raw_data_to_value(v, l); } void migraphx_to_value(value& v, const literal& l) { raw_data_to_value(v, l); }
void migraphx_from_value(const value& v, literal& l) void migraphx_from_value(const value& v, literal& l)
{ {
auto s = migraphx::from_value<shape>(v.at("shape")); auto s = migraphx::from_value<shape>(v.at("shape"));
l = literal(s, v.at("data").get_binary().data()); size_t binary_size = s.bytes();
size_t partition_length = std::numeric_limits<uint32_t>::max();
if(binary_size <= partition_length)
{
l = literal(s, v.at("data").get_binary().data());
}
else
{
assert(v.is_array());
size_t array_size = 1 + ((binary_size - 1) / partition_length);
assert(array_size == v.size());
std::vector<uint8_t> binary_array(binary_size);
size_t read_size = 0;
for(size_t i = 0; i < array_size; ++i)
{
binary_array.insert(binary_array.end(),
v.at(i).get_binary().data(),
v.at(i).get_binary().data() + v.at(i).get_binary().size());
read_size += v.at(i).get_binary().size();
}
assert(read_size == binary_size);
l = literal(s, binary_array.data());
}
} }
void migraphx_to_value(value& v, const argument& a) { raw_data_to_value(v, a); } void migraphx_to_value(value& v, const argument& a) { raw_data_to_value(v, a); }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment