"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "83c6be1666e8ccf9055e8b7813064644f0a1ad69"
Commit 4d053deb authored by Neal Wu's avatar Neal Wu
Browse files

Allow users to pass in num_classes to ResNet

parent 7cb653fd
......@@ -129,9 +129,9 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
# Running the model
###############################################################################
class Cifar10Model(resnet.Model):
def __init__(self, resnet_size, data_format=None):
"""These are the parameters that work for CIFAR-10 data.
"""
def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES):
"""These are the parameters that work for CIFAR-10 data."""
if resnet_size % 6 != 2:
raise ValueError('resnet_size must be 6n + 2:', resnet_size)
......@@ -139,7 +139,7 @@ class Cifar10Model(resnet.Model):
super(Cifar10Model, self).__init__(
resnet_size=resnet_size,
num_classes=_NUM_CLASSES,
num_classes=num_classes,
num_filters=16,
kernel_size=3,
conv_stride=1,
......
......@@ -132,9 +132,9 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
# Running the model
###############################################################################
class ImagenetModel(resnet.Model):
def __init__(self, resnet_size, data_format=None):
"""These are the parameters that work for Imagenet data.
"""
def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES):
"""These are the parameters that work for Imagenet data."""
# For bigger models, we want to use "bottleneck" layers
if resnet_size < 50:
......@@ -146,7 +146,7 @@ class ImagenetModel(resnet.Model):
super(ImagenetModel, self).__init__(
resnet_size=resnet_size,
num_classes=_NUM_CLASSES,
num_classes=num_classes,
num_filters=64,
kernel_size=7,
conv_stride=2,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment