Unverified Commit 70a8e05a authored by NVS Abhilash's avatar NVS Abhilash Committed by GitHub
Browse files

fix: atrous_rates for deeplabv3_mobilenet_v3_large (fixes #7956) (#8019)


Co-authored-by: default avatarNicolas Hug <nh.nicolas.hug@gmail.com>
parent 7e2050f1
from functools import partial from functools import partial
from typing import Any, List, Optional from typing import Any, Optional, Sequence
import torch import torch
from torch import nn from torch import nn
...@@ -46,9 +46,9 @@ class DeepLabV3(_SimpleSegmentationModel): ...@@ -46,9 +46,9 @@ class DeepLabV3(_SimpleSegmentationModel):
class DeepLabHead(nn.Sequential): class DeepLabHead(nn.Sequential):
def __init__(self, in_channels: int, num_classes: int) -> None: def __init__(self, in_channels: int, num_classes: int, atrous_rates: Sequence[int] = (12, 24, 36)) -> None:
super().__init__( super().__init__(
ASPP(in_channels, [12, 24, 36]), ASPP(in_channels, atrous_rates),
nn.Conv2d(256, 256, 3, padding=1, bias=False), nn.Conv2d(256, 256, 3, padding=1, bias=False),
nn.BatchNorm2d(256), nn.BatchNorm2d(256),
nn.ReLU(), nn.ReLU(),
...@@ -83,7 +83,7 @@ class ASPPPooling(nn.Sequential): ...@@ -83,7 +83,7 @@ class ASPPPooling(nn.Sequential):
class ASPP(nn.Module): class ASPP(nn.Module):
def __init__(self, in_channels: int, atrous_rates: List[int], out_channels: int = 256) -> None: def __init__(self, in_channels: int, atrous_rates: Sequence[int], out_channels: int = 256) -> None:
super().__init__() super().__init__()
modules = [] modules = []
modules.append( modules.append(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment