Unverified Commit fbd69f10 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Make StochasticDepth FX-compatible (#4373)

parent 72d650ae
...@@ -39,7 +39,7 @@ class TestFxFeatureExtraction: ...@@ -39,7 +39,7 @@ class TestFxFeatureExtraction:
'num_classes': 1, 'num_classes': 1,
'pretrained': False 'pretrained': False
} }
leaf_modules = [torchvision.ops.StochasticDepth] leaf_modules = []
def _create_feature_extractor(self, *args, **kwargs): def _create_feature_extractor(self, *args, **kwargs):
""" """
......
import torch import torch
import torch.fx
from torch import nn, Tensor from torch import nn, Tensor
...@@ -37,6 +38,9 @@ def stochastic_depth(input: Tensor, p: float, mode: str, training: bool = True) ...@@ -37,6 +38,9 @@ def stochastic_depth(input: Tensor, p: float, mode: str, training: bool = True)
return input * noise return input * noise
torch.fx.wrap('stochastic_depth')
class StochasticDepth(nn.Module): class StochasticDepth(nn.Module):
""" """
See :func:`stochastic_depth`. See :func:`stochastic_depth`.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment