Commit e0c791ea authored by Paul's avatar Paul
Browse files

Fix incorrect type

parent ee66d06f
......@@ -105,7 +105,7 @@ struct hip_copy_to_gpu
std::string name() const { return "hip::copy_to_gpu"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1, 2);
check_shapes{inputs, *this}.has(1, 2).same_type();
return inputs.at(0);
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
......@@ -131,7 +131,7 @@ struct hip_copy_from_gpu
std::string name() const { return "hip::copy_from_gpu"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1, 2);
check_shapes{inputs, *this}.has(1, 2).same_type();
return inputs.at(0);
}
argument
......@@ -159,7 +159,7 @@ struct hip_copy
std::string name() const { return "hip::copy"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2);
check_shapes{inputs, *this}.has(2).same_type();
return inputs.at(1);
}
argument compute(context& ctx, const shape&, std::vector<argument> args) const
......
......@@ -51,17 +51,20 @@ struct layernorm_base
}
check_shapes{inputs, static_cast<const Derived&>(*this)}.has(nargs + N);
auto s = inputs.at(0);
auto t = s.type();
if (not mods.empty())
t = mods.front()->get_output_shapes().front().type();
if(s.scalar())
{
return s;
}
else if(s.broadcasted())
{
return {s.type(), s.lens()};
return {t, s.lens()};
}
else
{
return s.with_lens(s.lens());
return s.with_lens(t, s.lens());
}
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment