"docs/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "8e2b2bf362375bca63f5d9c60dfe7075ae63746d"
Unverified Commit ff81e2b1 authored by Ailing's avatar Ailing Committed by GitHub
Browse files

Add device to torch.tensor. (#1979)

parent 4d77d3fa
...@@ -220,10 +220,10 @@ class BoxCoder(object): ...@@ -220,10 +220,10 @@ class BoxCoder(object):
pred_w = torch.exp(dw) * widths[:, None] pred_w = torch.exp(dw) * widths[:, None]
pred_h = torch.exp(dh) * heights[:, None] pred_h = torch.exp(dh) * heights[:, None]
pred_boxes1 = pred_ctr_x - torch.tensor(0.5, dtype=pred_ctr_x.dtype) * pred_w pred_boxes1 = pred_ctr_x - torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w
pred_boxes2 = pred_ctr_y - torch.tensor(0.5, dtype=pred_ctr_y.dtype) * pred_h pred_boxes2 = pred_ctr_y - torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h
pred_boxes3 = pred_ctr_x + torch.tensor(0.5, dtype=pred_ctr_x.dtype) * pred_w pred_boxes3 = pred_ctr_x + torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w
pred_boxes4 = pred_ctr_y + torch.tensor(0.5, dtype=pred_ctr_y.dtype) * pred_h pred_boxes4 = pred_ctr_y + torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h
pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=2).flatten(1) pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=2).flatten(1)
return pred_boxes return pred_boxes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment