"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "3fe3bc0642cf6ebfa1a815367afd0dc57675ecc7"
Commit 4dc711bc authored by henrymai's avatar henrymai Committed by mcarilli

prelu belongs in FP16_FUNCS (#257)

The main use of these functions (e.g. `torch.{conv*, prelu}`) is via their `torch.nn`
wrapper layers.

The `torch.nn` layers are the ones that hold the weights; they call into these lower-level
functions and pass the weights as a parameter in their `forward()` method.

The `torch.conv*` functions are already in the `FP16_FUNCS` list due to amp's philosophy of
casting the call arguments rather than the model/layer weights.

Conceptually, `torch.prelu` is the same as the `torch.conv*` case: its weight parameter
is passed in from its wrapper layer, `torch.nn.PReLU`.
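
Below is a minimal sketch (not apex's actual implementation, and `fp16_wrap` is a hypothetical helper) of what listing a function in `FP16_FUNCS` amounts to: every floating-point tensor argument is cast to FP16 at call time, so the FP32 weight stored on the `torch.nn.PReLU` module is only cast for the call, never modified in place.

```python
import torch

def fp16_wrap(fn):
    """Illustrative wrapper: cast floating-point tensor args to FP16 before calling fn."""
    def wrapper(*args, **kwargs):
        def cast(x):
            # Only floating-point tensors are cast; everything else passes through unchanged.
            return x.half() if torch.is_tensor(x) and x.is_floating_point() else x
        return fn(*[cast(a) for a in args],
                  **{k: cast(v) for k, v in kwargs.items()})
    return wrapper

if torch.cuda.is_available():
    # torch.nn.PReLU keeps its weight in FP32 and calls torch.prelu(input, self.weight)
    # in forward(); wrapping torch.prelu therefore casts both the input and the weight
    # argument at call time, exactly like the conv* entries already on the list.
    prelu_fp16 = fp16_wrap(torch.prelu)
    layer = torch.nn.PReLU(num_parameters=3).cuda()   # FP32 master weight lives here
    x = torch.randn(4, 3, device="cuda")              # FP32 activations
    out = prelu_fp16(x, layer.weight)                 # both cast to FP16 for this call
    print(out.dtype)                                  # torch.float16
```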
parent 2c18651b
@@ -5,8 +5,9 @@ from .. import utils
 MODULE = torch
 
 FP16_FUNCS = [
-    # Math
-    # TODO: why are these in top-level torch namespace?
+    # Low level functions wrapped by torch.nn layers.
+    # The wrapper layers contain the weights which are then passed in as a parameter
+    # to these functions.
     'conv1d',
     'conv2d',
     'conv3d',
@@ -14,6 +15,7 @@ FP16_FUNCS = [
     'conv_transpose2d',
     'conv_transpose3d',
     'conv_tbc',
+    'prelu',
 
     # BLAS
     'addmm',
@@ -76,7 +78,6 @@ CASTS = [
     'addcmul',
     'atan2',
     'cross',
-    'prelu',
 
     # Element-wise _or_ tensor-wise math
     'add',