Unverified Commit 68ca6f21 authored by Jiahang Xu's avatar Jiahang Xu Committed by GitHub
Browse files

Fix bug Issue4592 (#4614)

parent 358ea2eb
...@@ -4,8 +4,8 @@ import warnings ...@@ -4,8 +4,8 @@ import warnings
import torch import torch
import torch.nn as torch_nn import torch.nn as torch_nn
from torchvision.models.utils import load_state_dict_from_url
import torch.nn.functional as F import torch.nn.functional as F
from nni.retiarii import model_wrapper
import sys import sys
from pathlib import Path from pathlib import Path
...@@ -111,7 +111,7 @@ def _get_depths(depths, alpha): ...@@ -111,7 +111,7 @@ def _get_depths(depths, alpha):
rather than down. """ rather than down. """
return [_round_to_multiple_of(depth * alpha, 8) for depth in depths] return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]
@model_wrapper
class MNASNet(nn.Module): class MNASNet(nn.Module):
""" MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This """ MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This
implements the B1 variant of the model. implements the B1 variant of the model.
...@@ -180,7 +180,7 @@ class MNASNet(nn.Module): ...@@ -180,7 +180,7 @@ class MNASNet(nn.Module):
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
] ]
self.layers = nn.Sequential(*layers) self.layers = nn.Sequential(*layers)
self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True), self.classifier = nn.Sequential(nn.Dropout(p=dropout),
nn.Linear(1280, num_classes)) nn.Linear(1280, num_classes))
self._initialize_weights() self._initialize_weights()
#self.for_test = 10 #self.for_test = 10
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment