Ming-Xu Huang authored
22ccf9b1
[JAX] Fix missing axes parameters in TransformerLayer and the wrong shape of bias in LayerNormMLP (#196)
Fixed the missing axes parameters in TransformerLayer and the wrong shape of the bias in LayerNormMLP.
Signed-off-by: Ming Huang <mingh@nvidia.com>