Unverified Commit 7efdf026 authored by Martin Hahner's avatar Martin Hahner Committed by GitHub
Browse files

Fix issues with CaDDN (#542)

* Download DeepLabV3 if not available yet

* Fix dtype issue

Fix RuntimeError: expected backend CPU and dtype Float but got backend CPU and dtype Long
parent aaf9cbeb
......@@ -17,7 +17,7 @@ class FrustumGridGenerator(nn.Module):
"""
super().__init__()
self.dtype = torch.float32
self.grid_size = torch.as_tensor(grid_size)
self.grid_size = torch.as_tensor(grid_size, dtype=self.dtype)
self.pc_range = pc_range
self.out_of_bounds_val = -2
self.disc_cfg = disc_cfg
......
from collections import OrderedDict
from pathlib import Path
from torch import hub
import numpy as np
import torch
......@@ -57,6 +59,15 @@ class DDNTemplate(nn.Module):
if self.pretrained_path is not None:
model_dict = model.state_dict()
# Download pretrained model if not available yet
checkpoint_path = Path(self.pretrained_path)
if not checkpoint_path.exists():
checkpoint = checkpoint_path.name
save_dir = checkpoint_path.parent
save_dir.mkdir(parents=True)
url = f'https://download.pytorch.org/models/{checkpoint}'
hub.load_state_dict_from_url(url, save_dir)
# Get pretrained state dict
pretrained_dict = torch.load(self.pretrained_path)
pretrained_dict = self.filter_pretrained_dict(model_dict=model_dict,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment