Commit 1f085a0e authored by Karan Dwivedi's avatar Karan Dwivedi Committed by Soumith Chintala
Browse files

Add num_classes (#128)

parent 74d04d2c
......@@ -19,7 +19,7 @@ model_urls = {
class VGG(nn.Module):
def __init__(self, features):
def __init__(self, features, num_classes=1000):
super(VGG, self).__init__()
self.features = features
self.classifier = nn.Sequential(
......@@ -29,7 +29,7 @@ class VGG(nn.Module):
nn.Linear(4096, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, 1000),
nn.Linear(4096, num_classes),
)
self._initialize_weights()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment