[PyTorch] Fix bug in FP8 cast in LayerNormLinear/LayerNormMLP (#738)
Perform FP8 cast on gathered layernorm output in LayerNormLinear
Signed-off-by: Tim Moon <tmoon@nvidia.com>