Unverified Commit da01c234 authored by Frank Lee, committed by GitHub

Develop/experiments (#59)



* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.

* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>

* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
Fix convergence in cifar10, Imagenet1000

* Integrate 1d tensor parallel in Colossal-AI (#39)

* fixed 1D and 2D convergence (#38)

* optimized 2D operations

* fixed 1D ViT convergence problem

* Feature/ddp (#49)

* remove redundancy func in setup (#19) (#20)

* use env to control the language of doc (#24) (#25)

* Support TP-compatible Torch AMP and Update trainer API (#27)

* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.

* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>

* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)

* add explanation for ViT example (#35) (#36)

* support torch ddp

* fix loss accumulation

* add log for ddp

* change seed

* modify timing hook
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* Feature/pipeline (#40)

* remove redundancy func in setup (#19) (#20)

* use env to control the language of doc (#24) (#25)

* Support TP-compatible Torch AMP and Update trainer API (#27)

* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.

* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>

* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)

* add explanation for ViT example (#35) (#36)

* optimize communication of pipeline parallel

* fix grad clip for pipeline
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* optimized 3d layer to fix slow computation ; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)

* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset

* update api for better usability (#58)

update api for better usability
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
parent eb2f8b1f
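
The headline features in this commit message (gradient accumulation, TP-compatible Torch AMP, pipeline parallelism) are exercised by the tests and logs below. As a rough orientation for the first of these, here is a minimal plain-PyTorch sketch of the gradient accumulation pattern; it is illustrative only (the accumulation_steps value and the toy model are hypothetical, and Colossal-AI drives this through its engine and config instead):

import torch
import torch.nn as nn

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
batches = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(8)]

accumulation_steps = 4  # hypothetical value
optimizer.zero_grad()
for step, (x, y) in enumerate(batches):
    # scale each micro-batch loss so the accumulated gradient matches the large batch
    loss = criterion(model(x), y) / accumulation_steps
    loss.backward()  # gradients keep accumulating in .grad
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()      # one weight update per accumulation window
        optimizer.zero_grad()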
TACC: Starting up job 3497142
TACC: Starting parallel tasks...
warning: variables which start with __, or are module or class declarations, are omitted
process rank 0 is bound to device 0
distributed environment is initialized
model is created
warning: variables which start with __, or are module or class declarations, are omitted
process rank 2 is bound to device 2
Files already downloaded and verified
Files already downloaded and verified
warning: variables which start with __, or are module or class declarations, are omitted
process rank 3 is bound to device 3
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
training and testing dataloaders are created
loss is created
optimizer is created
start training
warning: variables which start with __, or are module or class declarations, are omitted
process rank 1 is bound to device 1
Files already downloaded and verified
Files already downloaded and verified
epoch: 0, train loss: 1.9320369898056498
epoch: 1, train loss: 1.6352128605453335
epoch: 1, eval loss: 1.5123237550258637, correct: 4542, total: 10000, acc = 0.45419999957084656
epoch: 2, train loss: 1.4457968728882926
epoch: 3, train loss: 1.3382204977833494
epoch: 3, eval loss: 1.2539702713489533, correct: 5451, total: 10000, acc = 0.5450999736785889
epoch: 4, train loss: 1.2739947474732691
epoch: 5, train loss: 1.2285400483073021
epoch: 5, eval loss: 1.1386113047599793, correct: 5908, total: 10000, acc = 0.5907999873161316
epoch: 6, train loss: 1.1903334296479517
epoch: 7, train loss: 1.1711674235305007
epoch: 7, eval loss: 1.1258068561553956, correct: 5967, total: 10000, acc = 0.5967000126838684
epoch: 8, train loss: 1.1419668745021432
epoch: 9, train loss: 1.1143895728247506
epoch: 9, eval loss: 1.040754759311676, correct: 6224, total: 10000, acc = 0.6223999857902527
epoch: 10, train loss: 1.1041023871120141
epoch: 11, train loss: 1.089750115968743
epoch: 11, eval loss: 1.0472844064235687, correct: 6265, total: 10000, acc = 0.6265000104904175
epoch: 12, train loss: 1.064698440687997
epoch: 13, train loss: 1.038266262229608
epoch: 13, eval loss: 1.0117274671792984, correct: 6415, total: 10000, acc = 0.6414999961853027
epoch: 14, train loss: 1.029945282303557
epoch: 15, train loss: 1.0171620620756734
epoch: 15, eval loss: 0.9712629705667496, correct: 6519, total: 10000, acc = 0.6518999934196472
epoch: 16, train loss: 0.9928132119227429
epoch: 17, train loss: 0.9921575498824217
epoch: 17, eval loss: 0.9429782271385193, correct: 6641, total: 10000, acc = 0.6640999913215637
epoch: 18, train loss: 0.9607366293060536
epoch: 19, train loss: 0.9427766927650997
epoch: 19, eval loss: 0.9346068739891052, correct: 6623, total: 10000, acc = 0.6622999906539917
epoch: 20, train loss: 0.9219280481338501
epoch: 21, train loss: 0.8945026689646195
epoch: 21, eval loss: 0.8710516095161438, correct: 6909, total: 10000, acc = 0.6908999681472778
epoch: 22, train loss: 0.8807675826306246
epoch: 23, train loss: 0.851514169756247
epoch: 23, eval loss: 0.8239740908145905, correct: 7052, total: 10000, acc = 0.7051999568939209
epoch: 24, train loss: 0.8388774534877466
epoch: 25, train loss: 0.8265813291072845
epoch: 25, eval loss: 0.8102335959672928, correct: 7137, total: 10000, acc = 0.713699996471405
epoch: 26, train loss: 0.8057564490911912
epoch: 27, train loss: 0.7816558753957554
epoch: 27, eval loss: 0.7648743063211441, correct: 7292, total: 10000, acc = 0.729200005531311
epoch: 28, train loss: 0.766656969883004
epoch: 29, train loss: 0.7515677390049915
epoch: 29, eval loss: 0.7517296761274338, correct: 7360, total: 10000, acc = 0.7360000014305115
epoch: 30, train loss: 0.7300611174836451
epoch: 31, train loss: 0.7038229193006244
epoch: 31, eval loss: 0.7385401755571366, correct: 7375, total: 10000, acc = 0.7374999523162842
epoch: 32, train loss: 0.6928578931458143
epoch: 33, train loss: 0.672958068093475
epoch: 33, eval loss: 0.6915913820266724, correct: 7596, total: 10000, acc = 0.7595999836921692
epoch: 34, train loss: 0.6505378533382805
epoch: 35, train loss: 0.6292881539889744
epoch: 35, eval loss: 0.7068031072616577, correct: 7567, total: 10000, acc = 0.7566999793052673
epoch: 36, train loss: 0.6092992303322773
epoch: 37, train loss: 0.5922880838720166
epoch: 37, eval loss: 0.6735526144504547, correct: 7662, total: 10000, acc = 0.7662000060081482
epoch: 38, train loss: 0.5777627850065425
epoch: 39, train loss: 0.562178050376931
epoch: 39, eval loss: 0.6323211371898652, correct: 7799, total: 10000, acc = 0.7798999547958374
epoch: 40, train loss: 0.5385949274106901
epoch: 41, train loss: 0.5233490755971597
epoch: 41, eval loss: 0.6360922038555146, correct: 7806, total: 10000, acc = 0.7805999517440796
epoch: 42, train loss: 0.50960702373057
epoch: 43, train loss: 0.48859657985823496
epoch: 43, eval loss: 0.607847985625267, correct: 7914, total: 10000, acc = 0.7913999557495117
epoch: 44, train loss: 0.47382923291654006
epoch: 45, train loss: 0.45052725380780745
epoch: 45, eval loss: 0.5986941397190094, correct: 8012, total: 10000, acc = 0.8011999726295471
epoch: 46, train loss: 0.43711013392526277
epoch: 47, train loss: 0.42507915229213483
epoch: 47, eval loss: 0.5871582478284836, correct: 8002, total: 10000, acc = 0.8001999855041504
epoch: 48, train loss: 0.40591827947266246
epoch: 49, train loss: 0.3911267008100237
epoch: 49, eval loss: 0.5832945287227631, correct: 8047, total: 10000, acc = 0.8046999573707581
epoch: 50, train loss: 0.3770884950550235
epoch: 51, train loss: 0.3587312725733738
epoch: 51, eval loss: 0.5942261666059494, correct: 8073, total: 10000, acc = 0.8072999715805054
epoch: 52, train loss: 0.34132662324272856
epoch: 53, train loss: 0.3267737687850485
epoch: 53, eval loss: 0.5920912757515907, correct: 8118, total: 10000, acc = 0.8118000030517578
epoch: 54, train loss: 0.3116904997399875
epoch: 55, train loss: 0.30321489380938665
epoch: 55, eval loss: 0.5957943320274353, correct: 8082, total: 10000, acc = 0.8082000017166138
epoch: 56, train loss: 0.2874147834218278
epoch: 57, train loss: 0.27991348140093747
epoch: 57, eval loss: 0.5895262002944947, correct: 8165, total: 10000, acc = 0.8165000081062317
epoch: 58, train loss: 0.274563160173747
epoch: 59, train loss: 0.2600744918596988
epoch: 59, eval loss: 0.5934095367789268, correct: 8150, total: 10000, acc = 0.8149999976158142
finish training
TACC: Starting up job 3498509
TACC: Starting parallel tasks...
warning: variables which start with __, or are module or class declarations, are omitted
process rank 0 is bound to device 0
distributed environment is initialized
model is created
Files already downloaded and verified
Files already downloaded and verified
training and testing dataloaders are created
loss is created
optimizer is created
start training
warning: variables which start with __, or are module or class declarations, are omitted
process rank 2 is bound to device 2
Files already downloaded and verified
Files already downloaded and verified
warning: variables which start with __, or are module or class declarations, are omitted
process rank 3 is bound to device 3
Files already downloaded and verified
Files already downloaded and verified
warning: variables which start with __, or are module or class declarations, are omitted
process rank 1 is bound to device 1
Files already downloaded and verified
Files already downloaded and verified
epoch: 0, train loss: 2.107759721425115
epoch: 1, train loss: 1.8388929500871776
epoch: 1, eval loss: 1.7622965753078461, correct: 3535, total: 10000, acc = 0.35349997878074646
epoch: 2, train loss: 1.7141443588295762
epoch: 3, train loss: 1.6003259931291853
epoch: 3, eval loss: 1.608506625890732, correct: 4263, total: 10000, acc = 0.4262999892234802
epoch: 4, train loss: 1.5016733225511045
epoch: 5, train loss: 1.4050611877927974
epoch: 5, eval loss: 1.386299443244934, correct: 4984, total: 10000, acc = 0.4983999729156494
epoch: 6, train loss: 1.3264902623332278
epoch: 7, train loss: 1.2681689250225923
epoch: 7, eval loss: 1.3251740992069245, correct: 5295, total: 10000, acc = 0.5295000076293945
epoch: 8, train loss: 1.2236176984650748
epoch: 9, train loss: 1.172800781775494
epoch: 9, eval loss: 1.1429427027702332, correct: 5966, total: 10000, acc = 0.5965999960899353
epoch: 10, train loss: 1.1335287532027887
epoch: 11, train loss: 1.0974334563527788
epoch: 11, eval loss: 1.1024536848068238, correct: 6107, total: 10000, acc = 0.6107000112533569
epoch: 12, train loss: 1.0638826300903244
epoch: 13, train loss: 1.0406859383291127
epoch: 13, eval loss: 1.0324654281139374, correct: 6282, total: 10000, acc = 0.6281999945640564
epoch: 14, train loss: 1.0157714376644211
epoch: 15, train loss: 0.990898135365272
epoch: 15, eval loss: 0.9790050059556961, correct: 6539, total: 10000, acc = 0.6538999676704407
epoch: 16, train loss: 0.963820260398242
epoch: 17, train loss: 0.9404383374720203
epoch: 17, eval loss: 0.9367435872554779, correct: 6691, total: 10000, acc = 0.6690999865531921
epoch: 18, train loss: 0.9299906589546982
epoch: 19, train loss: 0.9038882474510037
epoch: 19, eval loss: 0.9210823565721512, correct: 6709, total: 10000, acc = 0.6708999872207642
epoch: 20, train loss: 0.8825302799137271
epoch: 21, train loss: 0.8686576388320144
epoch: 21, eval loss: 0.8791542768478393, correct: 6913, total: 10000, acc = 0.6912999749183655
epoch: 22, train loss: 0.8509396040926174
epoch: 23, train loss: 0.8375457452268017
epoch: 23, eval loss: 0.8651147484779358, correct: 6948, total: 10000, acc = 0.6947999596595764
epoch: 24, train loss: 0.8163802222329744
epoch: 25, train loss: 0.8068491317787949
epoch: 25, eval loss: 0.8353333532810211, correct: 7089, total: 10000, acc = 0.708899974822998
epoch: 26, train loss: 0.7894753631280393
epoch: 27, train loss: 0.7779296344640304
epoch: 27, eval loss: 0.8161472469568253, correct: 7143, total: 10000, acc = 0.7142999768257141
epoch: 28, train loss: 0.763744876092794
epoch: 29, train loss: 0.7521962505214068
epoch: 29, eval loss: 0.7903082758188248, correct: 7219, total: 10000, acc = 0.7218999862670898
epoch: 30, train loss: 0.7443178624522929
epoch: 31, train loss: 0.7280340212948468
epoch: 31, eval loss: 0.7877005040645599, correct: 7233, total: 10000, acc = 0.7232999801635742
epoch: 32, train loss: 0.7196985489251663
epoch: 33, train loss: 0.7108793039711154
epoch: 33, eval loss: 0.7838329076766968, correct: 7292, total: 10000, acc = 0.729200005531311
epoch: 34, train loss: 0.6965019471791326
epoch: 35, train loss: 0.6875918537986522
epoch: 35, eval loss: 0.7513678789138794, correct: 7392, total: 10000, acc = 0.7391999959945679
epoch: 36, train loss: 0.6793362346230721
epoch: 37, train loss: 0.6741023343436572
epoch: 37, eval loss: 0.7752945452928544, correct: 7316, total: 10000, acc = 0.7315999865531921
epoch: 38, train loss: 0.6629589072295597
epoch: 39, train loss: 0.6507086388918818
epoch: 39, eval loss: 0.7758691757917404, correct: 7322, total: 10000, acc = 0.7321999669075012
epoch: 40, train loss: 0.6381483582817778
epoch: 41, train loss: 0.6374095179596726
epoch: 41, eval loss: 0.7589699536561966, correct: 7386, total: 10000, acc = 0.738599956035614
epoch: 42, train loss: 0.6251792050137812
epoch: 43, train loss: 0.6148473596086308
epoch: 43, eval loss: 0.7495014071464539, correct: 7478, total: 10000, acc = 0.7477999925613403
epoch: 44, train loss: 0.6119371378908351
epoch: 45, train loss: 0.6012086509441843
epoch: 45, eval loss: 0.725347763299942, correct: 7515, total: 10000, acc = 0.7515000104904175
epoch: 46, train loss: 0.597867566103838
epoch: 47, train loss: 0.5913592832429069
epoch: 47, eval loss: 0.7254288077354432, correct: 7529, total: 10000, acc = 0.7529000043869019
epoch: 48, train loss: 0.5801522807807339
epoch: 49, train loss: 0.575563525666996
epoch: 49, eval loss: 0.7291093468666077, correct: 7533, total: 10000, acc = 0.7532999515533447
epoch: 50, train loss: 0.573031121674849
epoch: 51, train loss: 0.5667383588698446
epoch: 51, eval loss: 0.7240727603435516, correct: 7570, total: 10000, acc = 0.7569999694824219
epoch: 52, train loss: 0.5578772419569443
epoch: 53, train loss: 0.5526659309255834
epoch: 53, eval loss: 0.7226850330829621, correct: 7576, total: 10000, acc = 0.7576000094413757
epoch: 54, train loss: 0.5473246245968099
epoch: 55, train loss: 0.5443006860358375
epoch: 55, eval loss: 0.720612645149231, correct: 7596, total: 10000, acc = 0.7595999836921692
epoch: 56, train loss: 0.5361242987671677
epoch: 57, train loss: 0.5323515981435776
epoch: 57, eval loss: 0.7203025311231613, correct: 7580, total: 10000, acc = 0.7579999566078186
epoch: 58, train loss: 0.5297852404871766
epoch: 59, train loss: 0.5288004583241989
epoch: 59, eval loss: 0.7189624041318894, correct: 7605, total: 10000, acc = 0.7604999542236328
finish training
TACC: Starting up job 3496458
TACC: Starting parallel tasks...
warning: variables which start with __, or are module or class declarations, are omitted
process rank 0 is bound to device 0
distributed environment is initialized
model is created
warning: variables which start with __, or are module or class declarations, are omitted
process rank 3 is bound to device 3
Files already downloaded and verified
Files already downloaded and verified
warning: variables which start with __, or are module or class declarations, are omitted
process rank 2 is bound to device 2
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
training and testing dataloaders are created
loss is created
warning: variables which start with __, or are module or class declarations, are omitted
process rank 7 is bound to device 3
Files already downloaded and verified
Files already downloaded and verified
warning: variables which start with __, or are module or class declarations, are omitted
process rank 6 is bound to device 2
Files already downloaded and verified
Files already downloaded and verified
optimizer is created
start training
warning: variables which start with __, or are module or class declarations, are omitted
process rank 4 is bound to device 0
Files already downloaded and verified
Files already downloaded and verified
warning: variables which start with __, or are module or class declarations, are omitted
process rank 5 is bound to device 1
Files already downloaded and verified
Files already downloaded and verified
warning: variables which start with __, or are module or class declarations, are omitted
process rank 1 is bound to device 1
Files already downloaded and verified
Files already downloaded and verified
epoch: 0, train loss: 1.936693473738067
epoch: 1, train loss: 1.627108974116189
epoch: 1, eval loss: 1.5279120564460755, correct: 4576, total: 10000, acc = 0.4575999975204468
epoch: 2, train loss: 1.438910031805233
epoch: 3, train loss: 1.3184991053172521
epoch: 3, eval loss: 1.3557079970836639, correct: 5129, total: 10000, acc = 0.5128999948501587
epoch: 4, train loss: 1.271946340191121
epoch: 5, train loss: 1.2340542175331894
epoch: 5, eval loss: 1.207822185754776, correct: 5703, total: 10000, acc = 0.5702999830245972
epoch: 6, train loss: 1.187913371592152
epoch: 7, train loss: 1.154962458172623
epoch: 7, eval loss: 1.0685692846775054, correct: 6100, total: 10000, acc = 0.6100000143051147
epoch: 8, train loss: 1.1158924905621275
epoch: 9, train loss: 1.0909727805731249
epoch: 9, eval loss: 1.0345157146453858, correct: 6328, total: 10000, acc = 0.6327999830245972
epoch: 10, train loss: 1.0725988399009316
epoch: 11, train loss: 1.0453423085261364
epoch: 11, eval loss: 0.9778846323490142, correct: 6543, total: 10000, acc = 0.6542999744415283
epoch: 12, train loss: 1.0397504823548454
epoch: 13, train loss: 1.011059400986652
epoch: 13, eval loss: 0.9668682873249054, correct: 6446, total: 10000, acc = 0.644599974155426
epoch: 14, train loss: 0.9938353963044225
epoch: 15, train loss: 0.9691349967401854
epoch: 15, eval loss: 0.9465512812137604, correct: 6657, total: 10000, acc = 0.6656999588012695
epoch: 16, train loss: 0.9470896617490419
epoch: 17, train loss: 0.927201622602891
epoch: 17, eval loss: 0.8875106543302536, correct: 6837, total: 10000, acc = 0.6836999654769897
epoch: 18, train loss: 0.8975223132542202
epoch: 19, train loss: 0.8810242603019792
epoch: 19, eval loss: 0.8688296616077423, correct: 6832, total: 10000, acc = 0.6832000017166138
epoch: 20, train loss: 0.8482622784011218
epoch: 21, train loss: 0.8266285700457436
epoch: 21, eval loss: 0.7801274597644806, correct: 7205, total: 10000, acc = 0.7204999923706055
epoch: 22, train loss: 0.8038581859092323
epoch: 23, train loss: 0.7879118153027126
epoch: 23, eval loss: 0.7779350578784943, correct: 7203, total: 10000, acc = 0.7202999591827393
epoch: 24, train loss: 0.7542270896386127
epoch: 25, train loss: 0.7369782894241567
epoch: 25, eval loss: 0.7534965008497239, correct: 7362, total: 10000, acc = 0.7361999750137329
epoch: 26, train loss: 0.7095995545387268
epoch: 27, train loss: 0.6873777825005201
epoch: 27, eval loss: 0.7344318777322769, correct: 7381, total: 10000, acc = 0.738099992275238
epoch: 28, train loss: 0.6713967414534822
epoch: 29, train loss: 0.650338428969286
epoch: 29, eval loss: 0.677948921918869, correct: 7653, total: 10000, acc = 0.7652999758720398
epoch: 30, train loss: 0.6301205882004329
epoch: 31, train loss: 0.5990057824825754
epoch: 31, eval loss: 0.6719370454549789, correct: 7643, total: 10000, acc = 0.7642999887466431
epoch: 32, train loss: 0.590088236696866
epoch: 33, train loss: 0.5689327443132595
epoch: 33, eval loss: 0.6191721886396409, correct: 7807, total: 10000, acc = 0.7806999683380127
epoch: 34, train loss: 0.5426055670392756
epoch: 35, train loss: 0.5270413601276825
epoch: 35, eval loss: 0.6150132775306701, correct: 7879, total: 10000, acc = 0.7878999710083008
epoch: 36, train loss: 0.5215025428606539
epoch: 37, train loss: 0.4952395400222467
epoch: 37, eval loss: 0.628344652056694, correct: 7868, total: 10000, acc = 0.786799967288971
epoch: 38, train loss: 0.47989121687655545
epoch: 39, train loss: 0.46510300618045186
epoch: 39, eval loss: 0.5977057978510857, correct: 7944, total: 10000, acc = 0.7943999767303467
epoch: 40, train loss: 0.4441945254802704
epoch: 41, train loss: 0.4285763985648447
epoch: 41, eval loss: 0.5695438250899315, correct: 8023, total: 10000, acc = 0.802299976348877
epoch: 42, train loss: 0.41337763776584546
epoch: 43, train loss: 0.3940146170100387
epoch: 43, eval loss: 0.5688270673155784, correct: 8091, total: 10000, acc = 0.8090999722480774
epoch: 44, train loss: 0.37741332303504554
epoch: 45, train loss: 0.36565779605690313
epoch: 45, eval loss: 0.5831407308578491, correct: 8104, total: 10000, acc = 0.8104000091552734
epoch: 46, train loss: 0.3468657017362361
epoch: 47, train loss: 0.32949359198005834
epoch: 47, eval loss: 0.5751512110233307, correct: 8097, total: 10000, acc = 0.8096999526023865
epoch: 48, train loss: 0.3140165246262842
epoch: 49, train loss: 0.29480520498995877
epoch: 49, eval loss: 0.5712087765336037, correct: 8184, total: 10000, acc = 0.818399965763092
epoch: 50, train loss: 0.2766021394303867
epoch: 51, train loss: 0.26527753776433516
epoch: 51, eval loss: 0.5643855139613152, correct: 8218, total: 10000, acc = 0.8217999935150146
epoch: 52, train loss: 0.2525861115784061
epoch: 53, train loss: 0.23714738658496312
epoch: 53, eval loss: 0.5732526823878288, correct: 8249, total: 10000, acc = 0.8248999714851379
epoch: 54, train loss: 0.2238179413335664
epoch: 55, train loss: 0.2119908875652722
epoch: 55, eval loss: 0.5957901775836945, correct: 8261, total: 10000, acc = 0.8260999917984009
epoch: 56, train loss: 0.19989302222217833
epoch: 57, train loss: 0.1875186789096618
epoch: 57, eval loss: 0.5905491337180138, correct: 8290, total: 10000, acc = 0.8289999961853027
epoch: 58, train loss: 0.18436841180129926
epoch: 59, train loss: 0.17459663231762088
epoch: 59, eval loss: 0.589044263958931, correct: 8313, total: 10000, acc = 0.8312999606132507
finish training
TACC: Starting up job 3498327
TACC: Starting parallel tasks...
warning: variables which start with __, or are module or class declarations, are omitted
process rank 0 is bound to device 0
distributed environment is initialized
model is created
Files already downloaded and verified
Files already downloaded and verified
training and testing dataloaders are created
loss is created
optimizer is created
start training
warning: variables which start with __, or are module or class declarations, are omitted
process rank 2 is bound to device 2
Files already downloaded and verified
Files already downloaded and verified
warning: variables which start with __, or are module or class declarations, are omitted
process rank 3 is bound to device 3
Files already downloaded and verified
Files already downloaded and verified
warning: variables which start with __, or are module or class declarations, are omitted
process rank 4 is bound to device 0
Files already downloaded and verified
Files already downloaded and verified
warning: variables which start with __, or are module or class declarations, are omitted
process rank 5 is bound to device 1
Files already downloaded and verified
Files already downloaded and verified
warning: variables which start with __, or are module or class declarations, are omitted
process rank 7 is bound to device 3
Files already downloaded and verified
Files already downloaded and verified
warning: variables which start with __, or are module or class declarations, are omitted
process rank 6 is bound to device 2
Files already downloaded and verified
Files already downloaded and verified
warning: variables which start with __, or are module or class declarations, are omitted
process rank 1 is bound to device 1
Files already downloaded and verified
Files already downloaded and verified
epoch: 0, train loss: 2.1005014667705613
epoch: 1, train loss: 1.8539113086097094
epoch: 1, eval loss: 1.7973519027233125, correct: 3362, total: 10000, acc = 0.3361999988555908
epoch: 2, train loss: 1.7149482040989155
epoch: 3, train loss: 1.5927067617980801
epoch: 3, eval loss: 1.5848429083824158, correct: 4344, total: 10000, acc = 0.4343999922275543
epoch: 4, train loss: 1.4912729798531046
epoch: 5, train loss: 1.3957378158763962
epoch: 5, eval loss: 1.4951884388923644, correct: 4841, total: 10000, acc = 0.48409998416900635
epoch: 6, train loss: 1.3090402642074896
epoch: 7, train loss: 1.2566283296565621
epoch: 7, eval loss: 1.2464738070964814, correct: 5562, total: 10000, acc = 0.5561999678611755
epoch: 8, train loss: 1.2084139476017075
epoch: 9, train loss: 1.1706127719003327
epoch: 9, eval loss: 1.162048089504242, correct: 5876, total: 10000, acc = 0.5875999927520752
epoch: 10, train loss: 1.120817175933293
epoch: 11, train loss: 1.084984731309268
epoch: 11, eval loss: 1.0764922022819519, correct: 6155, total: 10000, acc = 0.6154999732971191
epoch: 12, train loss: 1.0559214432628787
epoch: 13, train loss: 1.0261321286765896
epoch: 13, eval loss: 1.0338306188583375, correct: 6334, total: 10000, acc = 0.6333999633789062
epoch: 14, train loss: 0.992842432187528
epoch: 15, train loss: 0.9660871296512837
epoch: 15, eval loss: 1.0059030145406722, correct: 6458, total: 10000, acc = 0.645799994468689
epoch: 16, train loss: 0.9467733100968965
epoch: 17, train loss: 0.9243187673237859
epoch: 17, eval loss: 0.9469569176435471, correct: 6610, total: 10000, acc = 0.6609999537467957
epoch: 18, train loss: 0.9059403721167116
epoch: 19, train loss: 0.8819177935318071
epoch: 19, eval loss: 0.9196836709976196, correct: 6727, total: 10000, acc = 0.6726999878883362
epoch: 20, train loss: 0.8721987532109631
epoch: 21, train loss: 0.8469706013494608
epoch: 21, eval loss: 0.8634845405817032, correct: 6976, total: 10000, acc = 0.6976000070571899
epoch: 22, train loss: 0.8352831839298716
epoch: 23, train loss: 0.8124590455269327
epoch: 23, eval loss: 0.8418784946203232, correct: 7034, total: 10000, acc = 0.7033999562263489
epoch: 24, train loss: 0.7961219853284408
epoch: 25, train loss: 0.7883704268202489
epoch: 25, eval loss: 0.8191130340099335, correct: 7116, total: 10000, acc = 0.7116000056266785
epoch: 26, train loss: 0.7733409623710477
epoch: 27, train loss: 0.7561956893424598
epoch: 27, eval loss: 0.8028618812561035, correct: 7200, total: 10000, acc = 0.7199999690055847
epoch: 28, train loss: 0.7479740460308231
epoch: 29, train loss: 0.7343520899208225
epoch: 29, eval loss: 0.7829996794462204, correct: 7256, total: 10000, acc = 0.725600004196167
epoch: 30, train loss: 0.7244430549290716
epoch: 31, train loss: 0.7121965617549663
epoch: 31, eval loss: 0.765428164601326, correct: 7299, total: 10000, acc = 0.7299000024795532
epoch: 32, train loss: 0.6988190838268825
epoch: 33, train loss: 0.6908610359746583
epoch: 33, eval loss: 0.7602580636739731, correct: 7395, total: 10000, acc = 0.7394999861717224
epoch: 34, train loss: 0.6785666395206841
epoch: 35, train loss: 0.6664504153387887
epoch: 35, eval loss: 0.7671193510293961, correct: 7345, total: 10000, acc = 0.734499990940094
epoch: 36, train loss: 0.6639333245705585
epoch: 37, train loss: 0.6509425913800999
epoch: 37, eval loss: 0.7612941324710846, correct: 7382, total: 10000, acc = 0.7382000088691711
epoch: 38, train loss: 0.6416311720196082
epoch: 39, train loss: 0.6312643265237614
epoch: 39, eval loss: 0.7380059510469437, correct: 7496, total: 10000, acc = 0.7495999932289124
epoch: 40, train loss: 0.620578939209179
epoch: 41, train loss: 0.6195461816933691
epoch: 41, eval loss: 0.7172901630401611, correct: 7550, total: 10000, acc = 0.7549999952316284
epoch: 42, train loss: 0.6013389248020795
epoch: 43, train loss: 0.6049416010477104
epoch: 43, eval loss: 0.7145429253578186, correct: 7569, total: 10000, acc = 0.7568999528884888
epoch: 44, train loss: 0.5950779300563189
epoch: 45, train loss: 0.5786038743598121
epoch: 45, eval loss: 0.7171747118234635, correct: 7569, total: 10000, acc = 0.7568999528884888
epoch: 46, train loss: 0.5752052083915594
epoch: 47, train loss: 0.5669339743195748
epoch: 47, eval loss: 0.7040806382894516, correct: 7601, total: 10000, acc = 0.7601000070571899
epoch: 48, train loss: 0.5596802952338238
epoch: 49, train loss: 0.5521421706189915
epoch: 49, eval loss: 0.7221358746290207, correct: 7592, total: 10000, acc = 0.7591999769210815
epoch: 50, train loss: 0.5504364164508119
epoch: 51, train loss: 0.5363630725412952
epoch: 51, eval loss: 0.710089972615242, correct: 7650, total: 10000, acc = 0.7649999856948853
epoch: 52, train loss: 0.5382009008709265
epoch: 53, train loss: 0.5292040118757559
epoch: 53, eval loss: 0.7044323921203614, correct: 7672, total: 10000, acc = 0.7671999931335449
epoch: 54, train loss: 0.5289747638970005
epoch: 55, train loss: 0.5239191630056926
epoch: 55, eval loss: 0.6983724802732467, correct: 7694, total: 10000, acc = 0.7694000005722046
epoch: 56, train loss: 0.5177402243930467
epoch: 57, train loss: 0.5132759012738053
epoch: 57, eval loss: 0.7066506981849671, correct: 7671, total: 10000, acc = 0.7670999765396118
epoch: 58, train loss: 0.5119742675095188
epoch: 59, train loss: 0.5074386891661858
epoch: 59, eval loss: 0.7012903690338135, correct: 7693, total: 10000, acc = 0.7692999839782715
finish training
from pathlib import Path

import pytest
import torch.autograd

import colossalai
from colossalai.builder import build_lr_scheduler
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.nn.layer._parallel_utilities import _gather

CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_2p5d.py')


def eval(engine, test_dataloader):
    engine.eval()
    accumulated_loss = 0
    correct_sum = 0
    total_sum = 0
    num_steps = len(test_dataloader)
    data_iter = iter(test_dataloader)

    for i in range(num_steps):
        output, label, loss = engine.step(data_iter)
        accumulated_loss += loss.detach().cpu().numpy()

        # reassemble the full logits: gather along the 2.5D row, column and
        # depth process groups before taking the argmax
        output = _gather(
            output[0],
            ParallelMode.PARALLEL_2P5D_ROW,
            1
        )
        output = _gather(
            output,
            ParallelMode.PARALLEL_2P5D_COL,
            0,
        )
        output = _gather(
            output,
            ParallelMode.PARALLEL_2P5D_DEP,
            0,
        )
        output = torch.argmax(output, dim=-1)
        correct = torch.sum(label[0] == output)
        correct_sum += correct
        total_sum += label[0].size(0)
    avg_loss = accumulated_loss / num_steps
    return correct_sum, total_sum, avg_loss


def train(engine, train_dataloader, lr_scheduler):
    engine.train()
    accumulated_loss = 0
    num_steps = len(train_dataloader)
    data_iter = iter(train_dataloader)

    for i in range(num_steps):
        output, label, loss = engine.step(data_iter)
        accumulated_loss += loss.detach().cpu().numpy()
    avg_loss = accumulated_loss / num_steps
    lr_scheduler.step()
    return avg_loss


@pytest.mark.dist
@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
def test_2p5d_parallel_vision_transformer():
    # init dist
    engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH)
    lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, engine.optimizer)
    logger = get_global_dist_logger()

    logger.info('start training')
    for epoch in range(gpc.config.num_epochs):
        train_loss = train(engine, train_dataloader, lr_scheduler)
        logger.info(f'epoch {epoch} - train loss: {train_loss}')

        if epoch % 2 == 0:
            correct_sum, total_sum, eval_loss = eval(engine, test_dataloader)
            logger.info(
                f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, '
                f'correct: {correct_sum}, acc: {correct_sum / total_sum}')


if __name__ == '__main__':
    test_2p5d_parallel_vision_transformer()
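
For orientation, the _gather calls above reassemble logits that are sharded across the 2.5D row, column and depth process groups before accuracy is computed. A conceptual sketch of such a gather with raw torch.distributed primitives (an assumption about the semantics only; the real _gather in colossalai.nn.layer._parallel_utilities has its own signature):

import torch
import torch.distributed as dist

def gather_along_dim(tensor, dim, group=None):
    # must run under an initialized process group; every rank contributes its shard
    world_size = dist.get_world_size(group)
    shards = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(shards, tensor.contiguous(), group=group)
    # concatenating along `dim` restores the unsharded tensor
    return torch.cat(shards, dim=dim)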
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import time
from pathlib import Path

import torch
from tqdm import tqdm

import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.trainer import Trainer
from colossalai.trainer.metric import Accuracy3D
from colossalai.utils import print_rank_0

CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_3d.py')


def _train_epoch(epoch, engine):
    logger = get_global_dist_logger()
    print_rank_0('[Epoch %d] training start' % (epoch), logger)
    engine.train()

    train_loss = 0
    batch_cnt = 0
    num_samples = 0
    now = time.time()
    epoch_start = now
    progress = range(engine._schedule.num_steps)
    if gpc.get_global_rank() == 0:
        progress = tqdm(progress, desc='[Epoch %d]' % epoch, miniters=1)
    for step in progress:
        cur_lr = engine.get_lr()
        _, targets, loss = engine.step()

        batch_size = targets[0].size(0)
        train_loss += loss.item()
        num_samples += batch_size
        batch_cnt += 1

        batch_time = time.time() - now
        now = time.time()
        if gpc.get_global_rank() == 0:
            print_features = dict(lr='%g' % cur_lr,
                                  loss='%.3f' % (train_loss / (step + 1)),
                                  throughput='%.3f (images/sec)' %
                                  (batch_size / (batch_time + 1e-12)))
            progress.set_postfix(**print_features)

    epoch_end = time.time()
    epoch_loss = train_loss / batch_cnt
    epoch_throughput = num_samples / (epoch_end - epoch_start + 1e-12)
    print_rank_0(
        '[Epoch %d] Loss: %.3f | Throughput: %.3f (samples/sec)' %
        (epoch, epoch_loss, epoch_throughput), logger)


def _eval(epoch, engine):
    logger = get_global_dist_logger()
    engine.eval()

    eval_loss = 0
    acc = Accuracy3D(True, ParallelMode.PARALLEL_3D_OUTPUT,
                     ParallelMode.PARALLEL_3D_WEIGHT)
    total = 0
    with torch.no_grad():
        for _ in range(engine._schedule.num_steps):
            outputs, targets, loss = engine.step()
            if isinstance(outputs, (list, tuple)):
                outputs = outputs[0]
            if isinstance(targets, (list, tuple)):
                targets = targets[0]
            eval_loss += loss.item()
            acc.update(outputs, targets)
            total += targets.size(0)

    print_rank_0(
        '[Epoch %d] Evaluation loss: %.3f | Acc: %.3f%%' %
        (epoch, eval_loss / engine._schedule.num_steps,
         acc.get_accumulated_value() * 100), logger)


def train():
    # init dist
    engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH)
    logger = get_global_dist_logger()

    logger.info("Engine is built", ranks=[0])

    trainer = Trainer(engine=engine, verbose=True)
    logger.info("Trainer is built", ranks=[0])

    logger.info("Train start", ranks=[0])
    trainer.fit(train_dataloader=train_dataloader,
                test_dataloader=test_dataloader,
                epochs=gpc.config.num_epochs,
                hooks_cfg=gpc.config.hooks,
                display_progress=True,
                test_interval=1)


if __name__ == '__main__':
    train()
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from pathlib import Path

import pytest
import torch

from colossalai.builder import build_model
from colossalai.context import Config

CONFIG_PATH = Path(__file__).parent.joinpath('configs/vanilla_vit.py')


@pytest.mark.cpu
def test_with_vanilla_vit_config():
    config = Config.from_file(CONFIG_PATH)
    model = build_model(config.model)
    model.build_from_cfg()

    img = torch.randn(1, 3, config.IMG_SIZE, config.IMG_SIZE)
    out = model(img)
    loss = out.mean()
    loss.backward()


if __name__ == '__main__':
    test_with_vanilla_vit_config()
import os
from pathlib import Path

BATCH_SIZE = 128
IMG_SIZE = 32
num_epochs = 200

# resnet 50
model = dict(
    type='VanillaResNet',
    block_type='ResNetBottleneck',
    layers=[3, 4, 6, 3],
    num_cls=10
)

train_data = dict(
    dataset=dict(
        type='CIFAR10Dataset',
        root=Path(os.environ['DATA']),
        transform_pipeline=[
            dict(type='Resize', size=IMG_SIZE),
            dict(type='RandomCrop', size=IMG_SIZE, padding=4),
            dict(type='RandomHorizontalFlip'),
            dict(type='ToTensor'),
            dict(type='Normalize',
                 mean=[0.4914, 0.4822, 0.4465],
                 std=[0.2023, 0.1994, 0.2010]),
        ]
    ),
    dataloader=dict(
        batch_size=BATCH_SIZE,
        pin_memory=True,
        num_workers=4,
        shuffle=True
    )
)

test_data = dict(
    dataset=dict(
        type='CIFAR10Dataset',
        root=Path(os.environ['DATA']),
        train=False,
        transform_pipeline=[
            dict(type='Resize', size=IMG_SIZE),
            dict(type='ToTensor'),
            dict(type='Normalize',
                 mean=[0.4914, 0.4822, 0.4465],
                 std=[0.2023, 0.1994, 0.2010]
                 ),
        ]
    ),
    dataloader=dict(
        batch_size=BATCH_SIZE,
        pin_memory=True,
        num_workers=4,
        shuffle=True
    )
)

optimizer = dict(
    type='SGD',
    lr=0.2,
    momentum=0.9,
    weight_decay=5e-4
)

loss = dict(
    type='CrossEntropyLoss',
)

parallel = dict(
    pipeline=dict(size=1),
    tensor=dict(size=1, mode=None),
)

hooks = [
    dict(type='LogMetricByEpochHook'),
@@ -88,4 +17,3 @@ hooks = [
    ),
    dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'),
]
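
Configs like the one above are plain dicts whose 'type' key names a registered class, with the remaining keys passed to its constructor. A minimal sketch of that registry/builder pattern (the REGISTRY, register and build_from_cfg names here are hypothetical; the real builders live in colossalai.builder):

REGISTRY = {}

def register(cls):
    # index the class by its name so configs can refer to it as a string
    REGISTRY[cls.__name__] = cls
    return cls

def build_from_cfg(cfg):
    cfg = dict(cfg)  # copy so the caller's config dict is not mutated
    cls = REGISTRY[cfg.pop('type')]
    return cls(**cfg)  # remaining keys become constructor kwargs

@register
class CrossEntropyLoss:
    def __init__(self, **kwargs):
        self.kwargs = kwargs

loss_fn = build_from_cfg(dict(type='CrossEntropyLoss'))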
#!/usr/bin/env sh

test_file=$1
config_file=$2

-python $test_file --local_rank $SLURM_PROCID --world_size $SLURM_NPROCS --host $HOST --port 29500 --config $config_file
+python $test_file --rank $SLURM_PROCID --world_size $SLURM_NPROCS --host $HOST --port 29500
@@ -13,7 +13,7 @@ from colossalai.communication import (recv_backward, recv_forward,
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.initialize import init_dist, parse_args
-from colossalai.logging import get_global_dist_logger
+from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device

 BATCH_SIZE = 32

@@ -128,7 +128,7 @@ def test_main():
     world_size = args.world_size
     init_dist(CONFIG)
-    logger = get_global_dist_logger()
+    logger = get_dist_logger()
     rank = gpc.get_global_rank()
     prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
     up_ranks = gpc.get_ranks_in_group(ParallelMode.PIPELINE_PREV)

@@ -7,7 +7,7 @@ from torch.utils.data import DataLoader
 from colossalai.builder import build_dataset, ModelInitializer
 from colossalai.core import global_context
 from colossalai.initialize import init_dist
-from colossalai.logging import get_global_dist_logger
+from colossalai.logging import get_dist_logger

 DIR_PATH = osp.dirname(osp.realpath(__file__))
 CONFIG_PATH = osp.join(DIR_PATH, '../configs/pipeline_vanilla_resnet.py')

@@ -17,7 +17,7 @@ CONFIG_PATH = osp.join(DIR_PATH, '../configs/pipeline_vanilla_resnet.py')
 @pytest.mark.dist
 def test_partition():
     init_dist(CONFIG_PATH)
-    logger = get_global_dist_logger()
+    logger = get_dist_logger()
     logger.info('finished initialization')
     # build model

@@ -8,7 +8,7 @@ import pytest
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.initialize import initialize
-from colossalai.logging import get_global_dist_logger
+from colossalai.logging import get_dist_logger

 NUM_BATCH = 128

@@ -24,7 +24,7 @@ CONFIG_PATH = osp.join(DIR_PATH, '../configs/pipeline_vanilla_resnet.py')
 @pytest.mark.dist
 def test_schedule():
     engine, train_dataloader, test_dataloader = initialize(CONFIG_PATH)
-    logger = get_global_dist_logger()
+    logger = get_dist_logger()
     model = engine.model
     optimizer = engine.optimizer
import colossalai
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.trainer import Trainer


def test_trainer():
    engine, train_dataloader, test_dataloader = colossalai.initialize()
    logger = get_global_dist_logger()

    logger.info("engine is built", ranks=[0])

    trainer = Trainer(engine=engine,
                      verbose=True)
    logger.info("trainer is built", ranks=[0])

    logger.info("start training", ranks=[0])
    trainer.fit(
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
        hooks_cfg=gpc.config.hooks,
        epochs=gpc.config.num_epochs,
        display_progress=False,
        test_interval=5
    )


if __name__ == '__main__':
    test_trainer()
import colossalai
import os
from colossalai.amp.amp_type import AMP_TYPE
import torch.nn as nn
from pathlib import Path
from torchvision import transforms
from torch.optim import Adam
from colossalai.initialize import get_default_parser
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.trainer import Trainer
from colossalai.utils import get_dataloader
from torchvision.models import resnet18
from torchvision.datasets import CIFAR10

BATCH_SIZE = 128
IMG_SIZE = 32
NUM_EPOCHS = 200

CONFIG = dict(
    # Config
    fp16=dict(
        mode=AMP_TYPE.TORCH
    )
)


def test_trainer():
    parser = get_default_parser()
    args = parser.parse_args()
    colossalai.launch(
        config=CONFIG,
        rank=args.rank,
        world_size=args.world_size,
        host=args.host,
        port=args.port,
        backend=args.backend
    )

    # build model
    model = resnet18(num_classes=10)

    # build dataloaders
    train_dataset = CIFAR10(
        root=Path(os.environ['DATA']),
        download=True,
        transform=transforms.Compose(
            [
                transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
            ]
        )
    )

    test_dataset = CIFAR10(
        root=Path(os.environ['DATA']),
        train=False,
        download=True,
        transform=transforms.Compose(
            [
                transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
            ]
        )
    )

    train_dataloader = get_dataloader(dataset=train_dataset,
                                      shuffle=True,
                                      batch_size=BATCH_SIZE,
                                      num_workers=1,
                                      pin_memory=True,
                                      drop_last=True)

    test_dataloader = get_dataloader(dataset=test_dataset,
                                     batch_size=BATCH_SIZE,
                                     num_workers=1,
                                     pin_memory=True,
                                     drop_last=True)

    # build optimizer
    optimizer = Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    engine, train_dataloader, *args = colossalai.initialize(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        train_dataloader=train_dataloader
    )

    logger = get_dist_logger()
    logger.info("engine is built", ranks=[0])

    trainer = Trainer(engine=engine,
                      logger=logger)
    logger.info("trainer is built", ranks=[0])

    logger.info("start training", ranks=[0])
    trainer.fit(
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
        epochs=NUM_EPOCHS,
        max_steps=100,
        display_progress=True,
        test_interval=5
    )


if __name__ == '__main__':
    test_trainer()
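
In the CONFIG above, fp16=dict(mode=AMP_TYPE.TORCH) selects PyTorch's native automatic mixed precision. Roughly, the engine's forward/backward/step then behaves like the plain torch.cuda.amp loop below (an illustrative sketch assuming a CUDA device is available, not the engine's actual internals):

import torch

model = torch.nn.Linear(10, 2).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(8, 10, device='cuda')
y = torch.randint(0, 2, (8,), device='cuda')

with torch.cuda.amp.autocast():  # run the forward pass in mixed precision
    loss = criterion(model(x), y)
scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)         # unscales gradients, then applies the update
scaler.update()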
import colossalai
import os
import torch
from colossalai.amp.amp_type import AMP_TYPE
from colossalai.context.parallel_mode import ParallelMode
import torch.nn as nn
from pathlib import Path
from torchvision import transforms
from torch.optim import Adam
from colossalai.initialize import get_default_parser
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.trainer import Trainer
from colossalai.utils import get_dataloader
from colossalai.engine.schedule import PipelineSchedule
from torchvision.models import resnet18
from torchvision.datasets import CIFAR10

BATCH_SIZE = 32
IMG_SIZE = 32
NUM_EPOCHS = 200

CONFIG = dict(
    parallel=dict(
        pipeline=2,
    ),
    # Config
    fp16=dict(
        mode=AMP_TYPE.TORCH
    )
)


def test_trainer():
    parser = get_default_parser()
    args = parser.parse_args()
    colossalai.launch(
        config=CONFIG,
        rank=args.rank,
        world_size=args.world_size,
        host=args.host,
        port=args.port,
        backend=args.backend
    )

    # build model: each pipeline rank keeps only its half of resnet18
    model = resnet18(num_classes=10)

    if gpc.get_local_rank(ParallelMode.PIPELINE) == 0:
        model = nn.Sequential(
            model.conv1,
            model.bn1,
            model.relu,
            model.maxpool,
            model.layer1,
            model.layer2
        )
    elif gpc.get_local_rank(ParallelMode.PIPELINE) == 1:
        from functools import partial

        class Flatten(nn.Module):
            def forward(self, x):
                return torch.flatten(x, 1)

        model = nn.Sequential(
            model.layer3,
            model.layer4,
            model.avgpool,
            Flatten(),
            model.fc
        )

    # build dataloaders
    train_dataset = CIFAR10(
        root=Path(os.environ['DATA']),
        download=True,
        transform=transforms.Compose(
            [
                transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
            ]
        )
    )

    test_dataset = CIFAR10(
        root=Path(os.environ['DATA']),
        train=False,
        download=True,
        transform=transforms.Compose(
            [
                transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
            ]
        )
    )

    train_dataloader = get_dataloader(dataset=train_dataset,
                                      shuffle=True,
                                      batch_size=BATCH_SIZE,
                                      num_workers=1,
                                      pin_memory=True,
                                      drop_last=True)

    test_dataloader = get_dataloader(dataset=test_dataset,
                                     batch_size=BATCH_SIZE,
                                     num_workers=1,
                                     pin_memory=True,
                                     drop_last=True)

    # build optimizer
    optimizer = Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    engine, train_dataloader, *args = colossalai.initialize(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        train_dataloader=train_dataloader
    )

    logger = get_dist_logger()
    logger.info("engine is built", ranks=[0])

    pipe_schedule = PipelineSchedule(num_microbatches=4)
    trainer = Trainer(engine=engine,
                      schedule=pipe_schedule,
                      logger=logger)
    logger.info("trainer is built", ranks=[0])

    logger.info("start training", ranks=[0])
    trainer.fit(
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
        epochs=NUM_EPOCHS,
        max_steps=100,
        display_progress=True,
        test_interval=5
    )


if __name__ == '__main__':
    test_trainer()
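
PipelineSchedule(num_microbatches=4) above means each global batch of BATCH_SIZE = 32 is split into four micro-batches that are fed through the two stages in turn (the schedule also interleaves their forward and backward passes). The splitting step itself is conceptually just:

import torch

BATCH_SIZE, NUM_MICROBATCHES = 32, 4
batch = torch.randn(BATCH_SIZE, 3, 32, 32)
microbatches = torch.chunk(batch, NUM_MICROBATCHES, dim=0)
assert all(mb.size(0) == BATCH_SIZE // NUM_MICROBATCHES for mb in microbatches)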
import colossalai
import os
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from functools import partial
from pathlib import Path
from torchvision import transforms
from torch.optim import Adam
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import report_memory_usage, get_dataloader
from colossalai.initialize import get_default_parser
from torchvision.models import resnet18
from torchvision.datasets import CIFAR10

# Config
BATCH_SIZE = 16
IMG_SIZE = 224
NUM_CLASSES = 10

CONFIG = dict(
    parallel=dict(
        pipeline=dict(size=1),
        tensor=dict(size=1, mode=None)
    ),
    clip_grad_norm=1.0,
    gradient_accumulation=4
)


def run_no_pipeline(rank, world_size):
    # init dist env
    colossalai.launch(
        config=CONFIG,
        rank=rank,
        world_size=world_size,
        host='localhost',
        port=29500,
        backend='nccl'
    )

    # build model
    model = resnet18(num_classes=10)

    # build dataloaders
    train_dataset = CIFAR10(
        root=Path(os.environ['DATA']),
        download=True,
        transform=transforms.Compose(
            [
                transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
            ]
        )
    )

    train_dataloader = get_dataloader(dataset=train_dataset,
                                      shuffle=True,
                                      batch_size=BATCH_SIZE,
                                      pin_memory=True,
                                      drop_last=True)

    # build optimizer
    optimizer = Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    engine, train_dataloader, *args = colossalai.initialize(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        train_dataloader=train_dataloader
    )
    logger = get_dist_logger()
    rank = torch.distributed.get_rank()
    param_track = []
    grad_track = []
    next(model.parameters()).retain_grad()

    engine.train()
    step = 0
    for img, label in train_dataloader:
        engine.zero_grad()
        img = img.cuda()
        label = label.cuda()
        output = engine(img)
        loss = engine.criterion(output, label)
        engine.backward(loss)
        engine.step()

        # check
        param_track.append(next(model.parameters())[0].clone())
        grad_track.append(next(model.parameters()).grad[0].clone())
        step += 1
        if step == CONFIG['gradient_accumulation']:
            break

    assert not torch.all(grad_track[0] == grad_track[-1]), \
        'grad should be different in different iterations'
    assert torch.all(param_track[0] == param_track[1]) and not torch.all(param_track[0] == param_track[-1]), \
        'param should be the same in the first few iterations and only changed in the last iteration'

    gpc.destroy()


@pytest.mark.skip("This test should be invoked using the test.sh provided")
@pytest.mark.dist
def test_engine():
    func = partial(run_no_pipeline, world_size=4)
    mp.spawn(func, nprocs=4)


if __name__ == '__main__':
    test_engine()
@@ -2,90 +2,3 @@
 # -*- encoding: utf-8 -*-

import os
from pathlib import Path

BATCH_SIZE = 128
IMG_SIZE = 224
NUM_CLS = 1000

# resnet 50
model = dict(
    type='VanillaResNet',
    block_type='ResNetBottleneck',
    layers=[3, 4, 6, 3],
    num_cls=NUM_CLS
)

train_data = dict(
    dataset=dict(
        type='CIFAR10Dataset',
        root=Path(os.environ['DATA']),
        transform_pipeline=[
            dict(type='RandomResizedCrop', size=IMG_SIZE),
            dict(type='RandomHorizontalFlip'),
            dict(type='ToTensor'),
            dict(type='Normalize', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ]
    ),
    dataloader=dict(
        batch_size=64,
        pin_memory=True,
        num_workers=4,
        sampler=dict(
            type='DataParallelSampler',
            shuffle=True,
        )
    )
)

test_data = dict(
    dataset=dict(
        type='CIFAR10Dataset',
        root=Path(os.environ['DATA']),
        train=False,
        transform_pipeline=[
            dict(type='Resize', size=(IMG_SIZE, IMG_SIZE)),
            dict(type='ToTensor'),
            dict(type='Normalize', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ]
    ),
    dataloader=dict(
        batch_size=BATCH_SIZE,
        pin_memory=True,
        num_workers=4,
    )
)

dist_initializer = [
    dict(type='DataParallelInitializer'),
]

parallelization = dict(
    pipeline=1,
    tensor=1,
    sequence=-1
)

optimizer = dict(
    type='Adam',
    lr=0.01
)

loss = dict(
    type='CrossEntropyLoss'
)

trainer = dict(
    max_epochs=5,
    max_iters=1000
)

amp = dict(
    fp16=None,
)

level = 2

parallel = dict(
    pipeline=dict(size=1),
    tensor=dict(size=1, mode=None)
)
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import os.path as osp
import os
import pytest
import torch
from torch.utils.data import DataLoader

import colossalai
from colossalai.builder import build_dataset, build_loss, build_data_sampler, build_model
from colossalai.core import global_context
from colossalai.engine.gradient_handler import DataParallelGradientHandler
from colossalai.nn.optimizer import ZeroRedundancyOptimizer_Level_1, ZeroRedundancyOptimizer_Level_3, \
    ZeroRedundancyOptimizer_Level_2
from colossalai.utils import print_rank_0

DIR_PATH = osp.dirname(osp.abspath(__file__))
CONFIG_PATH = osp.join(DIR_PATH, 'config.py')


def run_dist():
    colossalai.init_dist(CONFIG_PATH)

    # build resnet model
    model = build_model(global_context.config.model)
    model.build_from_cfg()
    model = model.cuda()
    level = global_context.config.level
    if level > 1:
        model = model.half()

    # test init cuda memory
    _ = torch.rand(1).cuda()
    torch.cuda.synchronize()
    max_alloc = torch.cuda.max_memory_allocated()
    max_reserved = torch.cuda.max_memory_reserved()
    print(f'before run: max_allocation = {max_alloc}, max_reserved = {max_reserved}')

    # build dataloader
    train_dataset = build_dataset(global_context.config.train_data.dataset)
    from pathlib import Path
    sampler_cfg = global_context.config.train_data.dataloader.pop('sampler', None)
    if sampler_cfg is None:
        train_dataloader = DataLoader(dataset=train_dataset, **global_context.config.train_data.dataloader)
    else:
        sampler = build_data_sampler(sampler_cfg, train_dataset)
        train_dataloader = DataLoader(dataset=train_dataset, sampler=sampler,
                                      **global_context.config.train_data.dataloader)
    test_dataset = build_dataset(global_context.config.test_data.dataset)
    test_dataloader = DataLoader(dataset=test_dataset, **global_context.config.test_data.dataloader)

    # build optimizer and loss
    # optimizer = build_optimizer(global_context.config.optimizer, model)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    if level == 1:
        zero_optim = ZeroRedundancyOptimizer_Level_1(init_optimizer=optimizer, verbose=False)
    elif level == 2:
        zero_optim = ZeroRedundancyOptimizer_Level_2(init_optimizer=optimizer, cpu_offload=True, verbose=False)
    elif level == 3:
        zero_optim = ZeroRedundancyOptimizer_Level_3(init_optimizer=optimizer,
                                                     module=model,
 import colossalai
+from colossalai.initialize import get_default_parser
 from colossalai.core import global_context as gpc
+from colossalai.utils import get_dataloader
+from torchvision import transforms
+from torchvision.models import resnet18
+from torchvision.datasets import CIFAR10

 BATCH_SIZE = 128
 IMG_SIZE = 224
 NUM_CLS = 1000

+CONFIG = dict(
+    fp16=dict(
+        mode=None,
+    ),
+    zero=dict(
+        # ==============
+        # level 2 config
+        # ==============
+        # level=2,
+        # cpu_offload=True,
+        # verbose=False,
+        # ==============
+        # level 3 config
+        # ==============
+        level=3,
+        verbose=False,
+        offload_optimizer_config=dict(
+            device='cpu',
@@ -77,70 +49,70 @@ def run_dist():
+            buffer_size=1e8,
+            max_in_cpu=1e9
+        )
+    ),
+    parallel=dict(
+        pipeline=dict(size=1),
+        tensor=dict(size=1, mode=None)
+    )
+)
-    loss_fn = build_loss(global_context.config.loss)
-    gradient_handler = DataParallelGradientHandler(model, zero_optim)
-
-    # train
-    for epoch in range(100):
-        model.train()
-        # train
-        avg_train_loss = 0
-        train_iter = 0
-        for idx, (data, label) in enumerate(train_dataloader):
-            # model = model.half()
-            data = data[0].cuda()
-            label = label[0].cuda()
-            if level > 1:
-                data = data.half()
-            output = model(data)
-            loss = loss_fn(output[0], label)
-            if level > 1:
-                zero_optim.backward(loss)
-                zero_optim.overlapping_partition_gradients_reduce_epilogue()
-            else:
-                loss.backward()
-                gradient_handler.handle_gradient()
-            zero_optim.step()
-            zero_optim.zero_grad()
-            avg_train_loss += loss.detach().cpu().numpy()
-            train_iter += 1
-        print_rank_0(f'epoch: {epoch}, train loss: {avg_train_loss / train_iter}')
-        if epoch % 2 == 0:
-            model.eval()
-            avg_eval_loss = 0
-            correct = 0
-            total = 0
-            eval_iters = 0
+def run_dist():
+    parser = get_default_parser()
+    args = parser.parse_args()
+    colossalai.launch(config=CONFIG,
+                      rank=args.rank,
+                      world_size=args.world_size,
+                      host=args.host,
+                      port=args.port,
+                      backend=args.backend)
+
+    # build model
+    model = resnet18(num_classes=10)
+
-    # build dataloader
+    # build dataloaders
+    train_dataset = CIFAR10(
+        root=Path(os.environ['DATA']),
+        download=True,
+        transform=transforms.Compose(
+            [
+                transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
+                transforms.ToTensor(),
+                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+            ]
+        )
+    )
+
+    train_dataloader = get_dataloader(dataset=train_dataset,
+                                      shuffle=True,
+                                      batch_size=BATCH_SIZE,
+                                      num_workers=1,
+                                      pin_memory=True,
+                                      drop_last=True)
-            for idx, (data, label) in enumerate(test_dataloader):
-                with torch.no_grad():
-                    data = data[0].cuda()
-                    label = label[0].cuda()
+    # build optimizer and loss
+    # optimizer = build_optimizer(global_context.config.optimizer, model)
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+    criterion = torch.nn.CrossEntropyLoss()
-                    if level > 1:
-                        data = data.half()
+    engine, train_dataloader, *args = colossalai.initialize(model=model,
+                                                            optimizer=optimizer,
+                                                            criterion=criterion,
+                                                            train_dataloader=train_dataloader)
-                    output = model(data)
-                    loss = loss_fn(output[0], label)
+    # train
+    model.train()
+    for idx, (data, label) in enumerate(train_dataloader):
+        engine.zero_grad()
+        data = data.cuda()
+        label = label.cuda()
-                    avg_eval_loss += loss.detach().cpu().numpy()
-                    preds = torch.argmax(output[0], dim=1)
-                    total += data.size(0)
-                    correct += sum(preds == label)
-                    eval_iters += 1
+        output = engine(data)
+        loss = engine.criterion(output, label)
-            print_rank_0(f'epoch: {epoch}, eval loss: {avg_eval_loss / eval_iters}, acc: {correct / total}')
+        engine.backward(loss)
+        engine.step()
+        break

+@pytest.mark.skip("This test should be invoked manually using the script provided")
#!/bin/bash

test_file="test_zero.py"

-python $test_file --local_rank $SLURM_PROCID --world_size $SLURM_NPROCS --host $HOST --port 29500
\ No newline at end of file
+python $test_file --rank $SLURM_PROCID --world_size $SLURM_NPROCS --host $HOST --port 29500
\ No newline at end of file