"include/vscode:/vscode.git/clone" did not exist on "0e5c264c3e954c34483d8a50d9b622d5d455a160"
Unverified Commit 68ca6f21 authored by Jiahang Xu's avatar Jiahang Xu Committed by GitHub
Browse files

Fix bug Issue4592 (#4614)

parent 358ea2eb
......@@ -4,8 +4,8 @@ import warnings
import torch
import torch.nn as torch_nn
from torchvision.models.utils import load_state_dict_from_url
import torch.nn.functional as F
from nni.retiarii import model_wrapper
import sys
from pathlib import Path
......@@ -111,7 +111,7 @@ def _get_depths(depths, alpha):
rather than down. """
return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]
@model_wrapper
class MNASNet(nn.Module):
""" MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This
implements the B1 variant of the model.
......@@ -180,7 +180,7 @@ class MNASNet(nn.Module):
nn.ReLU(inplace=True),
]
self.layers = nn.Sequential(*layers)
self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True),
self.classifier = nn.Sequential(nn.Dropout(p=dropout),
nn.Linear(1280, num_classes))
self._initialize_weights()
#self.for_test = 10
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment