"megatron/vscode:/vscode.git/clone" did not exist on "1c4c360fdec75da2f9fe27e10046bcbe1e1e7e5b"
Commit 69f40c98 authored by yhcao6

Merge branch 'pytorch-1.0' of https://github.com/open-mmlab/mmdetection into pytorch-1.0

parents f3a939fa e9cb6fab
@@ -4,13 +4,13 @@
 - Linux (tested on Ubuntu 16.04 and CentOS 7.2)
 - Python 3.4+
-- PyTorch 0.4.1
+- PyTorch 1.0
 - Cython
-- [mmcv](https://github.com/open-mmlab/mmcv)
+- [mmcv](https://github.com/open-mmlab/mmcv) >= 0.2.2

 ### Install mmdetection

-a. Install PyTorch 0.4.1 and torchvision following the [official instructions](https://pytorch.org/).
+a. Install PyTorch 1.0 and torchvision following the [official instructions](https://pytorch.org/).

 b. Clone the mmdetection repository.
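A quick, illustrative way to confirm the environment matches the updated requirements before building the CUDA extensions (this snippet is not part of the repository; it only checks what the revised INSTALL text expects):

import torch
import torchvision

# The updated instructions target PyTorch 1.0; the roi_align/roi_pool ops
# further below also need a CUDA-enabled build to compile and run.
print(torch.__version__)          # expected to start with "1.0"
print(torchvision.__version__)
print(torch.cuda.is_available())  # True is needed for the custom CUDA ops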
......
@@ -34,13 +34,21 @@ def sigmoid_focal_loss(pred,
                        weight,
                        gamma=2.0,
                        alpha=0.25,
-                       reduction='elementwise_mean'):
+                       reduction='mean'):
     pred_sigmoid = pred.sigmoid()
     pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
     weight = (alpha * target + (1 - alpha) * (1 - target)) * weight
     weight = weight * pt.pow(gamma)
-    return F.binary_cross_entropy_with_logits(
-        pred, target, weight, reduction=reduction)
+    loss = F.binary_cross_entropy_with_logits(
+        pred, target, reduction='none') * weight
+    reduction_enum = F._Reduction.get_enum(reduction)
+    # none: 0, mean: 1, sum: 2
+    if reduction_enum == 0:
+        return loss
+    elif reduction_enum == 1:
+        return loss.mean()
+    elif reduction_enum == 2:
+        return loss.sum()


 def weighted_sigmoid_focal_loss(pred,
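The new branch above replaces the deprecated 'elementwise_mean' string with 'mean' and reduces the weighted loss manually instead of passing the weight into binary_cross_entropy_with_logits. A minimal standalone sketch of that reduction pattern, using only plain PyTorch (the helper name here is illustrative, not from the repository):

import torch
import torch.nn.functional as F

def reduce_loss(loss, reduction='mean'):
    # F._Reduction.get_enum maps 'none', 'mean', 'sum' to 0, 1, 2,
    # mirroring the branch added to sigmoid_focal_loss above.
    reduction_enum = F._Reduction.get_enum(reduction)
    if reduction_enum == 0:
        return loss
    elif reduction_enum == 1:
        return loss.mean()
    elif reduction_enum == 2:
        return loss.sum()

x = torch.tensor([1.0, 2.0, 3.0])
print(reduce_loss(x, 'none'))  # tensor([1., 2., 3.])
print(reduce_loss(x, 'mean'))  # tensor(2.)
print(reduce_loss(x, 'sum'))   # tensor(6.)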
@@ -62,22 +70,22 @@ def mask_cross_entropy(pred, target, label):
     inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
     pred_slice = pred[inds, label].squeeze(1)
     return F.binary_cross_entropy_with_logits(
-        pred_slice, target, reduction='elementwise_mean')[None]
+        pred_slice, target, reduction='mean')[None]


-def smooth_l1_loss(pred, target, beta=1.0, reduction='elementwise_mean'):
+def smooth_l1_loss(pred, target, beta=1.0, reduction='mean'):
     assert beta > 0
     assert pred.size() == target.size() and target.numel() > 0
     diff = torch.abs(pred - target)
     loss = torch.where(diff < beta, 0.5 * diff * diff / beta,
                        diff - 0.5 * beta)
-    reduction = F._Reduction.get_enum(reduction)
-    # none: 0, elementwise_mean: 1, sum: 2
-    if reduction == 0:
+    reduction_enum = F._Reduction.get_enum(reduction)
+    # none: 0, mean: 1, sum: 2
+    if reduction_enum == 0:
         return loss
-    elif reduction == 1:
+    elif reduction_enum == 1:
         return loss.sum() / pred.numel()
-    elif reduction == 2:
+    elif reduction_enum == 2:
         return loss.sum()
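For reference, the smooth L1 formula kept by the function above, checked with a tiny standalone example (values chosen arbitrarily):

import torch

pred = torch.tensor([0.0, 0.0])
target = torch.tensor([0.5, 2.0])
beta = 1.0

# 0.5 * d^2 / beta inside the quadratic zone (|d| < beta),
# |d| - 0.5 * beta outside it.
diff = torch.abs(pred - target)
loss = torch.where(diff < beta, 0.5 * diff * diff / beta, diff - 0.5 * beta)
print(loss)  # tensor([0.1250, 1.5000])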
......
-from torch.autograd import Function, Variable
+from torch.autograd import Function

 from .. import roi_align_cuda

@@ -49,11 +49,11 @@ class RoIAlignFunction(Function):
         grad_input = grad_rois = None
         if ctx.needs_input_grad[0]:
-            grad_input = Variable(
-                rois.new(batch_size, num_channels, data_height, data_width)
-                .zero_())
-            roi_align_cuda.backward(grad_output, rois, out_h, out_w,
-                                    spatial_scale, sample_num, grad_input)
+            grad_input = rois.new_zeros(batch_size, num_channels, data_height,
+                                        data_width)
+            roi_align_cuda.backward(grad_output.contiguous(), rois, out_h,
+                                    out_w, spatial_scale, sample_num,
+                                    grad_input)

         return grad_input, grad_rois, None, None, None
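The change above drops the Variable wrapper, which is a no-op in PyTorch 1.0, and uses new_zeros so the gradient buffer inherits dtype and device from rois. A small illustration of the equivalence (shapes made up):

import torch

rois = torch.randn(4, 5)

# PyTorch 0.4.x style:
#   grad_input = Variable(rois.new(2, 256, 7, 7).zero_())
# PyTorch 1.0 style, as in the updated backward pass:
grad_input = rois.new_zeros(2, 256, 7, 7)

print(grad_input.shape, grad_input.dtype, grad_input.device)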
......
-#include <torch/torch.h>
+#include <torch/extension.h>

 #include <cmath>
 #include <vector>
......
 #include <ATen/ATen.h>
 #include <THC/THCAtomics.cuh>

-using namespace at;  // temporal fix for pytorch<=0.4.1 (see #9848)

 #define CUDA_1D_KERNEL_LOOP(i, n)                            \
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
        i += blockDim.x * gridDim.x)
@@ -144,12 +142,7 @@ int ROIAlignForwardLaucher(const at::Tensor features, const at::Tensor rois,
             sample_num, channels, height, width, pooled_height,
             pooled_width, top_data);
       }));
-  cudaError_t err = cudaGetLastError();
-  if (cudaSuccess != err) {
-    fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err));
-    exit(-1);
-  }
+  THCudaCheck(cudaGetLastError());
   return 1;
 }
@@ -280,8 +273,7 @@ int ROIAlignBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
                             at::Tensor bottom_grad) {
   const int output_size = num_rois * pooled_height * pooled_width * channels;
-  // TODO: use AT_DISPATCH_FLOATING_TYPES_AND_HALF when atomicAdd is resolved
-  AT_DISPATCH_FLOATING_TYPES(
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       top_grad.type(), "ROIAlignLaucherBackward", ([&] {
         const scalar_t *top_diff = top_grad.data<scalar_t>();
         const scalar_t *rois_data = rois.data<scalar_t>();
@@ -297,11 +289,6 @@ int ROIAlignBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
             channels, height, width, pooled_height, pooled_width,
             bottom_diff);
       }));
-  cudaError_t err = cudaGetLastError();
-  if (cudaSuccess != err) {
-    fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err));
-    exit(-1);
-  }
+  THCudaCheck(cudaGetLastError());
   return 1;
 }
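Switching the backward dispatch to AT_DISPATCH_FLOATING_TYPES_AND_HALF means the kernel is now also instantiated for fp16. A hedged usage sketch only: it assumes a CUDA build of mmdetection, that mmdet.ops exposes RoIAlign, and that the constructor takes (out_size, spatial_scale, sample_num) positionally:

import torch
from mmdet.ops import RoIAlign  # import path assumed

roi_align = RoIAlign(7, 1.0 / 16, 2).cuda()
feats = torch.randn(1, 256, 50, 50, device='cuda', dtype=torch.half,
                    requires_grad=True)
# Each RoI is (batch_index, x1, y1, x2, y2) in image coordinates.
rois = torch.tensor([[0., 0., 0., 100., 100.]], device='cuda',
                    dtype=torch.half)

out = roi_align(feats, rois)
out.sum().backward()     # exercises the half-precision backward kernel
print(feats.grad.dtype)  # torch.float16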
@@ -24,9 +24,8 @@ class RoIPoolFunction(Function):
         num_channels = features.size(1)
         num_rois = rois.size(0)
         out_size = (num_rois, num_channels, out_h, out_w)
-        output = features.new_zeros(*out_size)
-        argmax = features.new_zeros(*out_size, dtype=torch.int)
+        output = features.new_zeros(out_size)
+        argmax = features.new_zeros(out_size, dtype=torch.int)
         roi_pool_cuda.forward(features, rois, out_h, out_w, spatial_scale,
                               output, argmax)
         ctx.spatial_scale = spatial_scale
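new_zeros accepts either a size tuple or unpacked dimensions; the updated lines pass the tuple directly, which matches the documented signature and keeps dtype as a keyword. A standalone illustration:

import torch

features = torch.randn(1, 3, 8, 8)
out_size = (2, 3, 7, 7)

output = features.new_zeros(out_size)
argmax = features.new_zeros(out_size, dtype=torch.int)
print(output.shape, argmax.dtype)  # torch.Size([2, 3, 7, 7]) torch.int32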
@@ -46,9 +45,9 @@ class RoIPoolFunction(Function):
         grad_input = grad_rois = None
         if ctx.needs_input_grad[0]:
-            grad_input = grad_output.new(feature_size).zero_()
-            roi_pool_cuda.backward(grad_output, rois, argmax, spatial_scale,
-                                   grad_input)
+            grad_input = grad_output.new_zeros(feature_size)
+            roi_pool_cuda.backward(grad_output.contiguous(), rois, argmax,
+                                   spatial_scale, grad_input)

         return grad_input, grad_rois, None, None
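The added .contiguous() call guards against non-contiguous incoming gradients (for example from a transposed or sliced downstream tensor), which the raw CUDA kernel cannot index safely. A quick illustration of when that matters:

import torch

g = torch.randn(4, 6).t()               # a transposed, non-contiguous view
print(g.is_contiguous())                 # False
print(g.contiguous().is_contiguous())    # True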
......
-#include <torch/torch.h>
+#include <torch/extension.h>

 #include <cmath>
 #include <vector>
......
 #include <ATen/ATen.h>
 #include <THC/THCAtomics.cuh>

-using namespace at;  // temporal fix for pytorch<=0.4.1 (see #9848)

 #define CUDA_1D_KERNEL_LOOP(i, n)                            \
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
        i += blockDim.x * gridDim.x)
@@ -100,11 +98,7 @@ int ROIPoolForwardLaucher(const at::Tensor features, const at::Tensor rois,
             channels, height, width, pooled_h, pooled_w, top_data,
             argmax_data);
       }));
-  cudaError_t err = cudaGetLastError();
-  if (cudaSuccess != err) {
-    fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err));
-    exit(-1);
-  }
+  THCudaCheck(cudaGetLastError());
   return 1;
 }
@@ -139,8 +133,7 @@ int ROIPoolBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
                            const int pooled_w, at::Tensor bottom_grad) {
   const int output_size = num_rois * pooled_h * pooled_w * channels;
-  // TODO: use AT_DISPATCH_FLOATING_TYPES_AND_HALF when atomicAdd is resolved
-  AT_DISPATCH_FLOATING_TYPES(
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      top_grad.type(), "ROIPoolLaucherBackward", ([&] {
         const scalar_t *top_diff = top_grad.data<scalar_t>();
         const scalar_t *rois_data = rois.data<scalar_t>();
@@ -158,11 +151,6 @@ int ROIPoolBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
             scalar_t(spatial_scale), channels, height, width, pooled_h,
             pooled_w, bottom_diff);
       }));
-  cudaError_t err = cudaGetLastError();
-  if (cudaSuccess != err) {
-    fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err));
-    exit(-1);
-  }
+  THCudaCheck(cudaGetLastError());
   return 1;
 }