Unverified Commit e3767f8a authored by Kai Zhang's avatar Kai Zhang Committed by GitHub
Browse files

add regnet_y_128gf factory function (#5176)



* add regnet_y_128gf

* fix test

* add expected test file

* update regnet factory function, add to prototype as well

* write torchscript to a temp file instead of BytesIO in model test

* docs

* clear GPU memory

* no_grad

* nit
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent a8f2dedb
...@@ -76,6 +76,7 @@ You can construct a model with random weights by calling its constructor: ...@@ -76,6 +76,7 @@ You can construct a model with random weights by calling its constructor:
regnet_y_8gf = models.regnet_y_8gf() regnet_y_8gf = models.regnet_y_8gf()
regnet_y_16gf = models.regnet_y_16gf() regnet_y_16gf = models.regnet_y_16gf()
regnet_y_32gf = models.regnet_y_32gf() regnet_y_32gf = models.regnet_y_32gf()
regnet_y_128gf = models.regnet_y_128gf()
regnet_x_400mf = models.regnet_x_400mf() regnet_x_400mf = models.regnet_x_400mf()
regnet_x_800mf = models.regnet_x_800mf() regnet_x_800mf = models.regnet_x_800mf()
regnet_x_1_6gf = models.regnet_x_1_6gf() regnet_x_1_6gf = models.regnet_x_1_6gf()
...@@ -439,6 +440,7 @@ RegNet ...@@ -439,6 +440,7 @@ RegNet
regnet_y_8gf regnet_y_8gf
regnet_y_16gf regnet_y_16gf
regnet_y_32gf regnet_y_32gf
regnet_y_128gf
regnet_x_400mf regnet_x_400mf
regnet_x_800mf regnet_x_800mf
regnet_x_1_6gf regnet_x_1_6gf
......
...@@ -27,6 +27,7 @@ from torchvision.models.regnet import ( ...@@ -27,6 +27,7 @@ from torchvision.models.regnet import (
regnet_y_8gf, regnet_y_8gf,
regnet_y_16gf, regnet_y_16gf,
regnet_y_32gf, regnet_y_32gf,
regnet_y_128gf,
regnet_x_400mf, regnet_x_400mf,
regnet_x_800mf, regnet_x_800mf,
regnet_x_1_6gf, regnet_x_1_6gf,
......
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
import contextlib import contextlib
import functools import functools
import io
import operator import operator
import os import os
import pkgutil import pkgutil
...@@ -8,6 +7,7 @@ import sys ...@@ -8,6 +7,7 @@ import sys
import traceback import traceback
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
from tempfile import TemporaryDirectory
import pytest import pytest
import torch import torch
...@@ -126,16 +126,16 @@ def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False): ...@@ -126,16 +126,16 @@ def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False):
def get_export_import_copy(m): def get_export_import_copy(m):
"""Save and load a TorchScript model""" """Save and load a TorchScript model"""
buffer = io.BytesIO() with TemporaryDirectory() as dir:
torch.jit.save(m, buffer) path = os.path.join(dir, "script.pt")
buffer.seek(0) m.save(path)
imported = torch.jit.load(buffer) imported = torch.jit.load(path)
return imported return imported
m_import = get_export_import_copy(m) m_import = get_export_import_copy(m)
with freeze_rng_state(): with torch.no_grad(), freeze_rng_state():
results = m(*args) results = m(*args)
with freeze_rng_state(): with torch.no_grad(), freeze_rng_state():
results_from_imported = m_import(*args) results_from_imported = m_import(*args)
tol = 3e-4 tol = 3e-4
torch.testing.assert_close(results, results_from_imported, atol=tol, rtol=tol) torch.testing.assert_close(results, results_from_imported, atol=tol, rtol=tol)
...@@ -156,10 +156,10 @@ def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False): ...@@ -156,10 +156,10 @@ def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False):
sm = torch.jit.script(nn_module) sm = torch.jit.script(nn_module)
with freeze_rng_state(): with torch.no_grad(), freeze_rng_state():
eager_out = nn_module(*args) eager_out = nn_module(*args)
with freeze_rng_state(): with torch.no_grad(), freeze_rng_state():
script_out = sm(*args) script_out = sm(*args)
if unwrapper: if unwrapper:
script_out = unwrapper(script_out) script_out = unwrapper(script_out)
......
...@@ -26,6 +26,7 @@ __all__ = [ ...@@ -26,6 +26,7 @@ __all__ = [
"regnet_y_8gf", "regnet_y_8gf",
"regnet_y_16gf", "regnet_y_16gf",
"regnet_y_32gf", "regnet_y_32gf",
"regnet_y_128gf",
"regnet_x_400mf", "regnet_x_400mf",
"regnet_x_800mf", "regnet_x_800mf",
"regnet_x_1_6gf", "regnet_x_1_6gf",
...@@ -505,6 +506,18 @@ def regnet_y_32gf(pretrained: bool = False, progress: bool = True, **kwargs: Any ...@@ -505,6 +506,18 @@ def regnet_y_32gf(pretrained: bool = False, progress: bool = True, **kwargs: Any
return _regnet("regnet_y_32gf", params, pretrained, progress, **kwargs) return _regnet("regnet_y_32gf", params, pretrained, progress, **kwargs)
def regnet_y_128gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetY_128GF architecture from
    `"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.

    NOTE: Pretrained weights are not available for this model.

    Args:
        pretrained (bool): If True, attempts to load pretrained weights (none exist for this variant).
        progress (bool): If True, displays a progress bar of the download to stderr.
    """
    # Build the block configuration inline and hand it straight to the shared builder.
    return _regnet(
        "regnet_y_128gf",
        BlockParams.from_init_params(
            depth=27, w_0=456, w_a=160.83, w_m=2.52, group_width=264, se_ratio=0.25, **kwargs
        ),
        pretrained,
        progress,
        **kwargs,
    )
def regnet_x_400mf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: def regnet_x_400mf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet:
""" """
Constructs a RegNetX_400MF architecture from Constructs a RegNetX_400MF architecture from
......
...@@ -20,6 +20,7 @@ __all__ = [ ...@@ -20,6 +20,7 @@ __all__ = [
"RegNet_Y_8GF_Weights", "RegNet_Y_8GF_Weights",
"RegNet_Y_16GF_Weights", "RegNet_Y_16GF_Weights",
"RegNet_Y_32GF_Weights", "RegNet_Y_32GF_Weights",
"RegNet_Y_128GF_Weights",
"RegNet_X_400MF_Weights", "RegNet_X_400MF_Weights",
"RegNet_X_800MF_Weights", "RegNet_X_800MF_Weights",
"RegNet_X_1_6GF_Weights", "RegNet_X_1_6GF_Weights",
...@@ -34,6 +35,7 @@ __all__ = [ ...@@ -34,6 +35,7 @@ __all__ = [
"regnet_y_8gf", "regnet_y_8gf",
"regnet_y_16gf", "regnet_y_16gf",
"regnet_y_32gf", "regnet_y_32gf",
"regnet_y_128gf",
"regnet_x_400mf", "regnet_x_400mf",
"regnet_x_800mf", "regnet_x_800mf",
"regnet_x_1_6gf", "regnet_x_1_6gf",
...@@ -253,6 +255,11 @@ class RegNet_Y_32GF_Weights(WeightsEnum): ...@@ -253,6 +255,11 @@ class RegNet_Y_32GF_Weights(WeightsEnum):
default = ImageNet1K_V2 default = ImageNet1K_V2
class RegNet_Y_128GF_Weights(WeightsEnum):
    """Weights enum for RegNetY_128GF. Intentionally empty: no pretrained weights are available yet."""
class RegNet_X_400MF_Weights(WeightsEnum): class RegNet_X_400MF_Weights(WeightsEnum):
ImageNet1K_V1 = Weights( ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth", url="https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth",
...@@ -501,6 +508,16 @@ def regnet_y_32gf(*, weights: Optional[RegNet_Y_32GF_Weights] = None, progress: ...@@ -501,6 +508,16 @@ def regnet_y_32gf(*, weights: Optional[RegNet_Y_32GF_Weights] = None, progress:
return _regnet(params, weights, progress, **kwargs) return _regnet(params, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", None))
def regnet_y_128gf(*, weights: Optional[RegNet_Y_128GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetY_128GF architecture from
    `"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.

    NOTE: No pretrained weights are available for this model yet.

    Args:
        weights (Optional[RegNet_Y_128GF_Weights]): Weights to load; only ``None`` is currently valid.
        progress (bool): If True, displays a progress bar of the download to stderr.
    """
    weights = RegNet_Y_128GF_Weights.verify(weights)

    block_params = BlockParams.from_init_params(
        depth=27, w_0=456, w_a=160.83, w_m=2.52, group_width=264, se_ratio=0.25, **kwargs
    )
    return _regnet(block_params, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", RegNet_X_400MF_Weights.ImageNet1K_V1)) @handle_legacy_interface(weights=("pretrained", RegNet_X_400MF_Weights.ImageNet1K_V1))
def regnet_x_400mf(*, weights: Optional[RegNet_X_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: def regnet_x_400mf(*, weights: Optional[RegNet_X_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
weights = RegNet_X_400MF_Weights.verify(weights) weights = RegNet_X_400MF_Weights.verify(weights)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment