Unverified Commit e3767f8a authored by Kai Zhang's avatar Kai Zhang Committed by GitHub
Browse files

add regnet_y_128gf factory function (#5176)



* add regnet_y_128gf

* fix test

* add expected test file

* update regnet factory function, add to prototype as well

* write torchscript to temp file instead bytesio in model test

* docs

* clear GPU memory

* no_grad

* nit
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent a8f2dedb
......@@ -76,6 +76,7 @@ You can construct a model with random weights by calling its constructor:
regnet_y_8gf = models.regnet_y_8gf()
regnet_y_16gf = models.regnet_y_16gf()
regnet_y_32gf = models.regnet_y_32gf()
regnet_y_128gf = models.regnet_y_128gf()
regnet_x_400mf = models.regnet_x_400mf()
regnet_x_800mf = models.regnet_x_800mf()
regnet_x_1_6gf = models.regnet_x_1_6gf()
......@@ -439,6 +440,7 @@ RegNet
regnet_y_8gf
regnet_y_16gf
regnet_y_32gf
regnet_y_128gf
regnet_x_400mf
regnet_x_800mf
regnet_x_1_6gf
......
......@@ -27,6 +27,7 @@ from torchvision.models.regnet import (
regnet_y_8gf,
regnet_y_16gf,
regnet_y_32gf,
regnet_y_128gf,
regnet_x_400mf,
regnet_x_800mf,
regnet_x_1_6gf,
......
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
import contextlib
import functools
import io
import operator
import os
import pkgutil
......@@ -8,6 +7,7 @@ import sys
import traceback
import warnings
from collections import OrderedDict
from tempfile import TemporaryDirectory
import pytest
import torch
......@@ -126,16 +126,16 @@ def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False):
def get_export_import_copy(m):
"""Save and load a TorchScript model"""
buffer = io.BytesIO()
torch.jit.save(m, buffer)
buffer.seek(0)
imported = torch.jit.load(buffer)
with TemporaryDirectory() as dir:
path = os.path.join(dir, "script.pt")
m.save(path)
imported = torch.jit.load(path)
return imported
m_import = get_export_import_copy(m)
with freeze_rng_state():
with torch.no_grad(), freeze_rng_state():
results = m(*args)
with freeze_rng_state():
with torch.no_grad(), freeze_rng_state():
results_from_imported = m_import(*args)
tol = 3e-4
torch.testing.assert_close(results, results_from_imported, atol=tol, rtol=tol)
......@@ -156,10 +156,10 @@ def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False):
sm = torch.jit.script(nn_module)
with freeze_rng_state():
with torch.no_grad(), freeze_rng_state():
eager_out = nn_module(*args)
with freeze_rng_state():
with torch.no_grad(), freeze_rng_state():
script_out = sm(*args)
if unwrapper:
script_out = unwrapper(script_out)
......
......@@ -26,6 +26,7 @@ __all__ = [
"regnet_y_8gf",
"regnet_y_16gf",
"regnet_y_32gf",
"regnet_y_128gf",
"regnet_x_400mf",
"regnet_x_800mf",
"regnet_x_1_6gf",
......@@ -505,6 +506,18 @@ def regnet_y_32gf(pretrained: bool = False, progress: bool = True, **kwargs: Any
return _regnet("regnet_y_32gf", params, pretrained, progress, **kwargs)
def regnet_y_128gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet:
"""
Constructs a RegNetY_128GF architecture from
`"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.
NOTE: Pretrained weights are not available for this model.
"""
params = BlockParams.from_init_params(
depth=27, w_0=456, w_a=160.83, w_m=2.52, group_width=264, se_ratio=0.25, **kwargs
)
return _regnet("regnet_y_128gf", params, pretrained, progress, **kwargs)
def regnet_x_400mf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet:
"""
Constructs a RegNetX_400MF architecture from
......
......@@ -20,6 +20,7 @@ __all__ = [
"RegNet_Y_8GF_Weights",
"RegNet_Y_16GF_Weights",
"RegNet_Y_32GF_Weights",
"RegNet_Y_128GF_Weights",
"RegNet_X_400MF_Weights",
"RegNet_X_800MF_Weights",
"RegNet_X_1_6GF_Weights",
......@@ -34,6 +35,7 @@ __all__ = [
"regnet_y_8gf",
"regnet_y_16gf",
"regnet_y_32gf",
"regnet_y_128gf",
"regnet_x_400mf",
"regnet_x_800mf",
"regnet_x_1_6gf",
......@@ -253,6 +255,11 @@ class RegNet_Y_32GF_Weights(WeightsEnum):
default = ImageNet1K_V2
class RegNet_Y_128GF_Weights(WeightsEnum):
# weights are not available yet.
pass
class RegNet_X_400MF_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth",
......@@ -501,6 +508,16 @@ def regnet_y_32gf(*, weights: Optional[RegNet_Y_32GF_Weights] = None, progress:
return _regnet(params, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", None))
def regnet_y_128gf(*, weights: Optional[RegNet_Y_128GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
weights = RegNet_Y_128GF_Weights.verify(weights)
params = BlockParams.from_init_params(
depth=27, w_0=456, w_a=160.83, w_m=2.52, group_width=264, se_ratio=0.25, **kwargs
)
return _regnet(params, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", RegNet_X_400MF_Weights.ImageNet1K_V1))
def regnet_x_400mf(*, weights: Optional[RegNet_X_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
weights = RegNet_X_400MF_Weights.verify(weights)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment