Commit 1f085a0e authored by Karan Dwivedi's avatar Karan Dwivedi Committed by Soumith Chintala
Browse files

Add num_classes (#128)

parent 74d04d2c
...@@ -19,7 +19,7 @@ model_urls = { ...@@ -19,7 +19,7 @@ model_urls = {
class VGG(nn.Module): class VGG(nn.Module):
def __init__(self, features): def __init__(self, features, num_classes=1000):
super(VGG, self).__init__() super(VGG, self).__init__()
self.features = features self.features = features
self.classifier = nn.Sequential( self.classifier = nn.Sequential(
...@@ -29,7 +29,7 @@ class VGG(nn.Module): ...@@ -29,7 +29,7 @@ class VGG(nn.Module):
nn.Linear(4096, 4096), nn.Linear(4096, 4096),
nn.ReLU(True), nn.ReLU(True),
nn.Dropout(), nn.Dropout(),
nn.Linear(4096, 1000), nn.Linear(4096, num_classes),
) )
self._initialize_weights() self._initialize_weights()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment