Unverified Commit 70a8e05a authored by NVS Abhilash's avatar NVS Abhilash Committed by GitHub
Browse files

fix: atrous_rates for deeplabv3_mobilenet_v3_large (fixes #7956) (#8019)


Co-authored-by: default avatarNicolas Hug <nh.nicolas.hug@gmail.com>
parent 7e2050f1
from functools import partial
from typing import Any, List, Optional
from typing import Any, Optional, Sequence
import torch
from torch import nn
......@@ -46,9 +46,9 @@ class DeepLabV3(_SimpleSegmentationModel):
class DeepLabHead(nn.Sequential):
def __init__(self, in_channels: int, num_classes: int) -> None:
def __init__(self, in_channels: int, num_classes: int, atrous_rates: Sequence[int] = (12, 24, 36)) -> None:
super().__init__(
ASPP(in_channels, [12, 24, 36]),
ASPP(in_channels, atrous_rates),
nn.Conv2d(256, 256, 3, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(),
......@@ -83,7 +83,7 @@ class ASPPPooling(nn.Sequential):
class ASPP(nn.Module):
def __init__(self, in_channels: int, atrous_rates: List[int], out_channels: int = 256) -> None:
def __init__(self, in_channels: int, atrous_rates: Sequence[int], out_channels: int = 256) -> None:
super().__init__()
modules = []
modules.append(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment