[JAX] Fix missing axes parameters in TransformerLayer and the wrong shape of...
[JAX] Fix missing axes parameters in TransformerLayer and the wrong shape of bias in LayerNormMLP (#196)
Fixed missing axes and wrong shape of bias in LayerNormMLP
Signed-off-by:
Ming Huang <mingh@nvidia.com>
Showing
Please register or sign in to comment