Unverified Commit 7efdf026 authored by Martin Hahner's avatar Martin Hahner Committed by GitHub
Browse files

Fix issues with CaDDN (#542)

* Download DeepLabV3 if not available yet

* Fix dtype issue

Fix RuntimeError: expected backend CPU and dtype Float but got backend CPU and dtype Long
parent aaf9cbeb
...@@ -17,7 +17,7 @@ class FrustumGridGenerator(nn.Module): ...@@ -17,7 +17,7 @@ class FrustumGridGenerator(nn.Module):
""" """
super().__init__() super().__init__()
self.dtype = torch.float32 self.dtype = torch.float32
self.grid_size = torch.as_tensor(grid_size) self.grid_size = torch.as_tensor(grid_size, dtype=self.dtype)
self.pc_range = pc_range self.pc_range = pc_range
self.out_of_bounds_val = -2 self.out_of_bounds_val = -2
self.disc_cfg = disc_cfg self.disc_cfg = disc_cfg
......
from collections import OrderedDict from collections import OrderedDict
from pathlib import Path
from torch import hub
import numpy as np import numpy as np
import torch import torch
...@@ -56,6 +58,15 @@ class DDNTemplate(nn.Module): ...@@ -56,6 +58,15 @@ class DDNTemplate(nn.Module):
# Update weights # Update weights
if self.pretrained_path is not None: if self.pretrained_path is not None:
model_dict = model.state_dict() model_dict = model.state_dict()
# Download pretrained model if not available yet
checkpoint_path = Path(self.pretrained_path)
if not checkpoint_path.exists():
checkpoint = checkpoint_path.name
save_dir = checkpoint_path.parent
save_dir.mkdir(parents=True)
url = f'https://download.pytorch.org/models/{checkpoint}'
hub.load_state_dict_from_url(url, save_dir)
# Get pretrained state dict # Get pretrained state dict
pretrained_dict = torch.load(self.pretrained_path) pretrained_dict = torch.load(self.pretrained_path)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment