Unverified Commit 11bd2eaa authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Port Multi-weight support from prototype to main (#5618)

* Moving basefiles outside of prototype and porting Alexnet, ConvNext, Densenet and EfficientNet.

* Porting googlenet

* Porting inception

* Porting mnasnet

* Porting mobilenetv2

* Porting mobilenetv3

* Porting regnet

* Porting resnet

* Porting shufflenetv2

* Porting squeezenet

* Porting vgg

* Porting vit

* Fix docstrings

* Fixing imports

* Adding missing import

* Fix mobilenet imports

* Fix tests

* Fix prototype tests

* Exclude get_weight from models on test

* Fix init files

* Porting googlenet

* Porting inception

* porting mobilenetv2

* porting mobilenetv3

* porting resnet

* porting shufflenetv2

* Fix test and linter

* Fixing docs.

* Porting Detection models (#5617)

* fix inits

* fix docs

* Port faster_rcnn

* Port fcos

* Port keypoint_rcnn

* Port mask_rcnn

* Port retinanet

* Port ssd

* Port ssdlite

* Fix linter

* Fixing tests

* Fi...
parent 375e4ab2
from .mobilenet import *
from .resnet import *
from .googlenet import *
from .inception import *
from .mobilenet import *
from .resnet import *
from .shufflenetv2 import *
from .mobilenetv2 import QuantizableMobileNetV2, mobilenet_v2, __all__ as mv2_all
from .mobilenetv3 import QuantizableMobileNetV3, mobilenet_v3_large, __all__ as mv3_all
from .mobilenetv2 import * # noqa: F401, F403
from .mobilenetv3 import * # noqa: F401, F403
from .mobilenetv2 import __all__ as mv2_all
from .mobilenetv3 import __all__ as mv3_all
__all__ = mv2_all + mv3_all
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
from .fcn import *
from .deeplabv3 import *
from .fcn import *
from .lraspp import *
......@@ -4,7 +4,6 @@ from typing import Optional, Dict
from torch import nn, Tensor
from torch.nn import functional as F
from ..._internally_replaced_utils import load_state_dict_from_url
from ...utils import _log_api_usage_once
......@@ -36,10 +35,3 @@ class _SimpleSegmentationModel(nn.Module):
result["aux"] = x
return result
def _load_weights(arch: str, model: nn.Module, model_url: Optional[str], progress: bool) -> None:
if model_url is None:
raise ValueError(f"No checkpoint is available for {arch}")
state_dict = load_state_dict_from_url(model_url, progress=progress)
model.load_state_dict(state_dict)
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment