Commit 6a10c7bf authored by unknown

Commit Swin-Transformer code
import os
import json

import torch.utils.data as data
import numpy as np
from PIL import Image

import warnings

warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)


class IN22KDATASET(data.Dataset):
    def __init__(self, root, ann_file='', transform=None, target_transform=None):
        super(IN22KDATASET, self).__init__()

        self.data_path = root
        self.ann_path = os.path.join(self.data_path, ann_file)
        self.transform = transform
        self.target_transform = target_transform
        # id & label: https://github.com/google-research/big_transfer/issues/7
        # total: 21843; only 21841 classes have images: map 21841->9205; 21842->15027
        self.database = json.load(open(self.ann_path))

    def _load_image(self, path):
        try:
            im = Image.open(path)
        except Exception:
            # fall back to a random image so a single corrupt file does not abort training
            print("ERROR IMG LOADED: ", path)
            random_img = np.random.rand(224, 224, 3) * 255
            im = Image.fromarray(np.uint8(random_img))
        return im

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is class_index of the target class.
        """
        idb = self.database[index]

        # images
        images = self._load_image(self.data_path + '/' + idb[0]).convert('RGB')
        if self.transform is not None:
            images = self.transform(images)

        # target
        target = int(idb[1])
        if self.target_transform is not None:
            target = self.target_transform(target)

        return images, target

    def __len__(self):
        return len(self.database)
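
# ----------------------------------------------------------------------------
# Usage sketch (not part of the original file): how the dataset above might be
# wired into a DataLoader. The annotation JSON is assumed to be a list of
# [relative_image_path, class_index] pairs, as implied by idb[0] / idb[1];
# the directory and file names below are hypothetical.
# ----------------------------------------------------------------------------
if __name__ == '__main__':
    from torch.utils.data import DataLoader
    from torchvision import transforms

    transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.ToTensor(),
    ])
    dataset = IN22KDATASET(root='imagenet22k', ann_file='train_ann.json', transform=transform)
    loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)

    images, targets = next(iter(loader))
    print(images.shape, targets.shape)  # torch.Size([64, 3, 224, 224]) torch.Size([64])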
359
368
460
475
486
492
496
514
516
525
547
548
556
563
575
641
648
723
733
765
801
826
852
858
878
896
900
905
908
910
935
946
947
994
999
1003
1005
1010
1027
1029
1048
1055
1064
1065
1069
1075
1079
1081
1085
1088
1093
1106
1143
1144
1145
1147
1168
1171
1178
1187
1190
1197
1205
1216
1223
1230
1236
1241
1245
1257
1259
1260
1267
1268
1269
1271
1272
1273
1277
1303
1344
1349
1355
1357
1384
1388
1391
1427
1429
1432
1437
1450
1461
1462
1474
1502
1503
1512
1552
1555
1577
1584
1587
1589
1599
1615
1616
1681
1692
1701
1716
1729
1757
1759
1764
1777
1786
1822
1841
1842
1848
1850
1856
1860
1861
1864
1876
1897
1898
1910
1913
1918
1922
1928
1932
1935
1947
1951
1953
1970
1977
1979
2001
2017
2067
2081
2087
2112
2128
2135
2147
2174
2175
2176
2177
2178
2181
2183
2184
2187
2189
2190
2191
2192
2193
2197
2202
2203
2206
2208
2209
2211
2212
2213
2214
2215
2216
2217
2219
2222
2223
2224
2225
2226
2227
2228
2229
2230
2236
2238
2240
2241
2242
2243
2244
2245
2247
2248
2249
2250
2251
2252
2255
2256
2257
2262
2263
2264
2265
2266
2268
2270
2271
2272
2273
2275
2276
2279
2280
2281
2282
2285
2289
2292
2295
2296
2297
2298
2299
2300
2301
2302
2303
2304
2305
2306
2309
2310
2312
2313
2314
2315
2316
2318
2319
2321
2322
2326
2329
2330
2331
2332
2334
2335
2336
2337
2338
2339
2341
2342
2343
2344
2346
2348
2349
2351
2352
2353
2355
2357
2358
2359
2360
2364
2365
2368
2369
2377
2382
2383
2385
2397
2398
2400
2402
2405
2412
2421
2428
2431
2432
2433
2436
2441
2445
2450
2453
2454
2465
2469
2532
2533
2538
2544
2547
2557
2565
2578
2612
2658
2702
2722
2731
2738
2741
2747
2810
2818
2833
2844
2845
2867
2874
2882
2884
2888
2889
3008
3012
3019
3029
3033
3042
3091
3106
3138
3159
3164
3169
3280
3296
3311
3318
3320
3324
3330
3366
3375
3381
3406
3419
3432
3434
3435
3493
3495
3503
3509
3511
3513
3517
3521
3526
3546
3554
3600
3601
3606
3612
3613
3616
3622
3623
3627
3632
3634
3636
3638
3644
3646
3649
3650
3651
3656
3663
3673
3674
3689
3690
3702
3733
3769
3971
3974
4065
4068
4073
4102
4136
4140
4151
4159
4165
4207
4219
4226
4249
4256
4263
4270
4313
4321
4378
4386
4478
4508
4512
4536
4542
4550
4560
4562
4570
4571
4572
4583
4588
4594
4604
4608
4623
4634
4636
4646
4651
4652
4686
4688
4691
4699
4724
4727
4737
4770
4774
4789
4802
4807
4819
4880
4886
4908
4927
4931
4936
4964
4976
4993
5028
5033
5043
5046
5096
5111
5114
5131
5132
5183
5199
5235
5275
5291
5293
5294
5343
5360
5362
5364
5390
5402
5418
5428
5430
5437
5443
5473
5484
5486
5505
5507
5508
5510
5567
5578
5580
5584
5606
5613
5629
5672
5676
5692
5701
5760
5769
5770
5779
5814
5850
5871
5893
5911
5949
5954
6005
6006
6012
6017
6023
6024
6040
6050
6054
6087
6105
6157
6235
6237
6256
6259
6286
6291
6306
6339
6341
6343
6379
6383
6393
6405
6479
6511
6517
6541
6561
6608
6611
6615
6678
6682
6707
6752
6798
6850
6880
6885
6890
6920
6981
7000
7009
7038
7049
7050
7052
7073
7078
7098
7111
7165
7198
7204
7280
7283
7286
7287
7293
7294
7305
7318
7341
7346
7354
7382
7427
7428
7435
7445
7450
7455
7467
7469
7497
7502
7506
7514
7523
7651
7661
7664
7672
7679
7685
7696
7730
7871
7873
7895
7914
7915
7920
7934
7935
7949
8009
8036
8051
8065
8074
8090
8112
8140
8164
8168
8178
8182
8198
8212
8216
8230
8242
8288
8289
8295
8318
8352
8368
8371
8375
8376
8401
8416
8419
8436
8460
8477
8478
8482
8498
8500
8539
8543
8552
8555
8580
8584
8586
8594
8598
8601
8606
8610
8611
8622
8627
8639
8649
8650
8653
8654
8667
8672
8673
8674
8676
8684
8720
8723
8750
8753
8801
8815
8831
8835
8842
8845
8858
8897
8916
8951
8954
8959
8970
8976
8981
8983
8989
8991
8993
9019
9039
9042
9043
9056
9057
9070
9087
9098
9106
9130
9131
9155
9171
9183
9198
9199
9201
9204
9212
9221
9225
9229
9250
9260
9271
9279
9295
9300
9310
9322
9345
9352
9376
9377
9382
9392
9401
9405
9441
9449
9464
9475
9502
9505
9514
9515
9545
9567
9576
9608
9609
9624
9633
9639
9643
9656
9674
9740
9752
9760
9767
9778
9802
9820
9839
9879
9924
9956
9961
9963
9970
9997
10010
10031
10040
10052
10073
10075
10078
10094
10097
10109
10118
10121
10124
10158
10226
10276
10304
10307
10314
10315
10332
10337
10338
10413
10423
10451
10463
10465
10487
10519
10522
10523
10532
10534
10535
10551
10559
10574
10583
10586
10589
10612
10626
10635
10638
10677
10683
10726
10776
10782
10783
10807
10837
10840
10848
10859
10871
10881
10884
10908
10914
10921
10936
10947
10951
10952
10957
10999
11003
11018
11023
11025
11027
11045
11055
11095
11110
11137
5564
11168
11186
11221
11223
11242
11255
11259
11279
11306
11311
11331
11367
11377
11389
11392
11401
11407
11437
11449
11466
11469
11473
11478
11483
11484
11507
11536
11558
11566
11575
11584
11594
11611
11612
11619
11621
11640
11643
11664
11674
11689
11709
11710
11716
11721
11726
11729
11743
11760
11771
11837
11839
11856
11876
11878
11884
11889
11896
11917
11923
11930
11944
11952
11980
11984
12214
12229
12239
12241
12242
12247
12283
12349
12369
12373
12422
12560
12566
12575
12688
12755
12768
12778
12780
12812
12832
12835
12836
12843
12847
12849
12850
12856
12858
12873
12938
12971
13017
13038
13046
13059
13085
13086
13088
13094
13134
13182
13230
13406
13444
13614
13690
13698
13709
13749
13804
13982
14051
14059
14219
14246
14256
14264
14294
14324
14367
14389
14394
14438
14442
14965
15732
16744
18037
18205
18535
18792
19102
20019
20462
21026
21045
21163
21171
21181
21196
21200
21369
21817
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------

import torch


class SubsetRandomSampler(torch.utils.data.Sampler):
    r"""Samples elements randomly from a given list of indices, without replacement.

    Arguments:
        indices (sequence): a sequence of indices
    """

    def __init__(self, indices):
        self.epoch = 0
        self.indices = indices

    def __iter__(self):
        return (self.indices[i] for i in torch.randperm(len(self.indices)))

    def __len__(self):
        return len(self.indices)

    def set_epoch(self, epoch):
        self.epoch = epoch
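
# ----------------------------------------------------------------------------
# Usage sketch (not part of the original file): restricting a DataLoader to a
# random subset of indices with the sampler above. The dataset and index split
# below are placeholders.
# ----------------------------------------------------------------------------
if __name__ == '__main__':
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.arange(100, dtype=torch.float32).unsqueeze(1))
    subset_indices = list(range(80, 100))           # e.g. hold out the last 20 samples
    sampler = SubsetRandomSampler(subset_indices)   # reshuffles the subset every epoch

    loader = DataLoader(dataset, batch_size=8, sampler=sampler)
    for epoch in range(2):
        sampler.set_epoch(epoch)  # kept for API parity with DistributedSampler
        for (batch,) in loader:
            pass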
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------

import os
import zipfile
import io
import numpy as np
from PIL import Image
from PIL import ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True


def is_zip_path(img_or_path):
    """Judge whether the given path is a zip-style path (i.e. contains '.zip@')."""
    return '.zip@' in img_or_path


class ZipReader(object):
    """A class to read files from zip archives, caching open archives."""
    zip_bank = dict()

    def __init__(self):
        super(ZipReader, self).__init__()

    @staticmethod
    def get_zipfile(path):
        zip_bank = ZipReader.zip_bank
        if path not in zip_bank:
            zfile = zipfile.ZipFile(path, 'r')
            zip_bank[path] = zfile
        return zip_bank[path]

    @staticmethod
    def split_zip_style_path(path):
        pos_at = path.index('@')
        assert pos_at != -1, "character '@' is not found in the given path '%s'" % path

        zip_path = path[0: pos_at]
        folder_path = path[pos_at + 1:]
        folder_path = str.strip(folder_path, '/')
        return zip_path, folder_path

    @staticmethod
    def list_folder(path):
        zip_path, folder_path = ZipReader.split_zip_style_path(path)
        zfile = ZipReader.get_zipfile(zip_path)

        folder_list = []
        for file_folder_name in zfile.namelist():
            file_folder_name = str.strip(file_folder_name, '/')
            if file_folder_name.startswith(folder_path) and \
               len(os.path.splitext(file_folder_name)[-1]) == 0 and \
               file_folder_name != folder_path:
                if len(folder_path) == 0:
                    folder_list.append(file_folder_name)
                else:
                    folder_list.append(file_folder_name[len(folder_path) + 1:])
        return folder_list

    @staticmethod
    def list_files(path, extension=None):
        if extension is None:
            extension = ['.*']
        zip_path, folder_path = ZipReader.split_zip_style_path(path)
        zfile = ZipReader.get_zipfile(zip_path)

        file_lists = []
        for file_folder_name in zfile.namelist():
            file_folder_name = str.strip(file_folder_name, '/')
            if file_folder_name.startswith(folder_path) and \
               str.lower(os.path.splitext(file_folder_name)[-1]) in extension:
                if len(folder_path) == 0:
                    file_lists.append(file_folder_name)
                else:
                    file_lists.append(file_folder_name[len(folder_path) + 1:])
        return file_lists

    @staticmethod
    def read(path):
        zip_path, path_img = ZipReader.split_zip_style_path(path)
        zfile = ZipReader.get_zipfile(zip_path)
        data = zfile.read(path_img)
        return data

    @staticmethod
    def imread(path):
        zip_path, path_img = ZipReader.split_zip_style_path(path)
        zfile = ZipReader.get_zipfile(zip_path)
        data = zfile.read(path_img)
        try:
            im = Image.open(io.BytesIO(data))
        except Exception:
            # fall back to a random image so a single corrupt file does not abort training
            print("ERROR IMG LOADED: ", path_img)
            random_img = np.random.rand(224, 224, 3) * 255
            im = Image.fromarray(np.uint8(random_img))
        return im
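
# ----------------------------------------------------------------------------
# Usage sketch (not part of the original file): resolving a zip-style path of
# the form '<archive>.zip@<member path>'. The archive and member names below
# are hypothetical.
# ----------------------------------------------------------------------------
if __name__ == '__main__':
    path = 'data/ImageNet-Zip/val.zip@ILSVRC2012_val_00000001.JPEG'
    if is_zip_path(path):
        img = ZipReader.imread(path).convert('RGB')  # PIL image read directly from the archive
        raw_bytes = ZipReader.read(path)             # or the raw bytes, if needed
        print(img.size, len(raw_bytes))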
# Swin Transformer for Image Classification
This folder contains the implementation of the Swin Transformer for image classification.
## Model Zoo
Please refer to [MODEL HUB](MODELHUB.md#imagenet-22k-pretrained-swin-moe-models) for more pre-trained models.
## Usage
### Install
We recommend using the PyTorch Docker image `nvcr>=21.05` from
NVIDIA: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch.
- Clone this repo:
```bash
git clone https://github.com/microsoft/Swin-Transformer.git
cd Swin-Transformer
```
- Create a conda virtual environment and activate it:
```bash
conda create -n swin python=3.7 -y
conda activate swin
```
- Install `CUDA>=10.2` with `cudnn>=7` following
the [official installation instructions](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html).
- Install `PyTorch>=1.8.0` and `torchvision>=0.9.0` with `CUDA>=10.2`:
```bash
conda install pytorch==1.8.0 torchvision==0.9.0 cudatoolkit=10.2 -c pytorch
```
- Install `timm==0.4.12`:
```bash
pip install timm==0.4.12
```
- Install other requirements:
```bash
pip install opencv-python==4.4.0.46 termcolor==1.1.0 yacs==0.1.8 pyyaml scipy
```
- Install the fused window-process kernel for acceleration, which is activated by passing `--fused_window_process` in the running script:
```bash
cd kernels/window_process
python setup.py install #--user
```
### Data preparation
We use the standard ImageNet dataset, which you can download from http://image-net.org/. We provide the following two
ways to load data:
- For standard folder dataset, move validation images to labeled sub-folders. The file structure should look like:
```bash
$ tree data
imagenet
├── train
│ ├── class1
│ │ ├── img1.jpeg
│ │ ├── img2.jpeg
│ │ └── ...
│ ├── class2
│ │ ├── img3.jpeg
│ │ └── ...
│ └── ...
└── val
├── class1
│ ├── img4.jpeg
│ ├── img5.jpeg
│ └── ...
├── class2
│ ├── img6.jpeg
│ └── ...
└── ...
```
- To avoid slow reads from a massive number of small files, we also support zipped ImageNet, which consists of
four files (a short sketch of how an entry is resolved follows the data-preparation examples below):
  - `train.zip`, `val.zip`: store the zipped folders for the train and validation splits.
  - `train_map.txt`, `val_map.txt`: store the relative paths within the corresponding zip file and the ground-truth
labels. Make sure the data folder looks like this:
```bash
$ tree data
data
└── ImageNet-Zip
├── train_map.txt
├── train.zip
├── val_map.txt
└── val.zip
$ head -n 5 data/ImageNet-Zip/val_map.txt
ILSVRC2012_val_00000001.JPEG 65
ILSVRC2012_val_00000002.JPEG 970
ILSVRC2012_val_00000003.JPEG 230
ILSVRC2012_val_00000004.JPEG 809
ILSVRC2012_val_00000005.JPEG 516
$ head -n 5 data/ImageNet-Zip/train_map.txt
n01440764/n01440764_10026.JPEG 0
n01440764/n01440764_10027.JPEG 0
n01440764/n01440764_10029.JPEG 0
n01440764/n01440764_10040.JPEG 0
n01440764/n01440764_10042.JPEG 0
```
- For the ImageNet-22K dataset, make a folder named `fall11_whole` and move all images into labeled sub-folders within
it. Then download the train-val split
files ([ILSVRC2011fall_whole_map_train.txt](https://github.com/SwinTransformer/storage/releases/download/v2.0.1/ILSVRC2011fall_whole_map_train.txt)
& [ILSVRC2011fall_whole_map_val.txt](https://github.com/SwinTransformer/storage/releases/download/v2.0.1/ILSVRC2011fall_whole_map_val.txt))
and put them in the parent directory of `fall11_whole`. The file structure should look like:
```bash
$ tree imagenet22k/
imagenet22k/
├── ILSVRC2011fall_whole_map_train.txt
├── ILSVRC2011fall_whole_map_val.txt
└── fall11_whole
├── n00004475
├── n00005787
├── n00006024
├── n00006484
└── ...
```
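For illustration, here is a minimal sketch (not part of the repository's data loaders) of how one `val_map.txt` entry from the zipped layout above can be resolved with the `ZipReader` helper included in this commit; the paths follow the example tree, and the snippet assumes `ZipReader` has been imported:
```python
# Hypothetical resolution of a single map-file entry into an image and its label.
with open('data/ImageNet-Zip/val_map.txt') as f:
    rel_path, label = f.readline().split()        # e.g. 'ILSVRC2012_val_00000001.JPEG 65'

zip_style_path = 'data/ImageNet-Zip/val.zip@' + rel_path
img = ZipReader.imread(zip_style_path).convert('RGB')
print(img.size, int(label))
```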
### Evaluation
To evaluate a pre-trained `Swin Transformer` on ImageNet val, run:
```bash
python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> --master_port 12345 main.py --eval \
--cfg <config-file> --resume <checkpoint> --data-path <imagenet-path>
```
For example, to evaluate the `Swin-B` with a single GPU:
```bash
python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py --eval \
--cfg configs/swin/swin_base_patch4_window7_224.yaml --resume swin_base_patch4_window7_224.pth --data-path <imagenet-path>
```
### Training from scratch on ImageNet-1K
To train a `Swin Transformer` on ImageNet from scratch, run:
```bash
python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> --master_port 12345 main.py \
--cfg <config-file> --data-path <imagenet-path> [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>]
```
**Notes**:
- To use zipped ImageNet instead of folder dataset, add `--zip` to the parameters.
- To cache the dataset in memory instead of reading from files every time, add `--cache-mode part`, which will
shard the dataset into non-overlapping pieces across GPUs and load only the corresponding shard on each GPU.
- When GPU memory is not enough, you can try the following suggestions:
  - Use gradient accumulation by adding `--accumulation-steps <steps>`, setting `<steps>` to a value appropriate for your GPU memory.
- Use gradient checkpointing by adding `--use-checkpoint`, e.g., it saves about 60% memory when training `Swin-B`.
Please refer to [this page](https://pytorch.org/docs/stable/checkpoint.html) for more details.
- We recommend using multiple nodes with more GPUs for training very large models; a tutorial can be found
on [this page](https://pytorch.org/tutorials/intermediate/dist_tuto.html).
- To change config options in general, you can use `--opts KEY1 VALUE1 KEY2 VALUE2`, e.g.,
`--opts TRAIN.EPOCHS 100 TRAIN.WARMUP_EPOCHS 5` will change total epochs to 100 and warm-up epochs to 5.
- For additional options, see [config](config.py) and run `python main.py --help` for detailed usage.
For example, to train `Swin Transformer` with 8 GPUs on a single node for 300 epochs, run:
`Swin-T`:
```bash
python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \
--cfg configs/swin/swin_tiny_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 128
```
`Swin-S`:
```bash
python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \
--cfg configs/swin/swin_small_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 128
```
`Swin-B`:
```bash
python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \
--cfg configs/swin/swin_base_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 64 \
--accumulation-steps 2 [--use-checkpoint]
```
### Pre-training on ImageNet-22K
For example, to pre-train a `Swin-B` model on ImageNet-22K:
```bash
python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \
--cfg configs/swin/swin_base_patch4_window7_224_22k.yaml --data-path <imagenet22k-path> --batch-size 64 \
--accumulation-steps 8 [--use-checkpoint]
```
### Fine-tuning on higher resolution
For example, to fine-tune a `Swin-B` model pre-trained on 224x224 resolution to 384x384 resolution:
```bash
python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \
--cfg configs/swin/swin_base_patch4_window12_384_finetune.yaml --pretrained swin_base_patch4_window7_224.pth \
--data-path <imagenet-path> --batch-size 64 --accumulation-steps 2 [--use-checkpoint]
```
### Fine-tuning from an ImageNet-22K (21K) pre-trained model
For example, to fine-tune a `Swin-B` model pre-trained on ImageNet-22K (21K):
```bash
python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \
--cfg configs/swin/swin_base_patch4_window7_224_22kto1k_finetune.yaml --pretrained swin_base_patch4_window7_224_22k.pth \
--data-path <imagenet-path> --batch-size 64 --accumulation-steps 2 [--use-checkpoint]
```
### Throughput
To measure the throughput, run:
```bash
python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py \
--cfg <config-file> --data-path <imagenet-path> --batch-size 64 --throughput --disable_amp
```
## Mixture-of-Experts Support
### Install [Tutel](https://github.com/microsoft/tutel)
```bash
python3 -m pip uninstall tutel -y
python3 -m pip install --user --upgrade git+https://github.com/microsoft/tutel@main
```
### Training Swin-MoE
For example, to train a `Swin-MoE-S` model with 32 experts on ImageNet-22K with 32 GPUs (4 nodes):
```bash
python -m torch.distributed.launch --nproc_per_node 8 --nnode=4 \
--node_rank=<node-rank> --master_addr=<master-ip> --master_port 12345 main_moe.py \
--cfg configs/swinmoe/swin_moe_small_patch4_window12_192_32expert_32gpu_22k.yaml --data-path <imagenet22k-path> --batch-size 128
```
### Evaluating Swin-MoE
To evaluate a `Swin-MoE-S` with 32 experts on ImageNet-22K with 32 GPUs (4 nodes):
1. Download the zip file [swin_moe_small_patch4_window12_192_32expert_32gpu_22k.zip](https://github.com/SwinTransformer/storage/releases/download/v2.0.2/swin_moe_small_patch4_window12_192_32expert_32gpu_22k.zip) which contains the pre-trained models for each rank, and unzip them to the folder "swin_moe_small_patch4_window12_192_32expert_32gpu_22k".
2. Run the following evaluation command; note that the checkpoint path should not contain the `.rank<x>` suffix.
```bash
python -m torch.distributed.launch --nproc_per_node 8 --nnode=4 \
--node_rank=<node-rank> --master_addr=<master-ip> --master_port 12345 main_moe.py \
--cfg configs/swinmoe/swin_moe_small_patch4_window12_192_32expert_32gpu_22k.yaml --data-path <imagenet22k-path> --batch-size 128 \
--resume swin_moe_small_patch4_window12_192_32expert_32gpu_22k/swin_moe_small_patch4_window12_192_32expert_32gpu_22k.pth
```
More Swin-MoE models can be found in [MODEL HUB](MODELHUB.md#imagenet-22k-pretrained-swin-moe-models).
## SimMIM Support
### Evaluating provided models
To evaluate a provided model on ImageNet validation set, run:
```bash
python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> main_simmim_ft.py \
--eval --cfg <config-file> --resume <checkpoint> --data-path <imagenet-path>
```
For example, to evaluate the `Swin Base` model on a single GPU, run:
```bash
python -m torch.distributed.launch --nproc_per_node 1 main_simmim_ft.py \
--eval --cfg configs/simmim/simmim_finetune__swin_base__img224_window7__800ep.yaml --resume simmim_finetune__swin_base__img224_window7__800ep.pth --data-path <imagenet-path>
```
### Pre-training with SimMIM
To pre-train models with `SimMIM`, run:
```bash
python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> main_simmim_pt.py \
--cfg <config-file> --data-path <imagenet-path>/train [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>]
```
For example, to pre-train `Swin Base` for 800 epochs on one DGX-2 server, run:
```bash
python -m torch.distributed.launch --nproc_per_node 16 main_simmim_pt.py \
--cfg configs/simmim/simmim_pretrain__swin_base__img192_window6__800ep.yaml --batch-size 128 --data-path <imagenet-path>/train [--output <output-directory> --tag <job-tag>]
```
### Fine-tuning pre-trained models
To fine-tune models pre-trained by `SimMIM`, run:
```bash
python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> main_simmim_ft.py \
--cfg <config-file> --data-path <imagenet-path> --pretrained <pretrained-ckpt> [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>]
```
For example, to fine-tune `Swin Base` pre-trained by `SimMIM` on one DGX-2 server, run:
```bash
python -m torch.distributed.launch --nproc_per_node 16 main_simmim_ft.py \
--cfg configs/simmim/simmim_finetune__swin_base__img224_window7__800ep.yaml --batch-size 128 --data-path <imagenet-path> --pretrained <pretrained-ckpt> [--output <output-directory> --tag <job-tag>]
```
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension


setup(name='swin_window_process',
      ext_modules=[
          CUDAExtension('swin_window_process', [
              'swin_window_process.cpp',
              'swin_window_process_kernel.cu',
          ])
      ],
      cmdclass={'build_ext': BuildExtension})
/*
* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/torch.h>
#include <torch/extension.h>
at::Tensor roll_and_window_partition_forward_cuda(
    at::Tensor & input,
    //at::Tensor & output,
    const int B,
    const int H,
    const int W,
    const int C,
    const int shift_size,
    const int window_size);

at::Tensor roll_and_window_partition_backward_cuda(
    at::Tensor & grad_in,
    //at::Tensor & grad_out,
    const int B,
    const int H,
    const int W,
    const int C,
    const int shift_size,
    const int window_size);

at::Tensor window_merge_and_roll_forward_cuda(
    at::Tensor & input,
    //at::Tensor & output,
    const int B,
    const int H,
    const int W,
    const int C,
    const int shift_size,
    const int window_size);

at::Tensor window_merge_and_roll_backward_cuda(
    at::Tensor & grad_in,
    //at::Tensor & grad_out,
    const int B,
    const int H,
    const int W,
    const int C,
    const int shift_size,
    const int window_size);

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

at::Tensor roll_and_window_partition_forward(
    at::Tensor & input,
    //at::Tensor & output,
    const int B,
    const int H,
    const int W,
    const int C,
    const int shift_size,
    const int window_size){
    CHECK_INPUT(input);
    return roll_and_window_partition_forward_cuda(input, B, H, W, C, shift_size, window_size);
}

at::Tensor roll_and_window_partition_backward(
    at::Tensor & grad_in,
    //at::Tensor & grad_out,
    const int B,
    const int H,
    const int W,
    const int C,
    const int shift_size,
    const int window_size){
    CHECK_INPUT(grad_in);
    return roll_and_window_partition_backward_cuda(grad_in, B, H, W, C, shift_size, window_size);
}

at::Tensor window_merge_and_roll_forward(
    at::Tensor & input,
    //at::Tensor & output,
    const int B,
    const int H,
    const int W,
    const int C,
    const int shift_size,
    const int window_size){
    CHECK_INPUT(input);
    return window_merge_and_roll_forward_cuda(input, B, H, W, C, shift_size, window_size);
}

at::Tensor window_merge_and_roll_backward(
    at::Tensor & grad_in,
    //at::Tensor & grad_out,
    const int B,
    const int H,
    const int W,
    const int C,
    const int shift_size,
    const int window_size){
    CHECK_INPUT(grad_in);
    return window_merge_and_roll_backward_cuda(grad_in, B, H, W, C, shift_size, window_size);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("roll_and_window_partition_forward", &roll_and_window_partition_forward, "torch.roll and window_partition.");
    m.def("roll_and_window_partition_backward", &roll_and_window_partition_backward, "torch.roll and window_partition.");
    m.def("window_merge_and_roll_forward", &window_merge_and_roll_forward, "window merge and torch.roll.");
    m.def("window_merge_and_roll_backward", &window_merge_and_roll_backward, "window merge and torch.roll.");
}
/*
* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <stdio.h>
int best_block_dim(int feat_dim){
int best_dim;
if (feat_dim < 384){
best_dim = 64;
}
else{
if (feat_dim < 1024){
best_dim = 128;
}
else{
best_dim = 256;
}
}
return best_dim;
}
template <typename T>
__global__ void roll_and_window_partition_forward_cuda_kernel(
T* input,
T* output,
const int B,
const int H,
const int W,
const int C,
const int shift_size,
const int window_size,
const int nH,
const int nW){
// start
//bool qual = threadIdx.x < C;
int index = threadIdx.x;
int offset;
for (int i = index; i < C; i += blockDim.x) {
offset = ((blockIdx.z * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * C + i; // C = blocksize
int input_offset = blockIdx.z / (nH * nW) * H * W * C +
(blockIdx.z % (nH * nW) / nW * window_size + blockIdx.y - shift_size + H) % H * W * C +
(blockIdx.z % nW * window_size + blockIdx.x - shift_size + W) % W * C +
i;
output[offset] = (T)(__ldg(input + input_offset));
}
}
template <typename T>
__global__ void roll_and_window_partition_backward_cuda_kernel(
T* grad_in,
T* grad_out,
const int B,
const int H,
const int W,
const int C,
const int shift_size,
const int window_size,
const int nH,
const int nW){
// start
int index = threadIdx.x;
int offset;
for (int i = index; i < C; i += blockDim.x) {
offset = ((blockIdx.z * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * C + i; // C = blocksize
int input_offset =
(blockIdx.z * nH * nW + (blockIdx.y + shift_size + H) % H / window_size * nW + (blockIdx.x + shift_size + W) % W / window_size) * window_size * window_size * C +
(blockIdx.y + shift_size + H ) % H % window_size * window_size * C +
(blockIdx.x + shift_size + W ) % W % window_size * C +
i;
grad_out[offset] = (T)(__ldg(grad_in + input_offset));
}
}
template <typename T>
__global__ void window_merge_and_roll_forward_cuda_kernel(
T* input,
T* output,
const int B,
const int H,
const int W,
const int C,
const int shift_size,
const int window_size,
const int nH,
const int nW){
// start
int index = threadIdx.x;
int offset;
for (int i = index; i < C; i += blockDim.x) {
offset = ((blockIdx.z * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * C + i; // C = blocksize
int input_offset =
(blockIdx.z * nH * nW + (blockIdx.y - shift_size + H) % H / window_size * nH + (blockIdx.x - shift_size + W) % W / window_size) * window_size * window_size * C +
(blockIdx.y - shift_size + H) % window_size * window_size * C +
(blockIdx.x - shift_size + W) % window_size * C +
i;
output[offset] = (T)(__ldg(input + input_offset));
}
}
template <typename T>
__global__ void window_merge_and_roll_backward_cuda_kernel(
T* grad_in,
T* grad_out,
const int B,
const int H,
const int W,
const int C,
const int shift_size,
const int window_size,
const int nH,
const int nW){
// start
int index = threadIdx.x;
int offset;
for (int i = index; i < C; i += blockDim.x) {
offset = ((blockIdx.z * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * C + i; // C = blocksize
int input_offset =
(blockIdx.z / (nH * nW)) * H * W * C +
(blockIdx.z % (nH * nW) / nW * window_size + blockIdx.y + shift_size + H) % H * W * C +
(blockIdx.z % nW * window_size + blockIdx.x + shift_size + W) % W * C +
i;
grad_out[offset] = (T)(__ldg(grad_in + input_offset));
}
}
// input: [B, H, W, C]
// output: [B*nH*nW, window_size, window_size, C]
at::Tensor roll_and_window_partition_forward_cuda(
at::Tensor & input,
//at::Tensor & output,
const int B,
const int H,
const int W,
const int C,
const int shift_size,
const int window_size){
int nH = H / window_size;
int nW = W / window_size;
dim3 grid(window_size, window_size, B * nH * nW);
//dim3 block((C + 31) / 32 * 32);
int blocknum = best_block_dim(C);
dim3 block(blocknum);
at::Tensor output;
if (input.scalar_type() == torch::kFloat16){
output = torch::empty({B*nH*nW, window_size, window_size, C}, torch::dtype(torch::kFloat16).device(torch::kCUDA).requires_grad(true));
}
else{
output = torch::empty({B*nH*nW, window_size, window_size, C}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(true));
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "roll_and_window_partition_forward_cuda_kernel", ([&] {
roll_and_window_partition_forward_cuda_kernel<scalar_t><<<grid, block, 0>>>(
input.data<scalar_t>(),
output.data<scalar_t>(),
B,
H,
W,
C,
shift_size,
window_size,
nH,
nW);
}));
return output;
}
// grad_in: [B*nH*nW, window_size, window_size, C]
// grad_out: [B, H, W, C]
at::Tensor roll_and_window_partition_backward_cuda(
at::Tensor & grad_in,
const int B,
const int H,
const int W,
const int C,
const int shift_size,
const int window_size){
int nH = H / window_size;
int nW = W / window_size;
dim3 grid(W, H, B);
//dim3 block((C + 31) / 32 * 32);
int blocknum = best_block_dim(C);
dim3 block(blocknum);
at::Tensor grad_out;
if (grad_in.scalar_type() == torch::kFloat16){
grad_out = torch::empty({B, H, W, C}, torch::dtype(torch::kFloat16).device(torch::kCUDA).requires_grad(false));
}
else{
grad_out = torch::empty({B, H, W, C}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false));
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_in.type(), "roll_and_window_partition_backward_cuda_kernel", ([&] {
roll_and_window_partition_backward_cuda_kernel<scalar_t><<<grid, block, 0>>>(
grad_in.data<scalar_t>(),
grad_out.data<scalar_t>(),
B,
H,
W,
C,
shift_size,
window_size,
nH,
nW);
}));
return grad_out;
}
// input: [B*nH*nW, window_size, window_size, C]
// output: [B, H, W, C]
at::Tensor window_merge_and_roll_forward_cuda(
at::Tensor & input,
//at::Tensor & output,
const int B,
const int H,
const int W,
const int C,
const int shift_size,
const int window_size){
int nH = H / window_size;
int nW = W / window_size;
dim3 grid(W, H, B);
//dim3 block((C + 31) / 32 * 32);
int blocknum = best_block_dim(C);
dim3 block(blocknum);
//generate output tensor inside
at::Tensor output;
if (input.scalar_type() == torch::kFloat16){
output = torch::empty({B, H, W, C}, torch::dtype(torch::kFloat16).device(torch::kCUDA).requires_grad(true));
}
else{
output = torch::empty({B, H, W, C}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(true));
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "window_merge_and_roll_forward_cuda_kernel", ([&] {
window_merge_and_roll_forward_cuda_kernel<scalar_t><<<grid, block, 0>>>(
input.data<scalar_t>(),
output.data<scalar_t>(),
B,
H,
W,
C,
shift_size,
window_size,
nH,
nW);
}));
return output;
}
at::Tensor window_merge_and_roll_backward_cuda(
at::Tensor & grad_in,
const int B,
const int H,
const int W,
const int C,
const int shift_size,
const int window_size){
int nH = H / window_size;
int nW = W / window_size;
dim3 grid(window_size, window_size, B * nH * nW);
//dim3 block((C + 31) / 32 * 32);
int blocknum = best_block_dim(C);
dim3 block(blocknum);
at::Tensor grad_out;
if (grad_in.scalar_type() == torch::kFloat16){
grad_out = torch::empty({B*nH*nW, window_size, window_size, C}, torch::dtype(torch::kFloat16).device(torch::kCUDA).requires_grad(false));
}
else{
grad_out = torch::empty({B*nH*nW, window_size, window_size, C}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false));
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_in.type(), "window_merge_and_roll_backward_cuda_kernel", ([&] {
window_merge_and_roll_backward_cuda_kernel<scalar_t><<<grid, block, 0>>>(
grad_in.data<scalar_t>(),
grad_out.data<scalar_t>(),
B,
H,
W,
C,
shift_size,
window_size,
nH,
nW);
}));
return grad_out;
}
# --------------------------------------------------------
# Fused kernel for window process for SwinTransformer
# Copyright (c) 2022 Nvidia
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

import torch
import swin_window_process
import random
import time
import unittest


class WindowProcess(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, B, H, W, C, shift_size, window_size):
        output = swin_window_process.roll_and_window_partition_forward(input, B, H, W, C, shift_size, window_size)

        ctx.B = B
        ctx.H = H
        ctx.W = W
        ctx.C = C
        ctx.shift_size = shift_size
        ctx.window_size = window_size
        return output

    @staticmethod
    def backward(ctx, grad_in):
        B = ctx.B
        H = ctx.H
        W = ctx.W
        C = ctx.C
        shift_size = ctx.shift_size
        window_size = ctx.window_size

        grad_out = swin_window_process.roll_and_window_partition_backward(grad_in, B, H, W, C, shift_size, window_size)
        return grad_out, None, None, None, None, None, None, None


class WindowProcessReverse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, B, H, W, C, shift_size, window_size):
        output = swin_window_process.window_merge_and_roll_forward(input, B, H, W, C, shift_size, window_size)

        ctx.B = B
        ctx.H = H
        ctx.W = W
        ctx.C = C
        ctx.shift_size = shift_size
        ctx.window_size = window_size
        return output

    @staticmethod
    def backward(ctx, grad_in):
        B = ctx.B
        H = ctx.H
        W = ctx.W
        C = ctx.C
        shift_size = ctx.shift_size
        window_size = ctx.window_size

        grad_out = swin_window_process.window_merge_and_roll_backward(grad_in, B, H, W, C, shift_size, window_size)
        return grad_out, None, None, None, None, None, None, None
def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


def pyt_forward(x, shift_size, window_size):
    # x in shape(B, H, W, C)
    # cyclic shift
    if shift_size > 0:
        shifted_x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))
    else:
        shifted_x = x
    # partition windows
    x_windows = window_partition(shifted_x, window_size)
    return x_windows


def reverse_pyt_forward(attn_windows, shift_size, window_size, H, W):
    # x in shape(B*nH*nW, window_size, window_size, C)
    shifted_x = window_reverse(attn_windows, window_size, H, W)
    if shift_size > 0:
        x = torch.roll(shifted_x, shifts=(shift_size, shift_size), dims=(1, 2))
    else:
        x = shifted_x
    return x


def copy_one_tensor(input, requires_grad=True):
    input1 = input.clone().detach().requires_grad_(requires_grad).cuda()
    return input1
class Test_WindowProcess(unittest.TestCase):
    def setUp(self):
        self.B = 192
        self.H = 56
        self.W = 56
        self.C = 96
        self.shift_size = 2
        self.window_size = 7
        self.nH = self.H // self.window_size
        self.nW = self.W // self.window_size

    def test_roll_and_window_partition_forward(self, dtype=torch.float32):
        input = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda()
        input1 = copy_one_tensor(input, True)
        input2 = copy_one_tensor(input, True)
        with torch.no_grad():
            # ori
            expected = pyt_forward(input1, self.shift_size, self.window_size)
            # fused kernel
            fused_output = WindowProcess.apply(input2, self.B, self.H, self.W, self.C, -self.shift_size, self.window_size)
        self.assertTrue(torch.equal(expected, fused_output))
        #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08))

    def test_roll_and_window_partition_backward(self, dtype=torch.float32):
        input = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda()
        d_loss_tensor = torch.randn((self.B*self.nW*self.nH, self.window_size, self.window_size, self.C), dtype=dtype).cuda()
        input1 = copy_one_tensor(input, True)
        input2 = copy_one_tensor(input, True)
        # ori
        expected = pyt_forward(input1, self.shift_size, self.window_size)
        expected.backward(d_loss_tensor)
        # fused kernel
        fused_output = WindowProcess.apply(input2, self.B, self.H, self.W, self.C, -self.shift_size, self.window_size)
        fused_output.backward(d_loss_tensor)
        self.assertTrue(torch.equal(expected, fused_output))
        #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08))

    def test_window_merge_and_roll_forward(self, dtype=torch.float32):
        input = torch.randn((self.B*self.nH*self.nW, self.window_size, self.window_size, self.C), dtype=dtype, requires_grad=True).cuda()
        input1 = copy_one_tensor(input, True)
        input2 = copy_one_tensor(input, True)
        with torch.no_grad():
            # ori
            expected = reverse_pyt_forward(input1, self.shift_size, self.window_size, self.H, self.W)
            # fused kernel
            fused_output = WindowProcessReverse.apply(input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size)
        self.assertTrue(torch.equal(expected, fused_output))
        #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08))

    def test_window_merge_and_roll_backward(self, dtype=torch.float32):
        input = torch.randn((self.B*self.nH*self.nW, self.window_size, self.window_size, self.C), dtype=dtype, requires_grad=True).cuda()
        d_loss_tensor = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda()
        input1 = copy_one_tensor(input, True)
        input2 = copy_one_tensor(input, True)
        # ori
        expected = reverse_pyt_forward(input1, self.shift_size, self.window_size, self.H, self.W)
        expected.backward(d_loss_tensor)
        # fused kernel
        fused_output = WindowProcessReverse.apply(input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size)
        fused_output.backward(d_loss_tensor)
        self.assertTrue(torch.equal(expected, fused_output))
        #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08))

    def test_forward_backward_speed(self, dtype=torch.float32, times=1000):
        input = torch.randn((self.B*self.nH*self.nW, self.window_size, self.window_size, self.C), dtype=dtype, requires_grad=True).cuda()
        d_loss_tensor = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda()
        input1 = copy_one_tensor(input, True)
        input2 = copy_one_tensor(input, True)

        # SwinTransformer official
        def run_pyt(t=1000):
            for _ in range(t):
                expected = reverse_pyt_forward(input1, self.shift_size, self.window_size, self.H, self.W)
                expected.backward(d_loss_tensor)

        # fused op
        def run_fusedop(t=1000):
            for _ in range(t):
                fused_output = WindowProcessReverse.apply(input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size)
                fused_output.backward(d_loss_tensor)

        torch.cuda.synchronize()
        t1 = time.time()
        run_pyt(t=times)
        torch.cuda.synchronize()
        t2 = time.time()
        run_fusedop(t=times)
        torch.cuda.synchronize()
        t3 = time.time()
        self.assertTrue((t3 - t2) < (t2 - t1))

        print('Run {} times'.format(times))
        print('Original time cost: {}'.format(t2 - t1))
        print('Fused op time cost: {}'.format(t3 - t2))

    def test_roll_and_window_partition_forward_fp16(self, dtype=torch.float16):
        self.test_roll_and_window_partition_forward(dtype=dtype)

    def test_roll_and_window_partition_backward_fp16(self, dtype=torch.float16):
        self.test_roll_and_window_partition_backward(dtype=dtype)

    def test_window_merge_and_roll_forward_fp16(self, dtype=torch.float16):
        self.test_window_merge_and_roll_forward(dtype=dtype)

    def test_window_merge_and_roll_backward_fp16(self, dtype=torch.float16):
        self.test_window_merge_and_roll_backward(dtype=dtype)

    def test_forward_backward_speed_fp16(self, dtype=torch.float16, times=1000):
        self.test_forward_backward_speed(dtype=dtype, times=times)


if __name__ == '__main__':
    print('Tests pass only if the two tensors are exactly equal (using torch.equal).\n')
    torch.manual_seed(0)
    unittest.main(verbosity=2)
# --------------------------------------------------------
# Fused kernel for window process for SwinTransformer
# Copyright (c) 2022 Nvidia
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

import torch
import swin_window_process


class WindowProcess(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, B, H, W, C, shift_size, window_size):
        output = swin_window_process.roll_and_window_partition_forward(input, B, H, W, C, shift_size, window_size)

        ctx.B = B
        ctx.H = H
        ctx.W = W
        ctx.C = C
        ctx.shift_size = shift_size
        ctx.window_size = window_size
        return output

    @staticmethod
    def backward(ctx, grad_in):
        B = ctx.B
        H = ctx.H
        W = ctx.W
        C = ctx.C
        shift_size = ctx.shift_size
        window_size = ctx.window_size

        grad_out = swin_window_process.roll_and_window_partition_backward(grad_in, B, H, W, C, shift_size, window_size)
        return grad_out, None, None, None, None, None, None, None


class WindowProcessReverse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, B, H, W, C, shift_size, window_size):
        output = swin_window_process.window_merge_and_roll_forward(input, B, H, W, C, shift_size, window_size)

        ctx.B = B
        ctx.H = H
        ctx.W = W
        ctx.C = C
        ctx.shift_size = shift_size
        ctx.window_size = window_size
        return output

    @staticmethod
    def backward(ctx, grad_in):
        B = ctx.B
        H = ctx.H
        W = ctx.W
        C = ctx.C
        shift_size = ctx.shift_size
        window_size = ctx.window_size

        #grad_out = ctx.saved_tensors[0]
        #grad_out = torch.zeros((B, H, W, C), dtype=dtype).cuda()
        grad_out = swin_window_process.window_merge_and_roll_backward(grad_in, B, H, W, C, shift_size, window_size)
        return grad_out, None, None, None, None, None, None, None
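
# ----------------------------------------------------------------------------
# Usage sketch (not part of the original file): the fused functions above stand
# in for torch.roll + window_partition (and the reverse) in the shifted-window
# attention path. Shapes follow the comments in the CUDA sources; the values
# below are placeholders, and a CUDA device is assumed.
# ----------------------------------------------------------------------------
if __name__ == '__main__':
    B, H, W, C = 4, 56, 56, 96
    shift_size, window_size = 3, 7

    x = torch.randn(B, H, W, C, device='cuda', requires_grad=True)

    # roll by -shift_size and partition into windows in one fused kernel:
    # (B, H, W, C) -> (B * H//window_size * W//window_size, window_size, window_size, C)
    x_windows = WindowProcess.apply(x, B, H, W, C, -shift_size, window_size)

    # ... window attention would operate on x_windows here ...

    # merge the windows back and roll by +shift_size, the inverse of the step above
    x_back = WindowProcessReverse.apply(x_windows, B, H, W, C, shift_size, window_size)
    print(x_back.shape)  # torch.Size([4, 56, 56, 96])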
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------

import os
import sys
import logging
import functools
from termcolor import colored


@functools.lru_cache()
def create_logger(output_dir, dist_rank=0, name=''):
    # create logger
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    logger.propagate = False

    # create formatter
    fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s'
    color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \
                colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s'

    # create console handlers for master process
    if dist_rank == 0:
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(logging.DEBUG)
        console_handler.setFormatter(
            logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S'))
        logger.addHandler(console_handler)

    # create file handlers
    file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a')
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S'))
    logger.addHandler(file_handler)

    return logger
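
# ----------------------------------------------------------------------------
# Usage sketch (not part of the original file): one logger per process, writing
# to log_rank<r>.txt inside the output directory. The directory and name below
# are placeholders.
# ----------------------------------------------------------------------------
if __name__ == '__main__':
    output_dir = './output/example'
    os.makedirs(output_dir, exist_ok=True)

    logger = create_logger(output_dir, dist_rank=0, name='swin')
    logger.info('logger initialised')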
from .build import build_model