"test/vscode:/vscode.git/clone" did not exist on "8cc29b3773eec3f3af1301b42ccecaf0e8a4861c"
Commit 4ac2919f authored by Paul's avatar Paul
Browse files

Ensure standard shape for triadd layernorm

parent 6d34b90f
......@@ -255,6 +255,11 @@ MIGRAPHX_REGISTER_OP(hip_layernorm)
struct hip_triadd_layernorm : ternary_device<hip_triadd_layernorm, &device::triadd_layernorm>
{
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(4).standard();
return inputs[0];
}
// Empty finalize to skip dimension reduction
void finalize(context&, const shape&, const std::vector<shape>&) {}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment