prelu belongs in FP16_CASTS (#257)
These functions (e.g. `torch.conv*` and `torch.prelu`) are mainly used via their `torch.nn`
wrapper layers.
It is the `torch.nn` layers that hold the weights; their `forward()` methods call into these
lower-level functions, passing the weights as arguments.
The `torch.conv*` functions are already in the `FP16_CASTS` list, in line with amp's philosophy of
casting arguments at the function-call boundary rather than converting the model/layer weights themselves.
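
As a rough illustration (a hypothetical sketch, not amp's actual implementation), an
`FP16_CASTS`-style registration amounts to wrapping the function so that every
floating-point tensor argument, weights included, is cast to half precision at the call site:

```python
import torch

def fp16_cast_wrapper(fn):
    """Hypothetical sketch of what an FP16_CASTS entry does: cast every
    floating-point tensor argument to FP16 before calling the underlying
    function. Simplified; not the real amp wrapping code."""
    def cast(t):
        return t.half() if torch.is_tensor(t) and t.is_floating_point() else t

    def wrapped(*args, **kwargs):
        return fn(*(cast(a) for a in args),
                  **{k: cast(v) for k, v in kwargs.items()})
    return wrapped
```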
Conceptually, `torch.prelu` is the same as the `torch.conv*` case: its weight
is passed in as an argument by its wrapper layer, `torch.nn.PReLU`, as sketched below.
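
A minimal sketch of that relationship (simplified from the real `torch.nn.PReLU`,
which also supports per-channel weights and other initialization options):

```python
import torch

class SketchPReLU(torch.nn.Module):
    # Simplified stand-in for torch.nn.PReLU: the layer owns the weight
    # and hands it to the lower-level torch.prelu as a call argument.
    def __init__(self, num_parameters=1, init=0.25):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.full((num_parameters,), init))

    def forward(self, x):
        # Because the weight arrives as an argument, casting torch.prelu's
        # inputs covers it, exactly as with a conv weight.
        return torch.prelu(x, self.weight)
```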