feat: support force downcast after FastRMSNorm multiply for Gemma (#1658)
This PR adds `force_downcast_after` to `FastRMSNorm.forward` which is used in the Gemma model. References https://github.com/huggingface/transformers/pull/29402 and https://github.com/huggingface/transformers/pull/29729 Setting `force_downcast_after=True` will perform the `hidden_states * weight` multiplication in f32 and then downcast to half. This differs slightly from the current implementation which first casts the `hidden_states` to a half and then multiples.
Showing
Please register or sign in to comment