    Simplify LayerNorm mixed precision logic. · 0257b276
    Reed Wanderman-Milne authored
    Instead of needing to ensure variables are float32, cast inputs to float32, etc., dtype="float32" is now passed to the layer constructor, which handles all of that logic automatically.
    
    The only difference is that the output of LayerNorm is now float32 instead of float16, so an extra cast is needed elsewhere.
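    
    A minimal sketch of the change described above, assuming TensorFlow's Keras LayerNormalization layer; the tensor name and shape are illustrative, not from the commit:
    
```python
import tensorflow as tf

# After this change: pass dtype="float32" to the constructor. The layer then
# keeps its variables in float32 and automatically casts float16 inputs to
# float32, replacing the manual casting logic the commit removes.
layer_norm = tf.keras.layers.LayerNormalization(dtype="float32")

# Illustrative float16 activations, e.g. from a mixed precision model.
hidden_states = tf.random.normal([2, 8, 16], dtype=tf.float16)

output = layer_norm(hidden_states)  # output dtype is float32, not float16

# Because the output is now float32, callers need an extra cast back to
# float16, as the commit message notes.
output = tf.cast(output, tf.float16)
```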
    
    PiperOrigin-RevId: 273833286
transformer.py