"src/diffusers/pipelines/lumina2/pipeline_lumina2.py" did not exist on "75a636da4882771ca8834b804f767daa9394ffa8"
Commit c0936f8f authored by Khalique's avatar Khalique
Browse files

made changes according to feedback

parent 7e2cf4e8
......@@ -36,7 +36,7 @@ struct tf_parser
std::unordered_map<std::string, op_func> ops;
std::vector<size_t> parse_axes(attribute_map& attributes, const std::string& s) const
std::vector<size_t> parse_axes(const attribute_map& attributes, const std::string& s) const
{
auto attrs = attributes.at(s).list().i();
std::vector<size_t> axes;
......@@ -45,7 +45,7 @@ struct tf_parser
{
for(size_t& axis : axes)
{
parse_axis(axis);
axis = parse_axis(axis);
}
}
return axes;
......@@ -57,27 +57,27 @@ struct tf_parser
std::vector<T> new_data(prev_data.size());
for(size_t i = 0; i < new_data.size(); i++)
{
auto new_idx = i;
parse_axis(new_idx);
auto new_idx = parse_axis(i);
new_data.at(new_idx) = prev_data.at(i);
}
prev_data = new_data;
}
template <class T>
void parse_axis(T& dim) const
T parse_axis(const T& dim) const
{
if(is_nhwc)
{
switch(dim)
{
case 0: dim = 0; break;
case 1: dim = 2; break;
case 2: dim = 3; break;
case 3: dim = 1; break;
default: break;
case 0: return 0;
case 1: return 2;
case 2: return 3;
case 3: return 1;
default: return T{dim};
}
}
return T{dim};
}
std::vector<int64_t> get_axes(size_t num_axes) const
......@@ -224,8 +224,7 @@ struct tf_parser
{
// get index for axis within args
size_t axis_idx = attributes.at("N").i();
size_t axis = args[axis_idx]->eval().at<int64_t>();
parse_axis(axis);
size_t axis = parse_axis(args[axis_idx]->eval().at<int64_t>());
op::concat op{axis};
// return only first N arguments (assuming last index is the axis value)
return prog.add_instruction(
......@@ -263,7 +262,6 @@ struct tf_parser
{
std::vector<size_t> padding;
copy(attributes.at("explicit_paddings").list().i(), std::back_inserter(padding));
reorder_data(padding);
if(padding.size() != 4)
{
MIGRAPHX_THROW("padding should have 4 values");
......@@ -369,7 +367,7 @@ struct tf_parser
op::reshape op;
if(args.size() != 2)
MIGRAPHX_THROW("reshape needs 2 arguments (input, new_shape)");
literal s = args[1]->get_literal();
auto s = args[1]->eval();
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
return prog.add_instruction(op, args[0]);
}
......@@ -674,7 +672,12 @@ struct tf_parser
{
std::vector<int> data_int32 = get_data_vals(t.half_val(), shape_size);
std::vector<uint16_t> data_uint16(data_int32.begin(), data_int32.end());
return literal{{shape::half_type, dims}, data_uint16};
std::vector<half> data_half;
std::transform(data_uint16.begin(),
data_uint16.end(),
std::back_inserter(data_half),
[](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
return literal{{shape::half_type, dims}, data_half};
}
case tensorflow::DataType::DT_DOUBLE:
return literal{{shape::double_type, dims}, get_data_vals(t.double_val(), shape_size)};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment