Unverified Commit c2cb2aab authored by Hang Zhang, committed by GitHub

update backend for PyTorch Update (#130)

* update backend

* version

fixes https://github.com/zhanghang1989/PyTorch-Encoding/issues/123
parent a0fe6223
MIT License
-Copyright (c) 2017 Hang Zhang. All rights reserved.
-Copyright (c) 2018 Amazon.com, Inc. or its affiliates. All rights reserved.
+Copyright (c) 2017- Hang Zhang. All rights reserved.
+Copyright (c) 2018- Amazon.com, Inc. or its affiliates. All rights reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
......
@@ -83,15 +83,15 @@ Test Pre-trained Model
<code xml:space="preserve" id="cmd_enc101_ade" style="display: none; text-align: left; white-space: pre-wrap">
-CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset ADE20K --model EncNet --aux --se-loss --backbone resnet101
+CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset ADE20K --model EncNet --aux --se-loss --backbone resnet101 --base-size 640 --crop-size 576
</code>
<code xml:space="preserve" id="cmd_enc101_voc" style="display: none; text-align: left; white-space: pre-wrap">
# First finetuning COCO dataset pretrained model on augmented set
# You can also train from scratch on COCO by yourself
-CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset Pascal_aug --model-zoo EncNet_Resnet101_COCO --aux --se-loss --lr 0.001 --syncbn --ngpus 4 --checkname res101
+CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset Pascal_aug --model-zoo EncNet_Resnet101_COCO --aux --se-loss --lr 0.001 --syncbn --ngpus 4 --checkname res101 --ft
# Finetuning on original set
-CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset Pascal_voc --model encnet --aux --se-loss --backbone resnet101 --lr 0.0001 --syncbn --ngpus 4 --checkname res101 --resume runs/Pascal_aug/encnet/res101/checkpoint.params
+CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset Pascal_voc --model encnet --aux --se-loss --backbone resnet101 --lr 0.0001 --syncbn --ngpus 4 --checkname res101 --resume runs/Pascal_aug/encnet/res101/checkpoint.params --ft
</code>
Quick Demo
......
@@ -22,7 +22,7 @@ Test Pre-trained Model
cd PyTorch-Encoding/
python scripts/prepare_minc.py
-- Download pre-trained model (pre-trained on train-1 split using single training size of 224, with an error rate of :math:`19.70\%` using single crop on test-1 set)::
+- Download pre-trained model (pre-trained on train-1 split using single training size of 224, with an error rate of :math:`18.96\%` using single crop on test-1 set)::
cd experiments/recognition
python model/download_models.py
......
@@ -10,4 +10,4 @@
"""An optimized PyTorch package with CUDA backend."""
from .version import __version__
-from . import nn, functions, dilated, parallel, utils, models, datasets, optimizer
+from . import nn, functions, dilated, parallel, utils, models, datasets
@@ -6,6 +6,7 @@
import os
import sys
+import random
import numpy as np
from tqdm import tqdm, trange
from PIL import Image, ImageOps, ImageFilter
@@ -93,7 +94,7 @@ class CitySegmentation(BaseDataset):
mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
crop_size = self.crop_size
# random scale (short edge from 480 to 720)
-short_size = random.randint(int(self.base_size*0.5), int(self.base_size*2.5))
+short_size = random.randint(int(self.base_size*0.5), int(self.base_size*2.0))
w, h = img.size
if h > w:
ow = short_size
......
+#include <torch/tensor.h>
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
@@ -42,7 +43,8 @@ std::vector<at::Tensor> Non_Max_Suppression_CPU(
auto num_boxes = input.size(1);
auto batch_size = input.size(0);
-auto mask = input.type().toScalarType(at::kByte).tensor({batch_size, num_boxes});
+auto mask = torch::zeros({batch_size, num_boxes}, input.type().toScalarType(at::kByte));
+//auto mask = input.type().toScalarType(at::kByte).tensor({batch_size, num_boxes});
mask.fill_(1);
auto *rawMask = mask.data<unsigned char>();
auto *rawIdx = sorted_inds.data<int64_t>();
......
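Most of the C++ changes in this commit, including the non-max-suppression hunk just above, follow one pattern: the deprecated Type::tensor(...) allocation is replaced by torch::zeros(sizes, options), which the newer ATen API requires. Below is a minimal sketch of that idiom, not code from this repository; make_keep_mask and scores are hypothetical names used only for illustration.

    #include <torch/torch.h>

    // Allocate a byte mask the new-style way, mirroring the NMS change above.
    at::Tensor make_keep_mask(const at::Tensor& scores) {
      // Deprecated pattern this commit removes:
      //   auto mask = scores.type().toScalarType(at::kByte).tensor({scores.size(0), scores.size(1)});
      // New pattern: pass the sizes plus TensorOptions; dtype (and device)
      // travel through options() instead of the Type object.
      auto mask = torch::zeros({scores.size(0), scores.size(1)},
                               scores.options().dtype(torch::kByte));
      mask.fill_(1);  // every box starts out marked as kept, as in the NMS code
      return mask;
    }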
+#include <torch/tensor.h>
#include <ATen/ATen.h>
//#include <omp.h>
@@ -404,7 +405,7 @@ at::Tensor ROIAlign_Forward_CPU(
AT_ASSERT(roi_cols == 4 || roi_cols == 5);
// Output at::Tensor is (num_rois, C, pooled_height, pooled_width)
-auto output = input.type().tensor({num_rois, channels, pooled_height, pooled_width});
+auto output = torch::zeros({num_rois, channels, pooled_height, pooled_width}, input.options());
AT_ASSERT(input.is_contiguous());
AT_ASSERT(bottom_rois.is_contiguous());
@@ -451,7 +452,7 @@ at::Tensor ROIAlign_Backward_CPU(
AT_ASSERT(roi_cols == 4 || roi_cols == 5);
// Output at::Tensor is (num_rois, C, pooled_height, pooled_width)
-auto grad_in = bottom_rois.type().tensor({b_size, channels, height, width}).zero_();
+auto grad_in = torch::zeros({b_size, channels, height, width}, bottom_rois.options());
AT_ASSERT(bottom_rois.is_contiguous());
......
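A related detail in the ROIAlign and syncbn hunks: the old calls chained .zero_() because Type::tensor() returned uninitialized memory, whereas torch::zeros already returns zero-filled memory, and options() copies dtype, layout, and device from the reference tensor. A short sketch under those assumptions (alloc_grad_like and its arguments are hypothetical, not repository code):

    #include <torch/torch.h>

    // Zero-filled gradient buffer matching the reference tensor's dtype/device.
    at::Tensor alloc_grad_like(const at::Tensor& features,
                               int64_t b, int64_t c, int64_t h, int64_t w) {
      // Old: features.type().tensor({b, c, h, w}).zero_();
      // New: torch::zeros is already zero-initialized, so no explicit .zero_().
      return torch::zeros({b, c, h, w}, features.options());
    }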
+#include <torch/tensor.h>
#include <ATen/ATen.h>
#include <vector>
@@ -45,8 +46,8 @@ std::vector<at::Tensor> BatchNorm_Backward_CPU(
std::vector<at::Tensor> Sum_Square_Forward_CPU(
const at::Tensor input) {
/* outputs */
-at::Tensor sum = input.type().tensor({input.size(1)}).zero_();
-at::Tensor square = input.type().tensor({input.size(1)}).zero_();
+at::Tensor sum = torch::zeros({input.size(1)}, input.options());
+at::Tensor square = torch::zeros({input.size(1)}, input.options());
return {sum, square};
}
......
#include <vector>
+#include <torch/tensor.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
@@ -165,7 +166,7 @@ at::Tensor Aggregate_Forward_CUDA(
const at::Tensor X_,
const at::Tensor C_) {
/* Device tensors */
-auto E_ = A_.type().tensor({A_.size(0), C_.size(0), C_.size(1)}).zero_();
+auto E_ = torch::zeros({A_.size(0), C_.size(0), C_.size(1)}, A_.options());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// B, K, D
dim3 blocks(C_.size(1), C_.size(0), X_.size(0));
@@ -214,7 +215,7 @@ at::Tensor ScaledL2_Forward_CUDA(
const at::Tensor X_,
const at::Tensor C_,
const at::Tensor S_) {
-auto SL_ = X_.type().tensor({X_.size(0), X_.size(1), C_.size(0)}).zero_();
+auto SL_ = torch::zeros({X_.size(0), X_.size(1), C_.size(0)}, X_.options());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 blocks(C_.size(0), X_.size(1), X_.size(0));
dim3 threads(getNumThreads(C_.size(1)));
......
#include <vector>
+#include <torch/tensor.h>
#include <ATen/ATen.h>
#include <ATen/Functions.h>
#include <ATen/cuda/CUDAContext.h>
@@ -239,7 +240,7 @@ at::Tensor Encoding_Dist_Inference_Forward_CUDA(
const at::Tensor STD_) {
// const at::Tensor S_,
// X \in R^{B, N, D}, C \in R^{K, D}, S \in R^K
-auto KD_ = X_.type().tensor({X_.size(0), X_.size(1), C_.size(0)}).zero_();
+auto KD_ = torch::zeros({X_.size(0), X_.size(1), C_.size(0)}, X_.options());
// E(x), E(x^2)
int N = X_.size(0) * X_.size(1);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -301,7 +302,7 @@ std::vector<at::Tensor> Encoding_Dist_Forward_CUDA(
double eps) {
// const at::Tensor S_,
// X \in R^{B, N, D}, C \in R^{K, D}, S \in R^K
-auto KD_ = X_.type().tensor({X_.size(0), X_.size(1), C_.size(0)}).zero_();
+auto KD_ = torch::zeros({X_.size(0), X_.size(1), C_.size(0)}, X_.options());
// E(x), E(x^2)
int N = X_.size(0) * X_.size(1);
auto SVar_ = (X_.pow(2).sum(0).sum(0).view({1, X_.size(2)}) -
@@ -373,7 +374,7 @@ at::Tensor AggregateV2_Forward_CUDA(
const at::Tensor C_,
const at::Tensor STD_) {
/* Device tensors */
-auto E_ = A_.type().tensor({A_.size(0), C_.size(0), C_.size(1)}).zero_();
+auto E_ = torch::zeros({A_.size(0), C_.size(0), C_.size(1)}, A_.options());
// auto IS_ = 1.0f / (S_ + eps).sqrt();
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// B, K, D
......
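The CUDA files use the same allocation idiom; because options() carries the device, the outputs land on the same GPU as the inputs, and the kernels keep launching on the stream returned by at::cuda::getCurrentCUDAStream(), as in the unchanged context lines above. A sketch of that pattern, assuming a CUDA build and a hypothetical kernel my_kernel (not part of this repository):

    #include <ATen/ATen.h>
    #include <ATen/cuda/CUDAContext.h>
    #include <torch/torch.h>

    at::Tensor cuda_output_like(const at::Tensor& X) {
      // X.options() carries the CUDA device and dtype, so the output is
      // allocated on the same GPU as the input without the deprecated Type API.
      auto Y = torch::zeros({X.size(0), X.size(1)}, X.options());
      // Kernels in this backend launch on the current stream:
      cudaStream_t stream = at::cuda::getCurrentCUDAStream();
      // my_kernel<<<blocks, threads, 0, stream>>>(...);  // hypothetical launch
      (void)stream;  // unused in this sketch
      return Y;
    }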
+#include <torch/tensor.h>
#include <ATen/ATen.h>
#include "ATen/NativeFunctions.h"
#include <ATen/cuda/CUDAContext.h>
@@ -75,7 +76,8 @@ std::vector<at::Tensor> Non_Max_Suppression_CUDA(
auto num_boxes = input.size(1);
auto batch_size = input.size(0);
-auto mask = input.type().toScalarType(at::kByte).tensor({batch_size, num_boxes});
+//auto mask = input.type().toScalarType(at::kByte).tensor({batch_size, num_boxes});
+auto mask = torch::zeros({batch_size, num_boxes}, input.type().toScalarType(at::kByte));
mask.fill_(1);
//need the indices of the boxes sorted by score.
......
+#include <torch/tensor.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
@@ -367,7 +368,7 @@ at::Tensor ROIAlign_Forward_CUDA(
auto width = input.size(3);
// Output Tensor is (num_rois, C, pooled_height, pooled_width)
-auto output = input.type().tensor({proposals, channels, pooled_height, pooled_width});
+auto output = torch::zeros({proposals, channels, pooled_height, pooled_width}, input.options());
auto count = output.numel();
@@ -414,7 +415,7 @@ at::Tensor ROIAlign_Backward_CUDA(
// Output Tensor is (num_rois, C, pooled_height, pooled_width)
// gradient wrt input features
-auto grad_in = rois.type().tensor({b_size, channels, height, width}).zero_();
+auto grad_in = torch::zeros({b_size, channels, height, width}, rois.options());
auto num_rois = rois.size(0);
auto count = grad_output.numel();
......
#include <vector>
+#include <torch/tensor.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
@@ -244,8 +245,8 @@ std::vector<at::Tensor> BatchNorm_Backward_CUDA(
std::vector<at::Tensor> Sum_Square_Forward_CUDA(
const at::Tensor input_) {
/* outputs */
-at::Tensor sum_ = input_.type().tensor({input_.size(1)}).zero_();
-at::Tensor square_ = input_.type().tensor({input_.size(1)}).zero_();
+at::Tensor sum_ = torch::zeros({input_.size(1)}, input_.options());
+at::Tensor square_ = torch::zeros({input_.size(1)}, input_.options());
/* cuda utils*/
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 blocks(input_.size(1));
......
@@ -11,7 +11,7 @@ _model_sha1 = {name: checksum for checksum, name in [
('2a57e44de9c853fa015b172309a1ee7e2d0e4e2a', 'resnet101'),
('0d43d698c66aceaa2bc0309f55efdd7ff4b143af', 'resnet152'),
('2e22611a7f3992ebdee6726af169991bc26d7363', 'deepten_minc'),
-('fc8c0b795abf0133700c2d4265d2f9edab7eb6cc', 'fcn_resnet50_ade'),
+('662e979de25a389f11c65e9f1df7e06c2c356381', 'fcn_resnet50_ade'),
('eeed8e582f0fdccdba8579e7490570adc6d85c7c', 'fcn_resnet50_pcontext'),
('54f70c772505064e30efd1ddd3a14e1759faa363', 'psp_resnet50_ade'),
('075195c5237b778c718fd73ceddfa1376c18dfd0', 'deeplab_resnet50_ade'),
......
@@ -92,7 +92,7 @@ class Options():
if args.epochs is None:
epoches = {
'coco': 30,
-'citys': 180,
+'citys': 240,
'pascal_voc': 50,
'pascal_aug': 50,
'pcontext': 80,
@@ -100,7 +100,7 @@ class Options():
}
args.epochs = epoches[args.dataset.lower()]
if args.batch_size is None:
-args.batch_size = 4 * torch.cuda.device_count()
+args.batch_size = 16
if args.test_batch_size is None:
args.test_batch_size = args.batch_size
if args.lr is None:
......
@@ -18,7 +18,7 @@ import setuptools.command.install
cwd = os.path.dirname(os.path.abspath(__file__))
-version = '0.5.0'
+version = '0.5.1'
try:
sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'],
cwd=cwd).decode('ascii').strip()
......