Commit afe88104 authored by lishj6

init0905

parent a48c4071
from typing import List, Optional, Tuple, Union
import warnings
import copy
import numpy as np
import cv2
import torch
import torch.nn as nn
from mmcv.utils import build_from_cfg
from mmcv.cnn import Linear, bias_init_with_prob
from mmcv.runner import BaseModule, force_fp32
from mmcv.cnn.bricks.registry import (
ATTENTION,
PLUGIN_LAYERS,
POSITIONAL_ENCODING,
FEEDFORWARD_NETWORK,
NORM_LAYERS,
)
from mmdet.core import reduce_mean
from mmdet.models import HEADS
from mmdet.core.bbox.builder import BBOX_SAMPLERS, BBOX_CODERS
from mmdet.models import build_loss
from projects.mmdet3d_plugin.datasets.utils import box3d_to_corners
from projects.mmdet3d_plugin.core.box3d import *
from ..attention import gen_sineembed_for_position
from ..blocks import linear_relu_ln
from ..instance_bank import topk
@HEADS.register_module()
class MotionPlanningHead(BaseModule):
def __init__(
self,
fut_ts=12,
fut_mode=6,
ego_fut_ts=6,
ego_fut_mode=3,
motion_anchor=None,
plan_anchor=None,
embed_dims=256,
decouple_attn=False,
instance_queue=None,
operation_order=None,
temp_graph_model=None,
graph_model=None,
cross_graph_model=None,
norm_layer=None,
ffn=None,
refine_layer=None,
motion_sampler=None,
motion_loss_cls=None,
motion_loss_reg=None,
planning_sampler=None,
plan_loss_cls=None,
plan_loss_reg=None,
plan_loss_status=None,
motion_decoder=None,
planning_decoder=None,
num_det=50,
num_map=10,
):
super(MotionPlanningHead, self).__init__()
self.fut_ts = fut_ts
self.fut_mode = fut_mode
self.ego_fut_ts = ego_fut_ts
self.ego_fut_mode = ego_fut_mode
self.decouple_attn = decouple_attn
self.operation_order = operation_order
# =========== build modules ===========
def build(cfg, registry):
if cfg is None:
return None
return build_from_cfg(cfg, registry)
self.instance_queue = build(instance_queue, PLUGIN_LAYERS)
self.motion_sampler = build(motion_sampler, BBOX_SAMPLERS)
self.planning_sampler = build(planning_sampler, BBOX_SAMPLERS)
self.motion_decoder = build(motion_decoder, BBOX_CODERS)
self.planning_decoder = build(planning_decoder, BBOX_CODERS)
self.op_config_map = {
"temp_gnn": [temp_graph_model, ATTENTION],
"gnn": [graph_model, ATTENTION],
"cross_gnn": [cross_graph_model, ATTENTION],
"norm": [norm_layer, NORM_LAYERS],
"ffn": [ffn, FEEDFORWARD_NETWORK],
"refine": [refine_layer, PLUGIN_LAYERS],
}
self.layers = nn.ModuleList(
[
build(*self.op_config_map.get(op, [None, None]))
for op in self.operation_order
]
)
self.embed_dims = embed_dims
if self.decouple_attn:
self.fc_before = nn.Linear(
self.embed_dims, self.embed_dims * 2, bias=False
)
self.fc_after = nn.Linear(
self.embed_dims * 2, self.embed_dims, bias=False
)
else:
self.fc_before = nn.Identity()
self.fc_after = nn.Identity()
self.motion_loss_cls = build_loss(motion_loss_cls)
self.motion_loss_reg = build_loss(motion_loss_reg)
self.plan_loss_cls = build_loss(plan_loss_cls)
self.plan_loss_reg = build_loss(plan_loss_reg)
self.plan_loss_status = build_loss(plan_loss_status)
# motion init
motion_anchor = np.load(motion_anchor)
self.motion_anchor = nn.Parameter(
torch.tensor(motion_anchor, dtype=torch.float32),
requires_grad=False,
)
self.motion_anchor_encoder = nn.Sequential(
*linear_relu_ln(embed_dims, 1, 1),
Linear(embed_dims, embed_dims),
)
# plan anchor init
plan_anchor = np.load(plan_anchor)
self.plan_anchor = nn.Parameter(
torch.tensor(plan_anchor, dtype=torch.float32),
requires_grad=False,
)
self.plan_anchor_encoder = nn.Sequential(
*linear_relu_ln(embed_dims, 1, 1),
Linear(embed_dims, embed_dims),
)
self.num_det = num_det
self.num_map = num_map
def init_weights(self):
for i, op in enumerate(self.operation_order):
if self.layers[i] is None:
continue
elif op != "refine":
for p in self.layers[i].parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
for m in self.modules():
if hasattr(m, "init_weight"):
m.init_weight()
def get_motion_anchor(
self,
classification,
prediction,
):
cls_ids = classification.argmax(dim=-1)
motion_anchor = self.motion_anchor[cls_ids]
prediction = prediction.detach()
return self._agent2lidar(motion_anchor, prediction)
def _agent2lidar(self, trajs, boxes):
yaw = torch.atan2(boxes[..., SIN_YAW], boxes[..., COS_YAW])
cos_yaw = torch.cos(yaw)
sin_yaw = torch.sin(yaw)
rot_mat_T = torch.stack(
[
torch.stack([cos_yaw, sin_yaw]),
torch.stack([-sin_yaw, cos_yaw]),
]
)
trajs_lidar = torch.einsum('abcij,jkab->abcik', trajs, rot_mat_T)
return trajs_lidar
def graph_model(
self,
index,
query,
key=None,
value=None,
query_pos=None,
key_pos=None,
**kwargs,
):
if self.decouple_attn:
query = torch.cat([query, query_pos], dim=-1)
if key is not None:
key = torch.cat([key, key_pos], dim=-1)
query_pos, key_pos = None, None
if value is not None:
value = self.fc_before(value)
return self.fc_after(
self.layers[index](
query,
key,
value,
query_pos=query_pos,
key_pos=key_pos,
**kwargs,
)
)
def forward(
self,
det_output,
map_output,
feature_maps,
metas,
anchor_encoder,
mask,
anchor_handler,
):
# =========== det/map feature/anchor ===========
instance_feature = det_output["instance_feature"]
anchor_embed = det_output["anchor_embed"]
det_classification = det_output["classification"][-1].sigmoid()
det_anchors = det_output["prediction"][-1]
det_confidence = det_classification.max(dim=-1).values
_, (instance_feature_selected, anchor_embed_selected) = topk(
det_confidence, self.num_det, instance_feature, anchor_embed
)
map_instance_feature = map_output["instance_feature"]
map_anchor_embed = map_output["anchor_embed"]
map_classification = map_output["classification"][-1].sigmoid()
map_anchors = map_output["prediction"][-1]
map_confidence = map_classification.max(dim=-1).values
_, (map_instance_feature_selected, map_anchor_embed_selected) = topk(
map_confidence, self.num_map, map_instance_feature, map_anchor_embed
)
# =========== get ego/temporal feature/anchor ===========
bs, num_anchor, dim = instance_feature.shape
(
ego_feature,
ego_anchor,
temp_instance_feature,
temp_anchor,
temp_mask,
) = self.instance_queue.get(
det_output,
feature_maps,
metas,
bs,
mask,
anchor_handler,
)
ego_anchor_embed = anchor_encoder(ego_anchor)
temp_anchor_embed = anchor_encoder(temp_anchor)
temp_instance_feature = temp_instance_feature.flatten(0, 1)
temp_anchor_embed = temp_anchor_embed.flatten(0, 1)
temp_mask = temp_mask.flatten(0, 1)
# =========== mode anchor init ===========
motion_anchor = self.get_motion_anchor(det_classification, det_anchors)
plan_anchor = torch.tile(
self.plan_anchor[None], (bs, 1, 1, 1, 1)
)
# =========== mode query init ===========
motion_mode_query = self.motion_anchor_encoder(gen_sineembed_for_position(motion_anchor[..., -1, :]))
plan_pos = gen_sineembed_for_position(plan_anchor[..., -1, :])
plan_mode_query = self.plan_anchor_encoder(plan_pos).flatten(1, 2).unsqueeze(1)
# =========== cat instance and ego ===========
instance_feature_selected = torch.cat([instance_feature_selected, ego_feature], dim=1)
anchor_embed_selected = torch.cat([anchor_embed_selected, ego_anchor_embed], dim=1)
instance_feature = torch.cat([instance_feature, ego_feature], dim=1)
anchor_embed = torch.cat([anchor_embed, ego_anchor_embed], dim=1)
# =================== forward the layers ====================
motion_classification = []
motion_prediction = []
planning_classification = []
planning_prediction = []
planning_status = []
for i, op in enumerate(self.operation_order):
if self.layers[i] is None:
continue
elif op == "temp_gnn":
instance_feature = self.graph_model(
i,
instance_feature.flatten(0, 1).unsqueeze(1),
temp_instance_feature,
temp_instance_feature,
query_pos=anchor_embed.flatten(0, 1).unsqueeze(1),
key_pos=temp_anchor_embed,
key_padding_mask=temp_mask,
)
instance_feature = instance_feature.reshape(bs, num_anchor + 1, dim)
elif op == "gnn":
instance_feature = self.graph_model(
i,
instance_feature,
instance_feature_selected,
instance_feature_selected,
query_pos=anchor_embed,
key_pos=anchor_embed_selected,
)
elif op == "norm" or op == "ffn":
instance_feature = self.layers[i](instance_feature)
elif op == "cross_gnn":
instance_feature = self.layers[i](
instance_feature,
key=map_instance_feature_selected,
query_pos=anchor_embed,
key_pos=map_anchor_embed_selected,
)
elif op == "refine":
motion_query = motion_mode_query + (instance_feature + anchor_embed)[:, :num_anchor].unsqueeze(2)
plan_query = plan_mode_query + (instance_feature + anchor_embed)[:, num_anchor:].unsqueeze(2)
(
motion_cls,
motion_reg,
plan_cls,
plan_reg,
plan_status,
) = self.layers[i](
motion_query,
plan_query,
instance_feature[:, num_anchor:],
anchor_embed[:, num_anchor:],
)
motion_classification.append(motion_cls)
motion_prediction.append(motion_reg)
planning_classification.append(plan_cls)
planning_prediction.append(plan_reg)
planning_status.append(plan_status)
self.instance_queue.cache_motion(instance_feature[:, :num_anchor], det_output, metas)
self.instance_queue.cache_planning(instance_feature[:, num_anchor:], plan_status)
motion_output = {
"classification": motion_classification,
"prediction": motion_prediction,
"period": self.instance_queue.period,
"anchor_queue": self.instance_queue.anchor_queue,
}
planning_output = {
"classification": planning_classification,
"prediction": planning_prediction,
"status": planning_status,
"period": self.instance_queue.ego_period,
"anchor_queue": self.instance_queue.ego_anchor_queue,
}
return motion_output, planning_output
def loss(self,
motion_model_outs,
planning_model_outs,
data,
motion_loss_cache
):
loss = {}
motion_loss = self.loss_motion(motion_model_outs, data, motion_loss_cache)
loss.update(motion_loss)
planning_loss = self.loss_planning(planning_model_outs, data)
loss.update(planning_loss)
return loss
    @force_fp32(apply_to=("model_outs",))
def loss_motion(self, model_outs, data, motion_loss_cache):
cls_scores = model_outs["classification"]
reg_preds = model_outs["prediction"]
output = {}
for decoder_idx, (cls, reg) in enumerate(
zip(cls_scores, reg_preds)
):
(
cls_target,
cls_weight,
reg_pred,
reg_target,
reg_weight,
num_pos
) = self.motion_sampler.sample(
reg,
data["gt_agent_fut_trajs"],
data["gt_agent_fut_masks"],
motion_loss_cache,
)
num_pos = max(reduce_mean(num_pos), 1.0)
cls = cls.flatten(end_dim=1)
cls_target = cls_target.flatten(end_dim=1)
cls_weight = cls_weight.flatten(end_dim=1)
cls_loss = self.motion_loss_cls(cls, cls_target, weight=cls_weight, avg_factor=num_pos)
reg_weight = reg_weight.flatten(end_dim=1)
reg_pred = reg_pred.flatten(end_dim=1)
reg_target = reg_target.flatten(end_dim=1)
reg_weight = reg_weight.unsqueeze(-1)
reg_pred = reg_pred.cumsum(dim=-2)
reg_target = reg_target.cumsum(dim=-2)
reg_loss = self.motion_loss_reg(
reg_pred, reg_target, weight=reg_weight, avg_factor=num_pos
)
output.update(
{
f"motion_loss_cls_{decoder_idx}": cls_loss,
f"motion_loss_reg_{decoder_idx}": reg_loss,
}
)
return output
    @force_fp32(apply_to=("model_outs",))
def loss_planning(self, model_outs, data):
cls_scores = model_outs["classification"]
reg_preds = model_outs["prediction"]
status_preds = model_outs["status"]
output = {}
for decoder_idx, (cls, reg, status) in enumerate(
zip(cls_scores, reg_preds, status_preds)
):
(
cls,
cls_target,
cls_weight,
reg_pred,
reg_target,
reg_weight,
) = self.planning_sampler.sample(
cls,
reg,
data['gt_ego_fut_trajs'],
data['gt_ego_fut_masks'],
data,
)
cls = cls.flatten(end_dim=1)
cls_target = cls_target.flatten(end_dim=1)
cls_weight = cls_weight.flatten(end_dim=1)
cls_loss = self.plan_loss_cls(cls, cls_target, weight=cls_weight)
reg_weight = reg_weight.flatten(end_dim=1)
reg_pred = reg_pred.flatten(end_dim=1)
reg_target = reg_target.flatten(end_dim=1)
reg_weight = reg_weight.unsqueeze(-1)
reg_loss = self.plan_loss_reg(
reg_pred, reg_target, weight=reg_weight
)
status_loss = self.plan_loss_status(status.squeeze(1), data['ego_status'])
output.update(
{
f"planning_loss_cls_{decoder_idx}": cls_loss,
f"planning_loss_reg_{decoder_idx}": reg_loss,
f"planning_loss_status_{decoder_idx}": status_loss,
}
)
return output
    @force_fp32(apply_to=("model_outs",))
def post_process(
self,
det_output,
motion_output,
planning_output,
data,
):
motion_result = self.motion_decoder.decode(
det_output["classification"],
det_output["prediction"],
det_output.get("instance_id"),
det_output.get("quality"),
motion_output,
)
planning_result = self.planning_decoder.decode(
det_output,
motion_output,
planning_output,
data,
)
return motion_result, planning_result
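# A minimal, hedged sketch (toy shapes and random values only, nothing taken
# from a real config) of the yaw rotation applied by _agent2lidar above:
# rot_mat_T stacks a 2x2 rotation per (batch, agent), and the einsum rotates
# every mode/timestep waypoint of that agent into the lidar frame. Guarded so
# it never runs on import.
if __name__ == "__main__":
    bs, num_anchor, mode, ts = 2, 3, 6, 12
    trajs = torch.randn(bs, num_anchor, mode, ts, 2)        # agent-frame waypoints
    yaw = torch.rand(bs, num_anchor) * 6.28318               # per-agent heading
    cos_yaw, sin_yaw = torch.cos(yaw), torch.sin(yaw)
    rot_mat_T = torch.stack(
        [torch.stack([cos_yaw, sin_yaw]), torch.stack([-sin_yaw, cos_yaw])]
    )                                                        # [2, 2, bs, num_anchor]
    trajs_lidar = torch.einsum("abcij,jkab->abcik", trajs, rot_mat_T)
    assert trajs_lidar.shape == (bs, num_anchor, mode, ts, 2)
    # Equivalent per-agent matrix multiply, broadcast over the mode dimension.
    ref = trajs @ rot_mat_T.permute(2, 3, 0, 1)[:, :, None]
    assert torch.allclose(trajs_lidar, ref, atol=1e-5)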
import torch
from mmdet.core.bbox.builder import BBOX_SAMPLERS
__all__ = ["MotionTarget", "PlanningTarget"]
def get_cls_target(
reg_preds,
reg_target,
reg_weight,
):
bs, num_pred, mode, ts, d = reg_preds.shape
reg_preds_cum = reg_preds.cumsum(dim=-2)
reg_target_cum = reg_target.cumsum(dim=-2)
dist = torch.linalg.norm(reg_target_cum.unsqueeze(2) - reg_preds_cum, dim=-1)
dist = dist * reg_weight.unsqueeze(2)
dist = dist.mean(dim=-1)
mode_idx = torch.argmin(dist, dim=-1)
return mode_idx
def get_best_reg(
reg_preds,
reg_target,
reg_weight,
):
bs, num_pred, mode, ts, d = reg_preds.shape
reg_preds_cum = reg_preds.cumsum(dim=-2)
reg_target_cum = reg_target.cumsum(dim=-2)
dist = torch.linalg.norm(reg_target_cum.unsqueeze(2) - reg_preds_cum, dim=-1)
dist = dist * reg_weight.unsqueeze(2)
dist = dist.mean(dim=-1)
mode_idx = torch.argmin(dist, dim=-1)
mode_idx = mode_idx[..., None, None, None].repeat(1, 1, 1, ts, d)
best_reg = torch.gather(reg_preds, 2, mode_idx).squeeze(2)
return best_reg
@BBOX_SAMPLERS.register_module()
class MotionTarget():
def __init__(
self,
):
super(MotionTarget, self).__init__()
def sample(
self,
reg_pred,
gt_reg_target,
gt_reg_mask,
motion_loss_cache,
):
bs, num_anchor, mode, ts, d = reg_pred.shape
reg_target = reg_pred.new_zeros((bs, num_anchor, ts, d))
reg_weight = reg_pred.new_zeros((bs, num_anchor, ts))
indices = motion_loss_cache['indices']
num_pos = reg_pred.new_tensor([0])
for i, (pred_idx, target_idx) in enumerate(indices):
if len(gt_reg_target[i]) == 0:
continue
reg_target[i, pred_idx] = gt_reg_target[i][target_idx]
reg_weight[i, pred_idx] = gt_reg_mask[i][target_idx]
num_pos += len(pred_idx)
cls_target = get_cls_target(reg_pred, reg_target, reg_weight)
cls_weight = reg_weight.any(dim=-1)
best_reg = get_best_reg(reg_pred, reg_target, reg_weight)
return cls_target, cls_weight, best_reg, reg_target, reg_weight, num_pos
@BBOX_SAMPLERS.register_module()
class PlanningTarget():
def __init__(
self,
ego_fut_ts,
ego_fut_mode,
):
super(PlanningTarget, self).__init__()
self.ego_fut_ts = ego_fut_ts
self.ego_fut_mode = ego_fut_mode
def sample(
self,
cls_pred,
reg_pred,
gt_reg_target,
gt_reg_mask,
data,
):
gt_reg_target = gt_reg_target.unsqueeze(1)
gt_reg_mask = gt_reg_mask.unsqueeze(1)
bs = reg_pred.shape[0]
bs_indices = torch.arange(bs, device=reg_pred.device)
cmd = data['gt_ego_fut_cmd'].argmax(dim=-1)
cls_pred = cls_pred.reshape(bs, 3, 1, self.ego_fut_mode)
reg_pred = reg_pred.reshape(bs, 3, 1, self.ego_fut_mode, self.ego_fut_ts, 2)
cls_pred = cls_pred[bs_indices, cmd]
reg_pred = reg_pred[bs_indices, cmd]
cls_target = get_cls_target(reg_pred, gt_reg_target, gt_reg_mask)
cls_weight = gt_reg_mask.any(dim=-1)
best_reg = get_best_reg(reg_pred, gt_reg_target, gt_reg_mask)
return cls_pred, cls_target, cls_weight, best_reg, gt_reg_target, gt_reg_mask
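# A minimal, hedged check (toy shapes, random data; not from any config) of the
# winner-take-all target assignment implemented by get_cls_target /
# get_best_reg above: the mode whose cumulative trajectory is closest (mean
# point-wise distance, masked by reg_weight) to the cumulative ground truth
# becomes the classification target, and that mode's regression is returned for
# the regression loss. Guarded so it never runs on import.
if __name__ == "__main__":
    bs, num_pred, mode, ts, d = 2, 4, 6, 12, 2
    reg_preds = torch.randn(bs, num_pred, mode, ts, d)
    reg_target = torch.randn(bs, num_pred, ts, d)
    reg_weight = torch.ones(bs, num_pred, ts)
    mode_idx = get_cls_target(reg_preds, reg_target, reg_weight)
    best_reg = get_best_reg(reg_preds, reg_target, reg_weight)
    assert mode_idx.shape == (bs, num_pred)
    assert best_reg.shape == (bs, num_pred, ts, d)
    picked = torch.gather(
        reg_preds, 2, mode_idx[..., None, None, None].expand(-1, -1, 1, ts, d)
    ).squeeze(2)
    assert torch.allclose(best_reg, picked)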
from inspect import signature
import torch
from mmcv.runner import force_fp32, auto_fp16
from mmcv.utils import build_from_cfg
from mmcv.cnn.bricks.registry import PLUGIN_LAYERS
from mmdet.models import (
DETECTORS,
BaseDetector,
build_backbone,
build_head,
build_neck,
)
from .grid_mask import GridMask
try:
from ..ops import feature_maps_format
DAF_VALID = True
except ImportError:
DAF_VALID = False
__all__ = ["SparseDrive"]
@DETECTORS.register_module()
class SparseDrive(BaseDetector):
def __init__(
self,
img_backbone,
head,
img_neck=None,
init_cfg=None,
train_cfg=None,
test_cfg=None,
pretrained=None,
use_grid_mask=True,
use_deformable_func=False,
depth_branch=None,
):
super(SparseDrive, self).__init__(init_cfg=init_cfg)
        if pretrained is not None:
            img_backbone.pretrained = pretrained
self.img_backbone = build_backbone(img_backbone)
if img_neck is not None:
self.img_neck = build_neck(img_neck)
self.head = build_head(head)
self.use_grid_mask = use_grid_mask
if use_deformable_func:
assert DAF_VALID, "deformable_aggregation needs to be set up."
self.use_deformable_func = use_deformable_func
if depth_branch is not None:
self.depth_branch = build_from_cfg(depth_branch, PLUGIN_LAYERS)
else:
self.depth_branch = None
if use_grid_mask:
self.grid_mask = GridMask(
True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7
)
@auto_fp16(apply_to=("img",), out_fp32=True)
def extract_feat(self, img, return_depth=False, metas=None):
bs = img.shape[0]
if img.dim() == 5: # multi-view
num_cams = img.shape[1]
img = img.flatten(end_dim=1)
else:
num_cams = 1
if self.use_grid_mask:
img = self.grid_mask(img)
img = img.to(memory_format=torch.channels_last)
if "metas" in signature(self.img_backbone.forward).parameters:
feature_maps = self.img_backbone(img, num_cams, metas=metas)
else:
feature_maps = self.img_backbone(img)
if self.img_neck is not None:
feature_maps = list(self.img_neck(feature_maps))
for i, feat in enumerate(feature_maps):
feature_maps[i] = torch.reshape(
feat, (bs, num_cams) + feat.shape[1:]
)
if return_depth and self.depth_branch is not None:
depths = self.depth_branch(feature_maps, metas.get("focal"))
else:
depths = None
if self.use_deformable_func:
feature_maps = feature_maps_format(feature_maps)
if return_depth:
return feature_maps, depths
return feature_maps
@force_fp32(apply_to=("img",))
def forward(self, img, **data):
if self.training:
return self.forward_train(img, **data)
else:
return self.forward_test(img, **data)
def forward_train(self, img, **data):
feature_maps, depths = self.extract_feat(img, True, data)
model_outs = self.head(feature_maps, data)
output = self.head.loss(model_outs, data)
if depths is not None and "gt_depth" in data:
output["loss_dense_depth"] = self.depth_branch.loss(
depths, data["gt_depth"]
)
return output
def forward_test(self, img, **data):
if isinstance(img, list):
return self.aug_test(img, **data)
else:
return self.simple_test(img, **data)
def simple_test(self, img, **data):
feature_maps = self.extract_feat(img)
model_outs = self.head(feature_maps, data)
results = self.head.post_process(model_outs, data)
output = [{"img_bbox": result} for result in results]
return output
def aug_test(self, img, **data):
# fake test time augmentation
for key in data.keys():
if isinstance(data[key], list):
data[key] = data[key][0]
return self.simple_test(img[0], **data)
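# A hedged sketch (toy shapes, a pooling layer standing in for the real
# backbone and neck) of the multi-view tensor handling in
# SparseDrive.extract_feat above: a 5-D [bs, num_cams, C, H, W] input is
# flattened to [bs * num_cams, C, H, W] for the 2-D image backbone, and each
# output level is reshaped back so the camera dimension is kept. Guarded so it
# never runs on import.
if __name__ == "__main__":
    bs, num_cams, C, H, W = 2, 6, 3, 32, 56
    img = torch.randn(bs, num_cams, C, H, W)
    flat = img.flatten(end_dim=1)                        # [bs * num_cams, C, H, W]
    feat = torch.nn.functional.avg_pool2d(flat, 4)       # stand-in for backbone + neck
    feat = torch.reshape(feat, (bs, num_cams) + feat.shape[1:])
    assert feat.shape == (bs, num_cams, C, H // 4, W // 4)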
from typing import List, Optional, Tuple, Union
import warnings
import numpy as np
import torch
import torch.nn as nn
from mmcv.runner import BaseModule
from mmdet.models import HEADS
from mmdet.models import build_head
@HEADS.register_module()
class SparseDriveHead(BaseModule):
def __init__(
self,
task_config: dict,
det_head = dict,
map_head = dict,
motion_plan_head = dict,
init_cfg=None,
**kwargs,
):
super(SparseDriveHead, self).__init__(init_cfg)
self.task_config = task_config
if self.task_config['with_det']:
self.det_head = build_head(det_head)
if self.task_config['with_map']:
self.map_head = build_head(map_head)
if self.task_config['with_motion_plan']:
self.motion_plan_head = build_head(motion_plan_head)
def init_weights(self):
if self.task_config['with_det']:
self.det_head.init_weights()
if self.task_config['with_map']:
self.map_head.init_weights()
if self.task_config['with_motion_plan']:
self.motion_plan_head.init_weights()
def forward(
self,
feature_maps: Union[torch.Tensor, List],
metas: dict,
):
if self.task_config['with_det']:
det_output = self.det_head(feature_maps, metas)
else:
det_output = None
if self.task_config['with_map']:
map_output = self.map_head(feature_maps, metas)
else:
map_output = None
if self.task_config['with_motion_plan']:
motion_output, planning_output = self.motion_plan_head(
det_output,
map_output,
feature_maps,
metas,
self.det_head.anchor_encoder,
self.det_head.instance_bank.mask,
self.det_head.instance_bank.anchor_handler,
)
else:
motion_output, planning_output = None, None
return det_output, map_output, motion_output, planning_output
def loss(self, model_outs, data):
det_output, map_output, motion_output, planning_output = model_outs
losses = dict()
if self.task_config['with_det']:
loss_det = self.det_head.loss(det_output, data)
losses.update(loss_det)
if self.task_config['with_map']:
loss_map = self.map_head.loss(map_output, data)
losses.update(loss_map)
if self.task_config['with_motion_plan']:
motion_loss_cache = dict(
indices=self.det_head.sampler.indices,
)
loss_motion = self.motion_plan_head.loss(
motion_output,
planning_output,
data,
motion_loss_cache
)
losses.update(loss_motion)
return losses
def post_process(self, model_outs, data):
det_output, map_output, motion_output, planning_output = model_outs
if self.task_config['with_det']:
det_result = self.det_head.post_process(det_output)
batch_size = len(det_result)
if self.task_config['with_map']:
            map_result = self.map_head.post_process(map_output)
batch_size = len(map_result)
if self.task_config['with_motion_plan']:
motion_result, planning_result = self.motion_plan_head.post_process(
det_output,
motion_output,
planning_output,
data,
)
        results = [dict() for _ in range(batch_size)]
for i in range(batch_size):
if self.task_config['with_det']:
results[i].update(det_result[i])
if self.task_config['with_map']:
results[i].update(map_result[i])
if self.task_config['with_motion_plan']:
results[i].update(motion_result[i])
results[i].update(planning_result[i])
return results
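# Note on the per-sample containers built in post_process above: a
# `[dict()] * batch_size` expression would create batch_size references to one
# shared dict, so every sample would end up holding the last sample's results;
# the comprehension creates independent dicts. A minimal illustration (toy
# data, no real heads), guarded so it never runs on import:
if __name__ == "__main__":
    shared = [dict()] * 2
    shared[0]["x"] = 1
    assert shared[1] == {"x": 1}              # aliased: both entries changed
    independent = [dict() for _ in range(2)]
    independent[0]["x"] = 1
    assert independent[1] == {}               # each sample keeps its own results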
import torch
from .deformable_aggregation import DeformableAggregationFunction
def deformable_aggregation_function(
feature_maps,
spatial_shape,
scale_start_index,
sampling_location,
weights,
):
return DeformableAggregationFunction.apply(
feature_maps,
spatial_shape,
scale_start_index,
sampling_location,
weights,
)
@torch.compile(mode="max-autotune-no-cudagraphs")
def feature_maps_format(feature_maps, inverse=False):
if inverse:
col_feats, spatial_shape, scale_start_index = feature_maps
num_cams, num_levels = spatial_shape.shape[:2]
split_size = spatial_shape[..., 0] * spatial_shape[..., 1]
split_size = split_size.cpu().numpy().tolist()
idx = 0
cam_split = [1]
cam_split_size = [sum(split_size[0])]
for i in range(num_cams - 1):
if not torch.all(spatial_shape[i] == spatial_shape[i + 1]):
cam_split.append(0)
cam_split_size.append(0)
cam_split[-1] += 1
cam_split_size[-1] += sum(split_size[i + 1])
mc_feat = [
x.unflatten(1, (cam_split[i], -1))
for i, x in enumerate(col_feats.split(cam_split_size, dim=1))
]
spatial_shape = spatial_shape.cpu().numpy().tolist()
mc_ms_feat = []
shape_index = 0
for i, feat in enumerate(mc_feat):
feat = list(feat.split(split_size[shape_index], dim=2))
for j, f in enumerate(feat):
feat[j] = f.unflatten(2, spatial_shape[shape_index][j])
feat[j] = feat[j].permute(0, 1, 4, 2, 3)
mc_ms_feat.append(feat)
shape_index += cam_split[i]
return mc_ms_feat
if isinstance(feature_maps[0], (list, tuple)):
formated = [feature_maps_format(x) for x in feature_maps]
col_feats = torch.cat([x[0] for x in formated], dim=1)
spatial_shape = torch.cat([x[1] for x in formated], dim=0)
scale_start_index = torch.cat([x[2] for x in formated], dim=0)
return [col_feats, spatial_shape, scale_start_index]
bs, num_cams = feature_maps[0].shape[:2]
spatial_shape = []
col_feats = []
for i, feat in enumerate(feature_maps):
spatial_shape.append(feat.shape[-2:])
col_feats.append(
torch.reshape(feat, (bs, num_cams, feat.shape[2], -1))
)
col_feats = torch.cat(col_feats, dim=-1).permute(0, 1, 3, 2).flatten(1, 2)
spatial_shape = [spatial_shape] * num_cams
spatial_shape = torch.tensor(
spatial_shape,
dtype=torch.int64,
device=col_feats.device,
)
scale_start_index = spatial_shape[..., 0] * spatial_shape[..., 1]
scale_start_index = scale_start_index.flatten().cumsum(dim=0)
scale_start_index = torch.cat(
[torch.tensor([0]).to(scale_start_index), scale_start_index[:-1]]
)
scale_start_index = scale_start_index.reshape(num_cams, -1)
feature_maps = [
col_feats,
spatial_shape,
scale_start_index,
]
return feature_maps
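# A hedged walkthrough (toy shapes only) of the flattened layout produced by
# feature_maps_format above: per-level [bs, num_cams, C, H_l, W_l] maps are
# flattened over the spatial axes and concatenated into
# col_feats [bs, num_cams * sum(H_l * W_l), C], while spatial_shape and
# scale_start_index record how to slice out each camera/scale block again.
# Calling the function triggers torch.compile on first use; guarded so it
# never runs on import.
if __name__ == "__main__":
    bs, num_cams, C = 1, 2, 8
    shapes = [(16, 24), (8, 12)]
    fmaps = [torch.randn(bs, num_cams, C, h, w) for h, w in shapes]
    col_feats, spatial_shape, scale_start_index = feature_maps_format(fmaps)
    assert col_feats.shape == (bs, num_cams * sum(h * w for h, w in shapes), C)
    assert spatial_shape.shape == (num_cams, len(shapes), 2)
    assert scale_start_index.shape == (num_cams, len(shapes))
    assert scale_start_index[0, 0].item() == 0
    assert scale_start_index[0, 1].item() == shapes[0][0] * shapes[0][1]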
import torch
from torch.autograd.function import Function, once_differentiable
from . import deformable_aggregation_ext
class DeformableAggregationFunction(Function):
@staticmethod
def forward(
ctx,
mc_ms_feat,
spatial_shape,
scale_start_index,
sampling_location,
weights,
):
# output: [bs, num_pts, num_embeds]
mc_ms_feat = mc_ms_feat.contiguous().float()
spatial_shape = spatial_shape.contiguous().int()
scale_start_index = scale_start_index.contiguous().int()
sampling_location = sampling_location.contiguous().float()
weights = weights.contiguous().float()
output = deformable_aggregation_ext.deformable_aggregation_forward(
mc_ms_feat,
spatial_shape,
scale_start_index,
sampling_location,
weights,
)
ctx.save_for_backward(
mc_ms_feat,
spatial_shape,
scale_start_index,
sampling_location,
weights,
)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
(
mc_ms_feat,
spatial_shape,
scale_start_index,
sampling_location,
weights,
) = ctx.saved_tensors
mc_ms_feat = mc_ms_feat.contiguous().float()
spatial_shape = spatial_shape.contiguous().int()
scale_start_index = scale_start_index.contiguous().int()
sampling_location = sampling_location.contiguous().float()
weights = weights.contiguous().float()
grad_mc_ms_feat = torch.zeros_like(mc_ms_feat)
grad_sampling_location = torch.zeros_like(sampling_location)
grad_weights = torch.zeros_like(weights)
deformable_aggregation_ext.deformable_aggregation_backward(
mc_ms_feat,
spatial_shape,
scale_start_index,
sampling_location,
weights,
grad_output.contiguous(),
grad_mc_ms_feat,
grad_sampling_location,
grad_weights,
)
return (
grad_mc_ms_feat,
None,
None,
grad_sampling_location,
grad_weights,
)
import os
import torch
from setuptools import setup
from torch.utils.cpp_extension import (
BuildExtension,
CppExtension,
CUDAExtension,
)
def make_cuda_ext(
name,
module,
sources,
sources_cuda=[],
extra_args=[],
extra_include_path=[],
):
define_macros = []
extra_compile_args = {"cxx": [] + extra_args}
    if torch.cuda.is_available() or os.getenv("FORCE_CUDA", "0") == "1":
define_macros += [("WITH_CUDA", None)]
extension = CUDAExtension
extra_compile_args["nvcc"] = extra_args + [
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
]
sources += sources_cuda
else:
print("Compiling {} without CUDA".format(name))
extension = CppExtension
return extension(
name="{}.{}".format(module, name),
sources=[os.path.join(*module.split("."), p) for p in sources],
include_dirs=extra_include_path,
define_macros=define_macros,
extra_compile_args=extra_compile_args,
)
if __name__ == "__main__":
setup(
name="deformable_aggregation_ext",
ext_modules=[
make_cuda_ext(
"deformable_aggregation_ext",
module=".",
sources=[
f"src/deformable_aggregation.cpp",
f"src/deformable_aggregation_cuda.cu",
],
),
],
cmdclass={"build_ext": BuildExtension},
)
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
void deformable_aggregation(
float* output,
const float* mc_ms_feat,
const int* spatial_shape,
const int* scale_start_index,
const float* sample_location,
const float* weights,
int batch_size,
int num_cams,
int num_feat,
int num_embeds,
int num_scale,
int num_anchors,
int num_pts,
int num_groups
);
/* feat: bs, num_feat, c */
/* _spatial_shape: cam, scale, 2 */
/* _scale_start_index: cam, scale */
/* _sampling_location: bs, anchor, pts, cam, 2 */
/* _weights: bs, anchor, pts, cam, scale, group */
/* output: bs, anchor, c */
/* kernel: bs, anchor, pts, c */
at::Tensor deformable_aggregation_forward(
const at::Tensor &_mc_ms_feat,
const at::Tensor &_spatial_shape,
const at::Tensor &_scale_start_index,
const at::Tensor &_sampling_location,
const at::Tensor &_weights
) {
at::DeviceGuard guard(_mc_ms_feat.device());
const at::cuda::OptionalCUDAGuard device_guard(device_of(_mc_ms_feat));
int batch_size = _mc_ms_feat.size(0);
int num_feat = _mc_ms_feat.size(1);
int num_embeds = _mc_ms_feat.size(2);
int num_cams = _spatial_shape.size(0);
int num_scale = _spatial_shape.size(1);
int num_anchors = _sampling_location.size(1);
int num_pts = _sampling_location.size(2);
int num_groups = _weights.size(5);
const float* mc_ms_feat = _mc_ms_feat.data_ptr<float>();
const int* spatial_shape = _spatial_shape.data_ptr<int>();
const int* scale_start_index = _scale_start_index.data_ptr<int>();
const float* sampling_location = _sampling_location.data_ptr<float>();
const float* weights = _weights.data_ptr<float>();
auto output = at::zeros({batch_size, num_anchors, num_embeds}, _mc_ms_feat.options());
deformable_aggregation(
output.data_ptr<float>(),
mc_ms_feat, spatial_shape, scale_start_index, sampling_location, weights,
batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
);
return output;
}
void deformable_aggregation_grad(
const float* mc_ms_feat,
const int* spatial_shape,
const int* scale_start_index,
const float* sample_location,
const float* weights,
const float* grad_output,
float* grad_mc_ms_feat,
float* grad_sampling_location,
float* grad_weights,
int batch_size,
int num_cams,
int num_feat,
int num_embeds,
int num_scale,
int num_anchors,
int num_pts,
int num_groups
);
void deformable_aggregation_backward(
const at::Tensor &_mc_ms_feat,
const at::Tensor &_spatial_shape,
const at::Tensor &_scale_start_index,
const at::Tensor &_sampling_location,
const at::Tensor &_weights,
const at::Tensor &_grad_output,
at::Tensor &_grad_mc_ms_feat,
at::Tensor &_grad_sampling_location,
at::Tensor &_grad_weights
) {
at::DeviceGuard guard(_mc_ms_feat.device());
const at::cuda::OptionalCUDAGuard device_guard(device_of(_mc_ms_feat));
int batch_size = _mc_ms_feat.size(0);
int num_feat = _mc_ms_feat.size(1);
int num_embeds = _mc_ms_feat.size(2);
int num_cams = _spatial_shape.size(0);
int num_scale = _spatial_shape.size(1);
int num_anchors = _sampling_location.size(1);
int num_pts = _sampling_location.size(2);
int num_groups = _weights.size(5);
const float* mc_ms_feat = _mc_ms_feat.data_ptr<float>();
const int* spatial_shape = _spatial_shape.data_ptr<int>();
const int* scale_start_index = _scale_start_index.data_ptr<int>();
const float* sampling_location = _sampling_location.data_ptr<float>();
const float* weights = _weights.data_ptr<float>();
const float* grad_output = _grad_output.data_ptr<float>();
float* grad_mc_ms_feat = _grad_mc_ms_feat.data_ptr<float>();
float* grad_sampling_location = _grad_sampling_location.data_ptr<float>();
float* grad_weights = _grad_weights.data_ptr<float>();
deformable_aggregation_grad(
mc_ms_feat, spatial_shape, scale_start_index, sampling_location, weights,
grad_output, grad_mc_ms_feat, grad_sampling_location, grad_weights,
batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"deformable_aggregation_forward",
&deformable_aggregation_forward,
"deformable_aggregation_forward"
);
m.def(
"deformable_aggregation_backward",
&deformable_aggregation_backward,
"deformable_aggregation_backward"
);
}
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <THC/THCAtomics.cuh>
#include <iostream>
#include <stdlib.h>
__device__ float bilinear_sampling(
const float *&bottom_data, const int &height, const int &width,
const int &num_embeds, const float &h_im, const float &w_im,
const int &base_ptr
) {
const int h_low = floorf(h_im);
const int w_low = floorf(w_im);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const float lh = h_im - h_low;
const float lw = w_im - w_low;
const float hh = 1 - lh, hw = 1 - lw;
const int w_stride = num_embeds;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
float v1 = 0;
if (h_low >= 0 && w_low >= 0) {
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
}
float v2 = 0;
if (h_low >= 0 && w_high <= width - 1) {
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
}
float v3 = 0;
if (h_high <= height - 1 && w_low >= 0) {
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
}
float v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1) {
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
}
const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
struct float2_t{
float a;
float b;
};
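// float2_t packs the two location-gradient partial sums (d_loc_x, d_loc_y) so
// the shuffle/block reductions below can move both values with one exchange;
// on the gfx936 fast path a single thread per block then issues the atomic
// adds into grad_sampling_location.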
__forceinline__ __device__
float2_t warp_reduce_sum(float2_t val, int max = 32) {
for (int offset = max; offset > 0; offset >>= 1) {
val.a += __shfl_down(val.a, offset);
val.b += __shfl_down(val.b, offset);
}
return val;
}
template <int blocksize>
__forceinline__ __device__
float2_t block_reduce_sum(float2_t val, float2_t* shared) {
const int lid = threadIdx.x % 64;
const int wid = threadIdx.x / 64;
constexpr int share_size = blocksize / 64;
val = warp_reduce_sum(val);
if constexpr (blocksize == 64) return val;
if (lid == 0 && wid < share_size) {
shared[wid] = val;
}
__syncthreads();
if (wid == 0 && lid < share_size) {
val = shared[lid];
val = warp_reduce_sum(val, share_size / 2);
}
return val;
}
template <int blocksize>
__device__ void bilinear_sampling_grad_sp(
const float *&bottom_data, const float &weight,
const int &height, const int &width,
const int &num_embeds, const float &h_im, const float &w_im,
const int &base_ptr,
const float &grad_output,
float *&grad_mc_ms_feat, float *grad_sampling_location, float *grad_weights,
float2_t* s_data) {
const int h_low = floorf(h_im);
const int w_low = floorf(w_im);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const float lh = h_im - h_low;
const float lw = w_im - w_low;
const float hh = 1 - lh, hw = 1 - lw;
const int w_stride = num_embeds;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const float top_grad_mc_ms_feat = grad_output * weight;
float grad_h_weight = 0, grad_w_weight = 0;
    // Accumulate the feature gradient for each valid bilinear corner; the plain
    // atomicAdd fallback mirrors the generic kernel for non-gfx936 builds.
    const int valid1 = (h_low >= 0 && w_low >= 0);
    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
    float v1 = valid1 ? bottom_data[ptr1] : 0.0f;
    if (valid1) {
#ifdef __gfx936__
        __builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + ptr1, w1 * top_grad_mc_ms_feat);
#else
        atomicAdd(grad_mc_ms_feat + ptr1, w1 * top_grad_mc_ms_feat);
#endif
    }
    const int valid2 = (h_low >= 0 && w_high <= width - 1);
    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
    float v2 = valid2 ? bottom_data[ptr2] : 0.0f;
    if (valid2) {
#ifdef __gfx936__
        __builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + ptr2, w2 * top_grad_mc_ms_feat);
#else
        atomicAdd(grad_mc_ms_feat + ptr2, w2 * top_grad_mc_ms_feat);
#endif
    }
    const int valid3 = (h_high <= height - 1 && w_low >= 0);
    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
    float v3 = valid3 ? bottom_data[ptr3] : 0.0f;
    if (valid3) {
#ifdef __gfx936__
        __builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + ptr3, w3 * top_grad_mc_ms_feat);
#else
        atomicAdd(grad_mc_ms_feat + ptr3, w3 * top_grad_mc_ms_feat);
#endif
    }
    const int valid4 = (h_high <= height - 1 && w_high <= width - 1);
    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
    float v4 = valid4 ? bottom_data[ptr4] : 0.0f;
    if (valid4) {
#ifdef __gfx936__
        __builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + ptr4, w4 * top_grad_mc_ms_feat);
#else
        atomicAdd(grad_mc_ms_feat + ptr4, w4 * top_grad_mc_ms_feat);
#endif
    }
grad_h_weight += (-hw * v1) + (-lw * v2) + ( hw * v3) + ( lw * v4);
grad_w_weight += (-hh * v1) + ( hh * v2) + (-lh * v3) + ( lh * v4);
float2_t spl;
spl.a = width * grad_w_weight * top_grad_mc_ms_feat;
spl.b = height * grad_h_weight * top_grad_mc_ms_feat;
spl = block_reduce_sum<blocksize>(spl, s_data);
const float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
float wei = grad_output * val;
for (int offset=16; offset>=1; offset >>= 1) {
wei += __shfl_down(wei, offset);
}
#ifdef __gfx936__
// __builtin_amdgcn_global_atomic_fadd_f32(grad_weights, grad_output * val);
if (threadIdx.x % 32 == 0) {
// __builtin_amdgcn_global_atomic_fadd_f32(grad_weights, wei);
*grad_weights += wei;
}
if (threadIdx.x ==0) {
__builtin_amdgcn_global_atomic_fadd_f32(grad_sampling_location, spl.a);
__builtin_amdgcn_global_atomic_fadd_f32(grad_sampling_location + 1, spl.b);
}
#else
atomicAdd(grad_weights, grad_output * val);
atomicAdd(grad_sampling_location, width * grad_w_weight * top_grad_mc_ms_feat);
atomicAdd(grad_sampling_location + 1, height * grad_h_weight * top_grad_mc_ms_feat);
#endif
}
__device__ void bilinear_sampling_grad(
const float *&bottom_data, const float &weight,
const int &height, const int &width,
const int &num_embeds, const float &h_im, const float &w_im,
const int &base_ptr,
const float &grad_output,
float *&grad_mc_ms_feat, float *grad_sampling_location, float *grad_weights) {
const int h_low = floorf(h_im);
const int w_low = floorf(w_im);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const float lh = h_im - h_low;
const float lw = w_im - w_low;
const float hh = 1 - lh, hw = 1 - lw;
const int w_stride = num_embeds;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const float top_grad_mc_ms_feat = grad_output * weight;
float grad_h_weight = 0, grad_w_weight = 0;
float v1 = 0;
if (h_low >= 0 && w_low >= 0) {
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
grad_h_weight -= hw * v1;
grad_w_weight -= hh * v1;
// atomicAdd(grad_mc_ms_feat + ptr1, w1 * top_grad_mc_ms_feat);
#ifdef __gfx936__
__builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + ptr1, w1 * top_grad_mc_ms_feat);
#else
atomicAdd(grad_mc_ms_feat + ptr1, w1 * top_grad_mc_ms_feat);
#endif
}
float v2 = 0;
if (h_low >= 0 && w_high <= width - 1) {
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
grad_h_weight -= lw * v2;
grad_w_weight += hh * v2;
// atomicAdd(grad_mc_ms_feat + ptr2, w2 * top_grad_mc_ms_feat);
#ifdef __gfx936__
__builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + ptr2, w2 * top_grad_mc_ms_feat);
#else
atomicAdd(grad_mc_ms_feat + ptr2, w2 * top_grad_mc_ms_feat);
#endif
}
float v3 = 0;
if (h_high <= height - 1 && w_low >= 0) {
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
grad_h_weight += hw * v3;
grad_w_weight -= lh * v3;
// atomicAdd(grad_mc_ms_feat + ptr3, w3 * top_grad_mc_ms_feat);
#ifdef __gfx936__
__builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + ptr3, w3 * top_grad_mc_ms_feat);
#else
atomicAdd(grad_mc_ms_feat + ptr3, w3 * top_grad_mc_ms_feat);
#endif
}
float v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1) {
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
grad_h_weight += lw * v4;
grad_w_weight += lh * v4;
// atomicAdd(grad_mc_ms_feat + ptr4, w4 * top_grad_mc_ms_feat);
#ifdef __gfx936__
__builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + ptr4, w4 * top_grad_mc_ms_feat);
#else
atomicAdd(grad_mc_ms_feat + ptr4, w4 * top_grad_mc_ms_feat);
#endif
}
const float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
// atomicAdd(grad_weights, grad_output * val);
// atomicAdd(grad_sampling_location, width * grad_w_weight * top_grad_mc_ms_feat);
// atomicAdd(grad_sampling_location + 1, height * grad_h_weight * top_grad_mc_ms_feat);
#ifdef __gfx936__
__builtin_amdgcn_global_atomic_fadd_f32(grad_weights, grad_output * val);
__builtin_amdgcn_global_atomic_fadd_f32(grad_sampling_location, width * grad_w_weight * top_grad_mc_ms_feat);
__builtin_amdgcn_global_atomic_fadd_f32(grad_sampling_location + 1, height * grad_h_weight * top_grad_mc_ms_feat);
#else
atomicAdd(grad_weights, grad_output * val);
atomicAdd(grad_sampling_location, width * grad_w_weight * top_grad_mc_ms_feat);
atomicAdd(grad_sampling_location + 1, height * grad_h_weight * top_grad_mc_ms_feat);
#endif
}
__global__ void deformable_aggregation_kernel(
const int64_t num_kernels,
float* output,
const float* mc_ms_feat,
const int* spatial_shape,
const int* scale_start_index,
const float* sample_location,
const float* weights,
int batch_size,
int num_cams,
int num_feat,
int num_embeds,
int num_scale,
int num_anchors,
int num_pts,
int num_groups
) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= num_kernels) return;
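    // The flat thread index enumerates (batch, anchor, pts, cam, scale, channel)
    // with channel fastest; dividing by (num_embeds / num_groups) collapses the
    // channel axis to its group, matching the weights layout, so one weight is
    // shared by all channels of a group.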
const float weight = *(weights + idx / (num_embeds / num_groups));
const int channel_index = idx % num_embeds;
idx /= num_embeds;
const int scale_index = idx % num_scale;
idx /= num_scale;
const int cam_index = idx % num_cams;
idx /= num_cams;
const int pts_index = idx % num_pts;
idx /= num_pts;
int anchor_index = idx % num_anchors;
idx /= num_anchors;
const int batch_index = idx % batch_size;
idx /= batch_size;
anchor_index = batch_index * num_anchors + anchor_index;
const int loc_offset = ((anchor_index * num_pts + pts_index) * num_cams + cam_index) << 1;
const float loc_w = sample_location[loc_offset];
if (loc_w <= 0 || loc_w >= 1) return;
const float loc_h = sample_location[loc_offset + 1];
if (loc_h <= 0 || loc_h >= 1) return;
int cam_scale_index = cam_index * num_scale + scale_index;
const int value_offset = (batch_index * num_feat + scale_start_index[cam_scale_index]) * num_embeds + channel_index;
cam_scale_index = cam_scale_index << 1;
const int h = spatial_shape[cam_scale_index];
const int w = spatial_shape[cam_scale_index + 1];
const float h_im = loc_h * h - 0.5;
const float w_im = loc_w * w - 0.5;
// atomicAdd(
// output + anchor_index * num_embeds + channel_index,
// bilinear_sampling(mc_ms_feat, h, w, num_embeds, h_im, w_im, value_offset) * weight
// );
#ifdef __gfx936__
__builtin_amdgcn_global_atomic_fadd_f32(output + anchor_index * num_embeds + channel_index, bilinear_sampling(mc_ms_feat, h, w, num_embeds, h_im, w_im, value_offset) * weight);
#else
atomicAdd(output + anchor_index * num_embeds + channel_index, bilinear_sampling(mc_ms_feat, h, w, num_embeds, h_im, w_im, value_offset) * weight);
#endif
}
template <int blocksize>
__global__ void deformable_aggregation_grad_kernel_sp(
const int64_t num_kernels,
const float* mc_ms_feat, // [bs, anchor, pts, cam, scale, channel]
const int* spatial_shape, // [cam, scale, 2]
const int* scale_start_index, // [cam, scale]
const float* sample_location, // [bs, anchor, pts, cam, 2(y, x)]
const float* weights, // [bs, anchor, cam, scale, group]
const float* grad_output, // [bs, anchor, c]
float* grad_mc_ms_feat, // same as feat
float* grad_sampling_location, // same as sampling location
float* grad_weights,
int batch_size,
int num_cams,
int num_feat,
int num_embeds,
int num_scale,
int num_anchors,
int num_pts,
int num_groups
) {
extern __shared__ float2_t s_data[];
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= num_kernels) return;
const int weights_ptr = idx / (num_embeds / num_groups);
const int channel_index = idx % num_embeds;
idx /= num_embeds;
const int scale_index = idx % num_scale;
idx /= num_scale;
const int cam_index = idx % num_cams;
idx /= num_cams;
const int pts_index = idx % num_pts;
idx /= num_pts;
int anchor_index = idx % num_anchors;
idx /= num_anchors;
const int batch_index = idx % batch_size;
idx /= batch_size;
anchor_index = batch_index * num_anchors + anchor_index;
const int loc_offset = ((anchor_index * num_pts + pts_index) * num_cams + cam_index) << 1;
const float loc_w = sample_location[loc_offset];
if (loc_w <= 0 || loc_w >= 1) return;
const float loc_h = sample_location[loc_offset + 1];
if (loc_h <= 0 || loc_h >= 1) return;
const float grad = grad_output[anchor_index*num_embeds + channel_index];
int cam_scale_index = cam_index * num_scale + scale_index;
const int value_offset = (batch_index * num_feat + scale_start_index[cam_scale_index]) * num_embeds + channel_index;
cam_scale_index = cam_scale_index << 1;
const int h = spatial_shape[cam_scale_index];
const int w = spatial_shape[cam_scale_index + 1];
const float h_im = loc_h * h - 0.5;
const float w_im = loc_w * w - 0.5;
/* atomicAdd( */
/* output + anchor_index * num_embeds + channel_index, */
/* bilinear_sampling(mc_ms_feat, h, w, num_embeds, h_im, w_im, value_offset) * weight */
/* ); */
const float weight = weights[weights_ptr];
float *grad_weights_ptr = grad_weights + weights_ptr;
float *grad_location_ptr = grad_sampling_location + loc_offset;
bilinear_sampling_grad_sp<blocksize>(
mc_ms_feat, weight, h, w, num_embeds, h_im, w_im,
value_offset,
grad,
grad_mc_ms_feat, grad_location_ptr, grad_weights_ptr,
s_data
);
}
__global__ void deformable_aggregation_grad_kernel(
const int64_t num_kernels,
const float* mc_ms_feat,
const int* spatial_shape,
const int* scale_start_index,
const float* sample_location,
const float* weights,
const float* grad_output,
float* grad_mc_ms_feat,
float* grad_sampling_location,
float* grad_weights,
int batch_size,
int num_cams,
int num_feat,
int num_embeds,
int num_scale,
int num_anchors,
int num_pts,
int num_groups
) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= num_kernels) return;
const int weights_ptr = idx / (num_embeds / num_groups);
const int channel_index = idx % num_embeds;
idx /= num_embeds;
const int scale_index = idx % num_scale;
idx /= num_scale;
const int cam_index = idx % num_cams;
idx /= num_cams;
const int pts_index = idx % num_pts;
idx /= num_pts;
int anchor_index = idx % num_anchors;
idx /= num_anchors;
const int batch_index = idx % batch_size;
idx /= batch_size;
anchor_index = batch_index * num_anchors + anchor_index;
const int loc_offset = ((anchor_index * num_pts + pts_index) * num_cams + cam_index) << 1;
const float loc_w = sample_location[loc_offset];
if (loc_w <= 0 || loc_w >= 1) return;
const float loc_h = sample_location[loc_offset + 1];
if (loc_h <= 0 || loc_h >= 1) return;
const float grad = grad_output[anchor_index*num_embeds + channel_index];
int cam_scale_index = cam_index * num_scale + scale_index;
const int value_offset = (batch_index * num_feat + scale_start_index[cam_scale_index]) * num_embeds + channel_index;
cam_scale_index = cam_scale_index << 1;
const int h = spatial_shape[cam_scale_index];
const int w = spatial_shape[cam_scale_index + 1];
const float h_im = loc_h * h - 0.5;
const float w_im = loc_w * w - 0.5;
/* atomicAdd( */
/* output + anchor_index * num_embeds + channel_index, */
/* bilinear_sampling(mc_ms_feat, h, w, num_embeds, h_im, w_im, value_offset) * weight */
/* ); */
const float weight = weights[weights_ptr];
float *grad_weights_ptr = grad_weights + weights_ptr;
float *grad_location_ptr = grad_sampling_location + loc_offset;
bilinear_sampling_grad(
mc_ms_feat, weight, h, w, num_embeds, h_im, w_im,
value_offset,
grad,
grad_mc_ms_feat, grad_location_ptr, grad_weights_ptr
);
}
void deformable_aggregation(
float* output,
const float* mc_ms_feat,
const int* spatial_shape,
const int* scale_start_index,
const float* sample_location,
const float* weights,
int batch_size,
int num_cams,
int num_feat,
int num_embeds,
int num_scale,
int num_anchors,
int num_pts,
int num_groups
) {
const int64_t num_kernels = (int64_t)batch_size * num_pts * num_embeds * num_anchors * num_cams * num_scale;
deformable_aggregation_kernel
<<<(int)ceil(((double)num_kernels/128)), 128>>>(
num_kernels, output,
mc_ms_feat, spatial_shape, scale_start_index, sample_location, weights,
batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
);
}
void deformable_aggregation_grad(
const float* mc_ms_feat,
const int* spatial_shape,
const int* scale_start_index,
const float* sample_location,
const float* weights,
const float* grad_output,
float* grad_mc_ms_feat,
float* grad_sampling_location,
float* grad_weights,
int batch_size,
int num_cams,
int num_feat,
int num_embeds,
int num_scale,
int num_anchors,
int num_pts,
int num_groups
) {
const int64_t num_kernels =(int64_t)batch_size * num_pts * num_embeds * num_anchors * num_cams * num_scale;
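    // Dispatch: the block-reduction kernel (`_sp`) below is specialized for the
    // 256-channel, 32-channels-per-group layout; any other configuration falls
    // back to the generic kernel that uses one atomic add per thread.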
if (num_embeds != 256 || ((num_embeds / num_groups) != 32)) {
deformable_aggregation_grad_kernel
<<<(int)ceil(((double)num_kernels/128)), 128>>>(
num_kernels,
mc_ms_feat, spatial_shape, scale_start_index, sample_location, weights,
grad_output, grad_mc_ms_feat, grad_sampling_location, grad_weights,
batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
);
} else {
int blk_dim = 256;
deformable_aggregation_grad_kernel_sp<256>
<<<(int)ceil(((double)num_kernels/blk_dim)), blk_dim, blk_dim * 2 * sizeof(float)>>>(
num_kernels,
mc_ms_feat, spatial_shape, scale_start_index, sample_location, weights,
grad_output, grad_mc_ms_feat, grad_sampling_location, grad_weights,
batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
);
}
}
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include <ATen/ATen.h>
#include <ATen/hip/HIPContext.h>
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
#include <THH/THHAtomics.cuh>
#include <iostream>
#include <stdlib.h>
__device__ float bilinear_sampling(
const float *&bottom_data, const int &height, const int &width,
const int &num_embeds, const float &h_im, const float &w_im,
const int &base_ptr
) {
const int h_low = floorf(h_im);
const int w_low = floorf(w_im);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const float lh = h_im - h_low;
const float lw = w_im - w_low;
const float hh = 1 - lh, hw = 1 - lw;
const int w_stride = num_embeds;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
float v1 = 0;
if (h_low >= 0 && w_low >= 0) {
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
}
float v2 = 0;
if (h_low >= 0 && w_high <= width - 1) {
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
}
float v3 = 0;
if (h_high <= height - 1 && w_low >= 0) {
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
}
float v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1) {
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
}
const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
struct float2_t{
float a;
float b;
};
__forceinline__ __device__
float2_t warp_reduce_sum(float2_t val, int max = 32) {
for (int offset = max; offset > 0; offset >>= 1) {
val.a += __shfl_down(val.a, offset);
val.b += __shfl_down(val.b, offset);
}
return val;
}
template <int blocksize>
__forceinline__ __device__
float2_t block_reduce_sum(float2_t val, float2_t* shared) {
const int lid = threadIdx.x % 64;
const int wid = threadIdx.x / 64;
constexpr int share_size = blocksize / 64;
val = warp_reduce_sum(val);
if constexpr (blocksize == 64) return val;
if (lid == 0 && wid < share_size) {
shared[wid] = val;
}
__syncthreads();
if (wid == 0 && lid < share_size) {
val = shared[lid];
val = warp_reduce_sum(val, share_size / 2);
}
return val;
}
template <int blocksize>
__device__ void bilinear_sampling_grad_sp(
const float *&bottom_data, const float &weight,
const int &height, const int &width,
const int &num_embeds, const float &h_im, const float &w_im,
const int &base_ptr,
const float &grad_output,
float *&grad_mc_ms_feat, float *grad_sampling_location, float *grad_weights,
float2_t* s_data) {
const int h_low = floorf(h_im);
const int w_low = floorf(w_im);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const float lh = h_im - h_low;
const float lw = w_im - w_low;
const float hh = 1 - lh, hw = 1 - lw;
const int w_stride = num_embeds;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const float top_grad_mc_ms_feat = grad_output * weight;
float grad_h_weight = 0, grad_w_weight = 0;
    // Accumulate the feature gradient for each valid bilinear corner; the plain
    // atomicAdd fallback mirrors the generic kernel for non-gfx936 builds.
    const int valid1 = (h_low >= 0 && w_low >= 0);
    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
    float v1 = valid1 ? bottom_data[ptr1] : 0.0f;
    if (valid1) {
#ifdef __gfx936__
        __builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + ptr1, w1 * top_grad_mc_ms_feat);
#else
        atomicAdd(grad_mc_ms_feat + ptr1, w1 * top_grad_mc_ms_feat);
#endif
    }
    const int valid2 = (h_low >= 0 && w_high <= width - 1);
    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
    float v2 = valid2 ? bottom_data[ptr2] : 0.0f;
    if (valid2) {
#ifdef __gfx936__
        __builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + ptr2, w2 * top_grad_mc_ms_feat);
#else
        atomicAdd(grad_mc_ms_feat + ptr2, w2 * top_grad_mc_ms_feat);
#endif
    }
    const int valid3 = (h_high <= height - 1 && w_low >= 0);
    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
    float v3 = valid3 ? bottom_data[ptr3] : 0.0f;
    if (valid3) {
#ifdef __gfx936__
        __builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + ptr3, w3 * top_grad_mc_ms_feat);
#else
        atomicAdd(grad_mc_ms_feat + ptr3, w3 * top_grad_mc_ms_feat);
#endif
    }
    const int valid4 = (h_high <= height - 1 && w_high <= width - 1);
    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
    float v4 = valid4 ? bottom_data[ptr4] : 0.0f;
    if (valid4) {
#ifdef __gfx936__
        __builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + ptr4, w4 * top_grad_mc_ms_feat);
#else
        atomicAdd(grad_mc_ms_feat + ptr4, w4 * top_grad_mc_ms_feat);
#endif
    }
grad_h_weight += (-hw * v1) + (-lw * v2) + ( hw * v3) + ( lw * v4);
grad_w_weight += (-hh * v1) + ( hh * v2) + (-lh * v3) + ( lh * v4);
float2_t spl;
spl.a = width * grad_w_weight * top_grad_mc_ms_feat;
spl.b = height * grad_h_weight * top_grad_mc_ms_feat;
spl = block_reduce_sum<blocksize>(spl, s_data);
const float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
float wei = grad_output * val;
for (int offset=16; offset>=1; offset >>= 1) {
wei += __shfl_down(wei, offset);
}
#ifdef __gfx936__
// __builtin_amdgcn_global_atomic_fadd_f32(grad_weights, grad_output * val);
if (threadIdx.x % 32 == 0) {
// __builtin_amdgcn_global_atomic_fadd_f32(grad_weights, wei);
*grad_weights += wei;
}
if (threadIdx.x == 0) {
__builtin_amdgcn_global_atomic_fadd_f32(grad_sampling_location, spl.a);
__builtin_amdgcn_global_atomic_fadd_f32(grad_sampling_location + 1, spl.b);
}
#else
atomicAdd(grad_weights, grad_output * val);
atomicAdd(grad_sampling_location, width * grad_w_weight * top_grad_mc_ms_feat);
atomicAdd(grad_sampling_location + 1, height * grad_h_weight * top_grad_mc_ms_feat);
#endif
}
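// Generic backward bilinear sampling: every thread issues its own atomic
// updates for the feature, weight and location gradients. Slower than the
// specialized path above, but it makes no assumptions about block shape.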
__device__ void bilinear_sampling_grad(
const float *&bottom_data, const float &weight,
const int &height, const int &width,
const int &num_embeds, const float &h_im, const float &w_im,
const int &base_ptr,
const float &grad_output,
float *&grad_mc_ms_feat, float *grad_sampling_location, float *grad_weights) {
const int h_low = floorf(h_im);
const int w_low = floorf(w_im);
const int h_high = h_low + 1;
const int w_high = w_low + 1;
const float lh = h_im - h_low;
const float lw = w_im - w_low;
const float hh = 1 - lh, hw = 1 - lw;
const int w_stride = num_embeds;
const int h_stride = width * w_stride;
const int h_low_ptr_offset = h_low * h_stride;
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
const int w_low_ptr_offset = w_low * w_stride;
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
const float top_grad_mc_ms_feat = grad_output * weight;
float grad_h_weight = 0, grad_w_weight = 0;
float v1 = 0;
if (h_low >= 0 && w_low >= 0) {
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
v1 = bottom_data[ptr1];
grad_h_weight -= hw * v1;
grad_w_weight -= hh * v1;
// atomicAdd(grad_mc_ms_feat + ptr1, w1 * top_grad_mc_ms_feat);
#ifdef __gfx936__
__builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + ptr1, w1 * top_grad_mc_ms_feat);
#else
atomicAdd(grad_mc_ms_feat + ptr1, w1 * top_grad_mc_ms_feat);
#endif
}
float v2 = 0;
if (h_low >= 0 && w_high <= width - 1) {
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
v2 = bottom_data[ptr2];
grad_h_weight -= lw * v2;
grad_w_weight += hh * v2;
// atomicAdd(grad_mc_ms_feat + ptr2, w2 * top_grad_mc_ms_feat);
#ifdef __gfx936__
__builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + ptr2, w2 * top_grad_mc_ms_feat);
#else
atomicAdd(grad_mc_ms_feat + ptr2, w2 * top_grad_mc_ms_feat);
#endif
}
float v3 = 0;
if (h_high <= height - 1 && w_low >= 0) {
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
v3 = bottom_data[ptr3];
grad_h_weight += hw * v3;
grad_w_weight -= lh * v3;
// atomicAdd(grad_mc_ms_feat + ptr3, w3 * top_grad_mc_ms_feat);
#ifdef __gfx936__
__builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + ptr3, w3 * top_grad_mc_ms_feat);
#else
atomicAdd(grad_mc_ms_feat + ptr3, w3 * top_grad_mc_ms_feat);
#endif
}
float v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1) {
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
v4 = bottom_data[ptr4];
grad_h_weight += lw * v4;
grad_w_weight += lh * v4;
// atomicAdd(grad_mc_ms_feat + ptr4, w4 * top_grad_mc_ms_feat);
#ifdef __gfx936__
__builtin_amdgcn_global_atomic_fadd_f32(grad_mc_ms_feat + ptr4, w4 * top_grad_mc_ms_feat);
#else
atomicAdd(grad_mc_ms_feat + ptr4, w4 * top_grad_mc_ms_feat);
#endif
}
const float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
// atomicAdd(grad_weights, grad_output * val);
// atomicAdd(grad_sampling_location, width * grad_w_weight * top_grad_mc_ms_feat);
// atomicAdd(grad_sampling_location + 1, height * grad_h_weight * top_grad_mc_ms_feat);
#ifdef __gfx936__
__builtin_amdgcn_global_atomic_fadd_f32(grad_weights, grad_output * val);
__builtin_amdgcn_global_atomic_fadd_f32(grad_sampling_location, width * grad_w_weight * top_grad_mc_ms_feat);
__builtin_amdgcn_global_atomic_fadd_f32(grad_sampling_location + 1, height * grad_h_weight * top_grad_mc_ms_feat);
#else
atomicAdd(grad_weights, grad_output * val);
atomicAdd(grad_sampling_location, width * grad_w_weight * top_grad_mc_ms_feat);
atomicAdd(grad_sampling_location + 1, height * grad_h_weight * top_grad_mc_ms_feat);
#endif
}
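// Forward kernel: one thread per (batch, anchor, pt, cam, scale, channel)
// element, decoded from the flat index by repeated modulo/divide with channel
// innermost. Samples outside the normalized (0, 1) range are skipped, and the
// bilinear sample is accumulated into output[anchor, channel] atomically
// because several points, cameras and scales contribute to the same element.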
__global__ void deformable_aggregation_kernel(
const int64_t num_kernels,
float* output,
const float* mc_ms_feat,
const int* spatial_shape,
const int* scale_start_index,
const float* sample_location,
const float* weights,
int batch_size,
int num_cams,
int num_feat,
int num_embeds,
int num_scale,
int num_anchors,
int num_pts,
int num_groups
) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= num_kernels) return;
const float weight = *(weights + idx / (num_embeds / num_groups));
const int channel_index = idx % num_embeds;
idx /= num_embeds;
const int scale_index = idx % num_scale;
idx /= num_scale;
const int cam_index = idx % num_cams;
idx /= num_cams;
const int pts_index = idx % num_pts;
idx /= num_pts;
int anchor_index = idx % num_anchors;
idx /= num_anchors;
const int batch_index = idx % batch_size;
idx /= batch_size;
anchor_index = batch_index * num_anchors + anchor_index;
const int loc_offset = ((anchor_index * num_pts + pts_index) * num_cams + cam_index) << 1;
const float loc_w = sample_location[loc_offset];
if (loc_w <= 0 || loc_w >= 1) return;
const float loc_h = sample_location[loc_offset + 1];
if (loc_h <= 0 || loc_h >= 1) return;
int cam_scale_index = cam_index * num_scale + scale_index;
const int value_offset = (batch_index * num_feat + scale_start_index[cam_scale_index]) * num_embeds + channel_index;
cam_scale_index = cam_scale_index << 1;
const int h = spatial_shape[cam_scale_index];
const int w = spatial_shape[cam_scale_index + 1];
const float h_im = loc_h * h - 0.5;
const float w_im = loc_w * w - 0.5;
// atomicAdd(
// output + anchor_index * num_embeds + channel_index,
// bilinear_sampling(mc_ms_feat, h, w, num_embeds, h_im, w_im, value_offset) * weight
// );
#ifdef __gfx936__
__builtin_amdgcn_global_atomic_fadd_f32(output + anchor_index * num_embeds + channel_index, bilinear_sampling(mc_ms_feat, h, w, num_embeds, h_im, w_im, value_offset) * weight);
#else
atomicAdd(output + anchor_index * num_embeds + channel_index, bilinear_sampling(mc_ms_feat, h, w, num_embeds, h_im, w_im, value_offset) * weight);
#endif
}
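// Specialized backward kernel. Index decoding matches the forward kernel; in
// addition it relies on blockDim.x == num_embeds == 256, so one block covers
// exactly the 256 channels of a single (batch, anchor, pt, cam, scale) tuple.
// All threads of a block then share one sampling location, the early returns
// below are taken uniformly, and the __syncthreads() inside block_reduce_sum
// stays safe. Dynamic shared memory must hold blocksize / 64 float2_t entries
// (the launcher over-allocates, which is harmless).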
template <int blocksize>
__global__ void deformable_aggregation_grad_kernel_sp(
const int64_t num_kernels,
const float* mc_ms_feat, // [bs, num_feat, channel]
const int* spatial_shape, // [cam, scale, 2]
const int* scale_start_index, // [cam, scale]
const float* sample_location, // [bs, anchor, pts, cam, 2(x, y)]
const float* weights, // [bs, anchor, pts, cam, scale, group]
const float* grad_output, // [bs, anchor, c]
float* grad_mc_ms_feat, // same as feat
float* grad_sampling_location, // same as sampling location
float* grad_weights,
int batch_size,
int num_cams,
int num_feat,
int num_embeds,
int num_scale,
int num_anchors,
int num_pts,
int num_groups
) {
extern __shared__ float2_t s_data[];
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= num_kernels) return;
const int weights_ptr = idx / (num_embeds / num_groups);
const int channel_index = idx % num_embeds;
idx /= num_embeds;
const int scale_index = idx % num_scale;
idx /= num_scale;
const int cam_index = idx % num_cams;
idx /= num_cams;
const int pts_index = idx % num_pts;
idx /= num_pts;
int anchor_index = idx % num_anchors;
idx /= num_anchors;
const int batch_index = idx % batch_size;
idx /= batch_size;
anchor_index = batch_index * num_anchors + anchor_index;
const int loc_offset = ((anchor_index * num_pts + pts_index) * num_cams + cam_index) << 1;
const float loc_w = sample_location[loc_offset];
if (loc_w <= 0 || loc_w >= 1) return;
const float loc_h = sample_location[loc_offset + 1];
if (loc_h <= 0 || loc_h >= 1) return;
const float grad = grad_output[anchor_index*num_embeds + channel_index];
int cam_scale_index = cam_index * num_scale + scale_index;
const int value_offset = (batch_index * num_feat + scale_start_index[cam_scale_index]) * num_embeds + channel_index;
cam_scale_index = cam_scale_index << 1;
const int h = spatial_shape[cam_scale_index];
const int w = spatial_shape[cam_scale_index + 1];
const float h_im = loc_h * h - 0.5;
const float w_im = loc_w * w - 0.5;
/* atomicAdd( */
/* output + anchor_index * num_embeds + channel_index, */
/* bilinear_sampling(mc_ms_feat, h, w, num_embeds, h_im, w_im, value_offset) * weight */
/* ); */
const float weight = weights[weights_ptr];
float *grad_weights_ptr = grad_weights + weights_ptr;
float *grad_location_ptr = grad_sampling_location + loc_offset;
bilinear_sampling_grad_sp<blocksize>(
mc_ms_feat, weight, h, w, num_embeds, h_im, w_im,
value_offset,
grad,
grad_mc_ms_feat, grad_location_ptr, grad_weights_ptr,
s_data
);
}
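// Generic backward kernel: same index decoding as the forward kernel, with
// gradients scattered by plain per-thread atomics via bilinear_sampling_grad.
// Used whenever the shape assumptions of the specialized kernel do not hold.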
__global__ void deformable_aggregation_grad_kernel(
const int64_t num_kernels,
const float* mc_ms_feat,
const int* spatial_shape,
const int* scale_start_index,
const float* sample_location,
const float* weights,
const float* grad_output,
float* grad_mc_ms_feat,
float* grad_sampling_location,
float* grad_weights,
int batch_size,
int num_cams,
int num_feat,
int num_embeds,
int num_scale,
int num_anchors,
int num_pts,
int num_groups
) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= num_kernels) return;
const int weights_ptr = idx / (num_embeds / num_groups);
const int channel_index = idx % num_embeds;
idx /= num_embeds;
const int scale_index = idx % num_scale;
idx /= num_scale;
const int cam_index = idx % num_cams;
idx /= num_cams;
const int pts_index = idx % num_pts;
idx /= num_pts;
int anchor_index = idx % num_anchors;
idx /= num_anchors;
const int batch_index = idx % batch_size;
idx /= batch_size;
anchor_index = batch_index * num_anchors + anchor_index;
const int loc_offset = ((anchor_index * num_pts + pts_index) * num_cams + cam_index) << 1;
const float loc_w = sample_location[loc_offset];
if (loc_w <= 0 || loc_w >= 1) return;
const float loc_h = sample_location[loc_offset + 1];
if (loc_h <= 0 || loc_h >= 1) return;
const float grad = grad_output[anchor_index*num_embeds + channel_index];
int cam_scale_index = cam_index * num_scale + scale_index;
const int value_offset = (batch_index * num_feat + scale_start_index[cam_scale_index]) * num_embeds + channel_index;
cam_scale_index = cam_scale_index << 1;
const int h = spatial_shape[cam_scale_index];
const int w = spatial_shape[cam_scale_index + 1];
const float h_im = loc_h * h - 0.5;
const float w_im = loc_w * w - 0.5;
/* atomicAdd( */
/* output + anchor_index * num_embeds + channel_index, */
/* bilinear_sampling(mc_ms_feat, h, w, num_embeds, h_im, w_im, value_offset) * weight */
/* ); */
const float weight = weights[weights_ptr];
float *grad_weights_ptr = grad_weights + weights_ptr;
float *grad_location_ptr = grad_sampling_location + loc_offset;
bilinear_sampling_grad(
mc_ms_feat, weight, h, w, num_embeds, h_im, w_im,
value_offset,
grad,
grad_mc_ms_feat, grad_location_ptr, grad_weights_ptr
);
}
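// Host-side forward launcher: one thread per output contribution, launched
// with 128 threads per block on the default (null) stream.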
void deformable_aggregation(
float* output,
const float* mc_ms_feat,
const int* spatial_shape,
const int* scale_start_index,
const float* sample_location,
const float* weights,
int batch_size,
int num_cams,
int num_feat,
int num_embeds,
int num_scale,
int num_anchors,
int num_pts,
int num_groups
) {
const int64_t num_kernels = (int64_t)batch_size * num_pts * num_embeds * num_anchors * num_cams * num_scale;
hipLaunchKernelGGL(( deformable_aggregation_kernel)
, dim3((int)ceil(((double)num_kernels/128))), dim3(128), 0, 0,
num_kernels, output,
mc_ms_feat, spatial_shape, scale_start_index, sample_location, weights,
batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
);
}
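// Host-side backward launcher. The specialized kernel (blocksize 256 plus
// dynamic shared memory for the block reduction) is used only when
// num_embeds == 256 and each group spans 32 channels; otherwise the generic
// kernel is launched with 128 threads per block.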
void deformable_aggregation_grad(
const float* mc_ms_feat,
const int* spatial_shape,
const int* scale_start_index,
const float* sample_location,
const float* weights,
const float* grad_output,
float* grad_mc_ms_feat,
float* grad_sampling_location,
float* grad_weights,
int batch_size,
int num_cams,
int num_feat,
int num_embeds,
int num_scale,
int num_anchors,
int num_pts,
int num_groups
) {
const int64_t num_kernels = (int64_t)batch_size * num_pts * num_embeds * num_anchors * num_cams * num_scale;
if (num_embeds != 256 || ((num_embeds / num_groups) != 32)) {
hipLaunchKernelGGL(( deformable_aggregation_grad_kernel)
, dim3((int)ceil(((double)num_kernels/128))), dim3(128), 0, 0,
num_kernels,
mc_ms_feat, spatial_shape, scale_start_index, sample_location, weights,
grad_output, grad_mc_ms_feat, grad_sampling_location, grad_weights,
batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
);
} else {
int blk_dim = 256;
hipLaunchKernelGGL(( deformable_aggregation_grad_kernel_sp<256>)
, dim3((int)ceil(((double)num_kernels/blk_dim))), dim3(blk_dim), blk_dim * 2 * sizeof(float), 0,
num_kernels,
mc_ms_feat, spatial_shape, scale_start_index, sample_location, weights,
grad_output, grad_mc_ms_feat, grad_sampling_location, grad_weights,
batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
);
}
}
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include <torch/extension.h>
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
void deformable_aggregation(
float* output,
const float* mc_ms_feat,
const int* spatial_shape,
const int* scale_start_index,
const float* sample_location,
const float* weights,
int batch_size,
int num_cams,
int num_feat,
int num_embeds,
int num_scale,
int num_anchors,
int num_pts,
int num_groups
);
/* feat: bs, num_feat, c */
/* _spatial_shape: cam, scale, 2 */
/* _scale_start_index: cam, scale */
/* _sampling_location: bs, anchor, pts, cam, 2 */
/* _weights: bs, anchor, pts, cam, scale, group */
/* output: bs, anchor, c */
/* kernel: bs, anchor, pts, c */
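// The wrappers below assume contiguous float32 tensors for the features,
// sampling locations and weights, and int32 tensors for spatial_shape and
// scale_start_index, since raw data_ptr<float>() / data_ptr<int>() pointers
// are handed straight to the kernels without further checking.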
at::Tensor deformable_aggregation_forward(
const at::Tensor &_mc_ms_feat,
const at::Tensor &_spatial_shape,
const at::Tensor &_scale_start_index,
const at::Tensor &_sampling_location,
const at::Tensor &_weights
) {
at::DeviceGuard guard(_mc_ms_feat.device());
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(_mc_ms_feat));
int batch_size = _mc_ms_feat.size(0);
int num_feat = _mc_ms_feat.size(1);
int num_embeds = _mc_ms_feat.size(2);
int num_cams = _spatial_shape.size(0);
int num_scale = _spatial_shape.size(1);
int num_anchors = _sampling_location.size(1);
int num_pts = _sampling_location.size(2);
int num_groups = _weights.size(5);
const float* mc_ms_feat = _mc_ms_feat.data_ptr<float>();
const int* spatial_shape = _spatial_shape.data_ptr<int>();
const int* scale_start_index = _scale_start_index.data_ptr<int>();
const float* sampling_location = _sampling_location.data_ptr<float>();
const float* weights = _weights.data_ptr<float>();
auto output = at::zeros({batch_size, num_anchors, num_embeds}, _mc_ms_feat.options());
deformable_aggregation(
output.data_ptr<float>(),
mc_ms_feat, spatial_shape, scale_start_index, sampling_location, weights,
batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
);
return output;
}
void deformable_aggregation_grad(
const float* mc_ms_feat,
const int* spatial_shape,
const int* scale_start_index,
const float* sample_location,
const float* weights,
const float* grad_output,
float* grad_mc_ms_feat,
float* grad_sampling_location,
float* grad_weights,
int batch_size,
int num_cams,
int num_feat,
int num_embeds,
int num_scale,
int num_anchors,
int num_pts,
int num_groups
);
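// The backward wrapper fills caller-provided gradient tensors in place. The
// kernels only accumulate into them (atomic adds, or a single-writer add for
// the specialized weight-gradient path), so _grad_mc_ms_feat,
// _grad_sampling_location and _grad_weights are assumed to be
// zero-initialized by the calling autograd function.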
void deformable_aggregation_backward(
const at::Tensor &_mc_ms_feat,
const at::Tensor &_spatial_shape,
const at::Tensor &_scale_start_index,
const at::Tensor &_sampling_location,
const at::Tensor &_weights,
const at::Tensor &_grad_output,
at::Tensor &_grad_mc_ms_feat,
at::Tensor &_grad_sampling_location,
at::Tensor &_grad_weights
) {
at::DeviceGuard guard(_mc_ms_feat.device());
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(_mc_ms_feat));
int batch_size = _mc_ms_feat.size(0);
int num_feat = _mc_ms_feat.size(1);
int num_embeds = _mc_ms_feat.size(2);
int num_cams = _spatial_shape.size(0);
int num_scale = _spatial_shape.size(1);
int num_anchors = _sampling_location.size(1);
int num_pts = _sampling_location.size(2);
int num_groups = _weights.size(5);
const float* mc_ms_feat = _mc_ms_feat.data_ptr<float>();
const int* spatial_shape = _spatial_shape.data_ptr<int>();
const int* scale_start_index = _scale_start_index.data_ptr<int>();
const float* sampling_location = _sampling_location.data_ptr<float>();
const float* weights = _weights.data_ptr<float>();
const float* grad_output = _grad_output.data_ptr<float>();
float* grad_mc_ms_feat = _grad_mc_ms_feat.data_ptr<float>();
float* grad_sampling_location = _grad_sampling_location.data_ptr<float>();
float* grad_weights = _grad_weights.data_ptr<float>();
deformable_aggregation_grad(
mc_ms_feat, spatial_shape, scale_start_index, sampling_location, weights,
grad_output, grad_mc_ms_feat, grad_sampling_location, grad_weights,
batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"deformable_aggregation_forward",
&deformable_aggregation_forward,
"deformable_aggregation_forward"
);
m.def(
"deformable_aggregation_backward",
&deformable_aggregation_backward,
"deformable_aggregation_backward"
);
}
numpy==1.23.5
mmcv_full==1.7.1
mmdet==2.28.2
urllib3==1.26.16
pyquaternion==0.9.9
nuscenes-devkit==1.1.10
yapf==0.33.0
tensorboard==2.14.0
motmetrics==1.1.3
pandas==1.1.5
flash-attn==2.3.2
opencv-python==4.8.1.78
prettytable==3.7.0
scikit-learn==1.3.0
numpy==1.23.5
mmdet==2.28.2
urllib3==1.26.16
pyquaternion
nuscenes-devkit
yapf
tensorboard
motmetrics
opencv-python
prettytable
scikit-learn