"vscode:/vscode.git/clone" did not exist on "0404891ee40f6def945f93e8664de7b041ee7537"
Commit 41b18fd8 authored by zhe chen's avatar zhe chen
Browse files

Use pre-commit to reformat code
parent ff20ea39
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import force_fp32
from mmdet.models import HEADS
from torch.distributions.categorical import Categorical

from .detgen_utils.causal_trans import (CausalTransformerDecoder,
                                        CausalTransformerDecoderLayer)
from .detgen_utils.utils import (generate_square_subsequent_mask, top_k_logits,
                                 top_p_logits)


@HEADS.register_module(force=True)
class PolylineGenerator(nn.Module):
@@ -63,7 +63,7 @@ class PolylineGenerator(nn.Module):
        self.fp16_enabled = False
        self.coord_dim = coord_dim  # 3 if we use xyz, 2 if we use xy
        self.kp_coord_dim = coord_dim if coord_dim == 2 else 2  # XXX
        self.register_buffer('canvas_size', torch.tensor(canvas_size))

        # initialize the model
@@ -126,7 +126,7 @@ class PolylineGenerator(nn.Module):
        # Discrete vertex value embeddings
        vert_embeddings = self.vertex_embed(bbox)
        return vert_embeddings + (bbox_embedding + coord_embeddings)[None]

    def _prepare_context(self, batch, context):
        """Prepare class label and vertex context."""
@@ -169,7 +169,7 @@ class PolylineGenerator(nn.Module):
    def _embed_inputs(self, seqs, condition_embedding=None):
        """Embeds input sequences and adds within- and between-face positions.

        Args:
            seqs: B, seqlen=vlen*3,
            condition_embedding: B, [c,xs,ys,xe,ye](5), h
        Returns:
            embeddings: B, seqlen, h
@@ -189,7 +189,7 @@ class PolylineGenerator(nn.Module):
        # Aggregate embeddings
        embeddings = vert_embeddings + \
            (coord_embeddings + pos_embeddings)[None]
        embeddings = torch.cat([condition_embedding, embeddings], dim=1)
        return embeddings
@@ -210,7 +210,7 @@ class PolylineGenerator(nn.Module):
            return self.forward_train(batch, **kwargs)
        else:
            return self.inference(batch, **kwargs)

    def sperate_forward(self, batch, context, **kwargs):
        polyline_length = batch['polyline_masks'].sum(-1)
@@ -218,24 +218,22 @@ class PolylineGenerator(nn.Module):
        sizes = [size, polyline_length.max()]

        polyline_logits = []
        for c_idx, size in zip([c1, c2], sizes):
            new_batch = assign_batch(batch, c_idx, size)
            _poly_logits = self._forward_train(new_batch, context, **kwargs)
            polyline_logits.append(_poly_logits)

        # maybe improve the speed
        for i, (_poly_logits, size) in enumerate(zip(polyline_logits, sizes)):
            if size < sizes[1]:
                _poly_logits = F.pad(_poly_logits, (0, 0, 0, sizes[1] - size), 'constant', 0)
            polyline_logits[i] = _poly_logits
        polyline_logits = torch.cat(polyline_logits, 0)
        polyline_logits = polyline_logits[revert_idx]

        cat_dist = Categorical(logits=polyline_logits)
        return {'polylines': cat_dist}
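Here `c1` and `c2` come from `get_chunk_idx` (defined near the end of this file): the batch is split into a short and a long chunk so the short one is padded only to its own maximum length. A self-contained sketch of the idea, with `run_model` standing in for `self._forward_train` (a toy name, not the repo's API):

import torch

def run_model(padded):  # hypothetical stand-in for self._forward_train
    return padded.float()

lengths = torch.tensor([3, 50, 4, 48])            # per-sample sequence lengths
order = torch.argsort(lengths)                    # short ones first
split = 2                                         # e.g. chosen by the cost heuristic
chunks = [order[:split], order[split:]]
outs = []
for idx in chunks:
    size = int(lengths[idx].max())                # pad only to the chunk max
    padded = torch.zeros(len(idx), size)
    outs.append(run_model(padded))
# re-pad the short chunk to the global max, undo the sort, and concatenate
full = int(lengths.max())
outs = [torch.nn.functional.pad(o, (0, full - o.shape[1])) for o in outs]
result = torch.cat(outs)[torch.argsort(order)]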
    def forward_train(self, batch: dict, context: dict, **kwargs):
        """
@@ -247,7 +245,7 @@ class PolylineGenerator(nn.Module):
        if False:
            polyline_logits = self._forward_train(batch, context, **kwargs)
            cat_dist = Categorical(logits=polyline_logits)
            return {'polylines': cat_dist}
        else:
            return self.sperate_forward(batch, context, **kwargs)
@@ -260,7 +258,7 @@ class PolylineGenerator(nn.Module):
        # we use the gt vertices
        global_context, seq_context = self._prepare_context(
            batch, context)

        logits = self.body(
            # Last element not used for preds
            batch['polylines'][:, :-1],
@@ -271,7 +269,7 @@ class PolylineGenerator(nn.Module):
        return logits
    @force_fp32(apply_to=('global_context_embedding', 'sequential_context_embeddings', 'cache'))
    def body(self,
             seqs,
             global_context_embedding=None,
@@ -303,7 +301,7 @@ class PolylineGenerator(nn.Module):
        if is_training:
            causal_msk = generate_square_subsequent_mask(
                decoder_inputs.shape[0], condition_len=condition_len, device=decoder_inputs.device)

        decoder_outputs, cache = self.decoder(
            tgt=decoder_inputs,
            cache=cache,
@@ -314,28 +312,27 @@ class PolylineGenerator(nn.Module):
        decoder_outputs = decoder_outputs.transpose(0, 1)
        # keep only the predicted part of the sequence
        decoder_outputs = decoder_outputs[:, condition_len - 1:]

        # Get logits and optionally process for sampling
        logits = self._project_to_logits(decoder_outputs)

        # y mask
        _vert_mask = torch.arange(logits.shape[-1], device=logits.device)
        vertices_mask_y = (_vert_mask < self.canvas_size[1] + 1)
        vertices_mask_y[0] = False  # y steps cannot emit the stop token
        logits[:, 1::self.coord_dim] = logits[:, 1::self.coord_dim] * \
            vertices_mask_y - ~vertices_mask_y * 1e9

        if self.coord_dim > 2:
            # z mask
            _vert_mask = torch.arange(logits.shape[-1], device=logits.device)
            vertices_mask_z = (_vert_mask < self.canvas_size[2] + 1)
            vertices_mask_z[0] = False  # z steps cannot emit the stop token either
            logits[:, 2::self.coord_dim] = logits[:, 2::self.coord_dim] * \
                vertices_mask_z - ~vertices_mask_z * 1e9

        logits = logits / temperature
        logits = top_k_logits(logits, top_k)
        logits = top_p_logits(logits, top_p)
        if return_logits:
@@ -350,9 +347,9 @@ class PolylineGenerator(nn.Module):
        weight = gt['polyline_weights']
        mask = gt['polyline_masks']

        loss = -torch.sum(
            pred['polylines'].log_prob(gt['polylines']) * mask * weight) / weight.sum()
        return {'seq': loss}
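The loss is a masked, weighted negative log-likelihood over the categorical token distribution. A tiny worked example with made-up numbers:

import torch
from torch.distributions.categorical import Categorical

logits = torch.tensor([[[2.0, 0.0, 0.0], [0.0, 2.0, 0.0]]])  # B=1, T=2, V=3
dist = Categorical(logits=logits)
tokens = torch.tensor([[0, 1]])
mask = torch.tensor([[1.0, 0.0]])      # second step is padding
weight = torch.tensor([[1.0, 1.0]])
loss = -(dist.log_prob(tokens) * mask * weight).sum() / weight.sum()
# only the first step contributes; the padded step is zeroed by the mask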
@@ -395,16 +392,15 @@ class PolylineGenerator(nn.Module):
        samples = torch.empty(
            [batch_size, 0], dtype=torch.int32, device=device)
        max_sample_length = max_sample_length or self.max_seq_length
        seq_len = max_sample_length * self.coord_dim + 1

        cache = None
        decoded_tokens = \
            torch.zeros((batch_size, seq_len),
                        device=device, dtype=torch.long)
        remain_idx = torch.arange(batch_size, device=device)
        for i in range(seq_len):
            # Loop body for autoregressive decoding.
            pred_dist, cache = self.body(
                samples,
@@ -417,22 +413,22 @@ class PolylineGenerator(nn.Module):
                is_training=False)
            samples = pred_dist.sample()
            decoded_tokens[remain_idx, i] = samples[:, -1]

            # Stopping condition for the autoregressive loop.
            if not (decoded_tokens[:, :i + 1] != 0).all(-1).any():
                break

            # update state: keep only samples whose new position is non-zero.
            valid_idx = (samples[:, -1] != 0).nonzero(as_tuple=True)[0]
            remain_idx = remain_idx[valid_idx]
            cache = cache[:, :, valid_idx]
            global_context = global_context[valid_idx]
            seq_context = seq_context[valid_idx]
            samples = samples[valid_idx]

        # decoded_tokens = torch.cat(decoded_tokens,dim=1)
        decoded_tokens = decoded_tokens[:, :i + 1]
        outputs = self.post_process(decoded_tokens, seq_len,
                                    device, only_return_complete)
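The sampling loop keeps only unfinished sequences in the active batch (token 0 is the stop symbol), and remain_idx maps surviving rows back to their original slots. A minimal sketch of just that bookkeeping, with a random dummy sampler instead of the model:

import torch

torch.manual_seed(0)
batch, steps = 4, 6
tokens = torch.zeros(batch, steps, dtype=torch.long)
remain = torch.arange(batch)                      # original slot of each active row
for i in range(steps):
    step = torch.randint(0, 3, (len(remain),))    # dummy model: 0 acts as stop
    tokens[remain, i] = step
    alive = (step != 0).nonzero(as_tuple=True)[0]
    if len(alive) == 0:
        break
    remain = remain[alive]                        # shrink the active batch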
@@ -455,7 +451,7 @@ class PolylineGenerator(nn.Module):
        _polyline_mask = torch.arange(sample_seq_length)[None].to(device)

        # Get largest stopping point for incomplete samples.
        valid_polyline_len = torch.full_like(polyline[:, 0], sample_seq_length)
        zero_inds = (polyline == 0).type(torch.int32).argmax(-1)
        # Real length
@@ -463,7 +459,7 @@ class PolylineGenerator(nn.Module):
        polyline_mask = _polyline_mask < valid_polyline_len[:, None]

        # Mask tokens beyond the stopping token with zeros
        polyline = polyline * polyline_mask

        # Pad to maximum size with zeros
        pad_size = max_seq_len - sample_seq_length
@@ -486,27 +482,27 @@ class PolylineGenerator(nn.Module):
        }
        return outputs
def find_best_sperate_plan(idx, array):
    h = array[-1] - array[idx]
    w = idx
    cost = h * w
    return cost


def get_chunk_idx(polyline_length):
    _polyline_length, polyline_length_idx = torch.sort(polyline_length)
    costs = []
    for i in range(len(_polyline_length)):
        cost = find_best_sperate_plan(i, _polyline_length)
        costs.append(cost)
    seperate_point = torch.stack(costs).argmax()

    chunk1 = polyline_length_idx[:seperate_point + 1]
    chunk2 = polyline_length_idx[seperate_point + 1:]
    revert_idx = torch.argsort(polyline_length_idx)
    return chunk1, chunk2, revert_idx, _polyline_length[seperate_point]
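The cost above is h * w = (max_len - len[idx]) * idx: roughly the padding that would be wasted on the idx shortest (sorted) samples if everything were padded to the global maximum, so the argmax picks the split that saves the most. A toy example:

import torch

lengths = torch.tensor([3, 4, 48, 50])                 # sorted polyline lengths
costs = [i * (lengths[-1] - lengths[i]).item() for i in range(len(lengths))]
# costs == [0, 46, 4, 0] -> split after index 1:
# chunk1 = the two short samples (padded to 4), chunk2 = the rest (padded to 50)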
@@ -517,9 +513,9 @@ def assign_bev(feat, idx):
def assign_batch(batch, idx, size):
    new_batch = {}
    for k, v in batch.items():
        new_batch[k] = v[idx]
        if new_batch[k].ndim > 1:
            new_batch[k] = new_batch[k][:, :size]
    return new_batch
import mmcv
import torch
from mmdet.models.builder import LOSSES
from mmdet.models.losses.utils import weighted_loss
from torch import nn as nn
from torch.nn import functional as F


@mmcv.jit(derivate=True, coderize=True)
@@ -62,7 +60,7 @@ class LinesLoss(nn.Module):
            target (torch.Tensor): The learning target of the prediction.
                shape: [bs, ...]
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
                It's useful when the predictions are not all valid.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
@@ -77,7 +75,7 @@ class LinesLoss(nn.Module):
        loss = smooth_l1_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor, beta=self.beta)
        return loss * self.loss_weight
@mmcv.jit(derivate=True, coderize=True)
@@ -123,7 +121,8 @@ class MasksLoss(nn.Module):
        loss = bce(pred, target, weight, reduction=reduction,
                   avg_factor=avg_factor)
        return loss * self.loss_weight
@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
@@ -165,6 +164,6 @@ class LenLoss(nn.Module):
            reduction_override if reduction_override else self.reduction)
        loss = ce(pred, target, weight, reduction=reduction,
                  avg_factor=avg_factor)
        return loss * self.loss_weight
\ No newline at end of file
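For context, the elementwise losses these classes call are wrapped with mmdet's weighted_loss decorator, which layers weight, reduction and avg_factor handling on top. A minimal sketch of that pattern (my_abs_loss is a made-up example, not the repo's):

import torch
from mmdet.models.losses.utils import weighted_loss

@weighted_loss
def my_abs_loss(pred, target):
    # only the elementwise term lives here; the decorator adds
    # weight, reduction ('none'/'mean'/'sum') and avg_factor
    return (pred - target).abs()

pred, target = torch.rand(4, 10), torch.rand(4, 10)
loss = my_abs_loss(pred, target, weight=torch.ones(4, 10), reduction='mean')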
from abc import ABCMeta

import torch.nn as nn
from mmcv.utils import print_log
from mmdet3d.models.builder import DETECTORS
from mmdet.utils import get_root_logger

MAPPERS = DETECTORS


class BaseMapper(nn.Module, metaclass=ABCMeta):
    """Base class for mappers."""
@@ -40,7 +39,7 @@ class BaseMapper(nn.Module, metaclass=ABCMeta):
        return ((hasattr(self, 'roi_head') and self.roi_head.with_mask)
                or (hasattr(self, 'mask_head') and self.mask_head is not None))

    # @abstractmethod
    def extract_feat(self, imgs):
        """Extract features from images."""
        pass
@@ -48,11 +47,11 @@ class BaseMapper(nn.Module, metaclass=ABCMeta):
    def forward_train(self, *args, **kwargs):
        pass

    # @abstractmethod
    def simple_test(self, img, img_metas, **kwargs):
        pass

    # @abstractmethod
    def aug_test(self, imgs, img_metas, **kwargs):
        """Test function with test time augmentation."""
        pass
@@ -88,7 +87,7 @@ class BaseMapper(nn.Module, metaclass=ABCMeta):
        should be double nested (i.e. List[Tensor], List[List[dict]]), with
        the outer list indicating test time augmentations.
        """
        if return_loss:
            return self.forward_train(*args, **kwargs)
        else:
@@ -123,7 +122,7 @@ class BaseMapper(nn.Module, metaclass=ABCMeta):
        averaging the logs.
        """
        loss, log_vars, num_samples = self(**data_dict)

        outputs = dict(
            loss=loss, log_vars=log_vars, num_samples=num_samples)
@@ -137,7 +136,7 @@ class BaseMapper(nn.Module, metaclass=ABCMeta):
        not implemented with this method, but an evaluation hook.
        """
        loss, log_vars, num_samples = self(**data)

        outputs = dict(
            loss=loss, log_vars=log_vars, num_samples=num_samples)
@@ -146,4 +145,4 @@ class BaseMapper(nn.Module, metaclass=ABCMeta):
    def show_result(self,
                    **kwargs):
        img = None
        return img
\ No newline at end of file
import mmcv
import numpy as np
import torch
import torch.nn as nn
from mmdet3d.models.builder import build_backbone, build_head, build_neck
from torch.nn.utils.rnn import pad_sequence
from torchvision.models.resnet import resnet18

from .base_mapper import MAPPERS, BaseMapper


@MAPPERS.register_module()
@@ -31,12 +27,11 @@ class VectorMapNet(BaseMapper):
                 model_name=None, **kwargs):
        super(VectorMapNet, self).__init__()

        # Attribute
        self.model_name = model_name
        self.last_epoch = None
        self.only_det = only_det

        self.backbone = build_backbone(backbone_cfg)
        if neck_cfg is not None:
@@ -53,17 +48,16 @@ class VectorMapNet(BaseMapper):
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1,
                             dilation=1, ceil_mode=False),
                trunk.layer1,
                nn.Conv2d(64, 128, kernel_size=1, bias=False),
            )

        # BEV
        if hasattr(self.backbone, 'bev_w'):
            self.bev_w = self.backbone.bev_w
            self.bev_h = self.backbone.bev_h

        self.head = build_head(head_cfg)

    def multiscale_neck(self, bev_embedding):
@@ -79,12 +73,12 @@ class VectorMapNet(BaseMapper):
            img: torch.Tensor of shape [B, N, 3, H, W]
                N: number of cams
            vectors: list[list[Tuple(lines, length, label)]]
                - lines: np.array of shape [num_points, 2].
                - length: int
                - label: int
                len(vectors) = batch_size
                len(vectors[_b]) = num of lines in sample _b
            img_metas:
                img_metas['lidar2img']: [B, N, 4, 4]
        Out:
            loss, log_vars, num_sample
@@ -92,12 +86,12 @@ class VectorMapNet(BaseMapper):
        # prepare labels and images
        batch, img, img_metas, valid_idx, points = self.batch_data(
            polys, img, img_metas, img.device, points)

        # corner case: fall back to the cached batch so the step does not fail
        if self.last_epoch is None:
            self.last_epoch = [batch, img, img_metas, valid_idx, points]
        if len(valid_idx) == 0:
            batch, img, img_metas, valid_idx, points = self.last_epoch
        else:
            del self.last_epoch
@@ -110,15 +104,15 @@ class VectorMapNet(BaseMapper):
        # Neck
        bev_feats = self.neck(_bev_feats)

        preds_dict, losses_dict = \
            self.head(batch,
                      context={
                          'bev_embeddings': bev_feats,
                          'batch_input_shape': _bev_feats.shape[2:],
                          'img_shape': img_shape,
                          'raw_bev_embeddings': _bev_feats},
                      only_det=self.only_det)

        # format outputs
        loss = 0
@@ -150,13 +144,13 @@ class VectorMapNet(BaseMapper):
        bev_feats = self.neck(_bev_feats)
        context = {'bev_embeddings': bev_feats,
                   'batch_input_shape': _bev_feats.shape[2:],
                   'img_shape': img_shape,  # XXX
                   'raw_bev_embeddings': _bev_feats}

        preds_dict = self.head(batch={},
                               context=context,
                               condition_on_det=True,
                               gt_condition=False,
                               only_det=self.only_det)
@@ -173,7 +167,7 @@ class VectorMapNet(BaseMapper):
        valid_idx = [i for i in range(len(polys)) if len(polys[i])]
        imgs = imgs[valid_idx]
        img_metas = [img_metas[i] for i in valid_idx]
        polys = [polys[i] for i in valid_idx]

        if points is not None:
@@ -184,16 +178,16 @@ class VectorMapNet(BaseMapper):
            return None, None, None, valid_idx, None

        batch = {}
        batch['det'] = format_det(polys, device)
        batch['gen'] = format_gen(polys, device)

        return batch, imgs, img_metas, valid_idx, points

    def batch_points(self, points):
        pad_points = pad_sequence(points, batch_first=True)
        points_mask = torch.zeros_like(pad_points[:, :, 0]).bool()
        for i in range(len(points)):
            valid_num = points[i].shape[0]
            points_mask[i][:valid_num] = True
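pad_sequence stacks the variable-length point sets into one zero-padded tensor, and the boolean mask records which rows are real. A quick shape check (toy tensors, not the dataset's):

import torch
from torch.nn.utils.rnn import pad_sequence

points = [torch.rand(3, 4), torch.rand(5, 4)]      # two clouds, 4 features each
padded = pad_sequence(points, batch_first=True)    # -> (2, 5, 4), zero-padded
mask = torch.zeros_like(padded[:, :, 0]).bool()    # -> (2, 5)
for i, p in enumerate(points):
    mask[i, :p.shape[0]] = True                    # True where points are real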
@@ -202,41 +196,38 @@ class VectorMapNet(BaseMapper):
def format_det(polys, device):
    batch = {
        'class_label': [],
        'batch_idx': [],
        'bbox': [],
    }
    for batch_idx, poly in enumerate(polys):
        keypoint_label = torch.from_numpy(poly['det_label']).to(device)
        keypoint = torch.from_numpy(poly['keypoint']).to(device)
        batch['class_label'].append(keypoint_label)
        batch['bbox'].append(keypoint)

    return batch
def format_gen(polys, device):
    line_cls = []
    polylines, polyline_masks, polyline_weights = [], [], []
    bbox, line_cls, line_bs_idx = [], [], []
    for batch_idx, poly in enumerate(polys):
        # convert to cuda tensor
        for k in poly.keys():
            if isinstance(poly[k], np.ndarray):
                poly[k] = torch.from_numpy(poly[k]).to(device)
            else:
                poly[k] = [torch.from_numpy(v).to(device) for v in poly[k]]

        line_cls += poly['gen_label']
        line_bs_idx += [batch_idx] * len(poly['gen_label'])

        # condition
        bbox += poly['qkeypoint']
@@ -257,5 +248,5 @@ def format_gen(polys, device):
    batch['polylines'] = pad_sequence(polylines, batch_first=True)
    batch['polyline_masks'] = pad_sequence(polyline_masks, batch_first=True)
    batch['polyline_weights'] = pad_sequence(polyline_weights, batch_first=True)

    return batch
\ No newline at end of file
from .deformable_transformer import DeformableDetrTransformer_, DeformableDetrTransformerDecoder_
from .base_transformer import PlaceHolderEncoder
\ No newline at end of file
import torch.nn as nn
from mmcv.cnn.bricks.registry import TRANSFORMER_LAYER_SEQUENCE


@TRANSFORMER_LAYER_SEQUENCE.register_module()
class PlaceHolderEncoder(nn.Module):
@@ -21,5 +10,4 @@ class PlaceHolderEncoder(nn.Module):
        self.embed_dims = embed_dims

    def forward(self, *args, query=None, **kwargs):
        return query
\ No newline at end of file
@@ -4,18 +4,12 @@ import warnings
import torch
import torch.nn as nn
from mmcv.cnn import xavier_init
from mmcv.cnn.bricks.registry import TRANSFORMER_LAYER_SEQUENCE
from mmcv.cnn.bricks.transformer import TransformerLayerSequence
from mmdet.models.utils.builder import TRANSFORMER
from mmdet.models.utils.transformer import Transformer
from torch.nn.init import normal_

try:
    from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention
@@ -27,6 +21,7 @@ except ImportError:
    from .fp16_dattn import MultiScaleDeformableAttentionFp16
def inverse_sigmoid(x, eps=1e-5):
    """Inverse function of sigmoid.

    Args:
@@ -44,6 +39,7 @@ def inverse_sigmoid(x, eps=1e-5):
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)
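inverse_sigmoid is the logit function with clamping, so reference points stored in (0, 1) can be updated additively in unbounded space and squashed back. A quick sanity check (not in the source; eps matches the default above):

import torch

x = torch.tensor([0.1, 0.5, 0.9])
y = torch.log(x.clamp(min=1e-5) / (1 - x).clamp(min=1e-5))  # inverse_sigmoid(x)
assert torch.allclose(torch.sigmoid(y), x, atol=1e-4)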
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class DeformableDetrTransformerDecoder_(TransformerLayerSequence):
    """Implements the decoder in DETR transformer.
@@ -53,8 +49,8 @@ class DeformableDetrTransformerDecoder_(TransformerLayerSequence):
            `LN`.
    """

    def __init__(self, *args,
                 return_intermediate=False, coord_dim=2, kp_coord_dim=2, **kwargs):
        super(DeformableDetrTransformerDecoder_, self).__init__(*args, **kwargs)
        self.return_intermediate = return_intermediate
@@ -94,25 +90,26 @@ class DeformableDetrTransformerDecoder_(TransformerLayerSequence):
        for lid, layer in enumerate(self.layers):
            reference_points_input = \
                reference_points[:, :, None, :self.kp_coord_dim] * \
                valid_ratios[:, None, :, :self.kp_coord_dim]
            # if reference_points.shape[-1] == 3 and self.kp_coord_dim==2:
            output = layer(
                output,
                *args,
                reference_points=reference_points_input[..., :self.kp_coord_dim],
                **kwargs)
            output = output.permute(1, 0, 2)

            if reg_branches is not None:
                tmp = reg_branches[lid](output)
                new_reference_points = tmp
                new_reference_points[..., :self.kp_coord_dim] = tmp[
                    ..., :self.kp_coord_dim] + inverse_sigmoid(
                        reference_points)
                new_reference_points = new_reference_points.sigmoid()
                if reference_points.shape[-1] == 3 and self.kp_coord_dim == 2:
                    reference_points[..., -1] = tmp[..., -1].sigmoid().detach()
                reference_points[..., :self.coord_dim] = new_reference_points.detach()

            output = output.permute(1, 0, 2)
            if self.return_intermediate:
@@ -174,7 +171,7 @@ class DeformableDetrTransformer_(Transformer):
        for m in self.modules():
            if isinstance(m, MultiScaleDeformableAttention):
                m.init_weights()
            elif isinstance(m, MultiScaleDeformableAttentionFp16):
                m.init_weights()
        if not self.as_two_stage:
            xavier_init(self.reference_points_embed, distribution='uniform', bias=0.)
@@ -204,9 +201,9 @@ class DeformableDetrTransformer_(Transformer):
                torch.linspace(
                    0.5, W - 0.5, W, dtype=torch.float32, device=device))
            ref_y = ref_y.reshape(-1)[None] / (
                valid_ratios[:, None, lvl, 1] * H)
            ref_x = ref_x.reshape(-1)[None] / (
                valid_ratios[:, None, lvl, 0] * W)
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)
@@ -231,7 +228,7 @@ class DeformableDetrTransformer_(Transformer):
        scale = 2 * math.pi
        dim_t = torch.arange(
            num_pos_feats, dtype=torch.float32, device=proposals.device)
        dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
        # N, L, 4
        proposals = proposals.sigmoid() * scale
        # N, L, 4, 128
@@ -317,7 +314,7 @@ class DeformableDetrTransformer_(Transformer):
        spatial_shapes = torch.as_tensor(
            spatial_shapes, dtype=torch.long, device=feat_flatten.device)
        level_start_index = torch.cat((spatial_shapes.new_zeros(
            (1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
        valid_ratios = torch.stack(
            [self.get_valid_ratio(m) for m in mlvl_masks], 1)
@@ -343,7 +340,7 @@ class DeformableDetrTransformer_(Transformer):
        memory = feat_flatten.permute(1, 0, 2)
        bs, _, c = memory.shape
        query_pos, query = torch.split(query_embed, c, dim=-1)

        reference_points = self.reference_points_embed(query_pos).sigmoid()
        init_reference_out = reference_points
@@ -366,4 +363,4 @@ class DeformableDetrTransformer_(Transformer):
            **kwargs)
        inter_references_out = inter_references

        return inter_states, init_reference_out, inter_references_out
\ No newline at end of file
import warnings

try:
    from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention
except ImportError:
@@ -7,12 +7,6 @@ except ImportError:
        '`MultiScaleDeformableAttention` in MMCV has been moved to '
        '`mmcv.ops.multi_scale_deform_attn`, please update your MMCV')
    from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention

import math
import warnings
@@ -20,13 +14,15 @@ import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv import deprecated_api_warning
from mmcv.cnn import constant_init, xavier_init
from mmcv.cnn.bricks.registry import ATTENTION
from mmcv.cnn.bricks.transformer import build_attention
from mmcv.runner import force_fp32
from mmcv.runner.base_module import BaseModule
from mmcv.utils import ext_loader
from torch.autograd.function import Function, once_differentiable

ext_module = ext_loader.load_ext(
    '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
from torch.cuda.amp import custom_bwd, custom_fwd
@@ -35,16 +31,15 @@ from torch.cuda.amp import custom_bwd, custom_fwd
@ATTENTION.register_module()
class MultiScaleDeformableAttentionFp16(BaseModule):

    def __init__(self, attn_cfg=None, init_cfg=None, **kwarg):
        super(MultiScaleDeformableAttentionFp16, self).__init__(init_cfg)
        # import ipdb; ipdb.set_trace()
        self.deformable_attention = build_attention(attn_cfg)
        self.deformable_attention.init_weights()
        self.fp16_enabled = False

    @force_fp32(apply_to=('query', 'key', 'value', 'query_pos', 'reference_points', 'identity'))
    def forward(self, query,
                key=None,
                value=None,
@@ -57,15 +52,14 @@ class MultiScaleDeformableAttentionFp16(BaseModule):
                **kwargs):
        # import ipdb; ipdb.set_trace()
        return self.deformable_attention(query,
                                         key=key,
                                         value=value,
                                         identity=identity,
                                         query_pos=query_pos,
                                         key_padding_mask=key_padding_mask,
                                         reference_points=reference_points,
                                         spatial_shapes=spatial_shapes,
                                         level_start_index=level_start_index, **kwargs)
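The class above is a thin fp32 shim: force_fp32 casts the listed tensor arguments back to float32 before the wrapped deformable attention runs, keeping the CUDA kernel numerically safe under mixed precision. A minimal sketch of the same pattern (ToyAttnFp32 is invented for illustration):

import torch.nn as nn
from mmcv.runner import force_fp32

class ToyAttnFp32(nn.Module):
    """Invented example of the fp32-shim pattern used above."""

    def __init__(self):
        super().__init__()
        # wrap_fp16_model() flips this to True when AMP is enabled
        self.fp16_enabled = False

    @force_fp32(apply_to=('query', 'value'))
    def forward(self, query, value=None):
        # with fp16_enabled=True, query/value arrive here cast to float32
        return query if value is None else query + value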
class MultiScaleDeformableAttnFunctionFp32(Function):
@@ -118,8 +112,8 @@ class MultiScaleDeformableAttnFunctionFp32(Function):
            Tuple[Tensor]: Gradient
                of input tensors in forward.
        """
        value, value_spatial_shapes, value_level_start_index, \
            sampling_locations, attention_weights = ctx.saved_tensors
        grad_value = torch.zeros_like(value)
        grad_sampling_loc = torch.zeros_like(sampling_locations)
        grad_attn_weight = torch.zeros_like(attention_weights)
@@ -137,7 +131,7 @@ class MultiScaleDeformableAttnFunctionFp32(Function):
            im2col_step=ctx.im2col_step)

        return grad_value, None, None, \
            grad_sampling_loc, grad_attn_weight, None
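This Function pairs a custom forward/backward with torch.cuda.amp's custom_fwd/custom_bwd so autocast regions feed it float32. A self-contained sketch of that decorator pattern (SquareFp32 is a toy op, not the repo's kernel):

import torch
from torch.autograd.function import Function, once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd

class SquareFp32(Function):
    """Toy op showing the custom_fwd/custom_bwd fp32 pattern."""

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)  # autocast inputs arrive as fp32
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    @once_differentiable  # the backward pass itself is not differentiated
    @custom_bwd
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out

# usage: y = SquareFp32.apply(torch.rand(3, requires_grad=True))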
def multi_scale_deformable_attn_pytorch(value, value_spatial_shapes,
@@ -161,7 +155,7 @@ def multi_scale_deformable_attn_pytorch(value, value_spatial_shapes,
    """
    bs, _, num_heads, embed_dims = value.shape
    _, num_queries, num_heads, num_levels, num_points, _ = \
        sampling_locations.shape
    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes],
                             dim=1)
@@ -178,7 +172,7 @@ def multi_scale_deformable_attn_pytorch(value, value_spatial_shapes,
        # bs, num_heads, num_queries, num_points, 2 ->
        # bs*num_heads, num_queries, num_points, 2
        sampling_grid_l_ = sampling_grids[:, :, :,
                                          level].transpose(1, 2).flatten(0, 1)
        # bs*num_heads, embed_dims, num_queries, num_points
        sampling_value_l_ = F.grid_sample(
            value_l_,
@@ -281,8 +275,8 @@ class MultiScaleDeformableAttentionFP32(BaseModule):
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = (grid_init /
                     grid_init.abs().max(-1, keepdim=True)[0]).view(
                         self.num_heads, 1, 1,
                         2).repeat(1, self.num_levels, self.num_points, 1)
        for i in range(self.num_points):
            grid_init[:, :, i, :] *= i + 1
@@ -373,13 +367,13 @@ class MultiScaleDeformableAttentionFP32(BaseModule):
            offset_normalizer = torch.stack(
                [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
            sampling_locations = reference_points[:, :, None, :, None, :] \
                + sampling_offsets \
                / offset_normalizer[None, None, None, :, None, :]
        elif reference_points.shape[-1] == 4:
            sampling_locations = reference_points[:, :, None, :, None, :2] \
                + sampling_offsets / self.num_points \
                * reference_points[:, :, None, :, None, 2:] \
                * 0.5
        else:
            raise ValueError(
                f'Last dim of reference_points must be'
@@ -399,4 +393,4 @@ class MultiScaleDeformableAttentionFP32(BaseModule):
        # (num_query, bs, embed_dims)
        output = output.permute(1, 0, 2)
        return self.dropout(output) + identity
\ No newline at end of file
import os
import sys

sys.path.append(os.path.abspath('.'))

import argparse

from src.datasets.evaluation.vector_eval import VectorEvaluate
def parse_args():
    parser = argparse.ArgumentParser(
        description='Evaluate a submission file')
    parser.add_argument('submission',
                        help='submission file in pickle or json format to be evaluated')
    parser.add_argument('gt',
                        help='gt annotation file')
    args = parser.parse_args()
    return args
def main(args):
    evaluator = VectorEvaluate(args.gt, n_workers=0)
    results = evaluator.evaluate(args.submission)
    print(results)


if __name__ == '__main__':
    args = parse_args()
    main(args)
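Typical invocation, with placeholder script and file names (substitute your own paths): `python eval.py results/submission.pkl data/gt_annotations.json`. The positional arguments map to `submission` and `gt` above.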
@@ -9,7 +9,6 @@ import torch
import torch.distributed as dist
from mmcv.image import tensor2imgs
from mmcv.runner import get_dist_info
from mmdet.core import encode_mask_results
@@ -120,7 +119,7 @@ def collect_results_cpu(result_part, size, tmpdir=None):
    if tmpdir is None:
        MAX_LEN = 512
        # 32 is whitespace
        dir_tensor = torch.full((MAX_LEN,),
                                32,
                                dtype=torch.uint8,
                                device='cuda')
@@ -8,7 +8,6 @@ from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
                         Fp16OptimizerHook, OptimizerHook, build_optimizer,
                         build_runner)
from mmcv.utils import build_from_cfg
from mmdet.core import DistEvalHook, EvalHook
from mmdet.datasets import (build_dataloader, build_dataset,
                            replace_ImageToTensor)
import argparse
import os
import os.path as osp

import torch
from mmcv import Config
from mmcv.cnn import fuse_conv_bn
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
                         wrap_fp16_model)
from mmdet3d.apis import single_gpu_test
from mmdet3d.datasets import build_dataloader, build_dataset
from mmdet3d.models import build_model
from mmdet.datasets import replace_ImageToTensor
from mmdet_test import multi_gpu_test
from mmdet_train import set_random_seed
def parse_args():
@@ -29,13 +27,13 @@ def parse_args():
        '--fuse-conv-bn',
        action='store_true',
        help='Whether to fuse conv and bn; this will slightly increase '
             'the inference speed')
    parser.add_argument(
        '--format-only',
        action='store_true',
        help='Format the output results without performing evaluation. It is '
             'useful when you want to format the result to a specific format and '
             'submit it to the test server')
    parser.add_argument(
        '--eval',
        action='store_true',
@@ -47,7 +45,7 @@ def parse_args():
    parser.add_argument(
        '--tmpdir',
        help='tmp directory used for collecting results from multiple '
             'workers, available when gpu-collect is not specified')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument(
        '--deterministic',
@@ -74,8 +72,8 @@ def main():
    if (args.eval and args.format_only) or (not args.eval and not args.format_only):
        raise ValueError('Please specify exactly one operation (eval/format) '
                         'with the argument "--eval" or "--format-only"')

    if args.eval and args.split == 'test':
        raise ValueError('Cannot evaluate on test set')
@@ -90,7 +88,7 @@ def main():
    # import modules from plugin/xx, registry will be updated
    import sys
    sys.path.append(os.path.abspath('.'))

    if hasattr(cfg, 'plugin'):
        if cfg.plugin:
            import importlib
@@ -106,11 +104,11 @@ def main():
                plg_lib = importlib.import_module(_module_path)

            plugin_dirs = cfg.plugin_dir
            if not isinstance(plugin_dirs, list):
                plugin_dirs = [plugin_dirs, ]
            for plugin_dir in plugin_dirs:
                import_path(plugin_dir)

        else:
            # import dir is the dirpath for the config file
            _module_dir = os.path.dirname(args.config)
@@ -151,10 +149,10 @@ def main():
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])

    cfg_data_dict.work_dir = cfg.work_dir
    print('work_dir: ', cfg.work_dir)

    dataset = build_dataset(cfg_data_dict)
    data_loader = build_dataloader(
        dataset,
@@ -181,7 +179,7 @@ def main():
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False)
        outputs = multi_gpu_test(model, data_loader, args.tmpdir,
                                 args.gpu_collect)

    rank, _ = get_dist_info()
    if rank == 0:
......
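`multi_gpu_test` collects per-worker results in one of two ways, selected by the flags above: CPU collection through a shared `--tmpdir`, or GPU collection with `--gpu-collect`. A hedged sketch of the two call patterns, assuming `model` and `data_loader` are built as in the script:

```python
# Sketch only; `model` and `data_loader` are assumed to exist as built above.
from mmdet.apis import multi_gpu_test

# 1) CPU collection: each rank dumps its results to tmpdir, rank 0 merges them.
outputs = multi_gpu_test(model, data_loader, tmpdir='./tmp_results')

# 2) GPU collection: results are gathered via GPU communication instead of disk.
outputs = multi_gpu_test(model, data_loader, gpu_collect=True)
```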
@@ -2,26 +2,25 @@ from __future__ import division
import argparse
import copy
import os
import time
import warnings
from os import path as osp

import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist

from mmdet import __version__ as mmdet_version
from mmdet3d import __version__ as mmdet3d_version
from mmdet3d.apis import train_model
from mmdet3d.datasets import build_dataset
# from builder import build_model
from mmdet3d.models import build_model
from mmdet3d.utils import collect_env, get_root_logger
# wrapper
from mmdet_train import set_random_seed
from mmseg import __version__ as mmseg_version
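The import reshuffle above is the kind of change pre-commit's isort hook produces: standard-library imports first, then third-party, each group alphabetized and separated by a blank line. A small sketch of the same behavior through isort's Python API (assuming isort >= 5 is installed):

```python
# Sketch: reproducing the grouping above with isort's API (isort >= 5 assumed).
import isort

messy = 'import torch\nimport os\nimport mmcv\nimport argparse\n'
print(isort.code(messy))
# Expected grouping: stdlib first (argparse, os), a blank line,
# then third-party (mmcv, torch).
```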
def parse_args():
@@ -39,13 +38,13 @@ def parse_args():
        '--gpus',
        type=int,
        help='number of gpus to use '
        '(only applicable to non-distributed training)')
    group_gpus.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='ids of gpus to use '
        '(only applicable to non-distributed training)')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument(
        '--deterministic',
@@ -56,18 +55,18 @@ def parse_args():
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file (deprecated), '
        'change to --cfg-options instead.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
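As the help text says, each `key=value` pair is parsed into a dict of dotted keys and merged into the config afterwards. A short sketch of `DictAction`'s parsing, with illustrative values:

```python
# Sketch of mmcv's DictAction parsing; the option values are illustrative.
from argparse import ArgumentParser

from mmcv import DictAction

parser = ArgumentParser()
parser.add_argument('--cfg-options', nargs='+', action=DictAction)
args = parser.parse_args(
    ['--cfg-options', 'optimizer.lr=0.001', 'data.samples_per_gpu=2'])
print(args.cfg_options)  # {'optimizer.lr': 0.001, 'data.samples_per_gpu': 2}
# cfg.merge_from_dict(args.cfg_options) later expands the dotted keys into the config.
```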
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
@@ -110,7 +109,7 @@ def main():
    # import modules, registry will be updated
    import sys
    sys.path.append(os.path.abspath('.'))
    if hasattr(cfg, 'plugin'):
        if cfg.plugin:
            import importlib
@@ -126,11 +125,11 @@ def main():
                plg_lib = importlib.import_module(_module_path)
                plugin_dirs = cfg.plugin_dir
                if not isinstance(plugin_dirs, list):
                    plugin_dirs = [plugin_dirs, ]
                for plugin_dir in plugin_dirs:
                    import_path(plugin_dir)
            else:
                # import dir is the dirpath for the config file
                _module_dir = os.path.dirname(args.config)
...
<div id="top" align="center"> <div id="top" align="center">
# InternImage for CVPR 2023 Workshop on End-to-End Autonomous Driving # InternImage for CVPR 2023 Workshop on End-to-End Autonomous Driving
</div>
</div>
## 1. InternImage-based Baseline for CVPR23 Occupancy Prediction Challenge ## 1. InternImage-based Baseline for CVPR23 Occupancy Prediction Challenge
We achieve an improvement of 1.44 in MIOU baseline by leveraging the InterImage-based model. We achieve an improvement of 1.44 in MIOU baseline by leveraging the InterImage-based model.
| model name | weight | mIoU | others | barrier | bicycle | bus | car | construction_vehicle | motorcycle | pedestrian | traffic_cone | trailer | truck | driveable_surface | other_flat | sidewalk | terrain | manmade | vegetation |
| ---------- | :----: | :--: | :----: | :-----: | :-----: | :-: | :-: | :------------------: | :--------: | :--------: | :----------: | :-----: | :---: | :---------------: | :--------: | :------: | :-----: | :-----: | :--------: |
| bevformer_intern-s_occ | [Google Drive](https://drive.google.com/file/d/1LV9K8hrskKf51xY1wbqTKzK7WZmVXEV_/view?usp=sharing) | 25.11 | 6.93 | 35.57 | 10.40 | 35.97 | 41.23 | 13.72 | 20.30 | 21.10 | 18.34 | 19.18 | 28.64 | 49.82 | 30.74 | 31.00 | 27.44 | 19.29 | 17.29 |
| bevformer_base_occ | [Google Drive](https://drive.google.com/file/d/1NyoiosafAmne1qiABeNOPXR-P-y0i7_I/view?usp=share_link) | 23.67 | 5.03 | 38.79 | 9.98 | 34.41 | 41.09 | 13.24 | 16.50 | 18.15 | 17.83 | 18.66 | 27.70 | 48.95 | 27.73 | 29.08 | 25.38 | 15.41 | 14.46 |
### Get Started

Please refer to [README.md](./occupancy_prediction/README.md).

## 2. InternImage-based Baseline for Online HD Map Construction Challenge for Autonomous Driving

By incorporating the InternImage-based model, we achieve an improvement of 6.56 mAP over the baseline.
| model name | weight | $\mathrm{mAP}$ | $\mathrm{AP}_{pc}$ | $\mathrm{AP}_{div}$ | $\mathrm{AP}_{bound}$ |
| ---------- | :----: | :------------: | :----------------: | :-----------------: | :-------------------: |
| vectormapnet_intern | [Checkpoint](https://github.com/OpenGVLab/InternImage/releases/download/track_model/vectormapnet_internimage.pth) | 49.35 | 45.05 | 56.78 | 46.22 |
| vectormapnet_base | [Google Drive](https://drive.google.com/file/d/16D1CMinwA8PG1sd9PV9_WtHzcBohvO-D/view) | 42.79 | 37.22 | 50.47 | 40.68 |
### Get Started

Please refer to [README.md](Online-HD-Map-Construction/README.md).

## 3. InternImage-based Baseline for CVPR23 OpenLane-V2 Challenge

With the InternImage-based model, we achieve an improvement of 0.009 in F-Score over the baseline.
|             | OpenLane-V2 Score | DET<sub>l</sub> | DET<sub>t</sub> | TOP<sub>ll</sub> | TOP<sub>lt</sub> | F-Score |
| ----------- | ----------------- | --------------- | --------------- | ---------------- | ---------------- | ------- |
| base r50    | 0.292             | 0.183           | 0.457           | 0.022            | 0.143            | 0.215   |
| InternImage | 0.325             | 0.194           | 0.537           | 0.02             | 0.17             | 0.224   |
### Get Started

Please refer to [README.md](./openlane-v2/README.md).
@@ -5,4 +5,4 @@ authors:
title: "OpenOccupancy: 3D Occupancy Benchmark for Scene Perception in Autonomous Driving"
date-released: 2023-02-10
url: "https://github.com/CVPR2023-Occupancy-Prediction-Challenge/CVPR2023-Occupancy-Prediction-Challenge"
license: Apache-2.0
\ No newline at end of file