Commit f63f2a97 authored by yl-1993's avatar yl-1993
Browse files

fix flake8 error

parent 8332ddbc
...@@ -12,9 +12,7 @@ class AlexNet(nn.Module): ...@@ -12,9 +12,7 @@ class AlexNet(nn.Module):
num_classes (int): number of classes for classification. num_classes (int): number of classes for classification.
""" """
def __init__(self, num_classes=-1):
def __init__(self,
num_classes=-1):
super(AlexNet, self).__init__() super(AlexNet, self).__init__()
self.num_classes = num_classes self.num_classes = num_classes
self.features = nn.Sequential( self.features = nn.Sequential(
......
import logging import logging
import math
import torch.nn as nn import torch.nn as nn
...@@ -118,7 +117,10 @@ class VGG(nn.Module): ...@@ -118,7 +117,10 @@ class VGG(nn.Module):
elif pretrained is None: elif pretrained is None:
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') nn.init.kaiming_normal_(
m.weight,
mode='fan_out',
nonlinearity='relu')
if m.bias is not None: if m.bias is not None:
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d): elif isinstance(m, nn.BatchNorm2d):
...@@ -131,6 +133,7 @@ class VGG(nn.Module): ...@@ -131,6 +133,7 @@ class VGG(nn.Module):
raise TypeError('pretrained must be a str or None') raise TypeError('pretrained must be a str or None')
def forward(self, x): def forward(self, x):
outs = []
for i, layer_name in enumerate(self.vgg_layers): for i, layer_name in enumerate(self.vgg_layers):
vgg_layer = getattr(self, layer_name) vgg_layer = getattr(self, layer_name)
x = vgg_layer(x) x = vgg_layer(x)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment