Commit e0c791ea authored by Paul's avatar Paul
Browse files

Fix incorrect type

parent ee66d06f
...@@ -105,7 +105,7 @@ struct hip_copy_to_gpu ...@@ -105,7 +105,7 @@ struct hip_copy_to_gpu
std::string name() const { return "hip::copy_to_gpu"; } std::string name() const { return "hip::copy_to_gpu"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1, 2); check_shapes{inputs, *this}.has(1, 2).same_type();
return inputs.at(0); return inputs.at(0);
} }
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
...@@ -131,7 +131,7 @@ struct hip_copy_from_gpu ...@@ -131,7 +131,7 @@ struct hip_copy_from_gpu
std::string name() const { return "hip::copy_from_gpu"; } std::string name() const { return "hip::copy_from_gpu"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1, 2); check_shapes{inputs, *this}.has(1, 2).same_type();
return inputs.at(0); return inputs.at(0);
} }
argument argument
...@@ -159,7 +159,7 @@ struct hip_copy ...@@ -159,7 +159,7 @@ struct hip_copy
std::string name() const { return "hip::copy"; } std::string name() const { return "hip::copy"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(2); check_shapes{inputs, *this}.has(2).same_type();
return inputs.at(1); return inputs.at(1);
} }
argument compute(context& ctx, const shape&, std::vector<argument> args) const argument compute(context& ctx, const shape&, std::vector<argument> args) const
......
...@@ -51,17 +51,20 @@ struct layernorm_base ...@@ -51,17 +51,20 @@ struct layernorm_base
} }
check_shapes{inputs, static_cast<const Derived&>(*this)}.has(nargs + N); check_shapes{inputs, static_cast<const Derived&>(*this)}.has(nargs + N);
auto s = inputs.at(0); auto s = inputs.at(0);
auto t = s.type();
if (not mods.empty())
t = mods.front()->get_output_shapes().front().type();
if(s.scalar()) if(s.scalar())
{ {
return s; return s;
} }
else if(s.broadcasted()) else if(s.broadcasted())
{ {
return {s.type(), s.lens()}; return {t, s.lens()};
} }
else else
{ {
return s.with_lens(s.lens()); return s.with_lens(t, s.lens());
} }
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment