"...text-generation-inference.git" did not exist on "e68509add7ac04f8d6b4b75fab6fa65f47c2a76c"
Unverified Commit 03e25734 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Update URL and add progress option (#1043)

parent 3254560b
......@@ -8,10 +8,10 @@ __all__ = ['MNASNet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3']
_MODEL_URLS = {
"mnasnet0_5":
"https://github.com/1e100/mnasnet_trainer/releases/download/v0.1/mnasnet0.5_top1_67.592-7c6cb539b9.pth",
"https://download.pytorch.org/models/mnasnet0.5_top1_67.592-7c6cb539b9.pth",
"mnasnet0_75": None,
"mnasnet1_0":
"https://github.com/1e100/mnasnet_trainer/releases/download/v0.1/mnasnet1.0_top1_73.512-f206786ef8.pth",
"https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
"mnasnet1_3": None
}
......@@ -143,41 +143,41 @@ class MNASNet(torch.nn.Module):
nn.init.zeros_(m.bias)
def _load_pretrained(model_name, model):
def _load_pretrained(model_name, model, progress):
if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None:
raise ValueError(
"No checkpoint is available for model type {}".format(model_name))
checkpoint_url = _MODEL_URLS[model_name]
model.load_state_dict(load_state_dict_from_url(checkpoint_url))
model.load_state_dict(load_state_dict_from_url(checkpoint_url, progress=progress))
def mnasnet0_5(pretrained=False, **kwargs):
def mnasnet0_5(pretrained=False, progress=True, **kwargs):
""" MNASNet with depth multiplier of 0.5. """
model = MNASNet(0.5, **kwargs)
if pretrained:
_load_pretrained("mnasnet0_5", model)
_load_pretrained("mnasnet0_5", model, progress)
return model
def mnasnet0_75(pretrained=False, **kwargs):
def mnasnet0_75(pretrained=False, progress=True, **kwargs):
""" MNASNet with depth multiplier of 0.75. """
model = MNASNet(0.75, **kwargs)
if pretrained:
_load_pretrained("mnasnet0_75", model)
_load_pretrained("mnasnet0_75", model, progress)
return model
def mnasnet1_0(pretrained=False, **kwargs):
def mnasnet1_0(pretrained=False, progress=True, **kwargs):
""" MNASNet with depth multiplier of 1.0. """
model = MNASNet(1.0, **kwargs)
if pretrained:
_load_pretrained("mnasnet1_0", model)
_load_pretrained("mnasnet1_0", model, progress)
return model
def mnasnet1_3(pretrained=False, **kwargs):
def mnasnet1_3(pretrained=False, progress=True, **kwargs):
""" MNASNet with depth multiplier of 1.3. """
model = MNASNet(1.3, **kwargs)
if pretrained:
_load_pretrained("mnasnet1_3", model)
_load_pretrained("mnasnet1_3", model, progress)
return model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment