Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
2f8083bd
Commit
2f8083bd
authored
Nov 28, 2021
by
Tim Dettmers
Browse files
Added AdamW. #10 #13
parent
ca2078a6
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
54 additions
and
13 deletions
+54
-13
CHANGELOG.md
CHANGELOG.md
+4
-0
Makefile
Makefile
+10
-9
bitsandbytes/optim/__init__.py
bitsandbytes/optim/__init__.py
+1
-0
bitsandbytes/optim/adam.py
bitsandbytes/optim/adam.py
+0
-1
bitsandbytes/optim/adamw.py
bitsandbytes/optim/adamw.py
+29
-0
csrc/kernels.cu
csrc/kernels.cu
+3
-0
tests/test_optim.py
tests/test_optim.py
+7
-3
No files found.
CHANGELOG.md
View file @
2f8083bd
...
@@ -42,3 +42,7 @@ Docs:
...
@@ -42,3 +42,7 @@ Docs:
Features:
Features:
-
Added Adagrad (without grad clipping) as 32-bit and 8-bit block-wise optimizer
-
Added Adagrad (without grad clipping) as 32-bit and 8-bit block-wise optimizer
-
Added AdamW (copy of Adam with weight decay init 1e-2)
Bug fixes:
-
Fixed a bug where weight decay was incorrectly applied to 32-bit Adam
Makefile
View file @
2f8083bd
...
@@ -19,15 +19,16 @@ INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/inclu
...
@@ -19,15 +19,16 @@ INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/inclu
LIB
:=
-L
$(CUDA_HOME)
/lib64
-lcudart
-lcuda
-lcublas
-lcurand
-lcusparse
-L
$(CONDA_PREFIX)
/lib
LIB
:=
-L
$(CUDA_HOME)
/lib64
-lcudart
-lcuda
-lcublas
-lcurand
-lcusparse
-L
$(CONDA_PREFIX)
/lib
# NVIDIA NVCC compilation flags
# NVIDIA NVCC compilation flags
COMPUTE_CAPABILITY
:=
-gencode
arch
=
compute_35,code
=
sm_35
# Kepler
#COMPUTE_CAPABILITY := -gencode arch=compute_35,code=sm_35 # Kepler
COMPUTE_CAPABILITY
+=
-gencode
arch
=
compute_37,code
=
sm_37
# Kepler
#COMPUTE_CAPABILITY += -gencode arch=compute_37,code=sm_37 # Kepler
COMPUTE_CAPABILITY
+=
-gencode
arch
=
compute_50,code
=
sm_50
# Maxwell
#COMPUTE_CAPABILITY += -gencode arch=compute_50,code=sm_50 # Maxwell
COMPUTE_CAPABILITY
+=
-gencode
arch
=
compute_52,code
=
sm_52
# Maxwell
#COMPUTE_CAPABILITY += -gencode arch=compute_52,code=sm_52 # Maxwell
COMPUTE_CAPABILITY
+=
-gencode
arch
=
compute_60,code
=
sm_60
# Pascal
#COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal
COMPUTE_CAPABILITY
+=
-gencode
arch
=
compute_61,code
=
sm_61
# Pascal
#COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal
COMPUTE_CAPABILITY
+=
-gencode
arch
=
compute_70,code
=
sm_70
# Volta
#COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta
COMPUTE_CAPABILITY
+=
-gencode
arch
=
compute_72,code
=
sm_72
# Volta
#COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
COMPUTE_CAPABILITY
+=
-gencode
arch
=
compute_72,code
=
sm_72
# Volta
#COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
COMPUTE_CAPABILITY
:=
-gencode
arch
=
compute_75,code
=
sm_75
# Volta
# CUDA 9.2 supports CC 3.0, but CUDA >= 11.0 does not
# CUDA 9.2 supports CC 3.0, but CUDA >= 11.0 does not
CC_CUDA92
:=
-gencode
arch
=
compute_30,code
=
sm_30
CC_CUDA92
:=
-gencode
arch
=
compute_30,code
=
sm_30
...
...
bitsandbytes/optim/__init__.py
View file @
2f8083bd
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# LICENSE file in the root directory of this source tree.
from
.adam
import
Adam
,
Adam8bit
,
Adam32bit
from
.adam
import
Adam
,
Adam8bit
,
Adam32bit
from
.adamw
import
AdamW
,
AdamW8bit
,
AdamW32bit
from
.sgd
import
SGD
,
SGD8bit
,
SGD32bit
from
.sgd
import
SGD
,
SGD8bit
,
SGD32bit
from
.lars
import
LARS
,
LARS8bit
,
LARS32bit
,
PytorchLARS
from
.lars
import
LARS
,
LARS8bit
,
LARS32bit
,
PytorchLARS
from
.lamb
import
LAMB
,
LAMB8bit
,
LAMB32bit
from
.lamb
import
LAMB
,
LAMB8bit
,
LAMB32bit
...
...
bitsandbytes/optim/adam.py
View file @
2f8083bd
...
@@ -28,7 +28,6 @@ class Adam32bit(Optimizer2State):
...
@@ -28,7 +28,6 @@ class Adam32bit(Optimizer2State):
weight_decay
,
32
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
)
weight_decay
,
32
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
)
class
AnalysisAdam
(
torch
.
optim
.
Optimizer
):
class
AnalysisAdam
(
torch
.
optim
.
Optimizer
):
"""Adam that performs 8-bit vs 32-bit error analysis.
"""Adam that performs 8-bit vs 32-bit error analysis.
...
...
bitsandbytes/optim/adamw.py
0 → 100644
View file @
2f8083bd
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
torch
from
bitsandbytes.optim.optimizer
import
Optimizer2State
import
bitsandbytes.functional
as
F
class
AdamW
(
Optimizer2State
):
def
__init__
(
self
,
params
,
lr
=
1e-3
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-8
,
weight_decay
=
1e-2
,
amsgrad
=
False
,
optim_bits
=
32
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
):
super
(
AdamW
,
self
).
__init__
(
'adam'
,
params
,
lr
,
betas
,
eps
,
weight_decay
,
optim_bits
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
)
class
AdamW8bit
(
Optimizer2State
):
def
__init__
(
self
,
params
,
lr
=
1e-3
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-8
,
weight_decay
=
1e-2
,
amsgrad
=
False
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
):
super
(
AdamW8bit
,
self
).
__init__
(
'adam'
,
params
,
lr
,
betas
,
eps
,
weight_decay
,
8
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
)
class
AdamW32bit
(
Optimizer2State
):
def
__init__
(
self
,
params
,
lr
=
1e-3
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-8
,
weight_decay
=
1e-2
,
amsgrad
=
False
,
args
=
None
,
min_8bit_size
=
4096
,
percentile_clipping
=
100
,
block_wise
=
True
):
super
(
AdamW32bit
,
self
).
__init__
(
'adam'
,
params
,
lr
,
betas
,
eps
,
weight_decay
,
32
,
args
,
min_8bit_size
,
percentile_clipping
,
block_wise
)
csrc/kernels.cu
View file @
2f8083bd
...
@@ -720,6 +720,9 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
...
@@ -720,6 +720,9 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
s1_vals
[
j
]
=
s1_vals
[
j
]
*
beta1
+
((
1.0
f
-
beta1
)
*
((
float
)
g_vals
[
j
]));
s1_vals
[
j
]
=
s1_vals
[
j
]
*
beta1
+
((
1.0
f
-
beta1
)
*
((
float
)
g_vals
[
j
]));
s2_vals
[
j
]
=
s2_vals
[
j
]
*
beta2
+
((
1.0
f
-
beta2
)
*
(((
float
)
g_vals
[
j
])
*
((
float
)
g_vals
[
j
])));
s2_vals
[
j
]
=
s2_vals
[
j
]
*
beta2
+
((
1.0
f
-
beta2
)
*
(((
float
)
g_vals
[
j
])
*
((
float
)
g_vals
[
j
])));
p_vals
[
j
]
=
((
float
)
p_vals
[
j
])
+
(
update_scale
*
step_size
*
(
s1_vals
[
j
]
/
(
sqrtf
(
s2_vals
[
j
])
+
(
eps
*
correction2
))));
p_vals
[
j
]
=
((
float
)
p_vals
[
j
])
+
(
update_scale
*
step_size
*
(
s1_vals
[
j
]
/
(
sqrtf
(
s2_vals
[
j
])
+
(
eps
*
correction2
))));
if
(
weight_decay
>
0.0
f
)
p_vals
[
j
]
=
((
float
)
p_vals
[
j
])
*
(
1.0
f
-
(
lr
*
weight_decay
));
}
}
break
;
break
;
}
}
...
...
tests/test_optim.py
View file @
2f8083bd
...
@@ -34,6 +34,7 @@ str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx,
...
@@ -34,6 +34,7 @@ str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx,
str2optimizers
[
'lars_apex'
]
=
(
None
,
lambda
pxx
:
apex
.
parallel
.
LARC
.
LARC
(
apex
.
optimizers
.
FusedSGD
(
pxx
,
0.01
,
0.9
)),
bnb
.
optim
.
Adam
)
str2optimizers
[
'lars_apex'
]
=
(
None
,
lambda
pxx
:
apex
.
parallel
.
LARC
.
LARC
(
apex
.
optimizers
.
FusedSGD
(
pxx
,
0.01
,
0.9
)),
bnb
.
optim
.
Adam
)
str2optimizers
[
'adam'
]
=
(
torch
.
optim
.
Adam
,
bnb
.
optim
.
Adam
)
str2optimizers
[
'adam'
]
=
(
torch
.
optim
.
Adam
,
bnb
.
optim
.
Adam
)
str2optimizers
[
'adamw'
]
=
(
torch
.
optim
.
AdamW
,
bnb
.
optim
.
AdamW
)
str2optimizers
[
'fused_adam'
]
=
(
apex
.
optimizers
.
FusedAdam
,
bnb
.
optim
.
Adam
)
str2optimizers
[
'fused_adam'
]
=
(
apex
.
optimizers
.
FusedAdam
,
bnb
.
optim
.
Adam
)
str2optimizers
[
'momentum'
]
=
(
lambda
pxx
:
torch
.
optim
.
SGD
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
bnb
.
optim
.
SGD
(
pxx
,
0.01
,
0.9
,
block_wise
=
False
))
str2optimizers
[
'momentum'
]
=
(
lambda
pxx
:
torch
.
optim
.
SGD
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
bnb
.
optim
.
SGD
(
pxx
,
0.01
,
0.9
,
block_wise
=
False
))
str2optimizers
[
'lars'
]
=
(
lambda
pxx
:
bnb
.
optim
.
PytorchLARS
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
bnb
.
optim
.
LARS
(
pxx
,
0.01
,
0.9
))
str2optimizers
[
'lars'
]
=
(
lambda
pxx
:
bnb
.
optim
.
PytorchLARS
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
bnb
.
optim
.
LARS
(
pxx
,
0.01
,
0.9
))
...
@@ -47,12 +48,14 @@ str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_
...
@@ -47,12 +48,14 @@ str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_
str2optimizers
[
'lars8bit'
]
=
(
lambda
pxx
:
bnb
.
optim
.
PytorchLARS
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
bnb
.
optim
.
LARS8bit
(
pxx
,
0.01
,
0.9
))
str2optimizers
[
'lars8bit'
]
=
(
lambda
pxx
:
bnb
.
optim
.
PytorchLARS
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
bnb
.
optim
.
LARS8bit
(
pxx
,
0.01
,
0.9
))
str2optimizers
[
'adam8bit_blockwise'
]
=
(
torch
.
optim
.
Adam
,
lambda
pxx
:
bnb
.
optim
.
Adam8bit
(
pxx
,
block_wise
=
True
))
str2optimizers
[
'adam8bit_blockwise'
]
=
(
torch
.
optim
.
Adam
,
lambda
pxx
:
bnb
.
optim
.
Adam8bit
(
pxx
,
block_wise
=
True
))
str2optimizers
[
'adamw8bit_blockwise'
]
=
(
torch
.
optim
.
Adam
,
lambda
pxx
:
bnb
.
optim
.
AdamW8bit
(
pxx
,
block_wise
=
True
))
str2optimizers
[
'momentum8bit_blockwise'
]
=
(
lambda
pxx
:
torch
.
optim
.
SGD
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
bnb
.
optim
.
SGD8bit
(
pxx
,
0.01
,
0.9
,
block_wise
=
True
))
str2optimizers
[
'momentum8bit_blockwise'
]
=
(
lambda
pxx
:
torch
.
optim
.
SGD
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
bnb
.
optim
.
SGD8bit
(
pxx
,
0.01
,
0.9
,
block_wise
=
True
))
str2optimizers
[
'rmsprop8bit_blockwise'
]
=
(
lambda
pxx
:
torch
.
optim
.
RMSprop
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
bnb
.
optim
.
RMSprop8bit
(
pxx
,
0.01
,
0.9
,
block_wise
=
True
))
str2optimizers
[
'rmsprop8bit_blockwise'
]
=
(
lambda
pxx
:
torch
.
optim
.
RMSprop
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
bnb
.
optim
.
RMSprop8bit
(
pxx
,
0.01
,
0.9
,
block_wise
=
True
))
str2optimizers
[
'adagrad8bit_blockwise'
]
=
(
lambda
pxx
:
torch
.
optim
.
Adagrad
(
pxx
,
0.01
),
lambda
pxx
:
bnb
.
optim
.
Adagrad8bit
(
pxx
,
0.01
,
block_wise
=
True
))
str2optimizers
[
'adagrad8bit_blockwise'
]
=
(
lambda
pxx
:
torch
.
optim
.
Adagrad
(
pxx
,
0.01
),
lambda
pxx
:
bnb
.
optim
.
Adagrad8bit
(
pxx
,
0.01
,
block_wise
=
True
))
str2statenames
=
{}
str2statenames
=
{}
str2statenames
[
'adam'
]
=
[(
'exp_avg'
,
'state1'
),
(
'exp_avg_sq'
,
'state2'
)]
str2statenames
[
'adam'
]
=
[(
'exp_avg'
,
'state1'
),
(
'exp_avg_sq'
,
'state2'
)]
str2statenames
[
'adamw'
]
=
[(
'exp_avg'
,
'state1'
),
(
'exp_avg_sq'
,
'state2'
)]
str2statenames
[
'momentum'
]
=
[(
'momentum_buffer'
,
'state1'
)]
str2statenames
[
'momentum'
]
=
[(
'momentum_buffer'
,
'state1'
)]
str2statenames
[
'lars'
]
=
[(
'momentum_buffer'
,
'state1'
)]
str2statenames
[
'lars'
]
=
[(
'momentum_buffer'
,
'state1'
)]
str2statenames
[
'lamb'
]
=
[(
'exp_avg'
,
'state1'
),
(
'exp_avg_sq'
,
'state2'
)]
str2statenames
[
'lamb'
]
=
[(
'exp_avg'
,
'state1'
),
(
'exp_avg_sq'
,
'state2'
)]
...
@@ -61,6 +64,7 @@ str2statenames['adagrad'] = [('sum', 'state1')]
...
@@ -61,6 +64,7 @@ str2statenames['adagrad'] = [('sum', 'state1')]
str2statenames
[
'adam8bit'
]
=
[(
'exp_avg'
,
'state1'
,
'qmap1'
,
'max1'
),
(
'exp_avg_sq'
,
'state2'
,
'qmap2'
,
'max2'
)]
str2statenames
[
'adam8bit'
]
=
[(
'exp_avg'
,
'state1'
,
'qmap1'
,
'max1'
),
(
'exp_avg_sq'
,
'state2'
,
'qmap2'
,
'max2'
)]
str2statenames
[
'lamb8bit'
]
=
[(
'exp_avg'
,
'state1'
,
'qmap1'
,
'max1'
),
(
'exp_avg_sq'
,
'state2'
,
'qmap2'
,
'max2'
)]
str2statenames
[
'lamb8bit'
]
=
[(
'exp_avg'
,
'state1'
,
'qmap1'
,
'max1'
),
(
'exp_avg_sq'
,
'state2'
,
'qmap2'
,
'max2'
)]
str2statenames
[
'adam8bit_blockwise'
]
=
[(
'exp_avg'
,
'state1'
,
'qmap1'
,
'absmax1'
),
(
'exp_avg_sq'
,
'state2'
,
'qmap2'
,
'absmax2'
)]
str2statenames
[
'adam8bit_blockwise'
]
=
[(
'exp_avg'
,
'state1'
,
'qmap1'
,
'absmax1'
),
(
'exp_avg_sq'
,
'state2'
,
'qmap2'
,
'absmax2'
)]
str2statenames
[
'adamw8bit_blockwise'
]
=
[(
'exp_avg'
,
'state1'
,
'qmap1'
,
'absmax1'
),
(
'exp_avg_sq'
,
'state2'
,
'qmap2'
,
'absmax2'
)]
str2statenames
[
'momentum8bit'
]
=
[(
'momentum_buffer'
,
'state1'
,
'qmap1'
,
'max1'
)]
str2statenames
[
'momentum8bit'
]
=
[(
'momentum_buffer'
,
'state1'
,
'qmap1'
,
'max1'
)]
str2statenames
[
'momentum8bit_blockwise'
]
=
[(
'momentum_buffer'
,
'state1'
,
'qmap1'
,
'absmax1'
)]
str2statenames
[
'momentum8bit_blockwise'
]
=
[(
'momentum_buffer'
,
'state1'
,
'qmap1'
,
'absmax1'
)]
str2statenames
[
'lars8bit'
]
=
[(
'momentum_buffer'
,
'state1'
,
'qmap1'
,
'max1'
)]
str2statenames
[
'lars8bit'
]
=
[(
'momentum_buffer'
,
'state1'
,
'qmap1'
,
'max1'
)]
...
@@ -71,7 +75,7 @@ str2statenames['adagrad8bit_blockwise'] = [('sum', 'state1', 'qmap1', 'absmax1')
...
@@ -71,7 +75,7 @@ str2statenames['adagrad8bit_blockwise'] = [('sum', 'state1', 'qmap1', 'absmax1')
dim1
=
[
1024
]
dim1
=
[
1024
]
dim2
=
[
32
,
1024
,
4097
,
1
]
dim2
=
[
32
,
1024
,
4097
,
1
]
gtype
=
[
torch
.
float32
,
torch
.
float16
]
gtype
=
[
torch
.
float32
,
torch
.
float16
]
optimizer_names
=
[
'adam'
,
'momentum'
,
'rmsprop'
,
'lars'
,
'lamb'
,
'adagrad'
]
optimizer_names
=
[
'adam'
,
'adamw'
,
'momentum'
,
'rmsprop'
,
'lars'
,
'lamb'
,
'adagrad'
]
values
=
list
(
product
(
dim1
,
dim2
,
gtype
,
optimizer_names
))
values
=
list
(
product
(
dim1
,
dim2
,
gtype
,
optimizer_names
))
names
=
[
'dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'
.
format
(
*
vals
)
for
vals
in
values
]
names
=
[
'dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'
.
format
(
*
vals
)
for
vals
in
values
]
@
pytest
.
mark
.
parametrize
(
"dim1, dim2, gtype, optim_name"
,
values
,
ids
=
names
)
@
pytest
.
mark
.
parametrize
(
"dim1, dim2, gtype, optim_name"
,
values
,
ids
=
names
)
...
@@ -86,7 +90,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
...
@@ -86,7 +90,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
bnb_optimizer
=
str2optimizers
[
optim_name
][
1
]([
p2
])
bnb_optimizer
=
str2optimizers
[
optim_name
][
1
]([
p2
])
if
gtype
==
torch
.
float32
:
if
gtype
==
torch
.
float32
:
atol
,
rtol
=
1
e-6
,
1e-5
atol
,
rtol
=
2
e-6
,
1e-5
else
:
else
:
atol
,
rtol
=
1e-4
,
1e-3
atol
,
rtol
=
1e-4
,
1e-3
...
@@ -201,7 +205,7 @@ def test_global_config(dim1, dim2, gtype):
...
@@ -201,7 +205,7 @@ def test_global_config(dim1, dim2, gtype):
dim1
=
[
1024
]
dim1
=
[
1024
]
dim2
=
[
32
,
1024
,
4097
]
dim2
=
[
32
,
1024
,
4097
]
gtype
=
[
torch
.
float32
,
torch
.
float16
]
gtype
=
[
torch
.
float32
,
torch
.
float16
]
optimizer_names
=
[
'adam8bit'
,
'momentum8bit'
,
'rmsprop8bit'
,
'adam8bit_blockwise'
,
'lamb8bit'
,
'lars8bit'
,
'momentum8bit_blockwise'
,
'rmsprop8bit_blockwise'
,
'adagrad8bit_blockwise'
]
optimizer_names
=
[
'adam8bit'
,
'momentum8bit'
,
'rmsprop8bit'
,
'adam8bit_blockwise'
,
'adamw8bit_blockwise'
,
'lamb8bit'
,
'lars8bit'
,
'momentum8bit_blockwise'
,
'rmsprop8bit_blockwise'
,
'adagrad8bit_blockwise'
]
values
=
list
(
product
(
dim1
,
dim2
,
gtype
,
optimizer_names
))
values
=
list
(
product
(
dim1
,
dim2
,
gtype
,
optimizer_names
))
names
=
[
'dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'
.
format
(
*
vals
)
for
vals
in
values
]
names
=
[
'dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'
.
format
(
*
vals
)
for
vals
in
values
]
@
pytest
.
mark
.
parametrize
(
"dim1, dim2, gtype, optim_name"
,
values
,
ids
=
names
)
@
pytest
.
mark
.
parametrize
(
"dim1, dim2, gtype, optim_name"
,
values
,
ids
=
names
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment