"src/vscode:/vscode.git/clone" did not exist on "f8d3e73e7661baee8ab706d3e057a02a322143d5"
Commit 4ac2919f authored by Paul's avatar Paul
Browse files

Ensure standard shape for triadd layernorm

parent 6d34b90f
...@@ -255,6 +255,11 @@ MIGRAPHX_REGISTER_OP(hip_layernorm) ...@@ -255,6 +255,11 @@ MIGRAPHX_REGISTER_OP(hip_layernorm)
struct hip_triadd_layernorm : ternary_device<hip_triadd_layernorm, &device::triadd_layernorm> struct hip_triadd_layernorm : ternary_device<hip_triadd_layernorm, &device::triadd_layernorm>
{ {
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(4).standard();
return inputs[0];
}
// Empty finalize to skip dimension reduction // Empty finalize to skip dimension reduction
void finalize(context&, const shape&, const std::vector<shape>&) {} void finalize(context&, const shape&, const std::vector<shape>&) {}
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment