Commit 4dc711bc authored by henrymai, committed by mcarilli

prelu belongs in FP16_FUNCS (#257)

The main use of these functions (e.g. `torch.{conv*, prelu}`) is via their `torch.nn`
wrapper layers.

The `torch.nn` layers are what hold the weights; they call into these lower-level
functions, passing the weights as a parameter, in their `forward()` method.
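
As a rough illustration of that wrapper-layer pattern, here is a minimal, hypothetical sketch (not the actual `torch.nn.PReLU` source): the module owns the weight and hands it to the low-level function as an ordinary argument.

```python
import torch
import torch.nn as nn

# Hypothetical sketch of the wrapper-layer pattern described above
# (not the actual torch.nn.PReLU implementation).
class SketchPReLU(nn.Module):
    def __init__(self, num_parameters=1, init=0.25):
        super().__init__()
        # The learnable weight is stored on the wrapper layer...
        self.weight = nn.Parameter(torch.full((num_parameters,), init))

    def forward(self, input):
        # ...and passed into the low-level function as an argument.
        return torch.prelu(input, self.weight)
```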

The `torch.conv*` functions are already in the `FP16_FUNCS` list due to amp's philosophy of
casting the arguments passed to these functions rather than the weights stored on the model/layer.
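
To make that casting philosophy concrete, here is an assumption-laden sketch (not amp's actual patching code): a function listed in `FP16_FUNCS` can be wrapped so that its floating-point tensor arguments, including the weight handed in by the wrapper layer, are cast to FP16 at call time.

```python
import functools
import torch

def cast_args_to_fp16(fn):
    # Hypothetical helper: casts floating-point tensor arguments to FP16
    # before calling the wrapped function; the model's stored weights are
    # left untouched.
    def cast(x):
        return x.half() if torch.is_tensor(x) and x.is_floating_point() else x

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        args = [cast(a) for a in args]
        kwargs = {k: cast(v) for k, v in kwargs.items()}
        return fn(*args, **kwargs)
    return wrapper

# Usage sketch: patched_prelu = cast_args_to_fp16(torch.prelu)
```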

Conceptually, `torch.prelu` is the same as the `torch.conv*` case: its weight parameter
is passed in from its wrapper layer, `torch.nn.PReLU`.
parent 2c18651b
@@ -5,8 +5,9 @@ from .. import utils
 MODULE = torch
 
 FP16_FUNCS = [
-    # Math
-    # TODO: why are these in top-level torch namespace?
+    # Low level functions wrapped by torch.nn layers.
+    # The wrapper layers contain the weights which are then passed in as a parameter
+    # to these functions.
     'conv1d',
     'conv2d',
     'conv3d',
@@ -14,6 +15,7 @@ FP16_FUNCS = [
     'conv_transpose2d',
     'conv_transpose3d',
     'conv_tbc',
+    'prelu',
 
     # BLAS
     'addmm',
@@ -76,7 +78,6 @@ CASTS = [
     'addcmul',
     'atan2',
     'cross',
-    'prelu',
 
     # Element-wise _or_ tensor-wise math
     'add',