Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
8b3c0f35
Commit
8b3c0f35
authored
Nov 10, 2021
by
Tim Dettmers
Browse files
Added adagrad with tests (no clipping).
parent
22b2877c
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
105 additions
and
2 deletions
+105
-2
bitsandbytes/functional.py
bitsandbytes/functional.py
+2
-0
bitsandbytes/optim/__init__.py
bitsandbytes/optim/__init__.py
+1
-0
bitsandbytes/optim/adagrad.py
bitsandbytes/optim/adagrad.py
+57
-0
csrc/kernels.cu
csrc/kernels.cu
+22
-0
csrc/ops.cu
csrc/ops.cu
+8
-0
csrc/ops.cuh
csrc/ops.cuh
+1
-0
csrc/pythonInterface.c
csrc/pythonInterface.c
+8
-0
tests/test_optim.py
tests/test_optim.py
+6
-2
No files found.
bitsandbytes/functional.py
View file @
8b3c0f35
...
@@ -19,6 +19,7 @@ str2optimizer32bit = {}
...
@@ -19,6 +19,7 @@ str2optimizer32bit = {}
str2optimizer32bit
[
'adam'
]
=
(
lib
.
cadam32bit_g32
,
lib
.
cadam32bit_g16
)
str2optimizer32bit
[
'adam'
]
=
(
lib
.
cadam32bit_g32
,
lib
.
cadam32bit_g16
)
str2optimizer32bit
[
'momentum'
]
=
(
lib
.
cmomentum32bit_g32
,
lib
.
cmomentum32bit_g16
)
str2optimizer32bit
[
'momentum'
]
=
(
lib
.
cmomentum32bit_g32
,
lib
.
cmomentum32bit_g16
)
str2optimizer32bit
[
'rmsprop'
]
=
(
lib
.
crmsprop32bit_g32
,
lib
.
crmsprop32bit_g16
)
str2optimizer32bit
[
'rmsprop'
]
=
(
lib
.
crmsprop32bit_g32
,
lib
.
crmsprop32bit_g16
)
str2optimizer32bit
[
'adagrad'
]
=
(
lib
.
cadagrad32bit_g32
,
lib
.
cadagrad32bit_g16
)
str2optimizer32bit
[
'lars'
]
=
(
lib
.
cmomentum32bit_g32
,
lib
.
cmomentum32bit_g16
)
str2optimizer32bit
[
'lars'
]
=
(
lib
.
cmomentum32bit_g32
,
lib
.
cmomentum32bit_g16
)
str2optimizer32bit
[
'lamb'
]
=
(
lib
.
cadam32bit_g32
,
lib
.
cadam32bit_g16
)
str2optimizer32bit
[
'lamb'
]
=
(
lib
.
cadam32bit_g32
,
lib
.
cadam32bit_g16
)
...
@@ -33,6 +34,7 @@ str2optimizer8bit_blockwise = {}
...
@@ -33,6 +34,7 @@ str2optimizer8bit_blockwise = {}
str2optimizer8bit_blockwise
[
'adam'
]
=
(
lib
.
cadam_8bit_blockwise_fp32
,
lib
.
cadam_8bit_blockwise_fp16
)
str2optimizer8bit_blockwise
[
'adam'
]
=
(
lib
.
cadam_8bit_blockwise_fp32
,
lib
.
cadam_8bit_blockwise_fp16
)
str2optimizer8bit_blockwise
[
'momentum'
]
=
(
lib
.
cmomentum_8bit_blockwise_fp32
,
lib
.
cmomentum_8bit_blockwise_fp16
)
str2optimizer8bit_blockwise
[
'momentum'
]
=
(
lib
.
cmomentum_8bit_blockwise_fp32
,
lib
.
cmomentum_8bit_blockwise_fp16
)
str2optimizer8bit_blockwise
[
'rmsprop'
]
=
(
lib
.
crmsprop_8bit_blockwise_fp32
,
lib
.
crmsprop_8bit_blockwise_fp16
)
str2optimizer8bit_blockwise
[
'rmsprop'
]
=
(
lib
.
crmsprop_8bit_blockwise_fp32
,
lib
.
crmsprop_8bit_blockwise_fp16
)
str2optimizer8bit_blockwise
[
'adagrad'
]
=
(
lib
.
cadagrad_8bit_blockwise_fp32
,
lib
.
cadagrad_8bit_blockwise_fp16
)
optimal_normal
=
[
-
0.9939730167388916
,
-
0.8727636337280273
,
-
0.8097418546676636
,
-
0.7660024166107178
,
-
0.7318882346153259
,
-
0.6793879270553589
,
-
0.657649040222168
,
-
0.6385974884033203
,
-
0.6211113333702087
,
-
0.5901028513908386
,
-
0.5762918591499329
,
-
0.5630806684494019
,
-
0.5509274005889893
,
-
0.5394591689109802
,
-
0.5283197164535522
,
-
0.517780065536499
,
-
0.5074946284294128
,
-
0.4980469048023224
,
-
0.48867011070251465
,
-
0.48003149032592773
,
-
0.47125306725502014
,
-
0.4629971981048584
,
-
0.4547359049320221
,
-
0.446626216173172
,
-
0.43902668356895447
,
-
0.43158355355262756
,
-
0.4244747757911682
,
-
0.4173796474933624
,
-
0.41038978099823
,
-
0.4055633544921875
,
-
0.4035947024822235
,
-
0.39701032638549805
,
-
0.39057496190071106
,
-
0.38439232110977173
,
-
0.3782760500907898
,
-
0.3721940815448761
,
-
0.3661896586418152
,
-
0.3604033589363098
,
-
0.354605108499527
,
-
0.34892538189888
,
-
0.34320303797721863
,
-
0.3376772701740265
,
-
0.3323028087615967
,
-
0.3269782066345215
,
-
0.32166096568107605
,
-
0.316457599401474
,
-
0.3112771809101105
,
-
0.3061025142669678
,
-
0.30106794834136963
,
-
0.2961243987083435
,
-
0.2912728488445282
,
-
0.28644347190856934
,
-
0.28165507316589355
,
-
0.2769731283187866
,
-
0.2722635865211487
,
-
0.26779335737228394
,
-
0.26314786076545715
,
-
0.2586647868156433
,
-
0.2541804611682892
,
-
0.2496625930070877
,
-
0.24527113139629364
,
-
0.24097171425819397
,
-
0.23659978806972504
,
-
0.23218469321727753
,
-
0.22799566388130188
,
-
0.22380566596984863
,
-
0.21965542435646057
,
-
0.2154538631439209
,
-
0.2113603949546814
,
-
0.20735277235507965
,
-
0.20334717631340027
,
-
0.19932441413402557
,
-
0.19530178606510162
,
-
0.19136647880077362
,
-
0.18736697733402252
,
-
0.18337111175060272
,
-
0.17951400578022003
,
-
0.1757056713104248
,
-
0.17182783782482147
,
-
0.1680615097284317
,
-
0.16431649029254913
,
-
0.16053077578544617
,
-
0.15685945749282837
,
-
0.15298527479171753
,
-
0.1493264138698578
,
-
0.14566898345947266
,
-
0.14188314974308014
,
-
0.13819937407970428
,
-
0.1344561129808426
,
-
0.1306886374950409
,
-
0.1271020770072937
,
-
0.12346585839986801
,
-
0.11981867253780365
,
-
0.11614970862865448
,
-
0.11256207525730133
,
-
0.10889036953449249
,
-
0.10525048524141312
,
-
0.1016591489315033
,
-
0.09824034571647644
,
-
0.09469068050384521
,
-
0.0911419615149498
,
-
0.08773849159479141
,
-
0.08416644483804703
,
-
0.08071305602788925
,
-
0.07720902562141418
,
-
0.07371306419372559
,
-
0.07019119709730148
,
-
0.06673648208379745
,
-
0.06329209357500076
,
-
0.059800852090120316
,
-
0.0564190037548542
,
-
0.05296570807695389
,
-
0.049522045999765396
,
-
0.04609023034572601
,
-
0.04262964054942131
,
-
0.039246633648872375
,
-
0.03577171266078949
,
-
0.03236335143446922
,
-
0.028855687007308006
,
-
0.02542758360505104
,
-
0.022069433704018593
,
-
0.018754752352833748
,
-
0.015386369079351425
,
-
0.01194947212934494
,
-
0.008439815603196621
,
-
0.004995611496269703
,
-
0.0016682245768606663
,
0.0
,
0.0015510577941313386
,
0.005062474869191647
,
0.008417150937020779
,
0.011741090565919876
,
0.015184164978563786
,
0.018582714721560478
,
0.02204744517803192
,
0.025471193715929985
,
0.02889077737927437
,
0.0323684960603714
,
0.03579240292310715
,
0.039281025528907776
,
0.0427563451230526
,
0.04619763046503067
,
0.04968220740556717
,
0.05326594039797783
,
0.05679265409708023
,
0.060245808213949203
,
0.06372645497322083
,
0.06721872836351395
,
0.0706876739859581
,
0.0742349922657013
,
0.07774098962545395
,
0.08123527467250824
,
0.08468879014253616
,
0.08810535818338394
,
0.09155989438295364
,
0.09498448669910431
,
0.0985206812620163
,
0.10206405073404312
,
0.10563778132200241
,
0.10921968519687653
,
0.11284469068050385
,
0.11653254181146622
,
0.12008969485759735
,
0.12368203699588776
,
0.1272617131471634
,
0.13089501857757568
,
0.134552001953125
,
0.1382799744606018
,
0.14194637537002563
,
0.14563234150409698
,
0.14930322766304016
,
0.15303383767604828
,
0.1567956507205963
,
0.16050070524215698
,
0.16431072354316711
,
0.16813558340072632
,
0.17204202711582184
,
0.1758781224489212
,
0.17973239719867706
,
0.1836014688014984
,
0.18753431737422943
,
0.19138391315937042
,
0.19535475969314575
,
0.19931404292583466
,
0.20333819091320038
,
0.20738255977630615
,
0.21152682602405548
,
0.21568812429904938
,
0.21978361904621124
,
0.22393859922885895
,
0.22814159095287323
,
0.23241068422794342
,
0.23675410449504852
,
0.24123944342136383
,
0.24569889903068542
,
0.2500703036785126
,
0.25904011726379395
,
0.26349544525146484
,
0.2682226300239563
,
0.272907555103302
,
0.2774306833744049
,
0.28220856189727783
,
0.2869136929512024
,
0.2916390895843506
,
0.29649388790130615
,
0.30142995715141296
,
0.3065022826194763
,
0.3114383816719055
,
0.31648796796798706
,
0.3216581642627716
,
0.32700115442276
,
0.3322487473487854
,
0.33778008818626404
,
0.3431521952152252
,
0.3487405776977539
,
0.3543166518211365
,
0.3601346015930176
,
0.36605337262153625
,
0.37217751145362854
,
0.378179669380188
,
0.3843980133533478
,
0.3906566798686981
,
0.39714935421943665
,
0.40357843041419983
,
0.4104187488555908
,
0.4171563684940338
,
0.42418959736824036
,
0.43136918544769287
,
0.4389212429523468
,
0.44673123955726624
,
0.45457619428634644
,
0.4627031683921814
,
0.47130417823791504
,
0.4798591434955597
,
0.48897242546081543
,
0.4979848861694336
,
0.5
,
0.5076631307601929
,
0.5177803635597229
,
0.5282770991325378
,
0.5392990112304688
,
0.5506287813186646
,
0.5632893443107605
,
0.5764452815055847
,
0.5903191566467285
,
0.6051878333091736
,
0.6209936141967773
,
0.6382884979248047
,
0.6573970913887024
,
0.6795773506164551
,
0.7037051916122437
,
0.7327037453651428
,
0.7677436470985413
,
0.8111193776130676
,
0.875165581703186
,
1.0
]
optimal_normal
=
[
-
0.9939730167388916
,
-
0.8727636337280273
,
-
0.8097418546676636
,
-
0.7660024166107178
,
-
0.7318882346153259
,
-
0.6793879270553589
,
-
0.657649040222168
,
-
0.6385974884033203
,
-
0.6211113333702087
,
-
0.5901028513908386
,
-
0.5762918591499329
,
-
0.5630806684494019
,
-
0.5509274005889893
,
-
0.5394591689109802
,
-
0.5283197164535522
,
-
0.517780065536499
,
-
0.5074946284294128
,
-
0.4980469048023224
,
-
0.48867011070251465
,
-
0.48003149032592773
,
-
0.47125306725502014
,
-
0.4629971981048584
,
-
0.4547359049320221
,
-
0.446626216173172
,
-
0.43902668356895447
,
-
0.43158355355262756
,
-
0.4244747757911682
,
-
0.4173796474933624
,
-
0.41038978099823
,
-
0.4055633544921875
,
-
0.4035947024822235
,
-
0.39701032638549805
,
-
0.39057496190071106
,
-
0.38439232110977173
,
-
0.3782760500907898
,
-
0.3721940815448761
,
-
0.3661896586418152
,
-
0.3604033589363098
,
-
0.354605108499527
,
-
0.34892538189888
,
-
0.34320303797721863
,
-
0.3376772701740265
,
-
0.3323028087615967
,
-
0.3269782066345215
,
-
0.32166096568107605
,
-
0.316457599401474
,
-
0.3112771809101105
,
-
0.3061025142669678
,
-
0.30106794834136963
,
-
0.2961243987083435
,
-
0.2912728488445282
,
-
0.28644347190856934
,
-
0.28165507316589355
,
-
0.2769731283187866
,
-
0.2722635865211487
,
-
0.26779335737228394
,
-
0.26314786076545715
,
-
0.2586647868156433
,
-
0.2541804611682892
,
-
0.2496625930070877
,
-
0.24527113139629364
,
-
0.24097171425819397
,
-
0.23659978806972504
,
-
0.23218469321727753
,
-
0.22799566388130188
,
-
0.22380566596984863
,
-
0.21965542435646057
,
-
0.2154538631439209
,
-
0.2113603949546814
,
-
0.20735277235507965
,
-
0.20334717631340027
,
-
0.19932441413402557
,
-
0.19530178606510162
,
-
0.19136647880077362
,
-
0.18736697733402252
,
-
0.18337111175060272
,
-
0.17951400578022003
,
-
0.1757056713104248
,
-
0.17182783782482147
,
-
0.1680615097284317
,
-
0.16431649029254913
,
-
0.16053077578544617
,
-
0.15685945749282837
,
-
0.15298527479171753
,
-
0.1493264138698578
,
-
0.14566898345947266
,
-
0.14188314974308014
,
-
0.13819937407970428
,
-
0.1344561129808426
,
-
0.1306886374950409
,
-
0.1271020770072937
,
-
0.12346585839986801
,
-
0.11981867253780365
,
-
0.11614970862865448
,
-
0.11256207525730133
,
-
0.10889036953449249
,
-
0.10525048524141312
,
-
0.1016591489315033
,
-
0.09824034571647644
,
-
0.09469068050384521
,
-
0.0911419615149498
,
-
0.08773849159479141
,
-
0.08416644483804703
,
-
0.08071305602788925
,
-
0.07720902562141418
,
-
0.07371306419372559
,
-
0.07019119709730148
,
-
0.06673648208379745
,
-
0.06329209357500076
,
-
0.059800852090120316
,
-
0.0564190037548542
,
-
0.05296570807695389
,
-
0.049522045999765396
,
-
0.04609023034572601
,
-
0.04262964054942131
,
-
0.039246633648872375
,
-
0.03577171266078949
,
-
0.03236335143446922
,
-
0.028855687007308006
,
-
0.02542758360505104
,
-
0.022069433704018593
,
-
0.018754752352833748
,
-
0.015386369079351425
,
-
0.01194947212934494
,
-
0.008439815603196621
,
-
0.004995611496269703
,
-
0.0016682245768606663
,
0.0
,
0.0015510577941313386
,
0.005062474869191647
,
0.008417150937020779
,
0.011741090565919876
,
0.015184164978563786
,
0.018582714721560478
,
0.02204744517803192
,
0.025471193715929985
,
0.02889077737927437
,
0.0323684960603714
,
0.03579240292310715
,
0.039281025528907776
,
0.0427563451230526
,
0.04619763046503067
,
0.04968220740556717
,
0.05326594039797783
,
0.05679265409708023
,
0.060245808213949203
,
0.06372645497322083
,
0.06721872836351395
,
0.0706876739859581
,
0.0742349922657013
,
0.07774098962545395
,
0.08123527467250824
,
0.08468879014253616
,
0.08810535818338394
,
0.09155989438295364
,
0.09498448669910431
,
0.0985206812620163
,
0.10206405073404312
,
0.10563778132200241
,
0.10921968519687653
,
0.11284469068050385
,
0.11653254181146622
,
0.12008969485759735
,
0.12368203699588776
,
0.1272617131471634
,
0.13089501857757568
,
0.134552001953125
,
0.1382799744606018
,
0.14194637537002563
,
0.14563234150409698
,
0.14930322766304016
,
0.15303383767604828
,
0.1567956507205963
,
0.16050070524215698
,
0.16431072354316711
,
0.16813558340072632
,
0.17204202711582184
,
0.1758781224489212
,
0.17973239719867706
,
0.1836014688014984
,
0.18753431737422943
,
0.19138391315937042
,
0.19535475969314575
,
0.19931404292583466
,
0.20333819091320038
,
0.20738255977630615
,
0.21152682602405548
,
0.21568812429904938
,
0.21978361904621124
,
0.22393859922885895
,
0.22814159095287323
,
0.23241068422794342
,
0.23675410449504852
,
0.24123944342136383
,
0.24569889903068542
,
0.2500703036785126
,
0.25904011726379395
,
0.26349544525146484
,
0.2682226300239563
,
0.272907555103302
,
0.2774306833744049
,
0.28220856189727783
,
0.2869136929512024
,
0.2916390895843506
,
0.29649388790130615
,
0.30142995715141296
,
0.3065022826194763
,
0.3114383816719055
,
0.31648796796798706
,
0.3216581642627716
,
0.32700115442276
,
0.3322487473487854
,
0.33778008818626404
,
0.3431521952152252
,
0.3487405776977539
,
0.3543166518211365
,
0.3601346015930176
,
0.36605337262153625
,
0.37217751145362854
,
0.378179669380188
,
0.3843980133533478
,
0.3906566798686981
,
0.39714935421943665
,
0.40357843041419983
,
0.4104187488555908
,
0.4171563684940338
,
0.42418959736824036
,
0.43136918544769287
,
0.4389212429523468
,
0.44673123955726624
,
0.45457619428634644
,
0.4627031683921814
,
0.47130417823791504
,
0.4798591434955597
,
0.48897242546081543
,
0.4979848861694336
,
0.5
,
0.5076631307601929
,
0.5177803635597229
,
0.5282770991325378
,
0.5392990112304688
,
0.5506287813186646
,
0.5632893443107605
,
0.5764452815055847
,
0.5903191566467285
,
0.6051878333091736
,
0.6209936141967773
,
0.6382884979248047
,
0.6573970913887024
,
0.6795773506164551
,
0.7037051916122437
,
0.7327037453651428
,
0.7677436470985413
,
0.8111193776130676
,
0.875165581703186
,
1.0
]
...
...
bitsandbytes/optim/__init__.py
View file @
8b3c0f35
...
@@ -7,4 +7,5 @@ from .sgd import SGD, SGD8bit, SGD32bit
...
@@ -7,4 +7,5 @@ from .sgd import SGD, SGD8bit, SGD32bit
from
.lars
import
LARS
,
LARS8bit
,
LARS32bit
,
PytorchLARS
from
.lars
import
LARS
,
LARS8bit
,
LARS32bit
,
PytorchLARS
from
.lamb
import
LAMB
,
LAMB8bit
,
LAMB32bit
from
.lamb
import
LAMB
,
LAMB8bit
,
LAMB32bit
from
.rmsprop
import
RMSprop
,
RMSprop8bit
,
RMSprop32bit
from
.rmsprop
import
RMSprop
,
RMSprop8bit
,
RMSprop32bit
from
.adagrad
import
Adagrad
,
Adagrad8bit
,
Adagrad32bit
from
.optimizer
import
GlobalOptimManager
from
.optimizer
import
GlobalOptimManager
bitsandbytes/optim/adagrad.py
0 → 100644
View file @
8b3c0f35
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
torch
from
bitsandbytes.optim.optimizer
import
Optimizer1State
torch
.
optim
.
Adagrad
class Adagrad(Optimizer1State):
    """Adagrad optimizer implemented on top of ``Optimizer1State``.

    The constructor validates the hyperparameters and then hands them to
    ``Optimizer1State.__init__`` together with the optimizer name
    ``'adagrad'`` and a zeroed pair ``(0.0, 0.0)`` in the betas slot.

    Args:
        params: parameters to optimize (forwarded to the base class).
        lr: learning rate; must be ``>= 0.0``.
        lr_decay: must be ``0``; learning-rate decay is not supported.
        weight_decay: weight decay; must be ``>= 0.0``.
        initial_accumulator_value: must be ``0``; other values are not
            supported.
        eps: epsilon term; must be ``>= 0.0``.
        optim_bits: forwarded unchanged to the base class.
        args, min_8bit_size, percentile_clipping, block_wise: forwarded
            unchanged to the base class.

    Raises:
        ValueError: if any of the constraints above is violated.
    """

    def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0,
                 initial_accumulator_value=0, eps=1e-10, optim_bits=32,
                 args=None, min_8bit_size=4096, percentile_clipping=100,
                 block_wise=True):
        # (error template, value) pairs; checked in this order, each value
        # has to satisfy ``0.0 <= value``.
        non_negative_checks = (
            ("Invalid learning rate: {}", lr),
            ("Invalid weight_decay value: {}", weight_decay),
            ("Invalid epsilon value: {}", eps),
        )
        for template, value in non_negative_checks:
            if not 0.0 <= value:
                raise ValueError(template.format(value))

        # Features of the reference Adagrad that this implementation rejects.
        if initial_accumulator_value != 0.0:
            raise ValueError('Initial accumulator value != 0.0 not supported!')
        if lr_decay != 0.0:
            raise ValueError('Lr Decay != 0.0 not supported!')

        zero_betas = (0.0, 0.0)
        super(Adagrad, self).__init__(
            'adagrad', params, lr, zero_betas, eps, weight_decay,
            optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
class Adagrad8bit(Optimizer1State):
    """Adagrad variant that always requests 8-bit optimizer state.

    The constructor validates the hyperparameters and then calls
    ``Optimizer1State.__init__`` with the optimizer name ``'adagrad'``, a
    zeroed pair ``(0.0, 0.0)`` in the betas slot and a hard-coded bit width
    of ``8``.

    Args:
        params: parameters to optimize (forwarded to the base class).
        lr: learning rate; must be ``>= 0.0``.
        lr_decay: must be ``0``; learning-rate decay is not supported.
        weight_decay: weight decay; must be ``>= 0.0``.
        initial_accumulator_value: must be ``0``; other values are not
            supported.
        eps: epsilon term; must be ``>= 0.0``.
        optim_bits: accepted but ignored; ``8`` is always passed to the base
            class (presumably kept for signature parity with ``Adagrad``).
        args, min_8bit_size, percentile_clipping: forwarded unchanged to the
            base class.
        block_wise: must be truthy; forwarded to the base class.

    Raises:
        ValueError: if any of the constraints above is violated.
    """

    def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0,
                 initial_accumulator_value=0, eps=1e-10, optim_bits=8,
                 args=None, min_8bit_size=4096, percentile_clipping=100,
                 block_wise=True):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= weight_decay:
            raise ValueError(
                "Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if initial_accumulator_value != 0.0:
            raise ValueError('Initial accumulator value != 0.0 not supported!')
        if lr_decay != 0.0:
            raise ValueError('Lr Decay != 0.0 not supported!')
        # Explicit check instead of ``assert block_wise``: assert statements
        # are removed when Python runs with -O, which would let an
        # unsupported configuration through silently.
        if not block_wise:
            raise ValueError('Adagrad8bit requires block_wise=True!')
        super(Adagrad8bit, self).__init__(
            'adagrad', params, lr, (0.0, 0.0), eps, weight_decay, 8, args,
            min_8bit_size, percentile_clipping, block_wise)
class Adagrad32bit(Optimizer1State):
    """Adagrad variant that always requests 32-bit optimizer state.

    After validating the hyperparameters, the constructor calls
    ``Optimizer1State.__init__`` with the optimizer name ``'adagrad'``, a
    zeroed pair ``(0.0, 0.0)`` in the betas slot and a hard-coded bit width
    of ``32``.

    Args:
        params: parameters to optimize (forwarded to the base class).
        lr: learning rate; must be ``>= 0.0``.
        lr_decay: must be ``0``; learning-rate decay is not supported.
        weight_decay: weight decay; must be ``>= 0.0``.
        initial_accumulator_value: must be ``0``; other values are not
            supported.
        eps: epsilon term; must be ``>= 0.0``.
        optim_bits: accepted but ignored; ``32`` is always passed to the
            base class.
        args, min_8bit_size, percentile_clipping, block_wise: forwarded
            unchanged to the base class.

    Raises:
        ValueError: if any of the constraints above is violated.
    """

    def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0,
                 initial_accumulator_value=0, eps=1e-10, optim_bits=32,
                 args=None, min_8bit_size=4096, percentile_clipping=100,
                 block_wise=True):
        # Values that must satisfy ``0.0 <= value``, paired with the error
        # template to use; the order of the checks is significant.
        must_be_non_negative = [
            ("Invalid learning rate: {}", lr),
            ("Invalid weight_decay value: {}", weight_decay),
            ("Invalid epsilon value: {}", eps),
        ]
        for message, candidate in must_be_non_negative:
            if not 0.0 <= candidate:
                raise ValueError(message.format(candidate))

        # Options that exist in the signature but only accept zero.
        if initial_accumulator_value != 0.0:
            raise ValueError('Initial accumulator value != 0.0 not supported!')
        if lr_decay != 0.0:
            raise ValueError('Lr Decay != 0.0 not supported!')

        state_bits = 32  # fixed regardless of the ``optim_bits`` argument
        super(Adagrad32bit, self).__init__(
            'adagrad', params, lr, (0.0, 0.0), eps, weight_decay,
            state_bits, args, min_8bit_size, percentile_clipping, block_wise)
csrc/kernels.cu
View file @
8b3c0f35
...
@@ -790,6 +790,11 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
...
@@ -790,6 +790,11 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
s1_vals
[
j
]
=
__fdividef
((
float
)
g_vals
[
j
],
sqrtf
(
s1_vals
[
j
])
+
eps
);
// update value
s1_vals
[
j
]
=
__fdividef
((
float
)
g_vals
[
j
],
sqrtf
(
s1_vals
[
j
])
+
eps
);
// update value
s1_vals
[
j
]
=
s1_vals
[
j
]
*
s1_vals
[
j
];
// update norm
s1_vals
[
j
]
=
s1_vals
[
j
]
*
s1_vals
[
j
];
// update norm
break
;
break
;
case
ADAGRAD
:
s1_vals
[
j
]
=
s1_vals
[
j
]
+
((
float
)
g_vals
[
j
])
*
((
float
)
g_vals
[
j
]);
// state update
s1_vals
[
j
]
=
__fdividef
((
float
)
g_vals
[
j
],
sqrtf
(
s1_vals
[
j
])
+
eps
);
// update value
s1_vals
[
j
]
=
s1_vals
[
j
]
*
s1_vals
[
j
];
// update norm
break
;
}
}
}
}
...
@@ -884,6 +889,10 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
...
@@ -884,6 +889,10 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
s1_vals
[
j
]
=
s1_vals
[
j
]
*
beta1
+
((
1.0
f
-
beta1
)
*
((
float
)
g_vals
[
j
])
*
((
float
)
g_vals
[
j
]));
s1_vals
[
j
]
=
s1_vals
[
j
]
*
beta1
+
((
1.0
f
-
beta1
)
*
((
float
)
g_vals
[
j
])
*
((
float
)
g_vals
[
j
]));
p_vals
[
j
]
=
((
float
)
p_vals
[
j
])
-
update_scale
*
(
lr
*
__fdividef
((
float
)
g_vals
[
j
],
sqrtf
((
float
)
s1_vals
[
j
])
+
eps
));
p_vals
[
j
]
=
((
float
)
p_vals
[
j
])
-
update_scale
*
(
lr
*
__fdividef
((
float
)
g_vals
[
j
],
sqrtf
((
float
)
s1_vals
[
j
])
+
eps
));
break
;
break
;
case
ADAGRAD
:
s1_vals
[
j
]
=
s1_vals
[
j
]
+
((
float
)
g_vals
[
j
])
*
((
float
)
g_vals
[
j
]);
p_vals
[
j
]
=
((
float
)
p_vals
[
j
])
-
lr
*
__fdividef
((
float
)
g_vals
[
j
],
sqrtf
((
float
)
s1_vals
[
j
])
+
eps
);
break
;
}
}
}
}
}
}
...
@@ -1653,6 +1662,9 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
...
@@ -1653,6 +1662,9 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
case
RMSPROP
:
case
RMSPROP
:
s1_vals
[
j
]
=
s1_vals
[
j
]
*
beta1
+
((
1.0
f
-
beta1
)
*
(
g_val
*
g_val
));
s1_vals
[
j
]
=
s1_vals
[
j
]
*
beta1
+
((
1.0
f
-
beta1
)
*
(
g_val
*
g_val
));
break
;
break
;
case
ADAGRAD
:
s1_vals
[
j
]
=
s1_vals
[
j
]
+
(
g_val
*
g_val
);
break
;
}
}
}
}
...
@@ -1688,6 +1700,10 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
...
@@ -1688,6 +1700,10 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
g_val
=
g_vals
[
j
];
g_val
=
g_vals
[
j
];
p_vals
[
j
]
=
((
float
)
p_vals
[
j
])
-
lr
*
(
__fdividef
(
g_val
,
sqrtf
(
s1_vals
[
j
])
+
eps
));
p_vals
[
j
]
=
((
float
)
p_vals
[
j
])
-
lr
*
(
__fdividef
(
g_val
,
sqrtf
(
s1_vals
[
j
])
+
eps
));
break
;
break
;
case
ADAGRAD
:
g_val
=
g_vals
[
j
];
p_vals
[
j
]
=
((
float
)
p_vals
[
j
])
-
lr
*
(
__fdividef
(
g_val
,
sqrtf
(
s1_vals
[
j
])
+
eps
));
break
;
}
}
}
}
}
}
...
@@ -1738,6 +1754,8 @@ MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half)
...
@@ -1738,6 +1754,8 @@ MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half)
MAKE_PreconditionOptimizer32bit1State
(
MOMENTUM
,
float
)
MAKE_PreconditionOptimizer32bit1State
(
MOMENTUM
,
float
)
MAKE_PreconditionOptimizer32bit1State
(
RMSPROP
,
half
)
MAKE_PreconditionOptimizer32bit1State
(
RMSPROP
,
half
)
MAKE_PreconditionOptimizer32bit1State
(
RMSPROP
,
float
)
MAKE_PreconditionOptimizer32bit1State
(
RMSPROP
,
float
)
MAKE_PreconditionOptimizer32bit1State
(
ADAGRAD
,
half
)
MAKE_PreconditionOptimizer32bit1State
(
ADAGRAD
,
float
)
#define MAKE_Optimizer32bit1State(oname, gtype) \
#define MAKE_Optimizer32bit1State(oname, gtype) \
template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \
template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \
...
@@ -1747,6 +1765,8 @@ MAKE_Optimizer32bit1State(MOMENTUM, half)
...
@@ -1747,6 +1765,8 @@ MAKE_Optimizer32bit1State(MOMENTUM, half)
MAKE_Optimizer32bit1State
(
MOMENTUM
,
float
)
MAKE_Optimizer32bit1State
(
MOMENTUM
,
float
)
MAKE_Optimizer32bit1State
(
RMSPROP
,
half
)
MAKE_Optimizer32bit1State
(
RMSPROP
,
half
)
MAKE_Optimizer32bit1State
(
RMSPROP
,
float
)
MAKE_Optimizer32bit1State
(
RMSPROP
,
float
)
MAKE_Optimizer32bit1State
(
ADAGRAD
,
half
)
MAKE_Optimizer32bit1State
(
ADAGRAD
,
float
)
#define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \
#define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \
template __global__ void kPreconditionOptimizer32bit2State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \
template __global__ void kPreconditionOptimizer32bit2State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \
...
@@ -1862,3 +1882,5 @@ MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8)
...
@@ -1862,3 +1882,5 @@ MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise
(
MOMENTUM
,
half
,
2048
,
8
)
MAKE_OptimizerStatic8bit1StateBlockwise
(
MOMENTUM
,
half
,
2048
,
8
)
MAKE_OptimizerStatic8bit1StateBlockwise
(
RMSPROP
,
float
,
2048
,
8
)
MAKE_OptimizerStatic8bit1StateBlockwise
(
RMSPROP
,
float
,
2048
,
8
)
MAKE_OptimizerStatic8bit1StateBlockwise
(
RMSPROP
,
half
,
2048
,
8
)
MAKE_OptimizerStatic8bit1StateBlockwise
(
RMSPROP
,
half
,
2048
,
8
)
MAKE_OptimizerStatic8bit1StateBlockwise
(
ADAGRAD
,
float
,
2048
,
8
)
MAKE_OptimizerStatic8bit1StateBlockwise
(
ADAGRAD
,
half
,
2048
,
8
)
csrc/ops.cu
View file @
8b3c0f35
...
@@ -199,6 +199,8 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
...
@@ -199,6 +199,8 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
break
;
break
;
case
MOMENTUM
:
case
MOMENTUM
:
case
RMSPROP
:
case
RMSPROP
:
case
ADAGRAD
:
if
(
max_unorm
>
0.0
f
)
if
(
max_unorm
>
0.0
f
)
{
{
CUDA_CHECK_RETURN
(
cudaMemset
(
unorm
,
0
,
1
*
sizeof
(
float
)));
CUDA_CHECK_RETURN
(
cudaMemset
(
unorm
,
0
,
1
*
sizeof
(
float
)));
...
@@ -240,6 +242,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
...
@@ -240,6 +242,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
break
;
break
;
case
MOMENTUM
:
case
MOMENTUM
:
case
RMSPROP
:
case
RMSPROP
:
case
ADAGRAD
:
CUDA_CHECK_RETURN
(
cudaMemset
(
new_max1
,
0
,
1
*
sizeof
(
float
)));
CUDA_CHECK_RETURN
(
cudaMemset
(
new_max1
,
0
,
1
*
sizeof
(
float
)));
kPreconditionOptimizerStatic8bit1State
<
T
,
OPTIMIZER
><<<
blocks
,
256
>>>
(
p
,
g
,
state1
,
unorm
,
beta1
,
eps
,
step
,
quantiles1
,
max1
,
new_max1
,
weight_decay
,
gnorm_scale
,
n
);
kPreconditionOptimizerStatic8bit1State
<
T
,
OPTIMIZER
><<<
blocks
,
256
>>>
(
p
,
g
,
state1
,
unorm
,
beta1
,
eps
,
step
,
quantiles1
,
max1
,
new_max1
,
weight_decay
,
gnorm_scale
,
n
);
CUDA_CHECK_RETURN
(
cudaPeekAtLastError
());
CUDA_CHECK_RETURN
(
cudaPeekAtLastError
());
...
@@ -274,6 +277,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
...
@@ -274,6 +277,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
break
;
break
;
case
MOMENTUM
:
case
MOMENTUM
:
case
RMSPROP
:
case
RMSPROP
:
case
ADAGRAD
:
blocks
=
n
/
BLOCKSIZE_1STATE
;
blocks
=
n
/
BLOCKSIZE_1STATE
;
blocks
=
n
%
BLOCKSIZE_1STATE
==
0
?
blocks
:
blocks
+
1
;
blocks
=
n
%
BLOCKSIZE_1STATE
==
0
?
blocks
:
blocks
+
1
;
kOptimizerStatic8bit1StateBlockwise
<
T
,
OPTIMIZER
,
BLOCKSIZE_1STATE
,
NUM_1STATE
><<<
blocks
,
BLOCKSIZE_1STATE
/
NUM_1STATE
>>>
(
p
,
g
,
state1
,
beta1
,
beta2
,
eps
,
step
,
lr
,
kOptimizerStatic8bit1StateBlockwise
<
T
,
OPTIMIZER
,
BLOCKSIZE_1STATE
,
NUM_1STATE
><<<
blocks
,
BLOCKSIZE_1STATE
/
NUM_1STATE
>>>
(
p
,
g
,
state1
,
beta1
,
beta2
,
eps
,
step
,
lr
,
...
@@ -321,6 +325,8 @@ MAKE_optimizer32bit(MOMENTUM, half)
...
@@ -321,6 +325,8 @@ MAKE_optimizer32bit(MOMENTUM, half)
MAKE_optimizer32bit
(
MOMENTUM
,
float
)
MAKE_optimizer32bit
(
MOMENTUM
,
float
)
MAKE_optimizer32bit
(
RMSPROP
,
half
)
MAKE_optimizer32bit
(
RMSPROP
,
half
)
MAKE_optimizer32bit
(
RMSPROP
,
float
)
MAKE_optimizer32bit
(
RMSPROP
,
float
)
MAKE_optimizer32bit
(
ADAGRAD
,
half
)
MAKE_optimizer32bit
(
ADAGRAD
,
float
)
#define MAKE_optimizerStatic8bit(name, gtype) \
#define MAKE_optimizerStatic8bit(name, gtype) \
template void optimizerStatic8bit<gtype, name>(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
template void optimizerStatic8bit<gtype, name>(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
...
@@ -350,6 +356,8 @@ MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM);
...
@@ -350,6 +356,8 @@ MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM);
MAKE_optimizerStatic8bitBlockwise
(
float
,
MOMENTUM
);
MAKE_optimizerStatic8bitBlockwise
(
float
,
MOMENTUM
);
MAKE_optimizerStatic8bitBlockwise
(
half
,
RMSPROP
);
MAKE_optimizerStatic8bitBlockwise
(
half
,
RMSPROP
);
MAKE_optimizerStatic8bitBlockwise
(
float
,
RMSPROP
);
MAKE_optimizerStatic8bitBlockwise
(
float
,
RMSPROP
);
MAKE_optimizerStatic8bitBlockwise
(
half
,
ADAGRAD
);
MAKE_optimizerStatic8bitBlockwise
(
float
,
ADAGRAD
);
template
void
percentileClipping
(
float
*
g
,
float
*
gnorm_vec
,
int
step
,
const
int
n
);
template
void
percentileClipping
(
float
*
g
,
float
*
gnorm_vec
,
int
step
,
const
int
n
);
template
void
percentileClipping
(
half
*
g
,
float
*
gnorm_vec
,
int
step
,
const
int
n
);
template
void
percentileClipping
(
half
*
g
,
float
*
gnorm_vec
,
int
step
,
const
int
n
);
csrc/ops.cuh
View file @
8b3c0f35
...
@@ -36,6 +36,7 @@ typedef enum Optimizer_t
...
@@ -36,6 +36,7 @@ typedef enum Optimizer_t
MOMENTUM
=
1
,
MOMENTUM
=
1
,
RMSPROP
=
2
,
RMSPROP
=
2
,
LARS
=
3
,
LARS
=
3
,
ADAGRAD
=
4
,
}
Optimizer_t
;
}
Optimizer_t
;
...
...
csrc/pythonInterface.c
View file @
8b3c0f35
...
@@ -29,6 +29,8 @@ MAKE_FUNC32(adam, ADAM, float, 32)
...
@@ -29,6 +29,8 @@ MAKE_FUNC32(adam, ADAM, float, 32)
MAKE_FUNC32
(
adam
,
ADAM
,
half
,
16
)
MAKE_FUNC32
(
adam
,
ADAM
,
half
,
16
)
MAKE_FUNC32
(
rmsprop
,
RMSPROP
,
float
,
32
)
MAKE_FUNC32
(
rmsprop
,
RMSPROP
,
float
,
32
)
MAKE_FUNC32
(
rmsprop
,
RMSPROP
,
half
,
16
)
MAKE_FUNC32
(
rmsprop
,
RMSPROP
,
half
,
16
)
MAKE_FUNC32
(
adagrad
,
ADAGRAD
,
float
,
32
)
MAKE_FUNC32
(
adagrad
,
ADAGRAD
,
half
,
16
)
#define MAKE_FUNC8(fname, oname, gtype, gbits) \
#define MAKE_FUNC8(fname, oname, gtype, gbits) \
void fname##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
void fname##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
...
@@ -62,6 +64,8 @@ MAKE_BLOCKWISE8(momentum, MOMENTUM, half, 16)
...
@@ -62,6 +64,8 @@ MAKE_BLOCKWISE8(momentum, MOMENTUM, half, 16)
MAKE_BLOCKWISE8
(
momentum
,
MOMENTUM
,
float
,
32
)
MAKE_BLOCKWISE8
(
momentum
,
MOMENTUM
,
float
,
32
)
MAKE_BLOCKWISE8
(
rmsprop
,
RMSPROP
,
half
,
16
)
MAKE_BLOCKWISE8
(
rmsprop
,
RMSPROP
,
half
,
16
)
MAKE_BLOCKWISE8
(
rmsprop
,
RMSPROP
,
float
,
32
)
MAKE_BLOCKWISE8
(
rmsprop
,
RMSPROP
,
float
,
32
)
MAKE_BLOCKWISE8
(
adagrad
,
ADAGRAD
,
half
,
16
)
MAKE_BLOCKWISE8
(
adagrad
,
ADAGRAD
,
float
,
32
)
void
percentileClipping_g32
(
float
*
g
,
float
*
gnorm_vec
,
int
step
,
const
int
n
){
percentileClipping
<
float
>
(
g
,
gnorm_vec
,
step
,
n
);
}
void
percentileClipping_g32
(
float
*
g
,
float
*
gnorm_vec
,
int
step
,
const
int
n
){
percentileClipping
<
float
>
(
g
,
gnorm_vec
,
step
,
n
);
}
...
@@ -102,6 +106,8 @@ extern "C"
...
@@ -102,6 +106,8 @@ extern "C"
MAKE_CFUNC32
(
momentum
,
half
,
16
)
MAKE_CFUNC32
(
momentum
,
half
,
16
)
MAKE_CFUNC32
(
rmsprop
,
float
,
32
)
MAKE_CFUNC32
(
rmsprop
,
float
,
32
)
MAKE_CFUNC32
(
rmsprop
,
half
,
16
)
MAKE_CFUNC32
(
rmsprop
,
half
,
16
)
MAKE_CFUNC32
(
adagrad
,
float
,
32
)
MAKE_CFUNC32
(
adagrad
,
half
,
16
)
#define MAKE_CFUNC8(name, gtype, gbits) \
#define MAKE_CFUNC8(name, gtype, gbits) \
void c##name##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
void c##name##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
...
@@ -135,6 +141,8 @@ extern "C"
...
@@ -135,6 +141,8 @@ extern "C"
MAKE_CBLOCKWISE8
(
momentum
,
MOMENTUM
,
float
,
32
)
MAKE_CBLOCKWISE8
(
momentum
,
MOMENTUM
,
float
,
32
)
MAKE_CBLOCKWISE8
(
rmsprop
,
RMSPROP
,
half
,
16
)
MAKE_CBLOCKWISE8
(
rmsprop
,
RMSPROP
,
half
,
16
)
MAKE_CBLOCKWISE8
(
rmsprop
,
RMSPROP
,
float
,
32
)
MAKE_CBLOCKWISE8
(
rmsprop
,
RMSPROP
,
float
,
32
)
MAKE_CBLOCKWISE8
(
adagrad
,
ADAGRAD
,
half
,
16
)
MAKE_CBLOCKWISE8
(
adagrad
,
ADAGRAD
,
float
,
32
)
void
cpercentile_clipping_g32
(
float
*
g
,
float
*
gnorm_vec
,
int
step
,
const
int
n
){
percentileClipping_g32
(
g
,
gnorm_vec
,
step
,
n
);
}
void
cpercentile_clipping_g32
(
float
*
g
,
float
*
gnorm_vec
,
int
step
,
const
int
n
){
percentileClipping_g32
(
g
,
gnorm_vec
,
step
,
n
);
}
...
...
tests/test_optim.py
View file @
8b3c0f35
...
@@ -39,6 +39,7 @@ str2optimizers['momentum'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambd
...
@@ -39,6 +39,7 @@ str2optimizers['momentum'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambd
str2optimizers
[
'lars'
]
=
(
lambda
pxx
:
bnb
.
optim
.
PytorchLARS
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
bnb
.
optim
.
LARS
(
pxx
,
0.01
,
0.9
))
str2optimizers
[
'lars'
]
=
(
lambda
pxx
:
bnb
.
optim
.
PytorchLARS
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
bnb
.
optim
.
LARS
(
pxx
,
0.01
,
0.9
))
str2optimizers
[
'lamb'
]
=
(
lambda
pxx
:
apex
.
optimizers
.
FusedLAMB
(
pxx
,
weight_decay
=
0.0
,
max_grad_norm
=
10000.0
,
eps
=
1e-8
,
use_nvlamb
=
True
),
bnb
.
optim
.
LAMB
)
str2optimizers
[
'lamb'
]
=
(
lambda
pxx
:
apex
.
optimizers
.
FusedLAMB
(
pxx
,
weight_decay
=
0.0
,
max_grad_norm
=
10000.0
,
eps
=
1e-8
,
use_nvlamb
=
True
),
bnb
.
optim
.
LAMB
)
str2optimizers
[
'rmsprop'
]
=
(
lambda
pxx
:
torch
.
optim
.
RMSprop
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
bnb
.
optim
.
RMSprop
(
pxx
,
0.01
,
0.9
,
block_wise
=
False
))
str2optimizers
[
'rmsprop'
]
=
(
lambda
pxx
:
torch
.
optim
.
RMSprop
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
bnb
.
optim
.
RMSprop
(
pxx
,
0.01
,
0.9
,
block_wise
=
False
))
str2optimizers
[
'adagrad'
]
=
(
lambda
pxx
:
torch
.
optim
.
Adagrad
(
pxx
,
0.01
),
lambda
pxx
:
bnb
.
optim
.
Adagrad
(
pxx
,
0.01
,
block_wise
=
False
))
str2optimizers
[
'adam8bit'
]
=
(
torch
.
optim
.
Adam
,
lambda
pxx
:
bnb
.
optim
.
Adam8bit
(
pxx
,
block_wise
=
False
))
str2optimizers
[
'adam8bit'
]
=
(
torch
.
optim
.
Adam
,
lambda
pxx
:
bnb
.
optim
.
Adam8bit
(
pxx
,
block_wise
=
False
))
str2optimizers
[
'momentum8bit'
]
=
(
lambda
pxx
:
torch
.
optim
.
SGD
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
bnb
.
optim
.
SGD8bit
(
pxx
,
0.01
,
0.9
,
block_wise
=
False
))
str2optimizers
[
'momentum8bit'
]
=
(
lambda
pxx
:
torch
.
optim
.
SGD
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
bnb
.
optim
.
SGD8bit
(
pxx
,
0.01
,
0.9
,
block_wise
=
False
))
str2optimizers
[
'rmsprop8bit'
]
=
(
lambda
pxx
:
torch
.
optim
.
RMSprop
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
bnb
.
optim
.
RMSprop8bit
(
pxx
,
0.01
,
0.9
,
block_wise
=
False
))
str2optimizers
[
'rmsprop8bit'
]
=
(
lambda
pxx
:
torch
.
optim
.
RMSprop
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
bnb
.
optim
.
RMSprop8bit
(
pxx
,
0.01
,
0.9
,
block_wise
=
False
))
...
@@ -48,6 +49,7 @@ str2optimizers['lars8bit'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
...
@@ -48,6 +49,7 @@ str2optimizers['lars8bit'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
str2optimizers
[
'adam8bit_blockwise'
]
=
(
torch
.
optim
.
Adam
,
lambda
pxx
:
bnb
.
optim
.
Adam8bit
(
pxx
,
block_wise
=
True
))
str2optimizers
[
'adam8bit_blockwise'
]
=
(
torch
.
optim
.
Adam
,
lambda
pxx
:
bnb
.
optim
.
Adam8bit
(
pxx
,
block_wise
=
True
))
str2optimizers
[
'momentum8bit_blockwise'
]
=
(
lambda
pxx
:
torch
.
optim
.
SGD
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
bnb
.
optim
.
SGD8bit
(
pxx
,
0.01
,
0.9
,
block_wise
=
True
))
str2optimizers
[
'momentum8bit_blockwise'
]
=
(
lambda
pxx
:
torch
.
optim
.
SGD
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
bnb
.
optim
.
SGD8bit
(
pxx
,
0.01
,
0.9
,
block_wise
=
True
))
str2optimizers
[
'rmsprop8bit_blockwise'
]
=
(
lambda
pxx
:
torch
.
optim
.
RMSprop
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
bnb
.
optim
.
RMSprop8bit
(
pxx
,
0.01
,
0.9
,
block_wise
=
True
))
str2optimizers
[
'rmsprop8bit_blockwise'
]
=
(
lambda
pxx
:
torch
.
optim
.
RMSprop
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
bnb
.
optim
.
RMSprop8bit
(
pxx
,
0.01
,
0.9
,
block_wise
=
True
))
str2optimizers
[
'adagrad8bit_blockwise'
]
=
(
lambda
pxx
:
torch
.
optim
.
Adagrad
(
pxx
,
0.01
),
lambda
pxx
:
bnb
.
optim
.
Adagrad8bit
(
pxx
,
0.01
,
block_wise
=
True
))
str2statenames
=
{}
str2statenames
=
{}
str2statenames
[
'adam'
]
=
[(
'exp_avg'
,
'state1'
),
(
'exp_avg_sq'
,
'state2'
)]
str2statenames
[
'adam'
]
=
[(
'exp_avg'
,
'state1'
),
(
'exp_avg_sq'
,
'state2'
)]
...
@@ -55,6 +57,7 @@ str2statenames['momentum'] = [('momentum_buffer', 'state1')]
...
@@ -55,6 +57,7 @@ str2statenames['momentum'] = [('momentum_buffer', 'state1')]
str2statenames
[
'lars'
]
=
[(
'momentum_buffer'
,
'state1'
)]
str2statenames
[
'lars'
]
=
[(
'momentum_buffer'
,
'state1'
)]
str2statenames
[
'lamb'
]
=
[(
'exp_avg'
,
'state1'
),
(
'exp_avg_sq'
,
'state2'
)]
str2statenames
[
'lamb'
]
=
[(
'exp_avg'
,
'state1'
),
(
'exp_avg_sq'
,
'state2'
)]
str2statenames
[
'rmsprop'
]
=
[(
'square_avg'
,
'state1'
)]
str2statenames
[
'rmsprop'
]
=
[(
'square_avg'
,
'state1'
)]
str2statenames
[
'adagrad'
]
=
[(
'sum'
,
'state1'
)]
str2statenames
[
'adam8bit'
]
=
[(
'exp_avg'
,
'state1'
,
'qmap1'
,
'max1'
),
(
'exp_avg_sq'
,
'state2'
,
'qmap2'
,
'max2'
)]
str2statenames
[
'adam8bit'
]
=
[(
'exp_avg'
,
'state1'
,
'qmap1'
,
'max1'
),
(
'exp_avg_sq'
,
'state2'
,
'qmap2'
,
'max2'
)]
str2statenames
[
'lamb8bit'
]
=
[(
'exp_avg'
,
'state1'
,
'qmap1'
,
'max1'
),
(
'exp_avg_sq'
,
'state2'
,
'qmap2'
,
'max2'
)]
str2statenames
[
'lamb8bit'
]
=
[(
'exp_avg'
,
'state1'
,
'qmap1'
,
'max1'
),
(
'exp_avg_sq'
,
'state2'
,
'qmap2'
,
'max2'
)]
str2statenames
[
'adam8bit_blockwise'
]
=
[(
'exp_avg'
,
'state1'
,
'qmap1'
,
'absmax1'
),
(
'exp_avg_sq'
,
'state2'
,
'qmap2'
,
'absmax2'
)]
str2statenames
[
'adam8bit_blockwise'
]
=
[(
'exp_avg'
,
'state1'
,
'qmap1'
,
'absmax1'
),
(
'exp_avg_sq'
,
'state2'
,
'qmap2'
,
'absmax2'
)]
...
@@ -63,11 +66,12 @@ str2statenames['momentum8bit_blockwise'] = [('momentum_buffer', 'state1', 'qmap1
...
@@ -63,11 +66,12 @@ str2statenames['momentum8bit_blockwise'] = [('momentum_buffer', 'state1', 'qmap1
str2statenames
[
'lars8bit'
]
=
[(
'momentum_buffer'
,
'state1'
,
'qmap1'
,
'max1'
)]
str2statenames
[
'lars8bit'
]
=
[(
'momentum_buffer'
,
'state1'
,
'qmap1'
,
'max1'
)]
str2statenames
[
'rmsprop8bit'
]
=
[(
'square_avg'
,
'state1'
,
'qmap1'
,
'max1'
)]
str2statenames
[
'rmsprop8bit'
]
=
[(
'square_avg'
,
'state1'
,
'qmap1'
,
'max1'
)]
str2statenames
[
'rmsprop8bit_blockwise'
]
=
[(
'square_avg'
,
'state1'
,
'qmap1'
,
'absmax1'
)]
str2statenames
[
'rmsprop8bit_blockwise'
]
=
[(
'square_avg'
,
'state1'
,
'qmap1'
,
'absmax1'
)]
str2statenames
[
'adagrad8bit_blockwise'
]
=
[(
'sum'
,
'state1'
,
'qmap1'
,
'absmax1'
)]
dim1
=
[
1024
]
dim1
=
[
1024
]
dim2
=
[
32
,
1024
,
4097
,
1
]
dim2
=
[
32
,
1024
,
4097
,
1
]
gtype
=
[
torch
.
float32
,
torch
.
float16
]
gtype
=
[
torch
.
float32
,
torch
.
float16
]
optimizer_names
=
[
'adam'
,
'momentum'
,
'rmsprop'
,
'lars'
,
'lamb'
]
optimizer_names
=
[
'adam'
,
'momentum'
,
'rmsprop'
,
'lars'
,
'lamb'
,
'adagrad'
]
values
=
list
(
product
(
dim1
,
dim2
,
gtype
,
optimizer_names
))
values
=
list
(
product
(
dim1
,
dim2
,
gtype
,
optimizer_names
))
names
=
[
'dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'
.
format
(
*
vals
)
for
vals
in
values
]
names
=
[
'dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'
.
format
(
*
vals
)
for
vals
in
values
]
@
pytest
.
mark
.
parametrize
(
"dim1, dim2, gtype, optim_name"
,
values
,
ids
=
names
)
@
pytest
.
mark
.
parametrize
(
"dim1, dim2, gtype, optim_name"
,
values
,
ids
=
names
)
...
@@ -197,7 +201,7 @@ def test_global_config(dim1, dim2, gtype):
...
@@ -197,7 +201,7 @@ def test_global_config(dim1, dim2, gtype):
dim1
=
[
1024
]
dim1
=
[
1024
]
dim2
=
[
32
,
1024
,
4097
]
dim2
=
[
32
,
1024
,
4097
]
gtype
=
[
torch
.
float32
,
torch
.
float16
]
gtype
=
[
torch
.
float32
,
torch
.
float16
]
optimizer_names
=
[
'adam8bit'
,
'momentum8bit'
,
'rmsprop8bit'
,
'adam8bit_blockwise'
,
'lamb8bit'
,
'lars8bit'
,
'momentum8bit_blockwise'
,
'rmsprop8bit_blockwise'
]
optimizer_names
=
[
'adam8bit'
,
'momentum8bit'
,
'rmsprop8bit'
,
'adam8bit_blockwise'
,
'lamb8bit'
,
'lars8bit'
,
'momentum8bit_blockwise'
,
'rmsprop8bit_blockwise'
,
'adagrad8bit_blockwise'
]
values
=
list
(
product
(
dim1
,
dim2
,
gtype
,
optimizer_names
))
values
=
list
(
product
(
dim1
,
dim2
,
gtype
,
optimizer_names
))
names
=
[
'dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'
.
format
(
*
vals
)
for
vals
in
values
]
names
=
[
'dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'
.
format
(
*
vals
)
for
vals
in
values
]
@
pytest
.
mark
.
parametrize
(
"dim1, dim2, gtype, optim_name"
,
values
,
ids
=
names
)
@
pytest
.
mark
.
parametrize
(
"dim1, dim2, gtype, optim_name"
,
values
,
ids
=
names
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment