Commit 870a396b authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

manual merge

parents 228b665c d309e02f
......@@ -27,6 +27,7 @@
#include <migraphx/algorithm.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -43,6 +44,7 @@ inline namespace MIGRAPHX_INLINE_NS {
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
//
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
std::vector<std::size_t> s1)
{
......@@ -50,25 +52,62 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
return s0;
if(s0.size() > s1.size())
s0.swap(s1);
std::vector<std::size_t> out_lens(s1);
auto offset = s1.size() - s0.size();
std::transform(
s0.begin(), s0.end(), s1.begin() + offset, out_lens.begin() + offset, [&](auto a, auto b) {
if(a != b and a != 1 and b != 1)
{
MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" + to_string_range(s0) + "} and {" +
to_string_range(s1) + "} mismatch!");
MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" + migraphx::to_string_range(s0) +
"} and {" + migraphx::to_string_range(s1) + "} mismatch!");
}
return std::max(a, b);
});
return out_lens;
}
std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, shape s1)
{
// change both shapes to dynamic_dimension representation
s0 = s0.to_dynamic();
s1 = s1.to_dynamic();
if(s0.ndim() > s1.ndim())
{
std::swap(s0, s1);
}
auto offset = s1.ndim() - s0.ndim();
std::vector<shape::dynamic_dimension> out_dims(s1.dyn_dims());
std::transform(
s0.dyn_dims().cbegin(),
s0.dyn_dims().cend(),
s1.dyn_dims().cbegin() + offset,
out_dims.begin() + offset,
[&](auto a, auto b) {
if(a == b)
{
return a;
}
else if(a == 1 or b == 1)
{
// setting opt to 0, may need to be changed
return shape::dynamic_dimension{std::max(a.min, b.min), std::max(a.max, b.max), 0};
}
else
{
MIGRAPHX_THROW("COMPUTE_BROADCASTED_DYN_DIMS: dynamic shapes {" +
migraphx::to_string_range(s0.dyn_dims()) + "} and {" +
migraphx::to_string_range(s1.dyn_dims()) + "} mismatch!");
}
});
return out_dims;
}
// Compute the common (broadcasted) dimensions of a list of fixed shapes
std::vector<std::size_t> compute_common_lens(const std::vector<shape>& shapes)
{
assert(not shapes.empty());
assert(
std::none_of(shapes.cbegin(), shapes.cend(), [](auto shape) { return shape.dynamic(); }));
return transform_accumulate(shapes.begin() + 1,
shapes.end(),
shapes.front().lens(),
......@@ -114,20 +153,63 @@ instruction_ref insert_common_op(module& m,
const operation& op,
std::vector<instruction_ref> inputs)
{
auto common = common_shape(to_shapes(inputs));
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
if(input->get_shape().lens() != common.lens())
if(std::any_of(
inputs.cbegin(), inputs.cend(), [](auto input) { return input->get_shape().dynamic(); }))
{
// currently only handles the binary case
if(inputs.size() != 2)
{
input = m.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", common.lens()}}), input);
MIGRAPHX_THROW("INSERT_COMMON_OP: not handled; " + migraphx::to_string(inputs.size()) +
"inputs, only handle two inputs if any are dynamic shape");
}
if(input->get_shape().type() != common.type())
auto c_type = compute_common_types(to_shapes(inputs));
auto c_dyn_dims =
compute_broadcasted_dyn_dims(inputs[0]->get_shape(), inputs[1]->get_shape());
// following should work for a static or dynamic shape
if(inputs[0]->get_shape().dyn_dims() != c_dyn_dims)
{
input = m.insert_instruction(
ins, make_op("convert", {{"target_type", common.type()}}), input);
inputs[0] = m.insert_instruction(
ins,
make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}),
inputs[0],
inputs[1]);
}
return input;
});
if(inputs[1]->get_shape().dyn_dims() != c_dyn_dims)
{
inputs[1] = m.insert_instruction(
ins,
make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}),
inputs[1],
inputs[0]);
}
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
if(input->get_shape().type() != c_type)
{
input =
m.insert_instruction(ins, make_op("convert", {{"target_type", c_type}}), input);
}
return input;
});
}
else
{
auto common = common_shape(to_shapes(inputs));
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
if(input->get_shape().lens() != common.lens())
{
input = m.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", common.lens()}}), input);
}
if(input->get_shape().type() != common.type())
{
input = m.insert_instruction(
ins, make_op("convert", {{"target_type", common.type()}}), input);
}
return input;
});
}
return m.insert_instruction(ins, op, inputs);
}
......
......@@ -51,8 +51,8 @@ void dead_code_elimination::apply(module& m) const
// Skip instruction with empty shape as output unless its [dynamic, builtin, undefined,
// identity, allocate]
if((not i->get_shape().dynamic() and i->get_shape().elements() == 0) and
i->name().front() != '@' and
not contains({"undefined", "identity", "allocate"}, i->name()))
not(i->name().front() == '@') and not contains({"identity", "allocate"}, i->name()) and
not i->is_undefined())
continue;
assert(std::distance(m.begin(), i) <= std::distance(m.begin(), last));
std::unordered_set<instruction_ref> visited;
......
......@@ -74,9 +74,9 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
auto x_main_module_19 = mmain->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {64, 3, 11, 11}}, 18));
auto x_main_module_20 = mmain->add_instruction(
migraphx::make_json_op("convolution",
"{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[4,"
"4],use_dynamic_same_auto_pad:0}"),
migraphx::make_json_op(
"convolution",
"{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[4,4]}"),
x_0,
x_main_module_19);
auto x_main_module_21 = mmain->add_instruction(
......@@ -90,9 +90,9 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"),
x_main_module_23);
auto x_main_module_25 = mmain->add_instruction(
migraphx::make_json_op("convolution",
"{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[1,"
"1],use_dynamic_same_auto_pad:0}"),
migraphx::make_json_op(
"convolution",
"{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[1,1]}"),
x_main_module_24,
x_main_module_17);
auto x_main_module_26 = mmain->add_instruction(
......@@ -106,9 +106,9 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"),
x_main_module_28);
auto x_main_module_30 = mmain->add_instruction(
migraphx::make_json_op("convolution",
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,"
"1],use_dynamic_same_auto_pad:0}"),
migraphx::make_json_op(
"convolution",
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}"),
x_main_module_29,
x_main_module_15);
auto x_main_module_31 = mmain->add_instruction(
......@@ -117,9 +117,9 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
mmain->add_instruction(migraphx::make_op("add"), x_main_module_30, x_main_module_31);
auto x_main_module_33 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_32);
auto x_main_module_34 = mmain->add_instruction(
migraphx::make_json_op("convolution",
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,"
"1],use_dynamic_same_auto_pad:0}"),
migraphx::make_json_op(
"convolution",
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}"),
x_main_module_33,
x_main_module_13);
auto x_main_module_35 = mmain->add_instruction(
......@@ -128,9 +128,9 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
mmain->add_instruction(migraphx::make_op("add"), x_main_module_34, x_main_module_35);
auto x_main_module_37 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_36);
auto x_main_module_38 = mmain->add_instruction(
migraphx::make_json_op("convolution",
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,"
"1],use_dynamic_same_auto_pad:0}"),
migraphx::make_json_op(
"convolution",
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}"),
x_main_module_37,
x_main_module_11);
auto x_main_module_39 = mmain->add_instruction(
......
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -44,7 +44,6 @@
#include <migraphx/propagate_constant.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/register_target.hpp>
......@@ -110,8 +109,12 @@ struct loader
ap(brief, {"--brief"}, ap.help("Make the output brief."), ap.set_value(true));
ap(output_type,
{"--cpp"},
ap.help("Print out the program as cpp program."),
ap.help("Print out the program as C++ program."),
ap.set_value("cpp"));
ap(output_type,
{"--python", "--py"},
ap.help("Print out the program as python program."),
ap.set_value("py"));
ap(output_type, {"--json"}, ap.help("Print out program as json."), ap.set_value("json"));
ap(output_type,
{"--text"},
......@@ -221,7 +224,6 @@ struct loader
{
migraphx::run_passes(*p.get_main_module(),
{
migraphx::rewrite_batchnorm{},
migraphx::eliminate_identity{},
migraphx::dead_code_elimination{},
migraphx::simplify_algebra{},
......@@ -261,7 +263,9 @@ struct loader
type = "binary";
}
if(type == "cpp")
if(type == "py")
p.print_py(*os);
else if(type == "cpp")
p.print_cpp(*os);
else if(type == "graphviz")
p.print_graph(*os, brief);
......
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -145,7 +145,7 @@ void verify_reduced(program p,
auto* mm = p.get_main_module();
auto last = std::prev(mm->end(), n + 1);
mm->remove_instructions(last, mm->end());
std::cout << "Verify: " << std::endl;
std::cout << "Verify: " << n << std::endl;
std::cout << p << std::endl;
verify_program(std::to_string(n), p, t, options, quantize, inputs, tolerance);
}
......@@ -159,6 +159,7 @@ void verify_reduced_program(const program& p,
{
const auto* mm = p.get_main_module();
auto n = std::distance(mm->begin(), mm->end());
std::cout << "Verify steps: " << n << std::endl;
for(std::size_t i = 0; i < n; i++)
{
verify_reduced(p, i, t, options, quantize, inputs, tolerance);
......
......@@ -42,6 +42,13 @@ static bool try_compute_shape(instruction_ref ins,
try
{
shape new_shape = ins->get_operator().compute_shape(inputs, mods);
// Cannot tell if a dynamic shape will need to be made contiguous
if(new_shape.dynamic())
{
return false;
}
// If the output shape is a standard shape, no need to try its output
if(new_shape.standard())
{
......@@ -133,14 +140,20 @@ static void remove_contiguous(const std::string& op_name, module& m, F f)
}
}
// Perform evaluations in parallel
// Perform static contiguous evaluations in parallel
std::vector<argument> literals(const_instructions.size());
par_for(const_instructions.size(), 1, [&](const auto i) {
auto c = op::contiguous{};
auto prev = const_instructions[i]->inputs().front();
literals[i] = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()});
auto c = op::contiguous{};
auto prev = const_instructions[i]->inputs().front();
// compute the output contiguous shape from the previous instruction shape
shape computed_shape = c.compute_shape({prev->get_shape()});
const std::vector<argument>& prev_eval = {prev->eval()};
// prev_eval should not be used in make_compute_output_shape() as computed_shape is static
auto co_shape = make_compute_output_shape(pack(c, computed_shape, prev_eval));
literals[i] = c.compute(co_shape, prev_eval);
});
// Replace static contiguous operations with a literal
for(size_t i = 0; i < const_instructions.size(); i++)
{
auto l = m.add_literal(literals[i].get_shape(), literals[i].data());
......
......@@ -30,23 +30,31 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class T>
T generic_read_file(const std::string& filename)
T generic_read_file(const std::string& filename, size_t offset = 0, size_t nbytes = 0)
{
std::ifstream is(filename, std::ios::binary | std::ios::ate);
std::streamsize size = is.tellg();
if(size < 1)
if(nbytes == 0)
{
// if there is a non-zero offset and nbytes is not set,
// calculate size of remaining bytes to read
nbytes = is.tellg();
if(offset > nbytes)
MIGRAPHX_THROW("offset is larger than file size");
nbytes -= offset;
}
if(nbytes < 1)
MIGRAPHX_THROW("Invalid size for: " + filename);
is.seekg(0, std::ios::beg);
is.seekg(offset, std::ios::beg);
T buffer(size, 0);
if(not is.read(&buffer[0], size))
T buffer(nbytes, 0);
if(not is.read(&buffer[0], nbytes))
MIGRAPHX_THROW("Error reading file: " + filename);
return buffer;
}
std::vector<char> read_buffer(const std::string& filename)
std::vector<char> read_buffer(const std::string& filename, size_t offset, size_t nbytes)
{
return generic_read_file<std::vector<char>>(filename);
return generic_read_file<std::vector<char>>(filename, offset, nbytes);
}
std::string read_string(const std::string& filename)
......
......@@ -39,13 +39,22 @@ static literal get_scalar(instruction_ref ins)
if(ins->name() == "contiguous")
return get_scalar(ins->inputs().front());
const auto& s = ins->get_shape();
if(not(s.elements() == 1 or s.scalar()))
if(s.elements() != 1 && not(s.scalar()))
return {};
if(not ins->can_eval())
return {};
auto e = ins->eval();
literal r{};
e.visit_at([&](auto x) { r = literal{x}; });
// needed for bool as visit_at invokes as() which promotes bool to int8
// Without this we'll break type checks for logical ops that are fused.
if(e.get_shape().type() == shape::bool_type)
{
r = literal{e.at<bool>()};
}
else
{
e.visit_at([&](auto x) { r = literal{x}; });
}
return r;
}
......
......@@ -107,6 +107,7 @@ struct argument : raw_data<argument>
data_t m_data{};
};
std::vector<shape> to_shapes(const std::vector<argument>& args);
void migraphx_to_value(value& v, const argument& a);
void migraphx_from_value(const value& v, argument& a);
......
......@@ -198,7 +198,7 @@ struct check_shapes
*/
const check_shapes& same_ndims() const
{
if(not this->same([](const shape& s) { return s.max_lens().size(); }))
if(not this->same([](const shape& s) { return s.ndim(); }))
MIGRAPHX_THROW(prefix() + "Number of dimensions do not match");
return *this;
}
......
......@@ -36,6 +36,9 @@ struct operation;
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
std::vector<std::size_t> s1);
std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, shape s1);
shape common_shape(const std::vector<shape>& shapes);
instruction_ref insert_common_op(module& m,
......
......@@ -66,6 +66,15 @@ any_ptr get_queue_context(T&)
{
return {};
}
template <class T>
void wait_for_context(T&, any_ptr)
{
}
template <class T>
void finish_on_context(T&, any_ptr)
{
}
#ifdef TYPE_ERASED_DECLARATION
......@@ -78,6 +87,10 @@ struct context
void from_value(const value& v);
// (optional)
any_ptr get_queue();
// (optional)
void wait_for(any_ptr queue);
// (optional)
void finish_on(any_ptr queue);
//
void finish() const;
};
......@@ -165,6 +178,18 @@ struct context
return (*this).private_detail_te_get_handle().get_queue();
}
void wait_for(any_ptr queue)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().wait_for(queue);
}
void finish_on(any_ptr queue)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().finish_on(queue);
}
void finish() const
{
assert((*this).private_detail_te_handle_mem_var);
......@@ -187,6 +212,8 @@ struct context
virtual value to_value() const = 0;
virtual void from_value(const value& v) = 0;
virtual any_ptr get_queue() = 0;
virtual void wait_for(any_ptr queue) = 0;
virtual void finish_on(any_ptr queue) = 0;
virtual void finish() const = 0;
};
......@@ -231,6 +258,33 @@ struct context
return get_queue_context(private_detail_te_self);
}
template <class T>
static auto private_detail_te_default_wait_for(char, T&& private_detail_te_self, any_ptr queue)
-> decltype(private_detail_te_self.wait_for(queue))
{
private_detail_te_self.wait_for(queue);
}
template <class T>
static void private_detail_te_default_wait_for(float, T&& private_detail_te_self, any_ptr queue)
{
wait_for_context(private_detail_te_self, queue);
}
template <class T>
static auto private_detail_te_default_finish_on(char, T&& private_detail_te_self, any_ptr queue)
-> decltype(private_detail_te_self.finish_on(queue))
{
private_detail_te_self.finish_on(queue);
}
template <class T>
static void
private_detail_te_default_finish_on(float, T&& private_detail_te_self, any_ptr queue)
{
finish_on_context(private_detail_te_self, queue);
}
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type
{
......@@ -248,7 +302,7 @@ struct context
PrivateDetailTypeErasedT value,
typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept
: private_detail_te_value(std::move(value))
: private_detail_te_value(value)
{
}
......@@ -277,6 +331,18 @@ struct context
return private_detail_te_default_get_queue(char(0), private_detail_te_value);
}
void wait_for(any_ptr queue) override
{
private_detail_te_default_wait_for(char(0), private_detail_te_value, queue);
}
void finish_on(any_ptr queue) override
{
private_detail_te_default_finish_on(char(0), private_detail_te_value, queue);
}
void finish() const override { private_detail_te_value.finish(); }
PrivateDetailTypeErasedT private_detail_te_value;
......
......@@ -21,36 +21,55 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_DYN_OUTPUT_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_DYN_OUTPUT_HPP
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/concat.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
argument concat(hipStream_t stream,
const migraphx::shape&,
std::vector<migraphx::argument> args,
std::vector<std::size_t> offsets)
struct dyn_output
{
auto ninputs = args.size() - 1;
for(std::size_t j = 0; j < ninputs; j++)
// original shape from the instruction
shape ins_shape;
// shape computed at eval time using input arguments
shape computed_shape;
};
/**
* Handle dynamic and static shape at evaluation time.
* If converted to shape type, returns original ins_shape.
* If converted to dyn_output type, will compute an output shape using the input arguments.
*/
template <class F>
struct compute_output_shape
{
F ins_inputs;
operator dyn_output() const
{
auto&& arg = args[j];
auto offset = offsets[j];
auto byte_offset = offset * arg.get_shape().type_size();
auto output_shape = shape{
arg.get_shape().type(), arg.get_shape().lens(), args.back().get_shape().strides()};
auto output = argument{output_shape, args.back().data() + byte_offset};
contiguous(stream, output, arg);
return ins_inputs([](const auto& x, shape ins_shape, const std::vector<argument>& inputs) {
if(ins_shape.dynamic())
return dyn_output{ins_shape, compute_shape(x, to_shapes(inputs))};
return dyn_output{ins_shape, ins_shape};
});
}
return args.back();
operator shape() const
{
return ins_inputs(
[](const auto&, shape ins_shape, const std::vector<argument>&) { return ins_shape; });
}
};
template <class F>
compute_output_shape<F> make_compute_output_shape(F f)
{
return {f};
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -21,22 +21,21 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_ADD_HPP
#define MIGRAPHX_GUARD_RTGLIB_ADD_HPP
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_EXECUTION_ENV_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_EXECUTION_ENV_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/add.hpp>
#include <migraphx/any_ptr.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_add : binary_device<hip_add, device::add>
struct execution_environment
{
any_ptr queue = any_ptr{};
bool async = false;
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#endif /* MIGRAPHX_GUARD_MIGRAPHLIB_EXECUTION_ENV_HPP */
......@@ -31,7 +31,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
std::vector<char> read_buffer(const std::string& filename);
std::vector<char> read_buffer(const std::string& filename, size_t offset = 0, size_t nbytes = 0);
std::string read_string(const std::string& filename);
void write_buffer(const std::string& filename, const char* buffer, std::size_t size);
......
......@@ -121,6 +121,8 @@ struct instruction
bool can_eval() const;
bool is_undefined() const;
argument eval(bool check_eval = true) const;
void finalize(context& ctx);
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP
......
......@@ -80,6 +80,7 @@ struct literal : raw_data<literal>
fill(start, end);
}
// Directly copies buffer of x
template <class T, MIGRAPHX_REQUIRES(sizeof(T) == 1)>
literal(const shape& s, T* x) : buffer(make_shared_array<char>(s.bytes())), m_shape(s)
{
......@@ -107,25 +108,15 @@ struct literal : raw_data<literal>
std::shared_ptr<char> buffer;
shape m_shape;
// Keeps the same data ordering as the given container
template <class Iterator>
void fill(Iterator start, Iterator end)
{
assert(std::distance(start, end) == m_shape.elements());
if(m_shape.standard())
{
m_shape.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.get())); });
}
else
{
auto it = start;
m_shape.visit_type([&](auto as) {
auto output = make_view(m_shape, as.from(buffer.get()));
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = *it; // NOLINT(bugprone-signed-char-misuse)
it++;
});
});
}
m_shape.visit_type([&](auto as) {
auto output = make_view(m_shape, as.from(buffer.get()));
std::copy(start, end, output.begin());
});
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment